Summary
PyTorch has officially released a native FlashAttention-4 backend for FlexAttention, accessible via `kernel_options={"BACKEND": "FLASH"}`. This provides automatic CuTeDSL score/mask function generation and JIT instantiation of FA4 kernels — delivering 1.2×–3.2× speedups over the Triton backend on Hopper and Blackwell GPUs.
TorchSpec currently has two separate code paths:
- `flex_attention` backend — uses the Triton-based FlexAttention with `compile_friendly_flex_attention`
- `fa_experimental` backend — manually imports `flash_attn.cute.flash_attn_func` and wires up the Eagle3 mask_mod through CuTeDSL directly (`LlamaFlashAttentionMasked`)
The new native pathway could unify these two backends into a single FlexAttention path with an optional BACKEND="FLASH" flag, reducing code complexity while getting FA4-level performance.
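Concretely, the unification could collapse backend choice into a single switch over FlexAttention kernel options. A minimal sketch of that idea (the `build_attention_config` helper and the backend strings are hypothetical, not existing TorchSpec APIs):

```python
def build_attention_config(backend: str) -> dict:
    """Map a backend name to FlexAttention kernel options (illustrative only)."""
    if backend == "flex_attention":
        # Current path: Triton-based FlexAttention with default kernels.
        return {"impl": "flex_attention", "kernel_options": {}}
    if backend == "flex_flash":
        # Unified path: same FlexAttention API, but the native
        # FlashAttention-4 kernels are requested via kernel_options.
        return {"impl": "flex_attention",
                "kernel_options": {"BACKEND": "FLASH"}}
    raise ValueError(f"unknown attention backend: {backend!r}")
```

Both branches would go through the same `flex_attention` entry point, which is what would allow `fa_experimental`'s separate CuTeDSL wiring to be retired.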
Motivation
- Simpler code — The current `fa_experimental` path requires manual CuTeDSL integration, custom forward/backward wiring, compilation patching (`_patch_cutlass_compilation`), and pre-compilation warmup (`precompile_flash_attn_masked`). The native pathway handles all of this automatically via `torch.compile`.
- Better Blackwell performance — The blog reports 2.2×–3.2× speedups on GB200 vs Triton FlexAttention. Since TorchSpec already targets SM100, this is directly relevant.
- Maintainability — The manual CuTeDSL integration is tightly coupled to `flash_attn.cute` internals (e.g., `_flash_attn_fwd`, `_flash_attn_bwd`, `BlockSparseTensorsTorch`). The native pathway is a stable PyTorch API.
Proposed Changes
- Add a new attention backend option (e.g., `flex_flash`) that uses `flex_attention` with `kernel_options={"BACKEND": "FLASH"}`
- Reuse the existing `score_mod`/`mask_mod` functions from `generate_eagle3_mask` — they should work as-is
- Benchmark against the current `fa_experimental` backend to validate performance parity
- If performance is equivalent or better, consider deprecating `fa_experimental` in favor of the unified path
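For the benchmarking step, a generic latency harness could be a starting point. This is a sketch, not a full GPU benchmark: the callables passed in would be the compiled `flex_flash` path and the `fa_experimental` path, and on CUDA you would add `torch.cuda.synchronize()` around the timed region (or use CUDA events) for accurate numbers:

```python
import time
from statistics import median

def bench_ms(fn, *args, iters: int = 50, warmup: int = 5) -> float:
    """Median wall-clock latency of fn(*args) in milliseconds."""
    for _ in range(warmup):
        fn(*args)  # warm up torch.compile / autotuning caches
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn(*args)  # on CUDA, synchronize here before reading the clock
        samples.append((time.perf_counter() - t0) * 1e3)
    return median(samples)
```

Comparing `bench_ms` results for both backends over the shapes TorchSpec actually serves would answer the parity question directly.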
Example
```python
import torch
from functools import partial
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

flex_flash = torch.compile(
    partial(flex_attention, kernel_options={"BACKEND": "FLASH"}),
    dynamic=False,
)

# Existing Eagle3 mask_mod works directly
mask_mod = generate_eagle3_mask(seq_lengths, Q_LEN, KV_LEN)
block_mask = create_block_mask(mask_mod, B, H, Q_LEN, KV_LEN, device="cuda")
out = flex_flash(query, key, value, block_mask=block_mask)
```

References
- FlexAttention + FlashAttention-4: Fast and Flexible (PyTorch Blog, March 2026)
- Reverse Engineering FlashAttention-4 (Modal Blog)