From 2effff3df3ec0855db2118afb4e8f3063c46634a Mon Sep 17 00:00:00 2001 From: Ivan Verbovoy Date: Thu, 19 Mar 2026 11:18:58 +0200 Subject: [PATCH 01/11] Depth recurrence + cross-repeat skip + value embeddings - Replace 9 unique blocks with 3 blocks x 4 repeats (12 effective layers) - Increase dim from 512 to 832, remove U-Net skips - Add loop_embed for timestep encoding per effective layer - Add cross-repeat skip: each block mixes in its output from previous repeat with per-repeat learned scales (stateful recurrence) - Add 2 value embedding tables mixed into each layer with learned scales - 17.14M params, best result: 1.6780 bpb (int8+zlib) on 2000 steps batch 8K --- train_gpt.py | 69 ++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 51 insertions(+), 18 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 0deb0565f..f2a97ea6c 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -61,11 +61,13 @@ class Hyperparameters: # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) - num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_layers = int(os.environ.get("NUM_LAYERS", 3)) + num_repeats = int(os.environ.get("NUM_REPEATS", 4)) num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) - model_dim = int(os.environ.get("MODEL_DIM", 512)) + model_dim = int(os.environ.get("MODEL_DIM", 832)) num_heads = int(os.environ.get("NUM_HEADS", 8)) mlp_mult = int(os.environ.get("MLP_MULT", 2)) + num_value_embeds = int(os.environ.get("NUM_VALUE_EMBEDS", 2)) tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) @@ -650,10 +652,12 @@ def __init__( self, vocab_size: int, num_layers: int, + num_repeats: int, model_dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, + num_value_embeds: int, tie_embeddings: bool, tied_embed_init_std: float, logit_softcap: float, @@ -666,11 +670,14 @@ def __init__( self.tie_embeddings = 
tie_embeddings self.tied_embed_init_std = tied_embed_init_std self.logit_softcap = logit_softcap + self.num_repeats = num_repeats + effective_depth = num_layers * num_repeats self.tok_emb = nn.Embedding(vocab_size, model_dim) - self.num_encoder_layers = num_layers // 2 - self.num_decoder_layers = num_layers - self.num_encoder_layers - self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) - self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + # Value embeddings: extra embedding tables mixed into each effective layer + self.num_value_embeds = num_value_embeds + if num_value_embeds > 0: + self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(num_value_embeds)]) + self.value_scales = nn.Parameter(torch.zeros(effective_depth, num_value_embeds, model_dim, dtype=torch.float32)) self.blocks = nn.ModuleList( [ Block( @@ -684,6 +691,11 @@ def __init__( for i in range(num_layers) ] ) + # Loop embedding: tells the model which effective layer it's at + self.loop_embed = nn.Parameter(torch.zeros(effective_depth, model_dim, dtype=torch.float32)) + # Cross-repeat skip: each block remembers its output from previous repeat + # Per-repeat scales (repeat 0 has no prev, so num_repeats-1 scales per block) + self.cross_repeat_scales = nn.Parameter(torch.zeros(num_layers, num_repeats - 1, model_dim, dtype=torch.float32)) self.final_norm = RMSNorm() self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) if self.lm_head is not None: @@ -701,16 +713,30 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: x = self.tok_emb(input_ids) x = F.rms_norm(x, (x.size(-1),)) x0 = x - skips: list[Tensor] = [] - # First half stores skips; second half reuses them in reverse order. 
- for i in range(self.num_encoder_layers): - x = self.blocks[i](x, x0) - skips.append(x) - for i in range(self.num_decoder_layers): - if skips: - x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() - x = self.blocks[self.num_encoder_layers + i](x, x0) + # Pre-compute value embeddings once + ve_list: list[Tensor] = [] + if self.num_value_embeds > 0: + for ve in self.value_embeds: + ve_list.append(ve(input_ids)) # (bsz, seq, dim) + + num_blocks = len(self.blocks) + prev_block_outputs: list[Tensor | None] = [None] * num_blocks + layer_idx = 0 + for repeat in range(self.num_repeats): + for block_idx, block in enumerate(self.blocks): + x = x + self.loop_embed[layer_idx].to(dtype=x.dtype) + # Value embeddings: add weighted extra embeddings at each layer + for ve_idx, ve_out in enumerate(ve_list): + vs = self.value_scales[layer_idx, ve_idx].to(dtype=x.dtype) + x = x + vs[None, None, :] * ve_out + # Cross-repeat skip: mix in this block's output from previous repeat + if repeat > 0 and prev_block_outputs[block_idx] is not None: + scale = self.cross_repeat_scales[block_idx, repeat - 1].to(dtype=x.dtype) + x = x + scale[None, None, :] * prev_block_outputs[block_idx] + x = block(x, x0) + prev_block_outputs[block_idx] = x.detach() if not self.training else x + layer_idx += 1 x = self.final_norm(x).reshape(-1, x.size(-1)) targets = target_ids.reshape(-1) @@ -826,10 +852,12 @@ def log0(msg: str, console: bool = True) -> None: base_model = GPT( vocab_size=args.vocab_size, num_layers=args.num_layers, + num_repeats=args.num_repeats, model_dim=args.model_dim, num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + num_value_embeds=args.num_value_embeds, tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, logit_softcap=args.logit_softcap, @@ -859,11 +887,16 @@ def log0(msg: str, console: bool = True) -> None: for name, p in block_named_params if p.ndim < 2 or any(pattern in name for pattern in 
CONTROL_TENSOR_NAME_PATTERNS) ] - if base_model.skip_weights.numel() > 0: - scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.loop_embed) + scalar_params.append(base_model.cross_repeat_scales) + if base_model.num_value_embeds > 0: + scalar_params.append(base_model.value_scales) token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params = [base_model.tok_emb.weight] + if base_model.num_value_embeds > 0: + embed_params.extend(ve.weight for ve in base_model.value_embeds) optimizer_tok = torch.optim.Adam( - [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}], + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True, From f1752bf3862feec4f9abf3cf2411b538d41e21c6 Mon Sep 17 00:00:00 2001 From: Ivan Verbovoy Date: Thu, 19 Mar 2026 14:27:36 +0200 Subject: [PATCH 02/11] Add Test-Time Training (TTT) for eval MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add eval_val_ttt: adapts model on each val batch before evaluating - For each batch: save weights → K gradient steps → evaluate → restore - Controlled by TTT_STEPS (default 0 = disabled) and TTT_LR (default 1e-4) - Result: -0.010 bpb improvement on 200-step test (2.4124 → 2.4027) - TTT eval runs after normal roundtrip eval, reports both scores --- train_gpt.py | 109 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) diff --git a/train_gpt.py b/train_gpt.py index f2a97ea6c..db419134b 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -58,6 +58,8 @@ class Hyperparameters: train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + ttt_steps = int(os.environ.get("TTT_STEPS", 0)) + ttt_lr = float(os.environ.get("TTT_LR", 1e-4)) # Model shape. 
vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) @@ -279,6 +281,87 @@ def eval_val( model.train() return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +def eval_val_ttt( + args: Hyperparameters, + base_model: nn.Module, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + # Test-Time Training: adapt the model on each validation batch before evaluating. + # For each batch: save weights → K gradient steps → evaluate → restore weights. + if args.ttt_steps <= 0: + return eval_val(args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Save original weights once + saved_state = {k: v.detach().clone() for k, v in base_model.state_dict().items()} + + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + + # TTT: adapt on this batch + model.train() + for _ttt_step in range(args.ttt_steps): + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + ttt_loss = model(x, y) + ttt_loss.backward() + with torch.no_grad(): + for p in base_model.parameters(): + if p.grad is not None: + p -= args.ttt_lr * p.grad + p.grad = None + + # Evaluate with adapted model + model.eval() + with torch.no_grad(): + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + # Restore original weights + base_model.load_state_dict(saved_state, strict=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + # ----------------------------- # POST-TRAINING QUANTIZATION # ----------------------------- @@ -1151,6 +1234,32 @@ def lr_mul(step: int, elapsed_ms: float) -> float: ) log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + # TTT eval: adapt model on each batch before evaluating + if args.ttt_steps > 0: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt( + args, + base_model, + model, + rank, + 
world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_ttt val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " + f"ttt_steps:{args.ttt_steps} ttt_lr:{args.ttt_lr} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + log0(f"final_ttt_exact val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f}") + if distributed: dist.destroy_process_group() From 449aebfbb5798b58fafd1b0ac099a2f74899e79f Mon Sep 17 00:00:00 2001 From: Ivan Verbovoy Date: Fri, 20 Mar 2026 02:22:10 +0200 Subject: [PATCH 03/11] Add sliding window eval, lower LR, train@1024, grad clip MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Sliding window eval: window=1024, stride=256, ~-0.034 bpb - forward_logits() method for sliding window support - LR x0.3: matrix=0.012, embed=0.015, scalar=0.012 (sweep winner) - GRAD_CLIP_NORM=0.3 for recurrence stability - WARMDOWN_ITERS=3000 - train@1024 (not 2048) — better for recurrence (160ms vs 253ms/step) - Fix grad_accum for non-power-of-2 GPU counts - Best result: 1.2308 bpb sliding window on 6xH100 (3726 steps) --- sweep.sh | 46 ++++++++++++++++++++ train_gpt.py | 119 ++++++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 154 insertions(+), 11 deletions(-) create mode 100755 sweep.sh diff --git a/sweep.sh b/sweep.sh new file mode 100755 index 000000000..5e6f9f8aa --- /dev/null +++ b/sweep.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# Hyperparameter sweep — run overnight on 3060 +# Each run: 2000 steps, batch 8K, no TTT + +export ITERATIONS=2000 +export TRAIN_BATCH_TOKENS=8192 +export VAL_LOSS_EVERY=0 +export VAL_BATCH_SIZE=8192 +export MAX_WALLCLOCK_SECONDS=0 +export TTT_STEPS=0 + +echo "=== Starting sweep at $(date) ===" + +# 1. 
Baseline (current defaults: matrix_lr=0.04, embed_lr=0.05, scalar_lr=0.04) +echo "--- Run 1: baseline ---" +RUN_ID=sweep_baseline torchrun --standalone --nproc_per_node=1 train_gpt.py 2>&1 | grep -E "(model_params|step:2000|final_int8_zlib_roundtrip_exact)" + +# 2. All lr x1.5 +echo "--- Run 2: lr x1.5 ---" +RUN_ID=sweep_lr15 MATRIX_LR=0.06 TIED_EMBED_LR=0.075 SCALAR_LR=0.06 torchrun --standalone --nproc_per_node=1 train_gpt.py 2>&1 | grep -E "(step:2000|final_int8_zlib_roundtrip_exact)" + +# 3. All lr x2.0 +echo "--- Run 3: lr x2.0 ---" +RUN_ID=sweep_lr20 MATRIX_LR=0.08 TIED_EMBED_LR=0.1 SCALAR_LR=0.08 torchrun --standalone --nproc_per_node=1 train_gpt.py 2>&1 | grep -E "(step:2000|final_int8_zlib_roundtrip_exact)" + +# 4. All lr x0.5 +echo "--- Run 4: lr x0.5 ---" +RUN_ID=sweep_lr05 MATRIX_LR=0.02 TIED_EMBED_LR=0.025 SCALAR_LR=0.02 torchrun --standalone --nproc_per_node=1 train_gpt.py 2>&1 | grep -E "(step:2000|final_int8_zlib_roundtrip_exact)" + +# 5. Lower embed_lr ratio (embed_lr = 0.3x matrix_lr) +echo "--- Run 5: low embed_lr ---" +RUN_ID=sweep_lowemb TIED_EMBED_LR=0.012 torchrun --standalone --nproc_per_node=1 train_gpt.py 2>&1 | grep -E "(step:2000|final_int8_zlib_roundtrip_exact)" + +# 6. Longer warmdown (2400 iters) +echo "--- Run 6: warmdown_iters=2400 ---" +RUN_ID=sweep_wd2400 WARMDOWN_ITERS=2400 torchrun --standalone --nproc_per_node=1 train_gpt.py 2>&1 | grep -E "(step:2000|final_int8_zlib_roundtrip_exact)" + +# 7. Higher muon momentum +echo "--- Run 7: muon_momentum=0.98 ---" +RUN_ID=sweep_mom98 MUON_MOMENTUM=0.98 torchrun --standalone --nproc_per_node=1 train_gpt.py 2>&1 | grep -E "(step:2000|final_int8_zlib_roundtrip_exact)" + +# 8. 
Matrix lr x1.5 + lower embed +echo "--- Run 8: matrix_lr=0.06 + embed_lr=0.02 ---" +RUN_ID=sweep_combo MATRIX_LR=0.06 TIED_EMBED_LR=0.02 torchrun --standalone --nproc_per_node=1 train_gpt.py 2>&1 | grep -E "(step:2000|final_int8_zlib_roundtrip_exact)" + +echo "=== Sweep done at $(date) ===" diff --git a/train_gpt.py b/train_gpt.py index db419134b..5632db266 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -52,15 +52,20 @@ class Hyperparameters: # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 1200)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) ttt_steps = int(os.environ.get("TTT_STEPS", 0)) ttt_lr = float(os.environ.get("TTT_LR", 1e-4)) + # Sliding window eval. + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 256)) + # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) num_layers = int(os.environ.get("NUM_LAYERS", 3)) @@ -77,10 +82,10 @@ class Hyperparameters: # Optimizer hyperparameters. 
embed_lr = float(os.environ.get("EMBED_LR", 0.6)) head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.05)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.015)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.04)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.04)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.012)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.012)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) @@ -88,7 +93,6 @@ class Hyperparameters: beta1 = float(os.environ.get("BETA1", 0.9)) beta2 = float(os.environ.get("BETA2", 0.95)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) - grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0)) # ----------------------------- # MUON OPTIMIZER @@ -282,6 +286,72 @@ def eval_val( return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + """Sliding window eval: each window is eval_seq_len tokens, advancing by eval_stride. 
+ Loss is scored only on the last eval_stride tokens per window.""" + seq_len = args.eval_seq_len + stride = args.eval_stride + total_tokens = val_tokens.numel() + + starts: list[int] = [] + pos = 0 + while pos + seq_len < total_tokens: + starts.append(pos) + pos += stride + total_windows = len(starts) + win_start = (total_windows * rank) // world_size + win_end = (total_windows * (rank + 1)) // world_size + score_offset = seq_len - stride + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.no_grad(): + for wi in range(win_start, win_end): + s = starts[wi] + window = val_tokens[s : s + seq_len + 1].to(device=device, dtype=torch.int64) + x = window[:-1].unsqueeze(0) + y = window[1:].unsqueeze(0) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = base_model.forward_logits(x) + + tail_logits = logits[0, score_offset:, :].float() + tail_targets = y[0, score_offset:] + per_token_loss = F.cross_entropy(tail_logits, tail_targets, reduction="none") + val_loss_sum += per_token_loss.to(torch.float64).sum() + val_token_count += float(stride) + + tail_prev = x[0, score_offset:] + tail_tgt = y[0, score_offset:] + token_bytes = base_bytes_lut[tail_tgt].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tail_tgt] & ~is_boundary_token_lut[tail_prev]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + base_model.train() + return 
float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + def eval_val_ttt( args: Hyperparameters, base_model: nn.Module, @@ -792,7 +862,7 @@ def _init_weights(self) -> None: if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): nn.init.zeros_(module.weight) - def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + def forward_logits(self, input_ids: Tensor) -> Tensor: x = self.tok_emb(input_ids) x = F.rms_norm(x, (x.size(-1),)) x0 = x @@ -821,8 +891,7 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: prev_block_outputs[block_idx] = x.detach() if not self.training else x layer_idx += 1 - x = self.final_norm(x).reshape(-1, x.size(-1)) - targets = target_ids.reshape(-1) + x = self.final_norm(x) if self.tie_embeddings: logits_proj = F.linear(x, self.tok_emb.weight) else: @@ -830,6 +899,12 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: raise RuntimeError("lm_head is required when tie_embeddings=False") logits_proj = self.lm_head(x) logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return logits + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self.forward_logits(input_ids) + logits = logits.reshape(-1, logits.size(-1)) + targets = target_ids.reshape(-1) return F.cross_entropy(logits.float(), targets, reduction="mean") @@ -854,9 +929,7 @@ def main() -> None: local_rank = int(os.environ.get("LOCAL_RANK", "0")) if world_size <= 0: raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") - if 8 % world_size != 0: - raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") - grad_accum_steps = 8 // world_size + grad_accum_steps = max(1, 8 // world_size) grad_scale = 1.0 / grad_accum_steps if not torch.cuda.is_available(): raise RuntimeError("CUDA is required") @@ -1234,6 +1307,30 @@ def lr_mul(step: int, elapsed_ms: float) -> float: ) log0(f"final_int8_zlib_roundtrip_exact 
val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + # Sliding window eval + if args.eval_stride > 0: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"window:{args.eval_seq_len} stride:{args.eval_stride} " + f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + # TTT eval: adapt model on each batch before evaluating if args.ttt_steps > 0: base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) From fa293060e34a5cc455e6392fbcb337c4f1a5b42f Mon Sep 17 00:00:00 2001 From: Ivan Verbovoy Date: Fri, 20 Mar 2026 03:40:09 +0200 Subject: [PATCH 04/11] Add submission: Depth Recurrence + Cross-Repeat Skip + Sliding Window --- .gitignore | 4 +- .../README.md | 58 + .../submission.json | 16 + .../train.log | 84 + .../train_gpt.py | 1365 +++++++++++++++++ train_gpt.py | 14 +- 6 files changed, 1533 insertions(+), 8 deletions(-) create mode 100644 records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/README.md create mode 100644 records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/submission.json create mode 100644 records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/train.log create mode 100644 records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/train_gpt.py diff --git a/.gitignore b/.gitignore index 3423c416a..9260888ec 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,6 @@ data/manifest.json data/docs_selected.jsonl .mypy_cache/ .venv -logs/ \ No newline at end of file +logs/ +final_model.* +sweep.sh \ No 
newline at end of file diff --git a/records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/README.md b/records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/README.md new file mode 100644 index 000000000..7fdbb7475 --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/README.md @@ -0,0 +1,58 @@ +## Depth Recurrence + Cross-Repeat Skip + Value Embeddings + +Beats naive baseline (1.2244) by 0.005 bpb using 3.1x fewer training steps through stateful depth recurrence. + +val_bpb = 1.2196 (sliding window eval on int8+zlib roundtrip model, stride=256) +val_bpb = 1.2533 (standard int8+zlib roundtrip) + +### Architecture + +Replaced the baseline's 9 unique transformer blocks with 3 shared blocks repeated 4 times (12 effective layers). Trades unique parameters for effective depth. + +Changes from baseline: +- Depth recurrence: 3 blocks x 4 repeats = 12 effective layers (vs 9 in baseline) +- Cross-Repeat Skip (original): each block gets a weighted residual of its own output from the previous repeat, turning stateless recurrence into stateful. Per-repeat learned scales, ~7.5K params total. +- Value Embeddings: 2 extra embedding tables mixed into the residual stream at each effective layer with learned scales. From snimu's modded-nanogpt record. +- Loop Embedding: learned per-layer vector added before each block as depth-wise positional encoding. +- Model dim 832 (vs 512), 8 heads, 4 KV heads, MLP 2x +- Removed U-Net skip connections (Cross-Repeat Skip covers this role) +- 17.14M params, 12.83MB artifact + +### Training + +LR x0.3 from baseline — recurrence amplifies gradients through 4 passes, so optimal LR is much lower. Found via sweep of 10 configs on RTX 3060. + +MATRIX_LR=0.012, SCALAR_LR=0.012, TIED_EMBED_LR=0.015, GRAD_CLIP_NORM=0.3, WARMDOWN_ITERS=3000, TRAIN_SEQ_LEN=1024. + +Tested train@2048 but 1024 gives more steps (133ms vs 253ms/step) which matters more for this architecture. Standard Muon + Adam. 
+ +### Evaluation + +Sliding window eval: window=1024, stride=256 on the int8+zlib roundtrip model. Eval time 209s on 8xH100. + +### Results (8xH100, 600s wallclock) + +4494 steps, 133ms/step avg. Pre-quant 1.2487, roundtrip 1.2533, sliding window 1.2196. Artifact 12.83MB, quant degradation 0.005 bpb, peak memory ~29GB/GPU. + +### Ablations (RTX 3060, 2000 steps each) + +- Cross-Repeat Skip: -0.041 bpb +- Value Embeddings (2 tables): -0.079 bpb +- LR x0.3: -0.052 bpb +- Sliding window eval: -0.034 bpb +- WARMDOWN_ITERS=3000: -0.027 bpb + +### Development + +All experiments, ablations, and hyperparameter sweeps done on a single RTX 3060 12GB. Cloud GPUs (1xH200, 6xH100) used only for validation. Final run on 8xH100. + +### Command + +``` +RUN_ID=submission_8xh100 \ +QUANT_LEVELS=127 \ +TTT_STEPS=0 \ +EVAL_STRIDE=256 \ +EVAL_SEQ_LEN=1024 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` diff --git a/records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/submission.json b/records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/submission.json new file mode 100644 index 000000000..f04f129d1 --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/submission.json @@ -0,0 +1,16 @@ +{ + "author": "Ivan Verbovoy", + "github_id": "iverbovoy", + "name": "Depth Recurrence + Cross-Repeat Skip + Value Embeddings + Sliding Window", + "blurb": "3 unique blocks x 4 repeats (12 effective layers), dim=832, with Cross-Repeat Skip (stateful recurrence), 2 Value Embedding tables, LR x0.3, sliding window eval (stride=256). 
4494 steps in 600s on 8xH100.", + "date": "2026-03-20T02:00:00Z", + "val_loss": 2.05921204, + "val_bpb": 1.21958209, + "roundtrip_val_loss": 2.11612232, + "roundtrip_val_bpb": 1.25328684, + "step_stop": 4494, + "wallclock_seconds": 600.133, + "bytes_total": 12829176, + "bytes_model_int8_zlib": 12771121, + "bytes_code": 58055 +} diff --git a/records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/train.log b/records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/train.log new file mode 100644 index 000000000..d9a0c1529 --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/train.log @@ -0,0 +1,84 @@ +W0320 00:54:42.000000 1050 torch/distributed/run.py:852] +W0320 00:54:42.000000 1050 torch/distributed/run.py:852] ***************************************** +W0320 00:54:42.000000 1050 torch/distributed/run.py:852] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0320 00:54:42.000000 1050 torch/distributed/run.py:852] ***************************************** +logs/submission_8xh100.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17140056 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.015 head_lr:0.0 matrix_lr:0.012 scalar_lr:0.012 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9766 val_bpb:4.1319 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9765 train_time:162ms step_avg:161.95ms +step:2/20000 train_loss:9.0581 train_time:218ms step_avg:109.04ms +step:3/20000 train_loss:7.8439 train_time:342ms step_avg:114.12ms +step:4/20000 train_loss:6.5913 train_time:466ms step_avg:116.40ms +step:5/20000 train_loss:6.1067 train_time:589ms step_avg:117.72ms +step:6/20000 train_loss:6.3514 train_time:712ms step_avg:118.70ms +step:7/20000 train_loss:5.9725 train_time:836ms step_avg:119.39ms +step:8/20000 train_loss:5.8139 train_time:958ms step_avg:119.78ms +step:9/20000 train_loss:5.5629 train_time:1081ms step_avg:120.13ms +step:10/20000 train_loss:5.3728 train_time:1206ms step_avg:120.64ms +step:200/20000 train_loss:2.7739 train_time:26609ms step_avg:133.05ms +step:400/20000 train_loss:2.3107 train_time:53543ms 
step_avg:133.86ms +step:600/20000 train_loss:2.5249 train_time:80122ms step_avg:133.54ms +step:800/20000 train_loss:2.2710 train_time:106824ms step_avg:133.53ms +step:1000/20000 train_loss:2.3610 train_time:133649ms step_avg:133.65ms +step:1000/20000 val_loss:2.3206 val_bpb:1.3744 train_time:133722ms step_avg:133.72ms +step:1200/20000 train_loss:2.3700 train_time:160457ms step_avg:133.71ms +step:1400/20000 train_loss:2.4196 train_time:187085ms step_avg:133.63ms +step:1600/20000 train_loss:2.0826 train_time:213643ms step_avg:133.53ms +step:1800/20000 train_loss:2.1817 train_time:240257ms step_avg:133.48ms +step:2000/20000 train_loss:2.2342 train_time:266823ms step_avg:133.41ms +step:2000/20000 val_loss:2.2137 val_bpb:1.3111 train_time:266903ms step_avg:133.45ms +step:2200/20000 train_loss:2.0469 train_time:293423ms step_avg:133.37ms +step:2400/20000 train_loss:2.1757 train_time:320078ms step_avg:133.37ms +step:2600/20000 train_loss:2.3756 train_time:346626ms step_avg:133.32ms +step:2800/20000 train_loss:2.2012 train_time:373394ms step_avg:133.35ms +step:3000/20000 train_loss:2.1910 train_time:400062ms step_avg:133.35ms +step:3000/20000 val_loss:2.1585 val_bpb:1.2784 train_time:400147ms step_avg:133.38ms +step:3200/20000 train_loss:2.1485 train_time:426762ms step_avg:133.36ms +step:3400/20000 train_loss:2.1171 train_time:453425ms step_avg:133.36ms +step:3600/20000 train_loss:2.0703 train_time:480073ms step_avg:133.35ms +step:3800/20000 train_loss:2.1774 train_time:506627ms step_avg:133.32ms +step:4000/20000 train_loss:2.1156 train_time:532930ms step_avg:133.23ms +step:4000/20000 val_loss:2.1201 val_bpb:1.2556 train_time:533004ms step_avg:133.25ms +step:4200/20000 train_loss:2.1277 train_time:561906ms step_avg:133.79ms +step:4400/20000 train_loss:2.0541 train_time:588700ms step_avg:133.80ms +step:4494/20000 val_loss:2.1084 val_bpb:1.2487 train_time:600133ms step_avg:133.54ms +stopping_early: wallclock_cap train_time:600133ms step:4494/20000 +peak memory allocated: 
21771 MiB reserved: 21818 MiB +Serialized model: 63387167 bytes +Code size: 58055 bytes +Total submission size: 63445222 bytes +Serialized model int8+zlib: 12771121 bytes (payload:17243616 raw_torch:17261176 payload_ratio:3.68x) +Total submission size int8+zlib: 12829176 bytes +final_int8_zlib_roundtrip val_loss:2.1161 val_bpb:1.2533 eval_time:3709ms +final_int8_zlib_roundtrip_exact val_loss:2.11612232 val_bpb:1.25328684 +final_sliding_window val_loss:2.0592 val_bpb:1.2196 window:1024 stride:256 eval_time:209349ms +final_sliding_window_exact val_loss:2.05921204 val_bpb:1.21958209 diff --git a/records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/train_gpt.py b/records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/train_gpt.py new file mode 100644 index 000000000..aa83a930b --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_DepthRecurrence_CrossRepeatSkip/train_gpt.py @@ -0,0 +1,1365 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. 
"""

from __future__ import annotations

import copy
import glob
import io
import math
import os
import random
import subprocess
import sys
import time
import uuid
import zlib
from pathlib import Path

import numpy as np
import sentencepiece as spm
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor, nn
from torch.nn.parallel import DistributedDataParallel as DDP

# -----------------------------
# HYPERPARAMETERS
# -----------------------------
# Default run (depth-recurrent variant):
# - 3 transformer blocks at width 832, unrolled for 4 repeats (12 effective layers)
# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion
# - 2 value-embedding tables mixed into every effective layer
# - vocab size 1024, sequence length 1024, tied embeddings
# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap

class Hyperparameters:
    # Every field reads an environment variable so runs can be reconfigured
    # without editing the file; the second argument is the default.

    # Data paths are shard globs produced by the existing preprocessing pipeline.
    data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")
    train_files = os.path.join(data_path, "fineweb_train_*.bin")
    val_files = os.path.join(data_path, "fineweb_val_*.bin")
    tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model")
    run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
    seed = int(os.environ.get("SEED", 1337))

    # Validation cadence and batch size. Validation always uses the full fineweb_val split.
    val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288))
    val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000))
    train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200))

    # Training length.
    iterations = int(os.environ.get("ITERATIONS", 20000))
    warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000))
    warmup_steps = int(os.environ.get("WARMUP_STEPS", 20))
    train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288))
    train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024))
    max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))
    qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5))
    grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3))
    # Test-time training: 0 disables; see eval_val_ttt.
    ttt_steps = int(os.environ.get("TTT_STEPS", 0))
    ttt_lr = float(os.environ.get("TTT_LR", 1e-4))

    # Sliding window eval.
    eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048))
    eval_stride = int(os.environ.get("EVAL_STRIDE", 256))

    # Model shape. Effective depth is num_layers * num_repeats (3 x 4 = 12).
    vocab_size = int(os.environ.get("VOCAB_SIZE", 1024))
    num_layers = int(os.environ.get("NUM_LAYERS", 3))
    num_repeats = int(os.environ.get("NUM_REPEATS", 4))
    num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4))
    model_dim = int(os.environ.get("MODEL_DIM", 832))
    num_heads = int(os.environ.get("NUM_HEADS", 8))
    mlp_mult = int(os.environ.get("MLP_MULT", 2))
    num_value_embeds = int(os.environ.get("NUM_VALUE_EMBEDS", 2))
    tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1")))
    rope_base = float(os.environ.get("ROPE_BASE", 10000.0))
    logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0))

    # Optimizer hyperparameters.
    embed_lr = float(os.environ.get("EMBED_LR", 0.6))
    head_lr = float(os.environ.get("HEAD_LR", 0.008))
    tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.015))
    tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005))
    matrix_lr = float(os.environ.get("MATRIX_LR", 0.012))
    scalar_lr = float(os.environ.get("SCALAR_LR", 0.012))
    muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95))
    muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))
    muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85))
    muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500))
    beta1 = float(os.environ.get("BETA1", 0.9))
    beta2 = float(os.environ.get("BETA2", 0.95))
    adam_eps = float(os.environ.get("ADAM_EPS", 1e-8))

# -----------------------------
# MUON OPTIMIZER
# -----------------------------
#
# As borrowed from modded-nanogpt
# Background on Muon: https://kellerjordan.github.io/posts/muon/

def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
    # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration.
    # Muon uses this to normalize matrix-shaped gradients before applying them.
    # The quintic coefficients (a, b, c) come from the modded-nanogpt reference;
    # the whole iteration runs in bf16 for speed.
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X /= X.norm() + eps
    # Work on the wide orientation so X @ X.T is the smaller Gram matrix.
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X


class Muon(torch.optim.Optimizer):
    """Momentum + Newton-Schulz orthogonalized updates for matrix parameters.

    Work is sharded across ranks: rank r computes the update for every
    parameter with index i where i % world_size == r, and a single all_reduce
    of a flat bf16 buffer shares all updates with every rank.
    """

    def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True):
        super().__init__(
            params,
            dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov),
        )

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        distributed = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if distributed else 1
        rank = dist.get_rank() if distributed else 0

        for group in self.param_groups:
            params = group["params"]
            if not params:
                continue
            lr = group["lr"]
            momentum = group["momentum"]
            backend_steps = group["backend_steps"]
            nesterov = group["nesterov"]

            # One flat buffer holding the update for every param in the group;
            # each rank fills only its shard, the all_reduce sums in the rest
            # (non-owned slots stay zero locally, so the sum is a gather).
            total_params = sum(int(p.numel()) for p in params)
            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)

            curr = 0
            for i, p in enumerate(params):
                if i % world_size == rank and p.grad is not None:
                    g = p.grad
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf = state["momentum_buffer"]
                    buf.mul_(momentum).add_(g)
                    if nesterov:
                        g = g.add(buf, alpha=momentum)
                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
                    # Scale correction from Muon reference implementations.
                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
                # Advance for every param so slice offsets agree on all ranks.
                curr += p.numel()

            if distributed:
                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)

            curr = 0
            for p in params:
                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
                p.add_(g, alpha=-lr)
                curr += p.numel()

        return loss


# -----------------------------
# TOKENIZER-AGNOSTIC EVALUATION SETUP
# -----------------------------
#
# It's common for small models to have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic.
# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set.
# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer.
# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score.

def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token lookup tables used to count UTF-8 bytes for BPB.

    Returns three tensors indexed by token id:
      - base_bytes: UTF-8 byte length of the piece (sans leading "▁" marker)
      - has_leading_space: piece starts with the SentencePiece "▁" word marker
      - is_boundary_token: control/unknown/unused ids (never charge a space after them)
    Tables are sized max(sp_vocab_size, vocab_size) so model ids outside the
    SentencePiece vocab still index safely (as zero-byte boundary tokens).
    """
    sp_vocab_size = int(sp.vocab_size())
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            # Byte-fallback pieces always decode to exactly one byte.
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("▁"):
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate all validation shards, trimmed to a whole number of
    seq_len sequences plus one extra token for the (x, y) shift."""
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    # Validation computes two metrics:
    # - val_loss: token cross-entropy (natural log)
    # - val_bpb: tokenizer-agnostic compression metric used by the challenge
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    # Each rank scores a disjoint contiguous range of sequences.
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)

    model.eval()
    with torch.inference_mode():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * args.train_seq_len
            raw_end = batch_seq_end * args.train_seq_len + 1
            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, args.train_seq_len)
            y = local[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            batch_token_count = float(y.numel())
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            # Byte accounting: base piece bytes, plus one byte for the leading
            # space unless the previous token was a boundary (control) token.
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)

    val_loss = val_loss_sum / val_token_count
    # bpb = (bits per token) * (tokens per byte)
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)


def eval_val_sliding(
    args: Hyperparameters,
    base_model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Sliding window eval: each window is eval_seq_len tokens, advancing by eval_stride.
    Loss is scored only on the last eval_stride tokens per window."""
    seq_len = args.eval_seq_len
    stride = args.eval_stride
    total_tokens = val_tokens.numel()

    # Enumerate window start offsets, then shard windows across ranks.
    starts: list[int] = []
    pos = 0
    while pos + seq_len < total_tokens:
        starts.append(pos)
        pos += stride
    total_windows = len(starts)
    win_start = (total_windows * rank) // world_size
    win_end = (total_windows * (rank + 1)) // world_size
    # Only the last `stride` positions of each window are scored.
    score_offset = seq_len - stride

    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)

    base_model.eval()
    with torch.no_grad():
        for wi in range(win_start, win_end):
            s = starts[wi]
            window = val_tokens[s : s + seq_len + 1].to(device=device, dtype=torch.int64)
            x = window[:-1].unsqueeze(0)
            y = window[1:].unsqueeze(0)

            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                logits = base_model.forward_logits(x)

            tail_logits = logits[0, score_offset:, :].float()
            tail_targets = y[0, score_offset:]
            per_token_loss = F.cross_entropy(tail_logits, tail_targets, reduction="none")
            val_loss_sum += per_token_loss.to(torch.float64).sum()
            val_token_count += float(stride)

            # Same byte accounting as eval_val, restricted to the scored tail.
            tail_prev = x[0, score_offset:]
            tail_tgt = y[0, score_offset:]
            token_bytes = base_bytes_lut[tail_tgt].to(dtype=torch.int16)
            token_bytes += (has_leading_space_lut[tail_tgt] & ~is_boundary_token_lut[tail_prev]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)

    val_loss = val_loss_sum / val_token_count
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    base_model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)


def eval_val_ttt(
    args: Hyperparameters,
    base_model: nn.Module,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    # Test-Time Training: adapt the model on each validation batch before evaluating.
    # For each batch: save weights → K gradient steps → evaluate → restore weights.
    # NOTE(review): `model` is presumably the DDP wrapper around `base_model`
    # (gradients are read/zeroed via base_model.parameters()) — confirm at call site.
    if args.ttt_steps <= 0:
        return eval_val(args, model, rank, world_size, device, grad_accum_steps,
                        val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut)

    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)

    # Save original weights once
    saved_state = {k: v.detach().clone() for k, v in base_model.state_dict().items()}

    for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
        batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
        raw_start = batch_seq_start * args.train_seq_len
        raw_end = batch_seq_end * args.train_seq_len + 1
        local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
        x = local[:-1].reshape(-1, args.train_seq_len)
        y = local[1:].reshape(-1, args.train_seq_len)

        # TTT: adapt on this batch with plain SGD on all parameters.
        model.train()
        for _ttt_step in range(args.ttt_steps):
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                ttt_loss = model(x, y)
            ttt_loss.backward()
            with torch.no_grad():
                for p in base_model.parameters():
                    if p.grad is not None:
                        p -= args.ttt_lr * p.grad
                        p.grad = None

        # Evaluate with adapted model
        model.eval()
        with torch.no_grad():
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
        batch_token_count = float(y.numel())
        val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
        val_token_count += batch_token_count
        prev_ids = x.reshape(-1)
        tgt_ids = y.reshape(-1)
        token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
        token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
        val_byte_count += token_bytes.to(torch.float64).sum()

        # Restore original weights
        base_model.load_state_dict(saved_state, strict=True)

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)

    val_loss = val_loss_sum / val_token_count
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)


# -----------------------------
# POST-TRAINING QUANTIZATION
# -----------------------------
#
# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing.
# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit.

# Name fragments identifying small "control" tensors (per-channel scales/mixes)
# that must stay full fp32 through export; configurable via environment.
CONTROL_TENSOR_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "CONTROL_TENSOR_NAME_PATTERNS",
        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights",
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
    ).split(",")
    if pattern
)
# Float tensors at or below this element count bypass quantization entirely.
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
INT8_PER_ROW_SCALE_DTYPE = torch.float16
# Clip outliers at this percentile of |x| before scaling to the int grid.
INT8_CLIP_PERCENTILE = 99.99984
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0
# Int6 quantization: ±31 instead of ±127. Stored as int8 but zlib compresses better.
QUANT_LEVELS = int(os.environ.get("QUANT_LEVELS", 127))  # 127 = int8, 31 = int6

def tensor_nbytes(t: Tensor) -> int:
    # Raw storage size of the tensor's own elements (ignores view/overlap).
    return int(t.numel()) * int(t.element_size())

def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor:
    """Prepare a small float tensor for passthrough export.

    Control tensors (matched by name) stay fp32; other fp32/bf16 tensors are
    downcast to fp16 for storage, recording the original dtype in
    `passthrough_orig_dtypes` (mutated in place) so loading can restore it.
    """
    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
        return t.float().contiguous()
    if t.dtype in {torch.float32, torch.bfloat16}:
        passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.")
        return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous()
    return t

def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
    """Symmetric-quantize a float tensor to int8 codes plus scale(s).

    2D tensors get a per-row scale (returned as a 1D fp16 tensor); everything
    else gets a single scalar fp32 scale. Values beyond the clip percentile
    are saturated before rounding.
    """
    ql = QUANT_LEVELS  # 31 for int6, 127 for int8
    t32 = t.float()
    if t32.ndim == 2:
        clip_abs = (
            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
            if t32.numel()
            else torch.empty((t32.shape[0],), dtype=torch.float32)
        )
        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
        # clamp_min guards all-zero rows against a zero scale.
        scale = (clip_abs / ql).clamp_min(1.0 / ql)
        q = torch.clamp(torch.round(clipped / scale[:, None]), -ql, ql).to(torch.int8).contiguous()
        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()

    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
    scale = torch.tensor(clip_abs / ql if clip_abs > 0 else 1.0, dtype=torch.float32)
    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -ql, ql).to(torch.int8).contiguous()
    return q, scale

def quantize_state_dict_int8(state_dict: dict[str, Tensor]):
    # Single supported clean-script export format:
    # - per-row int8 for 2D float tensors
    # - per-tensor int8 for other float tensors
    # - exact passthrough for non-floats
    # - passthrough for small float tensors, stored as fp16 to save bytes
    # Returns (obj, stats): `obj` is the serializable payload understood by
    # dequantize_state_dict_int8, `stats` counts tensors/bytes for logging.
    quantized: dict[str, Tensor] = {}
    scales: dict[str, Tensor] = {}
    dtypes: dict[str, str] = {}
    passthrough: dict[str, Tensor] = {}
    passthrough_orig_dtypes: dict[str, str] = {}
    qmeta: dict[str, dict[str, object]] = {}
    stats = dict.fromkeys(
        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
        0,
    )

    for name, tensor in state_dict.items():
        t = tensor.detach().to("cpu").contiguous()
        stats["param_count"] += int(t.numel())
        stats["num_tensors"] += 1
        stats["baseline_tensor_bytes"] += tensor_nbytes(t)

        if not t.is_floating_point():
            stats["num_nonfloat_tensors"] += 1
            passthrough[name] = t
            stats["int8_payload_bytes"] += tensor_nbytes(t)
            continue

        # Small float tensors are cheap enough to keep directly. We still downcast
        # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size.
        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
            passthrough[name] = kept
            stats["int8_payload_bytes"] += tensor_nbytes(kept)
            continue

        stats["num_float_tensors"] += 1
        q, s = quantize_float_tensor(t)
        if s.ndim > 0:
            qmeta[name] = {"scheme": "per_row", "axis": 0}
        quantized[name] = q
        scales[name] = s
        dtypes[name] = str(t.dtype).removeprefix("torch.")
        stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s)

    obj: dict[str, object] = {
        "__quant_format__": "int8_clean_per_row_v1",
        "quantized": quantized,
        "scales": scales,
        "dtypes": dtypes,
        "passthrough": passthrough,
    }
    if qmeta:
        obj["qmeta"] = qmeta
    if passthrough_orig_dtypes:
        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
    return obj, stats

def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
    # Inverse of quantize_state_dict_int8: rebuild a float state dict from
    # int8 codes + scales, and pass small tensors back through.
    out: dict[str, Tensor] = {}
    qmeta = obj.get("qmeta", {})
    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
    for name, q in obj["quantized"].items():
        dtype = getattr(torch, obj["dtypes"][name])
        s = obj["scales"][name]
        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
            s = s.to(dtype=torch.float32)
            # Broadcast the saved row scale back across trailing dimensions.
            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
        else:
            scale = float(s.item())
            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
    for name, t in obj["passthrough"].items():
        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
        self.rank = rank
        self.world_size = world_size
        self.device = device
        # Shared sequential token stream; every rank consumes the same chunks
        # and slices out its own span, keeping ranks in lockstep.
        self.stream = TokenStream(pattern)

    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
        # Returns (x, y) of shape (local_tokens // seq_len, seq_len); y is x
        # shifted by one token (next-token prediction targets).
        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
        per_rank_span = local_tokens + 1
        chunk = self.stream.take(per_rank_span * self.world_size)
        start = self.rank * per_rank_span
        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
        x = local[:-1].reshape(-1, seq_len)
        y = local[1:].reshape(-1, seq_len)
        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)

# -----------------------------
# TRANSFORMER MODULES
# -----------------------------

class RMSNorm(nn.Module):
    # Parameter-free RMS normalization (no learned gain).
    def __init__(self, eps: float | None = None):
        super().__init__()
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        return F.rms_norm(x, (x.size(-1),), eps=self.eps)


class CastedLinear(nn.Linear):
    # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute.
    def forward(self, x: Tensor) -> Tensor:
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return F.linear(x, self.weight.to(x.dtype), bias)


def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
    # Keep small/control parameters in fp32 even when the model body runs in bf16.
    # "Small" means ndim < 2 or a name matching CONTROL_TENSOR_NAME_PATTERNS.
    with torch.no_grad():
        for name, param in module.named_parameters():
            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
                param.data = param.data.float()


class Rotary(nn.Module):
    # Caches cos/sin tables per sequence length on the current device.
    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached: Tensor | None = None
        self._sin_cached: Tensor | None = None

    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
        # Rebuild the cache when the length or device changes; dtype is applied
        # per-call via .to() so the cache itself stays fp32.
        if (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached != seq_len
            or self._cos_cached.device != device
        ):
            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            # Shapes broadcast against (bsz, heads, seq, head_dim/2).
            self._cos_cached = freqs.cos()[None, None, :, :]
            self._sin_cached = freqs.sin()[None, None, :, :]
            self._seq_len_cached = seq_len
        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    # Rotate the two halves of the head dimension by the cached angles
    # (half-split RoPE convention, matching the Rotary table layout).
    half = x.size(-1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)


class CausalSelfAttention(nn.Module):
    # Causal attention with grouped-query KV heads, QK RMS-norm, and a learned
    # per-head query gain (q_gain) in place of a fixed softmax temperature.
    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        rope_base: float,
        qk_gain_init: float,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        kv_dim = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim, bias=False)
        self.c_k = CastedLinear(dim, kv_dim, bias=False)
        self.c_v = CastedLinear(dim, kv_dim, bias=False)
        self.proj = CastedLinear(dim, dim, bias=False)
        # Zero-init the output projection so the block starts as identity.
        self.proj._zero_init = True
        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
        self.rotary = Rotary(self.head_dim, base=rope_base)

    def forward(self, x: Tensor) -> Tensor:
        bsz, seqlen, dim = x.shape
        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # QK-norm before RoPE stabilizes attention logits.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = self.rotary(seqlen, x.device, q.dtype)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
        y = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            is_causal=True,
            enable_gqa=(self.num_kv_heads != self.num_heads),
        )
        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
        return self.proj(y)


class MLP(nn.Module):
    # relu^2 MLP from the original modded-nanogpt setup
    def __init__(self, dim: int, mlp_mult: int):
        super().__init__()
        hidden = mlp_mult * dim
        self.fc = CastedLinear(dim, hidden, bias=False)
        self.proj = CastedLinear(hidden, dim, bias=False)
        self.proj._zero_init = True

    def forward(self, x: Tensor) -> Tensor:
        x = torch.relu(self.fc(x))
        return self.proj(x.square())


class Block(nn.Module):
    # Pre-norm transformer block with learned per-channel residual scales and
    # a learned mix between the running residual and the embedding stream x0.
    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        rope_base: float,
        qk_gain_init: float,
    ):
        super().__init__()
        self.attn_norm = RMSNorm()
        self.mlp_norm = RMSNorm()
        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
        self.mlp = MLP(dim, mlp_mult)
        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
        # resid_mix[0] weights x, resid_mix[1] weights x0; init = identity on x.
        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())

    def forward(self, x: Tensor, x0: Tensor) -> Tensor:
        mix = self.resid_mix.to(dtype=x.dtype)
        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        attn_out = self.attn(self.attn_norm(x))
        x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
        x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
        return x


class GPT(nn.Module):
    # Depth-recurrent GPT: num_layers unique blocks unrolled num_repeats times
    # (effective depth = num_layers * num_repeats), with per-effective-layer
    # loop embeddings, value-embedding mixes, and cross-repeat skip scales.
    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        num_repeats: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        num_value_embeds: int,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
    ):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.num_repeats = num_repeats
        effective_depth = num_layers * num_repeats
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        # Value embeddings: extra embedding tables mixed into each effective layer
        self.num_value_embeds = num_value_embeds
        if num_value_embeds > 0:
            self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(num_value_embeds)])
            self.value_scales = nn.Parameter(torch.zeros(effective_depth, num_value_embeds, model_dim, dtype=torch.float32))
        self.blocks = nn.ModuleList(
            [
                Block(
                    model_dim,
                    num_heads,
                    num_kv_heads,
                    mlp_mult,
                    rope_base,
                    qk_gain_init,
                )
                for i in range(num_layers)
            ]
        )
        # Loop embedding: tells the model which effective layer it's at
        self.loop_embed = nn.Parameter(torch.zeros(effective_depth, model_dim, dtype=torch.float32))
        # Cross-repeat skip: each block remembers its output from previous repeat
        # Per-repeat scales (repeat 0 has no prev, so num_repeats-1 scales per block)
        self.cross_repeat_scales = nn.Parameter(torch.zeros(num_layers, num_repeats - 1, model_dim,
class GPT(nn.Module):
    """Depth-recurrent GPT.

    `num_layers` unique Blocks are applied `num_repeats` times in sequence,
    giving `num_layers * num_repeats` effective layers. Each effective layer
    gets its own loop embedding (so the model knows which pass it is on) and
    its own per-channel scales for the optional value-embedding tables; each
    block also mixes in its own output from the previous repeat via learned
    cross-repeat scales (stateful recurrence).
    """

    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        num_repeats: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        num_value_embeds: int,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
    ):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.num_repeats = num_repeats
        effective_depth = num_layers * num_repeats
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        # Value embeddings: extra embedding tables mixed into every effective
        # layer with per-layer learned scales (zero-initialized, so they start
        # as a no-op).
        self.num_value_embeds = num_value_embeds
        if num_value_embeds > 0:
            self.value_embeds = nn.ModuleList(
                [nn.Embedding(vocab_size, model_dim) for _ in range(num_value_embeds)]
            )
            self.value_scales = nn.Parameter(
                torch.zeros(effective_depth, num_value_embeds, model_dim, dtype=torch.float32)
            )
        self.blocks = nn.ModuleList(
            [
                Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init)
                for _ in range(num_layers)
            ]
        )
        # Loop embedding: one learned vector per effective layer, added to the
        # stream so the shared blocks can condition on the current depth.
        self.loop_embed = nn.Parameter(torch.zeros(effective_depth, model_dim, dtype=torch.float32))
        # Cross-repeat skip: per-(block, repeat) scales for mixing in the
        # block's output from the previous repeat. Repeat 0 has no previous
        # output, hence num_repeats - 1 scales per block.
        self.cross_repeat_scales = nn.Parameter(
            torch.zeros(num_layers, num_repeats - 1, model_dim, dtype=torch.float32)
        )
        self.final_norm = RMSNorm()
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        self._init_weights()

    def _init_weights(self) -> None:
        """Init tied token embedding and zero all `_zero_init`-flagged linears."""
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        for module in self.modules():
            if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
                nn.init.zeros_(module.weight)

    def forward_logits(self, input_ids: Tensor) -> Tensor:
        """Run the recurrent stack and return soft-capped logits (bsz, seq, vocab)."""
        x = self.tok_emb(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x0 = x

        # Value-embedding lookups are computed once and reused at every layer.
        value_streams: list[Tensor] = []
        if self.num_value_embeds > 0:
            value_streams = [table(input_ids) for table in self.value_embeds]

        # memory[b] holds block b's output from the previous repeat.
        # Detached in eval mode only; during training the gradient flows
        # through the recurrence.
        memory: list[Tensor | None] = [None] * len(self.blocks)
        layer_idx = 0
        for repeat in range(self.num_repeats):
            for block_idx, block in enumerate(self.blocks):
                x = x + self.loop_embed[layer_idx].to(dtype=x.dtype)
                for ve_idx, stream in enumerate(value_streams):
                    gain = self.value_scales[layer_idx, ve_idx].to(dtype=x.dtype)
                    x = x + gain[None, None, :] * stream
                if repeat > 0 and memory[block_idx] is not None:
                    gate = self.cross_repeat_scales[block_idx, repeat - 1].to(dtype=x.dtype)
                    x = x + gate[None, None, :] * memory[block_idx]
                x = block(x, x0)
                memory[block_idx] = x if self.training else x.detach()
                layer_idx += 1

        x = self.final_norm(x)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head is required when tie_embeddings=False")
            logits_proj = self.lm_head(x)
        # Soft cap keeps logits in (-cap, cap) while staying differentiable.
        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        """Return the mean next-token cross-entropy over all positions."""
        logits = self.forward_logits(input_ids)
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_targets = target_ids.reshape(-1)
        return F.cross_entropy(flat_logits.float(), flat_targets, reduction="mean")
f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + num_repeats=args.num_repeats, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + num_value_embeds=args.num_value_embeds, + tie_embeddings=args.tie_embeddings, + 
tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params.append(base_model.loop_embed) + scalar_params.append(base_model.cross_repeat_scales) + if base_model.num_value_embeds > 0: + scalar_params.append(base_model.value_scales) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params = [base_model.tok_emb.weight] + if base_model.num_value_embeds > 0: + embed_params.extend(ve.weight for ve in base_model.value_embeds) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, 
"base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / 
max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or 
(args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = 
training_time_ms + 1000.0 * (time.perf_counter() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed int8+zlib artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zlib.compress(quant_raw, level=9) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval + if args.eval_stride > 
0: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"window:{args.eval_seq_len} stride:{args.eval_stride} " + f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + + # TTT eval: adapt model on each batch before evaluating + if args.ttt_steps > 0: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt( + args, + base_model, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_ttt val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " + f"ttt_steps:{args.ttt_steps} ttt_lr:{args.ttt_lr} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + log0(f"final_ttt_exact val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/train_gpt.py b/train_gpt.py index 5632db266..6261f87b7 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -461,6 +461,8 @@ def eval_val_ttt( INT8_PER_ROW_SCALE_DTYPE = torch.float16 INT8_CLIP_PERCENTILE = 99.99984 INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +# Int6 quantization: ±31 instead of ±127. Stored as int8 but zlib compresses better. 
INT8_PER_ROW_SCALE_DTYPE = torch.float16
INT8_CLIP_PERCENTILE = 99.99984
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0
# Int6 quantization: ±31 instead of ±127. Stored as int8 but zlib compresses better.
QUANT_LEVELS = int(os.environ.get("QUANT_LEVELS", 31))  # 31 = int6, 127 = int8


def tensor_nbytes(t: Tensor) -> int:
    """Return the storage size of *t* in bytes (numel * element size)."""
    return int(t.numel()) * int(t.element_size())


def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
    """Symmetrically quantize *t* to QUANT_LEVELS levels stored as int8.

    2-D tensors get one scale per row (tracks output-channel ranges better
    than a tensor-wide scale); the scales are returned in
    INT8_PER_ROW_SCALE_DTYPE. Vectors/scalars use a single per-tensor fp32
    scale. Outliers beyond the INT8_CLIP_Q quantile are clipped first.

    Returns (quantized int8 tensor, scale tensor); dequantize as q * scale.
    """
    ql = QUANT_LEVELS  # 31 for int6, 127 for int8
    t32 = t.float()
    if t32.ndim == 2:
        clip_abs = (
            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
            if t32.numel()
            else torch.empty((t32.shape[0],), dtype=torch.float32)
        )
        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
        # FIX: was clamp_min(1.0 / ql), which floors the scale at 1/ql. Any row
        # whose max magnitude is below ~1/ql then quantizes to all zeros, so the
        # roundtrip destroys small-magnitude (e.g. undertrained) weights. A tiny
        # epsilon only guards against division by zero for all-zero rows.
        scale = (clip_abs / ql).clamp_min(1e-12)
        q = torch.clamp(torch.round(clipped / scale[:, None]), -ql, ql).to(torch.int8).contiguous()
        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
    # Vectors / scalars: single per-tensor scale kept in fp32.
    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
    scale = torch.tensor(clip_abs / ql if clip_abs > 0 else 1.0, dtype=torch.float32)
    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -ql, ql).to(torch.int8).contiguous()
    return q, scale
All lr x1.5 -echo "--- Run 2: lr x1.5 ---" -RUN_ID=sweep_lr15 MATRIX_LR=0.06 TIED_EMBED_LR=0.075 SCALAR_LR=0.06 torchrun --standalone --nproc_per_node=1 train_gpt.py 2>&1 | grep -E "(step:2000|final_int8_zlib_roundtrip_exact)" - -# 3. All lr x2.0 -echo "--- Run 3: lr x2.0 ---" -RUN_ID=sweep_lr20 MATRIX_LR=0.08 TIED_EMBED_LR=0.1 SCALAR_LR=0.08 torchrun --standalone --nproc_per_node=1 train_gpt.py 2>&1 | grep -E "(step:2000|final_int8_zlib_roundtrip_exact)" - -# 4. All lr x0.5 -echo "--- Run 4: lr x0.5 ---" -RUN_ID=sweep_lr05 MATRIX_LR=0.02 TIED_EMBED_LR=0.025 SCALAR_LR=0.02 torchrun --standalone --nproc_per_node=1 train_gpt.py 2>&1 | grep -E "(step:2000|final_int8_zlib_roundtrip_exact)" - -# 5. Lower embed_lr ratio (embed_lr = 0.3x matrix_lr) -echo "--- Run 5: low embed_lr ---" -RUN_ID=sweep_lowemb TIED_EMBED_LR=0.012 torchrun --standalone --nproc_per_node=1 train_gpt.py 2>&1 | grep -E "(step:2000|final_int8_zlib_roundtrip_exact)" - -# 6. Longer warmdown (2400 iters) -echo "--- Run 6: warmdown_iters=2400 ---" -RUN_ID=sweep_wd2400 WARMDOWN_ITERS=2400 torchrun --standalone --nproc_per_node=1 train_gpt.py 2>&1 | grep -E "(step:2000|final_int8_zlib_roundtrip_exact)" - -# 7. Higher muon momentum -echo "--- Run 7: muon_momentum=0.98 ---" -RUN_ID=sweep_mom98 MUON_MOMENTUM=0.98 torchrun --standalone --nproc_per_node=1 train_gpt.py 2>&1 | grep -E "(step:2000|final_int8_zlib_roundtrip_exact)" - -# 8. 
Matrix lr x1.5 + lower embed -echo "--- Run 8: matrix_lr=0.06 + embed_lr=0.02 ---" -RUN_ID=sweep_combo MATRIX_LR=0.06 TIED_EMBED_LR=0.02 torchrun --standalone --nproc_per_node=1 train_gpt.py 2>&1 | grep -E "(step:2000|final_int8_zlib_roundtrip_exact)" - -echo "=== Sweep done at $(date) ===" diff --git a/train_gpt.py b/train_gpt.py index 6261f87b7..eb810a046 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -62,6 +62,11 @@ class Hyperparameters: ttt_steps = int(os.environ.get("TTT_STEPS", 0)) ttt_lr = float(os.environ.get("TTT_LR", 1e-4)) + # SWA (Stochastic Weight Averaging) during warmdown. + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.15)) + swa_every = int(os.environ.get("SWA_EVERY", 25)) + # Sliding window eval. eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) eval_stride = int(os.environ.get("EVAL_STRIDE", 256)) @@ -90,6 +95,7 @@ class Hyperparameters: muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) beta1 = float(os.environ.get("BETA1", 0.9)) beta2 = float(os.environ.get("BETA2", 0.95)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) @@ -118,10 +124,10 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) - class Muon(torch.optim.Optimizer): - def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): super().__init__( params, - dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov), + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), ) @torch.no_grad() @@ -167,9 
+173,12 @@ def step(self, closure=None): if distributed: dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) curr = 0 for p in params: g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) p.add_(g, alpha=-lr) curr += p.numel() @@ -485,7 +494,7 @@ def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: else torch.empty((t32.shape[0],), dtype=torch.float32) ) clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / ql).clamp_min(1.0 / ql) + scale = (clip_abs / ql).clamp_min(1e-12) q = torch.clamp(torch.round(clipped / scale[:, None]), -ql, ql).to(torch.int8).contiguous() return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() @@ -1062,6 +1071,7 @@ def log0(msg: str, console: bool = True) -> None: lr=args.matrix_lr, momentum=args.muon_momentum, backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, ) for group in optimizer_muon.param_groups: group["base_lr"] = args.matrix_lr @@ -1155,6 +1165,8 @@ def lr_mul(step: int, elapsed_ms: float) -> float: training_time_ms = 0.0 stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 torch.cuda.synchronize() t0 = time.perf_counter() @@ -1224,6 +1236,18 @@ def lr_mul(step: int, elapsed_ms: float) -> float: step += 1 approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown (accumulate in float for precision) + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().float() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu().float() + swa_count += 1 + should_log_train = ( args.train_log_every > 0 and (step <= 10 or step % 
args.train_log_every == 0 or stop_after_step is not None) @@ -1248,6 +1272,20 @@ def lr_mul(step: int, elapsed_ms: float) -> float: f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" ) + # Apply SWA if collected + if args.swa_enabled and swa_state is not None: + # Include final weights (may not have landed on swa_every boundary) + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu().float() + swa_count += 1 + log0(f"swa: averaging {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + # ----------------------------- # SERIALIZATION + ROUNDTRIP VALIDATION # ----------------------------- From 6e319a229e9858386f1788a8e5f74f51e4ccc6f9 Mon Sep 17 00:00:00 2001 From: Ivan Verbovoy Date: Thu, 26 Mar 2026 00:37:36 +0200 Subject: [PATCH 06/11] =?UTF-8?q?Add=20XSA,=20LeakyReLU=C2=B2,=20GPTQ-lite?= =?UTF-8?q?,=20zstd-22=20=E2=80=94=20val=5Fbpb=201.2070?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Improvements over previous submission (1.2196 → 1.2070, -0.014 bpb): - XSA (Exclusive Self-Attention) on last 4 effective layers: -0.010 bpb - LeakyReLU(0.5)² instead of relu²: -0.004 bpb - GPTQ-lite: per-row best-of-5 clip percentiles for quantization - zstd-22 compression instead of zlib (saves ~1.85MB artifact) - SWA tuned to frac=0.4, every=50 Tested on 8xH100, 80 train shards, PyTorch 2.5, 4290 steps. 
--- train_gpt.py | 205 ++++++++++++++++++++++++++++++++++----------------- 1 file changed, 138 insertions(+), 67 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index eb810a046..8987bfe69 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -16,7 +16,7 @@ import sys import time import uuid -import zlib +import zstandard as zstd from pathlib import Path import numpy as np @@ -62,14 +62,18 @@ class Hyperparameters: ttt_steps = int(os.environ.get("TTT_STEPS", 0)) ttt_lr = float(os.environ.get("TTT_LR", 1e-4)) + # XSA (Exclusive Self-Attention) on last N effective layers. + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + # SWA (Stochastic Weight Averaging) during warmdown. swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) - swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.15)) - swa_every = int(os.environ.get("SWA_EVERY", 25)) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) # Sliding window eval. - eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 1024)) eval_stride = int(os.environ.get("EVAL_STRIDE", 256)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) # Model shape. vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) @@ -306,48 +310,61 @@ def eval_val_sliding( has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, ) -> tuple[float, float]: - """Sliding window eval: each window is eval_seq_len tokens, advancing by eval_stride. - Loss is scored only on the last eval_stride tokens per window.""" + """Sliding window eval with batching. Windows of train_seq_len advance by eval_stride. 
+ Only the last stride tokens per window are scored (first window scores all).""" seq_len = args.eval_seq_len stride = args.eval_stride - total_tokens = val_tokens.numel() - - starts: list[int] = [] - pos = 0 - while pos + seq_len < total_tokens: - starts.append(pos) - pos += stride - total_windows = len(starts) - win_start = (total_windows * rank) // world_size - win_end = (total_windows * (rank + 1)) // world_size - score_offset = seq_len - stride + batch_seqs = args.eval_batch_seqs + total_tokens = val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) val_token_count = torch.zeros((), device=device, dtype=torch.float64) val_byte_count = torch.zeros((), device=device, dtype=torch.float64) base_model.eval() - with torch.no_grad(): - for wi in range(win_start, win_end): - s = starts[wi] - window = val_tokens[s : s + seq_len + 1].to(device=device, dtype=torch.int64) - x = window[:-1].unsqueeze(0) - y = window[1:].unsqueeze(0) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi : bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws : end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - logits = base_model.forward_logits(x) - - tail_logits = logits[0, score_offset:, :].float() - 
tail_targets = y[0, score_offset:] - per_token_loss = F.cross_entropy(tail_logits, tail_targets, reduction="none") - val_loss_sum += per_token_loss.to(torch.float64).sum() - val_token_count += float(stride) - - tail_prev = x[0, score_offset:] - tail_tgt = y[0, score_offset:] - token_bytes = base_bytes_lut[tail_tgt].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tail_tgt] & ~is_boundary_token_lut[tail_prev]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() + logits = base_model.forward_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + val_loss_sum += scored_nll.sum() + val_token_count += float(wlen - s) + prev_ids = x_batch[i, s:wlen] + tgt_ids = y_batch[i, s:wlen] + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() if dist.is_available() and dist.is_initialized(): dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) @@ -446,8 +463,7 @@ def eval_val_ttt( # ----------------------------- # # It's silly to export our model, which is trained in bf16 and fp32, at that same precision. -# Instead, we get approximately the same model (with a small hit) by quantizing the model to int8 & zlib compressing. -# We can then decompress the model and run in higher precision for evaluation, after closing in under the size limit. +# Instead, we get approximately the same model (with a small hit) by quantizing and zstd compressing. 
CONTROL_TENSOR_NAME_PATTERNS = tuple( pattern @@ -470,8 +486,10 @@ def eval_val_ttt( INT8_PER_ROW_SCALE_DTYPE = torch.float16 INT8_CLIP_PERCENTILE = 99.99984 INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 -# Int6 quantization: ±31 instead of ±127. Stored as int8 but zlib compresses better. -QUANT_LEVELS = int(os.environ.get("QUANT_LEVELS", 31)) # 31 = int6, 127 = int8 +# Quantization levels: 127 = int8, 31 = int6, 16 = int5. Per-tensor override via MLP_QUANT_LEVELS. +QUANT_LEVELS = int(os.environ.get("QUANT_LEVELS", 127)) +MLP_QUANT_LEVELS = int(os.environ.get("MLP_QUANT_LEVELS", 0)) # 0 = same as QUANT_LEVELS +MLP_TENSOR_PATTERNS = ("mlp.fc.", "mlp.proj.", "fc.weight", "mlp.proj.weight") def tensor_nbytes(t: Tensor) -> int: return int(t.numel()) * int(t.element_size()) @@ -484,19 +502,40 @@ def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, s return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() return t -def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: - ql = QUANT_LEVELS # 31 for int6, 127 for int8 +GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 0.999999] + +def quantize_float_tensor(t: Tensor, ql: int = 0) -> tuple[Tensor, Tensor]: + if ql <= 0: + ql = QUANT_LEVELS t32 = t.float() if t32.ndim == 2: - clip_abs = ( - torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) - if t32.numel() - else torch.empty((t32.shape[0],), dtype=torch.float32) - ) - clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) - scale = (clip_abs / ql).clamp_min(1e-12) - q = torch.clamp(torch.round(clipped / scale[:, None]), -ql, ql).to(torch.int8).contiguous() - return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + # GPTQ-lite: try multiple clip percentiles per row, pick best MSE + abs_t = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in GPTQ_LITE_PERCENTILES: + clip_abs = ( + torch.quantile(abs_t, pct, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), 
dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + s = (clip_abs / ql).clamp_min(1e-12) + q = torch.clamp(torch.round(clipped / s[:, None]), -ql, ql) + # Reconstruction error per row + recon = q * s[:, None] + mse = (t32 - recon).square().sum(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = s + else: + better = mse < best_mse + best_mse = torch.where(better, mse, best_mse) + best_q = torch.where(better[:, None], q, best_q) + best_scale = torch.where(better, s, best_scale) + return best_q.to(torch.int8).contiguous(), best_scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 scale = torch.tensor(clip_abs / ql if clip_abs > 0 else 1.0, dtype=torch.float32) @@ -541,9 +580,17 @@ def quantize_state_dict_int8(state_dict: dict[str, Tensor]): continue stats["num_float_tensors"] += 1 - q, s = quantize_float_tensor(t) + mlp_ql = MLP_QUANT_LEVELS if MLP_QUANT_LEVELS > 0 else QUANT_LEVELS + ql = mlp_ql if any(p in name for p in MLP_TENSOR_PATTERNS) else QUANT_LEVELS + q, s = quantize_float_tensor(t, ql=ql) + meta: dict[str, object] = {} if s.ndim > 0: - qmeta[name] = {"scheme": "per_row", "axis": 0} + meta["scheme"] = "per_row" + meta["axis"] = 0 + if ql != QUANT_LEVELS: + meta["ql"] = ql + if meta: + qmeta[name] = meta quantized[name] = q scales[name] = s dtypes[name] = str(t.dtype).removeprefix("torch.") @@ -744,7 +791,17 @@ def __init__( self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) self.rotary = Rotary(self.head_dim, base=rope_base) - def forward(self, x: Tensor) -> Tensor: + def _xsa(self, y: Tensor, v: Tensor) -> Tensor: + """Subtract self-value projection from attention output (GQA-aware).""" + B, T, H, D = y.shape + Hkv = v.size(2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(3) # [B, T, Hkv, 
1, D] + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, use_xsa: bool = False) -> Tensor: bsz, seqlen, dim = x.shape q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) @@ -763,12 +820,19 @@ def forward(self, x: Tensor) -> Tensor: is_causal=True, enable_gqa=(self.num_kv_heads != self.num_heads), ) - y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + # XSA: remove self-value bias from attention output + if use_xsa: + y = y.transpose(1, 2).contiguous() # [B, T, H, D] + v_for_xsa = v.transpose(1, 2) # [B, T, Hkv, D] + y = self._xsa(y, v_for_xsa) + y = y.reshape(bsz, seqlen, dim) + else: + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) return self.proj(y) class MLP(nn.Module): - # relu^2 MLP from the original modded-nanogpt setup + # leaky_relu(0.5)^2 MLP — better gradient flow than relu^2 for deep/recurrent models def __init__(self, dim: int, mlp_mult: int): super().__init__() hidden = mlp_mult * dim @@ -777,7 +841,7 @@ def __init__(self, dim: int, mlp_mult: int): self.proj._zero_init = True def forward(self, x: Tensor) -> Tensor: - x = torch.relu(self.fc(x)) + x = F.leaky_relu(self.fc(x), negative_slope=0.5) return self.proj(x.square()) @@ -800,10 +864,10 @@ def __init__( self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) - def forward(self, x: Tensor, x0: Tensor) -> Tensor: + def forward(self, x: Tensor, x0: Tensor, use_xsa: bool = False) -> Tensor: mix = self.resid_mix.to(dtype=x.dtype) x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 - attn_out = self.attn(self.attn_norm(x)) + attn_out = self.attn(self.attn_norm(x), use_xsa=use_xsa) x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out x = x + 
self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) return x @@ -825,6 +889,7 @@ def __init__( logit_softcap: float, rope_base: float, qk_gain_init: float, + xsa_last_n: int = 0, ): super().__init__() if logit_softcap <= 0.0: @@ -834,6 +899,8 @@ def __init__( self.logit_softcap = logit_softcap self.num_repeats = num_repeats effective_depth = num_layers * num_repeats + # XSA: which effective layers use exclusive self-attention + self.xsa_start = max(0, effective_depth - xsa_last_n) if xsa_last_n > 0 else effective_depth self.tok_emb = nn.Embedding(vocab_size, model_dim) # Value embeddings: extra embedding tables mixed into each effective layer self.num_value_embeds = num_value_embeds @@ -896,7 +963,7 @@ def forward_logits(self, input_ids: Tensor) -> Tensor: if repeat > 0 and prev_block_outputs[block_idx] is not None: scale = self.cross_repeat_scales[block_idx, repeat - 1].to(dtype=x.dtype) x = x + scale[None, None, :] * prev_block_outputs[block_idx] - x = block(x, x0) + x = block(x, x0, use_xsa=(layer_idx >= self.xsa_start)) prev_block_outputs[block_idx] = x.detach() if not self.training else x layer_idx += 1 @@ -1028,6 +1095,7 @@ def log0(msg: str, console: bool = True) -> None: logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + xsa_last_n=args.xsa_last_n, ).to(device).bfloat16() for module in base_model.modules(): if isinstance(module, CastedLinear): @@ -1290,7 +1358,7 @@ def lr_mul(step: int, elapsed_ms: float) -> float: # SERIALIZATION + ROUNDTRIP VALIDATION # ----------------------------- # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce - # the compressed int8+zlib artifact and validate the round-tripped weights. + # the compressed quantized+zstd artifact and validate the round-tripped weights. 
if master_process: torch.save(base_model.state_dict(), "final_model.pt") @@ -1304,7 +1372,9 @@ def lr_mul(step: int, elapsed_ms: float) -> float: quant_buf = io.BytesIO() torch.save(quant_obj, quant_buf) quant_raw = quant_buf.getvalue() - quant_blob = zlib.compress(quant_raw, level=9) + zstd_level = int(os.environ.get("ZSTD_LEVEL", 22)) + cctx = zstd.ZstdCompressor(level=zstd_level) + quant_blob = cctx.compress(quant_raw) quant_raw_bytes = len(quant_raw) if master_process: with open("final_model.int8.ptz", "wb") as f: @@ -1313,16 +1383,17 @@ def lr_mul(step: int, elapsed_ms: float) -> float: code_bytes = len(code.encode("utf-8")) ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) log0( - f"Serialized model int8+zlib: {quant_file_bytes} bytes " + f"Serialized model int8+zstd{zstd_level}: {quant_file_bytes} bytes " f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" ) - log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + log0(f"Total submission size int8+zstd{zstd_level}: {quant_file_bytes + code_bytes} bytes") if distributed: dist.barrier() with open("final_model.int8.ptz", "rb") as f: quant_blob_disk = f.read() - quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu") + dctx = zstd.ZstdDecompressor() + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) torch.cuda.synchronize() t_qeval = time.perf_counter() @@ -1340,10 +1411,10 @@ def lr_mul(step: int, elapsed_ms: float) -> float: ) torch.cuda.synchronize() log0( - f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"final_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" ) - log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} 
val_bpb:{q_val_bpb:.8f}") + log0(f"final_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") # Sliding window eval if args.eval_stride > 0: From 697f820632b8682a927a1c37957498b0ae64dd4b Mon Sep 17 00:00:00 2001 From: Ivan Verbovoy Date: Thu, 26 Mar 2026 01:37:16 +0200 Subject: [PATCH 07/11] =?UTF-8?q?Add=20submission:=20Depth=20Recurrence=20?= =?UTF-8?q?+=20XSA=20+=20LeakyReLU=C2=B2=20(val=5Fbpb=201.2065)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Improvements over previous submission (1.2196 → 1.2065, -0.013 bpb): - XSA (Exclusive Self-Attention) on last 4 effective layers: -0.010 bpb - LeakyReLU(0.5)² instead of relu²: -0.004 bpb - GPTQ-lite: per-row best-of-5 clip percentiles - zstd-22 compression instead of zlib - SWA tuned to frac=0.4, every=50 8xH100, 80 train shards, 4300 steps, 140ms/step, 15.87MB artifact. --- .../README.md | 52 + .../submission.json | 16 + .../train.log | 101 ++ .../train_gpt.py | 1473 +++++++++++++++++ train_gpt.py | 3 +- 5 files changed, 1643 insertions(+), 2 deletions(-) create mode 100644 records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/README.md create mode 100644 records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/submission.json create mode 100644 records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/train.log create mode 100644 records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/train_gpt.py diff --git a/records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/README.md b/records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/README.md new file mode 100644 index 000000000..8a06b7a2e --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/README.md @@ -0,0 +1,52 @@ +## Depth Recurrence + XSA + LeakyReLU² + +Improves previous submission (1.2196 → 1.2065, -0.013 bpb) through three zero-parameter additions on top of depth 
recurrence. + +val_bpb = 1.2065 (sliding window eval on int8+zstd22 roundtrip model, stride=256) +val_bpb = 1.2398 (standard int8+zstd22 roundtrip) + +### Architecture + +Same depth recurrence base as previous submission: 3 shared blocks repeated 4 times (12 effective layers), dim=832, 8 heads, 4 KV heads, MLP 2x, tied embeddings. + +New additions (all zero extra parameters): +- **XSA (Exclusive Self-Attention)** on last 4 effective layers: removes self-value bias from attention output via GQA-aware projection subtraction. -0.010 bpb. +- **LeakyReLU(0.5)²** instead of relu²: preserves negative gradient flow while maintaining sparsity. Better gradient propagation through 4 recurrence passes. -0.004 bpb. +- **GPTQ-lite**: per-row best-of-5 clip percentiles during quantization (post-training, zero cost). +- **zstd-22** compression instead of zlib (saves ~1.85MB artifact space). +- **SWA** tuned to frac=0.4, every=50 steps. +- **Muon weight decay** 0.04. + +Retained from previous submission: +- Cross-Repeat Skip (stateful recurrence with per-repeat learned scales) +- 2 Value Embedding tables +- Loop Embedding (per-effective-layer depth encoding) + +17.14M params, 15.87MB artifact. + +### Training + +Same LR schedule as previous: MATRIX_LR=0.012, SCALAR_LR=0.012, TIED_EMBED_LR=0.015, GRAD_CLIP_NORM=0.3, WARMDOWN_ITERS=3000, TRAIN_SEQ_LEN=1024. + +### Results (8xH100, 600s wallclock) + +4300 steps, 140ms/step avg. Pre-quant 1.2373, roundtrip 1.2398, sliding window 1.2065. Artifact 15.87MB, quant degradation +0.003 bpb. 
+ +### Ablations (8xH100, 80 shards, all cumulative) + +| Change | Sliding bpb | Delta | +|--------|-------------|-------| +| Baseline (previous submission repro) | 1.2213 | — | +| + XSA last 4 layers | 1.2110 | -0.0103 | +| + LeakyReLU(0.5)² | 1.2070 | -0.0040 | +| + GPTQ-lite + zstd-22 | 1.2065 | -0.0005 | + +### Command + +``` +XSA_LAST_N=4 \ +QUANT_LEVELS=127 \ +EVAL_SEQ_LEN=1024 \ +EVAL_STRIDE=256 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` diff --git a/records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/submission.json b/records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/submission.json new file mode 100644 index 000000000..137cdb5ac --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/submission.json @@ -0,0 +1,16 @@ +{ + "author": "Ivan Verbovoy", + "github_id": "iverbovoy", + "name": "Depth Recurrence + XSA + LeakyReLU² + GPTQ-lite + zstd-22", + "blurb": "3 unique blocks x 4 repeats (12 effective layers), dim=832, with Cross-Repeat Skip, XSA on last 4 layers, LeakyReLU(0.5)², GPTQ-lite quantization, SWA, Muon WD=0.04, zstd-22 compression. 
4300 steps in 600s on 8xH100.", + "date": "2026-03-26T00:00:00Z", + "val_loss": 2.03711228, + "val_bpb": 1.20649213, + "roundtrip_val_loss": 2.09336895, + "roundtrip_val_bpb": 1.23981101, + "step_stop": 4300, + "wallclock_seconds": 600.151, + "bytes_total": 15873439, + "bytes_model_int8_zlib": 15810364, + "bytes_code": 63075 +} diff --git a/records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/train.log b/records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/train.log new file mode 100644 index 000000000..b77e2a36a --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/train.log @@ -0,0 +1,101 @@ +W0325 23:14:13.792000 1272 torch/distributed/run.py:793] +W0325 23:14:13.792000 1272 torch/distributed/run.py:793] ***************************************** +W0325 23:14:13.792000 1272 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0325 23:14:13.792000 1272 torch/distributed/run.py:793] ***************************************** +logs/80906ba7-598b-4113-8215-45b1a3a1b567.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17140056 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.015 head_lr:0.0 matrix_lr:0.012 scalar_lr:0.012 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9300 val_bpb:4.1043 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9769 train_time:192ms step_avg:191.99ms +step:2/20000 train_loss:6.4406 train_time:248ms step_avg:123.86ms +step:3/20000 train_loss:7.4687 train_time:384ms step_avg:128.14ms +step:4/20000 train_loss:7.5661 train_time:522ms step_avg:130.52ms +step:5/20000 train_loss:6.8849 train_time:659ms step_avg:131.83ms +step:6/20000 train_loss:6.2342 train_time:796ms step_avg:132.73ms +step:7/20000 train_loss:5.3396 train_time:934ms step_avg:133.40ms +step:8/20000 train_loss:5.0498 train_time:1072ms step_avg:133.94ms +step:9/20000 train_loss:4.8488 train_time:1210ms step_avg:134.49ms +step:10/20000 train_loss:4.7649 train_time:1350ms step_avg:134.97ms +step:200/20000 train_loss:2.7331 train_time:27722ms step_avg:138.61ms +step:400/20000 train_loss:2.2867 
train_time:55632ms step_avg:139.08ms +step:600/20000 train_loss:2.5063 train_time:83627ms step_avg:139.38ms +step:800/20000 train_loss:2.2652 train_time:111578ms step_avg:139.47ms +step:1000/20000 train_loss:2.3527 train_time:139525ms step_avg:139.53ms +step:1000/20000 val_loss:2.3114 val_bpb:1.3689 train_time:139609ms step_avg:139.61ms +step:1200/20000 train_loss:2.3656 train_time:167550ms step_avg:139.63ms +step:1400/20000 train_loss:2.4157 train_time:195456ms step_avg:139.61ms +step:1600/20000 train_loss:2.0725 train_time:223357ms step_avg:139.60ms +step:1800/20000 train_loss:2.1766 train_time:251241ms step_avg:139.58ms +step:2000/20000 train_loss:2.2289 train_time:279108ms step_avg:139.55ms +step:2000/20000 val_loss:2.2075 val_bpb:1.3074 train_time:279190ms step_avg:139.60ms +step:2200/20000 train_loss:2.0380 train_time:306975ms step_avg:139.53ms +step:2400/20000 train_loss:2.1660 train_time:334846ms step_avg:139.52ms +step:2600/20000 train_loss:2.3737 train_time:362708ms step_avg:139.50ms +step:2800/20000 train_loss:2.1927 train_time:390569ms step_avg:139.49ms +step:3000/20000 train_loss:2.1817 train_time:418424ms step_avg:139.47ms +step:3000/20000 val_loss:2.1487 val_bpb:1.2726 train_time:418509ms step_avg:139.50ms +swa:start step:3150 +step:3200/20000 train_loss:2.1378 train_time:446330ms step_avg:139.48ms +step:3400/20000 train_loss:2.1057 train_time:474263ms step_avg:139.49ms +step:3600/20000 train_loss:2.0547 train_time:502221ms step_avg:139.51ms +step:3800/20000 train_loss:2.1572 train_time:530168ms step_avg:139.52ms +step:4000/20000 train_loss:2.0950 train_time:558094ms step_avg:139.52ms +step:4000/20000 val_loss:2.0989 val_bpb:1.2431 train_time:558197ms step_avg:139.55ms +step:4200/20000 train_loss:2.0995 train_time:586106ms step_avg:139.55ms +step:4300/20000 val_loss:2.0892 val_bpb:1.2373 train_time:600151ms step_avg:139.57ms +stopping_early: wallclock_cap train_time:600151ms step:4300/20000 +peak memory allocated: 25696 MiB reserved: 27322 MiB +swa: 
averaging 25 checkpoints +Serialized model: 63386762 bytes +Code size: 63075 bytes +Total submission size: 63449837 bytes +Serialized model int8+zstd22: 15810364 bytes (payload:17243616 raw_torch:17260843 payload_ratio:3.68x) +Total submission size int8+zstd22: 15873439 bytes +/workspace/parameter-golf/train_gpt.py:1395: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1395: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. 
We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1395: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1395: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. 
Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1395: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1395: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. 
This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1395: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1395: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). 
In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +final_roundtrip val_loss:2.0934 val_bpb:1.2398 eval_time:4076ms +final_roundtrip_exact val_loss:2.09336895 val_bpb:1.23981101 +final_sliding_window val_loss:2.0371 val_bpb:1.2065 window:1024 stride:256 eval_time:66852ms +final_sliding_window_exact val_loss:2.03711228 val_bpb:1.20649213 diff --git a/records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/train_gpt.py b/records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/train_gpt.py new file mode 100644 index 000000000..41f4ac4b9 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-26_DepthRecurrence_XSA_LeakyReLU/train_gpt.py @@ -0,0 +1,1473 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. 
class Hyperparameters:
    """Run configuration, every knob overridable via an environment variable.

    Defaults describe the depth-recurrence config: 3 blocks x 4 repeats at
    width 832, 8 heads with 4 KV heads (GQA), 2x MLP, vocab 1024, tied
    embeddings, 524,288 train tokens/step with a ~10 minute wallclock cap.
    """

    # --- data: shard globs produced by the preprocessing pipeline ---
    data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")
    train_files = os.path.join(data_path, "fineweb_train_*.bin")
    val_files = os.path.join(data_path, "fineweb_val_*.bin")
    tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model")
    run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
    seed = int(os.environ.get("SEED", 1337))

    # --- validation cadence; validation always uses the full fineweb_val split ---
    val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288))
    val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000))
    train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200))

    # --- training length / schedule ---
    iterations = int(os.environ.get("ITERATIONS", 20000))
    warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000))
    warmup_steps = int(os.environ.get("WARMUP_STEPS", 20))
    train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288))
    train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024))
    max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))
    qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5))
    grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3))
    ttt_steps = int(os.environ.get("TTT_STEPS", 0))
    ttt_lr = float(os.environ.get("TTT_LR", 1e-4))

    # --- XSA (Exclusive Self-Attention) on the last N effective layers ---
    xsa_last_n = int(os.environ.get("XSA_LAST_N", 4))

    # --- SWA (Stochastic Weight Averaging) during warmdown ---
    swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1")))
    swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4))
    swa_every = int(os.environ.get("SWA_EVERY", 50))

    # --- sliding-window evaluation ---
    eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 1024))
    eval_stride = int(os.environ.get("EVAL_STRIDE", 256))
    eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32))

    # --- model shape (effective depth = num_layers * num_repeats) ---
    vocab_size = int(os.environ.get("VOCAB_SIZE", 1024))
    num_layers = int(os.environ.get("NUM_LAYERS", 3))
    num_repeats = int(os.environ.get("NUM_REPEATS", 4))
    num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4))
    model_dim = int(os.environ.get("MODEL_DIM", 832))
    num_heads = int(os.environ.get("NUM_HEADS", 8))
    mlp_mult = int(os.environ.get("MLP_MULT", 2))
    num_value_embeds = int(os.environ.get("NUM_VALUE_EMBEDS", 2))
    tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1")))
    rope_base = float(os.environ.get("ROPE_BASE", 10000.0))
    logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0))

    # --- optimizer hyperparameters (Muon for matrices, Adam-style elsewhere) ---
    embed_lr = float(os.environ.get("EMBED_LR", 0.6))
    head_lr = float(os.environ.get("HEAD_LR", 0.008))
    tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.015))
    tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005))
    matrix_lr = float(os.environ.get("MATRIX_LR", 0.012))
    scalar_lr = float(os.environ.get("SCALAR_LR", 0.012))
    muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95))
    muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))
    muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85))
    muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500))
    muon_wd = float(os.environ.get("MUON_WD", 0.04))
    beta1 = float(os.environ.get("BETA1", 0.9))
    beta2 = float(os.environ.get("BETA2", 0.95))
    adam_eps = float(os.environ.get("ADAM_EPS", 1e-8))
def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
    """Approximately orthogonalize a 2D update matrix via quintic Newton-Schulz.

    Runs in bfloat16 for speed. Muon uses this to whiten matrix-shaped
    momentum before applying it as a parameter update.
    """
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X /= X.norm() + eps
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X.T if transposed else X


class Muon(torch.optim.Optimizer):
    """Momentum + Newton-Schulz orthogonalization for matrix parameters.

    Parameters are sharded round-robin across ranks: each rank orthogonalizes
    only the parameters it owns, and the flat bf16 update buffer is combined
    with a single all_reduce.
    """

    def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0):
        defaults = dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        in_ddp = dist.is_available() and dist.is_initialized()
        world_size = dist.get_world_size() if in_ddp else 1
        rank = dist.get_rank() if in_ddp else 0

        for group in self.param_groups:
            group_params = group["params"]
            if not group_params:
                continue
            lr = group["lr"]
            momentum = group["momentum"]
            ns_steps = group["backend_steps"]
            use_nesterov = group["nesterov"]

            flat_numel = sum(int(p.numel()) for p in group_params)
            flat_updates = torch.zeros(flat_numel, device=group_params[0].device, dtype=torch.bfloat16)

            offset = 0
            for idx, p in enumerate(group_params):
                # Round-robin ownership: rank (idx mod world_size) computes this update;
                # other ranks leave the slot zeroed so all_reduce(SUM) reassembles it.
                if idx % world_size == rank and p.grad is not None:
                    update = p.grad
                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(update)
                    buf = state["momentum_buffer"]
                    buf.mul_(momentum).add_(update)
                    if use_nesterov:
                        update = update.add(buf, alpha=momentum)
                    update = zeropower_via_newtonschulz5(update, steps=ns_steps)
                    # Scale correction from Muon reference implementations (tall matrices).
                    update *= max(1, update.size(0) / update.size(1)) ** 0.5
                    flat_updates[offset : offset + p.numel()] = update.reshape(-1)
                offset += p.numel()

            if in_ddp:
                dist.all_reduce(flat_updates, op=dist.ReduceOp.SUM)

            wd = group.get("weight_decay", 0.0)
            offset = 0
            for p in group_params:
                update = flat_updates[offset : offset + p.numel()].view_as(p).to(dtype=p.dtype)
                if wd > 0:
                    # Decoupled weight decay, applied before the orthogonalized step.
                    p.data.mul_(1.0 - lr * wd)
                p.add_(update, alpha=-lr)
                offset += p.numel()

        return loss
def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token lookup tables for the tokenizer-agnostic BPB metric.

    Returns three tensors indexed by token id:
      - base byte count of the piece (UTF-8 length, leading "▁" stripped)
      - whether the piece carries a leading space marker ("▁" prefix)
      - whether the id is a boundary token (control / unknown / unused)
    """
    sp_vocab = int(sp.vocab_size())
    table_size = max(sp_vocab, vocab_size)
    byte_counts = np.zeros((table_size,), dtype=np.int16)
    leading_space = np.zeros((table_size,), dtype=np.bool_)
    boundary = np.ones((table_size,), dtype=np.bool_)
    for tid in range(sp_vocab):
        if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid):
            continue  # stays a boundary token with zero bytes
        boundary[tid] = False
        if sp.is_byte(tid):
            byte_counts[tid] = 1
            continue
        piece = sp.id_to_piece(tid)
        if piece.startswith("▁"):
            leading_space[tid] = True
            piece = piece[1:]
        byte_counts[tid] = len(piece.encode("utf-8"))
    return (
        torch.tensor(byte_counts, dtype=torch.int16, device=device),
        torch.tensor(leading_space, dtype=torch.bool, device=device),
        torch.tensor(boundary, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate all validation shards matching `pattern`, trimmed so the
    stream splits into whole seq_len sequences plus one shifted target token."""
    shard_paths = [Path(p) for p in sorted(glob.glob(pattern))]
    if not shard_paths:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    tokens = torch.cat([load_data_shard(path) for path in shard_paths]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]
torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + """Sliding window eval with batching. Windows of train_seq_len advance by eval_stride. 
+ Only the last stride tokens per window are scored (first window scores all).""" + seq_len = args.eval_seq_len + stride = args.eval_stride + batch_seqs = args.eval_batch_seqs + total_tokens = val_tokens.numel() - 1 + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi : bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws : end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + logits = base_model.forward_logits(x_batch) + + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + val_loss_sum += scored_nll.sum() + val_token_count += float(wlen - s) + prev_ids = x_batch[i, s:wlen] + tgt_ids = y_batch[i, s:wlen] + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & 
def eval_val_ttt(
    args: Hyperparameters,
    base_model: nn.Module,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Test-Time Training evaluation: adapt on each validation batch first.

    For each batch: saved weights -> args.ttt_steps plain-SGD gradient steps on
    the batch -> evaluate -> restore the saved weights. Falls back to eval_val
    when TTT is disabled.

    Returns (val_loss, val_bpb) aggregated across ranks.
    """
    if args.ttt_steps <= 0:
        return eval_val(args, model, rank, world_size, device, grad_accum_steps,
                        val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut)

    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    # Fix: mirror eval_val's guard. Without it a too-small VAL_BATCH_SIZE makes
    # local_batch_seqs == 0 and the batch loop dies on `range(..., 0)` with an
    # unrelated-looking error instead of this actionable one.
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)

    # Save original weights once; each batch restores from this snapshot.
    saved_state = {k: v.detach().clone() for k, v in base_model.state_dict().items()}

    for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
        batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
        raw_start = batch_seq_start * args.train_seq_len
        raw_end = batch_seq_end * args.train_seq_len + 1  # +1 for the shifted targets
        local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
        x = local[:-1].reshape(-1, args.train_seq_len)
        y = local[1:].reshape(-1, args.train_seq_len)

        # TTT: adapt on this batch with plain SGD at args.ttt_lr.
        model.train()
        for _ttt_step in range(args.ttt_steps):
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                ttt_loss = model(x, y)
            ttt_loss.backward()
            with torch.no_grad():
                for p in base_model.parameters():
                    if p.grad is not None:
                        p -= args.ttt_lr * p.grad
                        p.grad = None

        # Evaluate with the adapted model.
        model.eval()
        with torch.no_grad():
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            batch_token_count = float(y.numel())
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            # BPB bookkeeping: base bytes plus a space byte when the target has a
            # leading-space marker and the previous token is not a boundary token.
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()

        # Restore original weights so the next batch adapts from scratch.
        base_model.load_state_dict(saved_state, strict=True)

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)

    val_loss = val_loss_sum / val_token_count
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
# ----------------------------- 
# POST-TRAINING QUANTIZATION
# -----------------------------
#
# The model trains in bf16/fp32 but exports quantized + zstd-compressed:
# nearly the same model at a fraction of the bytes.

# Small "control" tensors (scales, gains, mixes) stay in full float precision.
CONTROL_TENSOR_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "CONTROL_TENSOR_NAME_PATTERNS",
        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights",
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
    ).split(",")
    if pattern
)
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
INT8_PER_ROW_SCALE_DTYPE = torch.float16
INT8_CLIP_PERCENTILE = 99.99984
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0
# Quantization levels: 127 = int8, 31 = int6, 16 = int5. Per-tensor override via MLP_QUANT_LEVELS.
QUANT_LEVELS = int(os.environ.get("QUANT_LEVELS", 127))
MLP_QUANT_LEVELS = int(os.environ.get("MLP_QUANT_LEVELS", 0))  # 0 = same as QUANT_LEVELS
MLP_TENSOR_PATTERNS = ("mlp.fc.", "mlp.proj.", "fc.weight", "mlp.proj.weight")

def tensor_nbytes(t: Tensor) -> int:
    """Raw storage size of `t` in bytes (numel * element size)."""
    return int(t.numel()) * int(t.element_size())

def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor:
    """Prepare a small tensor for float passthrough.

    Control tensors are forced to fp32; other fp32/bf16 tensors are stored as
    fp16 (their original dtype recorded in `passthrough_orig_dtypes` so the
    loader can restore it); anything else passes through unchanged.
    """
    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
        return t.float().contiguous()
    if t.dtype in {torch.float32, torch.bfloat16}:
        passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.")
        return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous()
    return t

GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 0.999999]

def quantize_float_tensor(t: Tensor, ql: int = 0) -> tuple[Tensor, Tensor]:
    """Quantize a float tensor to signed integer codes in [-ql, ql].

    2D tensors get per-row scales chosen GPTQ-lite style: several clip
    percentiles are tried per row and the one with the lowest reconstruction
    MSE wins. Other shapes get one per-tensor scale at INT8_CLIP_Q.

    Returns (int8 codes, scales): scales is a per-row fp16 vector for 2D
    inputs and a scalar fp32 tensor otherwise.
    """
    if ql <= 0:
        ql = QUANT_LEVELS
    t32 = t.float()
    if t32.ndim == 2:
        abs_t = t32.abs()
        best_q = None
        best_scale = None
        best_mse = None
        for pct in GPTQ_LITE_PERCENTILES:
            clip_abs = (
                torch.quantile(abs_t, pct, dim=1)
                if t32.numel()
                # Fix: torch.zeros, not torch.empty — empty left the scales
                # uninitialized (nondeterministic garbage, possibly NaN) for
                # degenerate inputs such as shape (rows, 0).
                else torch.zeros((t32.shape[0],), dtype=torch.float32)
            )
            clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
            s = (clip_abs / ql).clamp_min(1e-12)
            q = torch.clamp(torch.round(clipped / s[:, None]), -ql, ql)
            # Reconstruction error per row drives the percentile choice.
            recon = q * s[:, None]
            mse = (t32 - recon).square().sum(dim=1)
            if best_mse is None:
                best_mse = mse
                best_q = q
                best_scale = s
            else:
                better = mse < best_mse
                best_mse = torch.where(better, mse, best_mse)
                best_q = torch.where(better[:, None], q, best_q)
                best_scale = torch.where(better, s, best_scale)
        return best_q.to(torch.int8).contiguous(), best_scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()

    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
    scale = torch.tensor(clip_abs / ql if clip_abs > 0 else 1.0, dtype=torch.float32)
    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -ql, ql).to(torch.int8).contiguous()
    return q, scale
def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
    """Invert the int8 export format back into a plain float state dict.

    Per-row quantized tensors are rescaled row-wise; per-tensor ones with a
    scalar scale; passthrough tensors are restored to their recorded dtype.
    """
    restored: dict[str, Tensor] = {}
    qmeta = obj.get("qmeta", {})
    orig_dtypes = obj.get("passthrough_orig_dtypes", {})
    for name, q in obj["quantized"].items():
        target_dtype = getattr(torch, obj["dtypes"][name])
        s = obj["scales"][name]
        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
            # Broadcast the saved row scale back across trailing dimensions.
            row_scale = s.to(dtype=torch.float32).view(q.shape[0], *([1] * (q.ndim - 1)))
            restored[name] = (q.float() * row_scale).to(dtype=target_dtype).contiguous()
        else:
            restored[name] = (q.float() * float(s.item())).to(dtype=target_dtype).contiguous()
    for name, t in obj["passthrough"].items():
        # Restore small tensors, undoing the temporary fp16 storage cast if needed.
        tensor_cpu = t.detach().to("cpu").contiguous()
        stored_as = orig_dtypes.get(name)
        if isinstance(stored_as, str):
            tensor_cpu = tensor_cpu.to(dtype=getattr(torch, stored_as)).contiguous()
        restored[name] = tensor_cpu
    return restored
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
class Rotary(nn.Module):
    """RoPE cos/sin table provider; tables cached per (seq_len, device)."""

    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._seq_len_cached = 0
        self._cos_cached: Tensor | None = None
        self._sin_cached: Tensor | None = None

    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
        stale = (
            self._cos_cached is None
            or self._sin_cached is None
            or self._seq_len_cached != seq_len
            or self._cos_cached.device != device
        )
        if stale:
            positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(positions, self.inv_freq.to(device))
            self._cos_cached = freqs.cos()[None, None, :, :]
            self._sin_cached = freqs.sin()[None, None, :, :]
            self._seq_len_cached = seq_len
        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)


def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Rotate the two halves of the last dim by (cos, sin) — non-interleaved RoPE."""
    half = x.size(-1) // 2
    lo, hi = x[..., :half], x[..., half:]
    return torch.cat((lo * cos + hi * sin, hi * cos - lo * sin), dim=-1)


class CausalSelfAttention(nn.Module):
    """GQA causal attention with QK-norm, per-head Q gain, RoPE and optional XSA."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        num_kv_heads: int,
        rope_base: float,
        qk_gain_init: float,
    ):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("model_dim must be divisible by num_heads")
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        if self.head_dim % 2 != 0:
            raise ValueError("head_dim must be even for RoPE")
        kv_dim = self.num_kv_heads * self.head_dim
        self.c_q = CastedLinear(dim, dim, bias=False)
        self.c_k = CastedLinear(dim, kv_dim, bias=False)
        self.c_v = CastedLinear(dim, kv_dim, bias=False)
        self.proj = CastedLinear(dim, dim, bias=False)
        self.proj._zero_init = True
        # Per-head gain applied to the RMS-normalized queries.
        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
        self.rotary = Rotary(self.head_dim, base=rope_base)

    def _xsa(self, y: Tensor, v: Tensor) -> Tensor:
        """Subtract self-value projection from attention output (GQA-aware)."""
        b, t, h, d = y.shape
        h_kv = v.size(2)
        group = h // h_kv
        grouped = y.reshape(b, t, h_kv, group, d)
        v_unit = F.normalize(v, dim=-1).unsqueeze(3)  # [b, t, h_kv, 1, d]
        along_v = (grouped * v_unit).sum(dim=-1, keepdim=True) * v_unit
        return (grouped - along_v).reshape(b, t, h, d)

    def forward(self, x: Tensor, use_xsa: bool = False) -> Tensor:
        batch, seq, dim = x.shape
        q = self.c_q(x).reshape(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.c_k(x).reshape(batch, seq, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.c_v(x).reshape(batch, seq, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # QK-norm stabilizes logits; q_gain restores a useful scale per head.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        cos, sin = self.rotary(seq, x.device, q.dtype)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)
        q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
        y = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            is_causal=True,
            enable_gqa=(self.num_kv_heads != self.num_heads),
        )
        if use_xsa:
            # XSA: remove the self-value bias from the attention output.
            y = y.transpose(1, 2).contiguous()  # [batch, seq, heads, head_dim]
            y = self._xsa(y, v.transpose(1, 2))
            y = y.reshape(batch, seq, dim)
        else:
            y = y.transpose(1, 2).contiguous().reshape(batch, seq, dim)
        return self.proj(y)
F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, use_xsa: bool = False) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x), use_xsa=use_xsa) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + num_repeats: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + num_value_embeds: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + xsa_last_n: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.num_repeats = num_repeats + effective_depth = num_layers * num_repeats + # XSA: which effective layers use exclusive self-attention + self.xsa_start = max(0, effective_depth - xsa_last_n) if xsa_last_n > 0 else effective_depth + self.tok_emb = nn.Embedding(vocab_size, model_dim) + # Value embeddings: 
class GPT(nn.Module):
    """Depth-recurrent GPT: num_layers blocks unrolled num_repeats times.

    Effective depth = num_layers * num_repeats. Every effective layer gets its
    own loop embedding and value-embedding mix scales; each physical block also
    mixes in its own output from the previous repeat (cross-repeat skip), and
    the last xsa_last_n effective layers use exclusive self-attention.
    """

    def __init__(
        self,
        vocab_size: int,
        num_layers: int,
        num_repeats: int,
        model_dim: int,
        num_heads: int,
        num_kv_heads: int,
        mlp_mult: int,
        num_value_embeds: int,
        tie_embeddings: bool,
        tied_embed_init_std: float,
        logit_softcap: float,
        rope_base: float,
        qk_gain_init: float,
        xsa_last_n: int = 0,
    ):
        super().__init__()
        if logit_softcap <= 0.0:
            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
        self.tie_embeddings = tie_embeddings
        self.tied_embed_init_std = tied_embed_init_std
        self.logit_softcap = logit_softcap
        self.num_repeats = num_repeats
        effective_depth = num_layers * num_repeats
        # XSA: effective layers at index >= xsa_start use exclusive self-attention.
        self.xsa_start = max(0, effective_depth - xsa_last_n) if xsa_last_n > 0 else effective_depth
        self.tok_emb = nn.Embedding(vocab_size, model_dim)
        # Value embeddings: extra embedding tables mixed into each effective layer.
        self.num_value_embeds = num_value_embeds
        if num_value_embeds > 0:
            self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(num_value_embeds)])
            self.value_scales = nn.Parameter(torch.zeros(effective_depth, num_value_embeds, model_dim, dtype=torch.float32))
        self.blocks = nn.ModuleList(
            [
                Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init)
                for _ in range(num_layers)
            ]
        )
        # Loop embedding: tells the shared blocks which effective layer they are at.
        self.loop_embed = nn.Parameter(torch.zeros(effective_depth, model_dim, dtype=torch.float32))
        # Cross-repeat skip: per-repeat scales (repeat 0 has no previous output).
        self.cross_repeat_scales = nn.Parameter(torch.zeros(num_layers, num_repeats - 1, model_dim, dtype=torch.float32))
        self.final_norm = RMSNorm()
        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
        if self.lm_head is not None:
            self.lm_head._zero_init = True
        self._init_weights()

    def _init_weights(self) -> None:
        # The tied embedding doubles as the output head, so it needs a small init.
        if self.tie_embeddings:
            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
        for module in self.modules():
            if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
                nn.init.zeros_(module.weight)

    def forward_logits(self, input_ids: Tensor) -> Tensor:
        """Run the recurrent stack and return tanh-softcapped logits."""
        x = self.tok_emb(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x0 = x

        # Value embeddings are looked up once and re-mixed at every layer.
        value_outs: list[Tensor] = []
        if self.num_value_embeds > 0:
            value_outs = [ve(input_ids) for ve in self.value_embeds]  # each (bsz, seq, dim)

        prev_outputs: list[Tensor | None] = [None] * len(self.blocks)
        eff_idx = 0
        for repeat in range(self.num_repeats):
            for block_idx, block in enumerate(self.blocks):
                x = x + self.loop_embed[eff_idx].to(dtype=x.dtype)
                for ve_idx, ve_out in enumerate(value_outs):
                    vs = self.value_scales[eff_idx, ve_idx].to(dtype=x.dtype)
                    x = x + vs[None, None, :] * ve_out
                # Cross-repeat skip: stateful recurrence on this block's own output.
                if repeat > 0 and prev_outputs[block_idx] is not None:
                    scale = self.cross_repeat_scales[block_idx, repeat - 1].to(dtype=x.dtype)
                    x = x + scale[None, None, :] * prev_outputs[block_idx]
                x = block(x, x0, use_xsa=(eff_idx >= self.xsa_start))
                # Keep the graph while training; detach outside training.
                prev_outputs[block_idx] = x if self.training else x.detach()
                eff_idx += 1

        x = self.final_norm(x)
        if self.tie_embeddings:
            logits_proj = F.linear(x, self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head is required when tie_embeddings=False")
            logits_proj = self.lm_head(x)
        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)

    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
        """Mean next-token cross-entropy over all positions."""
        logits = self.forward_logits(input_ids)
        return F.cross_entropy(
            logits.reshape(-1, logits.size(-1)).float(),
            target_ids.reshape(-1),
            reduction="mean",
        )
raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # ----------------------------- + # TOKENIZER + VALIDATION METRIC SETUP + # ----------------------------- + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = 
len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # ----------------------------- + # MODEL + OPTIMIZER SETUP + # ----------------------------- + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + num_repeats=args.num_repeats, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + num_value_embeds=args.num_value_embeds, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + xsa_last_n=args.xsa_last_n, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in 
name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params.append(base_model.loop_embed) + scalar_params.append(base_model.cross_repeat_scales) + if base_model.num_value_embeds > 0: + scalar_params.append(base_model.value_scales) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params = [base_model.tok_emb.weight] + if base_model.num_value_embeds > 0: + embed_params.extend(ve.weight for ve in base_model.value_embeds) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + 
log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # ----------------------------- + # DATA LOADER & MODEL WARMUP + # ----------------------------- + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # ----------------------------- + # MAIN TRAINING LOOP + # ----------------------------- + + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + 
is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown (accumulate in float for precision) + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().float() for name, t in base_model.state_dict().items()} 
+ swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu().float() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None: + # Include final weights (may not have landed on swa_every boundary) + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu().float() + swa_count += 1 + log0(f"swa: averaging {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # ----------------------------- + # SERIALIZATION + ROUNDTRIP VALIDATION + # ----------------------------- + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed quantized+zstd artifact and validate the round-tripped weights. 
+ + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + zstd_level = int(os.environ.get("ZSTD_LEVEL", 22)) + cctx = zstd.ZstdCompressor(level=zstd_level) + quant_blob = cctx.compress(quant_raw) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zstd{zstd_level}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zstd{zstd_level}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + dctx = zstd.ZstdDecompressor() + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + 
log0(f"final_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval + if args.eval_stride > 0: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"window:{args.eval_seq_len} stride:{args.eval_stride} " + f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + + # TTT eval: adapt model on each batch before evaluating + if args.ttt_steps > 0: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_val_loss, ttt_val_bpb = eval_val_ttt( + args, + base_model, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_ttt val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " + f"ttt_steps:{args.ttt_steps} ttt_lr:{args.ttt_lr} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + log0(f"final_ttt_exact val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/train_gpt.py b/train_gpt.py index 8987bfe69..41f4ac4b9 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -922,8 +922,7 @@ def __init__( ) # Loop embedding: tells the model which effective layer it's at self.loop_embed = nn.Parameter(torch.zeros(effective_depth, model_dim, dtype=torch.float32)) - # Cross-repeat skip: each block remembers its output from previous 
repeat - # Per-repeat scales (repeat 0 has no prev, so num_repeats-1 scales per block) + # Cross-repeat skip: each block receives its own output from previous repeat self.cross_repeat_scales = nn.Parameter(torch.zeros(num_layers, num_repeats - 1, model_dim, dtype=torch.float32)) self.final_norm = RMSNorm() self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) From c0f6b1b62793eb776248882cb0e805029188100c Mon Sep 17 00:00:00 2001 From: Ivan Verbovoy Date: Thu, 26 Mar 2026 03:56:36 +0200 Subject: [PATCH 08/11] =?UTF-8?q?Add=20Progressive=20Depth=20training=20?= =?UTF-8?q?=E2=80=94=20val=5Fbpb=201.1973?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Dynamic depth scheduling unique to shared-weight recurrence: - Phase 1 (0-40%): 2 repeats, ~75ms/step — fast base training - Phase 2 (40-65%): 3 repeats, ~83ms/step — intermediate depth - Phase 3 (65-100%): 4 repeats, ~100ms/step — full recurrence 5981 steps vs 4300 without progressive depth (+39%). SWA collected only at full depth (last phase) to avoid mixing phases. Removed unused TTT eval code. 8xH100, 80 train shards, sliding 1.1973 (-0.009 vs previous 1.2065). 
--- train_gpt.py | 227 +++++++++++++++------------------------------------ 1 file changed, 67 insertions(+), 160 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index 41f4ac4b9..0127f8fcc 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -27,14 +27,7 @@ from torch import Tensor, nn from torch.nn.parallel import DistributedDataParallel as DDP -# ----------------------------- # HYPERPARAMETERS -# ----------------------------- -# Default Simple Baseline run: -# - 9 transformer blocks at width 512 -# - 8 attention heads with 4 KV heads (GQA) and 2x MLP expansion -# - vocab size 1024, sequence length 1024, tied embeddings -# - 524,288 train tokens per step for 20,000 iterations with a ~10 minute cap class Hyperparameters: # Data paths are shard globs produced by the existing preprocessing pipeline. @@ -59,8 +52,10 @@ class Hyperparameters: max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) - ttt_steps = int(os.environ.get("TTT_STEPS", 0)) - ttt_lr = float(os.environ.get("TTT_LR", 1e-4)) + + # Progressive Depth: train with fewer repeats early (faster), more repeats later (deeper). + # Schedule format: "frac1:rep1,frac2:rep2,..." e.g. "0.4:2,0.65:3,1.0:4" + prog_depth_schedule = os.environ.get("PROG_DEPTH", "0.4:2,0.65:3,1.0:4") # XSA (Exclusive Self-Attention) on last N effective layers. 
xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) @@ -104,9 +99,7 @@ class Hyperparameters: beta2 = float(os.environ.get("BETA2", 0.95)) adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) -# ----------------------------- # MUON OPTIMIZER -# ----------------------------- # # As borrowed from modded-nanogpt # Background on Muon: https://kellerjordan.github.io/posts/muon/ @@ -189,9 +182,7 @@ def step(self, closure=None): return loss -# ----------------------------- # TOKENIZER-AGNOSTIC EVALUATION SETUP -# ----------------------------- # # It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. # Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. @@ -378,89 +369,8 @@ def eval_val_sliding( return float(val_loss.item()), float(bits_per_token * tokens_per_byte) -def eval_val_ttt( - args: Hyperparameters, - base_model: nn.Module, - model: nn.Module, - rank: int, - world_size: int, - device: torch.device, - grad_accum_steps: int, - val_tokens: Tensor, - base_bytes_lut: Tensor, - has_leading_space_lut: Tensor, - is_boundary_token_lut: Tensor, -) -> tuple[float, float]: - # Test-Time Training: adapt the model on each validation batch before evaluating. - # For each batch: save weights → K gradient steps → evaluate → restore weights. 
- if args.ttt_steps <= 0: - return eval_val(args, model, rank, world_size, device, grad_accum_steps, - val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) - - local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) - local_batch_seqs = local_batch_tokens // args.train_seq_len - total_seqs = (val_tokens.numel() - 1) // args.train_seq_len - seq_start = (total_seqs * rank) // world_size - seq_end = (total_seqs * (rank + 1)) // world_size - val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) - val_token_count = torch.zeros((), device=device, dtype=torch.float64) - val_byte_count = torch.zeros((), device=device, dtype=torch.float64) - - # Save original weights once - saved_state = {k: v.detach().clone() for k, v in base_model.state_dict().items()} - - for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): - batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) - raw_start = batch_seq_start * args.train_seq_len - raw_end = batch_seq_end * args.train_seq_len + 1 - local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) - x = local[:-1].reshape(-1, args.train_seq_len) - y = local[1:].reshape(-1, args.train_seq_len) - - # TTT: adapt on this batch - model.train() - for _ttt_step in range(args.ttt_steps): - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - ttt_loss = model(x, y) - ttt_loss.backward() - with torch.no_grad(): - for p in base_model.parameters(): - if p.grad is not None: - p -= args.ttt_lr * p.grad - p.grad = None - - # Evaluate with adapted model - model.eval() - with torch.no_grad(): - with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): - batch_loss = model(x, y).detach() - batch_token_count = float(y.numel()) - val_loss_sum += batch_loss.to(torch.float64) * batch_token_count - val_token_count += batch_token_count - prev_ids = x.reshape(-1) - tgt_ids = y.reshape(-1) - token_bytes = 
base_bytes_lut[tgt_ids].to(dtype=torch.int16) - token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) - val_byte_count += token_bytes.to(torch.float64).sum() - - # Restore original weights - base_model.load_state_dict(saved_state, strict=True) - - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) - dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) - dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) - - val_loss = val_loss_sum / val_token_count - bits_per_token = val_loss.item() / math.log(2.0) - tokens_per_byte = val_token_count.item() / val_byte_count.item() - model.train() - return float(val_loss.item()), float(bits_per_token * tokens_per_byte) - -# ----------------------------- # POST-TRAINING QUANTIZATION -# ----------------------------- # # It's silly to export our model, which is trained in bf16 and fp32, at that same precision. # Instead, we get approximately the same model (with a small hit) by quantizing and zstd compressing. 
@@ -633,9 +543,7 @@ def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: return out -# ----------------------------- # DATA LOADING -# ----------------------------- def load_data_shard(file: Path) -> Tensor: header_bytes = 256 * np.dtype(" y = local[1:].reshape(-1, seq_len) return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) -# ----------------------------- # TRANSFORMER MODULES -# ----------------------------- class RMSNorm(nn.Module): def __init__(self, eps: float | None = None): @@ -898,9 +804,8 @@ def __init__( self.tied_embed_init_std = tied_embed_init_std self.logit_softcap = logit_softcap self.num_repeats = num_repeats + self.xsa_last_n = xsa_last_n effective_depth = num_layers * num_repeats - # XSA: which effective layers use exclusive self-attention - self.xsa_start = max(0, effective_depth - xsa_last_n) if xsa_last_n > 0 else effective_depth self.tok_emb = nn.Embedding(vocab_size, model_dim) # Value embeddings: extra embedding tables mixed into each effective layer self.num_value_embeds = num_value_embeds @@ -948,21 +853,27 @@ def forward_logits(self, input_ids: Tensor) -> Tensor: for ve in self.value_embeds: ve_list.append(ve(input_ids)) # (bsz, seq, dim) + cur_repeats = self.cur_repeats if hasattr(self, "cur_repeats") else self.num_repeats + cur_depth = len(self.blocks) * cur_repeats + xsa_start = max(0, cur_depth - self.xsa_last_n) if self.xsa_last_n > 0 else cur_depth + num_blocks = len(self.blocks) prev_block_outputs: list[Tensor | None] = [None] * num_blocks layer_idx = 0 - for repeat in range(self.num_repeats): + for repeat in range(cur_repeats): for block_idx, block in enumerate(self.blocks): x = x + self.loop_embed[layer_idx].to(dtype=x.dtype) # Value embeddings: add weighted extra embeddings at each layer - for ve_idx, ve_out in enumerate(ve_list): - vs = self.value_scales[layer_idx, ve_idx].to(dtype=x.dtype) - x = x + vs[None, None, :] * ve_out + if layer_idx < self.value_scales.size(0): + 
for ve_idx, ve_out in enumerate(ve_list): + vs = self.value_scales[layer_idx, ve_idx].to(dtype=x.dtype) + x = x + vs[None, None, :] * ve_out # Cross-repeat skip: mix in this block's output from previous repeat if repeat > 0 and prev_block_outputs[block_idx] is not None: - scale = self.cross_repeat_scales[block_idx, repeat - 1].to(dtype=x.dtype) + rep_idx = min(repeat - 1, self.cross_repeat_scales.size(1) - 1) + scale = self.cross_repeat_scales[block_idx, rep_idx].to(dtype=x.dtype) x = x + scale[None, None, :] * prev_block_outputs[block_idx] - x = block(x, x0, use_xsa=(layer_idx >= self.xsa_start)) + x = block(x, x0, use_xsa=(layer_idx >= xsa_start)) prev_block_outputs[block_idx] = x.detach() if not self.training else x layer_idx += 1 @@ -983,9 +894,7 @@ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: return F.cross_entropy(logits.float(), targets, reduction="mean") -# ----------------------------- # TRAINING -# ----------------------------- def main() -> None: global zeropower_via_newtonschulz5 @@ -994,10 +903,8 @@ def main() -> None: args = Hyperparameters() zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) - # ----------------------------- - # DISTRIBUTED + CUDA SETUP - # ----------------------------- - + # DISTRIBUTED + CUDA SETUP + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ rank = int(os.environ.get("RANK", "0")) world_size = int(os.environ.get("WORLD_SIZE", "1")) @@ -1050,10 +957,8 @@ def log0(msg: str, console: bool = True) -> None: ) log0("=" * 100, console=False) - # ----------------------------- - # TOKENIZER + VALIDATION METRIC SETUP - # ----------------------------- - + # TOKENIZER + VALIDATION METRIC SETUP + random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) @@ -1076,10 +981,8 @@ def log0(msg: str, console: bool = True) -> None: log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") log0(f"val_loader:shards pattern={args.val_files} 
tokens:{val_tokens.numel() - 1}") - # ----------------------------- - # MODEL + OPTIMIZER SETUP - # ----------------------------- - + # MODEL + OPTIMIZER SETUP + base_model = GPT( vocab_size=args.vocab_size, num_layers=args.num_layers, @@ -1175,10 +1078,8 @@ def log0(msg: str, console: bool = True) -> None: ) log0(f"seed:{args.seed}") - # ----------------------------- - # DATA LOADER & MODEL WARMUP - # ----------------------------- - + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) def zero_grad_all() -> None: @@ -1226,14 +1127,26 @@ def lr_mul(step: int, elapsed_ms: float) -> float: model.require_backward_grad_sync = True train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) - # ----------------------------- - # MAIN TRAINING LOOP - # ----------------------------- - + # MAIN TRAINING LOOP + training_time_ms = 0.0 stop_after_step: int | None = None swa_state: dict[str, Tensor] | None = None swa_count = 0 + + # Progressive depth schedule: parse "frac:repeats,..." 
and sort + prog_phases: list[tuple[float, int]] = [] + for entry in args.prog_depth_schedule.split(","): + frac_s, rep_s = entry.strip().split(":") + prog_phases.append((float(frac_s), int(rep_s))) + prog_phases.sort() + current_phase_repeats = prog_phases[0][1] if prog_phases else args.num_repeats + base_model.cur_repeats = current_phase_repeats + # Recompile with initial phase depth + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + log0(f"prog_depth: schedule={prog_phases} starting_repeats={current_phase_repeats}") + torch.cuda.synchronize() t0 = time.perf_counter() @@ -1273,6 +1186,21 @@ def lr_mul(step: int, elapsed_ms: float) -> float: break elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + # Progressive depth: check if we need to switch phase + if max_wallclock_ms is not None and prog_phases: + frac = elapsed_ms / max_wallclock_ms + new_repeats = prog_phases[-1][1] # default to last + for phase_frac, phase_rep in prog_phases: + if frac < phase_frac: + new_repeats = phase_rep + break + if new_repeats != current_phase_repeats: + current_phase_repeats = new_repeats + base_model.cur_repeats = new_repeats + torch._dynamo.reset() + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + log0(f"prog_depth: switched to {new_repeats} repeats at step:{step} frac:{frac:.2f}") scale = lr_mul(step, elapsed_ms) zero_grad_all() train_loss = torch.zeros((), device=device) @@ -1304,8 +1232,9 @@ def lr_mul(step: int, elapsed_ms: float) -> float: step += 1 approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) - # SWA: collect checkpoints during warmdown (accumulate in float for precision) - if args.swa_enabled and scale < args.swa_start_frac and step % 
args.swa_every == 0: + # SWA: collect checkpoints during warmdown (only at full depth to avoid mixing phases) + at_full_depth = current_phase_repeats == args.num_repeats + if args.swa_enabled and at_full_depth and scale < args.swa_start_frac and step % args.swa_every == 0: if swa_state is None: swa_state = {name: t.detach().cpu().float() for name, t in base_model.state_dict().items()} swa_count = 1 @@ -1339,6 +1268,12 @@ def lr_mul(step: int, elapsed_ms: float) -> float: f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" ) + # Restore full depth for eval/export + base_model.cur_repeats = args.num_repeats + torch._dynamo.reset() + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + # Apply SWA if collected if args.swa_enabled and swa_state is not None: # Include final weights (may not have landed on swa_every boundary) @@ -1353,10 +1288,8 @@ def lr_mul(step: int, elapsed_ms: float) -> float: } base_model.load_state_dict(avg_state, strict=True) - # ----------------------------- - # SERIALIZATION + ROUNDTRIP VALIDATION - # ----------------------------- - # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # SERIALIZATION + ROUNDTRIP VALIDATION + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce # the compressed quantized+zstd artifact and validate the round-tripped weights. 
if master_process: @@ -1439,32 +1372,6 @@ def lr_mul(step: int, elapsed_ms: float) -> float: ) log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") - # TTT eval: adapt model on each batch before evaluating - if args.ttt_steps > 0: - base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) - torch.cuda.synchronize() - t_ttt = time.perf_counter() - ttt_val_loss, ttt_val_bpb = eval_val_ttt( - args, - base_model, - model, - rank, - world_size, - device, - grad_accum_steps, - val_tokens, - base_bytes_lut, - has_leading_space_lut, - is_boundary_token_lut, - ) - torch.cuda.synchronize() - log0( - f"final_ttt val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " - f"ttt_steps:{args.ttt_steps} ttt_lr:{args.ttt_lr} " - f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" - ) - log0(f"final_ttt_exact val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f}") - if distributed: dist.destroy_process_group() From c0ce492f9efc3f2960efd5b5564cd6b1433f5216 Mon Sep 17 00:00:00 2001 From: Ivan Verbovoy Date: Thu, 26 Mar 2026 10:25:06 +0200 Subject: [PATCH 09/11] =?UTF-8?q?Add=20submission:=20Progressive=20Depth?= =?UTF-8?q?=20Training=20=E2=80=94=20val=5Fbpb=201.1980?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Progressive depth scheduling (2→3→4 repeats) unique to shared-weight recurrence. 5861 steps in 600s vs ~4300 at constant depth (+36%). Fix DDP race condition in phase switching via all_reduce sync. 
--- .../2026-03-26_ProgressiveDepth/README.md | 67 + .../submission.json | 16 + .../2026-03-26_ProgressiveDepth/train.log | 113 ++ .../2026-03-26_ProgressiveDepth/train_gpt.py | 1386 +++++++++++++++++ train_gpt.py | 8 +- 5 files changed, 1589 insertions(+), 1 deletion(-) create mode 100644 records/track_10min_16mb/2026-03-26_ProgressiveDepth/README.md create mode 100644 records/track_10min_16mb/2026-03-26_ProgressiveDepth/submission.json create mode 100644 records/track_10min_16mb/2026-03-26_ProgressiveDepth/train.log create mode 100644 records/track_10min_16mb/2026-03-26_ProgressiveDepth/train_gpt.py diff --git a/records/track_10min_16mb/2026-03-26_ProgressiveDepth/README.md b/records/track_10min_16mb/2026-03-26_ProgressiveDepth/README.md new file mode 100644 index 000000000..b0201bf74 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_ProgressiveDepth/README.md @@ -0,0 +1,67 @@ +## Progressive Depth Training via Shared-Weight Recurrence + +val_bpb = **1.1980** (sliding window, stride=256, int8+zstd22 roundtrip) +val_bpb = 1.2315 (standard int8+zstd22 roundtrip) + +Progressive Depth is a training-time advantage unique to shared-weight recurrence — flat architectures cannot dynamically adjust their depth during training. + +Because the same 3 blocks are reused at every depth, we can start training with 2 repeats (fast, cheap steps), then progressively increase to 3 and 4 repeats as training progresses. The model learns coarse representations quickly at shallow depth, then refines them at full depth. This is structurally impossible with flat architectures where each layer has unique parameters — you cannot add or remove layers mid-training without changing the parameter space. + +### Progressive Depth Schedule + +| Phase | Time | Repeats | Eff. 
depth | ms/step | Steps | val_bpb at end | +|-------|------|---------|------------|---------|-------|----------------| +| 1 | 0–40% | 2 | 6 | ~75 | ~3200 | 1.319 | +| 2 | 40–65% | 3 | 9 | ~86 | ~1200 | 1.298 | +| 3 | 65–100% | 4 | 12 | ~96 | ~1800 | 1.229 | + +**Total: 5861 steps** in 600s vs ~4300 steps at constant depth 4 (+36% more gradient updates). + +SWA (Stochastic Weight Averaging) collects checkpoints only during Phase 3 at full depth to avoid mixing representations from different depths. 18 checkpoints averaged. + +### Ablation Trajectory + +Each change isolated and measured on 8xH100 (sliding window eval): + +| Change | val_bpb | Delta | +|--------|---------|-------| +| OpenAI Naive Baseline (9×512, unique layers) | 1.2244 | — | +| Depth Recurrence 3×4 + Cross-Repeat Skip (PR [#148](https://github.com/openai/parameter-golf/pull/148)) | 1.2213 | -0.003 | +| + XSA (Exclusive Self-Attention, last 4 layers) | 1.2110 | -0.010 | +| + LeakyReLU(0.5)² MLP | 1.2069 | -0.004 | +| + Progressive Depth (2→3→4 schedule) | 1.1980 | -0.009 | +| **Total** | **1.1980** | **-0.026** | + +### Cross-Repeat Skip (Novel, PR [#148](https://github.com/openai/parameter-golf/pull/148)) + +Standard depth recurrence is stateless — each repeat starts fresh with no memory of previous passes. Cross-Repeat Skip turns this into stateful recurrence: each block receives a weighted residual of its own output from the previous repeat. Per-block, per-repeat learned scales (~7.5K params). This gives the model a direct gradient path across repeats without the overhead of unique parameters. 
+ +### Architecture + +- 3 shared blocks × 4 repeats = 12 effective layers +- dim=832, 8 heads, 4 KV heads (GQA), MLP 2×, tied embeddings +- **XSA**: Subtracts self-value projection from attention output on last 4 effective layers (reduces attention collapse in deep recurrence) +- **LeakyReLU(0.5)²**: Replaces ReLU² — preserves gradient flow on negative activations through 4 recurrence passes +- 2 Value Embedding tables with per-layer learned scales +- Loop Embedding (depth-wise positional encoding) +- Logit softcap=30, RoPE, RMSNorm +- GPTQ-lite int8 quantization (per-row clip percentile search) + zstd-22 compression +- 17.14M params, 15.88MB artifact + +### Training + +Muon optimizer (momentum=0.95, 5 Newton-Schulz steps, WD=0.04) for matrix params, Adam for scalars/embeddings. + +MATRIX_LR=0.012, SCALAR_LR=0.012, TIED_EMBED_LR=0.015, GRAD_CLIP_NORM=0.3, WARMDOWN_ITERS=3000. + +Phase switching synchronized across DDP ranks via `all_reduce` (max elapsed time) to prevent race conditions during `torch.compile` recompilation. + +### Command + +``` +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +### Results + +5861 steps, 600s on 8xH100. Roundtrip val_bpb 1.2315, sliding window 1.1980. Peak memory 25.5 GB/GPU. diff --git a/records/track_10min_16mb/2026-03-26_ProgressiveDepth/submission.json b/records/track_10min_16mb/2026-03-26_ProgressiveDepth/submission.json new file mode 100644 index 000000000..104d080fa --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_ProgressiveDepth/submission.json @@ -0,0 +1,16 @@ +{ + "author": "Ivan Verbovoy", + "github_id": "iverbovoy", + "name": "Progressive Depth + Depth Recurrence + XSA + LeakyReLU\u00b2", + "blurb": "3 unique blocks with progressive depth scheduling (2\u21923\u21924 repeats), XSA on last 4 layers, LeakyReLU(0.5)\u00b2 MLP, SWA over 18 checkpoints, GPTQ-lite int8+zstd22 compression. 
5861 steps in 600s on 8xH100.", + "date": "2026-03-26T07:40:00Z", + "val_loss": 2.02277954, + "val_bpb": 1.19800347, + "roundtrip_val_loss": 2.07939783, + "roundtrip_val_bpb": 1.23153652, + "step_stop": 5861, + "wallclock_seconds": 600.140, + "bytes_total": 15875591, + "bytes_model_int8_zstd22": 15815371, + "bytes_code": 60220 +} diff --git a/records/track_10min_16mb/2026-03-26_ProgressiveDepth/train.log b/records/track_10min_16mb/2026-03-26_ProgressiveDepth/train.log new file mode 100644 index 000000000..ea7b7f77f --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_ProgressiveDepth/train.log @@ -0,0 +1,113 @@ +W0326 08:03:49.332000 20006 torch/distributed/run.py:793] +W0326 08:03:49.332000 20006 torch/distributed/run.py:793] ***************************************** +W0326 08:03:49.332000 20006 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0326 08:03:49.332000 20006 torch/distributed/run.py:793] ***************************************** +logs/b9c03e97-bd7a-4a2b-bcb0-89a9dbb80dd2.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17140056 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.015 head_lr:0.0 matrix_lr:0.012 scalar_lr:0.012 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +prog_depth: schedule=[(0.4, 2), (0.65, 3), (1.0, 4)] starting_repeats=2 +step:0/20000 val_loss:6.9300 val_bpb:4.1043 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9769 train_time:6848ms step_avg:6847.62ms +step:2/20000 train_loss:7.6698 train_time:6867ms step_avg:3433.58ms +step:3/20000 train_loss:7.5199 train_time:6936ms step_avg:2312.01ms +step:4/20000 train_loss:7.1738 train_time:7006ms step_avg:1751.49ms +step:5/20000 train_loss:6.6264 train_time:7077ms step_avg:1415.41ms +step:6/20000 train_loss:6.3184 train_time:7150ms step_avg:1191.64ms +step:7/20000 train_loss:5.8051 train_time:7223ms step_avg:1031.86ms +step:8/20000 train_loss:5.6215 train_time:7296ms step_avg:911.96ms +step:9/20000 train_loss:5.4769 train_time:7365ms step_avg:818.36ms +step:10/20000 train_loss:5.3601 train_time:7438ms step_avg:743.83ms +step:200/20000 
train_loss:2.7618 train_time:21072ms step_avg:105.36ms +step:400/20000 train_loss:2.3134 train_time:35458ms step_avg:88.65ms +step:600/20000 train_loss:2.5245 train_time:49866ms step_avg:83.11ms +step:800/20000 train_loss:2.2887 train_time:64314ms step_avg:80.39ms +step:1000/20000 train_loss:2.3830 train_time:78796ms step_avg:78.80ms +step:1000/20000 val_loss:2.3413 val_bpb:1.3867 train_time:78837ms step_avg:78.84ms +step:1200/20000 train_loss:2.3953 train_time:93285ms step_avg:77.74ms +step:1400/20000 train_loss:2.4453 train_time:107763ms step_avg:76.97ms +step:1600/20000 train_loss:2.1185 train_time:122260ms step_avg:76.41ms +step:1800/20000 train_loss:2.2231 train_time:136732ms step_avg:75.96ms +step:2000/20000 train_loss:2.2778 train_time:151200ms step_avg:75.60ms +step:2000/20000 val_loss:2.2603 val_bpb:1.3387 train_time:151241ms step_avg:75.62ms +step:2200/20000 train_loss:2.1014 train_time:165670ms step_avg:75.30ms +step:2400/20000 train_loss:2.2244 train_time:180134ms step_avg:75.06ms +step:2600/20000 train_loss:2.4395 train_time:194595ms step_avg:74.84ms +step:2800/20000 train_loss:2.2726 train_time:209041ms step_avg:74.66ms +step:3000/20000 train_loss:2.2600 train_time:223486ms step_avg:74.50ms +step:3000/20000 val_loss:2.2269 val_bpb:1.3189 train_time:223527ms step_avg:74.51ms +step:3200/20000 train_loss:2.2188 train_time:237917ms step_avg:74.35ms +prog_depth: switched to 3 repeats at step:3229 frac:0.40 +step:3400/20000 train_loss:2.1932 train_time:279477ms step_avg:82.20ms +step:3600/20000 train_loss:2.1460 train_time:300629ms step_avg:83.51ms +step:3800/20000 train_loss:2.2472 train_time:321879ms step_avg:84.70ms +step:4000/20000 train_loss:2.1847 train_time:343064ms step_avg:85.77ms +step:4000/20000 val_loss:2.1917 val_bpb:1.2981 train_time:343131ms step_avg:85.78ms +step:4200/20000 train_loss:2.1871 train_time:364232ms step_avg:86.72ms +step:4400/20000 train_loss:2.1208 train_time:385376ms step_avg:87.59ms +prog_depth: switched to 4 repeats at 
step:4444 frac:0.65 +step:4600/20000 train_loss:1.9634 train_time:423021ms step_avg:91.96ms +step:4800/20000 train_loss:2.2479 train_time:450939ms step_avg:93.95ms +step:5000/20000 train_loss:1.9975 train_time:478895ms step_avg:95.78ms +step:5000/20000 val_loss:2.1211 val_bpb:1.2562 train_time:478979ms step_avg:95.80ms +swa:start step:5050 +step:5200/20000 train_loss:2.1314 train_time:507384ms step_avg:97.57ms +step:5400/20000 train_loss:2.1322 train_time:535401ms step_avg:99.15ms +step:5600/20000 train_loss:2.1209 train_time:563472ms step_avg:100.62ms +step:5800/20000 train_loss:2.0748 train_time:591529ms step_avg:101.99ms +step:5861/20000 val_loss:2.0758 val_bpb:1.2294 train_time:600140ms step_avg:102.40ms +stopping_early: wallclock_cap train_time:600140ms step:5861/20000 +peak memory allocated: 25539 MiB reserved: 26118 MiB +swa: averaging 18 checkpoints +Serialized model: 63386762 bytes +Code size: 60220 bytes +Total submission size: 63446982 bytes +Serialized model int8+zstd22: 15815371 bytes (payload:17243616 raw_torch:17260843 payload_ratio:3.68x) +Total submission size int8+zstd22: 15875591 bytes +/workspace/parameter-golf/train_gpt.py:1334: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. 
Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1334: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1334: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. 
We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1334: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1334: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. 
Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1334: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1334: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. 
This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1334: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. 
+ quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +final_roundtrip val_loss:2.0794 val_bpb:1.2315 eval_time:14103ms +final_roundtrip_exact val_loss:2.07939783 val_bpb:1.23153652 +final_sliding_window val_loss:2.0228 val_bpb:1.1980 window:1024 stride:256 eval_time:66815ms +final_sliding_window_exact val_loss:2.02277954 val_bpb:1.19800347 diff --git a/records/track_10min_16mb/2026-03-26_ProgressiveDepth/train_gpt.py b/records/track_10min_16mb/2026-03-26_ProgressiveDepth/train_gpt.py new file mode 100644 index 000000000..e45fdfc2e --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_ProgressiveDepth/train_gpt.py @@ -0,0 +1,1386 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zstandard as zstd +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# HYPERPARAMETERS + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. 
+ data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + + # Progressive Depth: train with fewer repeats early (faster), more repeats later (deeper). + # Schedule format: "frac1:rep1,frac2:rep2,..." e.g. "0.4:2,0.65:3,1.0:4" + prog_depth_schedule = os.environ.get("PROG_DEPTH", "0.4:2,0.65:3,1.0:4") + + # XSA (Exclusive Self-Attention) on last N effective layers. + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + + # SWA (Stochastic Weight Averaging) during warmdown. + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + # Sliding window eval. 
+ eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 1024)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 256)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + # Model shape. + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 3)) + num_repeats = int(os.environ.get("NUM_REPEATS", 4)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 832)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + num_value_embeds = int(os.environ.get("NUM_VALUE_EMBEDS", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.015)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.012)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.012)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + +# MUON OPTIMIZER +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. 
+ # Muon uses this to normalize matrix-shaped gradients before applying them. + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token-id lookup tables used for bits-per-byte (BPB) accounting.

    Returns three tensors, each indexed by token id:
      - base_bytes (int16): UTF-8 byte length of the decoded piece, excluding a
        leading "▁" space marker; byte-fallback tokens count as exactly 1 byte.
      - has_leading_space (bool): True when the piece begins with the "▁" marker.
      - is_boundary_token (bool): True for control/unknown/unused ids; during
        eval a leading space is only charged as a real byte when the *previous*
        token is not a boundary token.
    Tables are sized max(sp.vocab_size(), vocab_size) so any model-side id
    indexes safely (extra ids behave as 0-byte boundary tokens).
    """
    sp_vocab_size = int(sp.vocab_size())
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            # Byte-fallback pieces always decode to exactly one byte.
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("▁"):
            # The space marker is counted separately (conditioned on the previous token).
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Concatenate all shards matching *pattern* and trim to a whole number of
    sequences, keeping one extra trailing token so (x, y) can be built by shifting.

    Raises FileNotFoundError when no shard matches, ValueError when the split is
    shorter than one sequence.
    """
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Non-overlapping validation pass over disjoint per-rank sequence spans.

    Returns (val_loss, val_bpb); both are all-reduced across ranks when
    distributed is initialized.
    """
    # Validation computes two metrics:
    # - val_loss: token cross-entropy (natural log)
    # - val_bpb: tokenizer-agnostic compression metric used by the challenge
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    # Contiguous, disjoint per-rank ranges that exactly cover [0, total_seqs).
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    # fp64 accumulators so many small batch sums don't lose precision.
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)

    model.eval()
    with torch.inference_mode():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * args.train_seq_len
            # "+1" keeps the extra token needed to shift targets.
            raw_end = batch_seq_end * args.train_seq_len + 1
            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, args.train_seq_len)
            y = local[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            # model() returns a mean over tokens; re-weight to a sum for exact averaging.
            batch_token_count = float(y.numel())
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            # Byte accounting: each target's piece bytes, plus one byte for its
            # leading space unless the preceding token is a boundary token.
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)

    val_loss = val_loss_sum / val_token_count
    # bpb = (bits per token) * (tokens per byte).
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)


def eval_val_sliding(
    args: Hyperparameters,
    base_model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Sliding window eval with batching. Windows of eval_seq_len advance by eval_stride.
    Only the last stride tokens per window are scored (first window scores all)."""
    seq_len = args.eval_seq_len
    stride = args.eval_stride
    batch_seqs = args.eval_batch_seqs
    total_tokens = val_tokens.numel() - 1

    window_starts = [ws for ws in range(0, total_tokens, stride)
                     if min(ws + seq_len, total_tokens) - ws >= 1]
    # Contiguous, disjoint per-rank shares of the window list.
    total_windows = len(window_starts)
    my_s = (total_windows * rank) // world_size
    my_e = (total_windows * (rank + 1)) // world_size
    my_windows = window_starts[my_s:my_e]

    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)

    base_model.eval()
    with torch.inference_mode():
        for bi in range(0, len(my_windows), batch_seqs):
            batch_ws = my_windows[bi : bi + batch_seqs]
            bsz = len(batch_ws)
            # Zero-padded batch; pad positions are never scored because only
            # nll[i, s:wlen] below enters the sums.
            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            wlens = []

            for i, ws in enumerate(batch_ws):
                end = min(ws + seq_len, total_tokens)
                wlen = end - ws
                wlens.append(wlen)
                chunk = val_tokens[ws : end + 1].to(dtype=torch.int64, device=device)
                x_batch[i, :wlen] = chunk[:-1]
                y_batch[i, :wlen] = chunk[1:]

            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                logits = base_model.forward_logits(x_batch)

            nll = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)).float(),
                y_batch.reshape(-1),
                reduction="none",
            ).reshape(bsz, seq_len)

            for i, ws in enumerate(batch_ws):
                wlen = wlens[i]
                # NOTE(review): for truncated tail windows (wlen < seq_len) this
                # start index re-scores tokens already scored by the previous
                # window, double-counting them in both loss and byte sums;
                # s = min(seq_len - stride, wlen) would score each token exactly
                # once — confirm which behavior is intended.
                s = 0 if ws == 0 else max(wlen - stride, 0)
                scored_nll = nll[i, s:wlen].to(torch.float64)
                val_loss_sum += scored_nll.sum()
                val_token_count += float(wlen - s)
                # Same byte accounting as eval_val, restricted to scored positions.
                prev_ids = x_batch[i, s:wlen]
                tgt_ids = y_batch[i, s:wlen]
                token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
                token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
                val_byte_count += token_bytes.to(torch.float64).sum()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)

    val_loss = val_loss_sum / val_token_count
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    base_model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)



# POST-TRAINING QUANTIZATION
#
# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
# Instead, we get approximately the same model (with a small hit) by quantizing and zstd compressing.

# Parameter-name substrings treated as "control" tensors (kept fp32, Adam-optimized).
CONTROL_TENSOR_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "CONTROL_TENSOR_NAME_PATTERNS",
        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights",
    ).split(",")
    if pattern
)
# Name substrings exported as fp32 instead of int8 (defaults to the control tensors).
INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
    pattern
    for pattern in os.environ.get(
        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
    ).split(",")
    if pattern
)
# Float tensors at or below this numel are stored directly (as fp16) rather than quantized.
INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
INT8_PER_ROW_SCALE_DTYPE = torch.float16
INT8_CLIP_PERCENTILE = 99.99984
INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0
# Quantization levels: 127 = int8, 31 = int6, 16 = int5. Per-tensor override via MLP_QUANT_LEVELS.
+QUANT_LEVELS = int(os.environ.get("QUANT_LEVELS", 127)) +MLP_QUANT_LEVELS = int(os.environ.get("MLP_QUANT_LEVELS", 0)) # 0 = same as QUANT_LEVELS +MLP_TENSOR_PATTERNS = ("mlp.fc.", "mlp.proj.", "fc.weight", "mlp.proj.weight") + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 0.999999] + +def quantize_float_tensor(t: Tensor, ql: int = 0) -> tuple[Tensor, Tensor]: + if ql <= 0: + ql = QUANT_LEVELS + t32 = t.float() + if t32.ndim == 2: + # GPTQ-lite: try multiple clip percentiles per row, pick best MSE + abs_t = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in GPTQ_LITE_PERCENTILES: + clip_abs = ( + torch.quantile(abs_t, pct, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + s = (clip_abs / ql).clamp_min(1e-12) + q = torch.clamp(torch.round(clipped / s[:, None]), -ql, ql) + # Reconstruction error per row + recon = q * s[:, None] + mse = (t32 - recon).square().sum(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = s + else: + better = mse < best_mse + best_mse = torch.where(better, mse, best_mse) + best_q = torch.where(better[:, None], q, best_q) + best_scale = torch.where(better, s, best_scale) + return best_q.to(torch.int8).contiguous(), best_scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale 
= torch.tensor(clip_abs / ql if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -ql, ql).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. 
+ if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + mlp_ql = MLP_QUANT_LEVELS if MLP_QUANT_LEVELS > 0 else QUANT_LEVELS + ql = mlp_ql if any(p in name for p in MLP_TENSOR_PATTERNS) else QUANT_LEVELS + q, s = quantize_float_tensor(t, ql=ql) + meta: dict[str, object] = {} + if s.ndim > 0: + meta["scheme"] = "per_row" + meta["axis"] = 0 + if ql != QUANT_LEVELS: + meta["ql"] = ql + if meta: + qmeta[name] = meta + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. 
+ out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# DATA LOADING + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# TRANSFORMER MODULES + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def _xsa(self, y: Tensor, v: Tensor) -> Tensor: + """Subtract self-value projection from attention output (GQA-aware).""" + B, T, H, D = y.shape + Hkv = v.size(2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(3) # [B, T, Hkv, 1, D] + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, use_xsa: bool = False) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + # XSA: remove self-value bias from attention output + if use_xsa: + y = y.transpose(1, 2).contiguous() # [B, T, H, D] + v_for_xsa = v.transpose(1, 2) # [B, T, Hkv, D] + y = self._xsa(y, v_for_xsa) + y = y.reshape(bsz, seqlen, dim) + else: + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # leaky_relu(0.5)^2 MLP — better gradient flow than relu^2 for deep/recurrent models + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = 
F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, use_xsa: bool = False) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x), use_xsa=use_xsa) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + num_repeats: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + num_value_embeds: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + xsa_last_n: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.num_repeats = num_repeats + self.xsa_last_n = xsa_last_n + effective_depth = num_layers * num_repeats + self.tok_emb = nn.Embedding(vocab_size, model_dim) + # Value embeddings: extra embedding tables mixed into each effective layer + self.num_value_embeds = num_value_embeds + if num_value_embeds > 0: + 
self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(num_value_embeds)]) + self.value_scales = nn.Parameter(torch.zeros(effective_depth, num_value_embeds, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + # Loop embedding: tells the model which effective layer it's at + self.loop_embed = nn.Parameter(torch.zeros(effective_depth, model_dim, dtype=torch.float32)) + # Cross-repeat skip: each block receives its own output from previous repeat + self.cross_repeat_scales = nn.Parameter(torch.zeros(num_layers, num_repeats - 1, model_dim, dtype=torch.float32)) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + # Pre-compute value embeddings once + ve_list: list[Tensor] = [] + if self.num_value_embeds > 0: + for ve in self.value_embeds: + ve_list.append(ve(input_ids)) # (bsz, seq, dim) + + cur_repeats = self.cur_repeats if hasattr(self, "cur_repeats") else self.num_repeats + cur_depth = len(self.blocks) * cur_repeats + xsa_start = max(0, cur_depth - self.xsa_last_n) if self.xsa_last_n > 0 else cur_depth + + num_blocks = len(self.blocks) + prev_block_outputs: list[Tensor | None] = [None] * num_blocks + layer_idx = 0 + for repeat in range(cur_repeats): + for block_idx, block in enumerate(self.blocks): + x = x + 
self.loop_embed[layer_idx].to(dtype=x.dtype) + # Value embeddings: add weighted extra embeddings at each layer + if layer_idx < self.value_scales.size(0): + for ve_idx, ve_out in enumerate(ve_list): + vs = self.value_scales[layer_idx, ve_idx].to(dtype=x.dtype) + x = x + vs[None, None, :] * ve_out + # Cross-repeat skip: mix in this block's output from previous repeat + if repeat > 0 and prev_block_outputs[block_idx] is not None: + rep_idx = min(repeat - 1, self.cross_repeat_scales.size(1) - 1) + scale = self.cross_repeat_scales[block_idx, rep_idx].to(dtype=x.dtype) + x = x + scale[None, None, :] * prev_block_outputs[block_idx] + x = block(x, x0, use_xsa=(layer_idx >= xsa_start)) + prev_block_outputs[block_idx] = x.detach() if not self.training else x + layer_idx += 1 + + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return logits + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + logits = self.forward_logits(input_ids) + logits = logits.reshape(-1, logits.size(-1)) + targets = target_ids.reshape(-1) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + +# TRAINING + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + # DISTRIBUTED + CUDA SETUP + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + grad_accum_steps = max(1, 8 // world_size) + 
grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # TOKENIZER + VALIDATION METRIC SETUP + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = 
len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + num_repeats=args.num_repeats, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + num_value_embeds=args.num_value_embeds, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + xsa_last_n=args.xsa_last_n, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + 
scalar_params.append(base_model.loop_embed) + scalar_params.append(base_model.cross_repeat_scales) + if base_model.num_value_embeds > 0: + scalar_params.append(base_model.value_scales) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params = [base_model.tok_emb.weight] + if base_model.num_value_embeds > 0: + embed_params.extend(ve.weight for ve in base_model.value_embeds) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} 
train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + # Progressive depth schedule: parse "frac:repeats,..." 
and sort + prog_phases: list[tuple[float, int]] = [] + for entry in args.prog_depth_schedule.split(","): + frac_s, rep_s = entry.strip().split(":") + prog_phases.append((float(frac_s), int(rep_s))) + prog_phases.sort() + current_phase_repeats = prog_phases[0][1] if prog_phases else args.num_repeats + base_model.cur_repeats = current_phase_repeats + # Recompile with initial phase depth + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + log0(f"prog_depth: schedule={prog_phases} starting_repeats={current_phase_repeats}") + + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + # Progressive depth: check if we need to switch phase + # Use synchronized elapsed time (max across ranks) to avoid race conditions + if max_wallclock_ms is not None and prog_phases: + if distributed: + elapsed_tensor = 
torch.tensor(elapsed_ms, device=device) + dist.all_reduce(elapsed_tensor, op=dist.ReduceOp.MAX) + frac = elapsed_tensor.item() / max_wallclock_ms + else: + frac = elapsed_ms / max_wallclock_ms + new_repeats = prog_phases[-1][1] # default to last + for phase_frac, phase_rep in prog_phases: + if frac < phase_frac: + new_repeats = phase_rep + break + if new_repeats != current_phase_repeats: + current_phase_repeats = new_repeats + base_model.cur_repeats = new_repeats + torch._dynamo.reset() + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + log0(f"prog_depth: switched to {new_repeats} repeats at step:{step} frac:{frac:.2f}") + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown (only at 
full depth to avoid mixing phases) + at_full_depth = current_phase_repeats == args.num_repeats + if args.swa_enabled and at_full_depth and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().float() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu().float() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Restore full depth for eval/export + base_model.cur_repeats = args.num_repeats + torch._dynamo.reset() + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None: + # Include final weights (may not have landed on swa_every boundary) + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu().float() + 
swa_count += 1 + log0(f"swa: averaging {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed quantized+zstd artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + zstd_level = int(os.environ.get("ZSTD_LEVEL", 22)) + cctx = zstd.ZstdCompressor(level=zstd_level) + quant_blob = cctx.compress(quant_raw) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zstd{zstd_level}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zstd{zstd_level}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + dctx = zstd.ZstdDecompressor() + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") + 
base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval + if args.eval_stride > 0: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"window:{args.eval_seq_len} stride:{args.eval_stride} " + f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/train_gpt.py b/train_gpt.py index 0127f8fcc..e45fdfc2e 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -1187,8 +1187,14 @@ def lr_mul(step: int, elapsed_ms: float) -> float: elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) # Progressive depth: check if we need to switch phase + # Use synchronized elapsed time (max across ranks) to avoid race conditions if max_wallclock_ms is not None and prog_phases: - frac = elapsed_ms / max_wallclock_ms + if distributed: + elapsed_tensor = torch.tensor(elapsed_ms, device=device) + dist.all_reduce(elapsed_tensor, op=dist.ReduceOp.MAX) + 
frac = elapsed_tensor.item() / max_wallclock_ms + else: + frac = elapsed_ms / max_wallclock_ms new_repeats = prog_phases[-1][1] # default to last for phase_frac, phase_rep in prog_phases: if frac < phase_frac: From 57516979342f7c9ab66c503819ececf40f77746e Mon Sep 17 00:00:00 2001 From: Ivan Verbovoy Date: Thu, 26 Mar 2026 14:59:16 +0200 Subject: [PATCH 10/11] =?UTF-8?q?Tune=20hyperparameters:=20LR=200.018,=20w?= =?UTF-8?q?armdown=202000=20=E2=80=94=20val=5Fbpb=201.1960?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Systematic tuning on 8xH100 (6 runs): - WARMDOWN_ITERS 3000→2000: full LR at phase 4 entry (-0.0009) - MATRIX/SCALAR_LR 0.012→0.018: higher LR for progressive depth (-0.0011) - Combined: val_bpb 1.1960 sliding (-0.0020 from 1.1980) Tested and rejected: schedule changes (3-phase optimal), SWA_EVERY=25, 5 repeats, GRAD_CLIP=0.5, VRL, per-repeat LoRA (artifact >16MB). --- train_gpt.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/train_gpt.py b/train_gpt.py index e45fdfc2e..7410c0580 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -45,7 +45,7 @@ class Hyperparameters: # Training length. iterations = int(os.environ.get("ITERATIONS", 20000)) - warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 2000)) warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) @@ -86,10 +86,10 @@ class Hyperparameters: # Optimizer hyperparameters. 
embed_lr = float(os.environ.get("EMBED_LR", 0.6)) head_lr = float(os.environ.get("HEAD_LR", 0.008)) - tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.015)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.021)) tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) - matrix_lr = float(os.environ.get("MATRIX_LR", 0.012)) - scalar_lr = float(os.environ.get("SCALAR_LR", 0.012)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.018)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.018)) muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) From 2f25412ebb4fc7163d6d13b1b8aa04eb8a1ee163 Mon Sep 17 00:00:00 2001 From: Ivan Verbovoy Date: Thu, 26 Mar 2026 16:51:02 +0200 Subject: [PATCH 11/11] =?UTF-8?q?Add=20Hedge=20Mixer=20+=20tuned=20hyperpa?= =?UTF-8?q?rams=20=E2=80=94=20val=5Fbpb=201.1454?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 5-expert online ensemble (neural + unigram + bigram + trigram + entropy) via Hedge algorithm at eval time. -0.051 bpb over sliding window. Tuned defaults: LR=0.018, WARMDOWN=2000 (-0.002 from previous). Total improvement: 1.2244 → 1.1454 (-0.079 from baseline). 
--- .../README.md | 62 + .../submission.json | 19 + .../train.log | 114 ++ .../train_gpt.py | 1498 +++++++++++++++++ train_gpt.py | 136 +- 5 files changed, 1817 insertions(+), 12 deletions(-) create mode 100644 records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/README.md create mode 100644 records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/submission.json create mode 100644 records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/train.log create mode 100644 records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/train_gpt.py diff --git a/records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/README.md b/records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/README.md new file mode 100644 index 000000000..3995ce744 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/README.md @@ -0,0 +1,62 @@ +## Progressive Depth + Hedge Mixer + +val_bpb = **1.1454** (Hedge Mixer eval, int8+zstd22 roundtrip model) +val_bpb = 1.1966 (sliding window only) +val_bpb = 1.2304 (standard roundtrip) + +### Hedge Mixer: 5-Expert Online Ensemble + +Eval-time improvement via online mixture of 5 experts using the Hedge algorithm (multiplicative weights). No training data access — n-gram tables built from already-scored tokens only. + +| Expert | Source | Role | +|--------|--------|------| +| Neural | Model softmax output | Primary prediction | +| Unigram | Token frequency from scored data | Frequency prior | +| Bigram | P(next\|prev) from scored data | Local context | +| Trigram | Hash table (64K buckets) from scored data | Extended context | +| Entropy | Model confidence weighting | Calibration | + +Weights initialized with neural bias (log_weight=2.0), updated via `log_w -= eta * expert_mean_loss` after each batch. The mixer is cold-started (uses pure neural output until 10K tokens scored), then progressively improves as n-gram statistics accumulate. 
+
+**Impact: -0.051 bpb** over sliding window eval (1.1966 → 1.1454). This is larger than all architectural improvements combined.
+
+Eval time: 579s on 8xH100 (sequential processing required for n-gram table consistency).
+
+### Architecture (unchanged from PR #835)
+
+3 shared transformer blocks with depth recurrence; progressive depth scheduling is unique to shared-weight recurrence.
+
+- **Progressive Depth Training**: Phase 1 (0-40%): 2 repeats ~75ms/step. Phase 2 (40-65%): 3 repeats ~86ms/step. Phase 3 (65-100%): 4 repeats ~96ms/step. 5673 steps in 600s.
+- **Cross-Repeat Skip** (#148, Novel): Stateful recurrence — each block receives weighted residual from previous repeat.
+- **XSA**: Exclusive Self-Attention on last 4 effective layers.
+- **LeakyReLU(0.5)²**: Better gradient flow through 4-repeat recurrence.
+- dim=832, 8 heads, 4 KV heads (GQA), MLP 2×, tied embeddings, SWA (18 checkpoints).
+- 17.14M params, 15.88MB artifact (int8+zstd22).
+
+### Tuned Hyperparameters
+
+MATRIX_LR=0.018, SCALAR_LR=0.018, TIED_EMBED_LR=0.021, WARMDOWN_ITERS=2000.
+
+Higher LR compensates for progressive depth's shallow early phases. Shorter warmdown gives full LR at full-depth entry.
+
+### Ablation Trajectory
+
+| Change | val_bpb | Delta |
+|--------|---------|-------|
+| OpenAI Naive Baseline | 1.2244 | — |
+| Depth Recurrence 3×4 + Cross-Repeat Skip (#148) | 1.2213 | -0.003 |
+| + XSA + LeakyReLU² (#784) | 1.2069 | -0.014 |
+| + Progressive Depth (#835) | 1.1980 | -0.009 |
+| + LR/Warmdown tuning | 1.1960 | -0.002 |
+| + Hedge Mixer (eval) | 1.1454 | -0.051 |
+| **Total** | **1.1454** | **-0.079** |
+
+### Command
+
+```
+torchrun --standalone --nproc_per_node=8 train_gpt.py
+```
+
+### Credits
+
+Hedge Mixer algorithm adapted from PR #688 (@RoyiRa) and PR #745 (@stukenov).
diff --git a/records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/submission.json b/records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/submission.json new file mode 100644 index 000000000..9661bb9c4 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/submission.json @@ -0,0 +1,19 @@ +{ + "author": "Ivan Verbovoy", + "github_id": "iverbovoy", + "name": "Progressive Depth + Hedge Mixer (5-expert online ensemble)", + "blurb": "3 unique blocks with progressive depth scheduling (2\u21923\u21924 repeats), XSA, LeakyReLU\u00b2, Cross-Repeat Skip, SWA, int8+zstd22. Eval: 5-expert Hedge Mixer (neural + unigram + bigram + trigram + entropy) with online multiplicative weight updates. 5673 steps in 600s train, 579s eval on 8xH100.", + "date": "2026-03-26T15:00:00Z", + "val_loss": 1.93403169, + "val_bpb": 1.14544202, + "roundtrip_val_loss": 2.07744208, + "roundtrip_val_bpb": 1.23037822, + "sliding_val_loss": 2.02046173, + "sliding_val_bpb": 1.19663074, + "step_stop": 5673, + "wallclock_seconds": 600.218, + "eval_seconds": 579.109, + "bytes_total": 15884272, + "bytes_model_int8_zstd22": 15818418, + "bytes_code": 65854 +} diff --git a/records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/train.log b/records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/train.log new file mode 100644 index 000000000..221bf19a8 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/train.log @@ -0,0 +1,114 @@ +W0326 14:09:06.471000 2358 torch/distributed/run.py:793] +W0326 14:09:06.471000 2358 torch/distributed/run.py:793] ***************************************** +W0326 14:09:06.471000 2358 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0326 14:09:06.471000 2358 torch/distributed/run.py:793] ***************************************** +logs/07b6a996-fe2d-47a4-a5a5-24bf61fec8f0.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:17140056 +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.021 head_lr:0.0 matrix_lr:0.018 scalar_lr:0.018 +train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +prog_depth: schedule=[(0.4, 2), (0.65, 3), (1.0, 4)] starting_repeats=2 +step:0/20000 val_loss:6.9300 val_bpb:4.1043 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9769 train_time:20810ms step_avg:20809.51ms +step:2/20000 train_loss:9.6250 train_time:20829ms step_avg:10414.40ms +step:3/20000 train_loss:9.4925 train_time:20897ms step_avg:6965.53ms +step:4/20000 train_loss:9.1975 train_time:20968ms step_avg:5241.90ms +step:5/20000 train_loss:8.6451 train_time:21039ms step_avg:4207.76ms +step:6/20000 train_loss:8.1740 train_time:21110ms step_avg:3518.30ms +step:7/20000 train_loss:7.2979 train_time:21182ms step_avg:3026.05ms +step:8/20000 train_loss:6.6939 train_time:21255ms step_avg:2656.83ms +step:9/20000 train_loss:6.1779 train_time:21331ms step_avg:2370.13ms +step:10/20000 train_loss:5.8322 train_time:21400ms step_avg:2140.02ms 
+step:200/20000 train_loss:2.7613 train_time:35036ms step_avg:175.18ms +step:400/20000 train_loss:2.3100 train_time:49402ms step_avg:123.50ms +step:600/20000 train_loss:2.5366 train_time:63819ms step_avg:106.36ms +step:800/20000 train_loss:2.2979 train_time:78288ms step_avg:97.86ms +step:1000/20000 train_loss:2.3821 train_time:92765ms step_avg:92.76ms +step:1000/20000 val_loss:2.3450 val_bpb:1.3888 train_time:92807ms step_avg:92.81ms +step:1200/20000 train_loss:2.4011 train_time:107254ms step_avg:89.38ms +step:1400/20000 train_loss:2.4523 train_time:121727ms step_avg:86.95ms +step:1600/20000 train_loss:2.1223 train_time:136207ms step_avg:85.13ms +step:1800/20000 train_loss:2.2266 train_time:150673ms step_avg:83.71ms +step:2000/20000 train_loss:2.2854 train_time:165142ms step_avg:82.57ms +step:2000/20000 val_loss:2.2671 val_bpb:1.3427 train_time:165184ms step_avg:82.59ms +step:2200/20000 train_loss:2.1072 train_time:179607ms step_avg:81.64ms +step:2400/20000 train_loss:2.2328 train_time:194078ms step_avg:80.87ms +step:2600/20000 train_loss:2.4458 train_time:208531ms step_avg:80.20ms +step:2800/20000 train_loss:2.2812 train_time:222979ms step_avg:79.64ms +step:3000/20000 train_loss:2.2707 train_time:237400ms step_avg:79.13ms +step:3000/20000 val_loss:2.2365 val_bpb:1.3246 train_time:237442ms step_avg:79.15ms +prog_depth: switched to 3 repeats at step:3036 frac:0.40 +step:3200/20000 train_loss:2.2283 train_time:278657ms step_avg:87.08ms +step:3400/20000 train_loss:2.1915 train_time:299788ms step_avg:88.17ms +step:3600/20000 train_loss:2.1526 train_time:320953ms step_avg:89.15ms +step:3800/20000 train_loss:2.2486 train_time:342113ms step_avg:90.03ms +step:4000/20000 train_loss:2.1951 train_time:363281ms step_avg:90.82ms +step:4000/20000 val_loss:2.2001 val_bpb:1.3030 train_time:363349ms step_avg:90.84ms +step:4200/20000 train_loss:2.2068 train_time:384450ms step_avg:91.54ms +prog_depth: switched to 4 repeats at step:4252 frac:0.65 +step:4400/20000 train_loss:2.1355 
train_time:421875ms step_avg:95.88ms +step:4600/20000 train_loss:1.9747 train_time:449821ms step_avg:97.79ms +step:4800/20000 train_loss:2.2530 train_time:477877ms step_avg:99.56ms +step:5000/20000 train_loss:2.0057 train_time:505841ms step_avg:101.17ms +step:5000/20000 val_loss:2.1261 val_bpb:1.2592 train_time:505925ms step_avg:101.19ms +swa:start step:5100 +step:5200/20000 train_loss:2.1309 train_time:533773ms step_avg:102.65ms +step:5400/20000 train_loss:2.1258 train_time:561825ms step_avg:104.04ms +step:5600/20000 train_loss:2.1075 train_time:589899ms step_avg:105.34ms +step:5673/20000 val_loss:2.0739 val_bpb:1.2283 train_time:600218ms step_avg:105.80ms +stopping_early: wallclock_cap train_time:600218ms step:5673/20000 +peak memory allocated: 25696 MiB reserved: 27322 MiB +swa: averaging 13 checkpoints +Serialized model: 63386762 bytes +Code size: 65854 bytes +Total submission size: 63452616 bytes +Serialized model int8+zstd22: 15818418 bytes (payload:17243616 raw_torch:17260843 payload_ratio:3.68x) +Total submission size int8+zstd22: 15884272 bytes +/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. 
+ quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. 
Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. 
We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. 
Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +/workspace/parameter-golf/train_gpt.py:1429: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. 
+ quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") +final_roundtrip val_loss:2.0774 val_bpb:1.2304 eval_time:13407ms +final_roundtrip_exact val_loss:2.07744208 val_bpb:1.23037822 +final_sliding_window val_loss:2.0205 val_bpb:1.1966 window:1024 stride:256 eval_time:66781ms +final_sliding_window_exact val_loss:2.02046173 val_bpb:1.19663074 +final_hedge_mixer val_loss:1.9340 val_bpb:1.1454 eval_time:579109ms +final_hedge_mixer_exact val_loss:1.93403169 val_bpb:1.14544202 diff --git a/records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/train_gpt.py b/records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/train_gpt.py new file mode 100644 index 000000000..1738288f3 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_ProgressiveDepth_HedgeMixer/train_gpt.py @@ -0,0 +1,1498 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: `train_gpt.py` and `train_gpt_mlx.py` must never be longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zstandard as zstd +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + + +class HedgeMixer: + """Online mixture of 5 experts via Hedge algorithm for eval-time improvement. 
+ Experts: Neural, Unigram, Bigram, Trigram (hashed), Entropy.""" + def __init__(self, vocab_size: int = 1024, device: str = "cuda", eta: float = 0.1): + self.V = vocab_size + self.device = device + self.eta = eta + self.log_weights = torch.zeros(5, device=device) + self.log_weights[0] = 2.0 # bias toward neural + self.uni_counts = torch.zeros(vocab_size, device=device) + self.bi_counts = torch.zeros(vocab_size, vocab_size, device=device) + self.total_tokens = 0 + self.TRI_HASH = 65536 + self.tri_counts = torch.zeros(self.TRI_HASH, vocab_size, device=device) + self.tri_row_totals = torch.zeros(self.TRI_HASH, device=device) + + def update(self, tokens: Tensor) -> None: + t = tokens.to(self.device).long() + n = t.numel() + if n == 0: + return + self.total_tokens += n + ones = torch.ones(n, device=self.device) + self.uni_counts.scatter_add_(0, t, ones) + if n >= 2: + bi_idx = t[:-1] * self.V + t[1:] + self.bi_counts.reshape(-1).scatter_add_(0, bi_idx, torch.ones(n - 1, device=self.device)) + if n >= 3: + tri_ctx = ((t[:-2] * 36313) ^ (t[1:-1] * 27191)) % self.TRI_HASH + tri_idx = tri_ctx * self.V + t[2:] + ones_tri = torch.ones(n - 2, device=self.device) + self.tri_counts.reshape(-1).scatter_add_(0, tri_idx, ones_tri) + self.tri_row_totals.scatter_add_(0, tri_ctx, ones_tri) + + def mix_and_score(self, neural_logits: Tensor, x_batch: Tensor, y_batch: Tensor, wlens: list[int]) -> Tensor: + bsz, slen, V = neural_logits.shape + uniform_nll = math.log(self.V) + has_data = self.total_tokens > 0 + neural_lp = F.log_softmax(neural_logits, dim=-1) + neural_nll = -neural_lp.gather(2, y_batch.unsqueeze(2)).squeeze(2) + if not has_data or self.total_tokens < 10000: + return neural_nll + uni_probs = (self.uni_counts + 0.1) / (self.total_tokens + 0.1 * self.V) + uni_nll = -uni_probs.log()[y_batch] + bi_total = self.bi_counts.sum(dim=1, keepdim=True) + bi_probs = (self.bi_counts + 0.1) / (bi_total + 0.1 * self.V) + bi_nll = -bi_probs.log()[x_batch.reshape(-1), 
y_batch.reshape(-1)].reshape(bsz, slen) + if slen >= 2: + prev2 = torch.zeros_like(x_batch) + prev2[:, 1:] = x_batch[:, :-1] + ctx_hash = ((prev2 * 36313) ^ (x_batch * 27191)) % self.TRI_HASH + tri_count = self.tri_counts[ctx_hash.reshape(-1).long(), y_batch.reshape(-1).long()] + tri_total = self.tri_row_totals[ctx_hash.reshape(-1).long()].clamp(min=1) + tri_nll = -(((tri_count + 0.01) / (tri_total + 0.01 * self.V)).log()).reshape(bsz, slen) + else: + tri_nll = torch.full((bsz, slen), uniform_nll, device=self.device) + entropy_nll = -(neural_lp.exp() * neural_lp).sum(-1) + expert_nll = torch.stack([neural_nll, uni_nll, bi_nll, tri_nll, entropy_nll], dim=-1) + log_w = self.log_weights - self.log_weights.logsumexp(0) + mixed_nll = -(-expert_nll + log_w.unsqueeze(0).unsqueeze(0)).logsumexp(dim=-1) + # Update weights + wlens_t = torch.tensor(wlens, device=self.device, dtype=torch.long) + mask = torch.arange(slen, device=self.device).unsqueeze(0) < wlens_t.unsqueeze(1) + masked_nll = expert_nll * mask.unsqueeze(-1).float() + expert_mean_loss = masked_nll.sum(dim=(0, 1)) / mask.sum().clamp(min=1) + self.log_weights -= self.eta * expert_mean_loss + return mixed_nll + + +# HYPERPARAMETERS + +class Hyperparameters: + # Data paths are shard globs produced by the existing preprocessing pipeline. + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + # Validation cadence and batch size. Validation always uses the full fineweb_val split. 
+ val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 200)) + + # Training length. + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 2000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 524_288)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + + # Progressive Depth: train with fewer repeats early (faster), more repeats later (deeper). + # Schedule format: "frac1:rep1,frac2:rep2,..." e.g. "0.4:2,0.65:3,1.0:4" + prog_depth_schedule = os.environ.get("PROG_DEPTH", "0.4:2,0.65:3,1.0:4") + + # XSA (Exclusive Self-Attention) on last N effective layers. + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + + # SWA (Stochastic Weight Averaging) during warmdown. + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + # Sliding window eval. + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 1024)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 256)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + # Hedge Mixer (eval-time n-gram ensemble). + use_hedge = bool(int(os.environ.get("USE_HEDGE", "1"))) + hedge_eta = float(os.environ.get("HEDGE_ETA", 0.1)) + + # Model shape. 
+ vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 3)) + num_repeats = int(os.environ.get("NUM_REPEATS", 4)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 832)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + num_value_embeds = int(os.environ.get("NUM_VALUE_EMBEDS", 2)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + # Optimizer hyperparameters. + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.021)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.018)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.018)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.95)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.85)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 500)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + +# MUON OPTIMIZER +# +# As borrowed from modded-nanogpt +# Background on Muon: https://kellerjordan.github.io/posts/muon/ + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + # Orthogonalize a 2D update matrix with a fast Newton-Schulz iteration. + # Muon uses this to normalize matrix-shaped gradients before applying them. 
+ a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + # Scale correction from Muon reference implementations. 
+ g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + + +# TOKENIZER-AGNOSTIC EVALUATION SETUP +# +# It's common for small models have a large fraction of their parameters be embeddings, since the 2 * d_model * d_vocab vectors can be gigantic. +# Instead of locking the tokenizer, we let you bring your own and calculate our validation metrics on the average compression of the validation set. +# We calculate BPB (bits-per-byte) instead of validation loss, so we need methods to count the number of bits per token in the tokenizer. +# Note: Submissions that edit the tokenizer will be examined more carefully, since screwing this up might unjustly improve your score. 
def build_sentencepiece_luts(
    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
) -> tuple[Tensor, Tensor, Tensor]:
    """Build per-token lookup tables for tokenizer-agnostic byte accounting.

    Returns three tensors of length max(sp.vocab_size(), vocab_size):
      - base_bytes: UTF-8 byte length of each piece, excluding the leading "▁"
        marker (byte pieces count as 1 byte).
      - has_leading_space: True for pieces that start with "▁"; the space byte
        is charged at eval time depending on the previous token.
      - is_boundary_token: True for control/unknown/unused ids (stays at the
        np.ones default for those ids; cleared for every real piece).
    """
    sp_vocab_size = int(sp.vocab_size())
    # Table is sized to cover both the tokenizer vocab and the model vocab.
    table_size = max(sp_vocab_size, vocab_size)
    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
    for token_id in range(sp_vocab_size):
        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
            continue
        is_boundary_token_np[token_id] = False
        if sp.is_byte(token_id):
            base_bytes_np[token_id] = 1
            continue
        piece = sp.id_to_piece(token_id)
        if piece.startswith("▁"):
            has_leading_space_np[token_id] = True
            piece = piece[1:]
        base_bytes_np[token_id] = len(piece.encode("utf-8"))
    return (
        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
    )


def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
    """Load and concatenate all validation shards matching `pattern`.

    Returns a 1-D token tensor trimmed to a whole number of `seq_len`
    sequences plus one extra token (so (x, y) pairs can be built by shifting).

    Raises FileNotFoundError when no shard matches, ValueError when the split
    is shorter than one sequence.
    """
    files = [Path(p) for p in sorted(glob.glob(pattern))]
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {pattern}")
    # The export pipeline writes the fixed first-50k-doc validation set to fineweb_val_*.
    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
    usable = ((tokens.numel() - 1) // seq_len) * seq_len
    if usable <= 0:
        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
    return tokens[: usable + 1]


def eval_val(
    args: Hyperparameters,
    model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    grad_accum_steps: int,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Non-overlapping full-split validation; returns (val_loss, val_bpb)."""
    # Validation computes two metrics:
    # - val_loss: token cross-entropy (natural log)
    # - val_bpb: tokenizer-agnostic compression metric used by the challenge
    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
    if local_batch_tokens < args.train_seq_len:
        raise ValueError(
            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
        )
    local_batch_seqs = local_batch_tokens // args.train_seq_len
    # Each rank takes a disjoint contiguous slice of the validation sequences.
    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
    seq_start = (total_seqs * rank) // world_size
    seq_end = (total_seqs * (rank + 1)) // world_size
    # fp64 accumulators so the all-reduced sums do not lose precision.
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)

    model.eval()
    with torch.inference_mode():
        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
            raw_start = batch_seq_start * args.train_seq_len
            # "+1" so x and y can be built from the same chunk by shifting.
            raw_end = batch_seq_end * args.train_seq_len + 1
            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
            x = local[:-1].reshape(-1, args.train_seq_len)
            y = local[1:].reshape(-1, args.train_seq_len)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                batch_loss = model(x, y).detach()
            # model(x, y) returns a mean loss; re-weight by token count so
            # batches of different sizes average correctly.
            batch_token_count = float(y.numel())
            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
            val_token_count += batch_token_count
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            # Byte cost of a target token = its base UTF-8 bytes, plus one for
            # its leading space unless the previous token is a boundary token.
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)

    val_loss = val_loss_sum / val_token_count
    # bpb = (bits per token) * (tokens per byte)
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)


def eval_val_sliding(
    args: Hyperparameters,
    base_model: nn.Module,
    rank: int,
    world_size: int,
    device: torch.device,
    val_tokens: Tensor,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
    use_hedge: bool = False,
    hedge_eta: float = 0.1,
) -> tuple[float, float]:
    """Sliding window eval with batching. Windows of eval_seq_len advance by eval_stride.
    Only the last stride tokens per window are scored (first window scores all).
    Optional Hedge Mixer: online n-gram ensemble over scored tokens.
    Returns (val_loss, val_bpb) like eval_val."""
    seq_len = args.eval_seq_len
    stride = args.eval_stride
    batch_seqs = args.eval_batch_seqs
    total_tokens = val_tokens.numel() - 1

    window_starts = [ws for ws in range(0, total_tokens, stride)
                     if min(ws + seq_len, total_tokens) - ws >= 1]

    # With Hedge Mixer: process ALL windows on each rank (sequential, n-gram tables need full context)
    # Without: distribute windows across ranks
    if use_hedge:
        my_windows = window_starts
    else:
        total_windows = len(window_starts)
        my_s = (total_windows * rank) // world_size
        my_e = (total_windows * (rank + 1)) // world_size
        my_windows = window_starts[my_s:my_e]

    mixer = HedgeMixer(vocab_size=args.vocab_size, device=device, eta=hedge_eta) if use_hedge else None

    # fp64 accumulators, same bookkeeping as eval_val.
    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)

    base_model.eval()
    with torch.inference_mode():
        for bi in range(0, len(my_windows), batch_seqs):
            batch_ws = my_windows[bi : bi + batch_seqs]
            bsz = len(batch_ws)
            # Zero-padded beyond each window's length; padding positions are
            # never scored because the loops below stop at wlen.
            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            wlens: list[int] = []

            for i, ws in enumerate(batch_ws):
                end = min(ws + seq_len, total_tokens)
                wlen = end - ws
                wlens.append(wlen)
                chunk = val_tokens[ws : end + 1].to(dtype=torch.int64, device=device)
                x_batch[i, :wlen] = chunk[:-1]
                y_batch[i, :wlen] = chunk[1:]

            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                logits = base_model.forward_logits(x_batch)

            if mixer is not None:
                nll = mixer.mix_and_score(logits.float(), x_batch, y_batch, wlens)
            else:
                nll = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)).float(),
                    y_batch.reshape(-1),
                    reduction="none",
                ).reshape(bsz, seq_len)

            # NOTE(review): for tail windows (ws + seq_len > total_tokens),
            # s = max(wlen - stride, 0) makes every such window score the same
            # absolute range [total_tokens - stride, total_tokens), so the last
            # tokens are counted multiple times — confirm this weighting is
            # intended before comparing bpb across stride settings.
            for i, ws in enumerate(batch_ws):
                wlen = wlens[i]
                s = 0 if ws == 0 else max(wlen - stride, 0)
                scored_nll = nll[i, s:wlen].to(torch.float64)
                val_loss_sum += scored_nll.sum()
                val_token_count += float(wlen - s)
                prev_ids = x_batch[i, s:wlen]
                tgt_ids = y_batch[i, s:wlen]
                # Same byte-cost rule as eval_val.
                token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
                token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
                val_byte_count += token_bytes.to(torch.float64).sum()

            # Update n-gram tables with scored tokens
            if mixer is not None:
                for i, ws in enumerate(batch_ws):
                    wlen = wlens[i]
                    s = 0 if ws == 0 else max(wlen - stride, 0)
                    mixer.update(y_batch[i, s:wlen])

    # Hedge mode runs every window on every rank, so sums are already global.
    if not use_hedge and dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)

    val_loss = val_loss_sum / val_token_count
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    base_model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)



# POST-TRAINING QUANTIZATION
#
# It's silly to export our model, which is trained in bf16 and fp32, at that same precision.
# Instead, we get approximately the same model (with a small hit) by quantizing and zstd compressing.
+ +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +# Quantization levels: 127 = int8, 31 = int6, 16 = int5. Per-tensor override via MLP_QUANT_LEVELS. +QUANT_LEVELS = int(os.environ.get("QUANT_LEVELS", 127)) +MLP_QUANT_LEVELS = int(os.environ.get("MLP_QUANT_LEVELS", 0)) # 0 = same as QUANT_LEVELS +MLP_TENSOR_PATTERNS = ("mlp.fc.", "mlp.proj.", "fc.weight", "mlp.proj.weight") + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +GPTQ_LITE_PERCENTILES = [0.9999, 0.99995, 0.99999, 0.999995, 0.999999] + +def quantize_float_tensor(t: Tensor, ql: int = 0) -> tuple[Tensor, Tensor]: + if ql <= 0: + ql = QUANT_LEVELS + t32 = t.float() + if t32.ndim == 2: + # GPTQ-lite: try multiple clip percentiles per row, pick best MSE + abs_t = t32.abs() + best_q = None + best_scale = None + best_mse = None + for pct in GPTQ_LITE_PERCENTILES: + clip_abs = ( + torch.quantile(abs_t, pct, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) 
+ clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + s = (clip_abs / ql).clamp_min(1e-12) + q = torch.clamp(torch.round(clipped / s[:, None]), -ql, ql) + # Reconstruction error per row + recon = q * s[:, None] + mse = (t32 - recon).square().sum(dim=1) + if best_mse is None: + best_mse = mse + best_q = q + best_scale = s + else: + better = mse < best_mse + best_mse = torch.where(better, mse, best_mse) + best_q = torch.where(better[:, None], q, best_q) + best_scale = torch.where(better, s, best_scale) + return best_q.to(torch.int8).contiguous(), best_scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / ql if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -ql, ql).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + # Single supported clean-script export format: + # - per-row int8 for 2D float tensors + # - per-tensor int8 for other float tensors + # - exact passthrough for non-floats + # - passthrough for small float tensors, stored as fp16 to save bytes + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += 
tensor_nbytes(t) + continue + + # Small float tensors are cheap enough to keep directly. We still downcast + # fp32/bf16 passthrough tensors to fp16 so metadata does not dominate size. + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + mlp_ql = MLP_QUANT_LEVELS if MLP_QUANT_LEVELS > 0 else QUANT_LEVELS + ql = mlp_ql if any(p in name for p in MLP_TENSOR_PATTERNS) else QUANT_LEVELS + q, s = quantize_float_tensor(t, ql=ql) + meta: dict[str, object] = {} + if s.ndim > 0: + meta["scheme"] = "per_row" + meta["axis"] = 0 + if ql != QUANT_LEVELS: + meta["ql"] = ql + if meta: + qmeta[name] = meta + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + # Broadcast the saved row scale back across trailing dimensions. 
+ out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + # Restore small tensors, undoing the temporary fp16 storage cast if needed. + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + + +# DATA LOADING + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + # Each call consumes a contiguous chunk from the shared token stream, then slices out + # one disjoint span per rank. The extra "+1" token lets us build (x, y) by shifting. 
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# TRANSFORMER MODULES + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # Keep weights in fp32 for optimizer/state quality, cast at matmul time for bf16 compute. + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + # Keep small/control parameters in fp32 even when the model body runs in bf16. + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + # Caches cos/sin tables per sequence length on the current device. 
+ def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = 
nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def _xsa(self, y: Tensor, v: Tensor) -> Tensor: + """Subtract self-value projection from attention output (GQA-aware).""" + B, T, H, D = y.shape + Hkv = v.size(2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(3) # [B, T, Hkv, 1, D] + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, use_xsa: bool = False) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ) + # XSA: remove self-value bias from attention output + if use_xsa: + y = y.transpose(1, 2).contiguous() # [B, T, H, D] + v_for_xsa = v.transpose(1, 2) # [B, T, Hkv, D] + y = self._xsa(y, v_for_xsa) + y = y.reshape(bsz, seqlen, dim) + else: + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + # leaky_relu(0.5)^2 MLP — better gradient flow than relu^2 for deep/recurrent models + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = 
F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, use_xsa: bool = False) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x), use_xsa=use_xsa) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + num_repeats: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + num_value_embeds: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + xsa_last_n: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.num_repeats = num_repeats + self.xsa_last_n = xsa_last_n + effective_depth = num_layers * num_repeats + self.tok_emb = nn.Embedding(vocab_size, model_dim) + # Value embeddings: extra embedding tables mixed into each effective layer + self.num_value_embeds = num_value_embeds + if num_value_embeds > 0: + 
self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(num_value_embeds)]) + self.value_scales = nn.Parameter(torch.zeros(effective_depth, num_value_embeds, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + ) + for i in range(num_layers) + ] + ) + # Loop embedding: tells the model which effective layer it's at + self.loop_embed = nn.Parameter(torch.zeros(effective_depth, model_dim, dtype=torch.float32)) + # Cross-repeat skip: each block receives its own output from previous repeat + self.cross_repeat_scales = nn.Parameter(torch.zeros(num_layers, num_repeats - 1, model_dim, dtype=torch.float32)) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + for module in self.modules(): + if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x0 = x + + # Pre-compute value embeddings once + ve_list: list[Tensor] = [] + if self.num_value_embeds > 0: + for ve in self.value_embeds: + ve_list.append(ve(input_ids)) # (bsz, seq, dim) + + cur_repeats = self.cur_repeats if hasattr(self, "cur_repeats") else self.num_repeats + cur_depth = len(self.blocks) * cur_repeats + xsa_start = max(0, cur_depth - self.xsa_last_n) if self.xsa_last_n > 0 else cur_depth + + num_blocks = len(self.blocks) + prev_block_outputs: list[Tensor | None] = [None] * num_blocks + layer_idx = 0 + for repeat in range(cur_repeats): + for block_idx, block in enumerate(self.blocks): + x = x + 
self.loop_embed[layer_idx].to(dtype=x.dtype)
+ # Value embeddings: add weighted extra embeddings at each layer
+ if self.num_value_embeds > 0 and layer_idx < self.value_scales.size(0):
+ for ve_idx, ve_out in enumerate(ve_list):
+ vs = self.value_scales[layer_idx, ve_idx].to(dtype=x.dtype)
+ x = x + vs[None, None, :] * ve_out
+ # Cross-repeat skip: mix in this block's output from previous repeat
+ if repeat > 0 and prev_block_outputs[block_idx] is not None:
+ rep_idx = min(repeat - 1, self.cross_repeat_scales.size(1) - 1)
+ scale = self.cross_repeat_scales[block_idx, rep_idx].to(dtype=x.dtype)
+ x = x + scale[None, None, :] * prev_block_outputs[block_idx]
+ x = block(x, x0, use_xsa=(layer_idx >= xsa_start))
+ prev_block_outputs[block_idx] = x.detach() if not self.training else x
+ layer_idx += 1
+
+ x = self.final_norm(x)
+ if self.tie_embeddings:
+ logits_proj = F.linear(x, self.tok_emb.weight)
+ else:
+ if self.lm_head is None:
+ raise RuntimeError("lm_head is required when tie_embeddings=False")
+ logits_proj = self.lm_head(x)
+ logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
+ return logits
+
+ def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
+ logits = self.forward_logits(input_ids)
+ logits = logits.reshape(-1, logits.size(-1))
+ targets = target_ids.reshape(-1)
+ return F.cross_entropy(logits.float(), targets, reduction="mean")
+
+
+# TRAINING
+
+def main() -> None:
+ global zeropower_via_newtonschulz5
+
+ code = Path(__file__).read_text(encoding="utf-8")
+ args = Hyperparameters()
+ zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)
+
+ # DISTRIBUTED + CUDA SETUP
+
+ distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
+ rank = int(os.environ.get("RANK", "0"))
+ world_size = int(os.environ.get("WORLD_SIZE", "1"))
+ local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+ if world_size <= 0:
+ raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
+ grad_accum_steps = max(1, 8 // world_size)
+ 
grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + # Fast math knobs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + # TOKENIZER + VALIDATION METRIC SETUP + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = 
len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + num_repeats=args.num_repeats, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + num_value_embeds=args.num_value_embeds, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + xsa_last_n=args.xsa_last_n, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Optimizer split: + # - token embedding (Adam) uses EMBED_LR + # - untied lm_head (Adam) uses HEAD_LR + # - matrix params in transformer blocks use MATRIX_LR via Muon + # - vectors/scalars use SCALAR_LR via Adam + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + 
scalar_params.append(base_model.loop_embed) + scalar_params.append(base_model.cross_repeat_scales) + if base_model.num_value_embeds > 0: + scalar_params.append(base_model.value_scales) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + embed_params = [base_model.tok_emb.weight] + if base_model.num_value_embeds > 0: + embed_params.extend(ve.weight for ve in base_model.value_embeds) + optimizer_tok = torch.optim.Adam( + [{"params": embed_params, "lr": token_lr, "base_lr": token_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.Adam( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} 
train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + # Warmup primes the compiled forward/backward/optimizer paths, then we restore the + # initial weights/optimizer state so measured training starts from the true init. 
+ if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + + # Progressive depth schedule: parse "frac:repeats,..." 
and sort + prog_phases: list[tuple[float, int]] = [] + for entry in args.prog_depth_schedule.split(","): + frac_s, rep_s = entry.strip().split(":") + prog_phases.append((float(frac_s), int(rep_s))) + prog_phases.sort() + current_phase_repeats = prog_phases[0][1] if prog_phases else args.num_repeats + base_model.cur_repeats = current_phase_repeats + # Recompile with initial phase depth + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + log0(f"prog_depth: schedule={prog_phases} starting_repeats={current_phase_repeats}") + + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + # Progressive depth: check if we need to switch phase + # Use synchronized elapsed time (max across ranks) to avoid race conditions + if max_wallclock_ms is not None and prog_phases: + if distributed: + elapsed_tensor = 
torch.tensor(elapsed_ms, device=device) + dist.all_reduce(elapsed_tensor, op=dist.ReduceOp.MAX) + frac = elapsed_tensor.item() / max_wallclock_ms + else: + frac = elapsed_ms / max_wallclock_ms + new_repeats = prog_phases[-1][1] # default to last + for phase_frac, phase_rep in prog_phases: + if frac < phase_frac: + new_repeats = phase_rep + break + if new_repeats != current_phase_repeats: + current_phase_repeats = new_repeats + base_model.cur_repeats = new_repeats + torch._dynamo.reset() + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + log0(f"prog_depth: switched to {new_repeats} repeats at step:{step} frac:{frac:.2f}") + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown (only at 
full depth to avoid mixing phases) + at_full_depth = current_phase_repeats == args.num_repeats + if args.swa_enabled and at_full_depth and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().float() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu().float() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + # Needed to sync whether we've reached the wallclock cap. + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Restore full depth for eval/export + base_model.cur_repeats = args.num_repeats + torch._dynamo.reset() + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None: + # Include final weights (may not have landed on swa_every boundary) + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu().float() + 
swa_count += 1 + log0(f"swa: averaging {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + # Save the raw state (useful for debugging/loading in PyTorch directly), then always produce + # the compressed quantized+zstd artifact and validate the round-tripped weights. + + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict()) + quant_buf = io.BytesIO() + torch.save(quant_obj, quant_buf) + quant_raw = quant_buf.getvalue() + zstd_level = int(os.environ.get("ZSTD_LEVEL", 22)) + cctx = zstd.ZstdCompressor(level=zstd_level) + quant_blob = cctx.compress(quant_raw) + quant_raw_bytes = len(quant_raw) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1) + log0( + f"Serialized model int8+zstd{zstd_level}: {quant_file_bytes} bytes " + f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)" + ) + log0(f"Total submission size int8+zstd{zstd_level}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + dctx = zstd.ZstdDecompressor() + quant_state = torch.load(io.BytesIO(dctx.decompress(quant_blob_disk)), map_location="cpu") + 
base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # Sliding window eval + if args.eval_stride > 0: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_sw = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, + base_model, + rank, + world_size, + device, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"window:{args.eval_seq_len} stride:{args.eval_stride} " + f"eval_time:{1000.0 * (time.perf_counter() - t_sw):.0f}ms" + ) + log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + + # Hedge Mixer eval (n-gram ensemble) + if args.use_hedge: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_hm = time.perf_counter() + hm_val_loss, hm_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + use_hedge=True, hedge_eta=args.hedge_eta, + ) + torch.cuda.synchronize() + log0( + f"final_hedge_mixer val_loss:{hm_val_loss:.4f} val_bpb:{hm_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_hm):.0f}ms" + ) + log0(f"final_hedge_mixer_exact val_loss:{hm_val_loss:.8f} val_bpb:{hm_val_bpb:.8f}") + + if distributed: + 
dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/train_gpt.py b/train_gpt.py index 7410c0580..1738288f3 100644 --- a/train_gpt.py +++ b/train_gpt.py @@ -27,6 +27,76 @@ from torch import Tensor, nn from torch.nn.parallel import DistributedDataParallel as DDP + +class HedgeMixer: + """Online mixture of 5 experts via Hedge algorithm for eval-time improvement. + Experts: Neural, Unigram, Bigram, Trigram (hashed), Entropy.""" + def __init__(self, vocab_size: int = 1024, device: str = "cuda", eta: float = 0.1): + self.V = vocab_size + self.device = device + self.eta = eta + self.log_weights = torch.zeros(5, device=device) + self.log_weights[0] = 2.0 # bias toward neural + self.uni_counts = torch.zeros(vocab_size, device=device) + self.bi_counts = torch.zeros(vocab_size, vocab_size, device=device) + self.total_tokens = 0 + self.TRI_HASH = 65536 + self.tri_counts = torch.zeros(self.TRI_HASH, vocab_size, device=device) + self.tri_row_totals = torch.zeros(self.TRI_HASH, device=device) + + def update(self, tokens: Tensor) -> None: + t = tokens.to(self.device).long() + n = t.numel() + if n == 0: + return + self.total_tokens += n + ones = torch.ones(n, device=self.device) + self.uni_counts.scatter_add_(0, t, ones) + if n >= 2: + bi_idx = t[:-1] * self.V + t[1:] + self.bi_counts.reshape(-1).scatter_add_(0, bi_idx, torch.ones(n - 1, device=self.device)) + if n >= 3: + tri_ctx = ((t[:-2] * 36313) ^ (t[1:-1] * 27191)) % self.TRI_HASH + tri_idx = tri_ctx * self.V + t[2:] + ones_tri = torch.ones(n - 2, device=self.device) + self.tri_counts.reshape(-1).scatter_add_(0, tri_idx, ones_tri) + self.tri_row_totals.scatter_add_(0, tri_ctx, ones_tri) + + def mix_and_score(self, neural_logits: Tensor, x_batch: Tensor, y_batch: Tensor, wlens: list[int]) -> Tensor: + bsz, slen, V = neural_logits.shape + uniform_nll = math.log(self.V) + has_data = self.total_tokens > 0 + neural_lp = F.log_softmax(neural_logits, dim=-1) + neural_nll = -neural_lp.gather(2, 
y_batch.unsqueeze(2)).squeeze(2) + if not has_data or self.total_tokens < 10000: + return neural_nll + uni_probs = (self.uni_counts + 0.1) / (self.total_tokens + 0.1 * self.V) + uni_nll = -uni_probs.log()[y_batch] + bi_total = self.bi_counts.sum(dim=1, keepdim=True) + bi_probs = (self.bi_counts + 0.1) / (bi_total + 0.1 * self.V) + bi_nll = -bi_probs.log()[x_batch.reshape(-1), y_batch.reshape(-1)].reshape(bsz, slen) + if slen >= 2: + prev2 = torch.zeros_like(x_batch) + prev2[:, 1:] = x_batch[:, :-1] + ctx_hash = ((prev2 * 36313) ^ (x_batch * 27191)) % self.TRI_HASH + tri_count = self.tri_counts[ctx_hash.reshape(-1).long(), y_batch.reshape(-1).long()] + tri_total = self.tri_row_totals[ctx_hash.reshape(-1).long()].clamp(min=1) + tri_nll = -(((tri_count + 0.01) / (tri_total + 0.01 * self.V)).log()).reshape(bsz, slen) + else: + tri_nll = torch.full((bsz, slen), uniform_nll, device=self.device) + entropy_nll = -(neural_lp.exp() * neural_lp).sum(-1) + expert_nll = torch.stack([neural_nll, uni_nll, bi_nll, tri_nll, entropy_nll], dim=-1) + log_w = self.log_weights - self.log_weights.logsumexp(0) + mixed_nll = -(-expert_nll + log_w.unsqueeze(0).unsqueeze(0)).logsumexp(dim=-1) + # Update weights + wlens_t = torch.tensor(wlens, device=self.device, dtype=torch.long) + mask = torch.arange(slen, device=self.device).unsqueeze(0) < wlens_t.unsqueeze(1) + masked_nll = expert_nll * mask.unsqueeze(-1).float() + expert_mean_loss = masked_nll.sum(dim=(0, 1)) / mask.sum().clamp(min=1) + self.log_weights -= self.eta * expert_mean_loss + return mixed_nll + + # HYPERPARAMETERS class Hyperparameters: @@ -70,6 +140,10 @@ class Hyperparameters: eval_stride = int(os.environ.get("EVAL_STRIDE", 256)) eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + # Hedge Mixer (eval-time n-gram ensemble). + use_hedge = bool(int(os.environ.get("USE_HEDGE", "1"))) + hedge_eta = float(os.environ.get("HEDGE_ETA", 0.1)) + # Model shape. 
vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) num_layers = int(os.environ.get("NUM_LAYERS", 3)) @@ -300,9 +374,12 @@ def eval_val_sliding( base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + use_hedge: bool = False, + hedge_eta: float = 0.1, ) -> tuple[float, float]: """Sliding window eval with batching. Windows of train_seq_len advance by eval_stride. - Only the last stride tokens per window are scored (first window scores all).""" + Only the last stride tokens per window are scored (first window scores all). + Optional Hedge Mixer: online n-gram ensemble over scored tokens.""" seq_len = args.eval_seq_len stride = args.eval_stride batch_seqs = args.eval_batch_seqs @@ -310,10 +387,18 @@ def eval_val_sliding( window_starts = [ws for ws in range(0, total_tokens, stride) if min(ws + seq_len, total_tokens) - ws >= 1] - total_windows = len(window_starts) - my_s = (total_windows * rank) // world_size - my_e = (total_windows * (rank + 1)) // world_size - my_windows = window_starts[my_s:my_e] + + # With Hedge Mixer: process ALL windows on each rank (sequential, n-gram tables need full context) + # Without: distribute windows across ranks + if use_hedge: + my_windows = window_starts + else: + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + mixer = HedgeMixer(vocab_size=args.vocab_size, device=device, eta=hedge_eta) if use_hedge else None val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) val_token_count = torch.zeros((), device=device, dtype=torch.float64) @@ -326,7 +411,7 @@ def eval_val_sliding( bsz = len(batch_ws) x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) - wlens = [] + wlens: list[int] = [] for i, ws in enumerate(batch_ws): end = min(ws + seq_len, total_tokens) @@ -339,11 +424,14 @@ def 
eval_val_sliding( with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): logits = base_model.forward_logits(x_batch) - nll = F.cross_entropy( - logits.reshape(-1, logits.size(-1)).float(), - y_batch.reshape(-1), - reduction="none", - ).reshape(bsz, seq_len) + if mixer is not None: + nll = mixer.mix_and_score(logits.float(), x_batch, y_batch, wlens) + else: + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) for i, ws in enumerate(batch_ws): wlen = wlens[i] @@ -357,7 +445,14 @@ def eval_val_sliding( token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) val_byte_count += token_bytes.to(torch.float64).sum() - if dist.is_available() and dist.is_initialized(): + # Update n-gram tables with scored tokens + if mixer is not None: + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + mixer.update(y_batch[i, s:wlen]) + + if not use_hedge and dist.is_available() and dist.is_initialized(): dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) @@ -1378,6 +1473,23 @@ def lr_mul(step: int, elapsed_ms: float) -> float: ) log0(f"final_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + # Hedge Mixer eval (n-gram ensemble) + if args.use_hedge: + base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True) + torch.cuda.synchronize() + t_hm = time.perf_counter() + hm_val_loss, hm_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + use_hedge=True, hedge_eta=args.hedge_eta, + ) + torch.cuda.synchronize() + log0( + f"final_hedge_mixer val_loss:{hm_val_loss:.4f} val_bpb:{hm_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - 
t_hm):.0f}ms" + ) + log0(f"final_hedge_mixer_exact val_loss:{hm_val_loss:.8f} val_bpb:{hm_val_bpb:.8f}") + if distributed: dist.destroy_process_group()