@@ -14,36 +14,27 @@ def parse(input: str) -> Edges[str]:
1414 return edges
1515
1616
17- def you_to_out (input : str ) :
17+ def count_paths (input : str , src : str , dst : str , via : list [ str ] | None = None ) -> int :
1818 edges = parse (input )
1919
20- @cache
21- def count_paths (cur : str ) -> int :
22- if cur == "out" :
23- return 1
24- return sum (count_paths (n ) for n in edges [cur ])
25-
26- return count_paths ("you" )
27-
28-
29- def svr_to_out_via_dac_fft (input : str ):
30- edges = parse (input )
31- via = {"dac" : 0 , "fft" : 1 }
32- goal = Bitmask .from_list (via .values ())
20+ if via is None :
21+ via = []
22+ via_mask = {k : i for i , k in enumerate (via )}
23+ goal = Bitmask .from_list (via_mask .values ())
3324
3425 @cache
35- def count_paths (cur : str , seen : Bitmask ) -> int :
36- if cur == "out" :
26+ def count (cur : str , seen : Bitmask ) -> int :
27+ if cur == dst :
3728 return 1 if seen == goal else 0
38- if cur in via :
39- seen = seen .on (via [cur ])
40- return sum (count_paths (n , seen ) for n in edges [cur ])
29+ if cur in via_mask :
30+ seen = seen .on (via_mask [cur ])
31+ return sum (count (n , seen ) for n in edges [cur ])
4132
42- return count_paths ( "svr" , Bitmask ())
33+ return count ( src , Bitmask ())
4334
4435
4536if __name__ == "__main__" :
4637 main (
47- you_to_out ,
48- svr_to_out_via_dac_fft ,
38+ partial ( count_paths , src = "you" , dst = "out" ) ,
39+ partial ( count_paths , src = "svr" , dst = "out" , via = [ "dac" , "fft" ]) ,
4940 )
0 commit comments