diff --git a/examples/models/parakeet/export_parakeet_tdt.py b/examples/models/parakeet/export_parakeet_tdt.py index abf4292f852..92e32ca30bf 100644 --- a/examples/models/parakeet/export_parakeet_tdt.py +++ b/examples/models/parakeet/export_parakeet_tdt.py @@ -2,6 +2,9 @@ import argparse import os +import shutil +import tarfile +import tempfile import torch import torchaudio @@ -188,6 +191,51 @@ def load_model(): return model +def extract_tokenizer(output_dir: str) -> str | None: + """Extract tokenizer.model from the cached .nemo file. + + Args: + output_dir: Directory to save the tokenizer.model file. + + Returns: + Path to the extracted tokenizer.model, or None if extraction failed. + """ + from huggingface_hub import hf_hub_download + + # Download/get cached .nemo file path + nemo_path = hf_hub_download( + repo_id="nvidia/parakeet-tdt-0.6b-v3", + filename="parakeet-tdt-0.6b-v3.nemo", + ) + + # .nemo files are tar archives - extract tokenizer.model + tokenizer_filename = "tokenizer.model" + output_path = os.path.join(output_dir, tokenizer_filename) + + with tempfile.TemporaryDirectory() as tmpdir: + with tarfile.open(nemo_path, "r") as tar: + # Find tokenizer.model in the archive (may be in root or subdirectory) + tokenizer_member = None + for member in tar.getmembers(): + if member.name.endswith(tokenizer_filename): + tokenizer_member = member + break + + if tokenizer_member is None: + print(f"Warning: {tokenizer_filename} not found in .nemo archive") + return None + + # Extract to temp directory + tar.extract(tokenizer_member, tmpdir) + extracted_path = os.path.join(tmpdir, tokenizer_member.name) + + # Copy to output directory + shutil.copy2(extracted_path, output_path) + + print(f"Extracted tokenizer to: {output_path}") + return output_path + + class JointAfterProjection(torch.nn.Module): def __init__(self, joint): super().__init__() @@ -401,6 +449,9 @@ def main(): os.makedirs(args.output_dir, exist_ok=True) + print("Extracting tokenizer...") + extract_tokenizer(args.output_dir) + print("Loading model...") model = load_model()