
Commit 6474aac

Merge pull request #169 from Alexey234432/conformer_training
2 parents e5dcc9d + 50a3574

5 files changed: +973 -0 lines changed
Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
# Conformer-S Model Training

This repository provides an example of training the **Conformer-S** model on the **LibriSpeech** dataset.

## External dependencies
- **Model**: the Conformer-S implementation from https://github.com/sooftware/conformer
- **Dataset**: LibriSpeech (downloaded via `torchaudio`), used both to build the tokenizer and to train the Conformer model
- **Tokenizer**: generated with https://github.com/google/sentencepiece/
- **Python dependencies**: packages listed in **requirements.txt**

## Environment description
- AWS g5.24xlarge instance
- Python 3.12.7
- AWS AMI: Deep Learning OSS Nvidia Driver AMI GPU PyTorch (Ubuntu 22.04)

## Setup
1) Make sure the Conformer repository is cloned into the same directory as the training script:
```bash
git clone https://github.com/sooftware/conformer.git
```
2) Generate the SentencePiece tokenizer
- More information on what the SentencePiece tokenizer is and how to use it can be found at https://github.com/google/sentencepiece?tab=readme-ov-file#overview
- Generate the tokenizer using the following command:
```bash
python build_sp_128_librispeech.py \
  --root ./data \
  --subset train-clean-100 \
  --output_dir ./tokenizer_out \
  --vocab_size 128 \
  --model_type unigram \
  --lowercase \
  --disable_bos_eos \
  --pad_id -1
```
- Pass the tokenizer path to the training script via the `--sp-model` argument (a quick sanity check of the generated tokenizer is sketched after this list)
3) Create an empty `data` folder in the same directory as the training script
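
Once the tokenizer is generated, a quick sanity check is to load it and round-trip a transcript. This is a minimal sketch assuming the default output paths from the command above (`tokenizer_out/librispeech_sp.model`):

```python
import sentencepiece as spm

# Load the tokenizer produced above (default prefix: tokenizer_out/librispeech_sp).
sp = spm.SentencePieceProcessor(model_file="tokenizer_out/librispeech_sp.model")
print(sp.get_piece_size())                      # expected: 128

# The tokenizer was trained with --lowercase, so encode lowercase text.
ids = sp.encode("nothing was to be done but to wait", out_type=int)
print(ids)
print(sp.decode(ids))                           # should round-trip to the input text
```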

## Training
Run the following command to start training:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py \
  --train-sets "train-clean-100,train-clean-360,train-other-500" \
  --valid-set "dev-clean" \
  --epochs 160 \
  --batch-size 96 \
  --lr=0.0005 \
  --betas 0.9,0.98 \
  --weight-decay 1e-6 \
  --warmup-epochs 2.0 \
  --grad-clip 5 \
  --root "data" \
  --save-dir "checkpoints" \
  --num-workers=32 \
  --accum-steps 16 \
  2>&1 | tee train_log.txt
```
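
The `--accum-steps 16` flag implies gradient accumulation: gradients from several mini-batches are summed before each optimizer step, so with `--batch-size 96` each optimizer step aggregates roughly 96 × 16 = 1536 utterances (how the batch size interacts with the 4 GPUs depends on `train.py`). The sketch below shows the generic PyTorch pattern with a dummy model and random data; it is illustrative only, not `train.py`'s actual loop:

```python
import torch
import torch.nn as nn

# Illustrative gradient-accumulation pattern implied by --accum-steps.
# A dummy model and random batches stand in for the Conformer and DataLoader.
accum_steps = 16
model = nn.Linear(10, 4)                        # stand-in for the Conformer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4,
                              betas=(0.9, 0.98), weight_decay=1e-6)

optimizer.zero_grad()
for step in range(64):                          # stand-in for iterating the DataLoader
    x, y = torch.randn(8, 10), torch.randn(8, 4)
    loss = nn.functional.mse_loss(model(x), y)  # placeholder for the CTC loss on a batch
    (loss / accum_steps).backward()             # scale so the accumulated gradient averages
    if (step + 1) % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)  # matches --grad-clip 5
        optimizer.step()
        optimizer.zero_grad()
```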
## Notes and recommendations
- Hyperparameter tuning and active monitoring (“model babysitting”) are strongly recommended to achieve optimal performance
- Training should reach a WER in the 6-7% range on the test-clean set
- Checkpoints will be saved under the `checkpoints/` directory (see the reload sketch after this list)
- Logs are written to `train_log.txt` for convenience
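
A minimal sketch for reloading a saved checkpoint, following the same tolerant-loading convention the evaluation script in this PR uses. The exact checkpoint keys and filenames depend on `train.py`, so `checkpoints/best.pt` below is a hypothetical name:

```python
import torch

# Hypothetical checkpoint name; the real files are whatever train.py writes under checkpoints/.
ckpt = torch.load("checkpoints/best.pt", map_location="cpu")
state = ckpt.get("model", ckpt)                # wrapper dict vs. raw state_dict
# model.load_state_dict(state, strict=False)   # tolerates DataParallel "module." prefixes
```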
Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@
#!/usr/bin/env python3
#
# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#

import argparse
import os
import sys
import tempfile
from pathlib import Path

import torchaudio
from torchaudio.datasets import LIBRISPEECH

import sentencepiece as spm


def normalize_text(text: str, lowercase: bool = False) -> str:
    t = text.strip()
    if lowercase:
        t = t.lower()
    return t


def build_corpus(root: str, subset: str, lowercase: bool, limit: int | None) -> str:
    dataset = LIBRISPEECH(root=root, url=subset, download=True)
    n = len(dataset)
    if limit is not None:
        n = min(n, limit)

    tmp_fd, tmp_path = tempfile.mkstemp(prefix="librispeech_corpus_", suffix=".txt")
    os.close(tmp_fd)

    with open(tmp_path, "w", encoding="utf-8") as f:
        for idx in range(n):
            try:
                _, _, transcript, *_ = dataset[idx]
            except Exception as ex:
                print(f"Warning: failed to read sample {idx}: {ex}", file=sys.stderr)
                continue
            line = normalize_text(transcript, lowercase)
            if line:
                f.write(line + "\n")
    return tmp_path

def train_sentencepiece(
    corpus_path: str,
    output_dir: str,
    vocab_size: int,
    model_type: str,
    character_coverage: float,
    model_prefix: str,
    pad_id: int,
    disable_bos_eos: bool,
    seed_sentencepiece: int | None,
    input_sentence_size: int,
):
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    model_prefix_path = str(Path(output_dir) / model_prefix)

    sp_args = [
        f"--input={corpus_path}",
        f"--model_prefix={model_prefix_path}",
        f"--vocab_size={vocab_size}",
        f"--model_type={model_type}",
        f"--character_coverage={character_coverage}",
        "--unk_id=0",
        f"--input_sentence_size={input_sentence_size}",
        "--shuffle_input_sentence=true",
        "--hard_vocab_limit=true",
        "--num_threads=32",
    ]

    if disable_bos_eos:
        sp_args += ["--bos_id=-1", "--eos_id=-1"]
    else:
        sp_args += ["--bos_id=1", "--eos_id=2"]

    if pad_id is None or pad_id < 0:
        sp_args += ["--pad_id=-1"]
    else:
        sp_args += [f"--pad_id={pad_id}"]

    if seed_sentencepiece is not None:
        sp_args += [f"--seed_sentencepiece_size={seed_sentencepiece}"]

    spm.SentencePieceTrainer.Train(" ".join(sp_args))

    model_path = model_prefix_path + ".model"
    vocab_path = model_prefix_path + ".vocab"
    return model_path, vocab_path

def main():
    parser = argparse.ArgumentParser(description="Train a 128-token SentencePiece tokenizer on LibriSpeech using torchaudio.")
    parser.add_argument("--root", type=str, default="./data", help="Directory to store/lookup LibriSpeech.")
    parser.add_argument("--subset", type=str, default="train-clean-100",
                        choices=[
                            "train-clean-100", "train-clean-360", "train-other-500",
                            "dev-clean", "dev-other", "test-clean", "test-other"
                        ],
                        help="LibriSpeech subset to use.")
    parser.add_argument("--output_dir", type=str, default="./tokenizer_out", help="Where to write the tokenizer files.")
    parser.add_argument("--vocab_size", type=int, default=128, help="Total vocabulary size.")
    parser.add_argument("--model_type", type=str, default="unigram", choices=["unigram", "bpe"],
                        help="SentencePiece model type. 'unigram' is a good default for English transcripts.")
    parser.add_argument("--character_coverage", type=float, default=1.0,
                        help="Fraction of characters covered by the model (1.0 is fine for English).")
    parser.add_argument("--model_prefix", type=str, default="librispeech_sp",
                        help="Prefix (filename stem) for the trained model.")
    parser.add_argument("--lowercase", action="store_true", help="Lowercase transcripts before training.")
    parser.add_argument("--pad_id", type=int, default=-1, help="Pad ID; set to -1 to disable (default).")
    parser.add_argument("--disable_bos_eos", action="store_true", help="Disable BOS/EOS (recommended).")
    parser.add_argument("--enable_bos_eos", action="store_true", help="Enable BOS/EOS tokens.")
    parser.add_argument("--limit", type=int, default=None, help="Limit the number of samples for a quick run.")
    parser.add_argument("--seed_sentencepiece_size", type=int, default=None,
                        help="Advanced: initial number of seed sentencepieces (optional).")
    parser.add_argument("--input_sentence_size", type=int, default=1000000,
                        help="Number of sentences to sample during training.")

    args = parser.parse_args()

    # BOS/EOS are disabled by default; --disable_bos_eos takes precedence over --enable_bos_eos.
    disable_bos_eos = True
    if args.enable_bos_eos:
        disable_bos_eos = False
    if args.disable_bos_eos:
        disable_bos_eos = True

    print(f"Preparing corpus from LibriSpeech subset='{args.subset}'...")
    corpus_path = build_corpus(root=args.root, subset=args.subset, lowercase=args.lowercase, limit=args.limit)

    print(f"Training SentencePiece... corpus_path {corpus_path}")
    model_path, vocab_path = train_sentencepiece(
        corpus_path=corpus_path,
        output_dir=args.output_dir,
        vocab_size=args.vocab_size,
        model_type=args.model_type,
        character_coverage=args.character_coverage,
        model_prefix=args.model_prefix,
        pad_id=args.pad_id,
        disable_bos_eos=disable_bos_eos,
        seed_sentencepiece=args.seed_sentencepiece_size,
        input_sentence_size=args.input_sentence_size,
    )

    print("Done!")
    print(f"Model : {model_path}")
    print(f"Vocab : {vocab_path}")


if __name__ == "__main__":
    main()
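
A brief verification of the trained tokenizer (a sketch assuming the default `tokenizer_out/librispeech_sp` prefix used in the README): the special-token flags above should be reflected in the model, and the acoustic model adds one extra output for the CTC blank, mirroring `vocab_size = sp.get_piece_size() + 1` in the evaluation script below.

```python
import sentencepiece as spm

# Check the trained tokenizer (paths assume the defaults used in the README).
sp = spm.SentencePieceProcessor(model_file="tokenizer_out/librispeech_sp.model")
assert sp.get_piece_size() == 128                  # --vocab_size 128
assert sp.unk_id() == 0                            # --unk_id=0
assert sp.bos_id() == -1 and sp.eos_id() == -1     # --disable_bos_eos
assert sp.pad_id() == -1                           # --pad_id -1

# One extra class for the CTC blank, as in the evaluation script below.
print("CTC output dimension:", sp.get_piece_size() + 1)
```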
Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
#!/usr/bin/env python3
#
# SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#
from __future__ import annotations
import argparse
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchaudio.datasets import LIBRISPEECH
import sentencepiece as spm

# ---------------------------------------------------------------------
# Import the components we need from the original training script
# ---------------------------------------------------------------------
import train  # assumes train.py is in the same directory
from train import (
    create_model,
    AudioPreprocessor,
    collate_eval_factory,
    evaluate,
)

# ---------------------------------------------------------------------
# Argument parsing
# ---------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(description="Conformer evaluation only")
    p.add_argument("--root", type=str, default="/shared/LIBRISPEECH",
                   help="LibriSpeech root directory")
    p.add_argument("--set", type=str, default="test-clean",
                   help="Comma-separated LibriSpeech subset names (e.g. test-clean,test-other)")
    p.add_argument("--batch-size", type=int, default=64,
                   help="Batch size for evaluation")
    p.add_argument("--num-workers", type=int, default=4,
                   help="DataLoader worker processes")
    p.add_argument("--sp-model", type=str,
                   default="tokenizer_out/librispeech_sp.model",
                   help="Path to the SentencePiece *.model used at training time")
    p.add_argument("--checkpoint", type=str, required=True,
                   help="Path to a trained checkpoint (*.pt)")
    return p.parse_args()

# ---------------------------------------------------------------------
50+
# Main evaluation routine
51+
# ---------------------------------------------------------------------
52+
def main() -> None:
53+
args = parse_args()
54+
55+
# -------- SentencePiece ----------
56+
sp = spm.SentencePieceProcessor()
57+
sp.load(args.sp_model)
58+
59+
# expose the tokenizer inside the imported `train` module so that
60+
# train.int_to_text() works correctly during decoding
61+
train.sp = sp
62+
63+
vocab_size = sp.get_piece_size() + 1 # +1 for CTC blank
64+
65+
# -------- Model ----------
66+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
67+
model = create_model(vocab_size).to(device)
68+
69+
ckpt_path = Path(args.checkpoint)
70+
ckpt = torch.load(ckpt_path, map_location=device)
71+
state = ckpt.get("model", ckpt) # handles raw state_dict vs wrapper
72+
73+
# tolerate DataParallel prefix differences
74+
missing, unexpected = model.load_state_dict(state, strict=False)
75+
if missing:
76+
print(f"[warn] missing keys in checkpoint: {missing}")
77+
if unexpected:
78+
print(f"[warn] unexpected keys in checkpoint: {unexpected}")
79+
80+
print(f"=> loaded weights from {ckpt_path}")
81+
82+
if torch.cuda.device_count() > 1:
83+
print(f"Using {torch.cuda.device_count()} GPUs via DataParallel …")
84+
model = torch.nn.DataParallel(model)
85+
86+
loss_fn = nn.CTCLoss(blank=0, zero_infinity=True)
87+
88+
# -------- Datasets & evaluation ----------
89+
subsets = [s.strip() for s in args.set.split(",") if s.strip()]
90+
preproc = AudioPreprocessor(training=False)
91+
for subset in subsets:
92+
ds = LIBRISPEECH(args.root, url=subset, download=True)
93+
loader = DataLoader(
94+
ds,
95+
batch_size=args.batch_size,
96+
shuffle=False,
97+
collate_fn=collate_eval_factory(preproc),
98+
num_workers=args.num_workers,
99+
)
100+
101+
wer, val_loss = evaluate(model, loader, device, loss_fn)
102+
print(f"\n── Results on {subset} ──")
103+
print(f" • WER: {wer:6.2f} %")
104+
print(f" • CTC loss: {val_loss:8.4f}")
105+
106+
if __name__ == "__main__":
107+
main()
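
For reference, the sketch below shows the standard greedy CTC decoding step that produces hypotheses for WER scoring (collapse consecutive repeats, then drop the blank class). It is illustrative only: the actual decoding lives in `train.evaluate()`, and the exact mapping between CTC output indices and SentencePiece ids is defined by `train.py`.

```python
import torch

def greedy_ctc_ids(logits: torch.Tensor, blank: int = 0) -> list[list[int]]:
    """logits: (batch, time, vocab) model outputs. Collapse repeats and drop blanks."""
    best = logits.argmax(dim=-1)              # most likely class per frame
    hyps = []
    for frames in best.tolist():
        ids, prev = [], None
        for tok in frames:
            if tok != prev and tok != blank:
                ids.append(tok)
            prev = tok
        hyps.append(ids)
    return hyps

# Demo with random scores; 129 = 128 SentencePiece pieces + 1 CTC blank.
print(greedy_ctc_ids(torch.randn(2, 50, 129)))
# The resulting id sequences would then be detokenized (e.g. with the SentencePiece
# model) and compared against the reference transcripts to compute WER.
```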
