diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index f1bcda19..e5a13efa 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -15,7 +15,7 @@ jobs: - name: Install dependencies run: | - dnf install -y openssl-devel perl-IPC-Cmd + dnf install -y openssl-devel perl-IPC-Cmd openblas-devel curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --no-modify-path /opt/python/cp312-cp312/bin/pip install maturin @@ -25,7 +25,7 @@ jobs: for pyver in 39 310 311 312 313; do pybin="/opt/python/cp${pyver}-cp${pyver}/bin/python" if [ -f "$pybin" ]; then - /opt/python/cp312-cp312/bin/maturin build --release --out dist -i "$pybin" --features extension-module + /opt/python/cp312-cp312/bin/maturin build --release --out dist -i "$pybin" --features extension-module,openblas fi done @@ -58,7 +58,7 @@ jobs: run: pip install maturin - name: Build wheel - run: maturin build --release --out dist --features extension-module + run: maturin build --release --out dist --features extension-module,accelerate - name: Upload wheels uses: actions/upload-artifact@v4 diff --git a/.github/workflows/rust-test.yml b/.github/workflows/rust-test.yml index 293f346d..5f07d76e 100644 --- a/.github/workflows/rust-test.yml +++ b/.github/workflows/rust-test.yml @@ -39,9 +39,21 @@ jobs: - name: Install Rust toolchain uses: dtolnay/rust-toolchain@stable + - name: Install OpenBLAS (Linux) + if: runner.os == 'Linux' + run: sudo apt-get update && sudo apt-get install -y libopenblas-dev + - name: Run Rust tests working-directory: rust - run: cargo test --verbose + run: | + if [ "${{ runner.os }}" == "macOS" ]; then + cargo test --verbose --features accelerate + elif [ "${{ runner.os }}" == "Linux" ]; then + cargo test --verbose --features openblas + else + cargo test --verbose + fi + shell: bash # Build and test with Python on multiple platforms python-tests: @@ -68,10 +80,20 @@ jobs: # Keep in sync with pyproject.toml [project.dependencies] and 
[project.optional-dependencies.dev] run: pip install pytest pytest-xdist numpy pandas scipy + - name: Install OpenBLAS (Linux) + if: runner.os == 'Linux' + run: sudo apt-get update && sudo apt-get install -y libopenblas-dev + - name: Build and install with maturin run: | pip install maturin - maturin build --release -o dist + if [ "${{ runner.os }}" == "macOS" ]; then + maturin build --release -o dist --features extension-module,accelerate + elif [ "${{ runner.os }}" == "Linux" ]; then + maturin build --release -o dist --features extension-module,openblas + else + maturin build --release -o dist + fi echo "=== Built wheels ===" ls -la dist/ || dir dist shell: bash diff --git a/CHANGELOG.md b/CHANGELOG.md index 72a26832..f3ba89b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,23 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Added +- **Conditional BLAS linking for Rust backend** — Apple Accelerate on macOS, OpenBLAS on Linux. + Pre-built wheels now use platform-optimized BLAS for matrix-vector and matrix-matrix + operations across all Rust-accelerated code paths (weights, OLS, TROP). Windows continues + using pure Rust (no external dependencies). Improves Rust backend performance at larger scales. 
+- `rust_backend_info()` diagnostic function in `diff_diff._backend` — reports compile-time + BLAS feature status (blas, accelerate, openblas) + +### Fixed +- **Rust SDID backend performance regression at scale** — Frank-Wolfe solver was 3-10x slower than pure Python at 1k+ scale + - Gram-accelerated FW loop for time weights: precomputes A^T@A, reducing per-iteration cost from O(N×T0) to O(T0) (~100x speedup per iteration at 5k scale) + - Allocation-free FW loop for unit weights: 1 GEMV per iteration (was 3), zero heap allocations (was ~8) + - Dispatch based on problem dimensions: Gram path when T0 < N, standard path when T0 >= N + - Rust backend now faster than pure Python at all scales + ## [2.4.1] - 2026-02-17 ### Added diff --git a/CLAUDE.md b/CLAUDE.md index bf5db3d5..ad24d4ff 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -40,6 +40,15 @@ maturin develop # Build with release optimizations maturin develop --release +# Build with platform BLAS (macOS — links Apple Accelerate) +maturin develop --release --features accelerate + +# Build with platform BLAS (Linux — requires libopenblas-dev) +maturin develop --release --features openblas + +# Build without BLAS (Windows, or explicit pure Rust) +maturin develop --release + # Force pure Python mode (disable Rust backend) DIFF_DIFF_BACKEND=python pytest @@ -50,9 +59,11 @@ DIFF_DIFF_BACKEND=rust pytest pytest tests/test_rust_backend.py -v ``` -**Note**: As of v2.2.0, the Rust backend uses the pure-Rust `faer` library for linear algebra, -eliminating external BLAS/LAPACK dependencies. This enables Windows wheel builds and simplifies -cross-platform compilation - no OpenBLAS or Intel MKL installation required. +**Note**: As of v2.2.0, the Rust backend uses `faer` (pure Rust) for SVD and matrix inversion. +BLAS is optionally linked via Cargo features (`accelerate` on macOS, `openblas` on Linux) +for matrix-vector/matrix-matrix products. Windows builds remain fully pure Rust with no +external dependencies. 
Pre-built PyPI wheels include platform BLAS; source builds use +pure Rust by default. ## Architecture @@ -183,6 +194,7 @@ cross-platform compilation - no OpenBLAS or Intel MKL installation required. - Detects optional Rust backend availability - Handles `DIFF_DIFF_BACKEND` environment variable ('auto', 'python', 'rust') - Exports `HAS_RUST_BACKEND` flag and Rust function references + - `rust_backend_info()` — returns compile-time BLAS feature status dict - Other modules import from here to avoid circular imports with `__init__.py` - **`rust/`** - Optional Rust backend for accelerated computation (v2.0.0+): @@ -194,8 +206,10 @@ cross-platform compilation - no OpenBLAS or Intel MKL installation required. - `compute_unit_distance_matrix()` - Parallel pairwise RMSE distance computation (4-8x speedup) - `loocv_grid_search()` - Parallel LOOCV across tuning parameters (10-50x speedup) - `bootstrap_trop_variance()` - Parallel bootstrap variance estimation (5-15x speedup) - - Uses pure-Rust `faer` library for linear algebra (no external BLAS/LAPACK dependencies) - - Cross-platform: builds on Linux, macOS, and Windows without additional setup + - Uses pure-Rust `faer` library for SVD/matrix inversion (no external deps) + - Optional BLAS linking via Cargo features: `accelerate` (macOS), `openblas` (Linux) + - When BLAS is enabled, ndarray `.dot()` calls dispatch to platform-optimized dgemv/dgemm + - Cross-platform: Windows builds use pure Rust with no additional setup - Provides 4-8x speedup for SyntheticDiD, 5-20x speedup for TROP - **`diff_diff/results.py`** - Dataclass containers for estimation results: diff --git a/diff_diff/_backend.py b/diff_diff/_backend.py index 954f94c4..0c6e99f6 100644 --- a/diff_diff/_backend.py +++ b/diff_diff/_backend.py @@ -35,6 +35,8 @@ compute_time_weights as _rust_compute_time_weights, compute_noise_level as _rust_compute_noise_level, sc_weight_fw as _rust_sc_weight_fw, + # Diagnostics + rust_backend_info as _rust_backend_info, ) 
_rust_available = True except ImportError: @@ -56,6 +58,7 @@ _rust_compute_time_weights = None _rust_compute_noise_level = None _rust_sc_weight_fw = None + _rust_backend_info = None # Determine final backend based on environment variable and availability if _backend_env == 'python': @@ -78,6 +81,7 @@ _rust_compute_time_weights = None _rust_compute_noise_level = None _rust_sc_weight_fw = None + _rust_backend_info = None elif _backend_env == 'rust': # Force Rust mode - fail if not available if not _rust_available: @@ -90,8 +94,25 @@ # Auto mode - use Rust if available HAS_RUST_BACKEND = _rust_available + +def rust_backend_info(): + """Return compile-time BLAS feature information for the Rust backend. + + Returns a dict with keys: + - 'blas': True if any BLAS backend is linked + - 'accelerate': True if Apple Accelerate is linked (macOS) + - 'openblas': True if OpenBLAS is linked (Linux) + + If the Rust backend is not available (its import failed, or the backend was forced to 'python'), all values are False. + """ + if _rust_backend_info is not None: + return _rust_backend_info() + return {"blas": False, "accelerate": False, "openblas": False} + + __all__ = [ 'HAS_RUST_BACKEND', + 'rust_backend_info', '_rust_bootstrap_weights', '_rust_synthetic_weights', '_rust_project_simplex', diff --git a/docs/benchmarks.rst b/docs/benchmarks.rst index 3c4366a0..425f1b50 100644 --- a/docs/benchmarks.rst +++ b/docs/benchmarks.rst @@ -267,6 +267,11 @@ implementations: additional speedup since these estimators primarily use OLS and variance computations that are already highly optimized in NumPy/SciPy via BLAS/LAPACK. + As of v2.5.0, pre-built wheels on macOS and Linux link platform-optimized + BLAS libraries (Apple Accelerate and OpenBLAS respectively) for matrix-vector + and matrix-matrix products across all Rust-accelerated code paths. Windows + wheels continue to use pure Rust with no external dependencies.
+ Three-Way Performance Summary ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -394,10 +399,11 @@ Three-Way Performance Summary Frank-Wolfe optimization algorithm. At 5k scale, R takes ~9 minutes while pure Python completes in 32 seconds. ATT estimates are numerically identical (< 1e-10 difference) since both implementations use the same Frank-Wolfe - optimizer with two-pass sparsification. The Rust backend provides a speedup - at small scale (2.1x over pure Python) but is slower at larger scales due to - overhead in the placebo variance estimation loop; this is a known area for - future optimization. + optimizer with two-pass sparsification. The Rust backend uses a + Gram-accelerated Frank-Wolfe solver for time weights (reducing per-iteration + cost from O(N×T0) to O(T0)) and an allocation-free solver for unit weights + (1 GEMV per iteration instead of 3, zero heap allocations). These + optimizations make the Rust backend faster than pure Python at all scales. Dataset Sizes ~~~~~~~~~~~~~ diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 00e75504..3cdf3f7e 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -14,6 +14,10 @@ crate-type = ["cdylib", "rlib"] default = [] # extension-module is only needed for cdylib builds, not for cargo test extension-module = ["pyo3/extension-module"] +# Platform BLAS backends (optional, activated for pre-built wheels) +# When enabled, ndarray's .dot() and general_mat_vec_mul dispatch to BLAS dgemv/dgemm +accelerate = ["ndarray/blas", "dep:blas-src", "blas-src/accelerate"] +openblas = ["ndarray/blas"] [dependencies] # PyO3 0.22 supports Python 3.8-3.13 @@ -24,10 +28,14 @@ rand = "0.8" rand_xoshiro = "0.6" rayon = "1.8" -# Pure Rust linear algebra library - no external BLAS/LAPACK dependencies -# This enables Windows builds without Intel MKL complexity +# Pure Rust linear algebra for SVD/matrix inversion (no external deps). +# BLAS for matrix-vector products is optional via accelerate/openblas features. 
faer = "0.24" +# BLAS backend (optional; pulled in only by the accelerate feature — the openblas feature links via build.rs instead) +# blas-src 0.10 is ndarray's tested version (see ndarray/crates/blas-tests/Cargo.toml) +blas-src = { version = "0.10", optional = true } + [profile.release] lto = true codegen-units = 1 diff --git a/rust/build.rs b/rust/build.rs new file mode 100644 index 00000000..b9fff1ce --- /dev/null +++ b/rust/build.rs @@ -0,0 +1,12 @@ +/// Build script for diff_diff_rust. +/// +/// When the `openblas` feature is enabled (Cargo exposes it to this script as the `CARGO_FEATURE_OPENBLAS` environment variable), links against the system OpenBLAS +/// library directly. This avoids the `openblas-src` -> `openblas-build` -> +/// `ureq` -> `native-tls` dependency chain, which has Rust compiler +/// compatibility issues. Requires `libopenblas-dev` (Ubuntu) or +/// `openblas-devel` (CentOS/manylinux) to be installed. +fn main() { + if std::env::var("CARGO_FEATURE_OPENBLAS").is_ok() { + println!("cargo:rustc-link-lib=openblas"); + } +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 15606281..9507e1c7 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -3,7 +3,16 @@ //! This module provides optimized implementations of computationally //! intensive operations used in difference-in-differences analysis. +// Pull in BLAS linker flags for macOS Accelerate. +// blas-src is a linker-only crate — extern crate is required to ensure +// the Accelerate framework is actually linked. +// For OpenBLAS (Linux), linking is handled by build.rs instead of blas-src +// to avoid the openblas-src -> ureq -> native-tls dependency chain.
+#[cfg(feature = "accelerate")] +extern crate blas_src; + use pyo3::prelude::*; +use std::collections::HashMap; mod bootstrap; mod linalg; @@ -42,8 +51,24 @@ fn _rust_backend(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(trop::loocv_grid_search_joint, m)?)?; m.add_function(wrap_pyfunction!(trop::bootstrap_trop_variance_joint, m)?)?; + // Diagnostics + m.add_function(wrap_pyfunction!(rust_backend_info, m)?)?; + // Version info m.add("__version__", env!("CARGO_PKG_VERSION"))?; Ok(()) } + +/// Return compile-time BLAS feature flags (`blas`, `accelerate`, `openblas`) for diagnostics; `blas` is true when either backend feature is enabled. +#[pyfunction] +fn rust_backend_info() -> PyResult<HashMap<String, bool>> { + let mut info = HashMap::new(); + info.insert( + "blas".to_string(), + cfg!(feature = "accelerate") || cfg!(feature = "openblas"), + ); + info.insert("accelerate".to_string(), cfg!(feature = "accelerate")); + info.insert("openblas".to_string(), cfg!(feature = "openblas")); + Ok(info) +} diff --git a/rust/src/weights.rs b/rust/src/weights.rs index 9b5cfb5d..001b9b23 100644 --- a/rust/src/weights.rs +++ b/rust/src/weights.rs @@ -7,6 +7,7 @@ //! - SDID unit and time weight computation use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis}; +use ndarray::linalg::general_mat_vec_mul; use numpy::{PyArray1, PyReadonlyArray1, PyReadonlyArray2, ToPyArray}; use pyo3::prelude::*; @@ -201,46 +202,254 @@ fn sparsify_internal(v: &Array1<f64>) -> Array1<f64> { sum_normalize_internal(&result) } -/// Single Frank-Wolfe step on the simplex. -/// Matches R's fw.step() in synthdid's sc.weight.fw(). -fn fw_step_internal( - a: &ArrayView2<f64>, - x: &Array1<f64>, - b: &ArrayView1<f64>, - eta: f64, -) -> Array1<f64> { - let ax = a.dot(x); - let diff = &ax - b; - let half_grad = a.t().dot(&diff) + eta * x; +/// Interval (in iterations) at which the Gram path refreshes `ata_x = ATA @ lam` +/// from scratch to prevent floating-point drift from incremental updates.
+const GRAM_REFRESH_INTERVAL: usize = 100; - // Find vertex with smallest gradient component - let i = half_grad - .iter() +/// Find the index of the minimum element in a vector (first index on ties; incomparable/NaN pairs are treated as equal; returns 0 for an empty vector in release builds). +#[inline] +fn argmin_f64(v: &Array1<f64>) -> usize { + debug_assert!(!v.is_empty(), "argmin called on empty array"); + v.iter() .enumerate() .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) .map(|(idx, _)| idx) - .unwrap_or(0); + .unwrap_or(0) +} + +/// Gram-accelerated Frank-Wolfe loop for the T0 < N case (time weights). +/// +/// By precomputing `ATA = A^T @ A` (T0×T0) and `ATb = A^T @ b` (T0,), +/// the entire FW iteration reduces from O(N×T0) per step to O(T0) per step. +/// This is a ~N/T0 speedup (e.g., ~100x at 5k scale where N=4000, T0=40). +/// +/// Zero heap allocations per iteration — all workspace is pre-allocated. +fn sc_weight_fw_gram( + a: &ArrayView2<f64>, + b: &ArrayView1<f64>, + lam: &mut Array1<f64>, + eta: f64, + zeta: f64, + n: usize, + min_decrease_sq: f64, + max_iter: usize, +) { + let t0 = lam.len(); + + // Precompute Gram matrix and related quantities — O(N×T0²) once + let ata = a.t().dot(a); // (T0, T0) + let atb = a.t().dot(b); // (T0,) + let b_norm_sq = b.dot(b); // scalar + + // Diagonal of ATA for d_err_sq computation + let mut ata_diag = Array1::zeros(t0); + for j in 0..t0 { + ata_diag[j] = ata[[j, j]]; + } + + // Maintained incrementally: ata_x = ATA @ lam + let mut ata_x = ata.dot(lam); + + // Pre-allocate workspace + let mut half_grad = Array1::zeros(t0); + + let mut prev_val = f64::INFINITY; + + for t in 0..max_iter { + // Step 1: half_grad[j] = ata_x[j] - atb[j] + eta * lam[j] + for j in 0..t0 { + half_grad[j] = ata_x[j] - atb[j] + eta * lam[j]; + } + + // Step 2: Find vertex with smallest gradient component + let i = argmin_f64(&half_grad); + + // Step 3: Guard — check if direction is essentially zero + // d_x_norm_sq = ||lam||² + 1 - 2*lam[i] (since d_x = e_i - lam) + let lam_norm_sq: f64 = lam.iter().map(|&v| v * v).sum(); + let 
d_x_norm_sq = lam_norm_sq + 1.0 - 2.0 * lam[i]; + if d_x_norm_sq < 1e-24 { + // Already at optimal vertex — compute objective for convergence check + let xt_ata_x: f64 = ata_x.iter().zip(lam.iter()).map(|(&a, &b)| a * b).sum(); + let atb_dot_lam: f64 = atb.iter().zip(lam.iter()).map(|(&a, &b)| a * b).sum(); + let val = zeta * zeta * lam_norm_sq + (xt_ata_x - 2.0 * atb_dot_lam + b_norm_sq) / n as f64; + if t >= 1 && prev_val - val < min_decrease_sq { + break; + } + prev_val = val; + continue; + } + + // Step 4: Guard — compute denominator for step size + // d_err_sq = ||A[:,i] - A@lam||² = ata_diag[i] - 2*ata_x[i] + lam^T ATA lam + let xt_ata_x: f64 = ata_x.iter().zip(lam.iter()).map(|(&a, &b)| a * b).sum(); + let d_err_sq = ata_diag[i] - 2.0 * ata_x[i] + xt_ata_x; + let denom = d_err_sq + eta * d_x_norm_sq; + if denom <= 0.0 { + let atb_dot_lam: f64 = atb.iter().zip(lam.iter()).map(|(&a, &b)| a * b).sum(); + let val = zeta * zeta * lam_norm_sq + (xt_ata_x - 2.0 * atb_dot_lam + b_norm_sq) / n as f64; + if t >= 1 && prev_val - val < min_decrease_sq { + break; + } + prev_val = val; + continue; + } + + // Step 5: hg_dot_dx = half_grad[i] - half_grad @ lam + let hg_dot_lam: f64 = half_grad.iter().zip(lam.iter()).map(|(&a, &b)| a * b).sum(); + let hg_dot_dx = half_grad[i] - hg_dot_lam; + + // Step 6: step = clamp(-hg_dot_dx / denom, 0, 1) + let step = (-hg_dot_dx / denom).max(0.0).min(1.0); - // Direction: d_x = e_i - x - let mut d_x = -x.clone(); - d_x[i] += 1.0; + // Step 7: lam = (1-step)*lam + step*e_i (in-place) + let one_minus_step = 1.0 - step; + for j in 0..t0 { + lam[j] *= one_minus_step; + } + lam[i] += step; + + // Step 8: ata_x incremental update — O(T0) + // ata_x = (1-step)*ata_x + step*ATA[:,i] + let ata_col_i = ata.column(i); + for j in 0..t0 { + ata_x[j] = one_minus_step * ata_x[j] + step * ata_col_i[j]; + } + // Periodic refresh to prevent drift (written into ata_x in place, so no allocation) + if t > 0 && t % GRAM_REFRESH_INTERVAL == 0 { + general_mat_vec_mul(1.0, &ata, &*lam, 0.0, &mut ata_x); + } - // Check if direction 
is essentially zero - let d_x_norm_sq: f64 = d_x.iter().map(|&v| v * v).sum(); - if d_x_norm_sq < 1e-24 { - return x.clone(); + // Step 9: Compute objective + let lam_norm_sq: f64 = lam.iter().map(|&v| v * v).sum(); + let xt_ata_x: f64 = ata_x.iter().zip(lam.iter()).map(|(&a, &b)| a * b).sum(); + let atb_dot_lam: f64 = atb.iter().zip(lam.iter()).map(|(&a, &b)| a * b).sum(); + let val = zeta * zeta * lam_norm_sq + (xt_ata_x - 2.0 * atb_dot_lam + b_norm_sq) / n as f64; + + // Step 10: Convergence check + if t >= 1 && prev_val - val < min_decrease_sq { + break; + } + prev_val = val; } +} - // Compute step size via exact line search - let d_err = a.column(i).to_owned() - &ax; - let denom = d_err.dot(&d_err) + eta * d_x.dot(&d_x); - if denom <= 0.0 { - return x.clone(); +/// Allocation-free standard Frank-Wolfe loop for the T0 >= N case (unit weights). +/// +/// Same algorithm as the original `fw_step_internal` but with: +/// - 1 GEMV per iteration (down from 3) +/// - 0 heap allocations per iteration (down from ~8) +/// - Incremental `ax` maintenance instead of recomputing A @ lam each step +fn sc_weight_fw_standard( + a: &ArrayView2<f64>, + b: &ArrayView1<f64>, + lam: &mut Array1<f64>, + eta: f64, + zeta: f64, + n: usize, + min_decrease_sq: f64, + max_iter: usize, +) { + let t0 = lam.len(); + + // Precompute column norms: col_norms_sq[j] = ||A[:,j]||² + let mut col_norms_sq = Array1::zeros(t0); + for j in 0..t0 { + let col = a.column(j); + col_norms_sq[j] = col.dot(&col); } - let step = -(half_grad.dot(&d_x)) / denom; - let step = step.max(0.0).min(1.0); - x + &(step * &d_x) + // Pre-allocate workspace + let mut ax = a.dot(lam as &Array1<f64>); // (N,), maintained incrementally + let mut half_grad = Array1::zeros(t0); + let mut diff = Array1::zeros(n); // Reusable buffer for ax - b + + let mut prev_val = f64::INFINITY; + + for t in 0..max_iter { + // Step 1-2: Compute half_grad = A^T @ (ax - b) + eta * lam + // Uses general_mat_vec_mul which dispatches to BLAS dgemv when enabled, + // 
otherwise falls back to ndarray's optimized matrixmultiply kernel. + diff.assign(&ax); + diff -= &*b; + general_mat_vec_mul(1.0, &a.t(), &diff, 0.0, &mut half_grad); + half_grad.scaled_add(eta, &*lam); + + // Step 3: Find vertex with smallest gradient component + let i = argmin_f64(&half_grad); + + // Step 4: Guard — d_x_norm_sq = ||lam||² + 1 - 2*lam[i] + let lam_norm_sq: f64 = lam.iter().map(|&v| v * v).sum(); + let d_x_norm_sq = lam_norm_sq + 1.0 - 2.0 * lam[i]; + if d_x_norm_sq < 1e-24 { + // Compute objective for convergence check + let mut err_sq = 0.0; + for k in 0..n { + let e = ax[k] - b[k]; + err_sq += e * e; + } + let val = zeta * zeta * lam_norm_sq + err_sq / n as f64; + if t >= 1 && prev_val - val < min_decrease_sq { + break; + } + prev_val = val; + continue; + } + + // Step 5: d_err_sq = col_norms_sq[i] - 2*A[:,i].dot(&ax) + ax.dot(&ax) + let col_i = a.column(i); + let col_dot_ax: f64 = col_i.iter().zip(ax.iter()).map(|(&a, &b)| a * b).sum(); + let ax_dot_ax: f64 = ax.iter().map(|&v| v * v).sum(); + let d_err_sq = col_norms_sq[i] - 2.0 * col_dot_ax + ax_dot_ax; + + // Step 6: Guard — denom + let denom = d_err_sq + eta * d_x_norm_sq; + if denom <= 0.0 { + let mut err_sq = 0.0; + for k in 0..n { + let e = ax[k] - b[k]; + err_sq += e * e; + } + let val = zeta * zeta * lam_norm_sq + err_sq / n as f64; + if t >= 1 && prev_val - val < min_decrease_sq { + break; + } + prev_val = val; + continue; + } + + // Step 7: hg_dot_dx and step + let hg_dot_lam: f64 = half_grad.iter().zip(lam.iter()).map(|(&a, &b)| a * b).sum(); + let hg_dot_dx = half_grad[i] - hg_dot_lam; + let step = (-hg_dot_dx / denom).max(0.0).min(1.0); + + // Step 8: Update lam in-place: lam = (1-step)*lam + step*e_i + let one_minus_step = 1.0 - step; + for j in 0..t0 { + lam[j] *= one_minus_step; + } + lam[i] += step; + + // Step 9: Update ax incrementally — O(N) + // ax = (1-step)*ax + step*A[:,i] + for k in 0..n { + ax[k] = one_minus_step * ax[k] + step * col_i[k]; + } + + // Step 10: Compute 
objective + let mut err_sq = 0.0; + for k in 0..n { + let e = ax[k] - b[k]; + err_sq += e * e; + } + let lam_norm_sq: f64 = lam.iter().map(|&v| v * v).sum(); + let val = zeta * zeta * lam_norm_sq + err_sq / n as f64; + + if t >= 1 && prev_val - val < min_decrease_sq { + break; + } + prev_val = val; + } } /// Compute synthetic control weights via Frank-Wolfe optimization. @@ -249,6 +458,10 @@ fn fw_step_internal( /// min_{lambda on simplex} zeta^2 * ||lambda||^2 /// + (1/N) * ||A_centered @ lambda - b_centered||^2 /// +/// Dispatches to one of two optimized loop implementations: +/// - **Gram path** (T0 < N): Precomputes A^T@A, reducing per-iteration cost from O(N×T0) to O(T0) +/// - **Standard path** (T0 >= N): Allocation-free loop with 1 GEMV/iter (was 3) and 0 allocs (was ~8) +/// /// # Arguments /// * `y` - Matrix of shape (N, T0+1). Last column is the target. /// * `zeta` - Regularization strength. @@ -292,22 +505,14 @@ fn sc_weight_fw_internal( }; let min_decrease_sq = min_decrease * min_decrease; - let mut prev_val = f64::INFINITY; - - for t in 0..max_iter { - lam = fw_step_internal(&a, &lam, &b, eta); - - // Compute objective: zeta^2 * ||lam||^2 + (1/N) * ||Y @ [lam, -1]||^2 - let mut lam_ext = Array1::zeros(t0 + 1); - lam_ext.slice_mut(s![..t0]).assign(&lam); - lam_ext[t0] = -1.0; - let err = y_owned.dot(&lam_ext); - let val = zeta * zeta * lam.dot(&lam) + err.dot(&err) / n as f64; - if t >= 1 && prev_val - val < min_decrease_sq { - break; - } - prev_val = val; + // Dispatch to optimized loop based on problem dimensions + if t0 < n { + // Gram path: precompute A^T@A for O(T0) per iteration + sc_weight_fw_gram(&a, &b, &mut lam, eta, zeta, n, min_decrease_sq, max_iter); + } else { + // Standard path: allocation-free with 1 GEMV per iteration + sc_weight_fw_standard(&a, &b, &mut lam, eta, zeta, n, min_decrease_sq, max_iter); } lam @@ -710,4 +915,164 @@ mod tests { assert_eq!(result.len(), 1); assert!((result[0] - 1.0).abs() < 1e-10); } + + #[test] + fn 
test_fw_gram_matches_standard() { + // Create a T0 < N problem and verify both paths produce identical weights. + // N=20, T0=5 -> dispatch would pick the Gram path; both loops are called directly for comparison. + let vals: Vec<f64> = (0..120).map(|i| ((i * 7 + 3) % 97) as f64 / 97.0).collect(); + let y_gram = Array2::from_shape_vec((20, 6), vals).unwrap(); + + // Directly call both internal functions on the same centered data. + let y_owned: Array2<f64> = { + let col_means = y_gram.mean_axis(Axis(0)).unwrap(); + &y_gram - &col_means + }; + let t0 = 5; + let n = 20; + let a = y_owned.slice(s![.., ..t0]); + let b = y_owned.column(t0); + let eta = n as f64 * 0.3 * 0.3; + let min_decrease_sq = 1e-5 * 1e-5; + + // Run Gram path + let mut lam_gram = Array1::from_elem(t0, 1.0 / t0 as f64); + sc_weight_fw_gram(&a, &b, &mut lam_gram, eta, 0.3, n, min_decrease_sq, 10000); + + // Run standard path on same data + let mut lam_std = Array1::from_elem(t0, 1.0 / t0 as f64); + sc_weight_fw_standard(&a, &b, &mut lam_std, eta, 0.3, n, min_decrease_sq, 10000); + + // Both should produce nearly identical weights + for j in 0..t0 { + assert!( + (lam_gram[j] - lam_std[j]).abs() < 1e-10, + "Gram and standard paths diverge at index {}: gram={}, std={}", + j, lam_gram[j], lam_std[j] + ); + } + + // Verify both are valid simplex weights + let sum_gram: f64 = lam_gram.sum(); + let sum_std: f64 = lam_std.sum(); + assert!((sum_gram - 1.0).abs() < 1e-6, "Gram weights should sum to 1, got {}", sum_gram); + assert!((sum_std - 1.0).abs() < 1e-6, "Standard weights should sum to 1, got {}", sum_std); + } + + #[test] + fn test_fw_standard_no_regression() { + // Create a T0 >= N problem (unit weights case) and verify the result is a valid simplex vector. + // N=5, T0=8 -> Standard path.
+ let vals: Vec<f64> = (0..45).map(|i| ((i * 13 + 5) % 53) as f64 / 53.0).collect(); + let y = Array2::from_shape_vec((5, 9), vals).unwrap(); + + let result = sc_weight_fw_internal(&y.view(), 0.5, true, None, 1e-5, 10000); + + // Verify valid simplex weights + let sum: f64 = result.sum(); + assert!((sum - 1.0).abs() < 1e-6, "Weights should sum to 1, got {}", sum); + assert!(result.iter().all(|&w| w >= -1e-6), "Weights should be non-negative"); + assert_eq!(result.len(), 8); + } + + #[test] + fn test_incremental_ata_x_accuracy() { + // Allow up to 1000 iterations on a T0 < N problem so the periodic ata_x refresh + // can run; only the validity of the final weights is asserted (drift is not measured). + let vals: Vec<f64> = (0..200).map(|i| { + let x = (i as f64) * 0.1; + x.sin() + ((i * 7) % 31) as f64 / 31.0 + }).collect(); + let y = Array2::from_shape_vec((20, 10), vals).unwrap(); + + // Run with enough iterations to exercise the refresh mechanism + let result = sc_weight_fw_internal(&y.view(), 0.1, true, None, 1e-8, 1000); + + // Verify valid result (convergence with correct weights) + let sum: f64 = result.sum(); + assert!((sum - 1.0).abs() < 1e-6, "Weights should sum to 1, got {}", sum); + assert!(result.iter().all(|&w| w >= -1e-6), "Weights should be non-negative"); + + // Output has T0=9 entries (T0=9 < N=20, so dispatch takes the Gram path) + assert_eq!(result.len(), 9); + } + + #[test] + fn test_gram_boundary_t0_equals_n_minus_1() { + // T0 = N-1 triggers Gram path (T0 < N), works correctly + // N=6, T0=5 -> just barely Gram path + let vals: Vec<f64> = (0..36).map(|i| ((i * 11 + 7) % 41) as f64 / 41.0).collect(); + let y = Array2::from_shape_vec((6, 6), vals).unwrap(); + + let result = sc_weight_fw_internal(&y.view(), 0.2, true, None, 1e-5, 10000); + + let sum: f64 = result.sum(); + assert!((sum - 1.0).abs() < 1e-6, "Weights should sum to 1, got {}", sum); + assert!(result.iter().all(|&w| w >= -1e-6), "Weights should be non-negative"); + assert_eq!(result.len(), 5); // T0 = 5 + } + + #[test] + fn test_gemv_produces_correct_half_grad() { + // 
Verify that general_mat_vec_mul produces the same half_grad as manual loop. + let n = 10; + let t0 = 4; + let eta = 0.5; + + // Deterministic test data + let a_vals: Vec<f64> = (0..(n * t0)).map(|i| ((i * 7 + 3) % 41) as f64 / 41.0).collect(); + let a = Array2::from_shape_vec((n, t0), a_vals).unwrap(); + + let ax: Array1<f64> = (0..n).map(|i| ((i * 11 + 5) % 37) as f64 / 37.0).collect(); + let b: Array1<f64> = (0..n).map(|i| ((i * 13 + 2) % 29) as f64 / 29.0).collect(); + let lam: Array1<f64> = (0..t0).map(|j| ((j * 17 + 1) % 19) as f64 / 19.0).collect(); + + // Reference: manual loop + let mut ref_grad = Array1::zeros(t0); + for j in 0..t0 { + let col = a.column(j); + let mut dot = 0.0; + for k in 0..n { + dot += col[k] * (ax[k] - b[k]); + } + ref_grad[j] = dot + eta * lam[j]; + } + + // New code path: general_mat_vec_mul + scaled_add + let mut new_grad = Array1::zeros(t0); + let mut diff = Array1::zeros(n); + diff.assign(&ax); + diff -= &b; + general_mat_vec_mul(1.0, &a.t(), &diff, 0.0, &mut new_grad); + new_grad.scaled_add(eta, &lam); + + // Verify match to high precision + for j in 0..t0 { + assert!( + (ref_grad[j] - new_grad[j]).abs() < 1e-12, + "half_grad mismatch at index {}: manual={}, gemv={}", + j, ref_grad[j], new_grad[j] + ); + } + } + + #[test] + fn test_intercept_false_both_paths() { + // Verify both Gram and standard paths work with intercept=false + // Gram path: N=15, T0=4 (T0 < N) + let vals_gram: Vec<f64> = (0..75).map(|i| ((i * 3 + 1) % 37) as f64 / 37.0).collect(); + let y_gram = Array2::from_shape_vec((15, 5), vals_gram).unwrap(); + let result_gram = sc_weight_fw_internal(&y_gram.view(), 0.3, false, None, 1e-5, 10000); + let sum_gram: f64 = result_gram.sum(); + assert!((sum_gram - 1.0).abs() < 1e-6, "Gram intercept=false: weights should sum to 1, got {}", sum_gram); + assert!(result_gram.iter().all(|&w| w >= -1e-6), "Gram intercept=false: weights should be non-negative"); + + // Standard path: N=4, T0=10 (T0 >= N) + let vals_std: Vec<f64> = (0..44).map(|i| ((i * 5 + 2) % 
29) as f64 / 29.0).collect(); + let y_std = Array2::from_shape_vec((4, 11), vals_std).unwrap(); + let result_std = sc_weight_fw_internal(&y_std.view(), 0.3, false, None, 1e-5, 10000); + let sum_std: f64 = result_std.sum(); + assert!((sum_std - 1.0).abs() < 1e-6, "Standard intercept=false: weights should sum to 1, got {}", sum_std); + assert!(result_std.iter().all(|&w| w >= -1e-6), "Standard intercept=false: weights should be non-negative"); + } } diff --git a/tests/test_rust_backend.py b/tests/test_rust_backend.py index 376585b6..0a86aeea 100644 --- a/tests/test_rust_backend.py +++ b/tests/test_rust_backend.py @@ -23,6 +23,22 @@ def test_rust_backend_available(self): """Verify Rust backend is available when this test runs.""" assert HAS_RUST_BACKEND + def test_rust_backend_info(self): + """Test rust_backend_info returns valid diagnostics dict.""" + from diff_diff._backend import rust_backend_info + + info = rust_backend_info() + assert isinstance(info, dict) + assert "blas" in info + assert "accelerate" in info + assert "openblas" in info + assert isinstance(info["blas"], bool) + assert isinstance(info["accelerate"], bool) + assert isinstance(info["openblas"], bool) + # If either platform BLAS is enabled, blas should be True + if info["accelerate"] or info["openblas"]: + assert info["blas"] is True + # ========================================================================= # Bootstrap Weight Tests # ========================================================================= @@ -1828,6 +1844,85 @@ def test_unit_weights_single_control(self): assert weights.shape == (1,) assert abs(weights[0] - 1.0) < 1e-10 + def test_fw_gram_vs_standard_equivalence(self): + """Test the Rust Gram path (T0 < N) matches the pure-Python solver. + + Creates a problem where T0 < N (triggers Gram path in Rust), then + verifies the Rust result matches pure Python to 6 decimal places. This + checks that the Gram precomputation leaves the weights unchanged within that tolerance. 
+ """ + from diff_diff._rust_backend import sc_weight_fw as rust_fn + from diff_diff.utils import _sc_weight_fw_numpy as numpy_fn + + np.random.seed(42) + # N=50 rows, T0=8 columns + 1 target = 9 cols total + # This triggers Gram path (T0=8 < N=50) + Y = np.random.randn(50, 9) + + rust_w = rust_fn(Y, 0.3, True, None, 1e-5, 10000) + numpy_w = numpy_fn(Y, 0.3, True, None, 1e-5, 10000) + + # Weights must match to 6 decimal places + np.testing.assert_array_almost_equal( + rust_w, numpy_w, decimal=6, + err_msg="Gram path weights should match Python" + ) + assert abs(rust_w.sum() - 1.0) < 1e-6 + assert np.all(rust_w >= -1e-6) + + def test_fw_standard_path_equivalence(self): + """Test standard path (T0 >= N) produces results matching Python. + + Creates a problem where T0 >= N (triggers standard path in Rust), + then verifies the Rust result matches pure Python to 6 decimal places. + """ + from diff_diff._rust_backend import sc_weight_fw as rust_fn + from diff_diff.utils import _sc_weight_fw_numpy as numpy_fn + + np.random.seed(42) + # N=5 rows, T0=12 columns + 1 target = 13 cols total + # This triggers standard path (T0=12 >= N=5) + Y = np.random.randn(5, 13) + + rust_w = rust_fn(Y, 0.5, True, None, 1e-5, 10000) + numpy_w = numpy_fn(Y, 0.5, True, None, 1e-5, 10000) + + np.testing.assert_array_almost_equal( + rust_w, numpy_w, decimal=6, + err_msg="Standard path weights should match Python" + ) + assert abs(rust_w.sum() - 1.0) < 1e-6 + assert np.all(rust_w >= -1e-6) + + def test_sdid_intercept_false_rust_vs_python(self): + """Test intercept=false produces matching weights in both backends. + + Verifies both Gram and standard paths handle intercept=false correctly + (no column centering applied).
+ """ + from diff_diff._rust_backend import sc_weight_fw as rust_fn + from diff_diff.utils import _sc_weight_fw_numpy as numpy_fn + + np.random.seed(42) + + # Gram path: T0 < N + Y_gram = np.random.randn(30, 6) + rust_w_gram = rust_fn(Y_gram, 0.2, False, None, 1e-5, 10000) + numpy_w_gram = numpy_fn(Y_gram, 0.2, False, None, 1e-5, 10000) + np.testing.assert_array_almost_equal( + rust_w_gram, numpy_w_gram, decimal=6, + err_msg="Gram path intercept=false weights should match Python" + ) + + # Standard path: T0 >= N + Y_std = np.random.randn(4, 10) + rust_w_std = rust_fn(Y_std, 0.2, False, None, 1e-5, 10000) + numpy_w_std = numpy_fn(Y_std, 0.2, False, None, 1e-5, 10000) + np.testing.assert_array_almost_equal( + rust_w_std, numpy_w_std, decimal=6, + err_msg="Standard path intercept=false weights should match Python" + ) + def test_full_sdid_rust_vs_python(self): """Test full SDID estimation produces same results with Rust and Python.""" import pandas as pd