13 changes: 8 additions & 5 deletions .github/workflows/cuda.yml
@@ -87,8 +87,8 @@ jobs:
           export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH
           PYTHON_EXECUTABLE=python source .ci/scripts/test_model.sh "${{ matrix.model }}" cmake cuda
 
-  test-cuda-shims:
-    name: test-cuda-shims
+  unittest-cuda:
+    name: unittest-cuda
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     permissions:
       id-token: write
@@ -103,17 +103,20 @@ jobs:
       ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
       script: |
         set -eux
-        # Install requirements
-        bash ./install_requirements.sh
+        # Install executorch in editable mode so custom op libs land in-tree
+        bash ./install_executorch.sh --editable
 
         # Build ExecuTorch with CUDA support
         cmake --workflow --preset llm-release-cuda
 
-        # Build and run CUDA shim tests
+        # Build and run CUDA shim tests (C++)
         pushd backends/cuda/runtime/shims/tests
         cmake --workflow --preset default
         popd
 
+        # Run CUDA backend Python tests, overrides addopts so that we don't run all tests in pytest.ini
+        python -m pytest backends/cuda/tests backends/cuda/passes/tests extension/llm/custom_ops -v -o "addopts="
+
   export-model-cuda-artifact:
     name: export-model-cuda-artifact
     # Skip this job if the pull request is from a fork (HuggingFace secrets are not available)
5 changes: 4 additions & 1 deletion backends/aoti/aoti_backend.py
@@ -156,7 +156,10 @@ def preprocess(
         # Apply custom backend-specific passes
         custom_passes = cls.get_custom_passes(compile_specs)
         for custom_pass in custom_passes:
-            custom_pass(device_edge_program.graph_module)
+            if getattr(custom_pass, "requires_exported_program", False):
+                custom_pass(device_edge_program)
+            else:
+                custom_pass(device_edge_program.graph_module)
 
         # Run decompositions if any
         if decomposition_table:
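The dispatch above is keyed off an optional class attribute: a pass that sets `requires_exported_program = True` is handed the whole ExportedProgram (graph signature, constants, metadata), while every other pass keeps receiving only the graph module, unchanged from before. A minimal sketch of the two pass shapes this contract supports (the class names below are illustrative, not from the PR):

```
import torch
from torch.export import ExportedProgram


class GraphOnlyPass:
    """Default contract: called with just the FX graph module."""

    def __call__(self, graph_module: torch.fx.GraphModule) -> None:
        for node in graph_module.graph.nodes:
            pass  # rewrite individual nodes in place


class ProgramLevelPass:
    """Opts in to the full ExportedProgram, e.g. to touch constants or the signature."""

    requires_exported_program = True

    def __call__(self, exported_program: ExportedProgram) -> None:
        buffers = exported_program.graph_signature.inputs_to_buffers
        _ = buffers  # inspect or mutate program-level state; the caller validates afterwards
```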
8 changes: 7 additions & 1 deletion backends/cuda/cuda_backend.py
@@ -12,6 +12,9 @@
 
 import torch
 from executorch.backends.aoti.aoti_backend import AotiBackend
+from executorch.backends.cuda.passes.move_cond_predicate_to_cpu import (
+    MoveCondPredicateToCpuPass,
+)
 from executorch.backends.cuda.triton.replacement_pass import (
     ReplaceEdgeOpWithTritonOpPass,
 )
@@ -155,7 +158,10 @@ def get_custom_passes(cls, compile_specs: List[CompileSpec]) -> List[typing.Any]:
                 )
                 triton_kernel_mode = mode
 
-        return [ReplaceEdgeOpWithTritonOpPass()] if triton_kernel_mode == "ON" else []
+        passes = [MoveCondPredicateToCpuPass()]
+        if triton_kernel_mode == "ON":
+            passes.append(ReplaceEdgeOpWithTritonOpPass())
+        return passes
 
     @classmethod
     def get_aoti_compile_options(
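With this change the cond-predicate pass is registered unconditionally, and the Triton replacement pass stays gated on the kernel mode. A minimal sketch of what the backend now hands back (this assumes the class is exported as CudaBackend from backends/cuda/cuda_backend.py, and that an empty compile-spec list leaves the Triton kernel mode at its default):

```
from executorch.backends.cuda.cuda_backend import CudaBackend

# MoveCondPredicateToCpuPass is always first in the returned list; the Triton
# replacement pass is appended only when the compile specs set the mode to "ON".
passes = CudaBackend.get_custom_passes([])
assert type(passes[0]).__name__ == "MoveCondPredicateToCpuPass"
```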
Empty file.
90 changes: 90 additions & 0 deletions backends/cuda/passes/move_cond_predicate_to_cpu.py
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch.export import ExportedProgram


class MoveCondPredicateToCpuPass:
    """
    A pass that moves the predicate of torch.cond to CPU if the predicate is a constant buffer.
    This is useful for models that use the predicate as a constant buffer, such as an
    `initialized` flag for a cross-attention KV cache.

    This saves ~50us per torch.cond call on an RTX 5080.

    Example:
    ```
    class CrossAttentionWithCache(torch.nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            self.k_proj = torch.nn.Linear(hidden_size, hidden_size)
            self.v_proj = torch.nn.Linear(hidden_size, hidden_size)
            self.q_proj = torch.nn.Linear(hidden_size, hidden_size)
            self.out_proj = torch.nn.Linear(hidden_size, hidden_size)
            # Buffer used as predicate for torch.cond
            self.register_buffer("initialized", torch.tensor([False]), persistent=False)
            self.register_buffer("k_cache", torch.zeros(1, 10, hidden_size), persistent=False)
            self.register_buffer("v_cache", torch.zeros(1, 10, hidden_size), persistent=False)

        def compute_kv(self, encoder_hidden_states):
            k = self.k_proj(encoder_hidden_states)
            v = self.v_proj(encoder_hidden_states)
            self.k_cache.copy_(k)
            self.v_cache.copy_(v)
            self.initialized.fill_(True)
            return k, v

        def use_cached_kv(self, encoder_hidden_states):
            return self.k_cache.clone(), self.v_cache.clone()

        def forward(self, hidden_states, encoder_hidden_states):
            q = self.q_proj(hidden_states)
            # Use torch.cond with the initialized buffer as predicate
            k, v = torch.cond(
                self.initialized,
                self.use_cached_kv,
                self.compute_kv,
                (encoder_hidden_states,),
            )
            attn_output = torch.nn.functional.scaled_dot_product_attention(q, k, v)
            return self.out_proj(attn_output)
    ```
    In this example, if we keep `self.initialized` on the GPU, it has to be copied to the CPU
    on every forward pass. Moving the predicate to CPU avoids these device-to-host copies.
    This pass only applies to models that use torch.cond with a constant-buffer predicate.
    """

    requires_exported_program = True

    def __call__(self, exported_program: ExportedProgram):
        graph_module = exported_program.graph_module

        # Map input names to buffer names
        inputs_to_buffers = exported_program.graph_signature.inputs_to_buffers

        for node in graph_module.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == torch.ops.higher_order.cond
            ):
                pred_node = node.args[0]
                if (
                    pred_node.op == "placeholder"
                    and pred_node.name in inputs_to_buffers
                ):
                    buffer_name = inputs_to_buffers[pred_node.name]

                    if buffer_name in exported_program.constants:
                        tensor = exported_program._constants[buffer_name]
                        if tensor.device.type != "cpu":
                            exported_program._constants[buffer_name] = tensor.to("cpu")

                        # Also update the placeholder metadata
                        if "val" in pred_node.meta:
                            fake_tensor = pred_node.meta["val"]
                            if isinstance(fake_tensor, torch.Tensor):
                                pred_node.meta["val"] = fake_tensor.to("cpu")
        exported_program.validate()
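For context, a hedged, self-contained sketch of how this pass can be exercised on its own, in the spirit of the new Python tests the workflow runs from backends/cuda/passes/tests (the module and test names below are illustrative, and it assumes a CUDA device is available and that the toy module exports cleanly):

```
import torch
from torch.export import export

from executorch.backends.cuda.passes.move_cond_predicate_to_cpu import (
    MoveCondPredicateToCpuPass,
)


class TinyCond(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Non-persistent flag buffer used only as the torch.cond predicate.
        self.register_buffer("flag", torch.tensor([True]), persistent=False)

    def forward(self, x):
        return torch.cond(self.flag, lambda t: t + 1, lambda t: t - 1, (x,))


def test_cond_predicate_lands_on_cpu():
    module = TinyCond().to("cuda")
    ep = export(module, (torch.randn(4, device="cuda"),))

    MoveCondPredicateToCpuPass()(ep)

    # The non-persistent buffer is lifted into ep.constants; after the pass,
    # any constant backing a cond predicate should report a CPU device.
    assert all(
        tensor.device.type == "cpu"
        for name, tensor in ep.constants.items()
        if "flag" in name
    )
```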
Empty file.