Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .github/workflows/rust-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -135,24 +135,24 @@ jobs:
- name: Run Rust backend tests (Unix)
if: runner.os != 'Windows'
working-directory: /tmp
run: pytest tests/test_rust_backend.py -v
run: pytest tests/test_rust_backend.py -v -m ''

- name: Run Rust backend tests (Windows)
if: runner.os == 'Windows'
working-directory: ${{ runner.temp }}
run: pytest tests/test_rust_backend.py -v
run: pytest tests/test_rust_backend.py -v -m ''

- name: Run tests with Rust backend (Unix)
if: runner.os != 'Windows'
working-directory: /tmp
run: DIFF_DIFF_BACKEND=rust pytest tests/ -q -n auto --dist worksteal
run: DIFF_DIFF_BACKEND=rust pytest tests/ -q -n auto --dist worksteal -m ''

- name: Run tests with Rust backend (Windows)
if: runner.os == 'Windows'
working-directory: ${{ runner.temp }}
run: |
$env:DIFF_DIFF_BACKEND="rust"
pytest tests/ -q -n auto --dist worksteal
pytest tests/ -q -n auto --dist worksteal -m ''
shell: pwsh

# Test pure Python fallback (without Rust extension)
Expand All @@ -177,4 +177,4 @@ jobs:
PYTHONPATH=. python -c "from diff_diff import HAS_RUST_BACKEND; print(f'HAS_RUST_BACKEND: {HAS_RUST_BACKEND}'); assert not HAS_RUST_BACKEND"

- name: Run tests in pure Python mode
run: PYTHONPATH=. DIFF_DIFF_BACKEND=python pytest tests/ -q --ignore=tests/test_rust_backend.py -n auto --dist worksteal
run: PYTHONPATH=. DIFF_DIFF_BACKEND=python pytest tests/ -q --ignore=tests/test_rust_backend.py -n auto --dist worksteal -m ''
5 changes: 3 additions & 2 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,9 @@ category (`Methodology/Correctness`, `Performance`, or `Testing/Docs`):
`threshold = 0.40 if n_boot < 100 else 0.15`.
- **`assert_nan_inference()`** from conftest.py: Use to validate ALL inference fields are
NaN-consistent. Don't check individual fields separately.
- **Slow test suites**: `tests/test_trop.py` is very time-consuming. Skip with
`pytest --ignore=tests/test_trop.py` for unrelated changes.
- **Slow tests**: TROP, Sun-Abraham bootstrap, and TROP-parity tests are marked
`@pytest.mark.slow` and excluded by default via `addopts` in `pyproject.toml`
(`-m 'not slow'`). Run `pytest -m ''` to include them (the later `-m` overrides
the default), or `pytest -m slow` to run only slow tests.
- **Behavioral assertions**: Always assert expected outcomes, not just no-exception.
Bad: `result = func(bad_input)`. Good: `result = func(bad_input); assert np.isnan(result.coef)`.

Expand Down
24 changes: 9 additions & 15 deletions diff_diff/sun_abraham.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,6 +1000,10 @@ def _run_bootstrap(
all_units = df[unit].unique()
n_units = len(all_units)

# Pre-compute unit -> row indices mapping (avoids repeated boolean scans)
unit_row_indices = {u: df.index[df[unit] == u].values for u in all_units}
unit_row_counts = {u: len(idx) for u, idx in unit_row_indices.items()}

# Store bootstrap samples
rel_periods = sorted(original_event_study.keys())
bootstrap_effects = {e: np.zeros(self.n_bootstrap) for e in rel_periods}
Expand All @@ -1009,23 +1013,13 @@ def _run_bootstrap(
# Resample units with replacement (pairs bootstrap)
boot_units = rng.choice(all_units, size=n_units, replace=True)

# Create bootstrap sample efficiently
# Build index array for all selected units
boot_indices = np.concatenate([
df.index[df[unit] == u].values for u in boot_units
])
# Create bootstrap sample using pre-computed index mapping
boot_indices = np.concatenate([unit_row_indices[u] for u in boot_units])
df_b = df.iloc[boot_indices].copy()

# Reassign unique unit IDs for bootstrap sample
# Each resampled unit gets a unique ID
new_unit_ids = []
current_id = 0
for u in boot_units:
unit_rows = df[df[unit] == u]
for _ in range(len(unit_rows)):
new_unit_ids.append(current_id)
current_id += 1
df_b[unit] = new_unit_ids[:len(df_b)]
# Reassign unique unit IDs (vectorized)
rows_per_unit = np.array([unit_row_counts[u] for u in boot_units])
df_b[unit] = np.repeat(np.arange(n_units), rows_per_unit)

# Recompute relative time (vectorized)
df_b["_rel_time"] = np.where(
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ python-packages = ["diff_diff"]
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = "test_*.py"
# Run all tests including slow ones by default; use `pytest -m 'not slow'` for faster local runs
addopts = "-v --tb=short"
# Exclude slow tests by default; use `pytest -m ''` to run all tests
addopts = "-v --tb=short -m 'not slow'"
markers = [
"slow: marks tests as slow (excluded by default via addopts; run `pytest -m ''` to include them, or `pytest -m slow` to run only slow tests)",
]
Expand Down
2 changes: 2 additions & 0 deletions tests/test_rust_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,7 @@ def test_trop_produces_valid_results(self):
assert results.lambda_nn in [0.0, 0.1]


@pytest.mark.slow
@pytest.mark.skipif(not HAS_RUST_BACKEND, reason="Rust backend not available")
class TestTROPJointRustBackend:
"""Test suite for TROP joint method Rust backend functions."""
Expand Down Expand Up @@ -1269,6 +1270,7 @@ def test_bootstrap_trop_variance_joint_reproducible(self):
np.testing.assert_almost_equal(se1, se2)


@pytest.mark.slow
@pytest.mark.skipif(not HAS_RUST_BACKEND, reason="Rust backend not available")
class TestTROPJointRustVsNumpy:
"""Tests comparing TROP joint Rust and NumPy implementations."""
Expand Down
1 change: 1 addition & 0 deletions tests/test_sun_abraham.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ def test_invalid_level_error(self):
results.to_dataframe(level="invalid")


@pytest.mark.slow
class TestSunAbrahamBootstrap:
"""Tests for Sun-Abraham bootstrap inference."""

Expand Down
2 changes: 2 additions & 0 deletions tests/test_trop.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import pandas as pd
import pytest

pytestmark = pytest.mark.slow

from diff_diff import SyntheticDiD
from diff_diff.trop import TROP, TROPResults, trop
from diff_diff.prep import generate_factor_data
Expand Down
Loading