
Commit 2f1d5df

igerber and claude committed

Address tech debt from code reviews: bootstrap_weights, dedup, removals

- Add `bootstrap_weights` parameter to TwoStageDiD and ImputationDiD (rademacher/mammen/webb, matching the CallawaySantAnna API)
- Unify TwoStageDiD GMM score computation via a `_compute_gmm_scores()` static method with consistent NaN/overflow handling
- Extract a `_compute_target_weights()` helper for ImputationDiD weight construction, eliminating aggregation/bootstrap duplication
- Optimize the TwoStageDiD cluster score loop: a single `.toarray()` call replaces per-column `.getcol(j).toarray()`
- Add TROP `n_bootstrap >= 2` validation (raises `ValueError`)
- Remove SunAbraham's deprecated `min_pre_periods`/`min_post_periods` params
- Remove legacy `compute_placebo_effects` from `utils.py`
- Add ImputationDiD bootstrap + covariate test
- Update TODO.md to mark completed items

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent edbb5ca commit 2f1d5df
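The cluster-score loop optimization summarized in the commit message, one dense conversion instead of per-column densification, can be sketched on toy data; the names and shapes below are illustrative stand-ins, not the library's actual code:

```python
import numpy as np
from scipy import sparse

rng = np.random.default_rng(0)

# Toy stand-ins: a sparse score matrix (rows = observations,
# cols = moment conditions) and a cluster id per observation.
n_obs, n_params, n_clusters = 200, 5, 10
scores = sparse.random(n_obs, n_params, density=0.1, format="csc", random_state=0)
clusters = rng.integers(0, n_clusters, size=n_obs)

# Old pattern: densify one column per loop iteration.
slow = np.zeros((n_clusters, n_params))
for j in range(n_params):
    col = scores[:, [j]].toarray().ravel()  # per-column dense conversion
    for g in range(n_clusters):
        slow[g, j] = col[clusters == g].sum()

# New pattern: densify once, then sum rows within each cluster.
dense = scores.toarray()  # single .toarray() call
fast = np.zeros((n_clusters, n_params))
for g in range(n_clusters):
    fast[g] = dense[clusters == g].sum(axis=0)

assert np.allclose(slow, fast)
```

Both passes compute the same per-cluster score sums; the second avoids repeated sparse-to-dense conversions inside the loop, at the cost of materializing the dense matrix once.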

File tree

13 files changed: +209 / -312 lines

TODO.md

Lines changed: 11 additions & 11 deletions
@@ -28,7 +28,7 @@ Target: < 1000 lines per module for maintainability.
 | ~~`trop.py`~~ | ~~2904~~ ~2560 | ✅ Partially split: results extracted to `trop_results.py` (~340 lines) |
 | ~~`imputation.py`~~ | ~~2480~~ ~1740 | ✅ Split into imputation.py, imputation_results.py, imputation_bootstrap.py |
 | ~~`two_stage.py`~~ | ~~2209~~ ~1490 | ✅ Split into two_stage.py, two_stage_results.py, two_stage_bootstrap.py |
-| `utils.py` | 1879 | Monitor -- legacy placebo functions stay to avoid circular imports |
+| `utils.py` | 1780 | Monitor -- legacy placebo function removed |
 | `visualization.py` | 1678 | Monitor -- growing but cohesive |
 | `linalg.py` | 1537 | Monitor -- unified backend, splitting would hurt cohesion |
 | `honest_did.py` | 1511 | Acceptable |
@@ -58,18 +58,18 @@ Deferred items from PR reviews that were not addressed before merge.
 
 | Issue | Location | PR | Priority |
 |-------|----------|----|----------|
-| TwoStageDiD & ImputationDiD bootstrap hardcodes Rademacher only; no `bootstrap_weights` parameter unlike CallawaySantAnna | `two_stage_bootstrap.py`, `imputation_bootstrap.py` | #156, #141 | Medium |
-| TwoStageDiD GMM score logic duplicated between analytic/bootstrap with inconsistent NaN/overflow handling | `two_stage.py`, `two_stage_bootstrap.py` | #156 | Medium |
-| ImputationDiD weight construction duplicated between aggregation and bootstrap (drift risk) -- has explicit code comment acknowledging duplication | `imputation.py`, `imputation_bootstrap.py` | #141 | Medium |
-| ImputationDiD dense `(A0'A0).toarray()` scales O((U+T+K)^2), OOM risk on large panels | `imputation.py` | #141 | Medium |
+| ~~TwoStageDiD & ImputationDiD bootstrap hardcodes Rademacher only; no `bootstrap_weights` parameter unlike CallawaySantAnna~~ | ~~`two_stage_bootstrap.py`, `imputation_bootstrap.py`~~ | ~~#156, #141~~ | ✅ Fixed: Added `bootstrap_weights` parameter to both estimators |
+| ~~TwoStageDiD GMM score logic duplicated between analytic/bootstrap with inconsistent NaN/overflow handling~~ | ~~`two_stage.py`, `two_stage_bootstrap.py`~~ | ~~#156~~ | ✅ Fixed: Unified via `_compute_gmm_scores()` static method |
+| ~~ImputationDiD weight construction duplicated between aggregation and bootstrap (drift risk)~~ | ~~`imputation.py`, `imputation_bootstrap.py`~~ | ~~#141~~ | ✅ Fixed: Extracted `_compute_target_weights()` helper in `imputation_bootstrap.py` |
+| ImputationDiD dense `(A0'A0).toarray()` scales O((U+T+K)^2), OOM risk on large panels | `imputation.py` | #141 | Medium (deferred — only triggers when sparse solver fails; fixing requires sparse least-squares alternatives) |
 
 #### Performance
 
 | Issue | Location | PR | Priority |
 |-------|----------|----|----------|
-| TwoStageDiD per-column `.toarray()` in loop for cluster scores | `two_stage_bootstrap.py` | #156 | Medium |
+| ~~TwoStageDiD per-column `.toarray()` in loop for cluster scores~~ | ~~`two_stage_bootstrap.py`~~ | ~~#156~~ | ✅ Fixed: Single `.toarray()` call replaces per-column loop |
 | ImputationDiD event-study SEs recompute full conservative variance per horizon (should cache A0/A1 factorization) | `imputation.py` | #141 | Low |
-| Legacy `compute_placebo_effects` uses deprecated projected-gradient weights (marked deprecated, users directed to `SyntheticDiD`) | `utils.py:1689-1691` | #145 | Low |
+| ~~Legacy `compute_placebo_effects` uses deprecated projected-gradient weights~~ | ~~`utils.py:1689-1691`~~ | ~~#145~~ | ✅ Fixed: Removed function entirely |
 | Rust faer SVD ndarray-to-faer conversion overhead (minimal vs SVD cost) | `rust/src/linalg.rs:67` | #115 | Low |
 
 #### Testing/Docs
@@ -78,11 +78,11 @@ Deferred items from PR reviews that were not addressed before merge.
 |-------|----------|----|----------|
 | Tutorial notebooks not executed in CI | `docs/tutorials/*.ipynb` | #159 | Low |
 | ~~TwoStageDiD `test_nan_propagation` is a no-op~~ | ~~`tests/test_two_stage.py:643-652`~~ | ~~#156~~ | ✅ Fixed |
-| ImputationDiD bootstrap + covariate path untested | `tests/test_imputation.py` | #141 | Low |
-| TROP `n_bootstrap >= 2` validation missing (can yield 0/NaN SE silently) | `trop.py:462` | #124 | Low |
-| SunAbraham deprecated `min_pre_periods`/`min_post_periods` still in `fit()` docstring | `sun_abraham.py:458-487` | #153 | Low |
+| ~~ImputationDiD bootstrap + covariate path untested~~ | ~~`tests/test_imputation.py`~~ | ~~#141~~ | ✅ Fixed: Added `test_bootstrap_with_covariates` |
+| ~~TROP `n_bootstrap >= 2` validation missing (can yield 0/NaN SE silently)~~ | ~~`trop.py:462`~~ | ~~#124~~ | ✅ Fixed: Added `ValueError` for `n_bootstrap < 2` |
+| ~~SunAbraham deprecated `min_pre_periods`/`min_post_periods` still in `fit()` docstring~~ | ~~`sun_abraham.py:458-487`~~ | ~~#153~~ | ✅ Fixed: Removed deprecated params from `fit()` |
 | R comparison tests spawn separate `Rscript` per test (slow CI) | `tests/test_methodology_twfe.py:294` | #139 | Low |
-| Rust TROP bootstrap SE returns 0.0 instead of NaN for <2 samples | `rust/src/trop.rs:1038-1054` | #115 | Low |
+| ~~Rust TROP bootstrap SE returns 0.0 instead of NaN for <2 samples~~ | ~~`rust/src/trop.rs:1038-1054`~~ | ~~#115~~ | ✅ Already fixed: Returns `f64::NAN` at `rust/src/trop.rs:1034` |
 
 ---

diff_diff/imputation.py

Lines changed: 13 additions & 16 deletions
@@ -22,7 +22,7 @@
 from scipy import sparse, stats
 from scipy.sparse.linalg import spsolve
 
-from diff_diff.imputation_bootstrap import ImputationDiDBootstrapMixin
+from diff_diff.imputation_bootstrap import ImputationDiDBootstrapMixin, _compute_target_weights
 from diff_diff.imputation_results import ImputationBootstrapResults, ImputationDiDResults  # noqa: F401 (re-export)
 from diff_diff.linalg import solve_ols
 from diff_diff.utils import safe_inference
@@ -63,6 +63,8 @@ class ImputationDiD(ImputationDiDBootstrapMixin):
     n_bootstrap : int, default=0
         Number of bootstrap iterations. If 0, uses analytical inference
         (conservative variance from Theorem 3).
+    bootstrap_weights : str, default="rademacher"
+        Type of bootstrap weights: "rademacher", "mammen", or "webb".
     seed : int, optional
         Random seed for reproducibility.
     rank_deficient_action : str, default="warn"
@@ -126,6 +128,7 @@ def __init__(
         alpha: float = 0.05,
         cluster: Optional[str] = None,
         n_bootstrap: int = 0,
+        bootstrap_weights: str = "rademacher",
         seed: Optional[int] = None,
         rank_deficient_action: str = "warn",
         horizon_max: Optional[int] = None,
@@ -136,6 +139,11 @@ def __init__(
                 f"rank_deficient_action must be 'warn', 'error', or 'silent', "
                 f"got '{rank_deficient_action}'"
             )
+        if bootstrap_weights not in ("rademacher", "mammen", "webb"):
+            raise ValueError(
+                f"bootstrap_weights must be 'rademacher', 'mammen', or 'webb', "
+                f"got '{bootstrap_weights}'"
+            )
         if aux_partition not in ("cohort_horizon", "cohort", "horizon"):
             raise ValueError(
                 f"aux_partition must be 'cohort_horizon', 'cohort', or 'horizon', "
@@ -146,6 +154,7 @@ def __init__(
         self.alpha = alpha
         self.cluster = cluster
         self.n_bootstrap = n_bootstrap
+        self.bootstrap_weights = bootstrap_weights
         self.seed = seed
         self.rank_deficient_action = rank_deficient_action
         self.horizon_max = horizon_max
@@ -1359,15 +1368,7 @@ def _aggregate_event_study(
         effect = float(np.mean(valid_tau))
 
         # Compute SE via conservative variance with horizon-specific weights
-        weights_h = np.zeros(int(omega_1_mask.sum()))
-        # Map h_mask (relative to df_1) to weights array
-        h_indices_in_omega1 = np.where(h_mask)[0]
-        n_valid = len(valid_tau)
-        # Only weight valid (finite) observations
-        finite_mask = np.isfinite(tau_hat[h_mask])
-        valid_h_indices = h_indices_in_omega1[finite_mask]
-        for idx in valid_h_indices:
-            weights_h[idx] = 1.0 / n_valid
+        weights_h, n_valid = _compute_target_weights(tau_hat, h_mask)
 
         se = self._compute_conservative_variance(
             df=df,
@@ -1477,12 +1478,7 @@ def _aggregate_group(
         effect = float(np.mean(valid_tau))
 
         # Compute SE with group-specific weights
-        weights_g = np.zeros(int(omega_1_mask.sum()))
-        finite_mask = np.isfinite(tau_hat) & g_mask
-        g_indices = np.where(finite_mask)[0]
-        n_valid = len(valid_tau)
-        for idx in g_indices:
-            weights_g[idx] = 1.0 / n_valid
+        weights_g, _ = _compute_target_weights(tau_hat, g_mask)
 
         se = self._compute_conservative_variance(
             df=df,
@@ -1664,6 +1660,7 @@ def get_params(self) -> Dict[str, Any]:
             "alpha": self.alpha,
             "cluster": self.cluster,
             "n_bootstrap": self.n_bootstrap,
+            "bootstrap_weights": self.bootstrap_weights,
             "seed": self.seed,
             "rank_deficient_action": self.rank_deficient_action,
             "horizon_max": self.horizon_max,
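The three accepted `bootstrap_weights` values correspond to standard mean-zero, unit-variance multiplier-bootstrap distributions. A minimal sketch of what each one draws; `draw_weights` is a hypothetical name for illustration, and the library's own `_generate_bootstrap_weights_batch` may differ in implementation details:

```python
import numpy as np

def draw_weights(kind: str, n: int, rng: np.random.Generator) -> np.ndarray:
    """Draw one vector of multiplier-bootstrap weights (sketch)."""
    if kind == "rademacher":
        # +1 or -1 with probability 1/2 each
        return rng.choice([-1.0, 1.0], size=n)
    if kind == "mammen":
        # Mammen two-point distribution based on the golden ratio
        phi = (1 + np.sqrt(5)) / 2
        p = (np.sqrt(5) + 1) / (2 * np.sqrt(5))  # P(w = 1 - phi)
        return np.where(rng.random(n) < p, 1 - phi, phi)
    if kind == "webb":
        # Webb six-point distribution, each value with probability 1/6
        vals = np.array([-np.sqrt(1.5), -1.0, -np.sqrt(0.5),
                         np.sqrt(0.5), 1.0, np.sqrt(1.5)])
        return rng.choice(vals, size=n)
    raise ValueError(
        f"bootstrap_weights must be 'rademacher', 'mammen', or 'webb', "
        f"got '{kind}'"
    )

rng = np.random.default_rng(42)
for kind in ("rademacher", "mammen", "webb"):
    w = draw_weights(kind, 100_000, rng)
    # All three have mean 0 and variance 1 by construction
    assert abs(w.mean()) < 0.02 and abs(w.var() - 1.0) < 0.02
```

All three distributions leave the bootstrap draws unbiased with unit variance; they differ in higher moments, which is why Webb weights are often preferred with few clusters.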

diff_diff/imputation_bootstrap.py

Lines changed: 37 additions & 16 deletions
@@ -19,6 +19,39 @@
 ]
 
 
+def _compute_target_weights(
+    tau_hat: np.ndarray,
+    target_mask: np.ndarray,
+) -> "tuple[np.ndarray, int]":
+    """
+    Equal weights for finite tau_hat observations within target_mask.
+
+    Used by both aggregation and bootstrap paths to avoid weight logic
+    duplication.
+
+    Parameters
+    ----------
+    tau_hat : np.ndarray
+        Per-observation treatment effects (may contain NaN).
+    target_mask : np.ndarray
+        Boolean mask selecting the target subset within tau_hat.
+
+    Returns
+    -------
+    weights : np.ndarray
+        Weight array (same length as tau_hat). 1/n_valid for finite
+        observations in target_mask, 0 elsewhere.
+    n_valid : int
+        Number of finite observations in the target subset.
+    """
+    finite_target = np.isfinite(tau_hat) & target_mask
+    n_valid = int(finite_target.sum())
+    weights = np.zeros(len(tau_hat))
+    if n_valid > 0:
+        weights[np.where(finite_target)[0]] = 1.0 / n_valid
+    return weights, n_valid
+
+
 class ImputationDiDBootstrapMixin:
     """Mixin providing bootstrap inference methods for ImputationDiD."""
 
@@ -120,13 +153,10 @@ def _precompute_bootstrap_psi(
         result["overall"] = (overall_psi, cluster_ids)
 
         # Event study: per-horizon weights
-        # NOTE: weight logic duplicated from _aggregate_event_study.
-        # If weight scheme changes there, update here too.
         if event_study_effects:
             result["event_study"] = {}
             df_1 = df.loc[omega_1_mask]
             rel_times = df_1["_rel_time"].values
-            n_omega_1 = int(omega_1_mask.sum())
 
             # Balanced cohort mask (same logic as _aggregate_event_study)
             balanced_mask = None
@@ -150,37 +180,28 @@ def _precompute_bootstrap_psi(
                 h_mask = rel_times == h
                 if balanced_mask is not None:
                     h_mask = h_mask & balanced_mask
-                weights_h = np.zeros(n_omega_1)
-                finite_h = np.isfinite(tau_hat) & h_mask
-                n_valid_h = int(finite_h.sum())
+                weights_h, n_valid_h = _compute_target_weights(tau_hat, h_mask)
                 if n_valid_h == 0:
                     continue
-                weights_h[np.where(finite_h)[0]] = 1.0 / n_valid_h
 
                 psi_h, _ = self._compute_cluster_psi_sums(**common, weights=weights_h)
                 result["event_study"][h] = psi_h
 
         # Group effects: per-group weights
-        # NOTE: weight logic duplicated from _aggregate_group.
-        # If weight scheme changes there, update here too.
         if group_effects:
             result["group"] = {}
             df_1 = df.loc[omega_1_mask]
             cohorts = df_1[first_treat].values
-            n_omega_1 = int(omega_1_mask.sum())
 
             for g in group_effects:
                 if group_effects[g].get("n_obs", 0) == 0:
                     continue
                 if not np.isfinite(group_effects[g].get("effect", np.nan)):
                     continue
                 g_mask = cohorts == g
-                weights_g = np.zeros(n_omega_1)
-                finite_g = np.isfinite(tau_hat) & g_mask
-                n_valid_g = int(finite_g.sum())
+                weights_g, n_valid_g = _compute_target_weights(tau_hat, g_mask)
                 if n_valid_g == 0:
                     continue
-                weights_g[np.where(finite_g)[0]] = 1.0 / n_valid_g
 
                 psi_g, _ = self._compute_cluster_psi_sums(**common, weights=weights_g)
                 result["group"][g] = psi_g
@@ -216,7 +237,7 @@ def _run_bootstrap(
 
         # Generate ALL weights upfront: shape (n_bootstrap, n_clusters)
         all_weights = _generate_bootstrap_weights_batch(
-            self.n_bootstrap, n_clusters, "rademacher", rng
+            self.n_bootstrap, n_clusters, self.bootstrap_weights, rng
         )
 
         # Overall ATT bootstrap draws
@@ -295,7 +316,7 @@ def _run_bootstrap(
 
         return ImputationBootstrapResults(
             n_bootstrap=self.n_bootstrap,
-            weight_type="rademacher",
+            weight_type=self.bootstrap_weights,
             alpha=self.alpha,
             overall_att_se=overall_se,
             overall_att_ci=overall_ci,
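The helper's behavior is easy to check in isolation. A standalone sketch that mirrors `_compute_target_weights` rather than importing the library:

```python
import numpy as np

def compute_target_weights(tau_hat, target_mask):
    """Equal weights over finite effects inside a mask (standalone copy)."""
    finite_target = np.isfinite(tau_hat) & target_mask
    n_valid = int(finite_target.sum())
    weights = np.zeros(len(tau_hat))
    if n_valid > 0:
        weights[finite_target] = 1.0 / n_valid
    return weights, n_valid

tau_hat = np.array([0.5, np.nan, 1.2, 2.0, np.nan])
h_mask = np.array([True, True, True, False, False])  # e.g. one horizon

weights, n_valid = compute_target_weights(tau_hat, h_mask)
# Two finite effects inside the mask, so each gets weight 1/2;
# the NaN inside the mask and everything outside it get 0.
assert n_valid == 2
assert np.allclose(weights, [0.5, 0.0, 0.5, 0.0, 0.0])
```

Centralizing this in one function is what removes the drift risk the TODO entry flagged: the aggregation and bootstrap paths can no longer disagree on how NaN effects are weighted.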

diff_diff/sun_abraham.py

Lines changed: 0 additions & 22 deletions
@@ -433,8 +433,6 @@ def fit(
         time: str,
         first_treat: str,
         covariates: Optional[List[str]] = None,
-        min_pre_periods: int = 1,
-        min_post_periods: int = 1,
     ) -> SunAbrahamResults:
         """
         Fit the Sun-Abraham estimator using saturated regression.
@@ -454,10 +452,6 @@ def fit(
             Use 0 (or np.inf) for never-treated units.
         covariates : list, optional
             List of covariate column names to include in regression.
-        min_pre_periods : int, default=1
-            **Deprecated**: Accepted but ignored. Will be removed in a future version.
-        min_post_periods : int, default=1
-            **Deprecated**: Accepted but ignored. Will be removed in a future version.
 
         Returns
         -------
@@ -469,22 +463,6 @@ def fit(
         ValueError
             If required columns are missing or data validation fails.
         """
-        # Deprecation warnings for unimplemented parameters
-        if min_pre_periods != 1:
-            warnings.warn(
-                "min_pre_periods is not yet implemented and will be ignored. "
-                "This parameter will be removed in a future version.",
-                FutureWarning,
-                stacklevel=2,
-            )
-        if min_post_periods != 1:
-            warnings.warn(
-                "min_post_periods is not yet implemented and will be ignored. "
-                "This parameter will be removed in a future version.",
-                FutureWarning,
-                stacklevel=2,
-            )
-
         # Validate inputs
         required_cols = [outcome, unit, time, first_treat]
         if covariates:

diff_diff/trop.py

Lines changed: 6 additions & 0 deletions
@@ -156,6 +156,12 @@ def __init__(
         self.lambda_unit_grid = lambda_unit_grid or [0.0, 0.1, 0.5, 1.0, 2.0, 5.0]
         self.lambda_nn_grid = lambda_nn_grid or [0.0, 0.01, 0.1, 1.0, 10.0]
 
+        if n_bootstrap < 2:
+            raise ValueError(
+                "n_bootstrap must be >= 2 for TROP (bootstrap variance "
+                "estimation is always used)"
+            )
+
         self.max_iter = max_iter
         self.tol = tol
         self.alpha = alpha
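The failure mode this guard prevents: with fewer than two bootstrap draws, a sample standard deviation is degenerate (NaN with `ddof=1`, 0.0 with `ddof=0`), so the SE would be silently meaningless. A standalone sketch of the guard, not the library's code:

```python
import warnings

import numpy as np

# One bootstrap draw yields a degenerate SE: NaN with ddof=1, 0.0 with
# ddof=0. This is the silent failure the new ValueError rules out.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    assert np.isnan(np.std(np.array([0.7]), ddof=1))
assert np.std(np.array([0.7]), ddof=0) == 0.0

def validate_n_bootstrap(n_bootstrap: int) -> None:
    """Standalone copy of the guard added to TROP's __init__."""
    if n_bootstrap < 2:
        raise ValueError(
            "n_bootstrap must be >= 2 for TROP (bootstrap variance "
            "estimation is always used)"
        )

validate_n_bootstrap(500)  # fine
try:
    validate_n_bootstrap(1)
except ValueError:
    pass  # expected
else:
    raise AssertionError("expected ValueError for n_bootstrap < 2")
```

Failing fast in `__init__` surfaces the misconfiguration at construction time rather than letting a 0/NaN standard error propagate into downstream confidence intervals.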
