|
11 | 11 | import numpy as np |
12 | 12 | import pytest |
13 | 13 |
|
| 14 | +import pandas as pd |
| 15 | + |
14 | 16 | from diff_diff.synthetic_did import SyntheticDiD |
15 | 17 | from diff_diff.utils import ( |
16 | 18 | _compute_noise_level, |
@@ -730,3 +732,164 @@ def test_placebo_reestimates_weights_not_fixed(self): |
730 | 732 | f"({fixed_se:.6f}), suggesting weights are NOT being " |
731 | 733 | f"re-estimated as R's synthdid does." |
732 | 734 | ) |
| 735 | + |
| 736 | + |
| 737 | +# ============================================================================= |
| 738 | +# Treatment Validation |
| 739 | +# ============================================================================= |
| 740 | + |
| 741 | + |
class TestTreatmentValidation:
    """SDID must refuse a treatment indicator that changes within a unit."""

    def test_varying_treatment_within_unit_raises(self):
        """A unit whose treatment flag flips over time triggers ValueError."""
        np.random.seed(42)
        periods = [1, 2, 3, 4]
        # Three units observed over the same four periods, unit-major order.
        unit_ids = [u for u in (1, 2, 3) for _ in periods]
        # Only unit 1 switches on (from time 3); units 2 and 3 stay at zero.
        treated_flags = [0, 0, 1, 1] + [0] * 8
        panel = pd.DataFrame({
            "unit": unit_ids,
            "time": periods * 3,
            "outcome": np.random.randn(12),
            "treated": treated_flags,
        })
        estimator = SyntheticDiD()
        fit_kwargs = dict(
            outcome="outcome",
            treatment="treated",
            unit="unit",
            time="time",
            post_periods=[3, 4],
        )
        with pytest.raises(ValueError, match="Treatment indicator varies within"):
            estimator.fit(panel, **fit_kwargs)

    def test_constant_treatment_passes(self):
        """Block-treatment data (flag constant per unit) fits without error."""
        np.random.seed(42)
        n_units, n_periods = 10, 8
        # Units 0-2 are treated; their outcome gets a +2.0 bump from t=5 on.
        # One scalar draw per row, unit-major, so the data is reproducible.
        records = [
            {
                "unit": u,
                "time": t,
                "outcome": np.random.randn() + (2.0 if u < 3 and t >= 5 else 0),
                "treated": 1 if u < 3 else 0,
            }
            for u in range(n_units)
            for t in range(n_periods)
        ]
        panel = pd.DataFrame(records)
        estimator = SyntheticDiD()
        fit_kwargs = dict(
            outcome="outcome",
            treatment="treated",
            unit="unit",
            time="time",
            post_periods=[5, 6, 7],
        )
        result = estimator.fit(panel, **fit_kwargs)
        assert result is not None
| 782 | + |
| 783 | + |
| 784 | +# ============================================================================= |
| 785 | +# Balanced Panel Validation |
| 786 | +# ============================================================================= |
| 787 | + |
| 788 | + |
class TestBalancedPanelValidation:
    """SDID must refuse panels where some unit lacks an observation."""

    def test_unbalanced_panel_raises(self):
        """Removing one (unit, time) row makes fit raise ValueError."""
        np.random.seed(42)
        # 6 units x 5 periods; units 0 and 1 are the treated ones.
        records = [
            {
                "unit": u,
                "time": t,
                "outcome": np.random.randn(),
                "treated": 1 if u < 2 else 0,
            }
            for u in range(6)
            for t in range(5)
        ]
        full_panel = pd.DataFrame(records)
        # Keep every row except unit 3 at time 2, leaving a hole in the panel.
        keep_mask = (full_panel["unit"] != 3) | (full_panel["time"] != 2)
        panel = full_panel[keep_mask].reset_index(drop=True)

        estimator = SyntheticDiD()
        fit_kwargs = dict(
            outcome="outcome",
            treatment="treated",
            unit="unit",
            time="time",
            post_periods=[3, 4],
        )
        with pytest.raises(ValueError, match="Panel is not balanced"):
            estimator.fit(panel, **fit_kwargs)

    def test_balanced_panel_passes(self):
        """A complete units-by-periods panel fits without error."""
        np.random.seed(42)
        # 8 units x 6 periods; units 0 and 1 get a +1.5 bump from t=4 on.
        records = [
            {
                "unit": u,
                "time": t,
                "outcome": np.random.randn() + (1.5 if u < 2 and t >= 4 else 0),
                "treated": 1 if u < 2 else 0,
            }
            for u in range(8)
            for t in range(6)
        ]
        panel = pd.DataFrame(records)
        estimator = SyntheticDiD()
        fit_kwargs = dict(
            outcome="outcome",
            treatment="treated",
            unit="unit",
            time="time",
            post_periods=[4, 5],
        )
        result = estimator.fit(panel, **fit_kwargs)
        assert result is not None
| 834 | + |
| 835 | + |
| 836 | +# ============================================================================= |
| 837 | +# Pre-treatment Fit Warning |
| 838 | +# ============================================================================= |
| 839 | + |
| 840 | + |
class TestPreTreatmentFitWarning:
    """Check when the "Pre-treatment fit is poor" warning is emitted."""

    def test_poor_fit_emits_warning(self):
        """Treated outcomes far above every control should trigger the warning."""
        np.random.seed(42)
        # Units 0-1 (treated) sit near 100, the eight controls near 10.
        records = [
            {
                "unit": u,
                "time": t,
                "outcome": (100.0 if u < 2 else 10.0) + np.random.randn() * 0.5,
                "treated": 1 if u < 2 else 0,
            }
            for u in range(10)
            for t in range(8)
        ]
        panel = pd.DataFrame(records)
        estimator = SyntheticDiD()
        fit_kwargs = dict(
            outcome="outcome",
            treatment="treated",
            unit="unit",
            time="time",
            post_periods=[6, 7],
        )
        with warnings.catch_warnings(record=True) as caught:
            # Record every warning, including ones already shown once.
            warnings.simplefilter("always")
            estimator.fit(panel, **fit_kwargs)
        poor_fit = [
            item for item in caught
            if "Pre-treatment fit is poor" in str(item.message)
        ]
        assert poor_fit, (
            "Expected warning about poor pre-treatment fit but none was raised"
        )

    def test_good_fit_no_warning(self):
        """Groups sharing level and trend should not trigger the warning."""
        np.random.seed(42)
        # Every unit follows outcome ~ t; units 0-2 gain +2.0 from t=5 on.
        records = [
            {
                "unit": u,
                "time": t,
                "outcome": t + np.random.randn() * 0.3 + (2.0 if u < 3 and t >= 5 else 0),
                "treated": 1 if u < 3 else 0,
            }
            for u in range(10)
            for t in range(8)
        ]
        panel = pd.DataFrame(records)
        estimator = SyntheticDiD()
        fit_kwargs = dict(
            outcome="outcome",
            treatment="treated",
            unit="unit",
            time="time",
            post_periods=[5, 6, 7],
        )
        with warnings.catch_warnings(record=True) as caught:
            # Record every warning, including ones already shown once.
            warnings.simplefilter("always")
            estimator.fit(panel, **fit_kwargs)
        fit_warnings = [
            item for item in caught
            if "Pre-treatment fit is poor" in str(item.message)
        ]
        # The message is only formatted on failure, when the list is non-empty.
        assert not fit_warnings, (
            f"Unexpected pre-treatment fit warning: {fit_warnings[0].message}"
        )
0 commit comments