
[Feature] Adopt FlexAttention native FA4 backend (BACKEND="FLASH") to replace manual CuTeDSL integration #30

@cicirori

Description

Summary

PyTorch has officially released a native FlashAttention-4 backend for FlexAttention, accessible via kernel_options={"BACKEND": "FLASH"}. This provides automatic CuTeDSL score/mask function generation and JIT instantiation of FA4 kernels — delivering 1.2×–3.2× speedups over the Triton backend on Hopper and Blackwell GPUs.

TorchSpec currently has two separate code paths:

  • flex_attention backend — uses the Triton-based FlexAttention with compile_friendly_flex_attention
  • fa_experimental backend — manually imports flash_attn.cute.flash_attn_func and wires up Eagle3 mask_mod through CuTeDSL directly (LlamaFlashAttentionMasked)

The new native pathway could unify these two backends into a single FlexAttention path with an optional BACKEND="FLASH" flag, reducing code complexity while getting FA4-level performance.
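A minimal sketch of how the unification could look, assuming a single backend string selects the kernel options (the function name and backend strings here are illustrative, not existing TorchSpec API):

```python
# Hypothetical sketch: selecting FlexAttention kernel_options from one backend
# string, so both code paths share a single flex_attention call site.
# `kernel_options_for` and the backend names are illustrative only.

def kernel_options_for(backend: str) -> dict:
    if backend == "flex_flash":
        # Native FA4 path: FlexAttention JIT-instantiates CuTeDSL kernels.
        return {"BACKEND": "FLASH"}
    if backend == "flex_attention":
        # Default Triton-based FlexAttention codegen.
        return {}
    raise ValueError(f"unknown attention backend: {backend}")
```

With this shape, the `fa_experimental` machinery disappears entirely: the choice of FA4 versus Triton becomes a dictionary passed to the same `flex_attention` call.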

Motivation

  1. Simpler code — The current fa_experimental path requires manual CuTeDSL integration, custom forward/backward wiring, compilation patching (_patch_cutlass_compilation), and pre-compilation warmup (precompile_flash_attn_masked). The native pathway handles all of this automatically via torch.compile.

  2. Better Blackwell performance — The blog reports 2.2×–3.2× speedups on GB200 vs Triton FlexAttention. Since TorchSpec already targets SM100, this is directly relevant.

  3. Maintainability — The manual CuTeDSL integration is tightly coupled to flash_attn.cute internals (e.g., _flash_attn_fwd, _flash_attn_bwd, BlockSparseTensorsTorch). The native pathway is a stable PyTorch API.

Proposed Changes

  • Add a new attention backend option (e.g., flex_flash) that uses flex_attention with kernel_options={"BACKEND": "FLASH"}
  • Reuse the existing score_mod / mask_mod functions from generate_eagle3_mask — they should work as-is
  • Benchmark against the current fa_experimental backend to validate performance parity
  • If performance is equivalent or better, consider deprecating fa_experimental in favor of the unified path

Example

import torch
from functools import partial
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

flex_flash = torch.compile(
    partial(flex_attention, kernel_options={"BACKEND": "FLASH"}),
    dynamic=False,
)

# Existing Eagle3 mask_mod works directly
mask_mod = generate_eagle3_mask(seq_lengths, Q_LEN, KV_LEN)
block_mask = create_block_mask(mask_mod, B, H, Q_LEN, KV_LEN, device="cuda")
out = flex_flash(query, key, value, block_mask=block_mask)
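For context, a hedged sketch of what an Eagle3-style mask_mod looks like under FlexAttention's `(b, h, q_idx, kv_idx)` contract — the real `generate_eagle3_mask` lives in TorchSpec and may differ; the point is that this contract is unchanged under `BACKEND="FLASH"`, which is why the existing mask_mod is expected to work as-is. Plain ints are used below for clarity (FlexAttention passes index tensors, but the expression is identical):

```python
# Hedged sketch, NOT the real TorchSpec implementation: a causal mask_mod
# restricted to each batch element's valid KV prefix. `seq_lengths` is assumed
# to hold the per-batch number of valid KV positions.

def make_eagle3_style_mask(seq_lengths):
    def mask_mod(b, h, q_idx, kv_idx):
        # Attend causally, and only within the valid prefix of sequence b.
        return (kv_idx <= q_idx) & (kv_idx < seq_lengths[b])
    return mask_mod
```

Because the FLASH backend consumes the same mask_mod and `BlockMask` objects as the Triton backend, switching backends should require no changes to the masking code, only to the `kernel_options` dict.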
