Commit fc7888b

Merge pull request #153 from igerber/sa-method-review

Sun-Abraham methodology review: fix 5 issues, add R benchmarks

2 parents f966182 + 9026464 commit fc7888b

File tree

9 files changed: +1012 -38 lines changed

METHODOLOGY_REVIEW.md

Lines changed: 79 additions & 5 deletions
@@ -24,7 +24,7 @@ Each estimator in diff-diff should be periodically reviewed to ensure:
 | MultiPeriodDiD | `estimators.py` | `fixest::feols()` | **Complete** | 2026-02-02 |
 | TwoWayFixedEffects | `twfe.py` | `fixest::feols()` | **Complete** | 2026-02-08 |
 | CallawaySantAnna | `staggered.py` | `did::att_gt()` | **Complete** | 2026-01-24 |
-| SunAbraham | `sun_abraham.py` | `fixest::sunab()` | Not Started | - |
+| SunAbraham | `sun_abraham.py` | `fixest::sunab()` | **Complete** | 2026-02-15 |
 | SyntheticDiD | `synthetic_did.py` | `synthdid::synthdid_estimate()` | **Complete** | 2026-02-10 |
 | TripleDifference | `triple_diff.py` | (forthcoming) | Not Started | - |
 | TROP | `trop.py` | (forthcoming) | Not Started | - |
@@ -294,14 +294,88 @@ variables appear to the left of the `|` separator.
 | Module | `sun_abraham.py` |
 | Primary Reference | Sun & Abraham (2021) |
 | R Reference | `fixest::sunab()` |
-| Status | Not Started |
-| Last Review | - |
+| Status | **Complete** |
+| Last Review | 2026-02-15 |
+
+**Verified Components:**
+- [x] Saturated TWFE regression with cohort × relative-time interactions
+- [x] Within-transformation for unit and time fixed effects
+- [x] Interaction-weighted event study effects (δ̂_e = Σ_g ŵ_{g,e} × δ̂_{g,e})
+- [x] IW weights match event-time sample shares (n_{g,e} / Σ_g n_{g,e})
+- [x] Overall ATT as weighted average of post-treatment effects
+- [x] Delta method SE for aggregated effects (Var = w' Σ w)
+- [x] Cluster-robust SEs at unit level
+- [x] Reference period normalized to zero (e=-1 excluded from design matrix)
+- [x] R comparison: ATT matches `fixest::sunab()` within machine precision (<1e-11)
+- [x] R comparison: SE matches within 0.3% (small scale) / 0.1% (1k scale)
+- [x] R comparison: Event study effects correlation = 1.000000
+- [x] R comparison: Event study max diff < 1e-11
+- [x] Bootstrap inference (pairs bootstrap)
+- [x] Rank deficiency handling (warn/error/silent)
+- [x] All REGISTRY.md edge cases tested
+
+**Test Coverage:**
+- 43 tests in `tests/test_sun_abraham.py` (36 existing + 7 methodology verification)
+- R benchmark tests via `benchmarks/run_benchmarks.py --estimator sunab`
+
+**R Comparison Results:**
+- Overall ATT matches within machine precision (diff < 1e-11 at both scales)
+- Cluster-robust SE matches within 0.3% (well within the 1% threshold)
+- Event study effects match perfectly (correlation 1.0, max diff < 1e-11)
+- Validated at small (200 units) and 1k (1000 units) scales
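The delta-method variance checked above (Var = w' Σ w) is simple enough to sketch directly. The function and argument names below are illustrative, not the estimator's actual internals; this is a stand-alone sketch of the aggregation formula:

```python
import numpy as np

def aggregate_effects(deltas, vcov, weights):
    """Delta-method aggregation: effect = w'delta, Var = w' Sigma w."""
    w = np.asarray(weights, dtype=float)
    delta = np.asarray(deltas, dtype=float)
    sigma = np.asarray(vcov, dtype=float)
    effect = float(w @ delta)          # weighted average of cohort-level effects
    var = float(w @ sigma @ w)         # delta-method variance of the aggregate
    return effect, float(np.sqrt(var))

# Two cohort effects averaged with equal weights, no cross-covariance:
effect, se = aggregate_effects(
    deltas=[1.0, 3.0],
    vcov=[[0.04, 0.0], [0.0, 0.04]],
    weights=[0.5, 0.5],
)
# effect = 2.0; se = sqrt(0.25 * 0.04 + 0.25 * 0.04) ≈ 0.1414
```

With a full covariance matrix the off-diagonal terms enter w' Σ w automatically, which is exactly what the simplified fallback path (correction 5 below) omits.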

 **Corrections Made:**
-- (None yet)
+1. **DF adjustment for absorbed FE** (`sun_abraham.py`, `_fit_saturated_regression()`):
+   Added `df_adjustment = n_units + n_times - 1` to `LinearRegression.fit()` to account
+   for the absorbed unit and time fixed effects in the degrees of freedom. Unlike TWFE
+   (which uses `-2` plus an explicit intercept column), SunAbraham's saturated regression
+   has no intercept, so all absorbed df must come from the adjustment. This affects the
+   t-distribution DoF for cohort-level p-values/CIs (slightly larger p-values, slightly
+   wider CIs) but does NOT change VCV or SE values.
+
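A small numeric sketch of how the adjustment enters the residual degrees of freedom (the panel dimensions here are made up for illustration):

```python
# Hypothetical balanced panel: 200 units observed over 10 periods,
# with 30 cohort x relative-time interaction terms in the saturated regression.
n_obs, n_units, n_times, k_interactions = 2000, 200, 10, 30

# Unit and time FE are absorbed by the within-transformation, so they do not
# appear as design-matrix columns; their df must be removed explicitly.
# One df is shared because the two sets of dummies are collinear
# (and the saturated regression has no explicit intercept).
df_adjustment = n_units + n_times - 1

df_resid = n_obs - k_interactions - df_adjustment
print(df_resid)  # 2000 - 30 - 209 = 1761
```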
+2. **NaN return for no post-treatment effects** (`sun_abraham.py`, `_compute_overall_att()`):
+   Changed the return value from `(0.0, 0.0)` to `(np.nan, np.nan)` when no post-treatment
+   effects exist. All downstream inference fields (t_stat, p_value, conf_int) correctly
+   propagate NaN via existing guards in `fit()`.
+
+3. **Deprecation warnings for unused parameters** (`sun_abraham.py`, `fit()`):
+   Added a `FutureWarning` for the `min_pre_periods` and `min_post_periods` parameters,
+   which are accepted but never used (no-op). They will be removed in a future version.
+
+4. **Removed event-time truncation at [-20, 20]** (`sun_abraham.py`):
+   Removed the hardcoded cap `max(min(...), -20)` / `min(max(...), 20)` to match
+   R's `fixest::sunab()`, which has no such limit. All available relative times are
+   now estimated.
+
+5. **Warning for variance fallback path** (`sun_abraham.py`, `_compute_overall_att()`):
+   Added a `UserWarning` for when the full weight vector cannot be constructed and a
+   simplified variance (ignoring covariances between periods) is used as a fallback.
+
+6. **IW weights use event-time sample shares** (`sun_abraham.py`, `_compute_iw_effects()`):
+   Changed the IW weights from `n_g / Σ_g n_g` (cohort sizes) to `n_{g,e} / Σ_g n_{g,e}`
+   (per-event-time observation counts) to match the REGISTRY.md formula. For balanced
+   panels these are identical; for unbalanced panels the new formula correctly reflects
+   the actual sample composition at each event time. Added an unbalanced panel test.
+
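The weight change in correction 6 can be illustrated on a toy unbalanced count table (the column names are illustrative, not the module's actual variables):

```python
import pandas as pd

# Observation counts per (cohort g, event time e) in an unbalanced panel.
counts = pd.DataFrame({
    "cohort":     [2004, 2005, 2004, 2005],
    "event_time": [0,    0,    1,    1],
    "n":          [30,   10,   30,   5],
})

# New formula: the weight on cohort g at event time e is n_{g,e} / sum_g n_{g,e}.
counts["iw_weight"] = counts["n"] / counts.groupby("event_time")["n"].transform("sum")
print(counts["iw_weight"].tolist())  # [0.75, 0.25, ~0.857, ~0.143]
```

Under the old cohort-size formula the weights would be identical at every event time; here they shift at e=1 because cohort 2005 contributes fewer observations there.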
+7. **Normalize `np.inf` never-treated encoding** (`sun_abraham.py`, `fit()`):
+   `first_treat=np.inf` (documented as valid for never-treated) was included in
+   `treatment_groups` and `_rel_time` via `> 0` checks, producing `-inf` event times.
+   Fixed by normalizing `np.inf` to `0` immediately after computing `_never_treated`.
+   The same fix was applied to `staggered.py` (`CallawaySantAnna`).
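A minimal sketch of the normalization in correction 7, assuming a pandas Series of first-treatment periods (variable names are illustrative):

```python
import numpy as np
import pandas as pd

# Never-treated units may arrive coded as first_treat=0 or first_treat=np.inf.
first_treat = pd.Series([2004.0, 2005.0, 0.0, np.inf])

# Flag never-treated first, then collapse inf to the canonical 0 encoding
# so later `> 0` cohort checks cannot produce -inf relative event times.
never_treated = (first_treat == 0) | np.isinf(first_treat)
first_treat = first_treat.mask(np.isinf(first_treat), 0.0)

print(first_treat.tolist())  # [2004.0, 2005.0, 0.0, 0.0]
```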

 **Outstanding Concerns:**
-- (None yet)
+- **Inference distribution**: Cohort-level p-values use the t-distribution (via
+  `LinearRegression.get_inference()`), while aggregated event study and overall ATT
+  p-values use the normal distribution (via `compute_p_value()`). This is asymptotically
+  equivalent and standard for delta-method-aggregated quantities. R's fixest uses the
+  t-distribution at all levels, so aggregated p-values may differ slightly in small
+  samples; this is a documented deviation.
+
+**Deviations from R's fixest::sunab():**
+1. **NaN for no post-treatment effects**: Python returns `(NaN, NaN)` for the overall
+   ATT/SE when no post-treatment effects exist. R would error.
+2. **Normal distribution for aggregated inference**: Aggregated p-values use the normal
+   distribution (asymptotically equivalent); R uses the t-distribution.

---
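The size of the normal-vs-t deviation is easy to gauge. This sketch (assuming `scipy` is available; it is not a dependency stated in this document) compares the two p-values for the same test statistic:

```python
from scipy import stats

t_stat, df = 2.0, 30  # illustrative values

p_normal = 2 * stats.norm.sf(abs(t_stat))   # used for aggregated effects here
p_t = 2 * stats.t.sf(abs(t_stat), df)       # what a t-based report would give

# p_normal ≈ 0.0455, while p_t is slightly larger (just above 0.05 at df=30);
# the gap shrinks toward zero as df grows, which is why the two are
# asymptotically equivalent.
```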

benchmarks/R/benchmark_sunab.R

Lines changed: 131 additions & 0 deletions
```r
#!/usr/bin/env Rscript
# Benchmark: Sun-Abraham interaction-weighted estimator (R `fixest::sunab()`)
#
# This uses fixest::sunab() with unit+time FE and unit-level clustering,
# matching the Python SunAbraham estimator's approach.
#
# Usage:
#   Rscript benchmark_sunab.R --data path/to/data.csv --output path/to/results.json

library(fixest)
library(jsonlite)
library(data.table)

# Parse command line arguments
args <- commandArgs(trailingOnly = TRUE)

parse_args <- function(args) {
  result <- list(
    data = NULL,
    output = NULL
  )

  i <- 1
  while (i <= length(args)) {
    if (args[i] == "--data") {
      result$data <- args[i + 1]
      i <- i + 2
    } else if (args[i] == "--output") {
      result$output <- args[i + 1]
      i <- i + 2
    } else {
      i <- i + 1
    }
  }

  if (is.null(result$data) || is.null(result$output)) {
    stop("Usage: Rscript benchmark_sunab.R --data <path> --output <path>")
  }

  return(result)
}

config <- parse_args(args)

# Load data
message(sprintf("Loading data from: %s", config$data))
data <- fread(config$data)

# Convert first_treat to double before assigning Inf (integer column can't hold Inf)
data[, first_treat := as.double(first_treat)]
# Convert never-treated coding: first_treat=0 -> Inf (R's convention for never-treated)
data[first_treat == 0, first_treat := Inf]

# Run benchmark
message("Running Sun-Abraham estimation with fixest::sunab()...")
start_time <- Sys.time()

# Sun-Abraham with unit+time FE, clustered at unit level
# sunab(cohort, period) creates the interaction-weighted estimator
model <- feols(
  outcome ~ sunab(first_treat, time) | unit + time,
  data = data,
  cluster = ~unit
)

estimation_time <- as.numeric(difftime(Sys.time(), start_time, units = "secs"))

# Extract event study effects (per-relative-period IW coefficients)
es_coefs <- coef(model)
es_ses <- se(model)

# Build event study list
event_study <- list()
coef_names <- names(es_coefs)
for (i in seq_along(es_coefs)) {
  name <- coef_names[i]
  # fixest sunab names coefficients like "time::-4" or "time::2"
  event_time <- as.numeric(gsub("^time::(-?[0-9]+)$", "\\1", name))

  event_study[[length(event_study) + 1]] <- list(
    event_time = event_time,
    att = unname(es_coefs[i]),
    se = unname(es_ses[i])
  )
}

# Aggregate to get overall ATT (weighted by observation count per cell)
# aggregate() returns a matrix with columns: Estimate, Std. Error, t value, Pr(>|t|)
agg_result <- aggregate(model, agg = "ATT")

overall_att <- agg_result[1, "Estimate"]
overall_se <- agg_result[1, "Std. Error"]
overall_pvalue <- agg_result[1, "Pr(>|t|)"]

message(sprintf("Overall ATT: %.6f (SE: %.6f)", overall_att, overall_se))

# Format output
results <- list(
  estimator = "fixest::sunab()",
  cluster = "unit",

  # Overall ATT (aggregated)
  overall_att = overall_att,
  overall_se = overall_se,
  overall_pvalue = overall_pvalue,

  # Event study effects
  event_study = event_study,

  # Timing
  timing = list(
    estimation_seconds = estimation_time,
    total_seconds = estimation_time
  ),

  # Metadata
  metadata = list(
    r_version = R.version.string,
    fixest_version = as.character(packageVersion("fixest")),
    n_units = length(unique(data$unit)),
    n_periods = length(unique(data$time)),
    n_obs = nrow(data),
    n_event_study_coefs = length(es_coefs)
  )
)

# Write output
message(sprintf("Writing results to: %s", config$output))
write_json(results, config$output, auto_unbox = TRUE, pretty = TRUE, digits = 15)

message(sprintf("Completed in %.3f seconds", estimation_time))
```
benchmark_sun_abraham.py

Lines changed: 136 additions & 0 deletions
```python
#!/usr/bin/env python3
"""
Benchmark: SunAbraham interaction-weighted estimator (diff-diff SunAbraham class).

This benchmarks the SunAbraham estimator with cluster-robust SEs,
matching R's fixest::sunab() approach.

Usage:
    python benchmark_sun_abraham.py --data path/to/data.csv --output path/to/results.json
"""

import argparse
import json
import os
import sys
from pathlib import Path

# IMPORTANT: Parse --backend and set environment variable BEFORE importing diff_diff

def _get_backend_from_args():
    """Parse --backend argument without importing diff_diff."""
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--backend", default="auto", choices=["auto", "python", "rust"])
    args, _ = parser.parse_known_args()
    return args.backend

_requested_backend = _get_backend_from_args()
if _requested_backend in ("python", "rust"):
    os.environ["DIFF_DIFF_BACKEND"] = _requested_backend

# NOW import diff_diff and other dependencies (will see the env var)
import numpy as np
import pandas as pd

# Add parent to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from diff_diff import SunAbraham, HAS_RUST_BACKEND
from benchmarks.python.utils import Timer


def parse_args():
    parser = argparse.ArgumentParser(description="Benchmark SunAbraham estimator")
    parser.add_argument("--data", required=True, help="Path to input CSV data")
    parser.add_argument("--output", required=True, help="Path to output JSON results")
    parser.add_argument(
        "--backend", default="auto", choices=["auto", "python", "rust"],
        help="Backend to use: auto (default), python (pure Python), rust (Rust backend)"
    )
    return parser.parse_args()


def get_actual_backend() -> str:
    """Return the actual backend being used based on HAS_RUST_BACKEND."""
    return "rust" if HAS_RUST_BACKEND else "python"


def main():
    args = parse_args()

    actual_backend = get_actual_backend()
    print(f"Using backend: {actual_backend}")

    # Load data
    print(f"Loading data from: {args.data}")
    data = pd.read_csv(args.data)

    # Run benchmark using SunAbraham (analytical SEs, no bootstrap)
    print("Running Sun-Abraham estimation...")

    sa = SunAbraham(control_group="never_treated", n_bootstrap=0)

    with Timer() as timer:
        results = sa.fit(
            data,
            outcome="outcome",
            unit="unit",
            time="time",
            first_treat="first_treat",
        )

    overall_att = results.overall_att
    overall_se = results.overall_se
    overall_pvalue = results.overall_p_value

    # Extract event study effects
    event_study = []
    for e in sorted(results.event_study_effects.keys()):
        eff = results.event_study_effects[e]
        event_study.append({
            "event_time": int(e),
            "att": float(eff["effect"]),
            "se": float(eff["se"]),
        })

    total_time = timer.elapsed

    # Build output
    output = {
        "estimator": "diff_diff.SunAbraham",
        "backend": actual_backend,
        "cluster": "unit",
        # Overall ATT
        "overall_att": float(overall_att),
        "overall_se": float(overall_se),
        "overall_pvalue": float(overall_pvalue),
        # Event study effects
        "event_study": event_study,
        # Timing
        "timing": {
            "estimation_seconds": total_time,
            "total_seconds": total_time,
        },
        # Metadata
        "metadata": {
            "n_units": len(data["unit"].unique()),
            "n_periods": len(data["time"].unique()),
            "n_obs": len(data),
            "n_groups": len(results.groups),
            "n_event_study_effects": len(event_study),
        },
    }

    # Write output
    print(f"Writing results to: {args.output}")
    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        json.dump(output, f, indent=2)

    print(f"Overall ATT: {overall_att:.6f} (SE: {overall_se:.6f})")
    print(f"Completed in {total_time:.3f} seconds")
    return output


if __name__ == "__main__":
    main()
```

0 commit comments
