Merged
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/pytorch.txt
@@ -1 +1 @@
-7a064ed3eafa43f17412d434b395240c727b3000
+7a79b41e29a790ebb4b530eb98a89381e2d7de29
13 changes: 13 additions & 0 deletions extension/llm/custom_ops/TARGETS
@@ -63,3 +63,16 @@ runtime.python_test(
        "//executorch/extension/pybindings:portable_lib",
    ],
)

runtime.python_test(
    name = "test_update_cross_attn_cache",
    srcs = [
        "test_update_cross_attn_cache.py",
    ],
    preload_deps = [
        ":custom_ops_aot_py",
    ],
    deps = [
        "//caffe2:torch",
    ],
)
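
The test file itself is not shown in this diff. A minimal sketch of the kind of check such a test might perform, assuming the ops are registered by importing the custom_ops module (the import path, test name, and shapes below are assumptions, not taken from the PR):

# Hypothetical sketch only; not the PR's actual test_update_cross_attn_cache.py.
import torch

# Assumed import path; importing the module registers the executorch::* ops.
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401


def test_update_cross_attn_cache_copies_prefix() -> None:
    B, H, S, S_max, D = 1, 4, 3, 8, 16
    value = torch.randn(B, H, S, D)
    cache = torch.zeros(B, H, S_max, D)

    out = torch.ops.executorch.update_cross_attn_cache(value, cache)

    # The first S positions along the sequence dimension hold the new values...
    torch.testing.assert_close(out[:, :, :S, :], value)
    # ...and the rest of the cache stays zero-initialized.
    torch.testing.assert_close(out[:, :, S:, :], torch.zeros(B, H, S_max - S, D))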
106 changes: 106 additions & 0 deletions extension/llm/custom_ops/custom_ops.py
@@ -12,10 +12,16 @@

import logging

from typing import Tuple

import torch

from torch._inductor.lowering import lowerings as L, register_lowering

from torch.library import impl

aten = torch.ops.aten

try:
    op = torch.ops.llama.sdpa_with_kv_cache.default
    assert op is not None
@@ -387,3 +393,103 @@ def custom_quantized_sdpa_meta(
    )

    return torch.empty(query.size(), dtype=torch.float32, device="meta")


# 1) Define the custom op in the "executorch" namespace with name "alias"
@torch.library.custom_op("executorch::alias", mutates_args=())
def custom_alias(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # no copies, just pass-through
    return x, y


# 2) FakeTensor kernel: describes output metadata for compile-time
@custom_alias.register_fake
def _(x, y):
    # For this op, outputs have exactly the same shape/dtype/device as inputs.
    # We just need *dummy* tensors with that metadata.
    out_x = torch.empty_like(x)
    out_y = torch.empty_like(y)
    return out_x, out_y


@register_lowering(torch.ops.executorch.alias.default)
def lowering_custom_alias(x, y):
    # x, y here are IR values (Inductor's internal representation).
    # Alias is logically a no-op – just pass them through.
    return x, y
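
A small usage sketch of the alias op (illustrative, not part of the diff): the wrapper function and inputs are made up, and the op is exercised through torch.compile, where the fake kernel supplies trace-time metadata and the lowering above reduces it to a pass-through.

# Illustrative sketch; `fn` and the inputs are made up for demonstration.
import torch


def fn(a: torch.Tensor, b: torch.Tensor):
    # The op forwards its inputs without copying them.
    x, y = torch.ops.executorch.alias(a, b)
    return x + 1, y * 2


a, b = torch.randn(2, 3), torch.randn(2, 3)

# At trace time the fake kernel describes the outputs; at lowering time the
# registered lowering passes the IR values straight through.
compiled_out = torch.compile(fn)(a, b)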


# Expecting cache shape: (B, H, S_max, D), value shape (B, H, S, D) where S <= S_max
def _validate_cross_attn_cache_params(value: torch.Tensor, cache: torch.Tensor):
    torch._assert(value.dim() == 4, "value must be 4D")
    torch._assert(cache.dim() == 4, "cache must be 4D")
    # Cache shape: (B, H, S_max, D)
    # Value shape: (B, H, S, D)
    torch._assert(
        value.size(2) <= cache.size(2),
        f"value sequence length {value.size(2)} exceeds cache size {cache.size(2)}",
    )
    torch._assert(value.size(0) == cache.size(0), "batch size mismatch")
    torch._assert(value.size(1) == cache.size(1), "num heads mismatch")
    torch._assert(value.size(3) == cache.size(3), "head dim mismatch")
    torch._assert(value.dtype == cache.dtype, "dtype mismatch")


# Intentionally declaring no mutations so this op can be used inside torch.cond
# branches, which require branch functions to be mutation-free. We omit `cache`
# from `mutates_args` to satisfy that constraint, accepting the hidden in-place
# update for inference use.
@torch.library.custom_op("executorch::update_cross_attn_cache", mutates_args=[])
def _update_cross_attn_cache(value: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
    """
    Update cross-attention KV cache with new values.

    Copies the value tensor into the beginning of the cache tensor along the
    sequence dimension. This is used for cross-attention caching where the
    encoder outputs are computed once and reused across decoding steps.

    Args:
        value: New values to store in cache. Shape: [B, H, S, D] where
            B = batch size, H = num heads, S = sequence length, D = head dim.
        cache: Pre-allocated cache tensor to update. Shape: [B, H, S_max, D]
            where S_max >= S.

    Returns:
        A clone of the updated cache tensor. Note that this differs from the
        Inductor lowering, which returns the cache tensor itself. The reason is
        that returning the input buffer directly would fail torch's checks in
        higher-order ops.

    Note:
        The cache is mutated in-place, but we return a clone to avoid aliasing
        issues with the exported program.
    """
    _validate_cross_attn_cache_params(value, cache)
    cache[:, :, : value.size(2), :].copy_(value)
    return cache.clone()
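
As the comment above explains, leaving `cache` out of `mutates_args` is what lets this op appear inside torch.cond branches. A hedged sketch of that pattern (the wrapper, predicate, and shapes are illustrative and not taken from the PR):

# Illustrative torch.cond sketch; the predicate and shapes are made up.
import torch


def maybe_update_cache(
    should_update: torch.Tensor, value: torch.Tensor, cache: torch.Tensor
) -> torch.Tensor:
    def update(value, cache):
        return torch.ops.executorch.update_cross_attn_cache(value, cache)

    def skip(value, cache):
        return cache.clone()

    # Both branches are (declared) mutation-free and return tensors with the
    # same metadata, which the clone-returning eager kernel satisfies.
    return torch.cond(should_update, update, skip, (value, cache))


value = torch.randn(1, 8, 5, 64)
cache = torch.zeros(1, 8, 32, 64)
out = maybe_update_cache(torch.tensor(True), value, cache)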


# Register the fake (meta) kernel
@_update_cross_attn_cache.register_fake
def _update_cross_attn_cache_fake(
    value: torch.Tensor, cache: torch.Tensor
) -> torch.Tensor:
    _validate_cross_attn_cache_params(value, cache)
    return torch.empty_like(cache)


# Register Inductor lowering
@register_lowering(torch.ops.executorch.update_cross_attn_cache)
def _update_cross_attn_cache_lowering(value, cache):
    # cache shape: [B, H, S_max, D]
    # value shape: [B, H, S, D]

    # We need to slice the cache along dim 2 (sequence length)
    # slice(self, dim, start, end, step=1)
    seq_len = value.get_size()[2]
    cache_slice = L[aten.slice.Tensor](cache, 2, 0, seq_len, 1)

    # Copy value into the slice
    L[aten.copy_.default](cache_slice, value)

    return cache

Contributor:
This is returning the original cache whereas eager mode returns a copy. Do you want this discrepancy?

Contributor (Author):
Yeah. Eager has to return a copy to make the exported_program runnable. For the Inductor lowering we need it to be as efficient as possible.

Contributor:
Okay, then document the discrepancy.
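
A hedged sketch of the discrepancy being discussed (the wrapper and shapes are illustrative; the aliasing noted for the compiled path reflects the lowering's `return cache` described above, not something this snippet guarantees on every backend):

# Illustrative sketch of the eager-vs-Inductor return behavior.
import torch


def step(value: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
    return torch.ops.executorch.update_cross_attn_cache(value, cache)


value = torch.randn(1, 4, 3, 16)
cache = torch.zeros(1, 4, 8, 16)

# Eager: the cache is updated in place, but the returned tensor is a clone.
eager_out = step(value, cache)
print(eager_out.data_ptr() == cache.data_ptr())  # expected: False

# Compiled: the lowering copies `value` into the cache slice and returns the
# cache buffer itself, so the output is expected to alias `cache`.
compiled_out = torch.compile(step)(value, cache)
print(compiled_out.data_ptr() == cache.data_ptr())  # expected: True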
