Skip to content

Commit 50a3574

Browse files
committed
Conformer training scripts
Add Conformer model PyTorch training scripts. Produced checkpoint can be used to get an accurate int8 ExecuTorch PTE file. Change-Id: I8d19db5612172de5eac535572fdd2f48b770ffda
1 parent e5dcc9d commit 50a3574

File tree

5 files changed

+973
-0
lines changed

5 files changed

+973
-0
lines changed
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Conformer-S Model Training
2+
3+
This repository provides an example of training the **Conformer-S** model on the **LibriSpeech** dataset.
4+
5+
## External dependencies
6+
- **Model**: https://github.com/sooftware/conformer implementation of Conformer-S
7+
- **Dataset**: LibriSpeech (downloaded via `torchaudio`) - used both to build the tokenizer and to train the Conformer model
8+
- **Tokenizer**: Generated using https://github.com/google/sentencepiece/
9+
- **Python Dependencies**: Python packages listed in **requirements.txt**.
10+
11+
## Environment description
12+
- AWS g5.24xlarge instance
13+
- Python version 3.12.7
14+
- AWS AMI - Deep Learning OSS Nvidia Driver AMI GPU PyTorch (Ubuntu 22.04)
15+
16+
## Setup
17+
1) Make sure the Conformer repository is cloned in the same directory as the training script:
18+
```angular2html
19+
git clone https://github.com/sooftware/conformer.git
20+
```
21+
2) Generate SentencePiece Tokenizer
22+
- More information on what the SentencePiece tokenizer is and how to use it can be found at https://github.com/google/sentencepiece?tab=readme-ov-file#overview
23+
- Generate the tokenizer using the following command
24+
```angular2html
25+
!python build_sp_128_librispeech.py \
26+
--root ./data \
27+
--subset train-clean-100 \
28+
--output_dir ./tokenizer_out \
29+
--vocab_size 128 \
30+
--model_type unigram \
31+
--lowercase \
32+
--disable_bos_eos \
33+
--pad_id -1
34+
```
35+
- Pass the tokenizer path to the training script via the --sp-model argument
36+
3) Create an empty `data` folder in the same directory as the training script
37+
## Training
38+
Run the following command to start training:
39+
```angular2html
40+
!CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py \
41+
--train-sets "train-clean-100,train-clean-360,train-other-500" \
42+
--valid-set "dev-clean" \
43+
--epochs 160 \
44+
--batch-size 96 \
45+
--lr=0.0005 \
46+
--betas 0.9,0.98 \
47+
--weight-decay 1e-6 \
48+
--warmup-epochs 2.0 \
49+
--grad-clip 5 \
50+
--root "data" \
51+
--save-dir "checkpoints" \
52+
--num-workers=32 \
53+
--accum-steps 16 \
54+
2>&1 | tee train_log.txt
55+
```
56+
## Notes and recommendations
57+
- Hyperparameter tuning and active monitoring (“model babysitting”) are strongly recommended to achieve optimal performance
58+
- You should be able to reach a WER in the range of 6%-7% on the LibriSpeech `test-clean` dataset
59+
- Checkpoints will be saved under the checkpoints/ directory
60+
- Logs are written to train_log.txt for convenience
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
#
2+
# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3+
#
4+
# SPDX-License-Identifier: Apache-2.0
5+
#
6+
7+
#!/usr/bin/env python3
8+
9+
import argparse
10+
import os
11+
import sys
12+
import tempfile
13+
from pathlib import Path
14+
15+
import torchaudio
16+
from torchaudio.datasets import LIBRISPEECH
17+
18+
import sentencepiece as spm
19+
20+
21+
def normalize_text(text: str, lowercase: bool = False) -> str:
    """Trim surrounding whitespace from *text*, optionally lowercasing it.

    Args:
        text: Raw transcript string.
        lowercase: When true, the stripped text is converted to lower case.

    Returns:
        The stripped (and possibly lowercased) string; may be empty.
    """
    stripped = text.strip()
    return stripped.lower() if lowercase else stripped
26+
27+
28+
def build_corpus(root: str, subset: str, lowercase: bool, limit: int | None) -> str:
    """Write the transcripts of a LibriSpeech subset to a temporary text file.

    The dataset is opened through ``LIBRISPEECH(..., download=True)`` and the
    third element of every sample is taken as its transcript. Each transcript
    is normalised with ``normalize_text`` and written on its own line; empty
    results are skipped.

    Args:
        root: Directory handed to ``LIBRISPEECH`` as its root.
        subset: LibriSpeech subset name, passed as ``url``.
        lowercase: Forwarded to ``normalize_text``.
        limit: Upper bound on the number of samples read, or ``None`` for all.

    Returns:
        Path of the temporary corpus file. The caller owns the file.
    """
    librispeech = LIBRISPEECH(root=root, url=subset, download=True)
    total = len(librispeech)
    sample_count = total if limit is None else min(total, limit)

    # mkstemp only reserves a unique name here; the descriptor is closed
    # right away and the file is reopened in text mode below.
    handle, corpus_file = tempfile.mkstemp(prefix="librispeech_corpus_", suffix=".txt")
    os.close(handle)

    with open(corpus_file, "w", encoding="utf-8") as out:
        for sample_idx in range(sample_count):
            try:
                _waveform, _rate, transcript, *_rest = librispeech[sample_idx]
            except Exception as err:
                # Best effort: an unreadable sample is reported and skipped
                # rather than aborting the whole corpus build.
                print(f"Warning: failed to read sample {sample_idx}: {err}", file=sys.stderr)
                continue
            text = normalize_text(transcript, lowercase)
            if not text:
                continue
            out.write(text + "\n")
    return corpus_file
48+
49+
50+
def train_sentencepiece(
    corpus_path: str,
    output_dir: str,
    vocab_size: int,
    model_type: str,
    character_coverage: float,
    model_prefix: str,
    pad_id: int,
    disable_bos_eos: bool,
    seed_sentencepiece: int | None,
    input_sentence_size: int,
    num_threads: int = 32,
) -> tuple[str, str]:
    """Train a SentencePiece model on a text corpus.

    Args:
        corpus_path: Text file with one sentence per line.
        output_dir: Directory for the model files; created if missing.
        vocab_size: Value for the trainer's ``vocab_size`` option.
        model_type: Value for the trainer's ``model_type`` option.
        character_coverage: Value for the ``character_coverage`` option.
        model_prefix: Filename stem of the generated files inside
            ``output_dir``.
        pad_id: Pad token id; ``None`` or a negative value disables padding
            (``pad_id=-1``).
        disable_bos_eos: When true BOS/EOS are disabled (ids ``-1``),
            otherwise they get ids 1 and 2.
        seed_sentencepiece: Optional value for the trainer's
            ``seed_sentencepiece_size`` option; omitted when ``None``.
        input_sentence_size: Value for the ``input_sentence_size`` option.
        num_threads: Value for the ``num_threads`` option.

    Returns:
        ``(model_path, vocab_path)``: the prefix path with ``.model`` and
        ``.vocab`` appended.
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    model_prefix_path = str(Path(output_dir) / model_prefix)

    # Options are passed as keyword arguments instead of one space-joined
    # flag string, so paths that contain spaces reach the trainer intact.
    # Boolean options are spelled as the same "true" text the flags used.
    options = {
        "input": corpus_path,
        "model_prefix": model_prefix_path,
        "vocab_size": vocab_size,
        "model_type": model_type,
        "character_coverage": character_coverage,
        "unk_id": 0,
        "input_sentence_size": input_sentence_size,
        "shuffle_input_sentence": "true",
        "hard_vocab_limit": "true",
        "num_threads": num_threads,
        "bos_id": -1 if disable_bos_eos else 1,
        "eos_id": -1 if disable_bos_eos else 2,
        "pad_id": -1 if pad_id is None or pad_id < 0 else pad_id,
    }
    if seed_sentencepiece is not None:
        options["seed_sentencepiece_size"] = seed_sentencepiece

    spm.SentencePieceTrainer.Train(**options)

    return model_prefix_path + ".model", model_prefix_path + ".vocab"
96+
97+
98+
def main():
    """Command-line entry point: build a corpus and train the tokenizer.

    Parses the command-line options, dumps the transcripts of the selected
    LibriSpeech subset with ``build_corpus``, trains SentencePiece on them
    with ``train_sentencepiece`` and prints the resulting file paths. The
    temporary corpus file is removed once training finishes or fails.
    """
    parser = argparse.ArgumentParser(description="Train a 128-token SentencePiece tokenizer on LibriSpeech using torchaudio.")
    parser.add_argument("--root", type=str, default="./data", help="Directory to store/lookup LibriSpeech.")
    parser.add_argument("--subset", type=str, default="train-clean-100",
                        choices=[
                            "train-clean-100", "train-clean-360", "train-other-500",
                            "dev-clean", "dev-other", "test-clean", "test-other"
                        ],
                        help="LibriSpeech subset to use.")
    parser.add_argument("--output_dir", type=str, default="./tokenizer_out", help="Where to write the tokenizer files.")
    parser.add_argument("--vocab_size", type=int, default=128, help="Total vocabulary size.")
    parser.add_argument("--model_type", type=str, default="unigram", choices=["unigram", "bpe"],
                        help="SentencePiece model type. 'unigram' tends to match English lists like yours.")
    parser.add_argument("--character_coverage", type=float, default=1.0,
                        help="Fraction of characters covered by the model (1.0 is fine for English).")
    parser.add_argument("--model_prefix", type=str, default="librispeech_sp",
                        help="Prefix (filename stem) for the trained model.")
    parser.add_argument("--lowercase", action="store_true", help="Lowercase transcripts before training.")
    parser.add_argument("--pad_id", type=int, default=-1, help="Pad ID; set to -1 to disable (default).")
    parser.add_argument("--disable_bos_eos", action="store_true", help="Disable BOS/EOS (recommended).")
    parser.add_argument("--enable_bos_eos", action="store_true", help="Enable BOS/EOS tokens.")
    parser.add_argument("--limit", type=int, default=None, help="Limit the number of samples for a quick run.")
    parser.add_argument("--seed_sentencepiece_size", type=int, default=None,
                        help="Advanced: size of the initial seed sentencepiece vocabulary (optional).")
    parser.add_argument("--input_sentence_size", type=int, default=1000000,
                        help="Number of sentences to sample during training.")

    args = parser.parse_args()

    # BOS/EOS are disabled by default; --enable_bos_eos turns them on, but
    # --disable_bos_eos wins when both flags are given.
    disable_bos_eos = args.disable_bos_eos or not args.enable_bos_eos

    print(f"Preparing corpus from LibriSpeech subset='{args.subset}'...")
    corpus_path = build_corpus(root=args.root, subset=args.subset, lowercase=args.lowercase, limit=args.limit)

    try:
        print(f"Training SentencePiece... corpus_path {corpus_path}")
        model_path, vocab_path = train_sentencepiece(
            corpus_path=corpus_path,
            output_dir=args.output_dir,
            vocab_size=args.vocab_size,
            model_type=args.model_type,
            character_coverage=args.character_coverage,
            model_prefix=args.model_prefix,
            pad_id=args.pad_id,
            disable_bos_eos=disable_bos_eos,
            seed_sentencepiece=args.seed_sentencepiece_size,
            input_sentence_size=args.input_sentence_size,
        )
    finally:
        # The corpus is a scratch file in the system temp directory; delete
        # it so repeated runs do not accumulate stale copies.
        Path(corpus_path).unlink(missing_ok=True)

    print("Done!")
    print(f"Model : {model_path}")
    print(f"Vocab : {vocab_path}")


if __name__ == "__main__":
    main()
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
#
2+
# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
3+
#
4+
# SPDX-License-Identifier: Apache-2.0
5+
#
6+
7+
#!/usr/bin/env python3
8+
from __future__ import annotations
9+
import argparse
10+
from pathlib import Path
11+
12+
import torch
13+
import torch.nn as nn
14+
from torch.utils.data import DataLoader
15+
from torchaudio.datasets import LIBRISPEECH
16+
import sentencepiece as spm
17+
18+
# ---------------------------------------------------------------------
19+
# Import the components we need from the original training script
20+
# ---------------------------------------------------------------------
21+
import train # assumes train.py is in the same directory
22+
from train import (
23+
create_model,
24+
AudioPreprocessor,
25+
collate_eval_factory,
26+
evaluate,
27+
)
28+
29+
# ---------------------------------------------------------------------
30+
# Argument parsing
31+
# ---------------------------------------------------------------------
32+
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the evaluation script.

    Returns:
        Namespace with ``root``, ``set``, ``batch_size``, ``num_workers``,
        ``sp_model`` and ``checkpoint``. ``--checkpoint`` is mandatory.
    """
    parser = argparse.ArgumentParser(description="Conformer evaluation only")

    # (flag, add_argument keyword arguments), kept in registration order so
    # the generated --help output lists the options in the same sequence.
    option_specs = (
        ("--root", {"type": str, "default": "/shared/LIBRISPEECH",
                    "help": "LibriSpeech root directory"}),
        ("--set", {"type": str, "default": "test-clean",
                   "help": "Comma-separated LibriSpeech subset names (e.g. test-clean,test-other)"}),
        ("--batch-size", {"type": int, "default": 64,
                          "help": "Batch size for evaluation"}),
        ("--num-workers", {"type": int, "default": 4,
                           "help": "DataLoader worker processes"}),
        ("--sp-model", {"type": str,
                        "default": "tokenizer_out/librispeech_sp.model",
                        "help": "Path to the SentencePiece *.model used at training time"}),
        ("--checkpoint", {"type": str, "required": True,
                          "help": "Path to a trained checkpoint (*.pt)"}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    return parser.parse_args()
48+
49+
# ---------------------------------------------------------------------
50+
# Main evaluation routine
51+
# ---------------------------------------------------------------------
52+
def _strip_data_parallel_prefix(state: dict) -> dict:
    """Return *state* without a uniform leading ``module.`` on its keys.

    The prefix is removed only when every key carries it, so a checkpoint
    whose keys do not all start with ``module.`` is returned unchanged.
    """
    prefix = "module."
    if state and all(isinstance(key, str) and key.startswith(prefix) for key in state):
        return {key[len(prefix):]: value for key, value in state.items()}
    return state


def main() -> None:
    """Evaluate a trained checkpoint on one or more LibriSpeech subsets.

    Loads the SentencePiece tokenizer and the checkpoint named on the
    command line, then prints the WER and CTC loss returned by
    ``train.evaluate`` for every requested subset.
    """
    args = parse_args()

    # -------- SentencePiece ----------
    sp = spm.SentencePieceProcessor()
    sp.load(args.sp_model)

    # expose the tokenizer inside the imported `train` module so that
    # train.int_to_text() works correctly during decoding
    train.sp = sp

    vocab_size = sp.get_piece_size() + 1  # +1 for CTC blank

    # -------- Model ----------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = create_model(vocab_size).to(device)

    ckpt_path = Path(args.checkpoint)
    ckpt = torch.load(ckpt_path, map_location=device)
    state = ckpt.get("model", ckpt)  # handles raw state_dict vs wrapper

    # tolerate DataParallel prefix differences: strict=False alone would
    # match none of the "module."-prefixed keys and leave the model with
    # its freshly initialised weights, so the prefix is removed first
    state = _strip_data_parallel_prefix(state)

    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing:
        print(f"[warn] missing keys in checkpoint: {missing}")
    if unexpected:
        print(f"[warn] unexpected keys in checkpoint: {unexpected}")

    print(f"=> loaded weights from {ckpt_path}")

    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs via DataParallel …")
        model = torch.nn.DataParallel(model)

    loss_fn = nn.CTCLoss(blank=0, zero_infinity=True)

    # -------- Datasets & evaluation ----------
    subsets = [s.strip() for s in args.set.split(",") if s.strip()]
    preproc = AudioPreprocessor(training=False)
    for subset in subsets:
        ds = LIBRISPEECH(args.root, url=subset, download=True)
        loader = DataLoader(
            ds,
            batch_size=args.batch_size,
            shuffle=False,
            collate_fn=collate_eval_factory(preproc),
            num_workers=args.num_workers,
        )

        wer, val_loss = evaluate(model, loader, device, loss_fn)
        print(f"\n── Results on {subset} ──")
        print(f"  • WER:       {wer:6.2f} %")
        print(f"  • CTC loss:  {val_loss:8.4f}")


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)