22from functools import cache , partial
33
44from aoc import main
5- from aoc .collections import Bitmask
65from aoc .graph_dyn import Edges
76
87
@@ -14,27 +13,23 @@ def parse(input: str) -> Edges[str]:
1413 return edges
1514
1615
17- def count_paths (input : str , src : str , dst : str , via : list [str ] | None = None ) -> int :
16+ def count_paths (input : str , * , src : str , dst : str , via : set [str ] | None = None ) -> int :
1817 edges = parse (input )
19-
20- if via is None :
21- via = []
22- via_mask = {k : i for i , k in enumerate (via )}
23- goal = Bitmask .from_list (via_mask .values ())
18+ via = via or set ()
2419
2520 @cache
26- def count (cur : str , seen : Bitmask ) -> int :
21+ def count (cur : str , seen : frozenset [ str ] ) -> int :
2722 if cur == dst :
28- return 1 if seen == goal else 0
29- if cur in via_mask :
30- seen = seen . on ( via_mask [ cur ])
23+ return 1 if seen == via else 0
24+ if cur in via :
25+ seen = seen | { cur }
3126 return sum (count (n , seen ) for n in edges [cur ])
3227
33- return count (src , Bitmask ())
28+ return count (src , frozenset ())
3429
3530
3631if __name__ == "__main__" :
3732 main (
3833 partial (count_paths , src = "you" , dst = "out" ),
39- partial (count_paths , src = "svr" , dst = "out" , via = [ "dac" , "fft" ] ),
34+ partial (count_paths , src = "svr" , dst = "out" , via = { "dac" , "fft" } ),
4035 )
0 commit comments