import math
import multiprocessing
import sys
import time
from collections import deque
from typing import (
    Deque,
    Dict,
    Generator,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
)

from . import llama_cpp
@@ -15,15 +15,34 @@ class LlamaCache:
1515 """Cache for a llama.cpp model."""
1616
1717 def __init__ (self ):
18- self .cache_state : Dict [Sequence [llama_cpp .llama_token ], "LlamaState" ] = dict ()
18+ self .cache_state : Dict [Tuple [llama_cpp .llama_token , ...], "LlamaState" ] = dict ()
19+
20+ def _sorted_keys (self ) -> List [Tuple [llama_cpp .llama_token , ...]]:
21+ return [
22+ key
23+ for _ , key in sorted (
24+ ((len (key ), key ) for key in self .cache_state .keys ()), reverse = True
25+ )
26+ ]
27+
28+ def _find_key (
29+ self , key : Tuple [llama_cpp .llama_token , ...]
30+ ) -> Optional [Tuple [llama_cpp .llama_token , ...]]:
31+ for k in self ._sorted_keys ():
32+ if key [: len (k )] == k :
33+ return k
34+ return None
1935
2036 def __getitem__ (
2137 self , key : Sequence [llama_cpp .llama_token ]
2238 ) -> Optional ["LlamaState" ]:
23- return self .cache_state .get (tuple (key ), None )
39+ _key = self ._find_key (tuple (key ))
40+ if _key is None :
41+ return None
42+ return self .cache_state [_key ]
2443
2544 def __contains__ (self , key : Sequence [llama_cpp .llama_token ]) -> bool :
26- return tuple (key ) in self . cache_state
45+ return self . _find_key ( tuple (key )) is not None
2746
2847 def __setitem__ (self , key : Sequence [llama_cpp .llama_token ], value : "LlamaState" ):
2948 self .cache_state = dict () # NOTE: Currently limit to one cache entry.
@@ -295,7 +314,7 @@ def generate(
295314 if (
296315 reset
297316 and len (self .eval_tokens ) > 0
298- and self .eval_tokens == tokens [: len (self .eval_tokens )]
317+ and tuple ( self .eval_tokens ) == tuple ( tokens [: len (self .eval_tokens )])
299318 ):
300319 if self .verbose :
301320 print ("generate cache hit" , file = sys .stderr )
@@ -438,6 +457,8 @@ def _create_completion(
438457
439458 if self .cache and len (completion_tokens ) == 0 :
440459 if prompt_tokens not in self .cache :
460+ if self .verbose :
461+ print ("cache miss" , file = sys .stderr )
441462 self .cache [prompt_tokens ] = self .save_state ()
442463
443464 completion_tokens .append (token )
0 commit comments