Still inference only for now; full end-to-end support will take time. This is a working source tree for FlashMLA that builds and runs on Windows, specifically on Blackwell workstation cards. I hate NVIDIA; I really understand Linus now.

IISuperluminaLII/FlashMLA_Windows_Linux_sm120

Status:

Compiled extension: external\FlashMLA\flash_mla\cuda_sm120.cp312-win_amd64.pyd. Exported functions: ['dense_prefill_bwd', 'dense_prefill_fwd', 'fwd', 'fwd_kvcache_mla', 'get_mla_decoding_metadata', 'sparse_prefill_fwd']

Sauce: on SM120, TileShape = (64, 16, 128) with ThreadShape = (1, 1, 1), so TileShapeQK = (64/1, 16/1, 128/1) = (64, 16, 128). TileShapeQK.M = 64 rows per stage; the 16dp atoms handle 16 × 4 = 64 rows.

On SM100, TileShape = (128, 128, 128) with ThreadShape = (2, 1, 1), so TileShapeQK = (64, 128, 128). TileShapeQK.M = 64 rows per stage; the 16dp atoms handle 64 rows.

That's pretty much it: extra passes are issued when needed to satisfy the predefined Blackwell atoms. Keep driving instead of pitting under the safety car; flexibility over the logical tiling.

FlashMLA

Introduction

FlashMLA is DeepSeek's library of optimized attention kernels, powering the DeepSeek-V3 and DeepSeek-V3.2-Exp models. This repository contains the following implementations:

Sparse Attention Kernels

These kernels power DeepSeek Sparse Attention (DSA), as introduced in this paper.

  • Token-level sparse attention for the prefill stage
  • Token-level sparse attention for the decoding stage, with FP8 KV cache

Dense Attention Kernels

  • Dense attention for the prefill stage
  • Dense attention for the decoding stage

News

  • 2025.09.29 Release of Sparse Attention Kernels: With the launch of DeepSeek-V3.2, we are releasing the corresponding token-level sparse attention kernels. These kernels power the model's DeepSeek Sparse Attention (DSA) and achieve up to 640 TFlops during prefilling and 410 TFlops during decoding. We also release a deep-dive blog for our new FP8 sparse decoding kernel. Check it out here.
  • 2025.08.01 Kernels for MHA on SM100: Thanks to NVIDIA's PR for MHA forward / backward kernels on SM100!
  • 2025.04.22 Deep-Dive Blog: We'd love to share the technical details behind the new FlashMLA kernel! Check out our deep-dive write-up here.
  • 2025.04.22 Performance Update: We're excited to announce the new release of Flash MLA, which delivers 5% ~ 15% performance improvement for compute-bound workloads, achieving up to 660 TFlops on NVIDIA H800 SXM5 GPUs. The interface of the new version is fully compatible with the old one. Simply upgrade to the new version for an immediate performance boost! 🚀🚀🚀

Performance

Test & benchmark MLA decoding (Sparse & Dense):

python tests/test_flash_mla_decoding.py

The dense MLA decoding kernel achieves up to 3000 GB/s in memory-bound configurations and 660 TFLOPS in compute-bound configurations on H800 SXM5 with CUDA 12.8. The token-level sparse MLA decoding kernel (which uses an FP8 KV cache while performing the matrix multiplications in bfloat16) achieves 410 TFLOPS in compute-bound configurations on H800 SXM5 with CUDA 12.8, and up to 350 TFLOPS on B200 (not yet fully optimized).

Test & benchmark MHA prefill (Dense):

python tests/test_fmha_sm100.py

It achieves up to 1460 TFlops in forward and 1000 TFlops in backward computation on B200, as reported by NVIDIA.

Test & benchmark MLA prefill (Sparse):

python tests/test_flash_mla_prefill.py

It achieves up to 640 TFlops in forward computation on H800 SXM5 with CUDA 12.8, and achieves up to 1450 TFlops on B200, CUDA 12.9.

Requirements

  • SM90 / SM100 / SM120 (See the support matrix below)
  • CUDA 12.8 and above (CUDA 12.9+ is required for SM100/SM120 kernels)
  • PyTorch 2.0 and above

Support matrix:

Kernel            GPU Architecture        MLA Mode [2]   KVCache Format
Dense Decoding    SM90 & SM120            MQA            BF16
Sparse Decoding   SM90 & SM100            MQA            FP8 [1]
Dense Prefill     SM100 & SM120           MHA
Sparse Prefill    SM90 & SM100 & SM120    MQA

[1]: For more details on using the FP8 KV cache, see the documentation below.

[2]: Here "MLA Mode" refers to the mode used for the MLA calculation. MQA stands for Multi-Query Attention mode (i.e. head_dim_k = 576 with head_dim_v = 512), while MHA stands for Multi-Head Attention mode (i.e. head_dim_k = 192 / 128 with head_dim_v = 128). For a detailed explanation of these modes, please refer to the appendix of the DeepSeek-V3.2 paper.
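For concreteness, here is a minimal shape sketch of the two modes. The head counts and sequence lengths below are illustrative only, not the kernels' exact calling convention.

import torch

# MQA mode (decoding / sparse kernels): one shared latent KV head,
# head_dim_k = 576 (512 "NoPE" + 64 RoPE values) and head_dim_v = 512.
h_q, h_kv = 128, 1                                    # illustrative head counts
q_mqa  = torch.randn(1, h_q, 576, dtype=torch.bfloat16)
kv_mqa = torch.randn(4096, h_kv, 576, dtype=torch.bfloat16)
# The attention output in MQA mode has head_dim_v = 512.

# MHA mode (dense prefill kernels): per-head K/V,
# head_dim_k = 192 or 128 with head_dim_v = 128.
q_mha = torch.randn(1, h_q, 192, dtype=torch.bfloat16)
k_mha = torch.randn(4096, h_q, 192, dtype=torch.bfloat16)
v_mha = torch.randn(4096, h_q, 128, dtype=torch.bfloat16)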

Installation

git clone https://github.com/deepseek-ai/FlashMLA.git flash-mla
cd flash-mla
git submodule update --init --recursive
pip install -v .

Usage

MLA Decoding

To use the MLA decoding kernels, call get_mla_metadata once before the decoding loop to get the tile scheduler metadata. Then, call flash_mla_with_kvcache in each decoding step. For example:

from flash_mla import get_mla_metadata, flash_mla_with_kvcache

tile_scheduler_metadata, num_splits = get_mla_metadata(
    cache_seqlens,
    s_q * h_q // h_kv,
    h_kv,
    h_q,
    is_fp8,
    topk,
)

for i in range(num_layers):
    ...
    o_i, lse_i = flash_mla_with_kvcache(
        q_i, kvcache_i, block_table, cache_seqlens, dv,
        tile_scheduler_metadata, num_splits,
        is_causal, is_fp8_kvcache, indices,
    )
    ...

Where

  • s_q is the number of q tokens per q sequence. If MTP (speculative decoding) is disabled, it should be 1.
  • h_kv is the number of key-value heads.
  • h_q is the number of query heads.
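As a concrete example of the second argument in the call above (purely illustrative numbers): with MTP disabled and 128 query heads sharing a single latent KV head,

s_q, h_q, h_kv = 1, 128, 1                       # illustrative values; MTP disabled
num_q_tokens_per_kv_head = s_q * h_q // h_kv     # = 128, the second argument to get_mla_metadata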

FP8 KV Cache: If is_fp8_kvcache is set to True, the kernel reads the KV cache in the "FP8 with scale" format (described below). It dequantizes the cache to bfloat16 and performs attention computation in bfloat16. The output is also in bfloat16.

In the "FP8 with scale" format, each token's KV cache is 656 Bytes, structured as:

  • First 512 bytes: The "quantized NoPE" part, containing 512 float8_e4m3 values.
  • Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on.
  • Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy.

See tests/quant.py for quantization and dequantization details.
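To make the byte layout concrete, below is a minimal quantization sketch for a single token. It follows the format described above, but it is not the code from tests/quant.py; the helper name and the per-group scaling rule are assumptions.

import torch

def quantize_kv_token(kv: torch.Tensor) -> torch.Tensor:
    # Pack one token's 576 bf16 values (512 NoPE + 64 RoPE) into the
    # 656-byte "FP8 with scale" layout: 512 B FP8 NoPE + 16 B scales + 128 B BF16 RoPE.
    # Illustrative sketch only, not the reference implementation in tests/quant.py.
    assert kv.shape == (576,) and kv.dtype == torch.bfloat16
    nope, rope = kv[:512].float(), kv[512:]

    groups = nope.view(4, 128)                        # one scale per 128 values
    scales = groups.abs().amax(dim=1) / torch.finfo(torch.float8_e4m3fn).max
    scales = torch.where(scales == 0, torch.ones_like(scales), scales)
    quant = (groups / scales[:, None]).to(torch.float8_e4m3fn)

    return torch.cat([
        quant.view(torch.uint8).flatten(),            # first 512 bytes: quantized NoPE
        scales.view(torch.uint8),                     # next 16 bytes: 4 float32 scales
        rope.view(torch.uint8),                       # last 128 bytes: BF16 RoPE, unquantized
    ])

packed = quantize_kv_token(torch.randn(576, dtype=torch.bfloat16))
assert packed.numel() == 656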

Sparse Attention (indices tensor): The indices tensor (if provided) enables token-level sparse attention by instructing the kernel to compute attention only for specified tokens.

  • Shape: indices should be a 3D tensor of shape (batch_size, seq_len_q, topk).
  • Format: indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * page_block_size + (the offset of token t within the page block), where t is the k-th token for the j-th query sequence in the i-th batch. Since the index of the page block has already been encoded into indices_in_kvcache, the kernel does not require the block_table parameter.
  • Invalid entries: Set invalid indices to -1.
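A minimal sketch of how such indices can be built from per-token positions and a block table. The helper name and the assumption that block_table has shape (batch_size, max_blocks_per_seq) are hypothetical.

import torch

def positions_to_kvcache_indices(positions: torch.Tensor,
                                 block_table: torch.Tensor,
                                 page_block_size: int) -> torch.Tensor:
    # positions:   [batch_size, seq_len_q, topk] token positions within each sequence, -1 = invalid
    # block_table: [batch_size, max_blocks_per_seq] page-block IDs (assumed layout)
    # Returns indices where index = page_block_index * page_block_size + offset_within_block.
    valid = positions >= 0
    pos = positions.clamp(min=0)
    block_ids = torch.gather(
        block_table, 1,
        (pos // page_block_size).flatten(1).long(),
    ).view_as(pos)
    indices = block_ids * page_block_size + pos % page_block_size
    return torch.where(valid, indices, torch.full_like(indices, -1))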

Return Values: The kernel returns (out, lse), where:

  • out is the attention result.
  • lse is the log-sum-exp value of the attention scores for each query head.

See tests/test_flash_mla_decoding.py for a complete example.

Sparse MLA Prefill

For the sparse MLA prefill kernel, call flash_mla_sparse_fwd directly with the following parameters:

  • q: Query tensor of shape [s_q, h_q, d_qk]
  • kv: Key-Value tensor of shape [s_kv, h_kv, d_qk]
  • indices: Indices tensor of shape [s_q, h_kv, topk]
  • sm_scale: A scalar value

Note on batching: This kernel does not support a batch dimension. For multi-batch inference, reshape the input tensors and adjust the indices parameter to simulate batch processing.
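For example, one way to fold a batch dimension into this layout is to concatenate the per-batch tensors along the sequence axis and shift each batch's indices by its KV offset. This is an illustrative sketch; the leading batch dimension in the shapes below is an assumption on top of the batch-less shapes listed above.

import torch

def flatten_batch_for_sparse_prefill(q, kv, indices):
    # q:       [b, s_q, h_q, d_qk]   -> [b * s_q, h_q, d_qk]
    # kv:      [b, s_kv, h_kv, d_qk] -> [b * s_kv, h_kv, d_qk]
    # indices: [b, s_q, h_kv, topk]  -> [b * s_q, h_kv, topk], shifted so that each
    #          batch's entries point into its own slice of the concatenated KV.
    b, s_kv = kv.shape[0], kv.shape[1]
    kv_offsets = torch.arange(b, device=indices.device, dtype=indices.dtype) * s_kv
    shifted = torch.where(indices >= 0,
                          indices + kv_offsets.view(b, 1, 1, 1),
                          indices)                       # keep -1 entries invalid
    return q.flatten(0, 1), kv.flatten(0, 1), shifted.flatten(0, 1)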

Invalid indices: Set invalid entries in indices to -1 or any number >= s_kv.

Return Values and Equivalent PyTorch Code: The kernel returns (out, max_logits, lse). This is equivalent to the following PyTorch operations:

# Q:       [s_q, h_q, d_qk], bfloat16
# kv:      [s_kv, h_kv, d_qk], bfloat16
# indices: [s_q, h_kv, topk], int32
import math
import torch

kv = kv.squeeze(1)            # [s_kv, d_qk]; h_kv must be 1
indices = indices.squeeze(1)  # [s_q, topk]
# For the i-th query token, gather its top-k KV tokens based on indices[i, :],
# giving a tensor of shape [s_q, topk, d_qk].
focused_kv = kv[indices]

# Attention logits in base 2 (hence the log2(e) factor): [s_q, h_q, topk]
P = (Q @ focused_kv.transpose(-1, -2)) * sm_scale * math.log2(math.e)
max_logits = P.max(dim=-1).values                               # [s_q, h_q]
# Base-2 log-sum-exp: lse = log2(sum(2^P))
lse = torch.logsumexp(P * math.log(2), dim=-1) / math.log(2)    # [s_q, h_q]
S = torch.exp2(P - lse.unsqueeze(-1))                           # [s_q, h_q, topk]
out = S @ focused_kv                                            # [s_q, h_q, d_qk]

return (out, max_logits, lse)

See tests/test_flash_mla_prefill.py for a complete example.

Dense MHA Prefill

This kernel implements the standard dense Multi-Head Attention (MHA) forward and backward operations. It can be called using:

  • flash_attn_varlen_func
  • flash_attn_varlen_qkvpacked_func
  • flash_attn_varlen_kvpacked_func

The usage is similar to the flash_attn package. See tests/test_fmha_sm100.py for a complete example.
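A hedged usage sketch, assuming the flash_attn-style calling convention. The import path from flash_mla, the argument names, and the shapes below are assumptions; see tests/test_fmha_sm100.py for the authoritative usage.

import torch
from flash_mla import flash_attn_varlen_func   # import path assumed; see the test file

# Two variable-length sequences packed into one tensor of total_tokens = 96 + 160.
h, d = 16, 128
seqlens = torch.tensor([96, 160], dtype=torch.int32, device="cuda")
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
total = int(cu_seqlens[-1])

q = torch.randn(total, h, d, dtype=torch.bfloat16, device="cuda")
k = torch.randn(total, h, d, dtype=torch.bfloat16, device="cuda")
v = torch.randn(total, h, d, dtype=torch.bfloat16, device="cuda")

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=int(seqlens.max()), max_seqlen_k=int(seqlens.max()),
    causal=True,
)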

Windows Compilation

This repository includes Windows support with MSVC for NVIDIA Blackwell GPUs (SM100a/SM120).

Prerequisites

  • Windows 10/11
  • CUDA 12.9+
  • PyTorch 2.0+
  • Microsoft Visual Studio 2022 Build Tools
  • Python 3.8+

Compilation Options

Option 1: SM100a Only (B100/B200 Server GPUs - 227KB shared memory)

Build for SM100a server GPUs with full shared memory (227KB):

# Clean previous builds
rm -rf build flash_mla.egg-info

# Set environment variables
export NVCC_THREADS=32
export FLASH_MLA_DISABLE_SM90=1

# Build
python setup.py build_ext --inplace

This build includes:

  • Dense MHA prefill kernels (forward + backward) for training
  • Decode kernels for inference
  • Full SM100a optimization (227KB shared memory)

Option 2: SM120 (RTX 6000 Pro / RTX 50 Series - 99KB shared memory)

Build for SM120 workstation GPUs with optimized memory usage (99KB shared memory limit):

# Clean previous builds
rm -rf build flash_mla.egg-info

# Set environment variables
export NVCC_THREADS=32
export FLASH_MLA_DISABLE_SM90=1
export FLASH_MLA_DISABLE_SM100=1

# Build
python setup.py build_ext --inplace

This build includes full training support with the following memory optimizations:

  • kStages=2 (CUTLASS minimum requirement)
  • Minimum CUTLASS-compliant tiles: Q=128, K=128, DQK=16, DVO=16
  • Buffer sharing: smem_v union smem_dq (saves ~20KB)
  • Non-persistent scheduler: UnionType storage (saves ~16KB)
  • Reduced pipeline stages: kStagesReduceTmaStore=1

Total memory savings: ~36-40KB from SM100a baseline, fitting within 99KB limit

Supported GPUs:

  • NVIDIA RTX 6000 Pro Blackwell Workstation Edition
  • NVIDIA GeForce RTX 5090 / 5080 / 5070 (upcoming)

What's Included:

  • Dense MHA prefill kernels (forward + backward) - FULL TRAINING SUPPORT
  • All memory-optimized for 99KB shared memory limit

Verification

Test your build:

import sys
sys.path.insert(0, "path/to/FlashMLA")
import flash_mla
import torch

# Check CUDA availability
if torch.cuda.is_available():
    device = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(device)
    print(f"GPU: {props.name}")
    print(f"Compute Capability: {props.major}.{props.minor}")
    print(f"Shared Memory: {props.shared_memory_per_block / 1024:.0f} KB")

Key Files Modified for SM120 Support

  • csrc/sm100/prefill/dense/sm100_kernel_traits.hpp - Added Sm120WorkstationConfig with memory optimizations
  • csrc/sm100/prefill/dense/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp - Implemented buffer sharing (smem_v union smem_dq)
  • setup.py - Added conditional compilation for SM100a/SM120 variants

Acknowledgement

FlashMLA is inspired by the FlashAttention 2&3 and CUTLASS projects.

Community Support

MetaX

For MetaX GPUs, visit the official website: MetaX.

The corresponding FlashMLA version can be found at: MetaX-MACA/FlashMLA

Moore Threads

For the Moore Threads GPU, visit the official website: Moore Threads.

The corresponding FlashMLA version is available on GitHub: MooreThreads/MT-flashMLA.

Hygon DCU

For the Hygon DCU, visit the official website: Hygon Developer.

The corresponding FlashMLA version is available here: OpenDAS/MLAttention.

Intellifusion

For the Intellifusion NNP, visit the official website: Intellifusion.

The corresponding FlashMLA version is available on Gitee: Intellifusion/tyllm.

Iluvatar Corex

For Iluvatar Corex GPUs, visit the official website: Iluvatar Corex.

The corresponding FlashMLA version is available on GitHub: Deep-Spark/FlashMLA

AMD Instinct

For AMD Instinct GPUs, visit the official website: AMD Instinct.

The corresponding FlashMLA version can be found at: AITER/MLA

Citation

@misc{flashmla2025,
      title={FlashMLA: Efficient Multi-head Latent Attention Kernels},
      author={Jiashi Li and Shengyu Liu},
      year={2025},
      publisher={GitHub},
      howpublished={\url{https://github.com/deepseek-ai/FlashMLA}},
}
