Skip to content

Commit 9026464

Browse files
igerber and claude committed
Address AI review round 5: fix bootstrap NaN propagation, fix CI pandas compat
- Guard overall bootstrap stats in _run_bootstrap() when ATT is NaN, preventing _compute_bootstrap_pvalue from returning 1/(B+1) instead of NaN
- Add test_no_post_effects_bootstrap_returns_nan for the bootstrap NaN path
- Cast first_treat to float before assigning np.inf in tests for pandas compat

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 049a6ef commit 9026464

File tree

3 files changed

+71
-8
lines changed

3 files changed

+71
-8
lines changed

diff_diff/sun_abraham.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1144,11 +1144,16 @@ def _run_bootstrap(
11441144
event_study_p_values[e] = p_value
11451145

11461146
# Overall ATT statistics
1147-
overall_se = float(np.std(bootstrap_overall, ddof=1))
1148-
overall_ci = self._compute_percentile_ci(bootstrap_overall, self.alpha)
1149-
overall_p = self._compute_bootstrap_pvalue(
1150-
original_overall_att, bootstrap_overall
1151-
)
1147+
if not np.isfinite(original_overall_att):
1148+
overall_se = np.nan
1149+
overall_ci = (np.nan, np.nan)
1150+
overall_p = np.nan
1151+
else:
1152+
overall_se = float(np.std(bootstrap_overall, ddof=1))
1153+
overall_ci = self._compute_percentile_ci(bootstrap_overall, self.alpha)
1154+
overall_p = self._compute_bootstrap_pvalue(
1155+
original_overall_att, bootstrap_overall
1156+
)
11521157

11531158
return SABootstrapResults(
11541159
n_bootstrap=self.n_bootstrap,

tests/test_staggered.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -111,8 +111,9 @@ def test_never_treated_inf_encoding(self):
111111
data.copy(), outcome="outcome", unit="unit", time="time", first_treat="first_treat"
112112
)
113113

114-
# Re-encode never-treated from 0 to np.inf
114+
# Re-encode never-treated from 0 to np.inf (cast to float first for pandas compat)
115115
data_inf = data.copy()
116+
data_inf["first_treat"] = data_inf["first_treat"].astype(float)
116117
data_inf.loc[data_inf["first_treat"] == 0, "first_treat"] = np.inf
117118

118119
results_inf = cs.fit(

tests/test_sun_abraham.py

Lines changed: 59 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1123,6 +1123,61 @@ def test_no_post_effects_returns_nan(self):
11231123
f"Expected (NaN, NaN) overall_conf_int, got {results.overall_conf_int}"
11241124
)
11251125

1126+
def test_no_post_effects_bootstrap_returns_nan(self, ci_params):
1127+
"""Test that no post-treatment effects returns NaN even with bootstrap.
1128+
1129+
When there are no post-treatment periods, overall_att/se/t_stat/p_value/ci
1130+
should all be NaN. The bootstrap path must not overwrite NaN with non-NaN
1131+
values (regression test for P0 bug where _compute_bootstrap_pvalue returned
1132+
1/(B+1) instead of NaN when original_effect was NaN).
1133+
"""
1134+
# Create data where all periods are pre-treatment
1135+
np.random.seed(42)
1136+
n_units = 40
1137+
n_periods = 6
1138+
1139+
units = np.repeat(np.arange(n_units), n_periods)
1140+
times = np.tile(np.arange(n_periods), n_units)
1141+
1142+
# All treated units have first_treat at period 100 (well beyond data range)
1143+
first_treat = np.zeros(n_units)
1144+
first_treat[12:] = 100 # treated at period 100, but data only goes to period 5
1145+
first_treat_expanded = np.repeat(first_treat, n_periods)
1146+
1147+
unit_fe = np.repeat(np.random.randn(n_units), n_periods)
1148+
time_fe = np.tile(np.arange(n_periods) * 0.1, n_units)
1149+
outcomes = unit_fe + time_fe + np.random.randn(len(units)) * 0.3
1150+
1151+
data = pd.DataFrame({
1152+
"unit": units,
1153+
"time": times,
1154+
"outcome": outcomes,
1155+
"first_treat": first_treat_expanded.astype(int),
1156+
})
1157+
1158+
n_boot = ci_params.bootstrap(50)
1159+
sa = SunAbraham(n_bootstrap=n_boot, seed=42)
1160+
results = sa.fit(
1161+
data, outcome="outcome", unit="unit", time="time", first_treat="first_treat"
1162+
)
1163+
1164+
# All overall inference fields should be NaN
1165+
assert np.isnan(results.overall_att), (
1166+
f"Expected NaN overall_att, got {results.overall_att}"
1167+
)
1168+
assert np.isnan(results.overall_se), (
1169+
f"Expected NaN overall_se, got {results.overall_se}"
1170+
)
1171+
assert np.isnan(results.overall_t_stat), (
1172+
f"Expected NaN overall_t_stat, got {results.overall_t_stat}"
1173+
)
1174+
assert np.isnan(results.overall_p_value), (
1175+
f"Expected NaN overall_p_value with bootstrap, got {results.overall_p_value}"
1176+
)
1177+
assert np.isnan(results.overall_conf_int[0]) and np.isnan(results.overall_conf_int[1]), (
1178+
f"Expected (NaN, NaN) overall_conf_int, got {results.overall_conf_int}"
1179+
)
1180+
11261181
def test_deprecated_min_pre_periods_warning(self):
11271182
"""Test that min_pre_periods emits FutureWarning (Step 5c)."""
11281183
data = generate_staggered_data(seed=42)
@@ -1377,8 +1432,9 @@ def test_never_treated_inf_encoding(self):
13771432
data.copy(), outcome="outcome", unit="unit", time="time", first_treat="first_treat"
13781433
)
13791434

1380-
# Re-encode never-treated from 0 to np.inf
1435+
# Re-encode never-treated from 0 to np.inf (cast to float first for pandas compat)
13811436
data_inf = data.copy()
1437+
data_inf["first_treat"] = data_inf["first_treat"].astype(float)
13821438
data_inf.loc[data_inf["first_treat"] == 0, "first_treat"] = np.inf
13831439

13841440
results_inf = sa.fit(
@@ -1411,7 +1467,8 @@ def test_never_treated_inf_encoding(self):
14111467
def test_all_never_treated_inf_raises(self):
14121468
"""Test that all-never-treated data with np.inf encoding raises ValueError."""
14131469
data = generate_staggered_data(n_units=100, n_periods=10, n_cohorts=3, seed=42)
1414-
# Set ALL units to never-treated via np.inf
1470+
# Set ALL units to never-treated via np.inf (cast to float first for pandas compat)
1471+
data["first_treat"] = data["first_treat"].astype(float)
14151472
data["first_treat"] = np.inf
14161473

14171474
sa = SunAbraham(n_bootstrap=0)

0 commit comments

Comments
 (0)