Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion diff_diff/staggered.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,11 @@ class CallawaySantAnna(
- "universal": Always use g-1-anticipation as base period.
Both produce identical post-treatment effects. Matches R's
did::att_gt() base_period parameter.
cband : bool, default=True
Whether to compute simultaneous confidence bands (sup-t) for
event study aggregation. Requires ``n_bootstrap > 0``.
When True, results include ``cband_crit_value`` and per-event-time
``cband_conf_int`` entries controlling family-wise error rate.

Attributes
----------
Expand Down Expand Up @@ -302,6 +307,7 @@ def __init__(
seed: Optional[int] = None,
rank_deficient_action: str = "warn",
base_period: str = "varying",
cband: bool = True,
):
import warnings

Expand Down Expand Up @@ -362,6 +368,8 @@ def __init__(
self.rank_deficient_action = rank_deficient_action
self.base_period = base_period

self.cband = cband

self.is_fitted_ = False
self.results_: Optional[CallawaySantAnnaResults] = None

Expand Down Expand Up @@ -728,7 +736,8 @@ def fit(
if aggregate in ["event_study", "all"]:
event_study_effects = self._aggregate_event_study(
group_time_effects, influence_func_info,
treatment_groups, time_periods, balance_e
treatment_groups, time_periods, balance_e,
df, unit, precomputed,
)

if aggregate in ["group", "all"]:
Expand All @@ -746,6 +755,10 @@ def fit(
balance_e=balance_e,
treatment_groups=treatment_groups,
time_periods=time_periods,
df=df,
unit=unit,
precomputed=precomputed,
cband=self.cband,
)

# Update estimates with bootstrap inference
Expand Down Expand Up @@ -793,6 +806,20 @@ def fit(
se = float(group_effects[g]['se'])
group_effects[g]['t_stat'] = safe_inference(effect, se, alpha=self.alpha)[0]

# Compute simultaneous confidence band CIs if cband is available
cband_crit_value = None
if bootstrap_results is not None:
cband_crit_value = bootstrap_results.cband_crit_value

if cband_crit_value is not None and event_study_effects is not None:
for e, eff_data in event_study_effects.items():
se_val = eff_data['se']
if np.isfinite(se_val) and se_val > 0:
eff_data['cband_conf_int'] = (
eff_data['effect'] - cband_crit_value * se_val,
eff_data['effect'] + cband_crit_value * se_val,
)

# Store results
self.results_ = CallawaySantAnnaResults(
group_time_effects=group_time_effects,
Expand All @@ -812,6 +839,7 @@ def fit(
event_study_effects=event_study_effects,
group_effects=group_effects,
bootstrap_results=bootstrap_results,
cband_crit_value=cband_crit_value,
)

self.is_fitted_ = True
Expand Down Expand Up @@ -1085,6 +1113,7 @@ def get_params(self) -> Dict[str, Any]:
"seed": self.seed,
"rank_deficient_action": self.rank_deficient_action,
"base_period": self.base_period,
"cband": self.cband,
}

def set_params(self, **params) -> "CallawaySantAnna":
Expand Down
191 changes: 114 additions & 77 deletions diff_diff/staggered_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def _compute_aggregated_se(
variance = np.sum(psi_overall ** 2)
return np.sqrt(variance)

def _compute_aggregated_se_with_wif(
def _compute_combined_influence_function(
self,
gt_pairs: List[Tuple[Any, Any]],
weights: np.ndarray,
Expand All @@ -172,42 +172,48 @@ def _compute_aggregated_se_with_wif(
df: pd.DataFrame,
unit: str,
precomputed: Optional["PrecomputedData"] = None,
) -> float:
global_unit_to_idx: Optional[Dict[Any, int]] = None,
n_global_units: Optional[int] = None,
) -> Tuple[np.ndarray, Optional[List]]:
"""
Compute SE with weight influence function (wif) adjustment.

This matches R's `did` package approach for "simple" aggregation,
which accounts for uncertainty in estimating group-size weights.

The wif adjustment adds variance due to the fact that aggregation
weights w_g = n_g / N depend on estimated group sizes.

Formula (matching R's did::aggte):
agg_inf_i = Σ_k w_k × inf_i_k + wif_i × ATT_k
se = sqrt(mean(agg_inf^2) / n)

where:
- k indexes "keepers" (post-treatment (g,t) pairs)
- w_k = pg[k] / sum(pg[keepers]) where pg = n_g / n_all
- wif captures how unit i influences the weight estimation
Compute the combined (standard IF + WIF) influence function vector.

If global_unit_to_idx / n_global_units are provided, the returned vector
is zero-padded to the global unit set for bootstrap alignment.
Otherwise, the returned vector is indexed by the local unit set
(all units appearing in the (g,t) pairs).

Returns
-------
combined_if : np.ndarray
Per-unit combined influence function (standard IF + WIF).
all_units : list or None
Ordered list of units (only when using local indexing).
"""
if not influence_func_info:
return 0.0

# Build unit index mapping
all_units_set: Set[Any] = set()
for (g, t) in gt_pairs:
if (g, t) in influence_func_info:
info = influence_func_info[(g, t)]
all_units_set.update(info['treated_units'])
all_units_set.update(info['control_units'])
if n_global_units is not None:
return np.zeros(n_global_units), None
return np.zeros(0), None

# Build unit index mapping (local or global)
if global_unit_to_idx is not None and n_global_units is not None:
unit_to_idx = global_unit_to_idx
n_units = n_global_units
all_units = None # caller already has the unit list
else:
all_units_set: Set[Any] = set()
for (g, t) in gt_pairs:
if (g, t) in influence_func_info:
info = influence_func_info[(g, t)]
all_units_set.update(info['treated_units'])
all_units_set.update(info['control_units'])

if not all_units_set:
return 0.0
if not all_units_set:
return np.zeros(0), []

all_units = sorted(all_units_set)
n_units = len(all_units)
unit_to_idx = {u: i for i, u in enumerate(all_units)}
all_units = sorted(all_units_set)
n_units = len(all_units)
unit_to_idx = {u: i for i, u in enumerate(all_units)}

# Get unique groups and their information
unique_groups = sorted(set(groups_for_gt))
Expand All @@ -216,7 +222,6 @@ def _compute_aggregated_se_with_wif(

# Compute group-level probabilities matching R's formula:
# pg[g] = n_g / n_all (fraction of ALL units in group g)
# This differs from our old formula which used n_g / total_treated
group_sizes = {}
for g in unique_groups:
treated_in_g = df[df['first_treat'] == g][unit].nunique()
Expand All @@ -226,13 +231,12 @@ def _compute_aggregated_se_with_wif(
pg_by_group = np.array([group_sizes[g] / n_units for g in unique_groups])

# pg indexed by keeper (each (g,t) pair gets its group's pg)
# This matches R's: pg <- pgg[match(group, originalglist)]
pg_keepers = np.array([pg_by_group[group_to_idx[g]] for g in groups_for_gt])
sum_pg_keepers = np.sum(pg_keepers)

# Guard against zero weights (no keepers = no variance)
if sum_pg_keepers == 0:
return 0.0
return np.zeros(n_units), all_units

# Standard aggregated influence (without wif)
psi_standard = np.zeros(n_units)
Expand All @@ -254,58 +258,37 @@ def _compute_aggregated_se_with_wif(
if len(control_indices) > 0:
np.add.at(psi_standard, control_indices, w * info['control_inf'])

# Build unit-group array using precomputed data if available
# This is O(n_units) instead of O(n_units × n_obs) DataFrame lookups
# Build unit-group array: normalize iterator to (idx, uid) pairs
unit_groups_array = np.full(n_units, -1, dtype=np.float64)
idx_uid_pairs = (
[(idx, uid) for uid, idx in global_unit_to_idx.items()]
if global_unit_to_idx is not None
else list(enumerate(all_units))
)

if precomputed is not None:
# Use precomputed cohort mapping
precomputed_units = precomputed['all_units']
precomputed_cohorts = precomputed['unit_cohorts']
precomputed_unit_to_idx = precomputed['unit_to_idx']

# Build unit_groups_array for the units in this SE computation
# A value of -1 indicates never-treated or other (not in unique_groups)
unit_groups_array = np.full(n_units, -1, dtype=np.float64)
for i, uid in enumerate(all_units):
for idx, uid in idx_uid_pairs:
if uid in precomputed_unit_to_idx:
cohort = precomputed_cohorts[precomputed_unit_to_idx[uid]]
if cohort in unique_groups_set:
unit_groups_array[i] = cohort
unit_groups_array[idx] = cohort
else:
# Fallback: build from DataFrame (slow path for backward compatibility)
unit_groups_array = np.full(n_units, -1, dtype=np.float64)
for i, uid in enumerate(all_units):
for idx, uid in idx_uid_pairs:
unit_first_treat = df[df[unit] == uid]['first_treat'].iloc[0]
if unit_first_treat in unique_groups_set:
unit_groups_array[i] = unit_first_treat
unit_groups_array[idx] = unit_first_treat

# Vectorized WIF computation
# R's wif formula:
# if1[i,k] = (indicator(G_i == group_k) - pg[k]) / sum(pg[keepers])
# if2[i,k] = indicator_sum[i] * pg[k] / sum(pg[keepers])^2
# wif[i,k] = if1[i,k] - if2[i,k]
# wif_contrib[i] = sum_k(wif[i,k] * att[k])

# Build indicator matrix: (n_units, n_keepers)
# indicator_matrix[i, k] = 1.0 if unit i belongs to group for keeper k
groups_for_gt_array = np.array(groups_for_gt)
indicator_matrix = (unit_groups_array[:, np.newaxis] == groups_for_gt_array[np.newaxis, :]).astype(np.float64)

# Vectorized indicator_sum: sum over keepers
# indicator_sum[i] = sum_k(indicator(G_i == group_k) - pg[k])
indicator_sum = np.sum(indicator_matrix - pg_keepers, axis=1)

# Vectorized wif matrix computation
# Suppress RuntimeWarnings for edge cases (small samples, extreme weights)
# in division operations and matrix multiplication
with np.errstate(divide='ignore', invalid='ignore', over='ignore'):
# if1_matrix[i,k] = (indicator[i,k] - pg[k]) / sum_pg
if1_matrix = (indicator_matrix - pg_keepers) / sum_pg_keepers
# if2_matrix[i,k] = indicator_sum[i] * pg[k] / sum_pg^2
if2_matrix = np.outer(indicator_sum, pg_keepers) / (sum_pg_keepers ** 2)
wif_matrix = if1_matrix - if2_matrix

# Single matrix-vector multiply for all contributions
# wif_contrib[i] = sum_k(wif[i,k] * att[k])
wif_contrib = wif_matrix @ effects

# Check for non-finite values from edge cases
Expand All @@ -319,16 +302,64 @@ def _compute_aggregated_se_with_wif(
RuntimeWarning,
stacklevel=2
)
return np.nan # Signal invalid inference instead of biased SE
nan_result = np.full(n_units, np.nan)
return nan_result, all_units

# Scale by 1/n_units to match R's getSE formula: sqrt(mean(IF^2)/n)
# Scale by 1/n_units to match R's getSE formula
psi_wif = wif_contrib / n_units

# Combine standard and wif terms
psi_total = psi_standard + psi_wif

# Compute variance and SE
# R's formula: sqrt(mean(IF^2) / n) = sqrt(sum(IF^2) / n^2)
return psi_total, all_units

def _compute_aggregated_se_with_wif(
    self,
    gt_pairs: List[Tuple[Any, Any]],
    weights: np.ndarray,
    effects: np.ndarray,
    groups_for_gt: np.ndarray,
    influence_func_info: Dict,
    df: pd.DataFrame,
    unit: str,
    precomputed: Optional["PrecomputedData"] = None,
) -> float:
    """
    Compute a standard error that includes the weight influence function (WIF).

    Mirrors R's `did` package aggregation, which accounts for the sampling
    uncertainty in the estimated group-size weights:

        agg_inf_i = Σ_k w_k × inf_i_k + wif_i × ATT_k
        se = sqrt(mean(agg_inf^2) / n)

    Returns 0.0 when there is no influence information, and NaN when the
    combined influence function contains non-finite entries (invalid
    inference rather than a biased SE).
    """
    # Resolve the global unit universe up front so pg = n_g / N_total is
    # scaled by ALL units. Without this, the helper's local path derives
    # the unit set from only the selected (g,t) pairs, which inflates pg
    # at extreme event times where only early-adopter cohorts have data.
    unit_index_map = None
    total_units = None
    if precomputed is not None:
        unit_index_map = precomputed['unit_to_idx']
        total_units = len(precomputed['all_units'])
    elif df is not None and unit is not None:
        total_units = df[unit].nunique()

    combined_if, _ = self._compute_combined_influence_function(
        gt_pairs, weights, effects, groups_for_gt,
        influence_func_info, df, unit, precomputed,
        global_unit_to_idx=unit_index_map,
        n_global_units=total_units,
    )

    # Degenerate case: no units contributed any influence.
    if combined_if.size == 0:
        return 0.0

    # Non-finite values propagate from an ill-conditioned WIF computation;
    # report NaN so downstream inference is flagged as invalid.
    if not np.isfinite(combined_if).all():
        return np.nan

    # R's getSE formula: sqrt(mean(IF^2) / n) = sqrt(sum(IF^2) / n^2);
    # the helper already folds the 1/n scaling into combined_if.
    return np.sqrt(np.sum(combined_if ** 2))

Expand All @@ -339,14 +370,18 @@ def _aggregate_event_study(
groups: List[Any],
time_periods: List[Any],
balance_e: Optional[int] = None,
df: Optional[pd.DataFrame] = None,
unit: Optional[str] = None,
precomputed: Optional["PrecomputedData"] = None,
) -> Dict[int, Dict[str, Any]]:
"""
Aggregate effects by relative time (event study).

Computes average effect at each event time e = t - g.

Standard errors use influence function aggregation to account for
covariances across (g,t) pairs.
Standard errors include the weight influence function (WIF)
adjustment that accounts for uncertainty in group-size weights,
matching R's did::aggte(..., type="dynamic").
"""
# Organize effects by relative time, keeping track of (g,t) pairs
effects_by_e: Dict[int, List[Tuple[Tuple[Any, Any], float, int]]] = {}
Expand Down Expand Up @@ -396,9 +431,11 @@ def _aggregate_event_study(

agg_effect = np.sum(weights * effs)

# Compute SE using influence function aggregation
agg_se = self._compute_aggregated_se(
gt_pairs, weights, influence_func_info
# Compute SE with WIF adjustment (matching R's did::aggte)
groups_for_gt = np.array([g for (g, t) in gt_pairs])
agg_se = self._compute_aggregated_se_with_wif(
gt_pairs, weights, effs, groups_for_gt,
influence_func_info, df, unit, precomputed
)

t_stat, p_val, ci = safe_inference(agg_effect, agg_se, alpha=self.alpha)
Expand Down
Loading
Loading