Merged

19 commits
All by igerber, Jan 20, 2026:

- 0e6524e: Fix rank-deficient matrix handling in OLS solver
- 31371fa: Implement R-style rank deficiency handling instead of silent SVD trun…
- 293e17f: Address code review feedback for R-style rank deficiency handling
- 768ceea: Use dedicated simultaneous-treatment dataset for MultiPeriodDiD tutorial
- c94e8b2: Address third round of code review feedback
- afcc21f: Address fourth round of code review feedback
- 3e480a9: Address fifth round of code review feedback
- cdcb765: Address P2 and P3 issues from code review
- 8c18cac: Address P3 review feedback: skip_rank_check, TWFE warnings, tests, docs
- fdd580b: Fix MultiPeriodDiD rank-deficient vcov/df computation (P1)
- d6fc08b: Document P3 issues in TODO.md: check_finite bypass limitation
- c62c9a1: Align rank tolerance with R's lm() default (P1)
- 9db9e6c: Fix test_near_collinear_covariates for new R-style tolerance
- 4c62776: Fix average ATT inference for rank-deficient designs (P1) and docstri…
- c49edf6: R-style NA propagation for avg_att (P1) and scipy_lstsq tolerance (P1)
- f35e57c: Expose rank_deficient_action parameter at estimator level (P2)
- df5d62a: Propagate rank_deficient_action to all estimators (P2)
- e054cc3: Add REGISTRY.md docs and Rust NaN vcov fallback (P1, P2)
- 41120ae: Fix get_params() and NaN vcov fallback (P2 round 2)
22 changes: 19 additions & 3 deletions CLAUDE.md
@@ -119,13 +119,20 @@ pytest tests/test_rust_backend.py -v
- Integrated with `TwoWayFixedEffects.decompose()` method

- **`diff_diff/linalg.py`** - Unified linear algebra backend (v1.4.0+):
- `solve_ols()` - OLS solver using scipy's gelsy LAPACK driver (QR-based, faster than SVD)
- `solve_ols()` - OLS solver with R-style rank deficiency handling
- `_detect_rank_deficiency()` - Detect linearly dependent columns via pivoted QR
- `compute_robust_vcov()` - Vectorized HC1 and cluster-robust variance-covariance estimation
- `compute_r_squared()` - R-squared and adjusted R-squared computation
- `LinearRegression` - High-level OLS helper class with unified coefficient extraction and inference
- `InferenceResult` - Dataclass container for coefficient-level inference (SE, t-stat, p-value, CI)
- Single optimization point for all estimators (reduces code duplication)
- Cluster-robust SEs use pandas groupby instead of O(n × clusters) loop
- **Rank deficiency handling** (R-style):
- Detects rank-deficient matrices using pivoted QR decomposition
- `rank_deficient_action` parameter: "warn" (default), "error", or "silent"
- Dropped columns have NaN coefficients (like R's `lm()`)
  - Vcov matrix has NaN in the rows/columns of dropped coefficients
- Warnings include column names when provided
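
The detection step described above can be sketched as follows. This is illustrative only: the actual `_detect_rank_deficiency()` in `diff_diff/linalg.py` may differ in signature and tolerance handling, but the pivoted-QR rank rule is the standard one.

```python
import numpy as np
from scipy.linalg import qr

def detect_dependent_columns(X, tol=1e-7):
    """Return (rank, indices of linearly dependent columns). Illustrative sketch."""
    # Pivoted QR orders columns by decreasing importance, giving a
    # canonical choice of which columns to drop (as R's lm() does).
    _, R, pivots = qr(X, mode="economic", pivoting=True)
    diag = np.abs(np.diag(R))
    rank = int(np.sum(diag > tol * diag[0])) if diag.size else 0
    dropped = np.sort(pivots[rank:])  # columns beyond the rank are dependent
    return rank, dropped

# Third column is an exact multiple of the second -> rank 2, one column dropped
X = np.column_stack([np.ones(5), np.arange(5.0), 2 * np.arange(5.0)])
rank, dropped = detect_dependent_columns(X)
```

Because the pivoting is deterministic for a given matrix, the same column is reported as dropped on every run, which is what makes the warning messages stable.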

- **`diff_diff/_backend.py`** - Backend detection and configuration (v2.0.0):
- Detects optional Rust backend availability
@@ -240,16 +247,25 @@ diff-diff achieved significant performance improvements in v1.4.0, now **faster

All estimators use a single optimized OLS/SE implementation:

- **scipy.linalg.lstsq with 'gelsy' driver**: QR-based solving, faster than NumPy's default SVD-based solver
- **R-style rank deficiency handling**: Uses pivoted QR to detect linearly dependent columns, drops them, sets NaN for their coefficients, and emits informative warnings (following R's `lm()` approach)
- **Vectorized cluster-robust SE**: Uses pandas groupby aggregation instead of O(n × clusters) Python loop
- **Single optimization point**: Changes to `linalg.py` benefit all estimators

```python
# All estimators import from linalg.py
from diff_diff.linalg import solve_ols, compute_robust_vcov

# Example usage
# Example usage (warns on rank deficiency, sets NaN for dropped coefficients)
coefficients, residuals, vcov = solve_ols(X, y, cluster_ids=cluster_ids)

# Suppress warning or raise error:
coefficients, residuals, vcov = solve_ols(X, y, rank_deficient_action="silent") # no warning
coefficients, residuals, vcov = solve_ols(X, y, rank_deficient_action="error") # raises ValueError

# At estimator level (DifferenceInDifferences, MultiPeriodDiD):
from diff_diff import DifferenceInDifferences
did = DifferenceInDifferences(rank_deficient_action="error") # raises on collinear data
did = DifferenceInDifferences(rank_deficient_action="silent") # no warning
```

#### CallawaySantAnna Optimizations (`staggered.py`)
49 changes: 49 additions & 0 deletions TODO.md
@@ -15,6 +15,21 @@ Current limitations that may affect users:
| MultiPeriodDiD wild bootstrap not supported | `estimators.py:1068-1074` | Low | Edge case |
| `predict()` raises NotImplementedError | `estimators.py:532-554` | Low | Rarely needed |

### ~~NaN Standard Errors for Rank-Deficient Matrices~~ (RESOLVED)

**Status**: Resolved in v2.2.0 with R-style rank deficiency handling.

**Solution**: The OLS solver now detects rank-deficient design matrices using pivoted QR decomposition and handles them following R's `lm()` approach:
- Warns users about dropped columns
- Sets NaN for coefficients of linearly dependent columns
- Computes valid SEs for identified (non-dropped) coefficients only
- Expands vcov matrix with NaN for dropped rows/columns

This is controlled by the `rank_deficient_action` parameter in `solve_ols()`:
- `"warn"` (default): Emit warning, set NA for dropped coefficients
- `"error"`: Raise ValueError
- `"silent"`: No warning, but still set NA for dropped coefficients
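
The expand-with-NaN pattern described above can be sketched in plain numpy. This is not diff-diff's actual code; it just shows the shape of the behavior: solve on the identified columns, then pad the coefficient vector and vcov matrix with NaN for the dropped column.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=50)
X = np.column_stack([np.ones(50), x, x])   # third column duplicates the second
y = 1.0 + 2.0 * x + rng.normal(scale=0.1, size=50)

identified = np.array([True, True, False])  # suppose pivoted QR dropped column 2
X_r = X[:, identified]
beta_r, *_ = np.linalg.lstsq(X_r, y, rcond=None)
resid = y - X_r @ beta_r
mse = resid @ resid / (len(y) - X_r.shape[1])
vcov_r = mse * np.linalg.inv(X_r.T @ X_r)

# Expand to full size with NaN for the dropped column (R's lm() style)
beta = np.full(X.shape[1], np.nan)
beta[identified] = beta_r
vcov = np.full((X.shape[1], X.shape[1]), np.nan)
vcov[np.ix_(identified, identified)] = vcov_r
```

The identified coefficients and their SEs are valid OLS estimates on the reduced design; only the dropped column carries NaN, which is exactly what downstream consumers see.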

---

## Code Quality
@@ -143,3 +158,37 @@ Potential future optimizations:
- [ ] JIT compilation for bootstrap loops (numba)
- [ ] Sparse matrix handling for large fixed effects

### QR+SVD Redundancy in Rank Detection

**Background**: The current `solve_ols()` implementation performs both QR (for rank detection) and SVD (for solving) decompositions on rank-deficient matrices. This is technically redundant since SVD can determine rank directly.

**Current approach** (R-style, chosen for robustness):
1. QR with pivoting for rank detection (`_detect_rank_deficiency()`)
2. scipy's `lstsq` with 'gelsd' driver (SVD-based) for solving

**Why we use QR for rank detection**:
- QR with pivoting provides the canonical ordering of linearly dependent columns
- R's `lm()` uses this approach for consistent dropped-column reporting
- Ensures consistent column dropping across runs (SVD column selection can vary)

**Potential optimization** (future work):
- Skip QR when `rank_deficient_action="silent"` since we don't need column names
- Use SVD rank directly in the Rust backend (already implemented)
- Add `skip_rank_check` parameter for hot paths where matrix is known to be full-rank (implemented in v2.2.0)

**Priority**: Low - the QR overhead is minimal compared to SVD solve, and correctness is more important than micro-optimization.
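
The two-step approach above (pivoted QR to name dropped columns, SVD-based `lstsq` to solve the reduced system) can be composed like this. This is an assumed sketch, not the library's exact code:

```python
import numpy as np
from scipy.linalg import qr, lstsq

# Column 2 is an exact multiple of column 1 -> rank-deficient design
X = np.column_stack([np.ones(6), np.arange(6.0), 3 * np.arange(6.0)])
y = np.arange(6.0) + 1.0

# Step 1: pivoted QR gives the rank and a canonical dropped-column set
_, R, piv = qr(X, mode="economic", pivoting=True)
d = np.abs(np.diag(R))
rank = int(np.sum(d > 1e-7 * d[0]))
keep = np.sort(piv[:rank])

# Step 2: SVD-based solve ('gelsd' driver) on the full-rank submatrix
beta_kept, *_ = lstsq(X[:, keep], y, lapack_driver="gelsd")
beta = np.full(X.shape[1], np.nan)
beta[keep] = beta_kept
```

Here the SVD solve could in principle report the rank itself, but the pivoted QR in step 1 is what makes the choice of dropped column deterministic and nameable, which is the point of the R-style approach.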

### Incomplete `check_finite` Bypass

**Background**: The `solve_ols()` function accepts a `check_finite=False` parameter intended to skip NaN/Inf validation for performance in hot paths where data is known to be clean.

**Current limitation**: When `check_finite=False`, our explicit validation is skipped, but scipy's internal QR decomposition in `_detect_rank_deficiency()` still validates finite values. This means callers cannot fully bypass all finite checks.

**Impact**: Minimal - the scipy check is fast and only affects edge cases where users explicitly pass `check_finite=False` with non-finite data (which would be a bug in their code anyway).

**Potential fix** (future work):
- Pass `check_finite=False` through to scipy's QR call (requires scipy >= 1.9.0)
- Or skip `_detect_rank_deficiency()` entirely when `check_finite=False` and `_skip_rank_check=True`

**Priority**: Low - this is an edge case optimization that doesn't affect correctness.
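
The first proposed fix above amounts to forwarding the flag. A minimal sketch (assumed, not implemented in the library) of threading `check_finite` through to scipy's pivoted QR:

```python
import numpy as np
from scipy.linalg import qr

def detect_rank(X, tol=1e-7, check_finite=True):
    # scipy.linalg.qr has its own check_finite flag; forwarding it lets
    # callers skip the finiteness scan on data they know is clean.
    _, R, _ = qr(X, mode="economic", pivoting=True, check_finite=check_finite)
    d = np.abs(np.diag(R))
    return int(np.sum(d > tol * d[0]))

X = np.column_stack([np.ones(4), np.arange(4.0)])
rank = detect_rank(X, check_finite=False)
```

With `check_finite=False`, passing non-finite data produces undefined results rather than a clear error, so the bypass should stay opt-in.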

109 changes: 78 additions & 31 deletions diff_diff/estimators.py
@@ -64,6 +64,11 @@ class DifferenceInDifferences:
seed : int, optional
Random seed for reproducibility when using bootstrap inference.
If None (default), results will vary between runs.
rank_deficient_action : str, default "warn"
Action when design matrix is rank-deficient (linearly dependent columns):
- "warn": Issue warning and drop linearly dependent columns (default)
- "error": Raise ValueError
- "silent": Drop columns silently without warning

Attributes
----------
@@ -120,7 +125,8 @@ def __init__(
inference: str = "analytical",
n_bootstrap: int = 999,
bootstrap_weights: str = "rademacher",
seed: Optional[int] = None
seed: Optional[int] = None,
rank_deficient_action: str = "warn",
):
self.robust = robust
self.cluster = cluster
@@ -129,6 +135,7 @@ def __init__(
self.n_bootstrap = n_bootstrap
self.bootstrap_weights = bootstrap_weights
self.seed = seed
self.rank_deficient_action = rank_deficient_action

self.is_fitted_ = False
self.results_ = None
@@ -283,6 +290,7 @@ def fit(
robust=self.robust,
cluster_ids=cluster_ids if self.inference != "wild_bootstrap" else None,
alpha=self.alpha,
rank_deficient_action=self.rank_deficient_action,
).fit(X, y, df_adjustment=n_absorbed_effects)

coefficients = reg.coefficients_
@@ -596,6 +604,7 @@ def get_params(self) -> Dict[str, Any]:
"n_bootstrap": self.n_bootstrap,
"bootstrap_weights": self.bootstrap_weights,
"seed": self.seed,
"rank_deficient_action": self.rank_deficient_action,
}

def set_params(self, **params) -> "DifferenceInDifferences":
@@ -873,29 +882,43 @@ def fit(  # type: ignore[override]
var_names.append(col)

# Fit OLS using unified backend
coefficients, residuals, fitted, _ = solve_ols(
X, y, return_fitted=True, return_vcov=False
# Pass cluster_ids to solve_ols for proper vcov computation
# This handles rank-deficient matrices by returning NaN for dropped columns
cluster_ids = data[self.cluster].values if self.cluster is not None else None

# Note: Wild bootstrap for multi-period effects is complex (multiple coefficients)
# For now, we use analytical inference even if inference="wild_bootstrap"
coefficients, residuals, fitted, vcov = solve_ols(
X, y,
return_fitted=True,
return_vcov=True,
cluster_ids=cluster_ids,
column_names=var_names,
rank_deficient_action=self.rank_deficient_action,
)
r_squared = compute_r_squared(y, residuals)

# Degrees of freedom
df = len(y) - X.shape[1] - n_absorbed_effects
# Degrees of freedom using effective rank (non-NaN coefficients)
k_effective = int(np.sum(~np.isnan(coefficients)))
df = len(y) - k_effective - n_absorbed_effects

# Compute standard errors
# Note: Wild bootstrap for multi-period effects is complex (multiple coefficients)
# For now, we use analytical inference even if inference="wild_bootstrap"
if self.cluster is not None:
cluster_ids = data[self.cluster].values
vcov = compute_robust_vcov(X, residuals, cluster_ids)
elif self.robust:
vcov = compute_robust_vcov(X, residuals)
else:
# For non-robust, non-clustered case, we need homoskedastic vcov
# solve_ols returns HC1 by default, so compute homoskedastic if needed
if not self.robust and self.cluster is None:
n = len(y)
k = X.shape[1]
mse = np.sum(residuals**2) / (n - k)
mse = np.sum(residuals**2) / (n - k_effective)
# Use solve() instead of inv() for numerical stability
# solve(A, B) computes X where AX=B, so this yields (X'X)^{-1} * mse
vcov = np.linalg.solve(X.T @ X, mse * np.eye(k))
# Only compute for identified columns (non-NaN coefficients)
identified_mask = ~np.isnan(coefficients)
if np.all(identified_mask):
vcov = np.linalg.solve(X.T @ X, mse * np.eye(X.shape[1]))
else:
# For rank-deficient case, compute vcov on reduced matrix then expand
X_reduced = X[:, identified_mask]
vcov_reduced = np.linalg.solve(X_reduced.T @ X_reduced, mse * np.eye(X_reduced.shape[1]))
# Expand to full size with NaN for dropped columns
vcov = np.full((X.shape[1], X.shape[1]), np.nan)
vcov[np.ix_(identified_mask, identified_mask)] = vcov_reduced

# Extract period-specific treatment effects
period_effects = {}
Expand All @@ -922,19 +945,43 @@ def fit( # type: ignore[override]
effect_indices.append(idx)

# Compute average treatment effect
# Average ATT = mean of period-specific effects
avg_att = np.mean(effect_values)

# Standard error of average: need to account for covariance
# Var(avg) = (1/n^2) * sum of all elements in the sub-covariance matrix
n_post = len(post_periods)
sub_vcov = vcov[np.ix_(effect_indices, effect_indices)]
avg_var = np.sum(sub_vcov) / (n_post ** 2)
avg_se = np.sqrt(avg_var)

avg_t_stat = avg_att / avg_se if avg_se > 0 else 0.0
avg_p_value = compute_p_value(avg_t_stat, df=df)
avg_conf_int = compute_confidence_interval(avg_att, avg_se, self.alpha, df=df)
# R-style NA propagation: if ANY period effect is NaN, average is undefined
effect_arr = np.array(effect_values)

if np.any(np.isnan(effect_arr)):
# Some period effects are NaN (unidentified) - cannot compute valid average
# This follows R's default behavior where mean(c(1, 2, NA)) returns NA
avg_att = np.nan
avg_se = np.nan
avg_t_stat = np.nan
avg_p_value = np.nan
avg_conf_int = (np.nan, np.nan)
else:
# All effects identified - compute average normally
avg_att = float(np.mean(effect_arr))

# Standard error of average: need to account for covariance
n_post = len(post_periods)
sub_vcov = vcov[np.ix_(effect_indices, effect_indices)]
avg_var = np.sum(sub_vcov) / (n_post ** 2)

if np.isnan(avg_var) or avg_var < 0:
# Vcov has NaN (dropped columns) - propagate NaN
avg_se = np.nan
avg_t_stat = np.nan
avg_p_value = np.nan
avg_conf_int = (np.nan, np.nan)
else:
avg_se = float(np.sqrt(avg_var))
if avg_se > 0:
avg_t_stat = avg_att / avg_se
avg_p_value = compute_p_value(avg_t_stat, df=df)
avg_conf_int = compute_confidence_interval(avg_att, avg_se, self.alpha, df=df)
else:
# Zero SE (degenerate case)
avg_t_stat = np.nan
avg_p_value = np.nan
avg_conf_int = (np.nan, np.nan)

# Count observations
n_treated = int(np.sum(d))