Skip to content

Commit 20671c7

Browse files
committed
feat(structure): improve code structure (second review)
1 parent f568fbc commit 20671c7

24 files changed

+139
-134
lines changed

src/bikes/core/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def fit(self, inputs: schemas.Inputs, targets: schemas.Targets) -> "BaselineSkle
155155
@T.override
156156
def predict(self, inputs: schemas.Inputs) -> schemas.Outputs:
157157
model = self.get_internal_model()
158-
prediction = model.predict(inputs) # np.ndarray
158+
prediction = model.predict(inputs)
159159
outputs = schemas.Outputs(
160160
{schemas.OutputsSchema.prediction: prediction}, index=inputs.index
161161
)

src/bikes/io/registries.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ class Saver(abc.ABC, pdt.BaseModel, strict=True, frozen=True, extra="forbid"):
5757
e.g., to switch between serialization flavors.
5858
5959
Parameters:
60-
path (str): model path inside the MLflow store.
60+
path (str): model path inside the Mlflow store.
6161
"""
6262

6363
KIND: str
@@ -81,15 +81,15 @@ def save(
8181

8282

8383
class CustomSaver(Saver):
84-
"""Saver for project models using the MLflow PyFunc module.
84+
"""Saver for project models using the Mlflow PyFunc module.
8585
8686
https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html
8787
"""
8888

8989
KIND: T.Literal["CustomSaver"] = "CustomSaver"
9090

9191
class Adapter(mlflow.pyfunc.PythonModel): # type: ignore[misc]
92-
"""Adapt a custom model to the MLflow PyFunc flavor for saving operations.
92+
"""Adapt a custom model to the Mlflow PyFunc flavor for saving operations.
9393
9494
https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html?#mlflow.pyfunc.PythonModel
9595
"""
@@ -134,12 +134,12 @@ def save(
134134

135135

136136
class BuiltinSaver(Saver):
137-
"""Saver for built-in models using an MLflow flavor module.
137+
"""Saver for built-in models using an Mlflow flavor module.
138138
139139
https://mlflow.org/docs/latest/models.html#built-in-model-flavors
140140
141141
Parameters:
142-
flavor (str): MLflow flavor module to use for the serialization.
142+
flavor (str): Mlflow flavor module to use for the serialization.
143143
"""
144144

145145
KIND: T.Literal["BuiltinSaver"] = "BuiltinSaver"
@@ -201,7 +201,7 @@ def load(self, uri: str) -> "Loader.Adapter":
201201

202202

203203
class CustomLoader(Loader):
204-
"""Loader for custom models using the MLflow PyFunc module.
204+
"""Loader for custom models using the Mlflow PyFunc module.
205205
206206
https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html
207207
"""
@@ -233,9 +233,9 @@ def load(self, uri: str) -> "CustomLoader.Adapter":
233233

234234

235235
class BuiltinLoader(Loader):
236-
"""Loader for built-in models using the MLflow PyFunc module.
236+
"""Loader for built-in models using the Mlflow PyFunc module.
237237
238-
Note: use MLflow PyFunc instead of flavors to use standard API.
238+
Note: use Mlflow PyFunc instead of flavors to use standard API.
239239
240240
https://mlflow.org/docs/latest/models.html#built-in-model-flavors
241241
"""
@@ -298,17 +298,17 @@ def register(self, name: str, model_uri: str) -> Version:
298298
"""
299299

300300

301-
class MLflowRegister(Register):
302-
"""Register for models in the MLflow Model Registry.
301+
class MlflowRegister(Register):
302+
"""Register for models in the Mlflow Model Registry.
303303
304304
https://mlflow.org/docs/latest/model-registry.html
305305
"""
306306

307-
KIND: T.Literal["MLflowRegister"] = "MLflowRegister"
307+
KIND: T.Literal["MlflowRegister"] = "MlflowRegister"
308308

309309
@T.override
310310
def register(self, name: str, model_uri: str) -> Version:
311311
return mlflow.register_model(name=name, model_uri=model_uri, tags=self.tags)
312312

313313

314-
RegisterKind = MLflowRegister
314+
RegisterKind = MlflowRegister

src/bikes/io/services.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,12 @@ def logger(self) -> loguru.Logger:
8181
return loguru.logger
8282

8383

84-
class MLflowService(Service):
85-
"""Service for MLflow tracking and registry.
84+
class MlflowService(Service):
85+
"""Service for Mlflow tracking and registry.
8686
8787
Parameters:
88-
tracking_uri (str): the URI for the MLflow tracking server.
89-
registry_uri (str): the URI for the MLflow model registry.
88+
tracking_uri (str): the URI for the Mlflow tracking server.
89+
registry_uri (str): the URI for the Mlflow model registry.
9090
experiment_name (str): the name of tracking experiment.
9191
registry_name (str): the name of model registry.
9292
autolog_disable (bool): disable autologging.
@@ -96,9 +96,24 @@ class MLflowService(Service):
9696
autolog_log_model_signatures (bool): If True, logs model signatures during autologging.
9797
autolog_log_models (bool): If True, enables logging of models during autologging.
9898
autolog_log_datasets (bool): If True, logs datasets used during autologging.
99-
autolog_silent (bool): If True, suppresses all MLflow warnings during autologging.
99+
autolog_silent (bool): If True, suppresses all Mlflow warnings during autologging.
100100
"""
101101

102+
class RunConfig(pdt.BaseModel, strict=True, frozen=True, extra="forbid"):
103+
"""Run configuration for Mlflow tracking.
104+
105+
Parameters:
106+
name (str): name of the run.
107+
description (str | None): description of the run.
108+
tags (dict[str, T.Any] | None): tags for the run.
109+
log_system_metrics (bool | None): enable system metrics logging.
110+
"""
111+
112+
name: str
113+
description: str | None = None
114+
tags: dict[str, T.Any] | None = None
115+
log_system_metrics: bool | None = None
116+
102117
# server uri
103118
tracking_uri: str = "./mlruns"
104119
registry_uri: str = "./mlruns"
@@ -135,31 +150,25 @@ def start(self) -> None:
135150
)
136151

137152
@ctx.contextmanager
138-
def run(
139-
self,
140-
name: str,
141-
description: str | None = None,
142-
tags: dict[str, T.Any] | None = None,
143-
log_system_metrics: bool | None = None,
144-
) -> T.Generator[mlflow.ActiveRun, None, None]:
145-
"""Yield an active MLflow run and exit it afterwards.
153+
def run_context(self, run_config: RunConfig) -> T.Generator[mlflow.ActiveRun, None, None]:
154+
"""Yield an active Mlflow run and exit it afterwards.
146155
147156
Args:
148-
name (str): name of the run.
149-
description (str | None, optional): description of the run. Defaults to None.
150-
tags (dict[str, T.Any] | None, optional): dict of tags of the run. Defaults to None.
151-
log_system_metrics (bool | None, optional): enable system metrics logging. Defaults to None.
157+
run (str): run parameters.
152158
153159
Yields:
154160
T.Generator[mlflow.ActiveRun, None, None]: active run context. Will be closed as the end of context.
155161
"""
156162
with mlflow.start_run(
157-
run_name=name, description=description, tags=tags, log_system_metrics=log_system_metrics
163+
run_name=run_config.name,
164+
tags=run_config.tags,
165+
description=run_config.description,
166+
log_system_metrics=run_config.log_system_metrics,
158167
) as run:
159168
yield run
160169

161170
def client(self) -> mt.MlflowClient:
162-
"""Return a new MLflow client.
171+
"""Return a new Mlflow client.
163172
164173
Returns:
165174
MlflowClient: the mlflow client.

src/bikes/jobs/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@ class Job(abc.ABC, pdt.BaseModel, strict=True, frozen=True, extra="forbid"):
2626
2727
Parameters:
2828
logger_service (services.LoggerService): manage the logging system.
29-
mlflow_service (services.MLflowService): manage the mlflow system.
29+
mlflow_service (services.MlflowService): manage the mlflow system.
3030
"""
3131

3232
KIND: str
3333

3434
logger_service: services.LoggerService = services.LoggerService()
35-
mlflow_service: services.MLflowService = services.MLflowService()
35+
mlflow_service: services.MlflowService = services.MlflowService()
3636

3737
def __enter__(self) -> T.Self:
3838
"""Enter the job context.
@@ -43,7 +43,7 @@ def __enter__(self) -> T.Self:
4343
self.logger_service.start()
4444
logger = self.logger_service.logger()
4545
logger.debug("[START] Logger service: {}", self.logger_service)
46-
logger.debug("[START] MLflow service: {}", self.mlflow_service)
46+
logger.debug("[START] Mlflow service: {}", self.mlflow_service)
4747
self.mlflow_service.start()
4848
return self
4949

@@ -64,7 +64,7 @@ def __exit__(
6464
T.Literal[False]: always propagate exceptions.
6565
"""
6666
logger = self.logger_service.logger()
67-
logger.debug("[STOP] MLflow service: {}", self.mlflow_service)
67+
logger.debug("[STOP] Mlflow service: {}", self.mlflow_service)
6868
self.mlflow_service.stop()
6969
logger.debug("[STOP] Logger service: {}", self.logger_service)
7070
self.logger_service.stop()

src/bikes/jobs/promotion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ class PromotionJob(base.Job):
2121

2222
KIND: T.Literal["PromotionJob"] = "PromotionJob"
2323

24-
version: int | None = None
2524
alias: str = "Champion"
25+
version: int | None = None
2626

2727
@T.override
2828
def run(self) -> base.Locals:

src/bikes/jobs/training.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from bikes.core import metrics as metrics_
1010
from bikes.core import models, schemas
11-
from bikes.io import datasets, registries
11+
from bikes.io import datasets, registries, services
1212
from bikes.jobs import base
1313
from bikes.utils import signers, splitters
1414

@@ -19,9 +19,7 @@ class TrainingJob(base.Job):
1919
"""Train and register a single AI/ML model.
2020
2121
Parameters:
22-
run_name (str): name of the run.
23-
run_description (str, optional): description of the run.
24-
run_tags: (dict[str, T.Any], optional): tags for the run.
22+
run_config (services.MlflowService.RunConfig): mlflow run config.
2523
inputs (datasets.ReaderKind): reader for the inputs data.
2624
targets (datasets.ReaderKind): reader for the targets data.
2725
model (models.ModelKind): machine learning model to train.
@@ -35,9 +33,7 @@ class TrainingJob(base.Job):
3533
KIND: T.Literal["TrainingJob"] = "TrainingJob"
3634

3735
# Run
38-
run_name: str = "Tuning"
39-
run_description: str | None = None
40-
run_tags: dict[str, T.Any] | None = None
36+
run_config: services.MlflowService.RunConfig = services.MlflowService.RunConfig(name="Training")
4137
# Data
4238
inputs: datasets.ReaderKind = pdt.Field(..., discriminator="KIND")
4339
targets: datasets.ReaderKind = pdt.Field(..., discriminator="KIND")
@@ -55,7 +51,7 @@ class TrainingJob(base.Job):
5551
signer: signers.SignerKind = pdt.Field(signers.InferSigner(), discriminator="KIND")
5652
# Registrer
5753
# - avoid shadowing pydantic `register` pydantic function
58-
registry: registries.RegisterKind = pdt.Field(registries.MLflowRegister(), discriminator="KIND")
54+
registry: registries.RegisterKind = pdt.Field(registries.MlflowRegister(), discriminator="KIND")
5955

6056
@T.override
6157
def run(self) -> base.Locals:
@@ -65,9 +61,7 @@ def run(self) -> base.Locals:
6561
logger.info("With logger: {}", logger)
6662
# - mlflow
6763
client = self.mlflow_service.client()
68-
with self.mlflow_service.run(
69-
name=self.run_name, description=self.run_description, tags=self.run_tags
70-
) as run:
64+
with self.mlflow_service.run_context(run_config=self.run_config) as run:
7165
logger.info("With mlflow run id: {}", run.info.run_id)
7266
# data
7367
# - inputs

src/bikes/jobs/tuning.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pydantic as pdt
88

99
from bikes.core import metrics, models, schemas
10-
from bikes.io import datasets
10+
from bikes.io import datasets, services
1111
from bikes.jobs import base
1212
from bikes.utils import searchers, splitters
1313

@@ -18,9 +18,7 @@ class TuningJob(base.Job):
1818
"""Find the best hyperparameters for a model.
1919
2020
Parameters:
21-
run_name (str): name of the run.
22-
run_description (str, optional): description of the run.
23-
run_tags: (dict[str, T.Any], optional): tags for the run.
21+
run_config (services.MlflowService.RunConfig): mlflow run config.
2422
inputs (datasets.ReaderKind): reader for the inputs data.
2523
targets (datasets.ReaderKind): reader for the targets data.
2624
model (models.ModelKind): machine learning model to tune.
@@ -32,9 +30,7 @@ class TuningJob(base.Job):
3230
KIND: T.Literal["TuningJob"] = "TuningJob"
3331

3432
# Run
35-
run_name: str = "Tuning"
36-
run_description: str | None = None
37-
run_tags: dict[str, T.Any] | None = None
33+
run_config: services.MlflowService.RunConfig = services.MlflowService.RunConfig(name="Tuning")
3834
# Data
3935
inputs: datasets.ReaderKind = pdt.Field(..., discriminator="KIND")
4036
targets: datasets.ReaderKind = pdt.Field(..., discriminator="KIND")
@@ -64,9 +60,7 @@ def run(self) -> base.Locals:
6460
logger = self.logger_service.logger()
6561
logger.info("With logger: {}", logger)
6662
# - mlflow
67-
with self.mlflow_service.run(
68-
name=self.run_name, description=self.run_description, tags=self.run_tags
69-
) as run:
63+
with self.mlflow_service.run_context(run_config=self.run_config) as run:
7064
logger.info("With mlflow run id: {}", run.info.run_id)
7165
# data
7266
# - inputs

src/bikes/utils/searchers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,17 +60,17 @@ def search(
6060
metric (metrics.Metric): main metric to optimize.
6161
inputs (schemas.Inputs): model inputs for tuning.
6262
targets (schemas.Targets): model targets for tuning.
63-
cv (CrossValidation): structure for cross-folds strategy.
63+
cv (CrossValidation): choice for cross-fold validation.
6464
6565
Returns:
66-
Results: all the results of the searcher process.
66+
Results: all the results of the searcher execution process.
6767
"""
6868

6969

7070
class GridCVSearcher(Searcher):
7171
"""Grid searcher with cross-fold validation.
7272
73-
Metric should return higher values for better models.
73+
Convention: metric returns higher values for better models.
7474
7575
Parameters:
7676
n_jobs (int, optional): number of jobs to run in parallel.

src/bikes/utils/signers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class Signer(abc.ABC, pdt.BaseModel, strict=True, frozen=True, extra="forbid"):
2222
"""Base class for generating model signatures.
2323
2424
Allow to switch between model signing strategies.
25-
e.g., automatic inference, manual signatures, ...
25+
e.g., automatic inference, manual model signature, ...
2626
2727
https://mlflow.org/docs/latest/models.html#model-signature-and-input-example
2828
"""

src/bikes/utils/splitters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class TrainTestSplitter(Splitter):
6767
"""Split a dataframe into a train and test set.
6868
6969
Parameters:
70-
shuffle (bool): shuffle dataset before splitting it.
70+
shuffle (bool): shuffle the dataset. Default is False.
7171
test_size (int | float): number/ratio for the test set.
7272
random_state (int): random state for the splitter object.
7373
"""

0 commit comments

Comments
 (0)