Port of NanoChat (Andrej Karpathy's GPT-2 speedrun) to train on a single Colab Pro+ TPU v6e (32 GB HBM) using raw JAX (no Flax/Orbax).
Model: 168M params (n_head=8, n_embd=1024, depth=16, head_dim=128, seq_len=2048)
- Architecture: RoPE, QK-norm, ReLU² MLP, sliding window attention (SSSL pattern), logit softcap (15.0), x0 residual connections
- `n_head` is the primary scaling knob — `n_embd = n_head × head_dim`, `depth = n_embd / aspect_ratio`
- No value embeddings (removed — param cost outweighed benefit at this scale)
- Separate wte (embed) + lm_head (unembed) — weight tying was tried but regressed loss due to init scale × softcap interaction
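Two of the pieces above, the ReLU² MLP and the 15.0 logit softcap, are simple enough to sketch directly. This is an illustrative sketch, not the repo's code; the function names and plain-matmul weights are assumptions.

```python
import jax
import jax.numpy as jnp

def relu2_mlp(x, w_fc, w_proj):
    # ReLU² activation: ReLU followed by squaring, used in place of GELU.
    h = jnp.square(jax.nn.relu(x @ w_fc))
    return h @ w_proj

def softcap_logits(logits, cap=15.0):
    # Logit softcap: smoothly bounds logits to (-cap, cap) via tanh,
    # which keeps the loss well-behaved early in training.
    return cap * jnp.tanh(logits / cap)
```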
Training: 1K steps, device_batch_size=4, ~65K tok/s on v6e, val loss ~5.9 at step 1000
- AdamW optimizer (lr=3e-4, warmup 2%, warmdown 50%)
- Data: 50 shards from FineWeb-Edu-100B, tokenized on-the-fly with custom BPE (vocab 32768)
- Profiling: XProf annotations (`jax.named_scope`) on all model components, TensorBoard integration
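The warmup/warmdown numbers above suggest a trapezoidal learning-rate schedule. Here is a sketch under that assumption: linear warmup over the first 2% of steps, a flat plateau, then linear warmdown to zero over the last 50%. The function name and the linear shape are assumptions, not taken from the repo.

```python
import jax.numpy as jnp

def lr_schedule(step, base_lr=3e-4, total_steps=1000,
                warmup_frac=0.02, warmdown_frac=0.50):
    # Trapezoid: ramp up over warmup_steps, hold, ramp down to 0 at the end.
    warmup_steps = int(total_steps * warmup_frac)
    warmdown_steps = int(total_steps * warmdown_frac)
    warmup = jnp.minimum(1.0, (step + 1) / warmup_steps)
    warmdown = jnp.minimum(1.0, (total_steps - step) / warmdown_steps)
    return base_lr * jnp.minimum(warmup, warmdown)
```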
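`jax.named_scope` is a standard JAX context manager that attaches names to the ops traced inside it, so components show up as labeled groups in XProf/TensorBoard profiles. A minimal usage sketch (the function and parameter names here are illustrative, not the repo's):

```python
import jax
import jax.numpy as jnp

def attention_block(x, params):
    # Each named_scope becomes a labeled region in the profiler trace.
    with jax.named_scope("attn"):
        with jax.named_scope("qkv_proj"):
            q = x @ params["wq"]
            k = x @ params["wk"]
            v = x @ params["wv"]
        with jax.named_scope("scores"):
            scores = jax.nn.softmax(q @ k.T / jnp.sqrt(q.shape[-1]))
        return scores @ v
```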
Attention: 4 switchable implementations via `Config(attn_impl=...)`:
- `'einsum'` — manual QK^T/softmax/AV, supports sliding window (default)
- `'jax'` — `jax.nn.dot_product_attention`, supports sliding window via mask
- `'splash'` — Pallas splash kernel (used in MaxText/Gemma), supports sliding window
- `'pallas'` — Pallas flash attention, causal only
MXU utilization: ~14.5% at device_batch_size=4 — model is memory-bandwidth-bound, not compute-bound at this scale. Larger batch sizes help but hit HBM limits (batch_size=6 uses 28.8 GiB / 32 GiB).
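The `'einsum'` path can be sketched as follows. This is a minimal single-batch illustration assuming `(n_head, seq_len, head_dim)` tensors and a causal sliding-window mask; it is not the repo's implementation.

```python
import jax
import jax.numpy as jnp

def sliding_window_mask(seq_len, window):
    # Causal mask restricted to the last `window` positions.
    i = jnp.arange(seq_len)[:, None]
    j = jnp.arange(seq_len)[None, :]
    return (j <= i) & (i - j < window)

def einsum_attention(q, k, v, window):
    # q, k, v: (n_head, seq_len, head_dim). Manual QK^T / softmax / AV.
    seq_len = q.shape[1]
    scores = jnp.einsum("htd,hsd->hts", q, k) / (q.shape[-1] ** 0.5)
    scores = jnp.where(sliding_window_mask(seq_len, window), scores, -jnp.inf)
    probs = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("hts,hsd->htd", probs, v)
```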
- `02_train.py` — Main training notebook (jupytext percent format)
- `03_worker.py` — Hyperparameter sweep worker (wandb sweeps)
- `01_tokenizer.py` — Tokenizer training notebook
- `LOG.md` — Chronological development log (`Roman:` / `Agent:` entries)
- `update_notebooks.sh` — Converts `.py` → `.ipynb` via jupytext
- `nanochat/` — Reference nanochat repo (not committed, .gitignored)
- Edit the `.py` file (e.g. `02_train.py`) — this is the source of truth
- Run `bash update_notebooks.sh` to regenerate `.ipynb` files
- `git add -A && git commit -m "..." && git push`
- User re-opens the notebook in Colab from GitHub to pick up changes
- LOG.md — append entries after significant work. Prefix with `Agent:`. Roman prefixes his with `Roman:`.
- Config is frozen — all model/training params live in the `Config` dataclass. `Config` is registered as a JAX static type, so changes trigger recompilation.
- Weights use explicit head dims — QKV weights are `(n_embd, n_head, head_dim)`, c_proj is `(n_head, head_dim, n_embd)`. Einsums produce multi-head shapes directly, no reshapes.
- split_trainable/merge_params — RoPE parameters (rope_cos, rope_sin) are separated as non-trainable. This pattern is incompatible with `donate_argnums`.
- PrefetchDataLoader — background thread does `jax.device_put` to overlap host→device transfer with compute.
- Param counting — reported as embed / attn / mlp / lm_head separately for analysis.
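The split_trainable/merge_params pattern can be sketched as below. This sketch assumes a flat dict of params; the repo's actual param tree is likely nested, so treat the shapes and names as illustrative.

```python
def split_trainable(params, non_trainable_keys=("rope_cos", "rope_sin")):
    # Separate non-trainable buffers (the RoPE tables) from trainable weights,
    # so the optimizer only ever sees the trainable subset.
    trainable = {k: v for k, v in params.items() if k not in non_trainable_keys}
    frozen = {k: v for k, v in params.items() if k in non_trainable_keys}
    return trainable, frozen

def merge_params(trainable, frozen):
    # Recombine into one tree before the model forward pass.
    return {**trainable, **frozen}
```

Because the merged tree is rebuilt from two inputs on every step, the original buffers stay referenced, which is why this pattern conflicts with `donate_argnums`.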
Notebooks are stored as .py files in jupytext percent format for readable diffs. The corresponding .ipynb files are also committed so you can open them directly in Colab from GitHub.
| Notebook | Open in Colab | Description |
|---|---|---|
| `01_tokenizer.py` | Open | Train BPE tokenizer (vocab 32768), upload to HuggingFace Hub |
| `02_train.py` | Open | Pretrain GPT model in raw JAX on single TPU |
| `03_worker.py` | Open | wandb hyperparameter sweep worker |
| `04_maxtext.py` | Open | MaxText-inspired ~370M model (SwiGLU, 256-aligned dims) |