Skip to content

Commit 71e2184

Browse files
authored
Retrain Forecast Models (#1282)
2 parents 0fb5724 + 839229f commit 71e2184

File tree

4 files changed

+43
-23
lines changed

4 files changed

+43
-23
lines changed

ads/opctl/operator/lowcode/common/utils.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -218,17 +218,12 @@ def _write_file(local_filename, remote_filename, storage_options, **kwargs):
218218

219219

220220
def load_pkl(filepath):
221-
return _safe_write(fn=_load_pkl, filepath=filepath)
222-
223-
224-
def _load_pkl(filepath):
225221
storage_options = {}
226222
if ObjectStorageDetails.is_oci_path(filepath):
227223
storage_options = default_signer()
228224

229225
with fsspec.open(filepath, "rb", **storage_options) as f:
230226
return cloudpickle.load(f)
231-
return None
232227

233228

234229
def write_pkl(obj, filename, output_dir, storage_options):

ads/opctl/operator/lowcode/forecast/model/arima.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,11 @@ def _train_model(self, i, s_id, df, model_kwargs):
8585
X_pred = self.get_horizon(data).drop(target, axis=1)
8686

8787
if self.loaded_models is not None and s_id in self.loaded_models:
88-
model = self.loaded_models[s_id]
88+
model = self.loaded_models[s_id]["model"]
89+
order = model.order
90+
seasonal_order = model.seasonal_order
91+
model = pm.ARIMA(order=order, seasonal_order=seasonal_order)
92+
model.fit(y=y, X=X_in)
8993
else:
9094
# Build and fit model
9195
model = pm.auto_arima(y=y, X=X_in, **model_kwargs)

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -142,17 +142,20 @@ def _build_model(self) -> pd.DataFrame:
142142
)
143143

144144
if self.loaded_models is not None and s_id in self.loaded_models:
145-
model = self.loaded_models[s_id]
146-
else:
147-
model = Pipeline(
148-
task="forecasting",
149-
**model_kwargs,
150-
)
151-
model.fit(
152-
X=data_i.drop(target, axis=1),
153-
y=data_i[[target]],
154-
time_budget=time_budget,
155-
)
145+
model = self.loaded_models[s_id]["model"]
146+
model_kwargs["model_list"] = [model.selected_model_]
147+
model_kwargs["search_space"]={}
148+
model_kwargs["search_space"][model.selected_model_] = model.selected_model_params_
149+
150+
model = Pipeline(
151+
task="forecasting",
152+
**model_kwargs,
153+
)
154+
model.fit(
155+
X=data_i.drop(target, axis=1),
156+
y=data_i[[target]],
157+
time_budget=time_budget,
158+
)
156159
logger.debug(f"Selected model: {model.selected_model_}")
157160
logger.debug(f"Selected model params: {model.selected_model_params_}")
158161
summary_frame = model.forecast(

ads/opctl/operator/lowcode/forecast/model/prophet.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import matplotlib as mpl
1010
import numpy as np
1111
import optuna
12+
import inspect
1213
import pandas as pd
1314
from joblib import Parallel, delayed
1415

@@ -39,6 +40,22 @@ def _add_unit(num, unit):
3940
return f"{num} {unit}"
4041

4142

43+
def _extract_parameter(model):
44+
"""
45+
extract Prophet initialization parameters
46+
"""
47+
from prophet import Prophet
48+
sig = inspect.signature(Prophet.__init__)
49+
param_names = list(sig.parameters.keys())
50+
params = {}
51+
for name in param_names:
52+
if hasattr(model, name):
53+
value = getattr(model, name)
54+
if isinstance(value, (int, float, str, bool, type(None), dict, list)):
55+
params[name] = value
56+
return params
57+
58+
4259
def _fit_model(data, params, additional_regressors):
4360
from prophet import Prophet
4461

@@ -96,16 +113,17 @@ def _train_model(self, i, series_id, df, model_kwargs):
96113
data = self.preprocess(df, series_id)
97114
data_i = self.drop_horizon(data)
98115
if self.loaded_models is not None and series_id in self.loaded_models:
99-
model = self.loaded_models[series_id]
116+
previous_model = self.loaded_models[series_id]["model"]
117+
model_kwargs.update(_extract_parameter(previous_model))
100118
else:
101119
if self.perform_tuning:
102120
model_kwargs = self.run_tuning(data_i, model_kwargs)
103121

104-
model = _fit_model(
105-
data=data,
106-
params=model_kwargs,
107-
additional_regressors=self.additional_regressors,
108-
)
122+
model = _fit_model(
123+
data=data,
124+
params=model_kwargs,
125+
additional_regressors=self.additional_regressors,
126+
)
109127

110128
# Get future df for prediction
111129
future = data.drop("y", axis=1)

0 commit comments

Comments
 (0)