#!/usr/bin/env python3
#
# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
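"""Train a small SentencePiece tokenizer on LibriSpeech transcripts via torchaudio.

The script downloads the requested LibriSpeech subset, writes its transcripts to a
temporary text corpus, and trains a SentencePiece model on that corpus
(128 tokens by default).

Example invocation (the script filename is a placeholder; all flags are defined
in main() below):

    python train_librispeech_tokenizer.py --subset train-clean-100 --vocab_size 128 \
        --output_dir ./tokenizer_out --model_prefix librispeech_sp
"""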
import argparse
import os
import sys
import tempfile
from pathlib import Path

import torchaudio
from torchaudio.datasets import LIBRISPEECH

import sentencepiece as spm


def normalize_text(text: str, lowercase: bool = False) -> str:
    t = text.strip()
    if lowercase:
        t = t.lower()
    return t


def build_corpus(root: str, subset: str, lowercase: bool, limit: int | None) -> str:
    dataset = LIBRISPEECH(root=root, url=subset, download=True)
    n = len(dataset)
    if limit is not None:
        n = min(n, limit)

    tmp_fd, tmp_path = tempfile.mkstemp(prefix="librispeech_corpus_", suffix=".txt")
    os.close(tmp_fd)

    with open(tmp_path, "w", encoding="utf-8") as f:
        for idx in range(n):
            try:
                # Each LIBRISPEECH item is (waveform, sample_rate, transcript,
                # speaker_id, chapter_id, utterance_id); only the transcript is needed.
                _, _, transcript, *_ = dataset[idx]
            except Exception as ex:
                print(f"Warning: failed to read sample {idx}: {ex}", file=sys.stderr)
                continue
            line = normalize_text(transcript, lowercase)
            if line:
                f.write(line + "\n")
    return tmp_path


def train_sentencepiece(
    corpus_path: str,
    output_dir: str,
    vocab_size: int,
    model_type: str,
    character_coverage: float,
    model_prefix: str,
    pad_id: int | None,
    disable_bos_eos: bool,
    seed_sentencepiece: int | None,
    input_sentence_size: int,
):
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    model_prefix_path = str(Path(output_dir) / model_prefix)

    sp_args = [
        f"--input={corpus_path}",
        f"--model_prefix={model_prefix_path}",
        f"--vocab_size={vocab_size}",
        f"--model_type={model_type}",
        f"--character_coverage={character_coverage}",
        "--unk_id=0",
        "--input_sentence_size=" + str(input_sentence_size),
        "--shuffle_input_sentence=true",
        "--hard_vocab_limit=true",
        "--num_threads=32",
    ]

    if disable_bos_eos:
        sp_args += ["--bos_id=-1", "--eos_id=-1"]
    else:
        sp_args += ["--bos_id=1", "--eos_id=2"]

    if pad_id is None or pad_id < 0:
        sp_args += ["--pad_id=-1"]
    else:
        sp_args += [f"--pad_id={pad_id}"]

    if seed_sentencepiece is not None:
        sp_args += [f"--seed_sentencepiece_size={seed_sentencepiece}"]

    spm.SentencePieceTrainer.Train(" ".join(sp_args))

    model_path = model_prefix_path + ".model"
    vocab_path = model_prefix_path + ".vocab"
    return model_path, vocab_path

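
# A minimal sketch (not executed here) of loading the trained model with the
# standard sentencepiece inference API; the model path is a placeholder for
# <output_dir>/<model_prefix>.model written by train_sentencepiece():
#
#   sp = spm.SentencePieceProcessor(model_file="tokenizer_out/librispeech_sp.model")
#   ids = sp.encode("HELLO WORLD", out_type=int)
#   text = sp.decode(ids)

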
def main():
    parser = argparse.ArgumentParser(description="Train a 128-token SentencePiece tokenizer on LibriSpeech using torchaudio.")
    parser.add_argument("--root", type=str, default="./data", help="Directory to store/look up LibriSpeech.")
    parser.add_argument("--subset", type=str, default="train-clean-100",
                        choices=[
                            "train-clean-100", "train-clean-360", "train-other-500",
                            "dev-clean", "dev-other", "test-clean", "test-other",
                        ],
                        help="LibriSpeech subset to use.")
    parser.add_argument("--output_dir", type=str, default="./tokenizer_out", help="Where to write the tokenizer files.")
    parser.add_argument("--vocab_size", type=int, default=128, help="Total vocabulary size.")
    parser.add_argument("--model_type", type=str, default="unigram", choices=["unigram", "bpe"],
                        help="SentencePiece model type; 'unigram' generally works well for small English vocabularies.")
    parser.add_argument("--character_coverage", type=float, default=1.0,
                        help="Fraction of characters covered by the model (1.0 is fine for English).")
    parser.add_argument("--model_prefix", type=str, default="librispeech_sp",
                        help="Prefix (filename stem) for the trained model.")
    parser.add_argument("--lowercase", action="store_true", help="Lowercase transcripts before training.")
    parser.add_argument("--pad_id", type=int, default=-1, help="Pad ID; set to -1 to disable (default).")
    parser.add_argument("--disable_bos_eos", action="store_true", help="Disable BOS/EOS (recommended).")
    parser.add_argument("--enable_bos_eos", action="store_true", help="Enable BOS/EOS tokens.")
    parser.add_argument("--limit", type=int, default=None, help="Limit the number of samples for a quick run.")
    parser.add_argument("--seed_sentencepiece_size", type=int, default=None,
                        help="Advanced: initial seed size for SentencePiece's sentence sampling (optional).")
    parser.add_argument("--input_sentence_size", type=int, default=1000000,
                        help="Number of sentences to sample during training.")

    args = parser.parse_args()

    # BOS/EOS are disabled by default; --enable_bos_eos turns them on, and an
    # explicit --disable_bos_eos takes precedence if both flags are given.
    disable_bos_eos = True
    if args.enable_bos_eos:
        disable_bos_eos = False
    if args.disable_bos_eos:
        disable_bos_eos = True

    print(f"Preparing corpus from LibriSpeech subset='{args.subset}'...")
    corpus_path = build_corpus(root=args.root, subset=args.subset, lowercase=args.lowercase, limit=args.limit)

    print(f"Training SentencePiece... corpus_path {corpus_path}")
    model_path, vocab_path = train_sentencepiece(
        corpus_path=corpus_path,
        output_dir=args.output_dir,
        vocab_size=args.vocab_size,
        model_type=args.model_type,
        character_coverage=args.character_coverage,
        model_prefix=args.model_prefix,
        pad_id=args.pad_id,
        disable_bos_eos=disable_bos_eos,
        seed_sentencepiece=args.seed_sentencepiece_size,
        input_sentence_size=args.input_sentence_size,
    )

    print("Done!")
    print(f"Model : {model_path}")
    print(f"Vocab : {vocab_path}")


if __name__ == "__main__":
    main()