diff --git a/records/track_non_record_16mb/2026-03-24_11L_XSA_SwiGLU_LoRATTT_1xH100/README.md b/records/track_non_record_16mb/2026-03-24_11L_XSA_SwiGLU_LoRATTT_1xH100/README.md new file mode 100644 index 000000000..b1e9dbdd7 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-24_11L_XSA_SwiGLU_LoRATTT_1xH100/README.md @@ -0,0 +1,101 @@ +# Non-record: 11L XSA + SwiGLU + LoRA TTT (1xH100 PCIe) + +**val_bpb: 1.1573** (LoRA TTT) | **15.02 MB** artifact | 1xH100 PCIe, ~80 min + +## Key Techniques + +1. **XSA (Cross-token Self-Attention)** on last 4 layers — removes self-value projection, forcing attention to contribute cross-position context. -0.005 BPB. +2. **SwiGLU 3x MLP** — gated activation `swish(gate(x)) * up(x)`. More parameter-efficient than ReLU², which it beats by 0.004 BPB. +3. **SmearGate** — blends each token embedding with the previous token's embedding, giving bigram context at the embedding layer. Critical: +1.80 loss without it. +4. **U-Net skip connections** — encoder (L0-L4) saves skip outputs, decoder (L6-L10) adds them back. Ensures gradient flow through all 11 layers. +5. **Orthogonal initialization** — all weight matrices initialized orthogonally. Required for SmearGate to work. +6. **Muon optimizer with WD=0.04** — decoupled weight decay shrinks weights for better quantization + generalization. +7. **Stochastic Weight Averaging** — averages 15 checkpoints during warmdown for smoother quantized weights. +8. **Mixed quantization** — int5 (MLP) + int6 (attention) + int8 (embeddings) + zstd-22 compression. Fits in 15.02 MB. +9. **LoRA TTT** — per-document test-time training with rank-8 LoRA on Q and V projections. Score-then-train per 256-token chunk (legal: every token scored before being trained on). -0.034 BPB. 
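The score-then-train ordering in item 9 can be sketched in a few lines. This is a toy illustration, not the submission's LoRA code: a Laplace-smoothed unigram counter stands in for the LoRA-adapted model, and the names `score_then_train`, `vocab_size`, and `chunk_len` are hypothetical. The point is only the ordering: each chunk is scored with state learned from earlier chunks, and the adapter is updated on that chunk afterwards.

```python
import math
from collections import Counter


def score_then_train(tokens, vocab_size, chunk_len=256):
    """Average bits per token for one document under score-then-train.

    Assumption for illustration: a Laplace-smoothed unigram counter plays
    the role of the per-document adapter (the submission uses rank-8 LoRA).
    """
    if not tokens:
        return 0.0
    counts = Counter()  # per-document adapter state, starts empty
    seen = 0            # number of tokens already trained on
    total_bits = 0.0
    for start in range(0, len(tokens), chunk_len):
        chunk = tokens[start:start + chunk_len]
        # 1) Score: probabilities depend only on earlier chunks.
        for tok in chunk:
            p = (counts[tok] + 1) / (seen + vocab_size)
            total_bits -= math.log2(p)
        # 2) Train: update the adapter only after the chunk is fully scored.
        counts.update(chunk)
        seen += len(chunk)
    return total_bits / len(tokens)


if __name__ == "__main__":
    doc = [0] * 512  # a trivially repetitive "document"
    print(score_then_train(doc, vocab_size=4))
```

Because the first chunk is scored before any update, it is charged at the unadapted rate; only later chunks benefit from adaptation, which is what keeps the evaluation legal.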
+ +## Results + +| Eval Method | val_loss | val_bpb | Δ val_bpb vs. pre-quant | +|-------------|----------|---------|-------| +| Pre-quant (SWA) | 1.9800 | 1.1727 | — | +| Int8+zlib roundtrip | 1.9969 | 1.1826 | +0.010 | +| Mixed quant (int5/int6/int8+zstd) | 1.9913 | 1.1930 | +0.020 | +| **LoRA TTT (mixed quant)** | **1.9724** | **1.1573** | **-0.015** | + +## Architecture + +``` +11L, 512d, 8H/4KV (GQA), SwiGLU 3x MLP +XSA on L7-L10, SmearGate, U-Net skips +OrthoInit, Muon WD=0.04, SWA (15 checkpoints) +Mixed quant: int5-MLP + int6-attn + int8-embed + zstd-22 +LoRA TTT: rank-8, Q+V, LR=0.05, score-then-train, 256-token chunks +``` + +## Training Configuration + +- **GPU**: 1xH100 PCIe (RunPod) — grad accumulation 8 steps to match 524K batch +- **Wallclock**: ~4850s (~80 min) — NOT a 10-min record submission +- **Batch**: 524,288 tokens/step (grad_accum=8 × seq_len=2048 × micro_batch=32) +- **Sequence length**: 2048 +- **Warmdown**: 3000 iterations +- **Steps completed**: 7,926 / 20,000 (wallclock cap) + +## Why Non-Record + +This ran on 1xH100 PCIe for ~80 minutes (not 8xH100 in 10 min). The architecture and training are identical to what would run on 8xH100 — only the batch parallelism differs. + +## Development Journey + +18 experiments over 5 days, from val_bpb=3.10 (wrong batch size) to 1.1573: + +| Experiment | val_bpb | What changed | +|-----------|---------|-------------| +| 1 (baseline) | 3.10 | Wrong batch size | +| 2 | 1.46 | Fixed batch to 65K | +| 6 | 1.312 | 1200s training, warmdown=600 | +| 10 | 1.283 | + SmearGate, OrthoInit, MLP 3x, WD | +| 13 | — | SwiGLU beats ReLU² by 0.004 | +| 14-8x | 1.202 | 11 layers + SWA on 8xH100 | +| 15 | 1.187 | + seq_len=2048 | +| 17 | 1.183 | + XSA (last 4 layers) | +| + Quant | 1.191 | int5+int6+int8+zstd (15 MB) | +| **+ LoRA TTT** | **1.157** | Per-document adaptation at eval | + +Total compute cost: ~$50 across all experiments. 
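The batch arithmetic under Training Configuration, and the claim that only the batch parallelism differs on 8xH100, can be checked with a short sketch. The helper `tokens_per_step` is hypothetical, and the 8-GPU layout (one micro-step per rank, no accumulation) is an assumption about how the same global batch would be split, not something the logs confirm. The step count and average step time are taken from `train.log`.

```python
def tokens_per_step(grad_accum, micro_batch_seqs, seq_len, world_size=1):
    """Global tokens consumed per optimizer step."""
    return grad_accum * micro_batch_seqs * seq_len * world_size


# This run: 1 GPU, 8 accumulation steps of 32 sequences x 2048 tokens.
single_gpu = tokens_per_step(grad_accum=8, micro_batch_seqs=32, seq_len=2048, world_size=1)

# Assumed 8-GPU layout: same per-rank micro-batch, no accumulation.
eight_gpu = tokens_per_step(grad_accum=1, micro_batch_seqs=32, seq_len=2048, world_size=8)

# Values reported in train.log for the final step.
steps_completed = 7926
step_avg_seconds = 0.61192

total_tokens = steps_completed * single_gpu
train_seconds = steps_completed * step_avg_seconds

if __name__ == "__main__":
    print(f"tokens/step (1 GPU): {single_gpu}")
    print(f"tokens/step (8 GPU, assumed layout): {eight_gpu}")
    print(f"tokens seen: {total_tokens}")
    print(f"train time: {train_seconds:.0f} s")
```

Both layouts give the same global batch, so the optimizer sees the same number of tokens per step; the single-GPU run simply takes eight sequential micro-steps to get there.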
+ +## What Didn't Work + +| Technique | Result | Why | +|-----------|--------|-----| +| Register token | +0.002 worse | Step overhead > marginal benefit | +| Layer looping + wider | +0.034 worse | Step time from wider dim | +| Data sampling (juncture) | +0.002 worse | Shard-level too coarse | +| Hard example mining | +0.040 worse | Destroys Muon weight geometry | +| Partial RoPE (16/64) | +0.015 worse | Head dim too small | +| EMA (replacing SWA) | +0.015 worse | Over-smoothed warmdown weights | +| BigramHash | 0.000 | SmearGate makes it redundant | +| SGD TTT | +0.018 worse | Modifying dequantized weights directly breaks them | + +## Command + +```bash +RUN_ID=exp17_xsa \ +MAX_WALLCLOCK_SECONDS=4850 \ +TRAIN_BATCH_TOKENS=524288 \ +WARMDOWN_ITERS=3000 \ +MUON_WD=0.04 \ +NUM_LAYERS=11 \ +TRAIN_SEQ_LEN=2048 \ +MLP_MULT=3 \ +MATRIX_LR=0.04 \ +SCALAR_LR=0.04 \ +python train_gpt.py +``` + +## Included Files + +- `train_gpt.py` — full training + quantization + LoRA TTT eval script +- `train.log` — training log from 1xH100 run +- `submission.json` — metadata diff --git a/records/track_non_record_16mb/2026-03-24_11L_XSA_SwiGLU_LoRATTT_1xH100/submission.json b/records/track_non_record_16mb/2026-03-24_11L_XSA_SwiGLU_LoRATTT_1xH100/submission.json new file mode 100644 index 000000000..e8c9913af --- /dev/null +++ b/records/track_non_record_16mb/2026-03-24_11L_XSA_SwiGLU_LoRATTT_1xH100/submission.json @@ -0,0 +1,18 @@ +{ + "name": "11L XSA + SwiGLU + SWA + Mixed Quant + LoRA TTT (1xH100 PCIe)", + "author": "swapp1990", + "github_id": "swapp1990", + "date": "2026-03-24", + "val_bpb": 1.1573, + "val_loss": 1.9724, + "pre_quant_val_bpb": 1.1727, + "pre_quant_val_loss": 1.9800, + "step_stop": 7926, + "wallclock_seconds": 4850, + "gpu": "1xH100 PCIe (RunPod)", + "track": "non-record-16mb", + "bytes_total": 15793319, + "bytes_model": 15727804, + "bytes_code": 65515, + "blurb": "11-layer transformer with XSA (cross-token self-attention on last 4 layers), SwiGLU 3x MLP, SmearGate, 
U-Net skip connections, orthogonal init, Muon optimizer with WD=0.04, and stochastic weight averaging. Trained on 1xH100 PCIe for ~80 min with batch=524K seq=2048 (grad accumulation). Mixed quantization (int5 MLP + int6 attn + int8 embed + zstd) fits in 15.02 MB. LoRA TTT (rank-8, score-then-train per chunk) brings val_bpb from 1.191 to 1.157. 18 experiments over 5 days." +} diff --git a/records/track_non_record_16mb/2026-03-24_11L_XSA_SwiGLU_LoRATTT_1xH100/train.log b/records/track_non_record_16mb/2026-03-24_11L_XSA_SwiGLU_LoRATTT_1xH100/train.log new file mode 100644 index 000000000..d0c055826 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-24_11L_XSA_SwiGLU_LoRATTT_1xH100/train.log @@ -0,0 +1,103 @@ +logs/06226eeb-9a28-46a0-ba37-47715b3f2521.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27092057 +world_size:1 grad_accum_steps:8 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04 +train_batch_tokens:524288 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:4850.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9277 val_bpb:4.1030 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9295 train_time:578ms step_avg:577.95ms +step:2/20000 train_loss:9.7430 train_time:1181ms step_avg:590.66ms +step:3/20000 
train_loss:8.6021 train_time:1787ms step_avg:595.71ms +step:4/20000 train_loss:8.7606 train_time:2796ms step_avg:699.05ms +step:5/20000 train_loss:8.3936 train_time:3421ms step_avg:684.26ms +step:6/20000 train_loss:8.2054 train_time:4025ms step_avg:670.80ms +step:7/20000 train_loss:7.7758 train_time:4625ms step_avg:660.69ms +step:8/20000 train_loss:7.2954 train_time:5294ms step_avg:661.75ms +step:9/20000 train_loss:6.8198 train_time:5924ms step_avg:658.17ms +step:10/20000 train_loss:6.4736 train_time:6506ms step_avg:650.65ms +step:200/20000 train_loss:2.7500 train_time:125492ms step_avg:627.46ms +step:400/20000 train_loss:2.3249 train_time:246819ms step_avg:617.05ms +step:600/20000 train_loss:2.4240 train_time:367212ms step_avg:612.02ms +step:800/20000 train_loss:2.2708 train_time:488724ms step_avg:610.91ms +step:1000/20000 train_loss:2.2825 train_time:613127ms step_avg:613.13ms +step:1000/20000 val_loss:2.2538 val_bpb:1.3348 train_time:613139ms step_avg:613.14ms +step:1200/20000 train_loss:2.2151 train_time:734714ms step_avg:612.26ms +step:1400/20000 train_loss:2.2454 train_time:856615ms step_avg:611.87ms +step:1600/20000 train_loss:2.1343 train_time:978149ms step_avg:611.34ms +step:1800/20000 train_loss:2.1791 train_time:1099034ms step_avg:610.57ms +step:2000/20000 train_loss:2.1387 train_time:1219142ms step_avg:609.57ms +step:2000/20000 val_loss:2.1475 val_bpb:1.2719 train_time:1219151ms step_avg:609.58ms +step:2200/20000 train_loss:2.0670 train_time:1342246ms step_avg:610.11ms +step:2400/20000 train_loss:2.1141 train_time:1462622ms step_avg:609.43ms +step:2600/20000 train_loss:2.1813 train_time:1584180ms step_avg:609.30ms +step:2800/20000 train_loss:2.1298 train_time:1704582ms step_avg:608.78ms +step:3000/20000 train_loss:2.0552 train_time:1825266ms step_avg:608.42ms +step:3000/20000 val_loss:2.1026 val_bpb:1.2453 train_time:1825424ms step_avg:608.47ms +step:3200/20000 train_loss:2.1302 train_time:1950560ms step_avg:609.55ms +step:3400/20000 train_loss:2.1032 
train_time:2072471ms step_avg:609.55ms +step:3600/20000 train_loss:2.0625 train_time:2196370ms step_avg:610.10ms +step:3800/20000 train_loss:2.1419 train_time:2319658ms step_avg:610.44ms +step:4000/20000 train_loss:2.0426 train_time:2442476ms step_avg:610.62ms +step:4000/20000 val_loss:2.0758 val_bpb:1.2294 train_time:2442641ms step_avg:610.66ms +step:4200/20000 train_loss:2.1144 train_time:2570120ms step_avg:611.93ms +step:4400/20000 train_loss:2.1164 train_time:2695921ms step_avg:612.71ms +step:4600/20000 train_loss:1.9531 train_time:2819014ms step_avg:612.83ms +step:4800/20000 train_loss:2.0661 train_time:2939820ms step_avg:612.46ms +step:5000/20000 train_loss:2.0568 train_time:3060813ms step_avg:612.16ms +SWA: started at step 5000 (warmdown_start~4922) +step:5000/20000 val_loss:2.0563 val_bpb:1.2179 train_time:3061106ms step_avg:612.22ms +step:5200/20000 train_loss:2.0657 train_time:3181771ms step_avg:611.88ms +step:5400/20000 train_loss:2.0772 train_time:3303276ms step_avg:611.72ms +step:5600/20000 train_loss:2.0720 train_time:3427130ms step_avg:611.99ms +step:5800/20000 train_loss:2.0417 train_time:3548669ms step_avg:611.84ms +step:6000/20000 train_loss:2.1672 train_time:3670464ms step_avg:611.74ms +step:6000/20000 val_loss:2.0271 val_bpb:1.2006 train_time:3670951ms step_avg:611.83ms +step:6200/20000 train_loss:2.0269 train_time:3794142ms step_avg:611.96ms +step:6400/20000 train_loss:2.0329 train_time:3914942ms step_avg:611.71ms +step:6600/20000 train_loss:1.9839 train_time:4035937ms step_avg:611.51ms +step:6800/20000 train_loss:2.1208 train_time:4158062ms step_avg:611.48ms +step:7000/20000 train_loss:2.0063 train_time:4279307ms step_avg:611.33ms +step:7000/20000 val_loss:1.9992 val_bpb:1.1840 train_time:4279531ms step_avg:611.36ms +step:7200/20000 train_loss:2.0013 train_time:4402050ms step_avg:611.40ms +step:7400/20000 train_loss:1.9981 train_time:4523744ms step_avg:611.32ms +step:7600/20000 train_loss:1.9361 train_time:4645429ms step_avg:611.24ms 
+step:7800/20000 train_loss:1.9874 train_time:4773436ms step_avg:611.98ms +step:7926/20000 val_loss:1.9800 val_bpb:1.1727 train_time:4850090ms step_avg:611.92ms +stopping_early: wallclock_cap train_time:4850090ms step:7926/20000 +peak memory allocated: 18476 MiB reserved: 18726 MiB +SWA: averaging 15 checkpoints +SWA: applied +Serialized model: 106317248 bytes +Code size: 57899 bytes +Total submission size: 106375147 bytes +Serialized model int8+zlib: 24655764 bytes (payload:27333986 raw_torch:27397078 payload_ratio:3.89x) +Total submission size int8+zlib: 24713663 bytes +/runpod-volume/parameter-golf/train_gpt.py:1319: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. 
+ quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") +final_int8_zlib_roundtrip val_loss:1.9969 val_bpb:1.1826 eval_time:18994ms +final_int8_zlib_roundtrip_exact val_loss:1.99685268 val_bpb:1.18264864 diff --git a/records/track_non_record_16mb/2026-03-24_11L_XSA_SwiGLU_LoRATTT_1xH100/train_gpt.py b/records/track_non_record_16mb/2026-03-24_11L_XSA_SwiGLU_LoRATTT_1xH100/train_gpt.py new file mode 100644 index 000000000..cc3497b1d --- /dev/null +++ b/records/track_non_record_16mb/2026-03-24_11L_XSA_SwiGLU_LoRATTT_1xH100/train_gpt.py @@ -0,0 +1,1473 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep these scripts readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Defaults for this submission: +# - 11 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 3x MLP expansion +# - vocab size 1024, sequence length 2048, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. 
+ data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 3)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. 
+ embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + wd = group.get("weight_decay", 0.0) + if wd > 0 and p.ndim >= 2: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bytes per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
+ +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. 
+ tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. 
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + # Matrices get one scale per row, which usually tracks output-channel + # ranges much better than a single tensor-wide scale. + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + # Vectors / scalars use a simpler per-tensor scale. 
+ clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + + +# --- Mixed quantization (int5-MLP, int6-attn, int8-embed) --- + +def quantize_int5_per_row(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / 15.0).clamp_min(1e-12).to(torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -16, 15).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / 15.0, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -16, 15).to(torch.int8) + return q, scale + +def quantize_int6_per_row(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1e-12).to(torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -32, 31).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / 31.0, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -32, 31).to(torch.int8) + return q, scale + +MLP_QUANT_PATTERNS = ("mlp.", "gate.", "up.") +EMBED_QUANT_PATTERNS = ("tok_emb",) + +def quantize_state_dict_mixed(state_dict: dict) -> dict: + result = {} + for name, t in state_dict.items(): + t_cpu = t.detach().cpu() + if t_cpu.numel() <= 896 or not t_cpu.is_floating_point(): + result[name] = t_cpu.to(torch.float16) if t_cpu.is_floating_point() else t_cpu + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS) or any(p in name for p in ("smear",)): + result[name] = t_cpu.to(torch.float16).contiguous() + continue + t32 = t_cpu.float() + if any(p in name for p in EMBED_QUANT_PATTERNS): 
+ # int8 for embeddings + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / 127.0).clamp_min(1e-12).to(torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -128, 127).to(torch.int8) + else: + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / 127.0, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -128, 127).to(torch.int8) + result[name + ".__q"] = q + result[name + ".__scale"] = scale + result[name + ".__dtype"] = str(t_cpu.dtype) + elif any(p in name for p in MLP_QUANT_PATTERNS): + q, scale = quantize_int5_per_row(t_cpu) + result[name + ".__q"] = q + result[name + ".__scale"] = scale + result[name + ".__dtype"] = str(t_cpu.dtype) + else: + q, scale = quantize_int6_per_row(t_cpu) + result[name + ".__q"] = q + result[name + ".__scale"] = scale + result[name + ".__dtype"] = str(t_cpu.dtype) + result["__quant_format__"] = "mixed_v1" + return result + +def dequantize_state_dict_mixed(quant_dict: dict) -> dict: + quant_dict.pop("__quant_format__", None) + result = {} + seen = set() + for key in list(quant_dict.keys()): + if key.endswith(".__q"): + name = key[:-4] + if name in seen: continue + seen.add(name) + q = quant_dict[name + ".__q"] + scale = quant_dict[name + ".__scale"] + dtype = getattr(torch, quant_dict[name + ".__dtype"].split(".")[-1]) + if q.ndim == 2 and scale.ndim == 1: + result[name] = (q.float() * scale.float()[:, None]).to(dtype) + else: + result[name] = (q.float() * scale.float()).to(dtype) + elif not any(key.endswith(s) for s in (".__scale", ".__dtype")): + result[key] = quant_dict[key] + return result + + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: 
dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = 
obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + use_xsa: bool = False, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + self.use_xsa = use_xsa + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + # Expand KV heads for GQA (compatible with PyTorch < 2.5 without enable_gqa) + if self.num_kv_heads != self.num_heads: + reps = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(reps, dim=1) + v = v.repeat_interleave(reps, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True) + # XSA: remove self-value projection from attention output + if self.use_xsa: + # v is already expanded for GQA above + dot_yv = (y * v).sum(dim=-1, keepdim=True) + dot_vv = (v * v).sum(dim=-1, keepdim=True).clamp_min(1e-8) + y = y - (dot_yv / dot_vv) * v + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # SwiGLU MLP: swish(gate(x)) * up(x), then project down + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(2 * mlp_mult * dim / 3) + hidden = ((hidden + 63) // 64) * 64 + self.gate = CastedLinear(dim, hidden, bias=False) + self.up = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + return self.proj(F.silu(self.gate(x)) * self.up(x)) + + + +class SmearGate(nn.Module): + """Blend each token embedding with the 
previous token's embedding.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + use_xsa: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, use_xsa=use_xsa) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.smear = SmearGate(model_dim) + 
self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + use_xsa=(i >= num_layers - 4), + ) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + import math + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + + # First half stores skips; second half reuses them in reverse order. 
+ for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + 
torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + 
log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + optimizer_tok = torch.optim.Adam( + [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + 
lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 
0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + 
stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + 
group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + + # SWA: collect weights every 200 steps during warmdown + _swa_every = 200 + if not hasattr(base_model, '_swa_state'): + base_model._swa_state = None + base_model._swa_count = 0 + # Estimate the total step count from the wallclock cap; with no cap (max_wallclock_ms is None), fall back to args.iterations. + _est_total = args.iterations if max_wallclock_ms is None else int(max_wallclock_ms / (approx_training_time_ms / max(step, 1))) + _warmdown_start = max(0, _est_total - args.warmdown_iters) + if step >= _warmdown_start and step % _swa_every == 0 and step > 100: + if base_model._swa_state is None: + base_model._swa_state = {n: p.detach().cpu().clone().float() for n, p in base_model.state_dict().items()} + base_model._swa_count = 1 + log0(f"SWA: started at step {step} (warmdown_start~{_warmdown_start})") + else: + for n, p in base_model.state_dict().items(): + base_model._swa_state[n] += p.detach().cpu().float() + base_model._swa_count += 1 + + # Needed to sync whether we've reached the wallclock cap. 
+ reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + + # Apply SWA averaged weights + if hasattr(base_model, '_swa_state') and base_model._swa_state is not None and base_model._swa_count > 1: + log0(f"SWA: averaging {base_model._swa_count} checkpoints") + for n, t in base_model._swa_state.items(): + avg = (t / base_model._swa_count).to(dtype=base_model.state_dict()[n].dtype) + base_model.state_dict()[n].copy_(avg) + log0("SWA: applied") + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = 
len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu", weights_only=False) + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Mixed quantization (int5-MLP + int6-attn + int8-embed + zstd-22) + code_bytes = len(code.encode("utf-8")) + try: + import zstandard + quant_mixed = quantize_state_dict_mixed(base_model.state_dict()) + mixed_buf = io.BytesIO() + torch.save(quant_mixed, mixed_buf) + mixed_raw = mixed_buf.getvalue() + mixed_blob = zstandard.ZstdCompressor(level=22).compress(mixed_raw) + with open("final_model.mixed.ptz", "wb") as f: + f.write(mixed_blob) + log0(f"Serialized model mixed int5/int6/int8+zstd: {len(mixed_blob)} bytes") + _fits = "FITS" if (len(mixed_blob) + code_bytes) < 16_000_000 else "OVER" + log0(f"Total mixed submission: {len(mixed_blob) + code_bytes} bytes ({_fits} 16MB)") + # Roundtrip validation + mixed_state = torch.load( + 
io.BytesIO(zstandard.ZstdDecompressor().decompress(mixed_blob)), + map_location="cpu", weights_only=False + ) + base_model.load_state_dict(dequantize_state_dict_mixed(mixed_state), strict=True) + torch.cuda.synchronize() + t_mqeval = time.perf_counter() + mq_val_loss, mq_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0(f"final_mixed_roundtrip val_loss:{mq_val_loss:.4f} val_bpb:{mq_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_mqeval):.0f}ms") + # --- LoRA TTT eval on mixed-quantized model (score-then-train, parallelized) --- + # Each document is processed in chunks. For each chunk: + # 1. Score the chunk (record NLL + bytes) + # 2. Train LoRA on that chunk (so later chunks benefit from adaptation) + # This is legal: every token is scored BEFORE being trained on. + ttt_enabled = int(os.environ.get("TTT_ENABLED", "1")) + ttt_max_docs = int(os.environ.get("TTT_MAX_DOCS", "0")) # 0 = all docs + ttt_min_doc_len = int(os.environ.get("TTT_MIN_DOC_LEN", "32")) # skip tiny docs for TTT + if ttt_enabled: + log0("Running LoRA TTT eval (score-then-train per chunk, parallelized)...") + ttt_rank_r = int(os.environ.get("TTT_RANK", "8")) + ttt_lr = float(os.environ.get("TTT_LR", "0.05")) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", "256")) + + # LoRA adapter + class LoRALinear(nn.Module): + def __init__(self, original, lora_rank=8): + super().__init__() + self.original = original + in_d = original.weight.shape[1] + out_d = original.weight.shape[0] + self.lora_A = nn.Parameter(torch.randn(lora_rank, in_d, device=device) * 0.01) + self.lora_B = nn.Parameter(torch.randn(out_d, lora_rank, device=device) * 0.001) + self.scale = 1.0 / lora_rank + for p in self.original.parameters(): + p.requires_grad = False + def forward(self, x): + base = F.linear(x, self.original.weight.to(x.dtype), + 
self.original.bias.to(x.dtype) if self.original.bias is not None else None)
+                return base + (x @ self.lora_A.to(x.dtype).T @ self.lora_B.to(x.dtype).T) * self.scale
+            def reset(self):
+                nn.init.normal_(self.lora_A, std=0.01)
+                nn.init.normal_(self.lora_B, std=0.001)
+
+        # Inject LoRA into Q and V of all layers
+        for p in base_model.parameters():
+            p.requires_grad = False
+        lora_modules = []
+        for blk in base_model.blocks:
+            lq = LoRALinear(blk.attn.c_q, lora_rank=ttt_rank_r)
+            blk.attn.c_q = lq
+            lora_modules.append(lq)
+            lv = LoRALinear(blk.attn.c_v, lora_rank=ttt_rank_r)
+            blk.attn.c_v = lv
+            lora_modules.append(lv)
+        lora_params = []
+        for m in lora_modules:
+            lora_params.extend([m.lora_A, m.lora_B])
+        log0(f"TTT: {len(lora_modules)} LoRA modules, {sum(p.numel() for p in lora_params)} params, chunk_size={ttt_chunk_size}")
+
+        # Find document boundaries (BOS=1)
+        bos_positions = (val_tokens == 1).nonzero(as_tuple=True)[0].cpu().numpy()
+        n_all_docs = len(bos_positions)
+        if ttt_max_docs > 0:
+            n_all_docs = min(n_all_docs, ttt_max_docs)
+
+        # Build doc list with (start, length)
+        doc_list = []
+        for d in range(n_all_docs):
+            ds = int(bos_positions[d])
+            de = int(bos_positions[d + 1]) if d + 1 < len(bos_positions) else val_tokens.numel()
+            dl = de - ds
+            if dl >= 5:
+                doc_list.append((ds, dl))
+        n_ttt_docs = len(doc_list)
+
+        # Separate short and long docs
+        short_docs = [(ds, dl) for ds, dl in doc_list if dl < ttt_min_doc_len]
+        long_docs = [(ds, dl) for ds, dl in doc_list if dl >= ttt_min_doc_len]
+
+        # Shard docs across GPUs
+        my_short = short_docs[(len(short_docs) * rank) // world_size : (len(short_docs) * (rank + 1)) // world_size]
+        my_long = long_docs[(len(long_docs) * rank) // world_size : (len(long_docs) * (rank + 1)) // world_size]
+        log0(f"TTT: {n_ttt_docs} docs ({len(short_docs)} short, {len(long_docs)} long), "
+             f"rank {rank}: {len(my_short)} short + {len(my_long)} long")
+
+        ttt_nll = torch.zeros((), device=device, dtype=torch.float64)
+        ttt_bytes = torch.zeros((), device=device, dtype=torch.float64)
+        ttt_tokens = torch.zeros((), device=device, dtype=torch.float64)
+        t_ttt = time.perf_counter()
+
+        # Short docs: score without TTT adaptation (not enough context to benefit)
+        base_model.eval()
+        with torch.no_grad():
+            for ds, dl in my_short:
+                x = val_tokens[ds:ds + dl - 1].to(device=device, dtype=torch.int64).unsqueeze(0)
+                y = val_tokens[ds + 1:ds + dl].to(device=device, dtype=torch.int64).unsqueeze(0)
+                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+                    loss = base_model(x, y)
+                n = dl - 1
+                ttt_nll += loss.to(torch.float64) * n
+                ttt_tokens += n
+                prev_ids = x.reshape(-1)
+                tgt_ids = y.reshape(-1)
+                tb = base_bytes_lut[tgt_ids].to(torch.float64)
+                tb += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.float64)
+                ttt_bytes += tb.sum()
+
+        # Long docs: chunk-based score-then-train
+        # For each chunk: score it, then train LoRA on it.
+        # Every token is scored BEFORE being trained on (legal).
+        # LoRA adaptation from prior chunks provides document context.
+        for di, (ds, dl) in enumerate(my_long):
+            # Reset LoRA for each new document
+            for m in lora_modules:
+                m.reset()
+            optimizer = torch.optim.Adam(lora_params, lr=ttt_lr)
+
+            pred_len = dl - 1  # number of next-token predictions
+            for chunk_start in range(0, pred_len, ttt_chunk_size):
+                chunk_end = min(chunk_start + ttt_chunk_size, pred_len)
+                chunk_len = chunk_end - chunk_start
+                if chunk_len < 2:
+                    continue
+
+                x = val_tokens[ds + chunk_start:ds + chunk_end].to(device=device, dtype=torch.int64).unsqueeze(0)
+                y = val_tokens[ds + chunk_start + 1:ds + chunk_end + 1].to(device=device, dtype=torch.int64).unsqueeze(0)
+
+                is_last_chunk = (chunk_end >= pred_len)
+
+                # Forward pass: need grads only if we'll train after
+                if is_last_chunk:
+                    base_model.eval()
+                    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+                        loss = base_model(x, y)
+                else:
+                    base_model.train()
+                    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+                        loss = base_model(x, y)
+
+                # 1. Score: accumulate NLL + bytes for this chunk
+                with torch.no_grad():
+                    ttt_nll += loss.to(torch.float64) * chunk_len
+                    ttt_tokens += chunk_len
+                    prev_ids = x.reshape(-1)
+                    tgt_ids = y.reshape(-1)
+                    tb = base_bytes_lut[tgt_ids].to(torch.float64)
+                    tb += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.float64)
+                    ttt_bytes += tb.sum()
+
+                # 2. Train LoRA on this chunk (so subsequent chunks benefit)
+                if not is_last_chunk:
+                    optimizer.zero_grad()
+                    loss.backward()
+                    optimizer.step()
+
+            if (di + 1) % 1000 == 0:
+                elapsed = time.perf_counter() - t_ttt
+                running_bpb = (ttt_nll.item() / math.log(2.0)) / max(ttt_bytes.item(), 1.0)
+                log0(f"  TTT rank {rank}: {di+1}/{len(my_long)} long docs, bpb={running_bpb:.4f} elapsed={elapsed:.0f}s")
+
+        # All-reduce across GPUs
+        if distributed:
+            dist.all_reduce(ttt_nll, op=dist.ReduceOp.SUM)
+            dist.all_reduce(ttt_bytes, op=dist.ReduceOp.SUM)
+            dist.all_reduce(ttt_tokens, op=dist.ReduceOp.SUM)
+        ttt_bpb = (ttt_nll.item() / math.log(2.0)) / max(ttt_bytes.item(), 1.0)
+        ttt_loss = ttt_nll.item() / max(ttt_tokens.item(), 1.0)
+        log0(f"final_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} "
+             f"docs:{n_ttt_docs} eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms")
+
+    except ImportError:
+        log0("WARNING: zstandard not installed, skipping mixed quantization")
+    except Exception as e:
+        log0(f"WARNING: mixed quantization failed: {e}")
+        import traceback; traceback.print_exc()
+
+
+    if distributed:
+        dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    main()