1818from azimuth .startup import startup_tasks
1919from azimuth .task_manager import TaskManager
2020from azimuth .types import DatasetSplitName , ModuleOptions
21- from azimuth .utils .cluster import default_cluster
2221from azimuth .utils .conversion import JSONResponseIgnoreNan
2322from azimuth .utils .logs import set_logger_config
2423from azimuth .utils .project import load_dataset_split_managers_from_config , save_config
@@ -100,9 +99,7 @@ def start_app(config_path, debug=False) -> FastAPI:
10099 if azimuth_config .dataset is None :
101100 raise ValueError ("No dataset has been specified in the config." )
102101
103- local_cluster = default_cluster (large = azimuth_config .large_dask_cluster )
104-
105- run_startup_tasks (azimuth_config , local_cluster )
102+ run_startup_tasks (azimuth_config )
106103 assert_not_none (_task_manager ).client .run (set_logger_config , level )
107104
108105 app = create_app ()
@@ -228,25 +225,23 @@ def create_app() -> FastAPI:
228225 return app
229226
230227
231- def initialize_managers (azimuth_config : AzimuthConfig , cluster : SpecCluster ):
232- """Initialize DatasetSplitManagers and TaskManagers.
233-
228+ def initialize_managers_and_config (
229+ azimuth_config : AzimuthConfig , cluster : Optional [SpecCluster ] = None
230+ ):
231+ """Initialize DatasetSplitManagers and Config.
234232
235233 Args:
236- azimuth_config: Configuration
237- cluster: Dask cluster to use.
234+ azimuth_config: Config
235+ cluster: Dask cluster to use, if different than default .
238236 """
239237 global _task_manager , _dataset_split_managers , _azimuth_config
240- _azimuth_config = azimuth_config
241- if _task_manager is not None :
242- task_history = _task_manager .current_tasks
238+ if _task_manager :
239+ _task_manager . clear_worker_cache ()
240+ _task_manager .restart ()
243241 else :
244- task_history = {}
245-
246- _task_manager = TaskManager (azimuth_config , cluster = cluster )
247-
248- _task_manager .current_tasks = task_history
242+ _task_manager = TaskManager (cluster , azimuth_config .large_dask_cluster )
249243
244+ _azimuth_config = azimuth_config
250245 _dataset_split_managers = load_dataset_split_managers_from_config (azimuth_config )
251246
252247
@@ -283,15 +278,14 @@ def run_validation_module(pipeline_index=None):
283278 task_manager .restart ()
284279
285280
286- def run_startup_tasks (azimuth_config : AzimuthConfig , cluster : SpecCluster ):
281+ def run_startup_tasks (azimuth_config : AzimuthConfig , cluster : Optional [ SpecCluster ] = None ):
287282 """Initialize managers, run validation and startup tasks.
288283
289284 Args:
290285 azimuth_config: Config
291- cluster: Cluster
292-
286+ cluster: Dask cluster to use, if different than default.
293287 """
294- initialize_managers (azimuth_config , cluster )
288+ initialize_managers_and_config (azimuth_config , cluster )
295289
296290 task_manager = assert_not_none (get_task_manager ())
297291 # Validate that everything is in order **before** the startup tasks.
@@ -303,5 +297,5 @@ def run_startup_tasks(azimuth_config: AzimuthConfig, cluster: SpecCluster):
303297 save_config (azimuth_config ) # Save only after the validation modules ran successfully
304298
305299 global _startup_tasks , _ready_flag
306- _startup_tasks = startup_tasks (_dataset_split_managers , task_manager )
300+ _startup_tasks = startup_tasks (_dataset_split_managers , task_manager , azimuth_config )
307301 _ready_flag = Event ()
0 commit comments