From 0f04984c62b2e8ba8a8185b9d62b929e77225954 Mon Sep 17 00:00:00 2001
From: Linorman <3033797357@qq.com>
Date: Sat, 18 Oct 2025 11:43:07 +0800
Subject: [PATCH 1/3] fix: fix wrong param in SentenceChunker

---
 src/memos/chunkers/sentence_chunker.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/memos/chunkers/sentence_chunker.py b/src/memos/chunkers/sentence_chunker.py
index 4de0cf32b..36ebee8e0 100644
--- a/src/memos/chunkers/sentence_chunker.py
+++ b/src/memos/chunkers/sentence_chunker.py
@@ -21,7 +21,7 @@ def __init__(self, config: SentenceChunkerConfig):
         self.config = config

         self.chunker = ChonkieSentenceChunker(
-            tokenizer_or_token_counter=config.tokenizer_or_token_counter,
+            tokenizer=config.tokenizer_or_token_counter,
             chunk_size=config.chunk_size,
             chunk_overlap=config.chunk_overlap,
             min_sentences_per_chunk=config.min_sentences_per_chunk,

From 4809d326e21361f608a58db82f2e99b2134e175d Mon Sep 17 00:00:00 2001
From: Linorman <3033797357@qq.com>
Date: Sat, 18 Oct 2025 15:21:28 +0800
Subject: [PATCH 2/3] fix: handle both DynamicCache layouts when trimming the
 KV cache in hf.py

---
 src/memos/llms/hf.py | 22 +++++++++++++++++++---
 1 file changed, 19 insertions(+), 3 deletions(-)

diff --git a/src/memos/llms/hf.py b/src/memos/llms/hf.py
index 00081b581..b8e2adbad 100644
--- a/src/memos/llms/hf.py
+++ b/src/memos/llms/hf.py
@@ -382,7 +382,23 @@ def build_kv_cache(self, messages) -> DynamicCache:
         kv = DynamicCache()
         with torch.no_grad():
             self.model(**inputs, use_cache=True, past_key_values=kv)
-        for i, (k, v) in enumerate(zip(kv.key_cache, kv.value_cache, strict=False)):
-            kv.key_cache[i] = k[:, :, :seq_len, :]
-            kv.value_cache[i] = v[:, :, :seq_len, :]
+        try:
+            if hasattr(kv, "key_cache") and hasattr(kv, "value_cache"):
+                for i, (k, v) in enumerate(zip(kv.key_cache, kv.value_cache, strict=False)):
+                    if isinstance(k, torch.Tensor):
+                        kv.key_cache[i] = k[..., :seq_len, :]
+                    if isinstance(v, torch.Tensor):
+                        kv.value_cache[i] = v[..., :seq_len, :]
+            elif hasattr(kv, "layers"):
+                for layer in kv.layers:
+                    if hasattr(layer, "keys") and isinstance(layer.keys, torch.Tensor):
+                        layer.keys = layer.keys[..., :seq_len, :]
+                    if hasattr(layer, "values") and isinstance(layer.values, torch.Tensor):
+                        layer.values = layer.values[..., :seq_len, :]
+            else:
+                logger.warning(
+                    "DynamicCache object has no key_cache/value_cache or layers attributes; returning unmodified cache"
+                )
+        except Exception as e:
+            logger.exception("Failed while trimming KV cache to seq_len: %s", e)
         return kv
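Background (sketch, not repo code): patches 2 and 3 both revolve around the DynamicCache layout change in transformers. Releases before 4.54 store per-layer tensors in the parallel lists key_cache/value_cache; 4.54 and later (the version the removed check in kv.py pinned) expose cache.layers, a list of per-layer objects with .keys and .values. A minimal, self-contained illustration of the two layouts; the iter_kv_pairs helper and the tensor shapes are invented for this note:

    import torch
    from transformers import DynamicCache

    def iter_kv_pairs(cache: DynamicCache):
        """Yield per-layer (keys, values) under either DynamicCache layout."""
        if getattr(cache, "layers", None) is not None:
            # transformers >= 4.54: per-layer objects carrying .keys / .values
            for layer in cache.layers:
                yield layer.keys, layer.values
        else:
            # older transformers: parallel key_cache / value_cache lists
            yield from zip(cache.key_cache, cache.value_cache)

    cache = DynamicCache()
    # update() exists under both layouts; tensors are [batch, heads, seq, head_dim]
    cache.update(torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8), layer_idx=0)
    for k, v in iter_kv_pairs(cache):
        print(k.shape, v.shape)  # torch.Size([1, 2, 4, 8]) twice

Trimming with k[..., :seq_len, :] behaves the same in both branches because the sequence axis is dim=-2 under either layout.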
From aa1148984ca624106edb4c7ba2839e9a8f9ba276 Mon Sep 17 00:00:00 2001
From: Linorman <3033797357@qq.com>
Date: Sat, 18 Oct 2025 16:39:00 +0800
Subject: [PATCH 3/3] fix: support both DynamicCache layouts across the repo

---
 src/memos/llms/hf.py                   | 17 ++---
 src/memos/mem_os/utils/format_utils.py | 90 +++++++++++++++++++-------
 src/memos/memories/activation/kv.py    | 58 ++++++++++-------
 tests/memories/activation/test_kv.py   | 27 ++++++--
 4 files changed, 129 insertions(+), 63 deletions(-)

diff --git a/src/memos/llms/hf.py b/src/memos/llms/hf.py
index b8e2adbad..53f3c383e 100644
--- a/src/memos/llms/hf.py
+++ b/src/memos/llms/hf.py
@@ -383,21 +383,22 @@ def build_kv_cache(self, messages) -> DynamicCache:
         with torch.no_grad():
             self.model(**inputs, use_cache=True, past_key_values=kv)
         try:
-            if hasattr(kv, "key_cache") and hasattr(kv, "value_cache"):
-                for i, (k, v) in enumerate(zip(kv.key_cache, kv.value_cache, strict=False)):
-                    if isinstance(k, torch.Tensor):
-                        kv.key_cache[i] = k[..., :seq_len, :]
-                    if isinstance(v, torch.Tensor):
-                        kv.value_cache[i] = v[..., :seq_len, :]
-            elif hasattr(kv, "layers"):
+            # Prefer the new `layers` API first
+            if hasattr(kv, "layers") and kv.layers is not None:
                 for layer in kv.layers:
                     if hasattr(layer, "keys") and isinstance(layer.keys, torch.Tensor):
                         layer.keys = layer.keys[..., :seq_len, :]
                     if hasattr(layer, "values") and isinstance(layer.values, torch.Tensor):
                         layer.values = layer.values[..., :seq_len, :]
+            elif hasattr(kv, "key_cache") and hasattr(kv, "value_cache"):
+                for i, (k, v) in enumerate(zip(kv.key_cache, kv.value_cache, strict=False)):
+                    if isinstance(k, torch.Tensor):
+                        kv.key_cache[i] = k[..., :seq_len, :]
+                    if isinstance(v, torch.Tensor):
+                        kv.value_cache[i] = v[..., :seq_len, :]
             else:
                 logger.warning(
-                    "DynamicCache object has no key_cache/value_cache or layers attributes; returning unmodified cache"
+                    "DynamicCache object has no layers or key_cache/value_cache attributes; returning unmodified cache"
                 )
         except Exception as e:
             logger.exception("Failed while trimming KV cache to seq_len: %s", e)
         return kv

diff --git a/src/memos/mem_os/utils/format_utils.py b/src/memos/mem_os/utils/format_utils.py
index 5fdb59058..153a5978c 100644
--- a/src/memos/mem_os/utils/format_utils.py
+++ b/src/memos/mem_os/utils/format_utils.py
@@ -1088,35 +1088,68 @@ def convert_activation_memory_to_serializable(
     for item in act_mem_items:
         # Extract basic information that can be serialized
+        # Infer counts/device/dtype compatibly for new/old DynamicCache APIs
+        mem = item.memory
+        key_layers = 0
+        val_layers = 0
+        device_str = "unknown"
+        dtype_str = "unknown"
+
+        if mem:
+            if hasattr(mem, "layers") and mem.layers is not None:
+                key_layers = len(mem.layers)
+                val_layers = len(mem.layers)
+                # find first available tensor to report device/dtype
+                for lyr in mem.layers:
+                    t = getattr(lyr, "keys", None)
+                    if t is None:
+                        t = getattr(lyr, "values", None)
+                    if t is not None:
+                        device_str = str(t.device)
+                        dtype_str = str(t.dtype)
+                        break
+            else:
+                key_layers = len(getattr(mem, "key_cache", []) or [])
+                val_layers = len(getattr(mem, "value_cache", []) or [])
+                if getattr(mem, "key_cache", None):
+                    first = next((t for t in mem.key_cache if t is not None), None)
+                    if first is not None:
+                        device_str = str(first.device)
+                        dtype_str = str(first.dtype)
+
         serializable_item = {
             "id": item.id,
             "metadata": item.metadata,
             "memory_info": {
                 "type": "DynamicCache",
-                "key_cache_layers": len(item.memory.key_cache) if item.memory else 0,
-                "value_cache_layers": len(item.memory.value_cache) if item.memory else 0,
-                "device": str(item.memory.key_cache[0].device)
-                if item.memory and item.memory.key_cache
-                else "unknown",
-                "dtype": str(item.memory.key_cache[0].dtype)
-                if item.memory and item.memory.key_cache
-                else "unknown",
+                "key_cache_layers": key_layers,
+                "value_cache_layers": val_layers,
+                "device": device_str,
+                "dtype": dtype_str,
             },
         }

         # Add tensor shape information if available
-        if item.memory and item.memory.key_cache:
+        if item.memory:
             key_shapes = []
             value_shapes = []
-
-            for i, key_tensor in enumerate(item.memory.key_cache):
-                if key_tensor is not None:
-                    key_shapes.append({"layer": i, "shape": list(key_tensor.shape)})
-
-                    if i < len(item.memory.value_cache) and item.memory.value_cache[i] is not None:
-                        value_shapes.append(
-                            {"layer": i, "shape": list(item.memory.value_cache[i].shape)}
-                        )
+            mem = item.memory
+            if hasattr(mem, "layers") and mem.layers is not None:
+                for i, layer in enumerate(mem.layers):
+                    if getattr(layer, "keys", None) is not None:
+                        key_shapes.append({"layer": i, "shape": list(layer.keys.shape)})
+                    if getattr(layer, "values", None) is not None:
+                        value_shapes.append({"layer": i, "shape": list(layer.values.shape)})
+            elif getattr(mem, "key_cache", None):
+                for i, key_tensor in enumerate(mem.key_cache):
+                    if key_tensor is not None:
+                        key_shapes.append({"layer": i, "shape": list(key_tensor.shape)})
+                        if (
+                            hasattr(mem, "value_cache")
+                            and i < len(mem.value_cache)
+                            and mem.value_cache[i] is not None
+                        ):
+                            value_shapes.append({"layer": i, "shape": list(mem.value_cache[i].shape)})

             serializable_item["memory_info"]["key_shapes"] = key_shapes
             serializable_item["memory_info"]["value_shapes"] = value_shapes
@@ -1144,15 +1177,22 @@ def convert_activation_memory_summary(act_mem_items: list[KVCacheItem]) -> dict[
     total_parameters = 0

     for item in act_mem_items:
-        if item.memory and item.memory.key_cache:
-            total_layers += len(item.memory.key_cache)
-
-            # Calculate approximate parameter count
-            for key_tensor in item.memory.key_cache:
+        mem = item.memory
+        if not mem:
+            continue
+        if hasattr(mem, "layers") and mem.layers is not None:
+            total_layers += len(mem.layers)
+            for layer in mem.layers:
+                if getattr(layer, "keys", None) is not None:
+                    total_parameters += layer.keys.numel()
+                if getattr(layer, "values", None) is not None:
+                    total_parameters += layer.values.numel()
+        elif getattr(mem, "key_cache", None):
+            total_layers += len(mem.key_cache)
+            for key_tensor in mem.key_cache:
                 if key_tensor is not None:
                     total_parameters += key_tensor.numel()
-
-            for value_tensor in item.memory.value_cache:
+            for value_tensor in getattr(mem, "value_cache", []) or []:
                 if value_tensor is not None:
                     total_parameters += value_tensor.numel()
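Background (sketch, not repo code): the summary above approximates "parameters" by summing numel() over every cached tensor it can reach. The same accounting in isolation, with invented shapes; update() keeps the snippet layout-agnostic:

    import torch
    from transformers import DynamicCache

    cache = DynamicCache()
    cache.update(torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8), layer_idx=0)

    # keys and values each hold 1 * 2 * 4 * 8 = 64 elements, so 128 in total
    total = 0
    if getattr(cache, "layers", None) is not None:
        for layer in cache.layers:
            total += layer.keys.numel() + layer.values.numel()
    else:
        total += sum(t.numel() for t in cache.key_cache if t is not None)
        total += sum(t.numel() for t in cache.value_cache if t is not None)
    print(total)  # 128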
diff --git a/src/memos/memories/activation/kv.py b/src/memos/memories/activation/kv.py
index 2fa08590f..d07c7abb9 100644
--- a/src/memos/memories/activation/kv.py
+++ b/src/memos/memories/activation/kv.py
@@ -2,9 +2,6 @@
 import pickle

 from datetime import datetime
-from importlib.metadata import version
-
-from packaging.version import Version

 from transformers import DynamicCache

 from memos.configs.memory import KVCacheMemoryConfig
@@ -210,29 +207,26 @@ def _concat_caches(self, caches: list[DynamicCache]) -> DynamicCache:
         if len(caches) == 1:
             return caches[0]

-        merged = DynamicCache()
-        num_layers = len(caches[0].key_cache)
-
-        if Version(version("transformers")) >= Version("4.54.0"):
-            merged.append_new_layers(num_layers - 1)
+        # Newer transformers expose `layers` with `.keys`/`.values`
+        if hasattr(caches[0], "layers") and caches[0].layers is not None:
+            num_layers = len(caches[0].layers)
+            merged = DynamicCache()
             for layer in range(num_layers):
-                # gather all K and V for this layer
                 keys = [c.layers[layer].keys for c in caches]
                 vals = [c.layers[layer].values for c in caches]
-                # single concat per layer
-                merged.layers[layer].keys = torch.cat(keys, dim=-2)
-                merged.layers[layer].values = torch.cat(vals, dim=-2)
-
+                # update() appends a fresh layer, so caches[0] is not mutated in place
+                merged.update(torch.cat(keys, dim=-2), torch.cat(vals, dim=-2), layer)
+            return merged
         else:
+            # Legacy API: key_cache/value_cache lists
+            merged = DynamicCache()
+            num_layers = len(caches[0].key_cache)
             for layer in range(num_layers):
-                # gather all K and V for this layer
                 keys = [c.key_cache[layer] for c in caches]
                 vals = [c.value_cache[layer] for c in caches]
-                # single concat per layer
                 merged.key_cache.append(torch.cat(keys, dim=-2))
                 merged.value_cache.append(torch.cat(vals, dim=-2))
-
-        return merged
+            return merged


 def move_dynamic_cache_htod(dynamic_cache: DynamicCache, device: str) -> DynamicCache:
@@ -242,11 +236,27 @@ def move_dynamic_cache_htod(dynamic_cache: DynamicCache, device: str) -> DynamicCache:
     So before inferring with DynamicCache, we should move it to GPU in-place first.
     """
     # Currently, we put this function outside [class KVCacheMemory]
-    for i in range(len(dynamic_cache.key_cache)):
-        if dynamic_cache.key_cache[i] is not None:
-            dynamic_cache.key_cache[i] = dynamic_cache.key_cache[i].to(device, non_blocking=True)
-        if dynamic_cache.value_cache[i] is not None:
-            dynamic_cache.value_cache[i] = dynamic_cache.value_cache[i].to(
-                device, non_blocking=True
-            )
+    # Support both the old API (key_cache/value_cache) and the new API (layers with keys/values)
+    if hasattr(dynamic_cache, "layers") and dynamic_cache.layers is not None:
+        for layer in dynamic_cache.layers:
+            # Each layer is expected to have `.keys` and `.values` tensors
+            if hasattr(layer, "keys") and layer.keys is not None:
+                layer.keys = layer.keys.to(device, non_blocking=True)
+            if hasattr(layer, "values") and layer.values is not None:
+                layer.values = layer.values.to(device, non_blocking=True)
+    else:
+        # Fall back to the legacy attributes
+        for i in range(len(getattr(dynamic_cache, "key_cache", []))):
+            if dynamic_cache.key_cache[i] is not None:
+                dynamic_cache.key_cache[i] = dynamic_cache.key_cache[i].to(
+                    device, non_blocking=True
+                )
+            if (
+                hasattr(dynamic_cache, "value_cache")
+                and i < len(dynamic_cache.value_cache)
+                and dynamic_cache.value_cache[i] is not None
+            ):
+                dynamic_cache.value_cache[i] = dynamic_cache.value_cache[i].to(
+                    device, non_blocking=True
+                )
     return dynamic_cache
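Background (sketch, not repo code): _concat_caches merges caches layer by layer, concatenating keys and values along the sequence axis (dim=-2). The same semantics in isolation, with invented shapes:

    import torch
    from transformers import DynamicCache

    a, b = DynamicCache(), DynamicCache()
    a.update(torch.zeros(1, 2, 3, 8), torch.zeros(1, 2, 3, 8), layer_idx=0)
    b.update(torch.zeros(1, 2, 5, 8), torch.zeros(1, 2, 5, 8), layer_idx=0)

    merged = DynamicCache()
    for layer in range(1):  # one layer in this toy example
        if getattr(a, "layers", None) is not None:
            keys = [c.layers[layer].keys for c in (a, b)]
            vals = [c.layers[layer].values for c in (a, b)]
        else:
            keys = [c.key_cache[layer] for c in (a, b)]
            vals = [c.value_cache[layer] for c in (a, b)]
        merged.update(torch.cat(keys, dim=-2), torch.cat(vals, dim=-2), layer)

    print(merged.get_seq_length())  # 3 + 5 = 8

Building the result through update() on a fresh cache, as the patched branch does, leaves the input caches untouched when get_cache is called repeatedly.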
diff --git a/tests/memories/activation/test_kv.py b/tests/memories/activation/test_kv.py
index 6490d687f..20bcd2435 100644
--- a/tests/memories/activation/test_kv.py
+++ b/tests/memories/activation/test_kv.py
@@ -34,10 +34,20 @@ def kv_memory(dummy_config):


 def make_filled_cache():
-    # Create a DynamicCache with at least one dummy tensor layer
+    # Create a DynamicCache with at least one dummy tensor layer, supporting new/old APIs
     cache = DynamicCache()
-    cache.key_cache.append(torch.zeros(1, 2, 3))
-    cache.value_cache.append(torch.zeros(1, 2, 3))
+    if hasattr(cache, "layers") and cache.layers is not None:
+        # For new API, append a layer-like object with keys/values tensors
+        class _Layer:
+            def __init__(self):
+                self.keys = torch.zeros(1, 2, 3)
+                self.values = torch.zeros(1, 2, 3)
+
+        cache.layers.append(_Layer())
+    else:
+        # Legacy API
+        cache.key_cache.append(torch.zeros(1, 2, 3))
+        cache.value_cache.append(torch.zeros(1, 2, 3))
     return cache
@@ -58,9 +68,14 @@ def test_get_cache_merge(kv_memory):
     kv_memory.add([item1, item2])
     merged = kv_memory.get_cache([item1.id, item2.id])
     assert isinstance(merged, DynamicCache)
-    # Check the number of layers in merged key/value cache
-    assert len(merged.key_cache) == 1
-    assert len(merged.value_cache) == 1
+    # Check the number of layers in merged cache (new or old API)
+    if hasattr(merged, "layers") and merged.layers is not None:
+        assert len(merged.layers) == 1
+        assert getattr(merged.layers[0], "keys", None) is not None
+        assert getattr(merged.layers[0], "values", None) is not None
+    else:
+        assert len(merged.key_cache) == 1
+        assert len(merged.value_cache) == 1


 def test_delete_and_get_all(kv_memory):
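Background (sketch, not repo code): when these tests fail, the first thing to establish is which DynamicCache layout the installed transformers provides. A quick probe, assuming only that transformers is importable:

    from transformers import DynamicCache

    cache = DynamicCache()
    if getattr(cache, "layers", None) is not None:
        print("new layout: cache.layers with per-layer .keys/.values")
    else:
        print("legacy layout: cache.key_cache / cache.value_cache lists")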