diff --git a/.ci/docker/ci_commit_pins/pytorch.txt b/.ci/docker/ci_commit_pins/pytorch.txt
index f4ec226f512..64dbd40e9a6 100644
--- a/.ci/docker/ci_commit_pins/pytorch.txt
+++ b/.ci/docker/ci_commit_pins/pytorch.txt
@@ -1 +1 @@
-7a064ed3eafa43f17412d434b395240c727b3000
+7a79b41e29a790ebb4b530eb98a89381e2d7de29
diff --git a/extension/llm/custom_ops/TARGETS b/extension/llm/custom_ops/TARGETS
index 5dda2318f3f..3b7c41e248c 100644
--- a/extension/llm/custom_ops/TARGETS
+++ b/extension/llm/custom_ops/TARGETS
@@ -63,3 +63,16 @@ runtime.python_test(
         "//executorch/extension/pybindings:portable_lib",
     ],
 )
+
+runtime.python_test(
+    name = "test_update_cross_attn_cache",
+    srcs = [
+        "test_update_cross_attn_cache.py",
+    ],
+    preload_deps = [
+        ":custom_ops_aot_py",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
diff --git a/extension/llm/custom_ops/custom_ops.py b/extension/llm/custom_ops/custom_ops.py
index dfa357fe356..9aacded4b4c 100644
--- a/extension/llm/custom_ops/custom_ops.py
+++ b/extension/llm/custom_ops/custom_ops.py
@@ -12,10 +12,16 @@
 
 import logging
 
+from typing import Tuple
+
 import torch
 
+from torch._inductor.lowering import lowerings as L, register_lowering
+
 from torch.library import impl
 
+aten = torch.ops.aten
+
 try:
     op = torch.ops.llama.sdpa_with_kv_cache.default
     assert op is not None
@@ -387,3 +393,103 @@ def custom_quantized_sdpa_meta(
     )
 
     return torch.empty(query.size(), dtype=torch.float32, device="meta")
+
+
+# 1) Define the custom op in the "executorch" namespace with the name "alias".
+@torch.library.custom_op("executorch::alias", mutates_args=())
+def custom_alias(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    # No copies, just pass-through.
+    return x, y
+
+
+# 2) FakeTensor kernel: describes output metadata at compile time.
+@custom_alias.register_fake
+def _(x, y):
+    # For this op, the outputs have exactly the same shape/dtype/device as the
+    # inputs. We only need *dummy* tensors carrying that metadata.
+    out_x = torch.empty_like(x)
+    out_y = torch.empty_like(y)
+    return out_x, out_y
+
+
+@register_lowering(torch.ops.executorch.alias.default)
+def lowering_custom_alias(x, y):
+    # x and y here are IR values (Inductor's internal representation).
+    # Alias is logically a no-op, so just pass them through.
+    return x, y
+
+
+# Expected cache shape: (B, H, S_max, D); value shape: (B, H, S, D) where S <= S_max.
+def _validate_cross_attn_cache_params(value: torch.Tensor, cache: torch.Tensor):
+    torch._assert(value.dim() == 4, "value must be 4D")
+    torch._assert(cache.dim() == 4, "cache must be 4D")
+    # Cache shape: (B, H, S_max, D)
+    # Value shape: (B, H, S, D)
+    torch._assert(
+        value.size(2) <= cache.size(2),
+        f"value sequence length {value.size(2)} exceeds cache size {cache.size(2)}",
+    )
+    torch._assert(value.size(0) == cache.size(0), "batch size mismatch")
+    torch._assert(value.size(1) == cache.size(1), "num heads mismatch")
+    torch._assert(value.size(3) == cache.size(3), "head dim mismatch")
+    torch._assert(value.dtype == cache.dtype, "dtype mismatch")
+
+
+# Intentionally declare no mutations so that the op can be used inside
+# torch.cond branches, which require their branch functions to be
+# mutation-free. We omit `cache` from `mutates_args` to satisfy this
+# constraint and accept the in-place mutation for inference use.
+@torch.library.custom_op("executorch::update_cross_attn_cache", mutates_args=[])
+def _update_cross_attn_cache(value: torch.Tensor, cache: torch.Tensor) -> torch.Tensor:
+    """
+    Update the cross-attention KV cache with new values.
+
+    Copies the value tensor into the beginning of the cache tensor along the
+    sequence dimension. This is used for cross-attention caching, where the
+    encoder outputs are computed once and reused across decoding steps.
+
+    Args:
+        value: New values to store in the cache. Shape: [B, H, S, D] where
+            B = batch size, H = num heads, S = sequence length, D = head dim.
+        cache: Pre-allocated cache tensor to update. Shape: [B, H, S_max, D]
+            where S_max >= S.
+
+    Returns:
+        A clone of the updated cache tensor. Note that this differs from the
+        Inductor lowering, which returns the cache tensor itself: returning
+        the input buffer directly here would fail the aliasing checks that
+        higher-order ops perform.
+
+    Note:
+        The cache is mutated in place, but we return a clone to avoid aliasing
+        issues with the exported program.
+    """
+    _validate_cross_attn_cache_params(value, cache)
+    cache[:, :, : value.size(2), :].copy_(value)
+    return cache.clone()
+
+
+# Register the fake (meta) kernel.
+@_update_cross_attn_cache.register_fake
+def _update_cross_attn_cache_fake(
+    value: torch.Tensor, cache: torch.Tensor
+) -> torch.Tensor:
+    _validate_cross_attn_cache_params(value, cache)
+    return torch.empty_like(cache)
+
+
+# Register the Inductor lowering.
+@register_lowering(torch.ops.executorch.update_cross_attn_cache)
+def _update_cross_attn_cache_lowering(value, cache):
+    # cache shape: [B, H, S_max, D]
+    # value shape: [B, H, S, D]
+
+    # Slice the cache along dim 2 (sequence length):
+    # slice(self, dim, start, end, step=1)
+    seq_len = value.get_size()[2]
+    cache_slice = L[aten.slice.Tensor](cache, 2, 0, seq_len, 1)
+
+    # Copy value into the slice.
+    L[aten.copy_.default](cache_slice, value)
+
+    return cache
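As a quick illustration of the eager semantics defined above, here is a minimal sketch (not part of the patch). It assumes the module is importable as `executorch.extension.llm.custom_ops.custom_ops`, the import path used by the new test file below, so that the op is registered before use.

```python
import torch

# Importing the module registers torch.ops.executorch.update_cross_attn_cache
# (import path assumed; it matches the one used in the test file).
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

cache = torch.zeros(2, 1, 4, 4)  # [B, H, S_max=4, D]
value = torch.randn(2, 1, 2, 4)  # [B, H, S=2, D]

out = torch.ops.executorch.update_cross_attn_cache(value, cache)

# The eager kernel writes value into cache[:, :, :S, :] in place ...
torch.testing.assert_close(cache[:, :, :2, :], value)
# ... and returns a clone rather than the cache buffer itself.
assert out.data_ptr() != cache.data_ptr()
```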
diff --git a/extension/llm/custom_ops/test_update_cross_attn_cache.py b/extension/llm/custom_ops/test_update_cross_attn_cache.py
new file mode 100644
index 00000000000..dde2da68f51
--- /dev/null
+++ b/extension/llm/custom_ops/test_update_cross_attn_cache.py
@@ -0,0 +1,284 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+
+# Import the custom ops to ensure they are registered.
+from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401
+
+# Check CUDA availability once at module level.
+CUDA_AVAILABLE = torch.cuda.is_available()
+
+
+class TestUpdateCrossAttnCache(unittest.TestCase):
+    def test_update_cross_attn_cache(self):
+
+        # Create tensors
+        # Cache: [B=2, H=1, S_max=4, D=4]
+        cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
+        # Value: [B=2, H=1, S=2, D=4] (S < S_max)
+        value = torch.randn(2, 1, 2, 4, dtype=torch.float32)
+
+        # Compile a function that uses the op
+        @torch.compile
+        def fn(v, c):
+            return torch.ops.executorch.update_cross_attn_cache(v, c)
+
+        # Run it
+        out = fn(value, cache)
+
+        # Check correctness:
+        # the first 2 positions in dim 2 (the sequence dim) should match value
+        torch.testing.assert_close(
+            cache[:, :, :2, :], value, msg="Cache slice not updated correctly"
+        )
+
+        # Make sure out and cache match. The eager kernel returns a clone,
+        # while the Inductor lowering returns the cache buffer itself.
+        torch.testing.assert_close(
+            out, cache, msg="Output does not match the updated cache"
+        )
+
+        # The rest should be zeros
+        torch.testing.assert_close(
+            cache[:, :, 2:, :],
+            torch.zeros_like(cache[:, :, 2:, :]),
+            msg="Rest of cache was modified",
+        )
+
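The compile path in the test above relies on the fake kernels registered in custom_ops.py for shape propagation. The sketch below is not part of the patch and uses the private `torch._subclasses.fake_tensor` API; it shows that tracing only needs metadata, since under FakeTensorMode only the fake kernel runs and no real data is touched.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Assumes the ops are registered by importing the custom_ops module, as above.
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401

with FakeTensorMode():
    fake_cache = torch.zeros(2, 1, 4, 4)
    fake_value = torch.randn(2, 1, 2, 4)
    # Only the registered fake kernel runs here; the result is a FakeTensor
    # whose shape/dtype/device mirror the cache, with no real storage backing it.
    fake_out = torch.ops.executorch.update_cross_attn_cache(fake_value, fake_cache)
    assert fake_out.shape == fake_cache.shape
```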
+    def test_update_cross_attn_cache_in_cond(self):
+        # Create tensors
+
+        # Value: [B=2, H=1, S=2, D=4]
+        value = torch.randn(2, 1, 2, 4, dtype=torch.float32)
+        # Alternative value for false branch
+        value_alt = torch.randn(2, 1, 2, 4, dtype=torch.float32)
+
+        # Define a function that uses the op inside torch.cond
+        def fn_with_cond(pred, v1, v2, c):
+            def true_fn(v1, v2, cache):
+                return torch.ops.executorch.update_cross_attn_cache(v1, cache)
+
+            def false_fn(v1, v2, cache):
+                return torch.ops.executorch.update_cross_attn_cache(v2, cache)
+
+            return torch.cond(pred, true_fn, false_fn, (v1, v2, c))
+
+        # Test with true condition
+        pred_true = torch.tensor(True)
+        cache_true = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
+
+        # Compile the function
+        @torch.compile
+        def compiled_fn(pred, v1, v2, c):
+            return fn_with_cond(pred, v1, v2, c)
+
+        # Run with true condition
+        compiled_fn(pred_true, value, value_alt, cache_true)
+
+        # Check that the true branch was executed (value was used)
+        torch.testing.assert_close(
+            cache_true[:, :, :2, :],
+            value,
+            msg="Cache not updated correctly in true branch",
+        )
+
+        # Test with false condition
+        pred_false = torch.tensor(False)
+        cache_false = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
+
+        compiled_fn(pred_false, value, value_alt, cache_false)
+
+        # Check that the false branch was executed (value_alt was used)
+        torch.testing.assert_close(
+            cache_false[:, :, :2, :],
+            value_alt,
+            msg="Cache not updated correctly in false branch",
+        )
+
+    def test_update_cross_attn_cache_export(self):
+
+        # Create tensors
+        # Cache: [B=2, H=1, S_max=4, D=4]
+        cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
+        # Value: [B=2, H=1, S=2, D=4]
+        value = torch.randn(2, 1, 2, 4, dtype=torch.float32)
+        # Alternative value for false branch
+        value_alt = torch.randn(2, 1, 2, 4, dtype=torch.float32)
+
+        # Define a module that uses torch.cond with the op
+        class UpdateCacheCondModule(torch.nn.Module):
+            def forward(self, pred, v1, v2, c):
+                def true_fn(v1, v2, cache):
+                    return torch.ops.executorch.update_cross_attn_cache(v1, cache)
+
+                def false_fn(v1, v2, cache):
+                    return torch.ops.executorch.update_cross_attn_cache(v2, cache)
+
+                return torch.cond(pred, true_fn, false_fn, (v1, v2, c))
+
+        module = UpdateCacheCondModule()
+
+        # Export the module with the true condition as example input
+        pred_true = torch.tensor(True)
+        exported_program = torch.export.export(
+            module,
+            (pred_true, value, value_alt, cache),
+        )
+
+        # Run the exported program with true condition
+        cache_true = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
+        exported_program.module()(pred_true, value, value_alt, cache_true)
+
+        # Check that the true branch was executed (value was used)
+        torch.testing.assert_close(
+            cache_true[:, :, :2, :],
+            value,
+            msg="Cache not updated correctly in true branch after export",
+        )
+
+        # Run the exported program with false condition
+        pred_false = torch.tensor(False)
+        cache_false = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
+        exported_program.module()(pred_false, value, value_alt, cache_false)
+
+        # Check that the false branch was executed (value_alt was used)
+        torch.testing.assert_close(
+            cache_false[:, :, :2, :],
+            value_alt,
+            msg="Cache not updated correctly in false branch after export",
+        )
+
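Besides running the exported program, one can check that the custom op survives export inside the torch.cond subgraphs by walking the exported graph modules. A rough helper sketch follows (not part of the patch); `contains_update_op(exported_program)` on a program exported as in the test above would be expected to return True.

```python
import torch
from torch.export import ExportedProgram


def contains_update_op(ep: ExportedProgram) -> bool:
    """Return True if any subgraph (including torch.cond branches) calls the op."""
    for submodule in ep.graph_module.modules():
        if not isinstance(submodule, torch.fx.GraphModule):
            continue
        for node in submodule.graph.nodes:
            if (
                node.op == "call_function"
                and node.target is torch.ops.executorch.update_cross_attn_cache.default
            ):
                return True
    return False
```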
+    def test_update_cross_attn_cache_different_shapes(self):
+
+        # Test with different batch sizes and sequence lengths
+        test_cases = [
+            # (B, H, S_max, S, D)
+            (1, 2, 10, 5, 8),
+            (4, 4, 8, 3, 16),
+            (2, 1, 16, 10, 32),
+        ]
+
+        @torch.compile
+        def fn(v, c):
+            return torch.ops.executorch.update_cross_attn_cache(v, c)
+
+        for B, H, S_max, S, D in test_cases:
+            # Cache: [B, H, S_max, D], Value: [B, H, S, D]
+            cache = torch.zeros(B, H, S_max, D, dtype=torch.float32)
+            value = torch.randn(B, H, S, D, dtype=torch.float32)
+
+            fn(value, cache)
+
+            # Check that the first S positions in dim 2 are updated
+            torch.testing.assert_close(
+                cache[:, :, :S, :],
+                value,
+                msg=f"Failed for shape B={B}, H={H}, S_max={S_max}, S={S}, D={D}",
+            )
+
+            # Check that the rest remain zeros
+            if S < S_max:
+                torch.testing.assert_close(
+                    cache[:, :, S:, :],
+                    torch.zeros_like(cache[:, :, S:, :]),
+                    msg=f"Remaining cache modified for shape B={B}, H={H}, S_max={S_max}, S={S}, D={D}",
+                )
+
+    def test_update_cross_attn_cache_full_sequence(self):
+
+        # Cache: [B=2, H=1, S_max=4, D=4]
+        cache = torch.zeros(2, 1, 4, 4, dtype=torch.float32)
+        # Value: [B=2, H=1, S=4, D=4] (S == S_max)
+        value = torch.randn(2, 1, 4, 4, dtype=torch.float32)
+
+        @torch.compile
+        def fn(v, c):
+            return torch.ops.executorch.update_cross_attn_cache(v, c)
+
+        fn(value, cache)
+
+        # The entire cache should match value
+        torch.testing.assert_close(
+            cache, value, msg="Cache not fully updated when S == S_max"
+        )
+
+    @unittest.skipUnless(CUDA_AVAILABLE, "CUDA not available")
+    def test_alias_and_update_cross_attn_cache_with_cond_triton(self):
+        """Test combining the alias and update_cross_attn_cache ops with
+        torch.cond, lowered to Triton on CUDA. The true branch uses alias;
+        the false branch uses update_cross_attn_cache."""
+
+        # Create CUDA tensors
+        # Value: [B=2, H=1, S=2, D=4]
+        value = torch.randn(2, 1, 2, 4, dtype=torch.float32, device="cuda")
+        # Extra tensor for the alias op
+        extra = torch.randn(2, 1, 4, 4, dtype=torch.float32, device="cuda")
+
+        # Define a function that uses a different op in each branch
+        def fn_with_cond(pred, v, extra_tensor, c):
+            def true_fn(v, extra_tensor, cache):
+                # True branch: use the alias op only
+                aliased_cache, aliased_extra = torch.ops.executorch.alias(
+                    cache, extra_tensor
+                )
+                # Return the sum of the aliased tensors (no cache mutation)
+                return aliased_cache + aliased_extra
+
+            def false_fn(v, extra_tensor, cache):
+                # False branch: use the update_cross_attn_cache op only
+                updated = torch.ops.executorch.update_cross_attn_cache(v, cache)
+                return updated
+
+            return torch.cond(pred, true_fn, false_fn, (v, extra_tensor, c))
+
+        # Compile the function with the Inductor backend (Triton on CUDA)
+        @torch.compile(backend="inductor")
+        def compiled_fn(pred, v, extra_tensor, c):
+            return fn_with_cond(pred, v, extra_tensor, c)
+
+        # Test with true condition (alias branch)
+        pred_true = torch.tensor(True, device="cuda")
+        cache_true = torch.zeros(2, 1, 4, 4, dtype=torch.float32, device="cuda")
+
+        result_true = compiled_fn(pred_true, value, extra, cache_true)
+
+        # Check that the true branch was executed (alias: cache + extra)
+        expected_true = cache_true + extra
+        torch.testing.assert_close(
+            result_true,
+            expected_true,
+            msg="Result incorrect in true branch (alias) with CUDA/Triton",
+        )
+
+        # Cache should remain unchanged in the true branch (alias doesn't mutate)
+        torch.testing.assert_close(
+            cache_true,
+            torch.zeros(2, 1, 4, 4, dtype=torch.float32, device="cuda"),
+            msg="Cache should not be mutated in true branch (alias)",
+        )
+
+        # Test with false condition (update_cross_attn_cache branch)
+        pred_false = torch.tensor(False, device="cuda")
+        cache_false = torch.zeros(2, 1, 4, 4, dtype=torch.float32, device="cuda")
+
+        compiled_fn(pred_false, value, extra, cache_false)
+
+        # Check that the false branch was executed (update_cross_attn_cache):
+        # the cache should hold value in the first S positions
+        torch.testing.assert_close(
+            cache_false[:, :, :2, :],
+            value,
+            msg="Cache not updated correctly in false branch with CUDA/Triton",
+        )
+
+        # The rest of the cache should remain zeros
+        torch.testing.assert_close(
+            cache_false[:, :, 2:, :],
+            torch.zeros(2, 1, 2, 4, dtype=torch.float32, device="cuda"),
+            msg="Rest of cache was modified in false branch",
+        )
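For completeness, a compact way to exercise the Inductor lowering outside the test suite is sketched below (not part of the patch; assumes a CUDA device and the same import path as the tests). Passing `fullgraph=True` rules out silent graph breaks, so the whole function is handled by Inductor and thus by the lowering registered in custom_ops.py.

```python
import torch

# Registers the op (import path assumed, matching the test file).
from executorch.extension.llm.custom_ops import custom_ops  # noqa: F401


def update(v, c):
    return torch.ops.executorch.update_cross_attn_cache(v, c)


# Compile without allowing graph breaks so the op goes through Inductor.
compiled_update = torch.compile(update, backend="inductor", fullgraph=True)

cache = torch.zeros(2, 1, 4, 4, device="cuda")
value = torch.randn(2, 1, 2, 4, device="cuda")

out = compiled_update(value, cache)
torch.testing.assert_close(cache[:, :, :2, :], value)
```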
diff --git a/runtime/core/portable_type/c10/c10/util/safe_numerics.h b/runtime/core/portable_type/c10/c10/util/safe_numerics.h
index 32ffca52e48..bfdb968ff96 100644
--- a/runtime/core/portable_type/c10/c10/util/safe_numerics.h
+++ b/runtime/core/portable_type/c10/c10/util/safe_numerics.h
@@ -3,6 +3,7 @@
 #include
 #include
+#include
 
 // GCC has __builtin_mul_overflow from before it supported __has_builtin
 #ifdef _MSC_VER
@@ -15,31 +16,45 @@
 namespace c10 {
 
-C10_ALWAYS_INLINE bool add_overflows(uint64_t a, uint64_t b, uint64_t* out) {
+template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
+C10_ALWAYS_INLINE bool add_overflows(T a, T b, T* out) {
 #if C10_HAS_BUILTIN_OVERFLOW()
   return __builtin_add_overflow(a, b, out);
 #else
-  unsigned long long tmp;
-#if defined(_M_IX86) || defined(_M_X64)
-  auto carry = _addcarry_u64(0, a, b, &tmp);
-#else
-  tmp = a + b;
-  unsigned long long vector = (a & b) ^ ((a ^ b) & ~tmp);
-  auto carry = vector >> 63;
-#endif
-  *out = tmp;
-  return carry;
+  if constexpr (std::is_signed_v<T>) {
+    // For signed types, detect overflow by checking sign changes
+    volatile T tmp = a + b;
+    *out = tmp;
+
+    // If both operands have the same sign, check whether the result changed
+    // sign unexpectedly.
+    if ((a > 0) == (b > 0)) {
+      if ((a > 0) && (tmp <= 0)) {
+        return true; // Positive overflow
+      }
+      if ((a < 0) && (tmp >= 0)) {
+        return true; // Negative overflow
+      }
+    }
+    return false;
+  } else {
+    // For unsigned types, overflow causes wrap-around
+    volatile T tmp = a + b;
+    *out = tmp;
+    return (tmp < a || tmp < b);
+  }
 #endif
 }
 
-template <typename T>
+C10_ALWAYS_INLINE bool add_overflows(uint64_t a, uint64_t b, uint64_t* out) {
+  return add_overflows<uint64_t>(a, b, out);
+}
+
+template <typename T, std::enable_if_t<std::is_integral_v<T>, int> = 0>
 C10_ALWAYS_INLINE bool mul_overflows(T a, T b, T* out) {
 #if C10_HAS_BUILTIN_OVERFLOW()
   return __builtin_mul_overflow(a, b, out);
 #else
-  static_assert(
-      std::is_integral_v<T>, "mul_overflows only supports integral types");
-
   if constexpr (std::is_signed_v<T>) {
     // For signed types, use the division-based check
     volatile T tmp = a * b;
diff --git a/torch_pin.py b/torch_pin.py
index e934463cb70..62a2572fd78 100644
--- a/torch_pin.py
+++ b/torch_pin.py
@@ -1,2 +1,2 @@
-TORCH_VERSION = "2.10.0"
-NIGHTLY_VERSION = "dev20251120"
+TORCH_VERSION = "2.11.0"
+NIGHTLY_VERSION = "dev20251222"