Voxtral Realtime: enable CUDA backend with int4 quantization #17798

mergennachin wants to merge 1 commit into `main`.
Add CUDA/AOTI backend support for the Voxtral Realtime model alongside the existing XNNPACK and Metal backends.

Model (model.py):
- CudaSDPA: F.scaled_dot_product_attention with repeat_interleave for GQA expansion and boolean attention masks (a Triton SDPA requirement)
- StaticKVCache (shared with Metal) for the [B,H,S,D] layout with index_copy_
- StandardEncoderRingKVCache/StandardEncoderSDPA for the streaming encoder
- _build_causal_mask_bool: 4D boolean mask for Triton compatibility
- Simplified LMAttention.forward to always pass attn_mask (None for XNNPACK)

Export (export_voxtral_rt.py):
- --backend cuda with CudaPartitioner and conv1d_to_conv2d decomposition
- --dtype flag (default fp32; bf16 for CUDA Triton SDPA)
- --qlinear-packing-format / --qlinear-encoder-packing-format for tile_packed_to_4d int4 quantization
- CUDA device placement, Dim.AUTO for the audio encoder, .ptd output

Runner (main.cpp, voxtral_realtime_runner.cpp/.h):
- --data_path flag for .ptd delegate data (CUDA compiled kernels)
- Module two-arg constructor for pte+ptd loading

Build (CMakePresets.json, Makefile):
- voxtral-realtime-cuda preset
- make voxtral_realtime-cuda target

CI (.github/workflows/cuda.yml, .ci/scripts/):
- Voxtral Realtime in the CUDA CI matrix (int4-tile-packed, offline mode)
- Export/test scripts updated for CUDA quantization args and data path
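The CudaSDPA changes above can be sketched roughly as follows. This is an illustrative reconstruction, not the PR's actual code: the function names (`build_causal_mask_bool`, `gqa_sdpa`) and shapes are assumptions; it only shows the pattern of expanding grouped KV heads with `repeat_interleave` and passing a 4D boolean mask to `F.scaled_dot_product_attention`.

```python
import torch
import torch.nn.functional as F

def build_causal_mask_bool(seq_len: int) -> torch.Tensor:
    # 4D boolean mask [1, 1, S, S]; True = attend, False = masked out.
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    return mask.view(1, 1, seq_len, seq_len)

def gqa_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
             mask: torch.Tensor) -> torch.Tensor:
    # q: [B, Hq, S, D]; k, v: [B, Hkv, S, D], with Hq a multiple of Hkv (GQA).
    n_rep = q.shape[1] // k.shape[1]
    # Materialize the KV-head expansion so the head counts match before SDPA.
    k = k.repeat_interleave(n_rep, dim=1)
    v = v.repeat_interleave(n_rep, dim=1)
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

q = torch.randn(1, 8, 4, 16)   # 8 query heads
k = torch.randn(1, 2, 4, 16)   # 2 KV heads
v = torch.randn(1, 2, 4, 16)
out = gqa_sdpa(q, k, v, build_causal_mask_bool(4))
```

Passing a boolean mask (rather than an additive float mask) matches the Triton SDPA requirement mentioned above.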
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17798
Note: links to docs will display an error until the docs builds have completed.

❌ 7 New Failures, 1 Unrelated Failure as of commit 1e5399a with merge base 25f2a3f.

NEW FAILURES: jobs that failed and were not present on the merge base. BROKEN TRUNK: a job that failed but was already failing on the merge base; 👉 rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
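The static KV-cache update described in the model changes can be sketched as below. This is a minimal sketch under assumptions: the class name, constructor, and `update` signature are illustrative, not the PR's actual `StaticKVCache` API; it only demonstrates in-place writes into a fixed [B,H,S,D] buffer with `index_copy_`, which keeps tensor shapes static for ahead-of-time compilation.

```python
import torch

class StaticKVCache:
    """Fixed-size KV cache in [B, H, S, D] layout, updated in place."""

    def __init__(self, batch: int, heads: int, max_seq: int, head_dim: int):
        self.k = torch.zeros(batch, heads, max_seq, head_dim)
        self.v = torch.zeros(batch, heads, max_seq, head_dim)

    def update(self, pos: torch.Tensor, k_new: torch.Tensor, v_new: torch.Tensor):
        # pos: 1D tensor of sequence positions to write.
        # k_new/v_new: [B, H, len(pos), D].
        # index_copy_ writes along dim 2 (the S axis) without resizing,
        # so the cache tensors keep a static shape across decode steps.
        self.k.index_copy_(2, pos, k_new)
        self.v.index_copy_(2, pos, v_new)
        return self.k, self.v

cache = StaticKVCache(batch=1, heads=2, max_seq=8, head_dim=4)
k_step = torch.ones(1, 2, 1, 4)
v_step = torch.full((1, 2, 1, 4), 2.0)
k, v = cache.update(torch.tensor([3]), k_step, v_step)
```

Here position 3 of the cache receives the new step while all other positions stay untouched, which is the behavior a static-shape decode loop relies on.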