Skip to content

Commit fcf9e28

Browse files
authored
Merge pull request #135 from igerber/multi-period-r-comparison
Add MultiPeriodDiD vs R (fixest) benchmark
2 parents 4bb60bf + 3eee33e commit fcf9e28

File tree

7 files changed

+665
-4
lines changed

7 files changed

+665
-4
lines changed

CLAUDE.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ python benchmarks/run_benchmarks.py --all
353353

354354
# Run specific estimator
355355
python benchmarks/run_benchmarks.py --estimator callaway
356+
python benchmarks/run_benchmarks.py --estimator multiperiod
356357
```
357358

358359
See `docs/benchmarks.rst` for full methodology and validation results.

METHODOLOGY_REVIEW.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,10 @@ Each estimator in diff-diff should be periodically reviewed to ensure:
134134
fixed to use interaction sub-VCV instead of full regression VCV.
135135

136136
**Outstanding Concerns:**
137-
- No R comparison benchmarks yet (unlike DifferenceInDifferences and CallawaySantAnna which
138-
have formal R benchmark tests). Consider adding `benchmarks/R/multiperiod_benchmark.R`.
137+
- ~~No R comparison benchmarks yet~~**Resolved**: R comparison benchmark added via
138+
`benchmarks/R/benchmark_multiperiod.R` using `fixest::feols(outcome ~ treated * time_f | unit)`.
139+
Results match R exactly: ATT diff < 1e-11, SE diff 0.0%, period effects correlation 1.0.
140+
Validated at small (200 units) and 1k scales.
139141
- Default SE is HC1 (not cluster-robust at unit level as fixest uses). Cluster-robust
140142
available via `cluster` parameter but not the default.
141143
- Endpoint binning for distant event times not yet implemented.
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
#!/usr/bin/env Rscript
2+
# Benchmark: MultiPeriodDiD event study (R `fixest` package)
3+
#
4+
# Usage:
5+
# Rscript benchmark_multiperiod.R --data path/to/data.csv --output path/to/results.json \
6+
# --n-pre 4 --n-post 4
7+
8+
library(fixest)
9+
library(jsonlite)
10+
library(data.table)
11+
12+
# Parse command line arguments
13+
args <- commandArgs(trailingOnly = TRUE)
14+
15+
parse_args <- function(args) {
16+
result <- list(
17+
data = NULL,
18+
output = NULL,
19+
cluster = "unit",
20+
n_pre = NULL,
21+
n_post = NULL,
22+
reference_period = NULL
23+
)
24+
25+
i <- 1
26+
while (i <= length(args)) {
27+
if (args[i] == "--data") {
28+
result$data <- args[i + 1]
29+
i <- i + 2
30+
} else if (args[i] == "--output") {
31+
result$output <- args[i + 1]
32+
i <- i + 2
33+
} else if (args[i] == "--cluster") {
34+
result$cluster <- args[i + 1]
35+
i <- i + 2
36+
} else if (args[i] == "--n-pre") {
37+
result$n_pre <- as.integer(args[i + 1])
38+
i <- i + 2
39+
} else if (args[i] == "--n-post") {
40+
result$n_post <- as.integer(args[i + 1])
41+
i <- i + 2
42+
} else if (args[i] == "--reference-period") {
43+
result$reference_period <- as.integer(args[i + 1])
44+
i <- i + 2
45+
} else {
46+
i <- i + 1
47+
}
48+
}
49+
50+
if (is.null(result$data) || is.null(result$output)) {
51+
stop("Usage: Rscript benchmark_multiperiod.R --data <path> --output <path> --n-pre <int> --n-post <int>")
52+
}
53+
if (is.null(result$n_pre) || is.null(result$n_post)) {
54+
stop("--n-pre and --n-post are required")
55+
}
56+
57+
# Default reference period: last pre-period
58+
if (is.null(result$reference_period)) {
59+
result$reference_period <- result$n_pre
60+
}
61+
62+
return(result)
63+
}
64+
65+
config <- parse_args(args)
66+
67+
# Load data
68+
message(sprintf("Loading data from: %s", config$data))
69+
data <- fread(config$data)
70+
71+
ref_period <- config$reference_period
72+
message(sprintf("Reference period: %d", ref_period))
73+
message(sprintf("n_pre: %d, n_post: %d", config$n_pre, config$n_post))
74+
75+
# Create factor for time with reference level
76+
data[, time_f := relevel(factor(time), ref = as.character(ref_period))]
77+
78+
# Run benchmark
79+
message("Running MultiPeriodDiD estimation (fixest::feols)...")
80+
start_time <- Sys.time()
81+
82+
# Regression: outcome ~ treated * time_f | unit, clustered SEs
83+
# With | unit, fixest absorbs unit fixed effects. The unit-invariant 'treated'
84+
# main effect is collinear with unit FE and is absorbed automatically.
85+
# Interaction coefficients treated:time_fK remain identified.
86+
cluster_formula <- as.formula(paste0("~", config$cluster))
87+
model <- feols(outcome ~ treated * time_f | unit, data = data, cluster = cluster_formula)
88+
89+
estimation_time <- as.numeric(difftime(Sys.time(), start_time, units = "secs"))
90+
91+
# Extract all coefficients and SEs
92+
coefs <- coef(model)
93+
ses <- se(model)
94+
vcov_mat <- vcov(model)
95+
96+
# Extract interaction coefficients (treated:time_fK for each non-reference K)
97+
interaction_mask <- grepl("^treated:time_f", names(coefs))
98+
interaction_names <- names(coefs)[interaction_mask]
99+
interaction_coefs <- coefs[interaction_mask]
100+
interaction_ses <- ses[interaction_mask]
101+
102+
message(sprintf("Found %d interaction coefficients", length(interaction_names)))
103+
104+
# Build period effects list
105+
all_periods <- sort(unique(data$time))
106+
period_effects <- list()
107+
108+
for (i in seq_along(interaction_names)) {
109+
coef_name <- interaction_names[i]
110+
# Extract period value from coefficient name "treated:time_fK"
111+
period_val <- as.integer(sub("treated:time_f", "", coef_name))
112+
event_time <- period_val - ref_period
113+
114+
period_effects[[i]] <- list(
115+
period = period_val,
116+
event_time = event_time,
117+
att = unname(interaction_coefs[i]),
118+
se = unname(interaction_ses[i])
119+
)
120+
}
121+
122+
# Compute average ATT across post-periods (covariance-aware SE)
123+
post_period_names <- c()
124+
for (coef_name in interaction_names) {
125+
period_val <- as.integer(sub("treated:time_f", "", coef_name))
126+
if (period_val > config$n_pre) {
127+
post_period_names <- c(post_period_names, coef_name)
128+
}
129+
}
130+
131+
n_post_periods <- length(post_period_names)
132+
message(sprintf("Post-period interaction coefficients: %d", n_post_periods))
133+
134+
if (n_post_periods > 0) {
135+
avg_att <- mean(coefs[post_period_names])
136+
vcov_sub <- vcov_mat[post_period_names, post_period_names, drop = FALSE]
137+
avg_se <- sqrt(sum(vcov_sub) / n_post_periods^2)
138+
# NaN guard: match registry convention (REGISTRY.md lines 179-183)
139+
if (is.finite(avg_se) && avg_se > 0) {
140+
avg_t <- avg_att / avg_se
141+
avg_pval <- 2 * pt(abs(avg_t), df = model$nobs - length(coefs), lower.tail = FALSE)
142+
avg_ci_lower <- avg_att - qt(0.975, df = model$nobs - length(coefs)) * avg_se
143+
avg_ci_upper <- avg_att + qt(0.975, df = model$nobs - length(coefs)) * avg_se
144+
} else {
145+
avg_t <- NA
146+
avg_pval <- NA
147+
avg_ci_lower <- NA
148+
avg_ci_upper <- NA
149+
}
150+
} else {
151+
avg_att <- NA
152+
avg_se <- NA
153+
avg_pval <- NA
154+
avg_ci_lower <- NA
155+
avg_ci_upper <- NA
156+
}
157+
158+
message(sprintf("Average ATT: %.6f", avg_att))
159+
message(sprintf("Average SE: %.6f", avg_se))
160+
161+
# Format output
162+
results <- list(
163+
estimator = "fixest::feols (multiperiod)",
164+
cluster = config$cluster,
165+
166+
# Average treatment effect
167+
att = avg_att,
168+
se = avg_se,
169+
pvalue = avg_pval,
170+
ci_lower = avg_ci_lower,
171+
ci_upper = avg_ci_upper,
172+
173+
# Reference period
174+
reference_period = ref_period,
175+
176+
# Period-level effects
177+
period_effects = period_effects,
178+
179+
# Timing
180+
timing = list(
181+
estimation_seconds = estimation_time,
182+
total_seconds = estimation_time
183+
),
184+
185+
# Metadata
186+
metadata = list(
187+
r_version = R.version.string,
188+
fixest_version = as.character(packageVersion("fixest")),
189+
n_units = length(unique(data$unit)),
190+
n_periods = length(unique(data$time)),
191+
n_obs = nrow(data),
192+
n_pre = config$n_pre,
193+
n_post = config$n_post
194+
)
195+
)
196+
197+
# Write output
198+
message(sprintf("Writing results to: %s", config$output))
199+
write_json(results, config$output, auto_unbox = TRUE, pretty = TRUE, digits = 10)
200+
201+
message(sprintf("Completed in %.3f seconds", estimation_time))
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Benchmark: MultiPeriodDiD event study (diff-diff MultiPeriodDiD).
4+
5+
Usage:
6+
python benchmark_multiperiod.py --data path/to/data.csv --output path/to/results.json \
7+
--n-pre 4 --n-post 4
8+
"""
9+
10+
import argparse
11+
import json
12+
import os
13+
import sys
14+
from pathlib import Path
15+
16+
# IMPORTANT: Parse --backend and set environment variable BEFORE importing diff_diff
17+
# This ensures the backend configuration is respected by all modules
18+
def _get_backend_from_args():
19+
"""Parse --backend argument without importing diff_diff."""
20+
parser = argparse.ArgumentParser(add_help=False)
21+
parser.add_argument("--backend", default="auto", choices=["auto", "python", "rust"])
22+
args, _ = parser.parse_known_args()
23+
return args.backend
24+
25+
_requested_backend = _get_backend_from_args()
26+
if _requested_backend in ("python", "rust"):
27+
os.environ["DIFF_DIFF_BACKEND"] = _requested_backend
28+
29+
# NOW import diff_diff and other dependencies (will see the env var)
30+
import pandas as pd
31+
32+
# Add parent to path for imports
33+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
34+
35+
from diff_diff import MultiPeriodDiD, HAS_RUST_BACKEND
36+
from benchmarks.python.utils import Timer
37+
38+
39+
def parse_args():
40+
parser = argparse.ArgumentParser(description="Benchmark MultiPeriodDiD estimator")
41+
parser.add_argument("--data", required=True, help="Path to input CSV data")
42+
parser.add_argument("--output", required=True, help="Path to output JSON results")
43+
parser.add_argument(
44+
"--cluster", default="unit", help="Column to cluster standard errors on"
45+
)
46+
parser.add_argument(
47+
"--n-pre", type=int, required=True, help="Number of pre-treatment periods"
48+
)
49+
parser.add_argument(
50+
"--n-post", type=int, required=True, help="Number of post-treatment periods"
51+
)
52+
parser.add_argument(
53+
"--reference-period", type=int, default=None,
54+
help="Reference period (default: last pre-period = n_pre)"
55+
)
56+
parser.add_argument(
57+
"--backend", default="auto", choices=["auto", "python", "rust"],
58+
help="Backend to use: auto (default), python (pure Python), rust (Rust backend)"
59+
)
60+
return parser.parse_args()
61+
62+
63+
def get_actual_backend() -> str:
64+
"""Return the actual backend being used based on HAS_RUST_BACKEND."""
65+
return "rust" if HAS_RUST_BACKEND else "python"
66+
67+
68+
def main():
69+
args = parse_args()
70+
71+
# Get actual backend (already configured via env var before imports)
72+
actual_backend = get_actual_backend()
73+
print(f"Using backend: {actual_backend}")
74+
75+
# Load data
76+
print(f"Loading data from: {args.data}")
77+
data = pd.read_csv(args.data)
78+
79+
# Compute post_periods and reference_period from args
80+
all_periods = sorted(data["time"].unique())
81+
n_pre = args.n_pre
82+
post_periods = [p for p in all_periods if p > n_pre]
83+
ref_period = args.reference_period if args.reference_period is not None else n_pre
84+
85+
print(f"All periods: {all_periods}")
86+
print(f"Post periods: {post_periods}")
87+
print(f"Reference period: {ref_period}")
88+
89+
# Run benchmark
90+
print("Running MultiPeriodDiD estimation...")
91+
92+
did = MultiPeriodDiD(robust=True, cluster=args.cluster)
93+
94+
with Timer() as timer:
95+
results = did.fit(
96+
data,
97+
outcome="outcome",
98+
treatment="treated",
99+
time="time",
100+
post_periods=post_periods,
101+
reference_period=ref_period,
102+
absorb=["unit"],
103+
)
104+
105+
total_time = timer.elapsed
106+
107+
# Extract period effects (excluding reference period)
108+
period_effects = []
109+
for period, pe in sorted(results.period_effects.items()):
110+
event_time = period - ref_period
111+
period_effects.append({
112+
"period": int(period),
113+
"event_time": int(event_time),
114+
"att": float(pe.effect),
115+
"se": float(pe.se),
116+
})
117+
118+
# Build output
119+
output = {
120+
"estimator": "diff_diff.MultiPeriodDiD",
121+
"backend": actual_backend,
122+
"cluster": args.cluster,
123+
# Average treatment effect (across post-periods)
124+
"att": float(results.avg_att),
125+
"se": float(results.avg_se),
126+
"pvalue": float(results.avg_p_value),
127+
"ci_lower": float(results.avg_conf_int[0]),
128+
"ci_upper": float(results.avg_conf_int[1]),
129+
# Reference period
130+
"reference_period": int(ref_period),
131+
# Period-level effects
132+
"period_effects": period_effects,
133+
# Timing
134+
"timing": {
135+
"estimation_seconds": total_time,
136+
"total_seconds": total_time,
137+
},
138+
# Metadata
139+
"metadata": {
140+
"n_units": int(data["unit"].nunique()),
141+
"n_periods": int(data["time"].nunique()),
142+
"n_obs": len(data),
143+
"n_pre": n_pre,
144+
"n_post": len(post_periods),
145+
},
146+
}
147+
148+
# Write output
149+
print(f"Writing results to: {args.output}")
150+
output_path = Path(args.output)
151+
output_path.parent.mkdir(parents=True, exist_ok=True)
152+
with open(output_path, "w") as f:
153+
json.dump(output, f, indent=2)
154+
155+
print(f"ATT: {results.avg_att:.6f}")
156+
print(f"SE: {results.avg_se:.6f}")
157+
print(f"Completed in {total_time:.3f} seconds")
158+
return output
159+
160+
161+
if __name__ == "__main__":
162+
main()

0 commit comments

Comments
 (0)