gpt2-from-scratch

My reimplementation of the GPT-2 architecture at the level of nn.Parameter. Contains trained weights plus code for the model architecture, inference and multi-GPU training.

The main purpose of this project was for me to better grasp all the details of the end-to-end process of training a system such as GPT-2. However, I hope it can also be of use to others, if only as an example implementation of that process.

Contents

  • train.py contains the entry point for training this GPT-2-like model on a subset of FineWeb-Edu, across all GPUs of the machine.
  • inference.ipynb contains code for loading and streaming generations from the model trained in train.py (or from the model I have pre-trained for you); a rough sketch of this follows this list.
  • outputs/tokenizer and outputs/model.pt contain the result of my training run on a sequence-packed (length-1024) encoding of the sample-10BT subset of FineWeb-Edu: just under 30k updates (1.5 epochs) at a batch size of 512.
  • utilities contains the model and optimizer architectures, as well as all the helper code for loading and preprocessing data and running the training.
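
To give a rough idea of what the inference notebook does, here is a minimal decoding sketch. It assumes that model.pt stores the whole module (so torch.load returns it directly), that outputs/tokenizer contains a Hugging Face tokenizers JSON file named tokenizer.json, and that the model maps a batch of token ids to logits; the real loading and streaming helpers live in inference.ipynb and utilities, so treat every name below as an assumption.

import torch
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file("outputs/tokenizer/tokenizer.json")  # file name is an assumption
model = torch.load("outputs/model.pt", map_location="cpu", weights_only=False)  # assumes the full module was pickled
model.eval()

ids = tokenizer.encode("The history of mathematics").ids
for _ in range(50):                                    # generate 50 tokens, greedily for simplicity
    x = torch.tensor(ids).unsqueeze(0)                 # shape (1, seq_len)
    with torch.no_grad():
        logits = model(x)                              # assumed shape (1, seq_len, vocab_size)
    ids.append(int(torch.argmax(logits[0, -1])))
    print(tokenizer.decode(ids[-1:]), end="", flush=True)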

Prerequisites

  • pixi. Install with: curl -fsSL https://pixi.sh/install.sh | sh
  • Git LFS. Install with:
    • Linux: sudo apt-get install git-lfs ; git lfs install
    • macOS: brew install git-lfs ; git lfs install

Setup

git clone https://github.com/JaHeRoth/gpt2-from-scratch.git
cd gpt2-from-scratch
pixi install
pixi shell
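
With the environment active (i.e. inside pixi shell), training can presumably be launched by running the entry point directly (this exact invocation is an assumption; check train.py for any arguments it expects):

python train.py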

Learnings

  • Exact network architectures of GPT and GPT-2 (down to the level of every individual nn.Parameter)
  • Inner workings of the AdamW optimizer (see the update-rule sketch after this list)
  • LLM sampling tricks, including implementing temperature and nucleus sampling (sketch after this list)
  • Sequence packing (sketch after this list)
  • Hugging Face tokenizers, datasets and the Hub
  • The PyTorch stack
  • GPU tricks (kernel fusion through torch.compile, optimizing tensor sizes)
  • Mixed-precision training (see the training-step sketch after this list)
  • Distributed training with DistributedDataParallel (see the training-step sketch after this list)
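
The AdamW bullet boils down to a short update rule. Below is a simplified standalone sketch of one AdamW step for a single tensor (no AMSGrad; hyperparameter names follow PyTorch's defaults); the optimizer in utilities is the full version of this idea, not necessarily this exact code.

import torch

def adamw_step(p, grad, m, v, t, lr=3e-4, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01):
    # Decoupled weight decay: shrink the weights directly rather than adding
    # weight_decay * p to the gradient (this is what separates AdamW from Adam + L2).
    p.mul_(1 - lr * weight_decay)
    # Exponential moving averages of the gradient and of its elementwise square.
    m.mul_(beta1).add_(grad, alpha=1 - beta1)
    v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    # Bias correction, since m and v start at zero.
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # The update: step along m_hat, scaled per-parameter by sqrt(v_hat).
    p.addcdiv_(m_hat, v_hat.sqrt().add_(eps), value=-lr)

p = torch.randn(4)
m, v = torch.zeros_like(p), torch.zeros_like(p)
adamw_step(p, grad=torch.randn(4), m=m, v=v, t=1)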
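
For the sampling bullet, here is a minimal sketch of temperature plus nucleus (top-p) sampling for a single next-token distribution; the actual implementation in this repo may differ in its details.

import torch

def sample_next_token(logits, temperature=0.8, top_p=0.9):
    # Temperature: dividing logits by a value < 1 sharpens the distribution, > 1 flattens it.
    probs = torch.softmax(logits / temperature, dim=-1)
    # Nucleus sampling: keep only the smallest set of tokens whose cumulative probability
    # reaches top_p, then renormalize and sample from that set.
    sorted_probs, sorted_ids = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    keep = cumulative - sorted_probs < top_p   # cumulative mass *before* each token
    keep[0] = True                             # always keep the most likely token
    nucleus = sorted_probs * keep
    nucleus = nucleus / nucleus.sum()
    idx = torch.multinomial(nucleus, num_samples=1)
    return int(sorted_ids[idx])

next_id = sample_next_token(torch.randn(50257))  # 50257 = GPT-2's vocabulary size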
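
Sequence packing itself is simple: concatenate all tokenized documents into one long stream (with an end-of-text token between documents) and slice that stream into fixed-length rows, so no compute is spent on padding. A toy sketch with made-up token ids:

import torch

def pack_sequences(docs, seq_len=1024, eos_id=50256):
    # docs: list of tokenized documents (each a list of token ids).
    stream = []
    for doc in docs:
        stream.extend(doc)
        stream.append(eos_id)                 # mark the document boundary
    n_rows = len(stream) // seq_len           # drop the incomplete tail
    return torch.tensor(stream[: n_rows * seq_len]).view(n_rows, seq_len)

rows = pack_sequences([[5, 6, 7], [8, 9], [10, 11, 12, 13]], seq_len=4)
print(rows.shape)  # torch.Size([3, 4]); every row is fully used, no padding tokens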
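
The mixed-precision and DDP bullets (plus torch.compile) come together in a training step roughly like the one below. This is a generic PyTorch skeleton of the pattern, with a toy model, fake data and made-up hyperparameters, not an excerpt from train.py; it would be launched with torchrun --nproc_per_node=<num_gpus>.

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    dist.init_process_group(backend="nccl")            # torchrun provides rank/world-size env vars
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    model = torch.nn.Linear(1024, 1024).to(device)     # stand-in for the GPT-2 model
    model = torch.compile(model)                        # let the compiler fuse kernels
    model = DDP(model, device_ids=[local_rank])         # all-reduce gradients across GPUs
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    scaler = torch.amp.GradScaler("cuda")               # needed for fp16 (bf16 would not need it)

    for step in range(100):                             # placeholder data loop
        x = torch.randn(8, 1024, device=device)
        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            loss = model(x).float().pow(2).mean()       # placeholder loss
        scaler.scale(loss).backward()                   # scale up to avoid fp16 gradient underflow
        scaler.step(optimizer)                          # unscale and apply the AdamW update
        scaler.update()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()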

Future directions

  • Optimizations from newer models: MoE, MLA, RoPE, YaRN, SwiGLU, QK norm
  • SFT, to turn the model into a chatbot
  • Reinforcement learning and a "thinking" stage (although I doubt the model is capable enough to benefit from chain of thought)
