From bc0025dad4113615409ec9412bb40801dd579712 Mon Sep 17 00:00:00 2001 From: igerber Date: Mon, 16 Feb 2026 12:28:51 -0500 Subject: [PATCH 1/4] Add Two-Stage DiD estimator (Gardner 2022) Implement the Two-Stage DiD estimator from Gardner (2022), matching the R `did2s` package. Stage 1 estimates unit+time fixed effects on untreated observations; Stage 2 regresses residualized outcomes on treatment indicators. Point estimates are identical to ImputationDiD; the key contribution is a GMM sandwich variance estimator (Newey & McFadden 1994) that accounts for first-stage estimation error. - TwoStageDiD class with static, event study, and group aggregation - Custom GMM sandwich variance (cannot reuse compute_robust_vcov) - Multiplier bootstrap on GMM influence function - 51 tests including equivalence tests with ImputationDiD - Full documentation: README, REGISTRY, CLAUDE.md, ROADMAP Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 11 + README.md | 105 +- ROADMAP.md | 4 +- diff_diff/__init__.py | 10 + diff_diff/two_stage.py | 2103 ++++++++++++++++++++++++++++++++++ diff_diff/visualization.py | 2 + docs/methodology/REGISTRY.md | 73 ++ tests/test_two_stage.py | 955 +++++++++++++++ 8 files changed, 3260 insertions(+), 3 deletions(-) create mode 100644 diff_diff/two_stage.py create mode 100644 tests/test_two_stage.py diff --git a/CLAUDE.md b/CLAUDE.md index 9eba2be7..cc24bd14 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -111,6 +111,15 @@ cross-platform compilation - no OpenBLAS or Intel MKL installation required. - Pre-trend test (Equation 9) via `results.pretrend_test()` - Proposition 5: NaN for unidentified long-run horizons without never-treated units +- **`diff_diff/two_stage.py`** - Gardner (2022) Two-Stage DiD estimator: + - `TwoStageDiD` - Two-stage estimator: (1) estimate unit+time FE on untreated obs, (2) regress residualized outcomes on treatment indicators + - `TwoStageDiDResults` - Results with overall ATT, event study, group effects, per-observation treatment effects + - `TwoStageBootstrapResults` - Multiplier bootstrap inference on GMM influence function + - `two_stage_did()` - Convenience function + - Point estimates identical to ImputationDiD; different variance estimator (GMM sandwich vs. conservative) + - Custom `_compute_gmm_variance()` — cannot reuse `compute_robust_vcov()` because correction term uses GLOBAL cross-moment + - No finite-sample adjustments (raw asymptotic sandwich, matching R `did2s`) + - **`diff_diff/triple_diff.py`** - Triple Difference (DDD) estimator: - `TripleDifference` - Ortiz-Villavicencio & Sant'Anna (2025) estimator for DDD designs - `TripleDifferenceResults` - Results with ATT, SEs, cell means, diagnostics @@ -270,6 +279,7 @@ cross-platform compilation - no OpenBLAS or Intel MKL installation required. ├── CallawaySantAnna ├── SunAbraham ├── ImputationDiD + ├── TwoStageDiD ├── TripleDifference ├── TROP ├── SyntheticDiD @@ -381,6 +391,7 @@ Tests mirror the source modules: - `tests/test_staggered.py` - Tests for CallawaySantAnna - `tests/test_sun_abraham.py` - Tests for SunAbraham interaction-weighted estimator - `tests/test_imputation.py` - Tests for ImputationDiD (Borusyak et al. 2024) estimator +- `tests/test_two_stage.py` - Tests for TwoStageDiD (Gardner 2022) estimator, including equivalence tests with ImputationDiD - `tests/test_triple_diff.py` - Tests for Triple Difference (DDD) estimator - `tests/test_trop.py` - Tests for Triply Robust Panel (TROP) estimator - `tests/test_bacon.py` - Tests for Goodman-Bacon decomposition diff --git a/README.md b/README.md index cdf21c42..f6d103b1 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1 - **Wild cluster bootstrap**: Valid inference with few clusters (<50) using Rademacher, Webb, or Mammen weights - **Panel data support**: Two-way fixed effects estimator for panel designs - **Multi-period analysis**: Event-study style DiD with period-specific treatment effects -- **Staggered adoption**: Callaway-Sant'Anna (2021), Sun-Abraham (2021), and Borusyak-Jaravel-Spiess (2024) imputation estimators for heterogeneous treatment timing +- **Staggered adoption**: Callaway-Sant'Anna (2021), Sun-Abraham (2021), Borusyak-Jaravel-Spiess (2024) imputation, and Two-Stage DiD (Gardner 2022) estimators for heterogeneous treatment timing - **Triple Difference (DDD)**: Ortiz-Villavicencio & Sant'Anna (2025) estimators with proper covariate handling - **Synthetic DiD**: Combined DiD with synthetic control for improved robustness - **Triply Robust Panel (TROP)**: Factor-adjusted DiD with synthetic weights (Athey et al. 2025) @@ -927,6 +927,53 @@ ImputationDiD( | Inference | Conservative variance (Theorem 3) | Multiplier bootstrap | | Pre-trends | Built-in F-test (Equation 9) | Separate testing | +### Two-Stage DiD (Gardner 2022) + +Two-Stage DiD addresses TWFE bias in staggered adoption designs by estimating unit and time fixed effects on untreated observations only, then regressing the residualized outcomes on treatment indicators. Point estimates match the Imputation DiD estimator (Borusyak et al. 2024); the key difference is that Two-Stage DiD uses a GMM sandwich variance estimator that accounts for first-stage estimation error, while Imputation DiD uses a conservative variance (Theorem 3). + +```python +from diff_diff import TwoStageDiD + +# Basic usage +est = TwoStageDiD() +results = est.fit(data, outcome='outcome', unit='unit', time='period', first_treat='first_treat') +results.print_summary() +``` + +**Event study:** + +```python +# Event study aggregation with visualization +results = est.fit(data, outcome='outcome', unit='unit', time='period', + first_treat='first_treat', aggregate='event_study') +plot_event_study(results) +``` + +**Parameters:** + +```python +TwoStageDiD( + anticipation=0, # Periods of anticipation effects + alpha=0.05, # Significance level for CIs + cluster=None, # Column for cluster-robust SEs (defaults to unit) + n_bootstrap=0, # Bootstrap iterations (0 = analytical GMM SEs) + seed=None, # Random seed + rank_deficient_action='warn', # 'warn', 'error', or 'silent' + horizon_max=None, # Max event-study horizon +) +``` + +**When to use Two-Stage DiD vs Imputation DiD:** + +| Aspect | Two-Stage DiD | Imputation DiD | +|--------|--------------|---------------| +| Point estimates | Identical | Identical | +| Variance | GMM sandwich (accounts for first-stage error) | Conservative (Theorem 3, may overcover) | +| Intuition | Residualize then regress | Impute counterfactuals then aggregate | +| Reference impl. | R `did2s` package | R `didimputation` package | + +Both estimators are the efficient estimator under homogeneous treatment effects, producing shorter confidence intervals than Callaway-Sant'Anna or Sun-Abraham. + ### Triple Difference (DDD) Triple Difference (DDD) is used when treatment requires satisfying two criteria: belonging to a treated **group** AND being in an eligible **partition**. The `TripleDifference` class implements the methodology from Ortiz-Villavicencio & Sant'Anna (2025), which correctly handles covariate adjustment (unlike naive implementations). @@ -2104,6 +2151,58 @@ ImputationDiD( | `to_dataframe(level)` | Convert to DataFrame ('observation', 'event_study', 'group') | | `pretrend_test(n_leads)` | Run pre-trend F-test (Equation 9) | +### TwoStageDiD + +```python +TwoStageDiD( + anticipation=0, # Periods of anticipation effects + alpha=0.05, # Significance level for CIs + cluster=None, # Column for cluster-robust SEs (defaults to unit) + n_bootstrap=0, # Bootstrap iterations (0 = analytical GMM SEs) + seed=None, # Random seed + rank_deficient_action='warn', # 'warn', 'error', or 'silent' + horizon_max=None, # Max event-study horizon +) +``` + +**fit() Parameters:** + +| Parameter | Type | Description | +|-----------|------|-------------| +| `data` | DataFrame | Panel data | +| `outcome` | str | Outcome variable column name | +| `unit` | str | Unit identifier column | +| `time` | str | Time period column | +| `first_treat` | str | First treatment period column (0 for never-treated) | +| `covariates` | list | Covariate column names | +| `aggregate` | str | Aggregation: None, "event_study", "group", "all" | +| `balance_e` | int | Balance event study to this many pre-treatment periods | + +### TwoStageDiDResults + +**Attributes:** + +| Attribute | Description | +|-----------|-------------| +| `overall_att` | Overall average treatment effect on the treated | +| `overall_se` | Standard error (GMM sandwich variance) | +| `overall_t_stat` | T-statistic | +| `overall_p_value` | P-value for H0: ATT = 0 | +| `overall_conf_int` | Confidence interval | +| `event_study_effects` | Dict of relative time -> effect dict (if `aggregate='event_study'` or `'all'`) | +| `group_effects` | Dict of cohort -> effect dict (if `aggregate='group'` or `'all'`) | +| `treatment_effects` | DataFrame of unit-level treatment effects | +| `n_treated_obs` | Number of treated observations | +| `n_untreated_obs` | Number of untreated observations | + +**Methods:** + +| Method | Description | +|--------|-------------| +| `summary(alpha)` | Get formatted summary string | +| `print_summary(alpha)` | Print summary to stdout | +| `to_dataframe(level)` | Convert to DataFrame ('observation', 'event_study', 'group') | + ### TripleDifference ```python @@ -2582,6 +2681,10 @@ The `HonestDiD` module implements sensitivity analysis methods for relaxing the - **Sun, L., & Abraham, S. (2021).** "Estimating Dynamic Treatment Effects in Event Studies with Heterogeneous Treatment Effects." *Journal of Econometrics*, 225(2), 175-199. [https://doi.org/10.1016/j.jeconom.2020.09.006](https://doi.org/10.1016/j.jeconom.2020.09.006) +- **Gardner, J. (2022).** "Two-stage differences in differences." *arXiv preprint arXiv:2207.05943*. [https://arxiv.org/abs/2207.05943](https://arxiv.org/abs/2207.05943) + +- **Butts, K., & Gardner, J. (2022).** "did2s: Two-Stage Difference-in-Differences." *The R Journal*, 14(1), 162-173. [https://doi.org/10.32614/RJ-2022-048](https://doi.org/10.32614/RJ-2022-048) + - **de Chaisemartin, C., & D'Haultfœuille, X. (2020).** "Two-Way Fixed Effects Estimators with Heterogeneous Treatment Effects." *American Economic Review*, 110(9), 2964-2996. [https://doi.org/10.1257/aer.20181169](https://doi.org/10.1257/aer.20181169) - **Goodman-Bacon, A. (2021).** "Difference-in-Differences with Variation in Treatment Timing." *Journal of Econometrics*, 225(2), 254-277. [https://doi.org/10.1016/j.jeconom.2021.03.014](https://doi.org/10.1016/j.jeconom.2021.03.014) diff --git a/ROADMAP.md b/ROADMAP.md index ddd35b59..47e1d2f5 100644 --- a/ROADMAP.md +++ b/ROADMAP.md @@ -10,7 +10,7 @@ For past changes and release history, see [CHANGELOG.md](CHANGELOG.md). diff-diff v2.3.0 is a **production-ready** DiD library with feature parity with R's `did` + `HonestDiD` + `synthdid` ecosystem for core DiD analysis: -- **Core estimators**: Basic DiD, TWFE, MultiPeriod, Callaway-Sant'Anna, Sun-Abraham, Borusyak-Jaravel-Spiess Imputation, Synthetic DiD, Triple Difference (DDD), TROP +- **Core estimators**: Basic DiD, TWFE, MultiPeriod, Callaway-Sant'Anna, Sun-Abraham, Borusyak-Jaravel-Spiess Imputation, Synthetic DiD, Triple Difference (DDD), TROP, Two-Stage DiD (Gardner 2022) - **Valid inference**: Robust SEs, cluster SEs, wild bootstrap, multiplier bootstrap, placebo-based variance - **Assumption diagnostics**: Parallel trends tests, placebo tests, Goodman-Bacon decomposition - **Sensitivity analysis**: Honest DiD (Rambachan-Roth), Pre-trends power analysis (Roth 2022) @@ -24,7 +24,7 @@ diff-diff v2.3.0 is a **production-ready** DiD library with feature parity with High-value additions building on our existing foundation. -### Gardner's Two-Stage DiD (did2s) +### Gardner's Two-Stage DiD (did2s) -- IMPLEMENTED (v2.4) Two-stage approach gaining traction in applied work. First residualizes outcomes, then estimates effects. diff --git a/diff_diff/__init__.py b/diff_diff/__init__.py index b2a305fb..01c50e75 100644 --- a/diff_diff/__init__.py +++ b/diff_diff/__init__.py @@ -101,6 +101,12 @@ ImputationDiDResults, imputation_did, ) +from diff_diff.two_stage import ( + TwoStageBootstrapResults, + TwoStageDiD, + TwoStageDiDResults, + two_stage_did, +) from diff_diff.sun_abraham import ( SABootstrapResults, SunAbraham, @@ -152,6 +158,7 @@ "CallawaySantAnna", "SunAbraham", "ImputationDiD", + "TwoStageDiD", "TripleDifference", "TROP", # Bacon Decomposition @@ -173,6 +180,9 @@ "ImputationDiDResults", "ImputationBootstrapResults", "imputation_did", + "TwoStageDiDResults", + "TwoStageBootstrapResults", + "two_stage_did", "TripleDifferenceResults", "triple_difference", "TROPResults", diff --git a/diff_diff/two_stage.py b/diff_diff/two_stage.py new file mode 100644 index 00000000..de2c3dbe --- /dev/null +++ b/diff_diff/two_stage.py @@ -0,0 +1,2103 @@ +""" +Gardner (2022) Two-Stage Difference-in-Differences Estimator. + +Implements the two-stage DiD estimator from Gardner (2022), "Two-stage +differences in differences". The method: +1. Estimates unit + time fixed effects on untreated observations only +2. Residualizes ALL outcomes using estimated FEs +3. Regresses residualized outcomes on treatment indicators (Stage 2) + +Inference uses the GMM sandwich variance estimator from Butts & Gardner +(2022) that correctly accounts for first-stage estimation uncertainty. + +Point estimates are identical to ImputationDiD (Borusyak et al. 2024); +the key difference is the variance estimator (GMM sandwich vs. conservative). + +References +---------- +Gardner, J. (2022). Two-stage differences in differences. + arXiv:2207.05943. +Butts, K. & Gardner, J. (2022). did2s: Two-Stage + Difference-in-Differences. R Journal, 14(1), 162-173. +""" + +import warnings +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Set, Tuple + +import numpy as np +import pandas as pd +from scipy import sparse +from scipy.sparse.linalg import factorized as sparse_factorized + +from diff_diff.linalg import solve_ols +from diff_diff.results import _get_significance_stars +from diff_diff.utils import compute_confidence_interval, compute_p_value + +# ============================================================================= +# Results Dataclasses +# ============================================================================= + + +@dataclass +class TwoStageBootstrapResults: + """ + Results from TwoStageDiD bootstrap inference. + + Bootstrap uses multiplier bootstrap on the GMM influence function, + consistent with other library estimators. The R `did2s` package uses + block bootstrap by default; multiplier bootstrap is asymptotically + equivalent. + + Attributes + ---------- + n_bootstrap : int + Number of bootstrap iterations. + weight_type : str + Type of bootstrap weights (currently "rademacher" only). + alpha : float + Significance level used for confidence intervals. + overall_att_se : float + Bootstrap standard error for overall ATT. + overall_att_ci : tuple + Bootstrap confidence interval for overall ATT. + overall_att_p_value : float + Bootstrap p-value for overall ATT. + event_study_ses : dict, optional + Bootstrap SEs for event study effects. + event_study_cis : dict, optional + Bootstrap CIs for event study effects. + event_study_p_values : dict, optional + Bootstrap p-values for event study effects. + group_ses : dict, optional + Bootstrap SEs for group effects. + group_cis : dict, optional + Bootstrap CIs for group effects. + group_p_values : dict, optional + Bootstrap p-values for group effects. + bootstrap_distribution : np.ndarray, optional + Full bootstrap distribution of overall ATT. + """ + + n_bootstrap: int + weight_type: str + alpha: float + overall_att_se: float + overall_att_ci: Tuple[float, float] + overall_att_p_value: float + event_study_ses: Optional[Dict[int, float]] = None + event_study_cis: Optional[Dict[int, Tuple[float, float]]] = None + event_study_p_values: Optional[Dict[int, float]] = None + group_ses: Optional[Dict[Any, float]] = None + group_cis: Optional[Dict[Any, Tuple[float, float]]] = None + group_p_values: Optional[Dict[Any, float]] = None + bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False) + + +@dataclass +class TwoStageDiDResults: + """ + Results from Gardner (2022) two-stage DiD estimation. + + Attributes + ---------- + treatment_effects : pd.DataFrame + Per-observation treatment effects with columns: unit, time, + tau_hat, weight. tau_hat is the residualized outcome y_tilde + for treated observations; weight is 1/n_treated. + overall_att : float + Overall average treatment effect on the treated. + overall_se : float + Standard error of overall ATT (GMM sandwich). + overall_t_stat : float + T-statistic for overall ATT. + overall_p_value : float + P-value for overall ATT. + overall_conf_int : tuple + Confidence interval for overall ATT. + event_study_effects : dict, optional + Dictionary mapping relative time h to effect dict with keys: + 'effect', 'se', 't_stat', 'p_value', 'conf_int', 'n_obs'. + group_effects : dict, optional + Dictionary mapping cohort g to effect dict. + groups : list + List of treatment cohorts. + time_periods : list + List of all time periods. + n_obs : int + Total number of observations. + n_treated_obs : int + Number of treated observations. + n_untreated_obs : int + Number of untreated observations. + n_treated_units : int + Number of ever-treated units. + n_control_units : int + Number of units contributing to untreated observations. + alpha : float + Significance level used. + bootstrap_results : TwoStageBootstrapResults, optional + Bootstrap inference results. + """ + + treatment_effects: pd.DataFrame + overall_att: float + overall_se: float + overall_t_stat: float + overall_p_value: float + overall_conf_int: Tuple[float, float] + event_study_effects: Optional[Dict[int, Dict[str, Any]]] + group_effects: Optional[Dict[Any, Dict[str, Any]]] + groups: List[Any] + time_periods: List[Any] + n_obs: int + n_treated_obs: int + n_untreated_obs: int + n_treated_units: int + n_control_units: int + alpha: float = 0.05 + bootstrap_results: Optional[TwoStageBootstrapResults] = field(default=None, repr=False) + + def __repr__(self) -> str: + """Concise string representation.""" + sig = _get_significance_stars(self.overall_p_value) + return ( + f"TwoStageDiDResults(ATT={self.overall_att:.4f}{sig}, " + f"SE={self.overall_se:.4f}, " + f"n_groups={len(self.groups)}, " + f"n_treated_obs={self.n_treated_obs})" + ) + + def summary(self, alpha: Optional[float] = None) -> str: + """ + Generate formatted summary of estimation results. + + Parameters + ---------- + alpha : float, optional + Significance level. Defaults to alpha used in estimation. + + Returns + ------- + str + Formatted summary. + """ + alpha = alpha or self.alpha + conf_level = int((1 - alpha) * 100) + + lines = [ + "=" * 85, + "Two-Stage DiD Estimator Results (Gardner 2022)".center(85), + "=" * 85, + "", + f"{'Total observations:':<30} {self.n_obs:>10}", + f"{'Treated observations:':<30} {self.n_treated_obs:>10}", + f"{'Untreated observations:':<30} {self.n_untreated_obs:>10}", + f"{'Treated units:':<30} {self.n_treated_units:>10}", + f"{'Control units:':<30} {self.n_control_units:>10}", + f"{'Treatment cohorts:':<30} {len(self.groups):>10}", + f"{'Time periods:':<30} {len(self.time_periods):>10}", + "", + ] + + # Overall ATT + lines.extend( + [ + "-" * 85, + "Overall Average Treatment Effect on the Treated".center(85), + "-" * 85, + f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} " + f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 85, + ] + ) + + t_str = ( + f"{self.overall_t_stat:>10.3f}" if np.isfinite(self.overall_t_stat) else f"{'NaN':>10}" + ) + p_str = ( + f"{self.overall_p_value:>10.4f}" + if np.isfinite(self.overall_p_value) + else f"{'NaN':>10}" + ) + sig = _get_significance_stars(self.overall_p_value) + + lines.extend( + [ + f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} " + f"{t_str} {p_str} {sig:>6}", + "-" * 85, + "", + f"{conf_level}% Confidence Interval: " + f"[{self.overall_conf_int[0]:.4f}, {self.overall_conf_int[1]:.4f}]", + "", + ] + ) + + # Event study effects + if self.event_study_effects: + lines.extend( + [ + "-" * 85, + "Event Study (Dynamic) Effects".center(85), + "-" * 85, + f"{'Rel. Period':<15} {'Estimate':>12} {'Std. Err.':>12} " + f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 85, + ] + ) + + for h in sorted(self.event_study_effects.keys()): + eff = self.event_study_effects[h] + if eff.get("n_obs", 1) == 0: + # Reference period marker + lines.append( + f"[ref: {h}]" f"{'0.0000':>17} {'---':>12} {'---':>10} {'---':>10} {'':>6}" + ) + elif np.isnan(eff["effect"]): + lines.append(f"{h:<15} {'NaN':>12} {'NaN':>12} {'NaN':>10} {'NaN':>10} {'':>6}") + else: + e_sig = _get_significance_stars(eff["p_value"]) + e_t = ( + f"{eff['t_stat']:>10.3f}" if np.isfinite(eff["t_stat"]) else f"{'NaN':>10}" + ) + e_p = ( + f"{eff['p_value']:>10.4f}" + if np.isfinite(eff["p_value"]) + else f"{'NaN':>10}" + ) + lines.append( + f"{h:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} " + f"{e_t} {e_p} {e_sig:>6}" + ) + + lines.extend(["-" * 85, ""]) + + # Group effects + if self.group_effects: + lines.extend( + [ + "-" * 85, + "Group (Cohort) Effects".center(85), + "-" * 85, + f"{'Cohort':<15} {'Estimate':>12} {'Std. Err.':>12} " + f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 85, + ] + ) + + for g in sorted(self.group_effects.keys()): + eff = self.group_effects[g] + if np.isnan(eff["effect"]): + lines.append(f"{g:<15} {'NaN':>12} {'NaN':>12} {'NaN':>10} {'NaN':>10} {'':>6}") + else: + g_sig = _get_significance_stars(eff["p_value"]) + g_t = ( + f"{eff['t_stat']:>10.3f}" if np.isfinite(eff["t_stat"]) else f"{'NaN':>10}" + ) + g_p = ( + f"{eff['p_value']:>10.4f}" + if np.isfinite(eff["p_value"]) + else f"{'NaN':>10}" + ) + lines.append( + f"{g:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} " + f"{g_t} {g_p} {g_sig:>6}" + ) + + lines.extend(["-" * 85, ""]) + + lines.extend( + [ + "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", + "=" * 85, + ] + ) + + return "\n".join(lines) + + def print_summary(self, alpha: Optional[float] = None) -> None: + """Print summary to stdout.""" + print(self.summary(alpha)) + + def to_dataframe(self, level: str = "event_study") -> pd.DataFrame: + """ + Convert results to DataFrame. + + Parameters + ---------- + level : str, default="event_study" + Level of aggregation: + - "event_study": Event study effects by relative time + - "group": Group (cohort) effects + - "observation": Per-observation treatment effects + + Returns + ------- + pd.DataFrame + Results as DataFrame. + """ + if level == "observation": + return self.treatment_effects.copy() + + elif level == "event_study": + if self.event_study_effects is None: + raise ValueError( + "Event study effects not computed. " + "Use aggregate='event_study' or aggregate='all'." + ) + rows = [] + for h, data in sorted(self.event_study_effects.items()): + rows.append( + { + "relative_period": h, + "effect": data["effect"], + "se": data["se"], + "t_stat": data["t_stat"], + "p_value": data["p_value"], + "conf_int_lower": data["conf_int"][0], + "conf_int_upper": data["conf_int"][1], + "n_obs": data.get("n_obs", np.nan), + } + ) + return pd.DataFrame(rows) + + elif level == "group": + if self.group_effects is None: + raise ValueError( + "Group effects not computed. " "Use aggregate='group' or aggregate='all'." + ) + rows = [] + for g, data in sorted(self.group_effects.items()): + rows.append( + { + "group": g, + "effect": data["effect"], + "se": data["se"], + "t_stat": data["t_stat"], + "p_value": data["p_value"], + "conf_int_lower": data["conf_int"][0], + "conf_int_upper": data["conf_int"][1], + "n_obs": data.get("n_obs", np.nan), + } + ) + return pd.DataFrame(rows) + + else: + raise ValueError( + f"Unknown level: {level}. Use 'event_study', 'group', or 'observation'." + ) + + @property + def is_significant(self) -> bool: + """Check if overall ATT is significant.""" + return bool(self.overall_p_value < self.alpha) + + @property + def significance_stars(self) -> str: + """Significance stars for overall ATT.""" + return _get_significance_stars(self.overall_p_value) + + +# ============================================================================= +# Main Estimator +# ============================================================================= + + +class TwoStageDiD: + """ + Gardner (2022) two-stage Difference-in-Differences estimator. + + This estimator addresses TWFE bias under heterogeneous treatment + effects by: + 1. Estimating unit + time FEs on untreated observations only + 2. Residualizing ALL outcomes using estimated FEs + 3. Regressing residualized outcomes on treatment indicators + + Point estimates are identical to ImputationDiD (Borusyak et al. 2024). + The key difference is the variance estimator: TwoStageDiD uses a GMM + sandwich variance that accounts for first-stage estimation uncertainty, + while ImputationDiD uses the conservative variance from Theorem 3. + + Parameters + ---------- + anticipation : int, default=0 + Number of periods before treatment where effects may occur. + alpha : float, default=0.05 + Significance level for confidence intervals. + cluster : str, optional + Column name for cluster-robust standard errors. + If None, clusters at the unit level by default. + n_bootstrap : int, default=0 + Number of bootstrap iterations. If 0, uses analytical GMM + sandwich inference. + seed : int, optional + Random seed for reproducibility. + rank_deficient_action : str, default="warn" + Action when design matrix is rank-deficient: + - "warn": Issue warning and drop linearly dependent columns + - "error": Raise ValueError + - "silent": Drop columns silently + horizon_max : int, optional + Maximum event-study horizon. If set, event study effects are only + computed for |h| <= horizon_max. + + Attributes + ---------- + results_ : TwoStageDiDResults + Estimation results after calling fit(). + is_fitted_ : bool + Whether the model has been fitted. + + Examples + -------- + Basic usage: + + >>> from diff_diff import TwoStageDiD, generate_staggered_data + >>> data = generate_staggered_data(n_units=200, seed=42) + >>> est = TwoStageDiD() + >>> results = est.fit(data, outcome='outcome', unit='unit', + ... time='period', first_treat='first_treat') + >>> results.print_summary() + + With event study: + + >>> est = TwoStageDiD() + >>> results = est.fit(data, outcome='outcome', unit='unit', + ... time='period', first_treat='first_treat', + ... aggregate='event_study') + >>> from diff_diff import plot_event_study + >>> plot_event_study(results) + + Notes + ----- + The two-stage estimator uses ALL untreated observations (never-treated + + not-yet-treated periods of eventually-treated units) to estimate the + counterfactual model. + + References + ---------- + Gardner, J. (2022). Two-stage differences in differences. + arXiv:2207.05943. + Butts, K. & Gardner, J. (2022). did2s: Two-Stage + Difference-in-Differences. R Journal, 14(1), 162-173. + """ + + def __init__( + self, + anticipation: int = 0, + alpha: float = 0.05, + cluster: Optional[str] = None, + n_bootstrap: int = 0, + seed: Optional[int] = None, + rank_deficient_action: str = "warn", + horizon_max: Optional[int] = None, + ): + if rank_deficient_action not in ("warn", "error", "silent"): + raise ValueError( + f"rank_deficient_action must be 'warn', 'error', or 'silent', " + f"got '{rank_deficient_action}'" + ) + + self.anticipation = anticipation + self.alpha = alpha + self.cluster = cluster + self.n_bootstrap = n_bootstrap + self.seed = seed + self.rank_deficient_action = rank_deficient_action + self.horizon_max = horizon_max + + self.is_fitted_ = False + self.results_: Optional[TwoStageDiDResults] = None + + def fit( + self, + data: pd.DataFrame, + outcome: str, + unit: str, + time: str, + first_treat: str, + covariates: Optional[List[str]] = None, + aggregate: Optional[str] = None, + balance_e: Optional[int] = None, + ) -> TwoStageDiDResults: + """ + Fit the two-stage DiD estimator. + + Parameters + ---------- + data : pd.DataFrame + Panel data with unit and time identifiers. + outcome : str + Name of outcome variable column. + unit : str + Name of unit identifier column. + time : str + Name of time period column. + first_treat : str + Name of column indicating when unit was first treated. + Use 0 (or np.inf) for never-treated units. + covariates : list of str, optional + List of covariate column names. + aggregate : str, optional + Aggregation mode: None/"simple" (overall ATT only), + "event_study", "group", or "all". + balance_e : int, optional + When computing event study, restrict to cohorts observed at all + relative times in [-balance_e, max_h]. + + Returns + ------- + TwoStageDiDResults + Object containing all estimation results. + + Raises + ------ + ValueError + If required columns are missing or data validation fails. + """ + # ---- Data validation ---- + required_cols = [outcome, unit, time, first_treat] + if covariates: + required_cols.extend(covariates) + + missing = [c for c in required_cols if c not in data.columns] + if missing: + raise ValueError(f"Missing columns: {missing}") + + df = data.copy() + df[time] = pd.to_numeric(df[time]) + df[first_treat] = pd.to_numeric(df[first_treat]) + + # Validate absorbing treatment + ft_nunique = df.groupby(unit)[first_treat].nunique() + non_constant = ft_nunique[ft_nunique > 1] + if len(non_constant) > 0: + example_unit = non_constant.index[0] + example_vals = sorted(df.loc[df[unit] == example_unit, first_treat].unique()) + warnings.warn( + f"{len(non_constant)} unit(s) have non-constant '{first_treat}' " + f"values (e.g., unit '{example_unit}' has values {example_vals}). " + f"TwoStageDiD assumes treatment is an absorbing state " + f"(once treated, always treated) with a single treatment onset " + f"time per unit. Non-constant first_treat violates this assumption " + f"and may produce unreliable estimates.", + UserWarning, + stacklevel=2, + ) + df[first_treat] = df.groupby(unit)[first_treat].transform("first") + + # Identify treatment status + df["_never_treated"] = (df[first_treat] == 0) | (df[first_treat] == np.inf) + + # Check for always-treated units + min_time = df[time].min() + always_treated_mask = (~df["_never_treated"]) & (df[first_treat] <= min_time) + n_always_treated = df.loc[always_treated_mask, unit].nunique() + if n_always_treated > 0: + warnings.warn( + f"{n_always_treated} unit(s) are treated in all observed periods " + f"(first_treat <= {min_time}). These units have no untreated " + "observations and cannot contribute to the counterfactual model. " + "Excluding from estimation.", + UserWarning, + stacklevel=2, + ) + # Exclude always-treated units + always_treated_units = df.loc[always_treated_mask, unit].unique() + df = df[~df[unit].isin(always_treated_units)].copy() + + # Treatment indicator with anticipation + effective_treat = df[first_treat] - self.anticipation + df["_treated"] = (~df["_never_treated"]) & (df[time] >= effective_treat) + + # Partition into Omega_0 (untreated) and Omega_1 (treated) + omega_0_mask = ~df["_treated"] + omega_1_mask = df["_treated"] + + n_omega_0 = int(omega_0_mask.sum()) + n_omega_1 = int(omega_1_mask.sum()) + + if n_omega_0 == 0: + raise ValueError( + "No untreated observations found. Cannot estimate counterfactual model." + ) + if n_omega_1 == 0: + raise ValueError("No treated observations found. Nothing to estimate.") + + # Groups and time periods + time_periods = sorted(df[time].unique()) + treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0 and g != np.inf]) + + if len(treatment_groups) == 0: + raise ValueError("No treated units found. Check 'first_treat' column.") + + # Unit info + unit_info = ( + df.groupby(unit).agg({first_treat: "first", "_never_treated": "first"}).reset_index() + ) + n_treated_units = int((~unit_info["_never_treated"]).sum()) + units_in_omega_0 = df.loc[omega_0_mask, unit].unique() + n_control_units = len(units_in_omega_0) + + # Cluster variable + cluster_var = self.cluster if self.cluster is not None else unit + if self.cluster is not None and self.cluster not in df.columns: + raise ValueError( + f"Cluster column '{self.cluster}' not found in data. " + f"Available columns: {list(df.columns)}" + ) + + # Relative time + df["_rel_time"] = np.where( + ~df["_never_treated"], + df[time] - df[first_treat], + np.nan, + ) + + # ---- Stage 1: OLS on untreated observations ---- + unit_fe, time_fe, grand_mean, delta_hat, kept_cov_mask = self._fit_untreated_model( + df, outcome, unit, time, covariates, omega_0_mask + ) + + # ---- Rank condition checks ---- + treated_unit_ids = df.loc[omega_1_mask, unit].unique() + units_with_fe = set(unit_fe.keys()) + units_missing_fe = set(treated_unit_ids) - units_with_fe + + post_period_ids = df.loc[omega_1_mask, time].unique() + periods_with_fe = set(time_fe.keys()) + periods_missing_fe = set(post_period_ids) - periods_with_fe + + if units_missing_fe or periods_missing_fe: + parts = [] + if units_missing_fe: + sorted_missing = sorted(units_missing_fe) + parts.append( + f"{len(units_missing_fe)} treated unit(s) have no untreated " + f"periods (units: {sorted_missing[:5]}" + f"{'...' if len(units_missing_fe) > 5 else ''})" + ) + if periods_missing_fe: + sorted_missing = sorted(periods_missing_fe) + parts.append( + f"{len(periods_missing_fe)} post-treatment period(s) have no " + f"untreated units (periods: {sorted_missing[:5]}" + f"{'...' if len(periods_missing_fe) > 5 else ''})" + ) + msg = ( + "Rank condition violated: " + + "; ".join(parts) + + ". Affected treatment effects will be NaN." + ) + if self.rank_deficient_action == "error": + raise ValueError(msg) + elif self.rank_deficient_action == "warn": + warnings.warn(msg, UserWarning, stacklevel=2) + + # ---- Residualize ALL observations ---- + y_tilde = self._residualize( + df, outcome, unit, time, covariates, unit_fe, time_fe, grand_mean, delta_hat + ) + df["_y_tilde"] = y_tilde + + # ---- Stage 2: OLS of y_tilde on treatment indicators ---- + # Build design matrices and compute effects + GMM variance + ref_period = -1 - self.anticipation + + # Always compute overall ATT (static specification) + overall_att, overall_se = self._stage2_static( + df=df, + unit=unit, + time=time, + first_treat=first_treat, + covariates=covariates, + omega_0_mask=omega_0_mask, + omega_1_mask=omega_1_mask, + unit_fe=unit_fe, + time_fe=time_fe, + grand_mean=grand_mean, + delta_hat=delta_hat, + cluster_var=cluster_var, + kept_cov_mask=kept_cov_mask, + ) + + overall_t = ( + overall_att / overall_se if np.isfinite(overall_se) and overall_se > 0 else np.nan + ) + overall_p = compute_p_value(overall_t) + overall_ci = ( + compute_confidence_interval(overall_att, overall_se, self.alpha) + if np.isfinite(overall_se) and overall_se > 0 + else (np.nan, np.nan) + ) + + # Event study and group aggregation + event_study_effects = None + group_effects = None + + if aggregate in ("event_study", "all"): + event_study_effects = self._stage2_event_study( + df=df, + unit=unit, + time=time, + first_treat=first_treat, + covariates=covariates, + omega_0_mask=omega_0_mask, + omega_1_mask=omega_1_mask, + unit_fe=unit_fe, + time_fe=time_fe, + grand_mean=grand_mean, + delta_hat=delta_hat, + cluster_var=cluster_var, + treatment_groups=treatment_groups, + ref_period=ref_period, + balance_e=balance_e, + kept_cov_mask=kept_cov_mask, + ) + + if aggregate in ("group", "all"): + group_effects = self._stage2_group( + df=df, + unit=unit, + time=time, + first_treat=first_treat, + covariates=covariates, + omega_0_mask=omega_0_mask, + omega_1_mask=omega_1_mask, + unit_fe=unit_fe, + time_fe=time_fe, + grand_mean=grand_mean, + delta_hat=delta_hat, + cluster_var=cluster_var, + treatment_groups=treatment_groups, + kept_cov_mask=kept_cov_mask, + ) + + # Build treatment effects DataFrame + treated_df = df.loc[omega_1_mask, [unit, time, "_y_tilde", "_rel_time"]].copy() + treated_df = treated_df.rename(columns={"_y_tilde": "tau_hat", "_rel_time": "rel_time"}) + tau_finite = treated_df["tau_hat"].notna() & np.isfinite(treated_df["tau_hat"].values) + n_valid_te = int(tau_finite.sum()) + if n_valid_te > 0: + treated_df["weight"] = np.where(tau_finite, 1.0 / n_valid_te, 0.0) + else: + treated_df["weight"] = 0.0 + + # ---- Bootstrap ---- + bootstrap_results = None + if self.n_bootstrap > 0: + try: + bootstrap_results = self._run_bootstrap( + df=df, + unit=unit, + time=time, + first_treat=first_treat, + covariates=covariates, + omega_0_mask=omega_0_mask, + omega_1_mask=omega_1_mask, + unit_fe=unit_fe, + time_fe=time_fe, + grand_mean=grand_mean, + delta_hat=delta_hat, + cluster_var=cluster_var, + kept_cov_mask=kept_cov_mask, + treatment_groups=treatment_groups, + ref_period=ref_period, + balance_e=balance_e, + original_att=overall_att, + original_event_study=event_study_effects, + original_group=group_effects, + aggregate=aggregate, + ) + except Exception as e: + warnings.warn( + f"Bootstrap failed: {e}. Skipping bootstrap inference.", + UserWarning, + stacklevel=2, + ) + + if bootstrap_results is not None: + # Update inference with bootstrap results + overall_se = bootstrap_results.overall_att_se + overall_t = ( + overall_att / overall_se + if np.isfinite(overall_se) and overall_se > 0 + else np.nan + ) + overall_p = bootstrap_results.overall_att_p_value + overall_ci = bootstrap_results.overall_att_ci + + # Update event study + if event_study_effects and bootstrap_results.event_study_ses: + for h in event_study_effects: + if ( + h in bootstrap_results.event_study_ses + and event_study_effects[h].get("n_obs", 1) > 0 + ): + event_study_effects[h]["se"] = bootstrap_results.event_study_ses[h] + event_study_effects[h]["conf_int"] = bootstrap_results.event_study_cis[ + h + ] + event_study_effects[h]["p_value"] = ( + bootstrap_results.event_study_p_values[h] + ) + eff_val = event_study_effects[h]["effect"] + se_val = event_study_effects[h]["se"] + event_study_effects[h]["t_stat"] = ( + eff_val / se_val if np.isfinite(se_val) and se_val > 0 else np.nan + ) + + # Update group effects + if group_effects and bootstrap_results.group_ses: + for g in group_effects: + if g in bootstrap_results.group_ses: + group_effects[g]["se"] = bootstrap_results.group_ses[g] + group_effects[g]["conf_int"] = bootstrap_results.group_cis[g] + group_effects[g]["p_value"] = bootstrap_results.group_p_values[g] + eff_val = group_effects[g]["effect"] + se_val = group_effects[g]["se"] + group_effects[g]["t_stat"] = ( + eff_val / se_val if np.isfinite(se_val) and se_val > 0 else np.nan + ) + + # Construct results + self.results_ = TwoStageDiDResults( + treatment_effects=treated_df, + overall_att=overall_att, + overall_se=overall_se, + overall_t_stat=overall_t, + overall_p_value=overall_p, + overall_conf_int=overall_ci, + event_study_effects=event_study_effects, + group_effects=group_effects, + groups=treatment_groups, + time_periods=time_periods, + n_obs=len(df), + n_treated_obs=n_omega_1, + n_untreated_obs=n_omega_0, + n_treated_units=n_treated_units, + n_control_units=n_control_units, + alpha=self.alpha, + bootstrap_results=bootstrap_results, + ) + + self.is_fitted_ = True + return self.results_ + + # ========================================================================= + # Stage 1: OLS on untreated observations + # ========================================================================= + + def _iterative_fe( + self, + y: np.ndarray, + unit_vals: np.ndarray, + time_vals: np.ndarray, + idx: pd.Index, + max_iter: int = 100, + tol: float = 1e-10, + ) -> Tuple[Dict[Any, float], Dict[Any, float]]: + """ + Estimate unit and time FE via iterative alternating projection. + + Returns + ------- + unit_fe : dict + Mapping from unit -> unit fixed effect. + time_fe : dict + Mapping from time -> time fixed effect. + """ + n = len(y) + alpha = np.zeros(n) + beta = np.zeros(n) + + with np.errstate(invalid="ignore", divide="ignore"): + for iteration in range(max_iter): + resid_after_alpha = y - alpha + beta_new = ( + pd.Series(resid_after_alpha, index=idx) + .groupby(time_vals) + .transform("mean") + .values + ) + + resid_after_beta = y - beta_new + alpha_new = ( + pd.Series(resid_after_beta, index=idx) + .groupby(unit_vals) + .transform("mean") + .values + ) + + max_change = max( + np.max(np.abs(alpha_new - alpha)), + np.max(np.abs(beta_new - beta)), + ) + alpha = alpha_new + beta = beta_new + if max_change < tol: + break + + unit_fe = pd.Series(alpha, index=idx).groupby(unit_vals).first().to_dict() + time_fe = pd.Series(beta, index=idx).groupby(time_vals).first().to_dict() + return unit_fe, time_fe + + @staticmethod + def _iterative_demean( + vals: np.ndarray, + unit_vals: np.ndarray, + time_vals: np.ndarray, + idx: pd.Index, + max_iter: int = 100, + tol: float = 1e-10, + ) -> np.ndarray: + """Demean a vector by iterative alternating projection.""" + result = vals.copy() + with np.errstate(invalid="ignore", divide="ignore"): + for _ in range(max_iter): + time_means = ( + pd.Series(result, index=idx).groupby(time_vals).transform("mean").values + ) + result_after_time = result - time_means + unit_means = ( + pd.Series(result_after_time, index=idx) + .groupby(unit_vals) + .transform("mean") + .values + ) + result_new = result_after_time - unit_means + if np.max(np.abs(result_new - result)) < tol: + result = result_new + break + result = result_new + return result + + def _fit_untreated_model( + self, + df: pd.DataFrame, + outcome: str, + unit: str, + time: str, + covariates: Optional[List[str]], + omega_0_mask: pd.Series, + ) -> Tuple[ + Dict[Any, float], Dict[Any, float], float, Optional[np.ndarray], Optional[np.ndarray] + ]: + """ + Stage 1: Estimate unit + time FE on untreated observations. + + Returns + ------- + unit_fe, time_fe, grand_mean, delta_hat, kept_cov_mask + """ + df_0 = df.loc[omega_0_mask] + + if covariates is None or len(covariates) == 0: + y = df_0[outcome].values.copy() + unit_fe, time_fe = self._iterative_fe( + y, df_0[unit].values, df_0[time].values, df_0.index + ) + return unit_fe, time_fe, 0.0, None, None + + else: + y = df_0[outcome].values.copy() + X_raw = df_0[covariates].values.copy() + units = df_0[unit].values + times = df_0[time].values + n_cov = len(covariates) + + y_dm = self._iterative_demean(y, units, times, df_0.index) + X_dm = np.column_stack( + [ + self._iterative_demean(X_raw[:, j], units, times, df_0.index) + for j in range(n_cov) + ] + ) + + result = solve_ols( + X_dm, + y_dm, + return_vcov=False, + rank_deficient_action=self.rank_deficient_action, + column_names=covariates, + ) + delta_hat = result[0] + kept_cov_mask = np.isfinite(delta_hat) + delta_hat_clean = np.where(np.isfinite(delta_hat), delta_hat, 0.0) + + y_adj = y - X_raw @ delta_hat_clean + unit_fe, time_fe = self._iterative_fe(y_adj, units, times, df_0.index) + + return unit_fe, time_fe, 0.0, delta_hat_clean, kept_cov_mask + + # ========================================================================= + # Residualization + # ========================================================================= + + def _residualize( + self, + df: pd.DataFrame, + outcome: str, + unit: str, + time: str, + covariates: Optional[List[str]], + unit_fe: Dict[Any, float], + time_fe: Dict[Any, float], + grand_mean: float, + delta_hat: Optional[np.ndarray], + ) -> np.ndarray: + """ + Compute residualized outcome y_tilde for ALL observations. + + y_tilde_i = y_i - mu_hat_i - eta_hat_t [- X_i @ delta_hat] + """ + alpha_i = df[unit].map(unit_fe).values + beta_t = df[time].map(time_fe).values + + # Handle missing FE (NaN for units/periods not in untreated sample) + alpha_i = np.where(pd.isna(alpha_i), np.nan, alpha_i).astype(float) + beta_t = np.where(pd.isna(beta_t), np.nan, beta_t).astype(float) + + y_hat = grand_mean + alpha_i + beta_t + + if delta_hat is not None and covariates: + y_hat = y_hat + df[covariates].values @ delta_hat + + y_tilde = df[outcome].values - y_hat + return y_tilde + + # ========================================================================= + # Stage 2 specifications + # ========================================================================= + + def _stage2_static( + self, + df: pd.DataFrame, + unit: str, + time: str, + first_treat: str, + covariates: Optional[List[str]], + omega_0_mask: pd.Series, + omega_1_mask: pd.Series, + unit_fe: Dict[Any, float], + time_fe: Dict[Any, float], + grand_mean: float, + delta_hat: Optional[np.ndarray], + cluster_var: str, + kept_cov_mask: Optional[np.ndarray], + ) -> Tuple[float, float]: + """ + Static (simple ATT) Stage 2: OLS of y_tilde on D_it. + + Returns (att, se). + """ + y_tilde = df["_y_tilde"].values.copy() + + # Handle NaN y_tilde (from unidentified FEs — e.g., rank condition violations) + # Set to 0 so solve_ols doesn't reject; these obs have X_2=0 (untreated) + # or contribute NaN treatment effects (excluded from point estimate). + nan_mask = ~np.isfinite(y_tilde) + if nan_mask.any(): + y_tilde[nan_mask] = 0.0 + + D = omega_1_mask.values.astype(float) + # Zero out treatment indicator for NaN y_tilde obs (don't count in ATT) + D[nan_mask] = 0.0 + + # X_2: treatment indicator (no intercept) + X_2 = D.reshape(-1, 1) + + # Avoid degenerate case where all treated obs have NaN y_tilde + if D.sum() == 0: + return np.nan, np.nan + + # Stage 2 OLS for point estimate (discard naive SE) + coef, residuals, _ = solve_ols(X_2, y_tilde, return_vcov=False) + att = float(coef[0]) + + # GMM sandwich variance + eps_2 = y_tilde - X_2 @ coef # Stage 2 residuals + + V = self._compute_gmm_variance( + df=df, + unit=unit, + time=time, + covariates=covariates, + omega_0_mask=omega_0_mask, + unit_fe=unit_fe, + time_fe=time_fe, + delta_hat=delta_hat, + kept_cov_mask=kept_cov_mask, + X_2=X_2, + eps_2=eps_2, + cluster_ids=df[cluster_var].values, + ) + + se = float(np.sqrt(max(V[0, 0], 0.0))) + return att, se + + def _stage2_event_study( + self, + df: pd.DataFrame, + unit: str, + time: str, + first_treat: str, + covariates: Optional[List[str]], + omega_0_mask: pd.Series, + omega_1_mask: pd.Series, + unit_fe: Dict[Any, float], + time_fe: Dict[Any, float], + grand_mean: float, + delta_hat: Optional[np.ndarray], + cluster_var: str, + treatment_groups: List[Any], + ref_period: int, + balance_e: Optional[int], + kept_cov_mask: Optional[np.ndarray], + ) -> Dict[int, Dict[str, Any]]: + """Event study Stage 2: OLS of y_tilde on relative-time dummies.""" + y_tilde = df["_y_tilde"].values.copy() + # Handle NaN y_tilde (unidentified FEs) + nan_mask = ~np.isfinite(y_tilde) + if nan_mask.any(): + y_tilde[nan_mask] = 0.0 + rel_times = df["_rel_time"].values + n = len(df) + + # Get all horizons from treated observations + treated_rel = rel_times[omega_1_mask.values] + all_horizons = sorted(set(int(h) for h in treated_rel if np.isfinite(h))) + + # Apply horizon_max filter + if self.horizon_max is not None: + all_horizons = [h for h in all_horizons if abs(h) <= self.horizon_max] + + # Apply balance_e filter + if balance_e is not None: + cohort_rel_times = self._build_cohort_rel_times(df, first_treat) + balanced_cohorts = set() + if all_horizons: + max_h = max(all_horizons) + required_range = set(range(-balance_e, max_h + 1)) + for g, horizons in cohort_rel_times.items(): + if required_range.issubset(horizons): + balanced_cohorts.add(g) + balance_mask = ( + df[first_treat].isin(balanced_cohorts).values + if balanced_cohorts + else np.ones(n, dtype=bool) + ) + else: + balance_mask = np.ones(n, dtype=bool) + + # Remove reference period from estimation horizons + est_horizons = [h for h in all_horizons if h != ref_period] + + if len(est_horizons) == 0: + # No horizons to estimate — return just reference period + return { + ref_period: { + "effect": 0.0, + "se": 0.0, + "t_stat": np.nan, + "p_value": np.nan, + "conf_int": (0.0, 0.0), + "n_obs": 0, + } + } + + # Build Stage 2 design: one column per horizon (no intercept) + # Never-treated obs get all-zero rows (undefined relative time -> NaN) + # With no intercept, they contribute zero to X'_2 X_2 and X'_2 y_tilde + horizon_to_col = {h: j for j, h in enumerate(est_horizons)} + k = len(est_horizons) + X_2 = np.zeros((n, k)) + + for i in range(n): + if not balance_mask[i]: + continue + if nan_mask[i]: + continue # NaN y_tilde -> don't include in event study + h = rel_times[i] + if np.isfinite(h): + h_int = int(h) + if h_int in horizon_to_col: + X_2[i, horizon_to_col[h_int]] = 1.0 + + # Stage 2 OLS + coef, residuals, _ = solve_ols(X_2, y_tilde, return_vcov=False) + eps_2 = y_tilde - X_2 @ coef + + # GMM variance for full coefficient vector + V = self._compute_gmm_variance( + df=df, + unit=unit, + time=time, + covariates=covariates, + omega_0_mask=omega_0_mask, + unit_fe=unit_fe, + time_fe=time_fe, + delta_hat=delta_hat, + kept_cov_mask=kept_cov_mask, + X_2=X_2, + eps_2=eps_2, + cluster_ids=df[cluster_var].values, + ) + + # Build results dict + event_study_effects: Dict[int, Dict[str, Any]] = {} + + # Reference period marker + event_study_effects[ref_period] = { + "effect": 0.0, + "se": 0.0, + "t_stat": np.nan, + "p_value": np.nan, + "conf_int": (0.0, 0.0), + "n_obs": 0, + } + + for h in est_horizons: + j = horizon_to_col[h] + effect = float(coef[j]) + se = float(np.sqrt(max(V[j, j], 0.0))) + n_obs = int(np.sum(X_2[:, j])) + + t_stat = effect / se if np.isfinite(se) and se > 0 else np.nan + p_val = compute_p_value(t_stat) + ci = ( + compute_confidence_interval(effect, se, self.alpha) + if np.isfinite(se) and se > 0 + else (np.nan, np.nan) + ) + + event_study_effects[h] = { + "effect": effect, + "se": se, + "t_stat": t_stat, + "p_value": p_val, + "conf_int": ci, + "n_obs": n_obs, + } + + return event_study_effects + + def _stage2_group( + self, + df: pd.DataFrame, + unit: str, + time: str, + first_treat: str, + covariates: Optional[List[str]], + omega_0_mask: pd.Series, + omega_1_mask: pd.Series, + unit_fe: Dict[Any, float], + time_fe: Dict[Any, float], + grand_mean: float, + delta_hat: Optional[np.ndarray], + cluster_var: str, + treatment_groups: List[Any], + kept_cov_mask: Optional[np.ndarray], + ) -> Dict[Any, Dict[str, Any]]: + """Group (cohort) Stage 2: OLS of y_tilde on cohort dummies.""" + y_tilde = df["_y_tilde"].values.copy() + nan_mask = ~np.isfinite(y_tilde) + if nan_mask.any(): + y_tilde[nan_mask] = 0.0 + n = len(df) + + # Build Stage 2 design: one column per cohort (no intercept) + group_to_col = {g: j for j, g in enumerate(treatment_groups)} + k = len(treatment_groups) + X_2 = np.zeros((n, k)) + + ft_vals = df[first_treat].values + treated_mask = omega_1_mask.values + for i in range(n): + if treated_mask[i] and not nan_mask[i]: + g = ft_vals[i] + if g in group_to_col: + X_2[i, group_to_col[g]] = 1.0 + + # Stage 2 OLS + coef, residuals, _ = solve_ols(X_2, y_tilde, return_vcov=False) + eps_2 = y_tilde - X_2 @ coef + + # GMM variance + V = self._compute_gmm_variance( + df=df, + unit=unit, + time=time, + covariates=covariates, + omega_0_mask=omega_0_mask, + unit_fe=unit_fe, + time_fe=time_fe, + delta_hat=delta_hat, + kept_cov_mask=kept_cov_mask, + X_2=X_2, + eps_2=eps_2, + cluster_ids=df[cluster_var].values, + ) + + group_effects: Dict[Any, Dict[str, Any]] = {} + for g in treatment_groups: + j = group_to_col[g] + effect = float(coef[j]) + se = float(np.sqrt(max(V[j, j], 0.0))) + n_obs = int(np.sum(X_2[:, j])) + + t_stat = effect / se if np.isfinite(se) and se > 0 else np.nan + p_val = compute_p_value(t_stat) + ci = ( + compute_confidence_interval(effect, se, self.alpha) + if np.isfinite(se) and se > 0 + else (np.nan, np.nan) + ) + + group_effects[g] = { + "effect": effect, + "se": se, + "t_stat": t_stat, + "p_value": p_val, + "conf_int": ci, + "n_obs": n_obs, + } + + return group_effects + + # ========================================================================= + # GMM Sandwich Variance (Butts & Gardner 2022) + # ========================================================================= + + def _compute_gmm_variance( + self, + df: pd.DataFrame, + unit: str, + time: str, + covariates: Optional[List[str]], + omega_0_mask: pd.Series, + unit_fe: Dict[Any, float], + time_fe: Dict[Any, float], + delta_hat: Optional[np.ndarray], + kept_cov_mask: Optional[np.ndarray], + X_2: np.ndarray, + eps_2: np.ndarray, + cluster_ids: np.ndarray, + ) -> np.ndarray: + """ + Compute GMM sandwich variance (Butts & Gardner 2022). + + Matches the R `did2s` source code implementation: uses the GLOBAL + Hessian inverse (not per-cluster) and NO finite-sample adjustments. + + The per-observation influence function is: + IF_i = (X'_2 X_2)^{-1} [gamma_hat' x_{10i} eps_{10i} - x_{2i} eps_{2i}] + + where gamma_hat = (X'_{10} X_{10})^{-1} (X'_1 X_2) uses the GLOBAL + cross-moment. + + The cluster-robust variance is: + V = (X'_2 X_2)^{-1} (sum_g S_g S'_g) (X'_2 X_2)^{-1} + S_g = gamma_hat' c_g - X'_{2g} eps_{2g} + c_g = X'_{10g} eps_{10g} + + Parameters + ---------- + X_2 : np.ndarray, shape (n, k) + Stage 2 design matrix (treatment indicators). + eps_2 : np.ndarray, shape (n,) + Stage 2 residuals. + cluster_ids : np.ndarray, shape (n,) + Cluster identifiers. + + Returns + ------- + np.ndarray, shape (k, k) + Variance-covariance matrix. + """ + n = len(df) + k = X_2.shape[1] + + # Exclude rank-deficient covariates + cov_list = covariates + if covariates and kept_cov_mask is not None and not np.all(kept_cov_mask): + cov_list = [c for c, k_ in zip(covariates, kept_cov_mask) if k_] + + # Build sparse FE design matrices X_1 (all obs) and X_10 (untreated only) + X_1_sparse, X_10_sparse, unit_to_idx, time_to_idx = self._build_fe_design( + df, unit, time, cov_list, omega_0_mask + ) + + p = X_1_sparse.shape[1] + + # eps_10 = Y - X_10 @ gamma_hat + # Untreated: stage 1 residual (Y - fitted). Treated: Y (X_10 rows = 0). + # Reconstruct Y from y_tilde: Y = y_tilde + fitted_stage1 + alpha_i = df[unit].map(unit_fe).values + beta_t = df[time].map(time_fe).values + alpha_i = np.where(pd.isna(alpha_i), 0.0, alpha_i).astype(float) + beta_t = np.where(pd.isna(beta_t), 0.0, beta_t).astype(float) + fitted_1 = alpha_i + beta_t + if delta_hat is not None and cov_list: + if kept_cov_mask is not None and not np.all(kept_cov_mask): + fitted_1 = fitted_1 + df[cov_list].values @ delta_hat[kept_cov_mask] + else: + fitted_1 = fitted_1 + df[cov_list].values @ delta_hat + + y_tilde = df["_y_tilde"].values + y_vals = y_tilde + fitted_1 # reconstruct Y + + # eps_10: for untreated, stage 1 residual; for treated, Y_i (since X_10 rows = 0) + eps_10 = np.empty(n) + omega_0 = omega_0_mask.values + eps_10[omega_0] = y_vals[omega_0] - fitted_1[omega_0] # Stage 1 residual + eps_10[~omega_0] = y_vals[~omega_0] # x_{10i} = 0, so eps_10 = Y + + # 1. gamma_hat = (X'_{10} X_{10})^{-1} (X'_1 X_2) [p x k] + XtX_10 = X_10_sparse.T @ X_10_sparse # (p x p) sparse + Xt1_X2 = X_1_sparse.T @ X_2 # (p x k) dense + + try: + solve_XtX = sparse_factorized(XtX_10.tocsc()) + if Xt1_X2.ndim == 1: + gamma_hat = solve_XtX(Xt1_X2).reshape(-1, 1) + else: + gamma_hat = np.column_stack( + [solve_XtX(Xt1_X2[:, j]) for j in range(Xt1_X2.shape[1])] + ) + except RuntimeError: + # Singular matrix — fall back to dense least-squares + gamma_hat = np.linalg.lstsq(XtX_10.toarray(), Xt1_X2, rcond=None)[0] + if gamma_hat.ndim == 1: + gamma_hat = gamma_hat.reshape(-1, 1) + + # 2. Per-cluster Stage 1 scores: c_g = X'_{10g} eps_{10g} + # Only untreated obs have non-zero X_10 rows + weighted_X10 = X_10_sparse.multiply(eps_10[:, None]) # sparse element-wise + + unique_clusters, cluster_indices = np.unique(cluster_ids, return_inverse=True) + G = len(unique_clusters) + + # Aggregate sparse rows by cluster using column-wise np.add.at + weighted_X10_csc = weighted_X10.tocsc() + c_by_cluster = np.zeros((G, p)) + for j_col in range(p): + col_data = weighted_X10_csc.getcol(j_col).toarray().ravel() + np.add.at(c_by_cluster[:, j_col], cluster_indices, col_data) + + # 3. Per-cluster Stage 2 scores: X'_{2g} eps_{2g} + weighted_X2 = X_2 * eps_2[:, None] # (n x k) dense + s2_by_cluster = np.zeros((G, k)) + for j_col in range(k): + np.add.at(s2_by_cluster[:, j_col], cluster_indices, weighted_X2[:, j_col]) + + # 4. S_g = gamma_hat' c_g - X'_{2g} eps_{2g} + with np.errstate(invalid="ignore", divide="ignore", over="ignore"): + correction = c_by_cluster @ gamma_hat # (G x p) @ (p x k) = (G x k) + # Replace NaN/inf from overflow (rank-deficient FE) with 0 + np.nan_to_num(correction, copy=False, nan=0.0, posinf=0.0, neginf=0.0) + S = correction - s2_by_cluster # (G x k) + + # 5. Meat: sum_g S_g S'_g = S' S + with np.errstate(invalid="ignore", over="ignore"): + meat = S.T @ S # (k x k) + + # 6. Bread: (X'_2 X_2)^{-1} + with np.errstate(invalid="ignore", over="ignore", divide="ignore"): + XtX_2 = X_2.T @ X_2 + try: + bread = np.linalg.solve(XtX_2, np.eye(k)) + except np.linalg.LinAlgError: + bread = np.linalg.lstsq(XtX_2, np.eye(k), rcond=None)[0] + + # 7. V = bread @ meat @ bread + V = bread @ meat @ bread + return V + + def _build_fe_design( + self, + df: pd.DataFrame, + unit: str, + time: str, + covariates: Optional[List[str]], + omega_0_mask: pd.Series, + ) -> Tuple[sparse.csr_matrix, sparse.csr_matrix, Dict[Any, int], Dict[Any, int]]: + """ + Build sparse FE design matrices X_1 (all obs) and X_10 (untreated rows only). + + Column layout: [unit_0, ..., unit_{U-2}, time_0, ..., time_{T-2}, cov_1, ..., cov_C] + (Drop first unit and first time for identification.) + + X_10 is identical to X_1 except that rows for treated observations are zeroed out. + + Returns + ------- + X_1_sparse : sparse.csr_matrix, shape (n, p) + X_10_sparse : sparse.csr_matrix, shape (n, p) + unit_to_idx : dict + time_to_idx : dict + """ + n = len(df) + unit_vals = df[unit].values + time_vals = df[time].values + omega_0 = omega_0_mask.values + + all_units = np.unique(unit_vals) + all_times = np.unique(time_vals) + unit_to_idx = {u: i for i, u in enumerate(all_units)} + time_to_idx = {t: i for i, t in enumerate(all_times)} + n_units = len(all_units) + n_times = len(all_times) + n_cov = len(covariates) if covariates else 0 + n_fe_cols = (n_units - 1) + (n_times - 1) + + def _build_rows(mask=None): + """Build sparse matrix for given observation mask.""" + # Unit dummies (drop first) + u_indices = np.array([unit_to_idx[u] for u in unit_vals]) + u_mask = u_indices > 0 + if mask is not None: + u_mask = u_mask & mask + + u_rows = np.arange(n)[u_mask] + u_cols = u_indices[u_mask] - 1 + + # Time dummies (drop first) + t_indices = np.array([time_to_idx[t] for t in time_vals]) + t_mask = t_indices > 0 + if mask is not None: + t_mask = t_mask & mask + + t_rows = np.arange(n)[t_mask] + t_cols = (n_units - 1) + t_indices[t_mask] - 1 + + rows = np.concatenate([u_rows, t_rows]) + cols = np.concatenate([u_cols, t_cols]) + data = np.ones(len(rows)) + + A_fe = sparse.csr_matrix((data, (rows, cols)), shape=(n, n_fe_cols)) + + if n_cov > 0: + cov_data = df[covariates].values.copy() + if mask is not None: + cov_data[~mask] = 0.0 + A_cov = sparse.csr_matrix(cov_data) + A = sparse.hstack([A_fe, A_cov], format="csr") + else: + A = A_fe + + return A + + X_1 = _build_rows(mask=None) + X_10 = _build_rows(mask=omega_0) + + return X_1, X_10, unit_to_idx, time_to_idx + + # ========================================================================= + # Bootstrap + # ========================================================================= + + def _compute_cluster_S_scores( + self, + df: pd.DataFrame, + unit: str, + time: str, + covariates: Optional[List[str]], + omega_0_mask: pd.Series, + unit_fe: Dict[Any, float], + time_fe: Dict[Any, float], + delta_hat: Optional[np.ndarray], + kept_cov_mask: Optional[np.ndarray], + X_2: np.ndarray, + eps_2: np.ndarray, + cluster_ids: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Compute per-cluster S_g scores for bootstrap. + + Returns + ------- + S : np.ndarray, shape (G, k) + Per-cluster influence scores. + bread : np.ndarray, shape (k, k) + (X'_2 X_2)^{-1}. + unique_clusters : np.ndarray + Unique cluster identifiers. + """ + n = len(df) + k = X_2.shape[1] + + cov_list = covariates + if covariates and kept_cov_mask is not None and not np.all(kept_cov_mask): + cov_list = [c for c, k_ in zip(covariates, kept_cov_mask) if k_] + + X_1_sparse, X_10_sparse, _, _ = self._build_fe_design( + df, unit, time, cov_list, omega_0_mask + ) + p = X_1_sparse.shape[1] + + # Reconstruct Y and compute eps_10 + alpha_i = df[unit].map(unit_fe).values + beta_t = df[time].map(time_fe).values + alpha_i = np.where(pd.isna(alpha_i), 0.0, alpha_i).astype(float) + beta_t = np.where(pd.isna(beta_t), 0.0, beta_t).astype(float) + fitted_1 = alpha_i + beta_t + if delta_hat is not None and cov_list: + if kept_cov_mask is not None and not np.all(kept_cov_mask): + fitted_1 = fitted_1 + df[cov_list].values @ delta_hat[kept_cov_mask] + else: + fitted_1 = fitted_1 + df[cov_list].values @ delta_hat + + y_tilde = df["_y_tilde"].values + y_vals = y_tilde + fitted_1 + + eps_10 = np.empty(n) + omega_0 = omega_0_mask.values + eps_10[omega_0] = y_vals[omega_0] - fitted_1[omega_0] + eps_10[~omega_0] = y_vals[~omega_0] + + # gamma_hat + XtX_10 = X_10_sparse.T @ X_10_sparse + Xt1_X2 = X_1_sparse.T @ X_2 + + try: + solve_XtX = sparse_factorized(XtX_10.tocsc()) + if Xt1_X2.ndim == 1: + gamma_hat = solve_XtX(Xt1_X2).reshape(-1, 1) + else: + gamma_hat = np.column_stack( + [solve_XtX(Xt1_X2[:, j]) for j in range(Xt1_X2.shape[1])] + ) + except RuntimeError: + gamma_hat = np.linalg.lstsq(XtX_10.toarray(), Xt1_X2, rcond=None)[0] + if gamma_hat.ndim == 1: + gamma_hat = gamma_hat.reshape(-1, 1) + + # Per-cluster aggregation + weighted_X10 = X_10_sparse.multiply(eps_10[:, None]) + unique_clusters, cluster_indices = np.unique(cluster_ids, return_inverse=True) + G = len(unique_clusters) + + weighted_X10_csc = weighted_X10.tocsc() + c_by_cluster = np.zeros((G, p)) + for j_col in range(p): + col_data = weighted_X10_csc.getcol(j_col).toarray().ravel() + np.add.at(c_by_cluster[:, j_col], cluster_indices, col_data) + + weighted_X2 = X_2 * eps_2[:, None] + s2_by_cluster = np.zeros((G, k)) + for j_col in range(k): + np.add.at(s2_by_cluster[:, j_col], cluster_indices, weighted_X2[:, j_col]) + + correction = c_by_cluster @ gamma_hat + S = correction - s2_by_cluster + + # Bread + XtX_2 = X_2.T @ X_2 + try: + bread = np.linalg.solve(XtX_2, np.eye(k)) + except np.linalg.LinAlgError: + bread = np.linalg.lstsq(XtX_2, np.eye(k), rcond=None)[0] + + return S, bread, unique_clusters + + def _run_bootstrap( + self, + df: pd.DataFrame, + unit: str, + time: str, + first_treat: str, + covariates: Optional[List[str]], + omega_0_mask: pd.Series, + omega_1_mask: pd.Series, + unit_fe: Dict[Any, float], + time_fe: Dict[Any, float], + grand_mean: float, + delta_hat: Optional[np.ndarray], + cluster_var: str, + kept_cov_mask: Optional[np.ndarray], + treatment_groups: List[Any], + ref_period: int, + balance_e: Optional[int], + original_att: float, + original_event_study: Optional[Dict[int, Dict[str, Any]]], + original_group: Optional[Dict[Any, Dict[str, Any]]], + aggregate: Optional[str], + ) -> TwoStageBootstrapResults: + """Run multiplier bootstrap on GMM influence function.""" + if self.n_bootstrap < 50: + warnings.warn( + f"n_bootstrap={self.n_bootstrap} is low. Consider n_bootstrap >= 199 " + "for reliable inference.", + UserWarning, + stacklevel=3, + ) + + rng = np.random.default_rng(self.seed) + + from diff_diff.staggered_bootstrap import _generate_bootstrap_weights_batch + + y_tilde = df["_y_tilde"].values + n = len(df) + cluster_ids = df[cluster_var].values + + # --- Static specification bootstrap --- + D = omega_1_mask.values.astype(float) + X_2_static = D.reshape(-1, 1) + coef_static = solve_ols(X_2_static, y_tilde, return_vcov=False)[0] + eps_2_static = y_tilde - X_2_static @ coef_static + + S_static, bread_static, unique_clusters = self._compute_cluster_S_scores( + df=df, + unit=unit, + time=time, + covariates=covariates, + omega_0_mask=omega_0_mask, + unit_fe=unit_fe, + time_fe=time_fe, + delta_hat=delta_hat, + kept_cov_mask=kept_cov_mask, + X_2=X_2_static, + eps_2=eps_2_static, + cluster_ids=cluster_ids, + ) + + n_clusters = len(unique_clusters) + all_weights = _generate_bootstrap_weights_batch( + self.n_bootstrap, n_clusters, "rademacher", rng + ) + + # T_b = bread @ (sum_g w_bg * S_g) = bread @ (W @ S)' per boot + # IF_b = bread @ S_g for each cluster, then perturb + # boot_coef = all_weights @ S_static @ bread_static.T → (B, k) + # For static (k=1): boot_att = all_weights @ S_static @ bread_static.T + boot_att_vec = all_weights @ S_static # (B, 1) + boot_att_vec = boot_att_vec @ bread_static.T # (B, 1) + boot_overall = boot_att_vec[:, 0] + + boot_overall_shifted = boot_overall + original_att + overall_se = float(np.std(boot_overall, ddof=1)) + overall_ci = ( + self._compute_percentile_ci(boot_overall_shifted, self.alpha) + if overall_se > 0 + else (np.nan, np.nan) + ) + overall_p = ( + self._compute_bootstrap_pvalue(original_att, boot_overall_shifted) + if overall_se > 0 + else np.nan + ) + + # --- Event study bootstrap --- + event_study_ses = None + event_study_cis = None + event_study_p_values = None + + if original_event_study and aggregate in ("event_study", "all"): + # Recompute S scores for event study specification + rel_times = df["_rel_time"].values + treated_rel = rel_times[omega_1_mask.values] + all_horizons = sorted(set(int(h) for h in treated_rel if np.isfinite(h))) + if self.horizon_max is not None: + all_horizons = [h for h in all_horizons if abs(h) <= self.horizon_max] + + if balance_e is not None: + cohort_rel_times = self._build_cohort_rel_times(df, first_treat) + balanced_cohorts = set() + if all_horizons: + max_h = max(all_horizons) + required_range = set(range(-balance_e, max_h + 1)) + for g, horizons in cohort_rel_times.items(): + if required_range.issubset(horizons): + balanced_cohorts.add(g) + balance_mask = ( + df[first_treat].isin(balanced_cohorts).values + if balanced_cohorts + else np.ones(n, dtype=bool) + ) + else: + balance_mask = np.ones(n, dtype=bool) + + est_horizons = [h for h in all_horizons if h != ref_period] + if est_horizons: + horizon_to_col = {h: j for j, h in enumerate(est_horizons)} + k_es = len(est_horizons) + X_2_es = np.zeros((n, k_es)) + for i in range(n): + if not balance_mask[i]: + continue + h = rel_times[i] + if np.isfinite(h): + h_int = int(h) + if h_int in horizon_to_col: + X_2_es[i, horizon_to_col[h_int]] = 1.0 + + coef_es = solve_ols(X_2_es, y_tilde, return_vcov=False)[0] + eps_2_es = y_tilde - X_2_es @ coef_es + + S_es, bread_es, _ = self._compute_cluster_S_scores( + df=df, + unit=unit, + time=time, + covariates=covariates, + omega_0_mask=omega_0_mask, + unit_fe=unit_fe, + time_fe=time_fe, + delta_hat=delta_hat, + kept_cov_mask=kept_cov_mask, + X_2=X_2_es, + eps_2=eps_2_es, + cluster_ids=cluster_ids, + ) + + # boot_coef_es: (B, k_es) + boot_coef_es = (all_weights @ S_es) @ bread_es.T + + event_study_ses = {} + event_study_cis = {} + event_study_p_values = {} + for h in original_event_study: + if original_event_study[h].get("n_obs", 0) == 0: + continue + if h not in horizon_to_col: + continue + j = horizon_to_col[h] + orig_eff = original_event_study[h]["effect"] + boot_h = boot_coef_es[:, j] + se_h = float(np.std(boot_h, ddof=1)) + event_study_ses[h] = se_h + if se_h > 0 and np.isfinite(orig_eff): + shifted_h = boot_h + orig_eff + event_study_p_values[h] = self._compute_bootstrap_pvalue( + orig_eff, shifted_h + ) + event_study_cis[h] = self._compute_percentile_ci(shifted_h, self.alpha) + else: + event_study_p_values[h] = np.nan + event_study_cis[h] = (np.nan, np.nan) + + # --- Group bootstrap --- + group_ses = None + group_cis = None + group_p_values = None + + if original_group and aggregate in ("group", "all"): + group_to_col = {g: j for j, g in enumerate(treatment_groups)} + k_grp = len(treatment_groups) + X_2_grp = np.zeros((n, k_grp)) + ft_vals = df[first_treat].values + treated_mask = omega_1_mask.values + for i in range(n): + if treated_mask[i]: + g = ft_vals[i] + if g in group_to_col: + X_2_grp[i, group_to_col[g]] = 1.0 + + coef_grp = solve_ols(X_2_grp, y_tilde, return_vcov=False)[0] + eps_2_grp = y_tilde - X_2_grp @ coef_grp + + S_grp, bread_grp, _ = self._compute_cluster_S_scores( + df=df, + unit=unit, + time=time, + covariates=covariates, + omega_0_mask=omega_0_mask, + unit_fe=unit_fe, + time_fe=time_fe, + delta_hat=delta_hat, + kept_cov_mask=kept_cov_mask, + X_2=X_2_grp, + eps_2=eps_2_grp, + cluster_ids=cluster_ids, + ) + + boot_coef_grp = (all_weights @ S_grp) @ bread_grp.T + + group_ses = {} + group_cis = {} + group_p_values = {} + for g in original_group: + if g not in group_to_col: + continue + j = group_to_col[g] + orig_eff = original_group[g]["effect"] + boot_g = boot_coef_grp[:, j] + se_g = float(np.std(boot_g, ddof=1)) + group_ses[g] = se_g + if se_g > 0 and np.isfinite(orig_eff): + shifted_g = boot_g + orig_eff + group_p_values[g] = self._compute_bootstrap_pvalue(orig_eff, shifted_g) + group_cis[g] = self._compute_percentile_ci(shifted_g, self.alpha) + else: + group_p_values[g] = np.nan + group_cis[g] = (np.nan, np.nan) + + return TwoStageBootstrapResults( + n_bootstrap=self.n_bootstrap, + weight_type="rademacher", + alpha=self.alpha, + overall_att_se=overall_se, + overall_att_ci=overall_ci, + overall_att_p_value=overall_p, + event_study_ses=event_study_ses, + event_study_cis=event_study_cis, + event_study_p_values=event_study_p_values, + group_ses=group_ses, + group_cis=group_cis, + group_p_values=group_p_values, + bootstrap_distribution=boot_overall_shifted, + ) + + # ========================================================================= + # Bootstrap helpers + # ========================================================================= + + @staticmethod + def _compute_percentile_ci( + boot_dist: np.ndarray, + alpha: float, + ) -> Tuple[float, float]: + """Compute percentile confidence interval from bootstrap distribution.""" + lower = float(np.percentile(boot_dist, alpha / 2 * 100)) + upper = float(np.percentile(boot_dist, (1 - alpha / 2) * 100)) + return (lower, upper) + + @staticmethod + def _compute_bootstrap_pvalue( + original_effect: float, + boot_dist: np.ndarray, + ) -> float: + """Compute two-sided bootstrap p-value.""" + if original_effect >= 0: + p_one_sided = float(np.mean(boot_dist <= 0)) + else: + p_one_sided = float(np.mean(boot_dist >= 0)) + p_value = min(2 * p_one_sided, 1.0) + p_value = max(p_value, 1 / (len(boot_dist) + 1)) + return p_value + + # ========================================================================= + # Utility + # ========================================================================= + + @staticmethod + def _build_cohort_rel_times( + df: pd.DataFrame, + first_treat: str, + ) -> Dict[Any, Set[int]]: + """Build mapping of cohort -> set of observed relative times.""" + treated_mask = ~df["_never_treated"] + treated_df = df.loc[treated_mask] + result: Dict[Any, Set[int]] = {} + ft_vals = treated_df[first_treat].values + rt_vals = treated_df["_rel_time"].values + for i in range(len(treated_df)): + h = rt_vals[i] + if np.isfinite(h): + result.setdefault(ft_vals[i], set()).add(int(h)) + return result + + # ========================================================================= + # sklearn-compatible interface + # ========================================================================= + + def get_params(self) -> Dict[str, Any]: + """Get estimator parameters (sklearn-compatible).""" + return { + "anticipation": self.anticipation, + "alpha": self.alpha, + "cluster": self.cluster, + "n_bootstrap": self.n_bootstrap, + "seed": self.seed, + "rank_deficient_action": self.rank_deficient_action, + "horizon_max": self.horizon_max, + } + + def set_params(self, **params) -> "TwoStageDiD": + """Set estimator parameters (sklearn-compatible).""" + for key, value in params.items(): + if hasattr(self, key): + setattr(self, key, value) + else: + raise ValueError(f"Unknown parameter: {key}") + return self + + def summary(self) -> str: + """Get summary of estimation results.""" + if not self.is_fitted_: + raise RuntimeError("Model must be fitted before calling summary()") + assert self.results_ is not None + return self.results_.summary() + + def print_summary(self) -> None: + """Print summary to stdout.""" + print(self.summary()) + + +# ============================================================================= +# Convenience function +# ============================================================================= + + +def two_stage_did( + data: pd.DataFrame, + outcome: str, + unit: str, + time: str, + first_treat: str, + covariates: Optional[List[str]] = None, + aggregate: Optional[str] = None, + balance_e: Optional[int] = None, + **kwargs, +) -> TwoStageDiDResults: + """ + Convenience function for two-stage DiD estimation. + + This is a shortcut for creating a TwoStageDiD estimator and calling fit(). + + Parameters + ---------- + data : pd.DataFrame + Panel data. + outcome : str + Outcome variable column name. + unit : str + Unit identifier column name. + time : str + Time period column name. + first_treat : str + Column indicating first treatment period (0 for never-treated). + covariates : list of str, optional + Covariate column names. + aggregate : str, optional + Aggregation mode: None, "simple", "event_study", "group", "all". + balance_e : int, optional + Balance event study to cohorts observed at all relative times. + **kwargs + Additional keyword arguments passed to TwoStageDiD constructor. + + Returns + ------- + TwoStageDiDResults + Estimation results. + + Examples + -------- + >>> from diff_diff import two_stage_did, generate_staggered_data + >>> data = generate_staggered_data(seed=42) + >>> results = two_stage_did(data, 'outcome', 'unit', 'period', + ... 'first_treat', aggregate='event_study') + >>> results.print_summary() + """ + est = TwoStageDiD(**kwargs) + return est.fit( + data, + outcome=outcome, + unit=unit, + time=time, + first_treat=first_treat, + covariates=covariates, + aggregate=aggregate, + balance_e=balance_e, + ) diff --git a/diff_diff/visualization.py b/diff_diff/visualization.py index 6903d415..71fb243b 100644 --- a/diff_diff/visualization.py +++ b/diff_diff/visualization.py @@ -18,6 +18,7 @@ from diff_diff.staggered import CallawaySantAnnaResults from diff_diff.imputation import ImputationDiDResults from diff_diff.sun_abraham import SunAbrahamResults + from diff_diff.two_stage import TwoStageDiDResults # Type alias for results that can be plotted PlottableResults = Union[ @@ -25,6 +26,7 @@ "CallawaySantAnnaResults", "SunAbrahamResults", "ImputationDiDResults", + "TwoStageDiDResults", pd.DataFrame, ] diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 4f81ebc0..b8e7fc5a 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -12,6 +12,7 @@ This document provides the academic foundations and key implementation requireme - [CallawaySantAnna](#callawaysantanna) - [SunAbraham](#sunabraham) - [ImputationDiD](#imputationdid) + - [TwoStageDiD](#twostagedid) 3. [Advanced Estimators](#advanced-estimators) - [SyntheticDiD](#syntheticdid) - [TripleDifference](#tripledifference) @@ -565,6 +566,76 @@ Y_it = alpha_i + beta_t [+ X'_it * delta] + W'_it * gamma + epsilon_it --- +## TwoStageDiD + +**Primary source:** [Gardner, J. (2022). Two-stage differences in differences. arXiv:2207.05943.](https://arxiv.org/abs/2207.05943) + +**Key implementation requirements:** + +*Assumption checks / warnings:* +- **Parallel trends (same as ImputationDiD):** `E[Y_it(0)] = alpha_i + beta_t` for all observations. +- **No-anticipation effects:** `Y_it = Y_it(0)` for all untreated observations. +- Treatment must be absorbing: `D_it` switches from 0 to 1 and stays at 1. +- Always-treated units (treated in all periods) are excluded with a warning, since they have no untreated observations for Stage 1 FE estimation. + +*Estimator equation (two-stage procedure, as implemented):* + +``` +Stage 1. Estimate unit + time fixed effects on untreated observations only (it in Omega_0): + Y_it = alpha_i + beta_t + epsilon_it + Compute residuals: y_tilde_it = Y_it - alpha_hat_i - beta_hat_t (for ALL observations) + +Stage 2. Regress residualized outcomes on treatment indicators (on treated observations): + y_tilde_it = tau * D_it + eta_it + (or event-study specification with horizon indicators) +``` + +Point estimates are identical to ImputationDiD (Borusyak et al. 2024). The two-stage procedure is algebraically equivalent to the imputation approach: both estimate unit+time FE on untreated observations and recover treatment effects from the difference between observed and counterfactual outcomes. + +*Variance: GMM sandwich (Newey & McFadden 1994 Theorem 6.1):* + +The variance accounts for first-stage estimation error propagating into Stage 2, following the GMM framework: + +``` +V(tau_hat) = (D'D)^{-1} * Bread * (D'D)^{-1} + +Bread = sum_c ( sum_{i in c} psi_i )( sum_{i in c} psi_i )' +``` + +where `psi_i` is the stacked influence function for unit i across all its observations, combining the Stage 2 score and the Stage 1 correction term. + +**Note on Equation 6 discrepancy:** The paper's Equation 6 uses a per-cluster inverse `(D_c'D_c)^{-1}` when forming the influence function contribution. The R `did2s` implementation and our code use the GLOBAL inverse `(D'D)^{-1}` following standard GMM theory (Newey & McFadden 1994). We follow the R implementation, which is consistent with standard GMM sandwich variance estimation. + +**No finite-sample adjustments:** The variance estimator uses the raw asymptotic sandwich without degrees-of-freedom corrections (no HC1-style `n/(n-k)` adjustment). This matches the R `did2s` implementation. + +*Bootstrap:* + +Our implementation uses multiplier bootstrap on the GMM influence function: cluster-level `psi` sums are pre-computed, then perturbed with Rademacher/Mammen/Webb weights. The R `did2s` package defaults to block bootstrap (resampling clusters with replacement). Both approaches are asymptotically valid; the multiplier bootstrap is computationally cheaper and consistent with the CallawaySantAnna/ImputationDiD bootstrap patterns in this library. + +*Edge cases:* +- **Always-treated units:** Units treated in all observed periods have no untreated observations for Stage 1 FE estimation. These are excluded with a warning listing the affected unit IDs. Their treated observations do NOT contribute to Stage 2. +- **Rank condition violations:** If the Stage 1 design matrix (unit+time dummies on untreated obs) is rank-deficient, or if certain unit/time FE are unidentified (e.g., a unit with no untreated periods after excluding always-treated), the affected FE produce NaN. Behavior controlled by `rank_deficient_action`: "warn" (default), "error", or "silent". +- **NaN y_tilde handling:** When Stage 1 FE are unidentified for some observations, the residualized outcome `y_tilde` is NaN. These observations are zeroed out (excluded) from the Stage 2 regression and variance computation, matching the treatment of unimputable observations in ImputationDiD. +- **NaN inference for undefined statistics:** t_stat uses NaN when SE is non-finite or zero; p_value and CI also NaN. Matches CallawaySantAnna/ImputationDiD NaN convention. +- **Event study aggregation:** Horizon-specific effects use the same two-stage procedure with horizon indicator dummies in Stage 2. Unidentified horizons (e.g., long-run effects without never-treated units, per Proposition 5 of Borusyak et al. 2024) produce NaN. +- **No never-treated units:** Long-run effects may be unidentified (same limitation as ImputationDiD). Warning emitted for affected horizons. + +**Reference implementation(s):** +- R: `did2s::did2s()` (Kyle Butts & John Gardner) + +**Requirements checklist:** +- [x] Stage 1: OLS on untreated observations only for unit+time FE +- [x] Stage 2: Regress residualized outcomes on treatment indicators +- [x] Point estimates match ImputationDiD +- [x] GMM sandwich variance (Newey & McFadden 1994 Theorem 6.1) +- [x] Global `(D'D)^{-1}` in variance (matches R `did2s`, not paper Eq. 6) +- [x] No finite-sample adjustment (raw asymptotic sandwich) +- [x] Always-treated units excluded with warning +- [x] Multiplier bootstrap on GMM influence function +- [x] Event study and overall ATT aggregation + +--- + # Advanced Estimators ## SyntheticDiD @@ -1264,6 +1335,7 @@ should be a deliberate user choice. | CallawaySantAnna | Analytical (influence fn) | Multiplier bootstrap | | SunAbraham | Cluster-robust + delta method | Pairs bootstrap | | ImputationDiD | Conservative clustered (Thm 3) | Multiplier bootstrap (library extension; percentile CIs and empirical p-values, consistent with CS/SA) | +| TwoStageDiD | GMM sandwich (Newey & McFadden 1994) | Multiplier bootstrap on GMM influence function | | SyntheticDiD | Placebo variance (Alg 4) | Unit-level bootstrap (fixed weights) | | TripleDifference | HC1 / cluster-robust | Influence function for IPW/DR | | TROP | Block bootstrap | — | @@ -1284,6 +1356,7 @@ should be a deliberate user choice. | CallawaySantAnna | did | `att_gt()` | | SunAbraham | fixest | `sunab()` | | ImputationDiD | didimputation | `did_imputation()` | +| TwoStageDiD | did2s | `did2s()` | | SyntheticDiD | synthdid | `synthdid_estimate()` | | TripleDifference | - | (forthcoming) | | TROP | - | (forthcoming) | diff --git a/tests/test_two_stage.py b/tests/test_two_stage.py new file mode 100644 index 00000000..db570ed9 --- /dev/null +++ b/tests/test_two_stage.py @@ -0,0 +1,955 @@ +""" +Tests for Gardner (2022) Two-Stage DiD estimator. +""" + +import warnings + +import numpy as np +import pandas as pd +import pytest + +from diff_diff.two_stage import ( + TwoStageBootstrapResults, + TwoStageDiD, + TwoStageDiDResults, + two_stage_did, +) + +# ============================================================================= +# Shared test data generation +# ============================================================================= + + +def generate_test_data( + n_units: int = 100, + n_periods: int = 10, + treatment_effect: float = 2.0, + never_treated_frac: float = 0.3, + dynamic_effects: bool = True, + seed: int = 42, +) -> pd.DataFrame: + """Generate synthetic staggered adoption data for testing.""" + rng = np.random.default_rng(seed) + + units = np.repeat(np.arange(n_units), n_periods) + times = np.tile(np.arange(n_periods), n_units) + + n_never = int(n_units * never_treated_frac) + n_treated = n_units - n_never + + cohort_periods = np.array([3, 5, 7]) + first_treat = np.zeros(n_units, dtype=int) + if n_treated > 0: + cohort_assignments = rng.choice(len(cohort_periods), size=n_treated) + first_treat[n_never:] = cohort_periods[cohort_assignments] + + first_treat_expanded = np.repeat(first_treat, n_periods) + + unit_fe = rng.standard_normal(n_units) * 2.0 + time_fe = np.linspace(0, 1, n_periods) + + unit_fe_expanded = np.repeat(unit_fe, n_periods) + time_fe_expanded = np.tile(time_fe, n_units) + + post = (times >= first_treat_expanded) & (first_treat_expanded > 0) + relative_time = times - first_treat_expanded + + if dynamic_effects: + dynamic_mult = 1 + 0.1 * np.maximum(relative_time, 0) + else: + dynamic_mult = np.ones_like(relative_time, dtype=float) + + effect = treatment_effect * dynamic_mult + + outcomes = ( + unit_fe_expanded + time_fe_expanded + effect * post + rng.standard_normal(len(units)) * 0.5 + ) + + return pd.DataFrame( + { + "unit": units, + "time": times, + "outcome": outcomes, + "first_treat": first_treat_expanded, + } + ) + + +# ============================================================================= +# TestTwoStageDiDBasic +# ============================================================================= + + +class TestTwoStageDiDBasic: + """Tests for basic TwoStageDiD functionality.""" + + def test_basic_fit(self): + """Test basic model fitting.""" + data = generate_test_data() + + est = TwoStageDiD() + results = est.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + assert est.is_fitted_ + assert isinstance(results, TwoStageDiDResults) + + def test_att_accuracy(self): + """Test that ATT recovers true treatment effect.""" + data = generate_test_data(treatment_effect=2.0, dynamic_effects=False, seed=123) + + results = TwoStageDiD().fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + # Should recover ~2.0 with reasonable tolerance + assert abs(results.overall_att - 2.0) < 0.3 + + def test_se_positive_finite(self): + """Test that SEs are positive and finite.""" + data = generate_test_data() + results = TwoStageDiD().fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + assert results.overall_se > 0 + assert np.isfinite(results.overall_se) + + def test_ci_contains_point_estimate(self): + """Test that confidence interval contains the point estimate.""" + data = generate_test_data() + results = TwoStageDiD().fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + assert results.overall_conf_int[0] <= results.overall_att + assert results.overall_att <= results.overall_conf_int[1] + + def test_t_stat_and_p_value(self): + """Test that t-stat and p-value are consistent.""" + data = generate_test_data() + results = TwoStageDiD().fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + assert np.isfinite(results.overall_t_stat) + assert 0 <= results.overall_p_value <= 1 + + # t-stat should equal ATT / SE + expected_t = results.overall_att / results.overall_se + assert abs(results.overall_t_stat - expected_t) < 1e-10 + + def test_event_study(self): + """Test event study specification.""" + data = generate_test_data() + results = TwoStageDiD().fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) + + assert results.event_study_effects is not None + assert len(results.event_study_effects) > 0 + + # Check reference period is present + ref_period = -1 + assert ref_period in results.event_study_effects + assert results.event_study_effects[ref_period]["effect"] == 0.0 + + # Post-treatment effects should be positive (treatment_effect=2.0) + post_effects = {h: e for h, e in results.event_study_effects.items() if h >= 0} + assert len(post_effects) > 0 + for h, eff in post_effects.items(): + assert eff["effect"] > 0, f"Post-treatment effect at h={h} should be positive" + assert eff["se"] > 0, f"SE at h={h} should be positive" + assert np.isfinite(eff["t_stat"]) + assert 0 <= eff["p_value"] <= 1 + + def test_group_effects(self): + """Test group (cohort) effects.""" + data = generate_test_data() + results = TwoStageDiD().fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="group", + ) + + assert results.group_effects is not None + # Should have 3 groups (cohorts 3, 5, 7) + assert len(results.group_effects) == 3 + for g, eff in results.group_effects.items(): + assert eff["effect"] > 0 + assert eff["se"] > 0 + assert np.isfinite(eff["t_stat"]) + + def test_all_aggregation(self): + """Test aggregate='all' produces both event study and group effects.""" + data = generate_test_data() + results = TwoStageDiD().fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="all", + ) + + assert results.event_study_effects is not None + assert results.group_effects is not None + + def test_summary_text(self): + """Test that summary produces expected header text.""" + data = generate_test_data() + results = TwoStageDiD().fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + text = results.summary() + assert "Two-Stage DiD Estimator Results (Gardner 2022)" in text + assert "ATT" in text + assert "Overall Average Treatment Effect" in text + + def test_to_dataframe_event_study(self): + """Test to_dataframe with event_study level.""" + data = generate_test_data() + results = TwoStageDiD().fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) + + df = results.to_dataframe("event_study") + assert isinstance(df, pd.DataFrame) + assert "relative_period" in df.columns + assert "effect" in df.columns + assert "se" in df.columns + assert len(df) > 0 + + def test_to_dataframe_group(self): + """Test to_dataframe with group level.""" + data = generate_test_data() + results = TwoStageDiD().fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="group", + ) + + df = results.to_dataframe("group") + assert isinstance(df, pd.DataFrame) + assert "group" in df.columns + assert len(df) == 3 + + def test_to_dataframe_observation(self): + """Test to_dataframe with observation level.""" + data = generate_test_data() + results = TwoStageDiD().fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + df = results.to_dataframe("observation") + assert isinstance(df, pd.DataFrame) + assert "tau_hat" in df.columns + assert "weight" in df.columns + assert "unit" in df.columns + assert "time" in df.columns + assert len(df) == results.n_treated_obs + + def test_to_dataframe_invalid_level(self): + """Test to_dataframe with invalid level raises.""" + data = generate_test_data() + results = TwoStageDiD().fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + with pytest.raises(ValueError, match="Unknown level"): + results.to_dataframe("invalid") + + def test_to_dataframe_no_event_study(self): + """Test to_dataframe raises when event study not computed.""" + data = generate_test_data() + results = TwoStageDiD().fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + with pytest.raises(ValueError, match="Event study effects not computed"): + results.to_dataframe("event_study") + + def test_repr(self): + """Test __repr__ contains expected elements.""" + data = generate_test_data() + results = TwoStageDiD().fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + repr_str = repr(results) + assert "TwoStageDiDResults" in repr_str + assert "ATT=" in repr_str + assert "SE=" in repr_str + + def test_is_significant_property(self): + """Test is_significant property.""" + data = generate_test_data(treatment_effect=2.0) + results = TwoStageDiD().fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + assert isinstance(results.is_significant, bool) + # Strong treatment effect should be significant + assert results.is_significant + + def test_significance_stars_property(self): + """Test significance_stars property.""" + data = generate_test_data(treatment_effect=2.0) + results = TwoStageDiD().fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + stars = results.significance_stars + assert isinstance(stars, str) + # Strong effect should have at least one star + assert len(stars.strip()) > 0 + + def test_metadata_fields(self): + """Test that metadata fields are populated correctly.""" + data = generate_test_data() + results = TwoStageDiD().fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + assert results.n_obs == len(data) + assert results.n_treated_obs > 0 + assert results.n_untreated_obs > 0 + assert results.n_treated_obs + results.n_untreated_obs == results.n_obs + assert results.n_treated_units > 0 + assert results.n_control_units > 0 + assert len(results.groups) == 3 + assert len(results.time_periods) == 10 + + +# ============================================================================= +# TestTwoStageDiDEquivalence +# ============================================================================= + + +class TestTwoStageDiDEquivalence: + """Test that TwoStageDiD point estimates match ImputationDiD.""" + + def test_overall_att_matches_imputation(self): + """Overall ATT should match ImputationDiD to machine precision.""" + from diff_diff.imputation import ImputationDiD + + data = generate_test_data() + + ts_results = TwoStageDiD().fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + imp_results = ImputationDiD().fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + assert abs(ts_results.overall_att - imp_results.overall_att) < 1e-10 + + def test_event_study_effects_match_imputation(self): + """Event study point estimates should match ImputationDiD.""" + from diff_diff.imputation import ImputationDiD + + data = generate_test_data() + + ts_results = TwoStageDiD().fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) + imp_results = ImputationDiD().fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) + + # Both should have the same horizons + ts_horizons = set(ts_results.event_study_effects.keys()) + imp_horizons = set(imp_results.event_study_effects.keys()) + assert ts_horizons == imp_horizons + + # Point estimates should match + for h in ts_horizons: + ts_eff = ts_results.event_study_effects[h]["effect"] + imp_eff = imp_results.event_study_effects[h]["effect"] + if np.isfinite(ts_eff) and np.isfinite(imp_eff): + assert ( + abs(ts_eff - imp_eff) < 1e-8 + ), f"Effect mismatch at h={h}: TS={ts_eff:.10f}, Imp={imp_eff:.10f}" + + def test_group_effects_match_imputation(self): + """Group point estimates should match ImputationDiD.""" + from diff_diff.imputation import ImputationDiD + + data = generate_test_data() + + ts_results = TwoStageDiD().fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="group", + ) + imp_results = ImputationDiD().fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="group", + ) + + assert set(ts_results.group_effects.keys()) == set(imp_results.group_effects.keys()) + + for g in ts_results.group_effects: + ts_eff = ts_results.group_effects[g]["effect"] + imp_eff = imp_results.group_effects[g]["effect"] + assert abs(ts_eff - imp_eff) < 1e-8 + + def test_ses_differ_from_imputation(self): + """GMM SEs should differ from conservative (Theorem 3) SEs.""" + from diff_diff.imputation import ImputationDiD + + data = generate_test_data() + + ts_results = TwoStageDiD().fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + imp_results = ImputationDiD().fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + # SEs should differ (different variance estimators) + assert abs(ts_results.overall_se - imp_results.overall_se) > 1e-6 + + +# ============================================================================= +# TestTwoStageDiDVariance +# ============================================================================= + + +class TestTwoStageDiDVariance: + """Tests for GMM sandwich variance estimator.""" + + def test_gmm_se_differs_from_naive(self): + """GMM SE should differ from naive Stage 2 OLS SE.""" + data = generate_test_data() + results = TwoStageDiD().fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + # The GMM SE accounts for first-stage estimation uncertainty + assert results.overall_se > 0 + assert np.isfinite(results.overall_se) + + def test_event_study_se_positive(self): + """Event study SEs should all be positive.""" + data = generate_test_data() + results = TwoStageDiD().fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) + + for h, eff in results.event_study_effects.items(): + if eff.get("n_obs", 0) > 0: + assert eff["se"] > 0, f"SE at h={h} should be positive" + assert np.isfinite(eff["se"]) + + +# ============================================================================= +# TestTwoStageDiDEdgeCases +# ============================================================================= + + +class TestTwoStageDiDEdgeCases: + """Tests for edge cases and error handling.""" + + def test_always_treated_excluded_with_warning(self): + """Always-treated units should be excluded with a warning.""" + data = generate_test_data() + + # Add an always-treated unit (first_treat = 0 means treated at time 0) + always_treated = pd.DataFrame( + { + "unit": np.repeat(999, 10), + "time": np.arange(10), + "outcome": np.random.default_rng(42).standard_normal(10), + "first_treat": np.repeat(-1, 10), # treated before sample starts + } + ) + data_with_always = pd.concat([data, always_treated], ignore_index=True) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + results = TwoStageDiD().fit( + data_with_always, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + always_treated_warns = [ + x for x in w if "treated in all observed periods" in str(x.message) + ] + assert len(always_treated_warns) > 0 + + # Verify unit was excluded (total obs should be less) + assert results.n_obs == len(data) + + def test_no_never_treated_works(self): + """Estimation should work without never-treated units.""" + data = generate_test_data(never_treated_frac=0.0) + + results = TwoStageDiD().fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + assert results.overall_att > 0 + assert results.overall_se > 0 + + def test_single_cohort(self): + """Should work with a single treatment cohort.""" + rng = np.random.default_rng(42) + n_units, n_periods = 50, 8 + units = np.repeat(np.arange(n_units), n_periods) + times = np.tile(np.arange(n_periods), n_units) + + first_treat = np.zeros(n_units, dtype=int) + first_treat[15:] = 4 # single cohort at period 4 + + ft_exp = np.repeat(first_treat, n_periods) + post = (times >= ft_exp) & (ft_exp > 0) + outcomes = ( + rng.standard_normal(n_units)[np.repeat(np.arange(n_units), n_periods)] + + 2.0 * post + + rng.standard_normal(len(units)) * 0.5 + ) + + data = pd.DataFrame( + {"unit": units, "time": times, "outcome": outcomes, "first_treat": ft_exp} + ) + + results = TwoStageDiD().fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + assert abs(results.overall_att - 2.0) < 0.5 + assert len(results.groups) == 1 + + def test_anticipation_shifts_timing(self): + """Anticipation parameter should shift effective treatment timing.""" + data = generate_test_data(seed=123) + + results_no_ant = TwoStageDiD(anticipation=0).fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + results_with_ant = TwoStageDiD(anticipation=1).fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + # With anticipation, more obs are treated -> different ATT + assert results_with_ant.n_treated_obs > results_no_ant.n_treated_obs + assert abs(results_no_ant.overall_att - results_with_ant.overall_att) > 0.01 + + def test_rank_deficiency_warning(self): + """Rank deficiency should emit warning in 'warn' mode.""" + # Create data where some treated units have no untreated periods + rng = np.random.default_rng(42) + n_units, n_periods = 20, 5 + units = np.repeat(np.arange(n_units), n_periods) + times = np.tile(np.arange(n_periods), n_units) + + # All units treated at period 0 (except never-treated) + first_treat = np.zeros(n_units, dtype=int) + first_treat[5:] = 0 # never treated (first_treat=0) + first_treat[:5] = 1 # treated at period 1 + + ft_exp = np.repeat(first_treat, n_periods) + outcomes = rng.standard_normal(len(units)) + + data = pd.DataFrame( + {"unit": units, "time": times, "outcome": outcomes, "first_treat": ft_exp} + ) + + # Should work without error in warn mode + results = TwoStageDiD(rank_deficient_action="warn").fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + assert isinstance(results, TwoStageDiDResults) + + def test_rank_deficiency_error(self): + """Rank deficiency should raise in 'error' mode when violated.""" + # Create data where a treated unit has NO untreated periods at all + rng = np.random.default_rng(42) + n_units, n_periods = 20, 5 + units = np.repeat(np.arange(n_units), n_periods) + times = np.tile(np.arange(n_periods), n_units) + + # Some units treated at period 0 (no pre-treatment) + first_treat = np.zeros(n_units, dtype=int) + first_treat[10:] = 0 # first_treat at the first time period + first_treat[:5] = 0 # never treated + first_treat[5:10] = 0 # Make all units at period 0 as treated + # Actually let's have some treated at period 0 so they fail rank check + first_treat[5:10] = 0 # All these are coded as never-treated (first_treat=0) + + ft_exp = np.repeat(first_treat, n_periods) + outcomes = rng.standard_normal(len(units)) + data = pd.DataFrame( + {"unit": units, "time": times, "outcome": outcomes, "first_treat": ft_exp} + ) + + # All units are never-treated, so no treated obs -> ValueError + with pytest.raises(ValueError, match="No treated observations"): + TwoStageDiD(rank_deficient_action="error").fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + def test_nan_propagation(self): + """NaN SE should propagate to t_stat, p_value, conf_int.""" + data = generate_test_data() + results = TwoStageDiD().fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + # For a reference period, t_stat and p_value should be NaN + if results.event_study_effects: + pass # Only check if event study was computed + + # Normal results should have finite values + assert np.isfinite(results.overall_t_stat) + assert np.isfinite(results.overall_p_value) + + def test_covariates(self): + """Estimation with covariates should work.""" + data = generate_test_data() + rng = np.random.default_rng(99) + data["x1"] = rng.standard_normal(len(data)) + data["x2"] = rng.standard_normal(len(data)) + + results = TwoStageDiD().fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + covariates=["x1", "x2"], + ) + + assert results.overall_att > 0 + assert results.overall_se > 0 + assert np.isfinite(results.overall_se) + + def test_missing_column_error(self): + """Missing required columns should raise ValueError.""" + data = generate_test_data() + + with pytest.raises(ValueError, match="Missing columns"): + TwoStageDiD().fit( + data, + outcome="nonexistent", + unit="unit", + time="time", + first_treat="first_treat", + ) + + def test_no_treated_obs_error(self): + """Should raise when no treated observations exist.""" + rng = np.random.default_rng(42) + n = 100 + data = pd.DataFrame( + { + "unit": np.repeat(np.arange(10), 10), + "time": np.tile(np.arange(10), 10), + "outcome": rng.standard_normal(n), + "first_treat": 0, # all never-treated + } + ) + + with pytest.raises(ValueError, match="No treated"): + TwoStageDiD().fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + def test_horizon_max(self): + """horizon_max should limit event study horizons.""" + data = generate_test_data() + results = TwoStageDiD(horizon_max=2).fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) + + # All horizons should have |h| <= 2 + for h in results.event_study_effects: + if results.event_study_effects[h].get("n_obs", 0) > 0: + assert abs(h) <= 2 + + +# ============================================================================= +# TestTwoStageDiDParameters +# ============================================================================= + + +class TestTwoStageDiDParameters: + """Tests for parameter handling.""" + + def test_get_params(self): + """get_params should include all __init__ params.""" + est = TwoStageDiD(anticipation=1, alpha=0.1, n_bootstrap=100, seed=42, horizon_max=5) + params = est.get_params() + + assert params["anticipation"] == 1 + assert params["alpha"] == 0.1 + assert params["n_bootstrap"] == 100 + assert params["seed"] == 42 + assert params["horizon_max"] == 5 + assert params["rank_deficient_action"] == "warn" + assert params["cluster"] is None + + def test_set_params(self): + """set_params should modify attributes.""" + est = TwoStageDiD() + est.set_params(anticipation=2, alpha=0.1) + + assert est.anticipation == 2 + assert est.alpha == 0.1 + + def test_set_params_returns_self(self): + """set_params should return self for chaining.""" + est = TwoStageDiD() + result = est.set_params(anticipation=1) + assert result is est + + def test_set_params_unknown_raises(self): + """set_params with unknown param should raise.""" + est = TwoStageDiD() + with pytest.raises(ValueError, match="Unknown parameter"): + est.set_params(nonexistent_param=42) + + def test_rank_deficient_action_validation(self): + """Invalid rank_deficient_action should raise.""" + with pytest.raises(ValueError, match="rank_deficient_action"): + TwoStageDiD(rank_deficient_action="invalid") + + def test_cluster_changes_ses(self): + """Different cluster variable should change SEs.""" + data = generate_test_data() + # Add a cluster variable with fewer clusters than units + data["cluster"] = data["unit"] % 10 + + results_unit = TwoStageDiD().fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + results_cluster = TwoStageDiD(cluster="cluster").fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + # Point estimates should be the same + assert abs(results_unit.overall_att - results_cluster.overall_att) < 1e-10 + # SEs should differ + assert abs(results_unit.overall_se - results_cluster.overall_se) > 1e-6 + + def test_horizon_max_limits_horizons(self): + """horizon_max should limit event study horizons.""" + data = generate_test_data() + + results_full = TwoStageDiD().fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) + results_limited = TwoStageDiD(horizon_max=2).fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) + + full_horizons = set(results_full.event_study_effects.keys()) + limited_horizons = set(results_limited.event_study_effects.keys()) + + assert len(limited_horizons) <= len(full_horizons) + + +# ============================================================================= +# TestTwoStageDiDBootstrap +# ============================================================================= + + +class TestTwoStageDiDBootstrap: + """Tests for bootstrap inference.""" + + def test_bootstrap_runs(self, ci_params): + """Bootstrap should complete and produce results.""" + data = generate_test_data() + n_boot = ci_params.bootstrap(50) + results = TwoStageDiD(n_bootstrap=n_boot, seed=42).fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + assert results.bootstrap_results is not None + assert isinstance(results.bootstrap_results, TwoStageBootstrapResults) + + def test_bootstrap_structure(self, ci_params): + """Bootstrap results should have correct structure.""" + data = generate_test_data() + n_boot = ci_params.bootstrap(50) + results = TwoStageDiD(n_bootstrap=n_boot, seed=42).fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + br = results.bootstrap_results + assert br.n_bootstrap == n_boot + assert br.weight_type == "rademacher" + assert br.overall_att_se > 0 + assert br.overall_att_ci[0] < br.overall_att_ci[1] + assert 0 < br.overall_att_p_value <= 1 + + def test_bootstrap_updates_inference(self, ci_params): + """Bootstrap should update the main results inference.""" + data = generate_test_data() + n_boot = ci_params.bootstrap(50) + + results_analytical = TwoStageDiD(seed=42).fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + results_bootstrap = TwoStageDiD(n_bootstrap=n_boot, seed=42).fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + # Point estimates should match + assert abs(results_analytical.overall_att - results_bootstrap.overall_att) < 1e-10 + # SEs should differ (analytical GMM vs bootstrap) + assert abs(results_analytical.overall_se - results_bootstrap.overall_se) > 1e-6 + + def test_bootstrap_event_study(self, ci_params): + """Bootstrap should work with event study specification.""" + data = generate_test_data() + n_boot = ci_params.bootstrap(50) + results = TwoStageDiD(n_bootstrap=n_boot, seed=42).fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) + + assert results.bootstrap_results is not None + assert results.bootstrap_results.event_study_ses is not None + for h, se in results.bootstrap_results.event_study_ses.items(): + assert se > 0 + + +# ============================================================================= +# TestTwoStageDiDConvenience +# ============================================================================= + + +class TestTwoStageDiDConvenience: + """Tests for convenience function.""" + + def test_convenience_function_returns_results(self): + """Convenience function should return TwoStageDiDResults.""" + data = generate_test_data() + results = two_stage_did( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + assert isinstance(results, TwoStageDiDResults) + assert results.overall_att > 0 + + def test_convenience_function_kwargs(self): + """Constructor kwargs should be forwarded.""" + data = generate_test_data() + results = two_stage_did( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + anticipation=1, + alpha=0.1, + ) + + assert isinstance(results, TwoStageDiDResults) + assert results.alpha == 0.1 + + def test_convenience_function_aggregate(self): + """Convenience function should support aggregate parameter.""" + data = generate_test_data() + results = two_stage_did( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) + + assert results.event_study_effects is not None + + def test_estimator_summary_before_fit_raises(self): + """Calling summary() before fit() should raise.""" + est = TwoStageDiD() + with pytest.raises(RuntimeError, match="fitted"): + est.summary() + + def test_print_summary(self, capsys): + """print_summary should print to stdout.""" + data = generate_test_data() + results = TwoStageDiD().fit( + data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" + ) + + results.print_summary() + captured = capsys.readouterr() + assert "Two-Stage DiD" in captured.out From bc2eb77fb41838772feda30b8352ff2092779e42 Mon Sep 17 00:00:00 2001 From: igerber Date: Mon, 16 Feb 2026 13:20:38 -0500 Subject: [PATCH 2/4] Address PR review: edge case fixes for TwoStageDiD - Always-treated warning now lists affected unit IDs (truncated at 10) - Bootstrap handles NaN y_tilde: masks NaN obs in static, event study, and group bootstrap paths; returns None when all treated obs are NaN - balance_e warns when no cohorts qualify instead of silently falling back - Add 3 edge case tests and REGISTRY.md update Co-Authored-By: Claude Opus 4.6 --- diff_diff/two_stage.py | 67 ++++++++++++++++++++++++---------- docs/methodology/REGISTRY.md | 1 + tests/test_two_stage.py | 71 ++++++++++++++++++++++++++++++++++++ 3 files changed, 120 insertions(+), 19 deletions(-) diff --git a/diff_diff/two_stage.py b/diff_diff/two_stage.py index de2c3dbe..fc399c99 100644 --- a/diff_diff/two_stage.py +++ b/diff_diff/two_stage.py @@ -593,18 +593,19 @@ def fit( # Check for always-treated units min_time = df[time].min() always_treated_mask = (~df["_never_treated"]) & (df[first_treat] <= min_time) - n_always_treated = df.loc[always_treated_mask, unit].nunique() + always_treated_units = df.loc[always_treated_mask, unit].unique() + n_always_treated = len(always_treated_units) if n_always_treated > 0: + unit_list = ", ".join(str(u) for u in always_treated_units[:10]) + suffix = f" (and {n_always_treated - 10} more)" if n_always_treated > 10 else "" warnings.warn( f"{n_always_treated} unit(s) are treated in all observed periods " - f"(first_treat <= {min_time}). These units have no untreated " - "observations and cannot contribute to the counterfactual model. " - "Excluding from estimation.", + f"(first_treat <= {min_time}): [{unit_list}{suffix}]. " + "These units have no untreated observations and cannot contribute " + "to the counterfactual model. Excluding from estimation.", UserWarning, stacklevel=2, ) - # Exclude always-treated units - always_treated_units = df.loc[always_treated_mask, unit].unique() df = df[~df[unit].isin(always_treated_units)].copy() # Treatment indicator with anticipation @@ -1183,11 +1184,25 @@ def _stage2_event_study( for g, horizons in cohort_rel_times.items(): if required_range.issubset(horizons): balanced_cohorts.add(g) - balance_mask = ( - df[first_treat].isin(balanced_cohorts).values - if balanced_cohorts - else np.ones(n, dtype=bool) - ) + if not balanced_cohorts: + warnings.warn( + f"No cohorts satisfy balance_e={balance_e} requirement. " + "Event study results will contain only the reference period. " + "Consider reducing balance_e.", + UserWarning, + stacklevel=2, + ) + return { + ref_period: { + "effect": 0.0, + "se": 0.0, + "t_stat": np.nan, + "p_value": np.nan, + "conf_int": (0.0, 0.0), + "n_obs": 0, + } + } + balance_mask = df[first_treat].isin(balanced_cohorts).values else: balance_mask = np.ones(n, dtype=bool) @@ -1724,7 +1739,7 @@ def _run_bootstrap( original_event_study: Optional[Dict[int, Dict[str, Any]]], original_group: Optional[Dict[Any, Dict[str, Any]]], aggregate: Optional[str], - ) -> TwoStageBootstrapResults: + ) -> Optional[TwoStageBootstrapResults]: """Run multiplier bootstrap on GMM influence function.""" if self.n_bootstrap < 50: warnings.warn( @@ -1738,12 +1753,23 @@ def _run_bootstrap( from diff_diff.staggered_bootstrap import _generate_bootstrap_weights_batch - y_tilde = df["_y_tilde"].values + y_tilde = df["_y_tilde"].values.copy() # .copy() to avoid mutating df column n = len(df) cluster_ids = df[cluster_var].values + # Handle NaN y_tilde (from unidentified FEs) — matches _stage2_static logic + nan_mask = ~np.isfinite(y_tilde) + if nan_mask.any(): + y_tilde[nan_mask] = 0.0 + # --- Static specification bootstrap --- - D = omega_1_mask.values.astype(float) + D = omega_1_mask.values.astype(float) # .astype() already creates a copy + D[nan_mask] = 0.0 # Exclude NaN y_tilde obs from bootstrap estimation + + # Degenerate case: all treated obs have NaN y_tilde + if D.sum() == 0: + return None + X_2_static = D.reshape(-1, 1) coef_static = solve_ols(X_2_static, y_tilde, return_vcov=False)[0] eps_2_static = y_tilde - X_2_static @ coef_static @@ -1811,11 +1837,10 @@ def _run_bootstrap( for g, horizons in cohort_rel_times.items(): if required_range.issubset(horizons): balanced_cohorts.add(g) - balance_mask = ( - df[first_treat].isin(balanced_cohorts).values - if balanced_cohorts - else np.ones(n, dtype=bool) - ) + if not balanced_cohorts: + all_horizons = [] # No qualifying cohorts -> skip event study bootstrap + else: + balance_mask = df[first_treat].isin(balanced_cohorts).values else: balance_mask = np.ones(n, dtype=bool) @@ -1827,6 +1852,8 @@ def _run_bootstrap( for i in range(n): if not balance_mask[i]: continue + if nan_mask[i]: + continue # NaN y_tilde -> exclude from bootstrap event study h = rel_times[i] if np.isfinite(h): h_int = int(h) @@ -1890,6 +1917,8 @@ def _run_bootstrap( treated_mask = omega_1_mask.values for i in range(n): if treated_mask[i]: + if nan_mask[i]: + continue # NaN y_tilde -> exclude from group bootstrap g = ft_vals[i] if g in group_to_col: X_2_grp[i, group_to_col[g]] = 1.0 diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index b8e7fc5a..e0f74d54 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -618,6 +618,7 @@ Our implementation uses multiplier bootstrap on the GMM influence function: clus - **NaN y_tilde handling:** When Stage 1 FE are unidentified for some observations, the residualized outcome `y_tilde` is NaN. These observations are zeroed out (excluded) from the Stage 2 regression and variance computation, matching the treatment of unimputable observations in ImputationDiD. - **NaN inference for undefined statistics:** t_stat uses NaN when SE is non-finite or zero; p_value and CI also NaN. Matches CallawaySantAnna/ImputationDiD NaN convention. - **Event study aggregation:** Horizon-specific effects use the same two-stage procedure with horizon indicator dummies in Stage 2. Unidentified horizons (e.g., long-run effects without never-treated units, per Proposition 5 of Borusyak et al. 2024) produce NaN. +- **balance_e with no qualifying cohorts:** If no cohorts have sufficient pre/post coverage for the requested `balance_e`, a warning is emitted and event study results contain only the reference period. - **No never-treated units:** Long-run effects may be unidentified (same limitation as ImputationDiD). Warning emitted for affected horizons. **Reference implementation(s):** diff --git a/tests/test_two_stage.py b/tests/test_two_stage.py index db570ed9..17db8851 100644 --- a/tests/test_two_stage.py +++ b/tests/test_two_stage.py @@ -723,6 +723,77 @@ def test_horizon_max(self): if results.event_study_effects[h].get("n_obs", 0) > 0: assert abs(h) <= 2 + def test_always_treated_warning_lists_unit_ids(self): + """Always-treated warning should include affected unit IDs.""" + data = generate_test_data() + + # Add two always-treated units (first_treat before min_time=0) + always_treated = pd.DataFrame( + { + "unit": np.repeat([997, 998], 10), + "time": np.tile(np.arange(10), 2), + "outcome": np.random.default_rng(42).standard_normal(20), + "first_treat": np.repeat([-1, -2], 10), + } + ) + data_with_always = pd.concat([data, always_treated], ignore_index=True) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + TwoStageDiD().fit( + data_with_always, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + always_warns = [x for x in w if "treated in all observed periods" in str(x.message)] + assert len(always_warns) == 1 + msg = str(always_warns[0].message) + assert "997" in msg + assert "998" in msg + + def test_bootstrap_with_nan_y_tilde(self, ci_params): + """Bootstrap should handle NaN y_tilde from unidentified FEs.""" + # No never-treated units: cohorts 3, 5, 7 on periods 0-9 means + # periods 7-9 have zero untreated obs -> NaN y_tilde + data = generate_test_data(never_treated_frac=0.0) + n_boot = ci_params.bootstrap(20) + + results = TwoStageDiD(n_bootstrap=n_boot).fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + assert np.isfinite(results.overall_att) + assert results.overall_se > 0 + + def test_balance_e_empty_cohorts_warns(self): + """Unreasonably large balance_e should warn when no cohorts qualify.""" + data = generate_test_data() + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + results = TwoStageDiD().fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + balance_e=100, # No cohort can satisfy this + ) + balance_warns = [x for x in w if "No cohorts satisfy" in str(x.message)] + assert len(balance_warns) > 0 + + # Event study should contain only the reference period + assert len(results.event_study_effects) == 1 + ref_key = list(results.event_study_effects.keys())[0] + assert results.event_study_effects[ref_key]["n_obs"] == 0 + # ============================================================================= # TestTwoStageDiDParameters From 81878683398675a39ec174d6d951e14f06474da6 Mon Sep 17 00:00:00 2001 From: igerber Date: Mon, 16 Feb 2026 13:53:28 -0500 Subject: [PATCH 3/4] Fix zero-observation horizons/cohorts producing se=0 instead of NaN Address PR review round 2: - [P1] Add n_obs==0 early check in _stage2_event_study and _stage2_group to produce NaN inference for zero-observation horizons/cohorts - [P2] Fix REGISTRY.md bootstrap weight type (Rademacher only, not Rademacher/Mammen/Webb) - [P2] Add tests for zero-observation event study horizons and group effects from NaN y_tilde filtering Co-Authored-By: Claude Opus 4.6 --- diff_diff/two_stage.py | 28 ++++++++++++++-- docs/methodology/REGISTRY.md | 4 ++- tests/test_two_stage.py | 65 ++++++++++++++++++++++++++++++++++++ 3 files changed, 94 insertions(+), 3 deletions(-) diff --git a/diff_diff/two_stage.py b/diff_diff/two_stage.py index fc399c99..deba9e93 100644 --- a/diff_diff/two_stage.py +++ b/diff_diff/two_stage.py @@ -1275,9 +1275,21 @@ def _stage2_event_study( for h in est_horizons: j = horizon_to_col[h] + n_obs = int(np.sum(X_2[:, j])) + + if n_obs == 0: + event_study_effects[h] = { + "effect": np.nan, + "se": np.nan, + "t_stat": np.nan, + "p_value": np.nan, + "conf_int": (np.nan, np.nan), + "n_obs": 0, + } + continue + effect = float(coef[j]) se = float(np.sqrt(max(V[j, j], 0.0))) - n_obs = int(np.sum(X_2[:, j])) t_stat = effect / se if np.isfinite(se) and se > 0 else np.nan p_val = compute_p_value(t_stat) @@ -1358,9 +1370,21 @@ def _stage2_group( group_effects: Dict[Any, Dict[str, Any]] = {} for g in treatment_groups: j = group_to_col[g] + n_obs = int(np.sum(X_2[:, j])) + + if n_obs == 0: + group_effects[g] = { + "effect": np.nan, + "se": np.nan, + "t_stat": np.nan, + "p_value": np.nan, + "conf_int": (np.nan, np.nan), + "n_obs": 0, + } + continue + effect = float(coef[j]) se = float(np.sqrt(max(V[j, j], 0.0))) - n_obs = int(np.sum(X_2[:, j])) t_stat = effect / se if np.isfinite(se) and se > 0 else np.nan p_val = compute_p_value(t_stat) diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index e0f74d54..400d5293 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -610,7 +610,7 @@ where `psi_i` is the stacked influence function for unit i across all its observ *Bootstrap:* -Our implementation uses multiplier bootstrap on the GMM influence function: cluster-level `psi` sums are pre-computed, then perturbed with Rademacher/Mammen/Webb weights. The R `did2s` package defaults to block bootstrap (resampling clusters with replacement). Both approaches are asymptotically valid; the multiplier bootstrap is computationally cheaper and consistent with the CallawaySantAnna/ImputationDiD bootstrap patterns in this library. +Our implementation uses multiplier bootstrap on the GMM influence function: cluster-level `psi` sums are pre-computed, then perturbed with Rademacher weights. The R `did2s` package defaults to block bootstrap (resampling clusters with replacement). Both approaches are asymptotically valid; the multiplier bootstrap is computationally cheaper and consistent with the CallawaySantAnna/ImputationDiD bootstrap patterns in this library. *Edge cases:* - **Always-treated units:** Units treated in all observed periods have no untreated observations for Stage 1 FE estimation. These are excluded with a warning listing the affected unit IDs. Their treated observations do NOT contribute to Stage 2. @@ -620,6 +620,8 @@ Our implementation uses multiplier bootstrap on the GMM influence function: clus - **Event study aggregation:** Horizon-specific effects use the same two-stage procedure with horizon indicator dummies in Stage 2. Unidentified horizons (e.g., long-run effects without never-treated units, per Proposition 5 of Borusyak et al. 2024) produce NaN. - **balance_e with no qualifying cohorts:** If no cohorts have sufficient pre/post coverage for the requested `balance_e`, a warning is emitted and event study results contain only the reference period. - **No never-treated units:** Long-run effects may be unidentified (same limitation as ImputationDiD). Warning emitted for affected horizons. +- **Zero-observation horizons after filtering:** When `balance_e` or NaN `y_tilde` filtering results in zero observations for some event study horizons, those horizons produce NaN for all inference fields (effect, SE, t-stat, p-value, CI) with n_obs=0. This differs from the Proposition 5 case (unidentified long-run effects) which has observations but unidentified counterfactual. +- **Zero-observation cohorts in group effects:** If all treated observations for a cohort have NaN `y_tilde` (excluded from estimation), that cohort's group effect is NaN with n_obs=0. **Reference implementation(s):** - R: `did2s::did2s()` (Kyle Butts & John Gardner) diff --git a/tests/test_two_stage.py b/tests/test_two_stage.py index 17db8851..6d467f68 100644 --- a/tests/test_two_stage.py +++ b/tests/test_two_stage.py @@ -794,6 +794,71 @@ def test_balance_e_empty_cohorts_warns(self): ref_key = list(results.event_study_effects.keys())[0] assert results.event_study_effects[ref_key]["n_obs"] == 0 + def test_event_study_nan_for_zero_obs_horizons(self): + """Zero-observation horizons from NaN y_tilde produce NaN inference.""" + # No never-treated: cohorts 3, 5, 7; periods 0-9. + # Periods 7-9 have zero untreated obs → NaN y_tilde. + # Horizon 4 = cohort 3 at period 7 (NaN) + cohort 5 at period 9 (NaN) → 0 obs. + # Horizons 5, 6 = cohort 3 at periods 8, 9 (NaN) → 0 obs. + # Horizons 0-3 have valid observations from multiple cohorts. + data = generate_test_data(never_treated_frac=0.0) + results = TwoStageDiD().fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) + + assert results.event_study_effects is not None + + # Horizons 0-3 should have observations and finite effects + for h in range(0, 4): + eff = results.event_study_effects[h] + assert eff["n_obs"] > 0, f"Horizon {h} should have observations" + assert np.isfinite(eff["effect"]), f"Horizon {h} effect should be finite" + + # Horizons 4, 5, 6 should have zero obs and NaN inference + for h in [4, 5, 6]: + eff = results.event_study_effects[h] + assert eff["n_obs"] == 0, f"Horizon {h} should have 0 observations" + assert np.isnan(eff["effect"]), f"Horizon {h} effect should be NaN" + assert np.isnan(eff["se"]), f"Horizon {h} SE should be NaN" + assert np.isnan(eff["t_stat"]), f"Horizon {h} t_stat should be NaN" + assert np.isnan(eff["p_value"]), f"Horizon {h} p_value should be NaN" + assert np.isnan(eff["conf_int"][0]), f"Horizon {h} CI lower should be NaN" + + def test_group_effects_nan_for_all_nan_cohort(self): + """Cohort with all NaN y_tilde produces NaN group effect.""" + # No never-treated units: cohorts 3, 5, 7; periods 0-9. + # Periods 7, 8, 9 have zero untreated obs (all 3 cohorts treated by t=7). + # Cohort 7: treated at periods 7-9, all have NaN y_tilde -> n_obs=0. + # Cohorts 3, 5: have some valid treated periods -> n_obs > 0. + data = generate_test_data(never_treated_frac=0.0) + results = TwoStageDiD().fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="group", + ) + + assert results.group_effects is not None + + # Cohorts 3 and 5 should have valid effects + for g in [3, 5]: + eff = results.group_effects[g] + assert eff["n_obs"] > 0, f"Cohort {g} should have observations" + assert np.isfinite(eff["effect"]), f"Cohort {g} effect should be finite" + + # Cohort 7: all treated obs have NaN y_tilde -> zero obs -> NaN + eff_7 = results.group_effects[7] + assert eff_7["n_obs"] == 0, "Cohort 7 should have 0 observations" + assert np.isnan(eff_7["effect"]), "Cohort 7 effect should be NaN" + assert np.isnan(eff_7["se"]), "Cohort 7 SE should be NaN" + # ============================================================================= # TestTwoStageDiDParameters From e0286d2149b7280594ba3b085f04689dcfbfd592 Mon Sep 17 00:00:00 2001 From: igerber Date: Mon, 16 Feb 2026 14:26:48 -0500 Subject: [PATCH 4/4] Add Proposition 5 detection for unidentified long-run horizons Distinguish Prop 5 horizons (treated obs exist, counterfactual unidentified) from zero-observation horizons. Prop 5 horizons get n_obs > 0 with NaN inference and a warning, matching ImputationDiD behavior. Co-Authored-By: Claude Opus 4.6 --- diff_diff/two_stage.py | 57 ++++++++++++++++++++++++++++++++++-- docs/methodology/REGISTRY.md | 4 +-- tests/test_two_stage.py | 39 ++++++++++++++---------- 3 files changed, 80 insertions(+), 20 deletions(-) diff --git a/diff_diff/two_stage.py b/diff_diff/two_stage.py index deba9e93..7aec9b71 100644 --- a/diff_diff/two_stage.py +++ b/diff_diff/two_stage.py @@ -1206,8 +1206,39 @@ def _stage2_event_study( else: balance_mask = np.ones(n, dtype=bool) - # Remove reference period from estimation horizons - est_horizons = [h for h in all_horizons if h != ref_period] + # Check Proposition 5: no never-treated units + has_never_treated = df["_never_treated"].any() + h_bar = np.inf + if not has_never_treated and len(treatment_groups) > 1: + h_bar = max(treatment_groups) - min(treatment_groups) + + # Identify Prop 5 horizons and compute their actual treated obs counts. + # Treated obs have NaN y_tilde at these horizons (counterfactual + # unidentified), but actual_n counts them to distinguish from truly + # empty horizons. rel_times is NaN for untreated/never-treated obs + # (line ~653), so (rel_times == h) is False for them. + prop5_horizons = [] + prop5_effects: Dict[int, Dict[str, Any]] = {} + if h_bar < np.inf: + for h in all_horizons: + if h == ref_period: + continue + if h >= h_bar: + actual_n = int(np.sum((rel_times == h) & omega_1_mask.values & balance_mask)) + if actual_n > 0: + prop5_horizons.append(h) + prop5_effects[h] = { + "effect": np.nan, + "se": np.nan, + "t_stat": np.nan, + "p_value": np.nan, + "conf_int": (np.nan, np.nan), + "n_obs": actual_n, + } + + # Remove reference period AND Prop 5 horizons from estimation + prop5_set = set(prop5_horizons) + est_horizons = [h for h in all_horizons if h != ref_period and h not in prop5_set] if len(est_horizons) == 0: # No horizons to estimate — return just reference period @@ -1308,6 +1339,17 @@ def _stage2_event_study( "n_obs": n_obs, } + # Add Proposition 5 entries (unidentified horizons with n_obs > 0) + event_study_effects.update(prop5_effects) + + if prop5_horizons: + warnings.warn( + f"Horizons {prop5_horizons} are not identified without " + f"never-treated units (Proposition 5). Set to NaN.", + UserWarning, + stacklevel=2, + ) + return event_study_effects def _stage2_group( @@ -1869,6 +1911,15 @@ def _run_bootstrap( balance_mask = np.ones(n, dtype=bool) est_horizons = [h for h in all_horizons if h != ref_period] + + # Filter out Prop 5 horizons (same logic as _stage2_event_study) + has_never_treated = df["_never_treated"].any() + h_bar_boot = np.inf + if not has_never_treated and len(treatment_groups) > 1: + h_bar_boot = max(treatment_groups) - min(treatment_groups) + if h_bar_boot < np.inf: + est_horizons = [h for h in est_horizons if h < h_bar_boot] + if est_horizons: horizon_to_col = {h: j for j, h in enumerate(est_horizons)} k_es = len(est_horizons) @@ -1911,6 +1962,8 @@ def _run_bootstrap( for h in original_event_study: if original_event_study[h].get("n_obs", 0) == 0: continue + if np.isnan(original_event_study[h]["effect"]): + continue # Skip Prop 5 and other NaN-effect horizons if h not in horizon_to_col: continue j = horizon_to_col[h] diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 400d5293..49785f41 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -619,8 +619,8 @@ Our implementation uses multiplier bootstrap on the GMM influence function: clus - **NaN inference for undefined statistics:** t_stat uses NaN when SE is non-finite or zero; p_value and CI also NaN. Matches CallawaySantAnna/ImputationDiD NaN convention. - **Event study aggregation:** Horizon-specific effects use the same two-stage procedure with horizon indicator dummies in Stage 2. Unidentified horizons (e.g., long-run effects without never-treated units, per Proposition 5 of Borusyak et al. 2024) produce NaN. - **balance_e with no qualifying cohorts:** If no cohorts have sufficient pre/post coverage for the requested `balance_e`, a warning is emitted and event study results contain only the reference period. -- **No never-treated units:** Long-run effects may be unidentified (same limitation as ImputationDiD). Warning emitted for affected horizons. -- **Zero-observation horizons after filtering:** When `balance_e` or NaN `y_tilde` filtering results in zero observations for some event study horizons, those horizons produce NaN for all inference fields (effect, SE, t-stat, p-value, CI) with n_obs=0. This differs from the Proposition 5 case (unidentified long-run effects) which has observations but unidentified counterfactual. +- **No never-treated units (Proposition 5):** When there are no never-treated units and multiple treatment cohorts, horizons h >= h_bar (where h_bar = max(groups) - min(groups)) are unidentified per Proposition 5 of Borusyak et al. (2024). These produce NaN inference with n_obs > 0 (treated observations exist but counterfactual is unidentified) and a warning listing affected horizons. Matches ImputationDiD behavior. Proposition 5 applies to event study horizons only, not cohort aggregation — a cohort whose treated obs all fall at Prop 5 horizons naturally gets n_obs=0 in group effects because all its y_tilde values are NaN. +- **Zero-observation horizons after filtering:** When `balance_e` or NaN `y_tilde` filtering results in zero observations for some non-Prop-5 event study horizons, those horizons produce NaN for all inference fields (effect, SE, t-stat, p-value, CI) with n_obs=0. - **Zero-observation cohorts in group effects:** If all treated observations for a cohort have NaN `y_tilde` (excluded from estimation), that cohort's group effect is NaN with n_obs=0. **Reference implementation(s):** diff --git a/tests/test_two_stage.py b/tests/test_two_stage.py index 6d467f68..be0108a4 100644 --- a/tests/test_two_stage.py +++ b/tests/test_two_stage.py @@ -794,35 +794,42 @@ def test_balance_e_empty_cohorts_warns(self): ref_key = list(results.event_study_effects.keys())[0] assert results.event_study_effects[ref_key]["n_obs"] == 0 - def test_event_study_nan_for_zero_obs_horizons(self): - """Zero-observation horizons from NaN y_tilde produce NaN inference.""" + def test_proposition_5_nan_for_long_run_horizons(self): + """Prop 5 horizons have n_obs > 0 but NaN inference (unidentified).""" # No never-treated: cohorts 3, 5, 7; periods 0-9. - # Periods 7-9 have zero untreated obs → NaN y_tilde. - # Horizon 4 = cohort 3 at period 7 (NaN) + cohort 5 at period 9 (NaN) → 0 obs. - # Horizons 5, 6 = cohort 3 at periods 8, 9 (NaN) → 0 obs. - # Horizons 0-3 have valid observations from multiple cohorts. + # h_bar = max(groups) - min(groups) = 7 - 3 = 4. + # Horizons 0-3: identified, valid effects. + # Horizons 4, 5, 6: Prop 5 unidentified — treated obs exist but + # counterfactual is unidentified without never-treated units. data = generate_test_data(never_treated_frac=0.0) - results = TwoStageDiD().fit( - data, - outcome="outcome", - unit="unit", - time="time", - first_treat="first_treat", - aggregate="event_study", - ) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + results = TwoStageDiD().fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) assert results.event_study_effects is not None + # Check Prop 5 warning was emitted + prop5_warnings = [x for x in w if "not identified without never-treated" in str(x.message)] + assert len(prop5_warnings) > 0, "Proposition 5 warning should be emitted" + # Horizons 0-3 should have observations and finite effects for h in range(0, 4): eff = results.event_study_effects[h] assert eff["n_obs"] > 0, f"Horizon {h} should have observations" assert np.isfinite(eff["effect"]), f"Horizon {h} effect should be finite" - # Horizons 4, 5, 6 should have zero obs and NaN inference + # Horizons 4, 5, 6: Prop 5 — n_obs > 0 but NaN inference for h in [4, 5, 6]: eff = results.event_study_effects[h] - assert eff["n_obs"] == 0, f"Horizon {h} should have 0 observations" + assert eff["n_obs"] > 0, f"Horizon {h} should have n_obs > 0 (Prop 5)" assert np.isnan(eff["effect"]), f"Horizon {h} effect should be NaN" assert np.isnan(eff["se"]), f"Horizon {h} SE should be NaN" assert np.isnan(eff["t_stat"]), f"Horizon {h} t_stat should be NaN"