# [Opt] Implement caching for NgramHashMapping creation #13
## Conversation
Added a caching mechanism for `NgramHashMapping` to optimize performance.
**dino65-dev** left a comment:
Overall: Good optimization — caching NgramHashMapping avoids redundant tokenizer loading and hash computation across Engram layers. A few robustness issues to address:
- **Mutable reference bug (L317-338):** `layer_ids` is passed by reference. Pass `tuple(layer_ids)` to avoid silent corruption if the caller mutates the list.
- **Unbounded cache (L305-340):** `_HASH_MAPPING_CACHE` can grow indefinitely, holding large tokenizers. Consider `@lru_cache(maxsize=N)` or add a `clear_cache()` helper.
- **No thread safety (L328-339):** Race condition in the check-then-set pattern. Add a `threading.Lock` or use `@lru_cache` (handles locking internally).
Suggested refactor (addresses all 3):
```python
from functools import lru_cache

@lru_cache(maxsize=8)
def get_or_create_hash_mapping(
    engram_vocab_size,  # pass as tuple
    max_ngram_size,
    n_embed_per_ngram,
    n_head_per_ngram,
    layer_ids,  # pass as tuple
    tokenizer_name_or_path,
    pad_id,
    seed,
):
    return NgramHashMapping(
        engram_vocab_size=engram_vocab_size,
        max_ngram_size=max_ngram_size,
        n_embed_per_ngram=n_embed_per_ngram,
        n_head_per_ngram=n_head_per_ngram,
        layer_ids=layer_ids,
        tokenizer_name_or_path=tokenizer_name_or_path,
        pad_id=pad_id,
        seed=seed,
    )
```

Callers need to pass tuples: `tuple(engram_vocab_size)`, `tuple(layer_ids)`.
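For illustration, a call site would then convert its list-valued fields up front; the `config` object here is hypothetical:

```python
# Hypothetical call site: list-valued config fields must become tuples
# so the arguments are hashable and usable as lru_cache keys.
mapping = get_or_create_hash_mapping(
    engram_vocab_size=tuple(config.engram_vocab_size),
    max_ngram_size=config.max_ngram_size,
    n_embed_per_ngram=config.n_embed_per_ngram,
    n_head_per_ngram=config.n_head_per_ngram,
    layer_ids=tuple(config.layer_ids),
    tokenizer_name_or_path=config.tokenizer_name_or_path,
    pad_id=config.pad_id,
    seed=config.seed,
)
```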
```python
cache_key = (
    tuple(engram_vocab_size),
    max_ngram_size,
    n_embed_per_ngram,
    n_head_per_ngram,
    tuple(layer_ids),
    tokenizer_name_or_path,
    pad_id,
    seed,
)

if cache_key not in _HASH_MAPPING_CACHE:
    _HASH_MAPPING_CACHE[cache_key] = NgramHashMapping(
        engram_vocab_size=engram_vocab_size,
        max_ngram_size=max_ngram_size,
        n_embed_per_ngram=n_embed_per_ngram,
        n_head_per_ngram=n_head_per_ngram,
        layer_ids=layer_ids,
        tokenizer_name_or_path=tokenizer_name_or_path,
        pad_id=pad_id,
        seed=seed,
    )
```
Issue: `layer_ids` is passed by reference to `NgramHashMapping`, but the cache key uses `tuple(layer_ids)`. If the caller mutates the list later, the cached instance's state silently drifts out of sync with its cache key.
Fix: Pass an immutable copy:
```python
cache_key = (
    tuple(engram_vocab_size),
    max_ngram_size,
    n_embed_per_ngram,
    n_head_per_ngram,
    tuple(layer_ids),
    tokenizer_name_or_path,
    pad_id,
    seed,
)
if cache_key not in _HASH_MAPPING_CACHE:
    _HASH_MAPPING_CACHE[cache_key] = NgramHashMapping(
        engram_vocab_size=engram_vocab_size,
        max_ngram_size=max_ngram_size,
        n_embed_per_ngram=n_embed_per_ngram,
        n_head_per_ngram=n_head_per_ngram,
        layer_ids=tuple(layer_ids),  # <- immutable copy
        tokenizer_name_or_path=tokenizer_name_or_path,
        pad_id=pad_id,
        seed=seed,
    )
```
```python
_HASH_MAPPING_CACHE = {}

# Ensures that an NgramHashMapping with identical configuration is created only once.
def get_or_create_hash_mapping(
    engram_vocab_size,
    max_ngram_size,
    n_embed_per_ngram,
    n_head_per_ngram,
    layer_ids,
    tokenizer_name_or_path,
    pad_id,
    seed,
):
    cache_key = (
        tuple(engram_vocab_size),
        max_ngram_size,
        n_embed_per_ngram,
        n_head_per_ngram,
        tuple(layer_ids),
        tokenizer_name_or_path,
        pad_id,
        seed,
    )

    if cache_key not in _HASH_MAPPING_CACHE:
        _HASH_MAPPING_CACHE[cache_key] = NgramHashMapping(
            engram_vocab_size=engram_vocab_size,
            max_ngram_size=max_ngram_size,
            n_embed_per_ngram=n_embed_per_ngram,
            n_head_per_ngram=n_head_per_ngram,
            layer_ids=layer_ids,
            tokenizer_name_or_path=tokenizer_name_or_path,
            pad_id=pad_id,
            seed=seed,
        )

    return _HASH_MAPPING_CACHE[cache_key]
```
Issue: `_HASH_MAPPING_CACHE` is unbounded. Each entry holds a HuggingFace tokenizer plus lookup tables, so in long-running processes or hyperparameter sweeps the cache can grow indefinitely and OOM.
Fix: Use `lru_cache` with a size limit (which also addresses the thread-safety concern below):
```python
from functools import lru_cache

@lru_cache(maxsize=8)
def get_or_create_hash_mapping(
    engram_vocab_size,  # must be a tuple, not a list
    max_ngram_size,
    n_embed_per_ngram,
    n_head_per_ngram,
    layer_ids,  # must be a tuple, not a list
    tokenizer_name_or_path,
    pad_id,
    seed,
):
    return NgramHashMapping(
        engram_vocab_size=engram_vocab_size,
        max_ngram_size=max_ngram_size,
        n_embed_per_ngram=n_embed_per_ngram,
        n_head_per_ngram=n_head_per_ngram,
        layer_ids=layer_ids,
        tokenizer_name_or_path=tokenizer_name_or_path,
        pad_id=pad_id,
        seed=seed,
    )
```
Callers must pass tuples instead of lists. Alternatively, you can keep the manual cache and add a `clear_hash_mapping_cache()` helper, as sketched below.
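A minimal sketch of that helper against the manual cache (`clear_hash_mapping_cache` is the suggested name, not existing code):

```python
def clear_hash_mapping_cache():
    """Drop all cached NgramHashMapping instances so their tokenizers
    and lookup tables can be garbage-collected."""
    _HASH_MAPPING_CACHE.clear()
```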
```python
if cache_key not in _HASH_MAPPING_CACHE:
    _HASH_MAPPING_CACHE[cache_key] = NgramHashMapping(
        engram_vocab_size=engram_vocab_size,
        max_ngram_size=max_ngram_size,
        n_embed_per_ngram=n_embed_per_ngram,
        n_head_per_ngram=n_head_per_ngram,
        layer_ids=layer_ids,
        tokenizer_name_or_path=tokenizer_name_or_path,
        pad_id=pad_id,
        seed=seed,
    )
```
Issue: The check-then-set pattern isn't thread-safe. Multiple threads can race past `if cache_key not in _HASH_MAPPING_CACHE` and redundantly create expensive `NgramHashMapping` instances (tokenizer loading, prime computation).
Fix: Add a lock:
```python
import threading

_HASH_MAPPING_CACHE = {}
_HASH_MAPPING_LOCK = threading.Lock()

def get_or_create_hash_mapping(
    engram_vocab_size, max_ngram_size, n_embed_per_ngram, n_head_per_ngram,
    layer_ids, tokenizer_name_or_path, pad_id, seed,
):
    cache_key = (
        tuple(engram_vocab_size), max_ngram_size, n_embed_per_ngram,
        n_head_per_ngram, tuple(layer_ids), tokenizer_name_or_path, pad_id, seed,
    )
    # Hold the lock across check-then-set so only one thread builds each mapping.
    with _HASH_MAPPING_LOCK:
        if cache_key not in _HASH_MAPPING_CACHE:
            _HASH_MAPPING_CACHE[cache_key] = NgramHashMapping(
                engram_vocab_size=engram_vocab_size,
                max_ngram_size=max_ngram_size,
                n_embed_per_ngram=n_embed_per_ngram,
                n_head_per_ngram=n_head_per_ngram,
                layer_ids=layer_ids,
                tokenizer_name_or_path=tokenizer_name_or_path,
                pad_id=pad_id,
                seed=seed,
            )
    return _HASH_MAPPING_CACHE[cache_key]
```
Or use `@lru_cache`, which handles locking for its cache internally (concurrent first calls may still build duplicate instances, but only one is kept) and also fixes the unbounded-cache issue.
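As a side benefit, `lru_cache`-wrapped functions come with cache management built in, which also covers the `clear_cache()` helper suggested above:

```python
# Built-in cache management on any lru_cache-wrapped function:
get_or_create_hash_mapping.cache_clear()        # drop all cached mappings
print(get_or_create_hash_mapping.cache_info())  # hits, misses, maxsize, currsize
```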