This project designs a new audio segmentation model based on a two-layer bidirectional LSTM. The model takes the Whisper Encoder's output as input and emits a signal value (alpha) for each frame.
By predicting these alphas, the model is able to estimate the number of tokens within an audio segment, generate timestamps, and guide the segmentation of audio chunks. These predictions not only support the Whisper Decoder during decoding but also enable more efficient streaming speech recognition.
The training of this model follows the CIF method (Continuous Integrate-and-Fire), which has been successfully applied in ASR models such as Paraformer.
The general process is as follows:
- Pass the audio through the encoder to obtain the encoder output.
- Transform each frame’s feature into a scalar and apply a sigmoid function, yielding an alpha value between 0 and 1.
- Continuously accumulate alpha values until they reach a predefined threshold, then trigger a “fire.”
- When firing, reset the accumulator to zero. The encoder features of the frames whose alphas sum to 1 are combined through weighted averaging to produce a new frame as the output.
- The new frame is then sent into the decoder to produce a token.
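The integrate-and-fire loop above can be sketched in a few lines (a NumPy toy, not Paraformer's or this project's actual implementation; the threshold of 1.0 and the alpha split at the firing boundary follow the standard CIF formulation):

```python
import numpy as np

def cif(features, alphas, threshold=1.0):
    """Integrate alphas frame by frame; each time the accumulator reaches
    the threshold, emit one output frame as the alpha-weighted sum of the
    contributing encoder features (the weights sum to 1, so it is an average)."""
    acc = 0.0                                # alpha accumulator
    frame = np.zeros(features.shape[1])      # weighted-feature accumulator
    outputs = []
    for h, a in zip(features, alphas):
        if acc + a < threshold:              # keep integrating
            acc += a
            frame = frame + a * h
        else:                                # fire: split alpha at the boundary
            a1 = threshold - acc             # part that completes this token
            outputs.append(frame + a1 * h)
            acc = a - a1                     # remainder starts the next token
            frame = acc * h
    return np.stack(outputs) if outputs else np.empty((0, features.shape[1]))

# 6 frames whose alphas sum to 2.0 -> two fired output frames
feats = np.ones((6, 4))
alphas = np.array([0.4, 0.6, 0.3, 0.3, 0.2, 0.2])
print(cif(feats, alphas).shape)  # (2, 4)
```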
In previous implementations, CIF typically functioned as part of a larger model, meaning training was applied to the full model rather than CIF itself. In contrast, this project uses Whisper’s encoder and decoder with frozen parameters, and specifically trains a small model for signal prediction.
The output of this model is the per-frame signal value (alpha), which serves several purposes:
- **Token Count Prediction**: By summing the alphas, we can estimate the token count of the current audio segment. This allows us to design a decoder-side penalty function that guides decoding and mitigates hallucinations such as repeated or erroneous tokens.
- **Truncation Detection**: The sum of the alphas indicates whether the current segment is truncated. For instance, if the sum equals 2.5, the segment ends partway through the third token. If the decoder produces more than two tokens, the extra token is likely truncated and erroneous, and should be discarded.
- **Timestamp Generation**: By integrating the alpha values over time, we can compute the timestamp of each token.
- **Guiding Next-Segment Splitting**: The timestamp of the last complete token can serve as the starting point for the next audio segment, preventing truncation at the beginning of new segments and reducing decoding errors.
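All four uses reduce to simple operations on the alpha sequence. A hedged sketch (hard thresholding on the cumulative sum; `analyze_alphas` and its decision rules are illustrative, not the project's exact code):

```python
import numpy as np

def analyze_alphas(alphas):
    """Derive token count, truncation flag, and token end frames from
    per-frame alphas. Token u is taken to end at the first frame where
    the cumulative alpha reaches u (a hard version of the soft rule
    used in training)."""
    total = float(alphas.sum())
    n_complete = int(total)                  # fully contained tokens
    truncated = (total - n_complete) > 1e-3  # partial token at the end?
    cum = np.cumsum(alphas)
    ends = [int(np.argmax(cum >= u)) for u in range(1, n_complete + 1)]
    next_start = ends[-1] + 1 if ends else 0  # frame to resume from
    return n_complete, truncated, ends, next_start

# Sum is 2.5: two complete tokens plus a truncated third
alphas = np.array([0.5, 0.5, 0.2, 0.8, 0.3, 0.2])
print(analyze_alphas(alphas))  # (2, True, [1, 3], 4)
```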
In Whisper’s Audio Encoder module, the input audio is required to be exactly 30 seconds long. Any shorter input raises an error. Since this project focuses on streaming recognition, I set the segment length to 1 second.
Padding a 1-second segment to 30 seconds wastes computation, so I slightly modified Whisper’s Audio Encoder by removing the assert statement enforcing the 30-second input. This allows direct processing of 1-second audio chunks.
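Beyond deleting the assert, the positional embedding must also be sliced to the actual input length. A hedged NumPy sketch of the idea (shapes chosen for the base model; not Whisper's verbatim code):

```python
import numpy as np

MAX_FRAMES = 1500   # Whisper's fixed 30 s context after the conv stack
D_MODEL = 512       # base-model width, for illustration

pos_emb = np.random.randn(MAX_FRAMES, D_MODEL)

def add_positional(x):
    """x: (T, D) post-conv features for a possibly short segment.
    The original forward pass asserts T == MAX_FRAMES; slicing the
    positional embedding lets shorter inputs through."""
    T = x.shape[0]
    # assert x.shape == pos_emb.shape   <- the check removed for short inputs
    return x + pos_emb[:T]

one_second = np.zeros((50, D_MODEL))  # 1 s -> 50 encoder frames at 0.02 s
print(add_positional(one_second).shape)  # (50, 512)
```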
Both LibriSpeech (English) and AIShell-1 (Chinese) datasets lack timestamp annotations, so I generated them myself.
- For LibriSpeech, I used WhisperX.
- For AIShell-1, I used the `fa_zh` model from FunASR.
Timestamps were converted into label sequences. Given Whisper’s frame length of 0.02s, each timestamp was mapped to a frame index.
For each token, frames between its start and end indices were labeled with signal values, while all other frames were set to 0 (blank). The signal value for each token's frames was defined as:

$$\alpha^{\mathrm{label}}_{t} = \frac{1}{n}, \qquad s \le t \le e,$$

where $s$ and $e$ are the token's start and end frame indices and $n = e - s + 1$ is the number of frames the token spans, so that each token's labels sum to 1, matching the CIF firing threshold.
These alpha labels are hypothetical, serving only as training signals to help the model estimate token counts and timestamps. They are not direct ground truth values.
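Label construction can be sketched as follows (`make_labels` is a hypothetical helper; the per-frame value 1/n assumes each token's labels should sum to 1, consistent with the CIF firing threshold):

```python
import numpy as np

FRAME_LEN = 0.02  # Whisper encoder frame length in seconds

def make_labels(timestamps, num_frames):
    """timestamps: list of (start_sec, end_sec), one pair per token.
    Frames inside a token share the value 1/n so each token's labels
    sum to 1; all remaining frames stay 0 (blank)."""
    labels = np.zeros(num_frames)
    for start, end in timestamps:
        s = int(round(start / FRAME_LEN))   # start frame index
        e = int(round(end / FRAME_LEN))     # end frame index
        n = max(e - s + 1, 1)               # frames covered by this token
        labels[s:e + 1] = 1.0 / n
    return labels

labels = make_labels([(0.00, 0.06), (0.10, 0.14)], num_frames=10)
print(labels.sum())  # one unit of alpha mass per token, ~2.0
```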
The model input consists of Whisper Encoder outputs from 1-second audio segments. Steps:
- Traverse the timestamp list and extract a 1-second segment every 5 tokens, aligned with token boundaries. This reduces overlap between segments and ensures any padding falls at the beginning.
- Pad shorter segments to 1 second.
- Encode the segments using Whisper Encoder to obtain feature representations.
- Extract corresponding labels and pad them if needed.
This process yielded 100,000 audio segments for training.
The loss function plays a crucial role in this model, as it significantly impacts performance. I designed a composite loss consisting of three components: Token Count Loss, Timestamp Regression Loss, and Blank Loss. The overall formula is:

$$\mathcal{L} = \lambda_{\mathrm{count}}\,\mathcal{L}_{\mathrm{count}} + \lambda_{\mathrm{time}}\,\mathcal{L}_{\mathrm{time}} + \lambda_{\mathrm{blank}}\,\mathcal{L}_{\mathrm{blank}}$$

where:
- $\lambda_{\mathrm{count}}$, $\lambda_{\mathrm{time}}$, and $\lambda_{\mathrm{blank}}$ are the weighting coefficients for Token Count Loss, Timestamp Regression Loss, and Blank Loss, respectively.
To enable the model to predict the number of tokens in an audio segment, I designed a Token Count Loss:

$$\mathcal{L}_{\mathrm{count}} = \frac{1}{B}\sum_{i=1}^{B}\left(C_i - c_i\right)^{2}$$

where:
- $B$ is the batch size;
- $C_i = \sum_{t=1}^{T} \alpha_{i,t}$ is the predicted total token count for the $i$-th sample;
- $c_i$ is the ground-truth token count (`true_counts`) for the $i$-th sample.
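A minimal sketch of the Token Count Loss, assuming a mean-squared error between the summed alphas and the true counts (NumPy for brevity; the real training code would operate on torch tensors):

```python
import numpy as np

def token_count_loss(alphas, true_counts):
    """alphas: (B, T) predicted per-frame signals; true_counts: (B,).
    Penalizes the gap between the summed alphas and the true count."""
    pred_counts = alphas.sum(axis=1)        # C_i = sum_t alpha_{i,t}
    return float(np.mean((pred_counts - true_counts) ** 2))

alphas = np.array([[0.5, 0.5, 1.0],
                   [0.25, 0.25, 0.0]])
print(token_count_loss(alphas, np.array([2.0, 1.0])))  # (0^2 + 0.5^2) / 2 = 0.125
```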
To train the model to predict the timestamps of each token within a segment, I designed a Timestamp Regression Loss:

$$\mathcal{L}_{\mathrm{time}} = \frac{1}{\lvert\{\,i : c_i > 0\,\}\rvert}\sum_{i:\,c_i > 0}\frac{1}{c_i}\sum_{u=1}^{c_i}\left(\hat\tau_{i,u} - \tau_{i,u}\right)^{2}$$

where:
- $\{\,i : c_i > 0\,\}$ denotes the subset of samples with token counts greater than 0;
- $\hat\tau_{i}$ is the predicted frame-index list for the $i$-th sample, with each element $\hat\tau_{i,u}$ computed as a weighted average of frame positions:
  $$\hat\tau_{i,u} = \sum_{t=1}^{T} w_{i,u,t}\, t$$
- $w_{i,u,t}$ is the attention weight, defined as a softmax over the negative offset penalty:
  $$w_{i,u,t} = \frac{\exp\left(-f_{i,u,t}\right)}{\sum_{t'=1}^{T}\exp\left(-f_{i,u,t'}\right)}$$
- $f_{i,u,t}$ is the offset penalty, measuring how far the integrated signal at frame $t$ is from the $u$-th firing threshold:
  $$f_{i,u,t} = \left(A_{i,t} - u\right)^{2}$$
- $A_{i,t} = \sum_{k=1}^{t} \alpha_{i,k}$ is the temporal integration of $\alpha$;
- $u$ ranges from $1$ to $c_i$, where $c_i$ is the ground-truth token count for the $i$-th sample;
- $\tau_{i} = (\tau_{i,1}, \dots, \tau_{i,c_i})$ is the ground-truth frame-index list for the $i$-th sample.
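A sketch of the Timestamp Regression Loss built from these definitions, taking the offset penalty as the squared distance between the integrated alpha and the token index and the attention weights as a softmax over its negative (these specific forms are my assumption; NumPy for brevity):

```python
import numpy as np

def timestamp_loss(alphas, tau, counts):
    """alphas: (B, T); tau: per-sample lists of ground-truth frame indices;
    counts: (B,) ground-truth token counts. A soft-argmin of the offset
    penalty f = (A_t - u)^2 gives a differentiable frame index per token."""
    losses = []
    for a, t_true, c in zip(alphas, tau, counts):
        if c == 0:
            continue                              # skip token-free samples
        A = np.cumsum(a)                          # temporal integration of alpha
        for u in range(1, c + 1):
            f = (A - u) ** 2                      # offset penalty per frame
            w = np.exp(-f) / np.exp(-f).sum()     # softmax attention weights
            tau_hat = (w * np.arange(len(a))).sum()  # soft frame index
            losses.append((tau_hat - t_true[u - 1]) ** 2)
    return float(np.mean(losses)) if losses else 0.0

# One sample, two tokens with true end frames 1 and 3
alphas = np.array([[0.2, 0.8, 0.1, 0.9, 0.0]])
print(timestamp_loss(alphas, [[1, 3]], [2]))
```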
To prevent the model from producing excessive signal values in blank frames (e.g., silence or background noise), I designed a Blank Loss:

$$\mathcal{L}_{\mathrm{blank}} = \frac{\sum_{i=1}^{B}\sum_{t=1}^{T} m_{i,t}\,\alpha_{i,t}^{2}}{\sum_{i=1}^{B}\sum_{t=1}^{T} m_{i,t}}$$

where:
- $m_{i,t}$ is the blank-frame mask, defined as
  $$m_{i,t} = \begin{cases} 1, & l_{i,t} = 0 \\ 0, & \text{otherwise} \end{cases}$$
- $l_{i,t}$ is the label in the label sequence, with $0$ indicating a blank frame.
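A sketch of the Blank Loss as a masked mean of squared alphas over blank frames (the squared form is an assumption; NumPy for brevity):

```python
import numpy as np

def blank_loss(alphas, labels):
    """alphas, labels: (B, T). Mean squared alpha over frames whose
    label is 0 (blank); non-blank frames are masked out."""
    mask = (labels == 0).astype(float)    # m_{i,t} = 1 on blank frames
    denom = mask.sum()
    return float((mask * alphas ** 2).sum() / denom) if denom else 0.0

alphas = np.array([[0.5, 0.0, 0.2, 0.1]])
labels = np.array([[0.5, 0.0, 0.0, 0.0]])
print(blank_loss(alphas, labels))  # (0 + 0.04 + 0.01) / 3 ≈ 0.0167
```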
I trained on both LibriSpeech (English) and AIShell-1 (Chinese) datasets, using Whisper’s base, small, medium, and large-v3 models. Optimizer: AdamW, learning rate: 1e-4, batch size: 8, epochs: 20.
Key results (base model shown as example):
- Chinese dataset: Validation loss = 0.7213 (Count Loss: 0.0968, Time Loss: 0.4646, Blank Loss: 0.0632).
- English dataset: Validation loss = 2.8246 (sub-losses not recorded).
Performance was notably better on Chinese, likely because it is a monosyllabic language. In English, multi-syllabic words such as “everyone” may be misinterpreted as multiple tokens (“every” + “one”), an issue absent in Chinese.
To further evaluate, I implemented a local demo using my own recordings. Example timestamps (values are encoder frame indices, 0.02 s per frame):

| Start | End | Token |
|---|---|---|
| 59.5 | 70.5 | 大 |
| 70.5 | 82.5 | 家 |
| 83.5 | 95.5 | 好 |
| 127.5 | 139.5 | 我 |
| 139.5 | 151.5 | 叫 |
| 151.5 | 162.5 | 小 |
| 162.5 | 174.5 | 明 |
| 201.5 | 213.5 | 很 |
| 213.5 | 224.5 | 高 |
| 224.5 | 232.5 | 兴 |
| 232.5 | 242.5 | 认 |
| 242.5 | 254.5 | 识 |
| 254.5 | 275.7 | 你 |
The demo processes the recording segment by segment:
- Extract the first 1 s of audio and pass it through the Whisper Encoder and the signal model.
- Use the predicted token boundary to cut the next segment.
- Repeat until the full recording is processed.
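The segment loop can be sketched as follows, with a stand-in `alphas_fn` replacing the real encoder + signal model (the function names and the toy signal are hypothetical):

```python
import numpy as np

SEG_FRAMES = 50  # 1 s of encoder frames at 0.02 s per frame

def stream_segments(alphas_fn, total_frames):
    """Walk the audio 1 s at a time; each new segment restarts right
    after the end of the last *complete* token, so no token is split."""
    starts = []
    pos = 0
    while pos + SEG_FRAMES <= total_frames:
        starts.append(pos)
        alphas = alphas_fn(pos, pos + SEG_FRAMES)  # per-frame signals
        cum = np.cumsum(alphas)
        n_complete = int(cum[-1])                  # whole tokens only
        if n_complete == 0:
            pos += SEG_FRAMES                      # nothing fired: move on
        else:
            last_end = int(np.argmax(cum >= n_complete))
            pos = pos + last_end + 1               # resume after last token
    return starts

# Toy signal model: one token ending at frame 30 of every 50-frame window
toy = lambda s, e: np.eye(1, e - s, 30).ravel()
print(stream_segments(toy, 200))  # [0, 31, 62, 93, 124]
```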
The model's predicted token end frames:

```
[69.49354553222656, 81.75992584228516, 92.08222961425781, 138.9051971435547,
 148.89248657226562, 160.82579040527344, 176.3285369873047, 214.15577697753906,
 225.21575927734375, 233.57644653320312, 243.09828186035156, 254.3844757080078,
 273.79620361328125]
```
Predicted timestamps aligned closely with reference values, with only 2–3 frames of deviation, which is acceptable given timestamp variability.
- On blank segments (background noise > 1 s), the model output near-zero alphas, showing robustness to silence:

```
tensor([3.5944e-06, 2.8404e-07, 1.7056e-07, 1.7256e-07, 2.0879e-07, 2.6019e-07, 3.2543e-07, 3.5060e-07, 3.5138e-07, 3.0055e-07, 2.8155e-07, 2.5605e-07, 2.1286e-07, 2.0457e-07, 1.9383e-07, 1.8396e-07, 1.8302e-07, 1.7400e-07, 1.6818e-07, 1.7000e-07, 1.6519e-07, 1.4737e-07, 1.6438e-07, 1.5200e-07, 1.4527e-07, 1.3863e-07, 1.4256e-07, 1.4058e-07, 1.3799e-07, 1.2596e-07, 1.3253e-07, 1.3737e-07, 1.5898e-07, 1.7688e-07, 1.9847e-07, 2.2142e-07, 2.4273e-07, 3.0331e-07, 4.1480e-07, 5.2232e-07, 4.3415e-07, 5.9241e-07, 5.9167e-07, 6.0761e-07, 4.8981e-07, 4.7571e-07, 4.4415e-07, 3.5544e-07, 5.9094e-07, 8.9407e-06], device='cuda:0', grad_fn=<SelectBackward0>)
```

- On token-to-token gaps, alphas also remained near zero until a token appeared, further confirming robustness:

```
tensor([1.4861e-06, 1.4406e-07, 8.3902e-08, 7.3880e-08, 7.4446e-08, 7.3631e-08, 8.0129e-08, 8.8093e-08, 9.5247e-08, 1.0435e-07, 1.2186e-07, 1.2187e-07, 1.3772e-07, 1.6681e-07, 1.7682e-07, 1.3715e-07, 1.3125e-07, 9.6394e-08, 1.0572e-07, 8.6677e-08, 7.0555e-08, 5.6862e-08, 4.7397e-08, 4.9379e-08, 5.1056e-07, 1.3614e-05, 2.3531e-05, 1.4020e-04, 2.1676e-02, 4.8366e-01, 4.7016e-01, 4.5352e-03, 4.5237e-04, 1.7187e-04, 1.0282e-04, 5.6864e-05, 2.5481e-05, 1.6732e-05, 2.0137e-05, 2.2847e-05, 3.2090e-05, 8.0079e-05, 2.1820e-04, 9.8521e-04, 1.5554e-02, 1.3876e-01, 2.6434e-01, 3.0096e-01, 2.4608e-01, 1.1959e-01], device='cuda:0', grad_fn=<SelectBackward0>)
```