
Commit 7db4ecc

Custom op to update cache for torch.cond
ghstack-source-id: e3a8c16
ghstack-comment-id: 3683802199
Pull-Request: #16366
1 parent 1854f94 commit 7db4ecc
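
This commit adds an `executorch::update_cross_attn_cache` custom op whose eager kernel copies `value` into the leading sequence positions of `cache` but deliberately does not declare that mutation, so the op can sit inside `torch.cond` branches under `torch.compile` and `torch.export`. A minimal usage sketch, mirroring the new tests below (the tensor shapes and the `step` wrapper are illustrative only, not part of this commit):

import torch

# Importing the module registers the custom ops.
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

# Cache: [B, H, S_max, D]; value / value_alt: [B, H, S, D] with S <= S_max
cache = torch.zeros(2, 1, 4, 4)
value = torch.randn(2, 1, 2, 4)
value_alt = torch.randn(2, 1, 2, 4)

def true_fn(v1, v2, c):
    return torch.ops.executorch.update_cross_attn_cache(v1, c)

def false_fn(v1, v2, c):
    return torch.ops.executorch.update_cross_attn_cache(v2, c)

@torch.compile
def step(pred, v1, v2, c):
    return torch.cond(pred, true_fn, false_fn, (v1, v2, c))

step(torch.tensor(True), value, value_alt, cache)
# The first 2 positions of cache along dim 2 now hold value.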

3 files changed: +290, -2 lines

extension/llm/custom_ops/custom_ops.py

Lines changed: 88 additions & 0 deletions
@@ -12,10 +12,16 @@

import logging

+from typing import Tuple
+
import torch

+from torch._inductor.lowering import lowerings as L, register_lowering
+
from torch.library import impl

+aten = torch.ops.aten
+
try:
    op = torch.ops.llama.sdpa_with_kv_cache.default
    assert op is not None
@@ -387,3 +393,85 @@ def custom_quantized_sdpa_meta(
    )

    return torch.empty(query.size(), dtype=torch.float32, device="meta")
+
+
+# 1) Define the custom op in the "executorch" namespace with name "alias"
+@torch.library.custom_op("executorch::alias", mutates_args=())
+def custom_alias(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    # no copies, just pass-through
+    return x, y
+
+
+# 2) FakeTensor kernel: describes output metadata for compile-time
+@custom_alias.register_fake
+def _(x, y):
+    # For this op, outputs have exactly the same shape/dtype/device as inputs.
+    # We just need *dummy* tensors with that metadata.
+    out_x = torch.empty_like(x)
+    out_y = torch.empty_like(y)
+    return out_x, out_y
+
+
+@register_lowering(torch.ops.executorch.alias.default)
+def lowering_custom_alias(x, y):
+    # x, y here are IR values (Inductor's internal representation).
+    # Alias is logically a no-op, so just pass them through.
+    return x, y
+
+
+# Expected cache shape: (B, H, S_max, D); value shape: (B, H, S, D) where S <= S_max
+def _validate_cross_attn_cache_params(value: torch.Tensor, cache: torch.Tensor):
+    torch._assert(value.dim() == 4, "value must be 4D")
+    torch._assert(cache.dim() == 4, "cache must be 4D")
+    # Cache shape: (B, H, S_max, D)
+    # Value shape: (B, H, S, D)
+    torch._assert(
+        value.size(2) <= cache.size(2),
+        f"value sequence length {value.size(2)} exceeds cache size {cache.size(2)}",
+    )
+    torch._assert(value.size(0) == cache.size(0), "batch size mismatch")
+    torch._assert(value.size(1) == cache.size(1), "num heads mismatch")
+    torch._assert(value.size(3) == cache.size(3), "head dim mismatch")
+    torch._assert(value.dtype == cache.dtype, "dtype mismatch")
+
+
+# This is cheating: we deliberately do NOT mark `cache` as mutated, so that this
+# custom op can be used in HOPs such as `torch.cond`, where `torch.compile` requires
+# no aliasing or mutation in the branches. This is fine because we only care about inference.
+@torch.library.custom_op("executorch::update_cross_attn_cache", mutates_args=[])
+def _update_cross_attn_cache(value: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
+    # Eager implementation
+    _validate_cross_attn_cache_params(value, cache)
+
+    # Slice the cache to match value's sequence length and copy
+    # cache shape: [B, H, S_max, D]
+    # value shape: [B, H, S, D]
+    cache[:, :, : value.size(2), :].copy_(value)
+    # Return a clone of the cache to avoid aliasing with the input cache, so that we can still run the exported program.
+    return cache.clone()
+
+
+# Register the fake (meta) kernel
+@_update_cross_attn_cache.register_fake
+def _update_cross_attn_cache_fake(
+    value: torch.Tensor, cache: torch.Tensor
+) -> torch.Tensor:
+    _validate_cross_attn_cache_params(value, cache)
+    return torch.empty_like(cache)
+
+
+# Register Inductor lowering
+@register_lowering(torch.ops.executorch.update_cross_attn_cache)
+def _update_cross_attn_cache_lowering(value, cache):
+    # cache shape: [B, H, S_max, D]
+    # value shape: [B, H, S, D]
+
+    # We need to slice the cache along dim 2 (sequence length)
+    # slice(self, dim, start, end, step=1)
+    seq_len = value.get_size()[2]
+    cache_slice = L[aten.slice.Tensor](cache, 2, 0, seq_len, 1)
+
+    # Copy value into the slice
+    L[aten.copy_.default](cache_slice, value)
+
+    return cache
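
The op above follows the usual `torch.library.custom_op` / `register_fake` / Inductor `register_lowering` pattern, so its compile-time shape behavior can be checked without real data. A small sanity-check sketch (assumed usage, not part of this commit) that runs the op under `FakeTensorMode`, exercising only the registered fake kernel:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Importing the module registers the custom ops.
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

with FakeTensorMode():
    # Tensors created under the mode carry only metadata (shape/dtype/device).
    value = torch.randn(2, 1, 2, 4)
    cache = torch.zeros(2, 1, 4, 4)
    out = torch.ops.executorch.update_cross_attn_cache(value, cache)
    # The fake kernel returns empty_like(cache), so out matches the cache's metadata.
    assert out.shape == cache.shape and out.dtype == cache.dtype

Because the fake kernel returns a fresh tensor rather than a view of `cache`, tracing sees no aliasing or mutation, which is what lets the op appear inside `torch.cond` branches.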
Lines changed: 200 additions & 0 deletions
@@ -0,0 +1,200 @@
import unittest

import torch

# Import the custom ops to ensure they are registered
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401


class TestUpdateCrossAttnCache(unittest.TestCase):
    def test_update_cross_attn_cache(self):

        # Create tensors
        # Cache: [B=2, H=1, S_max=4, D=4]
        cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
        # Value: [B=2, H=1, S=2, D=4] (S < S_max)
        value = torch.randn(2, 1, 2, 4, dtype=torch.float32)

        # Compile a function that uses the op
        @torch.compile
        def fn(v, c):
            return torch.ops.executorch.update_cross_attn_cache(v, c)

        # Run it
        out = fn(value, cache)

        # Check correctness
        # The first 2 elements in dim 2 (sequence dim) should match value
        torch.testing.assert_close(
            cache[:, :, :2, :], value, msg="Cache slice not updated correctly"
        )

        # Make sure out and cache match in value (in eager, out is a clone of cache).
        torch.testing.assert_close(
            out, cache, msg="Output and cache do not match"
        )

        # The rest should be zeros
        torch.testing.assert_close(
            cache[:, :, 2:, :],
            torch.zeros_like(cache[:, :, 2:, :]),
            msg="Rest of cache was modified",
        )

    def test_update_cross_attn_cache_in_cond(self):
        # Create tensors

        # Value: [B=2, H=1, S=2, D=4]
        value = torch.randn(2, 1, 2, 4, dtype=torch.float32)
        # Alternative value for false branch
        value_alt = torch.randn(2, 1, 2, 4, dtype=torch.float32)

        # Define a function that uses the op inside torch.cond
        def fn_with_cond(pred, v1, v2, c):
            def true_fn(v1, v2, cache):
                return torch.ops.executorch.update_cross_attn_cache(v1, cache)

            def false_fn(v1, v2, cache):
                return torch.ops.executorch.update_cross_attn_cache(v2, cache)

            return torch.cond(pred, true_fn, false_fn, (v1, v2, c))

        # Test with true condition
        pred_true = torch.tensor(True)
        cache_true = torch.zeros(2, 1, 4, 4, dtype=torch.float32)

        # Compile the function
        @torch.compile
        def compiled_fn(pred, v1, v2, c):
            return fn_with_cond(pred, v1, v2, c)

        # Run with true condition
        compiled_fn(pred_true, value, value_alt, cache_true)

        # Check that the true branch was executed (value was used)
        torch.testing.assert_close(
            cache_true[:, :, :2, :],
            value,
            msg="Cache not updated correctly in true branch",
        )

        # Test with false condition
        pred_false = torch.tensor(False)
        cache_false = torch.zeros(2, 1, 4, 4, dtype=torch.float32)

        compiled_fn(pred_false, value, value_alt, cache_false)

        # Check that the false branch was executed (value_alt was used)
        torch.testing.assert_close(
            cache_false[:, :, :2, :],
            value_alt,
            msg="Cache not updated correctly in false branch",
        )

    def test_update_cross_attn_cache_export(self):

        # Create tensors
        # Cache: [B=2, H=1, S_max=4, D=4]
        cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
        # Value: [B=2, H=1, S=2, D=4]
        value = torch.randn(2, 1, 2, 4, dtype=torch.float32)
        # Alternative value for false branch
        value_alt = torch.randn(2, 1, 2, 4, dtype=torch.float32)

        # Define a module that uses torch.cond with the op
        class UpdateCacheCondModule(torch.nn.Module):
            def forward(self, pred, v1, v2, c):
                def true_fn(v1, v2, cache):
                    return torch.ops.executorch.update_cross_attn_cache(v1, cache)

                def false_fn(v1, v2, cache):
                    return torch.ops.executorch.update_cross_attn_cache(v2, cache)

                return torch.cond(pred, true_fn, false_fn, (v1, v2, c))

        module = UpdateCacheCondModule()

        # Export the module with true condition
        pred_true = torch.tensor(True)
        exported_program = torch.export.export(
            module,
            (pred_true, value, value_alt, cache),
        )

        # Run the exported program with true condition
        cache_true = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
        exported_program.module()(pred_true, value, value_alt, cache_true)

        # Check that the true branch was executed (value was used)
        torch.testing.assert_close(
            cache_true[:, :, :2, :],
            value,
            msg="Cache not updated correctly in true branch after export",
        )

        # Run the exported program with false condition
        pred_false = torch.tensor(False)
        cache_false = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
        exported_program.module()(pred_false, value, value_alt, cache_false)

        # Check that the false branch was executed (value_alt was used)
        torch.testing.assert_close(
            cache_false[:, :, :2, :],
            value_alt,
            msg="Cache not updated correctly in false branch after export",
        )

    def test_update_cross_attn_cache_different_shapes(self):
        print("Testing executorch::update_cross_attn_cache with different shapes...")

        # Test with different batch sizes and sequence lengths
        test_cases = [
            # (B, H, S_max, S, D)
            (1, 2, 10, 5, 8),
            (4, 4, 8, 3, 16),
            (2, 1, 16, 10, 32),
        ]

        for B, H, S_max, S, D in test_cases:
            # Cache: [B, H, S_max, D], Value: [B, H, S, D]
            cache = torch.zeros(B, H, S_max, D, dtype=torch.float32)
            value = torch.randn(B, H, S, D, dtype=torch.float32)

            @torch.compile
            def fn(v, c):
                return torch.ops.executorch.update_cross_attn_cache(v, c)

            fn(value, cache)

            # Check that the first S positions in dim 2 are updated
            torch.testing.assert_close(
                cache[:, :, :S, :],
                value,
                msg=f"Failed for shape B={B}, H={H}, S_max={S_max}, S={S}, D={D}",
            )

            # Check that the rest remain zeros
            if S < S_max:
                torch.testing.assert_close(
                    cache[:, :, S:, :],
                    torch.zeros_like(cache[:, :, S:, :]),
                    msg=f"Remaining cache modified for shape B={B}, H={H}, S_max={S_max}, S={S}, D={D}",
                )

    def test_update_cross_attn_cache_full_sequence(self):

        # Cache: [B=2, H=1, S_max=4, D=4]
        cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
        # Value: [B=2, H=1, S=4, D=4] (S == S_max)
        value = torch.randn(2, 1, 4, 4, dtype=torch.float32)

        @torch.compile
        def fn(v, c):
            return torch.ops.executorch.update_cross_attn_cache(v, c)

        fn(value, cache)

        # The entire cache should match value
        torch.testing.assert_close(
            cache, value, msg="Cache not fully updated when S == S_max"
        )

torch_pin.py

Lines changed: 2 additions & 2 deletions
@@ -1,2 +1,2 @@
-TORCH_VERSION = "2.10.0"
-NIGHTLY_VERSION = "dev20251120"
+TORCH_VERSION = "2.11.0"
+NIGHTLY_VERSION = "dev20251222"
