Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 62 additions & 63 deletions diff_diff/trop.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@
from diff_diff.utils import compute_confidence_interval, compute_p_value


# Sentinel value for "disabled" mode in LOOCV parameter search
# Following paper's footnote 2: λ=∞ disables the corresponding component
# Sentinel value for "disabled" λ_nn in LOOCV parameter search.
# Per paper's footnote 2: λ_nn=∞ disables the factor model (L=0).
# For λ_time and λ_unit, 0.0 means disabled (uniform weights) per Eq. 3:
# exp(-0 × dist) = 1 for all distances.
_LAMBDA_INF: float = float('inf')


Expand Down Expand Up @@ -116,15 +118,14 @@ class TROPResults:
treatment_effects : dict
Individual treatment effects for each treated (unit, time) pair.
lambda_time : float
Selected time weight decay parameter from grid. Note: infinity values
are converted internally (∞ → 0.0 for uniform weights) for computation.
Selected time weight decay parameter from grid. 0.0 = uniform time
weights (disabled) per Eq. 3.
lambda_unit : float
Selected unit weight decay parameter from grid. Note: infinity values
are converted internally (∞ → 0.0 for uniform weights) for computation.
Selected unit weight decay parameter from grid. 0.0 = uniform unit
weights (disabled) per Eq. 3.
lambda_nn : float
Selected nuclear norm regularization parameter from grid. Note: infinity
values are converted internally (∞ → 1e10, factor model disabled) for
computation.
Selected nuclear norm regularization parameter from grid. inf = factor
model disabled (L=0); converted to 1e10 internally for computation.
factor_matrix : np.ndarray
Estimated low-rank factor matrix L (n_periods x n_units).
effective_rank : float
Expand Down Expand Up @@ -382,11 +383,14 @@ class TROP:
penalty is finite.

lambda_time_grid : list, optional
Grid of time weight decay parameters. Default: [0, 0.1, 0.5, 1, 2, 5].
Grid of time weight decay parameters. 0.0 = uniform weights (disabled).
Must not contain inf. Default: [0, 0.1, 0.5, 1, 2, 5].
lambda_unit_grid : list, optional
Grid of unit weight decay parameters. Default: [0, 0.1, 0.5, 1, 2, 5].
Grid of unit weight decay parameters. 0.0 = uniform weights (disabled).
Must not contain inf. Default: [0, 0.1, 0.5, 1, 2, 5].
lambda_nn_grid : list, optional
Grid of nuclear norm regularization parameters. Default: [0, 0.01, 0.1, 1].
Grid of nuclear norm regularization parameters. inf = factor model
disabled (L=0). Default: [0, 0.01, 0.1, 1].
max_iter : int, default=100
Maximum iterations for nuclear norm optimization.
tol : float, default=1e-6
Expand Down Expand Up @@ -491,6 +495,21 @@ def __init__(
f"got '{variance_method}'"
)

# Validate that time/unit grids do not contain inf.
# Per Athey et al. (2025) Eq. 3, λ_time=0 and λ_unit=0 give uniform
# weights (exp(-0 × dist) = 1). Using inf is a misunderstanding of
# the paper's convention. Only λ_nn=∞ is valid (disables factor model).
for grid_name, grid_vals in [
("lambda_time_grid", self.lambda_time_grid),
("lambda_unit_grid", self.lambda_unit_grid),
]:
if any(np.isinf(v) for v in grid_vals):
raise ValueError(
f"{grid_name} must not contain inf. Use 0.0 for uniform "
f"weights (disabled) per Athey et al. (2025) Eq. 3: "
f"exp(-0 × dist) = 1 for all distances."
)

# Internal state
self.results_: Optional[TROPResults] = None
self.is_fitted_: bool = False
Expand Down Expand Up @@ -708,10 +727,11 @@ def _univariate_loocv_search(

Following paper's footnote 2, this performs a univariate grid search
for one tuning parameter while holding others fixed. The fixed_params
can include _LAMBDA_INF values to disable specific components:
use 0.0 for disabled time/unit weights and _LAMBDA_INF for disabled
factor model:
- lambda_nn = inf: Skip nuclear norm regularization (L=0)
- lambda_time = inf: Uniform time weights (treated as 0)
- lambda_unit = inf: Uniform unit weights (treated as 0)
- lambda_time = 0.0: Uniform time weights (exp(-0×dist)=1)
- lambda_unit = 0.0: Uniform unit weights (exp(-0×dist)=1)

Parameters
----------
Expand All @@ -732,7 +752,7 @@ def _univariate_loocv_search(
grid : List[float]
Grid of values to search over.
fixed_params : Dict[str, float]
Fixed values for other parameters. May include _LAMBDA_INF.
Fixed values for other parameters. May include _LAMBDA_INF for lambda_nn.

Returns
-------
Expand All @@ -745,22 +765,14 @@ def _univariate_loocv_search(
for value in grid:
params = {**fixed_params, param_name: value}

# Convert inf values to 0 for computation (inf means "disabled" = uniform weights)
lambda_time = params.get('lambda_time', 0.0)
lambda_unit = params.get('lambda_unit', 0.0)
lambda_nn = params.get('lambda_nn', 0.0)

# Handle infinity as "disabled" mode
# Per paper Equations 2-3:
# - λ_time/λ_unit=∞ → exp(-∞×dist)→0 for dist>0, uniform weights → use 0.0
# - λ_nn=∞ → infinite penalty → L≈0 (factor model disabled) → use 1e10
# Note: λ_nn=0 means NO regularization (full-rank L), opposite of "disabled"
if np.isinf(lambda_time):
lambda_time = 0.0 # Uniform time weights
if np.isinf(lambda_unit):
lambda_unit = 0.0 # Uniform unit weights
# Convert λ_nn=∞ → large finite value (factor model disabled, L≈0)
# λ_time and λ_unit use 0.0 for uniform weights per Eq. 3 (no inf conversion needed)
if np.isinf(lambda_nn):
lambda_nn = 1e10 # Very large → L≈0 (factor model disabled)
lambda_nn = 1e10

try:
score = self._loocv_score_obs_specific(
Expand Down Expand Up @@ -1423,9 +1435,9 @@ def _fit_joint(
for lambda_time_val in self.lambda_time_grid:
for lambda_unit_val in self.lambda_unit_grid:
for lambda_nn_val in self.lambda_nn_grid:
# Convert infinity values
lt = 0.0 if np.isinf(lambda_time_val) else lambda_time_val
lu = 0.0 if np.isinf(lambda_unit_val) else lambda_unit_val
# Convert λ_nn=∞ → large finite value (factor model disabled)
lt = lambda_time_val
lu = lambda_unit_val
ln = 1e10 if np.isinf(lambda_nn_val) else lambda_nn_val

try:
Expand All @@ -1451,13 +1463,10 @@ def _fit_joint(

# Final estimation with best parameters
lambda_time, lambda_unit, lambda_nn = best_lambda
original_lambda_time, original_lambda_unit, original_lambda_nn = best_lambda
original_lambda_nn = lambda_nn

# Convert infinity values for computation
if np.isinf(lambda_time):
lambda_time = 0.0
if np.isinf(lambda_unit):
lambda_unit = 0.0
# Convert λ_nn=∞ → large finite value (factor model disabled, L≈0)
# λ_time and λ_unit use 0.0 for uniform weights directly (no conversion needed)
if np.isinf(lambda_nn):
lambda_nn = 1e10

Expand Down Expand Up @@ -1535,8 +1544,8 @@ def _fit_joint(
unit_effects=unit_effects_dict,
time_effects=time_effects_dict,
treatment_effects=treatment_effects,
lambda_time=original_lambda_time,
lambda_unit=original_lambda_unit,
lambda_time=lambda_time,
lambda_unit=lambda_unit,
lambda_nn=original_lambda_nn,
factor_matrix=L,
effective_rank=effective_rank,
Expand Down Expand Up @@ -1866,9 +1875,9 @@ def fit(
TROPResults
Object containing the ATT estimate, standard error,
factor estimates, and tuning parameters. The lambda_*
attributes show the selected grid values. Infinity values
(∞) are converted internally: λ_time/λ_unit=∞ → 0.0 (uniform
weights), λ_nn=∞ → 1e10 (factor model disabled).
attributes show the selected grid values. For λ_time and
λ_unit, 0.0 means uniform weights; inf is not accepted.
For λ_nn, ∞ is converted to 1e10 (factor model disabled).
"""
# Validate inputs
required_cols = [outcome, treatment, unit, time]
Expand Down Expand Up @@ -2053,11 +2062,11 @@ def fit(
{'lambda_unit': 0.0, 'lambda_nn': _LAMBDA_INF}
)

# λ_nn search: fix λ_time= (uniform time weights), λ_unit=0
# λ_nn search: fix λ_time=0 (uniform time weights), λ_unit=0
lambda_nn_init, _ = self._univariate_loocv_search(
Y, D, control_mask, control_unit_idx, n_units, n_periods,
'lambda_nn', self.lambda_nn_grid,
{'lambda_time': _LAMBDA_INF, 'lambda_unit': 0.0}
{'lambda_time': 0.0, 'lambda_unit': 0.0}
)

# λ_unit search: fix λ_nn=∞, λ_time=0
Expand Down Expand Up @@ -2099,24 +2108,16 @@ def fit(
self._optimal_lambda = best_lambda
lambda_time, lambda_unit, lambda_nn = best_lambda

# Convert infinity values for final estimation (matching LOOCV conversion)
# This ensures final estimation uses the same effective parameters that LOOCV evaluated.
# See REGISTRY.md "λ=∞ implementation" for rationale.
#
# IMPORTANT: Store original grid values for results, use converted for computation.
# This lets users see what was selected from their grid, while ensuring consistent
# behavior between point estimation and variance estimation.
original_lambda_time, original_lambda_unit, original_lambda_nn = best_lambda

if np.isinf(lambda_time):
lambda_time = 0.0 # Uniform time weights
if np.isinf(lambda_unit):
lambda_unit = 0.0 # Uniform unit weights
# Store original λ_nn for results (only λ_nn needs original→effective conversion).
# λ_time and λ_unit use 0.0 for uniform weights directly per Eq. 3.
original_lambda_nn = lambda_nn

# Convert λ_nn=∞ → large finite value (factor model disabled, L≈0)
if np.isinf(lambda_nn):
lambda_nn = 1e10 # Very large → L≈0 (factor model disabled)
lambda_nn = 1e10

# Create effective_lambda with converted values for ALL downstream computation
# This ensures variance estimation uses the same parameters as point estimation
# effective_lambda with converted λ_nn for ALL downstream computation
# (variance estimation uses the same parameters as point estimation)
effective_lambda = (lambda_time, lambda_unit, lambda_nn)

# Step 2: Final estimation - per-observation model fitting following Algorithm 2
Expand Down Expand Up @@ -2214,10 +2215,8 @@ def fit(
unit_effects=unit_effects_dict,
time_effects=time_effects_dict,
treatment_effects=treatment_effects,
# Store ORIGINAL grid values (possibly inf) so users see what was selected.
# Internally, infinity values are converted for computation (see effective_lambda).
lambda_time=original_lambda_time,
lambda_unit=original_lambda_unit,
lambda_time=lambda_time,
lambda_unit=lambda_unit,
lambda_nn=original_lambda_nn,
factor_matrix=L_hat,
effective_rank=effective_rank,
Expand Down
53 changes: 32 additions & 21 deletions docs/methodology/REGISTRY.md
Original file line number Diff line number Diff line change
Expand Up @@ -488,30 +488,40 @@ has D=1), ATT will be incorrect - document this clearly.
- No separate "post_periods" concept - D matrix is the sole input for treatment timing
- Supports general assignment patterns including staggered adoption

*Estimator equation (as implemented):*
*Estimator equation (as implemented, Section 2.2):*

Factor model:
Working model (separating unit/time FE from regularized factor component):
```
Y_it = L_it + τ D_it + ε_it
Y_it(0) = α_i + β_t + L_it + ε_it, E[ε_it | L] = 0
```
where L = UΣV' is low-rank factor structure.
where α_i are unit fixed effects, β_t are time fixed effects, and L = UΣV' is a low-rank
factor structure. The FE are estimated separately from L because L is regularized but
the fixed effects are not.

Factor estimation via nuclear norm regularization:
Optimization (Equation 2):
```
= argmin_L ||Y_control - L||_F² + λ_nn ||L||_*
(α̂, β̂, L̂) = argmin_{α,β,L} Σ_j Σ_s θ_s^{i,t} ω_j^{i,t} (1-W_js)(Y_js - α_j - β_s - L_js)² + λ_nn ||L||_*
```
Solved via soft-thresholding of singular values:
Solved via alternating minimization with soft-thresholding of singular values for L:
```
L̂ = U × soft_threshold(Σ, λ_nn) × V'
```

Unit weights:
Per-observation weights (Equation 3):
```
ω_j = exp(-λ_unit × d(j, treated)) / Σ_k exp(-λ_unit × d(k, treated))
θ_s^{i,t}(λ) = exp(-λ_time × |t - s|)

ω_j^{i,t}(λ) = exp(-λ_unit × dist^unit_{-t}(j, i))

dist^unit_{-t}(j, i) = (Σ_u 1{u≠t}(1-W_iu)(1-W_ju)(Y_iu - Y_ju)² / Σ_u 1{u≠t}(1-W_iu)(1-W_ju))^{1/2}
```
where d(j, treated) is RMSE distance to treated units in pre-period.
Note: weights are per-(i,t) observation-specific. The distance formula excludes the
target period t and uses only periods where both units are untreated (W=0).

Time weights: analogous construction for periods.
*Special cases (Section 2.2):*
- λ_nn=∞, ω_j=θ_s=1 (uniform weights) → recovers DID/TWFE
- ω_j=θ_s=1, λ_nn<∞ → recovers Matrix Completion (Athey et al. 2021)
- λ_nn=∞ with specific ω_j, θ_s → recovers SC/SDID

*LOOCV tuning parameter selection (Equation 5, Footnote 2):*
```
Expand All @@ -521,13 +531,15 @@ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
- **Two-stage procedure** (per paper's footnote 2):
- Stage 1: Univariate grid searches with extreme fixed values
- λ_time search: fix λ_unit=0, λ_nn=∞ (disabled)
- λ_nn search: fix λ_time= (uniform time weights), λ_unit=0
- λ_nn search: fix λ_time=0 (uniform time weights), λ_unit=0
- λ_unit search: fix λ_nn=∞, λ_time=0
- Stage 2: Cycling (coordinate descent) until convergence
- **"Disabled" parameter semantics** (per paper Equations 2-3):
- `λ_time=∞` or `λ_unit=∞`: Converts to `0.0` internally → exp(-0×dist)=1 → uniform weights
- `λ_nn=∞`: Converts to `1e10` internally → very large penalty → L≈0 (factor model off, recovers DID/TWFE)
- **"Disabled" parameter semantics** (per paper Section 4.3, Table 5, Footnote 2):
- `λ_time=0`: Uniform time weights (disabled), because exp(-0 × dist) = 1
- `λ_unit=0`: Uniform unit weights (disabled), because exp(-0 × dist) = 1
- `λ_nn=∞`: Factor model disabled (L=0), because infinite penalty; converted to `1e10` internally
- **Note**: `λ_nn=0` means NO regularization (full-rank L), which is the OPPOSITE of "disabled"
- **Validation**: `lambda_time_grid` and `lambda_unit_grid` must not contain inf. A `ValueError` is raised if they do, guiding users to use 0.0 for uniform weights per Eq. 3.
- **Subsampling**: max_loocv_samples (default 100) for computational tractability
- This subsamples control observations, NOT parameter combinations
- Increases precision at cost of computation; increase for more precise tuning
Expand All @@ -538,21 +550,20 @@ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²
- This ensures λ selection only considers fully estimable combinations

*Standard errors:*
- Default: Block bootstrap preserving panel structure
- Alternative: Jackknife (leave-one-unit-out)
- Default: Block bootstrap preserving panel structure (Algorithm 3)
- Alternative: Jackknife (leave-one-unit-out) — **implementation addition** not described in the paper

*Edge cases:*
- Rank selection: automatic via cross-validation, information criterion, or elbow
- Zero singular values: handled by soft-thresholding
- Extreme distances: weights regularized to prevent degeneracy
- LOOCV fit failures: returns Q(λ) = ∞ on first failure (per Equation 5 requirement that Q sums over ALL D==0 cells); if all parameter combinations fail, falls back to defaults (1.0, 1.0, 0.1)
- **λ=∞ implementation**: Infinity values are converted in both LOOCV search and final estimation:
- λ_time=∞ or λ_unit=∞ → 0.0 (uniform weights via exp(-0×d)=1)
- **λ_nn=∞ implementation**: Only λ_nn uses infinity (converted to 1e10 for computation):
- λ_nn=∞ → 1e10 (large penalty → L≈0, factor model disabled)
- Conversion applied to grid values during LOOCV (including Rust backend)
- Conversion applied to selected values for point estimation
- Conversion applied to selected values for variance estimation (ensures SE matches ATT)
- **Results storage**: `TROPResults` stores *original* grid values (e.g., inf), while computations use converted values. This lets users see what was selected from their grid.
- **Results storage**: `TROPResults` stores *original* λ_nn value (inf), while computations use 1e10. λ_time and λ_unit store their selected values directly (0.0 = uniform).
- **Empty control observations**: If LOOCV control observations become empty (edge case during subsampling), returns Q(λ) = ∞ with warning. A score of 0.0 would incorrectly "win" over legitimate parameters.
- **Infinite LOOCV score handling**: If best LOOCV score is infinite, `best_lambda` is set to None, triggering defaults fallback
- Validation: requires at least 2 periods before first treatment
Expand All @@ -572,7 +583,7 @@ Q(λ) = Σ_{j,s: D_js=0} [τ̂_js^loocv(λ)]²

**Requirements checklist:**
- [x] Factor matrix estimated via soft-threshold SVD
- [x] Unit weights: `exp(-λ_unit × distance)` with normalization
- [x] Unit weights: `exp(-λ_unit × distance)` (unnormalized, matching Eq. 2)
- [x] LOOCV implemented for tuning parameter selection
- [x] LOOCV uses SUM of squared errors per Equation 5
- [x] Multiple rank selection methods: cv, ic, elbow
Expand Down
Loading