2727
2828from causalpy .custom_exceptions import BadIndexException
2929from causalpy .plot_utils import get_hdi_to_df , plot_xY
30- from causalpy .pymc_models import (
31- BayesianBasisExpansionTimeSeries ,
32- PyMCModel ,
33- StateSpaceTimeSeries ,
34- )
30+ from causalpy .pymc_models import PyMCModel
3531from causalpy .utils import round_num
3632
3733from .base import BaseExperiment
@@ -153,27 +149,15 @@ def __init__(
153149 )
154150
155151 # fit the model to the observed (pre-intervention) data
152+ # All PyMC models now accept xr.DataArray with consistent API
156153 if isinstance (self .model , PyMCModel ):
157- is_bsts_like = isinstance (
158- self .model , (BayesianBasisExpansionTimeSeries , StateSpaceTimeSeries )
159- )
160-
161- if is_bsts_like :
162- # BSTS/StateSpace models expect numpy arrays and datetime coords
163- X_fit = self .pre_X .values if self .pre_X .shape [1 ] > 0 else None # type: ignore[attr-defined]
164- y_fit = self .pre_y .isel (treated_units = 0 ).values # type: ignore[attr-defined]
165- pre_coords : dict [str , Any ] = {"datetime_index" : self .datapre .index }
166- if X_fit is not None :
167- pre_coords ["coeffs" ] = list (self .labels )
168- self .model .fit (X = X_fit , y = y_fit , coords = pre_coords )
169- else :
170- # General PyMC models expect xarray with treated_units
171- COORDS = {
172- "coeffs" : self .labels ,
173- "obs_ind" : np .arange (self .pre_X .shape [0 ]),
174- "treated_units" : ["unit_0" ],
175- }
176- self .model .fit (X = self .pre_X , y = self .pre_y , coords = COORDS )
154+ COORDS : dict [str , Any ] = {
155+ "coeffs" : self .labels ,
156+ "obs_ind" : np .arange (self .pre_X .shape [0 ]),
157+ "treated_units" : ["unit_0" ],
158+ "datetime_index" : self .datapre .index , # For time series models
159+ }
160+ self .model .fit (X = self .pre_X , y = self .pre_y , coords = COORDS )
177161 elif isinstance (self .model , RegressorMixin ):
178162 # For OLS models, use 1D y data
179163 self .model .fit (X = self .pre_X , y = self .pre_y .isel (treated_units = 0 ))
@@ -182,85 +166,28 @@ def __init__(
182166
183167 # score the goodness of fit to the pre-intervention data
184168 if isinstance (self .model , PyMCModel ):
185- is_bsts_like = isinstance (
186- self .model , (BayesianBasisExpansionTimeSeries , StateSpaceTimeSeries )
187- )
188- if is_bsts_like :
189- X_score = self .pre_X .values if self .pre_X .shape [1 ] > 0 else None # type: ignore[attr-defined]
190- y_score = self .pre_y .isel (treated_units = 0 ).values # type: ignore[attr-defined]
191- score_coords : dict [str , Any ] = {"datetime_index" : self .datapre .index }
192- if X_score is not None :
193- score_coords ["coeffs" ] = list (self .labels )
194- self .score = self .model .score (X = X_score , y = y_score , coords = score_coords )
195- else :
196- self .score = self .model .score (X = self .pre_X , y = self .pre_y )
169+ self .score = self .model .score (X = self .pre_X , y = self .pre_y )
197170 elif isinstance (self .model , RegressorMixin ):
198171 self .score = self .model .score (
199172 X = self .pre_X , y = self .pre_y .isel (treated_units = 0 )
200173 )
201174
202175 # get the model predictions of the observed (pre-intervention) data
203176 if isinstance (self .model , PyMCModel ):
204- is_bsts_like = isinstance (
205- self .model , (BayesianBasisExpansionTimeSeries , StateSpaceTimeSeries )
206- )
207- if is_bsts_like :
208- X_pre_predict = self .pre_X .values if self .pre_X .shape [1 ] > 0 else None # type: ignore[attr-defined]
209- pre_pred_coords : dict [str , Any ] = {"datetime_index" : self .datapre .index }
210- self .pre_pred = self .model .predict (
211- X = X_pre_predict , coords = pre_pred_coords
212- )
213- if not isinstance (self .pre_pred , az .InferenceData ):
214- self .pre_pred = az .InferenceData (posterior_predictive = self .pre_pred )
215- else :
216- self .pre_pred = self .model .predict (X = self .pre_X )
177+ self .pre_pred = self .model .predict (X = self .pre_X )
217178 elif isinstance (self .model , RegressorMixin ):
218179 self .pre_pred = self .model .predict (X = self .pre_X )
219180
220181 # calculate the counterfactual (post period)
221182 if isinstance (self .model , PyMCModel ):
222- is_bsts_like = isinstance (
223- self .model , (BayesianBasisExpansionTimeSeries , StateSpaceTimeSeries )
224- )
225- if is_bsts_like :
226- X_post_predict = (
227- self .post_X .values if self .post_X .shape [1 ] > 0 else None # type: ignore[attr-defined]
228- )
229- post_pred_coords : dict [str , Any ] = {
230- "datetime_index" : self .datapost .index
231- }
232- self .post_pred = self .model .predict (
233- X = X_post_predict , coords = post_pred_coords , out_of_sample = True
234- )
235- if not isinstance (self .post_pred , az .InferenceData ):
236- self .post_pred = az .InferenceData (
237- posterior_predictive = self .post_pred
238- )
239- else :
240- self .post_pred = self .model .predict (X = self .post_X )
183+ self .post_pred = self .model .predict (X = self .post_X , out_of_sample = True )
241184 elif isinstance (self .model , RegressorMixin ):
242185 self .post_pred = self .model .predict (X = self .post_X )
243186
244- # calculate impact - use appropriate y data format for each model type
187+ # calculate impact - all PyMC models now use 2D data with treated_units
245188 if isinstance (self .model , PyMCModel ):
246- is_bsts_like = isinstance (
247- self .model , (BayesianBasisExpansionTimeSeries , StateSpaceTimeSeries )
248- )
249- if is_bsts_like :
250- pre_y_for_impact = self .pre_y .isel (treated_units = 0 )
251- post_y_for_impact = self .post_y .isel (treated_units = 0 )
252- self .pre_impact = self .model .calculate_impact (
253- pre_y_for_impact , self .pre_pred
254- )
255- self .post_impact = self .model .calculate_impact (
256- post_y_for_impact , self .post_pred
257- )
258- else :
259- # PyMC models with treated_units use 2D data
260- self .pre_impact = self .model .calculate_impact (self .pre_y , self .pre_pred )
261- self .post_impact = self .model .calculate_impact (
262- self .post_y , self .post_pred
263- )
189+ self .pre_impact = self .model .calculate_impact (self .pre_y , self .pre_pred )
190+ self .post_impact = self .model .calculate_impact (self .post_y , self .post_pred )
264191 elif isinstance (self .model , RegressorMixin ):
265192 # SKL models work with 1D data
266193 self .pre_impact = self .model .calculate_impact (
0 commit comments