Conversation

@yunkchen

Added caching mechanism for NgramHashMapping to optimize performance.
yunkchen changed the title Implement caching for NgramHashMapping creation → [Opt] Implement caching for NgramHashMapping creation (Jan 22, 2026)

@dino65-dev left a comment


Overall: Good optimization. Caching NgramHashMapping avoids redundant tokenizer loading and hash computation across Engram layers. A few robustness issues to address:

Mutable reference bug (L317-338): layer_ids is passed by reference to the constructor while the cache key uses tuple(layer_ids). Pass tuple(layer_ids) to the constructor as well, to avoid silent corruption if the caller later mutates the list.

Unbounded cache (L305-340): _HASH_MAPPING_CACHE can grow indefinitely, and each entry holds a large tokenizer. Consider @lru_cache(maxsize=N) or add a clear_cache() helper.

No thread safety (L328-339): The check-then-set is a race condition. Add a threading.Lock, or use @lru_cache, which handles locking internally.

Suggested refactor (addresses all 3):

from functools import lru_cache
@lru_cache(maxsize=8)
def get_or_create_hash_mapping(
    engram_vocab_size,      # pass as tuple
    max_ngram_size,
    n_embed_per_ngram,
    n_head_per_ngram,
    layer_ids,              # pass as tuple
    tokenizer_name_or_path,
    pad_id,
    seed,
):
    return NgramHashMapping(
        engram_vocab_size=engram_vocab_size,
        max_ngram_size=max_ngram_size,
        n_embed_per_ngram=n_embed_per_ngram,
        n_head_per_ngram=n_head_per_ngram,
        layer_ids=layer_ids,
        tokenizer_name_or_path=tokenizer_name_or_path,
        pad_id=pad_id,
        seed=seed,
    )

Callers need to pass tuples: tuple(engram_vocab_size), tuple(layer_ids).
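For example, a call site that currently holds lists would convert at the boundary. A minimal sketch; the config object and its field names are placeholders, not code from this PR:

# Hypothetical call site; `config` and its fields are illustrative only.
mapping = get_or_create_hash_mapping(
    engram_vocab_size=tuple(config.engram_vocab_size),  # list -> hashable tuple
    max_ngram_size=config.max_ngram_size,
    n_embed_per_ngram=config.n_embed_per_ngram,
    n_head_per_ngram=config.n_head_per_ngram,
    layer_ids=tuple(config.layer_ids),                  # list -> hashable tuple
    tokenizer_name_or_path=config.tokenizer_name_or_path,
    pad_id=config.pad_id,
    seed=config.seed,
)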

Comment on lines +317 to +338
cache_key = (
    tuple(engram_vocab_size),
    max_ngram_size,
    n_embed_per_ngram,
    n_head_per_ngram,
    tuple(layer_ids),
    tokenizer_name_or_path,
    pad_id,
    seed,
)

if cache_key not in _HASH_MAPPING_CACHE:
    _HASH_MAPPING_CACHE[cache_key] = NgramHashMapping(
        engram_vocab_size=engram_vocab_size,
        max_ngram_size=max_ngram_size,
        n_embed_per_ngram=n_embed_per_ngram,
        n_head_per_ngram=n_head_per_ngram,
        layer_ids=layer_ids,
        tokenizer_name_or_path=tokenizer_name_or_path,
        pad_id=pad_id,
        seed=seed,
    )


Issue: layer_ids is passed by reference to NgramHashMapping, but the cache key uses tuple(layer_ids). If the caller mutates the list later, the cached instance sees the mutated data while its cache key still describes the original configuration.

Fix: Pass an immutable copy:

Suggested change

# Before:
cache_key = (
    tuple(engram_vocab_size),
    max_ngram_size,
    n_embed_per_ngram,
    n_head_per_ngram,
    tuple(layer_ids),
    tokenizer_name_or_path,
    pad_id,
    seed,
)
if cache_key not in _HASH_MAPPING_CACHE:
    _HASH_MAPPING_CACHE[cache_key] = NgramHashMapping(
        engram_vocab_size=engram_vocab_size,
        max_ngram_size=max_ngram_size,
        n_embed_per_ngram=n_embed_per_ngram,
        n_head_per_ngram=n_head_per_ngram,
        layer_ids=layer_ids,
        tokenizer_name_or_path=tokenizer_name_or_path,
        pad_id=pad_id,
        seed=seed,
    )

# After:
if cache_key not in _HASH_MAPPING_CACHE:
    _HASH_MAPPING_CACHE[cache_key] = NgramHashMapping(
        engram_vocab_size=engram_vocab_size,
        max_ngram_size=max_ngram_size,
        n_embed_per_ngram=n_embed_per_ngram,
        n_head_per_ngram=n_head_per_ngram,
        layer_ids=tuple(layer_ids),  # <- immutable copy
        tokenizer_name_or_path=tokenizer_name_or_path,
        pad_id=pad_id,
        seed=seed,
    )
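To make the failure mode concrete, here is a minimal repro sketch; MappingStub is a hypothetical stand-in for NgramHashMapping, assuming the real constructor stores the list it is given:

# Hypothetical repro; MappingStub stands in for NgramHashMapping.
class MappingStub:
    def __init__(self, layer_ids):
        self.layer_ids = layer_ids  # keeps a reference to the caller's list

cache = {}
layer_ids = [0, 2, 4]
cache[tuple(layer_ids)] = MappingStub(layer_ids)

layer_ids.append(6)  # caller reuses and mutates its list
print(cache[(0, 2, 4)].layer_ids)  # [0, 2, 4, 6]: the cached instance changed under its key

Passing tuple(layer_ids) into the constructor breaks this aliasing.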

Comment on lines +305 to +340
_HASH_MAPPING_CACHE = {}

# Ensures that an NgramHashMapping with identical configuration is created only once.
def get_or_create_hash_mapping(
    engram_vocab_size,
    max_ngram_size,
    n_embed_per_ngram,
    n_head_per_ngram,
    layer_ids,
    tokenizer_name_or_path,
    pad_id,
    seed,
):
    cache_key = (
        tuple(engram_vocab_size),
        max_ngram_size,
        n_embed_per_ngram,
        n_head_per_ngram,
        tuple(layer_ids),
        tokenizer_name_or_path,
        pad_id,
        seed,
    )

    if cache_key not in _HASH_MAPPING_CACHE:
        _HASH_MAPPING_CACHE[cache_key] = NgramHashMapping(
            engram_vocab_size=engram_vocab_size,
            max_ngram_size=max_ngram_size,
            n_embed_per_ngram=n_embed_per_ngram,
            n_head_per_ngram=n_head_per_ngram,
            layer_ids=layer_ids,
            tokenizer_name_or_path=tokenizer_name_or_path,
            pad_id=pad_id,
            seed=seed,
        )

    return _HASH_MAPPING_CACHE[cache_key]


Issue: _HASH_MAPPING_CACHE is unbounded. Each entry holds a HuggingFace tokenizer plus lookup tables, so in long-running processes or hyperparameter sweeps the cache can grow indefinitely and eventually OOM.

Fix: Use lru_cache with a size limit (its internal locking also provides thread safety):

Suggested change

# Before:
_HASH_MAPPING_CACHE = {}

# Ensures that an NgramHashMapping with identical configuration is created only once.
def get_or_create_hash_mapping(
    engram_vocab_size,
    max_ngram_size,
    n_embed_per_ngram,
    n_head_per_ngram,
    layer_ids,
    tokenizer_name_or_path,
    pad_id,
    seed,
):
    cache_key = (
        tuple(engram_vocab_size),
        max_ngram_size,
        n_embed_per_ngram,
        n_head_per_ngram,
        tuple(layer_ids),
        tokenizer_name_or_path,
        pad_id,
        seed,
    )
    if cache_key not in _HASH_MAPPING_CACHE:
        _HASH_MAPPING_CACHE[cache_key] = NgramHashMapping(
            engram_vocab_size=engram_vocab_size,
            max_ngram_size=max_ngram_size,
            n_embed_per_ngram=n_embed_per_ngram,
            n_head_per_ngram=n_head_per_ngram,
            layer_ids=layer_ids,
            tokenizer_name_or_path=tokenizer_name_or_path,
            pad_id=pad_id,
            seed=seed,
        )
    return _HASH_MAPPING_CACHE[cache_key]

# After:
from functools import lru_cache

@lru_cache(maxsize=8)
def get_or_create_hash_mapping(
    engram_vocab_size,   # must be tuple, not list
    max_ngram_size,
    n_embed_per_ngram,
    n_head_per_ngram,
    layer_ids,           # must be tuple, not list
    tokenizer_name_or_path,
    pad_id,
    seed,
):
    return NgramHashMapping(
        engram_vocab_size=engram_vocab_size,
        max_ngram_size=max_ngram_size,
        n_embed_per_ngram=n_embed_per_ngram,
        n_head_per_ngram=n_head_per_ngram,
        layer_ids=layer_ids,
        tokenizer_name_or_path=tokenizer_name_or_path,
        pad_id=pad_id,
        seed=seed,
    )

Callers must pass tuples instead of lists. Alternatively, keep the manual cache but add a clear_hash_mapping_cache() helper, sketched below.
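A minimal sketch of that helper, assuming the manual _HASH_MAPPING_CACHE dict is kept (the helper name is a suggestion, not code from this PR):

# Hypothetical helper for the manual-cache variant.
def clear_hash_mapping_cache():
    """Release all cached NgramHashMapping instances (and their tokenizers)."""
    _HASH_MAPPING_CACHE.clear()

With the @lru_cache version, the equivalent is the built-in get_or_create_hash_mapping.cache_clear().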

Comment on lines +328 to +339
if cache_key not in _HASH_MAPPING_CACHE:
    _HASH_MAPPING_CACHE[cache_key] = NgramHashMapping(
        engram_vocab_size=engram_vocab_size,
        max_ngram_size=max_ngram_size,
        n_embed_per_ngram=n_embed_per_ngram,
        n_head_per_ngram=n_head_per_ngram,
        layer_ids=layer_ids,
        tokenizer_name_or_path=tokenizer_name_or_path,
        pad_id=pad_id,
        seed=seed,
    )


Issue: The check-then-set pattern isn't thread-safe. Multiple threads can race past if cache_key not in _HASH_MAPPING_CACHE and redundantly create expensive NgramHashMapping instances (tokenizer loading, prime computation).

Fix: Add a lock:

Suggested change

# Before:
if cache_key not in _HASH_MAPPING_CACHE:
    _HASH_MAPPING_CACHE[cache_key] = NgramHashMapping(
        engram_vocab_size=engram_vocab_size,
        max_ngram_size=max_ngram_size,
        n_embed_per_ngram=n_embed_per_ngram,
        n_head_per_ngram=n_head_per_ngram,
        layer_ids=layer_ids,
        tokenizer_name_or_path=tokenizer_name_or_path,
        pad_id=pad_id,
        seed=seed,
    )

# After:
import threading

_HASH_MAPPING_CACHE = {}
_HASH_MAPPING_LOCK = threading.Lock()

def get_or_create_hash_mapping(...):
    cache_key = (...)
    with _HASH_MAPPING_LOCK:
        if cache_key not in _HASH_MAPPING_CACHE:
            _HASH_MAPPING_CACHE[cache_key] = NgramHashMapping(...)
        return _HASH_MAPPING_CACHE[cache_key]

Or use @lru_cache, which handles locking internally and also fixes the unbounded-cache issue.
