Skip to content

Commit c62d3bc

Browse files
igerber and claude committed
Add input validation, pre-fit warning, and docs fixes from AI review round 2
Address 5 issues from PR #145 AI review:
- P0: Validate treatment is constant within unit (reject staggered designs)
- P1: Enforce balanced panel (all units must have all periods)
- P1: Warn when pre-treatment fit RMSE exceeds treated outcome SD
- P1: Fix Registry FW iteration count (1000 → 10000, matching R/code)
- P2: Fix misleading placebo docstring (weights use fresh start, not warm start)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 7158bf0 commit c62d3bc

File tree

3 files changed

+218
-3
lines changed

3 files changed

+218
-3
lines changed

diff_diff/synthetic_did.py

Lines changed: 50 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -267,6 +267,24 @@ def fit( # type: ignore[override]
267267
# Identify treated and control units
268268
# Treatment indicator should be constant within unit
269269
unit_treatment = data.groupby(unit)[treatment].first()
270+
271+
# Validate treatment is constant within unit (SDID requires block treatment)
272+
treatment_nunique = data.groupby(unit)[treatment].nunique()
273+
varying_units = treatment_nunique[treatment_nunique > 1]
274+
if len(varying_units) > 0:
275+
example_unit = varying_units.index[0]
276+
example_vals = sorted(
277+
data.loc[data[unit] == example_unit, treatment].unique()
278+
)
279+
raise ValueError(
280+
f"Treatment indicator varies within {len(varying_units)} unit(s) "
281+
f"(e.g., unit '{example_unit}' has values {example_vals}). "
282+
f"SyntheticDiD requires 'block' treatment where treatment is "
283+
f"constant within each unit across all time periods. "
284+
f"For staggered adoption designs, use CallawaySantAnna or "
285+
f"ImputationDiD instead."
286+
)
287+
270288
treated_units = unit_treatment[unit_treatment == 1].index.tolist()
271289
control_units = unit_treatment[unit_treatment == 0].index.tolist()
272290

@@ -275,6 +293,21 @@ def fit( # type: ignore[override]
275293
if len(control_units) == 0:
276294
raise ValueError("No control units found")
277295

296+
# Validate balanced panel (SDID requires all units observed in all periods)
297+
periods_per_unit = data.groupby(unit)[time].nunique()
298+
expected_n_periods = len(all_periods)
299+
unbalanced_units = periods_per_unit[periods_per_unit != expected_n_periods]
300+
if len(unbalanced_units) > 0:
301+
example_unit = unbalanced_units.index[0]
302+
actual_count = unbalanced_units.iloc[0]
303+
raise ValueError(
304+
f"Panel is not balanced: {len(unbalanced_units)} unit(s) do not "
305+
f"have observations in all {expected_n_periods} periods "
306+
f"(e.g., unit '{example_unit}' has {actual_count} periods). "
307+
f"SyntheticDiD requires a balanced panel. Use "
308+
f"diff_diff.prep.balance_panel() to balance the panel first."
309+
)
310+
278311
# Residualize covariates if provided
279312
working_data = data.copy()
280313
if covariates:
@@ -338,6 +371,22 @@ def fit( # type: ignore[override]
338371
synthetic_pre = Y_pre_control @ unit_weights
339372
pre_fit_rmse = np.sqrt(np.mean((Y_pre_treated_mean - synthetic_pre) ** 2))
340373

374+
# Warn if pre-treatment fit is poor (Registry requirement).
375+
# Threshold: 1× SD of treated pre-treatment outcomes — a natural baseline
376+
# since RMSE exceeding natural variation indicates the synthetic control
377+
# fails to reproduce the treated series' level or trend.
378+
pre_treatment_sd = np.std(Y_pre_treated_mean, ddof=1) if len(Y_pre_treated_mean) > 1 else 0.0
379+
if pre_treatment_sd > 0 and pre_fit_rmse > pre_treatment_sd:
380+
warnings.warn(
381+
f"Pre-treatment fit is poor: RMSE ({pre_fit_rmse:.4f}) exceeds "
382+
f"the standard deviation of treated pre-treatment outcomes "
383+
f"({pre_treatment_sd:.4f}). The synthetic control may not "
384+
f"adequately reproduce treated unit trends. Consider adding "
385+
f"more control units or adjusting regularization.",
386+
UserWarning,
387+
stacklevel=2,
388+
)
389+
341390
# Compute standard errors based on variance_method
342391
if self.variance_method == "bootstrap":
343392
se, bootstrap_estimates = self._bootstrap_se(
@@ -814,7 +863,7 @@ def _placebo_variance_se(
814863

815864
# Re-estimate weights on permuted data (matching R's behavior)
816865
# R passes update.omega=TRUE, update.lambda=TRUE via opts,
817-
# using original weights as starting points for FW optimization.
866+
# re-estimating weights from uniform initialization (fresh start).
818867
# Unit weights: re-estimate on pseudo-control/pseudo-treated data
819868
pseudo_omega = compute_sdid_unit_weights(
820869
Y_pre_pseudo_control,

docs/methodology/REGISTRY.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,7 @@ where A = Y_unit[:, :N_co], b = Y_unit[:, N_co], and centering is column-wise (i
587587
**Two-pass sparsification procedure** (matches R's `synthdid::sc.weight.fw` + `sparsify_function`):
588588
1. First pass: Run Frank-Wolfe for 100 iterations (max_iter_pre_sparsify) from uniform initialization
589589
2. Sparsify: `v[v <= max(v)/4] = 0; v = v / sum(v)` (zero out small weights, renormalize)
590-
3. Second pass: Run Frank-Wolfe for 1000 iterations (max_iter) starting from sparsified weights
590+
3. Second pass: Run Frank-Wolfe for 10000 iterations (max_iter) starting from sparsified weights
591591

592592
The sparsification step concentrates weights on the most important control units, improving interpretability and stability.
593593

@@ -659,13 +659,16 @@ Convergence criterion: stop when objective decrease < min_decrease² (default mi
659659
- **Noise level with < 2 pre-periods**: Returns 0.0, which makes both zeta_omega and zeta_lambda equal to 0.0 (no regularization).
660660
- **NaN inference for undefined statistics**: t_stat uses NaN when SE is zero or non-finite; p_value and CI also NaN. Matches CallawaySantAnna NaN convention.
661661
- **Placebo p-value floor**: `p_value = max(empirical_p, 1/(n_replications + 1))` to avoid reporting exactly zero.
662+
- **Varying treatment within unit**: Raises `ValueError`. SDID requires block treatment (constant within each unit). Suggests CallawaySantAnna or ImputationDiD for staggered adoption.
663+
- **Unbalanced panel**: Raises `ValueError`. SDID requires all units observed in all periods. Suggests `balance_panel()`.
664+
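- **Poor pre-treatment fit**: Warns (`UserWarning`) when `pre_fit_rmse > std(treated_pre_outcomes, ddof=1)`, where the standard deviation is taken across pre-treatment periods of the treated pre-treatment mean series (`Y_pre_treated_mean`). The check is skipped when that series has only one pre-treatment period or its SD is zero. Diagnostic only; estimation proceeds.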
662665

663666
**Reference implementation(s):**
664667
- R: `synthdid::synthdid_estimate()` (Arkhangelsky et al.'s official package)
665668
- Key R functions matched: `sc.weight.fw()` (Frank-Wolfe), `sparsify_function` (sparsification), `vcov.synthdid_estimate()` (variance)
666669

667670
**Requirements checklist:**
668-
- [x] Unit weights: Frank-Wolfe on collapsed form (T_pre, N_co+1), two-pass sparsification (100 iters -> sparsify -> 1000 iters)
671+
- [x] Unit weights: Frank-Wolfe on collapsed form (T_pre, N_co+1), two-pass sparsification (100 iters -> sparsify -> 10000 iters)
669672
- [x] Time weights: Frank-Wolfe on collapsed form (N_co, T_pre+1), last column = per-control post mean
670673
- [x] Unit and time weights: sum to 1, non-negative (simplex constraint)
671674
- [x] Auto-regularization: noise_level = sd(first_diffs), zeta_omega = (N1*T1)^0.25 * noise_level, zeta_lambda = 1e-6 * noise_level

tests/test_methodology_sdid.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import numpy as np
1212
import pytest
1313

14+
import pandas as pd
15+
1416
from diff_diff.synthetic_did import SyntheticDiD
1517
from diff_diff.utils import (
1618
_compute_noise_level,
@@ -730,3 +732,164 @@ def test_placebo_reestimates_weights_not_fixed(self):
730732
f"({fixed_se:.6f}), suggesting weights are NOT being "
731733
f"re-estimated as R's synthdid does."
732734
)
735+
736+
737+
# =============================================================================
738+
# Treatment Validation
739+
# =============================================================================
740+
741+
742+
class TestTreatmentValidation:
743+
"""Test that SDID rejects time-varying treatment (staggered designs)."""
744+
745+
def test_varying_treatment_within_unit_raises(self):
746+
"""Unit whose treatment switches over time should raise ValueError."""
747+
np.random.seed(42)
748+
data = pd.DataFrame({
749+
"unit": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
750+
"time": [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4],
751+
"outcome": np.random.randn(12),
752+
# Unit 1: treatment turns on at time 3 (staggered)
753+
"treated": [0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
754+
})
755+
sdid = SyntheticDiD()
756+
with pytest.raises(ValueError, match="Treatment indicator varies within"):
757+
sdid.fit(
758+
data, outcome="outcome", treatment="treated",
759+
unit="unit", time="time", post_periods=[3, 4],
760+
)
761+
762+
def test_constant_treatment_passes(self):
763+
"""Normal block-treatment data should pass validation."""
764+
np.random.seed(42)
765+
n_units, n_periods = 10, 8
766+
rows = []
767+
for u in range(n_units):
768+
is_treated = 1 if u < 3 else 0
769+
for t in range(n_periods):
770+
rows.append({
771+
"unit": u, "time": t,
772+
"outcome": np.random.randn() + (2.0 if is_treated and t >= 5 else 0),
773+
"treated": is_treated,
774+
})
775+
data = pd.DataFrame(rows)
776+
sdid = SyntheticDiD()
777+
result = sdid.fit(
778+
data, outcome="outcome", treatment="treated",
779+
unit="unit", time="time", post_periods=[5, 6, 7],
780+
)
781+
assert result is not None
782+
783+
784+
# =============================================================================
785+
# Balanced Panel Validation
786+
# =============================================================================
787+
788+
789+
class TestBalancedPanelValidation:
790+
"""Test that SDID rejects unbalanced panels."""
791+
792+
def test_unbalanced_panel_raises(self):
793+
"""Unit missing a period should raise ValueError."""
794+
np.random.seed(42)
795+
rows = []
796+
for u in range(6):
797+
is_treated = 1 if u < 2 else 0
798+
for t in range(5):
799+
rows.append({
800+
"unit": u, "time": t,
801+
"outcome": np.random.randn(),
802+
"treated": is_treated,
803+
})
804+
data = pd.DataFrame(rows)
805+
# Drop one observation to make panel unbalanced
806+
data = data[~((data["unit"] == 3) & (data["time"] == 2))].reset_index(drop=True)
807+
808+
sdid = SyntheticDiD()
809+
with pytest.raises(ValueError, match="Panel is not balanced"):
810+
sdid.fit(
811+
data, outcome="outcome", treatment="treated",
812+
unit="unit", time="time", post_periods=[3, 4],
813+
)
814+
815+
def test_balanced_panel_passes(self):
816+
"""Fully balanced panel should pass validation."""
817+
np.random.seed(42)
818+
rows = []
819+
for u in range(8):
820+
is_treated = 1 if u < 2 else 0
821+
for t in range(6):
822+
rows.append({
823+
"unit": u, "time": t,
824+
"outcome": np.random.randn() + (1.5 if is_treated and t >= 4 else 0),
825+
"treated": is_treated,
826+
})
827+
data = pd.DataFrame(rows)
828+
sdid = SyntheticDiD()
829+
result = sdid.fit(
830+
data, outcome="outcome", treatment="treated",
831+
unit="unit", time="time", post_periods=[4, 5],
832+
)
833+
assert result is not None
834+
835+
836+
# =============================================================================
837+
# Pre-treatment Fit Warning
838+
# =============================================================================
839+
840+
841+
class TestPreTreatmentFitWarning:
842+
"""Test that poor pre-treatment fit emits a warning."""
843+
844+
def test_poor_fit_emits_warning(self):
845+
"""Treated units at very different level from controls should warn."""
846+
np.random.seed(42)
847+
rows = []
848+
for u in range(10):
849+
is_treated = 1 if u < 2 else 0
850+
# Large level difference: treated ~100, control ~10
851+
level = 100.0 if is_treated else 10.0
852+
for t in range(8):
853+
rows.append({
854+
"unit": u, "time": t,
855+
"outcome": level + np.random.randn() * 0.5,
856+
"treated": is_treated,
857+
})
858+
data = pd.DataFrame(rows)
859+
sdid = SyntheticDiD()
860+
with warnings.catch_warnings(record=True) as w:
861+
warnings.simplefilter("always")
862+
sdid.fit(
863+
data, outcome="outcome", treatment="treated",
864+
unit="unit", time="time", post_periods=[6, 7],
865+
)
866+
fit_warnings = [x for x in w if "Pre-treatment fit is poor" in str(x.message)]
867+
assert len(fit_warnings) >= 1, (
868+
"Expected warning about poor pre-treatment fit but none was raised"
869+
)
870+
871+
def test_good_fit_no_warning(self):
872+
"""Parallel trends data with similar levels should not warn."""
873+
np.random.seed(42)
874+
rows = []
875+
for u in range(10):
876+
is_treated = 1 if u < 3 else 0
877+
for t in range(8):
878+
# Same level, parallel trends, treatment effect only in post
879+
rows.append({
880+
"unit": u, "time": t,
881+
"outcome": t + np.random.randn() * 0.3 + (2.0 if is_treated and t >= 5 else 0),
882+
"treated": is_treated,
883+
})
884+
data = pd.DataFrame(rows)
885+
sdid = SyntheticDiD()
886+
with warnings.catch_warnings(record=True) as w:
887+
warnings.simplefilter("always")
888+
sdid.fit(
889+
data, outcome="outcome", treatment="treated",
890+
unit="unit", time="time", post_periods=[5, 6, 7],
891+
)
892+
fit_warnings = [x for x in w if "Pre-treatment fit is poor" in str(x.message)]
893+
assert len(fit_warnings) == 0, (
894+
f"Unexpected pre-treatment fit warning: {fit_warnings[0].message}"
895+
)

0 commit comments

Comments
 (0)