All FastPLMs sequence models share a common attention backend system controlled by config.attn_backend. This document covers how each backend works, when to use it, and how to configure it.
| Backend | Key | Numerical Equivalence | Speed | Availability |
|---|---|---|---|---|
| PyTorch SDPA | `"sdpa"` | Exact | Fast | Any PyTorch >= 2.0 |
| Flash Attention | `"kernels_flash"` | Approximate | Fastest | `pip install kernels` |
| Flex Attention | `"flex"` | Near-exact | Very fast | PyTorch >= 2.5 |
| Auto | `"auto"` | Varies | Best available | Always |
PyTorch's `scaled_dot_product_attention` dispatches to a fused CUDA kernel (cuDNN or memory-efficient attention) that is faster and more memory-efficient than naive attention while being mathematically identical.
When to use: Reproducibility, numerical sensitivity, general-purpose inference.
Implementation: Each attention layer calls `F.scaled_dot_product_attention(query, key, value, attn_mask)` with a 4D mask of shape `(batch, 1, 1, seq_len)`.
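A minimal sketch of this call shape, using a toy padding mask (the tensor sizes here are illustrative, not tied to any FastPLMs model):

```python
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 4, 8, 16
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# 2D padding mask: True = real token, False = padding
pad_mask = torch.ones(batch, seq_len, dtype=torch.bool)
pad_mask[1, 6:] = False  # second sequence ends early

# Expand to the additive 4D form (batch, 1, 1, seq_len) with -inf at padding
attn_mask = torch.zeros(batch, 1, 1, seq_len)
attn_mask.masked_fill_(~pad_mask[:, None, None, :], float("-inf"))

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 4, 8, 16])
```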
Attention weights: SDPA does not natively return attention weights. When output_attentions=True is requested, all backends (including SDPA) compute attention weights via a separate naive matrix multiplication: scores = Q @ K^T, softmax, then context = scores @ V. This separate computation negates the memory savings of fused attention, so output_attentions=True should only be used for inspection or contact prediction, not during high-throughput inference.
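The separate naive path can be illustrated in isolation. This sketch is not the FastPLMs code itself; it shows the explicit matmul/softmax computation and that its context output matches the fused kernel up to rounding:

```python
import math
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 5, 8)  # (batch, heads, seq, head_dim)
k = torch.randn(1, 2, 5, 8)
v = torch.randn(1, 2, 5, 8)

# Naive path: materializes the full (seq, seq) attention weight matrix
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
weights = torch.softmax(scores, dim=-1)  # what output_attentions exposes
context = weights @ v

# Agrees with the fused kernel up to floating-point rounding
fused = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(context, fused, atol=1e-5))  # True
```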
Flash Attention 2/3 tiles the attention computation into blocks that fit in SRAM and applies an online softmax algorithm. This avoids materializing the full (seq_len, seq_len) attention matrix in HBM, achieving O(n) memory and typically 2-4x faster throughput than SDPA on Ampere (A100) and Hopper (H100) GPUs at long sequence lengths.
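The online softmax at the heart of this algorithm can be sketched in plain PyTorch for a single attention row. This is a pedagogical model, not the actual kernel: it processes scores block by block, keeping a running max and normalizer so the full score vector is never needed at once, and agrees with one-pass softmax up to rounding:

```python
import torch

def online_softmax_weighted_sum(scores, values, block=32):
    # Streaming softmax over blocks, as in Flash Attention: keep a running
    # max (m) and normalizer (l), rescaling the accumulator whenever a new
    # block raises the max.
    m = torch.tensor(float("-inf"))
    l = torch.tensor(0.0)
    acc = torch.zeros(values.shape[-1])
    for i in range(0, scores.numel(), block):
        s, v = scores[i:i + block], values[i:i + block]
        m_new = torch.maximum(m, s.max())
        scale = torch.exp(m - m_new)  # rescales contributions of past blocks
        p = torch.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ v
        m = m_new
    return acc / l

torch.manual_seed(0)
scores = torch.randn(128)
values = torch.randn(128, 16)
ref = torch.softmax(scores, dim=0) @ values  # one-pass reference
out = online_softmax_weighted_sum(scores, values)
print(torch.allclose(out, ref, atol=1e-5))  # True
```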
When to use: Maximum throughput on A100/H100, long sequences, large batch sizes.
Numerical properties: The online softmax and tiling introduce floating-point rounding differences compared to standard attention. These are typically small but not guaranteed to be inconsequential. They can compound across layers and interact with low-precision dtypes (bf16/fp16). If exact reproducibility matters, use "sdpa".
Installation: FastPLMs uses the HuggingFace kernels package for pre-built Flash Attention binaries:
```shell
pip install kernels
```

No C++ compiler or CUDA toolkit version pinning is required. The `kernels` package fetches a pre-compiled binary matched to your GPU architecture (SM80 for Ampere, SM90 for Hopper). If no compatible binary exists, the model gracefully falls back to `"flex"` or `"sdpa"`.
Implementation details:
- Q, K, V are transposed from `(batch, heads, seq, dim)` to `(batch, seq, heads, dim)` for the kernels layout
- For variable-length batches, padding tokens are removed via `_unpad_input()`, which computes cumulative sequence lengths
- The kernels flash function is called with the unpadded tensors
- `pad_input()` reconstructs the full padded layout
- Flash Attention 3 is tried first (Hopper GPUs), falling back to Flash Attention 2
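A hypothetical sketch of the unpad/pad roundtrip. The names mirror `_unpad_input()`/`pad_input()` but this is not the actual implementation; it shows how cumulative sequence lengths are derived from a 2D mask and how the padded layout is restored:

```python
import torch

batch, seq_len, dim = 3, 6, 4
hidden = torch.randn(batch, seq_len, dim)
mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1, 1],
                     [1, 1, 0, 0, 0, 0]], dtype=torch.bool)

# Unpad: gather real tokens into (total_tokens, dim), record boundaries
indices = mask.flatten().nonzero(as_tuple=True)[0]
unpadded = hidden.reshape(-1, dim)[indices]
seqlens = mask.sum(dim=1)
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                        seqlens.cumsum(0).to(torch.int32)])
print(cu_seqlens.tolist())  # [0, 4, 10, 12]

# ... the flash kernel would run on `unpadded` with cu_seqlens ...

# Pad: scatter tokens back into the (batch, seq_len, dim) layout
restored = torch.zeros(batch * seq_len, dim)
restored[indices] = unpadded
restored = restored.reshape(batch, seq_len, dim)
print(torch.equal(restored * mask[..., None], hidden * mask[..., None]))  # True
```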
PyTorch's flex_attention (PyTorch >= 2.5) generates a fused Triton kernel customized to the mask pattern. The primary advantage is block masks that skip padding tokens entirely at the CUDA block level, providing meaningful speedups on variable-length batches.
When to use: Variable-length batches with significant padding, best sustained throughput with torch.compile.
Numerical properties: Near-exact to SDPA. Differences are typically within floating-point rounding of naive computation.
First-call compilation: The first forward pass triggers JIT compilation via Triton, which takes 30-120 seconds. All subsequent calls with the same mask shape are fast. When combined with torch.compile, this yields the best sustained throughput.
Implementation:
- A block mask is created from the 2D attention mask via `create_block_mask(mask_mod, batch, 1, seq_len, seq_len)`
- The mask mod function returns `True` for positions that should attend to each other
- `flex_attention(query, key, value, block_mask=block_mask)` generates and runs the fused kernel
- E1 uses a block-causal variant where within-sequence attention is bidirectional but cross-sequence attention is causal
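The mask-mod contract can be illustrated without running the Triton kernel. This sketch assumes the mask mod closes over a 2D padding mask (an assumption about the FastPLMs internals) and evaluates it densely for inspection; in practice the same callable would be handed to `create_block_mask`:

```python
import torch

# Toy 2D padding mask: True = real token
padding_mask = torch.tensor([[1, 1, 1, 0],
                             [1, 1, 1, 1]], dtype=torch.bool)

def mask_mod(b, h, q_idx, kv_idx):
    # True where query q_idx may attend to key kv_idx in batch element b;
    # head index h is unused for a pure padding mask
    return padding_mask[b, q_idx] & padding_mask[b, kv_idx]

# Evaluate the mask mod densely over all (batch, query, key) positions
b = torch.arange(2)[:, None, None]
q = torch.arange(4)[None, :, None]
kv = torch.arange(4)[None, None, :]
dense = mask_mod(b, 0, q, kv)
print(dense[0].int())  # padded position 3 attends to / from nothing
```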
The `"auto"` key selects the best available backend in priority order: kernels_flash -> flex -> sdpa. It is useful when you want maximum speed without manual configuration. The resolved backend may differ across machines depending on installed packages and GPU architecture.
The backend must be set on the config before calling from_pretrained:
```python
from transformers import AutoConfig, AutoModelForMaskedLM

config = AutoConfig.from_pretrained("Synthyra/ESM2-150M", trust_remote_code=True)
config.attn_backend = "flex"
model = AutoModelForMaskedLM.from_pretrained(
    "Synthyra/ESM2-150M", config=config, trust_remote_code=True
)
```

DPLM and DPLM2 expose a mutable property that propagates to all attention layers:
```python
model = AutoModelForMaskedLM.from_pretrained("Synthyra/DPLM-150M", trust_remote_code=True)
model.attn_backend = "flex"  # Updates every attention layer in-place
```

Each model has a `resolve_attention_backend()` function that:
- Validates the requested backend string
- For `"auto"`, probes available backends in order: kernels_flash -> flex -> sdpa
- Prints the resolved backend once (globally, to avoid log spam)
- Returns an `AttentionBackend` enum value
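A hypothetical sketch of that resolution logic. The enum values follow the table above, but the availability probes and the print-once mechanism here are assumptions, not the actual FastPLMs implementation:

```python
import enum
import importlib.util

class AttentionBackend(enum.Enum):
    KERNELS_FLASH = "kernels_flash"
    FLEX = "flex"
    SDPA = "sdpa"

_printed_backend = False

def _module_available(name: str) -> bool:
    try:
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        return False

def resolve_attention_backend(requested: str) -> AttentionBackend:
    valid = {b.value for b in AttentionBackend} | {"auto"}
    if requested not in valid:
        raise ValueError(f"Unknown attn_backend {requested!r}; choose one of {sorted(valid)}")
    if requested == "auto":
        # Probe in priority order: kernels_flash -> flex -> sdpa
        if _module_available("kernels"):
            resolved = AttentionBackend.KERNELS_FLASH
        elif _module_available("torch.nn.attention.flex_attention"):
            resolved = AttentionBackend.FLEX
        else:
            resolved = AttentionBackend.SDPA
    else:
        resolved = AttentionBackend(requested)
    global _printed_backend
    if not _printed_backend:  # print once, globally, to avoid log spam
        print(f"Resolved attention backend: {resolved.value}")
        _printed_backend = True
    return resolved
```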
The resolved enum is stored on each attention layer as `self.attn_backend` and on the encoder as `self.attention_backend`.
Different backends require different mask formats. The `get_attention_mask()` function (or equivalent) in each model produces:
| Backend | Mask Format | Shape |
|---|---|---|
| SDPA | Float 4D mask (`-inf` for masked) | `(batch, 1, 1, seq_len)` |
| Flash | Boolean 2D mask + cumulative seq lengths | `(batch, seq_len)` |
| Flex | `BlockMask` via `create_block_mask` | Opaque block mask object |
- SDPA: Works well with `torch.compile` out of the box
- Flex: Best performance when the entire model is compiled; the Triton kernel generation integrates with the compiler
- Flash: `torch.compile` wraps the kernels call; dynamic warmup detects when compilation has stabilized
The throughput benchmark (`testing/throughput.py`) applies `torch.compile` to all backends and uses dynamic warmup stabilization to ensure measurements reflect compiled performance.
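One way to implement dynamic warmup stabilization is to time successive calls until consecutive measurements agree. This is a generic sketch under that assumption, not the `testing/throughput.py` code; the function name and tolerance are hypothetical:

```python
import time

def warmup_until_stable(fn, rel_tol=0.05, max_iters=50):
    # Run fn repeatedly and stop once two consecutive timings agree within
    # rel_tol, i.e. JIT compilation and caching effects have settled.
    prev = None
    for i in range(max_iters):
        start = time.perf_counter()
        fn()
        elapsed = time.perf_counter() - start
        if prev is not None and abs(elapsed - prev) / prev < rel_tol:
            return i + 1  # iterations needed to stabilize
        prev = elapsed
    return max_iters  # never stabilized within the budget
```

Benchmark timing then starts only after `warmup_until_stable` returns, so the measured iterations exclude compilation overhead.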
When `output_s_max=True` is passed (ESM2, E1), each attention layer computes the per-head maximum attention score bound: `max(||Q|| * ||K||)` per head. This is useful for numerical stability analysis and debugging but adds overhead and should not be enabled during production inference.
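The bound can be sketched directly. This assumes `s_max` is the product of the largest query and key norms per head (an interpretation of the formula above); by Cauchy-Schwarz that product upper-bounds every raw dot-product score:

```python
import torch

batch, heads, seq_len, dim = 2, 4, 16, 8
q = torch.randn(batch, heads, seq_len, dim)
k = torch.randn(batch, heads, seq_len, dim)

# Per-head bound: max ||q_i|| times max ||k_j||
q_norms = q.norm(dim=-1)  # (batch, heads, seq)
k_norms = k.norm(dim=-1)
s_max = q_norms.amax(dim=-1) * k_norms.amax(dim=-1)  # (batch, heads)

# Every raw score |q_i . k_j| is bounded by s_max for its head
scores = torch.einsum("bhqd,bhkd->bhqk", q, k)
print(bool((scores.abs() <= s_max[..., None, None] + 1e-5).all()))  # True
```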