Commit f4913d5

igerber and claude committed
Spring cleanup: bootstrap NaN-gating, mypy fixes, doc snippet hardening
- Migrate imputation_bootstrap.py and two_stage_bootstrap.py to shared
  compute_effect_bootstrap_stats() for NaN filtering and SE<=0 guards
- Add @overload to solve_ols/_solve_ols_numpy to resolve 15 mypy unpacking
  errors; add assert guards for Optional indexing (81→9 errors)
- Replace blanket NameError catch in test_doc_snippets.py with explicit
  allow-list of 12 context-dependent snippets
- Update TODO.md to reflect resolved tech debt items

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2116ff5 commit f4913d5
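
Reviewer note: the `@overload` item in the commit message can be illustrated with a small standalone sketch. The function `slope_fit` and its `return_residuals` flag are hypothetical and are not part of `diff_diff`; the real `solve_ols` signature is not shown in this excerpt. The assumption is that a flag changes the shape of the returned tuple. Declaring one overload per flag value lets mypy select the exact tuple type at each call site, so unpacking no longer produces an error.

```python
from typing import List, Literal, Tuple, Union, overload


# Hypothetical example: the return tuple has 2 or 3 items depending on a flag.
@overload
def slope_fit(
    x: List[float], y: List[float], return_residuals: Literal[False] = ...
) -> Tuple[float, float]: ...
@overload
def slope_fit(
    x: List[float], y: List[float], return_residuals: Literal[True]
) -> Tuple[float, float, List[float]]: ...


def slope_fit(
    x: List[float], y: List[float], return_residuals: bool = False
) -> Union[Tuple[float, float], Tuple[float, float, List[float]]]:
    """Fit y = intercept + slope * x by ordinary least squares."""
    n = len(x)
    mean_x = sum(x) / n
    mean_y = sum(y) / n
    sxx = sum((xi - mean_x) ** 2 for xi in x)
    sxy = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(x, y))
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    if return_residuals:
        residuals = [yi - (intercept + slope * xi) for xi, yi in zip(x, y)]
        return intercept, slope, residuals
    return intercept, slope


# With the overloads, a type checker sees a 2-tuple here...
intercept, slope = slope_fit([0.0, 1.0, 2.0], [1.0, 3.0, 5.0])
# ...and a 3-tuple here, so both unpackings type-check.
intercept, slope, residuals = slope_fit(
    [0.0, 1.0, 2.0], [1.0, 3.0, 5.0], return_residuals=True
)
```

Without the overloads, the declared return type is the `Union` of both tuples, and a checker cannot tell how many values a given call returns.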

15 files changed, +484 -372 lines changed

TODO.md

Lines changed: 26 additions & 23 deletions
@@ -12,8 +12,8 @@ Current limitations that may affect users:

 | Issue | Location | Priority | Notes |
 |-------|----------|----------|-------|
-| MultiPeriodDiD wild bootstrap not supported | `estimators.py:779-785` | Low | Edge case |
-| `predict()` raises NotImplementedError | `estimators.py:568-587` | Low | Rarely needed |
+| MultiPeriodDiD wild bootstrap not supported | `estimators.py:778-784` | Low | Edge case |
+| `predict()` raises NotImplementedError | `estimators.py:567-588` | Low | Rarely needed |

 ## Code Quality

@@ -23,14 +23,20 @@ Target: < 1000 lines per module for maintainability.

 | File | Lines | Action |
 |------|-------|--------|
-| `utils.py` | 1780 | Monitor -- legacy placebo function removed |
-| `visualization.py` | 1678 | Monitor -- growing but cohesive |
-| `linalg.py` | 1537 | Monitor -- unified backend, splitting would hurt cohesion |
+| `trop.py` | 2738 | Consider splitting — 2.7× target |
+| `utils.py` | 1838 | Monitor |
+| `staggered.py` | 1785 | Monitor |
+| `imputation.py` | 1756 | Monitor |
+| `visualization.py` | 1727 | Monitor — growing but cohesive |
+| `linalg.py` | 1727 | Monitor — unified backend, splitting would hurt cohesion |
+| `triple_diff.py` | 1581 | Monitor |
 | `honest_did.py` | 1511 | Acceptable |
+| `two_stage.py` | 1451 | Acceptable |
 | `power.py` | 1350 | Acceptable |
-| `triple_diff.py` | 1322 | Acceptable |
-| `sun_abraham.py` | 1227 | Acceptable |
-| `estimators.py` | 1161 | Acceptable |
+| `prep.py` | 1242 | Acceptable |
+| `sun_abraham.py` | 1162 | Acceptable |
+| `continuous_did.py` | 1155 | Acceptable |
+| `estimators.py` | 1147 | Acceptable |
 | `pretrends.py` | 1104 | Acceptable |

 ---
@@ -44,7 +50,6 @@ Deferred items from PR reviews that were not addressed before merge.
 | Issue | Location | PR | Priority |
 |-------|----------|----|----------|
 | ImputationDiD dense `(A0'A0).toarray()` scales O((U+T+K)^2), OOM risk on large panels | `imputation.py` | #141 | Medium (deferred — only triggers when sparse solver fails; fixing requires sparse least-squares alternatives) |
-| Bootstrap NaN-gating gap: manual SE/CI/p-value without non-finite filtering or SE<=0 guard | `imputation_bootstrap.py`, `two_stage_bootstrap.py` | #177 | Medium — migrate to `compute_effect_bootstrap_stats` from `bootstrap_utils.py` |
 | EfficientDiD: warn when cohort share is very small (< 2 units or < 1% of sample) — inverted in Omega*/EIF | `efficient_did_weights.py` | #192 | Low |
 | EfficientDiD: API docs / tutorial page for new public estimator | `docs/` | #192 | Medium |

@@ -62,7 +67,7 @@ Deferred items from PR reviews that were not addressed before merge.
 | Tutorial notebooks not executed in CI | `docs/tutorials/*.ipynb` | #159 | Low |
 | R comparison tests spawn separate `Rscript` per test (slow CI) | `tests/test_methodology_twfe.py:294` | #139 | Low |
 | CS R helpers hard-code `xformla = ~ 1`; no covariate-adjusted R benchmark for IRLS path | `tests/test_methodology_callaway.py` | #202 | Low |
-| Context-dependent doc snippets pass via blanket NameError; no standalone validation | `tests/test_doc_snippets.py`, `docs/api/visualization.rst`, `docs/python_comparison.rst`, `docs/r_comparison.rst` | #206 | Low |
+| ~~Context-dependent doc snippets pass via blanket NameError~~ | `tests/test_doc_snippets.py` | #206 | ~~Low~~ — resolved: allow-list replaces blanket catch |
 | ~1,460 `duplicate object description` Sphinx warnings — each class attribute is documented in both module API pages and autosummary stubs; fix by adding `:no-index:` to one location or restructuring API docs to avoid overlap | `docs/api/*.rst`, `docs/api/_autosummary/` || Low |

 ---
@@ -82,22 +87,20 @@ Different estimators compute SEs differently. Consider unified interface.

 ### Type Annotations

-Pyright reports 282 type errors. Most are false positives from numpy/pandas type stubs.
+Mypy reports 9 errors (down from 81 before spring cleanup). All remaining are
+mixin `attr-defined` errors — methods accessed via `self` that live on the
+concrete class, not the mixin. Fixing these requires Protocol classes, which is
+low priority.

 | Category | Count | Notes |
 |----------|-------|-------|
-| reportArgumentType | 94 | numpy/pandas stub mismatches |
-| reportAttributeAccessIssue | 89 | Union types (results classes) |
-| reportReturnType | 21 | Return type mismatches |
-| reportOperatorIssue | 16 | Operators on incompatible types |
-| Others | 62 | Various minor issues |
-
-**Genuine issues to fix (low priority):**
-- [ ] Optional handling in `estimators.py:291,297,308` - None checks needed
-- [ ] Union type narrowing in `visualization.py:325-345` - results classes
-- [ ] numpy floating conversion in `diagnostics.py:669-673`
-
-**Note:** Most errors are false positives from imprecise type stubs. Mypy config in pyproject.toml already handles these via `disable_error_code`.
+| attr-defined (mixin methods) | 9 | Structural — requires Protocol refactor |
+
+**Resolved in spring cleanup:**
+- [x] `@overload` on `solve_ols` / `_solve_ols_numpy` — eliminated all unpacking mismatches
+- [x] `assert X is not None` guards — eliminated all Optional indexing errors
+- [x] Mixin scalar attribute stubs — eliminated 26 mixin attr-defined errors
+- [x] Matplotlib `tab10` lookup fix

 ## Deprecated Code

diff_diff/estimators.py

Lines changed: 2 additions & 0 deletions
@@ -296,6 +296,7 @@ def fit(
         coefficients = reg.coefficients_
         residuals = reg.residuals_
         fitted = reg.fitted_values_
+        assert coefficients is not None
         att = coefficients[att_idx]

         # Get inference - either from bootstrap or analytical
@@ -1029,6 +1030,7 @@ def fit( # type: ignore[override]
         post_effect_values = []
         post_effect_indices = []

+        assert vcov is not None
         for period in non_ref_periods:
             idx = interaction_indices[period]
             effect = coefficients[idx]

diff_diff/imputation.py

Lines changed: 10 additions & 6 deletions
@@ -23,12 +23,13 @@
 from scipy.sparse.linalg import spsolve

 from diff_diff.imputation_bootstrap import ImputationDiDBootstrapMixin, _compute_target_weights
-from diff_diff.imputation_results import ImputationBootstrapResults, ImputationDiDResults  # noqa: F401 (re-export)
+from diff_diff.imputation_results import (  # noqa: F401 (re-export)
+    ImputationBootstrapResults,
+    ImputationDiDResults,
+)
 from diff_diff.linalg import solve_ols
 from diff_diff.utils import safe_inference

-
-
 # =============================================================================
 # Main Estimator
 # =============================================================================
@@ -417,9 +418,7 @@ def fit(
             kept_cov_mask=kept_cov_mask,
         )

-        overall_t, overall_p, overall_ci = safe_inference(
-            overall_att, overall_se, alpha=self.alpha
-        )
+        overall_t, overall_p, overall_ci = safe_inference(overall_att, overall_se, alpha=self.alpha)

         # Event study and group aggregation
         event_study_effects = None
@@ -553,7 +552,9 @@ def fit(
                         and event_study_effects[h].get("n_obs", 1) > 0
                     ):
                         event_study_effects[h]["se"] = bootstrap_results.event_study_ses[h]
+                        assert bootstrap_results.event_study_cis is not None
                         event_study_effects[h]["conf_int"] = bootstrap_results.event_study_cis[h]
+                        assert bootstrap_results.event_study_p_values is not None
                         event_study_effects[h]["p_value"] = bootstrap_results.event_study_p_values[
                             h
                         ]
@@ -568,7 +569,9 @@ def fit(
                 for g in group_effects:
                     if g in bootstrap_results.group_ses:
                         group_effects[g]["se"] = bootstrap_results.group_ses[g]
+                        assert bootstrap_results.group_cis is not None
                         group_effects[g]["conf_int"] = bootstrap_results.group_cis[g]
+                        assert bootstrap_results.group_p_values is not None
                         group_effects[g]["p_value"] = bootstrap_results.group_p_values[g]
                         eff_val = group_effects[g]["effect"]
                         se_val = group_effects[g]["se"]
@@ -1614,6 +1617,7 @@ def _pretrend_test(self, n_leads: Optional[int] = None) -> Dict[str, Any]:
         )
         coefficients = result[0]
         vcov = result[2]
+        assert vcov is not None

         # Extract lead coefficients and their sub-VCV
         n_leads_actual = len(lead_cols)

diff_diff/imputation_bootstrap.py

Lines changed: 39 additions & 70 deletions
@@ -6,13 +6,18 @@
 """

 import warnings
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional

 import numpy as np
 import pandas as pd

+from diff_diff.bootstrap_utils import (
+    compute_effect_bootstrap_stats as _compute_effect_bootstrap_stats,
+)
+from diff_diff.bootstrap_utils import (
+    generate_bootstrap_weights_batch as _generate_bootstrap_weights_batch,
+)
 from diff_diff.imputation_results import ImputationBootstrapResults
-from diff_diff.staggered_bootstrap import _generate_bootstrap_weights_batch

 __all__ = [
     "ImputationDiDBootstrapMixin",
@@ -55,46 +60,13 @@ def _compute_target_weights(
 class ImputationDiDBootstrapMixin:
     """Mixin providing bootstrap inference methods for ImputationDiD."""

-    def _compute_percentile_ci(
-        self,
-        boot_dist: np.ndarray,
-        alpha: float,
-    ) -> Tuple[float, float]:
-        """Compute percentile confidence interval from bootstrap distribution."""
-        lower = float(np.percentile(boot_dist, alpha / 2 * 100))
-        upper = float(np.percentile(boot_dist, (1 - alpha / 2) * 100))
-        return (lower, upper)
-
-    def _compute_bootstrap_pvalue(
-        self,
-        original_effect: float,
-        boot_dist: np.ndarray,
-        n_valid: Optional[int] = None,
-    ) -> float:
-        """
-        Compute two-sided bootstrap p-value.
-
-        Uses the percentile method: p-value is the proportion of bootstrap
-        estimates on the opposite side of zero from the original estimate,
-        doubled for two-sided test.
-
-        Parameters
-        ----------
-        original_effect : float
-            Original point estimate.
-        boot_dist : np.ndarray
-            Bootstrap distribution of the effect.
-        n_valid : int, optional
-            Number of valid bootstrap samples. If None, uses self.n_bootstrap.
-        """
-        if original_effect >= 0:
-            p_one_sided = float(np.mean(boot_dist <= 0))
-        else:
-            p_one_sided = float(np.mean(boot_dist >= 0))
-        p_value = min(2 * p_one_sided, 1.0)
-        n_for_floor = n_valid if n_valid is not None else self.n_bootstrap
-        p_value = max(p_value, 1 / (n_for_floor + 1))
-        return p_value
+    # Type hints for attributes accessed from the main class
+    n_bootstrap: int
+    bootstrap_weights: str
+    alpha: float
+    seed: Optional[int]
+    anticipation: int
+    horizon_max: Optional[int]

     def _precompute_bootstrap_psi(
         self,
@@ -266,16 +238,11 @@ def _run_bootstrap(
         # We do the same here so percentile CIs and empirical p-values work correctly.
         boot_overall_shifted = boot_overall + original_att

-        overall_se = float(np.std(boot_overall, ddof=1))
-        overall_ci = (
-            self._compute_percentile_ci(boot_overall_shifted, self.alpha)
-            if overall_se > 0
-            else (np.nan, np.nan)
-        )
-        overall_p = (
-            self._compute_bootstrap_pvalue(original_att, boot_overall_shifted)
-            if overall_se > 0
-            else np.nan
+        overall_se, overall_ci, overall_p = _compute_effect_bootstrap_stats(
+            original_att,
+            boot_overall_shifted,
+            alpha=self.alpha,
+            context="ImputationDiD overall ATT",
         )

         event_study_ses = None
@@ -286,16 +253,17 @@ def _run_bootstrap(
             event_study_cis = {}
             event_study_p_values = {}
             for h in boot_event_study:
-                se_h = float(np.std(boot_event_study[h], ddof=1))
-                event_study_ses[h] = se_h
                 orig_eff = original_event_study[h]["effect"]
-                if se_h > 0 and np.isfinite(orig_eff):
-                    shifted_h = boot_event_study[h] + orig_eff
-                    event_study_p_values[h] = self._compute_bootstrap_pvalue(orig_eff, shifted_h)
-                    event_study_cis[h] = self._compute_percentile_ci(shifted_h, self.alpha)
-                else:
-                    event_study_p_values[h] = np.nan
-                    event_study_cis[h] = (np.nan, np.nan)
+                shifted_h = boot_event_study[h] + orig_eff
+                se_h, ci_h, p_h = _compute_effect_bootstrap_stats(
+                    orig_eff,
+                    shifted_h,
+                    alpha=self.alpha,
+                    context=f"ImputationDiD event study (h={h})",
+                )
+                event_study_ses[h] = se_h
+                event_study_cis[h] = ci_h
+                event_study_p_values[h] = p_h

         group_ses = None
         group_cis = None
@@ -305,16 +273,17 @@ def _run_bootstrap(
             group_cis = {}
             group_p_values = {}
             for g in boot_group:
-                se_g = float(np.std(boot_group[g], ddof=1))
-                group_ses[g] = se_g
                 orig_eff = original_group[g]["effect"]
-                if se_g > 0 and np.isfinite(orig_eff):
-                    shifted_g = boot_group[g] + orig_eff
-                    group_p_values[g] = self._compute_bootstrap_pvalue(orig_eff, shifted_g)
-                    group_cis[g] = self._compute_percentile_ci(shifted_g, self.alpha)
-                else:
-                    group_p_values[g] = np.nan
-                    group_cis[g] = (np.nan, np.nan)
+                shifted_g = boot_group[g] + orig_eff
+                se_g, ci_g, p_g = _compute_effect_bootstrap_stats(
+                    orig_eff,
+                    shifted_g,
+                    alpha=self.alpha,
+                    context=f"ImputationDiD group effect (g={g})",
+                )
+                group_ses[g] = se_g
+                group_cis[g] = ci_g
+                group_p_values[g] = p_g

         return ImputationBootstrapResults(
             n_bootstrap=self.n_bootstrap,
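
Reviewer note: a minimal sketch of the gating this migration is after, assuming NumPy. It is not the actual `compute_effect_bootstrap_stats` from `bootstrap_utils.py`, whose implementation is not part of this diff; `bootstrap_stats_sketch` is a hypothetical name and its signature is an assumption. The percentile CI and the floored two-sided p-value mirror the helpers removed in this file, while the non-finite filtering and the SE<=0 guard follow the commit message.

```python
import numpy as np


def bootstrap_stats_sketch(original_effect, boot_dist, alpha=0.05):
    """Return (se, (ci_lower, ci_upper), p_value), or all-NaN when gated.

    Hypothetical illustration only; not the library's implementation.
    """
    nan_result = (float("nan"), (float("nan"), float("nan")), float("nan"))

    # NaN gate: drop non-finite bootstrap draws before computing anything.
    boot = np.asarray(boot_dist, dtype=float)
    valid = boot[np.isfinite(boot)]
    if not np.isfinite(original_effect) or valid.size < 2:
        return nan_result

    # SE guard: a zero or non-finite SE makes the CI and p-value meaningless.
    se = float(np.std(valid, ddof=1))
    if not np.isfinite(se) or se <= 0:
        return nan_result

    # Percentile confidence interval, as in the removed _compute_percentile_ci.
    lower = float(np.percentile(valid, alpha / 2 * 100))
    upper = float(np.percentile(valid, (1 - alpha / 2) * 100))

    # Two-sided percentile p-value with a 1 / (n_valid + 1) floor,
    # as in the removed _compute_bootstrap_pvalue.
    if original_effect >= 0:
        p_one_sided = float(np.mean(valid <= 0))
    else:
        p_one_sided = float(np.mean(valid >= 0))
    p_value = max(min(2 * p_one_sided, 1.0), 1 / (valid.size + 1))

    return se, (lower, upper), p_value


# One NaN draw is filtered out; the remaining three drive the statistics.
se, ci, p = bootstrap_stats_sketch(1.0, [0.5, 1.0, 1.5, float("nan")])
# A constant distribution has SE == 0, so every statistic is gated to NaN.
se0, ci0, p0 = bootstrap_stats_sketch(1.0, [2.0, 2.0, 2.0])
```

Because the standard deviation is unchanged by adding a constant, computing the SE from the shifted distribution gives the same value the removed code obtained from the unshifted draws.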
