From e5c3690498310c05f13ae484c88bf2abec7c5d5e Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Mon, 2 Mar 2026 14:28:03 -0800 Subject: [PATCH] Voxtral Realtime: enable CUDA backend with int4 quantization Add CUDA/AOTI backend support for the Voxtral Realtime model alongside the existing XNNPACK and Metal backends. Model (model.py): - CudaSDPA: F.scaled_dot_product_attention with repeat_interleave for GQA expansion and boolean attention masks (Triton SDPA requirement) - StaticKVCache (shared with Metal) for [B,H,S,D] layout with index_copy_ - StandardEncoderRingKVCache/StandardEncoderSDPA for streaming encoder - _build_causal_mask_bool: 4D boolean mask for Triton compatibility - Simplified LMAttention.forward to always pass attn_mask (None for XNNPACK) Export (export_voxtral_rt.py): - --backend cuda with CudaPartitioner and conv1d_to_conv2d decomposition - --dtype flag (default fp32, bf16 for CUDA Triton SDPA) - --qlinear-packing-format / --qlinear-encoder-packing-format for tile_packed_to_4d int4 quantization - CUDA device placement, Dim.AUTO for audio encoder, .ptd output Runner (main.cpp, voxtral_realtime_runner.cpp/.h): - --data_path flag for .ptd delegate data (CUDA compiled kernels) - Module two-arg constructor for pte+ptd loading Build (CMakePresets.json, Makefile): - voxtral-realtime-cuda preset - make voxtral_realtime-cuda target CI (.github/workflows/cuda.yml, .ci/scripts/): - Voxtral Realtime in CUDA CI matrix (int4-tile-packed, offline mode) - Export/test scripts updated for CUDA quantization args and data path --- .ci/scripts/export_model_artifact.sh | 7 +- .ci/scripts/test_model_e2e.sh | 4 + .github/workflows/cuda.yml | 36 ++- Makefile | 14 +- .../models/voxtral_realtime/CMakePresets.json | 33 +++ examples/models/voxtral_realtime/README.md | 67 +++++- .../voxtral_realtime/export_voxtral_rt.py | 126 ++++++++-- examples/models/voxtral_realtime/main.cpp | 9 +- examples/models/voxtral_realtime/model.md | 47 ++-- examples/models/voxtral_realtime/model.py | 112 +++++++-- .../voxtral_realtime_runner.cpp | 223 +++++++++++++----- .../voxtral_realtime_runner.h | 18 +- 12 files changed, 580 insertions(+), 116 deletions(-) diff --git a/.ci/scripts/export_model_artifact.sh b/.ci/scripts/export_model_artifact.sh index 427bb743180..3c0848475a8 100755 --- a/.ci/scripts/export_model_artifact.sh +++ b/.ci/scripts/export_model_artifact.sh @@ -257,10 +257,14 @@ if [ "$MODEL_NAME" = "voxtral_realtime" ]; then # Per-component quantization flags VR_QUANT_ARGS="" + VR_DTYPE_ARGS="" if [ "$QUANT_NAME" = "quantized-8da4w" ]; then VR_QUANT_ARGS="--qlinear-encoder 8da4w --qlinear 8da4w --qlinear-group-size 32 --qembedding 8w" elif [ "$QUANT_NAME" = "quantized-int4-metal" ]; then VR_QUANT_ARGS="--qlinear-encoder fpa4w --qlinear fpa4w" + elif [ "$QUANT_NAME" = "quantized-int4-tile-packed" ]; then + VR_QUANT_ARGS="--qlinear-encoder 4w --qlinear-encoder-packing-format tile_packed_to_4d --qlinear 4w --qlinear-packing-format tile_packed_to_4d --qembedding 8w" + VR_DTYPE_ARGS="--dtype bf16" fi # Determine streaming mode based on MODE parameter @@ -284,7 +288,8 @@ if [ "$MODEL_NAME" = "voxtral_realtime" ]; then --backend "$DEVICE" \ ${STREAMING_ARG} \ --output-dir "${OUTPUT_DIR}" \ - ${VR_QUANT_ARGS} + ${VR_QUANT_ARGS} \ + ${VR_DTYPE_ARGS} # Export preprocessor python -m executorch.extension.audio.mel_spectrogram ${PREPROCESSOR_ARGS} diff --git a/.ci/scripts/test_model_e2e.sh b/.ci/scripts/test_model_e2e.sh index 12a1b78681d..b0d9a68c5b0 100755 --- a/.ci/scripts/test_model_e2e.sh +++ b/.ci/scripts/test_model_e2e.sh @@ -298,6 +298,10 @@ EOF ;; voxtral_realtime) RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --preprocessor_path ${MODEL_DIR}/$PREPROCESSOR --audio_path ${MODEL_DIR}/$AUDIO_FILE --temperature 0" + # Add CUDA data path if present + if [ "$DEVICE" = "cuda" ] && [ -f "${MODEL_DIR}/aoti_cuda_blob.ptd" ]; then + RUNNER_ARGS="$RUNNER_ARGS --data_path ${MODEL_DIR}/aoti_cuda_blob.ptd" + fi # Determine streaming mode based on MODE parameter USE_STREAMING="true" if [ "$MODE" = "vr-offline" ]; then diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml index 23d840b6946..703189a7ee2 100644 --- a/.github/workflows/cuda.yml +++ b/.github/workflows/cuda.yml @@ -132,6 +132,8 @@ jobs: model: - repo: "mistralai" name: "Voxtral-Mini-3B-2507" + - repo: "mistralai" + name: "Voxtral-Mini-4B-Realtime-2602" - repo: "openai" name: "whisper-small" - repo: "openai" @@ -152,6 +154,15 @@ jobs: repo: "google" name: "gemma-3-4b-it" quant: "quantized-int4-weight-only" + # Voxtral Realtime only supports int4-tile-packed on CUDA (offline mode) + - model: + repo: "mistralai" + name: "Voxtral-Mini-4B-Realtime-2602" + quant: "non-quantized" + - model: + repo: "mistralai" + name: "Voxtral-Mini-4B-Realtime-2602" + quant: "quantized-int4-weight-only" with: timeout: 90 secrets-env: EXECUTORCH_HF_TOKEN @@ -181,7 +192,12 @@ jobs: echo "::endgroup::" fi - source .ci/scripts/export_model_artifact.sh cuda "${{ matrix.model.repo }}/${{ matrix.model.name }}" "${{ matrix.quant }}" "${RUNNER_ARTIFACT_DIR}" + # Voxtral Realtime uses offline mode for CUDA CI (not streaming) + VR_MODE="" + if [ "${{ matrix.model.name }}" = "Voxtral-Mini-4B-Realtime-2602" ]; then + VR_MODE="vr-offline" + fi + source .ci/scripts/export_model_artifact.sh cuda "${{ matrix.model.repo }}/${{ matrix.model.name }}" "${{ matrix.quant }}" "${RUNNER_ARTIFACT_DIR}" "$VR_MODE" test-model-cuda-e2e: name: test-model-cuda-e2e @@ -196,6 +212,8 @@ jobs: model: - repo: "mistralai" name: "Voxtral-Mini-3B-2507" + - repo: "mistralai" + name: "Voxtral-Mini-4B-Realtime-2602" - repo: "openai" name: "whisper-small" - repo: "openai" @@ -214,6 +232,15 @@ jobs: repo: "google" name: "gemma-3-4b-it" quant: "quantized-int4-weight-only" + # Voxtral Realtime only supports int4-tile-packed on CUDA (offline mode) + - model: + repo: "mistralai" + name: "Voxtral-Mini-4B-Realtime-2602" + quant: "non-quantized" + - model: + repo: "mistralai" + name: "Voxtral-Mini-4B-Realtime-2602" + quant: "quantized-int4-weight-only" with: timeout: 90 runner: linux.g5.4xlarge.nvidia.gpu @@ -224,7 +251,12 @@ jobs: download-artifact: ${{ matrix.model.repo }}-${{ matrix.model.name }}-cuda-${{ matrix.quant }} ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} script: | - source .ci/scripts/test_model_e2e.sh cuda "${{ matrix.model.repo }}/${{ matrix.model.name }}" "${{ matrix.quant }}" "${RUNNER_ARTIFACT_DIR}" + # Voxtral Realtime uses offline mode for CUDA CI (not streaming) + VR_MODE="" + if [ "${{ matrix.model.name }}" = "Voxtral-Mini-4B-Realtime-2602" ]; then + VR_MODE="vr-offline" + fi + source .ci/scripts/test_model_e2e.sh cuda "${{ matrix.model.repo }}/${{ matrix.model.name }}" "${{ matrix.quant }}" "${RUNNER_ARTIFACT_DIR}" "$VR_MODE" test-cuda-pybind: name: test-cuda-pybind diff --git a/Makefile b/Makefile index 2bd59778ec6..ad8544210f7 100644 --- a/Makefile +++ b/Makefile @@ -15,7 +15,7 @@ # SUPPORTED MODELS: # ----------------- # - voxtral: Multimodal voice + text model (CPU, CUDA, Metal) -# - voxtral_realtime: Realtime speech-to-text model (CPU) +# - voxtral_realtime: Realtime speech-to-text model (CPU, CUDA, Metal) # - whisper: Speech recognition model (CPU, CUDA, Metal) # - parakeet: Speech recognition model (CPU, CUDA, Metal) # - sortformer: Speaker diarization model (CPU) @@ -91,13 +91,14 @@ # # ============================================================================== -.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral_realtime-cpu voxtral_realtime-metal whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu clean help +.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu clean help help: @echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make \`. Available targets:" @echo " voxtral-cuda - Build Voxtral runner with CUDA backend" @echo " voxtral-cpu - Build Voxtral runner with CPU backend" @echo " voxtral-metal - Build Voxtral runner with Metal backend (macOS only)" + @echo " voxtral_realtime-cuda - Build Voxtral Realtime runner with CUDA backend" @echo " voxtral_realtime-cpu - Build Voxtral Realtime runner with CPU backend" @echo " voxtral_realtime-metal - Build Voxtral Realtime runner with Metal backend (macOS only)" @echo " whisper-cuda - Build Whisper runner with CUDA backend" @@ -244,6 +245,15 @@ voxtral_realtime-metal: @echo "✓ Build complete!" @echo " Binary: cmake-out/examples/models/voxtral_realtime/voxtral_realtime_runner" +voxtral_realtime-cuda: + @echo "==> Building and installing ExecuTorch with CUDA..." + cmake --workflow --preset llm-release-cuda + @echo "==> Building Voxtral Realtime runner with CUDA..." + cd examples/models/voxtral_realtime && cmake --workflow --preset voxtral-realtime-cuda + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/voxtral_realtime/voxtral_realtime_runner" + silero-vad-cpu: @echo "==> Building and installing ExecuTorch..." cmake --workflow --preset llm-release diff --git a/examples/models/voxtral_realtime/CMakePresets.json b/examples/models/voxtral_realtime/CMakePresets.json index 7c6978ecc81..707e94b0169 100644 --- a/examples/models/voxtral_realtime/CMakePresets.json +++ b/examples/models/voxtral_realtime/CMakePresets.json @@ -28,6 +28,19 @@ "type": "equals", "rhs": "Darwin" } + }, + { + "name": "voxtral-realtime-cuda", + "displayName": "Voxtral Realtime runner (CUDA)", + "inherits": ["voxtral-realtime-base"], + "cacheVariables": { + "EXECUTORCH_BUILD_CUDA": "ON" + }, + "condition": { + "type": "inList", + "string": "${hostSystemName}", + "list": ["Linux", "Windows"] + } } ], "buildPresets": [ @@ -43,6 +56,12 @@ "configurePreset": "voxtral-realtime-metal", "configuration": "Release", "targets": ["voxtral_realtime_runner"] + }, + { + "name": "voxtral-realtime-cuda", + "displayName": "Build Voxtral Realtime runner (CUDA)", + "configurePreset": "voxtral-realtime-cuda", + "targets": ["voxtral_realtime_runner"] } ], "workflowPresets": [ @@ -73,6 +92,20 @@ "name": "voxtral-realtime-metal" } ] + }, + { + "name": "voxtral-realtime-cuda", + "displayName": "Configure and build Voxtral Realtime runner (CUDA)", + "steps": [ + { + "type": "configure", + "name": "voxtral-realtime-cuda" + }, + { + "type": "build", + "name": "voxtral-realtime-cuda" + } + ] } ] } diff --git a/examples/models/voxtral_realtime/README.md b/examples/models/voxtral_realtime/README.md index c39be3fd9bc..7d29ba8c11b 100644 --- a/examples/models/voxtral_realtime/README.md +++ b/examples/models/voxtral_realtime/README.md @@ -88,8 +88,43 @@ python export_voxtral_rt.py \ |---------|---------|-----------|--------------| | `xnnpack` | ✓ | ✓ | `4w`, `8w`, `8da4w`, `8da8w` | | `metal` | ✓ | ✓ | none (fp32) or `fpa4w` (Metal-specific 4-bit) | +| `cuda` | ✓ | ✓ | `4w`, `8w` | -Metal backend provides Apple GPU acceleration. +Metal backend provides Apple GPU acceleration. CUDA backend provides NVIDIA GPU +acceleration via AOTInductor. + +#### CUDA export examples + +Offline with int4 quantization: + +```bash +python export_voxtral_rt.py \ + --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 \ + --backend cuda \ + --dtype bf16 \ + --output-dir ./voxtral_rt_exports \ + --qlinear-encoder 4w \ + --qlinear-encoder-packing-format tile_packed_to_4d \ + --qlinear 4w \ + --qlinear-packing-format tile_packed_to_4d \ + --qembedding 8w +``` + +Streaming with int4 quantization: + +```bash +python export_voxtral_rt.py \ + --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 \ + --backend cuda \ + --dtype bf16 \ + --streaming \ + --output-dir ./voxtral_rt_exports \ + --qlinear-encoder 4w \ + --qlinear-encoder-packing-format tile_packed_to_4d \ + --qlinear 4w \ + --qlinear-packing-format tile_packed_to_4d \ + --qembedding 8w +``` #### Metal export examples @@ -133,14 +168,17 @@ EXECUTORCH_BUILD_KERNELS_TORCHAO=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 ./install_ex | Flag | Default | Description | |------|---------|-------------| | `--model-path` | (required) | Directory with `params.json` + `consolidated.safetensors` | -| `--backend` | `xnnpack` | `xnnpack`, `metal`, or `portable` | +| `--backend` | `xnnpack` | `xnnpack`, `metal`, `cuda`, or `portable` | +| `--dtype` | `fp32` | Model dtype: `fp32` or `bf16` | | `--output-dir` | `./voxtral_rt_exports` | Output directory | | `--max-seq-len` | `4096` | KV cache length | | `--delay-tokens` | `6` | Transcription delay in tokens (6 = 480ms) | | `--qlinear` | (none) | Decoder linear layer quantization (`4w`, `8w`, `8da4w`, `8da8w`, `fpa4w`) | | `--qlinear-group-size` | `32` | Group size for decoder linear quantization | +| `--qlinear-packing-format` | (none) | Packing format for decoder 4w quantization (`tile_packed_to_4d` for CUDA) | | `--qlinear-encoder` | (none) | Encoder linear layer quantization (`4w`, `8w`, `8da4w`, `8da8w`, `fpa4w`) | | `--qlinear-encoder-group-size` | `32` | Group size for encoder linear quantization | +| `--qlinear-encoder-packing-format` | (none) | Packing format for encoder 4w quantization (`tile_packed_to_4d` for CUDA) | | `--qembedding` | (none) | Embedding layer quantization (`8w`) | | `--streaming` | off | Export streaming encoder with KV cache | | `--max-enc-len` | `750` | Encoder sliding window size (streaming only) | @@ -164,6 +202,15 @@ make voxtral_realtime-cpu This builds ExecuTorch core libraries with XNNPACK, then the runner binary at `cmake-out/examples/models/voxtral_realtime/voxtral_realtime_runner`. +### CUDA (NVIDIA GPU) + +```bash +make voxtral_realtime-cuda +``` + +This builds ExecuTorch with CUDA backend support. The runner binary is at +the same path as above. Requires NVIDIA GPU with CUDA toolkit installed. + ### Metal (Apple GPU) ```bash @@ -180,10 +227,22 @@ The runner requires: - `tekken.json` — tokenizer from the model weights directory - `preprocessor.pte` — mel spectrogram preprocessor (see [Preprocessor](#preprocessor)) - A 16kHz mono WAV audio file (or live audio via `--mic`) +- For CUDA: `aoti_cuda_blob.ptd` — delegate data file (pass via `--data_path`) + +```bash +cmake-out/examples/models/voxtral_realtime/voxtral_realtime_runner \ + --model_path voxtral_rt_exports/model.pte \ + --tokenizer_path ~/models/Voxtral-Mini-4B-Realtime-2602/tekken.json \ + --preprocessor_path voxtral_rt_exports/preprocessor.pte \ + --audio_path input.wav +``` + +For CUDA, include the `.ptd` data file: ```bash cmake-out/examples/models/voxtral_realtime/voxtral_realtime_runner \ --model_path voxtral_rt_exports/model.pte \ + --data_path voxtral_rt_exports/aoti_cuda_blob.ptd \ --tokenizer_path ~/models/Voxtral-Mini-4B-Realtime-2602/tekken.json \ --preprocessor_path voxtral_rt_exports/preprocessor.pte \ --audio_path input.wav @@ -218,9 +277,13 @@ ffmpeg -f avfoundation -i ":0" -ar 16000 -ac 1 -f f32le -nostats -loglevel error Ctrl+C stops recording and flushes remaining text. +**CUDA:** Add `--data_path voxtral_rt_exports/aoti_cuda_blob.ptd` to all +run commands above when using the CUDA backend. + | Flag | Default | Description | |------|---------|-------------| | `--model_path` | `model.pte` | Path to exported model | +| `--data_path` | (none) | Path to delegate data file (`.ptd`, required for CUDA) | | `--tokenizer_path` | `tekken.json` | Path to Tekken tokenizer | | `--preprocessor_path` | (none) | Path to mel preprocessor `.pte` | | `--audio_path` | (none) | Path to 16kHz mono WAV file | diff --git a/examples/models/voxtral_realtime/export_voxtral_rt.py b/examples/models/voxtral_realtime/export_voxtral_rt.py index 21dc5bf0ea7..c813f0ecc34 100644 --- a/examples/models/voxtral_realtime/export_voxtral_rt.py +++ b/examples/models/voxtral_realtime/export_voxtral_rt.py @@ -22,12 +22,16 @@ and StandardEncoderSDPA (F.scaled_dot_product_attention) for streaming encoder, avoiding custom_sdpa which is incompatible with AOTI. Uses Dim.AUTO for audio encoder dynamic shapes (explicit bounds cause issues with AOTI). + - CUDA/AOTI: Uses CudaSDPA (F.scaled_dot_product_attention with GQA expansion) for text_decoder + and StandardEncoderSDPA for streaming encoder. Compiles to CUDA kernels via + AOTInductor. Supports int4 quantization via _weight_int4pack_mm fallback kernel. - Portable: Uses custom SDPA like XNNPACK Usage: python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 --streaming python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 --backend metal + python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 --backend cuda --qlinear 4w """ import argparse @@ -101,12 +105,19 @@ def forward(self, token_ids: torch.Tensor) -> torch.Tensor: def _export_decoder_and_embedding( - programs, model, max_seq_len, qlinear, qlinear_group_size, qembedding + programs, + model, + max_seq_len, + qlinear, + qlinear_group_size, + qlinear_packing_format, + qembedding, + device="cpu", ): """Export text_decoder and token_embedding into programs dict.""" from executorch.extension.llm.export.quantize import quantize_model_ - param_dtype = torch.float32 + param_dtype = next(model.parameters()).dtype print("\nExporting text_decoder...") text_decoder = TextDecoderExport(model) @@ -118,11 +129,14 @@ def _export_decoder_and_embedding( text_decoder, qlinear_config=qlinear, qlinear_group_size=qlinear_group_size, + qlinear_packing_format=qlinear_packing_format, ) seq_dim = Dim("seq_len", min=1, max=max_seq_len) - sample_embeds = torch.randn(1, 4, model.config.dim, dtype=param_dtype) - sample_pos = torch.arange(4, dtype=torch.long) + sample_embeds = torch.randn( + 1, 4, model.config.dim, dtype=param_dtype, device=device + ) + sample_pos = torch.arange(4, dtype=torch.long, device=device) programs["text_decoder"] = export( text_decoder, (sample_embeds, sample_pos), @@ -146,7 +160,7 @@ def _export_decoder_and_embedding( ) tok_seq_dim = Dim("tok_seq_len", min=1, max=max_seq_len) - sample_ids = torch.tensor([[0, 1, 2, 3]], dtype=torch.long) + sample_ids = torch.tensor([[0, 1, 2, 3]], dtype=torch.long, device=device) programs["token_embedding"] = export( tok_emb, (sample_ids,), @@ -161,8 +175,10 @@ def export_all( max_seq_len, qlinear_encoder=None, qlinear_encoder_group_size=32, + qlinear_encoder_packing_format=None, qlinear=None, qlinear_group_size=32, + qlinear_packing_format=None, qembedding=None, backend="xnnpack", ): @@ -170,7 +186,8 @@ def export_all( from executorch.extension.llm.export.quantize import quantize_model_ programs = {} - param_dtype = torch.float32 + param_dtype = next(model.parameters()).dtype + device = "cuda" if backend == "cuda" else "cpu" # 1. Audio encoder print("\nExporting audio_encoder...") @@ -183,20 +200,23 @@ def export_all( audio_encoder, qlinear_config=qlinear_encoder, qlinear_group_size=qlinear_encoder_group_size, + qlinear_packing_format=qlinear_encoder_packing_format, ) - # For Metal/AOTI: use max size as sample and Dim.AUTO (explicit bounds cause issues) + # For Metal/CUDA/AOTI: use max size as sample and Dim.AUTO (explicit bounds cause issues) # For XNNPACK: use small sample with explicit bounds - if backend == "metal": + if backend in ("metal", "cuda"): max_t_mel = 24000 # 3000 * 8 sample_mel = torch.randn( - 1, model.config.num_mel_bins, max_t_mel, dtype=param_dtype + 1, model.config.num_mel_bins, max_t_mel, dtype=param_dtype, device=device ) dynamic_shapes = {"mel": {2: Dim.AUTO}} else: _t_mel_base = Dim("_t_mel_base", min=1, max=3000) t_mel_dim = 8 * _t_mel_base - sample_mel = torch.randn(1, model.config.num_mel_bins, 160, dtype=param_dtype) + sample_mel = torch.randn( + 1, model.config.num_mel_bins, 160, dtype=param_dtype, device=device + ) dynamic_shapes = {"mel": {2: t_mel_dim}} programs["audio_encoder"] = export( @@ -209,7 +229,14 @@ def export_all( # 2-3. Text decoder + token embedding _export_decoder_and_embedding( - programs, model, max_seq_len, qlinear, qlinear_group_size, qembedding + programs, + model, + max_seq_len, + qlinear, + qlinear_group_size, + qlinear_packing_format, + qembedding, + device, ) metadata = { @@ -232,8 +259,10 @@ def export_streaming( max_enc_len=750, qlinear_encoder=None, qlinear_encoder_group_size=32, + qlinear_encoder_packing_format=None, qlinear=None, qlinear_group_size=32, + qlinear_packing_format=None, qembedding=None, backend="xnnpack", ): @@ -241,7 +270,8 @@ def export_streaming( from executorch.extension.llm.export.quantize import quantize_model_ programs = {} - param_dtype = torch.float32 + param_dtype = next(model.parameters()).dtype + device = "cuda" if backend == "cuda" else "cpu" # 1. Streaming audio encoder print("\nExporting encode_audio_chunk...") @@ -250,6 +280,7 @@ def export_streaming( ) streaming_enc = StreamingAudioEncoderExport(model, max_enc_len=max_enc_len) + streaming_enc.to(device=device, dtype=param_dtype) streaming_enc.eval() if qlinear_encoder: @@ -258,10 +289,13 @@ def export_streaming( streaming_enc, qlinear_config=qlinear_encoder, qlinear_group_size=qlinear_encoder_group_size, + qlinear_packing_format=qlinear_encoder_packing_format, ) - sample_mel_chunk = torch.randn(1, model.config.num_mel_bins, 8, dtype=param_dtype) - sample_enc_pos = torch.arange(4, dtype=torch.long) + sample_mel_chunk = torch.randn( + 1, model.config.num_mel_bins, 8, dtype=param_dtype, device=device + ) + sample_enc_pos = torch.arange(4, dtype=torch.long, device=device) programs["encode_audio_chunk"] = export( streaming_enc, @@ -275,7 +309,14 @@ def export_streaming( # 2-3. Text decoder + token embedding _export_decoder_and_embedding( - programs, model, max_seq_len, qlinear, qlinear_group_size, qembedding + programs, + model, + max_seq_len, + qlinear, + qlinear_group_size, + qlinear_packing_format, + qembedding, + device, ) # Derive STFT overlap from audio parameters. @@ -363,6 +404,25 @@ def lower_to_executorch(programs, metadata, backend="xnnpack"): for key in programs: compile_specs = [MetalBackend.generate_method_name_compile_spec(key)] partitioner[key] = [MetalPartitioner(compile_specs)] + elif backend == "cuda": + from executorch.backends.cuda.cuda_backend import CudaBackend + from executorch.backends.cuda.cuda_partitioner import CudaPartitioner + from torch._inductor.decomposition import conv1d_to_conv2d + + print("\nLowering to ExecuTorch with CUDA...") + + # Run conv1d decomposition for CUDA backend + updated_programs = {} + for key, ep in programs.items(): + updated_programs[key] = ep.run_decompositions( + {torch.ops.aten.conv1d.default: conv1d_to_conv2d} + ) + programs = updated_programs + + partitioner = {} + for key in programs: + compile_specs = [CudaBackend.generate_method_name_compile_spec(key)] + partitioner[key] = [CudaPartitioner(compile_specs)] else: print("\nLowering to ExecuTorch (portable)...") partitioner = [] @@ -403,7 +463,7 @@ def main(): parser.add_argument( "--backend", default="xnnpack", - choices=["portable", "xnnpack", "metal"], + choices=["portable", "xnnpack", "metal", "cuda"], help="Backend for acceleration (default: xnnpack)", ) parser.add_argument( @@ -435,6 +495,12 @@ def main(): default=32, help="Group size for decoder linear quantization (default: 32).", ) + parser.add_argument( + "--qlinear-packing-format", + default=None, + choices=["tile_packed_to_4d"], + help="Packing format for decoder 4w quantization (CUDA: tile_packed_to_4d).", + ) parser.add_argument( "--qlinear-encoder", default=None, @@ -447,6 +513,12 @@ def main(): default=32, help="Group size for encoder linear quantization (default: 32).", ) + parser.add_argument( + "--qlinear-encoder-packing-format", + default=None, + choices=["tile_packed_to_4d"], + help="Packing format for encoder 4w quantization (CUDA: tile_packed_to_4d).", + ) parser.add_argument( "--qembedding", default=None, @@ -464,6 +536,12 @@ def main(): default=750, help="Encoder sliding window size for streaming (default: 750).", ) + parser.add_argument( + "--dtype", + default="fp32", + choices=["fp32", "bf16"], + help="Model dtype (default: fp32).", + ) args = parser.parse_args() # Validate fpa4w quantization requires Metal backend @@ -474,15 +552,22 @@ def main(): os.makedirs(args.output_dir, exist_ok=True) - # Load model + model_dtype = {"fp32": torch.float32, "bf16": torch.bfloat16}[args.dtype] + print("Loading model...") model = load_model( args.model_path, max_seq_len=args.max_seq_len, n_delay_tokens=args.delay_tokens, + dtype=model_dtype, backend=args.backend, ) + # Move to CUDA for CUDA backend export (AOTInductor needs CUDA tensors) + if args.backend == "cuda": + print("Moving model to CUDA...") + model.cuda() + # Untie output/embedding weights before quantization so each layer gets # its own quantization config (embedding: 8w, output linear: 8da4w). if args.qlinear or args.qembedding: @@ -495,8 +580,10 @@ def main(): quant_args = { "qlinear_encoder": args.qlinear_encoder, "qlinear_encoder_group_size": args.qlinear_encoder_group_size, + "qlinear_encoder_packing_format": args.qlinear_encoder_packing_format, "qlinear": args.qlinear, "qlinear_group_size": args.qlinear_group_size, + "qlinear_packing_format": args.qlinear_packing_format, "qembedding": args.qembedding, "backend": args.backend, } @@ -518,6 +605,11 @@ def main(): size_mb = os.path.getsize(pte_path) / (1024 * 1024) print(f"Saved {size_mb:.1f} MB") + # Write tensor data for CUDA backend (.ptd file with compiled .so and weights) + if et._tensor_data: + et.write_tensor_data_to_file(args.output_dir) + print(f"Saved tensor data to {args.output_dir}/") + print("\nDone!") diff --git a/examples/models/voxtral_realtime/main.cpp b/examples/models/voxtral_realtime/main.cpp index 3824c50e921..0f474c22a93 100644 --- a/examples/models/voxtral_realtime/main.cpp +++ b/examples/models/voxtral_realtime/main.cpp @@ -60,6 +60,10 @@ DEFINE_int32( 80, "Mic read chunk size in ms. Multiples of 80 align with the model's " "streaming step (80, 160, 320, 640, 960)."); +DEFINE_string( + data_path, + "", + "Path to data file (.ptd) for delegate data (required for CUDA)."); DEFINE_string( color, "", @@ -97,7 +101,10 @@ int main(int argc, char** argv) { stats.model_load_start_ms = ::executorch::extension::llm::time_in_ms(); voxtral_realtime::VoxtralRealtimeRunner runner( - FLAGS_model_path, FLAGS_tokenizer_path, FLAGS_preprocessor_path); + FLAGS_model_path, + FLAGS_tokenizer_path, + FLAGS_preprocessor_path, + FLAGS_data_path); stats.model_load_end_ms = ::executorch::extension::llm::time_in_ms(); stats.inference_start_ms = ::executorch::extension::llm::time_in_ms(); diff --git a/examples/models/voxtral_realtime/model.md b/examples/models/voxtral_realtime/model.md index 4268eddfe2e..fe240b03d8c 100644 --- a/examples/models/voxtral_realtime/model.md +++ b/examples/models/voxtral_realtime/model.md @@ -101,8 +101,8 @@ VoxtralRealtimeModel attention_norm: RMSNorm attention: LMAttention wq/wk/wv/wo: Linear (no bias) - kv_cache: KVCache (XNNPACK) or StaticKVCache (Metal) - sdpa: SDPA (XNNPACK) or MetalSDPA (Metal) + kv_cache: KVCache (XNNPACK) or StaticKVCache (Metal/CUDA) + sdpa: SDPA (XNNPACK) or MetalSDPA (Metal) or CudaSDPA (CUDA) ffn_norm: RMSNorm ada_rms_norm_t_cond: Sequential(Linear, GELU, Linear) feed_forward: LMMLP (w1/w2/w3) @@ -115,8 +115,8 @@ StreamingAudioEncoderExport layers: 32x CausalEncoderLayer (shared from encoder.layers) enc_norm: RMSNorm (shared from encoder.norm) adapter: AudioLanguageAdapter (shared from model.adapter) - kv_caches: 32x EncoderRingKVCache (XNNPACK) or StandardEncoderRingKVCache (Metal) - sdpa: SDPA (XNNPACK) or StandardEncoderSDPA (Metal) + kv_caches: 32x EncoderRingKVCache (XNNPACK) or StandardEncoderRingKVCache (Metal/CUDA) + sdpa: SDPA (XNNPACK) or StandardEncoderSDPA (Metal/CUDA) inv_freq: RoPE inverse frequencies (owned, on-the-fly computation) ``` @@ -127,7 +127,7 @@ spectrogram at once. No KV cache, no GQA (n_heads == n_kv_heads). `EncoderAttention` uses `F.scaled_dot_product_attention` with `is_causal=True`, transposing to `[B, H, T, D]` internally. No custom -ops needed — works on all backends (XNNPACK, Metal, Portable). +ops needed — works on all backends (XNNPACK, Metal, CUDA, Portable). The offline encoder uses full causal attention (no sliding window). The model's `params.json` specifies `sliding_window: 750` but this is @@ -139,7 +139,7 @@ than 750 encoder frames (~15s), full causal is equivalent. The text decoder (`MistralDecoder`) is a 26-layer Mistral decoder with GQA (32 query heads, 8 KV heads). Backend selection is controlled by the `backend` config field, passed through from the export script's `--backend` -flag (e.g., `"xnnpack"`, `"metal"`, `"portable"`). +flag (e.g., `"xnnpack"`, `"metal"`, `"cuda"`, `"portable"`). ### KV cache @@ -150,7 +150,7 @@ triggers a `requires_grad` bug in `SpecPropPass` during `to_executorch()`. The `[B, S, H, D]` layout matches what `update_cache` and `custom_sdpa` expect, so there are no transposes between cache update and attention. -**Metal:** `StaticKVCache` with `[B, H, S, D]` layout. Uses `index_copy_` +**Metal/CUDA:** `StaticKVCache` with `[B, H, S, D]` layout. Uses `index_copy_` for cache updates, which is compatible with `torch.export` and AOTI. ### SDPA @@ -167,6 +167,12 @@ which handles GQA natively via `gqa_factor`, avoiding the memory bandwidth overhead of `repeat_interleave`. Uses explicit additive attention masks. AOTInductor has compatibility issues with the `custom_sdpa` custom op. +**CUDA:** `CudaSDPA` uses `F.scaled_dot_product_attention` with +`repeat_interleave` for GQA expansion (32 query heads / 8 KV heads = 4x). +Uses boolean attention masks (`True`=attend, `False`=masked) as required +by the Triton SDPA kernel. The CUDA backend's Triton SDPA replacement +pass optimizes the attention kernel at compile time. + ### Attention layout **XNNPACK/Portable:** Q/K/V projections produce `[B, T, H, D]` via @@ -176,10 +182,9 @@ AOTInductor has compatibility issues with the `custom_sdpa` custom op. `RemoveRedundantTransposes` post-export pass that Llama/optimum-executorch require when using `[B, H, S, D]` attention with `[B, S, H, D]` cache. -**Metal:** Q/K/V projections still produce `[B, T, H, D]`, but -`StaticKVCache` stores `[B, H, S, D]` and `MetalSDPA` transposes q to -`[B, H, T, D]` for `_scaled_dot_product_attention_math_for_mps`, then -transposes back. +**Metal/CUDA:** Q/K/V projections still produce `[B, T, H, D]`, but +`StaticKVCache` stores `[B, H, S, D]` and `MetalSDPA`/`CudaSDPA` transpose q to +`[B, H, T, D]` for the SDPA kernel, then transpose back. ### Adaptive RMSNorm @@ -220,10 +225,9 @@ mel_chunk (1, 128, 8) + enc_input_pos (4,) **XNNPACK/Portable:** Uses `EncoderRingKVCache` (`update_cache_with_indices` custom op) and `SDPA` (`custom_sdpa`). -**Metal:** Uses `StandardEncoderRingKVCache` (`index_copy_`-based ring +**Metal/CUDA:** Uses `StandardEncoderRingKVCache` (`index_copy_`-based ring buffer) and `StandardEncoderSDPA` (`F.scaled_dot_product_attention` with -explicit sliding window masks) — the same patterns used in the Metal -text decoder. +explicit sliding window masks) — AOTI-compatible patterns avoiding custom ops. ### Streaming decode loop @@ -257,7 +261,7 @@ encoder — verified to within fp32 precision (max diff < 2e-5). Each of the 32 encoder transformer layers gets its own ring buffer KV cache (`EncoderRingKVCache` for XNNPACK/Portable, `StandardEncoderRingKVCache` -for Metal) that overwrites old entries when the window is exceeded, +for Metal/CUDA) that overwrites old entries when the window is exceeded, enabling streaming of arbitrary length audio. - Cache shape: `(1, 2*max_enc_len, 32, 64)` per layer. The buffer is 2x the @@ -283,7 +287,7 @@ for the offline encoder. (a custom op that scatter-writes via an indices tensor). Write indices are computed analytically: `(arange(seq_len) + start_pos) % buf_size`. -**Metal:** Cache writes use `index_copy_` with wrapped indices +**Metal/CUDA:** Cache writes use `index_copy_` with wrapped indices (`input_pos % buf_size`). No mutable position state is needed in either variant. @@ -304,7 +308,10 @@ is computed from these positions each step: ```python valid = (cache_pos >= 0) & (delta >= 0) & (delta < window_size) +# Metal: float additive mask mask = torch.where(valid, 0.0, float("-inf")) +# CUDA: boolean mask (bool_mask=True returns valid directly) +mask = valid ``` The mask is identical for all 32 layers (same `input_pos`), so it @@ -360,6 +367,10 @@ Parakeet pattern), allowing different configs for encoder vs decoder: # Metal --qlinear-encoder fpa4w # encoder linear layers --qlinear fpa4w # decoder linear layers + +# CUDA +--qlinear-encoder 4w --qlinear-encoder-packing-format tile_packed_to_4d +--qlinear 4w --qlinear-packing-format tile_packed_to_4d ``` The streaming encoder references the same module objects that @@ -411,7 +422,7 @@ of ~34 GB for the full-size model): 1. **Meta device construction** — `with torch.device("meta"):` builds the model with zero-storage parameter tensors (shape/dtype metadata only). 2. **safetensors lazy access** — `safe_open` loads tensors on demand, cast - to float32 (the default; bf16 is rejected by the XNNPACK partitioner). + to the configured dtype (`--dtype`, default fp32; CUDA uses bf16). 3. **`assign=True` state dict loading** — replaces meta tensors by reference instead of copying into pre-allocated storage. No duplication. 4. **Post-load fixups** — re-tie `output.weight = tok_embeddings.weight` @@ -430,7 +441,7 @@ of ~34 GB for the full-size model): | `layers.*` | `decoder.layers.*` | | `norm.weight` | `decoder.norm.weight` | -Weights are cast to float32 during loading. `decoder.output.weight` is +Weights are cast to the configured dtype during loading. `decoder.output.weight` is not in the checkpoint — it is created by tying to `decoder.tok_embeddings.weight` in `VoxtralRealtimeModel.__init__`. During export with quantization, the tie is broken (the `if args.qlinear diff --git a/examples/models/voxtral_realtime/model.py b/examples/models/voxtral_realtime/model.py index c055b06a0f6..26778413834 100644 --- a/examples/models/voxtral_realtime/model.py +++ b/examples/models/voxtral_realtime/model.py @@ -50,7 +50,7 @@ class VoxtralRealtimeConfig: downsample_factor: int = 4 # Runtime max_seq_len: int = 4096 - backend: str = "xnnpack" # "xnnpack", "metal", or "portable" + backend: str = "xnnpack" # "xnnpack", "metal", "cuda", or "portable" @staticmethod def from_params_json(path: str) -> "VoxtralRealtimeConfig": @@ -447,6 +447,19 @@ def _build_attn_mask( return (valid.float() - 1.0) * 1e9 +def _build_causal_mask_bool( + input_pos: torch.Tensor, max_seq_len: int, device: torch.device +) -> torch.Tensor: + """Build boolean causal attention mask. True = attend, False = masked. + + Returns [1, 1, seqlen, max_seq_len] for Triton SDPA compatibility + (requires 4D mask with batch and head dims). + """ + k_pos = torch.arange(max_seq_len, device=device) + mask = input_pos.unsqueeze(1) >= k_pos.unsqueeze(0) # [seqlen, max_seq_len] + return mask.unsqueeze(0).unsqueeze(0) # [1, 1, seqlen, max_seq_len] + + class MetalSDPA(nn.Module): """Standard SDPA calling the MPS op directly for native GQA support. @@ -497,10 +510,64 @@ def forward( return y.view(bsz, seqlen, self.dim) +class CudaSDPA(nn.Module): + """Standard SDPA with GQA support for CUDA/AOTI backend. + + Uses F.scaled_dot_product_attention with repeat_interleave for GQA expansion. + KV cache uses [B, H, S, D] layout from StaticKVCache. Requires boolean + attention masks (Triton SDPA kernel only accepts torch.bool). + """ + + def __init__(self, n_heads: int, n_kv_heads: int, head_dim: int): + super().__init__() + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.n_rep = n_heads // n_kv_heads + self.head_dim = head_dim + self.dim = n_heads * head_dim + + def forward( + self, + input_pos: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + bsz: int, + seqlen: int, + attn_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Args: + input_pos: (seq_len,) position indices. + q: (B, seq_len, n_heads, head_dim) in [B, S, H, D] layout. + k, v: (B, n_kv_heads, max_seq_len, head_dim) in [B, H, S, D] layout from StaticKVCache. + bsz, seqlen: batch size and query sequence length. + attn_mask: precomputed boolean mask (True=attend), or None to compute here. + Returns: + output: (B, seq_len, n_heads * head_dim). + """ + q = q.transpose(1, 2) # [B, n_heads, seq_len, head_dim] + + # Expand KV for GQA + if self.n_rep > 1: + k = k.repeat_interleave(self.n_rep, dim=1) + v = v.repeat_interleave(self.n_rep, dim=1) + + if attn_mask is None: + attn_mask = _build_causal_mask_bool(input_pos, k.shape[2], q.device) + + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, is_causal=False + ) # [B, n_heads, seq_len, head_dim] + + y = y.transpose(1, 2).contiguous() # [B, seq_len, n_heads, head_dim] + return y.view(bsz, seqlen, self.dim) + + class StandardEncoderSDPA(nn.Module): """Standard SDPA for encoder using F.scaled_dot_product_attention. - Compatible with AOTI/Metal backend. Works with EncoderRingKVCache that uses + Compatible with AOTI/Metal/CUDA backend. Works with EncoderRingKVCache that uses [B, S, H, D] layout and sliding window masks. """ @@ -526,7 +593,8 @@ def forward( q: (B, seq_len, n_heads, head_dim) in [B, S, H, D] layout. k, v: (B, buf_size, n_heads, head_dim) in [B, S, H, D] layout from EncoderRingKVCache. bsz, seqlen: batch size and query sequence length. - mask: (seq_len, buf_size) additive attention mask (0.0 = attend, -inf = don't attend). + mask: (seq_len, buf_size) attention mask. Float additive (0.0=attend, -inf=masked) + for Metal, or boolean (True=attend) for CUDA. Returns: output: (B, seq_len, n_heads * head_dim). """ @@ -572,6 +640,9 @@ def __init__(self, config: VoxtralRealtimeConfig, max_seq_len: int): if self.backend == "metal": self.kv_cache = StaticKVCache(max_seq_len, self.n_kv_heads, self.head_dim) self.sdpa = MetalSDPA(self.n_heads, self.n_kv_heads, self.head_dim) + elif self.backend == "cuda": + self.kv_cache = StaticKVCache(max_seq_len, self.n_kv_heads, self.head_dim) + self.sdpa = CudaSDPA(self.n_heads, self.n_kv_heads, self.head_dim) else: self.kv_cache = KVCache(max_seq_len, self.n_kv_heads, self.head_dim) self.sdpa = SDPA(self.n_heads, self.head_dim) @@ -593,10 +664,7 @@ def forward( k, v = self.kv_cache.update(input_pos, k, v) - if self.backend == "metal": - y = self.sdpa(input_pos, q, k, v, B, T, attn_mask) - else: - y = self.sdpa(input_pos, q, k, v, B, T) + y = self.sdpa(input_pos, q, k, v, B, T, attn_mask) return self.wo(y) @@ -686,6 +754,11 @@ def forward( if self.config.backend == "metal": max_seq_len = self.freqs_cos.shape[0] attn_mask = _build_attn_mask(input_pos, max_seq_len, input_embeds.device) + elif self.config.backend == "cuda": + max_seq_len = self.freqs_cos.shape[0] + attn_mask = _build_causal_mask_bool( + input_pos, max_seq_len, input_embeds.device + ) x = input_embeds for layer in self.layers: @@ -798,7 +871,7 @@ def update( return self.k_cache, self.v_cache def create_causal_mask( - self, start_pos: torch.Tensor | int, seq_len: int + self, start_pos: torch.Tensor | int, seq_len: int, bool_mask: bool = False ) -> torch.Tensor: device = ( start_pos.device @@ -851,12 +924,16 @@ def update( return self.k_cache, self.v_cache - def create_causal_mask(self, start_pos: torch.Tensor, seq_len: int) -> torch.Tensor: + def create_causal_mask( + self, start_pos: torch.Tensor, seq_len: int, bool_mask: bool = False + ) -> torch.Tensor: """Create sliding window attention mask for ring buffer. Args: start_pos: Tensor containing the starting position (scalar tensor) seq_len: Number of query positions + bool_mask: If True, return boolean mask (True=attend). If False, + return float additive mask (0.0=attend, -inf=masked). """ total_written = start_pos + seq_len j = torch.arange(self.buf_size, dtype=torch.long, device=start_pos.device) @@ -868,6 +945,10 @@ def create_causal_mask(self, start_pos: torch.Tensor, seq_len: int) -> torch.Ten delta = pos_q - cache_pos.unsqueeze(0) valid = (cache_pos >= 0) & (delta >= 0) & (delta < self.window_size) + if bool_mask: + return valid.unsqueeze(0).unsqueeze( + 0 + ) # [1, 1, seq_len, buf_size] for Triton return torch.where(valid, 0.0, float("-inf")) @@ -896,6 +977,7 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750): self.downsample_factor = config.downsample_factor self.n_heads = config.enc_n_heads self.head_dim = config.enc_head_dim + self.bool_mask = config.backend == "cuda" # Register conv states as buffers (mutable state for streaming) self.register_buffer("conv1_state", torch.zeros(1, config.num_mel_bins, 2)) @@ -907,7 +989,7 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750): # Choose cache implementation based on backend cache_class = ( StandardEncoderRingKVCache - if config.backend == "metal" + if config.backend in ("metal", "cuda") else EncoderRingKVCache ) self.kv_caches = nn.ModuleList( @@ -918,7 +1000,7 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750): ) # Choose SDPA based on backend - if config.backend == "metal": + if config.backend in ("metal", "cuda"): self.sdpa = StandardEncoderSDPA(config.enc_n_heads, config.enc_head_dim) else: self.sdpa = SDPA(config.enc_n_heads, config.enc_head_dim) @@ -995,7 +1077,9 @@ def forward( # Sliding window mask — identical for all layers, compute once. T = x.size(1) # Pass start position as tensor (not .item()) to avoid unbacked symbols in AOTI - mask = self.kv_caches[0].create_causal_mask(enc_input_pos[0], T) + mask = self.kv_caches[0].create_causal_mask( + enc_input_pos[0], T, bool_mask=self.bool_mask + ) for i, layer in enumerate(self.layers): x = self._streaming_encoder_layer( @@ -1078,9 +1162,9 @@ def load_model( max_seq_len: Maximum sequence length for KV cache. n_delay_tokens: Transcription delay in tokens (default 6 = 480ms). dtype: Weight dtype (default: float32). - backend: Backend for acceleration ("xnnpack", "metal", or "portable"). + backend: Backend for acceleration ("xnnpack", "metal", "cuda", or "portable"). """ - _VALID_BACKENDS = ("xnnpack", "metal", "portable") + _VALID_BACKENDS = ("xnnpack", "metal", "cuda", "portable") if backend not in _VALID_BACKENDS: raise ValueError( f"Unknown backend '{backend}'. Must be one of {_VALID_BACKENDS}." diff --git a/examples/models/voxtral_realtime/voxtral_realtime_runner.cpp b/examples/models/voxtral_realtime/voxtral_realtime_runner.cpp index c6a329de95b..f181488197c 100644 --- a/examples/models/voxtral_realtime/voxtral_realtime_runner.cpp +++ b/examples/models/voxtral_realtime/voxtral_realtime_runner.cpp @@ -14,7 +14,6 @@ #include #include -#include #include #include #include @@ -31,11 +30,19 @@ VoxtralRealtimeRunner::VoxtralRealtimeRunner( const std::string& model_path, const std::string& tokenizer_path, const std::string& preprocessor_path, + const std::string& data_path, bool warmup) { // Load the main model (.pte with audio_encoder, text_decoder, // token_embedding methods). Mmap avoids copying the file into memory. + // For CUDA backend, data_path points to the .ptd file with compiled kernels. ET_LOG(Info, "Loading model from: %s", model_path.c_str()); - model_ = std::make_unique(model_path, Module::LoadMode::Mmap); + if (!data_path.empty()) { + ET_LOG(Info, "Loading data from: %s", data_path.c_str()); + model_ = + std::make_unique(model_path, data_path, Module::LoadMode::Mmap); + } else { + model_ = std::make_unique(model_path, Module::LoadMode::Mmap); + } auto load_error = model_->load(); ET_CHECK_MSG(load_error == Error::Ok, "Failed to load model."); @@ -54,12 +61,30 @@ VoxtralRealtimeRunner::VoxtralRealtimeRunner( if (dm.ok()) dim_ = dm.get()[0].toInt(); + // Detect model dtype from method metadata (same pattern as ASR runner). + // Checks the first input tensor's scalar_type of the audio_encoder or + // encode_audio_chunk method. Falls back to Float for old .pte files. + for (const char* method : {"audio_encoder", "encode_audio_chunk"}) { + auto meta_result = model_->method_meta(method); + if (meta_result.ok()) { + auto meta = meta_result.get(); + if (meta.num_inputs() > 0) { + auto input_meta = meta.input_tensor_meta(0); + if (input_meta.ok()) { + model_dtype_ = input_meta.get().scalar_type(); + } + } + break; + } + } + ET_LOG( Info, - "Model: max_seq_len=%ld, vocab_size=%ld, dim=%ld", + "Model: max_seq_len=%ld, vocab_size=%ld, dim=%ld, dtype=%s", static_cast(max_seq_len_), static_cast(vocab_size_), - static_cast(dim_)); + static_cast(dim_), + ::executorch::runtime::toString(model_dtype_)); // Detect streaming model (exported with --streaming flag). auto streaming_val = model_->execute("streaming", empty); @@ -140,7 +165,7 @@ VoxtralRealtimeRunner::VoxtralRealtimeRunner( session->feed_audio(dummy_audio.data(), step_samples_); session->flush(); } else { - // Preprocessor + // Preprocessor (always float32 — runs on CPU) auto pp_wav = from_blob( dummy_audio.data(), {static_cast(step_samples_)}, @@ -150,12 +175,14 @@ VoxtralRealtimeRunner::VoxtralRealtimeRunner( ET_CHECK_MSG(pp_r.ok(), "Warmup: preprocessor failed."); // Audio encoder (8 mel frames = minimum valid input) - std::vector dummy_mel( + // Create fp32 mel then convert to model dtype if needed. + std::vector dummy_mel_fp32( static_cast(num_mel_bins_ * 8), 0.0f); - auto mel_t = from_blob( - dummy_mel.data(), + auto mel_fp32 = from_blob( + dummy_mel_fp32.data(), {1, static_cast(num_mel_bins_), 8}, ::executorch::aten::ScalarType::Float); + auto mel_t = convert_to_model_dtype(std::move(mel_fp32)); auto enc_r = model_->execute("audio_encoder", std::vector{*mel_t}); ET_CHECK_MSG(enc_r.ok(), "Warmup: audio_encoder failed."); @@ -168,13 +195,13 @@ VoxtralRealtimeRunner::VoxtralRealtimeRunner( model_->execute("token_embedding", std::vector{*tok_t}); ET_CHECK_MSG(tok_r.ok(), "Warmup: token_embedding failed."); - // Text decoder - std::vector dummy_emb(static_cast(dim_), 0.0f); - int64_t dummy_pos = 0; + // Text decoder — create embeds in model dtype + auto tok_embed = tok_r.get()[0].toTensor(); auto emb_t = from_blob( - dummy_emb.data(), + tok_embed.mutable_data_ptr(), {1, 1, static_cast(dim_)}, - ::executorch::aten::ScalarType::Float); + model_dtype_); + int64_t dummy_pos = 0; auto pos_t = from_blob(&dummy_pos, {1}, ::executorch::aten::ScalarType::Long); auto dec_r = @@ -216,6 +243,51 @@ TensorPtr VoxtralRealtimeRunner::run_preprocessor( ::executorch::aten::ScalarType::Float); } +// Extract the last vocab_size logits as fp32 for sampling. +// If the tensor is already fp32, returns a pointer into it directly. +// Otherwise converts to fp32 into the provided buffer. +static float* get_logits_fp32( + ::executorch::aten::Tensor& logits, + int64_t vocab_size, + std::vector& buf) { + const int64_t offset = logits.numel() - vocab_size; + if (logits.scalar_type() == ::executorch::aten::ScalarType::Float) { + return logits.mutable_data_ptr() + offset; + } + // Convert bf16/half logits to fp32 + buf.resize(static_cast(vocab_size)); + if (logits.scalar_type() == ::executorch::aten::ScalarType::BFloat16) { + const auto* src = + logits.const_data_ptr<::executorch::aten::BFloat16>() + offset; + for (int64_t i = 0; i < vocab_size; i++) { + buf[static_cast(i)] = static_cast(src[i]); + } + } else if (logits.scalar_type() == ::executorch::aten::ScalarType::Half) { + const auto* src = + logits.const_data_ptr<::executorch::aten::Half>() + offset; + for (int64_t i = 0; i < vocab_size; i++) { + buf[static_cast(i)] = static_cast(src[i]); + } + } else { + ET_CHECK_MSG(false, "Unsupported logits dtype for sampling."); + } + return buf.data(); +} + +TensorPtr VoxtralRealtimeRunner::convert_to_model_dtype(TensorPtr tensor) { + if (model_dtype_ == ::executorch::aten::ScalarType::Float || + tensor->scalar_type() == model_dtype_) { + return tensor; + } + if (model_dtype_ == ::executorch::aten::ScalarType::BFloat16) { + auto result = ::executorch::extension::llm::convert_to_bfloat16(tensor); + ET_CHECK_MSG(result.ok(), "Failed to convert tensor to BFloat16."); + return std::move(result.get()); + } + ET_CHECK_MSG(false, "Unsupported model dtype conversion."); + return tensor; // unreachable +} + int VoxtralRealtimeRunner::transcribe( const float* audio_data, int64_t num_samples, @@ -223,7 +295,10 @@ int VoxtralRealtimeRunner::transcribe( TokenCallback token_cb) { // --- Step 1: Preprocess raw audio to mel spectrogram --- ET_CHECK_MSG(preprocessor_ != nullptr, "No preprocessor provided."); - TensorPtr mel = run_preprocessor(audio_data, num_samples); + TensorPtr mel_fp32 = run_preprocessor(audio_data, num_samples); + + // Convert mel from fp32 (preprocessor) to model dtype (may be bf16) + TensorPtr mel = convert_to_model_dtype(std::move(mel_fp32)); // --- Step 2: Encode mel to audio embeddings --- // audio_encoder: (1, 128, T_mel) -> (1, T_audio, 3072) @@ -257,14 +332,14 @@ int VoxtralRealtimeRunner::transcribe( const int64_t max_pos = std::min( static_cast(config.max_new_tokens) + t_audio, max_seq_len_); - std::vector input_embeds_buf(static_cast(dim_)); - - // Token sampler with xorshift RNG, seeded from wall clock. ::executorch::extension::llm::Sampler sampler( static_cast(vocab_size_), config.temperature, ::executorch::extension::llm::kTopp, static_cast(std::time(nullptr))); + std::vector logits_fp32_buf; + auto input_embeds = ::executorch::extension::empty( + {1, 1, static_cast(dim_)}, model_dtype_); for (int64_t pos = 0; pos < max_pos; pos++) { // a. Look up embedding for the previous token. @@ -276,27 +351,40 @@ int VoxtralRealtimeRunner::transcribe( model_->execute("token_embedding", std::vector{*token_tensor}); ET_CHECK_MSG(tok_result.ok(), "token_embedding failed."); auto tok_embed = tok_result.get()[0].toTensor(); - const float* tok_data = tok_embed.const_data_ptr(); // b. Sum audio + token embeddings (or token-only after audio ends). + // Both audio_embeds and tok_embed are in model_dtype_ (fp32 or bf16). + // Reuses pre-allocated input_embeds buffer (no per-token allocation). if (pos < t_audio) { - const float* audio_frame = - audio_embeds.const_data_ptr() + pos * dim_; - for (int64_t i = 0; i < dim_; i++) { - input_embeds_buf[static_cast(i)] = audio_frame[i] + tok_data[i]; + // Element-wise sum: audio_frame[i] + tok_data[i] + // Works for any dtype since we operate on raw bytes via BFloat16 type. + if (model_dtype_ == ::executorch::aten::ScalarType::BFloat16) { + auto* out = + input_embeds->mutable_data_ptr<::executorch::aten::BFloat16>(); + const auto* af = + audio_embeds.const_data_ptr<::executorch::aten::BFloat16>() + + pos * dim_; + const auto* tf = + tok_embed.const_data_ptr<::executorch::aten::BFloat16>(); + for (int64_t i = 0; i < dim_; i++) { + out[i] = ::executorch::aten::BFloat16( + static_cast(af[i]) + static_cast(tf[i])); + } + } else { + auto* out = input_embeds->mutable_data_ptr(); + const auto* af = audio_embeds.const_data_ptr() + pos * dim_; + const auto* tf = tok_embed.const_data_ptr(); + for (int64_t i = 0; i < dim_; i++) { + out[i] = af[i] + tf[i]; + } } } else { std::memcpy( - input_embeds_buf.data(), - tok_data, - static_cast(dim_) * sizeof(float)); + input_embeds->mutable_data_ptr(), + tok_embed.const_data_ptr(), + static_cast(dim_) * input_embeds->element_size()); } - auto input_embeds = from_blob( - input_embeds_buf.data(), - {1, 1, static_cast(dim_)}, - ::executorch::aten::ScalarType::Float); - // c. Run one decoder step. KV cache is updated internally by the model. auto cache_pos = from_blob(&pos, {1}, ::executorch::aten::ScalarType::Long); @@ -306,10 +394,8 @@ int VoxtralRealtimeRunner::transcribe( auto logits = dec_result.get()[0].toTensor(); - // d. Sample next token from logits. Safe to mutate the output buffer - // since text_decoder overwrites it on the next execute() call. - float* logits_data = - logits.mutable_data_ptr() + (logits.numel() - vocab_size_); + // d. Sample next token (persistent sampler preserves RNG state). + float* logits_data = get_logits_fp32(logits, vocab_size_, logits_fp32_buf); int64_t next_token = static_cast(sampler.sample(logits_data)); num_generated++; @@ -359,7 +445,9 @@ StreamingSession::StreamingSession( config.temperature, ::executorch::extension::llm::kTopp, static_cast(std::time(nullptr))), - input_embeds_buf_(static_cast(runner.dim_)) {} + input_embeds_(::executorch::extension::empty( + {1, 1, static_cast(runner.dim_)}, + runner.model_dtype_)) {} int StreamingSession::feed_audio(const float* data, int64_t num_samples) { audio_buf_.insert(audio_buf_.end(), data, data + num_samples); @@ -457,21 +545,24 @@ bool StreamingSession::try_process_step() { // These align exactly with the offline mel frames for this step. // Output layout is channels-first: (1, 128, T). For each channel, // copy 8 contiguous frames starting at offset mel_skip. - std::vector mel_chunk_buf( + std::vector mel_chunk_fp32( static_cast(num_mel_bins * chunk_mel_len)); const float* mel_data = mel.const_data_ptr(); for (int64_t c = 0; c < num_mel_bins; c++) { std::memcpy( - mel_chunk_buf.data() + c * chunk_mel_len, + mel_chunk_fp32.data() + c * chunk_mel_len, mel_data + c * total_mel_frames + mel_skip, static_cast(chunk_mel_len) * sizeof(float)); } - auto mel_chunk = from_blob( - mel_chunk_buf.data(), + auto mel_chunk_tensor = from_blob( + mel_chunk_fp32.data(), {1, static_cast(num_mel_bins), static_cast(chunk_mel_len)}, ::executorch::aten::ScalarType::Float); + // Convert to model dtype if needed (e.g., fp32 -> bf16 for CUDA) + auto mel_chunk = runner_.convert_to_model_dtype(std::move(mel_chunk_tensor)); + std::vector enc_pos_data(static_cast(enc_frames_per_chunk)); for (int64_t i = 0; i < enc_frames_per_chunk; i++) { enc_pos_data[static_cast(i)] = enc_frame_pos_ + i; @@ -487,16 +578,21 @@ bool StreamingSession::try_process_step() { ET_CHECK_MSG(enc_result.ok(), "encode_audio_chunk failed."); auto& enc_outputs = enc_result.get(); - auto audio_embeds = enc_outputs[0].toTensor(); + auto audio_embeds_tensor = + std::make_shared<::executorch::aten::Tensor>(enc_outputs[0].toTensor()); + auto audio_embeds_ptr = TensorPtr(audio_embeds_tensor); enc_frame_pos_ += enc_frames_per_chunk; samples_consumed_ += step; // --- Decode one step --- - return decode_step(audio_embeds.const_data_ptr()); + return decode_step(&audio_embeds_ptr); } -bool StreamingSession::decode_step(const float* audio_embeds) { +bool StreamingSession::decode_step(const TensorPtr* audio_embeds_tensor) { + const int64_t dim = runner_.dim_; + const auto model_dtype = runner_.model_dtype_; + // Token embedding for previous token. int64_t token_id = static_cast(prev_token_); auto token_tensor = @@ -506,35 +602,48 @@ bool StreamingSession::decode_step(const float* audio_embeds) { "token_embedding", std::vector{*token_tensor}); ET_CHECK_MSG(tok_result.ok(), "token_embedding failed."); auto tok_embed = tok_result.get()[0].toTensor(); - const float* tok_data = tok_embed.const_data_ptr(); - // Sum audio + token embeddings (or token-only if audio_embeds is null). - if (audio_embeds != nullptr) { - for (int64_t i = 0; i < runner_.dim_; i++) { - input_embeds_buf_[static_cast(i)] = audio_embeds[i] + tok_data[i]; + // Sum audio + token embeddings (or token-only if no audio). + // Reuses pre-allocated input_embeds_ buffer (no per-token allocation). + if (audio_embeds_tensor != nullptr) { + auto& audio_embeds = **audio_embeds_tensor; + if (model_dtype == ::executorch::aten::ScalarType::BFloat16) { + auto* out = + input_embeds_->mutable_data_ptr<::executorch::aten::BFloat16>(); + const auto* af = + audio_embeds.const_data_ptr<::executorch::aten::BFloat16>(); + const auto* tf = tok_embed.const_data_ptr<::executorch::aten::BFloat16>(); + for (int64_t i = 0; i < dim; i++) { + out[i] = ::executorch::aten::BFloat16( + static_cast(af[i]) + static_cast(tf[i])); + } + } else { + auto* out = input_embeds_->mutable_data_ptr(); + const auto* af = audio_embeds.const_data_ptr(); + const auto* tf = tok_embed.const_data_ptr(); + for (int64_t i = 0; i < dim; i++) { + out[i] = af[i] + tf[i]; + } } } else { std::memcpy( - input_embeds_buf_.data(), - tok_data, - static_cast(runner_.dim_) * sizeof(float)); + input_embeds_->mutable_data_ptr(), + tok_embed.const_data_ptr(), + static_cast(dim) * input_embeds_->element_size()); } - auto input_embeds = from_blob( - input_embeds_buf_.data(), - {1, 1, static_cast(runner_.dim_)}, - ::executorch::aten::ScalarType::Float); - auto cache_pos = from_blob(&dec_pos_, {1}, ::executorch::aten::ScalarType::Long); auto dec_result = runner_.model_->execute( - "text_decoder", std::vector{*input_embeds, *cache_pos}); + "text_decoder", std::vector{*input_embeds_, *cache_pos}); ET_CHECK_MSG(dec_result.ok(), "text_decoder failed."); auto logits = dec_result.get()[0].toTensor(); + + // Sample next token (persistent sampler preserves RNG state). float* logits_data = - logits.mutable_data_ptr() + (logits.numel() - runner_.vocab_size_); + get_logits_fp32(logits, runner_.vocab_size_, logits_fp32_buf_); int64_t next_token = static_cast(sampler_.sample(logits_data)); num_generated_++; diff --git a/examples/models/voxtral_realtime/voxtral_realtime_runner.h b/examples/models/voxtral_realtime/voxtral_realtime_runner.h index a209c73bd6a..fbd5a9aea4a 100644 --- a/examples/models/voxtral_realtime/voxtral_realtime_runner.h +++ b/examples/models/voxtral_realtime/voxtral_realtime_runner.h @@ -40,6 +40,7 @@ class VoxtralRealtimeRunner { const std::string& model_path, const std::string& tokenizer_path, const std::string& preprocessor_path = "", + const std::string& data_path = "", bool warmup = true); // Offline transcription: full encoder first, then step-by-step decode. @@ -78,6 +79,11 @@ class VoxtralRealtimeRunner { int64_t vocab_size_ = 131072; int64_t dim_ = 3072; + // Model dtype detected from method metadata (input_tensor_meta). + // Defaults to Float; set to BFloat16 if the model expects bf16 inputs. + ::executorch::aten::ScalarType model_dtype_ = + ::executorch::aten::ScalarType::Float; + // Streaming metadata (from constant_methods, if present) bool is_streaming_ = false; int64_t num_mel_bins_ = 128; @@ -103,6 +109,10 @@ class VoxtralRealtimeRunner { ::executorch::extension::TensorPtr run_preprocessor( const float* audio, int64_t num_samples); + + // Convert a tensor to model_dtype_ if needed (e.g., fp32 mel -> bf16). + ::executorch::extension::TensorPtr convert_to_model_dtype( + ::executorch::extension::TensorPtr tensor); }; // Streaming session: accepts raw audio incrementally via feed_audio(), @@ -147,13 +157,17 @@ class StreamingSession { bool flushed_ = false; ::executorch::extension::llm::Sampler sampler_; - std::vector input_embeds_buf_; + ::executorch::extension::TensorPtr input_embeds_; + std::vector logits_fp32_buf_; // Process one 80ms step from the audio buffer. bool try_process_step(); // Run one decoder step (token_embed + optional audio_embed -> logits). - bool decode_step(const float* audio_embeds); + // audio_embeds_tensor is the output from encode_audio_chunk, or nullptr + // for text-only decoding after audio ends. + bool decode_step( + const ::executorch::extension::TensorPtr* audio_embeds_tensor); }; } // namespace voxtral_realtime