diff --git a/.gitignore b/.gitignore index 3423c416a..9260888ec 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,6 @@ data/manifest.json data/docs_selected.jsonl .mypy_cache/ .venv -logs/ \ No newline at end of file +logs/ +final_model.* +sweep.sh \ No newline at end of file diff --git a/records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/README.md b/records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/README.md new file mode 100644 index 000000000..7fdbb7475 --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/README.md @@ -0,0 +1,58 @@ +## Depth Recurrence + Cross-Repeat Skip + Value Embeddings + +Beats naive baseline (1.2244) by 0.005 bpb using 3.1x fewer training steps through stateful depth recurrence. + +val_bpb = 1.2196 (sliding window eval on int8+zlib roundtrip model, stride=256) +val_bpb = 1.2533 (standard int8+zlib roundtrip) + +### Architecture + +Replaced the baseline's 9 unique transformer blocks with 3 shared blocks repeated 4 times (12 effective layers). Trades unique parameters for effective depth. + +Changes from baseline: +- Depth recurrence: 3 blocks x 4 repeats = 12 effective layers (vs 9 in baseline) +- Cross-Repeat Skip (original): each block gets a weighted residual of its own output from the previous repeat, turning stateless recurrence into stateful. Per-repeat learned scales, ~7.5K params total. +- Value Embeddings: 2 extra embedding tables mixed into the residual stream at each effective layer with learned scales. From snimu's modded-nanogpt record. +- Loop Embedding: learned per-layer vector added before each block as depth-wise positional encoding. +- Model dim 832 (vs 512), 8 heads, 4 KV heads, MLP 2x +- Removed U-Net skip connections (Cross-Repeat Skip covers this role) +- 17.14M params, 12.83MB artifact + +### Training + +LR x0.3 from baseline — recurrence amplifies gradients through 4 passes, so optimal LR is much lower. Found via sweep of 10 configs on RTX 3060. + +MATRIX_LR=0.012, SCALAR_LR=0.012, TIED_EMBED_LR=0.015, GRAD_CLIP_NORM=0.3, WARMDOWN_ITERS=3000, TRAIN_SEQ_LEN=1024. + +Tested train@2048 but 1024 gives more steps (133ms vs 253ms/step) which matters more for this architecture. Standard Muon + Adam. + +### Evaluation + +Sliding window eval: window=1024, stride=256 on the int8+zlib roundtrip model. Eval time 209s on 8xH100. + +### Results (8xH100, 600s wallclock) + +4494 steps, 133ms/step avg. Pre-quant 1.2487, roundtrip 1.2533, sliding window 1.2196. Artifact 12.83MB, quant degradation 0.005 bpb, peak memory ~29GB/GPU. + +### Ablations (RTX 3060, 2000 steps each) + +- Cross-Repeat Skip: -0.041 bpb +- Value Embeddings (2 tables): -0.079 bpb +- LR x0.3: -0.052 bpb +- Sliding window eval: -0.034 bpb +- WARMDOWN_ITERS=3000: -0.027 bpb + +### Development + +All experiments, ablations, and hyperparameter sweeps done on a single RTX 3060 12GB. Cloud GPUs (1xH200, 6xH100) used only for validation. Final run on 8xH100. + +### Command + +``` +RUN_ID=submission_8xh100 \ +QUANT_LEVELS=127 \ +TTT_STEPS=0 \ +EVAL_STRIDE=256 \ +EVAL_SEQ_LEN=1024 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` diff --git a/records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/submission.json b/records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/submission.json new file mode 100644 index 000000000..f04f129d1 --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/submission.json @@ -0,0 +1,16 @@ +{ + "author": "Ivan Verbovoy", + "github_id": "iverbovoy", + "name": "Depth Recurrence + Cross-Repeat Skip + Value Embeddings + Sliding Window", + "blurb": "3 unique blocks x 4 repeats (12 effective layers), dim=832, with Cross-Repeat Skip (stateful recurrence), 2 Value Embedding tables, LR x0.3, sliding window eval (stride=256). 4494 steps in 600s on 8xH100.", + "date": "2026-03-20T02:00:00Z", + "val_loss": 2.05921204, + "val_bpb": 1.21958209, + "roundtrip_val_loss": 2.11612232, + "roundtrip_val_bpb": 1.25328684, + "step_stop": 4494, + "wallclock_seconds": 600.133, + "bytes_total": 12829176, + "bytes_model_int8_zlib": 12771121, + "bytes_code": 58055 +} diff --git a/records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/train.log b/records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/train.log new file mode 100644 index 000000000..d9a0c1529 --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/train.log @@ -0,0 +1,84 @@ +W0320 00:54:42.000000 1050 torch/distributed/run.py:852] +W0320 00:54:42.000000 1050 torch/distributed/run.py:852] ***************************************** +W0320 00:54:42.000000 1050 torch/distributed/run.py:852] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0320 00:54:42.000000 1050 torch/distributed/run.py:852] ***************************************** +logs/submission_8xh100.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17140056 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.015 head_lr:0.0 matrix_lr:0.012 scalar_lr:0.012 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9766 val_bpb:4.1319 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9765 train_time:162ms step_avg:161.95ms +step:2/20000 train_loss:9.0581 train_time:218ms step_avg:109.04ms +step:3/20000 train_loss:7.8439 train_time:342ms step_avg:114.12ms +step:4/20000 train_loss:6.5913 train_time:466ms step_avg:116.40ms +step:5/20000 train_loss:6.1067 train_time:589ms step_avg:117.72ms +step:6/20000 train_loss:6.3514 train_time:712ms step_avg:118.70ms +step:7/20000 train_loss:5.9725 train_time:836ms step_avg:119.39ms +step:8/20000 train_loss:5.8139 train_time:958ms step_avg:119.78ms +step:9/20000 train_loss:5.5629 train_time:1081ms step_avg:120.13ms +step:10/20000 train_loss:5.3728 train_time:1206ms step_avg:120.64ms +step:200/20000 train_loss:2.7739 train_time:26609ms step_avg:133.05ms +step:400/20000 train_loss:2.3107 train_time:53543ms step_avg:133.86ms +step:600/20000 train_loss:2.5249 train_time:80122ms step_avg:133.54ms +step:800/20000 train_loss:2.2710 train_time:106824ms step_avg:133.53ms +step:1000/20000 train_loss:2.3610 train_time:133649ms step_avg:133.65ms +step:1000/20000 val_loss:2.3206 val_bpb:1.3744 train_time:133722ms step_avg:133.72ms +step:1200/20000 train_loss:2.3700 train_time:160457ms step_avg:133.71ms +step:1400/20000 train_loss:2.4196 train_time:187085ms step_avg:133.63ms +step:1600/20000 train_loss:2.0826 train_time:213643ms step_avg:133.53ms +step:1800/20000 train_loss:2.1817 train_time:240257ms step_avg:133.48ms +step:2000/20000 train_loss:2.2342 train_time:266823ms step_avg:133.41ms +step:2000/20000 val_loss:2.2137 val_bpb:1.3111 train_time:266903ms step_avg:133.45ms +step:2200/20000 train_loss:2.0469 train_time:293423ms step_avg:133.37ms +step:2400/20000 train_loss:2.1757 train_time:320078ms step_avg:133.37ms +step:2600/20000 train_loss:2.3756 train_time:346626ms step_avg:133.32ms +step:2800/20000 train_loss:2.2012 train_time:373394ms step_avg:133.35ms +step:3000/20000 train_loss:2.1910 train_time:400062ms step_avg:133.35ms +step:3000/20000 val_loss:2.1585 val_bpb:1.2784 train_time:400147ms step_avg:133.38ms +step:3200/20000 train_loss:2.1485 train_time:426762ms step_avg:133.36ms +step:3400/20000 train_loss:2.1171 train_time:453425ms step_avg:133.36ms +step:3600/20000 train_loss:2.0703 train_time:480073ms step_avg:133.35ms +step:3800/20000 train_loss:2.1774 train_time:506627ms step_avg:133.32ms +step:4000/20000 train_loss:2.1156 train_time:532930ms step_avg:133.23ms +step:4000/20000 val_loss:2.1201 val_bpb:1.2556 train_time:533004ms step_avg:133.25ms +step:4200/20000 train_loss:2.1277 train_time:561906ms step_avg:133.79ms +step:4400/20000 train_loss:2.0541 train_time:588700ms step_avg:133.80ms +step:4494/20000 val_loss:2.1084 val_bpb:1.2487 train_time:600133ms step_avg:133.54ms +stopping_early: wallclock_cap train_time:600133ms step:4494/20000 +peak memory allocated: 21771 MiB reserved: 21818 MiB +Serialized model: 63387167 bytes +Code size: 58055 bytes +Total submission size: 63445222 bytes +Serialized model int8+zlib: 12771121 bytes (payload:17243616 raw_torch:17261176 payload_ratio:3.68x) +Total submission size int8+zlib: 12829176 bytes +final_int8_zlib_roundtrip val_loss:2.1161 val_bpb:1.2533 eval_time:3709ms +final_int8_zlib_roundtrip_exact val_loss:2.11612232 val_bpb:1.25328684 +final_sliding_window val_loss:2.0592 val_bpb:1.2196 window:1024 stride:256 eval_time:209349ms +final_sliding_window_exact val_loss:2.05921204 val_bpb:1.21958209 diff --git a/records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/train_gpt.py b/records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/train_gpt.py new file mode 100644 index 000000000..aa83a930b --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/train_gpt.py @@ -0,0 +1,1365 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + ttt_steps = int(os.environ.get("TTT_STEPS", 0)) + ttt_lr = float(os.environ.get("TTT_LR", 1e-4)) + + # Sliding window eval. + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 256)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 3)) + num_repeats = int(os.environ.get("NUM_REPEATS", 4)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 832)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + num_value_embeds = int(os.environ.get("NUM_VALUE_EMBEDS", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.015)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.012)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.012)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + """Sliding window eval: each window is eval_seq_len tokens, advancing by eval_stride. + Loss is scored only on the last eval_stride tokens per window.""" + seq_len = args.eval_seq_len + stride = args.eval_stride + total_tokens = val_tokens.numel() + + starts: list[int] = [] + pos = 0 + while pos + seq_len < total_tokens: + starts.append(pos) + pos += stride + total_windows = len(starts) + win_start = (total_windows * rank) // world_size + win_end = (total_windows * (rank + 1)) // world_size + score_offset = seq_len - stride + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.no_grad(): + for wi in range(win_start, win_end): + s = starts[wi] + window = val_tokens[s : s + seq_len + 1].to(device=device, dtype=torch.int64) + x = window[:-1].unsqueeze(0) + y = window[1:].unsqueeze(0) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = base_model.forward_logits(x) + + tail_logits = logits[0, score_offset:, :].float() + tail_targets = y[0, score_offset:] + per_token_loss = F.cross_entropy(tail_logits, tail_targets, reduction="none") + val_loss_sum += per_token_loss.to(torch.float64).sum() + val_token_count += float(stride) + + tail_prev = x[0, score_offset:] + tail_tgt = y[0, score_offset:] + token_bytes = base_bytes_lut[tail_tgt].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tail_tgt] & ~is_boundary_token_lut[tail_prev]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + base_model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_ttt( + args: Hyperparameters, + base_model: nn.Module, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Test-Time Training: adapt the model on each validation batch before evaluating. + # For each batch: save weights → K gradient steps → evaluate → restore weights. + if args.ttt_steps <= 0: + return eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Save original weights once + saved_state = {k: v.detach().clone() for k, v in base_model.state_dict().items()} + + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + + # TTT: adapt on this batch + model.train() + for _ttt_step in range(args.ttt_steps): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ttt_loss = model(x, y) + ttt_loss.backward() + with torch.no_grad(): + for p in base_model.parameters(): + if p.grad is not None: + p -= args.ttt_lr * p.grad + p.grad = None + + # Evaluate with adapted model + model.eval() + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + # Restore original weights + base_model.load_state_dict(saved_state, strict=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. +# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +# Int6 quantization: ±31 instead of ±127. Stored as int8 but zlib compresses better. +QUANT_LEVELS = int(os.environ.get("QUANT_LEVELS", 127)) # 127 = int8, 31 = int6 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + ql = QUANT_LEVELS # 31 for int6, 127 for int8 + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / ql).clamp_min(1.0 / ql) + q = torch.clamp(torch.round(clipped / scale[:, None]), -ql, ql).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / ql if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -ql, ql).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # relu^2 MLP from the original modded-nanogpt setup + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x)) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + num_repeats: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + num_value_embeds: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.num_repeats = num_repeats + effective_depth = num_layers * num_repeats + self.tok_emb = nn.Embedding(vocab_size, model_dim) + # Value embeddings: extra embedding tables mixed into each effective layer + self.num_value_embeds = num_value_embeds + if num_value_embeds > 0: + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(num_value_embeds)]) + self.value_scales = nn.Parameter(torch.zeros(effective_depth, num_value_embeds, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + # Loop embedding: tells the model which effective layer it's at + self.loop_embed = nn.Parameter(torch.zeros(effective_depth, model_dim, dtype=torch.float32)) + # Cross-repeat skip: each block remembers its output from previous repeat + # Per-repeat scales (repeat 0 has no prev, so num_repeats-1 scales per block) + self.cross_repeat_scales = nn.Parameter(torch.zeros(num_layers, num_repeats - 1, model_dim, dtype=torch.float32)) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + # Pre-compute value embeddings once + ve_list: list[Tensor] = [] + if self.num_value_embeds > 0: + for ve in self.value_embeds: + ve_list.append(ve(input_ids)) # (bsz, seq, dim) + + num_blocks = len(self.blocks) + prev_block_outputs: list[Tensor | None] = [None] * num_blocks + layer_idx = 0 + for repeat in range(self.num_repeats): + for block_idx, block in enumerate(self.blocks): + x = x + self.loop_embed[layer_idx].to(dtype=x.dtype) + # Value embeddings: add weighted extra embeddings at each layer + for ve_idx, ve_out in enumerate(ve_list): + vs = self.value_scales[layer_idx, ve_idx].to(dtype=x.dtype) + x = x + vs[None, None, :] * ve_out + # Cross-repeat skip: mix in this block's output from previous repeat + if repeat > 0 and prev_block_outputs[block_idx] is not None: + scale = self.cross_repeat_scales[block_idx, repeat - 1].to(dtype=x.dtype) + x = x + scale[None, None, :] * prev_block_outputs[block_idx] + x = block(x, x0) + prev_block_outputs[block_idx] = x.detach() if not self.training else x + layer_idx += 1 + + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return logits + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self.forward_logits(input_ids) + logits = logits.reshape(-1, logits.size(-1)) + targets = target_ids.reshape(-1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + grad_accum_steps = max(1, 8 // world_size) + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + num_repeats=args.num_repeats, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + num_value_embeds=args.num_value_embeds, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params.append(base_model.loop_embed) + scalar_params.append(base_model.cross_repeat_scales) + if base_model.num_value_embeds > 0: + scalar_params.append(base_model.value_scales) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params = [base_model.tok_emb.weight] + if base_model.num_value_embeds > 0: + embed_params.extend(ve.weight for ve in base_model.value_embeds) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval + if args.eval_stride > 0: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"window:{args.eval_seq_len} stride:{args.eval_stride} " + f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + + # TTT eval: adapt model on each batch before evaluating + if args.ttt_steps > 0: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt( + args, + base_model, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_ttt val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " + f"ttt_steps:{args.ttt_steps} ttt_lr:{args.ttt_lr} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + log0(f"final_ttt_exact val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-26_ProgressiveDepth/README.md b/records/track_10min_16mb/2026-03-26_ProgressiveDepth/README.md new file mode 100644 index 000000000..b0201bf74 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_ProgressiveDepth/README.md @@ -0,0 +1,67 @@ +## Progressive Depth Training via Shared-Weight Recurrence + +val_bpb = **1.1980** (sliding window, stride=256, int8+zstd22 roundtrip) +val_bpb = 1.2315 (standard int8+zstd22 roundtrip) + +Progressive Depth is a training-time advantage unique to shared-weight recurrence — flat architectures cannot dynamically adjust their depth during training. + +Because the same 3 blocks are reused at every depth, we can start training with 2 repeats (fast, cheap steps), then progressively increase to 3 and 4 repeats as training progresses. The model learns coarse representations quickly at shallow depth, then refines them at full depth. This is structurally impossible with flat architectures where each layer has unique parameters — you cannot add or remove layers mid-training without changing the parameter space. + +### Progressive Depth Schedule + +| Phase | Time | Repeats | Eff. depth | ms/step | Steps | val_bpb at end | +|-------|------|---------|------------|---------|-------|----------------| +| 1 | 0–40% | 2 | 6 | ~75 | ~3200 | 1.319 | +| 2 | 40–65% | 3 | 9 | ~86 | ~1200 | 1.298 | +| 3 | 65–100% | 4 | 12 | ~96 | ~1800 | 1.229 | + +**Total: 5861 steps** in 600s vs ~4300 steps at constant depth 4 (+36% more gradient updates). + +SWA (Stochastic Weight Averaging) collects checkpoints only during Phase 3 at full depth to avoid mixing representations from different depths. 18 checkpoints averaged. + +### Ablation Trajectory + +Each change isolated and measured on 8xH100 (sliding window eval): + +| Change | val_bpb | Delta | +|--------|---------|-------| +| OpenAI Naive Baseline (9×512, unique layers) | 1.2244 | — | +| Depth Recurrence 3×4 + Cross-Repeat Skip (PR [#148](https://github.com/openai/parameter-golf/pull/148)) | 1.2213 | -0.003 | +| + XSA (Exclusive Self-Attention, last 4 layers) | 1.2110 | -0.010 | +| + LeakyReLU(0.5)² MLP | 1.2069 | -0.004 | +| + Progressive Depth (2→3→4 schedule) | 1.1980 | -0.009 | +| **Total** | **1.1980** | **-0.026** | + +### Cross-Repeat Skip (Novel, PR [#148](https://github.com/openai/parameter-golf/pull/148)) + +Standard depth recurrence is stateless — each repeat starts fresh with no memory of previous passes. Cross-Repeat Skip turns this into stateful recurrence: each block receives a weighted residual of its own output from the previous repeat. Per-block, per-repeat learned scales (~7.5K params). This gives the model a direct gradient path across repeats without the overhead of unique parameters. + +### Architecture + +- 3 shared blocks × 4 repeats = 12 effective layers +- dim=832, 8 heads, 4 KV heads (GQA), MLP 2×, tied embeddings +- **XSA**: Subtracts self-value projection from attention output on last 4 effective layers (reduces attention collapse in deep recurrence) +- **LeakyReLU(0.5)²**: Replaces ReLU² — preserves gradient flow on negative activations through 4 recurrence passes +- 2 Value Embedding tables with per-layer learned scales +- Loop Embedding (depth-wise positional encoding) +- Logit softcap=30, RoPE, RMSNorm +- GPTQ-lite int8 quantization (per-row clip percentile search) + zstd-22 compression +- 17.14M params, 15.88MB artifact + +### Training + +Muon optimizer (momentum=0.95, 5 Newton-Schulz steps, WD=0.04) for matrix params, Adam for scalars/embeddings. + +MATRIX_LR=0.012, SCALAR_LR=0.012, TIED_EMBED_LR=0.015, GRAD_CLIP_NORM=0.3, WARMDOWN_ITERS=3000. + +Phase switching synchronized across DDP ranks via `all_reduce` (max elapsed time) to prevent race conditions during `torch.compile` recompilation. + +### Command + +``` +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +### Results + +5861 steps, 600s on 8xH100. Roundtrip val_bpb 1.2315, sliding window 1.1980. Peak memory 25.5 GB/GPU. diff --git a/records/track_10min_16mb/2026-03-26_ProgressiveDepth/submission.json b/records/track_10min_16mb/2026-03-26_ProgressiveDepth/submission.json new file mode 100644 index 000000000..104d080fa --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_ProgressiveDepth/submission.json @@ -0,0 +1,16 @@ +{ + "author": "Ivan Verbovoy", + "github_id": "iverbovoy", + "name": "Progressive Depth + Depth Recurrence + XSA + LeakyReLU\u00b2", + "blurb": "3 unique blocks with progressive depth scheduling (2\u21923\u21924 repeats), XSA on last 4 layers, LeakyReLU(0.5)\u00b2 MLP, SWA over 18 checkpoints, GPTQ-lite int8+zstd22 compression. 5861 steps in 600s on 8xH100.", + "date": "2026-03-26T07:40:00Z", + "val_loss": 2.02277954, + "val_bpb": 1.19800347, + "roundtrip_val_loss": 2.07939783, + "roundtrip_val_bpb": 1.23153652, + "step_stop": 5861, + "wallclock_seconds": 600.140, + "bytes_total": 15875591, + "bytes_model_int8_zstd22": 15815371, + "bytes_code": 60220 +} diff --git a/records/track_10min_16mb/2026-03-26_ProgressiveDepth/train.log b/records/track_10min_16mb/2026-03-26_ProgressiveDepth/train.log new file mode 100644 index 000000000..ea7b7f77f --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_ProgressiveDepth/train.log @@ -0,0 +1,113 @@ +W0326 08:03:49.332000 20006 torch/distributed/run.py:793] +W0326 08:03:49.332000 20006 torch/distributed/run.py:793] ***************************************** +W0326 08:03:49.332000 20006 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0326 08:03:49.332000 20006 torch/distributed/run.py:793] ***************************************** +logs/b9c03e97-bd7a-4a2b-bcb0-89a9dbb80dd2.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17140056 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.015 head_lr:0.0 matrix_lr:0.012 scalar_lr:0.012 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +prog_depth: schedule=[(0.4, 2), (0.65, 3), (1.0, 4)] starting_repeats=2 +step:0/20000 val_loss:6.9300 val_bpb:4.1043 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9769 train_time:6848ms step_avg:6847.62ms +step:2/20000 train_loss:7.6698 train_time:6867ms step_avg:3433.58ms +step:3/20000 train_loss:7.5199 train_time:6936ms step_avg:2312.01ms +step:4/20000 train_loss:7.1738 train_time:7006ms step_avg:1751.49ms +step:5/20000 train_loss:6.6264 train_time:7077ms step_avg:1415.41ms +step:6/20000 train_loss:6.3184 train_time:7150ms step_avg:1191.64ms +step:7/20000 train_loss:5.8051 train_time:7223ms step_avg:1031.86ms +step:8/20000 train_loss:5.6215 train_time:7296ms step_avg:911.96ms +step:9/20000 train_loss:5.4769 train_time:7365ms step_avg:818.36ms +step:10/20000 train_loss:5.3601 train_time:7438ms step_avg:743.83ms +step:200/20000 train_loss:2.7618 train_time:21072ms step_avg:105.36ms +step:400/20000 train_loss:2.3134 train_time:35458ms step_avg:88.65ms +step:600/20000 train_loss:2.5245 train_time:49866ms step_avg:83.11ms +step:800/20000 train_loss:2.2887 train_time:64314ms step_avg:80.39ms +step:1000/20000 train_loss:2.3830 train_time:78796ms step_avg:78.80ms +step:1000/20000 val_loss:2.3413 val_bpb:1.3867 train_time:78837ms step_avg:78.84ms +step:1200/20000 train_loss:2.3953 train_time:93285ms step_avg:77.74ms +step:1400/20000 train_loss:2.4453 train_time:107763ms step_avg:76.97ms +step:1600/20000 train_loss:2.1185 train_time:122260ms step_avg:76.41ms +step:1800/20000 train_loss:2.2231 train_time:136732ms step_avg:75.96ms +step:2000/20000 train_loss:2.2778 train_time:151200ms step_avg:75.60ms +step:2000/20000 val_loss:2.2603 val_bpb:1.3387 train_time:151241ms step_avg:75.62ms +step:2200/20000 train_loss:2.1014 train_time:165670ms step_avg:75.30ms +step:2400/20000 train_loss:2.2244 train_time:180134ms step_avg:75.06ms +step:2600/20000 train_loss:2.4395 train_time:194595ms step_avg:74.84ms +step:2800/20000 train_loss:2.2726 train_time:209041ms step_avg:74.66ms +step:3000/20000 train_loss:2.2600 train_time:223486ms step_avg:74.50ms +step:3000/20000 val_loss:2.2269 val_bpb:1.3189 train_time:223527ms step_avg:74.51ms +step:3200/20000 train_loss:2.2188 train_time:237917ms step_avg:74.35ms +prog_depth: switched to 3 repeats at step:3229 frac:0.40 +step:3400/20000 train_loss:2.1932 train_time:279477ms step_avg:82.20ms +step:3600/20000 train_loss:2.1460 train_time:300629ms step_avg:83.51ms +step:3800/20000 train_loss:2.2472 train_time:321879ms step_avg:84.70ms +step:4000/20000 train_loss:2.1847 train_time:343064ms step_avg:85.77ms +step:4000/20000 val_loss:2.1917 val_bpb:1.2981 train_time:343131ms step_avg:85.78ms +step:4200/20000 train_loss:2.1871 train_time:364232ms step_avg:86.72ms +step:4400/20000 train_loss:2.1208 train_time:385376ms step_avg:87.59ms +prog_depth: switched to 4 repeats at step:4444 frac:0.65 +step:4600/20000 train_loss:1.9634 train_time:423021ms step_avg:91.96ms +step:4800/20000 train_loss:2.2479 train_time:450939ms step_avg:93.95ms +step:5000/20000 train_loss:1.9975 train_time:478895ms step_avg:95.78ms +step:5000/20000 val_loss:2.1211 val_bpb:1.2562 train_time:478979ms step_avg:95.80ms +swa:start step:5050 +step:5200/20000 train_loss:2.1314 train_time:507384ms step_avg:97.57ms +step:5400/20000 train_loss:2.1322 train_time:535401ms step_avg:99.15ms +step:5600/20000 train_loss:2.1209 train_time:563472ms step_avg:100.62ms +step:5800/20000 train_loss:2.0748 train_time:591529ms step_avg:101.99ms +step:5861/20000 val_loss:2.0758 val_bpb:1.2294 train_time:600140ms step_avg:102.40ms +stopping_early: wallclock_cap train_time:600140ms step:5861/20000 +peak memory allocated: 25539 MiB reserved: 26118 MiB +swa: averaging 18 checkpoints +Serialized model: 63386762 bytes +Code size: 60220 bytes +Total submission size: 63446982 bytes +Serialized model int8+zstd22: 15815371 bytes (payload:17243616 raw_torch:17260843 payload_ratio:3.68x) +Total submission size int8+zstd22: 15875591 bytes +/workspace/parameter-golf/train_gpt.py:1334: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1334: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1334: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1334: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1334: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1334: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1334: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1334: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +final_roundtrip val_loss:2.0794 val_bpb:1.2315 eval_time:14103ms +final_roundtrip_exact val_loss:2.07939783 val_bpb:1.23153652 +final_sliding_window val_loss:2.0228 val_bpb:1.1980 window:1024 stride:256 eval_time:66815ms +final_sliding_window_exact val_loss:2.02277954 val_bpb:1.19800347 diff --git a/records/track_10min_16mb/2026-03-26_ProgressiveDepth/train_gpt.py b/records/track_10min_16mb/2026-03-26_ProgressiveDepth/train_gpt.py new file mode 100644 index 000000000..e45fdfc2e --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_ProgressiveDepth/train_gpt.py @@ -0,0 +1,1386 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zstandard as zstd +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# HYPERPARAMETERS + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + + # Progressive Depth: train with fewer repeats early (faster), more repeats later (deeper). + # Schedule format: "frac1:rep1,frac2:rep2,..." e.g. "0.4:2,0.65:3,1.0:4" + prog_depth_schedule = os.environ.get("PROG_DEPTH", "0.4:2,0.65:3,1.0:4") + + # XSA (Exclusive Self-Attention) on last N effective layers. + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + + # SWA (Stochastic Weight Averaging) during warmdown. + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + # Sliding window eval. + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 1024)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 256)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 3)) + num_repeats = int(os.environ.get("NUM_REPEATS", 4)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 832)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + num_value_embeds = int(os.environ.get("NUM_VALUE_EMBEDS", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.015)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.012)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.012)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + +# MUON OPTIMIZER +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + """Sliding window eval with batching. Windows of train_seq_len advance by eval_stride. + Only the last stride tokens per window are scored (first window scores all).""" + seq_len = args.eval_seq_len + stride = args.eval_stride + batch_seqs = args.eval_batch_seqs + total_tokens = val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi : bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws : end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = base_model.forward_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + val_loss_sum += scored_nll.sum() + val_token_count += float(wlen - s) + prev_ids = x_batch[i, s:wlen] + tgt_ids = y_batch[i, s:wlen] + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + base_model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + + +# POST-TRAINING QUANTIZATION +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing and zstd compressing. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +# Quantization levels: 127 = int8, 31 = int6, 16 = int5. Per-tensor override via MLP_QUANT_LEVELS. +QUANT_LEVELS = int(os.environ.get("QUANT_LEVELS", 127)) +MLP_QUANT_LEVELS = int(os.environ.get("MLP_QUANT_LEVELS", 0)) # 0 = same as QUANT_LEVELS +MLP_TENSOR_PATTERNS = ("mlp.fc.", "mlp.proj.", "fc.weight", "mlp.proj.weight") + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 0.999999] + +def quantize_float_tensor(t: Tensor, ql: int = 0) -> tuple[Tensor, Tensor]: + if ql <= 0: + ql = QUANT_LEVELS + t32 = t.float() + if t32.ndim == 2: + # GPTQ-lite: try multiple clip percentiles per row, pick best MSE + abs_t = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in GPTQ_LITE_PERCENTILES: + clip_abs = ( + torch.quantile(abs_t, pct, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + s = (clip_abs / ql).clamp_min(1e-12) + q = torch.clamp(torch.round(clipped / s[:, None]), -ql, ql) + # Reconstruction error per row + recon = q * s[:, None] + mse = (t32 - recon).square().sum(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = s + else: + better = mse < best_mse + best_mse = torch.where(better, mse, best_mse) + best_q = torch.where(better[:, None], q, best_q) + best_scale = torch.where(better, s, best_scale) + return best_q.to(torch.int8).contiguous(), best_scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / ql if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -ql, ql).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + mlp_ql = MLP_QUANT_LEVELS if MLP_QUANT_LEVELS > 0 else QUANT_LEVELS + ql = mlp_ql if any(p in name for p in MLP_TENSOR_PATTERNS) else QUANT_LEVELS + q, s = quantize_float_tensor(t, ql=ql) + meta: dict[str, object] = {} + if s.ndim > 0: + meta["scheme"] = "per_row" + meta["axis"] = 0 + if ql != QUANT_LEVELS: + meta["ql"] = ql + if meta: + qmeta[name] = meta + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# DATA LOADING + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# TRANSFORMER MODULES + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def _xsa(self, y: Tensor, v: Tensor) -> Tensor: + """Subtract self-value projection from attention output (GQA-aware).""" + B, T, H, D = y.shape + Hkv = v.size(2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(3) # [B, T, Hkv, 1, D] + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, use_xsa: bool = False) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + # XSA: remove self-value bias from attention output + if use_xsa: + y = y.transpose(1, 2).contiguous() # [B, T, H, D] + v_for_xsa = v.transpose(1, 2) # [B, T, Hkv, D] + y = self._xsa(y, v_for_xsa) + y = y.reshape(bsz, seqlen, dim) + else: + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # leaky_relu(0.5)^2 MLP — better gradient flow than relu^2 for deep/recurrent models + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, use_xsa: bool = False) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x), use_xsa=use_xsa) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + num_repeats: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + num_value_embeds: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + xsa_last_n: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.num_repeats = num_repeats + self.xsa_last_n = xsa_last_n + effective_depth = num_layers * num_repeats + self.tok_emb = nn.Embedding(vocab_size, model_dim) + # Value embeddings: extra embedding tables mixed into each effective layer + self.num_value_embeds = num_value_embeds + if num_value_embeds > 0: + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(num_value_embeds)]) + self.value_scales = nn.Parameter(torch.zeros(effective_depth, num_value_embeds, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + # Loop embedding: tells the model which effective layer it's at + self.loop_embed = nn.Parameter(torch.zeros(effective_depth, model_dim, dtype=torch.float32)) + # Cross-repeat skip: each block receives its own output from previous repeat + self.cross_repeat_scales = nn.Parameter(torch.zeros(num_layers, num_repeats - 1, model_dim, dtype=torch.float32)) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + # Pre-compute value embeddings once + ve_list: list[Tensor] = [] + if self.num_value_embeds > 0: + for ve in self.value_embeds: + ve_list.append(ve(input_ids)) # (bsz, seq, dim) + + cur_repeats = self.cur_repeats if hasattr(self, "cur_repeats") else self.num_repeats + cur_depth = len(self.blocks) * cur_repeats + xsa_start = max(0, cur_depth - self.xsa_last_n) if self.xsa_last_n > 0 else cur_depth + + num_blocks = len(self.blocks) + prev_block_outputs: list[Tensor | None] = [None] * num_blocks + layer_idx = 0 + for repeat in range(cur_repeats): + for block_idx, block in enumerate(self.blocks): + x = x + self.loop_embed[layer_idx].to(dtype=x.dtype) + # Value embeddings: add weighted extra embeddings at each layer + if layer_idx < self.value_scales.size(0): + for ve_idx, ve_out in enumerate(ve_list): + vs = self.value_scales[layer_idx, ve_idx].to(dtype=x.dtype) + x = x + vs[None, None, :] * ve_out + # Cross-repeat skip: mix in this block's output from previous repeat + if repeat > 0 and prev_block_outputs[block_idx] is not None: + rep_idx = min(repeat - 1, self.cross_repeat_scales.size(1) - 1) + scale = self.cross_repeat_scales[block_idx, rep_idx].to(dtype=x.dtype) + x = x + scale[None, None, :] * prev_block_outputs[block_idx] + x = block(x, x0, use_xsa=(layer_idx >= xsa_start)) + prev_block_outputs[block_idx] = x.detach() if not self.training else x + layer_idx += 1 + + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return logits + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self.forward_logits(input_ids) + logits = logits.reshape(-1, logits.size(-1)) + targets = target_ids.reshape(-1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# TRAINING + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # DISTRIBUTED + CUDA SETUP + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + grad_accum_steps = max(1, 8 // world_size) + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # TOKENIZER + VALIDATION METRIC SETUP + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + num_repeats=args.num_repeats, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + num_value_embeds=args.num_value_embeds, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + xsa_last_n=args.xsa_last_n, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params.append(base_model.loop_embed) + scalar_params.append(base_model.cross_repeat_scales) + if base_model.num_value_embeds > 0: + scalar_params.append(base_model.value_scales) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params = [base_model.tok_emb.weight] + if base_model.num_value_embeds > 0: + embed_params.extend(ve.weight for ve in base_model.value_embeds) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + # Progressive depth schedule: parse "frac:repeats,..." and sort + prog_phases: list[tuple[float, int]] = [] + for entry in args.prog_depth_schedule.split(","): + frac_s, rep_s = entry.strip().split(":") + prog_phases.append((float(frac_s), int(rep_s))) + prog_phases.sort() + current_phase_repeats = prog_phases[0][1] if prog_phases else args.num_repeats + base_model.cur_repeats = current_phase_repeats + # Recompile with initial phase depth + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + log0(f"prog_depth: schedule={prog_phases} starting_repeats={current_phase_repeats}") + + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + # Progressive depth: check if we need to switch phase + # Use synchronized elapsed time (max across ranks) to avoid race conditions + if max_wallclock_ms is not None and prog_phases: + if distributed: + elapsed_tensor = torch.tensor(elapsed_ms, device=device) + dist.all_reduce(elapsed_tensor, op=dist.ReduceOp.MAX) + frac = elapsed_tensor.item() / max_wallclock_ms + else: + frac = elapsed_ms / max_wallclock_ms + new_repeats = prog_phases[-1][1] # default to last + for phase_frac, phase_rep in prog_phases: + if frac < phase_frac: + new_repeats = phase_rep + break + if new_repeats != current_phase_repeats: + current_phase_repeats = new_repeats + base_model.cur_repeats = new_repeats + torch._dynamo.reset() + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + log0(f"prog_depth: switched to {new_repeats} repeats at step:{step} frac:{frac:.2f}") + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown (only at full depth to avoid mixing phases) + at_full_depth = current_phase_repeats == args.num_repeats + if args.swa_enabled and at_full_depth and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().float() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu().float() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Restore full depth for eval/export + base_model.cur_repeats = args.num_repeats + torch._dynamo.reset() + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None: + # Include final weights (may not have landed on swa_every boundary) + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu().float() + swa_count += 1 + log0(f"swa: averaging {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed quantized+zstd artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + zstd_level = int(os.environ.get("ZSTD_LEVEL", 22)) + cctx = zstd.ZstdCompressor(level=zstd_level) + quant_blob = cctx.compress(quant_raw) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zstd{zstd_level}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zstd{zstd_level}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + dctx = zstd.ZstdDecompressor() + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval + if args.eval_stride > 0: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"window:{args.eval_seq_len} stride:{args.eval_stride} " + f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/README.md b/records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/README.md new file mode 100644 index 000000000..3995ce744 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/README.md @@ -0,0 +1,62 @@ +## Progressive Depth + Hedge Mixer + +val_bpb = **1.1454** (Hedge Mixer eval, int8+zstd22 roundtrip model) +val_bpb = 1.1966 (sliding window only) +val_bpb = 1.2304 (standard roundtrip) + +### Hedge Mixer: 5-Expert Online Ensemble + +Eval-time improvement via online mixture of 5 experts using the Hedge algorithm (multiplicative weights). No training data access — n-gram tables built from already-scored tokens only. + +| Expert | Source | Role | +|--------|--------|------| +| Neural | Model softmax output | Primary prediction | +| Unigram | Token frequency from scored data | Frequency prior | +| Bigram | P(next\|prev) from scored data | Local context | +| Trigram | Hash table (64K buckets) from scored data | Extended context | +| Entropy | Model confidence weighting | Calibration | + +Weights initialized with neural bias (log_weight=2.0), updated via `log_w -= eta * expert_mean_loss` after each batch. The mixer is cold-started (uses pure neural output until 10K tokens scored), then progressively improves as n-gram statistics accumulate. + +**Impact: -0.051 bpb** over sliding window eval (1.1966 → 1.1454). This is larger than all architectural improvements combined. + +Eval time: 579s on 8xH100 (sequential processing required for n-gram table consistency). + +### Architecture (unchanged from PR #835) + +3 shared transformer blocks with depth recurrence, progressive depth scheduling unique to shared-weight recurrence. + +- **Progressive Depth Training**: Phase 1 (0-40%): 2 repeats ~75ms/step. Phase 2 (40-65%): 3 repeats ~86ms/step. Phase 3 (65-100%): 4 repeats ~96ms/step. 5673 steps in 600s. +- **Cross-Repeat Skip** (#148, Novel): Stateful recurrence — each block receives weighted residual from previous repeat. +- **XSA**: Exclusive Self-Attention on last 4 effective layers. +- **LeakyReLU(0.5)²**: Better gradient flow through 4-repeat recurrence. +- dim=832, 8 heads, 4 KV heads (GQA), MLP 2×, tied embeddings, SWA (18 checkpoints). +- 17.14M params, 15.88MB artifact (int8+zstd22). + +### Tuned Hyperparameters + +MATRIX_LR=0.018, SCALAR_LR=0.018, TIED_EMBED_LR=0.021, WARMDOWN_ITERS=2000. + +Higher LR compensates for progressive depth's shallow early phases. Shorter warmdown gives full LR at full-depth entry. + +### Ablation Trajectory + +| Change | val_bpb | Delta | +|--------|---------|-------| +| OpenAI Naive Baseline | 1.2244 | — | +| Depth Recurrence 3×4 + Cross-Repeat Skip (#148) | 1.2213 | -0.003 | +| + XSA + LeakyReLU² (#784) | 1.2069 | -0.014 | +| + Progressive Depth (#835) | 1.1980 | -0.009 | +| + LR/Warmdown tuning | 1.1960 | -0.002 | +| + Hedge Mixer (eval) | 1.1454 | -0.051 | +| **Total** | **1.1454** | **-0.079** | + +### Command + +``` +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +### Credits + +Hedge Mixer algorithm adapted from PR #688 (@RoyiRa) and PR #745 (@stukenov). diff --git a/records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/submission.json b/records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/submission.json new file mode 100644 index 000000000..9661bb9c4 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/submission.json @@ -0,0 +1,19 @@ +{ + "author": "Ivan Verbovoy", + "github_id": "iverbovoy", + "name": "Progressive Depth + Hedge Mixer (5-expert online ensemble)", + "blurb": "3 unique blocks with progressive depth scheduling (2\u21923\u21924 repeats), XSA, LeakyReLU\u00b2, Cross-Repeat Skip, SWA, int8+zstd22. Eval: 5-expert Hedge Mixer (neural + unigram + bigram + trigram + entropy) with online multiplicative weight updates. 5673 steps in 600s train, 579s eval on 8xH100.", + "date": "2026-03-26T15:00:00Z", + "val_loss": 1.93403169, + "val_bpb": 1.14544202, + "roundtrip_val_loss": 2.07744208, + "roundtrip_val_bpb": 1.23037822, + "sliding_val_loss": 2.02046173, + "sliding_val_bpb": 1.19663074, + "step_stop": 5673, + "wallclock_seconds": 600.218, + "eval_seconds": 579.109, + "bytes_total": 15884272, + "bytes_model_int8_zstd22": 15818418, + "bytes_code": 65854 +} diff --git a/records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/train.log b/records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/train.log new file mode 100644 index 000000000..221bf19a8 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/train.log @@ -0,0 +1,114 @@ +W0326 14:09:06.471000 2358 torch/distributed/run.py:793] +W0326 14:09:06.471000 2358 torch/distributed/run.py:793] ***************************************** +W0326 14:09:06.471000 2358 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0326 14:09:06.471000 2358 torch/distributed/run.py:793] ***************************************** +logs/07b6a996-fe2d-47a4-a5a5-24bf61fec8f0.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17140056 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.021 head_lr:0.0 matrix_lr:0.018 scalar_lr:0.018 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +prog_depth: schedule=[(0.4, 2), (0.65, 3), (1.0, 4)] starting_repeats=2 +step:0/20000 val_loss:6.9300 val_bpb:4.1043 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9769 train_time:20810ms step_avg:20809.51ms +step:2/20000 train_loss:9.6250 train_time:20829ms step_avg:10414.40ms +step:3/20000 train_loss:9.4925 train_time:20897ms step_avg:6965.53ms +step:4/20000 train_loss:9.1975 train_time:20968ms step_avg:5241.90ms +step:5/20000 train_loss:8.6451 train_time:21039ms step_avg:4207.76ms +step:6/20000 train_loss:8.1740 train_time:21110ms step_avg:3518.30ms +step:7/20000 train_loss:7.2979 train_time:21182ms step_avg:3026.05ms +step:8/20000 train_loss:6.6939 train_time:21255ms step_avg:2656.83ms +step:9/20000 train_loss:6.1779 train_time:21331ms step_avg:2370.13ms +step:10/20000 train_loss:5.8322 train_time:21400ms step_avg:2140.02ms +step:200/20000 train_loss:2.7613 train_time:35036ms step_avg:175.18ms +step:400/20000 train_loss:2.3100 train_time:49402ms step_avg:123.50ms +step:600/20000 train_loss:2.5366 train_time:63819ms step_avg:106.36ms +step:800/20000 train_loss:2.2979 train_time:78288ms step_avg:97.86ms +step:1000/20000 train_loss:2.3821 train_time:92765ms step_avg:92.76ms +step:1000/20000 val_loss:2.3450 val_bpb:1.3888 train_time:92807ms step_avg:92.81ms +step:1200/20000 train_loss:2.4011 train_time:107254ms step_avg:89.38ms +step:1400/20000 train_loss:2.4523 train_time:121727ms step_avg:86.95ms +step:1600/20000 train_loss:2.1223 train_time:136207ms step_avg:85.13ms +step:1800/20000 train_loss:2.2266 train_time:150673ms step_avg:83.71ms +step:2000/20000 train_loss:2.2854 train_time:165142ms step_avg:82.57ms +step:2000/20000 val_loss:2.2671 val_bpb:1.3427 train_time:165184ms step_avg:82.59ms +step:2200/20000 train_loss:2.1072 train_time:179607ms step_avg:81.64ms +step:2400/20000 train_loss:2.2328 train_time:194078ms step_avg:80.87ms +step:2600/20000 train_loss:2.4458 train_time:208531ms step_avg:80.20ms +step:2800/20000 train_loss:2.2812 train_time:222979ms step_avg:79.64ms +step:3000/20000 train_loss:2.2707 train_time:237400ms step_avg:79.13ms +step:3000/20000 val_loss:2.2365 val_bpb:1.3246 train_time:237442ms step_avg:79.15ms +prog_depth: switched to 3 repeats at step:3036 frac:0.40 +step:3200/20000 train_loss:2.2283 train_time:278657ms step_avg:87.08ms +step:3400/20000 train_loss:2.1915 train_time:299788ms step_avg:88.17ms +step:3600/20000 train_loss:2.1526 train_time:320953ms step_avg:89.15ms +step:3800/20000 train_loss:2.2486 train_time:342113ms step_avg:90.03ms +step:4000/20000 train_loss:2.1951 train_time:363281ms step_avg:90.82ms +step:4000/20000 val_loss:2.2001 val_bpb:1.3030 train_time:363349ms step_avg:90.84ms +step:4200/20000 train_loss:2.2068 train_time:384450ms step_avg:91.54ms +prog_depth: switched to 4 repeats at step:4252 frac:0.65 +step:4400/20000 train_loss:2.1355 train_time:421875ms step_avg:95.88ms +step:4600/20000 train_loss:1.9747 train_time:449821ms step_avg:97.79ms +step:4800/20000 train_loss:2.2530 train_time:477877ms step_avg:99.56ms +step:5000/20000 train_loss:2.0057 train_time:505841ms step_avg:101.17ms +step:5000/20000 val_loss:2.1261 val_bpb:1.2592 train_time:505925ms step_avg:101.19ms +swa:start step:5100 +step:5200/20000 train_loss:2.1309 train_time:533773ms step_avg:102.65ms +step:5400/20000 train_loss:2.1258 train_time:561825ms step_avg:104.04ms +step:5600/20000 train_loss:2.1075 train_time:589899ms step_avg:105.34ms +step:5673/20000 val_loss:2.0739 val_bpb:1.2283 train_time:600218ms step_avg:105.80ms +stopping_early: wallclock_cap train_time:600218ms step:5673/20000 +peak memory allocated: 25696 MiB reserved: 27322 MiB +swa: averaging 13 checkpoints +Serialized model: 63386762 bytes +Code size: 65854 bytes +Total submission size: 63452616 bytes +Serialized model int8+zstd22: 15818418 bytes (payload:17243616 raw_torch:17260843 payload_ratio:3.68x) +Total submission size int8+zstd22: 15884272 bytes +/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +final_roundtrip val_loss:2.0774 val_bpb:1.2304 eval_time:13407ms +final_roundtrip_exact val_loss:2.07744208 val_bpb:1.23037822 +final_sliding_window val_loss:2.0205 val_bpb:1.1966 window:1024 stride:256 eval_time:66781ms +final_sliding_window_exact val_loss:2.02046173 val_bpb:1.19663074 +final_hedge_mixer val_loss:1.9340 val_bpb:1.1454 eval_time:579109ms +final_hedge_mixer_exact val_loss:1.93403169 val_bpb:1.14544202 diff --git a/records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/train_gpt.py b/records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/train_gpt.py new file mode 100644 index 000000000..1738288f3 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/train_gpt.py @@ -0,0 +1,1498 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zstandard as zstd +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + + +class HedgeMixer: + """Online mixture of 5 experts via Hedge algorithm for eval-time improvement. + Experts: Neural, Unigram, Bigram, Trigram (hashed), Entropy.""" + def __init__(self, vocab_size: int = 1024, device: str = "cuda", eta: float = 0.1): + self.V = vocab_size + self.device = device + self.eta = eta + self.log_weights = torch.zeros(5, device=device) + self.log_weights[0] = 2.0 # bias toward neural + self.uni_counts = torch.zeros(vocab_size, device=device) + self.bi_counts = torch.zeros(vocab_size, vocab_size, device=device) + self.total_tokens = 0 + self.TRI_HASH = 65536 + self.tri_counts = torch.zeros(self.TRI_HASH, vocab_size, device=device) + self.tri_row_totals = torch.zeros(self.TRI_HASH, device=device) + + def update(self, tokens: Tensor) -> None: + t = tokens.to(self.device).long() + n = t.numel() + if n == 0: + return + self.total_tokens += n + ones = torch.ones(n, device=self.device) + self.uni_counts.scatter_add_(0, t, ones) + if n >= 2: + bi_idx = t[:-1] * self.V + t[1:] + self.bi_counts.reshape(-1).scatter_add_(0, bi_idx, torch.ones(n - 1, device=self.device)) + if n >= 3: + tri_ctx = ((t[:-2] * 36313) ^ (t[1:-1] * 27191)) % self.TRI_HASH + tri_idx = tri_ctx * self.V + t[2:] + ones_tri = torch.ones(n - 2, device=self.device) + self.tri_counts.reshape(-1).scatter_add_(0, tri_idx, ones_tri) + self.tri_row_totals.scatter_add_(0, tri_ctx, ones_tri) + + def mix_and_score(self, neural_logits: Tensor, x_batch: Tensor, y_batch: Tensor, wlens: list[int]) -> Tensor: + bsz, slen, V = neural_logits.shape + uniform_nll = math.log(self.V) + has_data = self.total_tokens > 0 + neural_lp = F.log_softmax(neural_logits, dim=-1) + neural_nll = -neural_lp.gather(2, y_batch.unsqueeze(2)).squeeze(2) + if not has_data or self.total_tokens < 10000: + return neural_nll + uni_probs = (self.uni_counts + 0.1) / (self.total_tokens + 0.1 * self.V) + uni_nll = -uni_probs.log()[y_batch] + bi_total = self.bi_counts.sum(dim=1, keepdim=True) + bi_probs = (self.bi_counts + 0.1) / (bi_total + 0.1 * self.V) + bi_nll = -bi_probs.log()[x_batch.reshape(-1), y_batch.reshape(-1)].reshape(bsz, slen) + if slen >= 2: + prev2 = torch.zeros_like(x_batch) + prev2[:, 1:] = x_batch[:, :-1] + ctx_hash = ((prev2 * 36313) ^ (x_batch * 27191)) % self.TRI_HASH + tri_count = self.tri_counts[ctx_hash.reshape(-1).long(), y_batch.reshape(-1).long()] + tri_total = self.tri_row_totals[ctx_hash.reshape(-1).long()].clamp(min=1) + tri_nll = -(((tri_count + 0.01) / (tri_total + 0.01 * self.V)).log()).reshape(bsz, slen) + else: + tri_nll = torch.full((bsz, slen), uniform_nll, device=self.device) + entropy_nll = -(neural_lp.exp() * neural_lp).sum(-1) + expert_nll = torch.stack([neural_nll, uni_nll, bi_nll, tri_nll, entropy_nll], dim=-1) + log_w = self.log_weights - self.log_weights.logsumexp(0) + mixed_nll = -(-expert_nll + log_w.unsqueeze(0).unsqueeze(0)).logsumexp(dim=-1) + # Update weights + wlens_t = torch.tensor(wlens, device=self.device, dtype=torch.long) + mask = torch.arange(slen, device=self.device).unsqueeze(0) < wlens_t.unsqueeze(1) + masked_nll = expert_nll * mask.unsqueeze(-1).float() + expert_mean_loss = masked_nll.sum(dim=(0, 1)) / mask.sum().clamp(min=1) + self.log_weights -= self.eta * expert_mean_loss + return mixed_nll + + +# HYPERPARAMETERS + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 2000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + + # Progressive Depth: train with fewer repeats early (faster), more repeats later (deeper). + # Schedule format: "frac1:rep1,frac2:rep2,..." e.g. "0.4:2,0.65:3,1.0:4" + prog_depth_schedule = os.environ.get("PROG_DEPTH", "0.4:2,0.65:3,1.0:4") + + # XSA (Exclusive Self-Attention) on last N effective layers. + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + + # SWA (Stochastic Weight Averaging) during warmdown. + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + # Sliding window eval. + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 1024)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 256)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + # Hedge Mixer (eval-time n-gram ensemble). + use_hedge = bool(int(os.environ.get("USE_HEDGE", "1"))) + hedge_eta = float(os.environ.get("HEDGE_ETA", 0.1)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 3)) + num_repeats = int(os.environ.get("NUM_REPEATS", 4)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 832)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + num_value_embeds = int(os.environ.get("NUM_VALUE_EMBEDS", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.021)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.018)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.018)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + +# MUON OPTIMIZER +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + use_hedge: bool = False, + hedge_eta: float = 0.1, +) -> tuple[float, float]: + """Sliding window eval with batching. Windows of train_seq_len advance by eval_stride. + Only the last stride tokens per window are scored (first window scores all). + Optional Hedge Mixer: online n-gram ensemble over scored tokens.""" + seq_len = args.eval_seq_len + stride = args.eval_stride + batch_seqs = args.eval_batch_seqs + total_tokens = val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + + # With Hedge Mixer: process ALL windows on each rank (sequential, n-gram tables need full context) + # Without: distribute windows across ranks + if use_hedge: + my_windows = window_starts + else: + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + mixer = HedgeMixer(vocab_size=args.vocab_size, device=device, eta=hedge_eta) if use_hedge else None + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi : bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws : end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = base_model.forward_logits(x_batch) + + if mixer is not None: + nll = mixer.mix_and_score(logits.float(), x_batch, y_batch, wlens) + else: + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + val_loss_sum += scored_nll.sum() + val_token_count += float(wlen - s) + prev_ids = x_batch[i, s:wlen] + tgt_ids = y_batch[i, s:wlen] + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + # Update n-gram tables with scored tokens + if mixer is not None: + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + mixer.update(y_batch[i, s:wlen]) + + if not use_hedge and dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + base_model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + + +# POST-TRAINING QUANTIZATION +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing and zstd compressing. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +# Quantization levels: 127 = int8, 31 = int6, 16 = int5. Per-tensor override via MLP_QUANT_LEVELS. +QUANT_LEVELS = int(os.environ.get("QUANT_LEVELS", 127)) +MLP_QUANT_LEVELS = int(os.environ.get("MLP_QUANT_LEVELS", 0)) # 0 = same as QUANT_LEVELS +MLP_TENSOR_PATTERNS = ("mlp.fc.", "mlp.proj.", "fc.weight", "mlp.proj.weight") + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 0.999999] + +def quantize_float_tensor(t: Tensor, ql: int = 0) -> tuple[Tensor, Tensor]: + if ql <= 0: + ql = QUANT_LEVELS + t32 = t.float() + if t32.ndim == 2: + # GPTQ-lite: try multiple clip percentiles per row, pick best MSE + abs_t = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in GPTQ_LITE_PERCENTILES: + clip_abs = ( + torch.quantile(abs_t, pct, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + s = (clip_abs / ql).clamp_min(1e-12) + q = torch.clamp(torch.round(clipped / s[:, None]), -ql, ql) + # Reconstruction error per row + recon = q * s[:, None] + mse = (t32 - recon).square().sum(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = s + else: + better = mse < best_mse + best_mse = torch.where(better, mse, best_mse) + best_q = torch.where(better[:, None], q, best_q) + best_scale = torch.where(better, s, best_scale) + return best_q.to(torch.int8).contiguous(), best_scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / ql if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -ql, ql).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + mlp_ql = MLP_QUANT_LEVELS if MLP_QUANT_LEVELS > 0 else QUANT_LEVELS + ql = mlp_ql if any(p in name for p in MLP_TENSOR_PATTERNS) else QUANT_LEVELS + q, s = quantize_float_tensor(t, ql=ql) + meta: dict[str, object] = {} + if s.ndim > 0: + meta["scheme"] = "per_row" + meta["axis"] = 0 + if ql != QUANT_LEVELS: + meta["ql"] = ql + if meta: + qmeta[name] = meta + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# DATA LOADING + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# TRANSFORMER MODULES + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def _xsa(self, y: Tensor, v: Tensor) -> Tensor: + """Subtract self-value projection from attention output (GQA-aware).""" + B, T, H, D = y.shape + Hkv = v.size(2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(3) # [B, T, Hkv, 1, D] + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, use_xsa: bool = False) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + # XSA: remove self-value bias from attention output + if use_xsa: + y = y.transpose(1, 2).contiguous() # [B, T, H, D] + v_for_xsa = v.transpose(1, 2) # [B, T, Hkv, D] + y = self._xsa(y, v_for_xsa) + y = y.reshape(bsz, seqlen, dim) + else: + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # leaky_relu(0.5)^2 MLP — better gradient flow than relu^2 for deep/recurrent models + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, use_xsa: bool = False) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x), use_xsa=use_xsa) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + num_repeats: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + num_value_embeds: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + xsa_last_n: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.num_repeats = num_repeats + self.xsa_last_n = xsa_last_n + effective_depth = num_layers * num_repeats + self.tok_emb = nn.Embedding(vocab_size, model_dim) + # Value embeddings: extra embedding tables mixed into each effective layer + self.num_value_embeds = num_value_embeds + if num_value_embeds > 0: + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(num_value_embeds)]) + self.value_scales = nn.Parameter(torch.zeros(effective_depth, num_value_embeds, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + # Loop embedding: tells the model which effective layer it's at + self.loop_embed = nn.Parameter(torch.zeros(effective_depth, model_dim, dtype=torch.float32)) + # Cross-repeat skip: each block receives its own output from previous repeat + self.cross_repeat_scales = nn.Parameter(torch.zeros(num_layers, num_repeats - 1, model_dim, dtype=torch.float32)) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + # Pre-compute value embeddings once + ve_list: list[Tensor] = [] + if self.num_value_embeds > 0: + for ve in self.value_embeds: + ve_list.append(ve(input_ids)) # (bsz, seq, dim) + + cur_repeats = self.cur_repeats if hasattr(self, "cur_repeats") else self.num_repeats + cur_depth = len(self.blocks) * cur_repeats + xsa_start = max(0, cur_depth - self.xsa_last_n) if self.xsa_last_n > 0 else cur_depth + + num_blocks = len(self.blocks) + prev_block_outputs: list[Tensor | None] = [None] * num_blocks + layer_idx = 0 + for repeat in range(cur_repeats): + for block_idx, block in enumerate(self.blocks): + x = x + self.loop_embed[layer_idx].to(dtype=x.dtype) + # Value embeddings: add weighted extra embeddings at each layer + if layer_idx < self.value_scales.size(0): + for ve_idx, ve_out in enumerate(ve_list): + vs = self.value_scales[layer_idx, ve_idx].to(dtype=x.dtype) + x = x + vs[None, None, :] * ve_out + # Cross-repeat skip: mix in this block's output from previous repeat + if repeat > 0 and prev_block_outputs[block_idx] is not None: + rep_idx = min(repeat - 1, self.cross_repeat_scales.size(1) - 1) + scale = self.cross_repeat_scales[block_idx, rep_idx].to(dtype=x.dtype) + x = x + scale[None, None, :] * prev_block_outputs[block_idx] + x = block(x, x0, use_xsa=(layer_idx >= xsa_start)) + prev_block_outputs[block_idx] = x.detach() if not self.training else x + layer_idx += 1 + + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return logits + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self.forward_logits(input_ids) + logits = logits.reshape(-1, logits.size(-1)) + targets = target_ids.reshape(-1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# TRAINING + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # DISTRIBUTED + CUDA SETUP + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + grad_accum_steps = max(1, 8 // world_size) + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # TOKENIZER + VALIDATION METRIC SETUP + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + num_repeats=args.num_repeats, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + num_value_embeds=args.num_value_embeds, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + xsa_last_n=args.xsa_last_n, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params.append(base_model.loop_embed) + scalar_params.append(base_model.cross_repeat_scales) + if base_model.num_value_embeds > 0: + scalar_params.append(base_model.value_scales) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params = [base_model.tok_emb.weight] + if base_model.num_value_embeds > 0: + embed_params.extend(ve.weight for ve in base_model.value_embeds) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + # Progressive depth schedule: parse "frac:repeats,..." and sort + prog_phases: list[tuple[float, int]] = [] + for entry in args.prog_depth_schedule.split(","): + frac_s, rep_s = entry.strip().split(":") + prog_phases.append((float(frac_s), int(rep_s))) + prog_phases.sort() + current_phase_repeats = prog_phases[0][1] if prog_phases else args.num_repeats + base_model.cur_repeats = current_phase_repeats + # Recompile with initial phase depth + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + log0(f"prog_depth: schedule={prog_phases} starting_repeats={current_phase_repeats}") + + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + # Progressive depth: check if we need to switch phase + # Use synchronized elapsed time (max across ranks) to avoid race conditions + if max_wallclock_ms is not None and prog_phases: + if distributed: + elapsed_tensor = torch.tensor(elapsed_ms, device=device) + dist.all_reduce(elapsed_tensor, op=dist.ReduceOp.MAX) + frac = elapsed_tensor.item() / max_wallclock_ms + else: + frac = elapsed_ms / max_wallclock_ms + new_repeats = prog_phases[-1][1] # default to last + for phase_frac, phase_rep in prog_phases: + if frac < phase_frac: + new_repeats = phase_rep + break + if new_repeats != current_phase_repeats: + current_phase_repeats = new_repeats + base_model.cur_repeats = new_repeats + torch._dynamo.reset() + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + log0(f"prog_depth: switched to {new_repeats} repeats at step:{step} frac:{frac:.2f}") + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown (only at full depth to avoid mixing phases) + at_full_depth = current_phase_repeats == args.num_repeats + if args.swa_enabled and at_full_depth and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().float() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu().float() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Restore full depth for eval/export + base_model.cur_repeats = args.num_repeats + torch._dynamo.reset() + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None: + # Include final weights (may not have landed on swa_every boundary) + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu().float() + swa_count += 1 + log0(f"swa: averaging {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed quantized+zstd artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + zstd_level = int(os.environ.get("ZSTD_LEVEL", 22)) + cctx = zstd.ZstdCompressor(level=zstd_level) + quant_blob = cctx.compress(quant_raw) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zstd{zstd_level}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zstd{zstd_level}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + dctx = zstd.ZstdDecompressor() + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval + if args.eval_stride > 0: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"window:{args.eval_seq_len} stride:{args.eval_stride} " + f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + + # Hedge Mixer eval (n-gram ensemble) + if args.use_hedge: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_hm = time.perf_counter() + hm_val_loss, hm_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + use_hedge=True, hedge_eta=args.hedge_eta, + ) + torch.cuda.synchronize() + log0( + f"final_hedge_mixer val_loss:{hm_val_loss:.4f} val_bpb:{hm_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_hm):.0f}ms" + ) + log0(f"final_hedge_mixer_exact val_loss:{hm_val_loss:.8f} val_bpb:{hm_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/README.md b/records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/README.md new file mode 100644 index 000000000..8a06b7a2e --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/README.md @@ -0,0 +1,52 @@ +## Depth Recurrence + XSA + LeakyReLU² + +Improves previous submission (1.2196 → 1.2065, -0.013 bpb) through three zero-parameter additions on top of depth recurrence. + +val_bpb = 1.2065 (sliding window eval on int8+zstd22 roundtrip model, stride=256) +val_bpb = 1.2398 (standard int8+zstd22 roundtrip) + +### Architecture + +Same depth recurrence base as previous submission: 3 shared blocks repeated 4 times (12 effective layers), dim=832, 8 heads, 4 KV heads, MLP 2x, tied embeddings. + +New additions (all zero extra parameters): +- **XSA (Exclusive Self-Attention)** on last 4 effective layers: removes self-value bias from attention output via GQA-aware projection subtraction. -0.010 bpb. +- **LeakyReLU(0.5)²** instead of relu²: preserves negative gradient flow while maintaining sparsity. Better gradient propagation through 4 recurrence passes. -0.004 bpb. +- **GPTQ-lite**: per-row best-of-5 clip percentiles during quantization (post-training, zero cost). +- **zstd-22** compression instead of zlib (saves ~1.85MB artifact space). +- **SWA** tuned to frac=0.4, every=50 steps. +- **Muon weight decay** 0.04. + +Retained from previous submission: +- Cross-Repeat Skip (stateful recurrence with per-repeat learned scales) +- 2 Value Embedding tables +- Loop Embedding (per-effective-layer depth encoding) + +17.14M params, 15.87MB artifact. + +### Training + +Same LR schedule as previous: MATRIX_LR=0.012, SCALAR_LR=0.012, TIED_EMBED_LR=0.015, GRAD_CLIP_NORM=0.3, WARMDOWN_ITERS=3000, TRAIN_SEQ_LEN=1024. + +### Results (8xH100, 600s wallclock) + +4300 steps, 140ms/step avg. Pre-quant 1.2373, roundtrip 1.2398, sliding window 1.2065. Artifact 15.87MB, quant degradation +0.003 bpb. + +### Ablations (8xH100, 80 shards, all cumulative) + +| Change | Sliding bpb | Delta | +|--------|-------------|-------| +| Baseline (previous submission repro) | 1.2213 | — | +| + XSA last 4 layers | 1.2110 | -0.0103 | +| + LeakyReLU(0.5)² | 1.2070 | -0.0040 | +| + GPTQ-lite + zstd-22 | 1.2065 | -0.0005 | + +### Command + +``` +XSA_LAST_N=4 \ +QUANT_LEVELS=127 \ +EVAL_SEQ_LEN=1024 \ +EVAL_STRIDE=256 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` diff --git a/records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/submission.json b/records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/submission.json new file mode 100644 index 000000000..137cdb5ac --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/submission.json @@ -0,0 +1,16 @@ +{ + "author": "Ivan Verbovoy", + "github_id": "iverbovoy", + "name": "Depth Recurrence + XSA + LeakyReLU² + GPTQ-lite + zstd-22", + "blurb": "3 unique blocks x 4 repeats (12 effective layers), dim=832, with Cross-Repeat Skip, XSA on last 4 layers, LeakyReLU(0.5)², GPTQ-lite quantization, SWA, Muon WD=0.04, zstd-22 compression. 4300 steps in 600s on 8xH100.", + "date": "2026-03-26T00:00:00Z", + "val_loss": 2.03711228, + "val_bpb": 1.20649213, + "roundtrip_val_loss": 2.09336895, + "roundtrip_val_bpb": 1.23981101, + "step_stop": 4300, + "wallclock_seconds": 600.151, + "bytes_total": 15873439, + "bytes_model_int8_zlib": 15810364, + "bytes_code": 63075 +} diff --git a/records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/train.log b/records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/train.log new file mode 100644 index 000000000..b77e2a36a --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/train.log @@ -0,0 +1,101 @@ +W0325 23:14:13.792000 1272 torch/distributed/run.py:793] +W0325 23:14:13.792000 1272 torch/distributed/run.py:793] ***************************************** +W0325 23:14:13.792000 1272 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0325 23:14:13.792000 1272 torch/distributed/run.py:793] ***************************************** +logs/80906ba7-598b-4113-8215-45b1a3a1b567.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17140056 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.015 head_lr:0.0 matrix_lr:0.012 scalar_lr:0.012 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9300 val_bpb:4.1043 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9769 train_time:192ms step_avg:191.99ms +step:2/20000 train_loss:6.4406 train_time:248ms step_avg:123.86ms +step:3/20000 train_loss:7.4687 train_time:384ms step_avg:128.14ms +step:4/20000 train_loss:7.5661 train_time:522ms step_avg:130.52ms +step:5/20000 train_loss:6.8849 train_time:659ms step_avg:131.83ms +step:6/20000 train_loss:6.2342 train_time:796ms step_avg:132.73ms +step:7/20000 train_loss:5.3396 train_time:934ms step_avg:133.40ms +step:8/20000 train_loss:5.0498 train_time:1072ms step_avg:133.94ms +step:9/20000 train_loss:4.8488 train_time:1210ms step_avg:134.49ms +step:10/20000 train_loss:4.7649 train_time:1350ms step_avg:134.97ms +step:200/20000 train_loss:2.7331 train_time:27722ms step_avg:138.61ms +step:400/20000 train_loss:2.2867 train_time:55632ms step_avg:139.08ms +step:600/20000 train_loss:2.5063 train_time:83627ms step_avg:139.38ms +step:800/20000 train_loss:2.2652 train_time:111578ms step_avg:139.47ms +step:1000/20000 train_loss:2.3527 train_time:139525ms step_avg:139.53ms +step:1000/20000 val_loss:2.3114 val_bpb:1.3689 train_time:139609ms step_avg:139.61ms +step:1200/20000 train_loss:2.3656 train_time:167550ms step_avg:139.63ms +step:1400/20000 train_loss:2.4157 train_time:195456ms step_avg:139.61ms +step:1600/20000 train_loss:2.0725 train_time:223357ms step_avg:139.60ms +step:1800/20000 train_loss:2.1766 train_time:251241ms step_avg:139.58ms +step:2000/20000 train_loss:2.2289 train_time:279108ms step_avg:139.55ms +step:2000/20000 val_loss:2.2075 val_bpb:1.3074 train_time:279190ms step_avg:139.60ms +step:2200/20000 train_loss:2.0380 train_time:306975ms step_avg:139.53ms +step:2400/20000 train_loss:2.1660 train_time:334846ms step_avg:139.52ms +step:2600/20000 train_loss:2.3737 train_time:362708ms step_avg:139.50ms +step:2800/20000 train_loss:2.1927 train_time:390569ms step_avg:139.49ms +step:3000/20000 train_loss:2.1817 train_time:418424ms step_avg:139.47ms +step:3000/20000 val_loss:2.1487 val_bpb:1.2726 train_time:418509ms step_avg:139.50ms +swa:start step:3150 +step:3200/20000 train_loss:2.1378 train_time:446330ms step_avg:139.48ms +step:3400/20000 train_loss:2.1057 train_time:474263ms step_avg:139.49ms +step:3600/20000 train_loss:2.0547 train_time:502221ms step_avg:139.51ms +step:3800/20000 train_loss:2.1572 train_time:530168ms step_avg:139.52ms +step:4000/20000 train_loss:2.0950 train_time:558094ms step_avg:139.52ms +step:4000/20000 val_loss:2.0989 val_bpb:1.2431 train_time:558197ms step_avg:139.55ms +step:4200/20000 train_loss:2.0995 train_time:586106ms step_avg:139.55ms +step:4300/20000 val_loss:2.0892 val_bpb:1.2373 train_time:600151ms step_avg:139.57ms +stopping_early: wallclock_cap train_time:600151ms step:4300/20000 +peak memory allocated: 25696 MiB reserved: 27322 MiB +swa: averaging 25 checkpoints +Serialized model: 63386762 bytes +Code size: 63075 bytes +Total submission size: 63449837 bytes +Serialized model int8+zstd22: 15810364 bytes (payload:17243616 raw_torch:17260843 payload_ratio:3.68x) +Total submission size int8+zstd22: 15873439 bytes +/workspace/parameter-golf/train_gpt.py:1395: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1395: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1395: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1395: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1395: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1395: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1395: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1395: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +final_roundtrip val_loss:2.0934 val_bpb:1.2398 eval_time:4076ms +final_roundtrip_exact val_loss:2.09336895 val_bpb:1.23981101 +final_sliding_window val_loss:2.0371 val_bpb:1.2065 window:1024 stride:256 eval_time:66852ms +final_sliding_window_exact val_loss:2.03711228 val_bpb:1.20649213 diff --git a/records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/train_gpt.py b/records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/train_gpt.py new file mode 100644 index 000000000..41f4ac4b9 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/train_gpt.py @@ -0,0 +1,1473 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zstandard as zstd +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- +# Default Simple Baseline run: +# - 9 transformer blocks at width 512 +# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion +# - vocab size 1024, sequence length 1024, tied embeddings +# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + ttt_steps = int(os.environ.get("TTT_STEPS", 0)) + ttt_lr = float(os.environ.get("TTT_LR", 1e-4)) + + # XSA (Exclusive Self-Attention) on last N effective layers. + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + + # SWA (Stochastic Weight Averaging) during warmdown. + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + # Sliding window eval. + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 1024)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 256)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 3)) + num_repeats = int(os.environ.get("NUM_REPEATS", 4)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 832)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + num_value_embeds = int(os.environ.get("NUM_VALUE_EMBEDS", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.015)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.012)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.012)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# ----------------------------- +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + """Sliding window eval with batching. Windows of train_seq_len advance by eval_stride. + Only the last stride tokens per window are scored (first window scores all).""" + seq_len = args.eval_seq_len + stride = args.eval_stride + batch_seqs = args.eval_batch_seqs + total_tokens = val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi : bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws : end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = base_model.forward_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + val_loss_sum += scored_nll.sum() + val_token_count += float(wlen - s) + prev_ids = x_batch[i, s:wlen] + tgt_ids = y_batch[i, s:wlen] + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + base_model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_ttt( + args: Hyperparameters, + base_model: nn.Module, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Test-Time Training: adapt the model on each validation batch before evaluating. + # For each batch: save weights → K gradient steps → evaluate → restore weights. + if args.ttt_steps <= 0: + return eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Save original weights once + saved_state = {k: v.detach().clone() for k, v in base_model.state_dict().items()} + + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + + # TTT: adapt on this batch + model.train() + for _ttt_step in range(args.ttt_steps): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ttt_loss = model(x, y) + ttt_loss.backward() + with torch.no_grad(): + for p in base_model.parameters(): + if p.grad is not None: + p -= args.ttt_lr * p.grad + p.grad = None + + # Evaluate with adapted model + model.eval() + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + # Restore original weights + base_model.load_state_dict(saved_state, strict=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# POST-TRAINING QUANTIZATION +# ----------------------------- +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing and zstd compressing. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +# Quantization levels: 127 = int8, 31 = int6, 16 = int5. Per-tensor override via MLP_QUANT_LEVELS. +QUANT_LEVELS = int(os.environ.get("QUANT_LEVELS", 127)) +MLP_QUANT_LEVELS = int(os.environ.get("MLP_QUANT_LEVELS", 0)) # 0 = same as QUANT_LEVELS +MLP_TENSOR_PATTERNS = ("mlp.fc.", "mlp.proj.", "fc.weight", "mlp.proj.weight") + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 0.999999] + +def quantize_float_tensor(t: Tensor, ql: int = 0) -> tuple[Tensor, Tensor]: + if ql <= 0: + ql = QUANT_LEVELS + t32 = t.float() + if t32.ndim == 2: + # GPTQ-lite: try multiple clip percentiles per row, pick best MSE + abs_t = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in GPTQ_LITE_PERCENTILES: + clip_abs = ( + torch.quantile(abs_t, pct, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + s = (clip_abs / ql).clamp_min(1e-12) + q = torch.clamp(torch.round(clipped / s[:, None]), -ql, ql) + # Reconstruction error per row + recon = q * s[:, None] + mse = (t32 - recon).square().sum(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = s + else: + better = mse < best_mse + best_mse = torch.where(better, mse, best_mse) + best_q = torch.where(better[:, None], q, best_q) + best_scale = torch.where(better, s, best_scale) + return best_q.to(torch.int8).contiguous(), best_scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / ql if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -ql, ql).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + mlp_ql = MLP_QUANT_LEVELS if MLP_QUANT_LEVELS > 0 else QUANT_LEVELS + ql = mlp_ql if any(p in name for p in MLP_TENSOR_PATTERNS) else QUANT_LEVELS + q, s = quantize_float_tensor(t, ql=ql) + meta: dict[str, object] = {} + if s.ndim > 0: + meta["scheme"] = "per_row" + meta["axis"] = 0 + if ql != QUANT_LEVELS: + meta["ql"] = ql + if meta: + qmeta[name] = meta + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def _xsa(self, y: Tensor, v: Tensor) -> Tensor: + """Subtract self-value projection from attention output (GQA-aware).""" + B, T, H, D = y.shape + Hkv = v.size(2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(3) # [B, T, Hkv, 1, D] + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, use_xsa: bool = False) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + # XSA: remove self-value bias from attention output + if use_xsa: + y = y.transpose(1, 2).contiguous() # [B, T, H, D] + v_for_xsa = v.transpose(1, 2) # [B, T, Hkv, D] + y = self._xsa(y, v_for_xsa) + y = y.reshape(bsz, seqlen, dim) + else: + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # leaky_relu(0.5)^2 MLP — better gradient flow than relu^2 for deep/recurrent models + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, use_xsa: bool = False) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x), use_xsa=use_xsa) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + num_repeats: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + num_value_embeds: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + xsa_last_n: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.num_repeats = num_repeats + effective_depth = num_layers * num_repeats + # XSA: which effective layers use exclusive self-attention + self.xsa_start = max(0, effective_depth - xsa_last_n) if xsa_last_n > 0 else effective_depth + self.tok_emb = nn.Embedding(vocab_size, model_dim) + # Value embeddings: extra embedding tables mixed into each effective layer + self.num_value_embeds = num_value_embeds + if num_value_embeds > 0: + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(num_value_embeds)]) + self.value_scales = nn.Parameter(torch.zeros(effective_depth, num_value_embeds, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + # Loop embedding: tells the model which effective layer it's at + self.loop_embed = nn.Parameter(torch.zeros(effective_depth, model_dim, dtype=torch.float32)) + # Cross-repeat skip: each block receives its own output from previous repeat + self.cross_repeat_scales = nn.Parameter(torch.zeros(num_layers, num_repeats - 1, model_dim, dtype=torch.float32)) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + # Pre-compute value embeddings once + ve_list: list[Tensor] = [] + if self.num_value_embeds > 0: + for ve in self.value_embeds: + ve_list.append(ve(input_ids)) # (bsz, seq, dim) + + num_blocks = len(self.blocks) + prev_block_outputs: list[Tensor | None] = [None] * num_blocks + layer_idx = 0 + for repeat in range(self.num_repeats): + for block_idx, block in enumerate(self.blocks): + x = x + self.loop_embed[layer_idx].to(dtype=x.dtype) + # Value embeddings: add weighted extra embeddings at each layer + for ve_idx, ve_out in enumerate(ve_list): + vs = self.value_scales[layer_idx, ve_idx].to(dtype=x.dtype) + x = x + vs[None, None, :] * ve_out + # Cross-repeat skip: mix in this block's output from previous repeat + if repeat > 0 and prev_block_outputs[block_idx] is not None: + scale = self.cross_repeat_scales[block_idx, repeat - 1].to(dtype=x.dtype) + x = x + scale[None, None, :] * prev_block_outputs[block_idx] + x = block(x, x0, use_xsa=(layer_idx >= self.xsa_start)) + prev_block_outputs[block_idx] = x.detach() if not self.training else x + layer_idx += 1 + + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return logits + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self.forward_logits(input_ids) + logits = logits.reshape(-1, logits.size(-1)) + targets = target_ids.reshape(-1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # ----------------------------- + # DISTRIBUTED + CUDA SETUP + # ----------------------------- + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + grad_accum_steps = max(1, 8 // world_size) + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + num_repeats=args.num_repeats, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + num_value_embeds=args.num_value_embeds, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + xsa_last_n=args.xsa_last_n, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params.append(base_model.loop_embed) + scalar_params.append(base_model.cross_repeat_scales) + if base_model.num_value_embeds > 0: + scalar_params.append(base_model.value_scales) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params = [base_model.tok_emb.weight] + if base_model.num_value_embeds > 0: + embed_params.extend(ve.weight for ve in base_model.value_embeds) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown (accumulate in float for precision) + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().float() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu().float() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None: + # Include final weights (may not have landed on swa_every boundary) + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu().float() + swa_count += 1 + log0(f"swa: averaging {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed quantized+zstd artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + zstd_level = int(os.environ.get("ZSTD_LEVEL", 22)) + cctx = zstd.ZstdCompressor(level=zstd_level) + quant_blob = cctx.compress(quant_raw) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zstd{zstd_level}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zstd{zstd_level}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + dctx = zstd.ZstdDecompressor() + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval + if args.eval_stride > 0: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"window:{args.eval_seq_len} stride:{args.eval_stride} " + f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + + # TTT eval: adapt model on each batch before evaluating + if args.ttt_steps > 0: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt( + args, + base_model, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_ttt val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " + f"ttt_steps:{args.ttt_steps} ttt_lr:{args.ttt_lr} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + log0(f"final_ttt_exact val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_non_record_16mb/2026-03-26_ProgressiveDepth_4hour/README.md b/records/track_non_record_16mb/2026-03-26_ProgressiveDepth_4hour/README.md new file mode 100644 index 000000000..9193b9144 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_ProgressiveDepth_4hour/README.md @@ -0,0 +1,60 @@ +## Progressive Depth 4-Hour: Depth Recurrence Scaling Study + +val_bpb = **1.0889** (Hedge Mixer) | 1.1271 (sliding) | 1.1613 (roundtrip) + +4 hours on 8xH100 SXM. Non-record submission — unlimited compute track. + +### Research Question + +How does depth recurrence scale with compute? This is the first data point. + +Shared-weight recurrence has a unique scaling property: the same 3 blocks receive every gradient update. At 132K steps with 5 repeats, each block saw **~660K effective gradient passes** — impossible with unique-layer architectures at the same parameter count. + +### Scaling Curve + +| Steps | Time | Phase | val_bpb | Delta from prev | +|-------|------|-------|---------|-----------------| +| 5K | 6 min | 2 rep | 1.3061 | — | +| 30K | 36 min | 2 rep | 1.2713 | -0.035 | +| 55K | 66 min | 2 rep | 1.2649 | -0.006 | +| 60K | 73 min | 3 rep | 1.2627 | -0.002 | +| 85K | 117 min | 3 rep | 1.2437 | -0.019 | +| 100K | 151 min | 4 rep | 1.2341 | -0.010 | +| 115K | 188 min | 5 rep | 1.2273 | -0.007 | +| 125K | 217 min | 5 rep | 1.2179 | -0.009 | +| 132K | 240 min | 5 rep + SWA | **1.1576** | -0.060 | + +Key observations: +1. **Phase transitions matter**: each depth increase gives an immediate improvement, even late in training +2. **SWA is massive at scale**: 38 checkpoints gave -0.060 bpb — larger than any single phase transition +3. **Diminishing returns within phase**: Phase 1 (2 rep) shows clear flattening after ~40K steps +4. **Progressive Depth unlocks 5 repeats**: 15 effective layers from 3 physical blocks, only possible with gradual depth ramp + +### Comparison + +| Run | Compute | Steps | Sliding bpb | Hedge bpb | +|-----|---------|-------|-------------|-----------| +| Will DePue baseline (flat 9×512) | 4 hours | 329K | — | 1.2074 | +| Our 10-min Progressive Depth | 10 min | 5.7K | 1.1966 | 1.1454 | +| **Our 4-hour Progressive Depth** | **4 hours** | **132K** | **1.1271** | **1.0889** | + +4-hour depth recurrence beats 4-hour flat baseline by **0.119 bpb** (1.2074 → 1.0889 with Hedge, comparison without Hedge: 1.1271 vs 1.2074 = **-0.080**). + +### Configuration + +```bash +MAX_WALLCLOCK_SECONDS=14400 ITERATIONS=200000 \ +NUM_REPEATS=5 PROG_DEPTH="0.3:2,0.5:3,0.75:4,1.0:5" \ +WARMDOWN_ITERS=15000 SWA_EVERY=100 \ +MATRIX_LR=0.015 SCALAR_LR=0.015 TIED_EMBED_LR=0.018 \ +VAL_LOSS_EVERY=5000 TRAIN_LOG_EVERY=1000 USE_HEDGE=1 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +### Architecture + +3 shared blocks, progressive depth (2→3→4→5 repeats), dim=832, 8 heads, 4 KV heads, GQA, MLP 2×, tied embeddings. Cross-Repeat Skip (#148), XSA, LeakyReLU², SWA, Hedge Mixer. 17.15M params, 15.82MB artifact. + +### Credits + +Hedge Mixer from PR #688 (@RoyiRa), PR #745 (@stukenov). Will DePue's 4-hour baseline for comparison. diff --git a/records/track_non_record_16mb/2026-03-26_ProgressiveDepth_4hour/submission.json b/records/track_non_record_16mb/2026-03-26_ProgressiveDepth_4hour/submission.json new file mode 100644 index 000000000..7f636b44e --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_ProgressiveDepth_4hour/submission.json @@ -0,0 +1,19 @@ +{ + "author": "Ivan Verbovoy", + "github_id": "iverbovoy", + "name": "Progressive Depth 4-Hour: Depth Recurrence Scaling Study", + "blurb": "4-hour unlimited compute run exploring how depth recurrence scales. 3 shared blocks, progressive depth 2\u21923\u21924\u21925 repeats (15 effective layers), 132K steps, 38 SWA checkpoints, Hedge Mixer eval. First data point on depth recurrence scaling with compute.", + "date": "2026-03-26T19:00:00Z", + "val_loss": 1.83860318, + "val_bpb": 1.08892390, + "sliding_val_loss": 1.90311353, + "sliding_val_bpb": 1.12713055, + "roundtrip_val_loss": 1.96078305, + "roundtrip_val_bpb": 1.16128616, + "step_stop": 132937, + "wallclock_seconds": 14400.160, + "eval_seconds": 695.601, + "bytes_total": 15888183, + "bytes_model_int8_zstd22": 15822329, + "bytes_code": 65854 +} diff --git a/records/track_non_record_16mb/2026-03-26_ProgressiveDepth_4hour/train.log b/records/track_non_record_16mb/2026-03-26_ProgressiveDepth_4hour/train.log new file mode 100644 index 000000000..8dc2d16c6 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_ProgressiveDepth_4hour/train.log @@ -0,0 +1,240 @@ +W0326 15:53:02.963000 21735 torch/distributed/run.py:793] +W0326 15:53:02.963000 21735 torch/distributed/run.py:793] ***************************************** +W0326 15:53:02.963000 21735 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0326 15:53:02.963000 21735 torch/distributed/run.py:793] ***************************************** +logs/1dbeb80b-7db2-4f3c-a553-9cd735fa8456.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17150040 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.018 head_lr:0.0 matrix_lr:0.015 scalar_lr:0.015 +train_batch_tokens:524288 train_seq_len:1024 iterations:200000 warmup_steps:20 max_wallclock_seconds:14400.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +prog_depth: schedule=[(0.3, 2), (0.5, 3), (0.75, 4), (1.0, 5)] starting_repeats=2 +step:0/200000 val_loss:6.9300 val_bpb:4.1043 train_time:0ms step_avg:0.01ms +step:1/200000 train_loss:6.9769 train_time:16845ms step_avg:16844.90ms +step:2/200000 train_loss:8.6314 train_time:16863ms step_avg:8431.45ms +step:3/200000 train_loss:8.2566 train_time:16932ms step_avg:5643.85ms +step:4/200000 train_loss:7.3935 train_time:17001ms step_avg:4250.35ms +step:5/200000 train_loss:6.3765 train_time:17074ms step_avg:3414.76ms +step:6/200000 train_loss:5.9886 train_time:17145ms step_avg:2857.50ms +step:7/200000 train_loss:5.6334 train_time:17218ms step_avg:2459.73ms +step:8/200000 train_loss:5.4596 train_time:17290ms step_avg:2161.24ms +step:9/200000 train_loss:5.2870 train_time:17362ms step_avg:1929.15ms +step:10/200000 train_loss:5.1550 train_time:17436ms step_avg:1743.55ms +step:1000/200000 train_loss:2.3841 train_time:88844ms step_avg:88.84ms +step:2000/200000 train_loss:2.2821 train_time:161286ms step_avg:80.64ms +step:3000/200000 train_loss:2.2612 train_time:233555ms step_avg:77.85ms +step:4000/200000 train_loss:2.2084 train_time:305800ms step_avg:76.45ms +step:5000/200000 train_loss:2.0848 train_time:377987ms step_avg:75.60ms +step:5000/200000 val_loss:2.2052 val_bpb:1.3061 train_time:378030ms step_avg:75.61ms +step:6000/200000 train_loss:2.2732 train_time:450156ms step_avg:75.03ms +step:7000/200000 train_loss:2.2837 train_time:522290ms step_avg:74.61ms +step:8000/200000 train_loss:2.1746 train_time:594389ms step_avg:74.30ms +step:9000/200000 train_loss:2.0870 train_time:666559ms step_avg:74.06ms +step:10000/200000 train_loss:2.1783 train_time:738695ms step_avg:73.87ms +step:10000/200000 val_loss:2.1806 val_bpb:1.2915 train_time:738738ms step_avg:73.87ms +step:11000/200000 train_loss:2.1819 train_time:810833ms step_avg:73.71ms +step:12000/200000 train_loss:2.1383 train_time:882934ms step_avg:73.58ms +step:13000/200000 train_loss:2.2662 train_time:955174ms step_avg:73.47ms +step:14000/200000 train_loss:2.1719 train_time:1027305ms step_avg:73.38ms +step:15000/200000 train_loss:2.1102 train_time:1099410ms step_avg:73.29ms +step:15000/200000 val_loss:2.1646 val_bpb:1.2820 train_time:1099452ms step_avg:73.30ms +step:16000/200000 train_loss:2.2081 train_time:1171504ms step_avg:73.22ms +step:17000/200000 train_loss:2.1836 train_time:1243662ms step_avg:73.16ms +step:18000/200000 train_loss:2.1602 train_time:1315789ms step_avg:73.10ms +step:19000/200000 train_loss:2.1642 train_time:1387892ms step_avg:73.05ms +step:20000/200000 train_loss:2.1936 train_time:1460000ms step_avg:73.00ms +step:20000/200000 val_loss:2.1571 val_bpb:1.2776 train_time:1460041ms step_avg:73.00ms +step:21000/200000 train_loss:2.2452 train_time:1532270ms step_avg:72.97ms +step:22000/200000 train_loss:2.1390 train_time:1604389ms step_avg:72.93ms +step:23000/200000 train_loss:2.1860 train_time:1676489ms step_avg:72.89ms +step:24000/200000 train_loss:2.1376 train_time:1748590ms step_avg:72.86ms +step:25000/200000 train_loss:2.2498 train_time:1820769ms step_avg:72.83ms +step:25000/200000 val_loss:2.1502 val_bpb:1.2735 train_time:1820812ms step_avg:72.83ms +step:26000/200000 train_loss:2.0429 train_time:1892885ms step_avg:72.80ms +step:27000/200000 train_loss:2.3168 train_time:1965009ms step_avg:72.78ms +step:28000/200000 train_loss:2.1552 train_time:2037142ms step_avg:72.76ms +step:29000/200000 train_loss:2.2571 train_time:2109419ms step_avg:72.74ms +step:30000/200000 train_loss:2.1109 train_time:2181562ms step_avg:72.72ms +step:30000/200000 val_loss:2.1466 val_bpb:1.2713 train_time:2181606ms step_avg:72.72ms +step:31000/200000 train_loss:1.9727 train_time:2253693ms step_avg:72.70ms +step:32000/200000 train_loss:2.1049 train_time:2325791ms step_avg:72.68ms +step:33000/200000 train_loss:2.1319 train_time:2397967ms step_avg:72.67ms +step:34000/200000 train_loss:2.0386 train_time:2470075ms step_avg:72.65ms +step:35000/200000 train_loss:2.2149 train_time:2542194ms step_avg:72.63ms +step:35000/200000 val_loss:2.1457 val_bpb:1.2708 train_time:2542236ms step_avg:72.64ms +step:36000/200000 train_loss:1.9947 train_time:2614314ms step_avg:72.62ms +step:37000/200000 train_loss:2.1807 train_time:2686439ms step_avg:72.61ms +step:38000/200000 train_loss:2.0213 train_time:2758647ms step_avg:72.60ms +step:39000/200000 train_loss:2.3018 train_time:2830744ms step_avg:72.58ms +step:40000/200000 train_loss:2.3711 train_time:2902835ms step_avg:72.57ms +step:40000/200000 val_loss:2.1430 val_bpb:1.2692 train_time:2902878ms step_avg:72.57ms +step:41000/200000 train_loss:2.4223 train_time:2974947ms step_avg:72.56ms +step:42000/200000 train_loss:2.0273 train_time:3047107ms step_avg:72.55ms +step:43000/200000 train_loss:2.1829 train_time:3119185ms step_avg:72.54ms +step:44000/200000 train_loss:2.1263 train_time:3191302ms step_avg:72.53ms +step:45000/200000 train_loss:2.1253 train_time:3263435ms step_avg:72.52ms +step:45000/200000 val_loss:2.1392 val_bpb:1.2669 train_time:3263477ms step_avg:72.52ms +step:46000/200000 train_loss:2.2437 train_time:3335684ms step_avg:72.51ms +step:47000/200000 train_loss:2.1515 train_time:3407801ms step_avg:72.51ms +step:48000/200000 train_loss:2.1868 train_time:3479915ms step_avg:72.50ms +step:49000/200000 train_loss:2.1920 train_time:3552038ms step_avg:72.49ms +step:50000/200000 train_loss:2.2547 train_time:3624187ms step_avg:72.48ms +step:50000/200000 val_loss:2.1386 val_bpb:1.2666 train_time:3624229ms step_avg:72.48ms +step:51000/200000 train_loss:2.0858 train_time:3696306ms step_avg:72.48ms +step:52000/200000 train_loss:2.1921 train_time:3768412ms step_avg:72.47ms +step:53000/200000 train_loss:2.0741 train_time:3840517ms step_avg:72.46ms +step:54000/200000 train_loss:2.1679 train_time:3912762ms step_avg:72.46ms +step:55000/200000 train_loss:2.0476 train_time:3984860ms step_avg:72.45ms +step:55000/200000 val_loss:2.1357 val_bpb:1.2649 train_time:3984903ms step_avg:72.45ms +step:56000/200000 train_loss:2.2066 train_time:4056990ms step_avg:72.45ms +step:57000/200000 train_loss:2.2238 train_time:4129076ms step_avg:72.44ms +step:58000/200000 train_loss:2.2995 train_time:4201225ms step_avg:72.43ms +step:59000/200000 train_loss:2.1368 train_time:4273324ms step_avg:72.43ms +prog_depth: switched to 3 repeats at step:59647 frac:0.30 +step:60000/200000 train_loss:2.0361 train_time:4377985ms step_avg:72.97ms +step:60000/200000 val_loss:2.1320 val_bpb:1.2627 train_time:4378050ms step_avg:72.97ms +step:61000/200000 train_loss:2.0974 train_time:4483921ms step_avg:73.51ms +step:62000/200000 train_loss:2.1659 train_time:4589987ms step_avg:74.03ms +step:63000/200000 train_loss:2.1583 train_time:4695907ms step_avg:74.54ms +step:64000/200000 train_loss:2.0648 train_time:4801867ms step_avg:75.03ms +step:65000/200000 train_loss:2.1450 train_time:4907829ms step_avg:75.51ms +step:65000/200000 val_loss:2.1106 val_bpb:1.2500 train_time:4907895ms step_avg:75.51ms +step:66000/200000 train_loss:1.8860 train_time:5013838ms step_avg:75.97ms +step:67000/200000 train_loss:2.0870 train_time:5119857ms step_avg:76.42ms +step:68000/200000 train_loss:2.1900 train_time:5225807ms step_avg:76.85ms +step:69000/200000 train_loss:2.0886 train_time:5331736ms step_avg:77.27ms +step:70000/200000 train_loss:2.2211 train_time:5437727ms step_avg:77.68ms +step:70000/200000 val_loss:2.1081 val_bpb:1.2485 train_time:5437791ms step_avg:77.68ms +step:71000/200000 train_loss:2.1345 train_time:5543651ms step_avg:78.08ms +step:72000/200000 train_loss:2.1333 train_time:5649648ms step_avg:78.47ms +step:73000/200000 train_loss:2.0843 train_time:5755578ms step_avg:78.84ms +step:74000/200000 train_loss:2.0270 train_time:5861483ms step_avg:79.21ms +step:75000/200000 train_loss:1.9700 train_time:5967463ms step_avg:79.57ms +step:75000/200000 val_loss:2.1054 val_bpb:1.2469 train_time:5967529ms step_avg:79.57ms +step:76000/200000 train_loss:2.1318 train_time:6073388ms step_avg:79.91ms +step:77000/200000 train_loss:2.1236 train_time:6179320ms step_avg:80.25ms +step:78000/200000 train_loss:2.0288 train_time:6285314ms step_avg:80.58ms +step:79000/200000 train_loss:2.0727 train_time:6391247ms step_avg:80.90ms +step:80000/200000 train_loss:2.1677 train_time:6497157ms step_avg:81.21ms +step:80000/200000 val_loss:2.1026 val_bpb:1.2453 train_time:6497222ms step_avg:81.22ms +step:81000/200000 train_loss:2.2056 train_time:6603068ms step_avg:81.52ms +step:82000/200000 train_loss:2.0218 train_time:6708965ms step_avg:81.82ms +step:83000/200000 train_loss:2.0673 train_time:6814955ms step_avg:82.11ms +step:84000/200000 train_loss:2.1253 train_time:6920923ms step_avg:82.39ms +step:85000/200000 train_loss:2.0886 train_time:7026821ms step_avg:82.67ms +step:85000/200000 val_loss:2.1000 val_bpb:1.2437 train_time:7026879ms step_avg:82.67ms +step:86000/200000 train_loss:2.1928 train_time:7132698ms step_avg:82.94ms +prog_depth: switched to 4 repeats at step:86635 frac:0.50 +step:87000/200000 train_loss:2.1398 train_time:7272924ms step_avg:83.60ms +step:88000/200000 train_loss:2.0588 train_time:7412765ms step_avg:84.24ms +step:89000/200000 train_loss:2.1491 train_time:7552665ms step_avg:84.86ms +step:90000/200000 train_loss:2.1829 train_time:7692449ms step_avg:85.47ms +step:90000/200000 val_loss:2.0901 val_bpb:1.2379 train_time:7692531ms step_avg:85.47ms +step:91000/200000 train_loss:1.9459 train_time:7832253ms step_avg:86.07ms +step:92000/200000 train_loss:2.0296 train_time:7972057ms step_avg:86.65ms +step:93000/200000 train_loss:2.1412 train_time:8111893ms step_avg:87.22ms +step:94000/200000 train_loss:2.1227 train_time:8251664ms step_avg:87.78ms +step:95000/200000 train_loss:2.0571 train_time:8391465ms step_avg:88.33ms +step:95000/200000 val_loss:2.0879 val_bpb:1.2365 train_time:8391548ms step_avg:88.33ms +step:96000/200000 train_loss:2.0182 train_time:8531192ms step_avg:88.87ms +step:97000/200000 train_loss:1.9758 train_time:8670990ms step_avg:89.39ms +step:98000/200000 train_loss:2.0186 train_time:8810717ms step_avg:89.91ms +step:99000/200000 train_loss:2.0495 train_time:8950493ms step_avg:90.41ms +step:100000/200000 train_loss:2.0471 train_time:9090190ms step_avg:90.90ms +step:100000/200000 val_loss:2.0837 val_bpb:1.2341 train_time:9090274ms step_avg:90.90ms +step:101000/200000 train_loss:2.1117 train_time:9230025ms step_avg:91.39ms +step:102000/200000 train_loss:2.0209 train_time:9369757ms step_avg:91.86ms +step:103000/200000 train_loss:2.0541 train_time:9509549ms step_avg:92.33ms +step:104000/200000 train_loss:2.0081 train_time:9649290ms step_avg:92.78ms +step:105000/200000 train_loss:2.0522 train_time:9789037ms step_avg:93.23ms +step:105000/200000 val_loss:2.0846 val_bpb:1.2346 train_time:9789122ms step_avg:93.23ms +step:106000/200000 train_loss:2.1964 train_time:9928800ms step_avg:93.67ms +step:107000/200000 train_loss:2.1594 train_time:10068511ms step_avg:94.10ms +step:108000/200000 train_loss:2.0851 train_time:10208311ms step_avg:94.52ms +step:109000/200000 train_loss:2.0508 train_time:10348010ms step_avg:94.94ms +step:110000/200000 train_loss:1.9928 train_time:10487787ms step_avg:95.34ms +step:110000/200000 val_loss:2.0804 val_bpb:1.2321 train_time:10487872ms step_avg:95.34ms +step:111000/200000 train_loss:1.9835 train_time:10627479ms step_avg:95.74ms +step:112000/200000 train_loss:2.0173 train_time:10767231ms step_avg:96.14ms +prog_depth: switched to 5 repeats at step:112234 frac:0.75 +step:113000/200000 train_loss:1.9453 train_time:10945906ms step_avg:96.87ms +step:114000/200000 train_loss:2.0244 train_time:11119241ms step_avg:97.54ms +step:115000/200000 train_loss:1.9814 train_time:11292478ms step_avg:98.20ms +step:115000/200000 val_loss:2.0722 val_bpb:1.2273 train_time:11292561ms step_avg:98.20ms +step:116000/200000 train_loss:2.1759 train_time:11465777ms step_avg:98.84ms +step:117000/200000 train_loss:2.1031 train_time:11639170ms step_avg:99.48ms +step:118000/200000 train_loss:1.9274 train_time:11812408ms step_avg:100.11ms +step:119000/200000 train_loss:1.9307 train_time:11985632ms step_avg:100.72ms +step:120000/200000 train_loss:2.0588 train_time:12158917ms step_avg:101.32ms +step:120000/200000 val_loss:2.0707 val_bpb:1.2264 train_time:12159000ms step_avg:101.32ms +step:121000/200000 train_loss:1.8769 train_time:12332051ms step_avg:101.92ms +step:122000/200000 train_loss:2.1114 train_time:12505223ms step_avg:102.50ms +step:123000/200000 train_loss:2.1025 train_time:12678435ms step_avg:103.08ms +step:124000/200000 train_loss:2.0868 train_time:12851768ms step_avg:103.64ms +step:125000/200000 train_loss:1.7949 train_time:13024947ms step_avg:104.20ms +step:125000/200000 val_loss:2.0563 val_bpb:1.2179 train_time:13025031ms step_avg:104.20ms +step:126000/200000 train_loss:2.0538 train_time:13198136ms step_avg:104.75ms +step:127000/200000 train_loss:2.0125 train_time:13371358ms step_avg:105.29ms +step:128000/200000 train_loss:1.9549 train_time:13544572ms step_avg:105.82ms +step:129000/200000 train_loss:2.0297 train_time:13717722ms step_avg:106.34ms +swa:start step:129300 +step:130000/200000 train_loss:2.0308 train_time:13891013ms step_avg:106.85ms +step:130000/200000 val_loss:2.0000 val_bpb:1.1845 train_time:13891115ms step_avg:106.85ms +step:131000/200000 train_loss:2.0374 train_time:14064355ms step_avg:107.36ms +step:132000/200000 train_loss:1.9361 train_time:14237682ms step_avg:107.86ms +step:132937/200000 val_loss:1.9545 val_bpb:1.1576 train_time:14400160ms step_avg:108.32ms +stopping_early: wallclock_cap train_time:14400160ms step:132937/200000 +peak memory allocated: 31668 MiB reserved: 32118 MiB +swa: averaging 38 checkpoints +Serialized model: 63406730 bytes +Code size: 65854 bytes +Total submission size: 63472584 bytes +Serialized model int8+zstd22: 15822329 bytes (payload:17263584 raw_torch:17280811 payload_ratio:3.67x) +Total submission size int8+zstd22: 15888183 bytes +/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +final_roundtrip val_loss:1.9608 val_bpb:1.1613 eval_time:17704ms +final_roundtrip_exact val_loss:1.96078305 val_bpb:1.16128616 +final_sliding_window val_loss:1.9031 val_bpb:1.1271 window:1024 stride:256 eval_time:81914ms +final_sliding_window_exact val_loss:1.90311353 val_bpb:1.12713055 +final_hedge_mixer val_loss:1.8386 val_bpb:1.0889 eval_time:695601ms +final_hedge_mixer_exact val_loss:1.83860318 val_bpb:1.08892390 diff --git a/records/track_non_record_16mb/2026-03-26_ProgressiveDepth_4hour/train_gpt.py b/records/track_non_record_16mb/2026-03-26_ProgressiveDepth_4hour/train_gpt.py new file mode 100644 index 000000000..1738288f3 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_ProgressiveDepth_4hour/train_gpt.py @@ -0,0 +1,1498 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zstandard as zstd +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + + +class HedgeMixer: + """Online mixture of 5 experts via Hedge algorithm for eval-time improvement. + Experts: Neural, Unigram, Bigram, Trigram (hashed), Entropy.""" + def __init__(self, vocab_size: int = 1024, device: str = "cuda", eta: float = 0.1): + self.V = vocab_size + self.device = device + self.eta = eta + self.log_weights = torch.zeros(5, device=device) + self.log_weights[0] = 2.0 # bias toward neural + self.uni_counts = torch.zeros(vocab_size, device=device) + self.bi_counts = torch.zeros(vocab_size, vocab_size, device=device) + self.total_tokens = 0 + self.TRI_HASH = 65536 + self.tri_counts = torch.zeros(self.TRI_HASH, vocab_size, device=device) + self.tri_row_totals = torch.zeros(self.TRI_HASH, device=device) + + def update(self, tokens: Tensor) -> None: + t = tokens.to(self.device).long() + n = t.numel() + if n == 0: + return + self.total_tokens += n + ones = torch.ones(n, device=self.device) + self.uni_counts.scatter_add_(0, t, ones) + if n >= 2: + bi_idx = t[:-1] * self.V + t[1:] + self.bi_counts.reshape(-1).scatter_add_(0, bi_idx, torch.ones(n - 1, device=self.device)) + if n >= 3: + tri_ctx = ((t[:-2] * 36313) ^ (t[1:-1] * 27191)) % self.TRI_HASH + tri_idx = tri_ctx * self.V + t[2:] + ones_tri = torch.ones(n - 2, device=self.device) + self.tri_counts.reshape(-1).scatter_add_(0, tri_idx, ones_tri) + self.tri_row_totals.scatter_add_(0, tri_ctx, ones_tri) + + def mix_and_score(self, neural_logits: Tensor, x_batch: Tensor, y_batch: Tensor, wlens: list[int]) -> Tensor: + bsz, slen, V = neural_logits.shape + uniform_nll = math.log(self.V) + has_data = self.total_tokens > 0 + neural_lp = F.log_softmax(neural_logits, dim=-1) + neural_nll = -neural_lp.gather(2, y_batch.unsqueeze(2)).squeeze(2) + if not has_data or self.total_tokens < 10000: + return neural_nll + uni_probs = (self.uni_counts + 0.1) / (self.total_tokens + 0.1 * self.V) + uni_nll = -uni_probs.log()[y_batch] + bi_total = self.bi_counts.sum(dim=1, keepdim=True) + bi_probs = (self.bi_counts + 0.1) / (bi_total + 0.1 * self.V) + bi_nll = -bi_probs.log()[x_batch.reshape(-1), y_batch.reshape(-1)].reshape(bsz, slen) + if slen >= 2: + prev2 = torch.zeros_like(x_batch) + prev2[:, 1:] = x_batch[:, :-1] + ctx_hash = ((prev2 * 36313) ^ (x_batch * 27191)) % self.TRI_HASH + tri_count = self.tri_counts[ctx_hash.reshape(-1).long(), y_batch.reshape(-1).long()] + tri_total = self.tri_row_totals[ctx_hash.reshape(-1).long()].clamp(min=1) + tri_nll = -(((tri_count + 0.01) / (tri_total + 0.01 * self.V)).log()).reshape(bsz, slen) + else: + tri_nll = torch.full((bsz, slen), uniform_nll, device=self.device) + entropy_nll = -(neural_lp.exp() * neural_lp).sum(-1) + expert_nll = torch.stack([neural_nll, uni_nll, bi_nll, tri_nll, entropy_nll], dim=-1) + log_w = self.log_weights - self.log_weights.logsumexp(0) + mixed_nll = -(-expert_nll + log_w.unsqueeze(0).unsqueeze(0)).logsumexp(dim=-1) + # Update weights + wlens_t = torch.tensor(wlens, device=self.device, dtype=torch.long) + mask = torch.arange(slen, device=self.device).unsqueeze(0) < wlens_t.unsqueeze(1) + masked_nll = expert_nll * mask.unsqueeze(-1).float() + expert_mean_loss = masked_nll.sum(dim=(0, 1)) / mask.sum().clamp(min=1) + self.log_weights -= self.eta * expert_mean_loss + return mixed_nll + + +# HYPERPARAMETERS + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 2000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + + # Progressive Depth: train with fewer repeats early (faster), more repeats later (deeper). + # Schedule format: "frac1:rep1,frac2:rep2,..." e.g. "0.4:2,0.65:3,1.0:4" + prog_depth_schedule = os.environ.get("PROG_DEPTH", "0.4:2,0.65:3,1.0:4") + + # XSA (Exclusive Self-Attention) on last N effective layers. + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + + # SWA (Stochastic Weight Averaging) during warmdown. + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + # Sliding window eval. + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 1024)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 256)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + # Hedge Mixer (eval-time n-gram ensemble). + use_hedge = bool(int(os.environ.get("USE_HEDGE", "1"))) + hedge_eta = float(os.environ.get("HEDGE_ETA", 0.1)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 3)) + num_repeats = int(os.environ.get("NUM_REPEATS", 4)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 832)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + num_value_embeds = int(os.environ.get("NUM_VALUE_EMBEDS", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.021)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.018)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.018)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + +# MUON OPTIMIZER +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*. + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Validation computes two metrics: + # - val_loss: token cross-entropy (natural log) + # - val_bpb: tokenizer-agnostic compression metric used by the challenge + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + use_hedge: bool = False, + hedge_eta: float = 0.1, +) -> tuple[float, float]: + """Sliding window eval with batching. Windows of train_seq_len advance by eval_stride. + Only the last stride tokens per window are scored (first window scores all). + Optional Hedge Mixer: online n-gram ensemble over scored tokens.""" + seq_len = args.eval_seq_len + stride = args.eval_stride + batch_seqs = args.eval_batch_seqs + total_tokens = val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + + # With Hedge Mixer: process ALL windows on each rank (sequential, n-gram tables need full context) + # Without: distribute windows across ranks + if use_hedge: + my_windows = window_starts + else: + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + mixer = HedgeMixer(vocab_size=args.vocab_size, device=device, eta=hedge_eta) if use_hedge else None + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi : bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws : end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = base_model.forward_logits(x_batch) + + if mixer is not None: + nll = mixer.mix_and_score(logits.float(), x_batch, y_batch, wlens) + else: + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + val_loss_sum += scored_nll.sum() + val_token_count += float(wlen - s) + prev_ids = x_batch[i, s:wlen] + tgt_ids = y_batch[i, s:wlen] + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + # Update n-gram tables with scored tokens + if mixer is not None: + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + mixer.update(y_batch[i, s:wlen]) + + if not use_hedge and dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + base_model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + + +# POST-TRAINING QUANTIZATION +# +# It's silly to export our model, which is trained in bf16 and fp32, at that same precision. +# Instead, we get approximately the same model (with a small hit) by quantizing and zstd compressing. + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +# Quantization levels: 127 = int8, 31 = int6, 16 = int5. Per-tensor override via MLP_QUANT_LEVELS. +QUANT_LEVELS = int(os.environ.get("QUANT_LEVELS", 127)) +MLP_QUANT_LEVELS = int(os.environ.get("MLP_QUANT_LEVELS", 0)) # 0 = same as QUANT_LEVELS +MLP_TENSOR_PATTERNS = ("mlp.fc.", "mlp.proj.", "fc.weight", "mlp.proj.weight") + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 0.999999] + +def quantize_float_tensor(t: Tensor, ql: int = 0) -> tuple[Tensor, Tensor]: + if ql <= 0: + ql = QUANT_LEVELS + t32 = t.float() + if t32.ndim == 2: + # GPTQ-lite: try multiple clip percentiles per row, pick best MSE + abs_t = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in GPTQ_LITE_PERCENTILES: + clip_abs = ( + torch.quantile(abs_t, pct, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + s = (clip_abs / ql).clamp_min(1e-12) + q = torch.clamp(torch.round(clipped / s[:, None]), -ql, ql) + # Reconstruction error per row + recon = q * s[:, None] + mse = (t32 - recon).square().sum(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = s + else: + better = mse < best_mse + best_mse = torch.where(better, mse, best_mse) + best_q = torch.where(better[:, None], q, best_q) + best_scale = torch.where(better, s, best_scale) + return best_q.to(torch.int8).contiguous(), best_scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / ql if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -ql, ql).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + mlp_ql = MLP_QUANT_LEVELS if MLP_QUANT_LEVELS > 0 else QUANT_LEVELS + ql = mlp_ql if any(p in name for p in MLP_TENSOR_PATTERNS) else QUANT_LEVELS + q, s = quantize_float_tensor(t, ql=ql) + meta: dict[str, object] = {} + if s.ndim > 0: + meta["scheme"] = "per_row" + meta["axis"] = 0 + if ql != QUANT_LEVELS: + meta["ql"] = ql + if meta: + qmeta[name] = meta + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# DATA LOADING + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# TRANSFORMER MODULES + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def _xsa(self, y: Tensor, v: Tensor) -> Tensor: + """Subtract self-value projection from attention output (GQA-aware).""" + B, T, H, D = y.shape + Hkv = v.size(2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(3) # [B, T, Hkv, 1, D] + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, use_xsa: bool = False) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + # XSA: remove self-value bias from attention output + if use_xsa: + y = y.transpose(1, 2).contiguous() # [B, T, H, D] + v_for_xsa = v.transpose(1, 2) # [B, T, Hkv, D] + y = self._xsa(y, v_for_xsa) + y = y.reshape(bsz, seqlen, dim) + else: + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # leaky_relu(0.5)^2 MLP — better gradient flow than relu^2 for deep/recurrent models + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, use_xsa: bool = False) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x), use_xsa=use_xsa) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + num_repeats: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + num_value_embeds: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + xsa_last_n: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.num_repeats = num_repeats + self.xsa_last_n = xsa_last_n + effective_depth = num_layers * num_repeats + self.tok_emb = nn.Embedding(vocab_size, model_dim) + # Value embeddings: extra embedding tables mixed into each effective layer + self.num_value_embeds = num_value_embeds + if num_value_embeds > 0: + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(num_value_embeds)]) + self.value_scales = nn.Parameter(torch.zeros(effective_depth, num_value_embeds, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + # Loop embedding: tells the model which effective layer it's at + self.loop_embed = nn.Parameter(torch.zeros(effective_depth, model_dim, dtype=torch.float32)) + # Cross-repeat skip: each block receives its own output from previous repeat + self.cross_repeat_scales = nn.Parameter(torch.zeros(num_layers, num_repeats - 1, model_dim, dtype=torch.float32)) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + # Pre-compute value embeddings once + ve_list: list[Tensor] = [] + if self.num_value_embeds > 0: + for ve in self.value_embeds: + ve_list.append(ve(input_ids)) # (bsz, seq, dim) + + cur_repeats = self.cur_repeats if hasattr(self, "cur_repeats") else self.num_repeats + cur_depth = len(self.blocks) * cur_repeats + xsa_start = max(0, cur_depth - self.xsa_last_n) if self.xsa_last_n > 0 else cur_depth + + num_blocks = len(self.blocks) + prev_block_outputs: list[Tensor | None] = [None] * num_blocks + layer_idx = 0 + for repeat in range(cur_repeats): + for block_idx, block in enumerate(self.blocks): + x = x + self.loop_embed[layer_idx].to(dtype=x.dtype) + # Value embeddings: add weighted extra embeddings at each layer + if layer_idx < self.value_scales.size(0): + for ve_idx, ve_out in enumerate(ve_list): + vs = self.value_scales[layer_idx, ve_idx].to(dtype=x.dtype) + x = x + vs[None, None, :] * ve_out + # Cross-repeat skip: mix in this block's output from previous repeat + if repeat > 0 and prev_block_outputs[block_idx] is not None: + rep_idx = min(repeat - 1, self.cross_repeat_scales.size(1) - 1) + scale = self.cross_repeat_scales[block_idx, rep_idx].to(dtype=x.dtype) + x = x + scale[None, None, :] * prev_block_outputs[block_idx] + x = block(x, x0, use_xsa=(layer_idx >= xsa_start)) + prev_block_outputs[block_idx] = x.detach() if not self.training else x + layer_idx += 1 + + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return logits + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self.forward_logits(input_ids) + logits = logits.reshape(-1, logits.size(-1)) + targets = target_ids.reshape(-1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# TRAINING + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # DISTRIBUTED + CUDA SETUP + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + grad_accum_steps = max(1, 8 // world_size) + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # TOKENIZER + VALIDATION METRIC SETUP + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + num_repeats=args.num_repeats, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + num_value_embeds=args.num_value_embeds, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + xsa_last_n=args.xsa_last_n, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params.append(base_model.loop_embed) + scalar_params.append(base_model.cross_repeat_scales) + if base_model.num_value_embeds > 0: + scalar_params.append(base_model.value_scales) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params = [base_model.tok_emb.weight] + if base_model.num_value_embeds > 0: + embed_params.extend(ve.weight for ve in base_model.value_embeds) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + # Progressive depth schedule: parse "frac:repeats,..." and sort + prog_phases: list[tuple[float, int]] = [] + for entry in args.prog_depth_schedule.split(","): + frac_s, rep_s = entry.strip().split(":") + prog_phases.append((float(frac_s), int(rep_s))) + prog_phases.sort() + current_phase_repeats = prog_phases[0][1] if prog_phases else args.num_repeats + base_model.cur_repeats = current_phase_repeats + # Recompile with initial phase depth + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + log0(f"prog_depth: schedule={prog_phases} starting_repeats={current_phase_repeats}") + + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + # Progressive depth: check if we need to switch phase + # Use synchronized elapsed time (max across ranks) to avoid race conditions + if max_wallclock_ms is not None and prog_phases: + if distributed: + elapsed_tensor = torch.tensor(elapsed_ms, device=device) + dist.all_reduce(elapsed_tensor, op=dist.ReduceOp.MAX) + frac = elapsed_tensor.item() / max_wallclock_ms + else: + frac = elapsed_ms / max_wallclock_ms + new_repeats = prog_phases[-1][1] # default to last + for phase_frac, phase_rep in prog_phases: + if frac < phase_frac: + new_repeats = phase_rep + break + if new_repeats != current_phase_repeats: + current_phase_repeats = new_repeats + base_model.cur_repeats = new_repeats + torch._dynamo.reset() + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + log0(f"prog_depth: switched to {new_repeats} repeats at step:{step} frac:{frac:.2f}") + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown (only at full depth to avoid mixing phases) + at_full_depth = current_phase_repeats == args.num_repeats + if args.swa_enabled and at_full_depth and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().float() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu().float() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Restore full depth for eval/export + base_model.cur_repeats = args.num_repeats + torch._dynamo.reset() + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None: + # Include final weights (may not have landed on swa_every boundary) + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu().float() + swa_count += 1 + log0(f"swa: averaging {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed quantized+zstd artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + zstd_level = int(os.environ.get("ZSTD_LEVEL", 22)) + cctx = zstd.ZstdCompressor(level=zstd_level) + quant_blob = cctx.compress(quant_raw) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zstd{zstd_level}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zstd{zstd_level}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + dctx = zstd.ZstdDecompressor() + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval + if args.eval_stride > 0: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"window:{args.eval_seq_len} stride:{args.eval_stride} " + f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + + # Hedge Mixer eval (n-gram ensemble) + if args.use_hedge: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_hm = time.perf_counter() + hm_val_loss, hm_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + use_hedge=True, hedge_eta=args.hedge_eta, + ) + torch.cuda.synchronize() + log0( + f"final_hedge_mixer val_loss:{hm_val_loss:.4f} val_bpb:{hm_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_hm):.0f}ms" + ) + log0(f"final_hedge_mixer_exact val_loss:{hm_val_loss:.8f} val_bpb:{hm_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/train_gpt.py b/train_gpt.py index 0deb0565f..1738288f3 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -16,7 +16,7 @@ import sys import time import uuid -import zlib +import zstandard as zstd from pathlib import Path import numpy as np @@ -27,14 +27,77 @@ from torch import Tensor, nn from torch.nn.parallel import DistributedDataParallel as DDP -# ----------------------------- + +class HedgeMixer: + """Online mixture of 5 experts via Hedge algorithm for eval-time improvement. + Experts: Neural, Unigram, Bigram, Trigram (hashed), Entropy.""" + def __init__(self, vocab_size: int = 1024, device: str = "cuda", eta: float = 0.1): + self.V = vocab_size + self.device = device + self.eta = eta + self.log_weights = torch.zeros(5, device=device) + self.log_weights[0] = 2.0 # bias toward neural + self.uni_counts = torch.zeros(vocab_size, device=device) + self.bi_counts = torch.zeros(vocab_size, vocab_size, device=device) + self.total_tokens = 0 + self.TRI_HASH = 65536 + self.tri_counts = torch.zeros(self.TRI_HASH, vocab_size, device=device) + self.tri_row_totals = torch.zeros(self.TRI_HASH, device=device) + + def update(self, tokens: Tensor) -> None: + t = tokens.to(self.device).long() + n = t.numel() + if n == 0: + return + self.total_tokens += n + ones = torch.ones(n, device=self.device) + self.uni_counts.scatter_add_(0, t, ones) + if n >= 2: + bi_idx = t[:-1] * self.V + t[1:] + self.bi_counts.reshape(-1).scatter_add_(0, bi_idx, torch.ones(n - 1, device=self.device)) + if n >= 3: + tri_ctx = ((t[:-2] * 36313) ^ (t[1:-1] * 27191)) % self.TRI_HASH + tri_idx = tri_ctx * self.V + t[2:] + ones_tri = torch.ones(n - 2, device=self.device) + self.tri_counts.reshape(-1).scatter_add_(0, tri_idx, ones_tri) + self.tri_row_totals.scatter_add_(0, tri_ctx, ones_tri) + + def mix_and_score(self, neural_logits: Tensor, x_batch: Tensor, y_batch: Tensor, wlens: list[int]) -> Tensor: + bsz, slen, V = neural_logits.shape + uniform_nll = math.log(self.V) + has_data = self.total_tokens > 0 + neural_lp = F.log_softmax(neural_logits, dim=-1) + neural_nll = -neural_lp.gather(2, y_batch.unsqueeze(2)).squeeze(2) + if not has_data or self.total_tokens < 10000: + return neural_nll + uni_probs = (self.uni_counts + 0.1) / (self.total_tokens + 0.1 * self.V) + uni_nll = -uni_probs.log()[y_batch] + bi_total = self.bi_counts.sum(dim=1, keepdim=True) + bi_probs = (self.bi_counts + 0.1) / (bi_total + 0.1 * self.V) + bi_nll = -bi_probs.log()[x_batch.reshape(-1), y_batch.reshape(-1)].reshape(bsz, slen) + if slen >= 2: + prev2 = torch.zeros_like(x_batch) + prev2[:, 1:] = x_batch[:, :-1] + ctx_hash = ((prev2 * 36313) ^ (x_batch * 27191)) % self.TRI_HASH + tri_count = self.tri_counts[ctx_hash.reshape(-1).long(), y_batch.reshape(-1).long()] + tri_total = self.tri_row_totals[ctx_hash.reshape(-1).long()].clamp(min=1) + tri_nll = -(((tri_count + 0.01) / (tri_total + 0.01 * self.V)).log()).reshape(bsz, slen) + else: + tri_nll = torch.full((bsz, slen), uniform_nll, device=self.device) + entropy_nll = -(neural_lp.exp() * neural_lp).sum(-1) + expert_nll = torch.stack([neural_nll, uni_nll, bi_nll, tri_nll, entropy_nll], dim=-1) + log_w = self.log_weights - self.log_weights.logsumexp(0) + mixed_nll = -(-expert_nll + log_w.unsqueeze(0).unsqueeze(0)).logsumexp(dim=-1) + # Update weights + wlens_t = torch.tensor(wlens, device=self.device, dtype=torch.long) + mask = torch.arange(slen, device=self.device).unsqueeze(0) < wlens_t.unsqueeze(1) + masked_nll = expert_nll * mask.unsqueeze(-1).float() + expert_mean_loss = masked_nll.sum(dim=(0, 1)) / mask.sum().clamp(min=1) + self.log_weights -= self.eta * expert_mean_loss + return mixed_nll + + # HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap class Hyperparameters: # Data paths are shard globs produced by the existing preprocessing pipeline. @@ -52,20 +115,44 @@ class Hyperparameters: # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 2000)) warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + + # Progressive Depth: train with fewer repeats early (faster), more repeats later (deeper). + # Schedule format: "frac1:rep1,frac2:rep2,..." e.g. "0.4:2,0.65:3,1.0:4" + prog_depth_schedule = os.environ.get("PROG_DEPTH", "0.4:2,0.65:3,1.0:4") + + # XSA (Exclusive Self-Attention) on last N effective layers. + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + + # SWA (Stochastic Weight Averaging) during warmdown. + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + # Sliding window eval. + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 1024)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 256)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + # Hedge Mixer (eval-time n-gram ensemble). + use_hedge = bool(int(os.environ.get("USE_HEDGE", "1"))) + hedge_eta = float(os.environ.get("HEDGE_ETA", 0.1)) # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_layers = int(os.environ.get("NUM_LAYERS", 3)) + num_repeats = int(os.environ.get("NUM_REPEATS", 4)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) + model_dim = int(os.environ.get("MODEL_DIM", 832)) num_heads = int(os.environ.get("NUM_HEADS", 8)) mlp_mult = int(os.environ.get("MLP_MULT", 2)) + num_value_embeds = int(os.environ.get("NUM_VALUE_EMBEDS", 2)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) @@ -73,22 +160,20 @@ class Hyperparameters: # Optimizer hyperparameters. embed_lr = float(os.environ.get("EMBED_LR", 0.6)) head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.021)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.018)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.018)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) beta1 = float(os.environ.get("BETA1", 0.9)) beta2 = float(os.environ.get("BETA2", 0.95)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) -# ----------------------------- # MUON OPTIMIZER -# ----------------------------- # # As borrowed from modded-nanogpt # Background on Muon: https://kellerjordan.github.io/posts/muon/ @@ -110,10 +195,10 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) - class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): super().__init__( params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), ) @torch.no_grad() @@ -159,18 +244,19 @@ def step(self, closure=None): if distributed: dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) curr = 0 for p in params: g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) p.add_(g, alpha=-lr) curr += p.numel() return loss -# ----------------------------- # TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- # # It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. # Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. @@ -277,13 +363,112 @@ def eval_val( model.train() return float(val_loss.item()), float(bits_per_token * tokens_per_byte) -# ----------------------------- + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + use_hedge: bool = False, + hedge_eta: float = 0.1, +) -> tuple[float, float]: + """Sliding window eval with batching. Windows of train_seq_len advance by eval_stride. + Only the last stride tokens per window are scored (first window scores all). + Optional Hedge Mixer: online n-gram ensemble over scored tokens.""" + seq_len = args.eval_seq_len + stride = args.eval_stride + batch_seqs = args.eval_batch_seqs + total_tokens = val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + + # With Hedge Mixer: process ALL windows on each rank (sequential, n-gram tables need full context) + # Without: distribute windows across ranks + if use_hedge: + my_windows = window_starts + else: + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + mixer = HedgeMixer(vocab_size=args.vocab_size, device=device, eta=hedge_eta) if use_hedge else None + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi : bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws : end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = base_model.forward_logits(x_batch) + + if mixer is not None: + nll = mixer.mix_and_score(logits.float(), x_batch, y_batch, wlens) + else: + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + val_loss_sum += scored_nll.sum() + val_token_count += float(wlen - s) + prev_ids = x_batch[i, s:wlen] + tgt_ids = y_batch[i, s:wlen] + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + # Update n-gram tables with scored tokens + if mixer is not None: + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + mixer.update(y_batch[i, s:wlen]) + + if not use_hedge and dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + base_model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + + # POST-TRAINING QUANTIZATION -# ----------------------------- # # It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. +# Instead, we get approximately the same model (with a small hit) by quantizing and zstd compressing. CONTROL_TENSOR_NAME_PATTERNS = tuple( pattern @@ -306,6 +491,10 @@ def eval_val( INT8_PER_ROW_SCALE_DTYPE = torch.float16 INT8_CLIP_PERCENTILE = 99.99984 INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +# Quantization levels: 127 = int8, 31 = int6, 16 = int5. Per-tensor override via MLP_QUANT_LEVELS. +QUANT_LEVELS = int(os.environ.get("QUANT_LEVELS", 127)) +MLP_QUANT_LEVELS = int(os.environ.get("MLP_QUANT_LEVELS", 0)) # 0 = same as QUANT_LEVELS +MLP_TENSOR_PATTERNS = ("mlp.fc.", "mlp.proj.", "fc.weight", "mlp.proj.weight") def tensor_nbytes(t: Tensor) -> int: return int(t.numel()) * int(t.element_size()) @@ -318,25 +507,44 @@ def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, s return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() return t -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: +GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 0.999999] + +def quantize_float_tensor(t: Tensor, ql: int = 0) -> tuple[Tensor, Tensor]: + if ql <= 0: + ql = QUANT_LEVELS t32 = t.float() if t32.ndim == 2: - # Matrices get one scale per row, which usually tracks output-channel - # ranges much better than a single tensor-wide scale. - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) - q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + # GPTQ-lite: try multiple clip percentiles per row, pick best MSE + abs_t = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in GPTQ_LITE_PERCENTILES: + clip_abs = ( + torch.quantile(abs_t, pct, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + s = (clip_abs / ql).clamp_min(1e-12) + q = torch.clamp(torch.round(clipped / s[:, None]), -ql, ql) + # Reconstruction error per row + recon = q * s[:, None] + mse = (t32 - recon).square().sum(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = s + else: + better = mse < best_mse + best_mse = torch.where(better, mse, best_mse) + best_q = torch.where(better[:, None], q, best_q) + best_scale = torch.where(better, s, best_scale) + return best_q.to(torch.int8).contiguous(), best_scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() - # Vectors / scalars use a simpler per-tensor scale. clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 - scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) - q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + scale = torch.tensor(clip_abs / ql if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -ql, ql).to(torch.int8).contiguous() return q, scale def quantize_state_dict_int8(state_dict: dict[str, Tensor]): @@ -377,9 +585,17 @@ def quantize_state_dict_int8(state_dict: dict[str, Tensor]): continue stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) + mlp_ql = MLP_QUANT_LEVELS if MLP_QUANT_LEVELS > 0 else QUANT_LEVELS + ql = mlp_ql if any(p in name for p in MLP_TENSOR_PATTERNS) else QUANT_LEVELS + q, s = quantize_float_tensor(t, ql=ql) + meta: dict[str, object] = {} if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} + meta["scheme"] = "per_row" + meta["axis"] = 0 + if ql != QUANT_LEVELS: + meta["ql"] = ql + if meta: + qmeta[name] = meta quantized[name] = q scales[name] = s dtypes[name] = str(t.dtype).removeprefix("torch.") @@ -422,9 +638,7 @@ def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: return out -# ----------------------------- # DATA LOADING -# ----------------------------- def load_data_shard(file: Path) -> Tensor: header_bytes = 256 * np.dtype(" y = local[1:].reshape(-1, seq_len) return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) -# ----------------------------- # TRANSFORMER MODULES -# ----------------------------- class RMSNorm(nn.Module): def __init__(self, eps: float | None = None): @@ -580,7 +792,17 @@ def __init__( self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) self.rotary = Rotary(self.head_dim, base=rope_base) - def forward(self, x: Tensor) -> Tensor: + def _xsa(self, y: Tensor, v: Tensor) -> Tensor: + """Subtract self-value projection from attention output (GQA-aware).""" + B, T, H, D = y.shape + Hkv = v.size(2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(3) # [B, T, Hkv, 1, D] + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, use_xsa: bool = False) -> Tensor: bsz, seqlen, dim = x.shape q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) @@ -599,12 +821,19 @@ def forward(self, x: Tensor) -> Tensor: is_causal=True, enable_gqa=(self.num_kv_heads != self.num_heads), ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + # XSA: remove self-value bias from attention output + if use_xsa: + y = y.transpose(1, 2).contiguous() # [B, T, H, D] + v_for_xsa = v.transpose(1, 2) # [B, T, Hkv, D] + y = self._xsa(y, v_for_xsa) + y = y.reshape(bsz, seqlen, dim) + else: + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) return self.proj(y) class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup + # leaky_relu(0.5)^2 MLP — better gradient flow than relu^2 for deep/recurrent models def __init__(self, dim: int, mlp_mult: int): super().__init__() hidden = mlp_mult * dim @@ -613,7 +842,7 @@ def __init__(self, dim: int, mlp_mult: int): self.proj._zero_init = True def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) + x = F.leaky_relu(self.fc(x), negative_slope=0.5) return self.proj(x.square()) @@ -636,10 +865,10 @@ def __init__( self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - def forward(self, x: Tensor, x0: Tensor) -> Tensor: + def forward(self, x: Tensor, x0: Tensor, use_xsa: bool = False) -> Tensor: mix = self.resid_mix.to(dtype=x.dtype) x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) + attn_out = self.attn(self.attn_norm(x), use_xsa=use_xsa) x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) return x @@ -650,15 +879,18 @@ def __init__( self, vocab_size: int, num_layers: int, + num_repeats: int, model_dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, + num_value_embeds: int, tie_embeddings: bool, tied_embed_init_std: float, logit_softcap: float, rope_base: float, qk_gain_init: float, + xsa_last_n: int = 0, ): super().__init__() if logit_softcap <= 0.0: @@ -666,11 +898,15 @@ def __init__( self.tie_embeddings = tie_embeddings self.tied_embed_init_std = tied_embed_init_std self.logit_softcap = logit_softcap + self.num_repeats = num_repeats + self.xsa_last_n = xsa_last_n + effective_depth = num_layers * num_repeats self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + # Value embeddings: extra embedding tables mixed into each effective layer + self.num_value_embeds = num_value_embeds + if num_value_embeds > 0: + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(num_value_embeds)]) + self.value_scales = nn.Parameter(torch.zeros(effective_depth, num_value_embeds, model_dim, dtype=torch.float32)) self.blocks = nn.ModuleList( [ Block( @@ -684,6 +920,10 @@ def __init__( for i in range(num_layers) ] ) + # Loop embedding: tells the model which effective layer it's at + self.loop_embed = nn.Parameter(torch.zeros(effective_depth, model_dim, dtype=torch.float32)) + # Cross-repeat skip: each block receives its own output from previous repeat + self.cross_repeat_scales = nn.Parameter(torch.zeros(num_layers, num_repeats - 1, model_dim, dtype=torch.float32)) self.final_norm = RMSNorm() self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) if self.lm_head is not None: @@ -697,23 +937,42 @@ def _init_weights(self) -> None: if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): nn.init.zeros_(module.weight) - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + def forward_logits(self, input_ids: Tensor) -> Tensor: x = self.tok_emb(input_ids) x = F.rms_norm(x, (x.size(-1),)) x0 = x - skips: list[Tensor] = [] - - # First half stores skips; second half reuses them in reverse order. - for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) - - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) + + # Pre-compute value embeddings once + ve_list: list[Tensor] = [] + if self.num_value_embeds > 0: + for ve in self.value_embeds: + ve_list.append(ve(input_ids)) # (bsz, seq, dim) + + cur_repeats = self.cur_repeats if hasattr(self, "cur_repeats") else self.num_repeats + cur_depth = len(self.blocks) * cur_repeats + xsa_start = max(0, cur_depth - self.xsa_last_n) if self.xsa_last_n > 0 else cur_depth + + num_blocks = len(self.blocks) + prev_block_outputs: list[Tensor | None] = [None] * num_blocks + layer_idx = 0 + for repeat in range(cur_repeats): + for block_idx, block in enumerate(self.blocks): + x = x + self.loop_embed[layer_idx].to(dtype=x.dtype) + # Value embeddings: add weighted extra embeddings at each layer + if layer_idx < self.value_scales.size(0): + for ve_idx, ve_out in enumerate(ve_list): + vs = self.value_scales[layer_idx, ve_idx].to(dtype=x.dtype) + x = x + vs[None, None, :] * ve_out + # Cross-repeat skip: mix in this block's output from previous repeat + if repeat > 0 and prev_block_outputs[block_idx] is not None: + rep_idx = min(repeat - 1, self.cross_repeat_scales.size(1) - 1) + scale = self.cross_repeat_scales[block_idx, rep_idx].to(dtype=x.dtype) + x = x + scale[None, None, :] * prev_block_outputs[block_idx] + x = block(x, x0, use_xsa=(layer_idx >= xsa_start)) + prev_block_outputs[block_idx] = x.detach() if not self.training else x + layer_idx += 1 + + x = self.final_norm(x) if self.tie_embeddings: logits_proj = F.linear(x, self.tok_emb.weight) else: @@ -721,12 +980,16 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: raise RuntimeError("lm_head is required when tie_embeddings=False") logits_proj = self.lm_head(x) logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return logits + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self.forward_logits(input_ids) + logits = logits.reshape(-1, logits.size(-1)) + targets = target_ids.reshape(-1) return F.cross_entropy(logits.float(), targets, reduction="mean") -# ----------------------------- # TRAINING -# ----------------------------- def main() -> None: global zeropower_via_newtonschulz5 @@ -735,19 +998,15 @@ def main() -> None: args = Hyperparameters() zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - + # DISTRIBUTED + CUDA SETUP + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ rank = int(os.environ.get("RANK", "0")) world_size = int(os.environ.get("WORLD_SIZE", "1")) local_rank = int(os.environ.get("LOCAL_RANK", "0")) if world_size <= 0: raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size + grad_accum_steps = max(1, 8 // world_size) grad_scale = 1.0 / grad_accum_steps if not torch.cuda.is_available(): raise RuntimeError("CUDA is required") @@ -793,10 +1052,8 @@ def log0(msg: str, console: bool = True) -> None: ) log0("=" * 100, console=False) - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - + # TOKENIZER + VALIDATION METRIC SETUP + random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -819,22 +1076,23 @@ def log0(msg: str, console: bool = True) -> None: log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - + # MODEL + OPTIMIZER SETUP + base_model = GPT( vocab_size=args.vocab_size, num_layers=args.num_layers, + num_repeats=args.num_repeats, model_dim=args.model_dim, num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + num_value_embeds=args.num_value_embeds, tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + xsa_last_n=args.xsa_last_n, ).to(device).bfloat16() for module in base_model.modules(): if isinstance(module, CastedLinear): @@ -859,11 +1117,16 @@ def log0(msg: str, console: bool = True) -> None: for name, p in block_named_params if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.loop_embed) + scalar_params.append(base_model.cross_repeat_scales) + if base_model.num_value_embeds > 0: + scalar_params.append(base_model.value_scales) token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params = [base_model.tok_emb.weight] + if base_model.num_value_embeds > 0: + embed_params.extend(ve.weight for ve in base_model.value_embeds) optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True, @@ -873,6 +1136,7 @@ def log0(msg: str, console: bool = True) -> None: lr=args.matrix_lr, momentum=args.muon_momentum, backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, ) for group in optimizer_muon.param_groups: group["base_lr"] = args.matrix_lr @@ -909,10 +1173,8 @@ def log0(msg: str, console: bool = True) -> None: ) log0(f"seed:{args.seed}") - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) def zero_grad_all() -> None: @@ -960,12 +1222,26 @@ def lr_mul(step: int, elapsed_ms: float) -> float: model.require_backward_grad_sync = True train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - + # MAIN TRAINING LOOP + training_time_ms = 0.0 stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + # Progressive depth schedule: parse "frac:repeats,..." and sort + prog_phases: list[tuple[float, int]] = [] + for entry in args.prog_depth_schedule.split(","): + frac_s, rep_s = entry.strip().split(":") + prog_phases.append((float(frac_s), int(rep_s))) + prog_phases.sort() + current_phase_repeats = prog_phases[0][1] if prog_phases else args.num_repeats + base_model.cur_repeats = current_phase_repeats + # Recompile with initial phase depth + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + log0(f"prog_depth: schedule={prog_phases} starting_repeats={current_phase_repeats}") + torch.cuda.synchronize() t0 = time.perf_counter() @@ -1005,6 +1281,27 @@ def lr_mul(step: int, elapsed_ms: float) -> float: break elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + # Progressive depth: check if we need to switch phase + # Use synchronized elapsed time (max across ranks) to avoid race conditions + if max_wallclock_ms is not None and prog_phases: + if distributed: + elapsed_tensor = torch.tensor(elapsed_ms, device=device) + dist.all_reduce(elapsed_tensor, op=dist.ReduceOp.MAX) + frac = elapsed_tensor.item() / max_wallclock_ms + else: + frac = elapsed_ms / max_wallclock_ms + new_repeats = prog_phases[-1][1] # default to last + for phase_frac, phase_rep in prog_phases: + if frac < phase_frac: + new_repeats = phase_rep + break + if new_repeats != current_phase_repeats: + current_phase_repeats = new_repeats + base_model.cur_repeats = new_repeats + torch._dynamo.reset() + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + log0(f"prog_depth: switched to {new_repeats} repeats at step:{step} frac:{frac:.2f}") scale = lr_mul(step, elapsed_ms) zero_grad_all() train_loss = torch.zeros((), device=device) @@ -1035,6 +1332,19 @@ def lr_mul(step: int, elapsed_ms: float) -> float: step += 1 approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown (only at full depth to avoid mixing phases) + at_full_depth = current_phase_repeats == args.num_repeats + if args.swa_enabled and at_full_depth and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().float() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu().float() + swa_count += 1 + should_log_train = ( args.train_log_every > 0 and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) @@ -1059,11 +1369,29 @@ def lr_mul(step: int, elapsed_ms: float) -> float: f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" ) - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. + # Restore full depth for eval/export + base_model.cur_repeats = args.num_repeats + torch._dynamo.reset() + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None: + # Include final weights (may not have landed on swa_every boundary) + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu().float() + swa_count += 1 + log0(f"swa: averaging {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed quantized+zstd artifact and validate the round-tripped weights. if master_process: torch.save(base_model.state_dict(), "final_model.pt") @@ -1077,7 +1405,9 @@ def lr_mul(step: int, elapsed_ms: float) -> float: quant_buf = io.BytesIO() torch.save(quant_obj, quant_buf) quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) + zstd_level = int(os.environ.get("ZSTD_LEVEL", 22)) + cctx = zstd.ZstdCompressor(level=zstd_level) + quant_blob = cctx.compress(quant_raw) quant_raw_bytes = len(quant_raw) if master_process: with open("final_model.int8.ptz", "wb") as f: @@ -1086,16 +1416,17 @@ def lr_mul(step: int, elapsed_ms: float) -> float: code_bytes = len(code.encode("utf-8")) ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"Serialized model int8+zstd{zstd_level}: {quant_file_bytes} bytes " f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + log0(f"Total submission size int8+zstd{zstd_level}: {quant_file_bytes + code_bytes} bytes") if distributed: dist.barrier() with open("final_model.int8.ptz", "rb") as f: quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + dctx = zstd.ZstdDecompressor() + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) torch.cuda.synchronize() t_qeval = time.perf_counter() @@ -1113,10 +1444,51 @@ def lr_mul(step: int, elapsed_ms: float) -> float: ) torch.cuda.synchronize() log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"final_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + log0(f"final_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval + if args.eval_stride > 0: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"window:{args.eval_seq_len} stride:{args.eval_stride} " + f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + + # Hedge Mixer eval (n-gram ensemble) + if args.use_hedge: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_hm = time.perf_counter() + hm_val_loss, hm_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + use_hedge=True, hedge_eta=args.hedge_eta, + ) + torch.cuda.synchronize() + log0( + f"final_hedge_mixer val_loss:{hm_val_loss:.4f} val_bpb:{hm_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_hm):.0f}ms" + ) + log0(f"final_hedge_mixer_exact val_loss:{hm_val_loss:.8f} val_bpb:{hm_val_bpb:.8f}") if distributed: dist.destroy_process_group()