From 8361c9aec723953bc7763881f56ef4c66029ae99 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Fri, 19 Dec 2025 14:26:36 -0800 Subject: [PATCH 1/9] parakeet --- examples/models/parakeet/README.md | 36 ++ .../models/parakeet/export_parakeet_tdt.py | 383 ++++++++++++++++++ 2 files changed, 419 insertions(+) create mode 100644 examples/models/parakeet/README.md create mode 100644 examples/models/parakeet/export_parakeet_tdt.py diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md new file mode 100644 index 00000000000..97611b10d95 --- /dev/null +++ b/examples/models/parakeet/README.md @@ -0,0 +1,36 @@ +# Parakeet TDT Export for ExecuTorch + +Export [nvidia/parakeet-tdt-0.6b-v3](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3) speech recognition model to ExecuTorch. + +## Installation + +```bash +pip install nemo_toolkit[asr] torchaudio +``` + +## Usage + +Export the model (portable backend): +```bash +python export_parakeet_tdt.py +``` + +Export with a specific backend: +```bash +python export_parakeet_tdt.py --backend xnnpack # CPU acceleration +python export_parakeet_tdt.py --backend cuda # CUDA acceleration +python export_parakeet_tdt.py --backend cuda-windows # CUDA on Windows +``` + +Test transcription on an audio file: +```bash +python export_parakeet_tdt.py --audio /path/to/audio.wav +``` + +### Arguments + +| Argument | Description | +|----------|-------------| +| `--output-dir` | Output directory for exports (default: `./parakeet_tdt_exports`) | +| `--backend` | Backend for acceleration: `portable`, `xnnpack`, `cuda`, `cuda-windows` (default: `portable`) | +| `--audio` | Path to audio file for transcription test | diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py new file mode 100644 index 00000000000..8d51b6826fe --- /dev/null +++ b/examples/models/parakeet/export_parakeet_tdt.py @@ -0,0 +1,383 @@ +#!/usr/bin/env python3 +"""Export nvidia/parakeet-tdt-0.6b-v3 components to ExecuTorch.""" + +import os + +import torch +from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge_transform_and_lower +from executorch.exir.passes import MemoryPlanningPass +from torch.export import Dim, export + + +def load_audio(audio_path: str, sample_rate: int = 16000) -> torch.Tensor: + """Load audio file and resample to target sample rate.""" + try: + import torchaudio + + waveform, sr = torchaudio.load(audio_path) + except (ImportError, Exception): + from scipy.io import wavfile + + sr, data = wavfile.read(audio_path) + if data.dtype == "int16": + data = data.astype("float32") / 32768.0 + elif data.dtype == "int32": + data = data.astype("float32") / 2147483648.0 + waveform = torch.from_numpy(data).unsqueeze(0) + + if waveform.shape[0] > 1: + waveform = waveform.mean(dim=0, keepdim=True) + + if sr != sample_rate: + try: + import torchaudio + + resampler = torchaudio.transforms.Resample(sr, sample_rate) + waveform = resampler(waveform) + except ImportError: + from scipy import signal + + num_samples = int(len(waveform[0]) * sample_rate / sr) + resampled = signal.resample(waveform[0].numpy(), num_samples) + waveform = torch.from_numpy(resampled).unsqueeze(0).float() + + return waveform + + +def greedy_decode_eager(encoder_output: torch.Tensor, encoder_len: torch.Tensor, model) -> list[int]: + """Greedy decode using NeMo's built-in decoding.""" + hypotheses = model.decoding.rnnt_decoder_predictions_tensor( + encoder_output=encoder_output, + encoded_lengths=encoder_len, + return_hypotheses=True, + ) + return hypotheses[0].y_sequence + + +class DecoderPredict(torch.nn.Module): + """Wrapper for decoder.predict() with LSTM state.""" + + def __init__(self, decoder): + super().__init__() + self.decoder = decoder + self.pred_hidden = decoder.pred_hidden + self.pred_rnn_layers = getattr(decoder, "pred_rnn_layers", 2) + + def forward( + self, token: torch.Tensor, h: torch.Tensor, c: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + g, new_state = self.decoder.predict(y=token, state=[h, c], add_sos=False) + return g, new_state[0], new_state[1] + + +def greedy_decode_executorch( + encoder_output: torch.Tensor, + encoder_len: int, + program, + blank_id: int, + vocab_size: int, + num_rnn_layers: int = 2, + pred_hidden: int = 640, + max_symbols_per_step: int = 10, + durations: list[int] | None = None, +) -> list[int]: + """TDT duration-aware greedy decode using ExecuTorch runtime.""" + if durations is None: + durations = [0, 1, 2, 3, 4] + + hypothesis = [] + num_token_classes = vocab_size + 1 + + encoder_output = encoder_output.transpose(1, 2) + + proj_enc_method = program.load_method("joint_project_encoder") + f_proj = proj_enc_method.execute([encoder_output.contiguous()])[0] + + decoder_predict_method = program.load_method("decoder_predict") + proj_dec_method = program.load_method("joint_project_decoder") + joint_method = program.load_method("joint") + + h = torch.zeros(num_rnn_layers, 1, pred_hidden) + c = torch.zeros(num_rnn_layers, 1, pred_hidden) + + sos_g = torch.zeros(1, 1, pred_hidden) + g_proj = proj_dec_method.execute([sos_g])[0] + + t = 0 + symbols_on_frame = 0 + + # Scan over the encoder output + while t < encoder_len: + f_t = f_proj[:, t : t + 1, :].contiguous() + + joint_out = joint_method.execute([f_t, g_proj]) + + full_logits = joint_out[0].squeeze() + token_logits = full_logits[:num_token_classes] + duration_logits = full_logits[num_token_classes:] + + k = token_logits.argmax().item() + dur_idx = duration_logits.argmax().item() + dur = durations[dur_idx] + + # TDT decoding: joint network outputs both token logits and duration logits. + # - If blank: skip forward by predicted duration (min 1 frame) + # - If token: emit it, update decoder state, advance by duration. + # Duration=0 means "emit another token on this frame" (up to max_symbols_per_step). + if k == blank_id: + t += max(dur, 1) + symbols_on_frame = 0 + else: + hypothesis.append(k) + + token = torch.tensor([[k]], dtype=torch.long) + result = decoder_predict_method.execute([token, h, c]) + g = result[0] + h = result[1] + c = result[2] + + g_proj = proj_dec_method.execute([g])[0] + t += dur + + if dur == 0: + symbols_on_frame += 1 + if symbols_on_frame >= max_symbols_per_step: + t += 1 + symbols_on_frame = 0 + else: + symbols_on_frame = 0 + + return hypothesis + + +def transcribe_executorch(audio_path: str, model, et_buffer) -> str: + """Transcribe audio file using ExecuTorch runtime.""" + from executorch.runtime import Runtime + + runtime = Runtime.get() + program = runtime.load_program(et_buffer) + + with torch.no_grad(): + audio = load_audio(audio_path) + + mel, mel_len = model.preprocessor(input_signal=audio, length=torch.tensor([audio.shape[1]])) + + encoder_method = program.load_method("encoder") + enc_result = encoder_method.execute([mel, mel_len]) + encoded = enc_result[0] + encoded_len = enc_result[1].item() + + vocab_size = model.tokenizer.vocab_size + tokens = greedy_decode_executorch( + encoded, + encoded_len, + program, + blank_id=vocab_size, + vocab_size=vocab_size, + num_rnn_layers=model.decoder.pred_rnn_layers, + pred_hidden=model.decoder.pred_hidden, + ) + + return model.tokenizer.ids_to_text(tokens) + + +def transcribe_eager(audio_path: str, model) -> str: + """Transcribe audio file using eager PyTorch model.""" + with torch.no_grad(): + audio = load_audio(audio_path) + mel, mel_len = model.preprocessor(input_signal=audio, length=torch.tensor([audio.shape[1]])) + encoded, encoded_len = model.encoder(audio_signal=mel, length=mel_len) + tokens = greedy_decode_eager(encoded, encoded_len, model) + return model.tokenizer.ids_to_text(tokens) + + +def load_model(): + import nemo.collections.asr as nemo_asr + + model = nemo_asr.models.ASRModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v3") + model.eval() + return model + + +class JointAfterProjection(torch.nn.Module): + def __init__(self, joint): + super().__init__() + self.joint = joint + + def forward(self, f, g): + return self.joint.joint_after_projection(f, g) + + +class JointProjectEncoder(torch.nn.Module): + def __init__(self, joint): + super().__init__() + self.joint = joint + + def forward(self, f): + return self.joint.project_encoder(f) + + +class JointProjectDecoder(torch.nn.Module): + def __init__(self, joint): + super().__init__() + self.joint = joint + + def forward(self, g): + return self.joint.project_prednet(g) + + +def export_all(model): + """Export all components, return dict of ExportedPrograms.""" + programs = {} + + feat_in = getattr(model.encoder, "_feat_in", 80) + audio_signal = torch.randn(1, feat_in, 100) + length = torch.tensor([100], dtype=torch.int64) + programs["encoder"] = export( + model.encoder, + (), + kwargs={"audio_signal": audio_signal, "length": length}, + dynamic_shapes={"audio_signal": {2: Dim.AUTO}, "length": {}}, + strict=False, + ) + + decoder_predict = DecoderPredict(model.decoder) + decoder_predict.eval() + token = torch.tensor([[0]], dtype=torch.long) + num_layers = model.decoder.pred_rnn_layers + pred_hidden = model.decoder.pred_hidden + h = torch.zeros(num_layers, 1, pred_hidden) + c = torch.zeros(num_layers, 1, pred_hidden) + programs["decoder_predict"] = export( + decoder_predict, + (token, h, c), + dynamic_shapes={"token": {}, "h": {}, "c": {}}, + strict=False, + ) + + f_proj = torch.randn(1, 1, 640) + g_proj = torch.randn(1, 1, 640) + programs["joint"] = export( + JointAfterProjection(model.joint), + (f_proj, g_proj), + dynamic_shapes={"f": {}, "g": {}}, + strict=False, + ) + + programs["joint_project_encoder"] = export( + JointProjectEncoder(model.joint), + (torch.randn(1, 25, 1024),), + dynamic_shapes={"f": {1: Dim("enc_time", min=1, max=60000)}}, + strict=False, + ) + + pred_hidden = getattr(model.decoder, "pred_hidden", 640) + programs["joint_project_decoder"] = export( + JointProjectDecoder(model.joint), + (torch.randn(1, 1, pred_hidden),), + dynamic_shapes={"g": {}}, + strict=False, + ) + + return programs + + +def lower_to_executorch(programs, backend="portable"): + """Lower all ExportedPrograms to ExecuTorch.""" + partitioner = None + + if backend == "xnnpack": + from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner + + print("\nLowering to ExecuTorch with XNNPACK...") + partitioner = [XnnpackPartitioner()] + + elif backend in ("cuda", "cuda-windows"): + from torch._inductor.decomposition import conv1d_to_conv2d + + from executorch.backends.cuda.cuda_backend import CudaBackend + from executorch.backends.cuda.cuda_partitioner import CudaPartitioner + from executorch.exir.backend.compile_spec_schema import CompileSpec + + print(f"\nLowering to ExecuTorch with CUDA{' (Windows)' if backend == 'cuda-windows' else ''}...") + + # Decompose conv1d to conv2d for Triton kernel generation + for key, ep in programs.items(): + programs[key] = ep.run_decompositions({torch.ops.aten.conv1d.default: conv1d_to_conv2d}) + + partitioner = {} + for key in programs.keys(): + compile_specs = [CudaBackend.generate_method_name_compile_spec(key)] + if backend == "cuda-windows": + compile_specs.append(CompileSpec("platform", "windows".encode("utf-8"))) + partitioner[key] = [CudaPartitioner(compile_specs)] + + else: + print("\nLowering to ExecuTorch") + partitioner = [] + + et_prog = to_edge_transform_and_lower( + programs, + partitioner=partitioner, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + ) + return et_prog.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ), + ) + + +def main(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", default="./parakeet_tdt_exports") + parser.add_argument("--audio", type=str, help="Path to audio file for transcription test") + parser.add_argument( + "--backend", + choices=["portable", "xnnpack", "cuda", "cuda-windows"], + default="portable", + help="Backend for acceleration", + ) + args = parser.parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + + print("Loading model...") + model = load_model() + + print("\nExporting components...") + programs = export_all(model) + + et = lower_to_executorch(programs, backend=args.backend) + + if args.audio: + print("\n" + "=" * 60) + print("Testing transcription...") + print("=" * 60) + + print("\n[Eager PyTorch]") + eager_text = transcribe_eager(args.audio, model) + print(f" Result: {eager_text}") + + print("\n[ExecuTorch Runtime]") + et_text = transcribe_executorch(args.audio, model, et.buffer) + print(f" Result: {et_text}") + + if eager_text == et_text: + print("\n✓ Transcriptions match!") + else: + print("\n✗ Transcriptions differ!") + print(f" Eager: {eager_text}") + print(f" ET: {et_text}") + + print("\nDone!") + + +if __name__ == "__main__": + main() From ec962255352886e678b3413c854da97c4d46d215 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Fri, 19 Dec 2025 16:43:29 -0800 Subject: [PATCH 2/9] close to getting cpp working only the first few tokens are wrong --- examples/models/parakeet/CMakeLists.txt | 110 +++++ examples/models/parakeet/CMakePresets.json | 110 +++++ examples/models/parakeet/README.md | 52 ++- .../models/parakeet/export_parakeet_tdt.py | 135 +++++- examples/models/parakeet/main.cpp | 411 ++++++++++++++++++ 5 files changed, 815 insertions(+), 3 deletions(-) create mode 100644 examples/models/parakeet/CMakeLists.txt create mode 100644 examples/models/parakeet/CMakePresets.json create mode 100644 examples/models/parakeet/main.cpp diff --git a/examples/models/parakeet/CMakeLists.txt b/examples/models/parakeet/CMakeLists.txt new file mode 100644 index 00000000000..9ca451c095b --- /dev/null +++ b/examples/models/parakeet/CMakeLists.txt @@ -0,0 +1,110 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.24) +project(parakeet_runner) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) + +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) + +if(CMAKE_TOOLCHAIN_FILE MATCHES ".*(iOS|ios\.toolchain)\.cmake$") + set(CMAKE_TOOLCHAIN_IOS ON) +else() + set(CMAKE_TOOLCHAIN_IOS OFF) +endif() + +# Let files say "include " +set(_common_include_directories ${EXECUTORCH_ROOT}/..) + +# Need this for gflags +set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/gflags) +find_package(gflags REQUIRED) + +# Find executorch libraries +list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../..) +find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH) +executorch_target_link_options_shared_lib(executorch) + +set(link_libraries executorch gflags) + +# Common ops for all builds +list(APPEND link_libraries optimized_native_cpu_ops_lib cpublas eigen_blas) +executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib) + +# CPU-only builds need quantized and custom ops +if(NOT EXECUTORCH_BUILD_CUDA AND MSVC) + list(APPEND link_libraries quantized_ops_lib custom_ops) + executorch_target_link_options_shared_lib(quantized_ops_lib) + executorch_target_link_options_shared_lib(custom_ops) +endif() + +# XNNPACK +if(TARGET xnnpack_backend) + set(xnnpack_backend_libs xnnpack_backend XNNPACK xnnpack-microkernels-prod) + if(TARGET kleidiai) + list(APPEND xnnpack_backend_libs kleidiai) + endif() + list(APPEND link_libraries ${xnnpack_backend_libs}) + executorch_target_link_options_shared_lib(xnnpack_backend) +endif() + +# Needed for cpuinfo where it uses android specific log lib +if(ANDROID) + list(APPEND link_libraries log) +endif() + +# Add the required ExecuTorch extensions +list( + APPEND + link_libraries + extension_llm_runner + extension_module + extension_data_loader + extension_tensor + extension_flat_tensor + tokenizers::tokenizers +) + +# Link CUDA backend +if(EXECUTORCH_BUILD_CUDA) + find_package(CUDAToolkit REQUIRED) + list(APPEND link_libraries aoti_cuda_backend) + if(NOT MSVC) + executorch_target_link_options_shared_lib(aoti_cuda_backend) + endif() +endif() + +if(EXECUTORCH_BUILD_METAL) + list(APPEND link_libraries metal_backend) + executorch_target_link_options_shared_lib(metal_backend) +endif() + +add_executable(parakeet_runner main.cpp) +if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") + target_link_options_gc_sections(parakeet_runner) + if(NOT APPLE AND NOT MSVC) + target_link_options(parakeet_runner PRIVATE "LINKER:-s") + endif() +endif() + +target_include_directories(parakeet_runner PUBLIC ${_common_include_directories}) +target_link_libraries(parakeet_runner PUBLIC ${link_libraries}) +target_compile_options(parakeet_runner PUBLIC ${_common_compile_options}) + +# On Windows, copy required DLLs to the executable directory +if(MSVC AND EXECUTORCH_BUILD_CUDA) + add_custom_command( + TARGET parakeet_runner + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different $ + $ + COMMENT "Copying aoti_cuda_shims.dll to parakeet_runner directory" + ) +endif() diff --git a/examples/models/parakeet/CMakePresets.json b/examples/models/parakeet/CMakePresets.json new file mode 100644 index 00000000000..ea93d257ba7 --- /dev/null +++ b/examples/models/parakeet/CMakePresets.json @@ -0,0 +1,110 @@ +{ + "version": 6, + "configurePresets": [ + { + "name": "parakeet-base", + "hidden": true, + "binaryDir": "${sourceDir}/../../../cmake-out/examples/models/parakeet", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_FIND_ROOT_PATH": "${sourceDir}/../../../cmake-out", + "CMAKE_PREFIX_PATH": "${sourceDir}/../../../cmake-out" + } + }, + { + "name": "parakeet-cpu", + "displayName": "Parakeet runner (CPU)", + "inherits": ["parakeet-base"] + }, + { + "name": "parakeet-cuda", + "displayName": "Parakeet runner (CUDA)", + "inherits": ["parakeet-base"], + "cacheVariables": { + "EXECUTORCH_BUILD_CUDA": "ON" + }, + "condition": { + "type": "inList", + "string": "${hostSystemName}", + "list": ["Linux", "Windows"] + } + }, + { + "name": "parakeet-metal", + "displayName": "Parakeet runner (Metal)", + "inherits": ["parakeet-base"], + "cacheVariables": { + "EXECUTORCH_BUILD_METAL": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "type": "equals", + "rhs": "Darwin" + } + } + ], + "buildPresets": [ + { + "name": "parakeet-cpu", + "displayName": "Build Parakeet runner (CPU)", + "configurePreset": "parakeet-cpu", + "targets": ["parakeet_runner"] + }, + { + "name": "parakeet-cuda", + "displayName": "Build Parakeet runner (CUDA)", + "configurePreset": "parakeet-cuda", + "targets": ["parakeet_runner"] + }, + { + "name": "parakeet-metal", + "displayName": "Build Parakeet runner (Metal)", + "configurePreset": "parakeet-metal", + "targets": ["parakeet_runner"] + } + ], + "workflowPresets": [ + { + "name": "parakeet-cpu", + "displayName": "Configure and build Parakeet runner (CPU)", + "steps": [ + { + "type": "configure", + "name": "parakeet-cpu" + }, + { + "type": "build", + "name": "parakeet-cpu" + } + ] + }, + { + "name": "parakeet-cuda", + "displayName": "Configure and build Parakeet runner (CUDA)", + "steps": [ + { + "type": "configure", + "name": "parakeet-cuda" + }, + { + "type": "build", + "name": "parakeet-cuda" + } + ] + }, + { + "name": "parakeet-metal", + "displayName": "Configure and build Parakeet runner (Metal)", + "steps": [ + { + "type": "configure", + "name": "parakeet-metal" + }, + { + "type": "build", + "name": "parakeet-metal" + } + ] + } + ] +} diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md index 97611b10d95..a5b340c1dcf 100644 --- a/examples/models/parakeet/README.md +++ b/examples/models/parakeet/README.md @@ -8,7 +8,7 @@ Export [nvidia/parakeet-tdt-0.6b-v3](https://huggingface.co/nvidia/parakeet-tdt- pip install nemo_toolkit[asr] torchaudio ``` -## Usage +## Export Export the model (portable backend): ```bash @@ -27,10 +27,58 @@ Test transcription on an audio file: python export_parakeet_tdt.py --audio /path/to/audio.wav ``` -### Arguments +### Export Arguments | Argument | Description | |----------|-------------| | `--output-dir` | Output directory for exports (default: `./parakeet_tdt_exports`) | | `--backend` | Backend for acceleration: `portable`, `xnnpack`, `cuda`, `cuda-windows` (default: `portable`) | | `--audio` | Path to audio file for transcription test | + +## C++ Runner + +### Building + +First, build ExecuTorch with the LLM preset: + +```bash +cd executorch +cmake --workflow --preset llm-release +``` + +Then build the parakeet runner: + +```bash +cd examples/models/parakeet +cmake --workflow --preset parakeet-cpu +``` + +For Metal (macOS): +```bash +cd examples/models/parakeet +cmake --workflow --preset parakeet-metal +``` + +For CUDA (Linux/Windows): +```bash +cd examples/models/parakeet +cmake --workflow --preset parakeet-cuda +``` + +### Running + +```bash +./cmake-out/examples/models/parakeet/parakeet_runner \ + --model_path parakeet.pte \ + --processor_path preprocessor.pte \ + --audio_path audio.wav +``` + +### Runner Arguments + +| Argument | Description | +|----------|-------------| +| `--model_path` | Path to Parakeet model (.pte) | +| `--processor_path` | Path to preprocessor .pte for mel spectrogram extraction | +| `--audio_path` | Path to input audio file (.wav) | +| `--tokenizer_path` | Path to tokenizer file (for token-to-text conversion) | diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py index 8d51b6826fe..00a727cad06 100644 --- a/examples/models/parakeet/export_parakeet_tdt.py +++ b/examples/models/parakeet/export_parakeet_tdt.py @@ -106,6 +106,9 @@ def greedy_decode_executorch( t = 0 symbols_on_frame = 0 + # Debug: print first few tokens + debug_count = 0 + # Scan over the encoder output while t < encoder_len: f_t = f_proj[:, t : t + 1, :].contiguous() @@ -128,6 +131,9 @@ def greedy_decode_executorch( t += max(dur, 1) symbols_on_frame = 0 else: + if debug_count < 20: + print(f"Token[{debug_count}]: t={t} k={k} dur={dur}") + debug_count += 1 hypothesis.append(k) token = torch.tensor([[k]], dtype=torch.long) @@ -161,6 +167,7 @@ def transcribe_executorch(audio_path: str, model, et_buffer) -> str: audio = load_audio(audio_path) mel, mel_len = model.preprocessor(input_signal=audio, length=torch.tensor([audio.shape[1]])) + print(mel.shape) encoder_method = program.load_method("encoder") enc_result = encoder_method.execute([mel, mel_len]) @@ -230,7 +237,8 @@ def export_all(model): """Export all components, return dict of ExportedPrograms.""" programs = {} - feat_in = getattr(model.encoder, "_feat_in", 80) + feat_in = getattr(model.encoder, "_feat_in", 128) + print(f"Encoder feat_in: {feat_in}") audio_signal = torch.randn(1, feat_in, 100) length = torch.tensor([100], dtype=torch.int64) programs["encoder"] = export( @@ -332,6 +340,65 @@ def lower_to_executorch(programs, backend="portable"): ) +def export_preprocessor(model, output_dir: str, backend: str = "portable"): + """Export NeMo's preprocessor to ExecuTorch.""" + + class PreprocessorWrapper(torch.nn.Module): + def __init__(self, preprocessor): + super().__init__() + self.preprocessor = preprocessor + + def forward(self, audio: torch.Tensor) -> torch.Tensor: + # audio is 1D: [num_samples] + # Add batch dimension and compute length + audio_signal = audio.unsqueeze(0) # [1, num_samples] + length = torch.tensor([audio.shape[0]], dtype=torch.int64) + + mel, mel_len = self.preprocessor(input_signal=audio_signal, length=length) + return mel + + wrapper = PreprocessorWrapper(model.preprocessor) + wrapper.eval() + + # Export with dynamic audio length + sample_audio = torch.randn(16000 * 10) # 10 seconds + + preprocessor_ep = export( + wrapper, + (sample_audio,), + dynamic_shapes={"audio": {0: Dim("audio_len", min=1600, max=16000 * 600)}}, + strict=False, + ) + + partitioner = [] + if backend == "xnnpack": + from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner + partitioner = [XnnpackPartitioner()] + + et_prog = to_edge_transform_and_lower( + {"forward": preprocessor_ep}, + partitioner=partitioner, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + ) + + et_preprocessor = et_prog.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ), + ) + + pte_path = os.path.join(output_dir, "parakeet_preprocessor.pte") + print(f"Saving preprocessor to: {pte_path}") + with open(pte_path, "wb") as f: + et_preprocessor.write_to_file(f) + + return pte_path + + def main(): import argparse @@ -344,6 +411,16 @@ def main(): default="portable", help="Backend for acceleration", ) + parser.add_argument( + "--export-preprocessor", + action="store_true", + help="Export NeMo's preprocessor to ExecuTorch", + ) + parser.add_argument( + "--test-preprocessor", + type=str, + help="Test exported preprocessor against NeMo's native preprocessor", + ) args = parser.parse_args() os.makedirs(args.output_dir, exist_ok=True) @@ -351,11 +428,67 @@ def main(): print("Loading model...") model = load_model() + if args.test_preprocessor: + print("\nTesting preprocessor...") + from executorch.runtime import Runtime + + audio = load_audio(args.test_preprocessor) + audio_1d = audio.squeeze(0) # [num_samples] + + print(f"Python audio shape: {audio.shape}, first 5 samples: {audio[0, :5].tolist()}") + + # NeMo's native preprocessor + mel_native, mel_len_native = model.preprocessor( + input_signal=audio, + length=torch.tensor([audio.shape[1]]) + ) + print(f"NeMo mel shape: {mel_native.shape}, mel_len: {mel_len_native.item()}") + + # Exported preprocessor + pte_path = os.path.join(args.output_dir, "parakeet_preprocessor.pte") + with open(pte_path, "rb") as f: + runtime = Runtime.get() + program = runtime.load_program(f.read()) + method = program.load_method("forward") + mel_exported = method.execute([audio_1d])[0] + print(f"Exported mel shape: {mel_exported.shape}") + + # Compare + mel_native_np = mel_native.numpy() + mel_exported_np = mel_exported.numpy() + + max_diff = abs(mel_native_np - mel_exported_np).max() + mean_diff = abs(mel_native_np - mel_exported_np).mean() + print(f"Max diff: {max_diff:.6f}, Mean diff: {mean_diff:.6f}") + + if max_diff < 1e-4: + print("✓ Preprocessors match!") + else: + print("✗ Preprocessors differ!") + # Print first few values + print(f"Native [0,0,:5]: {mel_native_np[0,0,:5]}") + print(f"Exported [0,0,:5]: {mel_exported_np[0,0,:5]}") + return + + if args.export_preprocessor: + print("\nExporting preprocessor...") + export_preprocessor(model, args.output_dir, args.backend) + print("Preprocessor exported!") + if not args.audio: + return + print("\nExporting components...") programs = export_all(model) et = lower_to_executorch(programs, backend=args.backend) + # Save the .pte file + pte_path = os.path.join(args.output_dir, "parakeet_tdt.pte") + print(f"\nSaving ExecuTorch program to: {pte_path}") + with open(pte_path, "wb") as f: + et.write_to_file(f) + print(f"Saved {os.path.getsize(pte_path) / (1024 * 1024):.1f} MB") + if args.audio: print("\n" + "=" * 60) print("Testing transcription...") diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp new file mode 100644 index 00000000000..cdef1521cf8 --- /dev/null +++ b/examples/models/parakeet/main.cpp @@ -0,0 +1,411 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include + +DEFINE_string(model_path, "parakeet.pte", "Path to Parakeet model (.pte)."); +DEFINE_string( + processor_path, + "", + "Path to preprocessor .pte for converting raw audio to mel spectrogram."); +DEFINE_string(audio_path, "", "Path to input audio file (.wav)."); +DEFINE_string( + tokenizer_path, + "tokenizer.json", + "Path to SentencePiece tokenizer model file."); + +using ::executorch::extension::from_blob; +using ::executorch::extension::Module; +using ::executorch::runtime::EValue; +using ::executorch::runtime::Error; + +namespace { + +// TDT duration values +const std::vector DURATIONS = {0, 1, 2, 3, 4}; + +std::vector greedy_decode_executorch( + Module& model, + const ::executorch::aten::Tensor& encoder_output, + int64_t encoder_len, + int64_t blank_id, + int64_t vocab_size, + int64_t num_rnn_layers = 2, + int64_t pred_hidden = 640, + int64_t max_symbols_per_step = 10) { + std::vector hypothesis; + int64_t num_token_classes = vocab_size + 1; + + // Transpose encoder output from [1, enc_dim, time] to [1, time, enc_dim] + // The encoder output shape is [1, 1024, T], we need [1, T, 1024] + auto enc_sizes = encoder_output.sizes(); + int64_t batch = enc_sizes[0]; + int64_t enc_dim = enc_sizes[1]; + int64_t time_steps = enc_sizes[2]; + + // Create transposed tensor + std::vector transposed_data(batch * time_steps * enc_dim); + const float* src = encoder_output.const_data_ptr(); + for (int64_t t = 0; t < time_steps; t++) { + for (int64_t d = 0; d < enc_dim; d++) { + transposed_data[t * enc_dim + d] = src[d * time_steps + t]; + } + } + + auto transposed_tensor = from_blob( + transposed_data.data(), + {static_cast<::executorch::aten::SizesType>(batch), + static_cast<::executorch::aten::SizesType>(time_steps), + static_cast<::executorch::aten::SizesType>(enc_dim)}, + ::executorch::aten::ScalarType::Float); + + // Project encoder output + auto proj_enc_result = model.execute( + "joint_project_encoder", + std::vector<::executorch::runtime::EValue>{transposed_tensor}); + if (!proj_enc_result.ok()) { + ET_LOG(Error, "joint_project_encoder failed"); + return hypothesis; + } + auto f_proj = proj_enc_result.get()[0].toTensor(); + + // Initialize LSTM state + std::vector h_data(num_rnn_layers * 1 * pred_hidden, 0.0f); + std::vector c_data(num_rnn_layers * 1 * pred_hidden, 0.0f); + + auto h = from_blob( + h_data.data(), + {static_cast<::executorch::aten::SizesType>(num_rnn_layers), + 1, + static_cast<::executorch::aten::SizesType>(pred_hidden)}, + ::executorch::aten::ScalarType::Float); + auto c = from_blob( + c_data.data(), + {static_cast<::executorch::aten::SizesType>(num_rnn_layers), + 1, + static_cast<::executorch::aten::SizesType>(pred_hidden)}, + ::executorch::aten::ScalarType::Float); + + // Initialize decoder with SOS (zeros) + std::vector sos_g_data(1 * 1 * pred_hidden, 0.0f); + auto sos_g = from_blob( + sos_g_data.data(), + {1, 1, static_cast<::executorch::aten::SizesType>(pred_hidden)}, + ::executorch::aten::ScalarType::Float); + + auto g_proj_result = model.execute( + "joint_project_decoder", + std::vector<::executorch::runtime::EValue>{sos_g}); + if (!g_proj_result.ok()) { + ET_LOG(Error, "joint_project_decoder failed"); + return hypothesis; + } + auto g_proj_tensor = g_proj_result.get()[0].toTensor(); + + // Copy g_proj data for reuse + std::vector g_proj_data( + g_proj_tensor.const_data_ptr(), + g_proj_tensor.const_data_ptr() + g_proj_tensor.numel()); + + int64_t t = 0; + int64_t symbols_on_frame = 0; + + // Debug: print first few tokens + bool debug = true; + int debug_count = 0; + + // Scan over encoder output + while (t < encoder_len) { + // Get encoder frame at time t: f_proj[:, t:t+1, :] + const float* f_proj_data = f_proj.const_data_ptr(); + int64_t proj_dim = f_proj.sizes()[2]; + + std::vector f_t_data(1 * 1 * proj_dim); + for (int64_t d = 0; d < proj_dim; d++) { + f_t_data[d] = f_proj_data[t * proj_dim + d]; + } + auto f_t = from_blob( + f_t_data.data(), + {1, 1, static_cast<::executorch::aten::SizesType>(proj_dim)}, + ::executorch::aten::ScalarType::Float); + + auto g_proj = from_blob( + g_proj_data.data(), + {1, 1, static_cast<::executorch::aten::SizesType>(proj_dim)}, + ::executorch::aten::ScalarType::Float); + + // Joint network + auto joint_result = model.execute( + "joint", std::vector<::executorch::runtime::EValue>{f_t, g_proj}); + if (!joint_result.ok()) { + ET_LOG(Error, "joint failed at t=%ld", t); + return hypothesis; + } + auto full_logits = joint_result.get()[0].toTensor(); + + // Split logits into token and duration + const float* logits_data = full_logits.const_data_ptr(); + + // Find argmax for token logits + int64_t k = 0; + float max_token_logit = logits_data[0]; + for (int64_t i = 1; i < num_token_classes; i++) { + if (logits_data[i] > max_token_logit) { + max_token_logit = logits_data[i]; + k = i; + } + } + + // Find argmax for duration logits + int64_t dur_idx = 0; + float max_dur_logit = logits_data[num_token_classes]; + for (size_t i = 1; i < DURATIONS.size(); i++) { + if (logits_data[num_token_classes + i] > max_dur_logit) { + max_dur_logit = logits_data[num_token_classes + i]; + dur_idx = i; + } + } + int64_t dur = DURATIONS[dur_idx]; + + // TDT decoding: joint network outputs both token logits and duration + // logits. + // - If blank: skip forward by predicted duration (min 1 frame) + // - If token: emit it, update decoder state, advance by duration. + // Duration=0 means "emit another token on this frame" (up to + // max_symbols_per_step). + if (k == blank_id) { + t += std::max(dur, (int64_t)1); + symbols_on_frame = 0; + } else { + if (debug && debug_count < 20) { + ET_LOG(Info, "Token[%d]: t=%ld k=%ld dur=%ld", debug_count, t, k, dur); + debug_count++; + } + hypothesis.push_back(k); + + // Update decoder state + std::vector token_data = {k}; + auto token = from_blob( + token_data.data(), {1, 1}, ::executorch::aten::ScalarType::Long); + + auto decoder_result = model.execute( + "decoder_predict", + std::vector<::executorch::runtime::EValue>{token, h, c}); + if (!decoder_result.ok()) { + ET_LOG(Error, "decoder_predict failed"); + return hypothesis; + } + auto& outputs = decoder_result.get(); + auto g = outputs[0].toTensor(); + auto new_h = outputs[1].toTensor(); + auto new_c = outputs[2].toTensor(); + + // Update h and c + std::memcpy( + h_data.data(), + new_h.const_data_ptr(), + h_data.size() * sizeof(float)); + std::memcpy( + c_data.data(), + new_c.const_data_ptr(), + c_data.size() * sizeof(float)); + + // Project decoder output + auto proj_dec_result = model.execute( + "joint_project_decoder", + std::vector<::executorch::runtime::EValue>{g}); + if (!proj_dec_result.ok()) { + ET_LOG(Error, "joint_project_decoder failed"); + return hypothesis; + } + auto new_g_proj = proj_dec_result.get()[0].toTensor(); + std::memcpy( + g_proj_data.data(), + new_g_proj.const_data_ptr(), + g_proj_data.size() * sizeof(float)); + + t += dur; + + if (dur == 0) { + symbols_on_frame++; + if (symbols_on_frame >= max_symbols_per_step) { + t++; + symbols_on_frame = 0; + } + } else { + symbols_on_frame = 0; + } + } + } + + return hypothesis; +} + +std::string tokens_to_text( + const std::vector& tokens, + tokenizers::Tokenizer* tokenizer) { + // Decode tokens to text one by one + std::string result; + uint64_t prev_token = 0; + for (size_t i = 0; i < tokens.size(); i++) { + uint64_t token = static_cast(tokens[i]); + auto decode_result = tokenizer->decode(prev_token, token); + if (decode_result.ok()) { + result += decode_result.get(); + } + prev_token = token; + } + + return result; +} + +} // namespace + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + if (FLAGS_audio_path.empty()) { + ET_LOG(Error, "audio_path flag must be provided."); + return 1; + } + + if (FLAGS_processor_path.empty()) { + ET_LOG(Error, "processor_path flag must be provided."); + return 1; + } + + // Load audio + ET_LOG(Info, "Loading audio from: %s", FLAGS_audio_path.c_str()); + std::vector audio_data = + ::executorch::extension::llm::load_wav_audio_data(FLAGS_audio_path); + ET_LOG(Info, "Loaded %zu audio samples", audio_data.size()); + ET_LOG( + Info, + "First 5 audio samples: %f, %f, %f, %f, %f", + audio_data[0], + audio_data[1], + audio_data[2], + audio_data[3], + audio_data[4]); + + // Load preprocessor + ET_LOG(Info, "Loading preprocessor from: %s", FLAGS_processor_path.c_str()); + Module processor(FLAGS_processor_path, Module::LoadMode::Mmap); + auto proc_load_error = processor.load(); + if (proc_load_error != Error::Ok) { + ET_LOG(Error, "Failed to load preprocessor module."); + return 1; + } + + // Process audio to mel spectrogram + auto audio_tensor = from_blob( + audio_data.data(), + {static_cast<::executorch::aten::SizesType>(audio_data.size())}, + ::executorch::aten::ScalarType::Float); + + auto proc_result = processor.execute( + "forward", std::vector<::executorch::runtime::EValue>{audio_tensor}); + if (!proc_result.ok()) { + ET_LOG(Error, "Preprocessor forward failed."); + return 1; + } + auto& proc_outputs = proc_result.get(); + auto mel = proc_outputs[0].toTensor(); + + // Compute mel_len from tensor shape + std::vector mel_len_data = { + static_cast(mel.sizes()[2])}; + auto mel_len = from_blob( + mel_len_data.data(), {1}, ::executorch::aten::ScalarType::Long); + + ET_LOG( + Info, + "Mel spectrogram shape: [%ld, %ld, %ld]", + static_cast(mel.sizes()[0]), + static_cast(mel.sizes()[1]), + static_cast(mel.sizes()[2])); + + // Load model + ET_LOG(Info, "Loading model from: %s", FLAGS_model_path.c_str()); + Module model(FLAGS_model_path, Module::LoadMode::Mmap); + auto model_load_error = model.load(); + if (model_load_error != Error::Ok) { + ET_LOG(Error, "Failed to load model."); + return 1; + } + + // Run encoder + ET_LOG(Info, "Running encoder..."); + auto enc_result = model.execute( + "encoder", std::vector<::executorch::runtime::EValue>{mel, mel_len}); + if (!enc_result.ok()) { + ET_LOG(Error, "Encoder forward failed."); + return 1; + } + auto& enc_outputs = enc_result.get(); + auto encoded = enc_outputs[0].toTensor(); + int64_t encoded_len = enc_outputs[1].toTensor().const_data_ptr()[0]; + + ET_LOG( + Info, + "Encoder output shape: [%ld, %ld, %ld], len=%ld", + static_cast(encoded.sizes()[0]), + static_cast(encoded.sizes()[1]), + static_cast(encoded.sizes()[2]), + static_cast(encoded_len)); + + // Greedy decode + // Parakeet TDT uses vocab_size=8192, blank_id=8192 + int64_t vocab_size = 8192; + int64_t blank_id = vocab_size; + int64_t num_rnn_layers = 2; + int64_t pred_hidden = 640; + + ET_LOG(Info, "Running TDT greedy decode..."); + auto tokens = greedy_decode_executorch( + model, + encoded, + encoded_len, + blank_id, + vocab_size, + num_rnn_layers, + pred_hidden); + + ET_LOG(Info, "Decoded %zu tokens", tokens.size()); + + // Load tokenizer using the LLM runner helper + ET_LOG(Info, "Loading tokenizer from: %s", FLAGS_tokenizer_path.c_str()); + auto tokenizer = + ::executorch::extension::llm::load_tokenizer(FLAGS_tokenizer_path); + if (!tokenizer || !tokenizer->is_loaded()) { + ET_LOG(Error, "Failed to load tokenizer from: %s", FLAGS_tokenizer_path.c_str()); + return 1; + } + + // Convert tokens to text + std::string text = tokens_to_text(tokens, tokenizer.get()); + std::cout << "Transcription tokens: " << text << std::endl; + + ET_LOG(Info, "Done!"); + return 0; +} From 47ae3e0957fa14beea2bae6b45ef5770690a27ac Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Mon, 22 Dec 2025 12:17:13 -0800 Subject: [PATCH 3/9] clean up and fixes --- examples/models/parakeet/README.md | 23 +- .../models/parakeet/export_parakeet_tdt.py | 214 +++++------------- examples/models/parakeet/main.cpp | 131 +++++------ 3 files changed, 122 insertions(+), 246 deletions(-) diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md index a5b340c1dcf..e0debb9d5f2 100644 --- a/examples/models/parakeet/README.md +++ b/examples/models/parakeet/README.md @@ -15,14 +15,7 @@ Export the model (portable backend): python export_parakeet_tdt.py ``` -Export with a specific backend: -```bash -python export_parakeet_tdt.py --backend xnnpack # CPU acceleration -python export_parakeet_tdt.py --backend cuda # CUDA acceleration -python export_parakeet_tdt.py --backend cuda-windows # CUDA on Windows -``` - -Test transcription on an audio file: +Test transcription on an audio file and compare eager vs lowered results: ```bash python export_parakeet_tdt.py --audio /path/to/audio.wav ``` @@ -53,24 +46,11 @@ cd examples/models/parakeet cmake --workflow --preset parakeet-cpu ``` -For Metal (macOS): -```bash -cd examples/models/parakeet -cmake --workflow --preset parakeet-metal -``` - -For CUDA (Linux/Windows): -```bash -cd examples/models/parakeet -cmake --workflow --preset parakeet-cuda -``` - ### Running ```bash ./cmake-out/examples/models/parakeet/parakeet_runner \ --model_path parakeet.pte \ - --processor_path preprocessor.pte \ --audio_path audio.wav ``` @@ -79,6 +59,5 @@ cmake --workflow --preset parakeet-cuda | Argument | Description | |----------|-------------| | `--model_path` | Path to Parakeet model (.pte) | -| `--processor_path` | Path to preprocessor .pte for mel spectrogram extraction | | `--audio_path` | Path to input audio file (.wav) | | `--tokenizer_path` | Path to tokenizer file (for token-to-text conversion) | diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py index 00a727cad06..56a9822cda2 100644 --- a/examples/models/parakeet/export_parakeet_tdt.py +++ b/examples/models/parakeet/export_parakeet_tdt.py @@ -45,7 +45,6 @@ def load_audio(audio_path: str, sample_rate: int = 16000) -> torch.Tensor: def greedy_decode_eager(encoder_output: torch.Tensor, encoder_len: torch.Tensor, model) -> list[int]: - """Greedy decode using NeMo's built-in decoding.""" hypotheses = model.decoding.rnnt_decoder_predictions_tensor( encoder_output=encoder_output, encoded_lengths=encoder_len, @@ -55,8 +54,6 @@ def greedy_decode_eager(encoder_output: torch.Tensor, encoder_len: torch.Tensor, class DecoderPredict(torch.nn.Module): - """Wrapper for decoder.predict() with LSTM state.""" - def __init__(self, decoder): super().__init__() self.decoder = decoder @@ -81,7 +78,6 @@ def greedy_decode_executorch( max_symbols_per_step: int = 10, durations: list[int] | None = None, ) -> list[int]: - """TDT duration-aware greedy decode using ExecuTorch runtime.""" if durations is None: durations = [0, 1, 2, 3, 4] @@ -106,10 +102,6 @@ def greedy_decode_executorch( t = 0 symbols_on_frame = 0 - # Debug: print first few tokens - debug_count = 0 - - # Scan over the encoder output while t < encoder_len: f_t = f_proj[:, t : t + 1, :].contiguous() @@ -123,17 +115,10 @@ def greedy_decode_executorch( dur_idx = duration_logits.argmax().item() dur = durations[dur_idx] - # TDT decoding: joint network outputs both token logits and duration logits. - # - If blank: skip forward by predicted duration (min 1 frame) - # - If token: emit it, update decoder state, advance by duration. - # Duration=0 means "emit another token on this frame" (up to max_symbols_per_step). if k == blank_id: t += max(dur, 1) symbols_on_frame = 0 else: - if debug_count < 20: - print(f"Token[{debug_count}]: t={t} k={k} dur={dur}") - debug_count += 1 hypothesis.append(k) token = torch.tensor([[k]], dtype=torch.long) @@ -157,20 +142,26 @@ def greedy_decode_executorch( def transcribe_executorch(audio_path: str, model, et_buffer) -> str: - """Transcribe audio file using ExecuTorch runtime.""" from executorch.runtime import Runtime runtime = Runtime.get() program = runtime.load_program(et_buffer) - with torch.no_grad(): - audio = load_audio(audio_path) + # Get sample rate from model + sample_rate = model.preprocessor._cfg.sample_rate - mel, mel_len = model.preprocessor(input_signal=audio, length=torch.tensor([audio.shape[1]])) - print(mel.shape) + with torch.no_grad(): + audio = load_audio(audio_path, sample_rate=sample_rate) + preprocessor_method = program.load_method("preprocessor") + audio_1d = audio.squeeze(0) + audio_len = torch.tensor([audio_1d.shape[0]], dtype=torch.int64) + proc_result = preprocessor_method.execute([audio_1d, audio_len]) + mel = proc_result[0] + mel_len = proc_result[1].item() encoder_method = program.load_method("encoder") - enc_result = encoder_method.execute([mel, mel_len]) + mel_len_tensor = torch.tensor([mel_len], dtype=torch.int64) + enc_result = encoder_method.execute([mel, mel_len_tensor]) encoded = enc_result[0] encoded_len = enc_result[1].item() @@ -189,7 +180,6 @@ def transcribe_executorch(audio_path: str, model, et_buffer) -> str: def transcribe_eager(audio_path: str, model) -> str: - """Transcribe audio file using eager PyTorch model.""" with torch.no_grad(): audio = load_audio(audio_path) mel, mel_len = model.preprocessor(input_signal=audio, length=torch.tensor([audio.shape[1]])) @@ -233,12 +223,32 @@ def forward(self, g): return self.joint.project_prednet(g) +class PreprocessorWrapper(torch.nn.Module): + def __init__(self, preprocessor): + super().__init__() + self.preprocessor = preprocessor + + def forward(self, audio: torch.Tensor, length: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + audio_signal = audio.unsqueeze(0) + mel, mel_len = self.preprocessor(input_signal=audio_signal, length=length) + return mel, mel_len + + def export_all(model): - """Export all components, return dict of ExportedPrograms.""" programs = {} + preprocessor_wrapper = PreprocessorWrapper(model.preprocessor) + preprocessor_wrapper.eval() + sample_audio = torch.randn(16000 * 10) + sample_length = torch.tensor([sample_audio.shape[0]], dtype=torch.int64) + programs["preprocessor"] = export( + preprocessor_wrapper, + (sample_audio, sample_length), + dynamic_shapes={"audio": {0: Dim("audio_len", min=1600, max=16000 * 600)}, "length": {}}, + strict=False, + ) + feat_in = getattr(model.encoder, "_feat_in", 128) - print(f"Encoder feat_in: {feat_in}") audio_signal = torch.randn(1, feat_in, 100) length = torch.tensor([100], dtype=torch.int64) programs["encoder"] = export( @@ -263,8 +273,10 @@ def export_all(model): strict=False, ) - f_proj = torch.randn(1, 1, 640) - g_proj = torch.randn(1, 1, 640) + joint_hidden = model.joint.joint_hidden + + f_proj = torch.randn(1, 1, joint_hidden) + g_proj = torch.randn(1, 1, joint_hidden) programs["joint"] = export( JointAfterProjection(model.joint), (f_proj, g_proj), @@ -272,14 +284,15 @@ def export_all(model): strict=False, ) + enc_output_dim = getattr(model.encoder, "_feat_out", 1024) + programs["joint_project_encoder"] = export( JointProjectEncoder(model.joint), - (torch.randn(1, 25, 1024),), + (torch.randn(1, 25, enc_output_dim),), dynamic_shapes={"f": {1: Dim("enc_time", min=1, max=60000)}}, strict=False, ) - pred_hidden = getattr(model.decoder, "pred_hidden", 640) programs["joint_project_decoder"] = export( JointProjectDecoder(model.joint), (torch.randn(1, 1, pred_hidden),), @@ -287,11 +300,20 @@ def export_all(model): strict=False, ) - return programs + sample_rate = model.preprocessor._cfg.sample_rate + metadata = { + "num_rnn_layers": num_layers, + "pred_hidden": pred_hidden, + "joint_hidden": joint_hidden, + "vocab_size": model.tokenizer.vocab_size, + "blank_id": model.tokenizer.vocab_size, + "sample_rate": sample_rate, + } + + return programs, metadata -def lower_to_executorch(programs, backend="portable"): - """Lower all ExportedPrograms to ExecuTorch.""" +def lower_to_executorch(programs, metadata=None, backend="portable"): partitioner = None if backend == "xnnpack": @@ -309,7 +331,6 @@ def lower_to_executorch(programs, backend="portable"): print(f"\nLowering to ExecuTorch with CUDA{' (Windows)' if backend == 'cuda-windows' else ''}...") - # Decompose conv1d to conv2d for Triton kernel generation for key, ep in programs.items(): programs[key] = ep.run_decompositions({torch.ops.aten.conv1d.default: conv1d_to_conv2d}) @@ -324,6 +345,11 @@ def lower_to_executorch(programs, backend="portable"): print("\nLowering to ExecuTorch") partitioner = [] + constant_methods = {} + if metadata: + for key, value in metadata.items(): + constant_methods[key] = value + et_prog = to_edge_transform_and_lower( programs, partitioner=partitioner, @@ -331,6 +357,7 @@ def lower_to_executorch(programs, backend="portable"): _check_ir_validity=False, _skip_dim_order=True, ), + constant_methods=constant_methods if constant_methods else None, ) return et_prog.to_executorch( config=ExecutorchBackendConfig( @@ -340,65 +367,6 @@ def lower_to_executorch(programs, backend="portable"): ) -def export_preprocessor(model, output_dir: str, backend: str = "portable"): - """Export NeMo's preprocessor to ExecuTorch.""" - - class PreprocessorWrapper(torch.nn.Module): - def __init__(self, preprocessor): - super().__init__() - self.preprocessor = preprocessor - - def forward(self, audio: torch.Tensor) -> torch.Tensor: - # audio is 1D: [num_samples] - # Add batch dimension and compute length - audio_signal = audio.unsqueeze(0) # [1, num_samples] - length = torch.tensor([audio.shape[0]], dtype=torch.int64) - - mel, mel_len = self.preprocessor(input_signal=audio_signal, length=length) - return mel - - wrapper = PreprocessorWrapper(model.preprocessor) - wrapper.eval() - - # Export with dynamic audio length - sample_audio = torch.randn(16000 * 10) # 10 seconds - - preprocessor_ep = export( - wrapper, - (sample_audio,), - dynamic_shapes={"audio": {0: Dim("audio_len", min=1600, max=16000 * 600)}}, - strict=False, - ) - - partitioner = [] - if backend == "xnnpack": - from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner - partitioner = [XnnpackPartitioner()] - - et_prog = to_edge_transform_and_lower( - {"forward": preprocessor_ep}, - partitioner=partitioner, - compile_config=EdgeCompileConfig( - _check_ir_validity=False, - _skip_dim_order=True, - ), - ) - - et_preprocessor = et_prog.to_executorch( - config=ExecutorchBackendConfig( - extract_delegate_segments=True, - memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), - ), - ) - - pte_path = os.path.join(output_dir, "parakeet_preprocessor.pte") - print(f"Saving preprocessor to: {pte_path}") - with open(pte_path, "wb") as f: - et_preprocessor.write_to_file(f) - - return pte_path - - def main(): import argparse @@ -411,16 +379,6 @@ def main(): default="portable", help="Backend for acceleration", ) - parser.add_argument( - "--export-preprocessor", - action="store_true", - help="Export NeMo's preprocessor to ExecuTorch", - ) - parser.add_argument( - "--test-preprocessor", - type=str, - help="Test exported preprocessor against NeMo's native preprocessor", - ) args = parser.parse_args() os.makedirs(args.output_dir, exist_ok=True) @@ -428,61 +386,11 @@ def main(): print("Loading model...") model = load_model() - if args.test_preprocessor: - print("\nTesting preprocessor...") - from executorch.runtime import Runtime - - audio = load_audio(args.test_preprocessor) - audio_1d = audio.squeeze(0) # [num_samples] - - print(f"Python audio shape: {audio.shape}, first 5 samples: {audio[0, :5].tolist()}") - - # NeMo's native preprocessor - mel_native, mel_len_native = model.preprocessor( - input_signal=audio, - length=torch.tensor([audio.shape[1]]) - ) - print(f"NeMo mel shape: {mel_native.shape}, mel_len: {mel_len_native.item()}") - - # Exported preprocessor - pte_path = os.path.join(args.output_dir, "parakeet_preprocessor.pte") - with open(pte_path, "rb") as f: - runtime = Runtime.get() - program = runtime.load_program(f.read()) - method = program.load_method("forward") - mel_exported = method.execute([audio_1d])[0] - print(f"Exported mel shape: {mel_exported.shape}") - - # Compare - mel_native_np = mel_native.numpy() - mel_exported_np = mel_exported.numpy() - - max_diff = abs(mel_native_np - mel_exported_np).max() - mean_diff = abs(mel_native_np - mel_exported_np).mean() - print(f"Max diff: {max_diff:.6f}, Mean diff: {mean_diff:.6f}") - - if max_diff < 1e-4: - print("✓ Preprocessors match!") - else: - print("✗ Preprocessors differ!") - # Print first few values - print(f"Native [0,0,:5]: {mel_native_np[0,0,:5]}") - print(f"Exported [0,0,:5]: {mel_exported_np[0,0,:5]}") - return - - if args.export_preprocessor: - print("\nExporting preprocessor...") - export_preprocessor(model, args.output_dir, args.backend) - print("Preprocessor exported!") - if not args.audio: - return - print("\nExporting components...") - programs = export_all(model) + programs, metadata = export_all(model) - et = lower_to_executorch(programs, backend=args.backend) + et = lower_to_executorch(programs, metadata=metadata, backend=args.backend) - # Save the .pte file pte_path = os.path.join(args.output_dir, "parakeet_tdt.pte") print(f"\nSaving ExecuTorch program to: {pte_path}") with open(pte_path, "wb") as f: diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp index cdef1521cf8..f99a55422e1 100644 --- a/examples/models/parakeet/main.cpp +++ b/examples/models/parakeet/main.cpp @@ -24,10 +24,6 @@ #include DEFINE_string(model_path, "parakeet.pte", "Path to Parakeet model (.pte)."); -DEFINE_string( - processor_path, - "", - "Path to preprocessor .pte for converting raw audio to mel spectrogram."); DEFINE_string(audio_path, "", "Path to input audio file (.wav)."); DEFINE_string( tokenizer_path, @@ -57,7 +53,6 @@ std::vector greedy_decode_executorch( int64_t num_token_classes = vocab_size + 1; // Transpose encoder output from [1, enc_dim, time] to [1, time, enc_dim] - // The encoder output shape is [1, 1024, T], we need [1, T, 1024] auto enc_sizes = encoder_output.sizes(); int64_t batch = enc_sizes[0]; int64_t enc_dim = enc_sizes[1]; @@ -72,12 +67,12 @@ std::vector greedy_decode_executorch( } } - auto transposed_tensor = from_blob( - transposed_data.data(), - {static_cast<::executorch::aten::SizesType>(batch), - static_cast<::executorch::aten::SizesType>(time_steps), - static_cast<::executorch::aten::SizesType>(enc_dim)}, - ::executorch::aten::ScalarType::Float); + auto transposed_tensor = from_blob( + transposed_data.data(), + {static_cast<::executorch::aten::SizesType>(batch), + static_cast<::executorch::aten::SizesType>(time_steps), + static_cast<::executorch::aten::SizesType>(enc_dim)}, + ::executorch::aten::ScalarType::Float); // Project encoder output auto proj_enc_result = model.execute( @@ -106,7 +101,7 @@ std::vector greedy_decode_executorch( static_cast<::executorch::aten::SizesType>(pred_hidden)}, ::executorch::aten::ScalarType::Float); - // Initialize decoder with SOS (zeros) + // Initialize decoder state with zeros std::vector sos_g_data(1 * 1 * pred_hidden, 0.0f); auto sos_g = from_blob( sos_g_data.data(), @@ -130,10 +125,6 @@ std::vector greedy_decode_executorch( int64_t t = 0; int64_t symbols_on_frame = 0; - // Debug: print first few tokens - bool debug = true; - int debug_count = 0; - // Scan over encoder output while (t < encoder_len) { // Get encoder frame at time t: f_proj[:, t:t+1, :] @@ -158,7 +149,7 @@ std::vector greedy_decode_executorch( auto joint_result = model.execute( "joint", std::vector<::executorch::runtime::EValue>{f_t, g_proj}); if (!joint_result.ok()) { - ET_LOG(Error, "joint failed at t=%ld", t); + ET_LOG(Error, "joint failed at t=%lld", static_cast(t)); return hypothesis; } auto full_logits = joint_result.get()[0].toTensor(); @@ -187,20 +178,10 @@ std::vector greedy_decode_executorch( } int64_t dur = DURATIONS[dur_idx]; - // TDT decoding: joint network outputs both token logits and duration - // logits. - // - If blank: skip forward by predicted duration (min 1 frame) - // - If token: emit it, update decoder state, advance by duration. - // Duration=0 means "emit another token on this frame" (up to - // max_symbols_per_step). if (k == blank_id) { t += std::max(dur, (int64_t)1); symbols_on_frame = 0; } else { - if (debug && debug_count < 20) { - ET_LOG(Info, "Token[%d]: t=%ld k=%ld dur=%ld", debug_count, t, k, dur); - debug_count++; - } hypothesis.push_back(k); // Update decoder state @@ -289,8 +270,12 @@ int main(int argc, char** argv) { return 1; } - if (FLAGS_processor_path.empty()) { - ET_LOG(Error, "processor_path flag must be provided."); + // Load model (which includes the bundled preprocessor) + ET_LOG(Info, "Loading model from: %s", FLAGS_model_path.c_str()); + Module model(FLAGS_model_path, Module::LoadMode::Mmap); + auto model_load_error = model.load(); + if (model_load_error != Error::Ok) { + ET_LOG(Error, "Failed to load model."); return 1; } @@ -299,60 +284,41 @@ int main(int argc, char** argv) { std::vector audio_data = ::executorch::extension::llm::load_wav_audio_data(FLAGS_audio_path); ET_LOG(Info, "Loaded %zu audio samples", audio_data.size()); - ET_LOG( - Info, - "First 5 audio samples: %f, %f, %f, %f, %f", - audio_data[0], - audio_data[1], - audio_data[2], - audio_data[3], - audio_data[4]); - - // Load preprocessor - ET_LOG(Info, "Loading preprocessor from: %s", FLAGS_processor_path.c_str()); - Module processor(FLAGS_processor_path, Module::LoadMode::Mmap); - auto proc_load_error = processor.load(); - if (proc_load_error != Error::Ok) { - ET_LOG(Error, "Failed to load preprocessor module."); - return 1; - } - // Process audio to mel spectrogram auto audio_tensor = from_blob( audio_data.data(), {static_cast<::executorch::aten::SizesType>(audio_data.size())}, ::executorch::aten::ScalarType::Float); - - auto proc_result = processor.execute( - "forward", std::vector<::executorch::runtime::EValue>{audio_tensor}); + std::vector audio_len_data = {static_cast(audio_data.size())}; + auto audio_len_tensor = from_blob( + audio_len_data.data(), + {1}, + ::executorch::aten::ScalarType::Long); + + ET_LOG(Info, "Running preprocessor..."); + auto proc_result = model.execute( + "preprocessor", std::vector<::executorch::runtime::EValue>{audio_tensor, audio_len_tensor}); if (!proc_result.ok()) { ET_LOG(Error, "Preprocessor forward failed."); return 1; } auto& proc_outputs = proc_result.get(); auto mel = proc_outputs[0].toTensor(); + auto mel_len_tensor_out = proc_outputs[1].toTensor(); + int64_t mel_len_value = mel_len_tensor_out.const_data_ptr()[0]; - // Compute mel_len from tensor shape - std::vector mel_len_data = { - static_cast(mel.sizes()[2])}; + // Create mel_len tensor for encoder + std::vector mel_len_data = {mel_len_value}; auto mel_len = from_blob( mel_len_data.data(), {1}, ::executorch::aten::ScalarType::Long); ET_LOG( Info, - "Mel spectrogram shape: [%ld, %ld, %ld]", + "Mel spectrogram shape: [%ld, %ld, %ld], mel_len: %lld", static_cast(mel.sizes()[0]), static_cast(mel.sizes()[1]), - static_cast(mel.sizes()[2])); - - // Load model - ET_LOG(Info, "Loading model from: %s", FLAGS_model_path.c_str()); - Module model(FLAGS_model_path, Module::LoadMode::Mmap); - auto model_load_error = model.load(); - if (model_load_error != Error::Ok) { - ET_LOG(Error, "Failed to load model."); - return 1; - } + static_cast(mel.sizes()[2]), + static_cast(mel_len_value)); // Run encoder ET_LOG(Info, "Running encoder..."); @@ -374,12 +340,35 @@ int main(int argc, char** argv) { static_cast(encoded.sizes()[2]), static_cast(encoded_len)); - // Greedy decode - // Parakeet TDT uses vocab_size=8192, blank_id=8192 - int64_t vocab_size = 8192; - int64_t blank_id = vocab_size; - int64_t num_rnn_layers = 2; - int64_t pred_hidden = 640; + // Query model metadata from constant_methods + std::vector<::executorch::runtime::EValue> empty_inputs; + auto num_rnn_layers_result = model.execute("num_rnn_layers", empty_inputs); + auto pred_hidden_result = model.execute("pred_hidden", empty_inputs); + auto vocab_size_result = model.execute("vocab_size", empty_inputs); + auto blank_id_result = model.execute("blank_id", empty_inputs); + auto sample_rate_result = model.execute("sample_rate", empty_inputs); + + if (!num_rnn_layers_result.ok() || !pred_hidden_result.ok() || + !vocab_size_result.ok() || !blank_id_result.ok() || + !sample_rate_result.ok()) { + ET_LOG(Error, "Failed to query model metadata. Make sure the model was exported with constant_methods."); + return 1; + } + + int64_t vocab_size = vocab_size_result.get()[0].toInt(); + int64_t blank_id = blank_id_result.get()[0].toInt(); + int64_t num_rnn_layers = num_rnn_layers_result.get()[0].toInt(); + int64_t pred_hidden = pred_hidden_result.get()[0].toInt(); + int64_t sample_rate = sample_rate_result.get()[0].toInt(); + + ET_LOG( + Info, + "Model metadata: vocab_size=%lld, blank_id=%lld, num_rnn_layers=%lld, pred_hidden=%lld, sample_rate=%lld", + static_cast(vocab_size), + static_cast(blank_id), + static_cast(num_rnn_layers), + static_cast(pred_hidden), + static_cast(sample_rate)); ET_LOG(Info, "Running TDT greedy decode..."); auto tokens = greedy_decode_executorch( @@ -393,7 +382,7 @@ int main(int argc, char** argv) { ET_LOG(Info, "Decoded %zu tokens", tokens.size()); - // Load tokenizer using the LLM runner helper + // Load tokenizer ET_LOG(Info, "Loading tokenizer from: %s", FLAGS_tokenizer_path.c_str()); auto tokenizer = ::executorch::extension::llm::load_tokenizer(FLAGS_tokenizer_path); From a2a46879028ab449d62856310c804da2ea4d9415 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Mon, 22 Dec 2025 12:34:02 -0800 Subject: [PATCH 4/9] block backend arg --- examples/models/parakeet/export_parakeet_tdt.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py index 56a9822cda2..fa1cc58d7ea 100644 --- a/examples/models/parakeet/export_parakeet_tdt.py +++ b/examples/models/parakeet/export_parakeet_tdt.py @@ -342,7 +342,7 @@ def lower_to_executorch(programs, metadata=None, backend="portable"): partitioner[key] = [CudaPartitioner(compile_specs)] else: - print("\nLowering to ExecuTorch") + print("\nLowering to ExecuTorch...") partitioner = [] constant_methods = {} @@ -369,18 +369,18 @@ def lower_to_executorch(programs, metadata=None, backend="portable"): def main(): import argparse + import sys parser = argparse.ArgumentParser() parser.add_argument("--output-dir", default="./parakeet_tdt_exports") parser.add_argument("--audio", type=str, help="Path to audio file for transcription test") - parser.add_argument( - "--backend", - choices=["portable", "xnnpack", "cuda", "cuda-windows"], - default="portable", - help="Backend for acceleration", - ) + parser.add_argument("--backend", type=str, default=None, help=argparse.SUPPRESS) args = parser.parse_args() + if args.backend is not None: + print("Error: --backend is not currently supported. Backend acceleration is still being verified.") + sys.exit(1) + os.makedirs(args.output_dir, exist_ok=True) print("Loading model...") @@ -389,7 +389,7 @@ def main(): print("\nExporting components...") programs, metadata = export_all(model) - et = lower_to_executorch(programs, metadata=metadata, backend=args.backend) + et = lower_to_executorch(programs, metadata=metadata) pte_path = os.path.join(args.output_dir, "parakeet_tdt.pte") print(f"\nSaving ExecuTorch program to: {pte_path}") From af801b10fc01edbd79b883bd5ac4d1be46143a39 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Mon, 22 Dec 2025 12:57:26 -0800 Subject: [PATCH 5/9] lint and readme --- examples/models/parakeet/CMakeLists.txt | 4 +- examples/models/parakeet/README.md | 20 +++++--- .../models/parakeet/export_parakeet_tdt.py | 49 +++++++++++++------ examples/models/parakeet/main.cpp | 38 ++++++++------ 4 files changed, 73 insertions(+), 38 deletions(-) diff --git a/examples/models/parakeet/CMakeLists.txt b/examples/models/parakeet/CMakeLists.txt index 9ca451c095b..950686dd00a 100644 --- a/examples/models/parakeet/CMakeLists.txt +++ b/examples/models/parakeet/CMakeLists.txt @@ -94,7 +94,9 @@ if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") endif() endif() -target_include_directories(parakeet_runner PUBLIC ${_common_include_directories}) +target_include_directories( + parakeet_runner PUBLIC ${_common_include_directories} +) target_link_libraries(parakeet_runner PUBLIC ${link_libraries}) target_compile_options(parakeet_runner PUBLIC ${_common_compile_options}) diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md index e0debb9d5f2..4dc5d55318f 100644 --- a/examples/models/parakeet/README.md +++ b/examples/models/parakeet/README.md @@ -10,7 +10,7 @@ pip install nemo_toolkit[asr] torchaudio ## Export -Export the model (portable backend): +Export the model: ```bash python export_parakeet_tdt.py ``` @@ -25,17 +25,15 @@ python export_parakeet_tdt.py --audio /path/to/audio.wav | Argument | Description | |----------|-------------| | `--output-dir` | Output directory for exports (default: `./parakeet_tdt_exports`) | -| `--backend` | Backend for acceleration: `portable`, `xnnpack`, `cuda`, `cuda-windows` (default: `portable`) | | `--audio` | Path to audio file for transcription test | ## C++ Runner ### Building -First, build ExecuTorch with the LLM preset: +First, build ExecuTorch with the LLM preset from the executorch root directory: ```bash -cd executorch cmake --workflow --preset llm-release ``` @@ -46,12 +44,20 @@ cd examples/models/parakeet cmake --workflow --preset parakeet-cpu ``` +Available presets: +- `parakeet-cpu` - CPU-only build +- `parakeet-cuda` - CUDA acceleration (Linux/Windows) +- `parakeet-metal` - Metal acceleration (macOS) + ### Running +From the executorch root directory: + ```bash ./cmake-out/examples/models/parakeet/parakeet_runner \ - --model_path parakeet.pte \ - --audio_path audio.wav + --model_path examples/models/parakeet/parakeet_tdt_exports/parakeet_tdt.pte \ + --audio_path /path/to/audio.wav \ + --tokenizer_path examples/models/parakeet/tokenizer.model ``` ### Runner Arguments @@ -60,4 +66,4 @@ cmake --workflow --preset parakeet-cpu |----------|-------------| | `--model_path` | Path to Parakeet model (.pte) | | `--audio_path` | Path to input audio file (.wav) | -| `--tokenizer_path` | Path to tokenizer file (for token-to-text conversion) | +| `--tokenizer_path` | Path to tokenizer file (default: `tokenizer.json`) | diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py index fa1cc58d7ea..9001bb0f2b6 100644 --- a/examples/models/parakeet/export_parakeet_tdt.py +++ b/examples/models/parakeet/export_parakeet_tdt.py @@ -1,10 +1,13 @@ -#!/usr/bin/env python3 """Export nvidia/parakeet-tdt-0.6b-v3 components to ExecuTorch.""" import os import torch -from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge_transform_and_lower +from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, +) from executorch.exir.passes import MemoryPlanningPass from torch.export import Dim, export @@ -15,7 +18,7 @@ def load_audio(audio_path: str, sample_rate: int = 16000) -> torch.Tensor: import torchaudio waveform, sr = torchaudio.load(audio_path) - except (ImportError, Exception): + except Exception: from scipy.io import wavfile sr, data = wavfile.read(audio_path) @@ -44,7 +47,9 @@ def load_audio(audio_path: str, sample_rate: int = 16000) -> torch.Tensor: return waveform -def greedy_decode_eager(encoder_output: torch.Tensor, encoder_len: torch.Tensor, model) -> list[int]: +def greedy_decode_eager( + encoder_output: torch.Tensor, encoder_len: torch.Tensor, model +) -> list[int]: hypotheses = model.decoding.rnnt_decoder_predictions_tensor( encoder_output=encoder_output, encoded_lengths=encoder_len, @@ -182,7 +187,9 @@ def transcribe_executorch(audio_path: str, model, et_buffer) -> str: def transcribe_eager(audio_path: str, model) -> str: with torch.no_grad(): audio = load_audio(audio_path) - mel, mel_len = model.preprocessor(input_signal=audio, length=torch.tensor([audio.shape[1]])) + mel, mel_len = model.preprocessor( + input_signal=audio, length=torch.tensor([audio.shape[1]]) + ) encoded, encoded_len = model.encoder(audio_signal=mel, length=mel_len) tokens = greedy_decode_eager(encoded, encoded_len, model) return model.tokenizer.ids_to_text(tokens) @@ -228,7 +235,9 @@ def __init__(self, preprocessor): super().__init__() self.preprocessor = preprocessor - def forward(self, audio: torch.Tensor, length: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def forward( + self, audio: torch.Tensor, length: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: audio_signal = audio.unsqueeze(0) mel, mel_len = self.preprocessor(input_signal=audio_signal, length=length) return mel, mel_len @@ -244,7 +253,10 @@ def export_all(model): programs["preprocessor"] = export( preprocessor_wrapper, (sample_audio, sample_length), - dynamic_shapes={"audio": {0: Dim("audio_len", min=1600, max=16000 * 600)}, "length": {}}, + dynamic_shapes={ + "audio": {0: Dim("audio_len", min=1600, max=16000 * 600)}, + "length": {}, + }, strict=False, ) @@ -317,22 +329,27 @@ def lower_to_executorch(programs, metadata=None, backend="portable"): partitioner = None if backend == "xnnpack": - from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner + from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( + XnnpackPartitioner, + ) print("\nLowering to ExecuTorch with XNNPACK...") partitioner = [XnnpackPartitioner()] elif backend in ("cuda", "cuda-windows"): - from torch._inductor.decomposition import conv1d_to_conv2d - from executorch.backends.cuda.cuda_backend import CudaBackend from executorch.backends.cuda.cuda_partitioner import CudaPartitioner from executorch.exir.backend.compile_spec_schema import CompileSpec + from torch._inductor.decomposition import conv1d_to_conv2d - print(f"\nLowering to ExecuTorch with CUDA{' (Windows)' if backend == 'cuda-windows' else ''}...") + print( + f"\nLowering to ExecuTorch with CUDA{' (Windows)' if backend == 'cuda-windows' else ''}..." + ) for key, ep in programs.items(): - programs[key] = ep.run_decompositions({torch.ops.aten.conv1d.default: conv1d_to_conv2d}) + programs[key] = ep.run_decompositions( + {torch.ops.aten.conv1d.default: conv1d_to_conv2d} + ) partitioner = {} for key in programs.keys(): @@ -373,12 +390,16 @@ def main(): parser = argparse.ArgumentParser() parser.add_argument("--output-dir", default="./parakeet_tdt_exports") - parser.add_argument("--audio", type=str, help="Path to audio file for transcription test") + parser.add_argument( + "--audio", type=str, help="Path to audio file for transcription test" + ) parser.add_argument("--backend", type=str, default=None, help=argparse.SUPPRESS) args = parser.parse_args() if args.backend is not None: - print("Error: --backend is not currently supported. Backend acceleration is still being verified.") + print( + "Error: --backend is not currently supported. Backend acceleration is still being verified." + ) sys.exit(1) os.makedirs(args.output_dir, exist_ok=True) diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp index f99a55422e1..b247caa1f15 100644 --- a/examples/models/parakeet/main.cpp +++ b/examples/models/parakeet/main.cpp @@ -32,8 +32,8 @@ DEFINE_string( using ::executorch::extension::from_blob; using ::executorch::extension::Module; -using ::executorch::runtime::EValue; using ::executorch::runtime::Error; +using ::executorch::runtime::EValue; namespace { @@ -67,12 +67,12 @@ std::vector greedy_decode_executorch( } } - auto transposed_tensor = from_blob( - transposed_data.data(), - {static_cast<::executorch::aten::SizesType>(batch), - static_cast<::executorch::aten::SizesType>(time_steps), - static_cast<::executorch::aten::SizesType>(enc_dim)}, - ::executorch::aten::ScalarType::Float); + auto transposed_tensor = from_blob( + transposed_data.data(), + {static_cast<::executorch::aten::SizesType>(batch), + static_cast<::executorch::aten::SizesType>(time_steps), + static_cast<::executorch::aten::SizesType>(enc_dim)}, + ::executorch::aten::ScalarType::Float); // Project encoder output auto proj_enc_result = model.execute( @@ -289,15 +289,16 @@ int main(int argc, char** argv) { audio_data.data(), {static_cast<::executorch::aten::SizesType>(audio_data.size())}, ::executorch::aten::ScalarType::Float); - std::vector audio_len_data = {static_cast(audio_data.size())}; + std::vector audio_len_data = { + static_cast(audio_data.size())}; auto audio_len_tensor = from_blob( - audio_len_data.data(), - {1}, - ::executorch::aten::ScalarType::Long); + audio_len_data.data(), {1}, ::executorch::aten::ScalarType::Long); ET_LOG(Info, "Running preprocessor..."); auto proc_result = model.execute( - "preprocessor", std::vector<::executorch::runtime::EValue>{audio_tensor, audio_len_tensor}); + "preprocessor", + std::vector<::executorch::runtime::EValue>{ + audio_tensor, audio_len_tensor}); if (!proc_result.ok()) { ET_LOG(Error, "Preprocessor forward failed."); return 1; @@ -309,8 +310,8 @@ int main(int argc, char** argv) { // Create mel_len tensor for encoder std::vector mel_len_data = {mel_len_value}; - auto mel_len = from_blob( - mel_len_data.data(), {1}, ::executorch::aten::ScalarType::Long); + auto mel_len = + from_blob(mel_len_data.data(), {1}, ::executorch::aten::ScalarType::Long); ET_LOG( Info, @@ -351,7 +352,9 @@ int main(int argc, char** argv) { if (!num_rnn_layers_result.ok() || !pred_hidden_result.ok() || !vocab_size_result.ok() || !blank_id_result.ok() || !sample_rate_result.ok()) { - ET_LOG(Error, "Failed to query model metadata. Make sure the model was exported with constant_methods."); + ET_LOG( + Error, + "Failed to query model metadata. Make sure the model was exported with constant_methods."); return 1; } @@ -387,7 +390,10 @@ int main(int argc, char** argv) { auto tokenizer = ::executorch::extension::llm::load_tokenizer(FLAGS_tokenizer_path); if (!tokenizer || !tokenizer->is_loaded()) { - ET_LOG(Error, "Failed to load tokenizer from: %s", FLAGS_tokenizer_path.c_str()); + ET_LOG( + Error, + "Failed to load tokenizer from: %s", + FLAGS_tokenizer_path.c_str()); return 1; } From fa22b18b5ed74fa4dd7030e50d8fc250eaf2420d Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Mon, 22 Dec 2025 13:46:01 -0800 Subject: [PATCH 6/9] cuda graph preprocessor issue --- examples/models/parakeet/export_parakeet_tdt.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py index 9001bb0f2b6..ffdb268b780 100644 --- a/examples/models/parakeet/export_parakeet_tdt.py +++ b/examples/models/parakeet/export_parakeet_tdt.py @@ -250,6 +250,9 @@ def export_all(model): preprocessor_wrapper.eval() sample_audio = torch.randn(16000 * 10) sample_length = torch.tensor([sample_audio.shape[0]], dtype=torch.int64) + old_cuda_is_available = torch.cuda.is_available() + # The preprocessor definition changes if cuda is available (likely due to making it cuda graphable). Unfortunately that new definition is not supported by export, so we need to stop that from happening. + torch.cuda.is_available = lambda: False programs["preprocessor"] = export( preprocessor_wrapper, (sample_audio, sample_length), @@ -259,6 +262,7 @@ def export_all(model): }, strict=False, ) + torch.cuda.is_available = lambda: old_cuda_is_available feat_in = getattr(model.encoder, "_feat_in", 128) audio_signal = torch.randn(1, feat_in, 100) From 1829505b9fd4d0132741f98a58a5296fe1aa85d2 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Mon, 22 Dec 2025 15:05:42 -0800 Subject: [PATCH 7/9] fix torch cuda bug --- examples/models/parakeet/README.md | 3 ++ .../models/parakeet/export_parakeet_tdt.py | 52 +++++++++++-------- 2 files changed, 34 insertions(+), 21 deletions(-) diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md index 4dc5d55318f..a899fc15b9c 100644 --- a/examples/models/parakeet/README.md +++ b/examples/models/parakeet/README.md @@ -25,8 +25,11 @@ python export_parakeet_tdt.py --audio /path/to/audio.wav | Argument | Description | |----------|-------------| | `--output-dir` | Output directory for exports (default: `./parakeet_tdt_exports`) | +| `--backend` | Backend for acceleration: `portable`, `xnnpack`, `cuda`, `cuda-windows` (default: `portable`) | | `--audio` | Path to audio file for transcription test | +**Note:** The preprocessor is always lowered with the portable backend regardless of the `--backend` setting. + ## C++ Runner ### Building diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py index ffdb268b780..325cba41f5b 100644 --- a/examples/models/parakeet/export_parakeet_tdt.py +++ b/examples/models/parakeet/export_parakeet_tdt.py @@ -250,8 +250,9 @@ def export_all(model): preprocessor_wrapper.eval() sample_audio = torch.randn(16000 * 10) sample_length = torch.tensor([sample_audio.shape[0]], dtype=torch.int64) - old_cuda_is_available = torch.cuda.is_available() - # The preprocessor definition changes if cuda is available (likely due to making it cuda graphable). Unfortunately that new definition is not supported by export, so we need to stop that from happening. + # The preprocessor definition changes if cuda is available (likely due to making it cuda graphable). + # Unfortunately that new definition is not supported by export, so we need to stop that from happening. + old_cuda_is_available = torch.cuda.is_available torch.cuda.is_available = lambda: False programs["preprocessor"] = export( preprocessor_wrapper, @@ -262,7 +263,7 @@ def export_all(model): }, strict=False, ) - torch.cuda.is_available = lambda: old_cuda_is_available + torch.cuda.is_available = old_cuda_is_available feat_in = getattr(model.encoder, "_feat_in", 128) audio_signal = torch.randn(1, feat_in, 100) @@ -330,7 +331,7 @@ def export_all(model): def lower_to_executorch(programs, metadata=None, backend="portable"): - partitioner = None + partitioner = {} if backend == "xnnpack": from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( @@ -338,7 +339,11 @@ def lower_to_executorch(programs, metadata=None, backend="portable"): ) print("\nLowering to ExecuTorch with XNNPACK...") - partitioner = [XnnpackPartitioner()] + for key in programs.keys(): + if key == "preprocessor": + partitioner[key] = [] + else: + partitioner[key] = [XnnpackPartitioner()] elif backend in ("cuda", "cuda-windows"): from executorch.backends.cuda.cuda_backend import CudaBackend @@ -351,16 +356,21 @@ def lower_to_executorch(programs, metadata=None, backend="portable"): ) for key, ep in programs.items(): - programs[key] = ep.run_decompositions( - {torch.ops.aten.conv1d.default: conv1d_to_conv2d} - ) + if key != "preprocessor": + programs[key] = ep.run_decompositions( + {torch.ops.aten.conv1d.default: conv1d_to_conv2d} + ) - partitioner = {} for key in programs.keys(): - compile_specs = [CudaBackend.generate_method_name_compile_spec(key)] - if backend == "cuda-windows": - compile_specs.append(CompileSpec("platform", "windows".encode("utf-8"))) - partitioner[key] = [CudaPartitioner(compile_specs)] + if key == "preprocessor": + partitioner[key] = [] + else: + compile_specs = [CudaBackend.generate_method_name_compile_spec(key)] + if backend == "cuda-windows": + compile_specs.append( + CompileSpec("platform", "windows".encode("utf-8")) + ) + partitioner[key] = [CudaPartitioner(compile_specs)] else: print("\nLowering to ExecuTorch...") @@ -397,15 +407,15 @@ def main(): parser.add_argument( "--audio", type=str, help="Path to audio file for transcription test" ) - parser.add_argument("--backend", type=str, default=None, help=argparse.SUPPRESS) + parser.add_argument( + "--backend", + type=str, + default="portable", + choices=["portable", "xnnpack", "cuda", "cuda-windows"], + help="Backend for acceleration (default: portable)", + ) args = parser.parse_args() - if args.backend is not None: - print( - "Error: --backend is not currently supported. Backend acceleration is still being verified." - ) - sys.exit(1) - os.makedirs(args.output_dir, exist_ok=True) print("Loading model...") @@ -414,7 +424,7 @@ def main(): print("\nExporting components...") programs, metadata = export_all(model) - et = lower_to_executorch(programs, metadata=metadata) + et = lower_to_executorch(programs, metadata=metadata, backend=args.backend) pte_path = os.path.join(args.output_dir, "parakeet_tdt.pte") print(f"\nSaving ExecuTorch program to: {pte_path}") From 54a6319692aa7f5881152739374f8d01ff3b5b4d Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Mon, 22 Dec 2025 15:24:42 -0800 Subject: [PATCH 8/9] ptd support --- examples/models/parakeet/README.md | 1 + .../models/parakeet/export_parakeet_tdt.py | 15 ++++++++- examples/models/parakeet/main.cpp | 31 +++++++++++++------ 3 files changed, 36 insertions(+), 11 deletions(-) diff --git a/examples/models/parakeet/README.md b/examples/models/parakeet/README.md index a899fc15b9c..045f22571fd 100644 --- a/examples/models/parakeet/README.md +++ b/examples/models/parakeet/README.md @@ -70,3 +70,4 @@ From the executorch root directory: | `--model_path` | Path to Parakeet model (.pte) | | `--audio_path` | Path to input audio file (.wav) | | `--tokenizer_path` | Path to tokenizer file (default: `tokenizer.json`) | +| `--data_path` | Path to data file (.ptd) for delegate data (optional, required for CUDA) | diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py index 325cba41f5b..0547203bc10 100644 --- a/examples/models/parakeet/export_parakeet_tdt.py +++ b/examples/models/parakeet/export_parakeet_tdt.py @@ -198,8 +198,11 @@ def transcribe_eager(audio_path: str, model) -> str: def load_model(): import nemo.collections.asr as nemo_asr - model = nemo_asr.models.ASRModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v3") + model = nemo_asr.models.ASRModel.from_pretrained( + "nvidia/parakeet-tdt-0.6b-v3", map_location="cpu" + ) model.eval() + model.cpu() return model @@ -432,6 +435,16 @@ def main(): et.write_to_file(f) print(f"Saved {os.path.getsize(pte_path) / (1024 * 1024):.1f} MB") + # Save .ptd data files (e.g., CUDA delegate data) + data_files = et.data_files + if data_files: + print(f"\nSaving {len(data_files)} data file(s)...") + for filename, data in data_files.items(): + ptd_path = os.path.join(args.output_dir, filename) + with open(ptd_path, "wb") as f: + f.write(data) + print(f" Saved {filename} ({len(data) / (1024 * 1024):.1f} MB)") + if args.audio: print("\n" + "=" * 60) print("Testing transcription...") diff --git a/examples/models/parakeet/main.cpp b/examples/models/parakeet/main.cpp index b247caa1f15..173cf722d77 100644 --- a/examples/models/parakeet/main.cpp +++ b/examples/models/parakeet/main.cpp @@ -29,6 +29,10 @@ DEFINE_string( tokenizer_path, "tokenizer.json", "Path to SentencePiece tokenizer model file."); +DEFINE_string( + data_path, + "", + "Path to data file (.ptd) for delegate data (optional, required for CUDA)."); using ::executorch::extension::from_blob; using ::executorch::extension::Module; @@ -272,8 +276,15 @@ int main(int argc, char** argv) { // Load model (which includes the bundled preprocessor) ET_LOG(Info, "Loading model from: %s", FLAGS_model_path.c_str()); - Module model(FLAGS_model_path, Module::LoadMode::Mmap); - auto model_load_error = model.load(); + std::unique_ptr model; + if (!FLAGS_data_path.empty()) { + ET_LOG(Info, "Loading data from: %s", FLAGS_data_path.c_str()); + model = std::make_unique( + FLAGS_model_path, FLAGS_data_path, Module::LoadMode::Mmap); + } else { + model = std::make_unique(FLAGS_model_path, Module::LoadMode::Mmap); + } + auto model_load_error = model->load(); if (model_load_error != Error::Ok) { ET_LOG(Error, "Failed to load model."); return 1; @@ -295,7 +306,7 @@ int main(int argc, char** argv) { audio_len_data.data(), {1}, ::executorch::aten::ScalarType::Long); ET_LOG(Info, "Running preprocessor..."); - auto proc_result = model.execute( + auto proc_result = model->execute( "preprocessor", std::vector<::executorch::runtime::EValue>{ audio_tensor, audio_len_tensor}); @@ -323,7 +334,7 @@ int main(int argc, char** argv) { // Run encoder ET_LOG(Info, "Running encoder..."); - auto enc_result = model.execute( + auto enc_result = model->execute( "encoder", std::vector<::executorch::runtime::EValue>{mel, mel_len}); if (!enc_result.ok()) { ET_LOG(Error, "Encoder forward failed."); @@ -343,11 +354,11 @@ int main(int argc, char** argv) { // Query model metadata from constant_methods std::vector<::executorch::runtime::EValue> empty_inputs; - auto num_rnn_layers_result = model.execute("num_rnn_layers", empty_inputs); - auto pred_hidden_result = model.execute("pred_hidden", empty_inputs); - auto vocab_size_result = model.execute("vocab_size", empty_inputs); - auto blank_id_result = model.execute("blank_id", empty_inputs); - auto sample_rate_result = model.execute("sample_rate", empty_inputs); + auto num_rnn_layers_result = model->execute("num_rnn_layers", empty_inputs); + auto pred_hidden_result = model->execute("pred_hidden", empty_inputs); + auto vocab_size_result = model->execute("vocab_size", empty_inputs); + auto blank_id_result = model->execute("blank_id", empty_inputs); + auto sample_rate_result = model->execute("sample_rate", empty_inputs); if (!num_rnn_layers_result.ok() || !pred_hidden_result.ok() || !vocab_size_result.ok() || !blank_id_result.ok() || @@ -375,7 +386,7 @@ int main(int argc, char** argv) { ET_LOG(Info, "Running TDT greedy decode..."); auto tokens = greedy_decode_executorch( - model, + *model, encoded, encoded_len, blank_id, From 09cb1be935bb32f2d58831df790b1a86316210b0 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Mon, 22 Dec 2025 15:37:14 -0800 Subject: [PATCH 9/9] ptd serialization --- examples/models/parakeet/export_parakeet_tdt.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py index 0547203bc10..509da67051c 100644 --- a/examples/models/parakeet/export_parakeet_tdt.py +++ b/examples/models/parakeet/export_parakeet_tdt.py @@ -436,14 +436,9 @@ def main(): print(f"Saved {os.path.getsize(pte_path) / (1024 * 1024):.1f} MB") # Save .ptd data files (e.g., CUDA delegate data) - data_files = et.data_files - if data_files: - print(f"\nSaving {len(data_files)} data file(s)...") - for filename, data in data_files.items(): - ptd_path = os.path.join(args.output_dir, filename) - with open(ptd_path, "wb") as f: - f.write(data) - print(f" Saved {filename} ({len(data) / (1024 * 1024):.1f} MB)") + if et._tensor_data: + print(f"\nSaving {len(et._tensor_data)} data file(s)...") + et.write_tensor_data_to_file(args.output_dir) if args.audio: print("\n" + "=" * 60)