Custom op to update cache for torch.cond #16366
@@ -1 +1 @@
-7a064ed3eafa43f17412d434b395240c727b3000
+7a79b41e29a790ebb4b530eb98a89381e2d7de29
@@ -12,10 +12,16 @@
import logging
from typing import Tuple

import torch
from torch._inductor.lowering import lowerings as L, register_lowering
from torch.library import impl

aten = torch.ops.aten

try:
    op = torch.ops.llama.sdpa_with_kv_cache.default
    assert op is not None
@@ -387,3 +393,103 @@ def custom_quantized_sdpa_meta( | |
| ) | ||
|
|
||
| return torch.empty(query.size(), dtype=torch.float32, device="meta") | ||
|
|
||
|
|
||
| # 1) Define the custom op in the "executorch" namespace with name "alias" | ||
| @torch.library.custom_op("executorch::alias", mutates_args=()) | ||
| def custom_alias(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | ||
| # no copies, just pass-through | ||
| return x, y | ||
|
|
||
|
|
||
| # 2) FakeTensor kernel: describes output metadata for compile-time | ||
| @custom_alias.register_fake | ||
| def _(x, y): | ||
| # For this op, outputs have exactly the same shape/dtype/device as inputs. | ||
| # We just need *dummy* tensors with that metadata. | ||
| out_x = torch.empty_like(x) | ||
| out_y = torch.empty_like(y) | ||
| return out_x, out_y | ||
|
|
||
|
|
||
| @register_lowering(torch.ops.executorch.alias.default) | ||
| def lowering_custom_alias(x, y): | ||
| # x, y here are IR values (Inductor's internal representation). | ||
| # Alias is logically a no-op – just pass them through. | ||
| return x, y | ||
|
|
||
|
|
||
| # Expecting cache shape: (B, H, S_max, D), value shape (B, H, S, D) where S <= S_max | ||
| def _validate_cross_attn_cache_params(value: torch.Tensor, cache: torch.Tensor): | ||
| torch._assert(value.dim() == 4, "value must be 4D") | ||
| torch._assert(cache.dim() == 4, "cache must be 4D") | ||
| # Cache shape: (B, H, S_max, D) | ||
| # Value shape: (B, H, S, D) | ||
| torch._assert( | ||
| value.size(2) <= cache.size(2), | ||
| f"value sequence length {value.size(2)} exceeds cache size {cache.size(2)}", | ||
| ) | ||
| torch._assert(value.size(0) == cache.size(0), "batch size mismatch") | ||
| torch._assert(value.size(1) == cache.size(1), "num heads mismatch") | ||
| torch._assert(value.size(3) == cache.size(3), "head dim mismatch") | ||
| torch._assert(value.dtype == cache.dtype, "dtype mismatch") | ||
|
|
||
|
|
||
| # Intentionally declaring no mutations to enable use inside torch.cond branches, | ||
| # which require pure functions. torch.cond requires branch functions to be mutation-free. | ||
| # We omit `cache` from `mutates_args` to satisfy this constraint, accepting the | ||
| # mutation for inference use. | ||
| @torch.library.custom_op("executorch::update_cross_attn_cache", mutates_args=[]) | ||
| def _update_cross_attn_cache(value: torch.Tensor, cache: torch.Tensor) -> torch.Tensor: | ||
| """ | ||
| Update cross-attention KV cache with new values. | ||
|
|
||
| Copies the value tensor into the beginning of the cache tensor along the | ||
| sequence dimension. This is used for cross-attention caching where the | ||
| encoder outputs are computed once and reused across decoding steps. | ||
|
|
||
| Args: | ||
| value: New values to store in cache. Shape: [B, H, S, D] where | ||
| B = batch size, H = num heads, S = sequence length, D = head dim. | ||
| cache: Pre-allocated cache tensor to update. Shape: [B, H, S_max, D] | ||
| where S_max >= S. | ||
|
|
||
| Returns: | ||
| A clone of the updated cache tensor. Note that this is different from | ||
| inductor lowering which returns the cache tensor itself. The reason is | ||
| that if we return input buffer directly, we will fail torch check in | ||
| higher order ops. | ||
|
|
||
| Note: | ||
| The cache is mutated in-place, but we return a clone to avoid aliasing | ||
| issues with the exported program. | ||
| """ | ||
| _validate_cross_attn_cache_params(value, cache) | ||
| cache[:, :, : value.size(2), :].copy_(value) | ||
| return cache.clone() | ||
|
|
||
|
|
||
| # Register the fake (meta) kernel | ||
| @_update_cross_attn_cache.register_fake | ||
| def _update_cross_attn_cache_fake( | ||
| value: torch.Tensor, cache: torch.Tensor | ||
| ) -> torch.Tensor: | ||
| _validate_cross_attn_cache_params(value, cache) | ||
| return torch.empty_like(cache) | ||
|
|
||
|
|
||
| # Register Inductor lowering | ||
| @register_lowering(torch.ops.executorch.update_cross_attn_cache) | ||
| def _update_cross_attn_cache_lowering(value, cache): | ||
| # cache shape: [B, H, S_max, D] | ||
| # value shape: [B, H, S, D] | ||
|
|
||
| # We need to slice the cache along dim 2 (sequence length) | ||
| # slice(self, dim, start, end, step=1) | ||
| seq_len = value.get_size()[2] | ||
| cache_slice = L[aten.slice.Tensor](cache, 2, 0, seq_len, 1) | ||
|
|
||
| # Copy value into the slice | ||
| L[aten.copy_.default](cache_slice, value) | ||
|
|
||
| return cache | ||
|
Contributor:
This is returning the original cache, whereas eager mode returns a copy. Do you want this discrepancy?

Contributor (Author):
Yeah. Eager has to return a copy to make the exported program runnable. For the Inductor lowering we need it to be as efficient as possible.

Contributor:
Okay, then document the discrepancy.
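For context, here is a minimal sketch of how an op declared this way could be exercised inside torch.cond during export. The module, shapes, and predicate below are illustrative only (not from this PR), and it assumes the module registering executorch::update_cross_attn_cache has already been imported:

import torch


class CrossAttnCacheWriter(torch.nn.Module):
    def forward(self, is_first_step: torch.Tensor, value: torch.Tensor, cache: torch.Tensor):
        # torch.cond requires both branches to be mutation- and alias-free,
        # which is why the op declares mutates_args=[] and the skip branch
        # returns a clone instead of the cache itself.
        def write(value, cache):
            return torch.ops.executorch.update_cross_attn_cache(value, cache)

        def skip(value, cache):
            return cache.clone()

        return torch.cond(is_first_step, write, skip, (value, cache))


module = CrossAttnCacheWriter()
cache = torch.zeros(1, 8, 128, 64)  # [B, H, S_max, D]
value = torch.randn(1, 8, 16, 64)   # [B, H, S, D]
exported = torch.export.export(module, (torch.tensor(True), value, cache))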
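And a quick eager-mode smoke test of the documented semantics (hypothetical, not part of this PR): the cache prefix is updated in place, while the returned tensor equals the updated cache but is a freshly allocated clone:

import torch

value = torch.randn(1, 8, 16, 64)
cache = torch.zeros(1, 8, 128, 64)

out = torch.ops.executorch.update_cross_attn_cache(value, cache)

# The first S positions of the cache now hold `value`...
assert torch.equal(cache[:, :, :16, :], value)
# ...and the output matches the updated cache without aliasing its storage.
assert torch.equal(out, cache)
assert out.data_ptr() != cache.data_ptr()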