2 changes: 1 addition & 1 deletion src/memos/chunkers/sentence_chunker.py
@@ -21,7 +21,7 @@ def __init__(self, config: SentenceChunkerConfig):
 
         self.config = config
         self.chunker = ChonkieSentenceChunker(
-            tokenizer_or_token_counter=config.tokenizer_or_token_counter,
+            tokenizer=config.tokenizer_or_token_counter,
             chunk_size=config.chunk_size,
             chunk_overlap=config.chunk_overlap,
             min_sentences_per_chunk=config.min_sentences_per_chunk,
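The one-line change above tracks a keyword rename in chonkie's SentenceChunker: the PR switches from tokenizer_or_token_counter to tokenizer, which suggests the installed chonkie release expects the latter name. A minimal sketch of a construction that works under either name; build_chunker is a hypothetical helper and assumes only the config fields used in this diff:

import inspect

from chonkie import SentenceChunker as ChonkieSentenceChunker


def build_chunker(config):
    # Probe the installed chonkie signature once and pass the tokenizer
    # under whichever keyword this release accepts.
    params = inspect.signature(ChonkieSentenceChunker.__init__).parameters
    tok_kwarg = "tokenizer" if "tokenizer" in params else "tokenizer_or_token_counter"
    return ChonkieSentenceChunker(
        **{tok_kwarg: config.tokenizer_or_token_counter},
        chunk_size=config.chunk_size,
        chunk_overlap=config.chunk_overlap,
        min_sentences_per_chunk=config.min_sentences_per_chunk,
    )

Probing the signature at construction time keeps the call site free of version checks, mirroring the hasattr-based probing this PR applies to DynamicCache below.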
90 changes: 65 additions & 25 deletions src/memos/mem_os/utils/format_utils.py
@@ -1088,35 +1088,68 @@ def convert_activation_memory_to_serializable(
 
     for item in act_mem_items:
         # Extract basic information that can be serialized
+        # Infer counts/device/dtype compatibly for new/old DynamicCache APIs
+        mem = item.memory
+        key_layers = 0
+        val_layers = 0
+        device_str = "unknown"
+        dtype_str = "unknown"
+
+        if mem:
+            if hasattr(mem, "layers") and mem.layers is not None:
+                key_layers = len(mem.layers)
+                val_layers = len(mem.layers)
+                # find first available tensor to report device/dtype
+                for lyr in mem.layers:
+                    t = getattr(lyr, "keys", None)
+                    if t is None:
+                        t = getattr(lyr, "values", None)
+                    if t is not None:
+                        device_str = str(t.device)
+                        dtype_str = str(t.dtype)
Comment on lines +1104 to +1109 (Copilot AI, Nov 4, 2025):
[nitpick] The variable name t is unclear and could be improved for readability. Consider renaming it to tensor or sample_tensor to better convey its purpose as a tensor used to determine device and dtype information.

Suggested change:
-                    t = getattr(lyr, "keys", None)
-                    if t is None:
-                        t = getattr(lyr, "values", None)
-                    if t is not None:
-                        device_str = str(t.device)
-                        dtype_str = str(t.dtype)
+                    tensor = getattr(lyr, "keys", None)
+                    if tensor is None:
+                        tensor = getattr(lyr, "values", None)
+                    if tensor is not None:
+                        device_str = str(tensor.device)
+                        dtype_str = str(tensor.dtype)
+                        break
+            else:
+                key_layers = len(getattr(mem, "key_cache", []) or [])
+                val_layers = len(getattr(mem, "value_cache", []) or [])
+                if getattr(mem, "key_cache", None):
+                    first = next((t for t in mem.key_cache if t is not None), None)
+                    if first is not None:
+                        device_str = str(first.device)
+                        dtype_str = str(first.dtype)

         serializable_item = {
             "id": item.id,
             "metadata": item.metadata,
             "memory_info": {
                 "type": "DynamicCache",
-                "key_cache_layers": len(item.memory.key_cache) if item.memory else 0,
-                "value_cache_layers": len(item.memory.value_cache) if item.memory else 0,
-                "device": str(item.memory.key_cache[0].device)
-                if item.memory and item.memory.key_cache
-                else "unknown",
-                "dtype": str(item.memory.key_cache[0].dtype)
-                if item.memory and item.memory.key_cache
-                else "unknown",
+                "key_cache_layers": key_layers,
+                "value_cache_layers": val_layers,
+                "device": device_str,
+                "dtype": dtype_str,
             },
         }

         # Add tensor shape information if available
-        if item.memory and item.memory.key_cache:
+        if item.memory:
             key_shapes = []
             value_shapes = []
 
-            for i, key_tensor in enumerate(item.memory.key_cache):
-                if key_tensor is not None:
-                    key_shapes.append({"layer": i, "shape": list(key_tensor.shape)})
-
-                if i < len(item.memory.value_cache) and item.memory.value_cache[i] is not None:
-                    value_shapes.append(
-                        {"layer": i, "shape": list(item.memory.value_cache[i].shape)}
-                    )
+            mem = item.memory
+            if hasattr(mem, "layers") and mem.layers is not None:
+                for i, layer in enumerate(mem.layers):
+                    if getattr(layer, "keys", None) is not None:
+                        key_shapes.append({"layer": i, "shape": list(layer.keys.shape)})
+                    if getattr(layer, "values", None) is not None:
+                        value_shapes.append({"layer": i, "shape": list(layer.values.shape)})
+            elif getattr(mem, "key_cache", None):
+                for i, key_tensor in enumerate(mem.key_cache):
+                    if key_tensor is not None:
+                        key_shapes.append({"layer": i, "shape": list(key_tensor.shape)})
+                    if (
+                        hasattr(mem, "value_cache")
+                        and i < len(mem.value_cache)
+                        and mem.value_cache[i] is not None
+                    ):
+                        value_shapes.append({"layer": i, "shape": list(mem.value_cache[i].shape)})
 
             serializable_item["memory_info"]["key_shapes"] = key_shapes
             serializable_item["memory_info"]["value_shapes"] = value_shapes
@@ -1144,15 +1177,22 @@ def convert_activation_memory_summary(act_mem_items: list[KVCacheItem]) -> dict[
     total_parameters = 0
 
     for item in act_mem_items:
-        if item.memory and item.memory.key_cache:
-            total_layers += len(item.memory.key_cache)
-
-            # Calculate approximate parameter count
-            for key_tensor in item.memory.key_cache:
+        mem = item.memory
+        if not mem:
+            continue
+        if hasattr(mem, "layers") and mem.layers is not None:
+            total_layers += len(mem.layers)
+            for layer in mem.layers:
+                if getattr(layer, "keys", None) is not None:
+                    total_parameters += layer.keys.numel()
+                if getattr(layer, "values", None) is not None:
+                    total_parameters += layer.values.numel()
+        elif getattr(mem, "key_cache", None):
+            total_layers += len(mem.key_cache)
+            for key_tensor in mem.key_cache:
                 if key_tensor is not None:
                     total_parameters += key_tensor.numel()
 
-            for value_tensor in item.memory.value_cache:
+            for value_tensor in getattr(mem, "value_cache", []) or []:
                 if value_tensor is not None:
                     total_parameters += value_tensor.numel()

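Both functions in this file now repeat the same probe: a DynamicCache either exposes layers[i].keys / layers[i].values (newer transformers) or parallel key_cache / value_cache lists (older releases). A minimal sketch of that pattern factored into a single helper; iter_kv is hypothetical, not code from this PR:

from transformers import DynamicCache


def iter_kv(cache: DynamicCache):
    """Yield (layer_index, keys, values) across either DynamicCache layout."""
    if hasattr(cache, "layers") and cache.layers is not None:
        # New API: per-layer objects carrying .keys/.values tensors
        for i, layer in enumerate(cache.layers):
            yield i, getattr(layer, "keys", None), getattr(layer, "values", None)
    else:
        # Legacy API: parallel lists of key/value tensors
        key_cache = getattr(cache, "key_cache", []) or []
        value_cache = getattr(cache, "value_cache", []) or []
        for i, keys in enumerate(key_cache):
            values = value_cache[i] if i < len(value_cache) else None
            yield i, keys, values

With such a helper, the parameter count in convert_activation_memory_summary would reduce to summing t.numel() over every non-None tensor yielded by iter_kv(mem).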
28 changes: 11 additions & 17 deletions src/memos/memories/activation/kv.py
@@ -2,9 +2,6 @@
 import pickle
 
 from datetime import datetime
-from importlib.metadata import version
-
-from packaging.version import Version
 from transformers import DynamicCache
 
 from memos.configs.memory import KVCacheMemoryConfig
@@ -210,29 +207,26 @@ def _concat_caches(self, caches: list[DynamicCache]) -> DynamicCache:
         if len(caches) == 1:
             return caches[0]
 
-        merged = DynamicCache()
-        num_layers = len(caches[0].key_cache)
-
-        if Version(version("transformers")) >= Version("4.54.0"):
-            merged.append_new_layers(num_layers - 1)
+        # Newer transformers expose `layers` with `.keys`/`.values`
+        if hasattr(caches[0], "layers") and caches[0].layers is not None:
+            num_layers = len(caches[0].layers)
+            base = caches[0]
             for layer in range(num_layers):
                 # gather all K and V for this layer
                 keys = [c.layers[layer].keys for c in caches]
                 vals = [c.layers[layer].values for c in caches]
                 # single concat per layer
-                merged.layers[layer].keys = torch.cat(keys, dim=-2)
-                merged.layers[layer].values = torch.cat(vals, dim=-2)
+                base.layers[layer].keys = torch.cat(keys, dim=-2)
+                base.layers[layer].values = torch.cat(vals, dim=-2)
+            return base
Comment on lines +213 to +219 (Copilot AI, Nov 4, 2025):
Mutating the first cache object in-place (caches[0]) is problematic because it modifies the original cache that may still be referenced elsewhere. This could lead to unexpected side effects if the caller expects the original caches to remain unchanged. Consider creating a new DynamicCache() object and populating its layers similar to the legacy API path.
         else:
+            # Legacy API: key_cache/value_cache lists
+            merged = DynamicCache()
+            num_layers = len(caches[0].key_cache)
             for layer in range(num_layers):
                 # gather all K and V for this layer
                 keys = [c.key_cache[layer] for c in caches]
                 vals = [c.value_cache[layer] for c in caches]
                 # single concat per layer
                 merged.key_cache.append(torch.cat(keys, dim=-2))
                 merged.value_cache.append(torch.cat(vals, dim=-2))
-
-        return merged
+            return merged


def move_dynamic_cache_htod(dynamic_cache: DynamicCache, device: str) -> DynamicCache:
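On the review point above: a minimal non-mutating variant of the new-API branch, assuming the layers layout and accepting one extra copy of the first cache's tensors; concat_caches_copying is a hypothetical sketch, not the PR's implementation:

import copy

import torch
from transformers import DynamicCache


def concat_caches_copying(caches: list[DynamicCache]) -> DynamicCache:
    """Merge per-layer K/V along the sequence axis without mutating caches[0]."""
    merged = copy.deepcopy(caches[0])  # costs one copy of the first cache's tensors
    for i in range(len(merged.layers)):
        # Concatenate every cache's tensors for this layer, caches[0] included
        merged.layers[i].keys = torch.cat([c.layers[i].keys for c in caches], dim=-2)
        merged.layers[i].values = torch.cat([c.layers[i].values for c in caches], dim=-2)
    return merged

The in-place version in the diff avoids that copy but leaves caches[0] holding the merged tensors, which is exactly the side effect the reviewer flags.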
27 changes: 21 additions & 6 deletions tests/memories/activation/test_kv.py
@@ -34,10 +34,20 @@ def kv_memory(dummy_config):
 
 
 def make_filled_cache():
-    # Create a DynamicCache with at least one dummy tensor layer
+    # Create a DynamicCache with at least one dummy tensor layer, supporting new/old APIs
     cache = DynamicCache()
-    cache.key_cache.append(torch.zeros(1, 2, 3))
-    cache.value_cache.append(torch.zeros(1, 2, 3))
+    if hasattr(cache, "layers") and cache.layers is not None:
+        # For new API, append a layer-like object with keys/values tensors
+        class _Layer:
+            def __init__(self):
+                self.keys = torch.zeros(1, 2, 3)
+                self.values = torch.zeros(1, 2, 3)
+
+        cache.layers.append(_Layer())
+    else:
+        # Legacy API
+        cache.key_cache.append(torch.zeros(1, 2, 3))
+        cache.value_cache.append(torch.zeros(1, 2, 3))
     return cache


@@ -58,9 +68,14 @@ def test_get_cache_merge(kv_memory):
     kv_memory.add([item1, item2])
     merged = kv_memory.get_cache([item1.id, item2.id])
     assert isinstance(merged, DynamicCache)
-    # Check the number of layers in merged key/value cache
-    assert len(merged.key_cache) == 1
-    assert len(merged.value_cache) == 1
+    # Check the number of layers in merged cache (new or old API)
+    if hasattr(merged, "layers") and merged.layers is not None:
+        assert len(merged.layers) == 1
+        assert getattr(merged.layers[0], "keys", None) is not None
+        assert getattr(merged.layers[0], "values", None) is not None
+    else:
+        assert len(merged.key_cache) == 1
+        assert len(merged.value_cache) == 1


def test_delete_and_get_all(kv_memory):
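As a quick, API-independent sanity check of why every merge path concatenates on dim=-2: with the conventional (batch, heads, seq_len, head_dim) K/V layout, that axis is the sequence length, so merging caches should add their sequence lengths. The shapes below are illustrative only:

import torch

# (batch, heads, seq_len, head_dim); dim=-2 is the sequence axis
k1 = torch.zeros(1, 2, 4, 8)
k2 = torch.zeros(1, 2, 6, 8)
merged_keys = torch.cat([k1, k2], dim=-2)
assert merged_keys.shape == (1, 2, 10, 8)  # 4 + 6 along the sequence axis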