Skip to content

Commit 8a4dfbb

Browse files
Improve caching to speed up taxbenefitsystem load (#414)
1 parent ad4a406 commit 8a4dfbb

5 files changed

Lines changed: 87 additions & 30 deletions

File tree

changelog_entry.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- bump: patch
2+
changes:
3+
fixed:
4+
- Optimisation improvements for loading tax-benefit systems (caching).

policyengine_core/parameters/at_instant_like.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
from policyengine_core import periods
55
from policyengine_core.periods import Instant
66

7+
# Cache for instant -> string conversions used in get_at_instant
8+
_instant_str_cache: dict = {}
9+
710

811
class AtInstantLike(abc.ABC):
912
"""
@@ -14,8 +17,34 @@ def __call__(self, instant: Instant) -> Any:
1417
return self.get_at_instant(instant)
1518

1619
def get_at_instant(self, instant: Instant) -> Any:
17-
instant = str(periods.instant(instant))
18-
return self._get_at_instant(instant)
20+
# Fast path for Instant objects - use their __str__ which is cached
21+
if isinstance(instant, Instant):
22+
return self._get_at_instant(str(instant))
23+
24+
# For other types, use a cache to avoid repeated conversions
25+
# Create a hashable cache key
26+
cache_key = None
27+
if isinstance(instant, str):
28+
cache_key = instant
29+
elif isinstance(instant, tuple):
30+
cache_key = instant
31+
elif isinstance(instant, int):
32+
cache_key = (instant,)
33+
elif hasattr(instant, "year"): # datetime.date
34+
cache_key = (instant.year, instant.month, instant.day)
35+
36+
if cache_key is not None:
37+
cached_str = _instant_str_cache.get(cache_key)
38+
if cached_str is not None:
39+
return self._get_at_instant(cached_str)
40+
instant_obj = periods.instant(instant)
41+
instant_str = str(instant_obj)
42+
_instant_str_cache[cache_key] = instant_str
43+
return self._get_at_instant(instant_str)
44+
45+
# Fallback for other types (Period, list, etc.)
46+
instant_str = str(periods.instant(instant))
47+
return self._get_at_instant(instant_str)
1948

2049
@abc.abstractmethod
2150
def _get_at_instant(self, instant): ...

policyengine_core/parameters/operations/uprate_parameters.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -124,26 +124,33 @@ def uprate_parameters(root: ParameterNode) -> ParameterNode:
124124
parameter.values_list[0].instant_str
125125
)
126126

127+
# Pre-compute values that don't change in the loop
128+
last_instant_str = str(last_instant)
129+
value_at_start = parameter(last_instant)
130+
uprater_at_start = uprating_parameter(last_instant)
131+
132+
if uprater_at_start is None:
133+
raise ValueError(
134+
f"Failed to uprate using {uprating_parameter.name} at {last_instant} for {parameter.name} because the uprating parameter is not defined at {last_instant}."
135+
)
136+
137+
# Check once whether rounding metadata is present, instead of on every loop iteration
138+
has_rounding = "rounding" in meta
139+
127140
# For each defined instant in the uprating parameter
128141
for entry in uprating_parameter.values_list[::-1]:
129142
entry_instant = instant(entry.instant_str)
130143
# If the uprater instant is defined after the last parameter instant
131144
if entry_instant > last_instant:
132145
# Apply the uprater and add to the parameter
133-
value_at_start = parameter(last_instant)
134-
uprater_at_start = uprating_parameter(last_instant)
135-
if uprater_at_start is None:
136-
raise ValueError(
137-
f"Failed to uprate using {uprating_parameter.name} at {last_instant} for {parameter.name} at {entry_instant} because the uprating parameter is not defined at {last_instant}."
138-
)
139146
uprater_at_entry = uprating_parameter(
140147
entry_instant
141148
)
142149
uprater_change = (
143150
uprater_at_entry / uprater_at_start
144151
)
145152
uprated_value = value_at_start * uprater_change
146-
if "rounding" in meta:
153+
if has_rounding:
147154
uprated_value = round_uprated_value(
148155
meta, uprated_value
149156
)

policyengine_core/periods/helpers.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55
from policyengine_core import periods
66
from policyengine_core.periods import config
77

8+
# Global cache of Instant objects keyed by (year, month, day), to avoid rebuilding identical instances
9+
_instant_cache: dict = {}
810

9-
@lru_cache(maxsize=1024)
11+
12+
@lru_cache(maxsize=10000)
1013
def _instant_from_string(instant_str: str) -> "periods.Instant":
1114
"""Cached parsing of instant strings."""
1215
if not config.INSTANT_PATTERN.match(instant_str):
@@ -48,18 +51,35 @@ def instant(instant):
4851
return instant
4952
if isinstance(instant, str):
5053
return _instant_from_string(instant)
54+
55+
# For other types, create a cache key and check the cache
56+
cache_key = None
57+
# Check Period before tuple since Period is a subclass of tuple
58+
if isinstance(instant, periods.Period):
59+
return instant.start
5160
elif isinstance(instant, datetime.date):
52-
instant = periods.Instant((instant.year, instant.month, instant.day))
61+
cache_key = (instant.year, instant.month, instant.day)
5362
elif isinstance(instant, int):
54-
instant = (instant,)
55-
elif isinstance(instant, list):
56-
assert 1 <= len(instant) <= 3
57-
instant = tuple(instant)
58-
elif isinstance(instant, periods.Period):
59-
instant = instant.start
60-
else:
61-
assert isinstance(instant, tuple), instant
62-
assert 1 <= len(instant) <= 3
63+
cache_key = (instant, 1, 1)
64+
elif isinstance(instant, (tuple, list)):
65+
if len(instant) == 1:
66+
cache_key = (instant[0], 1, 1)
67+
elif len(instant) == 2:
68+
cache_key = (instant[0], instant[1], 1)
69+
elif len(instant) == 3:
70+
cache_key = tuple(instant)
71+
72+
if cache_key is not None:
73+
cached = _instant_cache.get(cache_key)
74+
if cached is not None:
75+
return cached
76+
result = periods.Instant(cache_key)
77+
_instant_cache[cache_key] = result
78+
return result
79+
80+
# Fallback for unexpected types
81+
assert isinstance(instant, tuple), instant
82+
assert 1 <= len(instant) <= 3
6383
if len(instant) == 1:
6484
return periods.Instant((instant[0], 1, 1))
6585
if len(instant) == 2:

policyengine_core/periods/period_.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import calendar
4-
from datetime import datetime
4+
from datetime import datetime, date, timedelta
55
from typing import List
66

77
from policyengine_core import periods
@@ -463,15 +463,12 @@ def stop(self) -> periods.Instant:
463463
return periods.Instant((float("inf"), float("inf"), float("inf")))
464464
if unit == "day":
465465
if size > 1:
466-
day += size - 1
467-
month_last_day = calendar.monthrange(year, month)[1]
468-
while day > month_last_day:
469-
month += 1
470-
if month == 13:
471-
year += 1
472-
month = 1
473-
day -= month_last_day
474-
month_last_day = calendar.monthrange(year, month)[1]
466+
# Use datetime arithmetic for efficient day calculation
467+
start_date = date(year, month, day)
468+
end_date = start_date + timedelta(days=size - 1)
469+
return periods.Instant(
470+
(end_date.year, end_date.month, end_date.day)
471+
)
475472
else:
476473
if unit == "month":
477474
month += size

0 commit comments

Comments
 (0)