diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml
index dbce874cfc0..4452debd480 100644
--- a/.github/workflows/cuda.yml
+++ b/.github/workflows/cuda.yml
@@ -87,8 +87,8 @@ jobs:
         export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH
         PYTHON_EXECUTABLE=python source .ci/scripts/test_model.sh "${{ matrix.model }}" cmake cuda
 
-  test-cuda-shims:
-    name: test-cuda-shims
+  unittest-cuda:
+    name: unittest-cuda
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     permissions:
       id-token: write
@@ -103,17 +103,20 @@ jobs:
       ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
       script: |
         set -eux
-        # Install requirements
-        bash ./install_requirements.sh
+        # Install executorch in editable mode so custom op libs land in-tree
+        bash ./install_executorch.sh --editable
 
         # Build ExecuTorch with CUDA support
         cmake --workflow --preset llm-release-cuda
 
-        # Build and run CUDA shim tests
+        # Build and run CUDA shim tests (C++)
         pushd backends/cuda/runtime/shims/tests
         cmake --workflow --preset default
         popd
 
+        # Run CUDA backend Python tests; override addopts so that we don't run all tests configured in pytest.ini
+        python -m pytest backends/cuda/tests backends/cuda/passes/tests extension/llm/custom_ops -v -o "addopts="
+
   export-model-cuda-artifact:
     name: export-model-cuda-artifact
     # Skip this job if the pull request is from a fork (HuggingFace secrets are not available)
diff --git a/backends/aoti/aoti_backend.py b/backends/aoti/aoti_backend.py
index c2c587da9fe..eb732df2a83 100644
--- a/backends/aoti/aoti_backend.py
+++ b/backends/aoti/aoti_backend.py
@@ -156,7 +156,10 @@ def preprocess(
         # Apply custom backend-specific passes
         custom_passes = cls.get_custom_passes(compile_specs)
         for custom_pass in custom_passes:
-            custom_pass(device_edge_program.graph_module)
+            if getattr(custom_pass, "requires_exported_program", False):
+                custom_pass(device_edge_program)
+            else:
+                custom_pass(device_edge_program.graph_module)
 
         # Run decompositions if any
         if decomposition_table:
diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py
index dbbd79f4881..ed34591096e 100644
--- a/backends/cuda/cuda_backend.py
+++ b/backends/cuda/cuda_backend.py
@@ -12,6 +12,9 @@
 import torch
 
 from executorch.backends.aoti.aoti_backend import AotiBackend
+from executorch.backends.cuda.passes.move_cond_predicate_to_cpu import (
+    MoveCondPredicateToCpuPass,
+)
 from executorch.backends.cuda.triton.replacement_pass import (
     ReplaceEdgeOpWithTritonOpPass,
 )
@@ -155,7 +158,10 @@ def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]
             )
             triton_kernel_mode = mode
 
-        return [ReplaceEdgeOpWithTritonOpPass()] if triton_kernel_mode == "ON" else []
+        passes = [MoveCondPredicateToCpuPass()]
+        if triton_kernel_mode == "ON":
+            passes.append(ReplaceEdgeOpWithTritonOpPass())
+        return passes
 
     @classmethod
     def get_aoti_compile_options(
diff --git a/backends/cuda/passes/__init__.py b/backends/cuda/passes/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/backends/cuda/passes/move_cond_predicate_to_cpu.py b/backends/cuda/passes/move_cond_predicate_to_cpu.py
new file mode 100644
index 00000000000..f876a684019
--- /dev/null
+++ b/backends/cuda/passes/move_cond_predicate_to_cpu.py
@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from torch.export import ExportedProgram
+
+
+class MoveCondPredicateToCpuPass:
+    """
+    A pass that moves the predicate of torch.cond to CPU if the predicate is a constant buffer.
+    This is useful for models whose predicate is a constant buffer, such as an `initialized` flag for a cross-attention KV cache.
+
+    This saves ~50us per torch.cond call on an RTX 5080.
+
+    Example:
+    ```
+    class CrossAttentionWithCache(torch.nn.Module):
+        def __init__(self, hidden_size):
+            super().__init__()
+            self.k_proj = torch.nn.Linear(hidden_size, hidden_size)
+            self.v_proj = torch.nn.Linear(hidden_size, hidden_size)
+            self.q_proj = torch.nn.Linear(hidden_size, hidden_size)
+            self.out_proj = torch.nn.Linear(hidden_size, hidden_size)
+            # Buffer used as predicate for torch.cond
+            self.register_buffer("initialized", torch.tensor([False]), persistent=False)
+            self.register_buffer("k_cache", torch.zeros(1, 10, hidden_size), persistent=False)
+            self.register_buffer("v_cache", torch.zeros(1, 10, hidden_size), persistent=False)
+
+        def compute_kv(self, encoder_hidden_states):
+            k = self.k_proj(encoder_hidden_states)
+            v = self.v_proj(encoder_hidden_states)
+            self.k_cache.copy_(k)
+            self.v_cache.copy_(v)
+            self.initialized.fill_(True)
+            return k, v
+
+        def use_cached_kv(self, encoder_hidden_states):
+            return self.k_cache.clone(), self.v_cache.clone()
+
+        def forward(self, hidden_states, encoder_hidden_states):
+            q = self.q_proj(hidden_states)
+            # Use torch.cond with initialized buffer as predicate
+            k, v = torch.cond(
+                self.initialized,
+                self.use_cached_kv,
+                self.compute_kv,
+                (encoder_hidden_states,),
+            )
+            attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+            return self.out_proj(attn_output)
+    ```
+    In this example, if we keep `self.initialized` on GPU, it has to be copied to CPU on every forward pass.
+    We move the predicate to CPU to avoid these device-to-host copies.
+    This pass only applies to models that use torch.cond with a constant-buffer predicate.
+    """
+
+    requires_exported_program = True
+
+    def __call__(self, exported_program: ExportedProgram):
+        graph_module = exported_program.graph_module
+
+        # Map input names to buffer names
+        inputs_to_buffers = exported_program.graph_signature.inputs_to_buffers
+
+        for node in graph_module.graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target == torch.ops.higher_order.cond
+            ):
+                pred_node = node.args[0]
+                if (
+                    pred_node.op == "placeholder"
+                    and pred_node.name in inputs_to_buffers
+                ):
+                    buffer_name = inputs_to_buffers[pred_node.name]
+
+                    if buffer_name in exported_program.constants:
+                        tensor = exported_program._constants[buffer_name]
+                        if tensor.device.type != "cpu":
+                            exported_program._constants[buffer_name] = tensor.to("cpu")
+
+                            # Also update the placeholder metadata
+                            if "val" in pred_node.meta:
+                                fake_tensor = pred_node.meta["val"]
+                                if isinstance(fake_tensor, torch.Tensor):
+                                    pred_node.meta["val"] = fake_tensor.to("cpu")
+        exported_program.validate()
diff --git a/backends/cuda/passes/tests/__init__.py b/backends/cuda/passes/tests/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/backends/cuda/passes/tests/test_move_cond_predicate_to_cpu.py b/backends/cuda/passes/tests/test_move_cond_predicate_to_cpu.py
new file mode 100644
index 00000000000..6b1cc1f63e1
--- /dev/null
+++ b/backends/cuda/passes/tests/test_move_cond_predicate_to_cpu.py
@@ -0,0 +1,553 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+from backends.cuda.passes.move_cond_predicate_to_cpu import MoveCondPredicateToCpuPass
+from torch.export import export
+
+
+@unittest.skipUnless(torch.cuda.is_available(), "CUDA is not available")
+class TestMoveCondPredicateToCpuPass(unittest.TestCase):
+    """Test the MoveCondPredicateToCpuPass transformation pass."""
+
+    def test_gpu_buffer_predicate_moved_to_cpu(self):
+        """Test that a GPU non-persistent buffer used as predicate is moved to CPU."""
+
+        class CondWithGpuBufferPredicate(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                # Non-persistent buffer goes to constants
+                self.register_buffer(
+                    "flag", torch.tensor([False], device="cuda"), persistent=False
+                )
+
+            def true_branch(self, x):
+                return x * 2
+
+            def false_branch(self, x):
+                return x + 1
+
+            def forward(self, x):
+                return torch.cond(
+                    self.flag,
+                    self.true_branch,
+                    self.false_branch,
+                    (x,),
+                )
+
+        module = CondWithGpuBufferPredicate()
+        module.eval()
+        inputs = (torch.randn(4, 4, device="cuda"),)
+
+        # Export the model
+        exported_program = export(module, inputs, strict=True)
+
+        # Verify the buffer is on GPU before the pass
+        buffer_name = None
+        for name in exported_program.constants:
+            if "flag" in name:
+                buffer_name = name
+                break
+
+        self.assertIsNotNone(buffer_name, "Buffer 'flag' should exist in constants")
+        self.assertEqual(
+            exported_program._constants[buffer_name].device.type,
+            "cuda",
+            "Buffer should be on CUDA before the pass",
+        )
+
+        # Apply the pass
+        pass_instance = MoveCondPredicateToCpuPass()
+        pass_instance(exported_program)
+
+        # Verify the buffer is now on CPU
+        self.assertEqual(
+            exported_program._constants[buffer_name].device.type,
+            "cpu",
+            "Buffer should be on CPU after the pass",
+        )
+
+    def test_cpu_buffer_predicate_unchanged(self):
+        """Test that a CPU non-persistent buffer used as predicate remains on CPU."""
+
+        class CondWithCpuBufferPredicate(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                # Non-persistent buffer on CPU
+                self.register_buffer(
+                    "flag", torch.tensor([True], device="cpu"), persistent=False
+                )
+
+            def true_branch(self, x):
+                return x * 2
+
+            def false_branch(self, x):
+                return x + 1
+
+            def forward(self, x):
+                return torch.cond(
+                    self.flag,
+                    self.true_branch,
+                    self.false_branch,
+                    (x,),
+                )
+
+        module = CondWithCpuBufferPredicate()
+        module.eval()
+        # Input still on cuda, but buffer on cpu
+        inputs = (torch.randn(4, 4, device="cuda"),)
+
+        exported_program = export(module, inputs, strict=True)
+
+        # Find the buffer
+        buffer_name = None
+        for name in exported_program.constants:
+            if "flag" in name:
+                buffer_name = name
+                break
+
+        self.assertIsNotNone(buffer_name, "Buffer 'flag' should exist in constants")
+        self.assertEqual(
+            exported_program._constants[buffer_name].device.type,
+            "cpu",
+            "Buffer should be on CPU before the pass",
+        )
+
+        # Apply the pass
+        pass_instance = MoveCondPredicateToCpuPass()
+        pass_instance(exported_program)
+
+        # Verify the buffer remains on CPU
+        self.assertEqual(
+            exported_program._constants[buffer_name].device.type,
+            "cpu",
+            "Buffer should remain on CPU after the pass",
+        )
+
+    def test_computed_predicate_no_change(self):
+        """Test that a computed predicate (not a buffer) is not affected."""
+
+        class CondWithComputedPredicate(torch.nn.Module):
+            def true_branch(self, x):
+                return x * 2
+
+            def false_branch(self, x):
+                return x + 1
+
+            def forward(self, x):
+                # Predicate is computed from input, not a buffer
+                pred = x.sum() > 0
+                return torch.cond(
+                    pred,
+                    self.true_branch,
+                    self.false_branch,
+                    (x,),
+                )
+
+        module = CondWithComputedPredicate()
+        module.eval()
+        inputs = (torch.randn(4, 4, device="cuda"),)
+
+        exported_program = export(module, inputs, strict=True)
+
+        # Apply the pass - should not raise any errors
+        pass_instance = MoveCondPredicateToCpuPass()
+        pass_instance(exported_program)
+
+        # Validate the program is still valid
+        exported_program.validate()
+
+    def test_multiple_cond_with_buffer_predicates(self):
+        """Test that multiple torch.cond calls with non-persistent buffer predicates are handled."""
+
+        class MultipleCondWithBufferPredicates(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                # Non-persistent buffers go to constants
+                self.register_buffer(
+                    "flag1", torch.tensor([False], device="cuda"), persistent=False
+                )
+                self.register_buffer(
+                    "flag2", torch.tensor([True], device="cuda"), persistent=False
+                )
+
+            def true_branch1(self, x):
+                return x * 2
+
+            def false_branch1(self, x):
+                return x + 1
+
+            def true_branch2(self, x):
+                return x - 1
+
+            def false_branch2(self, x):
+                return x / 2
+
+            def forward(self, x):
+                y = torch.cond(
+                    self.flag1,
+                    self.true_branch1,
+                    self.false_branch1,
+                    (x,),
+                )
+                z = torch.cond(
+                    self.flag2,
+                    self.true_branch2,
+                    self.false_branch2,
+                    (y,),
+                )
+                return z
+
+        module = MultipleCondWithBufferPredicates()
+        module.eval()
+        inputs = (torch.randn(4, 4, device="cuda"),)
+
+        exported_program = export(module, inputs, strict=True)
+
+        # Verify both buffers are on GPU before the pass
+        flag1_name = None
+        flag2_name = None
+        for name in exported_program.constants:
+            if "flag1" in name:
+                flag1_name = name
+            elif "flag2" in name:
+                flag2_name = name
+
+        self.assertIsNotNone(flag1_name)
+        self.assertIsNotNone(flag2_name)
+
+        self.assertEqual(
+            exported_program._constants[flag1_name].device.type,
+            "cuda",
+            "flag1 should be on CUDA before the pass",
+        )
+        self.assertEqual(
+            exported_program._constants[flag2_name].device.type,
+            "cuda",
+            "flag2 should be on CUDA before the pass",
+        )
+
+        # Apply the pass
+        pass_instance = MoveCondPredicateToCpuPass()
+        pass_instance(exported_program)
+
+        # Verify both buffers are now on CPU
+        self.assertEqual(
+            exported_program._constants[flag1_name].device.type,
+            "cpu",
+            "flag1 should be on CPU after the pass",
+        )
+        self.assertEqual(
+            exported_program._constants[flag2_name].device.type,
+            "cpu",
+            "flag2 should be on CPU after the pass",
+        )
+
+    def test_cross_attention_cache_pattern(self):
+        """Test the cross-attention cache pattern from the docstring."""
+
+        class CrossAttentionWithCache(torch.nn.Module):
+            def __init__(self, hidden_size):
+                super().__init__()
+                self.k_proj = torch.nn.Linear(hidden_size, hidden_size)
+                self.v_proj = torch.nn.Linear(hidden_size, hidden_size)
+                self.q_proj = torch.nn.Linear(hidden_size, hidden_size)
+                self.out_proj = torch.nn.Linear(hidden_size, hidden_size)
+                # Non-persistent buffer used as predicate for torch.cond
+                self.register_buffer(
+                    "initialized",
+                    torch.tensor([False], device="cuda"),
+                    persistent=False,
+                )
+                # Non-persistent k_cache and v_cache (these should not be moved)
+                self.register_buffer(
+                    "k_cache",
+                    torch.zeros(1, 10, hidden_size, device="cuda"),
+                    persistent=False,
+                )
+                self.register_buffer(
+                    "v_cache",
+                    torch.zeros(1, 10, hidden_size, device="cuda"),
+                    persistent=False,
+                )
+
+            def compute_kv(self, q, encoder_hidden_states):
+                k = self.k_proj(encoder_hidden_states)
+                v = self.v_proj(encoder_hidden_states)
+                return k, v
+
+            def use_cached_kv(self, q, encoder_hidden_states):
+                return self.k_cache.clone(), self.v_cache.clone()
+
+            def forward(self, hidden_states, encoder_hidden_states):
+                q = self.q_proj(hidden_states)
+                # Use torch.cond with initialized buffer as predicate
+                k, v = torch.cond(
+                    self.initialized,
+                    self.use_cached_kv,
+                    self.compute_kv,
+                    (q, encoder_hidden_states),
+                )
+                attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+                return self.out_proj(attn_output)
+
+        hidden_size = 64
+        module = CrossAttentionWithCache(hidden_size).cuda()
+        module.eval()
+        inputs = (
+            torch.randn(1, 5, hidden_size, device="cuda"),  # hidden_states
+            torch.randn(1, 10, hidden_size, device="cuda"),  # encoder_hidden_states
+        )
+
+        exported_program = export(module, inputs, strict=True)
+
+        # Find the initialized buffer
+        initialized_name = None
+        for name in exported_program.constants:
+            if "initialized" in name:
+                initialized_name = name
+                break
+
+        self.assertIsNotNone(
+            initialized_name, "Buffer 'initialized' should exist in constants"
+        )
+        self.assertEqual(
+            exported_program._constants[initialized_name].device.type,
+            "cuda",
+            "initialized buffer should be on CUDA before the pass",
+        )
+
+        # Apply the pass
+        pass_instance = MoveCondPredicateToCpuPass()
+        pass_instance(exported_program)
+
+        # Verify the initialized buffer is now on CPU
+        self.assertEqual(
+            exported_program._constants[initialized_name].device.type,
+            "cpu",
+            "initialized buffer should be on CPU after the pass",
+        )
+
+        # Other buffers (k_cache, v_cache) should remain on GPU
+        for name in exported_program.constants:
+            if "k_cache" in name or "v_cache" in name:
+                self.assertEqual(
+                    exported_program._constants[name].device.type,
+                    "cuda",
+                    f"{name} should remain on CUDA (not used as cond predicate)",
+                )
+
+    def test_placeholder_meta_updated(self):
+        """Test that placeholder metadata is updated when buffer is moved."""
+
+        class CondWithGpuBufferPredicate(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                # Non-persistent buffer goes to constants
+                self.register_buffer(
+                    "flag", torch.tensor([False], device="cuda"), persistent=False
+                )
+
+            def true_branch(self, x):
+                return x * 2
+
+            def false_branch(self, x):
+                return x + 1
+
+            def forward(self, x):
+                return torch.cond(
+                    self.flag,
+                    self.true_branch,
+                    self.false_branch,
+                    (x,),
+                )
+
+        module = CondWithGpuBufferPredicate()
+        module.eval()
+        inputs = (torch.randn(4, 4, device="cuda"),)
+
+        exported_program = export(module, inputs, strict=True)
+
+        # Find the predicate placeholder node and verify its metadata
+        pred_node = None
+        for node in exported_program.graph_module.graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target == torch.ops.higher_order.cond
+            ):
+                pred_node = node.args[0]
+                break
+
+        self.assertIsNotNone(pred_node, "Should find a cond node with predicate")
+
+        # Check metadata before pass
+        if "val" in pred_node.meta:
+            fake_tensor = pred_node.meta["val"]
+            if isinstance(fake_tensor, torch.Tensor):
+                self.assertEqual(
+                    fake_tensor.device.type,
+                    "cuda",
+                    "Placeholder metadata should show CUDA before the pass",
+                )
+
+        # Apply the pass
+        pass_instance = MoveCondPredicateToCpuPass()
+        pass_instance(exported_program)
+
+        # Check metadata after pass
+        if "val" in pred_node.meta:
+            fake_tensor = pred_node.meta["val"]
+            if isinstance(fake_tensor, torch.Tensor):
+                self.assertEqual(
+                    fake_tensor.device.type,
+                    "cpu",
+                    "Placeholder metadata should show CPU after the pass",
+                )
+
+    def test_requires_exported_program_attribute(self):
+        """Test that the pass has requires_exported_program attribute set to True."""
+        pass_instance = MoveCondPredicateToCpuPass()
+        self.assertTrue(
+            pass_instance.requires_exported_program,
+            "Pass should require an ExportedProgram",
+        )
+
+    def test_program_validates_after_pass(self):
+        """Test that exported program is valid after applying the pass."""
+
+        class CondWithGpuBufferPredicate(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                # Non-persistent buffer goes to constants
+                self.register_buffer(
+                    "flag", torch.tensor([False], device="cuda"), persistent=False
+                )
+
+            def true_branch(self, x):
+                return x * 2
+
+            def false_branch(self, x):
+                return x + 1
+
+            def forward(self, x):
+                return torch.cond(
+                    self.flag,
+                    self.true_branch,
+                    self.false_branch,
+                    (x,),
+                )
+
+        module = CondWithGpuBufferPredicate()
+        module.eval()
+        inputs = (torch.randn(4, 4, device="cuda"),)
+
+        exported_program = export(module, inputs, strict=True)
+
+        # Apply the pass - should not raise
+        pass_instance = MoveCondPredicateToCpuPass()
+        pass_instance(exported_program)
+
+        # validate() is called inside the pass, but we call it again to be sure
+        exported_program.validate()
+
+    def test_no_cond_in_graph(self):
+        """Test that pass works correctly when there is no torch.cond in the graph."""
+
+        class SimpleModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                # Non-persistent buffer to test that it stays on GPU
+                self.register_buffer(
+                    "buffer", torch.tensor([1.0], device="cuda"), persistent=False
+                )
+
+            def forward(self, x):
+                return x + self.buffer
+
+        module = SimpleModule()
+        module.eval()
+        inputs = (torch.randn(4, 4, device="cuda"),)
+
+        exported_program = export(module, inputs, strict=True)
+
+        # Apply the pass - should not raise
+        pass_instance = MoveCondPredicateToCpuPass()
+        pass_instance(exported_program)
+
+        # Buffer should remain on GPU since it's not a cond predicate
+        buffer_name = None
+        for name in exported_program.constants:
+            if "buffer" in name:
+                buffer_name = name
+                break
+
+        if buffer_name:
+            self.assertEqual(
+                exported_program._constants[buffer_name].device.type,
+                "cuda",
+                "Buffer should remain on CUDA since it's not a cond predicate",
+            )
+
+    def test_persistent_buffer_predicate_not_moved(self):
+        """Test that a persistent buffer (in state_dict) used as predicate is NOT moved.
+
+        The pass only handles non-persistent buffers stored in `constants`.
+        Persistent buffers are stored in `state_dict` and should remain unchanged.
+        """
+
+        class CondWithPersistentBufferPredicate(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                # Persistent buffer (default) goes to state_dict, not constants
+                self.register_buffer("flag", torch.tensor([False], device="cuda"))
+
+            def true_branch(self, x):
+                return x * 2
+
+            def false_branch(self, x):
+                return x + 1
+
+            def forward(self, x):
+                return torch.cond(
+                    self.flag,
+                    self.true_branch,
+                    self.false_branch,
+                    (x,),
+                )
+
+        module = CondWithPersistentBufferPredicate()
+        module.eval()
+        inputs = (torch.randn(4, 4, device="cuda"),)
+
+        exported_program = export(module, inputs, strict=True)
+
+        # Verify the buffer is in state_dict, NOT in constants
+        self.assertIn("flag", exported_program.state_dict)
+        self.assertNotIn("flag", exported_program.constants)
+
+        # Verify the buffer is on GPU before the pass
+        self.assertEqual(
+            exported_program.state_dict["flag"].device.type,
+            "cuda",
+            "Persistent buffer should be on CUDA before the pass",
+        )
+
+        # Apply the pass
+        pass_instance = MoveCondPredicateToCpuPass()
+        pass_instance(exported_program)
+
+        # Verify the buffer remains on GPU (pass should not affect state_dict buffers)
+        self.assertEqual(
+            exported_program.state_dict["flag"].device.type,
+            "cuda",
+            "Persistent buffer should remain on CUDA after the pass (not in constants)",
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
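Note: outside the CUDA backend's pass pipeline, the new pass can also be applied to an `ExportedProgram` on its own, mirroring what the tests above do. The sketch below is a minimal usage example, not part of the diff; `MyCondModule` and `example_inputs` are placeholders for any module whose `torch.cond` predicate is a non-persistent CUDA buffer.

```python
import torch
from torch.export import export

from executorch.backends.cuda.passes.move_cond_predicate_to_cpu import (
    MoveCondPredicateToCpuPass,
)

# Export a module whose torch.cond predicate is a non-persistent CUDA buffer
# (MyCondModule and example_inputs are placeholders for your own model).
exported_program = export(MyCondModule().cuda().eval(), example_inputs, strict=True)

# Move the predicate constant to CPU in place; the pass re-validates the program,
# so later torch.cond calls avoid a device-to-host copy of the predicate.
MoveCondPredicateToCpuPass()(exported_program)
```

In CI, the same code paths are exercised by the `unittest-cuda` job added in `.github/workflows/cuda.yml`.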