Skip to content

Commit 8e72034

Browse files
committed
Remove config from task manager and stop killing it
1 parent e8849f3 commit 8e72034

22 files changed

+174
-127
lines changed

azimuth/app.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from azimuth.startup import startup_tasks
1919
from azimuth.task_manager import TaskManager
2020
from azimuth.types import DatasetSplitName, ModuleOptions
21-
from azimuth.utils.cluster import default_cluster
2221
from azimuth.utils.conversion import JSONResponseIgnoreNan
2322
from azimuth.utils.logs import set_logger_config
2423
from azimuth.utils.project import load_dataset_split_managers_from_config, save_config
@@ -100,9 +99,7 @@ def start_app(config_path, debug=False) -> FastAPI:
10099
if azimuth_config.dataset is None:
101100
raise ValueError("No dataset has been specified in the config.")
102101

103-
local_cluster = default_cluster(large=azimuth_config.large_dask_cluster)
104-
105-
run_startup_tasks(azimuth_config, local_cluster)
102+
run_startup_tasks(azimuth_config)
106103
assert_not_none(_task_manager).client.run(set_logger_config, level)
107104

108105
app = create_app()
@@ -228,25 +225,23 @@ def create_app() -> FastAPI:
228225
return app
229226

230227

231-
def initialize_managers(azimuth_config: AzimuthConfig, cluster: SpecCluster):
232-
"""Initialize DatasetSplitManagers and TaskManagers.
233-
228+
def initialize_managers_and_config(
229+
azimuth_config: AzimuthConfig, cluster: Optional[SpecCluster] = None
230+
):
231+
"""Initialize DatasetSplitManagers and Config.
234232
235233
Args:
236-
azimuth_config: Configuration
237-
cluster: Dask cluster to use.
234+
azimuth_config: Config
235+
cluster: Dask cluster to use, if different than default.
238236
"""
239237
global _task_manager, _dataset_split_managers, _azimuth_config
240-
_azimuth_config = azimuth_config
241-
if _task_manager is not None:
242-
task_history = _task_manager.current_tasks
238+
if _task_manager:
239+
_task_manager.clear_worker_cache()
240+
_task_manager.restart()
243241
else:
244-
task_history = {}
245-
246-
_task_manager = TaskManager(azimuth_config, cluster=cluster)
247-
248-
_task_manager.current_tasks = task_history
242+
_task_manager = TaskManager(cluster, azimuth_config.large_dask_cluster)
249243

244+
_azimuth_config = azimuth_config
250245
_dataset_split_managers = load_dataset_split_managers_from_config(azimuth_config)
251246

252247

@@ -283,15 +278,14 @@ def run_validation_module(pipeline_index=None):
283278
task_manager.restart()
284279

285280

286-
def run_startup_tasks(azimuth_config: AzimuthConfig, cluster: SpecCluster):
281+
def run_startup_tasks(azimuth_config: AzimuthConfig, cluster: Optional[SpecCluster] = None):
287282
"""Initialize managers, run validation and startup tasks.
288283
289284
Args:
290285
azimuth_config: Config
291-
cluster: Cluster
292-
286+
cluster: Dask cluster to use, if different than default.
293287
"""
294-
initialize_managers(azimuth_config, cluster)
288+
initialize_managers_and_config(azimuth_config, cluster)
295289

296290
task_manager = assert_not_none(get_task_manager())
297291
# Validate that everything is in order **before** the startup tasks.
@@ -303,5 +297,5 @@ def run_startup_tasks(azimuth_config: AzimuthConfig, cluster: SpecCluster):
303297
save_config(azimuth_config) # Save only after the validation modules ran successfully
304298

305299
global _startup_tasks, _ready_flag
306-
_startup_tasks = startup_tasks(_dataset_split_managers, task_manager)
300+
_startup_tasks = startup_tasks(_dataset_split_managers, task_manager, azimuth_config)
307301
_ready_flag = Event()

azimuth/routers/v1/admin.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from azimuth.app import (
1212
get_config,
1313
get_task_manager,
14-
initialize_managers,
14+
initialize_managers_and_config,
1515
require_editable_config,
1616
run_startup_tasks,
1717
)
@@ -77,11 +77,11 @@ def patch_config(
7777
) -> AzimuthConfig:
7878
try:
7979
new_config = update_config(old_config=config, partial_config=partial_config)
80-
run_startup_tasks(new_config, task_manager.cluster)
80+
run_startup_tasks(new_config)
8181
except Exception as e:
8282
log.error("Rollback config update due to error", exc_info=e)
8383
new_config = config
84-
initialize_managers(new_config, task_manager.cluster)
84+
initialize_managers_and_config(new_config)
8585
if isinstance(e, AzimuthValidationError):
8686
raise HTTPException(HTTP_400_BAD_REQUEST, detail=str(e))
8787
elif isinstance(e, ValidationError):

azimuth/routers/v1/app.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,12 @@ def get_dataset_info(
8888
get_dataset_split_manager_mapping
8989
),
9090
startup_tasks: Dict[str, Module] = Depends(get_startup_tasks),
91-
task_manager: TaskManager = Depends(get_task_manager),
9291
config: AzimuthConfig = Depends(get_config),
9392
):
9493
eval_dm = dataset_split_managers.get(DatasetSplitName.eval)
9594
training_dm = dataset_split_managers.get(DatasetSplitName.train)
9695
dm = assert_not_none(eval_dm or training_dm)
9796

98-
model_contract = task_manager.config.model_contract
99-
10097
return DatasetInfoResponse(
10198
project_name=config.name,
10299
class_names=dm.get_class_names(),
@@ -109,19 +106,16 @@ def get_dataset_info(
109106
if training_dm is not None
110107
else [],
111108
startup_tasks={k: v.status() for k, v in startup_tasks.items()},
112-
model_contract=model_contract,
113-
prediction_available=predictions_available(task_manager.config),
114-
perturbation_testing_available=perturbation_testing_available(task_manager.config),
109+
model_contract=config.model_contract,
110+
prediction_available=predictions_available(config),
111+
perturbation_testing_available=perturbation_testing_available(config),
115112
available_dataset_splits=AvailableDatasetSplits(
116113
eval=eval_dm is not None, train=training_dm is not None
117114
),
118-
similarity_available=similarity_available(task_manager.config),
115+
similarity_available=similarity_available(config),
119116
postprocessing_editable=None
120117
if config.pipelines is None
121-
else [
122-
postprocessing_editable(task_manager.config, idx)
123-
for idx in range(len(config.pipelines))
124-
],
118+
else [postprocessing_editable(config, idx) for idx in range(len(config.pipelines))],
125119
)
126120

127121

@@ -177,6 +171,7 @@ def get_perturbation_testing_summary(
177171
SupportedModule.PerturbationTestingMerged,
178172
dataset_split_name=DatasetSplitName.all,
179173
task_manager=task_manager,
174+
config=config,
180175
last_update=last_update,
181176
mod_options=ModuleOptions(pipeline_index=pipeline_index),
182177
)[0]
@@ -192,6 +187,7 @@ def get_perturbation_testing_summary(
192187
SupportedModule.PerturbationTestingSummary,
193188
dataset_split_name=DatasetSplitName.all,
194189
task_manager=task_manager,
190+
config=config,
195191
mod_options=ModuleOptions(pipeline_index=pipeline_index),
196192
)[0]
197193
return PerturbationTestingSummary(

azimuth/routers/v1/class_overlap.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def get_class_overlap_plot(
7474
SupportedModule.ClassOverlap,
7575
dataset_split_name=dataset_split_name,
7676
task_manager=task_manager,
77+
config=config,
7778
last_update=-1,
7879
)[0]
7980
class_overlap_plot_response: ClassOverlapPlotResponse = make_sankey_plot(
@@ -97,6 +98,7 @@ def get_class_overlap_plot(
9798
def get_class_overlap(
9899
dataset_split_name: DatasetSplitName,
99100
task_manager: TaskManager = Depends(get_task_manager),
101+
config: AzimuthConfig = Depends(get_config),
100102
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
101103
dataset_split_managers: Dict[DatasetSplitName, DatasetSplitManager] = Depends(
102104
get_dataset_split_manager_mapping
@@ -109,6 +111,7 @@ def get_class_overlap(
109111
SupportedModule.ClassOverlap,
110112
dataset_split_name=dataset_split_name,
111113
task_manager=task_manager,
114+
config=config,
112115
last_update=-1,
113116
)[0]
114117
dataset_class_count = class_overlap_result.s_matrix.shape[0]
@@ -124,6 +127,7 @@ def get_class_overlap(
124127
SupportedModule.ConfusionMatrix,
125128
DatasetSplitName.eval,
126129
task_manager=task_manager,
130+
config=config,
127131
mod_options=ModuleOptions(
128132
pipeline_index=pipeline_index, cf_normalize=False, cf_reorder_classes=False
129133
),

azimuth/routers/v1/custom_utterances.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,13 @@ def get_saliency(
9797
utterances: List[str] = Query([], title="Utterances"),
9898
pipeline_index: int = Depends(require_pipeline_index),
9999
task_manager: TaskManager = Depends(get_task_manager),
100+
config: AzimuthConfig = Depends(get_config),
100101
) -> List[SaliencyResponse]:
101102
task_result: List[SaliencyResponse] = get_custom_task_result(
102103
SupportedMethod.Saliency,
103104
task_manager=task_manager,
104-
custom_query={task_manager.config.columns.text_input: utterances},
105+
config=config,
106+
custom_query={config.columns.text_input: utterances},
105107
mod_options=ModuleOptions(pipeline_index=pipeline_index),
106108
)
107109

azimuth/routers/v1/dataset_warnings.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
from fastapi import APIRouter, Depends
88

9-
from azimuth.app import get_all_dataset_split_managers, get_task_manager
9+
from azimuth.app import get_all_dataset_split_managers, get_config, get_task_manager
10+
from azimuth.config import AzimuthConfig
1011
from azimuth.dataset_split_manager import DatasetSplitManager
1112
from azimuth.task_manager import TaskManager
1213
from azimuth.types import DatasetSplitName, SupportedModule
@@ -28,6 +29,7 @@
2829
)
2930
def get_dataset_warnings(
3031
task_manager: TaskManager = Depends(get_task_manager),
32+
config: AzimuthConfig = Depends(get_config),
3133
dataset_split_managers: Dict[DatasetSplitName, DatasetSplitManager] = Depends(
3234
get_all_dataset_split_managers
3335
),
@@ -42,6 +44,7 @@ def get_dataset_warnings(
4244
SupportedModule.DatasetWarnings,
4345
dataset_split_name=DatasetSplitName.all,
4446
task_manager=task_manager,
47+
config=config,
4548
last_update=last_update,
4649
)[0]
4750

azimuth/routers/v1/export.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -99,19 +99,17 @@ def get_export_perturbation_testing_summary(
9999
SupportedModule.PerturbationTestingSummary,
100100
DatasetSplitName.all,
101101
task_manager=task_manager,
102+
config=config,
102103
last_update=last_update,
103104
mod_options=ModuleOptions(pipeline_index=pipeline_index),
104105
)[0].all_tests_summary
105106

106-
cfg = task_manager.config
107107
df = pd.DataFrame.from_records([t.dict() for t in task_result])
108108
df["example"] = df["example"].apply(lambda i: i["perturbedUtterance"])
109109
file_label = time.strftime("%Y%m%d_%H%M%S", time.localtime())
110110

111-
filename = f"azimuth_export_behavioral_testing_summary_{cfg.name}_{file_label}.csv"
112-
113-
pt = pjoin(cfg.get_artifact_path(), filename)
114-
111+
filename = f"azimuth_export_behavioral_testing_summary_{config.name}_{file_label}.csv"
112+
pt = pjoin(config.get_artifact_path(), filename)
115113
df.to_csv(pt, index=False)
116114

117115
return FileResponse(path=pt, filename=filename)
@@ -135,15 +133,14 @@ def get_export_perturbed_set(
135133
) -> FileResponse:
136134
pipeline_index_not_null = assert_not_none(pipeline_index)
137135
file_label = time.strftime("%Y%m%d_%H%M%S", time.localtime())
138-
cfg = task_manager.config
139-
140-
filename = f"azimuth_export_modified_set_{cfg.name}_{dataset_split_name}_{file_label}.json"
141-
pt = pjoin(cfg.get_artifact_path(), filename)
136+
filename = f"azimuth_export_modified_set_{config.name}_{dataset_split_name}_{file_label}.json"
137+
pt = pjoin(config.get_artifact_path(), filename)
142138

143139
task_result: List[List[PerturbedUtteranceResult]] = get_standard_task_result(
144140
SupportedModule.PerturbationTesting,
145141
dataset_split_name,
146142
task_manager,
143+
config=config,
147144
mod_options=ModuleOptions(pipeline_index=pipeline_index_not_null),
148145
)
149146

azimuth/routers/v1/model_performance/confidence_histogram.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
from fastapi import APIRouter, Depends, Query
66

7-
from azimuth.app import get_dataset_split_manager, get_task_manager
7+
from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
8+
from azimuth.config import AzimuthConfig
89
from azimuth.dataset_split_manager import DatasetSplitManager
910
from azimuth.task_manager import TaskManager
1011
from azimuth.types import (
@@ -36,6 +37,7 @@ def get_confidence_histogram(
3637
dataset_split_name: DatasetSplitName,
3738
named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
3839
task_manager: TaskManager = Depends(get_task_manager),
40+
config: AzimuthConfig = Depends(get_config),
3941
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
4042
pipeline_index: int = Depends(require_pipeline_index),
4143
without_postprocessing: bool = Query(False, title="Without Postprocessing"),
@@ -50,6 +52,7 @@ def get_confidence_histogram(
5052
task_name=SupportedModule.ConfidenceHistogram,
5153
dataset_split_name=dataset_split_name,
5254
task_manager=task_manager,
55+
config=config,
5356
mod_options=mod_options,
5457
last_update=dataset_split_manager.last_update,
5558
)[0]

azimuth/routers/v1/model_performance/confusion_matrix.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
from fastapi import APIRouter, Depends, Query
66

7-
from azimuth.app import get_dataset_split_manager, get_task_manager
7+
from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
8+
from azimuth.config import AzimuthConfig
89
from azimuth.dataset_split_manager import DatasetSplitManager
910
from azimuth.task_manager import TaskManager
1011
from azimuth.types import (
@@ -36,6 +37,7 @@ def get_confusion_matrix(
3637
dataset_split_name: DatasetSplitName,
3738
named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
3839
task_manager: TaskManager = Depends(get_task_manager),
40+
config: AzimuthConfig = Depends(get_config),
3941
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
4042
pipeline_index: int = Depends(require_pipeline_index),
4143
without_postprocessing: bool = Query(False, title="Without Postprocessing"),
@@ -54,6 +56,7 @@ def get_confusion_matrix(
5456
SupportedModule.ConfusionMatrix,
5557
dataset_split_name,
5658
task_manager=task_manager,
59+
config=config,
5760
mod_options=mod_options,
5861
last_update=dataset_split_manager.last_update,
5962
)[0]

azimuth/routers/v1/model_performance/metrics.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
from fastapi import APIRouter, Depends, Query
77

8-
from azimuth.app import get_dataset_split_manager, get_task_manager
8+
from azimuth.app import get_config, get_dataset_split_manager, get_task_manager
9+
from azimuth.config import AzimuthConfig
910
from azimuth.dataset_split_manager import DatasetSplitManager
1011
from azimuth.modules.model_performance.metrics import MetricsModule
1112
from azimuth.task_manager import TaskManager
@@ -44,6 +45,7 @@ def get_metrics(
4445
dataset_split_name: DatasetSplitName,
4546
named_filters: NamedDatasetFilters = Depends(build_named_dataset_filters),
4647
task_manager: TaskManager = Depends(get_task_manager),
48+
config: AzimuthConfig = Depends(get_config),
4749
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
4850
pipeline_index: int = Depends(require_pipeline_index),
4951
without_postprocessing: bool = Query(False, title="Without Postprocessing"),
@@ -58,6 +60,7 @@ def get_metrics(
5860
SupportedModule.Metrics,
5961
dataset_split_name,
6062
task_manager,
63+
config=config,
6164
mod_options=mod_options,
6265
last_update=dataset_split_manager.last_update,
6366
)
@@ -77,6 +80,7 @@ def get_metrics(
7780
def get_metrics_per_filter(
7881
dataset_split_name: DatasetSplitName,
7982
task_manager: TaskManager = Depends(get_task_manager),
83+
config: AzimuthConfig = Depends(get_config),
8084
dataset_split_manager: DatasetSplitManager = Depends(get_dataset_split_manager),
8185
pipeline_index: int = Depends(require_pipeline_index),
8286
) -> MetricsPerFilterAPIResponse:
@@ -85,6 +89,7 @@ def get_metrics_per_filter(
8589
SupportedModule.MetricsPerFilter,
8690
dataset_split_name,
8791
task_manager,
92+
config=config,
8893
mod_options=mod_options,
8994
last_update=dataset_split_manager.last_update,
9095
)[0]
@@ -93,6 +98,7 @@ def get_metrics_per_filter(
9398
SupportedModule.Metrics,
9499
dataset_split_name,
95100
task_manager,
101+
config=config,
96102
mod_options=mod_options,
97103
last_update=dataset_split_manager.last_update,
98104
)[0]

0 commit comments

Comments
 (0)