fix: fix wrong param in SentenceChunker #370
base: dev
Changes from all commits: 0f04984, 4809d32, aa11489, 656d653, de789b2
```diff
@@ -2,9 +2,6 @@
 import pickle
 
 from datetime import datetime
-from importlib.metadata import version
-
-from packaging.version import Version
 from transformers import DynamicCache
 
 from memos.configs.memory import KVCacheMemoryConfig
@@ -210,29 +207,26 @@ def _concat_caches(self, caches: list[DynamicCache]) -> DynamicCache:
         if len(caches) == 1:
             return caches[0]
 
-        merged = DynamicCache()
-        num_layers = len(caches[0].key_cache)
-
-        if Version(version("transformers")) >= Version("4.54.0"):
-            merged.append_new_layers(num_layers - 1)
+        # Newer transformers expose `layers` with `.keys`/`.values`
+        if hasattr(caches[0], "layers") and caches[0].layers is not None:
+            num_layers = len(caches[0].layers)
+            base = caches[0]
             for layer in range(num_layers):
                 # gather all K and V for this layer
                 keys = [c.layers[layer].keys for c in caches]
                 vals = [c.layers[layer].values for c in caches]
                 # single concat per layer
-                merged.layers[layer].keys = torch.cat(keys, dim=-2)
-                merged.layers[layer].values = torch.cat(vals, dim=-2)
+                base.layers[layer].keys = torch.cat(keys, dim=-2)
+                base.layers[layer].values = torch.cat(vals, dim=-2)
+            return base
         else:
+            # Legacy API: key_cache/value_cache lists
+            merged = DynamicCache()
+            num_layers = len(caches[0].key_cache)
             for layer in range(num_layers):
                 # gather all K and V for this layer
                 keys = [c.key_cache[layer] for c in caches]
                 vals = [c.value_cache[layer] for c in caches]
                 # single concat per layer
                 merged.key_cache.append(torch.cat(keys, dim=-2))
                 merged.value_cache.append(torch.cat(vals, dim=-2))
-
-        return merged
+            return merged
 
 
 def move_dynamic_cache_htod(dynamic_cache: DynamicCache, device: str) -> DynamicCache:
```
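In effect, the patch swaps the `packaging`-based version gate (`Version(version("transformers")) >= Version("4.54.0")`) for runtime feature detection on the cache object itself, which is what lets the two version-related imports be dropped. It also stops pre-allocating a fresh `merged` cache via `append_new_layers` in the new-API path and instead reuses `caches[0]` as the merge target. Below is a minimal, self-contained sketch of the same detect-and-concatenate pattern; the `make_cache` helper and the toy shapes are illustrative, not part of the PR:

```python
import torch
from transformers import DynamicCache

def make_cache(seq_len: int) -> DynamicCache:
    # Toy single-layer cache with shape (batch=1, heads=2, seq_len, head_dim=4)
    cache = DynamicCache()
    k = torch.randn(1, 2, seq_len, 4)
    v = torch.randn(1, 2, seq_len, 4)
    cache.update(k, v, layer_idx=0)  # Cache.update is stable across versions
    return cache

a, b = make_cache(3), make_cache(5)

# The same feature detection the PR uses: prefer the new `layers` layout,
# fall back to the legacy key_cache/value_cache lists on older releases.
if getattr(a, "layers", None) is not None:
    merged_keys = torch.cat([a.layers[0].keys, b.layers[0].keys], dim=-2)
else:
    merged_keys = torch.cat([a.key_cache[0], b.key_cache[0]], dim=-2)

print(merged_keys.shape)  # torch.Size([1, 2, 8, 4]): 3 + 5 joined on dim=-2
```

Concatenating on `dim=-2` works because cached K/V tensors are laid out as `(batch, heads, seq_len, head_dim)`, so the sequence axis is second from last.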
Comment on lines +213 to +219:
[nitpick] The variable name `t` is unclear and could be improved for readability. Consider renaming it to `tensor` or `sample_tensor` to better convey its purpose as a tensor used to determine device and dtype information.
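The code this nitpick targets is not fully visible in this view; the fragment below is a purely hypothetical illustration of the suggested rename, with all names assumed:

```python
import torch

# Hypothetical fragment only; the real code around the reviewed lines is not
# shown in this diff view.
cache_tensors = [torch.zeros(1, 2, 4, 8)]

# Before (per the review): an opaque single-letter name
t = cache_tensors[0]

# After: the name states what the tensor is actually for
sample_tensor = cache_tensors[0]  # consulted only for .device and .dtype
target_device, target_dtype = sample_tensor.device, sample_tensor.dtype
```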