nanoXLSTM is a minimal codebase for playing around with language models based on the xLSTM (extended Long Short-Term Memory) architecture from the awesome research paper "xLSTM: Extended Long Short-Term Memory", heavily inspired by Andrej Karpathy's nanoGPT.
**Note: Work in progress!!!** I am working on improving the generated text.
No lofty goals here - just a simple codebase for tinkering with this innovative xLSTM technology!
Contributions are more than welcome as I continue exploring this exciting research direction.
```sh
pip install torch numpy transformers datasets tiktoken wandb tqdm
```
```sh
python data/shakespeare_char/prepare.py
python train.py config/train_shakespeare_char.py
python sample.py --out_dir=out-shakespeare-char
```
- Run hyperparameter sweep
- Import OneCycleLR: The OneCycleLR scheduler is imported from `torch.optim.lr_scheduler`.
- `sLSTM` class: `f_bias` and `dropout` are added to the `sLSTM` class.
- `mLSTM` class: `f_bias` and `dropout` are added to the `mLSTM` class.
- `xLSTMBlock` class: The `xLSTMBlock` class is implemented with a configurable ratio of `sLSTM` and `mLSTM` blocks, and layer normalization is applied.
- `GPT` class: The `xLSTM_blocks` are used in the `GPT` class instead of separate `sLSTM` and `mLSTM` blocks.
- `configure_optimizers` method: The `configure_optimizers` method in the `GPT` class is updated to use the AdamW optimizer and the OneCycleLR scheduler (a minimal sketch follows this list).
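As a rough illustration of the last point, here is a minimal sketch of an AdamW + OneCycleLR setup. The decay/no-decay parameter split, argument names, and `total_steps` plumbing are assumptions for the example, not necessarily the exact code in this repo:

```python
import torch
from torch.optim.lr_scheduler import OneCycleLR

def configure_optimizers(model, weight_decay, learning_rate, betas, total_steps):
    # Weight-decay only the 2D matmul weights; biases and norm params are left undecayed.
    decay = [p for p in model.parameters() if p.requires_grad and p.dim() >= 2]
    no_decay = [p for p in model.parameters() if p.requires_grad and p.dim() < 2]
    optim_groups = [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
    optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
    # OneCycleLR ramps the learning rate up and back down over the whole run,
    # so it needs the total number of optimizer steps and must be stepped once per step.
    scheduler = OneCycleLR(optimizer, max_lr=learning_rate, total_steps=total_steps)
    return optimizer, scheduler

# Example usage with a placeholder module standing in for the GPT model:
optimizer, scheduler = configure_optimizers(
    torch.nn.Linear(64, 64), weight_decay=0.1, learning_rate=6e-4,
    betas=(0.9, 0.95), total_steps=1000,
)
```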
20/05/24
- Initialize the forget gate bias (`self.f_bias`) with values between 3 and 6 instead of ones. This helps the forget gate be effective from the beginning of training.
- Introduce a stabilization technique to avoid overflow from the exponential function: use the max function to compute a stabilization factor and subtract it from the input-gate and forget-gate pre-activations before applying the exponential (both ideas are sketched below).
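A minimal sketch of both ideas, loosely following the running-max stabilization described in the xLSTM paper; tensor names (`i_pre`, `f_pre`, `m_prev`) and sizes are illustrative, not the repo's actual code:

```python
import torch
import torch.nn as nn

hidden_size = 128  # illustrative size

# Forget-gate bias drawn uniformly from [3, 6] so the gate starts strongly "open"
# (a large positive pre-activation) instead of sitting near zero at initialization.
f_bias = nn.Parameter(torch.empty(hidden_size).uniform_(3.0, 6.0))

def stabilized_exp_gates(i_pre, f_pre, m_prev):
    # Exponential gate activations overflow easily, so subtract a stabilizer m
    # (the max of the competing log-space terms) before calling exp().
    m = torch.maximum(f_pre + m_prev, i_pre)
    i_gate = torch.exp(i_pre - m)
    f_gate = torch.exp(f_pre + m_prev - m)
    return i_gate, f_gate, m

# Example usage with random pre-activations and an initial stabilizer of zero:
i_pre = torch.randn(2, hidden_size)
f_pre = torch.randn(2, hidden_size) + f_bias   # bias pushes the forget gate toward "remember"
m_prev = torch.zeros(2, hidden_size)
i_gate, f_gate, m = stabilized_exp_gates(i_pre, f_pre, m_prev)
```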
- Import statement: The OneCycleLR scheduler is imported.
- Optimizer and scheduler initialization: The optimizer and scheduler are obtained from the `configure_optimizers` method of the `GPT` class.
- Loading optimizer and scheduler state: The optimizer and scheduler states are loaded from the checkpoint when resuming training.
- Saving scheduler state: The scheduler state is included in the checkpoint dictionary.
- Stepping the scheduler: `scheduler.step()` is called after each optimizer step.
- Logging learning rate and MFU: The learning rate and MFU are logged using `wandb` (if `wandb_log` is enabled).
- `estimate_loss` function: The `estimate_loss` function is updated to use the `ctx` context manager.
- Training loop: The training loop is updated to use `scaler.scale(loss).backward()` and `scaler.step(optimizer)` for gradient scaling when training in fp16 (see the sketch after this list).
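Putting several of these pieces together, here is a compressed, self-contained sketch of an fp16-style training step with gradient scaling, per-step scheduler updates, and a checkpoint that carries the scheduler state. The placeholder model, batch, and hyperparameters are stand-ins, not the repo's actual training script:

```python
import torch
import torch.nn as nn
from contextlib import nullcontext

# Illustrative stand-ins: the real script builds the model, batches, and config elsewhere.
device = "cuda" if torch.cuda.is_available() else "cpu"
use_fp16 = device == "cuda"
ctx = torch.autocast(device_type="cuda", dtype=torch.float16) if use_fp16 else nullcontext()
model = nn.Linear(64, 64).to(device)   # placeholder for the xLSTM-based GPT model
optimizer = torch.optim.AdamW(model.parameters(), lr=6e-4)
max_iters = 100
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=6e-4, total_steps=max_iters)
scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)  # no-op passthrough when fp16 is disabled

for iter_num in range(max_iters):
    x = torch.randn(8, 64, device=device)    # placeholder batch
    with ctx:                                # autocast context (or nullcontext on CPU)
        loss = model(x).pow(2).mean()        # placeholder loss
    scaler.scale(loss).backward()            # scaled backward pass for fp16
    scaler.step(optimizer)                   # unscales gradients, then runs optimizer.step()
    scaler.update()
    scheduler.step()                         # OneCycleLR advances once per optimizer step
    optimizer.zero_grad(set_to_none=True)
    current_lr = scheduler.get_last_lr()[0]  # what would be logged to wandb alongside loss/MFU

# The checkpoint carries the scheduler state so resuming restores the LR schedule.
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
    "iter_num": max_iters,
}
torch.save(checkpoint, "ckpt.pt")
```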

