11import re
2- from functools import cache , partial
2+ from functools import cache
33
44from aoc import main
5- from aoc .collections import Bitmask
65from aoc .graph_dyn import Edges
76
87
@@ -14,27 +13,24 @@ def parse(input: str) -> Edges[str]:
1413 return edges
1514
1615
17- def count_paths (input : str , src : str , dst : str , via : list [str ] | None = None ) -> int :
18- edges = parse (input )
19-
20- if via is None :
21- via = []
22- via_mask = {k : i for i , k in enumerate (via )}
23- goal = Bitmask .from_list (via_mask .values ())
16+ def count_paths (
17+ edges : Edges , src : str , dst : str , * , via : set [str ] | None = None
18+ ) -> int :
19+ via = via or set ()
2420
2521 @cache
26- def count (cur : str , seen : Bitmask ) -> int :
22+ def count (cur : str , seen : frozenset [ str ] ) -> int :
2723 if cur == dst :
28- return 1 if seen == goal else 0
29- if cur in via_mask :
30- seen = seen . on ( via_mask [ cur ])
24+ return 1 if seen == via else 0
25+ if cur in via :
26+ seen = seen | { cur }
3127 return sum (count (n , seen ) for n in edges [cur ])
3228
33- return count (src , Bitmask ())
29+ return count (src , frozenset ())
3430
3531
3632if __name__ == "__main__" :
3733 main (
38- partial ( count_paths , src = "you" , dst = "out" ),
39- partial ( count_paths , src = "svr" , dst = "out" , via = [ "dac" , "fft" ] ),
34+ lambda s : count_paths ( parse ( s ), "you" , "out" ),
35+ lambda s : count_paths ( parse ( s ), "svr" , "out" , via = { "dac" , "fft" } ),
4036 )
0 commit comments