Skip to content

Commit 5d2fdaf

Browse files
igerberclaude
andcommitted
Address PR #208 review: add TWFE to registry, fix unregistered-estimator fallback
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent c6465b7 commit 5d2fdaf

File tree

2 files changed

+101
-13
lines changed

2 files changed

+101
-13
lines changed

diff_diff/power.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,15 @@ def _basic_fit_kwargs(
136136
return dict(outcome="outcome", treatment="treated", time="post")
137137

138138

139+
def _twfe_fit_kwargs(
140+
data: pd.DataFrame,
141+
n_units: int,
142+
n_periods: int,
143+
treatment_period: int,
144+
) -> Dict[str, Any]:
145+
return dict(outcome="outcome", treatment="treated", time="post", unit="unit")
146+
147+
139148
def _multiperiod_fit_kwargs(
140149
data: pd.DataFrame,
141150
n_units: int,
@@ -255,6 +264,13 @@ def _get_registry() -> Dict[str, _EstimatorProfile]:
255264
result_extractor=_extract_simple,
256265
min_n=20,
257266
),
267+
"TwoWayFixedEffects": _EstimatorProfile(
268+
default_dgp=generate_did_data,
269+
dgp_kwargs_builder=_basic_dgp_kwargs,
270+
fit_kwargs_builder=_twfe_fit_kwargs,
271+
result_extractor=_extract_simple,
272+
min_n=20,
273+
),
258274
"MultiPeriodDiD": _EstimatorProfile(
259275
default_dgp=generate_did_data,
260276
dgp_kwargs_builder=_basic_dgp_kwargs,
@@ -1303,7 +1319,9 @@ def simulate_power(
13031319
if profile is None and data_generator is None:
13041320
raise ValueError(
13051321
f"Estimator '{estimator_name}' not in registry. "
1306-
f"Provide a custom data_generator and estimator_kwargs."
1322+
f"Provide a custom data_generator and estimator_kwargs "
1323+
f"(the full dict of keyword arguments for estimator.fit(), "
1324+
f"e.g. dict(outcome='y', treatment='treat', time='period'))."
13071325
)
13081326

13091327
# When a custom data_generator is provided, bypass registry DGP
@@ -1414,8 +1432,7 @@ def simulate_power(
14141432
)
14151433
fit_kwargs.update(est_kwargs)
14161434
else:
1417-
fit_kwargs = dict(outcome="outcome", treatment="treated", time="post")
1418-
fit_kwargs.update(est_kwargs)
1435+
fit_kwargs = dict(est_kwargs)
14191436

14201437
result = estimator.fit(data, **fit_kwargs)
14211438

tests/test_power.py

Lines changed: 81 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
_staggered_fit_kwargs,
4545
_trop_fit_kwargs,
4646
)
47+
from diff_diff.prep import generate_did_data
4748

4849

4950
class TestPowerAnalysis:
@@ -800,17 +801,10 @@ def test_continuous_did_not_in_registry(self):
800801
progress=False,
801802
)
802803

803-
def test_twfe_not_in_registry(self):
804-
"""TwoWayFixedEffects is not in registry and raises without custom data_generator."""
804+
def test_twfe_in_registry(self):
805+
"""TwoWayFixedEffects is in the registry."""
805806
registry = _get_registry()
806-
assert "TwoWayFixedEffects" not in registry
807-
808-
with pytest.raises(ValueError, match="not in registry"):
809-
simulate_power(
810-
TwoWayFixedEffects(),
811-
n_simulations=5,
812-
progress=False,
813-
)
807+
assert "TwoWayFixedEffects" in registry
814808

815809
def test_unknown_estimator_raises_without_data_generator(self):
816810
"""Unknown estimator without data_generator raises ValueError."""
@@ -1066,6 +1060,83 @@ def test_sdid_sample_size(self):
10661060
assert isinstance(result, SimulationSampleSizeResults)
10671061
assert result.required_n > 0
10681062

1063+
@pytest.mark.slow
1064+
def test_twfe(self):
1065+
result = simulate_power(
1066+
TwoWayFixedEffects(),
1067+
n_simulations=5,
1068+
seed=42,
1069+
progress=False,
1070+
)
1071+
self._assert_valid_result(result, "TwoWayFixedEffects")
1072+
1073+
@pytest.mark.slow
1074+
def test_twfe_mde(self):
1075+
result = simulate_mde(
1076+
TwoWayFixedEffects(),
1077+
n_simulations=5,
1078+
effect_range=(0.5, 5.0),
1079+
seed=42,
1080+
progress=False,
1081+
)
1082+
assert isinstance(result, SimulationMDEResults)
1083+
assert result.mde > 0
1084+
1085+
@pytest.mark.slow
1086+
def test_twfe_sample_size(self):
1087+
result = simulate_sample_size(
1088+
TwoWayFixedEffects(),
1089+
n_simulations=5,
1090+
n_range=(20, 100),
1091+
seed=42,
1092+
progress=False,
1093+
)
1094+
assert isinstance(result, SimulationSampleSizeResults)
1095+
assert result.required_n > 0
1096+
1097+
@pytest.mark.slow
1098+
def test_custom_fallback_unregistered_estimator(self):
1099+
"""Unregistered estimator works with custom data_generator and estimator_kwargs."""
1100+
1101+
class _UnregisteredEstimator:
1102+
"""Unregistered wrapper for testing custom fallback."""
1103+
1104+
def __init__(self):
1105+
self._inner = DifferenceInDifferences()
1106+
1107+
def fit(self, data, **kwargs):
1108+
return self._inner.fit(data, **kwargs)
1109+
1110+
result = simulate_power(
1111+
_UnregisteredEstimator(),
1112+
data_generator=generate_did_data,
1113+
estimator_kwargs=dict(outcome="outcome", treatment="treated", time="post"),
1114+
n_simulations=5,
1115+
seed=42,
1116+
progress=False,
1117+
)
1118+
assert 0 <= result.power <= 1
1119+
assert result.n_simulations > 0
1120+
1121+
def test_custom_fallback_missing_kwargs_raises(self):
1122+
"""Unregistered estimator with no estimator_kwargs fails on fit."""
1123+
1124+
class _UnregisteredEstimator:
1125+
def __init__(self):
1126+
self._inner = DifferenceInDifferences()
1127+
1128+
def fit(self, data, **kwargs):
1129+
return self._inner.fit(data, **kwargs)
1130+
1131+
with pytest.raises((ValueError, TypeError, RuntimeError)):
1132+
simulate_power(
1133+
_UnregisteredEstimator(),
1134+
data_generator=generate_did_data,
1135+
n_simulations=5,
1136+
seed=42,
1137+
progress=False,
1138+
)
1139+
10691140

10701141
# ---------------------------------------------------------------------------
10711142
# simulate_mde tests

0 commit comments

Comments
 (0)