Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,12 @@ pytest tests/test_rust_backend.py -v
- **`rust/src/bootstrap.rs`** - Parallel bootstrap weight generation (Rademacher, Mammen, Webb)
- **`rust/src/linalg.rs`** - OLS solver and cluster-robust variance estimation
- **`rust/src/weights.rs`** - Synthetic control weights and simplex projection
- **`rust/src/trop.rs`** - TROP estimator acceleration:
- `compute_unit_distance_matrix()` - Parallel pairwise RMSE distance computation (4-8x speedup)
- `loocv_grid_search()` - Parallel LOOCV across tuning parameters (10-50x speedup)
- `bootstrap_trop_variance()` - Parallel bootstrap variance estimation (5-15x speedup)
- Uses ndarray-linalg with OpenBLAS (Linux/macOS) or Intel MKL (Windows)
- Provides 4-8x speedup for SyntheticDiD, minimal benefit for other estimators
- Provides 4-8x speedup for SyntheticDiD, 5-20x speedup for TROP

- **`diff_diff/results.py`** - Dataclass containers for estimation results:
- `DiDResults`, `MultiPeriodDiDResults`, `SyntheticDiDResults`, `PeriodEffect`
Expand Down
16 changes: 16 additions & 0 deletions diff_diff/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
project_simplex as _rust_project_simplex,
solve_ols as _rust_solve_ols,
compute_robust_vcov as _rust_compute_robust_vcov,
# TROP estimator acceleration
compute_unit_distance_matrix as _rust_unit_distance_matrix,
loocv_grid_search as _rust_loocv_grid_search,
bootstrap_trop_variance as _rust_bootstrap_trop_variance,
)
_rust_available = True
except ImportError:
Expand All @@ -32,6 +36,10 @@
_rust_project_simplex = None
_rust_solve_ols = None
_rust_compute_robust_vcov = None
# TROP estimator acceleration
_rust_unit_distance_matrix = None
_rust_loocv_grid_search = None
_rust_bootstrap_trop_variance = None

# Determine final backend based on environment variable and availability
if _backend_env == 'python':
Expand All @@ -42,6 +50,10 @@
_rust_project_simplex = None
_rust_solve_ols = None
_rust_compute_robust_vcov = None
# TROP estimator acceleration
_rust_unit_distance_matrix = None
_rust_loocv_grid_search = None
_rust_bootstrap_trop_variance = None
elif _backend_env == 'rust':
# Force Rust mode - fail if not available
if not _rust_available:
Expand All @@ -61,4 +73,8 @@
'_rust_project_simplex',
'_rust_solve_ols',
'_rust_compute_robust_vcov',
# TROP estimator acceleration
'_rust_unit_distance_matrix',
'_rust_loocv_grid_search',
'_rust_bootstrap_trop_variance',
]
176 changes: 143 additions & 33 deletions diff_diff/trop.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Estimators. *Working Paper*. https://arxiv.org/abs/2508.21536
"""

import logging
import warnings
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union
Expand All @@ -25,11 +26,19 @@
import pandas as pd
from scipy import stats

logger = logging.getLogger(__name__)

try:
from typing import TypedDict
except ImportError:
from typing_extensions import TypedDict

from diff_diff._backend import (
HAS_RUST_BACKEND,
_rust_unit_distance_matrix,
_rust_loocv_grid_search,
_rust_bootstrap_trop_variance,
)
from diff_diff.results import _get_significance_stars
from diff_diff.utils import compute_confidence_interval, compute_p_value

Expand Down Expand Up @@ -489,7 +498,11 @@ def _precompute_structures(
"""
# Compute pairwise unit distances (for all observation-specific weights)
# Following Equation 3 (page 7): RMSE between units over pre-treatment
unit_dist_matrix = self._compute_all_unit_distances(Y, D, n_units, n_periods)
if HAS_RUST_BACKEND and _rust_unit_distance_matrix is not None:
# Use Rust backend for parallel distance computation (4-8x speedup)
unit_dist_matrix = _rust_unit_distance_matrix(Y, D.astype(np.float64))
else:
unit_dist_matrix = self._compute_all_unit_distances(Y, D, n_units, n_periods)

# Pre-compute time distance vectors for each target period
# Time distance: |t - s| for all s and each target t
Expand Down Expand Up @@ -759,20 +772,51 @@ def fit(
Y, D, control_unit_idx, n_units, n_periods
)

for lambda_time in self.lambda_time_grid:
for lambda_unit in self.lambda_unit_grid:
for lambda_nn in self.lambda_nn_grid:
try:
score = self._loocv_score_obs_specific(
Y, D, control_mask, control_unit_idx,
lambda_time, lambda_unit, lambda_nn,
n_units, n_periods
)
if score < best_score:
best_score = score
best_lambda = (lambda_time, lambda_unit, lambda_nn)
except (np.linalg.LinAlgError, ValueError):
continue
# Use Rust backend for parallel LOOCV grid search (10-50x speedup)
if HAS_RUST_BACKEND and _rust_loocv_grid_search is not None:
try:
# Prepare inputs for Rust function
control_mask_u8 = control_mask.astype(np.uint8)
time_dist_matrix = self._precomputed["time_dist_matrix"].astype(np.int64)
unit_dist_matrix = self._precomputed["unit_dist_matrix"]
control_unit_idx_i64 = control_unit_idx.astype(np.int64)

lambda_time_arr = np.array(self.lambda_time_grid, dtype=np.float64)
lambda_unit_arr = np.array(self.lambda_unit_grid, dtype=np.float64)
lambda_nn_arr = np.array(self.lambda_nn_grid, dtype=np.float64)

best_lt, best_lu, best_ln, best_score = _rust_loocv_grid_search(
Y, D.astype(np.float64), control_mask_u8, control_unit_idx_i64,
unit_dist_matrix, time_dist_matrix,
lambda_time_arr, lambda_unit_arr, lambda_nn_arr,
self.max_loocv_samples, self.max_iter, self.tol,
self.seed if self.seed is not None else 0
)
best_lambda = (best_lt, best_lu, best_ln)
except Exception as e:
# Fall back to Python implementation on error
logger.debug(
"Rust LOOCV grid search failed, falling back to Python: %s", e
)
best_lambda = None
best_score = np.inf

# Fall back to Python implementation if Rust unavailable or failed
if best_lambda is None:
for lambda_time in self.lambda_time_grid:
for lambda_unit in self.lambda_unit_grid:
for lambda_nn in self.lambda_nn_grid:
try:
score = self._loocv_score_obs_specific(
Y, D, control_mask, control_unit_idx,
lambda_time, lambda_unit, lambda_nn,
n_units, n_periods
)
if score < best_score:
best_score = score
best_lambda = (lambda_time, lambda_unit, lambda_nn)
except (np.linalg.LinAlgError, ValueError):
continue

if best_lambda is None:
warnings.warn(
Expand Down Expand Up @@ -841,7 +885,7 @@ def fit(
if self.variance_method == "bootstrap":
se, bootstrap_dist = self._bootstrap_variance(
data, outcome, treatment, unit, time, post_periods_list,
best_lambda
best_lambda, Y=Y, D=D, control_unit_idx=control_unit_idx
)
else:
se, bootstrap_dist = self._jackknife_variance(
Expand Down Expand Up @@ -1285,41 +1329,107 @@ def _bootstrap_variance(
time: str,
post_periods: List[Any],
optimal_lambda: Tuple[float, float, float],
Y: Optional[np.ndarray] = None,
D: Optional[np.ndarray] = None,
control_unit_idx: Optional[np.ndarray] = None,
) -> Tuple[float, np.ndarray]:
"""
Compute bootstrap standard error using unit-level block bootstrap.

When the optional Rust backend is available and the matrix parameters
(Y, D, control_unit_idx) are provided, uses parallelized Rust
implementation for 5-15x speedup. Falls back to Python implementation
if Rust is unavailable or if matrix parameters are not provided.

Parameters
----------
data : pd.DataFrame
Original data.
Original data in long format with unit, time, outcome, and treatment.
outcome : str
Outcome column name.
Name of the outcome column in data.
treatment : str
Treatment column name.
Name of the treatment indicator column in data.
unit : str
Unit column name.
Name of the unit identifier column in data.
time : str
Time column name.
Name of the time period column in data.
post_periods : list
Post-treatment periods.
optimal_lambda : tuple
Optimal (lambda_time, lambda_unit, lambda_nn).
List of post-treatment time periods.
optimal_lambda : tuple of float
Optimal tuning parameters (lambda_time, lambda_unit, lambda_nn)
from cross-validation. Used for model estimation in each bootstrap.
Y : np.ndarray, optional
Outcome matrix of shape (n_periods, n_units). Required for Rust
backend acceleration. If None, falls back to Python implementation.
D : np.ndarray, optional
Treatment indicator matrix of shape (n_periods, n_units) where
D[t,i]=1 indicates unit i is treated at time t. Required for Rust
backend acceleration.
control_unit_idx : np.ndarray, optional
Array of indices for control units (never-treated). Required for
Rust backend acceleration.

Returns
-------
tuple
(se, bootstrap_estimates).
se : float
Bootstrap standard error of the ATT estimate.
bootstrap_estimates : np.ndarray
Array of ATT estimates from each bootstrap iteration. Length may
be less than n_bootstrap if some iterations failed.

Notes
-----
Uses unit-level block bootstrap where entire unit time series are
resampled with replacement. This preserves within-unit correlation
structure and is appropriate for panel data.
"""
lambda_time, lambda_unit, lambda_nn = optimal_lambda

# Try Rust backend for parallel bootstrap (5-15x speedup)
if (HAS_RUST_BACKEND and _rust_bootstrap_trop_variance is not None
and self._precomputed is not None and Y is not None
and D is not None and control_unit_idx is not None):
try:
# Prepare inputs
treated_observations = self._precomputed["treated_observations"]
treated_t = np.array([t for t, i in treated_observations], dtype=np.int64)
treated_i = np.array([i for t, i in treated_observations], dtype=np.int64)
control_mask = self._precomputed["control_mask"]

bootstrap_estimates, se = _rust_bootstrap_trop_variance(
Y, D.astype(np.float64),
control_mask.astype(np.uint8),
control_unit_idx.astype(np.int64),
treated_t, treated_i,
self._precomputed["unit_dist_matrix"],
self._precomputed["time_dist_matrix"].astype(np.int64),
lambda_time, lambda_unit, lambda_nn,
self.n_bootstrap, self.max_iter, self.tol,
self.seed if self.seed is not None else 0
)

if len(bootstrap_estimates) >= 10:
return float(se), bootstrap_estimates
# Fall through to Python if too few bootstrap samples
logger.debug(
"Rust bootstrap returned only %d samples, falling back to Python",
len(bootstrap_estimates)
)
except Exception as e:
logger.debug(
"Rust bootstrap variance failed, falling back to Python: %s", e
)

# Python implementation (fallback)
rng = np.random.default_rng(self.seed)
all_units = data[unit].unique()
n_units = len(all_units)
n_units_data = len(all_units)

bootstrap_estimates = []
bootstrap_estimates_list = []

for b in range(self.n_bootstrap):
for _ in range(self.n_bootstrap):
# Sample units with replacement
sampled_units = rng.choice(all_units, size=n_units, replace=True)
sampled_units = rng.choice(all_units, size=n_units_data, replace=True)

# Create bootstrap sample with unique unit IDs
boot_data = pd.concat([
Expand All @@ -1333,11 +1443,11 @@ def _bootstrap_variance(
boot_data, outcome, treatment, unit, time,
post_periods, optimal_lambda
)
bootstrap_estimates.append(att)
bootstrap_estimates_list.append(att)
except (ValueError, np.linalg.LinAlgError, KeyError):
continue

bootstrap_estimates = np.array(bootstrap_estimates)
bootstrap_estimates = np.array(bootstrap_estimates_list)

if len(bootstrap_estimates) < 10:
warnings.warn(
Expand All @@ -1349,7 +1459,7 @@ def _bootstrap_variance(
return 0.0, np.array([])

se = np.std(bootstrap_estimates, ddof=1)
return se, bootstrap_estimates
return float(se), bootstrap_estimates

def _jackknife_variance(
self,
Expand Down
6 changes: 6 additions & 0 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use pyo3::prelude::*;

mod bootstrap;
mod linalg;
mod trop;
mod weights;

/// A Python module implemented in Rust for diff-diff acceleration.
Expand All @@ -26,6 +27,11 @@ fn _rust_backend(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(linalg::solve_ols, m)?)?;
m.add_function(wrap_pyfunction!(linalg::compute_robust_vcov, m)?)?;

// TROP estimator acceleration
m.add_function(wrap_pyfunction!(trop::compute_unit_distance_matrix, m)?)?;
m.add_function(wrap_pyfunction!(trop::loocv_grid_search, m)?)?;
m.add_function(wrap_pyfunction!(trop::bootstrap_trop_variance, m)?)?;

// Version info
m.add("__version__", env!("CARGO_PKG_VERSION"))?;

Expand Down
Loading