diff --git a/CLAUDE.md b/CLAUDE.md index e4f71006..bf5db3d5 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -103,22 +103,35 @@ cross-platform compilation - no OpenBLAS or Intel MKL installation required. - **`diff_diff/imputation.py`** - Borusyak-Jaravel-Spiess imputation DiD estimator: - `ImputationDiD` - Borusyak et al. (2024) efficient imputation estimator for staggered DiD - - `ImputationDiDResults` - Results with overall ATT, event study, group effects, pre-trend test - - `ImputationBootstrapResults` - Multiplier bootstrap inference results - `imputation_did()` - Convenience function - Steps: (1) OLS on untreated obs for unit+time FE, (2) impute counterfactual Y(0), (3) aggregate - Conservative variance (Theorem 3) with `aux_partition` parameter for SE tightness - Pre-trend test (Equation 9) via `results.pretrend_test()` - Proposition 5: NaN for unidentified long-run horizons without never-treated units + - Re-exports result and bootstrap classes for backward compatibility + +- **`diff_diff/imputation_results.py`** - Result container classes: + - `ImputationBootstrapResults` - Multiplier bootstrap inference results + - `ImputationDiDResults` - Results with overall ATT, event study, group effects, pre-trend test + +- **`diff_diff/imputation_bootstrap.py`** - Bootstrap inference: + - `ImputationDiDBootstrapMixin` - Mixin with multiplier bootstrap methods + - Cluster-level influence function sums (Theorem 3) with Rademacher weights - **`diff_diff/two_stage.py`** - Gardner (2022) Two-Stage DiD estimator: - `TwoStageDiD` - Two-stage estimator: (1) estimate unit+time FE on untreated obs, (2) regress residualized outcomes on treatment indicators - - `TwoStageDiDResults` - Results with overall ATT, event study, group effects, per-observation treatment effects - - `TwoStageBootstrapResults` - Multiplier bootstrap inference on GMM influence function - `two_stage_did()` - Convenience function - Point estimates identical to ImputationDiD; different variance estimator (GMM sandwich vs. conservative) - Custom `_compute_gmm_variance()` — cannot reuse `compute_robust_vcov()` because correction term uses GLOBAL cross-moment - No finite-sample adjustments (raw asymptotic sandwich, matching R `did2s`) + - Re-exports result and bootstrap classes for backward compatibility + +- **`diff_diff/two_stage_results.py`** - Result container classes: + - `TwoStageBootstrapResults` - Multiplier bootstrap inference on GMM influence function + - `TwoStageDiDResults` - Results with overall ATT, event study, group effects, per-observation treatment effects + +- **`diff_diff/two_stage_bootstrap.py`** - Bootstrap inference: + - `TwoStageDiDBootstrapMixin` - Mixin with GMM influence function bootstrap methods - **`diff_diff/triple_diff.py`** - Triple Difference (DDD) estimator: - `TripleDifference` - Ortiz-Villavicencio & Sant'Anna (2025) estimator for DDD designs @@ -129,7 +142,6 @@ cross-platform compilation - no OpenBLAS or Intel MKL installation required. - **`diff_diff/trop.py`** - Triply Robust Panel (TROP) estimator (v2.1.0): - `TROP` - Athey, Imbens, Qu & Viviano (2025) estimator with factor model adjustment - - `TROPResults` - Results with ATT, factors, loadings, unit/time weights - `trop()` - Convenience function for quick estimation - Three robustness components: factor adjustment, unit weights, time weights - Two estimation methods via `method` parameter: @@ -137,6 +149,12 @@ cross-platform compilation - no OpenBLAS or Intel MKL installation required. - `"joint"`: Weighted least squares with homogeneous treatment effect (faster) - Automatic rank selection via cross-validation, information criterion, or elbow detection - Bootstrap and placebo-based variance estimation + - Re-exports result classes for backward compatibility + +- **`diff_diff/trop_results.py`** - Result container classes: + - `_LAMBDA_INF` - Sentinel value for disabled factor model (λ_nn=∞) + - `_PrecomputedStructures` - TypedDict for cached matrices + - `TROPResults` - Results with ATT, factors, loadings, unit/time weights - **`diff_diff/bacon.py`** - Goodman-Bacon decomposition for TWFE diagnostics: - `BaconDecomposition` - Decompose TWFE into weighted 2x2 comparisons (Goodman-Bacon 2021) diff --git a/TODO.md b/TODO.md index 1d436bef..37485884 100644 --- a/TODO.md +++ b/TODO.md @@ -25,10 +25,10 @@ Target: < 1000 lines per module for maintainability. |------|-------|--------| | ~~`staggered.py`~~ | ~~2301~~ 1120 | ✅ Split into staggered.py, staggered_bootstrap.py, staggered_aggregation.py, staggered_results.py | | ~~`prep.py`~~ | ~~1993~~ 1241 | ✅ Split: DGP functions moved to `prep_dgp.py` (777 lines) | -| `trop.py` | 2904 | **Needs split** -- nearly 3x the 1000-line target | -| `imputation.py` | 2480 | **Needs split** -- results, bootstrap, aggregation like staggered | -| `two_stage.py` | 2209 | **Needs split** -- same pattern as imputation | -| `utils.py` | 1879 | **Needs split** -- legacy placebo functions could move to diagnostics | +| ~~`trop.py`~~ | ~~2904~~ ~2560 | ✅ Partially split: results extracted to `trop_results.py` (~340 lines) | +| ~~`imputation.py`~~ | ~~2480~~ ~1740 | ✅ Split into imputation.py, imputation_results.py, imputation_bootstrap.py | +| ~~`two_stage.py`~~ | ~~2209~~ ~1490 | ✅ Split into two_stage.py, two_stage_results.py, two_stage_bootstrap.py | +| `utils.py` | 1879 | Monitor -- legacy placebo functions stay to avoid circular imports | | `visualization.py` | 1678 | Monitor -- growing but cohesive | | `linalg.py` | 1537 | Monitor -- unified backend, splitting would hurt cohesion | | `honest_did.py` | 1511 | Acceptable | @@ -42,49 +42,11 @@ Target: < 1000 lines per module for maintainability. All 7 t_stat locations fixed (diagnostics.py, sun_abraham.py, triple_diff.py) -- all now use `np.nan` or `np.isfinite()` guards. Fixed in PR #118 and follow-up PRs. -**Remaining nuance**: `diagnostics.py:785` still has `se = ... else 0.0` for the SE variable itself (not t_stat). The downstream t_stat line correctly returns `np.nan`, so inference is safe, but the SE value of 0.0 is technically incorrect for an undefined SE. +~~**Remaining nuance**: `diagnostics.py:785` SE = 0.0~~ — ✅ Fixed: SE now returns `np.nan` when undefined, and all downstream inference uses `safe_inference()`. -### Migrate Existing Inference Call Sites to `safe_inference()` +### ~~Migrate Existing Inference Call Sites to `safe_inference()`~~ -- DONE -`safe_inference()` was added to `diff_diff/utils.py` to compute t_stat, p_value, and CI together with a NaN gate at the top. It is now the prescribed pattern for all new code (see CLAUDE.md design pattern #7). However, ~26 existing inline inference computations across 12 files have **not** been migrated yet. - -**Files with inline inference to migrate:** - -| File | Approx. Locations | -|------|-------------------| -| `estimators.py` | 2 (lines 1038, 1089) | -| `sun_abraham.py` | 4 (lines 621, 644, 661, 905) | -| `staggered.py` | 6 (lines 696, 725, 763, 777, 792, 806) | -| `staggered_aggregation.py` | 2 (lines 407, 479) | -| `triple_diff.py` | 1 (line 601) | -| `imputation.py` | 2 (lines 1806, 1927) | -| `two_stage.py` | 2 (lines 1325, 1431) | -| `diagnostics.py` | 2 (lines 665, 786) | -| `synthetic_did.py` | 1 (line 426) | -| `trop.py` | 2 (lines 1474, 2054) | -| `utils.py` | 1 (line 641) | -| `linalg.py` | 1 (line 1310) | - -**How to find them:** -```bash -grep -En "(t_stat|overall_t)\s*=\s*[^#]*/|\[.t_stat.\]\s*=\s*[^#]*/" diff_diff/*.py | grep -v "def safe_inference" -``` - -**Note**: This command has one false positive (`utils.py:178`, inside the `safe_inference()` body) and misses multi-line expressions (e.g., `sun_abraham.py:660-661`). The table above is the authoritative list. - -**Migration pattern:** -```python -# Before (inline, error-prone) -t_stat = effect / se if se > 0 else 0.0 -p_value = compute_p_value(t_stat) -ci = compute_confidence_interval(effect, se, alpha) - -# After (NaN-safe, consistent) -from diff_diff.utils import safe_inference -t_stat, p_value, ci = safe_inference(effect, se, alpha=alpha, df=df) -``` - -**Priority**: Medium — the NaN-handling table above covers the worst cases (those using `0.0`). The remaining sites may use partial guards but should still be migrated for consistency and to prevent regressions. +✅ All ~32 inline inference call sites migrated to `safe_inference()` across 11 source files: `estimators.py`, `sun_abraham.py`, `staggered.py`, `staggered_aggregation.py`, `triple_diff.py`, `imputation.py`, `two_stage.py`, `diagnostics.py`, `synthetic_did.py`, `trop.py`, `utils.py`. Two sites left as-is with comments: `diagnostics.py:665` (permutation-based p_value) and `linalg.py:1310` (deliberately uses ±inf for zero-SE). --- @@ -96,17 +58,17 @@ Deferred items from PR reviews that were not addressed before merge. | Issue | Location | PR | Priority | |-------|----------|----|----------| -| TwoStageDiD & ImputationDiD bootstrap hardcodes Rademacher only; no `bootstrap_weights` parameter unlike CallawaySantAnna | `two_stage.py:1860`, `imputation.py:2363` | #156, #141 | Medium | -| TwoStageDiD GMM score logic duplicated between analytic/bootstrap with inconsistent NaN/overflow handling | `two_stage.py:1454-1784` | #156 | Medium | -| ImputationDiD weight construction duplicated between aggregation and bootstrap (drift risk) -- has explicit code comment acknowledging duplication | `imputation.py:1777-1786`, `imputation.py:2216-2221` | #141 | Medium | -| ImputationDiD dense `(A0'A0).toarray()` scales O((U+T+K)^2), OOM risk on large panels | `imputation.py:1564` | #141 | Medium | +| TwoStageDiD & ImputationDiD bootstrap hardcodes Rademacher only; no `bootstrap_weights` parameter unlike CallawaySantAnna | `two_stage_bootstrap.py`, `imputation_bootstrap.py` | #156, #141 | Medium | +| TwoStageDiD GMM score logic duplicated between analytic/bootstrap with inconsistent NaN/overflow handling | `two_stage.py`, `two_stage_bootstrap.py` | #156 | Medium | +| ImputationDiD weight construction duplicated between aggregation and bootstrap (drift risk) -- has explicit code comment acknowledging duplication | `imputation.py`, `imputation_bootstrap.py` | #141 | Medium | +| ImputationDiD dense `(A0'A0).toarray()` scales O((U+T+K)^2), OOM risk on large panels | `imputation.py` | #141 | Medium | #### Performance | Issue | Location | PR | Priority | |-------|----------|----|----------| -| TwoStageDiD per-column `.toarray()` in loop for cluster scores | `two_stage.py:1766-1767` | #156 | Medium | -| ImputationDiD event-study SEs recompute full conservative variance per horizon (should cache A0/A1 factorization) | `imputation.py:1772-1804` | #141 | Low | +| TwoStageDiD per-column `.toarray()` in loop for cluster scores | `two_stage_bootstrap.py` | #156 | Medium | +| ImputationDiD event-study SEs recompute full conservative variance per horizon (should cache A0/A1 factorization) | `imputation.py` | #141 | Low | | Legacy `compute_placebo_effects` uses deprecated projected-gradient weights (marked deprecated, users directed to `SyntheticDiD`) | `utils.py:1689-1691` | #145 | Low | | Rust faer SVD ndarray-to-faer conversion overhead (minimal vs SVD cost) | `rust/src/linalg.rs:67` | #115 | Low | @@ -115,7 +77,7 @@ Deferred items from PR reviews that were not addressed before merge. | Issue | Location | PR | Priority | |-------|----------|----|----------| | Tutorial notebooks not executed in CI | `docs/tutorials/*.ipynb` | #159 | Low | -| TwoStageDiD `test_nan_propagation` is a no-op (only `pass`) | `tests/test_two_stage.py:643-652` | #156 | Low | +| ~~TwoStageDiD `test_nan_propagation` is a no-op~~ | ~~`tests/test_two_stage.py:643-652`~~ | ~~#156~~ | ✅ Fixed | | ImputationDiD bootstrap + covariate path untested | `tests/test_imputation.py` | #141 | Low | | TROP `n_bootstrap >= 2` validation missing (can yield 0/NaN SE silently) | `trop.py:462` | #124 | Low | | SunAbraham deprecated `min_pre_periods`/`min_post_periods` still in `fit()` docstring | `sun_abraham.py:458-487` | #153 | Low | @@ -209,19 +171,9 @@ Spurious RuntimeWarnings ("divide by zero", "overflow", "invalid value") are emi - Occurs in IPW and DR estimation methods with covariates - Related to logistic regression overflow in edge cases (separate from BLAS bug) -### Fix Plan (follow-up PR) - -Replace `@` operator with `np.dot()` at affected call sites. `np.dot()` bypasses the ufunc FPE dispatch layer and produces identical results with zero spurious warnings on M4. - -**Affected files and lines:** -- `linalg.py`: 332, 690, 704, 829, 1463 -- `staggered.py`: 78, 85, 102, 860, 1024-1025 -- `triple_diff.py`: 301, 307, 323 -- `utils.py`: 515 -- `imputation.py`: 1253, 1301, 1602, 1662 -- `trop.py`: 1098 +### ~~Fix Plan (follow-up PR)~~ -- DONE -**Regression test:** Assert no RuntimeWarnings from `solve_ols()` with n ≥ 500 on all platforms. +✅ Replaced `@` operator with `np.dot()` at all 19 affected call sites across 6 files: `linalg.py` (5), `staggered.py` (5), `triple_diff.py` (3), `utils.py` (1), `imputation.py` (4), `trop.py` (1). Regression test added in `test_linalg.py::TestNoDotRuntimeWarnings`. **Long-term:** Revert to `@` operator when numpy ≥ 2.3 becomes the minimum supported version. diff --git a/diff_diff/diagnostics.py b/diff_diff/diagnostics.py index cd31f530..e3d79c9d 100644 --- a/diff_diff/diagnostics.py +++ b/diff_diff/diagnostics.py @@ -19,7 +19,7 @@ from diff_diff.estimators import DifferenceInDifferences from diff_diff.results import _get_significance_stars -from diff_diff.utils import compute_confidence_interval, compute_p_value +from diff_diff.utils import safe_inference @dataclass @@ -661,7 +661,7 @@ def permutation_test( ci_lower = np.percentile(valid_effects, alpha / 2 * 100) ci_upper = np.percentile(valid_effects, (1 - alpha / 2) * 100) - # T-stat from original estimate + # NOTE: Not using safe_inference — p_value is permutation-based, CI is percentile-based. t_stat = original_att / se if np.isfinite(se) and se > 0 else np.nan return PlaceboTestResults( @@ -782,15 +782,9 @@ def leave_one_out_test( # Statistics of LOO distribution mean_effect = np.mean(valid_effects) - se = np.std(valid_effects, ddof=1) if len(valid_effects) > 1 else 0.0 - t_stat = mean_effect / se if np.isfinite(se) and se > 0 else np.nan - - # Use t-distribution for p-value + se = np.std(valid_effects, ddof=1) if len(valid_effects) > 1 else np.nan df = len(valid_effects) - 1 if len(valid_effects) > 1 else 1 - p_value = compute_p_value(t_stat, df=df) - - # CI - conf_int = compute_confidence_interval(mean_effect, se, alpha, df=df) if np.isfinite(se) and se > 0 else (np.nan, np.nan) + t_stat, p_value, conf_int = safe_inference(mean_effect, se, alpha=alpha, df=df) return PlaceboTestResults( test_type="leave_one_out", diff --git a/diff_diff/estimators.py b/diff_diff/estimators.py index 367e6604..c725c274 100644 --- a/diff_diff/estimators.py +++ b/diff_diff/estimators.py @@ -27,9 +27,8 @@ from diff_diff.results import DiDResults, MultiPeriodDiDResults, PeriodEffect from diff_diff.utils import ( WildBootstrapResults, - compute_confidence_interval, - compute_p_value, demean_by_group, + safe_inference, validate_binary, wild_bootstrap_se, ) @@ -1034,14 +1033,7 @@ def fit( # type: ignore[override] idx = interaction_indices[period] effect = coefficients[idx] se = np.sqrt(vcov[idx, idx]) - if np.isfinite(se) and se > 0: - t_stat = effect / se - p_value = compute_p_value(t_stat, df=df) - conf_int = compute_confidence_interval(effect, se, self.alpha, df=df) - else: - t_stat = np.nan - p_value = np.nan - conf_int = (np.nan, np.nan) + t_stat, p_value, conf_int = safe_inference(effect, se, alpha=self.alpha, df=df) period_effects[period] = PeriodEffect( period=period, @@ -1085,15 +1077,9 @@ def fit( # type: ignore[override] avg_conf_int = (np.nan, np.nan) else: avg_se = float(np.sqrt(avg_var)) - if np.isfinite(avg_se) and avg_se > 0: - avg_t_stat = avg_att / avg_se - avg_p_value = compute_p_value(avg_t_stat, df=df) - avg_conf_int = compute_confidence_interval(avg_att, avg_se, self.alpha, df=df) - else: - # Zero SE (degenerate case) - avg_t_stat = np.nan - avg_p_value = np.nan - avg_conf_int = (np.nan, np.nan) + avg_t_stat, avg_p_value, avg_conf_int = safe_inference( + avg_att, avg_se, alpha=self.alpha, df=df + ) # Count observations n_treated = int(np.sum(d)) diff --git a/diff_diff/imputation.py b/diff_diff/imputation.py index 28455abc..ae651efe 100644 --- a/diff_diff/imputation.py +++ b/diff_diff/imputation.py @@ -15,7 +15,6 @@ """ import warnings -from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Set, Tuple import numpy as np @@ -23,420 +22,11 @@ from scipy import sparse, stats from scipy.sparse.linalg import spsolve +from diff_diff.imputation_bootstrap import ImputationDiDBootstrapMixin +from diff_diff.imputation_results import ImputationBootstrapResults, ImputationDiDResults # noqa: F401 (re-export) from diff_diff.linalg import solve_ols -from diff_diff.results import _get_significance_stars -from diff_diff.utils import compute_confidence_interval, compute_p_value +from diff_diff.utils import safe_inference -# ============================================================================= -# Results Dataclasses -# ============================================================================= - - -@dataclass -class ImputationBootstrapResults: - """ - Results from ImputationDiD bootstrap inference. - - Bootstrap is a library extension beyond Borusyak et al. (2024), which - proposes only analytical inference via the conservative variance estimator. - Provided for consistency with CallawaySantAnna and SunAbraham. - - Attributes - ---------- - n_bootstrap : int - Number of bootstrap iterations. - weight_type : str - Type of bootstrap weights (currently "rademacher" only). - alpha : float - Significance level used for confidence intervals. - overall_att_se : float - Bootstrap standard error for overall ATT. - overall_att_ci : tuple - Bootstrap confidence interval for overall ATT. - overall_att_p_value : float - Bootstrap p-value for overall ATT. - event_study_ses : dict, optional - Bootstrap SEs for event study effects. - event_study_cis : dict, optional - Bootstrap CIs for event study effects. - event_study_p_values : dict, optional - Bootstrap p-values for event study effects. - group_ses : dict, optional - Bootstrap SEs for group effects. - group_cis : dict, optional - Bootstrap CIs for group effects. - group_p_values : dict, optional - Bootstrap p-values for group effects. - bootstrap_distribution : np.ndarray, optional - Full bootstrap distribution of overall ATT. - """ - - n_bootstrap: int - weight_type: str - alpha: float - overall_att_se: float - overall_att_ci: Tuple[float, float] - overall_att_p_value: float - event_study_ses: Optional[Dict[int, float]] = None - event_study_cis: Optional[Dict[int, Tuple[float, float]]] = None - event_study_p_values: Optional[Dict[int, float]] = None - group_ses: Optional[Dict[Any, float]] = None - group_cis: Optional[Dict[Any, Tuple[float, float]]] = None - group_p_values: Optional[Dict[Any, float]] = None - bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False) - - -@dataclass -class ImputationDiDResults: - """ - Results from Borusyak-Jaravel-Spiess (2024) imputation DiD estimation. - - Attributes - ---------- - treatment_effects : pd.DataFrame - Unit-level treatment effects with columns: unit, time, tau_hat, weight. - overall_att : float - Overall average treatment effect on the treated. - overall_se : float - Standard error of overall ATT. - overall_t_stat : float - T-statistic for overall ATT. - overall_p_value : float - P-value for overall ATT. - overall_conf_int : tuple - Confidence interval for overall ATT. - event_study_effects : dict, optional - Dictionary mapping relative time h to effect dict with keys: - 'effect', 'se', 't_stat', 'p_value', 'conf_int', 'n_obs'. - group_effects : dict, optional - Dictionary mapping cohort g to effect dict. - groups : list - List of treatment cohorts. - time_periods : list - List of all time periods. - n_obs : int - Total number of observations. - n_treated_obs : int - Number of treated observations (|Omega_1|). - n_untreated_obs : int - Number of untreated observations (|Omega_0|). - n_treated_units : int - Number of ever-treated units. - n_control_units : int - Number of units contributing to Omega_0. - alpha : float - Significance level used. - pretrend_results : dict, optional - Populated by pretrend_test(). - bootstrap_results : ImputationBootstrapResults, optional - Bootstrap inference results. - """ - - treatment_effects: pd.DataFrame - overall_att: float - overall_se: float - overall_t_stat: float - overall_p_value: float - overall_conf_int: Tuple[float, float] - event_study_effects: Optional[Dict[int, Dict[str, Any]]] - group_effects: Optional[Dict[Any, Dict[str, Any]]] - groups: List[Any] - time_periods: List[Any] - n_obs: int - n_treated_obs: int - n_untreated_obs: int - n_treated_units: int - n_control_units: int - alpha: float = 0.05 - pretrend_results: Optional[Dict[str, Any]] = field(default=None, repr=False) - bootstrap_results: Optional[ImputationBootstrapResults] = field(default=None, repr=False) - # Internal: stores data needed for pretrend_test() - _estimator_ref: Optional[Any] = field(default=None, repr=False) - - def __repr__(self) -> str: - """Concise string representation.""" - sig = _get_significance_stars(self.overall_p_value) - return ( - f"ImputationDiDResults(ATT={self.overall_att:.4f}{sig}, " - f"SE={self.overall_se:.4f}, " - f"n_groups={len(self.groups)}, " - f"n_treated_obs={self.n_treated_obs})" - ) - - def summary(self, alpha: Optional[float] = None) -> str: - """ - Generate formatted summary of estimation results. - - Parameters - ---------- - alpha : float, optional - Significance level. Defaults to alpha used in estimation. - - Returns - ------- - str - Formatted summary. - """ - alpha = alpha or self.alpha - conf_level = int((1 - alpha) * 100) - - lines = [ - "=" * 85, - "Imputation DiD Estimator Results (Borusyak et al. 2024)".center(85), - "=" * 85, - "", - f"{'Total observations:':<30} {self.n_obs:>10}", - f"{'Treated observations:':<30} {self.n_treated_obs:>10}", - f"{'Untreated observations:':<30} {self.n_untreated_obs:>10}", - f"{'Treated units:':<30} {self.n_treated_units:>10}", - f"{'Control units:':<30} {self.n_control_units:>10}", - f"{'Treatment cohorts:':<30} {len(self.groups):>10}", - f"{'Time periods:':<30} {len(self.time_periods):>10}", - "", - ] - - # Overall ATT - lines.extend( - [ - "-" * 85, - "Overall Average Treatment Effect on the Treated".center(85), - "-" * 85, - f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} " - f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", - "-" * 85, - ] - ) - - t_str = ( - f"{self.overall_t_stat:>10.3f}" if np.isfinite(self.overall_t_stat) else f"{'NaN':>10}" - ) - p_str = ( - f"{self.overall_p_value:>10.4f}" - if np.isfinite(self.overall_p_value) - else f"{'NaN':>10}" - ) - sig = _get_significance_stars(self.overall_p_value) - - lines.extend( - [ - f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} " - f"{t_str} {p_str} {sig:>6}", - "-" * 85, - "", - f"{conf_level}% Confidence Interval: " - f"[{self.overall_conf_int[0]:.4f}, {self.overall_conf_int[1]:.4f}]", - "", - ] - ) - - # Event study effects - if self.event_study_effects: - lines.extend( - [ - "-" * 85, - "Event Study (Dynamic) Effects".center(85), - "-" * 85, - f"{'Rel. Period':<15} {'Estimate':>12} {'Std. Err.':>12} " - f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", - "-" * 85, - ] - ) - - for h in sorted(self.event_study_effects.keys()): - eff = self.event_study_effects[h] - if eff.get("n_obs", 1) == 0: - # Reference period marker - lines.append( - f"[ref: {h}]" f"{'0.0000':>17} {'---':>12} {'---':>10} {'---':>10} {'':>6}" - ) - elif np.isnan(eff["effect"]): - lines.append(f"{h:<15} {'NaN':>12} {'NaN':>12} {'NaN':>10} {'NaN':>10} {'':>6}") - else: - e_sig = _get_significance_stars(eff["p_value"]) - e_t = ( - f"{eff['t_stat']:>10.3f}" if np.isfinite(eff["t_stat"]) else f"{'NaN':>10}" - ) - e_p = ( - f"{eff['p_value']:>10.4f}" - if np.isfinite(eff["p_value"]) - else f"{'NaN':>10}" - ) - lines.append( - f"{h:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} " - f"{e_t} {e_p} {e_sig:>6}" - ) - - lines.extend(["-" * 85, ""]) - - # Group effects - if self.group_effects: - lines.extend( - [ - "-" * 85, - "Group (Cohort) Effects".center(85), - "-" * 85, - f"{'Cohort':<15} {'Estimate':>12} {'Std. Err.':>12} " - f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", - "-" * 85, - ] - ) - - for g in sorted(self.group_effects.keys()): - eff = self.group_effects[g] - if np.isnan(eff["effect"]): - lines.append(f"{g:<15} {'NaN':>12} {'NaN':>12} {'NaN':>10} {'NaN':>10} {'':>6}") - else: - g_sig = _get_significance_stars(eff["p_value"]) - g_t = ( - f"{eff['t_stat']:>10.3f}" if np.isfinite(eff["t_stat"]) else f"{'NaN':>10}" - ) - g_p = ( - f"{eff['p_value']:>10.4f}" - if np.isfinite(eff["p_value"]) - else f"{'NaN':>10}" - ) - lines.append( - f"{g:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} " - f"{g_t} {g_p} {g_sig:>6}" - ) - - lines.extend(["-" * 85, ""]) - - # Pre-trend test - if self.pretrend_results is not None: - pt = self.pretrend_results - lines.extend( - [ - "-" * 85, - "Pre-Trend Test (Equation 9)".center(85), - "-" * 85, - f"{'F-statistic:':<30} {pt['f_stat']:>10.3f}", - f"{'P-value:':<30} {pt['p_value']:>10.4f}", - f"{'Degrees of freedom:':<30} {pt['df']:>10}", - f"{'Number of leads:':<30} {pt['n_leads']:>10}", - "-" * 85, - "", - ] - ) - - lines.extend( - [ - "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", - "=" * 85, - ] - ) - - return "\n".join(lines) - - def print_summary(self, alpha: Optional[float] = None) -> None: - """Print summary to stdout.""" - print(self.summary(alpha)) - - def to_dataframe(self, level: str = "observation") -> pd.DataFrame: - """ - Convert results to DataFrame. - - Parameters - ---------- - level : str, default="observation" - Level of aggregation: - - "observation": Unit-level treatment effects - - "event_study": Event study effects by relative time - - "group": Group (cohort) effects - - Returns - ------- - pd.DataFrame - Results as DataFrame. - """ - if level == "observation": - return self.treatment_effects.copy() - - elif level == "event_study": - if self.event_study_effects is None: - raise ValueError( - "Event study effects not computed. " - "Use aggregate='event_study' or aggregate='all'." - ) - rows = [] - for h, data in sorted(self.event_study_effects.items()): - rows.append( - { - "relative_period": h, - "effect": data["effect"], - "se": data["se"], - "t_stat": data["t_stat"], - "p_value": data["p_value"], - "conf_int_lower": data["conf_int"][0], - "conf_int_upper": data["conf_int"][1], - "n_obs": data.get("n_obs", np.nan), - } - ) - return pd.DataFrame(rows) - - elif level == "group": - if self.group_effects is None: - raise ValueError( - "Group effects not computed. " "Use aggregate='group' or aggregate='all'." - ) - rows = [] - for g, data in sorted(self.group_effects.items()): - rows.append( - { - "group": g, - "effect": data["effect"], - "se": data["se"], - "t_stat": data["t_stat"], - "p_value": data["p_value"], - "conf_int_lower": data["conf_int"][0], - "conf_int_upper": data["conf_int"][1], - "n_obs": data.get("n_obs", np.nan), - } - ) - return pd.DataFrame(rows) - - else: - raise ValueError( - f"Unknown level: {level}. Use 'observation', 'event_study', or 'group'." - ) - - def pretrend_test(self, n_leads: Optional[int] = None) -> Dict[str, Any]: - """ - Run a pre-trend test (Equation 9 of Borusyak et al. 2024). - - Adds pre-treatment lead indicators to the Step 1 OLS and tests - their joint significance via a cluster-robust Wald F-test. - - Parameters - ---------- - n_leads : int, optional - Number of pre-treatment leads to include. If None, uses all - available pre-treatment periods minus one (for the reference period). - - Returns - ------- - dict - Dictionary with keys: 'f_stat', 'p_value', 'df', 'n_leads', - 'lead_coefficients'. - """ - if self._estimator_ref is None: - raise RuntimeError( - "Pre-trend test requires internal estimator reference. " - "Re-fit the model to use this method." - ) - result = self._estimator_ref._pretrend_test(n_leads=n_leads) - self.pretrend_results = result - return result - - @property - def is_significant(self) -> bool: - """Check if overall ATT is significant.""" - return bool(self.overall_p_value < self.alpha) - - @property - def significance_stars(self) -> str: - """Significance stars for overall ATT.""" - return _get_significance_stars(self.overall_p_value) # ============================================================================= @@ -444,7 +34,7 @@ def significance_stars(self) -> str: # ============================================================================= -class ImputationDiD: +class ImputationDiD(ImputationDiDBootstrapMixin): """ Borusyak-Jaravel-Spiess (2024) imputation DiD estimator. @@ -818,14 +408,8 @@ def fit( kept_cov_mask=kept_cov_mask, ) - overall_t = ( - overall_att / overall_se if np.isfinite(overall_se) and overall_se > 0 else np.nan - ) - overall_p = compute_p_value(overall_t) - overall_ci = ( - compute_confidence_interval(overall_att, overall_se, self.alpha) - if np.isfinite(overall_se) and overall_se > 0 - else (np.nan, np.nan) + overall_t, overall_p, overall_ci = safe_inference( + overall_att, overall_se, alpha=self.alpha ) # Event study and group aggregation @@ -966,9 +550,9 @@ def fit( ] eff_val = event_study_effects[h]["effect"] se_val = event_study_effects[h]["se"] - event_study_effects[h]["t_stat"] = ( - eff_val / se_val if np.isfinite(se_val) and se_val > 0 else np.nan - ) + event_study_effects[h]["t_stat"] = safe_inference( + eff_val, se_val, alpha=self.alpha + )[0] # Update group effects if group_effects and bootstrap_results.group_ses: @@ -979,9 +563,9 @@ def fit( group_effects[g]["p_value"] = bootstrap_results.group_p_values[g] eff_val = group_effects[g]["effect"] se_val = group_effects[g]["se"] - group_effects[g]["t_stat"] = ( - eff_val / se_val if np.isfinite(se_val) and se_val > 0 else np.nan - ) + group_effects[g]["t_stat"] = safe_inference( + eff_val, se_val, alpha=self.alpha + )[0] # Construct results self.results_ = ImputationDiDResults( @@ -1250,7 +834,7 @@ def _fit_untreated_model( delta_hat_clean = np.where(np.isfinite(delta_hat), delta_hat, 0.0) # Step C: Recover FE from covariate-adjusted outcome using iterative FE - y_adj = y - X_raw @ delta_hat_clean + y_adj = y - np.dot(X_raw, delta_hat_clean) unit_fe, time_fe = self._iterative_fe(y_adj, units, times, df_0.index) # grand_mean = 0: iterative FE absorb the intercept @@ -1298,7 +882,7 @@ def _impute_treatment_effects( if delta_hat is not None and covariates: X_1 = df_1[covariates].values - y_hat_0 = y_hat_0 + X_1 @ delta_hat + y_hat_0 = y_hat_0 + np.dot(X_1, delta_hat) tau_hat = df_1[outcome].values - y_hat_0 @@ -1599,7 +1183,7 @@ def _compute_auxiliary_residuals_treated( y_hat_0 = grand_mean + alpha_i + beta_t if delta_hat is not None and covariates: - y_hat_0 = y_hat_0 + df_1[covariates].values @ delta_hat + y_hat_0 = y_hat_0 + np.dot(df_1[covariates].values, delta_hat) tau_hat = df_1[outcome].values - y_hat_0 @@ -1659,7 +1243,7 @@ def _compute_residuals_untreated( y_hat = grand_mean + alpha_i + beta_t if delta_hat is not None and covariates: - y_hat = y_hat + df_0[covariates].values @ delta_hat + y_hat = y_hat + np.dot(df_0[covariates].values, delta_hat) return df_0[outcome].values - y_hat @@ -1803,13 +1387,7 @@ def _aggregate_event_study( kept_cov_mask=kept_cov_mask, ) - t_stat = effect / se if np.isfinite(se) and se > 0 else np.nan - p_value = compute_p_value(t_stat) - conf_int = ( - compute_confidence_interval(effect, se, self.alpha) - if np.isfinite(se) and se > 0 - else (np.nan, np.nan) - ) + t_stat, p_value, conf_int = safe_inference(effect, se, alpha=self.alpha) event_study_effects[h] = { "effect": effect, @@ -1924,13 +1502,7 @@ def _aggregate_group( kept_cov_mask=kept_cov_mask, ) - t_stat = effect / se if np.isfinite(se) and se > 0 else np.nan - p_value = compute_p_value(t_stat) - conf_int = ( - compute_confidence_interval(effect, se, self.alpha) - if np.isfinite(se) and se > 0 - else (np.nan, np.nan) - ) + t_stat, p_value, conf_int = safe_inference(effect, se, alpha=self.alpha) group_effects[g] = { "effect": effect, @@ -2081,299 +1653,6 @@ def _pretrend_test(self, n_leads: Optional[int] = None) -> Dict[str, Any]: "lead_coefficients": lead_coefficients, } - # ========================================================================= - # Bootstrap - # ========================================================================= - - def _compute_percentile_ci( - self, - boot_dist: np.ndarray, - alpha: float, - ) -> Tuple[float, float]: - """Compute percentile confidence interval from bootstrap distribution.""" - lower = float(np.percentile(boot_dist, alpha / 2 * 100)) - upper = float(np.percentile(boot_dist, (1 - alpha / 2) * 100)) - return (lower, upper) - - def _compute_bootstrap_pvalue( - self, - original_effect: float, - boot_dist: np.ndarray, - n_valid: Optional[int] = None, - ) -> float: - """ - Compute two-sided bootstrap p-value. - - Uses the percentile method: p-value is the proportion of bootstrap - estimates on the opposite side of zero from the original estimate, - doubled for two-sided test. - - Parameters - ---------- - original_effect : float - Original point estimate. - boot_dist : np.ndarray - Bootstrap distribution of the effect. - n_valid : int, optional - Number of valid bootstrap samples. If None, uses self.n_bootstrap. - """ - if original_effect >= 0: - p_one_sided = float(np.mean(boot_dist <= 0)) - else: - p_one_sided = float(np.mean(boot_dist >= 0)) - p_value = min(2 * p_one_sided, 1.0) - n_for_floor = n_valid if n_valid is not None else self.n_bootstrap - p_value = max(p_value, 1 / (n_for_floor + 1)) - return p_value - - def _precompute_bootstrap_psi( - self, - df: pd.DataFrame, - outcome: str, - unit: str, - time: str, - first_treat: str, - covariates: Optional[List[str]], - omega_0_mask: pd.Series, - omega_1_mask: pd.Series, - unit_fe: Dict[Any, float], - time_fe: Dict[Any, float], - grand_mean: float, - delta_hat: Optional[np.ndarray], - cluster_var: str, - kept_cov_mask: Optional[np.ndarray], - overall_weights: np.ndarray, - event_study_effects: Optional[Dict[int, Dict[str, Any]]], - group_effects: Optional[Dict[Any, Dict[str, Any]]], - treatment_groups: List[Any], - tau_hat: np.ndarray, - balance_e: Optional[int], - ) -> Dict[str, Any]: - """ - Pre-compute cluster-level influence function sums for each bootstrap target. - - For each aggregation target (overall, per-horizon, per-group), computes - psi_i = sum_t v_it * epsilon_tilde_it for each cluster. The multiplier - bootstrap then perturbs these psi sums with Rademacher weights. - - Computational cost scales with the number of aggregation targets, since - each target requires its own v_untreated computation (weight-dependent). - """ - result: Dict[str, Any] = {} - - common = dict( - df=df, - outcome=outcome, - unit=unit, - time=time, - first_treat=first_treat, - covariates=covariates, - omega_0_mask=omega_0_mask, - omega_1_mask=omega_1_mask, - unit_fe=unit_fe, - time_fe=time_fe, - grand_mean=grand_mean, - delta_hat=delta_hat, - cluster_var=cluster_var, - kept_cov_mask=kept_cov_mask, - ) - - # Overall ATT - overall_psi, cluster_ids = self._compute_cluster_psi_sums(**common, weights=overall_weights) - result["overall"] = (overall_psi, cluster_ids) - - # Event study: per-horizon weights - # NOTE: weight logic duplicated from _aggregate_event_study. - # If weight scheme changes there, update here too. - if event_study_effects: - result["event_study"] = {} - df_1 = df.loc[omega_1_mask] - rel_times = df_1["_rel_time"].values - n_omega_1 = int(omega_1_mask.sum()) - - # Balanced cohort mask (same logic as _aggregate_event_study) - balanced_mask = None - if balance_e is not None: - all_horizons = sorted(set(int(h) for h in rel_times if np.isfinite(h))) - if self.horizon_max is not None: - all_horizons = [h for h in all_horizons if abs(h) <= self.horizon_max] - cohort_rel_times = self._build_cohort_rel_times(df, first_treat) - balanced_mask = self._compute_balanced_cohort_mask( - df_1, first_treat, all_horizons, balance_e, cohort_rel_times - ) - - ref_period = -1 - self.anticipation - for h in event_study_effects: - if event_study_effects[h].get("n_obs", 0) == 0: - continue - if h == ref_period: - continue - if not np.isfinite(event_study_effects[h].get("effect", np.nan)): - continue - h_mask = rel_times == h - if balanced_mask is not None: - h_mask = h_mask & balanced_mask - weights_h = np.zeros(n_omega_1) - finite_h = np.isfinite(tau_hat) & h_mask - n_valid_h = int(finite_h.sum()) - if n_valid_h == 0: - continue - weights_h[np.where(finite_h)[0]] = 1.0 / n_valid_h - - psi_h, _ = self._compute_cluster_psi_sums(**common, weights=weights_h) - result["event_study"][h] = psi_h - - # Group effects: per-group weights - # NOTE: weight logic duplicated from _aggregate_group. - # If weight scheme changes there, update here too. - if group_effects: - result["group"] = {} - df_1 = df.loc[omega_1_mask] - cohorts = df_1[first_treat].values - n_omega_1 = int(omega_1_mask.sum()) - - for g in group_effects: - if group_effects[g].get("n_obs", 0) == 0: - continue - if not np.isfinite(group_effects[g].get("effect", np.nan)): - continue - g_mask = cohorts == g - weights_g = np.zeros(n_omega_1) - finite_g = np.isfinite(tau_hat) & g_mask - n_valid_g = int(finite_g.sum()) - if n_valid_g == 0: - continue - weights_g[np.where(finite_g)[0]] = 1.0 / n_valid_g - - psi_g, _ = self._compute_cluster_psi_sums(**common, weights=weights_g) - result["group"][g] = psi_g - - return result - - def _run_bootstrap( - self, - original_att: float, - original_event_study: Optional[Dict[int, Dict[str, Any]]], - original_group: Optional[Dict[Any, Dict[str, Any]]], - psi_data: Dict[str, Any], - ) -> ImputationBootstrapResults: - """ - Run multiplier bootstrap on pre-computed influence function sums. - - Uses T_b = sum_i w_b_i * psi_i where w_b_i are Rademacher weights - and psi_i are cluster-level influence function sums from Theorem 3. - SE = std(T_b, ddof=1). - """ - if self.n_bootstrap < 50: - warnings.warn( - f"n_bootstrap={self.n_bootstrap} is low. Consider n_bootstrap >= 199 " - "for reliable inference.", - UserWarning, - stacklevel=3, - ) - - rng = np.random.default_rng(self.seed) - - from diff_diff.staggered_bootstrap import _generate_bootstrap_weights_batch - - overall_psi, cluster_ids = psi_data["overall"] - n_clusters = len(cluster_ids) - - # Generate ALL weights upfront: shape (n_bootstrap, n_clusters) - all_weights = _generate_bootstrap_weights_batch( - self.n_bootstrap, n_clusters, "rademacher", rng - ) - - # Overall ATT bootstrap draws - boot_overall = all_weights @ overall_psi # (n_bootstrap,) - - # Event study: loop over horizons - boot_event_study: Optional[Dict[int, np.ndarray]] = None - if original_event_study and "event_study" in psi_data: - boot_event_study = {} - for h, psi_h in psi_data["event_study"].items(): - boot_event_study[h] = all_weights @ psi_h - - # Group effects: loop over groups - boot_group: Optional[Dict[Any, np.ndarray]] = None - if original_group and "group" in psi_data: - boot_group = {} - for g, psi_g in psi_data["group"].items(): - boot_group[g] = all_weights @ psi_g - - # --- Inference (percentile bootstrap, matching CS/SA convention) --- - # Shift perturbation-centered draws to effect-centered draws. - # The multiplier bootstrap produces T_b = sum w_b_i * psi_i centered at 0. - # CS adds the original effect back (L411 of staggered_bootstrap.py). - # We do the same here so percentile CIs and empirical p-values work correctly. - boot_overall_shifted = boot_overall + original_att - - overall_se = float(np.std(boot_overall, ddof=1)) - overall_ci = ( - self._compute_percentile_ci(boot_overall_shifted, self.alpha) - if overall_se > 0 - else (np.nan, np.nan) - ) - overall_p = ( - self._compute_bootstrap_pvalue(original_att, boot_overall_shifted) - if overall_se > 0 - else np.nan - ) - - event_study_ses = None - event_study_cis = None - event_study_p_values = None - if boot_event_study and original_event_study: - event_study_ses = {} - event_study_cis = {} - event_study_p_values = {} - for h in boot_event_study: - se_h = float(np.std(boot_event_study[h], ddof=1)) - event_study_ses[h] = se_h - orig_eff = original_event_study[h]["effect"] - if se_h > 0 and np.isfinite(orig_eff): - shifted_h = boot_event_study[h] + orig_eff - event_study_p_values[h] = self._compute_bootstrap_pvalue(orig_eff, shifted_h) - event_study_cis[h] = self._compute_percentile_ci(shifted_h, self.alpha) - else: - event_study_p_values[h] = np.nan - event_study_cis[h] = (np.nan, np.nan) - - group_ses = None - group_cis = None - group_p_values = None - if boot_group and original_group: - group_ses = {} - group_cis = {} - group_p_values = {} - for g in boot_group: - se_g = float(np.std(boot_group[g], ddof=1)) - group_ses[g] = se_g - orig_eff = original_group[g]["effect"] - if se_g > 0 and np.isfinite(orig_eff): - shifted_g = boot_group[g] + orig_eff - group_p_values[g] = self._compute_bootstrap_pvalue(orig_eff, shifted_g) - group_cis[g] = self._compute_percentile_ci(shifted_g, self.alpha) - else: - group_p_values[g] = np.nan - group_cis[g] = (np.nan, np.nan) - - return ImputationBootstrapResults( - n_bootstrap=self.n_bootstrap, - weight_type="rademacher", - alpha=self.alpha, - overall_att_se=overall_se, - overall_att_ci=overall_ci, - overall_att_p_value=overall_p, - event_study_ses=event_study_ses, - event_study_cis=event_study_cis, - event_study_p_values=event_study_p_values, - group_ses=group_ses, - group_cis=group_cis, - group_p_values=group_p_values, - bootstrap_distribution=boot_overall_shifted, - ) - # ========================================================================= # sklearn-compatible interface # ========================================================================= diff --git a/diff_diff/imputation_bootstrap.py b/diff_diff/imputation_bootstrap.py new file mode 100644 index 00000000..7d9f1e3d --- /dev/null +++ b/diff_diff/imputation_bootstrap.py @@ -0,0 +1,310 @@ +""" +Bootstrap inference methods for the Imputation DiD estimator. + +This module contains ImputationDiDBootstrapMixin, which provides multiplier +bootstrap inference. Extracted from imputation.py for module size management. +""" + +import warnings +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd + +from diff_diff.imputation_results import ImputationBootstrapResults +from diff_diff.staggered_bootstrap import _generate_bootstrap_weights_batch + +__all__ = [ + "ImputationDiDBootstrapMixin", +] + + +class ImputationDiDBootstrapMixin: + """Mixin providing bootstrap inference methods for ImputationDiD.""" + + def _compute_percentile_ci( + self, + boot_dist: np.ndarray, + alpha: float, + ) -> Tuple[float, float]: + """Compute percentile confidence interval from bootstrap distribution.""" + lower = float(np.percentile(boot_dist, alpha / 2 * 100)) + upper = float(np.percentile(boot_dist, (1 - alpha / 2) * 100)) + return (lower, upper) + + def _compute_bootstrap_pvalue( + self, + original_effect: float, + boot_dist: np.ndarray, + n_valid: Optional[int] = None, + ) -> float: + """ + Compute two-sided bootstrap p-value. + + Uses the percentile method: p-value is the proportion of bootstrap + estimates on the opposite side of zero from the original estimate, + doubled for two-sided test. + + Parameters + ---------- + original_effect : float + Original point estimate. + boot_dist : np.ndarray + Bootstrap distribution of the effect. + n_valid : int, optional + Number of valid bootstrap samples. If None, uses self.n_bootstrap. + """ + if original_effect >= 0: + p_one_sided = float(np.mean(boot_dist <= 0)) + else: + p_one_sided = float(np.mean(boot_dist >= 0)) + p_value = min(2 * p_one_sided, 1.0) + n_for_floor = n_valid if n_valid is not None else self.n_bootstrap + p_value = max(p_value, 1 / (n_for_floor + 1)) + return p_value + + def _precompute_bootstrap_psi( + self, + df: pd.DataFrame, + outcome: str, + unit: str, + time: str, + first_treat: str, + covariates: Optional[List[str]], + omega_0_mask: pd.Series, + omega_1_mask: pd.Series, + unit_fe: Dict[Any, float], + time_fe: Dict[Any, float], + grand_mean: float, + delta_hat: Optional[np.ndarray], + cluster_var: str, + kept_cov_mask: Optional[np.ndarray], + overall_weights: np.ndarray, + event_study_effects: Optional[Dict[int, Dict[str, Any]]], + group_effects: Optional[Dict[Any, Dict[str, Any]]], + treatment_groups: List[Any], + tau_hat: np.ndarray, + balance_e: Optional[int], + ) -> Dict[str, Any]: + """ + Pre-compute cluster-level influence function sums for each bootstrap target. + + For each aggregation target (overall, per-horizon, per-group), computes + psi_i = sum_t v_it * epsilon_tilde_it for each cluster. The multiplier + bootstrap then perturbs these psi sums with Rademacher weights. + + Computational cost scales with the number of aggregation targets, since + each target requires its own v_untreated computation (weight-dependent). + """ + result: Dict[str, Any] = {} + + common = dict( + df=df, + outcome=outcome, + unit=unit, + time=time, + first_treat=first_treat, + covariates=covariates, + omega_0_mask=omega_0_mask, + omega_1_mask=omega_1_mask, + unit_fe=unit_fe, + time_fe=time_fe, + grand_mean=grand_mean, + delta_hat=delta_hat, + cluster_var=cluster_var, + kept_cov_mask=kept_cov_mask, + ) + + # Overall ATT + overall_psi, cluster_ids = self._compute_cluster_psi_sums(**common, weights=overall_weights) + result["overall"] = (overall_psi, cluster_ids) + + # Event study: per-horizon weights + # NOTE: weight logic duplicated from _aggregate_event_study. + # If weight scheme changes there, update here too. + if event_study_effects: + result["event_study"] = {} + df_1 = df.loc[omega_1_mask] + rel_times = df_1["_rel_time"].values + n_omega_1 = int(omega_1_mask.sum()) + + # Balanced cohort mask (same logic as _aggregate_event_study) + balanced_mask = None + if balance_e is not None: + all_horizons = sorted(set(int(h) for h in rel_times if np.isfinite(h))) + if self.horizon_max is not None: + all_horizons = [h for h in all_horizons if abs(h) <= self.horizon_max] + cohort_rel_times = self._build_cohort_rel_times(df, first_treat) + balanced_mask = self._compute_balanced_cohort_mask( + df_1, first_treat, all_horizons, balance_e, cohort_rel_times + ) + + ref_period = -1 - self.anticipation + for h in event_study_effects: + if event_study_effects[h].get("n_obs", 0) == 0: + continue + if h == ref_period: + continue + if not np.isfinite(event_study_effects[h].get("effect", np.nan)): + continue + h_mask = rel_times == h + if balanced_mask is not None: + h_mask = h_mask & balanced_mask + weights_h = np.zeros(n_omega_1) + finite_h = np.isfinite(tau_hat) & h_mask + n_valid_h = int(finite_h.sum()) + if n_valid_h == 0: + continue + weights_h[np.where(finite_h)[0]] = 1.0 / n_valid_h + + psi_h, _ = self._compute_cluster_psi_sums(**common, weights=weights_h) + result["event_study"][h] = psi_h + + # Group effects: per-group weights + # NOTE: weight logic duplicated from _aggregate_group. + # If weight scheme changes there, update here too. + if group_effects: + result["group"] = {} + df_1 = df.loc[omega_1_mask] + cohorts = df_1[first_treat].values + n_omega_1 = int(omega_1_mask.sum()) + + for g in group_effects: + if group_effects[g].get("n_obs", 0) == 0: + continue + if not np.isfinite(group_effects[g].get("effect", np.nan)): + continue + g_mask = cohorts == g + weights_g = np.zeros(n_omega_1) + finite_g = np.isfinite(tau_hat) & g_mask + n_valid_g = int(finite_g.sum()) + if n_valid_g == 0: + continue + weights_g[np.where(finite_g)[0]] = 1.0 / n_valid_g + + psi_g, _ = self._compute_cluster_psi_sums(**common, weights=weights_g) + result["group"][g] = psi_g + + return result + + def _run_bootstrap( + self, + original_att: float, + original_event_study: Optional[Dict[int, Dict[str, Any]]], + original_group: Optional[Dict[Any, Dict[str, Any]]], + psi_data: Dict[str, Any], + ) -> ImputationBootstrapResults: + """ + Run multiplier bootstrap on pre-computed influence function sums. + + Uses T_b = sum_i w_b_i * psi_i where w_b_i are Rademacher weights + and psi_i are cluster-level influence function sums from Theorem 3. + SE = std(T_b, ddof=1). + """ + if self.n_bootstrap < 50: + warnings.warn( + f"n_bootstrap={self.n_bootstrap} is low. Consider n_bootstrap >= 199 " + "for reliable inference.", + UserWarning, + stacklevel=3, + ) + + rng = np.random.default_rng(self.seed) + + overall_psi, cluster_ids = psi_data["overall"] + n_clusters = len(cluster_ids) + + # Generate ALL weights upfront: shape (n_bootstrap, n_clusters) + all_weights = _generate_bootstrap_weights_batch( + self.n_bootstrap, n_clusters, "rademacher", rng + ) + + # Overall ATT bootstrap draws + boot_overall = np.dot(all_weights, overall_psi) # (n_bootstrap,) + + # Event study: loop over horizons + boot_event_study: Optional[Dict[int, np.ndarray]] = None + if original_event_study and "event_study" in psi_data: + boot_event_study = {} + for h, psi_h in psi_data["event_study"].items(): + boot_event_study[h] = np.dot(all_weights, psi_h) + + # Group effects: loop over groups + boot_group: Optional[Dict[Any, np.ndarray]] = None + if original_group and "group" in psi_data: + boot_group = {} + for g, psi_g in psi_data["group"].items(): + boot_group[g] = np.dot(all_weights, psi_g) + + # --- Inference (percentile bootstrap, matching CS/SA convention) --- + # Shift perturbation-centered draws to effect-centered draws. + # The multiplier bootstrap produces T_b = sum w_b_i * psi_i centered at 0. + # CS adds the original effect back (L411 of staggered_bootstrap.py). + # We do the same here so percentile CIs and empirical p-values work correctly. + boot_overall_shifted = boot_overall + original_att + + overall_se = float(np.std(boot_overall, ddof=1)) + overall_ci = ( + self._compute_percentile_ci(boot_overall_shifted, self.alpha) + if overall_se > 0 + else (np.nan, np.nan) + ) + overall_p = ( + self._compute_bootstrap_pvalue(original_att, boot_overall_shifted) + if overall_se > 0 + else np.nan + ) + + event_study_ses = None + event_study_cis = None + event_study_p_values = None + if boot_event_study and original_event_study: + event_study_ses = {} + event_study_cis = {} + event_study_p_values = {} + for h in boot_event_study: + se_h = float(np.std(boot_event_study[h], ddof=1)) + event_study_ses[h] = se_h + orig_eff = original_event_study[h]["effect"] + if se_h > 0 and np.isfinite(orig_eff): + shifted_h = boot_event_study[h] + orig_eff + event_study_p_values[h] = self._compute_bootstrap_pvalue(orig_eff, shifted_h) + event_study_cis[h] = self._compute_percentile_ci(shifted_h, self.alpha) + else: + event_study_p_values[h] = np.nan + event_study_cis[h] = (np.nan, np.nan) + + group_ses = None + group_cis = None + group_p_values = None + if boot_group and original_group: + group_ses = {} + group_cis = {} + group_p_values = {} + for g in boot_group: + se_g = float(np.std(boot_group[g], ddof=1)) + group_ses[g] = se_g + orig_eff = original_group[g]["effect"] + if se_g > 0 and np.isfinite(orig_eff): + shifted_g = boot_group[g] + orig_eff + group_p_values[g] = self._compute_bootstrap_pvalue(orig_eff, shifted_g) + group_cis[g] = self._compute_percentile_ci(shifted_g, self.alpha) + else: + group_p_values[g] = np.nan + group_cis[g] = (np.nan, np.nan) + + return ImputationBootstrapResults( + n_bootstrap=self.n_bootstrap, + weight_type="rademacher", + alpha=self.alpha, + overall_att_se=overall_se, + overall_att_ci=overall_ci, + overall_att_p_value=overall_p, + event_study_ses=event_study_ses, + event_study_cis=event_study_cis, + event_study_p_values=event_study_p_values, + group_ses=group_ses, + group_cis=group_cis, + group_p_values=group_p_values, + bootstrap_distribution=boot_overall_shifted, + ) diff --git a/diff_diff/imputation_results.py b/diff_diff/imputation_results.py new file mode 100644 index 00000000..6a05a5d8 --- /dev/null +++ b/diff_diff/imputation_results.py @@ -0,0 +1,426 @@ +""" +Result containers for the Imputation DiD estimator. + +This module contains ImputationBootstrapResults and ImputationDiDResults +dataclasses. Extracted from imputation.py for module size management. +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd + +from diff_diff.results import _get_significance_stars + +__all__ = [ + "ImputationBootstrapResults", + "ImputationDiDResults", +] + + +@dataclass +class ImputationBootstrapResults: + """ + Results from ImputationDiD bootstrap inference. + + Bootstrap is a library extension beyond Borusyak et al. (2024), which + proposes only analytical inference via the conservative variance estimator. + Provided for consistency with CallawaySantAnna and SunAbraham. + + Attributes + ---------- + n_bootstrap : int + Number of bootstrap iterations. + weight_type : str + Type of bootstrap weights (currently "rademacher" only). + alpha : float + Significance level used for confidence intervals. + overall_att_se : float + Bootstrap standard error for overall ATT. + overall_att_ci : tuple + Bootstrap confidence interval for overall ATT. + overall_att_p_value : float + Bootstrap p-value for overall ATT. + event_study_ses : dict, optional + Bootstrap SEs for event study effects. + event_study_cis : dict, optional + Bootstrap CIs for event study effects. + event_study_p_values : dict, optional + Bootstrap p-values for event study effects. + group_ses : dict, optional + Bootstrap SEs for group effects. + group_cis : dict, optional + Bootstrap CIs for group effects. + group_p_values : dict, optional + Bootstrap p-values for group effects. + bootstrap_distribution : np.ndarray, optional + Full bootstrap distribution of overall ATT. + """ + + n_bootstrap: int + weight_type: str + alpha: float + overall_att_se: float + overall_att_ci: Tuple[float, float] + overall_att_p_value: float + event_study_ses: Optional[Dict[int, float]] = None + event_study_cis: Optional[Dict[int, Tuple[float, float]]] = None + event_study_p_values: Optional[Dict[int, float]] = None + group_ses: Optional[Dict[Any, float]] = None + group_cis: Optional[Dict[Any, Tuple[float, float]]] = None + group_p_values: Optional[Dict[Any, float]] = None + bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False) + + +@dataclass +class ImputationDiDResults: + """ + Results from Borusyak-Jaravel-Spiess (2024) imputation DiD estimation. + + Attributes + ---------- + treatment_effects : pd.DataFrame + Unit-level treatment effects with columns: unit, time, tau_hat, weight. + overall_att : float + Overall average treatment effect on the treated. + overall_se : float + Standard error of overall ATT. + overall_t_stat : float + T-statistic for overall ATT. + overall_p_value : float + P-value for overall ATT. + overall_conf_int : tuple + Confidence interval for overall ATT. + event_study_effects : dict, optional + Dictionary mapping relative time h to effect dict with keys: + 'effect', 'se', 't_stat', 'p_value', 'conf_int', 'n_obs'. + group_effects : dict, optional + Dictionary mapping cohort g to effect dict. + groups : list + List of treatment cohorts. + time_periods : list + List of all time periods. + n_obs : int + Total number of observations. + n_treated_obs : int + Number of treated observations (|Omega_1|). + n_untreated_obs : int + Number of untreated observations (|Omega_0|). + n_treated_units : int + Number of ever-treated units. + n_control_units : int + Number of units contributing to Omega_0. + alpha : float + Significance level used. + pretrend_results : dict, optional + Populated by pretrend_test(). + bootstrap_results : ImputationBootstrapResults, optional + Bootstrap inference results. + """ + + treatment_effects: pd.DataFrame + overall_att: float + overall_se: float + overall_t_stat: float + overall_p_value: float + overall_conf_int: Tuple[float, float] + event_study_effects: Optional[Dict[int, Dict[str, Any]]] + group_effects: Optional[Dict[Any, Dict[str, Any]]] + groups: List[Any] + time_periods: List[Any] + n_obs: int + n_treated_obs: int + n_untreated_obs: int + n_treated_units: int + n_control_units: int + alpha: float = 0.05 + pretrend_results: Optional[Dict[str, Any]] = field(default=None, repr=False) + bootstrap_results: Optional[ImputationBootstrapResults] = field(default=None, repr=False) + # Internal: stores data needed for pretrend_test() + _estimator_ref: Optional[Any] = field(default=None, repr=False) + + def __repr__(self) -> str: + """Concise string representation.""" + sig = _get_significance_stars(self.overall_p_value) + return ( + f"ImputationDiDResults(ATT={self.overall_att:.4f}{sig}, " + f"SE={self.overall_se:.4f}, " + f"n_groups={len(self.groups)}, " + f"n_treated_obs={self.n_treated_obs})" + ) + + def summary(self, alpha: Optional[float] = None) -> str: + """ + Generate formatted summary of estimation results. + + Parameters + ---------- + alpha : float, optional + Significance level. Defaults to alpha used in estimation. + + Returns + ------- + str + Formatted summary. + """ + alpha = alpha or self.alpha + conf_level = int((1 - alpha) * 100) + + lines = [ + "=" * 85, + "Imputation DiD Estimator Results (Borusyak et al. 2024)".center(85), + "=" * 85, + "", + f"{'Total observations:':<30} {self.n_obs:>10}", + f"{'Treated observations:':<30} {self.n_treated_obs:>10}", + f"{'Untreated observations:':<30} {self.n_untreated_obs:>10}", + f"{'Treated units:':<30} {self.n_treated_units:>10}", + f"{'Control units:':<30} {self.n_control_units:>10}", + f"{'Treatment cohorts:':<30} {len(self.groups):>10}", + f"{'Time periods:':<30} {len(self.time_periods):>10}", + "", + ] + + # Overall ATT + lines.extend( + [ + "-" * 85, + "Overall Average Treatment Effect on the Treated".center(85), + "-" * 85, + f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} " + f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 85, + ] + ) + + t_str = ( + f"{self.overall_t_stat:>10.3f}" if np.isfinite(self.overall_t_stat) else f"{'NaN':>10}" + ) + p_str = ( + f"{self.overall_p_value:>10.4f}" + if np.isfinite(self.overall_p_value) + else f"{'NaN':>10}" + ) + sig = _get_significance_stars(self.overall_p_value) + + lines.extend( + [ + f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} " + f"{t_str} {p_str} {sig:>6}", + "-" * 85, + "", + f"{conf_level}% Confidence Interval: " + f"[{self.overall_conf_int[0]:.4f}, {self.overall_conf_int[1]:.4f}]", + "", + ] + ) + + # Event study effects + if self.event_study_effects: + lines.extend( + [ + "-" * 85, + "Event Study (Dynamic) Effects".center(85), + "-" * 85, + f"{'Rel. Period':<15} {'Estimate':>12} {'Std. Err.':>12} " + f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 85, + ] + ) + + for h in sorted(self.event_study_effects.keys()): + eff = self.event_study_effects[h] + if eff.get("n_obs", 1) == 0: + # Reference period marker + lines.append( + f"[ref: {h}]" f"{'0.0000':>17} {'---':>12} {'---':>10} {'---':>10} {'':>6}" + ) + elif np.isnan(eff["effect"]): + lines.append(f"{h:<15} {'NaN':>12} {'NaN':>12} {'NaN':>10} {'NaN':>10} {'':>6}") + else: + e_sig = _get_significance_stars(eff["p_value"]) + e_t = ( + f"{eff['t_stat']:>10.3f}" if np.isfinite(eff["t_stat"]) else f"{'NaN':>10}" + ) + e_p = ( + f"{eff['p_value']:>10.4f}" + if np.isfinite(eff["p_value"]) + else f"{'NaN':>10}" + ) + lines.append( + f"{h:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} " + f"{e_t} {e_p} {e_sig:>6}" + ) + + lines.extend(["-" * 85, ""]) + + # Group effects + if self.group_effects: + lines.extend( + [ + "-" * 85, + "Group (Cohort) Effects".center(85), + "-" * 85, + f"{'Cohort':<15} {'Estimate':>12} {'Std. Err.':>12} " + f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 85, + ] + ) + + for g in sorted(self.group_effects.keys()): + eff = self.group_effects[g] + if np.isnan(eff["effect"]): + lines.append(f"{g:<15} {'NaN':>12} {'NaN':>12} {'NaN':>10} {'NaN':>10} {'':>6}") + else: + g_sig = _get_significance_stars(eff["p_value"]) + g_t = ( + f"{eff['t_stat']:>10.3f}" if np.isfinite(eff["t_stat"]) else f"{'NaN':>10}" + ) + g_p = ( + f"{eff['p_value']:>10.4f}" + if np.isfinite(eff["p_value"]) + else f"{'NaN':>10}" + ) + lines.append( + f"{g:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} " + f"{g_t} {g_p} {g_sig:>6}" + ) + + lines.extend(["-" * 85, ""]) + + # Pre-trend test + if self.pretrend_results is not None: + pt = self.pretrend_results + lines.extend( + [ + "-" * 85, + "Pre-Trend Test (Equation 9)".center(85), + "-" * 85, + f"{'F-statistic:':<30} {pt['f_stat']:>10.3f}", + f"{'P-value:':<30} {pt['p_value']:>10.4f}", + f"{'Degrees of freedom:':<30} {pt['df']:>10}", + f"{'Number of leads:':<30} {pt['n_leads']:>10}", + "-" * 85, + "", + ] + ) + + lines.extend( + [ + "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", + "=" * 85, + ] + ) + + return "\n".join(lines) + + def print_summary(self, alpha: Optional[float] = None) -> None: + """Print summary to stdout.""" + print(self.summary(alpha)) + + def to_dataframe(self, level: str = "observation") -> pd.DataFrame: + """ + Convert results to DataFrame. + + Parameters + ---------- + level : str, default="observation" + Level of aggregation: + - "observation": Unit-level treatment effects + - "event_study": Event study effects by relative time + - "group": Group (cohort) effects + + Returns + ------- + pd.DataFrame + Results as DataFrame. + """ + if level == "observation": + return self.treatment_effects.copy() + + elif level == "event_study": + if self.event_study_effects is None: + raise ValueError( + "Event study effects not computed. " + "Use aggregate='event_study' or aggregate='all'." + ) + rows = [] + for h, data in sorted(self.event_study_effects.items()): + rows.append( + { + "relative_period": h, + "effect": data["effect"], + "se": data["se"], + "t_stat": data["t_stat"], + "p_value": data["p_value"], + "conf_int_lower": data["conf_int"][0], + "conf_int_upper": data["conf_int"][1], + "n_obs": data.get("n_obs", np.nan), + } + ) + return pd.DataFrame(rows) + + elif level == "group": + if self.group_effects is None: + raise ValueError( + "Group effects not computed. " "Use aggregate='group' or aggregate='all'." + ) + rows = [] + for g, data in sorted(self.group_effects.items()): + rows.append( + { + "group": g, + "effect": data["effect"], + "se": data["se"], + "t_stat": data["t_stat"], + "p_value": data["p_value"], + "conf_int_lower": data["conf_int"][0], + "conf_int_upper": data["conf_int"][1], + "n_obs": data.get("n_obs", np.nan), + } + ) + return pd.DataFrame(rows) + + else: + raise ValueError( + f"Unknown level: {level}. Use 'observation', 'event_study', or 'group'." + ) + + def pretrend_test(self, n_leads: Optional[int] = None) -> Dict[str, Any]: + """ + Run a pre-trend test (Equation 9 of Borusyak et al. 2024). + + Adds pre-treatment lead indicators to the Step 1 OLS and tests + their joint significance via a cluster-robust Wald F-test. + + Parameters + ---------- + n_leads : int, optional + Number of pre-treatment leads to include. If None, uses all + available pre-treatment periods minus one (for the reference period). + + Returns + ------- + dict + Dictionary with keys: 'f_stat', 'p_value', 'df', 'n_leads', + 'lead_coefficients'. + """ + if self._estimator_ref is None: + raise RuntimeError( + "Pre-trend test requires internal estimator reference. " + "Re-fit the model to use this method." + ) + result = self._estimator_ref._pretrend_test(n_leads=n_leads) + self.pretrend_results = result + return result + + @property + def is_significant(self) -> bool: + """Check if overall ATT is significant.""" + return bool(self.overall_p_value < self.alpha) + + @property + def significance_stars(self) -> str: + """Significance stars for overall ATT.""" + return _get_significance_stars(self.overall_p_value) diff --git a/diff_diff/linalg.py b/diff_diff/linalg.py index f5764d8c..a74de9f4 100644 --- a/diff_diff/linalg.py +++ b/diff_diff/linalg.py @@ -329,7 +329,7 @@ def _solve_ols_rust( # Return with optional fitted values if return_fitted: - fitted = X @ coefficients + fitted = np.dot(X, coefficients) return coefficients, residuals, fitted, vcov else: return coefficients, residuals, vcov @@ -687,7 +687,7 @@ def _solve_ols_numpy( # Compute residuals using only the identified coefficients # Note: Dropped coefficients are NaN, so we use the reduced form - fitted = X_reduced @ coefficients_reduced + fitted = np.dot(X_reduced, coefficients_reduced) residuals = y - fitted # Compute variance-covariance matrix for reduced system, then expand @@ -701,7 +701,7 @@ def _solve_ols_numpy( coefficients = scipy_lstsq(X, y, lapack_driver="gelsd", check_finite=False, cond=1e-07)[0] # Compute residuals and fitted values - fitted = X @ coefficients + fitted = np.dot(X, coefficients) residuals = y - fitted # Compute variance-covariance matrix if requested @@ -826,7 +826,7 @@ def _compute_robust_vcov_numpy( adjustment = n / (n - k) u_squared = residuals**2 # Vectorized meat computation: X' diag(u^2) X = (X * u^2)' X - meat = X.T @ (X * u_squared[:, np.newaxis]) + meat = np.dot(X.T, X * u_squared[:, np.newaxis]) else: # Cluster-robust standard errors (vectorized via groupby) cluster_ids = np.asarray(cluster_ids) @@ -1299,7 +1299,7 @@ def get_inference( "This may indicate perfect multicollinearity or numerical issues.", UserWarning, ) - # Use inf for t-stat when SE is zero (perfect fit scenario) + # NOTE: Deliberately uses ±inf (not NaN via safe_inference) for zero-SE coefficients. if coef > 0: t_stat = np.inf elif coef < 0: @@ -1460,7 +1460,7 @@ def predict(self, X: np.ndarray) -> np.ndarray: coef = self.coefficients_.copy() coef[np.isnan(coef)] = 0.0 - return X @ coef + return np.dot(X, coef) # ============================================================================= diff --git a/diff_diff/staggered.py b/diff_diff/staggered.py index f2eb0829..9bca072b 100644 --- a/diff_diff/staggered.py +++ b/diff_diff/staggered.py @@ -13,10 +13,7 @@ from scipy import optimize from diff_diff.linalg import solve_ols -from diff_diff.utils import ( - compute_confidence_interval, - compute_p_value, -) +from diff_diff.utils import safe_inference # Import from split modules from diff_diff.staggered_results import ( @@ -75,17 +72,17 @@ def _logistic_regression( X_with_intercept = np.column_stack([np.ones(n), X]) def neg_log_likelihood(beta: np.ndarray) -> float: - z = X_with_intercept @ beta + z = np.dot(X_with_intercept, beta) # Clip to prevent overflow z = np.clip(z, -500, 500) log_lik = np.sum(y * z - np.log(1 + np.exp(z))) return -log_lik def gradient(beta: np.ndarray) -> np.ndarray: - z = X_with_intercept @ beta + z = np.dot(X_with_intercept, beta) z = np.clip(z, -500, 500) probs = 1 / (1 + np.exp(-z)) - return -X_with_intercept.T @ (y - probs) + return -np.dot(X_with_intercept.T, y - probs) # Initialize with zeros beta_init = np.zeros(p + 1) @@ -99,7 +96,7 @@ def gradient(beta: np.ndarray) -> np.ndarray: ) beta = result.x - z = X_with_intercept @ beta + z = np.dot(X_with_intercept, beta) z = np.clip(z, -500, 500) probs = 1 / (1 + np.exp(-z)) @@ -693,9 +690,7 @@ def fit( ) if att_gt is not None: - t_stat = att_gt / se_gt if np.isfinite(se_gt) and se_gt > 0 else np.nan - p_val = compute_p_value(t_stat) - ci = compute_confidence_interval(att_gt, se_gt, self.alpha) + t_stat, p_val, ci = safe_inference(att_gt, se_gt, alpha=self.alpha) group_time_effects[(g, t)] = { 'effect': att_gt, @@ -720,14 +715,9 @@ def fit( overall_att, overall_se = self._aggregate_simple( group_time_effects, influence_func_info, df, unit, precomputed ) - # Use NaN for t-stat and p-value when SE is undefined (NaN or non-positive) - if np.isfinite(overall_se) and overall_se > 0: - overall_t = overall_att / overall_se - overall_p = compute_p_value(overall_t) - else: - overall_t = np.nan - overall_p = np.nan - overall_ci = compute_confidence_interval(overall_att, overall_se, self.alpha) + overall_t, overall_p, overall_ci = safe_inference( + overall_att, overall_se, alpha=self.alpha + ) # Compute additional aggregations if requested event_study_effects = None @@ -758,11 +748,7 @@ def fit( # Update estimates with bootstrap inference overall_se = bootstrap_results.overall_att_se - # Use NaN for t-stat when SE is undefined; p-value comes from bootstrap - if np.isfinite(overall_se) and overall_se > 0: - overall_t = overall_att / overall_se - else: - overall_t = np.nan + overall_t = safe_inference(overall_att, overall_se, alpha=self.alpha)[0] overall_p = bootstrap_results.overall_att_p_value overall_ci = bootstrap_results.overall_att_ci @@ -774,7 +760,7 @@ def fit( group_time_effects[gt]['p_value'] = bootstrap_results.group_time_p_values[gt] effect = float(group_time_effects[gt]['effect']) se = float(group_time_effects[gt]['se']) - group_time_effects[gt]['t_stat'] = effect / se if np.isfinite(se) and se > 0 else np.nan + group_time_effects[gt]['t_stat'] = safe_inference(effect, se, alpha=self.alpha)[0] # Update event study effects with bootstrap SEs if (event_study_effects is not None @@ -789,7 +775,7 @@ def fit( event_study_effects[e]['p_value'] = p_val effect = float(event_study_effects[e]['effect']) se = float(event_study_effects[e]['se']) - event_study_effects[e]['t_stat'] = effect / se if np.isfinite(se) and se > 0 else np.nan + event_study_effects[e]['t_stat'] = safe_inference(effect, se, alpha=self.alpha)[0] # Update group effects with bootstrap SEs if (group_effects is not None @@ -803,7 +789,7 @@ def fit( group_effects[g]['p_value'] = bootstrap_results.group_effect_p_values[g] effect = float(group_effects[g]['effect']) se = float(group_effects[g]['se']) - group_effects[g]['t_stat'] = effect / se if np.isfinite(se) and se > 0 else np.nan + group_effects[g]['t_stat'] = safe_inference(effect, se, alpha=self.alpha)[0] # Store results self.results_ = CallawaySantAnnaResults( @@ -860,7 +846,7 @@ def _outcome_regression( # Predict counterfactual for treated units X_treated_with_intercept = np.column_stack([np.ones(n_t), X_treated]) - predicted_control = X_treated_with_intercept @ beta + predicted_control = np.dot(X_treated_with_intercept, beta) # ATT = mean(observed treated change - predicted counterfactual) att = np.mean(treated_change - predicted_control) @@ -1024,8 +1010,8 @@ def _doubly_robust( # Predict counterfactual for both treated and control X_treated_with_intercept = np.column_stack([np.ones(n_t), X_treated]) X_control_with_intercept = np.column_stack([np.ones(n_c), X_control]) - m_treated = X_treated_with_intercept @ beta - m_control = X_control_with_intercept @ beta + m_treated = np.dot(X_treated_with_intercept, beta) + m_control = np.dot(X_control_with_intercept, beta) # Step 2: Propensity score estimation X_all = np.vstack([X_treated, X_control]) diff --git a/diff_diff/staggered_aggregation.py b/diff_diff/staggered_aggregation.py index 14194138..85df1fc6 100644 --- a/diff_diff/staggered_aggregation.py +++ b/diff_diff/staggered_aggregation.py @@ -10,10 +10,7 @@ import numpy as np import pandas as pd -from diff_diff.utils import ( - compute_confidence_interval, - compute_p_value, -) +from diff_diff.utils import safe_inference # Type alias for pre-computed structures (defined at module scope for runtime access) PrecomputedData = Dict[str, Any] @@ -404,9 +401,7 @@ def _aggregate_event_study( gt_pairs, weights, influence_func_info ) - t_stat = agg_effect / agg_se if np.isfinite(agg_se) and agg_se > 0 else np.nan - p_val = compute_p_value(t_stat) - ci = compute_confidence_interval(agg_effect, agg_se, self.alpha) + t_stat, p_val, ci = safe_inference(agg_effect, agg_se, alpha=self.alpha) event_study_effects[e] = { 'effect': agg_effect, @@ -476,9 +471,7 @@ def _aggregate_by_group( gt_pairs, weights, influence_func_info ) - t_stat = agg_effect / agg_se if np.isfinite(agg_se) and agg_se > 0 else np.nan - p_val = compute_p_value(t_stat) - ci = compute_confidence_interval(agg_effect, agg_se, self.alpha) + t_stat, p_val, ci = safe_inference(agg_effect, agg_se, alpha=self.alpha) group_effects[g] = { 'effect': agg_effect, diff --git a/diff_diff/sun_abraham.py b/diff_diff/sun_abraham.py index 79e3bad7..7f84fec7 100644 --- a/diff_diff/sun_abraham.py +++ b/diff_diff/sun_abraham.py @@ -19,8 +19,7 @@ from diff_diff.linalg import LinearRegression, compute_robust_vcov from diff_diff.results import _get_significance_stars from diff_diff.utils import ( - compute_confidence_interval, - compute_p_value, + safe_inference, within_transform as _within_transform_util, ) @@ -618,9 +617,7 @@ def fit( coef_index_map, ) - overall_t = overall_att / overall_se if np.isfinite(overall_se) and overall_se > 0 else np.nan - overall_p = compute_p_value(overall_t) - overall_ci = compute_confidence_interval(overall_att, overall_se, self.alpha) if np.isfinite(overall_se) and overall_se > 0 else (np.nan, np.nan) + overall_t, overall_p, overall_ci = safe_inference(overall_att, overall_se, alpha=self.alpha) # Run bootstrap if requested bootstrap_results = None @@ -641,7 +638,7 @@ def fit( # Update results with bootstrap inference overall_se = bootstrap_results.overall_att_se - overall_t = overall_att / overall_se if np.isfinite(overall_se) and overall_se > 0 else np.nan + overall_t = safe_inference(overall_att, overall_se, alpha=self.alpha)[0] overall_p = bootstrap_results.overall_att_p_value overall_ci = bootstrap_results.overall_att_ci @@ -657,9 +654,9 @@ def fit( ) eff_val = event_study_effects[e]["effect"] se_val = event_study_effects[e]["se"] - event_study_effects[e]["t_stat"] = ( - eff_val / se_val if np.isfinite(se_val) and se_val > 0 else np.nan - ) + event_study_effects[e]["t_stat"] = safe_inference( + eff_val, se_val, alpha=self.alpha + )[0] # Convert cohort effects to storage format cohort_effects_storage: Dict[Tuple[Any, int], Dict[str, Any]] = {} @@ -902,9 +899,7 @@ def _compute_iw_effects( agg_var = float(weight_vec @ vcov_subset @ weight_vec) agg_se = np.sqrt(max(agg_var, 0)) - t_stat = agg_effect / agg_se if np.isfinite(agg_se) and agg_se > 0 else np.nan - p_val = compute_p_value(t_stat) - ci = compute_confidence_interval(agg_effect, agg_se, self.alpha) if np.isfinite(agg_se) and agg_se > 0 else (np.nan, np.nan) + t_stat, p_val, ci = safe_inference(agg_effect, agg_se, alpha=self.alpha) event_study_effects[e] = { "effect": agg_effect, diff --git a/diff_diff/synthetic_did.py b/diff_diff/synthetic_did.py index f0591647..59258452 100644 --- a/diff_diff/synthetic_did.py +++ b/diff_diff/synthetic_did.py @@ -15,11 +15,10 @@ from diff_diff.utils import ( _compute_regularization, _sum_normalize, - compute_confidence_interval, - compute_p_value, compute_sdid_estimator, compute_sdid_unit_weights, compute_time_weights, + safe_inference, validate_binary, ) @@ -422,24 +421,14 @@ def fit( # type: ignore[override] inference_method = "placebo" # Compute test statistics - if np.isfinite(se) and se > 0: - t_stat = att / se - # Use placebo distribution for p-value if available - if len(placebo_effects) > 0: - # Two-sided p-value from placebo distribution - p_value = np.mean(np.abs(placebo_effects) >= np.abs(att)) - p_value = max(p_value, 1.0 / (len(placebo_effects) + 1)) - else: - p_value = compute_p_value(t_stat) - else: - t_stat = np.nan - p_value = np.nan - - # Confidence interval - if np.isfinite(se) and se > 0: - conf_int = compute_confidence_interval(att, se, self.alpha) + t_stat, p_value_analytical, conf_int = safe_inference(att, se, alpha=self.alpha) + if len(placebo_effects) > 0 and np.isfinite(t_stat): + p_value = max( + np.mean(np.abs(placebo_effects) >= np.abs(att)), + 1.0 / (len(placebo_effects) + 1), + ) else: - conf_int = (np.nan, np.nan) + p_value = p_value_analytical # Create weight dictionaries unit_weights_dict = { diff --git a/diff_diff/triple_diff.py b/diff_diff/triple_diff.py index fc9c98d7..fd4051db 100644 --- a/diff_diff/triple_diff.py +++ b/diff_diff/triple_diff.py @@ -39,10 +39,7 @@ from diff_diff.linalg import LinearRegression, compute_robust_vcov, solve_ols from diff_diff.results import _get_significance_stars -from diff_diff.utils import ( - compute_confidence_interval, - compute_p_value, -) +from diff_diff.utils import safe_inference # ============================================================================= # Results Classes @@ -298,16 +295,16 @@ def _logistic_regression( X_with_intercept = np.column_stack([np.ones(n), X]) def neg_log_likelihood(beta: np.ndarray) -> float: - z = X_with_intercept @ beta + z = np.dot(X_with_intercept, beta) z = np.clip(z, -500, 500) log_lik = np.sum(y * z - np.log(1 + np.exp(z))) return -log_lik def gradient(beta: np.ndarray) -> np.ndarray: - z = X_with_intercept @ beta + z = np.dot(X_with_intercept, beta) z = np.clip(z, -500, 500) probs = 1 / (1 + np.exp(-z)) - return -X_with_intercept.T @ (y - probs) + return -np.dot(X_with_intercept.T, y - probs) beta_init = np.zeros(p + 1) @@ -320,7 +317,7 @@ def gradient(beta: np.ndarray) -> np.ndarray: ) beta = result.x - z = X_with_intercept @ beta + z = np.dot(X_with_intercept, beta) z = np.clip(z, -500, 500) probs = 1 / (1 + np.exp(-z)) @@ -598,14 +595,12 @@ def fit( ) # Compute inference - t_stat = att / se if np.isfinite(se) and se > 0 else np.nan df = n_obs - 8 # Approximate df (8 cell means) if covariates: df -= len(covariates) df = max(df, 1) - p_value = compute_p_value(t_stat, df=df) - conf_int = compute_confidence_interval(att, se, self.alpha, df=df) if np.isfinite(se) and se > 0 else (np.nan, np.nan) + t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha, df=df) # Get number of clusters if clustering n_clusters = None diff --git a/diff_diff/trop.py b/diff_diff/trop.py index 179412cd..1059cca0 100644 --- a/diff_diff/trop.py +++ b/diff_diff/trop.py @@ -19,20 +19,13 @@ import logging import warnings -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple import numpy as np import pandas as pd -from scipy import stats logger = logging.getLogger(__name__) -try: - from typing import TypedDict -except ImportError: - from typing_extensions import TypedDict - from diff_diff._backend import ( HAS_RUST_BACKEND, _rust_unit_distance_matrix, @@ -41,305 +34,12 @@ _rust_loocv_grid_search_joint, _rust_bootstrap_trop_variance_joint, ) -from diff_diff.results import _get_significance_stars -from diff_diff.utils import compute_confidence_interval, compute_p_value - - -# Sentinel value for "disabled" λ_nn in LOOCV parameter search. -# Per paper's footnote 2: λ_nn=∞ disables the factor model (L=0). -# For λ_time and λ_unit, 0.0 means disabled (uniform weights) per Eq. 3: -# exp(-0 × dist) = 1 for all distances. -_LAMBDA_INF: float = float('inf') - - -class _PrecomputedStructures(TypedDict): - """Type definition for pre-computed structures used across LOOCV iterations. - - These structures are computed once in `_precompute_structures()` and reused - to avoid redundant computation during LOOCV and final estimation. - """ - - unit_dist_matrix: np.ndarray - """Pairwise unit distance matrix (n_units x n_units).""" - time_dist_matrix: np.ndarray - """Time distance matrix where [t, s] = |t - s| (n_periods x n_periods).""" - control_mask: np.ndarray - """Boolean mask for control observations (D == 0).""" - treated_mask: np.ndarray - """Boolean mask for treated observations (D == 1).""" - treated_observations: List[Tuple[int, int]] - """List of (t, i) tuples for treated observations.""" - control_obs: List[Tuple[int, int]] - """List of (t, i) tuples for valid control observations.""" - control_unit_idx: np.ndarray - """Array of never-treated unit indices (for backward compatibility).""" - D: np.ndarray - """Treatment indicator matrix (n_periods x n_units) for dynamic control sets.""" - Y: np.ndarray - """Outcome matrix (n_periods x n_units).""" - n_units: int - """Number of units.""" - n_periods: int - """Number of time periods.""" - - -@dataclass -class TROPResults: - """ - Results from a Triply Robust Panel (TROP) estimation. - - TROP combines nuclear norm regularized factor estimation with - exponential distance-based unit weights and time decay weights. - - Attributes - ---------- - att : float - Average Treatment effect on the Treated (ATT). - se : float - Standard error of the ATT estimate. - t_stat : float - T-statistic for the ATT estimate. - p_value : float - P-value for the null hypothesis that ATT = 0. - conf_int : tuple[float, float] - Confidence interval for the ATT. - n_obs : int - Number of observations used in estimation. - n_treated : int - Number of treated units. - n_control : int - Number of control units. - n_treated_obs : int - Number of treated unit-time observations. - unit_effects : dict - Estimated unit fixed effects (alpha_i). - time_effects : dict - Estimated time fixed effects (beta_t). - treatment_effects : dict - Individual treatment effects for each treated (unit, time) pair. - lambda_time : float - Selected time weight decay parameter from grid. 0.0 = uniform time - weights (disabled) per Eq. 3. - lambda_unit : float - Selected unit weight decay parameter from grid. 0.0 = uniform unit - weights (disabled) per Eq. 3. - lambda_nn : float - Selected nuclear norm regularization parameter from grid. inf = factor - model disabled (L=0); converted to 1e10 internally for computation. - factor_matrix : np.ndarray - Estimated low-rank factor matrix L (n_periods x n_units). - effective_rank : float - Effective rank of the factor matrix (sum of singular values / max). - loocv_score : float - Leave-one-out cross-validation score for selected parameters. - alpha : float - Significance level for confidence interval. - n_pre_periods : int - Number of pre-treatment periods. - n_post_periods : int - Number of post-treatment periods (periods with D=1 observations). - n_bootstrap : int, optional - Number of bootstrap replications (if bootstrap variance). - bootstrap_distribution : np.ndarray, optional - Bootstrap distribution of estimates. - """ - - att: float - se: float - t_stat: float - p_value: float - conf_int: Tuple[float, float] - n_obs: int - n_treated: int - n_control: int - n_treated_obs: int - unit_effects: Dict[Any, float] - time_effects: Dict[Any, float] - treatment_effects: Dict[Tuple[Any, Any], float] - lambda_time: float - lambda_unit: float - lambda_nn: float - factor_matrix: np.ndarray - effective_rank: float - loocv_score: float - alpha: float = 0.05 - n_pre_periods: int = 0 - n_post_periods: int = 0 - n_bootstrap: Optional[int] = field(default=None) - bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False) - - def __repr__(self) -> str: - """Concise string representation.""" - sig = _get_significance_stars(self.p_value) - return ( - f"TROPResults(ATT={self.att:.4f}{sig}, " - f"SE={self.se:.4f}, " - f"eff_rank={self.effective_rank:.1f}, " - f"p={self.p_value:.4f})" - ) - - def summary(self, alpha: Optional[float] = None) -> str: - """ - Generate a formatted summary of the estimation results. - - Parameters - ---------- - alpha : float, optional - Significance level for confidence intervals. Defaults to the - alpha used during estimation. - - Returns - ------- - str - Formatted summary table. - """ - alpha = alpha or self.alpha - conf_level = int((1 - alpha) * 100) - - lines = [ - "=" * 75, - "Triply Robust Panel (TROP) Estimation Results".center(75), - "Athey, Imbens, Qu & Viviano (2025)".center(75), - "=" * 75, - "", - f"{'Observations:':<25} {self.n_obs:>10}", - f"{'Treated units:':<25} {self.n_treated:>10}", - f"{'Control units:':<25} {self.n_control:>10}", - f"{'Treated observations:':<25} {self.n_treated_obs:>10}", - f"{'Pre-treatment periods:':<25} {self.n_pre_periods:>10}", - f"{'Post-treatment periods:':<25} {self.n_post_periods:>10}", - "", - "-" * 75, - "Tuning Parameters (selected via LOOCV)".center(75), - "-" * 75, - f"{'Lambda (time decay):':<25} {self.lambda_time:>10.4f}", - f"{'Lambda (unit distance):':<25} {self.lambda_unit:>10.4f}", - f"{'Lambda (nuclear norm):':<25} {self.lambda_nn:>10.4f}", - f"{'Effective rank:':<25} {self.effective_rank:>10.2f}", - f"{'LOOCV score:':<25} {self.loocv_score:>10.6f}", - ] - - # Variance info - if self.n_bootstrap is not None: - lines.append(f"{'Bootstrap replications:':<25} {self.n_bootstrap:>10}") - - lines.extend([ - "", - "-" * 75, - f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} " - f"{'t-stat':>10} {'P>|t|':>10} {'':>5}", - "-" * 75, - f"{'ATT':<15} {self.att:>12.4f} {self.se:>12.4f} " - f"{self.t_stat:>10.3f} {self.p_value:>10.4f} {self.significance_stars:>5}", - "-" * 75, - "", - f"{conf_level}% Confidence Interval: [{self.conf_int[0]:.4f}, {self.conf_int[1]:.4f}]", - ]) - - # Add significance codes - lines.extend([ - "", - "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", - "=" * 75, - ]) - - return "\n".join(lines) - - def print_summary(self, alpha: Optional[float] = None) -> None: - """Print the summary to stdout.""" - print(self.summary(alpha)) - - def to_dict(self) -> Dict[str, Any]: - """ - Convert results to a dictionary. - - Returns - ------- - Dict[str, Any] - Dictionary containing all estimation results. - """ - return { - "att": self.att, - "se": self.se, - "t_stat": self.t_stat, - "p_value": self.p_value, - "conf_int_lower": self.conf_int[0], - "conf_int_upper": self.conf_int[1], - "n_obs": self.n_obs, - "n_treated": self.n_treated, - "n_control": self.n_control, - "n_treated_obs": self.n_treated_obs, - "n_pre_periods": self.n_pre_periods, - "n_post_periods": self.n_post_periods, - "lambda_time": self.lambda_time, - "lambda_unit": self.lambda_unit, - "lambda_nn": self.lambda_nn, - "effective_rank": self.effective_rank, - "loocv_score": self.loocv_score, - } - - def to_dataframe(self) -> pd.DataFrame: - """ - Convert results to a pandas DataFrame. - - Returns - ------- - pd.DataFrame - DataFrame with estimation results. - """ - return pd.DataFrame([self.to_dict()]) - - def get_treatment_effects_df(self) -> pd.DataFrame: - """ - Get individual treatment effects as a DataFrame. - - Returns - ------- - pd.DataFrame - DataFrame with unit, time, and treatment effect columns. - """ - return pd.DataFrame([ - {"unit": unit, "time": time, "effect": effect} - for (unit, time), effect in self.treatment_effects.items() - ]) - - def get_unit_effects_df(self) -> pd.DataFrame: - """ - Get unit fixed effects as a DataFrame. - - Returns - ------- - pd.DataFrame - DataFrame with unit and effect columns. - """ - return pd.DataFrame([ - {"unit": unit, "effect": effect} - for unit, effect in self.unit_effects.items() - ]) - - def get_time_effects_df(self) -> pd.DataFrame: - """ - Get time fixed effects as a DataFrame. - - Returns - ------- - pd.DataFrame - DataFrame with time and effect columns. - """ - return pd.DataFrame([ - {"time": time, "effect": effect} - for time, effect in self.time_effects.items() - ]) - - @property - def is_significant(self) -> bool: - """Check if the ATT is statistically significant at the alpha level.""" - return bool(self.p_value < self.alpha) - - @property - def significance_stars(self) -> str: - """Return significance stars based on p-value.""" - return _get_significance_stars(self.p_value) +from diff_diff.trop_results import ( + _LAMBDA_INF, + _PrecomputedStructures, + TROPResults, +) +from diff_diff.utils import safe_inference class TROP: @@ -1095,7 +795,7 @@ def _solve_joint_no_lowrank( coeffs, _, _, _ = np.linalg.lstsq(X_weighted, y_weighted, rcond=None) except np.linalg.LinAlgError: # Fallback: use pseudo-inverse - coeffs = np.linalg.pinv(X_weighted) @ y_weighted + coeffs = np.dot(np.linalg.pinv(X_weighted), y_weighted) # Extract parameters mu = coeffs[0] @@ -1470,14 +1170,8 @@ def _fit_joint( ) # Compute test statistics - if se > 0: - t_stat = att / se - p_value = 2 * (1 - stats.t.cdf(abs(t_stat), df=max(1, n_treated_obs - 1))) - conf_int = compute_confidence_interval(att, se, self.alpha) - else: - t_stat = np.nan - p_value = np.nan - conf_int = (np.nan, np.nan) + df_trop = max(1, n_treated_obs - 1) + t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha, df=df_trop) # Create results dictionaries unit_effects_dict = {idx_to_unit[i]: alpha[i] for i in range(n_units)} @@ -2050,15 +1744,8 @@ def fit( ) # Compute test statistics - if se > 0: - t_stat = att / se - p_value = 2 * (1 - stats.t.cdf(abs(t_stat), df=max(1, n_treated_obs - 1))) - conf_int = compute_confidence_interval(att, se, self.alpha) - else: - # When SE is undefined/zero, ALL inference fields should be NaN - t_stat = np.nan - p_value = np.nan - conf_int = (np.nan, np.nan) + df_trop = max(1, n_treated_obs - 1) + t_stat, p_value, conf_int = safe_inference(att, se, alpha=self.alpha, df=df_trop) # Create results dictionaries unit_effects_dict = {idx_to_unit[i]: alpha_hat[i] for i in range(n_units)} diff --git a/diff_diff/trop_results.py b/diff_diff/trop_results.py new file mode 100644 index 00000000..a2189e81 --- /dev/null +++ b/diff_diff/trop_results.py @@ -0,0 +1,322 @@ +""" +Result containers for the Triply Robust Panel (TROP) estimator. + +This module contains the TROPResults dataclass, _PrecomputedStructures TypedDict, +and _LAMBDA_INF sentinel value. Extracted from trop.py for module size management. +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd + +try: + from typing import TypedDict +except ImportError: + from typing_extensions import TypedDict + +from diff_diff.results import _get_significance_stars + +__all__ = [ + "_LAMBDA_INF", + "_PrecomputedStructures", + "TROPResults", +] + + +# Sentinel value for "disabled" λ_nn in LOOCV parameter search. +# Per paper's footnote 2: λ_nn=∞ disables the factor model (L=0). +# For λ_time and λ_unit, 0.0 means disabled (uniform weights) per Eq. 3: +# exp(-0 × dist) = 1 for all distances. +_LAMBDA_INF: float = float('inf') + + +class _PrecomputedStructures(TypedDict): + """Type definition for pre-computed structures used across LOOCV iterations. + + These structures are computed once in `_precompute_structures()` and reused + to avoid redundant computation during LOOCV and final estimation. + """ + + unit_dist_matrix: np.ndarray + """Pairwise unit distance matrix (n_units x n_units).""" + time_dist_matrix: np.ndarray + """Time distance matrix where [t, s] = |t - s| (n_periods x n_periods).""" + control_mask: np.ndarray + """Boolean mask for control observations (D == 0).""" + treated_mask: np.ndarray + """Boolean mask for treated observations (D == 1).""" + treated_observations: List[Tuple[int, int]] + """List of (t, i) tuples for treated observations.""" + control_obs: List[Tuple[int, int]] + """List of (t, i) tuples for valid control observations.""" + control_unit_idx: np.ndarray + """Array of never-treated unit indices (for backward compatibility).""" + D: np.ndarray + """Treatment indicator matrix (n_periods x n_units) for dynamic control sets.""" + Y: np.ndarray + """Outcome matrix (n_periods x n_units).""" + n_units: int + """Number of units.""" + n_periods: int + """Number of time periods.""" + + +@dataclass +class TROPResults: + """ + Results from a Triply Robust Panel (TROP) estimation. + + TROP combines nuclear norm regularized factor estimation with + exponential distance-based unit weights and time decay weights. + + Attributes + ---------- + att : float + Average Treatment effect on the Treated (ATT). + se : float + Standard error of the ATT estimate. + t_stat : float + T-statistic for the ATT estimate. + p_value : float + P-value for the null hypothesis that ATT = 0. + conf_int : tuple[float, float] + Confidence interval for the ATT. + n_obs : int + Number of observations used in estimation. + n_treated : int + Number of treated units. + n_control : int + Number of control units. + n_treated_obs : int + Number of treated unit-time observations. + unit_effects : dict + Estimated unit fixed effects (alpha_i). + time_effects : dict + Estimated time fixed effects (beta_t). + treatment_effects : dict + Individual treatment effects for each treated (unit, time) pair. + lambda_time : float + Selected time weight decay parameter from grid. 0.0 = uniform time + weights (disabled) per Eq. 3. + lambda_unit : float + Selected unit weight decay parameter from grid. 0.0 = uniform unit + weights (disabled) per Eq. 3. + lambda_nn : float + Selected nuclear norm regularization parameter from grid. inf = factor + model disabled (L=0); converted to 1e10 internally for computation. + factor_matrix : np.ndarray + Estimated low-rank factor matrix L (n_periods x n_units). + effective_rank : float + Effective rank of the factor matrix (sum of singular values / max). + loocv_score : float + Leave-one-out cross-validation score for selected parameters. + alpha : float + Significance level for confidence interval. + n_pre_periods : int + Number of pre-treatment periods. + n_post_periods : int + Number of post-treatment periods (periods with D=1 observations). + n_bootstrap : int, optional + Number of bootstrap replications (if bootstrap variance). + bootstrap_distribution : np.ndarray, optional + Bootstrap distribution of estimates. + """ + + att: float + se: float + t_stat: float + p_value: float + conf_int: Tuple[float, float] + n_obs: int + n_treated: int + n_control: int + n_treated_obs: int + unit_effects: Dict[Any, float] + time_effects: Dict[Any, float] + treatment_effects: Dict[Tuple[Any, Any], float] + lambda_time: float + lambda_unit: float + lambda_nn: float + factor_matrix: np.ndarray + effective_rank: float + loocv_score: float + alpha: float = 0.05 + n_pre_periods: int = 0 + n_post_periods: int = 0 + n_bootstrap: Optional[int] = field(default=None) + bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False) + + def __repr__(self) -> str: + """Concise string representation.""" + sig = _get_significance_stars(self.p_value) + return ( + f"TROPResults(ATT={self.att:.4f}{sig}, " + f"SE={self.se:.4f}, " + f"eff_rank={self.effective_rank:.1f}, " + f"p={self.p_value:.4f})" + ) + + def summary(self, alpha: Optional[float] = None) -> str: + """ + Generate a formatted summary of the estimation results. + + Parameters + ---------- + alpha : float, optional + Significance level for confidence intervals. Defaults to the + alpha used during estimation. + + Returns + ------- + str + Formatted summary table. + """ + alpha = alpha or self.alpha + conf_level = int((1 - alpha) * 100) + + lines = [ + "=" * 75, + "Triply Robust Panel (TROP) Estimation Results".center(75), + "Athey, Imbens, Qu & Viviano (2025)".center(75), + "=" * 75, + "", + f"{'Observations:':<25} {self.n_obs:>10}", + f"{'Treated units:':<25} {self.n_treated:>10}", + f"{'Control units:':<25} {self.n_control:>10}", + f"{'Treated observations:':<25} {self.n_treated_obs:>10}", + f"{'Pre-treatment periods:':<25} {self.n_pre_periods:>10}", + f"{'Post-treatment periods:':<25} {self.n_post_periods:>10}", + "", + "-" * 75, + "Tuning Parameters (selected via LOOCV)".center(75), + "-" * 75, + f"{'Lambda (time decay):':<25} {self.lambda_time:>10.4f}", + f"{'Lambda (unit distance):':<25} {self.lambda_unit:>10.4f}", + f"{'Lambda (nuclear norm):':<25} {self.lambda_nn:>10.4f}", + f"{'Effective rank:':<25} {self.effective_rank:>10.2f}", + f"{'LOOCV score:':<25} {self.loocv_score:>10.6f}", + ] + + # Variance info + if self.n_bootstrap is not None: + lines.append(f"{'Bootstrap replications:':<25} {self.n_bootstrap:>10}") + + lines.extend([ + "", + "-" * 75, + f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} " + f"{'t-stat':>10} {'P>|t|':>10} {'':>5}", + "-" * 75, + f"{'ATT':<15} {self.att:>12.4f} {self.se:>12.4f} " + f"{self.t_stat:>10.3f} {self.p_value:>10.4f} {self.significance_stars:>5}", + "-" * 75, + "", + f"{conf_level}% Confidence Interval: [{self.conf_int[0]:.4f}, {self.conf_int[1]:.4f}]", + ]) + + # Add significance codes + lines.extend([ + "", + "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", + "=" * 75, + ]) + + return "\n".join(lines) + + def print_summary(self, alpha: Optional[float] = None) -> None: + """Print the summary to stdout.""" + print(self.summary(alpha)) + + def to_dict(self) -> Dict[str, Any]: + """ + Convert results to a dictionary. + + Returns + ------- + Dict[str, Any] + Dictionary containing all estimation results. + """ + return { + "att": self.att, + "se": self.se, + "t_stat": self.t_stat, + "p_value": self.p_value, + "conf_int_lower": self.conf_int[0], + "conf_int_upper": self.conf_int[1], + "n_obs": self.n_obs, + "n_treated": self.n_treated, + "n_control": self.n_control, + "n_treated_obs": self.n_treated_obs, + "n_pre_periods": self.n_pre_periods, + "n_post_periods": self.n_post_periods, + "lambda_time": self.lambda_time, + "lambda_unit": self.lambda_unit, + "lambda_nn": self.lambda_nn, + "effective_rank": self.effective_rank, + "loocv_score": self.loocv_score, + } + + def to_dataframe(self) -> pd.DataFrame: + """ + Convert results to a pandas DataFrame. + + Returns + ------- + pd.DataFrame + DataFrame with estimation results. + """ + return pd.DataFrame([self.to_dict()]) + + def get_treatment_effects_df(self) -> pd.DataFrame: + """ + Get individual treatment effects as a DataFrame. + + Returns + ------- + pd.DataFrame + DataFrame with unit, time, and treatment effect columns. + """ + return pd.DataFrame([ + {"unit": unit, "time": time, "effect": effect} + for (unit, time), effect in self.treatment_effects.items() + ]) + + def get_unit_effects_df(self) -> pd.DataFrame: + """ + Get unit fixed effects as a DataFrame. + + Returns + ------- + pd.DataFrame + DataFrame with unit and effect columns. + """ + return pd.DataFrame([ + {"unit": unit, "effect": effect} + for unit, effect in self.unit_effects.items() + ]) + + def get_time_effects_df(self) -> pd.DataFrame: + """ + Get time fixed effects as a DataFrame. + + Returns + ------- + pd.DataFrame + DataFrame with time and effect columns. + """ + return pd.DataFrame([ + {"time": time, "effect": effect} + for time, effect in self.time_effects.items() + ]) + + @property + def is_significant(self) -> bool: + """Check if the ATT is statistically significant at the alpha level.""" + return bool(self.p_value < self.alpha) + + @property + def significance_stars(self) -> str: + """Return significance stars based on p-value.""" + return _get_significance_stars(self.p_value) diff --git a/diff_diff/two_stage.py b/diff_diff/two_stage.py index 7aec9b71..5e03dfb6 100644 --- a/diff_diff/two_stage.py +++ b/diff_diff/two_stage.py @@ -22,8 +22,7 @@ """ import warnings -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Tuple import numpy as np import pandas as pd @@ -31,372 +30,9 @@ from scipy.sparse.linalg import factorized as sparse_factorized from diff_diff.linalg import solve_ols -from diff_diff.results import _get_significance_stars -from diff_diff.utils import compute_confidence_interval, compute_p_value - -# ============================================================================= -# Results Dataclasses -# ============================================================================= - - -@dataclass -class TwoStageBootstrapResults: - """ - Results from TwoStageDiD bootstrap inference. - - Bootstrap uses multiplier bootstrap on the GMM influence function, - consistent with other library estimators. The R `did2s` package uses - block bootstrap by default; multiplier bootstrap is asymptotically - equivalent. - - Attributes - ---------- - n_bootstrap : int - Number of bootstrap iterations. - weight_type : str - Type of bootstrap weights (currently "rademacher" only). - alpha : float - Significance level used for confidence intervals. - overall_att_se : float - Bootstrap standard error for overall ATT. - overall_att_ci : tuple - Bootstrap confidence interval for overall ATT. - overall_att_p_value : float - Bootstrap p-value for overall ATT. - event_study_ses : dict, optional - Bootstrap SEs for event study effects. - event_study_cis : dict, optional - Bootstrap CIs for event study effects. - event_study_p_values : dict, optional - Bootstrap p-values for event study effects. - group_ses : dict, optional - Bootstrap SEs for group effects. - group_cis : dict, optional - Bootstrap CIs for group effects. - group_p_values : dict, optional - Bootstrap p-values for group effects. - bootstrap_distribution : np.ndarray, optional - Full bootstrap distribution of overall ATT. - """ - - n_bootstrap: int - weight_type: str - alpha: float - overall_att_se: float - overall_att_ci: Tuple[float, float] - overall_att_p_value: float - event_study_ses: Optional[Dict[int, float]] = None - event_study_cis: Optional[Dict[int, Tuple[float, float]]] = None - event_study_p_values: Optional[Dict[int, float]] = None - group_ses: Optional[Dict[Any, float]] = None - group_cis: Optional[Dict[Any, Tuple[float, float]]] = None - group_p_values: Optional[Dict[Any, float]] = None - bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False) - - -@dataclass -class TwoStageDiDResults: - """ - Results from Gardner (2022) two-stage DiD estimation. - - Attributes - ---------- - treatment_effects : pd.DataFrame - Per-observation treatment effects with columns: unit, time, - tau_hat, weight. tau_hat is the residualized outcome y_tilde - for treated observations; weight is 1/n_treated. - overall_att : float - Overall average treatment effect on the treated. - overall_se : float - Standard error of overall ATT (GMM sandwich). - overall_t_stat : float - T-statistic for overall ATT. - overall_p_value : float - P-value for overall ATT. - overall_conf_int : tuple - Confidence interval for overall ATT. - event_study_effects : dict, optional - Dictionary mapping relative time h to effect dict with keys: - 'effect', 'se', 't_stat', 'p_value', 'conf_int', 'n_obs'. - group_effects : dict, optional - Dictionary mapping cohort g to effect dict. - groups : list - List of treatment cohorts. - time_periods : list - List of all time periods. - n_obs : int - Total number of observations. - n_treated_obs : int - Number of treated observations. - n_untreated_obs : int - Number of untreated observations. - n_treated_units : int - Number of ever-treated units. - n_control_units : int - Number of units contributing to untreated observations. - alpha : float - Significance level used. - bootstrap_results : TwoStageBootstrapResults, optional - Bootstrap inference results. - """ - - treatment_effects: pd.DataFrame - overall_att: float - overall_se: float - overall_t_stat: float - overall_p_value: float - overall_conf_int: Tuple[float, float] - event_study_effects: Optional[Dict[int, Dict[str, Any]]] - group_effects: Optional[Dict[Any, Dict[str, Any]]] - groups: List[Any] - time_periods: List[Any] - n_obs: int - n_treated_obs: int - n_untreated_obs: int - n_treated_units: int - n_control_units: int - alpha: float = 0.05 - bootstrap_results: Optional[TwoStageBootstrapResults] = field(default=None, repr=False) - - def __repr__(self) -> str: - """Concise string representation.""" - sig = _get_significance_stars(self.overall_p_value) - return ( - f"TwoStageDiDResults(ATT={self.overall_att:.4f}{sig}, " - f"SE={self.overall_se:.4f}, " - f"n_groups={len(self.groups)}, " - f"n_treated_obs={self.n_treated_obs})" - ) - - def summary(self, alpha: Optional[float] = None) -> str: - """ - Generate formatted summary of estimation results. - - Parameters - ---------- - alpha : float, optional - Significance level. Defaults to alpha used in estimation. - - Returns - ------- - str - Formatted summary. - """ - alpha = alpha or self.alpha - conf_level = int((1 - alpha) * 100) - - lines = [ - "=" * 85, - "Two-Stage DiD Estimator Results (Gardner 2022)".center(85), - "=" * 85, - "", - f"{'Total observations:':<30} {self.n_obs:>10}", - f"{'Treated observations:':<30} {self.n_treated_obs:>10}", - f"{'Untreated observations:':<30} {self.n_untreated_obs:>10}", - f"{'Treated units:':<30} {self.n_treated_units:>10}", - f"{'Control units:':<30} {self.n_control_units:>10}", - f"{'Treatment cohorts:':<30} {len(self.groups):>10}", - f"{'Time periods:':<30} {len(self.time_periods):>10}", - "", - ] - - # Overall ATT - lines.extend( - [ - "-" * 85, - "Overall Average Treatment Effect on the Treated".center(85), - "-" * 85, - f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} " - f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", - "-" * 85, - ] - ) - - t_str = ( - f"{self.overall_t_stat:>10.3f}" if np.isfinite(self.overall_t_stat) else f"{'NaN':>10}" - ) - p_str = ( - f"{self.overall_p_value:>10.4f}" - if np.isfinite(self.overall_p_value) - else f"{'NaN':>10}" - ) - sig = _get_significance_stars(self.overall_p_value) - - lines.extend( - [ - f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} " - f"{t_str} {p_str} {sig:>6}", - "-" * 85, - "", - f"{conf_level}% Confidence Interval: " - f"[{self.overall_conf_int[0]:.4f}, {self.overall_conf_int[1]:.4f}]", - "", - ] - ) - - # Event study effects - if self.event_study_effects: - lines.extend( - [ - "-" * 85, - "Event Study (Dynamic) Effects".center(85), - "-" * 85, - f"{'Rel. Period':<15} {'Estimate':>12} {'Std. Err.':>12} " - f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", - "-" * 85, - ] - ) - - for h in sorted(self.event_study_effects.keys()): - eff = self.event_study_effects[h] - if eff.get("n_obs", 1) == 0: - # Reference period marker - lines.append( - f"[ref: {h}]" f"{'0.0000':>17} {'---':>12} {'---':>10} {'---':>10} {'':>6}" - ) - elif np.isnan(eff["effect"]): - lines.append(f"{h:<15} {'NaN':>12} {'NaN':>12} {'NaN':>10} {'NaN':>10} {'':>6}") - else: - e_sig = _get_significance_stars(eff["p_value"]) - e_t = ( - f"{eff['t_stat']:>10.3f}" if np.isfinite(eff["t_stat"]) else f"{'NaN':>10}" - ) - e_p = ( - f"{eff['p_value']:>10.4f}" - if np.isfinite(eff["p_value"]) - else f"{'NaN':>10}" - ) - lines.append( - f"{h:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} " - f"{e_t} {e_p} {e_sig:>6}" - ) - - lines.extend(["-" * 85, ""]) - - # Group effects - if self.group_effects: - lines.extend( - [ - "-" * 85, - "Group (Cohort) Effects".center(85), - "-" * 85, - f"{'Cohort':<15} {'Estimate':>12} {'Std. Err.':>12} " - f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", - "-" * 85, - ] - ) - - for g in sorted(self.group_effects.keys()): - eff = self.group_effects[g] - if np.isnan(eff["effect"]): - lines.append(f"{g:<15} {'NaN':>12} {'NaN':>12} {'NaN':>10} {'NaN':>10} {'':>6}") - else: - g_sig = _get_significance_stars(eff["p_value"]) - g_t = ( - f"{eff['t_stat']:>10.3f}" if np.isfinite(eff["t_stat"]) else f"{'NaN':>10}" - ) - g_p = ( - f"{eff['p_value']:>10.4f}" - if np.isfinite(eff["p_value"]) - else f"{'NaN':>10}" - ) - lines.append( - f"{g:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} " - f"{g_t} {g_p} {g_sig:>6}" - ) - - lines.extend(["-" * 85, ""]) - - lines.extend( - [ - "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", - "=" * 85, - ] - ) - - return "\n".join(lines) - - def print_summary(self, alpha: Optional[float] = None) -> None: - """Print summary to stdout.""" - print(self.summary(alpha)) - - def to_dataframe(self, level: str = "event_study") -> pd.DataFrame: - """ - Convert results to DataFrame. - - Parameters - ---------- - level : str, default="event_study" - Level of aggregation: - - "event_study": Event study effects by relative time - - "group": Group (cohort) effects - - "observation": Per-observation treatment effects - - Returns - ------- - pd.DataFrame - Results as DataFrame. - """ - if level == "observation": - return self.treatment_effects.copy() - - elif level == "event_study": - if self.event_study_effects is None: - raise ValueError( - "Event study effects not computed. " - "Use aggregate='event_study' or aggregate='all'." - ) - rows = [] - for h, data in sorted(self.event_study_effects.items()): - rows.append( - { - "relative_period": h, - "effect": data["effect"], - "se": data["se"], - "t_stat": data["t_stat"], - "p_value": data["p_value"], - "conf_int_lower": data["conf_int"][0], - "conf_int_upper": data["conf_int"][1], - "n_obs": data.get("n_obs", np.nan), - } - ) - return pd.DataFrame(rows) - - elif level == "group": - if self.group_effects is None: - raise ValueError( - "Group effects not computed. " "Use aggregate='group' or aggregate='all'." - ) - rows = [] - for g, data in sorted(self.group_effects.items()): - rows.append( - { - "group": g, - "effect": data["effect"], - "se": data["se"], - "t_stat": data["t_stat"], - "p_value": data["p_value"], - "conf_int_lower": data["conf_int"][0], - "conf_int_upper": data["conf_int"][1], - "n_obs": data.get("n_obs", np.nan), - } - ) - return pd.DataFrame(rows) - - else: - raise ValueError( - f"Unknown level: {level}. Use 'event_study', 'group', or 'observation'." - ) - - @property - def is_significant(self) -> bool: - """Check if overall ATT is significant.""" - return bool(self.overall_p_value < self.alpha) - - @property - def significance_stars(self) -> str: - """Significance stars for overall ATT.""" - return _get_significance_stars(self.overall_p_value) +from diff_diff.two_stage_bootstrap import TwoStageDiDBootstrapMixin +from diff_diff.two_stage_results import TwoStageBootstrapResults, TwoStageDiDResults # noqa: F401 (re-export) +from diff_diff.utils import safe_inference # ============================================================================= @@ -404,7 +40,7 @@ def significance_stars(self) -> str: # ============================================================================= -class TwoStageDiD: +class TwoStageDiD(TwoStageDiDBootstrapMixin): """ Gardner (2022) two-stage Difference-in-Differences estimator. @@ -723,14 +359,8 @@ def fit( kept_cov_mask=kept_cov_mask, ) - overall_t = ( - overall_att / overall_se if np.isfinite(overall_se) and overall_se > 0 else np.nan - ) - overall_p = compute_p_value(overall_t) - overall_ci = ( - compute_confidence_interval(overall_att, overall_se, self.alpha) - if np.isfinite(overall_se) and overall_se > 0 - else (np.nan, np.nan) + overall_t, overall_p, overall_ci = safe_inference( + overall_att, overall_se, alpha=self.alpha ) # Event study and group aggregation @@ -845,9 +475,9 @@ def fit( ) eff_val = event_study_effects[h]["effect"] se_val = event_study_effects[h]["se"] - event_study_effects[h]["t_stat"] = ( - eff_val / se_val if np.isfinite(se_val) and se_val > 0 else np.nan - ) + event_study_effects[h]["t_stat"] = safe_inference( + eff_val, se_val, alpha=self.alpha + )[0] # Update group effects if group_effects and bootstrap_results.group_ses: @@ -858,9 +488,9 @@ def fit( group_effects[g]["p_value"] = bootstrap_results.group_p_values[g] eff_val = group_effects[g]["effect"] se_val = group_effects[g]["se"] - group_effects[g]["t_stat"] = ( - eff_val / se_val if np.isfinite(se_val) and se_val > 0 else np.nan - ) + group_effects[g]["t_stat"] = safe_inference( + eff_val, se_val, alpha=self.alpha + )[0] # Construct results self.results_ = TwoStageDiDResults( @@ -1027,7 +657,7 @@ def _fit_untreated_model( kept_cov_mask = np.isfinite(delta_hat) delta_hat_clean = np.where(np.isfinite(delta_hat), delta_hat, 0.0) - y_adj = y - X_raw @ delta_hat_clean + y_adj = y - np.dot(X_raw, delta_hat_clean) unit_fe, time_fe = self._iterative_fe(y_adj, units, times, df_0.index) return unit_fe, time_fe, 0.0, delta_hat_clean, kept_cov_mask @@ -1063,7 +693,7 @@ def _residualize( y_hat = grand_mean + alpha_i + beta_t if delta_hat is not None and covariates: - y_hat = y_hat + df[covariates].values @ delta_hat + y_hat = y_hat + np.dot(df[covariates].values, delta_hat) y_tilde = df[outcome].values - y_hat return y_tilde @@ -1118,7 +748,7 @@ def _stage2_static( att = float(coef[0]) # GMM sandwich variance - eps_2 = y_tilde - X_2 @ coef # Stage 2 residuals + eps_2 = y_tilde - np.dot(X_2, coef) # Stage 2 residuals V = self._compute_gmm_variance( df=df, @@ -1273,7 +903,7 @@ def _stage2_event_study( # Stage 2 OLS coef, residuals, _ = solve_ols(X_2, y_tilde, return_vcov=False) - eps_2 = y_tilde - X_2 @ coef + eps_2 = y_tilde - np.dot(X_2, coef) # GMM variance for full coefficient vector V = self._compute_gmm_variance( @@ -1322,13 +952,7 @@ def _stage2_event_study( effect = float(coef[j]) se = float(np.sqrt(max(V[j, j], 0.0))) - t_stat = effect / se if np.isfinite(se) and se > 0 else np.nan - p_val = compute_p_value(t_stat) - ci = ( - compute_confidence_interval(effect, se, self.alpha) - if np.isfinite(se) and se > 0 - else (np.nan, np.nan) - ) + t_stat, p_val, ci = safe_inference(effect, se, alpha=self.alpha) event_study_effects[h] = { "effect": effect, @@ -1391,7 +1015,7 @@ def _stage2_group( # Stage 2 OLS coef, residuals, _ = solve_ols(X_2, y_tilde, return_vcov=False) - eps_2 = y_tilde - X_2 @ coef + eps_2 = y_tilde - np.dot(X_2, coef) # GMM variance V = self._compute_gmm_variance( @@ -1428,13 +1052,7 @@ def _stage2_group( effect = float(coef[j]) se = float(np.sqrt(max(V[j, j], 0.0))) - t_stat = effect / se if np.isfinite(se) and se > 0 else np.nan - p_val = compute_p_value(t_stat) - ci = ( - compute_confidence_interval(effect, se, self.alpha) - if np.isfinite(se) and se > 0 - else (np.nan, np.nan) - ) + t_stat, p_val, ci = safe_inference(effect, se, alpha=self.alpha) group_effects[g] = { "effect": effect, @@ -1522,9 +1140,9 @@ def _compute_gmm_variance( fitted_1 = alpha_i + beta_t if delta_hat is not None and cov_list: if kept_cov_mask is not None and not np.all(kept_cov_mask): - fitted_1 = fitted_1 + df[cov_list].values @ delta_hat[kept_cov_mask] + fitted_1 = fitted_1 + np.dot(df[cov_list].values, delta_hat[kept_cov_mask]) else: - fitted_1 = fitted_1 + df[cov_list].values @ delta_hat + fitted_1 = fitted_1 + np.dot(df[cov_list].values, delta_hat) y_tilde = df["_y_tilde"].values y_vals = y_tilde + fitted_1 # reconstruct Y @@ -1575,7 +1193,7 @@ def _compute_gmm_variance( # 4. S_g = gamma_hat' c_g - X'_{2g} eps_{2g} with np.errstate(invalid="ignore", divide="ignore", over="ignore"): - correction = c_by_cluster @ gamma_hat # (G x p) @ (p x k) = (G x k) + correction = np.dot(c_by_cluster, gamma_hat) # (G x p) @ (p x k) = (G x k) # Replace NaN/inf from overflow (rank-deficient FE) with 0 np.nan_to_num(correction, copy=False, nan=0.0, posinf=0.0, neginf=0.0) S = correction - s2_by_cluster # (G x k) @@ -1675,435 +1293,6 @@ def _build_rows(mask=None): return X_1, X_10, unit_to_idx, time_to_idx - # ========================================================================= - # Bootstrap - # ========================================================================= - - def _compute_cluster_S_scores( - self, - df: pd.DataFrame, - unit: str, - time: str, - covariates: Optional[List[str]], - omega_0_mask: pd.Series, - unit_fe: Dict[Any, float], - time_fe: Dict[Any, float], - delta_hat: Optional[np.ndarray], - kept_cov_mask: Optional[np.ndarray], - X_2: np.ndarray, - eps_2: np.ndarray, - cluster_ids: np.ndarray, - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - Compute per-cluster S_g scores for bootstrap. - - Returns - ------- - S : np.ndarray, shape (G, k) - Per-cluster influence scores. - bread : np.ndarray, shape (k, k) - (X'_2 X_2)^{-1}. - unique_clusters : np.ndarray - Unique cluster identifiers. - """ - n = len(df) - k = X_2.shape[1] - - cov_list = covariates - if covariates and kept_cov_mask is not None and not np.all(kept_cov_mask): - cov_list = [c for c, k_ in zip(covariates, kept_cov_mask) if k_] - - X_1_sparse, X_10_sparse, _, _ = self._build_fe_design( - df, unit, time, cov_list, omega_0_mask - ) - p = X_1_sparse.shape[1] - - # Reconstruct Y and compute eps_10 - alpha_i = df[unit].map(unit_fe).values - beta_t = df[time].map(time_fe).values - alpha_i = np.where(pd.isna(alpha_i), 0.0, alpha_i).astype(float) - beta_t = np.where(pd.isna(beta_t), 0.0, beta_t).astype(float) - fitted_1 = alpha_i + beta_t - if delta_hat is not None and cov_list: - if kept_cov_mask is not None and not np.all(kept_cov_mask): - fitted_1 = fitted_1 + df[cov_list].values @ delta_hat[kept_cov_mask] - else: - fitted_1 = fitted_1 + df[cov_list].values @ delta_hat - - y_tilde = df["_y_tilde"].values - y_vals = y_tilde + fitted_1 - - eps_10 = np.empty(n) - omega_0 = omega_0_mask.values - eps_10[omega_0] = y_vals[omega_0] - fitted_1[omega_0] - eps_10[~omega_0] = y_vals[~omega_0] - - # gamma_hat - XtX_10 = X_10_sparse.T @ X_10_sparse - Xt1_X2 = X_1_sparse.T @ X_2 - - try: - solve_XtX = sparse_factorized(XtX_10.tocsc()) - if Xt1_X2.ndim == 1: - gamma_hat = solve_XtX(Xt1_X2).reshape(-1, 1) - else: - gamma_hat = np.column_stack( - [solve_XtX(Xt1_X2[:, j]) for j in range(Xt1_X2.shape[1])] - ) - except RuntimeError: - gamma_hat = np.linalg.lstsq(XtX_10.toarray(), Xt1_X2, rcond=None)[0] - if gamma_hat.ndim == 1: - gamma_hat = gamma_hat.reshape(-1, 1) - - # Per-cluster aggregation - weighted_X10 = X_10_sparse.multiply(eps_10[:, None]) - unique_clusters, cluster_indices = np.unique(cluster_ids, return_inverse=True) - G = len(unique_clusters) - - weighted_X10_csc = weighted_X10.tocsc() - c_by_cluster = np.zeros((G, p)) - for j_col in range(p): - col_data = weighted_X10_csc.getcol(j_col).toarray().ravel() - np.add.at(c_by_cluster[:, j_col], cluster_indices, col_data) - - weighted_X2 = X_2 * eps_2[:, None] - s2_by_cluster = np.zeros((G, k)) - for j_col in range(k): - np.add.at(s2_by_cluster[:, j_col], cluster_indices, weighted_X2[:, j_col]) - - correction = c_by_cluster @ gamma_hat - S = correction - s2_by_cluster - - # Bread - XtX_2 = X_2.T @ X_2 - try: - bread = np.linalg.solve(XtX_2, np.eye(k)) - except np.linalg.LinAlgError: - bread = np.linalg.lstsq(XtX_2, np.eye(k), rcond=None)[0] - - return S, bread, unique_clusters - - def _run_bootstrap( - self, - df: pd.DataFrame, - unit: str, - time: str, - first_treat: str, - covariates: Optional[List[str]], - omega_0_mask: pd.Series, - omega_1_mask: pd.Series, - unit_fe: Dict[Any, float], - time_fe: Dict[Any, float], - grand_mean: float, - delta_hat: Optional[np.ndarray], - cluster_var: str, - kept_cov_mask: Optional[np.ndarray], - treatment_groups: List[Any], - ref_period: int, - balance_e: Optional[int], - original_att: float, - original_event_study: Optional[Dict[int, Dict[str, Any]]], - original_group: Optional[Dict[Any, Dict[str, Any]]], - aggregate: Optional[str], - ) -> Optional[TwoStageBootstrapResults]: - """Run multiplier bootstrap on GMM influence function.""" - if self.n_bootstrap < 50: - warnings.warn( - f"n_bootstrap={self.n_bootstrap} is low. Consider n_bootstrap >= 199 " - "for reliable inference.", - UserWarning, - stacklevel=3, - ) - - rng = np.random.default_rng(self.seed) - - from diff_diff.staggered_bootstrap import _generate_bootstrap_weights_batch - - y_tilde = df["_y_tilde"].values.copy() # .copy() to avoid mutating df column - n = len(df) - cluster_ids = df[cluster_var].values - - # Handle NaN y_tilde (from unidentified FEs) — matches _stage2_static logic - nan_mask = ~np.isfinite(y_tilde) - if nan_mask.any(): - y_tilde[nan_mask] = 0.0 - - # --- Static specification bootstrap --- - D = omega_1_mask.values.astype(float) # .astype() already creates a copy - D[nan_mask] = 0.0 # Exclude NaN y_tilde obs from bootstrap estimation - - # Degenerate case: all treated obs have NaN y_tilde - if D.sum() == 0: - return None - - X_2_static = D.reshape(-1, 1) - coef_static = solve_ols(X_2_static, y_tilde, return_vcov=False)[0] - eps_2_static = y_tilde - X_2_static @ coef_static - - S_static, bread_static, unique_clusters = self._compute_cluster_S_scores( - df=df, - unit=unit, - time=time, - covariates=covariates, - omega_0_mask=omega_0_mask, - unit_fe=unit_fe, - time_fe=time_fe, - delta_hat=delta_hat, - kept_cov_mask=kept_cov_mask, - X_2=X_2_static, - eps_2=eps_2_static, - cluster_ids=cluster_ids, - ) - - n_clusters = len(unique_clusters) - all_weights = _generate_bootstrap_weights_batch( - self.n_bootstrap, n_clusters, "rademacher", rng - ) - - # T_b = bread @ (sum_g w_bg * S_g) = bread @ (W @ S)' per boot - # IF_b = bread @ S_g for each cluster, then perturb - # boot_coef = all_weights @ S_static @ bread_static.T → (B, k) - # For static (k=1): boot_att = all_weights @ S_static @ bread_static.T - boot_att_vec = all_weights @ S_static # (B, 1) - boot_att_vec = boot_att_vec @ bread_static.T # (B, 1) - boot_overall = boot_att_vec[:, 0] - - boot_overall_shifted = boot_overall + original_att - overall_se = float(np.std(boot_overall, ddof=1)) - overall_ci = ( - self._compute_percentile_ci(boot_overall_shifted, self.alpha) - if overall_se > 0 - else (np.nan, np.nan) - ) - overall_p = ( - self._compute_bootstrap_pvalue(original_att, boot_overall_shifted) - if overall_se > 0 - else np.nan - ) - - # --- Event study bootstrap --- - event_study_ses = None - event_study_cis = None - event_study_p_values = None - - if original_event_study and aggregate in ("event_study", "all"): - # Recompute S scores for event study specification - rel_times = df["_rel_time"].values - treated_rel = rel_times[omega_1_mask.values] - all_horizons = sorted(set(int(h) for h in treated_rel if np.isfinite(h))) - if self.horizon_max is not None: - all_horizons = [h for h in all_horizons if abs(h) <= self.horizon_max] - - if balance_e is not None: - cohort_rel_times = self._build_cohort_rel_times(df, first_treat) - balanced_cohorts = set() - if all_horizons: - max_h = max(all_horizons) - required_range = set(range(-balance_e, max_h + 1)) - for g, horizons in cohort_rel_times.items(): - if required_range.issubset(horizons): - balanced_cohorts.add(g) - if not balanced_cohorts: - all_horizons = [] # No qualifying cohorts -> skip event study bootstrap - else: - balance_mask = df[first_treat].isin(balanced_cohorts).values - else: - balance_mask = np.ones(n, dtype=bool) - - est_horizons = [h for h in all_horizons if h != ref_period] - - # Filter out Prop 5 horizons (same logic as _stage2_event_study) - has_never_treated = df["_never_treated"].any() - h_bar_boot = np.inf - if not has_never_treated and len(treatment_groups) > 1: - h_bar_boot = max(treatment_groups) - min(treatment_groups) - if h_bar_boot < np.inf: - est_horizons = [h for h in est_horizons if h < h_bar_boot] - - if est_horizons: - horizon_to_col = {h: j for j, h in enumerate(est_horizons)} - k_es = len(est_horizons) - X_2_es = np.zeros((n, k_es)) - for i in range(n): - if not balance_mask[i]: - continue - if nan_mask[i]: - continue # NaN y_tilde -> exclude from bootstrap event study - h = rel_times[i] - if np.isfinite(h): - h_int = int(h) - if h_int in horizon_to_col: - X_2_es[i, horizon_to_col[h_int]] = 1.0 - - coef_es = solve_ols(X_2_es, y_tilde, return_vcov=False)[0] - eps_2_es = y_tilde - X_2_es @ coef_es - - S_es, bread_es, _ = self._compute_cluster_S_scores( - df=df, - unit=unit, - time=time, - covariates=covariates, - omega_0_mask=omega_0_mask, - unit_fe=unit_fe, - time_fe=time_fe, - delta_hat=delta_hat, - kept_cov_mask=kept_cov_mask, - X_2=X_2_es, - eps_2=eps_2_es, - cluster_ids=cluster_ids, - ) - - # boot_coef_es: (B, k_es) - boot_coef_es = (all_weights @ S_es) @ bread_es.T - - event_study_ses = {} - event_study_cis = {} - event_study_p_values = {} - for h in original_event_study: - if original_event_study[h].get("n_obs", 0) == 0: - continue - if np.isnan(original_event_study[h]["effect"]): - continue # Skip Prop 5 and other NaN-effect horizons - if h not in horizon_to_col: - continue - j = horizon_to_col[h] - orig_eff = original_event_study[h]["effect"] - boot_h = boot_coef_es[:, j] - se_h = float(np.std(boot_h, ddof=1)) - event_study_ses[h] = se_h - if se_h > 0 and np.isfinite(orig_eff): - shifted_h = boot_h + orig_eff - event_study_p_values[h] = self._compute_bootstrap_pvalue( - orig_eff, shifted_h - ) - event_study_cis[h] = self._compute_percentile_ci(shifted_h, self.alpha) - else: - event_study_p_values[h] = np.nan - event_study_cis[h] = (np.nan, np.nan) - - # --- Group bootstrap --- - group_ses = None - group_cis = None - group_p_values = None - - if original_group and aggregate in ("group", "all"): - group_to_col = {g: j for j, g in enumerate(treatment_groups)} - k_grp = len(treatment_groups) - X_2_grp = np.zeros((n, k_grp)) - ft_vals = df[first_treat].values - treated_mask = omega_1_mask.values - for i in range(n): - if treated_mask[i]: - if nan_mask[i]: - continue # NaN y_tilde -> exclude from group bootstrap - g = ft_vals[i] - if g in group_to_col: - X_2_grp[i, group_to_col[g]] = 1.0 - - coef_grp = solve_ols(X_2_grp, y_tilde, return_vcov=False)[0] - eps_2_grp = y_tilde - X_2_grp @ coef_grp - - S_grp, bread_grp, _ = self._compute_cluster_S_scores( - df=df, - unit=unit, - time=time, - covariates=covariates, - omega_0_mask=omega_0_mask, - unit_fe=unit_fe, - time_fe=time_fe, - delta_hat=delta_hat, - kept_cov_mask=kept_cov_mask, - X_2=X_2_grp, - eps_2=eps_2_grp, - cluster_ids=cluster_ids, - ) - - boot_coef_grp = (all_weights @ S_grp) @ bread_grp.T - - group_ses = {} - group_cis = {} - group_p_values = {} - for g in original_group: - if g not in group_to_col: - continue - j = group_to_col[g] - orig_eff = original_group[g]["effect"] - boot_g = boot_coef_grp[:, j] - se_g = float(np.std(boot_g, ddof=1)) - group_ses[g] = se_g - if se_g > 0 and np.isfinite(orig_eff): - shifted_g = boot_g + orig_eff - group_p_values[g] = self._compute_bootstrap_pvalue(orig_eff, shifted_g) - group_cis[g] = self._compute_percentile_ci(shifted_g, self.alpha) - else: - group_p_values[g] = np.nan - group_cis[g] = (np.nan, np.nan) - - return TwoStageBootstrapResults( - n_bootstrap=self.n_bootstrap, - weight_type="rademacher", - alpha=self.alpha, - overall_att_se=overall_se, - overall_att_ci=overall_ci, - overall_att_p_value=overall_p, - event_study_ses=event_study_ses, - event_study_cis=event_study_cis, - event_study_p_values=event_study_p_values, - group_ses=group_ses, - group_cis=group_cis, - group_p_values=group_p_values, - bootstrap_distribution=boot_overall_shifted, - ) - - # ========================================================================= - # Bootstrap helpers - # ========================================================================= - - @staticmethod - def _compute_percentile_ci( - boot_dist: np.ndarray, - alpha: float, - ) -> Tuple[float, float]: - """Compute percentile confidence interval from bootstrap distribution.""" - lower = float(np.percentile(boot_dist, alpha / 2 * 100)) - upper = float(np.percentile(boot_dist, (1 - alpha / 2) * 100)) - return (lower, upper) - - @staticmethod - def _compute_bootstrap_pvalue( - original_effect: float, - boot_dist: np.ndarray, - ) -> float: - """Compute two-sided bootstrap p-value.""" - if original_effect >= 0: - p_one_sided = float(np.mean(boot_dist <= 0)) - else: - p_one_sided = float(np.mean(boot_dist >= 0)) - p_value = min(2 * p_one_sided, 1.0) - p_value = max(p_value, 1 / (len(boot_dist) + 1)) - return p_value - - # ========================================================================= - # Utility - # ========================================================================= - - @staticmethod - def _build_cohort_rel_times( - df: pd.DataFrame, - first_treat: str, - ) -> Dict[Any, Set[int]]: - """Build mapping of cohort -> set of observed relative times.""" - treated_mask = ~df["_never_treated"] - treated_df = df.loc[treated_mask] - result: Dict[Any, Set[int]] = {} - ft_vals = treated_df[first_treat].values - rt_vals = treated_df["_rel_time"].values - for i in range(len(treated_df)): - h = rt_vals[i] - if np.isfinite(h): - result.setdefault(ft_vals[i], set()).add(int(h)) - return result - # ========================================================================= # sklearn-compatible interface # ========================================================================= diff --git a/diff_diff/two_stage_bootstrap.py b/diff_diff/two_stage_bootstrap.py new file mode 100644 index 00000000..f33a5501 --- /dev/null +++ b/diff_diff/two_stage_bootstrap.py @@ -0,0 +1,449 @@ +""" +Bootstrap inference methods for the Two-Stage DiD estimator. + +This module contains TwoStageDiDBootstrapMixin, which provides multiplier +bootstrap inference on the GMM influence function. Extracted from two_stage.py +for module size management. +""" + +import warnings +from typing import Any, Dict, List, Optional, Set, Tuple + +import numpy as np +import pandas as pd +from scipy.sparse.linalg import factorized as sparse_factorized + +from diff_diff.linalg import solve_ols +from diff_diff.staggered_bootstrap import _generate_bootstrap_weights_batch +from diff_diff.two_stage_results import TwoStageBootstrapResults + +__all__ = [ + "TwoStageDiDBootstrapMixin", +] + + +class TwoStageDiDBootstrapMixin: + """Mixin providing bootstrap inference methods for TwoStageDiD.""" + + def _compute_cluster_S_scores( + self, + df: pd.DataFrame, + unit: str, + time: str, + covariates: Optional[List[str]], + omega_0_mask: pd.Series, + unit_fe: Dict[Any, float], + time_fe: Dict[Any, float], + delta_hat: Optional[np.ndarray], + kept_cov_mask: Optional[np.ndarray], + X_2: np.ndarray, + eps_2: np.ndarray, + cluster_ids: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Compute per-cluster S_g scores for bootstrap. + + Returns + ------- + S : np.ndarray, shape (G, k) + Per-cluster influence scores. + bread : np.ndarray, shape (k, k) + (X'_2 X_2)^{-1}. + unique_clusters : np.ndarray + Unique cluster identifiers. + """ + n = len(df) + k = X_2.shape[1] + + cov_list = covariates + if covariates and kept_cov_mask is not None and not np.all(kept_cov_mask): + cov_list = [c for c, k_ in zip(covariates, kept_cov_mask) if k_] + + X_1_sparse, X_10_sparse, _, _ = self._build_fe_design( + df, unit, time, cov_list, omega_0_mask + ) + p = X_1_sparse.shape[1] + + # Reconstruct Y and compute eps_10 + alpha_i = df[unit].map(unit_fe).values + beta_t = df[time].map(time_fe).values + alpha_i = np.where(pd.isna(alpha_i), 0.0, alpha_i).astype(float) + beta_t = np.where(pd.isna(beta_t), 0.0, beta_t).astype(float) + fitted_1 = alpha_i + beta_t + if delta_hat is not None and cov_list: + if kept_cov_mask is not None and not np.all(kept_cov_mask): + fitted_1 = fitted_1 + np.dot(df[cov_list].values, delta_hat[kept_cov_mask]) + else: + fitted_1 = fitted_1 + np.dot(df[cov_list].values, delta_hat) + + y_tilde = df["_y_tilde"].values + y_vals = y_tilde + fitted_1 + + eps_10 = np.empty(n) + omega_0 = omega_0_mask.values + eps_10[omega_0] = y_vals[omega_0] - fitted_1[omega_0] + eps_10[~omega_0] = y_vals[~omega_0] + + # gamma_hat + XtX_10 = X_10_sparse.T @ X_10_sparse + Xt1_X2 = X_1_sparse.T @ X_2 + + try: + solve_XtX = sparse_factorized(XtX_10.tocsc()) + if Xt1_X2.ndim == 1: + gamma_hat = solve_XtX(Xt1_X2).reshape(-1, 1) + else: + gamma_hat = np.column_stack( + [solve_XtX(Xt1_X2[:, j]) for j in range(Xt1_X2.shape[1])] + ) + except RuntimeError: + gamma_hat = np.linalg.lstsq(XtX_10.toarray(), Xt1_X2, rcond=None)[0] + if gamma_hat.ndim == 1: + gamma_hat = gamma_hat.reshape(-1, 1) + + # Per-cluster aggregation + weighted_X10 = X_10_sparse.multiply(eps_10[:, None]) + unique_clusters, cluster_indices = np.unique(cluster_ids, return_inverse=True) + G = len(unique_clusters) + + weighted_X10_csc = weighted_X10.tocsc() + c_by_cluster = np.zeros((G, p)) + for j_col in range(p): + col_data = weighted_X10_csc.getcol(j_col).toarray().ravel() + np.add.at(c_by_cluster[:, j_col], cluster_indices, col_data) + + weighted_X2 = X_2 * eps_2[:, None] + s2_by_cluster = np.zeros((G, k)) + for j_col in range(k): + np.add.at(s2_by_cluster[:, j_col], cluster_indices, weighted_X2[:, j_col]) + + correction = np.dot(c_by_cluster, gamma_hat) + S = correction - s2_by_cluster + + # Bread + XtX_2 = np.dot(X_2.T, X_2) + try: + bread = np.linalg.solve(XtX_2, np.eye(k)) + except np.linalg.LinAlgError: + bread = np.linalg.lstsq(XtX_2, np.eye(k), rcond=None)[0] + + return S, bread, unique_clusters + + def _run_bootstrap( + self, + df: pd.DataFrame, + unit: str, + time: str, + first_treat: str, + covariates: Optional[List[str]], + omega_0_mask: pd.Series, + omega_1_mask: pd.Series, + unit_fe: Dict[Any, float], + time_fe: Dict[Any, float], + grand_mean: float, + delta_hat: Optional[np.ndarray], + cluster_var: str, + kept_cov_mask: Optional[np.ndarray], + treatment_groups: List[Any], + ref_period: int, + balance_e: Optional[int], + original_att: float, + original_event_study: Optional[Dict[int, Dict[str, Any]]], + original_group: Optional[Dict[Any, Dict[str, Any]]], + aggregate: Optional[str], + ) -> Optional[TwoStageBootstrapResults]: + """Run multiplier bootstrap on GMM influence function.""" + if self.n_bootstrap < 50: + warnings.warn( + f"n_bootstrap={self.n_bootstrap} is low. Consider n_bootstrap >= 199 " + "for reliable inference.", + UserWarning, + stacklevel=3, + ) + + rng = np.random.default_rng(self.seed) + + y_tilde = df["_y_tilde"].values.copy() # .copy() to avoid mutating df column + n = len(df) + cluster_ids = df[cluster_var].values + + # Handle NaN y_tilde (from unidentified FEs) — matches _stage2_static logic + nan_mask = ~np.isfinite(y_tilde) + if nan_mask.any(): + y_tilde[nan_mask] = 0.0 + + # --- Static specification bootstrap --- + D = omega_1_mask.values.astype(float) # .astype() already creates a copy + D[nan_mask] = 0.0 # Exclude NaN y_tilde obs from bootstrap estimation + + # Degenerate case: all treated obs have NaN y_tilde + if D.sum() == 0: + return None + + X_2_static = D.reshape(-1, 1) + coef_static = solve_ols(X_2_static, y_tilde, return_vcov=False)[0] + eps_2_static = y_tilde - np.dot(X_2_static, coef_static) + + S_static, bread_static, unique_clusters = self._compute_cluster_S_scores( + df=df, + unit=unit, + time=time, + covariates=covariates, + omega_0_mask=omega_0_mask, + unit_fe=unit_fe, + time_fe=time_fe, + delta_hat=delta_hat, + kept_cov_mask=kept_cov_mask, + X_2=X_2_static, + eps_2=eps_2_static, + cluster_ids=cluster_ids, + ) + + n_clusters = len(unique_clusters) + all_weights = _generate_bootstrap_weights_batch( + self.n_bootstrap, n_clusters, "rademacher", rng + ) + + # T_b = bread @ (sum_g w_bg * S_g) = bread @ (W @ S)' per boot + # IF_b = bread @ S_g for each cluster, then perturb + # boot_coef = all_weights @ S_static @ bread_static.T -> (B, k) + # For static (k=1): boot_att = all_weights @ S_static @ bread_static.T + boot_att_vec = np.dot(all_weights, S_static) # (B, 1) + boot_att_vec = np.dot(boot_att_vec, bread_static.T) # (B, 1) + boot_overall = boot_att_vec[:, 0] + + boot_overall_shifted = boot_overall + original_att + overall_se = float(np.std(boot_overall, ddof=1)) + overall_ci = ( + self._compute_percentile_ci(boot_overall_shifted, self.alpha) + if overall_se > 0 + else (np.nan, np.nan) + ) + overall_p = ( + self._compute_bootstrap_pvalue(original_att, boot_overall_shifted) + if overall_se > 0 + else np.nan + ) + + # --- Event study bootstrap --- + event_study_ses = None + event_study_cis = None + event_study_p_values = None + + if original_event_study and aggregate in ("event_study", "all"): + # Recompute S scores for event study specification + rel_times = df["_rel_time"].values + treated_rel = rel_times[omega_1_mask.values] + all_horizons = sorted(set(int(h) for h in treated_rel if np.isfinite(h))) + if self.horizon_max is not None: + all_horizons = [h for h in all_horizons if abs(h) <= self.horizon_max] + + if balance_e is not None: + cohort_rel_times = self._build_cohort_rel_times(df, first_treat) + balanced_cohorts = set() + if all_horizons: + max_h = max(all_horizons) + required_range = set(range(-balance_e, max_h + 1)) + for g, horizons in cohort_rel_times.items(): + if required_range.issubset(horizons): + balanced_cohorts.add(g) + if not balanced_cohorts: + all_horizons = [] # No qualifying cohorts -> skip event study bootstrap + else: + balance_mask = df[first_treat].isin(balanced_cohorts).values + else: + balance_mask = np.ones(n, dtype=bool) + + est_horizons = [h for h in all_horizons if h != ref_period] + + # Filter out Prop 5 horizons (same logic as _stage2_event_study) + has_never_treated = df["_never_treated"].any() + h_bar_boot = np.inf + if not has_never_treated and len(treatment_groups) > 1: + h_bar_boot = max(treatment_groups) - min(treatment_groups) + if h_bar_boot < np.inf: + est_horizons = [h for h in est_horizons if h < h_bar_boot] + + if est_horizons: + horizon_to_col = {h: j for j, h in enumerate(est_horizons)} + k_es = len(est_horizons) + X_2_es = np.zeros((n, k_es)) + for i in range(n): + if not balance_mask[i]: + continue + if nan_mask[i]: + continue # NaN y_tilde -> exclude from bootstrap event study + h = rel_times[i] + if np.isfinite(h): + h_int = int(h) + if h_int in horizon_to_col: + X_2_es[i, horizon_to_col[h_int]] = 1.0 + + coef_es = solve_ols(X_2_es, y_tilde, return_vcov=False)[0] + eps_2_es = y_tilde - np.dot(X_2_es, coef_es) + + S_es, bread_es, _ = self._compute_cluster_S_scores( + df=df, + unit=unit, + time=time, + covariates=covariates, + omega_0_mask=omega_0_mask, + unit_fe=unit_fe, + time_fe=time_fe, + delta_hat=delta_hat, + kept_cov_mask=kept_cov_mask, + X_2=X_2_es, + eps_2=eps_2_es, + cluster_ids=cluster_ids, + ) + + # boot_coef_es: (B, k_es) + boot_coef_es = np.dot(np.dot(all_weights, S_es), bread_es.T) + + event_study_ses = {} + event_study_cis = {} + event_study_p_values = {} + for h in original_event_study: + if original_event_study[h].get("n_obs", 0) == 0: + continue + if np.isnan(original_event_study[h]["effect"]): + continue # Skip Prop 5 and other NaN-effect horizons + if h not in horizon_to_col: + continue + j = horizon_to_col[h] + orig_eff = original_event_study[h]["effect"] + boot_h = boot_coef_es[:, j] + se_h = float(np.std(boot_h, ddof=1)) + event_study_ses[h] = se_h + if se_h > 0 and np.isfinite(orig_eff): + shifted_h = boot_h + orig_eff + event_study_p_values[h] = self._compute_bootstrap_pvalue( + orig_eff, shifted_h + ) + event_study_cis[h] = self._compute_percentile_ci(shifted_h, self.alpha) + else: + event_study_p_values[h] = np.nan + event_study_cis[h] = (np.nan, np.nan) + + # --- Group bootstrap --- + group_ses = None + group_cis = None + group_p_values = None + + if original_group and aggregate in ("group", "all"): + group_to_col = {g: j for j, g in enumerate(treatment_groups)} + k_grp = len(treatment_groups) + X_2_grp = np.zeros((n, k_grp)) + ft_vals = df[first_treat].values + treated_mask = omega_1_mask.values + for i in range(n): + if treated_mask[i]: + if nan_mask[i]: + continue # NaN y_tilde -> exclude from group bootstrap + g = ft_vals[i] + if g in group_to_col: + X_2_grp[i, group_to_col[g]] = 1.0 + + coef_grp = solve_ols(X_2_grp, y_tilde, return_vcov=False)[0] + eps_2_grp = y_tilde - np.dot(X_2_grp, coef_grp) + + S_grp, bread_grp, _ = self._compute_cluster_S_scores( + df=df, + unit=unit, + time=time, + covariates=covariates, + omega_0_mask=omega_0_mask, + unit_fe=unit_fe, + time_fe=time_fe, + delta_hat=delta_hat, + kept_cov_mask=kept_cov_mask, + X_2=X_2_grp, + eps_2=eps_2_grp, + cluster_ids=cluster_ids, + ) + + boot_coef_grp = np.dot(np.dot(all_weights, S_grp), bread_grp.T) + + group_ses = {} + group_cis = {} + group_p_values = {} + for g in original_group: + if g not in group_to_col: + continue + j = group_to_col[g] + orig_eff = original_group[g]["effect"] + boot_g = boot_coef_grp[:, j] + se_g = float(np.std(boot_g, ddof=1)) + group_ses[g] = se_g + if se_g > 0 and np.isfinite(orig_eff): + shifted_g = boot_g + orig_eff + group_p_values[g] = self._compute_bootstrap_pvalue(orig_eff, shifted_g) + group_cis[g] = self._compute_percentile_ci(shifted_g, self.alpha) + else: + group_p_values[g] = np.nan + group_cis[g] = (np.nan, np.nan) + + return TwoStageBootstrapResults( + n_bootstrap=self.n_bootstrap, + weight_type="rademacher", + alpha=self.alpha, + overall_att_se=overall_se, + overall_att_ci=overall_ci, + overall_att_p_value=overall_p, + event_study_ses=event_study_ses, + event_study_cis=event_study_cis, + event_study_p_values=event_study_p_values, + group_ses=group_ses, + group_cis=group_cis, + group_p_values=group_p_values, + bootstrap_distribution=boot_overall_shifted, + ) + + # ========================================================================= + # Bootstrap helpers + # ========================================================================= + + @staticmethod + def _compute_percentile_ci( + boot_dist: np.ndarray, + alpha: float, + ) -> Tuple[float, float]: + """Compute percentile confidence interval from bootstrap distribution.""" + lower = float(np.percentile(boot_dist, alpha / 2 * 100)) + upper = float(np.percentile(boot_dist, (1 - alpha / 2) * 100)) + return (lower, upper) + + @staticmethod + def _compute_bootstrap_pvalue( + original_effect: float, + boot_dist: np.ndarray, + ) -> float: + """Compute two-sided bootstrap p-value.""" + if original_effect >= 0: + p_one_sided = float(np.mean(boot_dist <= 0)) + else: + p_one_sided = float(np.mean(boot_dist >= 0)) + p_value = min(2 * p_one_sided, 1.0) + p_value = max(p_value, 1 / (len(boot_dist) + 1)) + return p_value + + # ========================================================================= + # Utility + # ========================================================================= + + @staticmethod + def _build_cohort_rel_times( + df: pd.DataFrame, + first_treat: str, + ) -> Dict[Any, Set[int]]: + """Build mapping of cohort -> set of observed relative times.""" + treated_mask = ~df["_never_treated"] + treated_df = df.loc[treated_mask] + result: Dict[Any, Set[int]] = {} + ft_vals = treated_df[first_treat].values + rt_vals = treated_df["_rel_time"].values + for i in range(len(treated_df)): + h = rt_vals[i] + if np.isfinite(h): + result.setdefault(ft_vals[i], set()).add(int(h)) + return result diff --git a/diff_diff/two_stage_results.py b/diff_diff/two_stage_results.py new file mode 100644 index 00000000..a1277726 --- /dev/null +++ b/diff_diff/two_stage_results.py @@ -0,0 +1,379 @@ +""" +Result containers for the Two-Stage DiD estimator. + +This module contains TwoStageBootstrapResults and TwoStageDiDResults +dataclasses. Extracted from two_stage.py for module size management. +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd + +from diff_diff.results import _get_significance_stars + +__all__ = [ + "TwoStageBootstrapResults", + "TwoStageDiDResults", +] + + +@dataclass +class TwoStageBootstrapResults: + """ + Results from TwoStageDiD bootstrap inference. + + Bootstrap uses multiplier bootstrap on the GMM influence function, + consistent with other library estimators. The R `did2s` package uses + block bootstrap by default; multiplier bootstrap is asymptotically + equivalent. + + Attributes + ---------- + n_bootstrap : int + Number of bootstrap iterations. + weight_type : str + Type of bootstrap weights (currently "rademacher" only). + alpha : float + Significance level used for confidence intervals. + overall_att_se : float + Bootstrap standard error for overall ATT. + overall_att_ci : tuple + Bootstrap confidence interval for overall ATT. + overall_att_p_value : float + Bootstrap p-value for overall ATT. + event_study_ses : dict, optional + Bootstrap SEs for event study effects. + event_study_cis : dict, optional + Bootstrap CIs for event study effects. + event_study_p_values : dict, optional + Bootstrap p-values for event study effects. + group_ses : dict, optional + Bootstrap SEs for group effects. + group_cis : dict, optional + Bootstrap CIs for group effects. + group_p_values : dict, optional + Bootstrap p-values for group effects. + bootstrap_distribution : np.ndarray, optional + Full bootstrap distribution of overall ATT. + """ + + n_bootstrap: int + weight_type: str + alpha: float + overall_att_se: float + overall_att_ci: Tuple[float, float] + overall_att_p_value: float + event_study_ses: Optional[Dict[int, float]] = None + event_study_cis: Optional[Dict[int, Tuple[float, float]]] = None + event_study_p_values: Optional[Dict[int, float]] = None + group_ses: Optional[Dict[Any, float]] = None + group_cis: Optional[Dict[Any, Tuple[float, float]]] = None + group_p_values: Optional[Dict[Any, float]] = None + bootstrap_distribution: Optional[np.ndarray] = field(default=None, repr=False) + + +@dataclass +class TwoStageDiDResults: + """ + Results from Gardner (2022) two-stage DiD estimation. + + Attributes + ---------- + treatment_effects : pd.DataFrame + Per-observation treatment effects with columns: unit, time, + tau_hat, weight. tau_hat is the residualized outcome y_tilde + for treated observations; weight is 1/n_treated. + overall_att : float + Overall average treatment effect on the treated. + overall_se : float + Standard error of overall ATT (GMM sandwich). + overall_t_stat : float + T-statistic for overall ATT. + overall_p_value : float + P-value for overall ATT. + overall_conf_int : tuple + Confidence interval for overall ATT. + event_study_effects : dict, optional + Dictionary mapping relative time h to effect dict with keys: + 'effect', 'se', 't_stat', 'p_value', 'conf_int', 'n_obs'. + group_effects : dict, optional + Dictionary mapping cohort g to effect dict. + groups : list + List of treatment cohorts. + time_periods : list + List of all time periods. + n_obs : int + Total number of observations. + n_treated_obs : int + Number of treated observations. + n_untreated_obs : int + Number of untreated observations. + n_treated_units : int + Number of ever-treated units. + n_control_units : int + Number of units contributing to untreated observations. + alpha : float + Significance level used. + bootstrap_results : TwoStageBootstrapResults, optional + Bootstrap inference results. + """ + + treatment_effects: pd.DataFrame + overall_att: float + overall_se: float + overall_t_stat: float + overall_p_value: float + overall_conf_int: Tuple[float, float] + event_study_effects: Optional[Dict[int, Dict[str, Any]]] + group_effects: Optional[Dict[Any, Dict[str, Any]]] + groups: List[Any] + time_periods: List[Any] + n_obs: int + n_treated_obs: int + n_untreated_obs: int + n_treated_units: int + n_control_units: int + alpha: float = 0.05 + bootstrap_results: Optional[TwoStageBootstrapResults] = field(default=None, repr=False) + + def __repr__(self) -> str: + """Concise string representation.""" + sig = _get_significance_stars(self.overall_p_value) + return ( + f"TwoStageDiDResults(ATT={self.overall_att:.4f}{sig}, " + f"SE={self.overall_se:.4f}, " + f"n_groups={len(self.groups)}, " + f"n_treated_obs={self.n_treated_obs})" + ) + + def summary(self, alpha: Optional[float] = None) -> str: + """ + Generate formatted summary of estimation results. + + Parameters + ---------- + alpha : float, optional + Significance level. Defaults to alpha used in estimation. + + Returns + ------- + str + Formatted summary. + """ + alpha = alpha or self.alpha + conf_level = int((1 - alpha) * 100) + + lines = [ + "=" * 85, + "Two-Stage DiD Estimator Results (Gardner 2022)".center(85), + "=" * 85, + "", + f"{'Total observations:':<30} {self.n_obs:>10}", + f"{'Treated observations:':<30} {self.n_treated_obs:>10}", + f"{'Untreated observations:':<30} {self.n_untreated_obs:>10}", + f"{'Treated units:':<30} {self.n_treated_units:>10}", + f"{'Control units:':<30} {self.n_control_units:>10}", + f"{'Treatment cohorts:':<30} {len(self.groups):>10}", + f"{'Time periods:':<30} {len(self.time_periods):>10}", + "", + ] + + # Overall ATT + lines.extend( + [ + "-" * 85, + "Overall Average Treatment Effect on the Treated".center(85), + "-" * 85, + f"{'Parameter':<15} {'Estimate':>12} {'Std. Err.':>12} " + f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 85, + ] + ) + + t_str = ( + f"{self.overall_t_stat:>10.3f}" if np.isfinite(self.overall_t_stat) else f"{'NaN':>10}" + ) + p_str = ( + f"{self.overall_p_value:>10.4f}" + if np.isfinite(self.overall_p_value) + else f"{'NaN':>10}" + ) + sig = _get_significance_stars(self.overall_p_value) + + lines.extend( + [ + f"{'ATT':<15} {self.overall_att:>12.4f} {self.overall_se:>12.4f} " + f"{t_str} {p_str} {sig:>6}", + "-" * 85, + "", + f"{conf_level}% Confidence Interval: " + f"[{self.overall_conf_int[0]:.4f}, {self.overall_conf_int[1]:.4f}]", + "", + ] + ) + + # Event study effects + if self.event_study_effects: + lines.extend( + [ + "-" * 85, + "Event Study (Dynamic) Effects".center(85), + "-" * 85, + f"{'Rel. Period':<15} {'Estimate':>12} {'Std. Err.':>12} " + f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 85, + ] + ) + + for h in sorted(self.event_study_effects.keys()): + eff = self.event_study_effects[h] + if eff.get("n_obs", 1) == 0: + # Reference period marker + lines.append( + f"[ref: {h}]" f"{'0.0000':>17} {'---':>12} {'---':>10} {'---':>10} {'':>6}" + ) + elif np.isnan(eff["effect"]): + lines.append(f"{h:<15} {'NaN':>12} {'NaN':>12} {'NaN':>10} {'NaN':>10} {'':>6}") + else: + e_sig = _get_significance_stars(eff["p_value"]) + e_t = ( + f"{eff['t_stat']:>10.3f}" if np.isfinite(eff["t_stat"]) else f"{'NaN':>10}" + ) + e_p = ( + f"{eff['p_value']:>10.4f}" + if np.isfinite(eff["p_value"]) + else f"{'NaN':>10}" + ) + lines.append( + f"{h:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} " + f"{e_t} {e_p} {e_sig:>6}" + ) + + lines.extend(["-" * 85, ""]) + + # Group effects + if self.group_effects: + lines.extend( + [ + "-" * 85, + "Group (Cohort) Effects".center(85), + "-" * 85, + f"{'Cohort':<15} {'Estimate':>12} {'Std. Err.':>12} " + f"{'t-stat':>10} {'P>|t|':>10} {'Sig.':>6}", + "-" * 85, + ] + ) + + for g in sorted(self.group_effects.keys()): + eff = self.group_effects[g] + if np.isnan(eff["effect"]): + lines.append(f"{g:<15} {'NaN':>12} {'NaN':>12} {'NaN':>10} {'NaN':>10} {'':>6}") + else: + g_sig = _get_significance_stars(eff["p_value"]) + g_t = ( + f"{eff['t_stat']:>10.3f}" if np.isfinite(eff["t_stat"]) else f"{'NaN':>10}" + ) + g_p = ( + f"{eff['p_value']:>10.4f}" + if np.isfinite(eff["p_value"]) + else f"{'NaN':>10}" + ) + lines.append( + f"{g:<15} {eff['effect']:>12.4f} {eff['se']:>12.4f} " + f"{g_t} {g_p} {g_sig:>6}" + ) + + lines.extend(["-" * 85, ""]) + + lines.extend( + [ + "Signif. codes: '***' 0.001, '**' 0.01, '*' 0.05, '.' 0.1", + "=" * 85, + ] + ) + + return "\n".join(lines) + + def print_summary(self, alpha: Optional[float] = None) -> None: + """Print summary to stdout.""" + print(self.summary(alpha)) + + def to_dataframe(self, level: str = "event_study") -> pd.DataFrame: + """ + Convert results to DataFrame. + + Parameters + ---------- + level : str, default="event_study" + Level of aggregation: + - "event_study": Event study effects by relative time + - "group": Group (cohort) effects + - "observation": Per-observation treatment effects + + Returns + ------- + pd.DataFrame + Results as DataFrame. + """ + if level == "observation": + return self.treatment_effects.copy() + + elif level == "event_study": + if self.event_study_effects is None: + raise ValueError( + "Event study effects not computed. " + "Use aggregate='event_study' or aggregate='all'." + ) + rows = [] + for h, data in sorted(self.event_study_effects.items()): + rows.append( + { + "relative_period": h, + "effect": data["effect"], + "se": data["se"], + "t_stat": data["t_stat"], + "p_value": data["p_value"], + "conf_int_lower": data["conf_int"][0], + "conf_int_upper": data["conf_int"][1], + "n_obs": data.get("n_obs", np.nan), + } + ) + return pd.DataFrame(rows) + + elif level == "group": + if self.group_effects is None: + raise ValueError( + "Group effects not computed. " "Use aggregate='group' or aggregate='all'." + ) + rows = [] + for g, data in sorted(self.group_effects.items()): + rows.append( + { + "group": g, + "effect": data["effect"], + "se": data["se"], + "t_stat": data["t_stat"], + "p_value": data["p_value"], + "conf_int_lower": data["conf_int"][0], + "conf_int_upper": data["conf_int"][1], + "n_obs": data.get("n_obs", np.nan), + } + ) + return pd.DataFrame(rows) + + else: + raise ValueError( + f"Unknown level: {level}. Use 'event_study', 'group', or 'observation'." + ) + + @property + def is_significant(self) -> bool: + """Check if overall ATT is significant.""" + return bool(self.overall_p_value < self.alpha) + + @property + def significance_stars(self) -> str: + """Significance stars for overall ATT.""" + return _get_significance_stars(self.overall_p_value) diff --git a/diff_diff/utils.py b/diff_diff/utils.py index 4d551685..712af691 100644 --- a/diff_diff/utils.py +++ b/diff_diff/utils.py @@ -512,7 +512,7 @@ def wild_bootstrap_se( obs_weights[indices] = cluster_weights[g] # Construct bootstrap sample: y* = X @ beta_restricted + e_restricted * weights - y_star = X @ beta_restricted + residuals_restricted * obs_weights + y_star = np.dot(X, beta_restricted) + residuals_restricted * obs_weights # Estimate bootstrap coefficients with cluster-robust SE beta_star, residuals_star, vcov_star = _solve_ols_linalg( @@ -638,8 +638,7 @@ def compute_trend(group_data: pd.DataFrame) -> Tuple[float, float]: # Test for difference in trends slope_diff = treated_slope - control_slope se_diff = np.sqrt(treated_se ** 2 + control_se ** 2) - t_stat = slope_diff / se_diff if se_diff > 0 else np.nan - p_value = compute_p_value(t_stat) if not np.isnan(t_stat) else np.nan + t_stat, p_value, _ = safe_inference(slope_diff, se_diff) return { "treated_trend": treated_slope, diff --git a/docs/methodology/REGISTRY.md b/docs/methodology/REGISTRY.md index 49785f41..08116c96 100644 --- a/docs/methodology/REGISTRY.md +++ b/docs/methodology/REGISTRY.md @@ -942,6 +942,7 @@ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]² - Wrong D specification: if user provides event-style D (only first treatment period), the absorbing-state validation will raise ValueError with helpful guidance - **LOOCV failure metadata**: When LOOCV fits fail in the Rust backend, the first failed observation coordinates (t, i) are returned to Python for informative warning messages +- **Inference CI distribution**: After `safe_inference()` migration, CI uses t-distribution (df = max(1, n_treated_obs - 1)), consistent with p_value. Previously CI used normal-distribution while p_value used t-distribution (inconsistent). This is a minor behavioral change; CIs may be slightly wider for small n_treated_obs. **Reference implementation(s):** - Authors' replication code (forthcoming) diff --git a/tests/test_diagnostics.py b/tests/test_diagnostics.py index 2bd34469..d019c1e6 100644 --- a/tests/test_diagnostics.py +++ b/tests/test_diagnostics.py @@ -549,6 +549,45 @@ def test_loo_summary_shows_stats(self, simple_panel_data): assert "Units analyzed" in summary assert "Mean effect" in summary + def test_loo_single_valid_effect_nan_inference(self): + """SE should be NaN when only 1 valid LOO effect (len(valid_effects) <= 1).""" + from tests.conftest import assert_nan_inference + + # Unit 0: treated, both periods (the only unit providing a valid LOO effect) + # Unit 1: treated, pre-period only → removing unit 0 leaves unit 1 with + # no post-treatment data, making treated*post unidentified → NaN ATT + # Controls: units 2-5 with both periods + data = [] + data.append({"unit": 0, "post": 0, "outcome": 1.0, "treated": 1}) + data.append({"unit": 0, "post": 1, "outcome": 4.0, "treated": 1}) + data.append({"unit": 1, "post": 0, "outcome": 1.5, "treated": 1}) + for u in [2, 3, 4, 5]: + data.append({"unit": u, "post": 0, "outcome": 1.0, "treated": 0}) + data.append({"unit": u, "post": 1, "outcome": 2.0, "treated": 0}) + + df = pd.DataFrame(data) + + import warnings + + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + result = leave_one_out_test( + df, + outcome="outcome", + treatment="treated", + time="post", + unit="unit", + ) + + # Only 1 valid LOO effect → SE should be NaN (not 0.0) + assert np.isnan(result.se), f"SE should be NaN with 1 valid effect, got {result.se}" + assert_nan_inference({ + "se": result.se, + "t_stat": result.t_stat, + "p_value": result.p_value, + "conf_int": result.conf_int, + }) + # ============================================================================= # run_placebo_test dispatcher diff --git a/tests/test_linalg.py b/tests/test_linalg.py index ac519d97..07efd46c 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -1508,3 +1508,29 @@ def test_sun_abraham_estimator_produces_valid_results(self): assert result.overall_se > 0 assert np.isfinite(result.overall_att) assert len(result.event_study_effects) > 0 + + +class TestNoDotRuntimeWarnings: + """Verify np.dot replacement avoids Apple M4 BLAS ufunc FPE bug.""" + + def test_solve_ols_no_runtime_warnings(self): + """No RuntimeWarnings from solve_ols with n >= 500.""" + import warnings + + rng = np.random.default_rng(42) + n = 500 + k = 5 + X = rng.standard_normal((n, k)) + beta_true = rng.standard_normal(k) + y = np.dot(X, beta_true) + rng.standard_normal(n) * 0.1 + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + coefficients, residuals, vcov = solve_ols(X, y, return_vcov=True) + + runtime_warnings = [x for x in w if issubclass(x.category, RuntimeWarning)] + assert len(runtime_warnings) == 0, ( + f"Got {len(runtime_warnings)} RuntimeWarning(s): " + f"{[str(x.message) for x in runtime_warnings]}" + ) + assert np.allclose(coefficients, beta_true, atol=0.1) diff --git a/tests/test_staggered.py b/tests/test_staggered.py index 79651454..1b8b8bb7 100644 --- a/tests/test_staggered.py +++ b/tests/test_staggered.py @@ -2986,3 +2986,41 @@ def test_event_study_universal_no_effects_raises_error(self): first_treat='first_treat', aggregate='event_study' ) + + +class TestCallawaySantAnnaCIBugFix: + """Regression test: safe_inference fixes CI computed with NaN SE.""" + + def test_nan_se_group_time_ci_is_nan(self): + """conf_int should be (NaN, NaN) when SE is NaN, not finite values.""" + from tests.conftest import assert_nan_inference + + # Generate data with very few units to produce NaN-SE group-time effects + # (small sample → degenerate groups) + data = generate_staggered_data( + n_units=20, n_periods=6, n_cohorts=3, never_treated_frac=0.1, seed=123 + ) + + import warnings + + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + cs = CallawaySantAnna() + results = cs.fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + ) + + # Check all group-time effects: if SE is NaN, CI must also be NaN + for (g, t), eff in results.group_time_effects.items(): + se = eff["se"] + if not (np.isfinite(se) and se > 0): + assert_nan_inference({ + "se": se, + "t_stat": eff["t_stat"], + "p_value": eff["p_value"], + "conf_int": eff["conf_int"], + }) diff --git a/tests/test_two_stage.py b/tests/test_two_stage.py index be0108a4..5450992e 100644 --- a/tests/test_two_stage.py +++ b/tests/test_two_stage.py @@ -642,14 +642,34 @@ def test_rank_deficiency_error(self): def test_nan_propagation(self): """NaN SE should propagate to t_stat, p_value, conf_int.""" - data = generate_test_data() - results = TwoStageDiD().fit( - data, outcome="outcome", unit="unit", time="time", first_treat="first_treat" - ) + from tests.conftest import assert_nan_inference + + # Use never_treated_frac=0.0 to trigger Proposition 5 NaN horizons + data = generate_test_data(never_treated_frac=0.0) + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + results = TwoStageDiD().fit( + data, + outcome="outcome", + unit="unit", + time="time", + first_treat="first_treat", + aggregate="event_study", + ) - # For a reference period, t_stat and p_value should be NaN - if results.event_study_effects: - pass # Only check if event study was computed + # Proposition 5 horizons should have NaN inference fields + assert results.event_study_effects, "Event study should be computed" + nan_horizons_found = 0 + for h, eff in results.event_study_effects.items(): + if np.isnan(eff["effect"]): + nan_horizons_found += 1 + assert_nan_inference({ + "se": eff["se"], + "t_stat": eff["t_stat"], + "p_value": eff["p_value"], + "conf_int": eff["conf_int"], + }) + assert nan_horizons_found > 0, "Should have at least one Prop 5 NaN horizon" # Normal results should have finite values assert np.isfinite(results.overall_t_stat)