|
44 | 44 | _staggered_fit_kwargs, |
45 | 45 | _trop_fit_kwargs, |
46 | 46 | ) |
| 47 | +from diff_diff.prep import generate_did_data |
47 | 48 |
|
48 | 49 |
|
49 | 50 | class TestPowerAnalysis: |
@@ -800,17 +801,10 @@ def test_continuous_did_not_in_registry(self): |
800 | 801 | progress=False, |
801 | 802 | ) |
802 | 803 |
|
803 | | - def test_twfe_not_in_registry(self): |
804 | | - """TwoWayFixedEffects is not in registry and raises without custom data_generator.""" |
| 804 | + def test_twfe_in_registry(self): |
| 805 | + """TwoWayFixedEffects is in the registry.""" |
805 | 806 | registry = _get_registry() |
806 | | - assert "TwoWayFixedEffects" not in registry |
807 | | - |
808 | | - with pytest.raises(ValueError, match="not in registry"): |
809 | | - simulate_power( |
810 | | - TwoWayFixedEffects(), |
811 | | - n_simulations=5, |
812 | | - progress=False, |
813 | | - ) |
| 807 | + assert "TwoWayFixedEffects" in registry |
814 | 808 |
|
815 | 809 | def test_unknown_estimator_raises_without_data_generator(self): |
816 | 810 | """Unknown estimator without data_generator raises ValueError.""" |
@@ -1066,6 +1060,83 @@ def test_sdid_sample_size(self): |
1066 | 1060 | assert isinstance(result, SimulationSampleSizeResults) |
1067 | 1061 | assert result.required_n > 0 |
1068 | 1062 |
|
| 1063 | + @pytest.mark.slow |
| 1064 | + def test_twfe(self): |
| 1065 | + result = simulate_power( |
| 1066 | + TwoWayFixedEffects(), |
| 1067 | + n_simulations=5, |
| 1068 | + seed=42, |
| 1069 | + progress=False, |
| 1070 | + ) |
| 1071 | + self._assert_valid_result(result, "TwoWayFixedEffects") |
| 1072 | + |
| 1073 | + @pytest.mark.slow |
| 1074 | + def test_twfe_mde(self): |
| 1075 | + result = simulate_mde( |
| 1076 | + TwoWayFixedEffects(), |
| 1077 | + n_simulations=5, |
| 1078 | + effect_range=(0.5, 5.0), |
| 1079 | + seed=42, |
| 1080 | + progress=False, |
| 1081 | + ) |
| 1082 | + assert isinstance(result, SimulationMDEResults) |
| 1083 | + assert result.mde > 0 |
| 1084 | + |
| 1085 | + @pytest.mark.slow |
| 1086 | + def test_twfe_sample_size(self): |
| 1087 | + result = simulate_sample_size( |
| 1088 | + TwoWayFixedEffects(), |
| 1089 | + n_simulations=5, |
| 1090 | + n_range=(20, 100), |
| 1091 | + seed=42, |
| 1092 | + progress=False, |
| 1093 | + ) |
| 1094 | + assert isinstance(result, SimulationSampleSizeResults) |
| 1095 | + assert result.required_n > 0 |
| 1096 | + |
| 1097 | + @pytest.mark.slow |
| 1098 | + def test_custom_fallback_unregistered_estimator(self): |
| 1099 | + """Unregistered estimator works with custom data_generator and estimator_kwargs.""" |
| 1100 | + |
| 1101 | + class _UnregisteredEstimator: |
| 1102 | + """Unregistered wrapper for testing custom fallback.""" |
| 1103 | + |
| 1104 | + def __init__(self): |
| 1105 | + self._inner = DifferenceInDifferences() |
| 1106 | + |
| 1107 | + def fit(self, data, **kwargs): |
| 1108 | + return self._inner.fit(data, **kwargs) |
| 1109 | + |
| 1110 | + result = simulate_power( |
| 1111 | + _UnregisteredEstimator(), |
| 1112 | + data_generator=generate_did_data, |
| 1113 | + estimator_kwargs=dict(outcome="outcome", treatment="treated", time="post"), |
| 1114 | + n_simulations=5, |
| 1115 | + seed=42, |
| 1116 | + progress=False, |
| 1117 | + ) |
| 1118 | + assert 0 <= result.power <= 1 |
| 1119 | + assert result.n_simulations > 0 |
| 1120 | + |
| 1121 | + def test_custom_fallback_missing_kwargs_raises(self): |
| 1122 | + """Unregistered estimator with no estimator_kwargs fails on fit.""" |
| 1123 | + |
| 1124 | + class _UnregisteredEstimator: |
| 1125 | + def __init__(self): |
| 1126 | + self._inner = DifferenceInDifferences() |
| 1127 | + |
| 1128 | + def fit(self, data, **kwargs): |
| 1129 | + return self._inner.fit(data, **kwargs) |
| 1130 | + |
| 1131 | + with pytest.raises((ValueError, TypeError, RuntimeError)): |
| 1132 | + simulate_power( |
| 1133 | + _UnregisteredEstimator(), |
| 1134 | + data_generator=generate_did_data, |
| 1135 | + n_simulations=5, |
| 1136 | + seed=42, |
| 1137 | + progress=False, |
| 1138 | + ) |
| 1139 | + |
1069 | 1140 |
|
1070 | 1141 | # --------------------------------------------------------------------------- |
1071 | 1142 | # simulate_mde tests |
|
0 commit comments