Skip to content

Commit 0961b30

Browse files
committed
Refactor PyMC time series models to use xarray API
Unified the API for BayesianBasisExpansionTimeSeries and StateSpaceTimeSeries to accept xarray.DataArray inputs for X and y, with coordinates for datetime and treated units. Removed legacy numpy/datetime handling and updated internal conversion logic. Adjusted InterruptedTimeSeries and tests to use the new API, ensuring consistent handling of exogenous regressors and time indices. Improved error handling and warnings for coordinate mismatches and prediction inputs.
1 parent 5309a1f commit 0961b30

File tree

7 files changed

+565
-416
lines changed

7 files changed

+565
-416
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,5 @@ dist/
1414
docs/build/
1515
docs/jupyter_execute/
1616
docs/source/api/generated/
17+
18+
.cursor/

causalpy/experiments/interrupted_time_series.py

Lines changed: 15 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,7 @@
2727

2828
from causalpy.custom_exceptions import BadIndexException
2929
from causalpy.plot_utils import get_hdi_to_df, plot_xY
30-
from causalpy.pymc_models import (
31-
BayesianBasisExpansionTimeSeries,
32-
PyMCModel,
33-
StateSpaceTimeSeries,
34-
)
30+
from causalpy.pymc_models import PyMCModel
3531
from causalpy.utils import round_num
3632

3733
from .base import BaseExperiment
@@ -153,27 +149,15 @@ def __init__(
153149
)
154150

155151
# fit the model to the observed (pre-intervention) data
152+
# All PyMC models now accept xr.DataArray with consistent API
156153
if isinstance(self.model, PyMCModel):
157-
is_bsts_like = isinstance(
158-
self.model, (BayesianBasisExpansionTimeSeries, StateSpaceTimeSeries)
159-
)
160-
161-
if is_bsts_like:
162-
# BSTS/StateSpace models expect numpy arrays and datetime coords
163-
X_fit = self.pre_X.values if self.pre_X.shape[1] > 0 else None # type: ignore[attr-defined]
164-
y_fit = self.pre_y.isel(treated_units=0).values # type: ignore[attr-defined]
165-
pre_coords: dict[str, Any] = {"datetime_index": self.datapre.index}
166-
if X_fit is not None:
167-
pre_coords["coeffs"] = list(self.labels)
168-
self.model.fit(X=X_fit, y=y_fit, coords=pre_coords)
169-
else:
170-
# General PyMC models expect xarray with treated_units
171-
COORDS = {
172-
"coeffs": self.labels,
173-
"obs_ind": np.arange(self.pre_X.shape[0]),
174-
"treated_units": ["unit_0"],
175-
}
176-
self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
154+
COORDS: dict[str, Any] = {
155+
"coeffs": self.labels,
156+
"obs_ind": np.arange(self.pre_X.shape[0]),
157+
"treated_units": ["unit_0"],
158+
"datetime_index": self.datapre.index, # For time series models
159+
}
160+
self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
177161
elif isinstance(self.model, RegressorMixin):
178162
# For OLS models, use 1D y data
179163
self.model.fit(X=self.pre_X, y=self.pre_y.isel(treated_units=0))
@@ -182,85 +166,28 @@ def __init__(
182166

183167
# score the goodness of fit to the pre-intervention data
184168
if isinstance(self.model, PyMCModel):
185-
is_bsts_like = isinstance(
186-
self.model, (BayesianBasisExpansionTimeSeries, StateSpaceTimeSeries)
187-
)
188-
if is_bsts_like:
189-
X_score = self.pre_X.values if self.pre_X.shape[1] > 0 else None # type: ignore[attr-defined]
190-
y_score = self.pre_y.isel(treated_units=0).values # type: ignore[attr-defined]
191-
score_coords: dict[str, Any] = {"datetime_index": self.datapre.index}
192-
if X_score is not None:
193-
score_coords["coeffs"] = list(self.labels)
194-
self.score = self.model.score(X=X_score, y=y_score, coords=score_coords)
195-
else:
196-
self.score = self.model.score(X=self.pre_X, y=self.pre_y)
169+
self.score = self.model.score(X=self.pre_X, y=self.pre_y)
197170
elif isinstance(self.model, RegressorMixin):
198171
self.score = self.model.score(
199172
X=self.pre_X, y=self.pre_y.isel(treated_units=0)
200173
)
201174

202175
# get the model predictions of the observed (pre-intervention) data
203176
if isinstance(self.model, PyMCModel):
204-
is_bsts_like = isinstance(
205-
self.model, (BayesianBasisExpansionTimeSeries, StateSpaceTimeSeries)
206-
)
207-
if is_bsts_like:
208-
X_pre_predict = self.pre_X.values if self.pre_X.shape[1] > 0 else None # type: ignore[attr-defined]
209-
pre_pred_coords: dict[str, Any] = {"datetime_index": self.datapre.index}
210-
self.pre_pred = self.model.predict(
211-
X=X_pre_predict, coords=pre_pred_coords
212-
)
213-
if not isinstance(self.pre_pred, az.InferenceData):
214-
self.pre_pred = az.InferenceData(posterior_predictive=self.pre_pred)
215-
else:
216-
self.pre_pred = self.model.predict(X=self.pre_X)
177+
self.pre_pred = self.model.predict(X=self.pre_X)
217178
elif isinstance(self.model, RegressorMixin):
218179
self.pre_pred = self.model.predict(X=self.pre_X)
219180

220181
# calculate the counterfactual (post period)
221182
if isinstance(self.model, PyMCModel):
222-
is_bsts_like = isinstance(
223-
self.model, (BayesianBasisExpansionTimeSeries, StateSpaceTimeSeries)
224-
)
225-
if is_bsts_like:
226-
X_post_predict = (
227-
self.post_X.values if self.post_X.shape[1] > 0 else None # type: ignore[attr-defined]
228-
)
229-
post_pred_coords: dict[str, Any] = {
230-
"datetime_index": self.datapost.index
231-
}
232-
self.post_pred = self.model.predict(
233-
X=X_post_predict, coords=post_pred_coords, out_of_sample=True
234-
)
235-
if not isinstance(self.post_pred, az.InferenceData):
236-
self.post_pred = az.InferenceData(
237-
posterior_predictive=self.post_pred
238-
)
239-
else:
240-
self.post_pred = self.model.predict(X=self.post_X)
183+
self.post_pred = self.model.predict(X=self.post_X, out_of_sample=True)
241184
elif isinstance(self.model, RegressorMixin):
242185
self.post_pred = self.model.predict(X=self.post_X)
243186

244-
# calculate impact - use appropriate y data format for each model type
187+
# calculate impact - all PyMC models now use 2D data with treated_units
245188
if isinstance(self.model, PyMCModel):
246-
is_bsts_like = isinstance(
247-
self.model, (BayesianBasisExpansionTimeSeries, StateSpaceTimeSeries)
248-
)
249-
if is_bsts_like:
250-
pre_y_for_impact = self.pre_y.isel(treated_units=0)
251-
post_y_for_impact = self.post_y.isel(treated_units=0)
252-
self.pre_impact = self.model.calculate_impact(
253-
pre_y_for_impact, self.pre_pred
254-
)
255-
self.post_impact = self.model.calculate_impact(
256-
post_y_for_impact, self.post_pred
257-
)
258-
else:
259-
# PyMC models with treated_units use 2D data
260-
self.pre_impact = self.model.calculate_impact(self.pre_y, self.pre_pred)
261-
self.post_impact = self.model.calculate_impact(
262-
self.post_y, self.post_pred
263-
)
189+
self.pre_impact = self.model.calculate_impact(self.pre_y, self.pre_pred)
190+
self.post_impact = self.model.calculate_impact(self.post_y, self.post_pred)
264191
elif isinstance(self.model, RegressorMixin):
265192
# SKL models work with 1D data
266193
self.pre_impact = self.model.calculate_impact(

0 commit comments

Comments
 (0)