Commit 44124ca

igerber and claude committed
Guard W_max==0 division in twostep nuclear norm solver + update docstrings
- Add conditional threshold when W_max==0 to prevent ZeroDivisionError, matching Rust backend behavior (trop.rs:665)
- Update Python and Rust docstrings to reflect correct FISTA/Nesterov acceleration formulas (L_f = 2·max(W), η = 1/(2·max(W)))
- Add regression test for all-zero weights edge case

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 6d0a8be commit 44124ca

3 files changed: +31 -8 lines changed

diff_diff/trop.py

Lines changed: 7 additions & 5 deletions
@@ -1990,10 +1990,11 @@ def _weighted_nuclear_norm_solve(
         paper's Equation 2 (page 7). The full objective is:
             min_L Σ W_{ti}(R_{ti} - L_{ti})² + λ_nn||L||_*

-        This uses a proximal gradient / soft-impute approach (Mazumder et al. 2010):
-            L_{k+1} = prox_{λ||·||_*}(L_k + W ⊙ (R - L_k))
-
-        where W ⊙ denotes element-wise multiplication with normalized weights.
+        This uses proximal gradient descent (Mazumder et al. 2010) with
+        FISTA/Nesterov acceleration. Lipschitz constant L_f = 2·max(W),
+        step size η = 1/(2·max(W)), proximal threshold η·λ_nn:
+            G_k = L_k + (W/max(W)) ⊙ (R - L_k)
+            L_{k+1} = prox_{η·λ_nn·||·||_*}(G_k)

         IMPORTANT: For observations with W=0 (treated observations), we keep
         L values from the previous iteration rather than setting L = R, which
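Review note: the constants in the updated docstring can be checked by hand. The derivation below is a sketch that assumes non-negative weights and writes f for the smooth (weighted least-squares) part of the objective; the symbols f, ∇f and L' are introduced here for illustration and do not appear in the source.

```latex
\begin{aligned}
f(L) &= \sum_{t,i} W_{ti}\,(R_{ti} - L_{ti})^2,
\qquad \nabla f(L) = -2\,W \odot (R - L), \\
\|\nabla f(L) - \nabla f(L')\|_F &= 2\,\|W \odot (L - L')\|_F
\le 2\max(W)\,\|L - L'\|_F
\;\Longrightarrow\; L_f = 2\max(W), \\
G_k &= L_k - \eta\,\nabla f(L_k)
= L_k + \frac{W}{\max(W)} \odot (R - L_k),
\qquad \eta = \frac{1}{L_f} = \frac{1}{2\max(W)}, \\
L_{k+1} &= \operatorname{prox}_{\eta\,\lambda_{nn}\,\|\cdot\|_*}(G_k).
\end{aligned}
```

With η = 1/(2·max(W)) the factor 2 in the gradient cancels, which is why the gradient step uses the normalized weights W/max(W) while the proximal threshold becomes λ_nn/(2·max(W)).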
@@ -2068,7 +2069,8 @@ def _weighted_nuclear_norm_solve(

             # Proximal step: soft-threshold singular values
             L_prev = L.copy()
-            L = self._soft_threshold_svd(gradient_step, lambda_nn / (2.0 * W_max))
+            threshold = lambda_nn / (2.0 * W_max) if W_max > 0 else lambda_nn / 2.0
+            L = self._soft_threshold_svd(gradient_step, threshold)
             t_fista = t_fista_new

             # Check convergence
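Review note: for readers without the full file, the patched proximal step could sit inside a loop like the one below. This is a minimal standalone sketch, not the repository's implementation. The function names `soft_threshold_svd` and `solve_weighted_nuclear_norm`, the `max_iter` and `tol` parameters, the direct use of a residual matrix `R` (instead of `Y`, `alpha`, `beta`), and the handling of the normalized weights when all weights are zero are assumptions made for illustration. Only the `threshold` conditional and the `t_fista` naming are taken from the diff.

```python
import numpy as np


def soft_threshold_svd(M, threshold):
    """Shrink the singular values of M by `threshold` (nuclear-norm prox)."""
    U, s, Vt = np.linalg.svd(M, full_matrices=False)
    s_shrunk = np.maximum(s - threshold, 0.0)
    return (U * s_shrunk) @ Vt


def solve_weighted_nuclear_norm(R, W, lambda_nn, L_init=None, max_iter=200, tol=1e-8):
    """Hypothetical stand-in: FISTA on sum(W * (R - L)**2) + lambda_nn * ||L||_*."""
    W_max = float(np.max(W))
    # Guard from the diff: fall back to lambda_nn / 2 when every weight is zero.
    threshold = lambda_nn / (2.0 * W_max) if W_max > 0 else lambda_nn / 2.0
    # Assumption: with all-zero weights the gradient step leaves L unchanged.
    W_norm = W / W_max if W_max > 0 else np.zeros_like(W, dtype=float)

    L = np.zeros_like(R, dtype=float) if L_init is None else np.array(L_init, dtype=float)
    L_momentum = L.copy()
    t_fista = 1.0
    for _ in range(max_iter):
        # Gradient step with step size 1 / (2 * max(W)), taken at the momentum point.
        gradient_step = L_momentum + W_norm * (R - L_momentum)
        # Proximal step: soft-threshold singular values.
        L_prev = L
        L = soft_threshold_svd(gradient_step, threshold)
        # Nesterov/FISTA momentum update.
        t_fista_new = (1.0 + np.sqrt(1.0 + 4.0 * t_fista**2)) / 2.0
        L_momentum = L + ((t_fista - 1.0) / t_fista_new) * (L - L_prev)
        t_fista = t_fista_new
        # Check convergence on the largest entrywise change.
        if np.max(np.abs(L - L_prev)) < tol:
            break
    return L
```

With uniform weights the gradient step reduces to `R`, so the sketch returns the singular-value soft-thresholding of `R` at `lambda_nn / 2`. With all-zero weights it repeatedly shrinks the starting matrix and stays finite, which is the property the regression test in this commit checks.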

rust/src/trop.rs

Lines changed: 4 additions & 3 deletions
@@ -620,9 +620,10 @@ fn compute_weight_matrix(
 ///
 /// Minimizes: Σ W_{ti}(Y_{ti} - α_i - β_t - L_{ti})² + λ_nn||L||_*
 ///
-/// Paper alignment: Uses weighted proximal gradient for L update:
-///   L ← prox_{η·λ_nn·||·||_*}(L + η·(W ⊙ (R - L)))
-/// where η ≤ 1/max(W) for convergence.
+/// Paper alignment: Uses weighted proximal gradient for L update with
+/// Lipschitz constant L_f = 2·max(W), step size η = 1/(2·max(W)):
+///   G = L + (W/max(W)) ⊙ (R - L)
+///   L ← prox_{η·λ_nn·||·||_*}(G)
 ///
 /// Returns None if estimation fails due to numerical issues.
 #[allow(clippy::too_many_arguments)]

tests/test_trop.py

Lines changed: 20 additions & 0 deletions
@@ -2693,6 +2693,26 @@ def test_twostep_nonuniform_weights_objective(self):
             f"Nuclear norm not reduced: {nuclear_norm_L} >= {nuclear_norm_R}"
         )

+    def test_zero_weights_no_division_error(self):
+        """Verify solver handles all-zero weights without ZeroDivisionError."""
+        rng = np.random.default_rng(99)
+        Y = rng.normal(0, 1, (6, 4))
+        W = np.zeros((6, 4))
+        L_init = rng.normal(0, 1, (6, 4))
+
+        trop_est = TROP(method="twostep", n_bootstrap=2)
+        result = trop_est._weighted_nuclear_norm_solve(
+            Y=Y,
+            W=W,
+            L_init=L_init,
+            alpha=np.zeros(4),
+            beta=np.zeros(6),
+            lambda_nn=0.3,
+        )
+
+        assert np.isfinite(result).all(), "Result contains NaN or Inf"
+        assert result.shape == (6, 4), f"Expected (6, 4), got {result.shape}"
+

 class TestTROPJointMethod:
     """Tests for TROP method='joint'.
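Review note: the new test asserts finiteness rather than the absence of an exception, which matters because the failure mode of the old expression depends on the type of `W_max`, and the diff does not show how `W_max` is computed. The sketch below illustrates both cases. The helper names `guarded_threshold` and `unguarded_outcome` are hypothetical and exist only for this illustration.

```python
import math

import numpy as np


def guarded_threshold(lambda_nn, W_max):
    """Same conditional as the patched line: never divide by a zero W_max."""
    return lambda_nn / (2.0 * W_max) if W_max > 0 else lambda_nn / 2.0


def unguarded_outcome(lambda_nn, W_max):
    """Describe what the pre-patch expression does for a given W_max."""
    try:
        # Silence NumPy's divide-by-zero warning so only the result is inspected.
        with np.errstate(divide="ignore"):
            value = lambda_nn / (2.0 * W_max)
    except ZeroDivisionError:
        return "ZeroDivisionError"
    return "inf" if math.isinf(value) else "finite"
```

With plain Python floats, `unguarded_outcome(0.3, 0.0)` reports a `ZeroDivisionError`. With NumPy scalars, `unguarded_outcome(np.float64(0.3), np.float64(0.0))` reports `inf` instead, which would propagate into the soft-thresholding step without raising. The guarded version gives `lambda_nn / 2.0` in both cases, and the `np.isfinite` assertion in the test would catch either failure mode.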
