Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,6 @@ Cargo.lock

# Maturin build artifacts
target/

# Claude Code - local settings (user-specific permissions)
.claude/settings.local.json
4 changes: 4 additions & 0 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ Enhancements for `honest_did.py`:
## CallawaySantAnna Bootstrap Improvements

- [ ] Consider aligning p-value computation with R `did` package (symmetric percentile method)
- [ ] Investigate RuntimeWarnings in influence function aggregation (`staggered.py:1722`, `staggered.py:1999-2018`)
- Warnings: "divide by zero", "overflow", "invalid value" in matmul operations
- Occurs during bootstrap SE computation with small sample sizes or edge cases
- Does not affect correctness (results are still valid), but should be suppressed or handled gracefully

---

Expand Down
115 changes: 104 additions & 11 deletions diff_diff/pretrends.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,63 @@ def to_dataframe(self) -> pd.DataFrame:
"""Convert results to DataFrame."""
return pd.DataFrame([self.to_dict()])

def power_at(self, M: float) -> float:
    """
    Compute power to detect a specific violation magnitude.

    The power is evaluated from the values already stored on this results
    object (``n_pre_periods``, ``violation_type``, ``vcov`` and
    ``critical_value``), so different M values can be explored without
    re-fitting the model.

    Parameters
    ----------
    M : float
        Violation magnitude to evaluate. Only ``M**2`` enters the
        computation, so the sign of M does not matter.

    Returns
    -------
    float
        Power to detect violation of magnitude M.

    Warns
    -----
    UserWarning
        If ``violation_type`` is not one of "linear", "constant" or
        "last_period". The violation weights cannot be reconstructed in
        that case and equal weights are used as a fallback.
    """
    import warnings

    from scipy import stats

    n_pre = self.n_pre_periods

    # Reconstruct violation weights based on violation type.
    # Intended to mirror PreTrendsPower._get_violation_weights().
    if self.violation_type == "linear":
        # Linear trend: weights decrease toward treatment,
        # [n-1, n-2, ..., 1, 0] for n pre-periods.
        weights = np.arange(n_pre - 1, -1, -1, dtype=float)
    elif self.violation_type == "constant":
        weights = np.ones(n_pre)
    elif self.violation_type == "last_period":
        weights = np.zeros(n_pre)
        weights[-1] = 1.0
    else:
        # The weights for this violation type are not available here, so
        # the original best-effort fallback (equal weights) is kept, but
        # the caller is told that the pattern being evaluated differs.
        warnings.warn(
            f"Cannot reconstruct violation weights for violation_type="
            f"{self.violation_type!r}; using equal weights as a fallback.",
            UserWarning,
            stacklevel=2,
        )
        weights = np.ones(n_pre)

    # Normalize weights to unit L2 norm (skipped for an all-zero vector).
    norm = np.linalg.norm(weights)
    if norm > 0:
        weights = weights / norm

    # Invert the covariance matrix; fall back to the pseudo-inverse when
    # the matrix is singular.
    try:
        vcov_inv = np.linalg.inv(self.vcov)
    except np.linalg.LinAlgError:
        vcov_inv = np.linalg.pinv(self.vcov)

    # delta = M * weights
    # nc = delta' * V^{-1} * delta
    noncentrality = M**2 * (weights @ vcov_inv @ weights)

    # Upper-tail probability of the non-central chi-squared distribution.
    # The survival function is used instead of 1 - cdf to avoid loss of
    # precision when the power is very small.
    power = stats.ncx2.sf(self.critical_value, df=n_pre, nc=noncentrality)

    return float(power)


@dataclass
class PreTrendsPowerCurve:
Expand Down Expand Up @@ -471,10 +528,18 @@ def _get_violation_weights(self, n_pre: int) -> np.ndarray:
def _extract_pre_period_params(
self,
results: Union[MultiPeriodDiDResults, Any],
pre_periods: Optional[List[int]] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int]:
"""
Extract pre-period parameters from results.

Parameters
----------
results : MultiPeriodDiDResults or similar
Results object from event study estimation.
pre_periods : list of int, optional
Explicit list of pre-treatment periods. If None, uses results.pre_periods.

Returns
-------
effects : np.ndarray
Expand All @@ -487,13 +552,18 @@ def _extract_pre_period_params(
Number of pre-periods.
"""
if isinstance(results, MultiPeriodDiDResults):
# Get pre-period information
all_pre_periods = results.pre_periods
# Get pre-period information - use explicit pre_periods if provided
if pre_periods is not None:
all_pre_periods = list(pre_periods)
else:
all_pre_periods = results.pre_periods

if len(all_pre_periods) == 0:
raise ValueError(
"No pre-treatment periods found in results. "
"Pre-trends power analysis requires pre-period coefficients."
"Pre-trends power analysis requires pre-period coefficients. "
"If you estimated all periods as post_periods, use the pre_periods "
"parameter to specify which are actually pre-treatment."
)

# Only include periods with actual estimated coefficients
Expand Down Expand Up @@ -775,6 +845,7 @@ def fit(
self,
results: Union[MultiPeriodDiDResults, Any],
M: Optional[float] = None,
pre_periods: Optional[List[int]] = None,
) -> PreTrendsPowerResults:
"""
Compute pre-trends power analysis.
Expand All @@ -786,14 +857,19 @@ def fit(
M : float, optional
Specific violation magnitude to evaluate. If None, evaluates at
a default magnitude based on the data.
pre_periods : list of int, optional
Explicit list of pre-treatment periods to use for power analysis.
If None, attempts to infer from results.pre_periods. Use this when
you've estimated an event study with all periods in post_periods
and need to specify which are actually pre-treatment.

Returns
-------
PreTrendsPowerResults
Power analysis results including power and MDV.
"""
# Extract pre-period parameters
effects, ses, vcov, n_pre = self._extract_pre_period_params(results)
effects, ses, vcov, n_pre = self._extract_pre_period_params(results, pre_periods)

# Get violation weights
weights = self._get_violation_weights(n_pre)
Expand Down Expand Up @@ -831,6 +907,7 @@ def power_at(
self,
results: Union[MultiPeriodDiDResults, Any],
M: float,
pre_periods: Optional[List[int]] = None,
) -> float:
"""
Compute power to detect a specific violation magnitude.
Expand All @@ -841,20 +918,23 @@ def power_at(
Event study results.
M : float
Violation magnitude.
pre_periods : list of int, optional
Explicit list of pre-treatment periods. See fit() for details.

Returns
-------
float
Power to detect violation of magnitude M.
"""
result = self.fit(results, M=M)
result = self.fit(results, M=M, pre_periods=pre_periods)
return result.power

def power_curve(
self,
results: Union[MultiPeriodDiDResults, Any],
M_grid: Optional[List[float]] = None,
n_points: int = 50,
pre_periods: Optional[List[int]] = None,
) -> PreTrendsPowerCurve:
"""
Compute power across a range of violation magnitudes.
Expand All @@ -868,14 +948,16 @@ def power_curve(
automatic grid from 0 to 2.5 * MDV.
n_points : int, default=50
Number of points in automatic grid.
pre_periods : list of int, optional
Explicit list of pre-treatment periods. See fit() for details.

Returns
-------
PreTrendsPowerCurve
Power curve data with plot method.
"""
# Extract parameters
effects, ses, vcov, n_pre = self._extract_pre_period_params(results)
_, ses, vcov, n_pre = self._extract_pre_period_params(results, pre_periods)
weights = self._get_violation_weights(n_pre)

# Compute MDV
Expand Down Expand Up @@ -906,6 +988,7 @@ def power_curve(
def sensitivity_to_honest_did(
self,
results: Union[MultiPeriodDiDResults, Any],
pre_periods: Optional[List[int]] = None,
) -> Dict[str, Any]:
"""
Compare pre-trends power analysis with HonestDiD sensitivity.
Expand All @@ -917,6 +1000,8 @@ def sensitivity_to_honest_did(
----------
results : results object
Event study results.
pre_periods : list of int, optional
Explicit list of pre-treatment periods. See fit() for details.

Returns
-------
Expand All @@ -926,7 +1011,7 @@ def sensitivity_to_honest_did(
- honest_M_at_mdv: Corresponding M value for HonestDiD
- interpretation: Text explaining the relationship
"""
pt_results = self.fit(results)
pt_results = self.fit(results, pre_periods=pre_periods)
mdv = pt_results.mdv

# The MDV represents the size of violation the test could detect
Expand Down Expand Up @@ -993,6 +1078,7 @@ def compute_pretrends_power(
alpha: float = 0.05,
target_power: float = 0.80,
violation_type: str = "linear",
pre_periods: Optional[List[int]] = None,
) -> PreTrendsPowerResults:
"""
Convenience function for pre-trends power analysis.
Expand All @@ -1009,6 +1095,9 @@ def compute_pretrends_power(
Target power for MDV calculation.
violation_type : str, default='linear'
Type of violation pattern.
pre_periods : list of int, optional
Explicit list of pre-treatment periods. If None, attempts to infer
from results. Use when you've estimated all periods as post_periods.

Returns
-------
Expand All @@ -1021,7 +1110,7 @@ def compute_pretrends_power(
>>> from diff_diff.pretrends import compute_pretrends_power
>>>
>>> results = MultiPeriodDiD().fit(data, ...)
>>> power_results = compute_pretrends_power(results)
>>> power_results = compute_pretrends_power(results, pre_periods=[0, 1, 2, 3])
>>> print(f"MDV: {power_results.mdv:.3f}")
>>> print(f"Power: {power_results.power:.1%}")
"""
Expand All @@ -1030,14 +1119,15 @@ def compute_pretrends_power(
power=target_power,
violation_type=violation_type,
)
return pt.fit(results, M=M)
return pt.fit(results, M=M, pre_periods=pre_periods)


def compute_mdv(
results: Union[MultiPeriodDiDResults, Any],
alpha: float = 0.05,
target_power: float = 0.80,
violation_type: str = "linear",
pre_periods: Optional[List[int]] = None,
) -> float:
"""
Compute minimum detectable violation.
Expand All @@ -1049,9 +1139,12 @@ def compute_mdv(
alpha : float, default=0.05
Significance level.
target_power : float, default=0.80
Target power.
Target power for MDV calculation.
violation_type : str, default='linear'
Type of violation pattern.
pre_periods : list of int, optional
Explicit list of pre-treatment periods. If None, attempts to infer
from results. Use when you've estimated all periods as post_periods.

Returns
-------
Expand All @@ -1063,5 +1156,5 @@ def compute_mdv(
power=target_power,
violation_type=violation_type,
)
result = pt.fit(results)
result = pt.fit(results, pre_periods=pre_periods)
return result.mdv
4 changes: 4 additions & 0 deletions diff_diff/staggered.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,6 +1053,10 @@ def fit(
df[time] = pd.to_numeric(df[time])
df[first_treat] = pd.to_numeric(df[first_treat])

# Standardize the first_treat column name for internal use
# This avoids hardcoding column names in internal methods
df['first_treat'] = df[first_treat]

# Identify groups and time periods
time_periods = sorted(df[time].unique())
treatment_groups = sorted([g for g in df[first_treat].unique() if g > 0])
Expand Down
Loading