Skip to content

Commit ff93ef4

Browse files
MaxGhenis and claude authored
Fix _fast_cache invalidation bug and improve cache key performance (#448)
set_input() was not clearing _fast_cache entries, causing stale values to be returned on subsequent calculate() calls. Also switches cache keys from str(period) to Period objects (already hashable tuples) to avoid unnecessary string conversions, deduplicates enum LUT and structured dtype unification logic in vectorial parameter lookups, and uses the public trace property instead of the private _trace attr. Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 1a8eeb5 commit ff93ef4

4 files changed

Lines changed: 191 additions & 66 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix _fast_cache invalidation bug in set_input and add cache tests.

policyengine_core/parameters/vectorial_parameter_node_at_instant.py

Lines changed: 50 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,45 @@
1212
from policyengine_core.parameters.parameter_node import ParameterNode
1313

1414

15+
def _build_enum_lut(enum, name_to_child_idx, sentinel, stringify_names=False):
16+
"""Build a lookup table mapping enum int codes to child indices."""
17+
enum_items = list(enum)
18+
max_code = max(item.index for item in enum_items) + 1
19+
lut = numpy.full(max_code, sentinel, dtype=numpy.intp)
20+
for item in enum_items:
21+
name = str(item.name) if stringify_names else item.name
22+
child_idx = name_to_child_idx.get(name)
23+
if child_idx is not None:
24+
lut[item.index] = child_idx
25+
return lut
26+
27+
28+
def _unify_structured_dtypes(values):
29+
"""Compute a unified dtype across structured arrays with potentially
30+
different fields, and cast all values to that dtype.
31+
32+
Returns (unified_dtype, all_fields, casted_values).
33+
"""
34+
all_fields = []
35+
seen = set()
36+
for val in values:
37+
for field in val.dtype.names:
38+
if field not in seen:
39+
all_fields.append(field)
40+
seen.add(field)
41+
42+
unified_dtype = numpy.dtype([(f, "<f8") for f in all_fields])
43+
44+
casted_values = []
45+
for val in values:
46+
casted = numpy.zeros(len(val), dtype=unified_dtype)
47+
for field in val.dtype.names:
48+
casted[field] = val[field]
49+
casted_values.append(casted)
50+
51+
return unified_dtype, all_fields, casted_values
52+
53+
1554
class VectorialParameterNodeAtInstant:
1655
"""
1756
Parameter node of the legislation at a given instant which has been vectorized.
@@ -205,13 +244,7 @@ def __getitem__(self, key: str) -> Any:
205244
self._enum_lut_cache = {}
206245
lut = self._enum_lut_cache.get(cache_key)
207246
if lut is None:
208-
enum_items = list(enum)
209-
max_code = max(item.index for item in enum_items) + 1
210-
lut = numpy.full(max_code, SENTINEL, dtype=numpy.intp)
211-
for item in enum_items:
212-
child_idx = name_to_child_idx.get(item.name)
213-
if child_idx is not None:
214-
lut[item.index] = child_idx
247+
lut = _build_enum_lut(enum, name_to_child_idx, SENTINEL)
215248
self._enum_lut_cache[cache_key] = lut
216249
idx = lut[numpy.asarray(key)]
217250
elif (
@@ -224,13 +257,9 @@ def __getitem__(self, key: str) -> Any:
224257
self._enum_lut_cache = {}
225258
lut = self._enum_lut_cache.get(cache_key)
226259
if lut is None:
227-
enum_items = list(enum)
228-
max_code = max(item.index for item in enum_items) + 1
229-
lut = numpy.full(max_code, SENTINEL, dtype=numpy.intp)
230-
for item in enum_items:
231-
child_idx = name_to_child_idx.get(str(item.name))
232-
if child_idx is not None:
233-
lut[item.index] = child_idx
260+
lut = _build_enum_lut(
261+
enum, name_to_child_idx, SENTINEL, stringify_names=True
262+
)
234263
self._enum_lut_cache[cache_key] = lut
235264
codes = numpy.array([v.index for v in key], dtype=numpy.intp)
236265
idx = lut[codes]
@@ -262,23 +291,9 @@ def __getitem__(self, key: str) -> Any:
262291
if v0_len <= 1:
263292
# 1-element structured arrays: simple concat + index
264293
if not dtypes_match:
265-
all_fields = []
266-
seen = set()
267-
for val in values:
268-
for field in val.dtype.names:
269-
if field not in seen:
270-
all_fields.append(field)
271-
seen.add(field)
272-
273-
unified_dtype = numpy.dtype([(f, "<f8") for f in all_fields])
274-
275-
values_cast = []
276-
for val in values:
277-
casted = numpy.zeros(len(val), dtype=unified_dtype)
278-
for field in val.dtype.names:
279-
casted[field] = val[field]
280-
values_cast.append(casted)
281-
294+
unified_dtype, all_fields, values_cast = (
295+
_unify_structured_dtypes(values)
296+
)
282297
default = numpy.zeros(1, dtype=unified_dtype)
283298
for field in unified_dtype.names:
284299
default[field] = numpy.nan
@@ -302,22 +317,9 @@ def __getitem__(self, key: str) -> Any:
302317
# Nested structured: fall back to numpy.select
303318
conditions = [idx == i for i in range(len(values))]
304319
if not dtypes_match:
305-
all_fields = []
306-
seen = set()
307-
for val in values:
308-
for field in val.dtype.names:
309-
if field not in seen:
310-
all_fields.append(field)
311-
seen.add(field)
312-
unified_dtype = numpy.dtype(
313-
[(f, "<f8") for f in all_fields]
320+
unified_dtype, all_fields, values_cast = (
321+
_unify_structured_dtypes(values)
314322
)
315-
values_cast = []
316-
for val in values:
317-
casted = numpy.zeros(len(val), dtype=unified_dtype)
318-
for field in val.dtype.names:
319-
casted[field] = val[field]
320-
values_cast.append(casted)
321323
default = numpy.zeros(v0_len, dtype=unified_dtype)
322324
for field in unified_dtype.names:
323325
default[field] = numpy.nan
@@ -328,22 +330,9 @@ def __getitem__(self, key: str) -> Any:
328330
else:
329331
# Flat structured: fast per-field indexing
330332
if not dtypes_match:
331-
all_fields = []
332-
seen = set()
333-
for val in values:
334-
for field in val.dtype.names:
335-
if field not in seen:
336-
all_fields.append(field)
337-
seen.add(field)
338-
unified_dtype = numpy.dtype(
339-
[(f, "<f8") for f in all_fields]
333+
unified_dtype, all_fields, values_unified = (
334+
_unify_structured_dtypes(values)
340335
)
341-
values_unified = []
342-
for val in values:
343-
casted = numpy.zeros(len(val), dtype=unified_dtype)
344-
for field in val.dtype.names:
345-
casted[field] = val[field]
346-
values_unified.append(casted)
347336
field_names = all_fields
348337
result_dtype = unified_dtype
349338
else:

policyengine_core/simulations/simulation.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -452,8 +452,8 @@ def calculate(
452452
# Fast path: skip tracer, random seed and all _calculate() machinery for
453453
# already-computed values. map_to and decode_enums are NOT cached here —
454454
# they are post-processing steps that vary per call site.
455-
if map_to is None and not decode_enums and not getattr(self, "_trace", False):
456-
_fast_key = (variable_name, str(period))
455+
if map_to is None and not decode_enums and not getattr(self, "trace", False):
456+
_fast_key = (variable_name, period)
457457
_fast_cache = getattr(self, "_fast_cache", None)
458458
if _fast_cache is not None:
459459
_cached = _fast_cache.get(_fast_key)
@@ -765,7 +765,7 @@ def _calculate(self, variable_name: str, period: Period = None) -> ArrayLike:
765765
smc.set_cache_value(cache_path, array)
766766

767767
if hasattr(self, "_fast_cache"):
768-
self._fast_cache[(variable_name, str(period))] = array
768+
self._fast_cache[(variable_name, period)] = array
769769

770770
return array
771771

@@ -776,7 +776,7 @@ def purge_cache_of_invalid_values(self) -> None:
776776
for _name, _period in self.invalidated_caches:
777777
holder = self.get_holder(_name)
778778
holder.delete_arrays(_period)
779-
self._fast_cache.pop((_name, str(_period)), None)
779+
self._fast_cache.pop((_name, _period), None)
780780
self.invalidated_caches = set()
781781

782782
def calculate_add(
@@ -1142,7 +1142,9 @@ def delete_arrays(self, variable: str, period: Period = None) -> None:
11421142
k: v for k, v in self._fast_cache.items() if k[0] != variable
11431143
}
11441144
else:
1145-
self._fast_cache.pop((variable, str(period)), None)
1145+
if not isinstance(period, Period):
1146+
period = periods.period(period)
1147+
self._fast_cache.pop((variable, period), None)
11461148

11471149
def get_known_periods(self, variable: str) -> List[Period]:
11481150
"""
@@ -1187,6 +1189,7 @@ def set_input(self, variable_name: str, period: Period, value: ArrayLike) -> Non
11871189
if (variable.end is not None) and (period.start.date > variable.end):
11881190
return
11891191
self.get_holder(variable_name).set_input(period, value, self.branch_name)
1192+
self._fast_cache.pop((variable_name, period), None)
11901193

11911194
def get_variable_population(self, variable_name: str) -> Population:
11921195
variable = self.tax_benefit_system.get_variable(

tests/core/test_fast_cache.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
"""Tests for the _fast_cache mechanism in Simulation."""
2+
3+
import numpy as np
4+
from policyengine_core.simulations import SimulationBuilder
5+
6+
7+
def _make_simulation(tax_benefit_system, salary=3000):
    """Return a simulation holding one person ("bill") in one household,
    with ``salary`` set for 2017-01."""
    entities = {
        "persons": {"bill": {"salary": {"2017-01": salary}}},
        "households": {"household": {"parents": ["bill"]}},
    }
    builder = SimulationBuilder()
    return builder.build_from_entities(tax_benefit_system, entities)
18+
19+
20+
def test_fast_cache_returns_cached_value(tax_benefit_system):
    """A repeated calculate() of a computed variable is served from
    _fast_cache rather than recomputed."""
    sim = _make_simulation(tax_benefit_system)

    # income_tax comes from a formula, so the first call populates the cache.
    first = sim.calculate("income_tax", "2017-01")
    assert sim._fast_cache

    # Identity, not just equality, shows the second call was a cache hit.
    second = sim.calculate("income_tax", "2017-01")
    assert second is first
32+
33+
34+
def test_fast_cache_invalidated_after_set_input(tax_benefit_system):
    """set_input() must evict the stale _fast_cache entry so the next
    calculate() returns the new value."""
    sim = _make_simulation(tax_benefit_system)

    # Populate the cache with a computed variable
    result1 = sim.calculate("income_tax", "2017-01")
    assert len(sim._fast_cache) > 0
    old_val = result1[0]
    # Precondition: if the computed value already equalled the override,
    # a stale cache hit would be indistinguishable from a correct result.
    assert not np.isclose(old_val, 9999.0)

    # Overwrite income_tax with a direct value
    sim.set_input("income_tax", "2017-01", np.array([9999.0]))

    # The cache entry for income_tax must be gone
    result2 = sim.calculate("income_tax", "2017-01")
    assert np.isclose(result2[0], 9999.0), (
        f"Expected 9999.0 after set_input, got {result2[0]} (stale cache bug)"
    )
52+
53+
54+
def test_fast_cache_invalidated_after_delete_arrays_with_period(
    tax_benefit_system,
):
    """delete_arrays(variable, period) drops the _fast_cache entry for
    that variable and period."""
    sim = _make_simulation(tax_benefit_system)

    sim.calculate("income_tax", "2017-01")
    assert sim._fast_cache

    sim.delete_arrays("income_tax", "2017-01")

    # Only one period was ever cached, so no income_tax key may remain.
    assert not any(key[0] == "income_tax" for key in sim._fast_cache)
68+
69+
70+
def test_fast_cache_invalidated_after_delete_arrays_all_periods(
    tax_benefit_system,
):
    """delete_arrays(variable) without a period drops every _fast_cache
    entry for that variable."""
    sim = _make_simulation(tax_benefit_system)

    sim.calculate("income_tax", "2017-01")
    assert sim._fast_cache

    sim.delete_arrays("income_tax")

    assert not any(key[0] == "income_tax" for key in sim._fast_cache)
84+
85+
86+
def test_fast_cache_empty_after_clone(tax_benefit_system):
    """A clone starts with an empty _fast_cache even when the source
    simulation's cache is populated."""
    original = _make_simulation(tax_benefit_system)

    original.calculate("income_tax", "2017-01")
    assert original._fast_cache

    duplicate = original.clone()
    assert len(duplicate._fast_cache) == 0
95+
96+
97+
def test_fast_cache_invalidated_after_purge_cache(tax_benefit_system):
    """purge_cache_of_invalid_values() drops from _fast_cache every entry
    named in invalidated_caches."""
    from policyengine_core.periods import period as make_period

    sim = _make_simulation(tax_benefit_system)

    sim.calculate("income_tax", "2017-01")
    assert sim._fast_cache

    # Flag the entry by hand, standing in for what the framework does
    # during dependency tracking.
    stale_entry = ("income_tax", make_period("2017-01"))
    sim.invalidated_caches.add(stale_entry)
    # The stack must be empty for purge to fire
    sim.tracer._stack.clear()
    sim.purge_cache_of_invalid_values()

    assert not any(key[0] == "income_tax" for key in sim._fast_cache)
116+
117+
118+
def test_fast_cache_uses_period_not_str_as_key(tax_benefit_system):
    """_fast_cache keys should use Period objects, not str(period),
    to avoid unnecessary string conversions."""
    sim = _make_simulation(tax_benefit_system)
    sim.calculate("income_tax", "2017-01")

    # Without this guard the loop below would pass vacuously on an
    # empty cache and the key-type check would never run.
    assert len(sim._fast_cache) > 0

    # All keys should have Period as second element, not str
    for key in sim._fast_cache:
        _variable_name, period_key = key
        assert not isinstance(period_key, str), (
            f"Expected Period as cache key, got str: {period_key!r}"
        )
        assert isinstance(period_key, tuple), (
            f"Period should be a tuple subclass, got {type(period_key)}"
        )

0 commit comments

Comments
 (0)