@@ -81,12 +81,12 @@ def logger(self) -> loguru.Logger:
8181 return loguru .logger
8282
8383
84- class MLflowService (Service ):
85- """Service for MLflow tracking and registry.
84+ class MlflowService (Service ):
85+ """Service for Mlflow tracking and registry.
8686
8787 Parameters:
88- tracking_uri (str): the URI for the MLflow tracking server.
89- registry_uri (str): the URI for the MLflow model registry.
88+ tracking_uri (str): the URI for the Mlflow tracking server.
89+ registry_uri (str): the URI for the Mlflow model registry.
9090 experiment_name (str): the name of tracking experiment.
9191 registry_name (str): the name of model registry.
9292 autolog_disable (bool): disable autologging.
@@ -96,9 +96,24 @@ class MLflowService(Service):
9696 autolog_log_model_signatures (bool): If True, logs model signatures during autologging.
9797 autolog_log_models (bool): If True, enables logging of models during autologging.
9898 autolog_log_datasets (bool): If True, logs datasets used during autologging.
99- autolog_silent (bool): If True, suppresses all MLflow warnings during autologging.
99+ autolog_silent (bool): If True, suppresses all Mlflow warnings during autologging.
100100 """
101101
102+ class RunConfig (pdt .BaseModel , strict = True , frozen = True , extra = "forbid" ):
103+ """Run configuration for Mlflow tracking.
104+
105+ Parameters:
106+ name (str): name of the run.
107+ description (str | None): description of the run.
108+ tags (dict[str, T.Any] | None): tags for the run.
109+ log_system_metrics (bool | None): enable system metrics logging.
110+ """
111+
112+ name : str
113+ description : str | None = None
114+ tags : dict [str , T .Any ] | None = None
115+ log_system_metrics : bool | None = None
116+
102117 # server uri
103118 tracking_uri : str = "./mlruns"
104119 registry_uri : str = "./mlruns"
@@ -135,31 +150,25 @@ def start(self) -> None:
135150 )
136151
137152 @ctx .contextmanager
138- def run (
139- self ,
140- name : str ,
141- description : str | None = None ,
142- tags : dict [str , T .Any ] | None = None ,
143- log_system_metrics : bool | None = None ,
144- ) -> T .Generator [mlflow .ActiveRun , None , None ]:
145- """Yield an active MLflow run and exit it afterwards.
153+ def run_context (self , run_config : RunConfig ) -> T .Generator [mlflow .ActiveRun , None , None ]:
154+ """Yield an active Mlflow run and exit it afterwards.
146155
147156 Args:
148- name (str): name of the run.
149- description (str | None, optional): description of the run. Defaults to None.
150- tags (dict[str, T.Any] | None, optional): dict of tags of the run. Defaults to None.
151- log_system_metrics (bool | None, optional): enable system metrics logging. Defaults to None.
157+ run (str): run parameters.
152158
153159 Yields:
154160 T.Generator[mlflow.ActiveRun, None, None]: active run context. Will be closed as the end of context.
155161 """
156162 with mlflow .start_run (
157- run_name = name , description = description , tags = tags , log_system_metrics = log_system_metrics
163+ run_name = run_config .name ,
164+ tags = run_config .tags ,
165+ description = run_config .description ,
166+ log_system_metrics = run_config .log_system_metrics ,
158167 ) as run :
159168 yield run
160169
161170 def client (self ) -> mt .MlflowClient :
162- """Return a new MLflow client.
171+ """Return a new Mlflow client.
163172
164173 Returns:
165174 MlflowClient: the mlflow client.
0 commit comments