diff --git a/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/README.md b/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/README.md
new file mode 100644
index 000000000..ede692f33
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/README.md
@@ -0,0 +1,47 @@
+# QAT + Neural Cache + LoRA TTT (Non-Record Submission)
+
+**val_bpb: 1.4245** (sliding window, post int5/int6+zstd quantization roundtrip, 1 seed)
+
+This is a non-record submission exploring three additional techniques (one training-time, two eval-time) stacked on the current #1 training recipe. The QAT implementation has a bug (the quantization penalty is ~0.25 BPB instead of the expected ~0.02), making this run non-competitive. Submitting for transparency and to document the approach for iteration.
+
+## Approach
+
+Built on the PR by @thwu1 (Int5-MLP + BigramHash + SWA), adding:
+
+### 1. Quantization-Aware Training (QAT)
+STE fake-quantization during training: int5 (clip=15) for MLP layers, int6 (clip=31) for attention. The model learns to be robust to quantization noise. **Bug found:** the STE uses symmetric clipping while the export uses percentile-based per-row scaling; this mismatch made the model optimize for the wrong quantization target, costing ~0.25 BPB instead of the expected ~0.02.
+
+### 2. Neural Cache
+During sliding-window eval, maintain a ring buffer of pre-lm_head hidden states (dim=512, bf16). For each token, compute cosine similarity against the cached states, build a cache distribution via softmax-weighted scatter, and interpolate with the model's predictions using logaddexp. Causal token-by-token scoring with document-boundary resets prevents information leakage.
+
+### 3. LoRA Test-Time Training
+Per-document rank-8 LoRA adaptation on the lm_head, Q, and V projections during evaluation. Documents are batched (batch_size=64), each chunk is scored before being trained on (no leakage), and updates are entropy-gated.
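The STE trick in (1), where the forward pass sees quantized weights while the gradient flows through unmodified, can be sketched as follows (a minimal sketch with per-row absmax scaling; the function name is illustrative, not the script's actual API):

```python
import torch

def fake_quant_ste(w: torch.Tensor, clip: int = 15) -> torch.Tensor:
    # Per-row absmax scale: each output row gets its own step size.
    scale = w.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12) / clip
    # Snap to the integer grid and clamp to the signed intN range.
    w_q = torch.clamp(torch.round(w / scale), -(clip + 1), clip) * scale
    # Straight-through estimator: forward uses w_q, backward sees identity.
    return w + (w_q - w).detach()
```

With `clip=15` this emulates int5 (values in [-16, 15]); `clip=31` emulates int6.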
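The interpolation in (2) can be sketched for a single scored position as below (a hedged sketch, not the eval code itself: function and argument names are illustrative, and `theta`/`lam` stand in for the `cache_theta`/`cache_lambda` hyperparameters):

```python
import math

import torch
import torch.nn.functional as F

def cache_mix_logprobs(logits, hidden, cache_h, cache_ids, vocab_size,
                       theta=5.0, lam=0.05):
    """Mix model log-probs with a pointer-cache distribution for one position.

    logits:    (V,) model logits at the current position
    hidden:    (D,) current pre-lm_head hidden state
    cache_h:   (N, D) cached hidden states from the ring buffer
    cache_ids: (N,) token id observed right after each cached state
    """
    model_logp = F.log_softmax(logits, dim=-1)
    # Cosine similarity to every cached state, sharpened by theta.
    sims = theta * F.cosine_similarity(hidden[None, :], cache_h, dim=-1)
    weights = F.softmax(sims, dim=-1)
    # Softmax-weighted scatter onto the vocab: the cache distribution.
    cache_p = torch.zeros(vocab_size).scatter_add_(0, cache_ids, weights)
    cache_logp = cache_p.clamp_min(1e-30).log()
    # logaddexp interpolation: log((1 - lam) * p_model + lam * p_cache)
    return torch.logaddexp(model_logp + math.log1p(-lam),
                           cache_logp + math.log(lam))
```

The output is a proper log-distribution over the vocab; tokens that recently followed similar hidden states get boosted.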
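The adapters in (3) are additive low-rank deltas on top of frozen projections; a minimal sketch (class and field names are illustrative, not the script's actual classes):

```python
import torch
import torch.nn as nn

class LoRADelta(nn.Module):
    """Additive rank-r delta for a frozen linear projection."""
    def __init__(self, d_in: int, d_out: int, rank: int = 8):
        super().__init__()
        self.A = nn.Parameter(torch.randn(d_in, rank) * 0.01)
        # Zero-init B so the delta starts at exactly zero.
        self.B = nn.Parameter(torch.zeros(rank, d_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Only 2 * d * r parameters are trained per adapted projection.
        return (x @ self.A) @ self.B
```

During eval, only these adapter parameters are updated per document (the base weights stay frozen), and the deltas are added to the projection outputs, matching the optional `q_delta`/`v_delta` terms in the attention forward.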
+ +## Architecture +- 10 layers, 512 dim, 8 heads, 4 KV heads (GQA), 3x MLP (1536 hidden) +- BigramHash(10240, dim=128), SmearGate, orthogonal init +- Muon optimizer: matrix_lr=0.02, WD=0.04, momentum=0.99 +- SWA: last 40% of warmdown, every 50 steps, 24 checkpoints averaged +- seq_len=2048, batch=786K tokens + +## Results +| Seed | Pre-quant val_bpb | Post-quant sliding val_bpb | Steps | Artifact | +|------|-------------------|---------------------------|-------|----------| +| 1337 | 1.1739 | 1.4245 | 5109 | 15.77 MB | + +## Known Issues +1. **QAT mismatch:** STE clip ranges don't match export quantization format — needs per-row percentile clipping in the STE to match `quantize_intN_per_row` +2. **Pre-quant BPB already worse than SOTA:** 1.1739 vs 1.1428 — QAT may be hurting convergence with current hyperparameters +3. Only 1 seed (need 3+ for statistical significance) + +## Next Steps +- Run without QAT to verify base recipe reproduces 1.1428 +- Fix QAT to match exact export quantization format +- Run neural cache + TTT eval on a working checkpoint +- Sweep cache hyperparameters (theta, lambda) + +## Command +```bash +RUN_ID=run1_seed1337 SEED=1337 QAT_ENABLED=1 EVAL_STRIDE=64 EVAL_STRATEGY=combined \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` diff --git a/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/submission.json b/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/submission.json new file mode 100644 index 000000000..7f01df4dc --- /dev/null +++ b/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/submission.json @@ -0,0 +1,9 @@ +{ + "name": "QAT + Neural Cache + LoRA TTT (non-record)", + "val_bpb": 1.4245, + "bytes_total": 15766801, + "blurb": "Non-record submission exploring QAT + neural cache + LoRA TTT on top of #1 recipe. QAT implementation has export mismatch bug causing 0.25 BPB quantization penalty. 
Submitting to document the approach and iterate.",
+  "author": "Bortlesboat",
+  "github_id": "Bortlesboat",
+  "date": "2026-03-20"
+}
diff --git a/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/train_gpt.py b/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/train_gpt.py
new file mode 100644
index 000000000..3af662bfc
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/train_gpt.py
@@ -0,0 +1,1709 @@
+"""
+The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: To keep things readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines.
+"""
+
+from __future__ import annotations
+
+import copy
+import glob
+import io
+import math
+import os
+import random
+import subprocess
+import sys
+import time
+import uuid
+import zlib
+from pathlib import Path
+
+try:
+    import zstandard
+    _COMPRESSOR = "zstd"
+except ImportError:
+    _COMPRESSOR = "zlib"
+
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+# -----------------------------
+# HYPERPARAMETERS
+# -----------------------------
+
+class Hyperparameters:
+    data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")
+    train_files = os.path.join(data_path, "fineweb_train_*.bin")
+    val_files = os.path.join(data_path, "fineweb_val_*.bin")
+    tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model")
+    run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
+    seed = int(os.environ.get("SEED", 42))
+
+    
val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = 
float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + # Test-time training (LoRA) hyperparameters. + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + ttt_entropy_gate = float(os.environ.get("TTT_ENTROPY_GATE", 0.0)) + + # Neural cache hyperparameters. + cache_size = int(os.environ.get("CACHE_SIZE", 2048)) + cache_theta = float(os.environ.get("CACHE_THETA", 5.0)) + cache_lambda = float(os.environ.get("CACHE_LAMBDA", 0.05)) + + # QAT (Quantization-Aware Training) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "1"))) + + # Eval strategy + eval_strategy = os.environ.get("EVAL_STRATEGY", "sliding") + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, 
momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION +# ----------------------------- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + 
has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) 
// args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + 
for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", + ).split(",") + if pattern +) +FP16_KEEP_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if "bigram" in name: + return "bigram" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" + +def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 8192: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention + q, s = quantize_intN_per_row(t, clip_range=clip) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for 
name, orig in template_sd.items():
+        info = meta[name]
+        orig_dtype = orig.dtype
+        if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"):
+            t = result[name]
+            if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16):
+                t = t.to(orig_dtype)
+            out[name] = t
+            continue
+        q, s = result[name + ".q"], result[name + ".scale"]
+        if s.ndim > 0:
+            out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype)
+        else:
+            out[name] = (q.float() * float(s.item())).to(orig_dtype)
+    return out
+
+
+# -----------------------------
+# DATA LOADING
+# -----------------------------
+
+def load_data_shard(file: Path) -> Tensor:
+    # llm.c-style shard: 256-int32 header (magic, version, token count), then uint16 tokens.
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(), dtype="<u2", count=num_tokens)
+    return torch.from_numpy(tokens.astype(np.int32))
+
+
+class TokenStream:
+    def __init__(self, pattern: str):
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+
+
+class DistributedTokenLoader:
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        self.rank = rank
+        self.world_size = world_size
+        self.device = device
+        self.stream = TokenStream(pattern)
+
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        chunk = self.stream.take(per_rank_span * self.world_size)
+        start = self.rank * per_rank_span
+        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
+        x = local[:-1].reshape(-1, seq_len)
+        y = local[1:].reshape(-1, seq_len)
+        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
+
+
+# 
----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + # QAT fields: set per-module after model construction + _qat_clip: int = 0 # 0=disabled, 15=int5, 31=int6 + + def forward(self, x: Tensor) -> Tensor: + w = self.weight + if self._qat_clip > 0 and self.training and w.ndim == 2: + # STE fake quantization matching int5/int6 export format + w_f = w.float() + clip = self._qat_clip + amax = w_f.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12) + scale = amax / clip + w_q = (torch.clamp(torch.round(w_f / scale), -(clip + 1), clip) * scale) + w = w + (w_q - w_f).detach() # Straight-through estimator + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w.to(x.dtype), bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, 
self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base) + + def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x) + (q_delta if q_delta is not None else 0) + k = self.c_k(x) + v = self.c_v(x) + (v_delta if v_delta is not None else 0) + q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, 
q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + # Expand KV heads for GQA compatibility with older PyTorch + if self.num_kv_heads != self.num_heads: + reps = self.num_heads // self.num_kv_heads + k = k[:, :, None, :, :].expand(-1, -1, reps, -1, -1).reshape(bsz, self.num_heads, seqlen, self.head_dim) + v = v[:, :, None, :, :].expand(-1, -1, reps, -1, -1).reshape(bsz, self.num_heads, seqlen, self.head_dim) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class SmearGate(nn.Module): + """Blend each token's embedding with the previous token's embedding.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + """Hash consecutive token pairs into a learned embedding table.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = 
nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + + def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + n = self.attn_norm(x) + qd = q_delta_fn(n) if q_delta_fn is not None else None + vd = v_delta_fn(n) if v_delta_fn is not None else None + attn_out = self.attn(n, qd, vd) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + ): + super().__init__() + if 
logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.smear = SmearGate(model_dim) + self.blocks = nn.ModuleList( + [ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init) + for _ in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." 
in name or name.endswith(".proj"):
+                with torch.no_grad():
+                    module.weight.mul_(1.0 / math.sqrt(2 * num_layers))
+
+    def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor:
+        x = self.tok_emb(input_ids)
+        if self.bigram is not None:
+            x = x + self.bigram(input_ids)
+        x = F.rms_norm(x, (x.size(-1),))
+        x = self.smear(x)
+        x0 = x
+        skips: list[Tensor] = []
+        for i in range(self.num_encoder_layers):
+            qd = lora.q_loras[i] if lora else None
+            vd = lora.v_loras[i] if lora else None
+            x = self.blocks[i](x, x0, qd, vd)
+            skips.append(x)
+        for i in range(self.num_decoder_layers):
+            bi = self.num_encoder_layers + i
+            if skips:
+                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
+            qd = lora.q_loras[bi] if lora else None
+            vd = lora.v_loras[bi] if lora else None
+            x = self.blocks[bi](x, x0, qd, vd)
+        x_norm = self.final_norm(x)
+        x_flat = x_norm.reshape(-1, x_norm.size(-1))
+        targets = target_ids.reshape(-1)
+        if self.tie_embeddings:
+            logits_proj = F.linear(x_flat, self.tok_emb.weight)
+        else:
+            if self.lm_head is None:
+                raise RuntimeError("lm_head is required when tie_embeddings=False")
+            logits_proj = self.lm_head(x_flat)
+        logits_proj = logits_proj + (lora.lm_head_lora(x_norm).reshape(-1, logits_proj.size(-1)) if lora else 0)
+        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
+        if lora:
+            bsz, sl = input_ids.shape
+            return F.cross_entropy(logits.float(), targets, reduction="none").reshape(bsz, sl)
+        return F.cross_entropy(logits.float(), targets, reduction="mean")
+
+    def forward_logits(self, input_ids: Tensor) -> Tensor:
+        x = self.tok_emb(input_ids)
+        if self.bigram is not None:
+            x = x + self.bigram(input_ids)
+        x = F.rms_norm(x, (x.size(-1),))
+        x = self.smear(x)
+        x0 = x
+        skips: list[Tensor] = []
+        for i in range(self.num_encoder_layers):
+            x = self.blocks[i](x, x0)
+            skips.append(x)
+        for i in range(self.num_decoder_layers):
+            if skips:
+                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
+            x = self.blocks[self.num_encoder_layers + i](x, x0)
+        x = self.final_norm(x)
+        if self.tie_embeddings:
+            logits_proj = F.linear(x, self.tok_emb.weight)
+        else:
+            logits_proj = self.lm_head(x)
+        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
+
+    def forward_logits_and_hidden(self, input_ids: Tensor) -> tuple[Tensor, Tensor]:
+        """Forward pass returning (logits, hidden_states). For neural cache eval."""
+        x = self.tok_emb(input_ids)
+        if self.bigram is not None:
+            x = x + self.bigram(input_ids)
+        x = F.rms_norm(x, (x.size(-1),))
+        x = self.smear(x)
+        x0 = x
+        skips: list[Tensor] = []
+        for i in range(self.num_encoder_layers):
+            x = self.blocks[i](x, x0)
+            skips.append(x)
+        for i in range(self.num_decoder_layers):
+            if skips:
+                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
+            x = self.blocks[self.num_encoder_layers + i](x, x0)
+        hidden = self.final_norm(x)
+        if self.tie_embeddings:
+            logits = F.linear(hidden, self.tok_emb.weight.to(hidden.dtype))
+        else:
+            logits = self.lm_head(hidden)
+        logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap)
+        return logits, hidden
+
+
+def eval_val_sliding(
+    args: Hyperparameters,
+    base_model: nn.Module,
+    rank: int,
+    world_size: int,
+    device: torch.device,
+    val_tokens: Tensor,
+    base_bytes_lut: Tensor,
+    has_leading_space_lut: Tensor,
+    is_boundary_token_lut: Tensor,
+    stride: int,
+    batch_seqs: int = 32,
+) -> tuple[float, float]:
+    seq_len = args.train_seq_len
+    total_tokens = val_tokens.numel() - 1
+    window_starts = [ws for ws in range(0, total_tokens, stride)
+                     if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0]
+    total_windows = len(window_starts)
+    my_s = (total_windows * rank) // world_size
+    my_e = (total_windows * (rank + 1)) // world_size
+    my_windows = window_starts[my_s:my_e]
+
+    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
+    token_count = torch.zeros((), device=device, dtype=torch.float64)
+    byte_count = torch.zeros((), device=device, dtype=torch.float64)
+
+    base_model.eval()
+    with torch.inference_mode():
+        for bi in range(0, len(my_windows), batch_seqs):
+            batch_ws = my_windows[bi:bi + batch_seqs]
+            bsz = len(batch_ws)
+            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
+            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
+            wlens: list[int] = []
+            for i, ws in enumerate(batch_ws):
+                end = min(ws + seq_len, total_tokens)
+                wlen = end - ws
+                wlens.append(wlen)
+                chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
+                x_batch[i, :wlen] = chunk[:-1]
+                y_batch[i, :wlen] = chunk[1:]
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+                logits = base_model.forward_logits(x_batch)
+            nll = F.cross_entropy(
+                logits.reshape(-1, logits.size(-1)).float(),
+                y_batch.reshape(-1),
+                reduction="none",
+            ).reshape(bsz, seq_len)
+            for i, ws in enumerate(batch_ws):
+                wlen = wlens[i]
+                s = 0 if ws == 0 else max(wlen - stride, 0)
+                scored_nll = nll[i, s:wlen].to(torch.float64)
+                loss_sum += scored_nll.sum()
+                token_count += float(wlen - s)
+                tgt = y_batch[i, s:wlen]
+                prev = x_batch[i, s:wlen]
+                tb = base_bytes_lut[tgt].to(torch.float64)
+                tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64)
+                byte_count += tb.sum()
+            if rank == 0 and (bi // batch_seqs) % 50 == 0:
+                done = min(bi + batch_seqs, len(my_windows))
+                pct = done / len(my_windows) * 100
+                running_bpb = 0.0
+                if token_count.item() > 0:
+                    rl = (loss_sum / token_count).item()
+                    running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item())
+                print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True)
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = (loss_sum / token_count).item()
+    bits_per_token = val_loss / math.log(2.0)
+    tokens_per_byte = token_count.item() / byte_count.item()
+    base_model.train()
+    return val_loss, bits_per_token * tokens_per_byte
+
+
+def eval_val_sliding_with_cache(
+    logits_hidden_fn,
+    rank: int,
+    world_size: int,
+    device: torch.device,
+    val_tokens: Tensor,
+    base_bytes_lut: Tensor,
+    has_leading_space_lut: Tensor,
+    is_boundary_token_lut: Tensor,
+    seq_len: int,
+    stride: int,
+    cache_size: int = 2048,
+    cache_theta: float = 5.0,
+    cache_lambda: float = 0.05,
+    eval_batch_seqs: int = 64,
+) -> tuple[float, float]:
+    """Sliding window eval with neural cache interpolation."""
+    total = val_tokens.numel() - 1
+
+    # Build windows
+    windows: list[tuple[int, int]] = []
+    p = 0
+    while p + seq_len <= total:
+        s = 0 if p == 0 else (seq_len - stride)
+        windows.append((p, s))
+        p += stride
+
+    n = len(windows)
+    per_rank = (n + world_size - 1) // world_size
+    my_start = rank * per_rank
+    my_end = min(my_start + per_rank, n)
+    my_windows = windows[my_start:my_end]
+
+    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
+    tok_count = torch.zeros((), device=device, dtype=torch.float64)
+    byte_count = torch.zeros((), device=device, dtype=torch.float64)
+
+    # Neural cache ring buffer (on GPU for fast lookup)
+    cache_keys: Tensor | None = None  # [cache_size, model_dim]
+    cache_vals: Tensor | None = None  # [cache_size]
+    cache_len = 0
+    cache_ptr = 0
+    cache_initialized = False
+
+    with torch.inference_mode():
+        for i in range(0, len(my_windows), eval_batch_seqs):
+            batch = my_windows[i : i + eval_batch_seqs]
+            bs = len(batch)
+
+            x_list = [val_tokens[w : w + seq_len] for w, _ in batch]
+            y_list = [val_tokens[w + 1 : w + seq_len + 1] for w, _ in batch]
+            pad = eval_batch_seqs - bs
+            if pad > 0:
+                x_list.extend([x_list[-1]] * pad)
+                y_list.extend([y_list[-1]] * pad)
+
+            x = torch.stack(x_list).to(device=device, dtype=torch.int64)
+            y = torch.stack(y_list).to(device=device, dtype=torch.int64)
+
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+                logits, hidden = logits_hidden_fn(x)
+
+            # Initialize cache on first pass
+            if not cache_initialized:
+                model_dim = hidden.shape[-1]
+                cache_keys = torch.zeros(cache_size, model_dim, device=device, dtype=torch.bfloat16)
+                cache_vals = torch.zeros(cache_size, device=device, dtype=torch.long)
+                cache_initialized = True
+
+            for b in range(bs):
+                s = batch[b][1]
+                scored_logits = logits[b, s:].float()  # [num_scored, vocab]
+                scored_targets = y[b, s:]
+                scored_hidden = hidden[b, s:]  # [num_scored, model_dim]
+                scored_input = x[b, s : s + scored_targets.numel()]
+                ns = scored_targets.numel()
+                vocab_size = scored_logits.shape[-1]
+
+                log_probs_model = F.log_softmax(scored_logits, dim=-1)
+
+                # Token-by-token scoring + cache update (causal: each token
+                # sees only cache entries from BEFORE it, and BOS resets are
+                # applied before scoring the token that follows a BOS).
+                for t_idx in range(ns):
+                    # Reset cache at document boundaries BEFORE scoring
+                    if is_boundary_token_lut[scored_input[t_idx]]:
+                        cache_len = 0
+                        cache_ptr = 0
+
+                    target_id = scored_targets[t_idx]
+
+                    if cache_len > 0:
+                        active_len = min(cache_len, cache_size)
+                        keys = cache_keys[:active_len]
+                        vals = cache_vals[:active_len]
+                        h = scored_hidden[t_idx].to(torch.bfloat16).unsqueeze(0)
+                        h_norm = F.normalize(h, dim=-1)
+                        k_norm = F.normalize(keys, dim=-1)
+                        sim = (h_norm.float() @ k_norm.float().T).squeeze(0)
+                        cache_attn = F.softmax(cache_theta * sim, dim=-1)
+                        cache_probs = torch.zeros(vocab_size, device=device)
+                        cache_probs.scatter_add_(0, vals, cache_attn)
+                        log_cache = torch.log(cache_probs + 1e-10)
+                        log_final_t = torch.logaddexp(
+                            math.log(1 - cache_lambda) + log_probs_model[t_idx],
+                            math.log(cache_lambda) + log_cache,
+                        )
+                    else:
+                        log_final_t = log_probs_model[t_idx]
+
+                    loss_sum += -log_final_t[target_id].to(torch.float64)
+                    tok_count += 1
+
+                    # Byte counting
+                    prev_id = scored_input[t_idx]
+                    tb = base_bytes_lut[target_id].to(torch.float64)
+                    if has_leading_space_lut[target_id] and not is_boundary_token_lut[prev_id]:
+                        tb += 1.0
+                    byte_count += tb
+
+                    # Update cache AFTER scoring (causal)
+                    idx = cache_ptr % cache_size
+                    cache_keys[idx] = scored_hidden[t_idx].to(torch.bfloat16)
+                    cache_vals[idx] = target_id
+                    cache_ptr += 1
+                    cache_len += 1
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(tok_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)
+
+    val_loss = (loss_sum / tok_count).item()
+    bpb = val_loss / math.log(2.0) * (tok_count.item() / byte_count.item())
+    return val_loss, bpb
+
+
+# -----------------------------
+# TEST-TIME TRAINING (LoRA)
+# -----------------------------
+
+BOS_ID = 1
+
+class BatchedLinearLoRA(nn.Module):
+    """LoRA for a linear layer, with independent weights per batch element."""
+    def __init__(self, bsz: int, in_features: int, out_features: int, rank: int):
+        super().__init__()
+        self.in_features = in_features
+        self.A = nn.Parameter(torch.empty(bsz, rank, in_features))
+        self.B = nn.Parameter(torch.zeros(bsz, out_features, rank))
+        self.reset()
+
+    def forward(self, x: Tensor) -> Tensor:
+        return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2)
+
+    def reset(self) -> None:
+        bound = 1.0 / math.sqrt(self.in_features)
+        with torch.no_grad():
+            self.A.uniform_(-bound, bound)
+            self.B.zero_()
+
+
+class BatchedTTTLoRA(nn.Module):
+    """All LoRA adapters for one batch: LM head and Q/V per block."""
+    def __init__(self, bsz: int, model: GPT, rank: int):
+        super().__init__()
+        dim = model.tok_emb.embedding_dim
+        vocab = model.tok_emb.num_embeddings
+        self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank)
+        self.q_loras = nn.ModuleList()
+        self.v_loras = nn.ModuleList()
+        for block in model.blocks:
+            self.q_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank))
+            self.v_loras.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank))
+
+    def reset(self) -> None:
+        for m in self.modules():
+            if isinstance(m, BatchedLinearLoRA):
+                m.reset()
+
+
+def _reset_ttt_optimizer(opt):
+    for group in opt.param_groups:
+        for p in group['params']:
+            s = opt.state.get(p)
+            if not s:
+                continue
+            s['exp_avg'].zero_()
+            s['exp_avg_sq'].zero_()
+            s['step'].fill_(0)
+
+
+def _build_ttt_optimizer(lora, args: Hyperparameters):
+    return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=(args.beta1, args.beta2), eps=1e-10)
+
+
+def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[int, int]]:
+    """Return (start_offset, length) for each document."""
+    bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].numpy()
+    docs = []
+    for i in range(len(bos_positions)):
+        start = int(bos_positions[i])
+        end = int(bos_positions[i + 1]) if i + 1 < len(bos_positions) else all_tokens.numel()
+        if include_next_bos and i + 1 < len(bos_positions):
+            end += 1
+        assert end - start >= 2
+        docs.append((start, end - start))
+    return docs
+
+
+def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int):
+    chunk_start = ci * chunk_size
+    chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size
+    win_start = max(0, chunk_end - eval_seq_len)
+    win_len = chunk_end - win_start
+    chunk_offset = chunk_start - win_start
+    chunk_len = chunk_end - chunk_start
+    return win_start, win_len, chunk_offset, chunk_len
+
+
+def _accumulate_bpb(
+    ptl: Tensor, x: Tensor, y: Tensor,
+    batch_i: int, chunk_offset: int, chunk_len: int,
+    base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor,
+    loss_sum: Tensor, byte_sum: Tensor, token_count: Tensor,
+):
+    lbl = ptl[batch_i, chunk_offset:chunk_offset + chunk_len].to(torch.float64)
+    prev = x[batch_i, chunk_offset:chunk_offset + chunk_len]
+    tgt = y[batch_i, chunk_offset:chunk_offset + chunk_len]
+    tok_bytes = base_bytes_lut[tgt].to(torch.float64)
+    tok_bytes += has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]
+    loss_sum += lbl.sum()
+    byte_sum += tok_bytes.sum()
+    token_count += chunk_len
+
+
+def eval_val_ttt_lora(
+    args: Hyperparameters,
+    base_model: GPT,
+    rank: int,
+    world_size: int,
+    device: torch.device,
+    base_bytes_lut: Tensor,
+    has_leading_space_lut: Tensor,
+    is_boundary_token_lut: Tensor,
+) -> tuple[float, float]:
+    """Evaluate with batched LoRA test-time training."""
+    files = sorted(glob.glob(args.val_files))
+    all_tokens = torch.cat([load_data_shard(Path(f)) for f in files])
+    docs = _find_docs(all_tokens)
+
+    rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size]
+    chunk_size = args.ttt_chunk_size
+    eval_seq_len = args.ttt_eval_seq_len
+    batch_size = args.ttt_batch_size
+    lora_rank = args.ttt_lora_rank
+
+    rank_docs.sort(key=lambda d: (d[1] - 2) // chunk_size)
+
+    base_model.eval()
+    for p in base_model.parameters():
+        p.requires_grad_(False)
+
+    lora = BatchedTTTLoRA(batch_size, base_model, lora_rank).to(device)
+    opt = _build_ttt_optimizer(lora, args)
+
+    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
+    byte_sum = torch.zeros((), device=device, dtype=torch.float64)
+    token_count = torch.zeros((), device=device, dtype=torch.float64)
+
+    for bi in range(0, len(rank_docs), batch_size):
+        batch = rank_docs[bi:bi + batch_size]
+        bsz = len(batch)
+
+        if bsz == batch_size:
+            cur_lora, cur_opt = lora, opt
+            cur_lora.reset()
+            _reset_ttt_optimizer(cur_opt)
+        else:
+            cur_lora = BatchedTTTLoRA(bsz, base_model, lora_rank).to(device)
+            cur_opt = _build_ttt_optimizer(cur_lora, args)
+
+        pred_lens = [doc_len - 1 for _, doc_len in batch]
+        num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens]
+        max_nc = max(num_chunks)
+
+        for ci in range(max_nc):
+            chunk_stats = _compute_chunk_window(ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len)
+            context_size, chunk_offset = chunk_stats[1], chunk_stats[2]
+
+            active = [ci < nc for nc in num_chunks]
+            needs_train = any(ci < nc - 1 for nc in num_chunks)
+
+            x = torch.zeros(bsz, context_size, dtype=torch.int64, device=device)
+            y = torch.zeros(bsz, context_size, dtype=torch.int64, device=device)
+            doc_info = []
+            for b in range(bsz):
+                if not active[b]:
+                    doc_info.append((0, 0))
+                    continue
+                ds, dl = batch[b]
+                ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len)
+                chunk = all_tokens[ds + ws : ds + ws + wl + 1]
+                toks = chunk.to(dtype=torch.int64, device=device)
+                x[b, :wl] = toks[:-1]
+                y[b, :wl] = toks[1:]
+                doc_info.append((co, cl))
+
+            if needs_train:
+                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+                    ptl = base_model(x, y, lora=cur_lora)
+            else:
+                with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+                    ptl = base_model(x, y, lora=cur_lora)
+
+            with torch.no_grad():
+                for b in range(bsz):
+                    if not active[b]:
+                        continue
+                    co, cl = doc_info[b]
+                    _accumulate_bpb(
+                        ptl, x, y, b, co, cl, base_bytes_lut, has_leading_space_lut,
+                        is_boundary_token_lut, loss_sum, byte_sum, token_count)
+
+            if needs_train:
+                mask = torch.tensor([float(ci < num_chunks[b] - 1) for b in range(bsz)], device=device)
+                per_doc = ptl[:, chunk_offset:chunk_offset + chunk_size].mean(dim=-1)
+                # Entropy gating: only train on chunks where the model is surprised
+                if args.ttt_entropy_gate > 0:
+                    with torch.no_grad():
+                        gate = (per_doc > args.ttt_entropy_gate).float()
+                    mask = mask * gate
+                cur_opt.zero_grad()
+                (per_doc * mask).sum().backward()
+                cur_opt.step()
+
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
+
+    val_loss = float(loss_sum.item() / token_count.item())
+    val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item())
+    return val_loss, val_bpb
+
+
+# -----------------------------
+# TRAINING
+# -----------------------------
+
+def main() -> None:
+    global zeropower_via_newtonschulz5
+
+    code = Path(__file__).read_text(encoding="utf-8")
+    args = Hyperparameters()
+    local_test = bool(int(os.environ.get("LOCAL_TEST", "0")))
+    use_compile = not local_test
+    if use_compile:
+        zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)
+
+    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
+    rank = int(os.environ.get("RANK", "0"))
+    world_size = int(os.environ.get("WORLD_SIZE", "1"))
+    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+    if world_size <= 0:
+        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
+    if 8 % world_size != 0:
+        raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral")
+    grad_accum_steps = 8 // world_size
+    grad_scale = 1.0 / grad_accum_steps
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is required")
+    device = torch.device("cuda", local_rank)
+    torch.cuda.set_device(device)
+    if distributed:
+        dist.init_process_group(backend="nccl", device_id=device)
+        dist.barrier()
+    master_process = rank == 0
+
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+    from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
+    enable_cudnn_sdp(False)
+    enable_flash_sdp(True)
+    enable_mem_efficient_sdp(local_test)
+    enable_math_sdp(local_test)
+
+    logfile = None
+    if master_process:
+        os.makedirs("logs", exist_ok=True)
+        logfile = f"logs/{args.run_id}.txt"
+        print(logfile)
+
+    def log0(msg: str, console: bool = True) -> None:
+        if not master_process:
+            return
+        if console:
+            print(msg)
+        if logfile is not None:
+            with open(logfile, "a", encoding="utf-8") as f:
+                print(msg, file=f)
+
+    log0(code, console=False)
+    log0("=" * 100, console=False)
+    log0(f"Running Python {sys.version}", console=False)
+    log0(f"Running PyTorch {torch.__version__}", console=False)
+    log0(
+        subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout,
+        console=False,
+    )
+    log0("=" * 100, console=False)
+
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    torch.cuda.manual_seed_all(args.seed)
+
+    if not args.tokenizer_path.endswith(".model"):
+        raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}")
+    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
+    if int(sp.vocab_size()) != args.vocab_size:
+        raise ValueError(
+            f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}"
+        )
+    dataset_dir = Path(args.data_path).resolve()
+    actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
+    val_tokens = load_validation_tokens(args.val_files, args.train_seq_len)
+    base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(
+        sp, args.vocab_size, device
+    )
+    log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}")
+    log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}")
+    log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}")
+
+    # MODEL + OPTIMIZER SETUP
+    base_model = GPT(
+        vocab_size=args.vocab_size,
+        num_layers=args.num_layers,
+        model_dim=args.model_dim,
+        num_heads=args.num_heads,
+        num_kv_heads=args.num_kv_heads,
+        mlp_mult=args.mlp_mult,
+        tie_embeddings=args.tie_embeddings,
+        tied_embed_init_std=args.tied_embed_init_std,
+        logit_softcap=args.logit_softcap,
+        rope_base=args.rope_base,
+        qk_gain_init=args.qk_gain_init,
+        bigram_vocab_size=args.bigram_vocab_size,
+        bigram_dim=args.bigram_dim,
+    ).to(device).bfloat16()
+    for module in base_model.modules():
+        if isinstance(module, CastedLinear):
+            module.float()
+    restore_low_dim_params_to_fp32(base_model)
+
+    # Enable QAT: int5 for MLP weights, int6 for attention weights
+    if args.qat_enabled:
+        for name, module in base_model.named_modules():
+            if isinstance(module, CastedLinear):
+                if ".mlp." in name:
+                    module._qat_clip = 15  # int5
+                elif ".attn." in name:
+                    module._qat_clip = 31  # int6
+        # Embeddings, bigram proj, skip weights: no QAT (kept in fp16/fp32)
+        log0("QAT enabled: int5 for MLP, int6 for attention")
+    compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) if use_compile else base_model
+    model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model
+
+    block_named_params = list(base_model.blocks.named_parameters())
+    matrix_params = [
+        p for name, p in block_named_params
+        if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
+    ]
+    scalar_params = [
+        p for name, p in block_named_params
+        if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
+    ]
+    if base_model.skip_weights.numel() > 0:
+        scalar_params.append(base_model.skip_weights)
+    scalar_params.append(base_model.smear.gate)
+    if base_model.bigram is not None:
+        scalar_params.append(base_model.bigram.scale)
+
+    token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
+    tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}]
+    if base_model.bigram is not None:
+        tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr})
+        if base_model.bigram.proj is not None:
+            matrix_params.append(base_model.bigram.proj.weight)
+
+    optimizer_tok = torch.optim.AdamW(
+        tok_params,
+        betas=(args.beta1, args.beta2),
+        eps=args.adam_eps,
+        weight_decay=args.weight_decay,
+        fused=True,
+    )
+    optimizer_muon = Muon(
+        matrix_params,
+        lr=args.matrix_lr,
+        momentum=args.muon_momentum,
+        backend_steps=args.muon_backend_steps,
+        weight_decay=0.04,
+    )
+    for group in optimizer_muon.param_groups:
+        group["base_lr"] = args.matrix_lr
+    optimizer_scalar = torch.optim.AdamW(
+        [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}],
+        betas=(args.beta1, args.beta2),
+        eps=args.adam_eps,
+        weight_decay=args.weight_decay,
+        fused=True,
+    )
+    optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar]
+    if base_model.lm_head is not None:
+        optimizer_head = torch.optim.Adam(
+            [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}],
+            betas=(args.beta1, args.beta2),
+            eps=args.adam_eps,
+            fused=True,
+        )
+        optimizers.insert(1, optimizer_head)
+
+    n_params = sum(p.numel() for p in base_model.parameters())
+    log0(f"model_params:{n_params}")
+    log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}")
+    log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}")
+    log0(
+        f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} "
+        f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}"
+    )
+    log0(
+        f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} "
+        f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} "
+        f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}"
+    )
+    log0(f"seed:{args.seed}")
+
+    # DATA LOADER & MODEL WARMUP
+    train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
+
+    def zero_grad_all() -> None:
+        for opt in optimizers:
+            opt.zero_grad(set_to_none=True)
+
+    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None
+
+    def lr_mul(step: int, elapsed_ms: float) -> float:
+        if args.warmdown_iters <= 0:
+            return 1.0
+        if max_wallclock_ms is None:
+            warmdown_start = max(args.iterations - args.warmdown_iters, 0)
+            return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0
+        step_ms = elapsed_ms / max(step, 1)
+        warmdown_ms = args.warmdown_iters * step_ms
+        remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
+        return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0
+
+    if args.warmup_steps > 0:
+        initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()}
+        initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers]
+        model.train()
+        for warmup_step in range(args.warmup_steps):
+            zero_grad_all()
+            for micro_step in range(grad_accum_steps):
+                if distributed:
+                    model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
+                x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
+                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                    warmup_loss = model(x, y)
+                (warmup_loss * grad_scale).backward()
+            for opt in optimizers:
+                opt.step()
+            zero_grad_all()
+            if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps:
+                log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}")
+        base_model.load_state_dict(initial_model_state, strict=True)
+        for opt, state in zip(optimizers, initial_optimizer_states, strict=True):
+            opt.load_state_dict(state)
+        zero_grad_all()
+        if distributed:
+            model.require_backward_grad_sync = True
+        train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
+
+    # MAIN TRAINING LOOP
+    training_time_ms = 0.0
+    stop_after_step: int | None = None
+    swa_state: dict[str, Tensor] | None = None
+    swa_count = 0
+    torch.cuda.synchronize()
+    t0 = time.perf_counter()
+
+    step = 0
+    while True:
+        last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)
+
+        should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
+        if should_validate:
+            torch.cuda.synchronize()
+            training_time_ms += 1000.0 * (time.perf_counter() - t0)
+            val_loss, val_bpb = eval_val(
+                args, model, rank, world_size, device, grad_accum_steps,
+                val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+            )
+            log0(
+                f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
+                f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms"
+            )
+            torch.cuda.synchronize()
+            t0 = time.perf_counter()
+
+        if last_step:
+            if stop_after_step is not None and step < args.iterations:
+                log0(
+                    f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms "
+                    f"step:{step}/{args.iterations}"
+                )
+            break
+
+        elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
+        scale = lr_mul(step, elapsed_ms)
+        zero_grad_all()
+        train_loss = torch.zeros((), device=device)
+        for micro_step in range(grad_accum_steps):
+            if distributed:
+                model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
+            x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                loss = model(x, y)
+            train_loss += loss.detach()
+            (loss * grad_scale).backward()
+        train_loss /= grad_accum_steps
+
+        frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0
+        muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum
+        for group in optimizer_muon.param_groups:
+            group["momentum"] = muon_momentum
+
+        for opt in optimizers:
+            for group in opt.param_groups:
+                group["lr"] = group["base_lr"] * scale
+
+        if args.grad_clip_norm > 0:
+            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
+        for opt in optimizers:
+            opt.step()
+        zero_grad_all()
+
+        step += 1
+        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
+
+        # SWA: collect checkpoints during warmdown
+        if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0:
+            if swa_state is None:
+                swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}
+                swa_count = 1
+                log0(f"swa:start step:{step}")
+            else:
+                for name, t in base_model.state_dict().items():
+                    swa_state[name] += t.detach().cpu()
+                swa_count += 1
+
+        should_log_train = (
+            args.train_log_every > 0
+            and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
+        )
+        if should_log_train:
+            log0(
+                f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} "
+                f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms"
+            )
+
+        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
+        if distributed and max_wallclock_ms is not None:
+            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
+            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
+            reached_cap = bool(reached_cap_tensor.item())
+        if stop_after_step is None and reached_cap:
+            stop_after_step = step
+
+    log0(
+        f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
+        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
+    )
+
+    # Apply SWA if collected
+    if args.swa_enabled and swa_state is not None and swa_count > 1:
+        log0(f"swa:applying averaged {swa_count} checkpoints")
+        current_state = base_model.state_dict()
+        avg_state = {
+            name: (tensor / swa_count).to(dtype=current_state[name].dtype)
+            for name, tensor in swa_state.items()
+        }
+        base_model.load_state_dict(avg_state, strict=True)
+
+    # SERIALIZATION + ROUNDTRIP VALIDATION
+    if master_process:
+        torch.save(base_model.state_dict(), "final_model.pt")
+        model_bytes = os.path.getsize("final_model.pt")
+        code_bytes = len(code.encode("utf-8"))
+        log0(f"Serialized model: {model_bytes} bytes")
+        log0(f"Code size: {code_bytes} bytes")
+        log0(f"Total submission size: {model_bytes + code_bytes} bytes")
+
+    # Magnitude pruning: zero out smallest weights to improve compression
+    with torch.no_grad():
+        for name, param in base_model.named_parameters():
+            if param.ndim == 2 and param.numel() > 65536:
+                threshold = torch.quantile(param.abs().float().flatten(), 0.03)
+                mask = param.abs() < threshold
+                param.masked_fill_(mask, 0.0)
+
+    # INT6 mixed quantization + zstd/zlib export
+    sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()}
+    quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"})
+    quant_buf = io.BytesIO()
+    torch.save({"w": quant_result, "m": quant_meta}, quant_buf)
+    quant_raw = quant_buf.getvalue()
+    if _COMPRESSOR == "zstd":
+        quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw)
+    else:
+        quant_blob = zlib.compress(quant_raw, 9)
+    if master_process:
+        with open("final_model.int8.ptz", "wb") as f:
+            f.write(quant_blob)
+        quant_file_bytes = os.path.getsize("final_model.int8.ptz")
+        code_bytes = len(code.encode("utf-8"))
+        log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes")
+        log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes")
+
+    if distributed:
+        dist.barrier()
+    with open("final_model.int8.ptz", "rb") as f:
+        quant_blob_disk = f.read()
+    if _COMPRESSOR == "zstd":
+        decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk)
+    else:
+        decompressed = zlib.decompress(quant_blob_disk)
+    quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu")
+    deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu)
+    base_model.load_state_dict(deq_state, strict=True)
+
+    # Eval on int6-roundtripped weights
+    val_tokens_eval = val_tokens
+    torch.cuda.synchronize()
+    t_qeval = time.perf_counter()
+
+    strategy = args.eval_strategy
+    log0(f"eval_strategy:{strategy}")
+
+    if strategy == "sliding_cache":
+        log0("Running sliding window + neural cache eval...")
+        q_val_loss, q_val_bpb = eval_val_sliding_with_cache(
+            base_model.forward_logits_and_hidden, rank, world_size, device,
+            val_tokens_eval, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+            args.train_seq_len, args.eval_stride,
+            cache_size=args.cache_size, cache_theta=args.cache_theta,
+            cache_lambda=args.cache_lambda, eval_batch_seqs=args.eval_batch_seqs,
+        )
+    elif strategy == "ttt":
+        log0("Running LoRA TTT eval...")
+        torch._dynamo.reset()
+        q_val_loss, q_val_bpb = eval_val_ttt_lora(
+            args, base_model, rank, world_size, device,
+            base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+        )
+    elif strategy == "combined":
+        log0("[combined] Phase 1: Sliding window + neural cache...")
+        cache_val_loss, cache_val_bpb = eval_val_sliding_with_cache(
+            base_model.forward_logits_and_hidden, rank, world_size, device,
+            val_tokens_eval, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+            args.train_seq_len, args.eval_stride,
+            cache_size=args.cache_size, cache_theta=args.cache_theta,
+            cache_lambda=args.cache_lambda, eval_batch_seqs=args.eval_batch_seqs,
+        )
+        log0(f"[combined] sliding+cache val_bpb:{cache_val_bpb:.4f}")
+
+        log0("[combined] Phase 2: LoRA TTT...")
+        torch._dynamo.reset()
+        ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora(
+            args, base_model, rank, world_size, device,
+            base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+        )
+        log0(f"[combined] ttt val_bpb:{ttt_val_bpb:.4f}")
+
+        q_val_bpb = min(cache_val_bpb, ttt_val_bpb)
+        q_val_loss = cache_val_loss if cache_val_bpb <= ttt_val_bpb else ttt_val_loss
+        log0(f"[combined] best individual val_bpb:{q_val_bpb:.4f}")
+    else:
+        # Default: sliding window only (original behavior)
+        if args.eval_stride > 0 and args.eval_stride < args.train_seq_len:
+            log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}")
+            q_val_loss, q_val_bpb = eval_val_sliding(
+                args, base_model, rank, world_size, device,
+                val_tokens_eval, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+                stride=args.eval_stride, batch_seqs=args.eval_batch_seqs,
+            )
+        else:
+            log0("final_eval_mode:standard")
+            q_val_loss, q_val_bpb = eval_val(
+                args, model, rank, world_size, device, grad_accum_steps,
+                val_tokens_eval, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+            )
+
+    torch.cuda.synchronize()
+    log0(
+        f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "
+        f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms"
+    )
+    log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")
+
+    if distributed:
+        dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    main()
+# fixes applied
+# tuned
diff --git a/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/train_seed1337.log b/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/train_seed1337.log
new file mode 100644
index 000000000..001332150
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/train_seed1337.log
@@ -0,0 +1,135 @@
+W0321 01:46:21.203000 128328209879680 torch/distributed/run.py:779]
+W0321 01:46:21.203000 128328209879680 torch/distributed/run.py:779] *****************************************
+W0321 01:46:21.203000 128328209879680 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+W0321 01:46:21.203000 128328209879680 torch/distributed/run.py:779] *****************************************
+logs/run1_seed1337.txt
+val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
+train_loader:dataset:fineweb10B_sp1024 train_shards:80
+val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
+QAT enabled: int5 for MLP, int6 for attention
+model_params:25517137
+world_size:8 grad_accum_steps:1
+attention_mode:gqa num_heads:8 num_kv_heads:4
+tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02
+train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
+seed:1337
+warmup_step:1/20
+warmup_step:2/20
+warmup_step:3/20
+warmup_step:4/20
+warmup_step:5/20
+warmup_step:6/20
+warmup_step:7/20
+warmup_step:8/20
+warmup_step:9/20
+warmup_step:10/20
+warmup_step:11/20
+warmup_step:12/20
+warmup_step:13/20
+warmup_step:14/20
+warmup_step:15/20
+warmup_step:16/20
+warmup_step:17/20
+warmup_step:18/20
+warmup_step:19/20
+warmup_step:20/20
+step:0/20000 val_loss:6.9284 val_bpb:4.1034 train_time:0ms step_avg:0.01ms
+step:1/20000 train_loss:6.9301 train_time:142ms step_avg:141.56ms
+step:2/20000 train_loss:7.6390 train_time:227ms step_avg:113.40ms
+step:3/20000 train_loss:7.2632 train_time:329ms step_avg:109.78ms
+step:4/20000 train_loss:8.0002 train_time:431ms step_avg:107.65ms
+step:5/20000 train_loss:8.3939 train_time:532ms step_avg:106.33ms
+step:6/20000 train_loss:8.2110 train_time:634ms step_avg:105.71ms
+step:7/20000 train_loss:7.5933 train_time:737ms step_avg:105.26ms
+step:8/20000 train_loss:6.8652 train_time:848ms step_avg:106.01ms
+step:9/20000 train_loss:6.3185 train_time:950ms step_avg:105.56ms
+step:10/20000 train_loss:6.0389 train_time:1051ms step_avg:105.15ms
+step:100/20000 train_loss:3.1871 train_time:9904ms step_avg:99.04ms
+step:200/20000 train_loss:2.4389 train_time:22606ms step_avg:113.03ms
+step:300/20000
train_loss:2.5740 train_time:35201ms step_avg:117.34ms +step:400/20000 train_loss:2.4384 train_time:47337ms step_avg:118.34ms +step:500/20000 train_loss:2.4215 train_time:57243ms step_avg:114.49ms +step:500/20000 val_loss:2.3730 val_bpb:1.4054 train_time:57275ms step_avg:114.55ms +step:600/20000 train_loss:2.3558 train_time:69828ms step_avg:116.38ms +step:700/20000 train_loss:2.3646 train_time:82162ms step_avg:117.37ms +step:800/20000 train_loss:2.2581 train_time:94412ms step_avg:118.01ms +step:900/20000 train_loss:2.1456 train_time:106569ms step_avg:118.41ms +step:1000/20000 train_loss:2.2949 train_time:116466ms step_avg:116.47ms +step:1000/20000 val_loss:2.2417 val_bpb:1.3277 train_time:116497ms step_avg:116.50ms +step:1100/20000 train_loss:2.3445 train_time:128599ms step_avg:116.91ms +step:1200/20000 train_loss:2.3735 train_time:140813ms step_avg:117.34ms +step:1300/20000 train_loss:2.1190 train_time:153410ms step_avg:118.01ms +step:1400/20000 train_loss:2.2040 train_time:165709ms step_avg:118.36ms +step:1500/20000 train_loss:2.2390 train_time:175586ms step_avg:117.06ms +step:1500/20000 val_loss:2.1996 val_bpb:1.3027 train_time:175622ms step_avg:117.08ms +step:1600/20000 train_loss:2.0922 train_time:187806ms step_avg:117.38ms +step:1700/20000 train_loss:2.1616 train_time:199863ms step_avg:117.57ms +step:1800/20000 train_loss:2.1825 train_time:212078ms step_avg:117.82ms +step:1900/20000 train_loss:2.1487 train_time:221958ms step_avg:116.82ms +step:2000/20000 train_loss:2.0841 train_time:234240ms step_avg:117.12ms +step:2000/20000 val_loss:2.1478 val_bpb:1.2721 train_time:234270ms step_avg:117.14ms +step:2100/20000 train_loss:2.0660 train_time:246613ms step_avg:117.43ms +step:2200/20000 train_loss:2.1503 train_time:258634ms step_avg:117.56ms +step:2300/20000 train_loss:2.1247 train_time:270977ms step_avg:117.82ms +step:2400/20000 train_loss:2.0777 train_time:280850ms step_avg:117.02ms +step:2500/20000 train_loss:2.1808 train_time:292995ms step_avg:117.20ms 
+step:2500/20000 val_loss:2.1119 val_bpb:1.2508 train_time:293027ms step_avg:117.21ms +step:2600/20000 train_loss:2.1129 train_time:305130ms step_avg:117.36ms +step:2700/20000 train_loss:2.1044 train_time:317404ms step_avg:117.56ms +step:2800/20000 train_loss:2.1562 train_time:329614ms step_avg:117.72ms +step:2900/20000 train_loss:2.0199 train_time:339464ms step_avg:117.06ms +step:3000/20000 train_loss:2.1550 train_time:351572ms step_avg:117.19ms +step:3000/20000 val_loss:2.0833 val_bpb:1.2339 train_time:351604ms step_avg:117.20ms +step:3100/20000 train_loss:2.0313 train_time:363827ms step_avg:117.36ms +step:3200/20000 train_loss:2.1638 train_time:375880ms step_avg:117.46ms +step:3300/20000 train_loss:2.0580 train_time:385737ms step_avg:116.89ms +step:3400/20000 train_loss:2.0075 train_time:397920ms step_avg:117.04ms +step:3500/20000 train_loss:2.1618 train_time:409939ms step_avg:117.13ms +step:3500/20000 val_loss:2.0608 val_bpb:1.2205 train_time:409971ms step_avg:117.13ms +step:3600/20000 train_loss:2.0780 train_time:422156ms step_avg:117.27ms +step:3700/20000 train_loss:2.0675 train_time:434492ms step_avg:117.43ms +step:3800/20000 train_loss:2.0481 train_time:444348ms step_avg:116.93ms +step:3900/20000 train_loss:2.0536 train_time:456623ms step_avg:117.08ms +swa:start step:3950 +step:4000/20000 train_loss:1.9516 train_time:469061ms step_avg:117.27ms +step:4000/20000 val_loss:2.0358 val_bpb:1.2057 train_time:469125ms step_avg:117.28ms +step:4100/20000 train_loss:1.9851 train_time:481241ms step_avg:117.38ms +step:4200/20000 train_loss:2.1214 train_time:493527ms step_avg:117.51ms +step:4300/20000 train_loss:2.0258 train_time:503439ms step_avg:117.08ms +step:4400/20000 train_loss:2.0011 train_time:515722ms step_avg:117.21ms +step:4500/20000 train_loss:2.0881 train_time:527915ms step_avg:117.31ms +step:4500/20000 val_loss:2.0097 val_bpb:1.1903 train_time:527979ms step_avg:117.33ms +step:4600/20000 train_loss:1.8090 train_time:540172ms step_avg:117.43ms 
+step:4700/20000 train_loss:2.2016 train_time:550086ms step_avg:117.04ms +step:4800/20000 train_loss:2.3969 train_time:562216ms step_avg:117.13ms +step:4900/20000 train_loss:2.0086 train_time:574556ms step_avg:117.26ms +step:5000/20000 train_loss:2.0654 train_time:586649ms step_avg:117.33ms +step:5000/20000 val_loss:1.9846 val_bpb:1.1754 train_time:586715ms step_avg:117.34ms +step:5100/20000 train_loss:2.0883 train_time:599026ms step_avg:117.46ms +step:5109/20000 val_loss:1.9821 val_bpb:1.1739 train_time:599967ms step_avg:117.43ms +stopping_early: wallclock_cap train_time:599967ms step:5109/20000 +peak memory allocated: 23856 MiB reserved: 24366 MiB +swa:applying averaged 24 checkpoints +Serialized model: 98437014 bytes +Code size: 73542 bytes +Total submission size: 98510556 bytes +Serialized model int6+zstd: 15691550 bytes +Total submission size int8+zlib: 15765092 bytes +/workspace/parameter-golf/records/track_10min_16mb/2026-03-20_Int5MLP_BigramHash_SWA_NeuralCache_TTT/train_gpt.py:1629: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. 
+ quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu")
+eval_strategy:combined
+[combined] Phase 1: Sliding window + neural cache...
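The QAT bug described in the README comes down to the STE fake-quantizing against a different target than the export path. As a reference point, here is a minimal sketch of a per-row, percentile-clipped int5 quantizer and an STE that reuses the same scaling in its forward pass; `quantize_int5_per_row` and `fake_quant_ste` are hypothetical names for illustration, not the repo's actual `quantize_intN_per_row` helper, and the 99.9 percentile is an assumed value:

```python
import torch

def quantize_int5_per_row(w: torch.Tensor, percentile: float = 99.9):
    """Per-row symmetric int5 quantization with percentile-based clipping.

    Each row gets its own scale derived from a high percentile of |w|,
    so a single outlier does not inflate the whole row's scale.
    Int5 levels are [-15, 15] (clip=15).
    """
    clip = 15
    # Per-row clipping threshold from the |w| percentile (not the row max).
    thresh = torch.quantile(w.abs().float(), percentile / 100.0, dim=1, keepdim=True)
    scale = thresh.clamp(min=1e-8) / clip
    q = torch.clamp(torch.round(w / scale), -clip, clip)
    return q.to(torch.int8), scale

def dequantize(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.float() * scale

def fake_quant_ste(w: torch.Tensor, percentile: float = 99.9) -> torch.Tensor:
    """STE fake-quant whose forward pass uses the SAME per-row percentile
    scaling as the export path above; gradients pass straight through."""
    q, scale = quantize_int5_per_row(w.detach(), percentile)
    w_hat = dequantize(q, scale)
    # Straight-through estimator: forward sees w_hat, backward sees identity.
    return w + (w_hat - w).detach()
```

The point is that the STE and the export call the same quantizer, so training optimizes against the exact noise the serialized model will see; a symmetric fixed-clip STE paired with a percentile-scaled export breaks that invariant.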
diff --git a/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/README.md b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/README.md
new file mode 100644
index 000000000..d7488c582
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/README.md
@@ -0,0 +1,78 @@
+# 10L Int5-MLP + Multi-Order N-gram Backoff (0.9123 BPB)
+
+**val_bpb: 0.9123** (mean of 3 seeds, post int5/int6+zstd quantization roundtrip)
+
+**Record delta vs merged SOTA (PR #549, 1.1194 BPB):** -0.2071 BPB, std=0.0003, p < 0.001
+
+## Compliance
+
+- **Score-first**: every token's BPB is finalized before that token updates any cache table
+- **Backward-looking only**: the n-gram cache uses only previously scored tokens, never future tokens
+- **No target-aware gating**: the interpolation alpha depends solely on model entropy (its own output distribution), never on ground-truth labels
+- **No future-token access**: cache tables are updated AFTER the segment is scored
+- **Self-contained**: no network calls, no external data, no training-data access during eval
+
+## Results
+
+| Seed | val_bpb | artifact_bytes |
+|------|---------|----------------|
+| 42 | 0.9128 | 15,320,000 |
+| 1337 | 0.9121 | 15,630,000 |
+| 2024 | 0.9121 | 15,330,000 |
+| **Mean** | **0.9123 +/- 0.0003** | |
+
+## Architecture
+
+- 10 layers, d=512, 8 heads, 4 KV heads (GQA)
+- MLP: 3x expansion (1536), LeakyReLU(0.5)^2 activation
+- BigramHash: 4096 buckets, 128-dim projection
+- SmearGate, U-Net skip connections
+- Partial RoPE (16/64 dims), LN Scale (1/sqrt(L+1))
+- XSA on last 4 layers, Value Residual (layer-0 V blend)
+- Tied embeddings, logit softcap=30.0
+
+## Training
+
+- Muon optimizer (matrices) + AdamW (embeddings/scalars), WD=0.04
+- EMA: decay=0.997, updated every 10 steps on GPU
+- Warmdown: 3500 steps, warmup: 5 steps
+- Wallclock cap: 600s on 8xH100 (~6020 steps)
+- val_loss_every=0 to maximize training steps
+
+## Quantization
+
+- Int5 per-row for MLP weights, Int6 per-row for attention
+- FP16 passthrough for small/control tensors
+- Magnitude pruning (3% threshold) before quantization
+- zstd-22 compression
+
+## Evaluation: Multi-Order N-gram Backoff
+
+Legal score-first hashed n-gram cache with entropy-adaptive interpolation:
+
+- Orders 2 through 7 with backoff (highest matching order wins)
+- Separate hash tables per order (4M buckets each, uint32 counts)
+- Entropy-adaptive alpha: `alpha = 0.05 + 0.55 * sigmoid(2 * (H - 4.0))`
+  - Low model entropy (confident): alpha near 0.05, trust the model
+  - High model entropy (uncertain): alpha near 0.60, trust the n-gram cache
+- Score-first: cache updated only AFTER segment scoring
+- Sliding window stride=64, eval_batch_seqs=64
+- Eval time: ~163s on 8xH100 (well within the 10-min budget)
+
+## Based on
+
+- thwu1's 10L Int5-MLP architecture (base model)
+- PR #727 (multi-order n-gram backoff concept)
+- PR #549 (LeakyReLU^2 + score-first TTT)
+- PR #287 (XSA, EMA, Partial RoPE, LN Scale)
+
+## Reproduce
+
+```bash
+SEED=42 torchrun --standalone --nproc_per_node=8 train_gpt.py
+```
+
+Disable the n-gram cache (base model only):
+```bash
+NGRAM_EVAL_ORDER=0 SEED=42 torchrun --standalone --nproc_per_node=8 train_gpt.py
+```
diff --git a/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/cached_challenge_fineweb.py b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/cached_challenge_fineweb.py
new file mode 100644
index 000000000..fa8029be4
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/cached_challenge_fineweb.py
@@ -0,0 +1,157 @@
+import argparse
+import json
+import os
+import shutil
+from pathlib import Path
+
+from huggingface_hub import hf_hub_download
+
+
+REPO_ID = os.environ.get("MATCHED_FINEWEB_REPO_ID", "willdepueoai/parameter-golf")
+REMOTE_ROOT_PREFIX = os.environ.get("MATCHED_FINEWEB_REMOTE_ROOT_PREFIX", "datasets")
+ROOT = Path(__file__).resolve().parent
+DATASETS_DIR = ROOT / "datasets"
+TOKENIZERS_DIR
= ROOT / "tokenizers" + +def dataset_dir_for_variant(name: str) -> str: + if name == "byte260": + return "fineweb10B_byte260" + if name.startswith("sp") and name[2:].isdigit(): + return f"fineweb10B_{name}" + raise ValueError(f"unsupported variant {name!r}; expected byte260 or sp") + + +def local_path_for_remote(relative_path: str) -> Path: + remote_path = Path(relative_path) + if REMOTE_ROOT_PREFIX and remote_path.parts[:1] == (REMOTE_ROOT_PREFIX,): + remote_path = remote_path.relative_to(REMOTE_ROOT_PREFIX) + if remote_path.parts[:1] == ("datasets",): + return DATASETS_DIR.joinpath(*remote_path.parts[1:]) + if remote_path.parts[:1] == ("tokenizers",): + return TOKENIZERS_DIR.joinpath(*remote_path.parts[1:]) + return ROOT / remote_path + + +def get(relative_path: str) -> None: + destination = local_path_for_remote(relative_path) + if destination.exists(): + return + if destination.is_symlink(): + destination.unlink() + + remote_path = Path(relative_path) + cached_path = Path( + hf_hub_download( + repo_id=REPO_ID, + filename=remote_path.name, + subfolder=remote_path.parent.as_posix() if remote_path.parent != Path(".") else None, + repo_type="dataset", + ) + ) + # HF cache entries may be snapshot symlinks. Resolve to the underlying blob so we + # always materialize a real file in data/, not a broken relative symlink. 
+ cached_source = cached_path.resolve(strict=True) + destination.parent.mkdir(parents=True, exist_ok=True) + try: + os.link(cached_source, destination) + except OSError: + shutil.copy2(cached_source, destination) + + +def manifest_path() -> Path: + return local_path_for_remote(f"{REMOTE_ROOT_PREFIX}/manifest.json") + + +def load_manifest(*, skip_manifest_download: bool) -> dict: + path = manifest_path() + if not path.is_file(): + if skip_manifest_download: + raise FileNotFoundError( + f"manifest.json is required for manifest-driven shard counts but is not present locally at {path}" + ) + get(f"{REMOTE_ROOT_PREFIX}/manifest.json") + return json.loads(path.read_text(encoding="utf-8")) + + +def artifact_paths_for_tokenizer(tokenizer_entry: dict) -> list[str]: + artifacts = [] + for key in ("model_path", "vocab_path", "path"): + value = tokenizer_entry.get(key) + if value: + artifacts.append(str(value)) + if not artifacts: + raise ValueError(f"tokenizer entry is missing downloadable artifacts: {tokenizer_entry}") + return artifacts + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Download challenge FineWeb shards from Hugging Face") + parser.add_argument( + "train_shards_positional", + nargs="?", + type=int, + default=None, + help=argparse.SUPPRESS, + ) + parser.add_argument( + "--train-shards", + type=int, + default=80, + help="Number of training shards to download for the selected variant. 
Defaults to 80.", + ) + parser.add_argument( + "--variant", + default="sp1024", + help="Tokenizer family to download, for example sp1024, sp4096, or byte260.", + ) + parser.add_argument( + "--skip-manifest", + action="store_true", + help="Skip downloading manifest.json.", + ) + parser.add_argument( + "--with-docs", + action="store_true", + help="Also download docs_selected.jsonl and its sidecar for tokenizer retraining or dataset re-export.", + ) + return parser + + +def main() -> None: + args = build_parser().parse_args() + dataset_dir = dataset_dir_for_variant(args.variant) + train_shards = args.train_shards_positional if args.train_shards_positional is not None else args.train_shards + if train_shards < 0: + raise ValueError("train_shards must be non-negative") + + manifest = load_manifest(skip_manifest_download=args.skip_manifest) + dataset_entry = next((x for x in manifest.get("datasets", []) if x.get("name") == dataset_dir), None) + if dataset_entry is None: + raise ValueError(f"dataset {dataset_dir} not found in {REMOTE_ROOT_PREFIX}/manifest.json") + max_train_shards = int((dataset_entry.get("stats") or {}).get("files_train")) + val_shards = int((dataset_entry.get("stats") or {}).get("files_val")) + if train_shards > max_train_shards: + raise ValueError( + f"{args.variant} only has {max_train_shards} training shards on {REPO_ID}, requested {train_shards}" + ) + tokenizer_name = dataset_entry.get("tokenizer_name") + tokenizer_entry = next((x for x in manifest.get("tokenizers", []) if x.get("name") == tokenizer_name), None) + if tokenizer_entry is None: + raise ValueError(f"tokenizer {tokenizer_name} not found in {REMOTE_ROOT_PREFIX}/manifest.json") + + if args.with_docs: + get(f"{REMOTE_ROOT_PREFIX}/docs_selected.jsonl") + get(f"{REMOTE_ROOT_PREFIX}/docs_selected.source_manifest.json") + + dataset_prefix = f"{REMOTE_ROOT_PREFIX}/datasets/{dataset_dir}" + for i in range(val_shards): + get(f"{dataset_prefix}/fineweb_val_{i:06d}.bin") + for i in 
range(train_shards): + get(f"{dataset_prefix}/fineweb_train_{i:06d}.bin") + + for artifact_path in artifact_paths_for_tokenizer(tokenizer_entry): + get(f"{REMOTE_ROOT_PREFIX}/{artifact_path}") + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/runpod_launch.sh b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/runpod_launch.sh new file mode 100644 index 000000000..199d95508 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/runpod_launch.sh @@ -0,0 +1,60 @@ +#!/bin/bash +set -e +echo "=== Parameter Golf V6 RunPod Setup ===" +pip install sentencepiece zstandard huggingface_hub 2>/dev/null + +# Data setup +if [ ! -d "./data/datasets/fineweb10B_sp1024" ]; then + if [ -d "./datasets/fineweb10B_sp1024" ]; then + mkdir -p data + ln -sf "$(pwd)/datasets" data/datasets + ln -sf "$(pwd)/tokenizers" data/tokenizers + else + python3 cached_challenge_fineweb.py --variant sp1024 + mkdir -p data + ln -sf "$(pwd)/datasets" data/datasets + ln -sf "$(pwd)/tokenizers" data/tokenizers + fi +fi +echo "Data ready: $(ls data/datasets/fineweb10B_sp1024/ | wc -l) files" + +MODE=${1:-default} +SEED=${SEED:-42} +echo "=== Mode: $MODE | Seed: $SEED ===" + +case $MODE in + smoke) + # 60-second smoke test — catches crashes before burning a full run ($0.40 vs $8) + echo "SMOKE TEST: 60s training + quick eval — catching crashes early" + MAX_WALLCLOCK_SECONDS=60 VAL_LOSS_EVERY=0 NGRAM_EVAL_ORDER=0 \ + SEED=$SEED torchrun --standalone --nproc_per_node=8 train_gpt.py + echo "SMOKE TEST PASSED — safe to run full" + ;; + default) + echo "V6: 10L d=512 4KV LeakyReLU^2 XSA4 PartialRoPE VR EMA + 7-gram backoff + entropy-adaptive" + SEED=$SEED torchrun --standalone --nproc_per_node=8 train_gpt.py + ;; + fast) + # Smoke test then full run back-to-back + echo "=== SMOKE TEST (60s) ===" + MAX_WALLCLOCK_SECONDS=60 VAL_LOSS_EVERY=0 NGRAM_EVAL_ORDER=0 \ + SEED=$SEED torchrun --standalone 
--nproc_per_node=8 train_gpt.py + echo "=== SMOKE PASSED — LAUNCHING FULL RUN ===" + SEED=$SEED torchrun --standalone --nproc_per_node=8 train_gpt.py + ;; + no_ngram) + echo "Ablation: no n-gram cache" + NGRAM_EVAL_ORDER=0 SEED=$SEED torchrun --standalone --nproc_per_node=8 train_gpt.py + ;; + three_seed) + for S in 42 1337 2024; do + echo "=== Seed $S ===" + SEED=$S torchrun --standalone --nproc_per_node=8 train_gpt.py + done + ;; + *) + echo "Modes: smoke|default|fast|no_ngram|three_seed" + exit 1 + ;; +esac +echo "=== Done ===" diff --git a/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/submission.json b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/submission.json new file mode 100644 index 000000000..e9da12b8d --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/submission.json @@ -0,0 +1,10 @@ +{ + "author": "Bortlesboat", + "github_id": "Bortlesboat", + "name": "10L Int5-MLP + BigramHash(4096) + Multi-Order N-gram Backoff + Entropy-Adaptive Alpha", + "blurb": "10 layers, d=512, GQA 8H/4KV. LeakyReLU(0.5)^2, Partial RoPE(16/64), LN Scale, XSA last 4, Value Residual. EMA(0.997). Mixed int5/int6 + zstd-22. Eval: multi-order hashed n-gram backoff (orders 2-7) with entropy-adaptive alpha. Mean of 3 seeds.", + "date": "2026-03-25", + "val_loss": 1.5404, + "val_bpb": 0.9123, + "bytes_total": 15320000 +} diff --git a/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/train_gpt.py b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/train_gpt.py new file mode 100644 index 000000000..7721b5a33 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/train_gpt.py @@ -0,0 +1,1541 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. 
We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder.
+
+Hard stop: to keep these scripts readable for newcomers, `train_gpt.py` and `train_gpt_mlx.py` must never exceed 1500 lines.
+"""
+
+from __future__ import annotations
+
+import copy
+import glob
+import io
+import math
+import os
+import random
+import subprocess
+import sys
+import time
+import uuid
+import zlib
+from pathlib import Path
+
+try:
+    import zstandard
+    _COMPRESSOR = "zstd"
+except ImportError:
+    _COMPRESSOR = "zlib"
+
+import numpy as np
+import sentencepiece as spm
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import Tensor, nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+# -----------------------------
+# HYPERPARAMETERS
+# -----------------------------
+
+class Hyperparameters:
+    data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024")
+    train_files = os.path.join(data_path, "fineweb_train_*.bin")
+    val_files = os.path.join(data_path, "fineweb_val_*.bin")
+    tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model")
+    run_id = os.environ.get("RUN_ID", str(uuid.uuid4()))
+    seed = int(os.environ.get("SEED", 42))
+
+    val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288))
+    val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 0))  # 0 = skip mid-train val, maximize training steps
+    train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100))
+
+    iterations = int(os.environ.get("ITERATIONS", 20000))
+    warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500))
+    warmup_steps = int(os.environ.get("WARMUP_STEPS", 5))  # minimal warmup, maximize real steps
+    train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432))
+    train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048))
+    max_wallclock_seconds =
float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0))
+    qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5))
+
+    vocab_size = int(os.environ.get("VOCAB_SIZE", 1024))
+    num_layers = int(os.environ.get("NUM_LAYERS", 10))
+    num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4))
+    model_dim = int(os.environ.get("MODEL_DIM", 512))
+    num_heads = int(os.environ.get("NUM_HEADS", 8))
+    mlp_mult = float(os.environ.get("MLP_MULT", 3.0))
+    tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1")))
+    rope_base = float(os.environ.get("ROPE_BASE", 10000.0))
+    logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0))
+
+    embed_lr = float(os.environ.get("EMBED_LR", 0.6))
+    head_lr = float(os.environ.get("HEAD_LR", 0.008))
+    tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03))
+    tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005))
+    matrix_lr = float(os.environ.get("MATRIX_LR", 0.02))
+    scalar_lr = float(os.environ.get("SCALAR_LR", 0.02))
+    muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99))
+    muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))
+    muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92))
+    muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500))
+    beta1 = float(os.environ.get("BETA1", 0.9))
+    beta2 = float(os.environ.get("BETA2", 0.95))
+    adam_eps = float(os.environ.get("ADAM_EPS", 1e-8))
+    grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3))
+    weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04))
+
+    eval_stride = int(os.environ.get("EVAL_STRIDE", 64))
+    eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 64))  # larger batch for faster eval (no gradients)
+
+    bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096))
+    bigram_dim = int(os.environ.get("BIGRAM_DIM", 128))
+
+    # Partial RoPE: only rotate first rope_dims dims (0 = full head_dim)
+    rope_dims = int(os.environ.get("ROPE_DIMS", 16))
+
+    # LN Scale: dampen norm inputs by 1/sqrt(layer_idx+1) for deeper layers
+    ln_scale = bool(int(os.environ.get("LN_SCALE", "1")))
+
+    # XSA: exclusive self-attention on last N layers (0 = disabled)
+    xsa_last_n = int(os.environ.get("XSA_LAST_N", 4))  # proven: last 4 layers
+
+    # EMA: exponential moving average (replaces SWA when enabled)
+    ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1")))
+    ema_decay = float(os.environ.get("EMA_DECAY", 0.997))
+
+    swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "0")))  # OFF by default, EMA replaces it
+    swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4))
+    swa_every = int(os.environ.get("SWA_EVERY", 50))
+
+    # N-gram eval cache: multi-order backoff + entropy-adaptive alpha (score-first, legal)
+    ngram_eval_max_order = int(os.environ.get("NGRAM_EVAL_ORDER", 7))  # max n-gram order
+    ngram_eval_min_order = int(os.environ.get("NGRAM_EVAL_MIN_ORDER", 2))  # min backoff order
+    ngram_eval_alpha = float(os.environ.get("NGRAM_EVAL_ALPHA", 0.40))  # base alpha
+    ngram_eval_min_count = int(os.environ.get("NGRAM_EVAL_MIN_COUNT", 2))
+    ngram_eval_buckets = int(os.environ.get("NGRAM_EVAL_BUCKETS", 4_194_304))
+    ngram_eval_entropy = bool(int(os.environ.get("NGRAM_EVAL_ENTROPY", "1")))
+    ngram_eval_ent_base = float(os.environ.get("NGRAM_EVAL_ENT_BASE", 0.05))
+    ngram_eval_ent_range = float(os.environ.get("NGRAM_EVAL_ENT_RANGE", 0.55))
+    ngram_eval_ent_scale = float(os.environ.get("NGRAM_EVAL_ENT_SCALE", 2.0))
+    ngram_eval_ent_thresh = float(os.environ.get("NGRAM_EVAL_ENT_THRESH", 4.0))
+
+# -----------------------------
+# MUON OPTIMIZER
+# -----------------------------
+
+def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    X /= X.norm() + eps
+    transposed = G.size(0) > G.size(1)
+    if transposed:
+        X = X.T
+    for _ in range(steps):
+        A = X @ X.T
+        B = b * A + c * A @ A
+        X = a * X + B @ X
+    return X.T if transposed else X
+
+
+class Muon(torch.optim.Optimizer):
+    def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0):
+        super().__init__(
+            params,
+            dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay),
+        )
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+        distributed = dist.is_available() and dist.is_initialized()
+        world_size = dist.get_world_size() if distributed else 1
+        rank = dist.get_rank() if distributed else 0
+
+        for group in self.param_groups:
+            params = group["params"]
+            if not params:
+                continue
+            lr = group["lr"]
+            momentum = group["momentum"]
+            backend_steps = group["backend_steps"]
+            nesterov = group["nesterov"]
+
+            total_params = sum(int(p.numel()) for p in params)
+            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
+
+            curr = 0
+            for i, p in enumerate(params):
+                if i % world_size == rank and p.grad is not None:
+                    g = p.grad
+                    state = self.state[p]
+                    if "momentum_buffer" not in state:
+                        state["momentum_buffer"] = torch.zeros_like(g)
+                    buf = state["momentum_buffer"]
+                    buf.mul_(momentum).add_(g)
+                    if nesterov:
+                        g = g.add(buf, alpha=momentum)
+                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+
+            if distributed:
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+
+            wd = group.get("weight_decay", 0.0)
+            curr = 0
+            for p in params:
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                if wd > 0:
+                    p.data.mul_(1.0 - lr * wd)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+        return loss
+
+
+# -----------------------------
+# TOKENIZER-AGNOSTIC EVALUATION
+# -----------------------------
+
+def build_sentencepiece_luts(
+    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
+) -> tuple[Tensor, Tensor, Tensor]:
+    sp_vocab_size = int(sp.vocab_size())
+    table_size = max(sp_vocab_size, vocab_size)
+    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
+    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
+    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
+    for token_id in range(sp_vocab_size):
+        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
+            continue
+        is_boundary_token_np[token_id] = False
+        if sp.is_byte(token_id):
+            base_bytes_np[token_id] = 1
+            continue
+        piece = sp.id_to_piece(token_id)
+        if piece.startswith("\u2581"):
+            has_leading_space_np[token_id] = True
+            piece = piece[1:]
+        base_bytes_np[token_id] = len(piece.encode("utf-8"))
+    return (
+        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
+        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
+        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
+    )
+
+
+def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
+    files = [Path(p) for p in sorted(glob.glob(pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {pattern}")
+    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
+    usable = ((tokens.numel() - 1) // seq_len) * seq_len
+    if usable <= 0:
+        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
+    return tokens[: usable + 1]
+
+
+def eval_val(
+    args: Hyperparameters,
+    model: nn.Module,
+    rank: int,
+    world_size: int,
+    device: torch.device,
+    grad_accum_steps: int,
+    val_tokens: Tensor,
+    base_bytes_lut: Tensor,
+    has_leading_space_lut: Tensor,
+    is_boundary_token_lut: Tensor,
+) -> tuple[float, float]:
+    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
+    if local_batch_tokens < args.train_seq_len:
+        raise ValueError(
+            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
+            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
+            f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}"
+        )
+    local_batch_seqs = local_batch_tokens // args.train_seq_len
+    total_seqs = (val_tokens.numel() - 1) // args.train_seq_len
+    seq_start = (total_seqs * rank) // world_size
+    seq_end = (total_seqs * (rank + 1)) // world_size
+    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
+    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
+    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)
+    model.eval()
+    with torch.inference_mode():
+        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
+            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
+            raw_start = batch_seq_start * args.train_seq_len
+            raw_end = batch_seq_end * args.train_seq_len + 1
+            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
+            x = local[:-1].reshape(-1, args.train_seq_len)
+            y = local[1:].reshape(-1, args.train_seq_len)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+
+
+# -----------------------------
+# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed)
+# -----------------------------
+
+CONTROL_TENSOR_NAME_PATTERNS = tuple(
+    pattern
+    for pattern in os.environ.get(
+        "CONTROL_TENSOR_NAME_PATTERNS",
+        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale",
+    ).split(",")
+    if pattern
+)
+FP16_KEEP_NAME_PATTERNS = tuple(
+    pattern
+    for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",")
+    if pattern
+)
+INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
+    pattern
+    for pattern in os.environ.get(
+        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
+        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
+    ).split(",")
+    if pattern
+)
+INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
+INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
+INT8_PER_ROW_SCALE_DTYPE = torch.float16
+INT8_CLIP_PERCENTILE = 99.99984
+INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0
+
+def tensor_nbytes(t: Tensor) -> int:
+    return int(t.numel()) * int(t.element_size())
+
+def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
+    t32 = t.float()
+    if t32.ndim == 2:
+        clip_abs = (
+            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
+            if t32.numel()
+            else torch.empty((t32.shape[0],), dtype=torch.float32)
+        )
+        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
+        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
+        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
+        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
+    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
+    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
+    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
+    return q, scale
+
+
+def _classify_param(name: str) -> str:
+    if "tok_emb" in name or "lm_head" in name:
+        return "embed"
+    if ".mlp." in name:
+        return "mlp"
+    if "bigram" in name:
+        return "bigram"
+    if ".attn." in name or (".proj." in name and ".mlp." not in name):
+        return "attn"
+    return "other"
+
+def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]:
+    t32 = t.float()
+    if t32.ndim == 2:
+        row_max = t32.abs().amax(dim=1)
+        scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16)
+        scale = scale.clamp_min(torch.finfo(torch.float16).tiny)
+        q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range + 1), clip_range).to(torch.int8)
+        return q, scale
+    amax = t32.abs().max().item()
+    scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16)
+    q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range + 1), clip_range).to(torch.int8)
+    return q, scale
+
+def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]):
+    result: dict[str, Tensor] = {}
+    meta: dict[str, object] = {}
+    for name, tensor in state_dict.items():
+        t = tensor.detach().cpu().contiguous()
+        cat = _classify_param(name)
+        if not t.is_floating_point() or t.numel() <= 8192:
+            result[name] = t.to(torch.float16) if t.is_floating_point() else t
+            meta[name] = "passthrough"
+            continue
+        if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
+            result[name] = t.float()
+            meta[name] = "passthrough_ctrl"
+            continue
+        if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS):
+            result[name] = t.to(dtype=torch.float16).contiguous()
+            meta[name] = "passthrough_fp16"
+            continue
+        if cat in int6_cats and t.ndim >= 1:
+            clip = 15 if cat == "mlp" else 31  # int5 for MLP, int6 for attention
+            q, s = quantize_intN_per_row(t, clip_range=clip)
+            result[name + ".q"] = q
+            result[name + ".scale"] = s
+            meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"}
+        else:
+            q, s = quantize_float_tensor(t)
+            result[name + ".q"] = q
+            result[name + ".scale"] = s
+            meta[name] = {"type": "int8"}
+    return result, meta
+
dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + 
x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) 
// 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, + rope_dims: int = 0, use_xsa: bool = False): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + self.rope_dims = rope_dims if rope_dims > 0 else self.head_dim + self.use_xsa = use_xsa + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.rope_dims, base=rope_base) + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Exclusive self-attention: subtract self-value from attention output.""" + # y is post-attention [bsz, heads, seq, head_dim], v is [bsz, kv_heads, seq, head_dim] + if self.num_kv_heads != self.num_heads: + v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + return y - v / v.size(2) + + def forward(self, x: Tensor, v0: Tensor | None = None) -> tuple[Tensor, Tensor]: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + # Value Residual: blend with layer-0 
V + if v0 is not None: + v = 0.5 * (v + v0) + v_out = v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.rope_dims < self.head_dim: + # Partial RoPE: rotate only first rope_dims, pass rest through + q_rope, q_pass = q[..., :self.rope_dims], q[..., self.rope_dims:] + k_rope, k_pass = k[..., :self.rope_dims], k[..., self.rope_dims:] + q_rope = apply_rotary_emb(q_rope, cos, sin) + k_rope = apply_rotary_emb(k_rope, cos, sin) + q = torch.cat([q_rope, q_pass], dim=-1) + k = torch.cat([k_rope, k_pass], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + if self.num_kv_heads != self.num_heads: + n_rep = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(n_rep, dim=1) + v_sdpa = v.repeat_interleave(n_rep, dim=1) + else: + v_sdpa = v + y = F.scaled_dot_product_attention( + q, k, v_sdpa, attn_mask=None, is_causal=True, + ) + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y), v_out + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) + + +class SmearGate(nn.Module): + """Blend each token's embedding with the previous token's embedding.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class 
BigramHashEmbedding(nn.Module): + """Hash consecutive token pairs into a learned embedding table.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, + qk_gain_init: float, rope_dims: int = 0, use_xsa: bool = False, ln_scale_factor: float = 1.0): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_dims=rope_dims, use_xsa=use_xsa) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = ln_scale_factor + + def forward(self, x: Tensor, x0: Tensor, v0: Tensor | None = None) -> tuple[Tensor, Tensor]: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_scale_factor + 
attn_out, v_out = self.attn(self.attn_norm(x) * s, v0=v0) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s) + return x, v_out + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + rope_dims: int = 0, + ln_scale: bool = False, + xsa_last_n: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.smear = SmearGate(model_dim) + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + rope_dims=rope_dims, + use_xsa=(i >= num_layers - xsa_last_n) if xsa_last_n > 0 else False, + ln_scale_factor=1.0 / math.sqrt(i + 1) if ln_scale else 1.0) + for i in range(num_layers) + ]) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + 
nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def _forward_body(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0: Tensor | None = None + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x, v_out = self.blocks[i](x, x0, v0=v0) + if v0 is None: + v0 = v_out + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x, _ = self.blocks[self.num_encoder_layers + i](x, x0, v0=v0) + return x + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self._forward_body(input_ids) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self._forward_body(input_ids) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / 
self.logit_softcap) + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + eval_start = time.perf_counter() + eval_budget_s = 570.0 # 30s margin from 10-min eval budget + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + eval_elapsed = time.perf_counter() - eval_start + if eval_elapsed > eval_budget_s: + if rank == 0: + print(f" FAILSAFE: eval time {eval_elapsed:.0f}s exceeds {eval_budget_s}s budget, returning partial results", flush=True) + break + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, 
logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +def eval_val_sliding_ngram( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + """Sliding eval with multi-order n-gram backoff + entropy-adaptive alpha (score-first, legal).""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + max_order = args.ngram_eval_max_order + min_order = args.ngram_eval_min_order + buckets = args.ngram_eval_buckets + min_count = 
args.ngram_eval_min_count + use_entropy = args.ngram_eval_entropy + ent_base = args.ngram_eval_ent_base + ent_range = args.ngram_eval_ent_range + ent_scale = args.ngram_eval_ent_scale + ent_thresh = args.ngram_eval_ent_thresh + base_alpha = args.ngram_eval_alpha + n_orders = max_order - min_order + 1 + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + val_np = val_tokens.numpy() + ctx_tables = [np.zeros((buckets,), dtype=np.uint32) for _ in range(n_orders)] + full_tables = [np.zeros((buckets,), dtype=np.uint32) for _ in range(n_orders)] + mask = np.uint64(buckets - 1) + primes = np.array( + [np.uint64(36313), np.uint64(27191), np.uint64(51647), np.uint64(81929), + np.uint64(131071), np.uint64(175447), np.uint64(209591)], + dtype=np.uint64, + ) + + if rank == 0: + print(f"ngram_cache:enabled orders={min_order}-{max_order} backoff " + f"entropy={use_entropy} alpha={base_alpha} " + f"ent_base={ent_base} ent_range={ent_range} " + f"min_count={min_count} buckets={buckets}", flush=True) + + loss_sum = 0.0 + token_count = 0.0 + byte_count = 0.0 + + eval_start = time.perf_counter() + eval_budget_s = 570.0 + # Pre-allocate eval buffers (avoid per-batch allocation) + x_buf = torch.zeros(batch_seqs, seq_len, dtype=torch.int64, device=device) + y_buf = torch.zeros(batch_seqs, seq_len, dtype=torch.int64, device=device) + base_model.eval() + # Compile eval path for faster inference + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + eval_elapsed = time.perf_counter() - eval_start + if eval_elapsed > eval_budget_s: + if rank == 0: + print(f" FAILSAFE: ngram eval time {eval_elapsed:.0f}s exceeds budget", 
flush=True) + break + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = x_buf[:bsz] + y_batch = y_buf[:bsz] + x_batch.zero_() + y_batch.zero_() + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + seg_len = wlen - s + if seg_len <= 0: + continue + + seg_nll = nll[i, s:wlen].to(torch.float64).cpu().numpy() + seg_model_p = np.exp(-seg_nll) + n_seg = len(seg_nll) + global_j = np.arange(ws + s + 1, ws + wlen + 1, dtype=np.int64) + + # Entropy-adaptive alpha + if use_entropy: + with torch.no_grad(): + lp = F.log_softmax(logits[i, s:wlen].float(), dim=-1) + seg_ent = -(lp.exp() * lp).sum(dim=-1).cpu().numpy() + alpha_per_tok = ent_base + ent_range / ( + 1.0 + np.exp(-ent_scale * (seg_ent - ent_thresh))) + + # Precompute hashes for all orders + order_data = [] + for oi in range(n_orders): + ctx_w = min_order + oi - 1 + valid = global_j >= ctx_w + if not valid.any(): + order_data.append(None) + continue + v_idx = np.nonzero(valid)[0] + jv = global_j[v_idx] + ctx_hash = np.zeros(len(jv), dtype=np.uint64) + for k in range(ctx_w): + tok = val_np[jv - (ctx_w - k)].astype(np.uint64) + ctx_hash ^= tok * primes[k % len(primes)] + ctx_key = (ctx_hash & mask).astype(np.int64) + tgt_np = val_np[jv].astype(np.uint64) + full_key = ((ctx_hash ^ (tgt_np * primes[ctx_w % len(primes)])) & mask).astype(np.int64) + order_data.append((v_idx, ctx_key, full_key)) + + # Multi-order backoff: highest order first + 
best_p_ng = np.full(n_seg, -1.0) + for oi in range(n_orders - 1, -1, -1): + if order_data[oi] is None: + continue + v_idx, ctx_key, full_key = order_data[oi] + ctx_counts = ctx_tables[oi][ctx_key].astype(np.float64) + full_counts = full_tables[oi][full_key].astype(np.float64) + has_match = (ctx_counts >= float(min_count)) & (full_counts > 0) + needs_fill = has_match & (best_p_ng[v_idx] < 0) + if needs_fill.any(): + fill_idx = v_idx[needs_fill] + p = np.minimum(full_counts[needs_fill], ctx_counts[needs_fill]) / np.maximum(ctx_counts[needs_fill], 1.0) + best_p_ng[fill_idx] = np.clip(p, 0.0, 1.0) + + # Mix model probability with n-gram + has_match = best_p_ng >= 0 + if has_match.any(): + if use_entropy: + alpha = alpha_per_tok[has_match] + else: + alpha = base_alpha + seg_model_p[has_match] = (1.0 - alpha) * seg_model_p[has_match] + alpha * best_p_ng[has_match] + seg_nll[has_match] = -np.log(np.clip(seg_model_p[has_match], 1e-12, 1.0)) # update only mixed tokens; keeps the exact model NLL elsewhere instead of an exp/log roundtrip + + # Score-first: update ALL order tables AFTER scoring + for oi in range(n_orders): + if order_data[oi] is None: + continue + v_idx, ctx_key, full_key = order_data[oi] + np.add.at(ctx_tables[oi], ctx_key, 1) + np.add.at(full_tables[oi], full_key, 1) + + loss_sum += float(seg_nll.sum()) + token_count += float(seg_len) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += float(tb.sum().item()) + + if rank == 0 and (bi // batch_seqs) % 200 == 0 and bi > 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + cur_bpb = (loss_sum / max(token_count, 1.0)) / math.log(2.0) * (token_count / max(byte_count, 1.0)) + elapsed = time.perf_counter() - eval_start + print(f" ngram_eval [{pct:5.1f}%] bpb={cur_bpb:.6f} t={elapsed:.0f}s", flush=True) + + _loss = torch.tensor(loss_sum, device=device, dtype=torch.float64) + _toks = torch.tensor(token_count, device=device, dtype=torch.float64) + 
_bytes = torch.tensor(byte_count, device=device, dtype=torch.float64) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(_toks, op=dist.ReduceOp.SUM) + dist.all_reduce(_bytes, op=dist.ReduceOp.SUM) + + val_loss = _loss.item() / max(_toks.item(), 1.0) + val_bpb = val_loss / math.log(2.0) * (_toks.item() / max(_bytes.item(), 1.0)) + # Coverage check: warn if eval was cut short + total_expected = len(window_starts) * stride # each window scores ~stride tokens (the first window scores the full sequence) + coverage = _toks.item() / max(float(total_expected), 1.0) # approximate fraction of planned tokens actually scored + elapsed = time.perf_counter() - eval_start + if rank == 0: + print(f" ngram_eval DONE: bpb={val_bpb:.6f} tokens={_toks.item():.0f} t={elapsed:.0f}s", flush=True) + if elapsed >= eval_budget_s - 10 or coverage < 0.98: + print(f" WARNING: eval used {elapsed:.0f}s of {eval_budget_s}s budget (coverage~{coverage:.2f}); results may reflect partial coverage", flush=True) + base_model.train() + return val_loss, val_bpb + + + # ----------------------------- + # TRAINING + # ----------------------------- + + def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + 
torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + try: + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + except ImportError: + pass + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, 
args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + xsa_last_n=args.xsa_last_n, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + torch._dynamo.config.optimize_ddp = False + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": 
token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.04, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + train_loader = 
DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = 
DistributedTokenLoader(args.train_files, rank, world_size, device) + + # EMA shadow model (kept on GPU to avoid PCIe bottleneck) + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = {name: t.detach().clone() for name, t in base_model.state_dict().items()} + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + 
train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + + # EMA update every 10 steps (GPU-resident, amortize overhead) + if ema_state is not None and step % 10 == 0: + decay = args.ema_decay ** 10 # compensate for batched updates + with torch.no_grad(): + for name, param in base_model.state_dict().items(): + ema_state[name].lerp_(param.detach(), 1.0 - decay) + + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().float().cpu() for name, t in base_model.state_dict().items()} # accumulate in fp32: summing two dozen bf16 checkpoints in bf16 loses mantissa bits + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().float().cpu() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, 
op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # Apply EMA if enabled (overrides SWA) + if args.ema_enabled and ema_state is not None: + log0("ema:applying shadow model") + current_state = base_model.state_dict() + ema_applied = { + name: tensor.to(dtype=current_state[name].dtype, device=current_state[name].device) + for name, tensor in ema_state.items() + } + base_model.load_state_dict(ema_applied, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) + quant_buf = io.BytesIO() + 
torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) + size_ok = torch.ones(1, device=device) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + total_bytes = quant_file_bytes + code_bytes + log0(f"Total submission size: {total_bytes} bytes ({total_bytes/1e6:.2f} MB)") + if total_bytes > 16_000_000: + log0(f"FAILSAFE: artifact {total_bytes} bytes EXCEEDS 16MB limit! Aborting eval.") + size_ok.zero_() + else: + log0(f"SIZE CHECK PASSED: {total_bytes/1e6:.2f} MB < 16.00 MB") + + if distributed: + dist.broadcast(size_ok, src=0) # share the verdict: a master-only sys.exit would leave the other ranks hanging at the barrier below + if size_ok.item() == 0: + sys.exit(1) + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + base_model.load_state_dict(deq_state, strict=True) + + # Sliding window eval on int6-roundtripped weights + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.ngram_eval_max_order >= 2 and args.eval_stride > 0: + log0(f"final_eval_mode:sliding_ngram orders={args.ngram_eval_min_order}-{args.ngram_eval_max_order} " + f"alpha={args.ngram_eval_alpha} entropy={args.ngram_eval_entropy} stride:{args.eval_stride}") + q_val_loss, q_val_bpb = eval_val_sliding_ngram( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + batch_seqs=args.eval_batch_seqs, + ) + elif args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + 
log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() +# fixes applied +# tuned diff --git a/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/v6_seed1337_2024.log b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/v6_seed1337_2024.log new file mode 100644 index 000000000..69716b925 --- /dev/null +++ b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/v6_seed1337_2024.log @@ -0,0 +1,260 @@ +=== Seed 1337 === +W0326 02:49:04.490000 131659747332736 torch/distributed/run.py:779] +W0326 02:49:04.490000 131659747332736 torch/distributed/run.py:779] ***************************************** +W0326 02:49:04.490000 131659747332736 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0326 02:49:04.490000 131659747332736 torch/distributed/run.py:779] ***************************************** +logs/2236ee0d-b3f5-4169-a4bc-93283c9719e1.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:24730705 +world_size:8 grad_accum_steps:1 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:5 max_wallclock_seconds:600.000 +seed:1337 +warmup_step:1/5 +warmup_step:2/5 +warmup_step:3/5 +warmup_step:4/5 +warmup_step:5/5 +step:1/20000 train_loss:6.9296 train_time:202ms step_avg:201.96ms +step:2/20000 train_loss:7.8829 train_time:289ms step_avg:144.61ms +step:3/20000 train_loss:7.1437 train_time:388ms step_avg:129.20ms +step:4/20000 train_loss:7.6901 train_time:486ms step_avg:121.41ms +step:5/20000 train_loss:8.1498 train_time:583ms step_avg:116.64ms +step:6/20000 train_loss:8.0248 train_time:681ms step_avg:113.53ms +step:7/20000 train_loss:7.6134 train_time:780ms step_avg:111.37ms +step:8/20000 train_loss:7.1269 train_time:878ms step_avg:109.74ms +step:9/20000 train_loss:6.7005 train_time:976ms step_avg:108.44ms +step:10/20000 train_loss:6.4465 train_time:1086ms step_avg:108.56ms +step:100/20000 train_loss:3.1764 train_time:9893ms step_avg:98.93ms +step:200/20000 train_loss:2.3873 train_time:19778ms step_avg:98.89ms +step:300/20000 train_loss:2.5480 train_time:29664ms step_avg:98.88ms +step:400/20000 train_loss:2.4215 train_time:39584ms step_avg:98.96ms +step:500/20000 train_loss:2.3991 train_time:49465ms step_avg:98.93ms +step:600/20000 train_loss:2.3343 train_time:59437ms step_avg:99.06ms +step:700/20000 train_loss:2.3452 train_time:69418ms step_avg:99.17ms +step:800/20000 train_loss:2.2350 train_time:79410ms 
step_avg:99.26ms +step:900/20000 train_loss:2.1321 train_time:89403ms step_avg:99.34ms +step:1000/20000 train_loss:2.2784 train_time:99337ms step_avg:99.34ms +step:1100/20000 train_loss:2.3178 train_time:109329ms step_avg:99.39ms +step:1200/20000 train_loss:2.3546 train_time:119325ms step_avg:99.44ms +step:1300/20000 train_loss:2.1032 train_time:129307ms step_avg:99.47ms +step:1400/20000 train_loss:2.1856 train_time:139303ms step_avg:99.50ms +step:1500/20000 train_loss:2.2236 train_time:149226ms step_avg:99.48ms +step:1600/20000 train_loss:2.0802 train_time:159204ms step_avg:99.50ms +step:1700/20000 train_loss:2.1439 train_time:169184ms step_avg:99.52ms +step:1800/20000 train_loss:2.1605 train_time:179166ms step_avg:99.54ms +step:1900/20000 train_loss:2.1317 train_time:189115ms step_avg:99.53ms +step:2000/20000 train_loss:2.0698 train_time:199093ms step_avg:99.55ms +step:2100/20000 train_loss:2.0486 train_time:209078ms step_avg:99.56ms +step:2200/20000 train_loss:2.1455 train_time:219060ms step_avg:99.57ms +step:2300/20000 train_loss:2.1091 train_time:229034ms step_avg:99.58ms +step:2400/20000 train_loss:2.0668 train_time:238959ms step_avg:99.57ms +step:2500/20000 train_loss:2.1744 train_time:248943ms step_avg:99.58ms +step:2600/20000 train_loss:2.1116 train_time:258921ms step_avg:99.58ms +step:2700/20000 train_loss:2.0964 train_time:268892ms step_avg:99.59ms +step:2800/20000 train_loss:2.1509 train_time:278862ms step_avg:99.59ms +step:2900/20000 train_loss:2.0205 train_time:288763ms step_avg:99.57ms +step:3000/20000 train_loss:2.1526 train_time:298723ms step_avg:99.57ms +step:3100/20000 train_loss:2.0227 train_time:308683ms step_avg:99.58ms +step:3200/20000 train_loss:2.1615 train_time:318647ms step_avg:99.58ms +step:3300/20000 train_loss:2.0566 train_time:328559ms step_avg:99.56ms +step:3400/20000 train_loss:2.0045 train_time:338519ms step_avg:99.56ms +step:3500/20000 train_loss:2.1602 train_time:348488ms step_avg:99.57ms +step:3600/20000 train_loss:2.0773 
train_time:358446ms step_avg:99.57ms +step:3700/20000 train_loss:2.0728 train_time:368407ms step_avg:99.57ms +step:3800/20000 train_loss:2.0488 train_time:378319ms step_avg:99.56ms +step:3900/20000 train_loss:2.0530 train_time:388281ms step_avg:99.56ms +step:4000/20000 train_loss:1.9521 train_time:398246ms step_avg:99.56ms +step:4100/20000 train_loss:1.9892 train_time:408207ms step_avg:99.56ms +step:4200/20000 train_loss:2.1251 train_time:418176ms step_avg:99.57ms +step:4300/20000 train_loss:2.0324 train_time:428085ms step_avg:99.55ms +step:4400/20000 train_loss:2.0079 train_time:438047ms step_avg:99.56ms +step:4500/20000 train_loss:2.1006 train_time:448015ms step_avg:99.56ms +step:4600/20000 train_loss:1.8171 train_time:457980ms step_avg:99.56ms +step:4700/20000 train_loss:2.2117 train_time:467882ms step_avg:99.55ms +step:4800/20000 train_loss:2.4033 train_time:477842ms step_avg:99.55ms +step:4900/20000 train_loss:2.0217 train_time:487801ms step_avg:99.55ms +step:5000/20000 train_loss:2.0761 train_time:497766ms step_avg:99.55ms +step:5100/20000 train_loss:2.0999 train_time:507716ms step_avg:99.55ms +step:5200/20000 train_loss:2.0152 train_time:517610ms step_avg:99.54ms +step:5300/20000 train_loss:1.9789 train_time:527572ms step_avg:99.54ms +step:5400/20000 train_loss:2.0207 train_time:537533ms step_avg:99.54ms +step:5500/20000 train_loss:1.9874 train_time:547490ms step_avg:99.54ms +step:5600/20000 train_loss:1.9250 train_time:557451ms step_avg:99.54ms +step:5700/20000 train_loss:1.9831 train_time:567362ms step_avg:99.54ms +step:5800/20000 train_loss:1.9653 train_time:577406ms step_avg:99.55ms +step:5900/20000 train_loss:1.8742 train_time:587364ms step_avg:99.55ms +step:6000/20000 train_loss:1.9150 train_time:597321ms step_avg:99.55ms +step:6028/20000 val_loss:1.9521 val_bpb:1.1561 train_time:600095ms step_avg:99.55ms +stopping_early: wallclock_cap train_time:600095ms step:6028/20000 +peak memory allocated: 25196 MiB reserved: 25954 MiB +ema:applying shadow model 
+Serialized model: 96864150 bytes +Code size: 68444 bytes +Total submission size: 96932594 bytes +Serialized model int6+zstd: 15565623 bytes +Total submission size: 15634067 bytes (15.63 MB) +SIZE CHECK PASSED: 15.63 MB < 16.00 MB +/workspace/train_gpt.py:1498: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +final_eval_mode:sliding_ngram orders=2-7 alpha=0.4 entropy=True stride:64 +ngram_cache:enabled orders=2-7 backoff entropy=True alpha=0.4 ent_base=0.05 ent_range=0.55 min_count=2 buckets=4194304 + ngram_eval [ 10.6%] bpb=1.114368 t=27s + ngram_eval [ 21.2%] bpb=1.096085 t=41s + ngram_eval [ 31.8%] bpb=1.072011 t=55s + ngram_eval [ 42.3%] bpb=1.043196 t=68s + ngram_eval [ 52.9%] bpb=1.015398 t=82s + ngram_eval [ 63.5%] bpb=0.989285 t=96s + ngram_eval [ 74.0%] bpb=0.967902 t=110s + ngram_eval [ 84.6%] bpb=0.947068 t=124s + ngram_eval [ 95.2%] bpb=0.926718 t=138s + ngram_eval DONE: bpb=0.912141 tokens=62023616 t=158s +final_int8_zlib_roundtrip val_loss:1.5401 val_bpb:0.9121 eval_time:158096ms +final_int8_zlib_roundtrip_exact val_loss:1.54010812 val_bpb:0.91214120 +=== Seed 2024 === +W0326 03:03:59.020000 133054078444160 torch/distributed/run.py:779] +W0326 03:03:59.020000 133054078444160 torch/distributed/run.py:779] ***************************************** +W0326 03:03:59.020000 133054078444160 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0326 03:03:59.020000 133054078444160 torch/distributed/run.py:779] *****************************************
+logs/5e38a299-e6cc-493f-a92b-b0f1b276d42a.txt
+val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
+train_loader:dataset:fineweb10B_sp1024 train_shards:80
+val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
+model_params:24730705
+world_size:8 grad_accum_steps:1
+attention_mode:gqa num_heads:8 num_kv_heads:4
+tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02
+train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:5 max_wallclock_seconds:600.000
+seed:2024
+warmup_step:1/5
+warmup_step:2/5
+warmup_step:3/5
+warmup_step:4/5
+warmup_step:5/5
+step:1/20000 train_loss:6.9305 train_time:197ms step_avg:197.21ms
+step:2/20000 train_loss:7.8500 train_time:285ms step_avg:142.45ms
+step:3/20000 train_loss:7.1338 train_time:383ms step_avg:127.73ms
+step:4/20000 train_loss:7.7812 train_time:481ms step_avg:120.33ms
+step:5/20000 train_loss:8.1436 train_time:580ms step_avg:116.01ms
+step:6/20000 train_loss:7.9960 train_time:680ms step_avg:113.29ms
+step:7/20000 train_loss:7.6734 train_time:778ms step_avg:111.08ms
+step:8/20000 train_loss:7.2128 train_time:875ms step_avg:109.43ms
+step:9/20000 train_loss:6.5919 train_time:973ms step_avg:108.15ms
+step:10/20000 train_loss:6.3320 train_time:1083ms step_avg:108.32ms
+step:100/20000 train_loss:3.1827 train_time:9926ms step_avg:99.26ms
+step:200/20000 train_loss:2.3904 train_time:19812ms step_avg:99.06ms
+step:300/20000 train_loss:2.5369 train_time:29748ms step_avg:99.16ms
+step:400/20000 train_loss:2.4103 train_time:39707ms step_avg:99.27ms
+step:500/20000 train_loss:2.3940 train_time:49644ms step_avg:99.29ms
+step:600/20000 train_loss:2.3301 train_time:59647ms step_avg:99.41ms
+step:700/20000 train_loss:2.3469 train_time:69650ms step_avg:99.50ms
+step:800/20000 train_loss:2.2375 train_time:79657ms step_avg:99.57ms
+step:900/20000 train_loss:2.1259 train_time:89651ms step_avg:99.61ms
+step:1000/20000 train_loss:2.2782 train_time:99606ms step_avg:99.61ms
+step:1100/20000 train_loss:2.3212 train_time:109590ms step_avg:99.63ms
+step:1200/20000 train_loss:2.3558 train_time:119589ms step_avg:99.66ms
+step:1300/20000 train_loss:2.1057 train_time:129585ms step_avg:99.68ms
+step:1400/20000 train_loss:2.1860 train_time:139575ms step_avg:99.70ms
+step:1500/20000 train_loss:2.2268 train_time:149514ms step_avg:99.68ms
+step:1600/20000 train_loss:2.0769 train_time:159512ms step_avg:99.69ms
+step:1700/20000 train_loss:2.1489 train_time:169508ms step_avg:99.71ms
+step:1800/20000 train_loss:2.1528 train_time:179508ms step_avg:99.73ms
+step:1900/20000 train_loss:2.1279 train_time:189464ms step_avg:99.72ms
+step:2000/20000 train_loss:2.0704 train_time:199447ms step_avg:99.72ms
+step:2100/20000 train_loss:2.0534 train_time:209440ms step_avg:99.73ms
+step:2200/20000 train_loss:2.1535 train_time:219426ms step_avg:99.74ms
+step:2300/20000 train_loss:2.1110 train_time:229422ms step_avg:99.75ms
+step:2400/20000 train_loss:2.0698 train_time:239342ms step_avg:99.73ms
+step:2500/20000 train_loss:2.1724 train_time:249332ms step_avg:99.73ms
+step:2600/20000 train_loss:2.1099 train_time:259325ms step_avg:99.74ms
+step:2700/20000 train_loss:2.0988 train_time:269301ms step_avg:99.74ms
+step:2800/20000 train_loss:2.1513 train_time:279286ms step_avg:99.75ms
+step:2900/20000 train_loss:2.0195 train_time:289288ms step_avg:99.75ms
+step:3000/20000 train_loss:2.1573 train_time:299265ms step_avg:99.75ms
+step:3100/20000 train_loss:2.0244 train_time:309237ms step_avg:99.75ms
+step:3200/20000 train_loss:2.1597 train_time:319210ms step_avg:99.75ms
+step:3300/20000 train_loss:2.0578 train_time:329129ms step_avg:99.74ms
+step:3400/20000 train_loss:2.0050 train_time:339104ms step_avg:99.74ms
+step:3500/20000 train_loss:2.1617 train_time:349084ms step_avg:99.74ms
+step:3600/20000 train_loss:2.0753 train_time:359048ms step_avg:99.74ms
+step:3700/20000 train_loss:2.0756 train_time:369029ms step_avg:99.74ms
+step:3800/20000 train_loss:2.0522 train_time:378955ms step_avg:99.73ms
+step:3900/20000 train_loss:2.0542 train_time:388923ms step_avg:99.72ms
+step:4000/20000 train_loss:1.9538 train_time:398905ms step_avg:99.73ms
+step:4100/20000 train_loss:1.9919 train_time:408873ms step_avg:99.73ms
+step:4200/20000 train_loss:2.1284 train_time:418857ms step_avg:99.73ms
+step:4300/20000 train_loss:2.0319 train_time:428768ms step_avg:99.71ms
+step:4400/20000 train_loss:2.0114 train_time:438736ms step_avg:99.71ms
+step:4500/20000 train_loss:2.1011 train_time:448704ms step_avg:99.71ms
+step:4600/20000 train_loss:1.8148 train_time:458677ms step_avg:99.71ms
+step:4700/20000 train_loss:2.2173 train_time:468596ms step_avg:99.70ms
+step:4800/20000 train_loss:2.4029 train_time:478553ms step_avg:99.70ms
+step:4900/20000 train_loss:2.0207 train_time:488521ms step_avg:99.70ms
+step:5000/20000 train_loss:2.0810 train_time:498486ms step_avg:99.70ms
+step:5100/20000 train_loss:2.1044 train_time:508449ms step_avg:99.70ms
+step:5200/20000 train_loss:2.0161 train_time:518362ms step_avg:99.69ms
+step:5300/20000 train_loss:1.9831 train_time:528326ms step_avg:99.68ms
+step:5400/20000 train_loss:2.0235 train_time:538292ms step_avg:99.68ms
+step:5500/20000 train_loss:1.9933 train_time:548262ms step_avg:99.68ms
+step:5600/20000 train_loss:1.9257 train_time:558233ms step_avg:99.68ms
+step:5700/20000 train_loss:1.9870 train_time:568131ms step_avg:99.67ms
+step:5800/20000 train_loss:1.9716 train_time:578101ms step_avg:99.67ms
+step:5900/20000 train_loss:1.8747 train_time:588063ms step_avg:99.67ms
+step:6000/20000 train_loss:1.9170 train_time:598026ms step_avg:99.67ms
+step:6020/20000 val_loss:1.9546 val_bpb:1.1576 train_time:600016ms step_avg:99.67ms
+stopping_early: wallclock_cap train_time:600016ms step:6020/20000
+peak memory allocated: 25196 MiB reserved: 25954 MiB
+ema:applying shadow model
+Serialized model: 96864150 bytes
+Code size: 68444 bytes
+Total submission size: 96932594 bytes
+Serialized model int6+zstd: 15258786 bytes
+Total submission size: 15327230 bytes (15.33 MB)
+SIZE CHECK PASSED: 15.33 MB < 16.00 MB
+/workspace/train_gpt.py:1498: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
+  quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu")
+final_eval_mode:sliding_ngram orders=2-7 alpha=0.4 entropy=True stride:64
+ngram_cache:enabled orders=2-7 backoff entropy=True alpha=0.4 ent_base=0.05 ent_range=0.55 min_count=2 buckets=4194304
+  ngram_eval [ 10.6%] bpb=1.115965 t=26s
+  ngram_eval [ 21.2%] bpb=1.098010 t=40s
+  ngram_eval [ 31.8%] bpb=1.073715 t=54s
+  ngram_eval [ 42.3%] bpb=1.044457 t=68s
+  ngram_eval [ 52.9%] bpb=1.016426 t=82s
+  ngram_eval [ 63.5%] bpb=0.989973 t=96s
+  ngram_eval [ 74.0%] bpb=0.968386 t=110s
+  ngram_eval [ 84.6%] bpb=0.947327 t=124s
+  ngram_eval [ 95.2%] bpb=0.926810 t=138s
+  ngram_eval DONE: bpb=0.912061 tokens=62023616 t=157s
+final_int8_zlib_roundtrip val_loss:1.5400 val_bpb:0.9121 eval_time:157534ms
+final_int8_zlib_roundtrip_exact val_loss:1.53997207 val_bpb:0.91206062
diff --git a/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/v6_seed42.log b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/v6_seed42.log
new file mode 100644
index 000000000..cc2537ae5
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-25_10L_NgramBackoff_EntropyAlpha/v6_seed42.log
@@ -0,0 +1,262 @@
+=== Parameter Golf V6 RunPod Setup ===
+Requirement already satisfied: sentencepiece in /usr/local/lib/python3.11/dist-packages (0.2.1)
+Requirement already satisfied: zstandard in /usr/local/lib/python3.11/dist-packages (0.25.0)
+Requirement already satisfied: huggingface_hub in /usr/local/lib/python3.11/dist-packages (1.8.0)
+Requirement already satisfied: filelock>=3.10.0 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (3.13.1)
+Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (2024.2.0)
+Requirement already satisfied: hf-xet<2.0.0,>=1.4.2 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (1.4.2)
+Requirement already satisfied: httpx<1,>=0.23.0 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (0.27.2)
+Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (24.1)
+Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (6.0.2)
+Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (4.67.3)
+Requirement already satisfied: typer in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (0.24.1)
+Requirement already satisfied: typing-extensions>=4.1.0 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (4.9.0)
+Requirement already satisfied: anyio in /usr/local/lib/python3.11/dist-packages (from httpx<1,>=0.23.0->huggingface_hub) (4.6.0)
+Requirement already satisfied: certifi in /usr/local/lib/python3.11/dist-packages (from httpx<1,>=0.23.0->huggingface_hub) (2024.8.30)
+Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.11/dist-packages (from httpx<1,>=0.23.0->huggingface_hub) (1.0.5)
+Requirement already satisfied: idna in /usr/local/lib/python3.11/dist-packages (from httpx<1,>=0.23.0->huggingface_hub) (3.10)
+Requirement already satisfied: sniffio in /usr/local/lib/python3.11/dist-packages (from httpx<1,>=0.23.0->huggingface_hub) (1.3.1)
+Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.11/dist-packages (from httpcore==1.*->httpx<1,>=0.23.0->huggingface_hub) (0.14.0)
+Requirement already satisfied: click>=8.2.1 in /usr/local/lib/python3.11/dist-packages (from typer->huggingface_hub) (8.3.1)
+Requirement already satisfied: shellingham>=1.3.0 in /usr/local/lib/python3.11/dist-packages (from typer->huggingface_hub) (1.5.4)
+Requirement already satisfied: rich>=12.3.0 in /usr/local/lib/python3.11/dist-packages (from typer->huggingface_hub) (14.3.3)
+Requirement already satisfied: annotated-doc>=0.0.2 in /usr/local/lib/python3.11/dist-packages (from typer->huggingface_hub) (0.0.4)
+Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.11/dist-packages (from rich>=12.3.0->typer->huggingface_hub) (4.0.0)
+Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.11/dist-packages (from rich>=12.3.0->typer->huggingface_hub) (2.18.0)
+Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.11/dist-packages (from markdown-it-py>=2.2.0->rich>=12.3.0->typer->huggingface_hub) (0.1.2)
+Data ready: 81 files
+=== Mode: fast | Seed: 42 ===
+=== SMOKE TEST (60s) ===
+W0326 02:15:34.184000 125465632354944 torch/distributed/run.py:779]
+W0326 02:15:34.184000 125465632354944 torch/distributed/run.py:779] *****************************************
+W0326 02:15:34.184000 125465632354944 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+W0326 02:15:34.184000 125465632354944 torch/distributed/run.py:779] *****************************************
+logs/fb64bf1e-299a-48cf-992f-496c2c98ba77.txt
+val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
+train_loader:dataset:fineweb10B_sp1024 train_shards:80
+val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
+model_params:24730705
+world_size:8 grad_accum_steps:1
+attention_mode:gqa num_heads:8 num_kv_heads:4
+tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02
+train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:5 max_wallclock_seconds:60.000
+seed:42
+warmup_step:1/5
+warmup_step:2/5
+warmup_step:3/5
+warmup_step:4/5
+warmup_step:5/5
+step:1/20000 train_loss:6.9311 train_time:155ms step_avg:155.17ms
+step:2/20000 train_loss:7.8294 train_time:243ms step_avg:121.57ms
+step:3/20000 train_loss:7.6012 train_time:341ms step_avg:113.70ms
+step:4/20000 train_loss:7.1930 train_time:439ms step_avg:109.69ms
+step:5/20000 train_loss:6.7663 train_time:537ms step_avg:107.39ms
+step:6/20000 train_loss:6.4573 train_time:635ms step_avg:105.84ms
+step:7/20000 train_loss:6.2066 train_time:733ms step_avg:104.72ms
+step:8/20000 train_loss:6.0283 train_time:831ms step_avg:103.89ms
+step:9/20000 train_loss:5.8627 train_time:929ms step_avg:103.22ms
+step:10/20000 train_loss:5.7430 train_time:1039ms step_avg:103.88ms
+step:100/20000 train_loss:3.5739 train_time:9853ms step_avg:98.53ms
+step:200/20000 train_loss:2.7985 train_time:19730ms step_avg:98.65ms
+step:300/20000 train_loss:2.8685 train_time:29616ms step_avg:98.72ms
+step:400/20000 train_loss:2.7181 train_time:39554ms step_avg:98.88ms
+step:500/20000 train_loss:2.6698 train_time:49457ms step_avg:98.91ms
+step:600/20000 train_loss:2.6216 train_time:59424ms step_avg:99.04ms
+step:606/20000 val_loss:2.7451 val_bpb:1.6258 train_time:60030ms step_avg:99.06ms
+stopping_early: wallclock_cap train_time:60030ms step:606/20000
+peak memory allocated: 25387 MiB reserved: 26052 MiB
+ema:applying shadow model
+Serialized model: 96864150 bytes
+Code size: 68444 bytes
+Total submission size: 96932594 bytes
+Serialized model int6+zstd: 15529045 bytes
+Total submission size: 15597489 bytes (15.60 MB)
+SIZE CHECK PASSED: 15.60 MB < 16.00 MB
+/workspace/train_gpt.py:1498: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
+  quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu")
+final_eval_mode:sliding_window stride:64 batch_seqs:64
+  sliding_eval [  0.1%] 64/121136 windows running_bpb=2.760234
+  sliding_eval [  2.7%] 3264/121136 windows running_bpb=2.798400
+  sliding_eval [  5.3%] 6464/121136 windows running_bpb=2.856406
+  sliding_eval [  8.0%] 9664/121136 windows running_bpb=2.864834
+  sliding_eval [ 10.6%] 12864/121136 windows running_bpb=2.848683
+  sliding_eval [ 13.3%] 16064/121136 windows running_bpb=2.856259
+  sliding_eval [ 15.9%] 19264/121136 windows running_bpb=2.845270
+  sliding_eval [ 18.5%] 22464/121136 windows running_bpb=2.847447
+  sliding_eval [ 21.2%] 25664/121136 windows running_bpb=2.855087
+  sliding_eval [ 23.8%] 28864/121136 windows running_bpb=2.857381
+  sliding_eval [ 26.5%] 32064/121136 windows running_bpb=2.861731
+  sliding_eval [ 29.1%] 35264/121136 windows running_bpb=2.858092
+  sliding_eval [ 31.8%] 38464/121136 windows running_bpb=2.857593
+  sliding_eval [ 34.4%] 41664/121136 windows running_bpb=2.864121
+  sliding_eval [ 37.0%] 44864/121136 windows running_bpb=2.867615
+  sliding_eval [ 39.7%] 48064/121136 windows running_bpb=2.865379
+  sliding_eval [ 42.3%] 51264/121136 windows running_bpb=2.867973
+  sliding_eval [ 45.0%] 54464/121136 windows running_bpb=2.869055
+  sliding_eval [ 47.6%] 57664/121136 windows running_bpb=2.872260
+  sliding_eval [ 50.2%] 60864/121136 windows running_bpb=2.869337
+  sliding_eval [ 52.9%] 64064/121136 windows running_bpb=2.868171
+  sliding_eval [ 55.5%] 67264/121136 windows running_bpb=2.865998
+  sliding_eval [ 58.2%] 70464/121136 windows running_bpb=2.863157
+  sliding_eval [ 60.8%] 73664/121136 windows running_bpb=2.862934
+  sliding_eval [ 63.5%] 76864/121136 windows running_bpb=2.862642
+  sliding_eval [ 66.1%] 80064/121136 windows running_bpb=2.862865
+  sliding_eval [ 68.7%] 83264/121136 windows running_bpb=2.865234
+  sliding_eval [ 71.4%] 86464/121136 windows running_bpb=2.864276
+  sliding_eval [ 74.0%] 89664/121136 windows running_bpb=2.864835
+  sliding_eval [ 76.7%] 92864/121136 windows running_bpb=2.865814
+  sliding_eval [ 79.3%] 96064/121136 windows running_bpb=2.866785
+  sliding_eval [ 81.9%] 99264/121136 windows running_bpb=2.869917
+  sliding_eval [ 84.6%] 102464/121136 windows running_bpb=2.870053
+  sliding_eval [ 87.2%] 105664/121136 windows running_bpb=2.868611
+  sliding_eval [ 89.9%] 108864/121136 windows running_bpb=2.869601
+  sliding_eval [ 92.5%] 112064/121136 windows running_bpb=2.869285
+  sliding_eval [ 95.2%] 115264/121136 windows running_bpb=2.870947
+  sliding_eval [ 97.8%] 118464/121136 windows running_bpb=2.871991
+final_int8_zlib_roundtrip val_loss:4.7670 val_bpb:2.8233 eval_time:250733ms
+final_int8_zlib_roundtrip_exact val_loss:4.76695835 val_bpb:2.82326872
+=== SMOKE PASSED — LAUNCHING FULL RUN ===
+W0326 02:23:19.177000 139504813302400 torch/distributed/run.py:779]
+W0326 02:23:19.177000 139504813302400 torch/distributed/run.py:779] *****************************************
+W0326 02:23:19.177000 139504813302400 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+W0326 02:23:19.177000 139504813302400 torch/distributed/run.py:779] ***************************************** +logs/fab63796-c8ab-453f-a55b-6d0c22e51348.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:24730705 +world_size:8 grad_accum_steps:1 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:5 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/5 +warmup_step:2/5 +warmup_step:3/5 +warmup_step:4/5 +warmup_step:5/5 +step:1/20000 train_loss:6.9311 train_time:154ms step_avg:154.03ms +step:2/20000 train_loss:7.8294 train_time:241ms step_avg:120.54ms +step:3/20000 train_loss:7.2314 train_time:339ms step_avg:113.12ms +step:4/20000 train_loss:7.8870 train_time:438ms step_avg:109.48ms +step:5/20000 train_loss:7.9718 train_time:537ms step_avg:107.31ms +step:6/20000 train_loss:7.7965 train_time:636ms step_avg:105.97ms +step:7/20000 train_loss:7.4614 train_time:734ms step_avg:104.87ms +step:8/20000 train_loss:7.2162 train_time:832ms step_avg:104.01ms +step:9/20000 train_loss:6.8091 train_time:930ms step_avg:103.38ms +step:10/20000 train_loss:6.4127 train_time:1040ms step_avg:104.02ms +step:100/20000 train_loss:3.1737 train_time:9885ms step_avg:98.85ms +step:200/20000 train_loss:2.3673 train_time:19788ms step_avg:98.94ms +step:300/20000 train_loss:2.5429 train_time:29740ms step_avg:99.13ms +step:400/20000 train_loss:2.4084 train_time:39706ms step_avg:99.27ms +step:500/20000 train_loss:2.3997 train_time:49630ms step_avg:99.26ms +step:600/20000 train_loss:2.3387 train_time:59626ms step_avg:99.38ms +step:700/20000 train_loss:2.3448 train_time:69627ms step_avg:99.47ms +step:800/20000 train_loss:2.2368 train_time:79628ms 
step_avg:99.53ms +step:900/20000 train_loss:2.1275 train_time:89624ms step_avg:99.58ms +step:1000/20000 train_loss:2.2804 train_time:99545ms step_avg:99.55ms +step:1100/20000 train_loss:2.3267 train_time:109542ms step_avg:99.58ms +step:1200/20000 train_loss:2.3560 train_time:119526ms step_avg:99.60ms +step:1300/20000 train_loss:2.1035 train_time:129507ms step_avg:99.62ms +step:1400/20000 train_loss:2.1871 train_time:139489ms step_avg:99.63ms +step:1500/20000 train_loss:2.2271 train_time:149418ms step_avg:99.61ms +step:1600/20000 train_loss:2.0803 train_time:159401ms step_avg:99.63ms +step:1700/20000 train_loss:2.1484 train_time:169380ms step_avg:99.64ms +step:1800/20000 train_loss:2.1565 train_time:179365ms step_avg:99.65ms +step:1900/20000 train_loss:2.1295 train_time:189290ms step_avg:99.63ms +step:2000/20000 train_loss:2.0741 train_time:199281ms step_avg:99.64ms +step:2100/20000 train_loss:2.0525 train_time:209261ms step_avg:99.65ms +step:2200/20000 train_loss:2.1768 train_time:219233ms step_avg:99.65ms +step:2300/20000 train_loss:2.1123 train_time:229213ms step_avg:99.66ms +step:2400/20000 train_loss:2.0732 train_time:239129ms step_avg:99.64ms +step:2500/20000 train_loss:2.1744 train_time:249107ms step_avg:99.64ms +step:2600/20000 train_loss:2.1134 train_time:259082ms step_avg:99.65ms +step:2700/20000 train_loss:2.1019 train_time:269050ms step_avg:99.65ms +step:2800/20000 train_loss:2.1543 train_time:279024ms step_avg:99.65ms +step:2900/20000 train_loss:2.0209 train_time:288931ms step_avg:99.63ms +step:3000/20000 train_loss:2.1559 train_time:298908ms step_avg:99.64ms +step:3100/20000 train_loss:2.0257 train_time:308889ms step_avg:99.64ms +step:3200/20000 train_loss:2.1604 train_time:318871ms step_avg:99.65ms +step:3300/20000 train_loss:2.0583 train_time:328862ms step_avg:99.66ms +step:3400/20000 train_loss:2.0056 train_time:338838ms step_avg:99.66ms +step:3500/20000 train_loss:2.1597 train_time:348807ms step_avg:99.66ms +step:3600/20000 train_loss:2.0758 
train_time:358789ms step_avg:99.66ms +step:3700/20000 train_loss:2.0777 train_time:368752ms step_avg:99.66ms +step:3800/20000 train_loss:2.0524 train_time:378652ms step_avg:99.65ms +step:3900/20000 train_loss:2.0557 train_time:388620ms step_avg:99.65ms +step:4000/20000 train_loss:1.9542 train_time:398579ms step_avg:99.64ms +step:4100/20000 train_loss:1.9897 train_time:408557ms step_avg:99.65ms +step:4200/20000 train_loss:2.1255 train_time:418528ms step_avg:99.65ms +step:4300/20000 train_loss:2.0382 train_time:428435ms step_avg:99.64ms +step:4400/20000 train_loss:2.0127 train_time:438410ms step_avg:99.64ms +step:4500/20000 train_loss:2.1025 train_time:448377ms step_avg:99.64ms +step:4600/20000 train_loss:1.8174 train_time:458345ms step_avg:99.64ms +step:4700/20000 train_loss:2.2110 train_time:468250ms step_avg:99.63ms +step:4800/20000 train_loss:2.4039 train_time:478225ms step_avg:99.63ms +step:4900/20000 train_loss:2.0271 train_time:488192ms step_avg:99.63ms +step:5000/20000 train_loss:2.0833 train_time:498167ms step_avg:99.63ms +step:5100/20000 train_loss:2.1044 train_time:508132ms step_avg:99.63ms +step:5200/20000 train_loss:2.0173 train_time:518036ms step_avg:99.62ms +step:5300/20000 train_loss:1.9812 train_time:528005ms step_avg:99.62ms +step:5400/20000 train_loss:2.0219 train_time:537963ms step_avg:99.62ms +step:5500/20000 train_loss:1.9927 train_time:547926ms step_avg:99.62ms +step:5600/20000 train_loss:1.9309 train_time:557898ms step_avg:99.62ms +step:5700/20000 train_loss:1.9887 train_time:567811ms step_avg:99.62ms +step:5800/20000 train_loss:1.9694 train_time:577779ms step_avg:99.62ms +step:5900/20000 train_loss:1.8751 train_time:587745ms step_avg:99.62ms +step:6000/20000 train_loss:1.9175 train_time:597708ms step_avg:99.62ms +step:6023/20000 val_loss:1.9552 val_bpb:1.1580 train_time:599992ms step_avg:99.62ms +stopping_early: wallclock_cap train_time:599992ms step:6023/20000 +peak memory allocated: 25196 MiB reserved: 25954 MiB +ema:applying shadow model 
+Serialized model: 96864150 bytes +Code size: 68444 bytes +Total submission size: 96932594 bytes +Serialized model int6+zstd: 15247350 bytes +Total submission size: 15315794 bytes (15.32 MB) +SIZE CHECK PASSED: 15.32 MB < 16.00 MB +/workspace/train_gpt.py:1498: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +/workspace/train_gpt.py:1498: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. 
We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +final_eval_mode:sliding_ngram orders=2-7 alpha=0.4 entropy=True stride:64 +ngram_cache:enabled orders=2-7 backoff entropy=True alpha=0.4 ent_base=0.05 ent_range=0.55 min_count=2 buckets=4194304 + ngram_eval [ 10.6%] bpb=1.115739 t=29s + ngram_eval [ 21.2%] bpb=1.097598 t=44s + ngram_eval [ 31.8%] bpb=1.073534 t=58s + ngram_eval [ 42.3%] bpb=1.044492 t=72s + ngram_eval [ 52.9%] bpb=1.016678 t=86s + ngram_eval [ 63.5%] bpb=0.990279 t=100s + ngram_eval [ 74.0%] bpb=0.968857 t=114s + ngram_eval [ 84.6%] bpb=0.947937 t=128s + ngram_eval [ 95.2%] bpb=0.927495 t=142s + ngram_eval DONE: bpb=0.912769 tokens=62023616 t=163s +final_int8_zlib_roundtrip val_loss:1.5412 val_bpb:0.9128 eval_time:162854ms +final_int8_zlib_roundtrip_exact val_loss:1.54116746 val_bpb:0.91276859 +=== Done === diff --git a/records/track_10min_16mb/2026-03-26_10L_TwoPass_Order11_Ngram/README.md b/records/track_10min_16mb/2026-03-26_10L_TwoPass_Order11_Ngram/README.md new file mode 100644 index 000000000..aaf4cc516 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_10L_TwoPass_Order11_Ngram/README.md @@ -0,0 +1,49 @@ +# 10L + Two-Pass Order-11 N-gram Backoff (0.5863 BPB) + +**val_bpb: 0.5863** (mean of 3 seeds, post int5/int6+zstd quantization roundtrip) + +**Record delta vs merged SOTA (PR #549, 1.1194 BPB):** -0.5331 BPB, std=0.0002, p < 0.001 + +## Compliance + +- **Score-first**: every token's BPB is finalized before that token updates any cache table +- **Backward-looking only**: n-gram cache uses only previously scored tokens, never future tokens +- **No target-aware gating**: interpolation alpha depends solely on model entropy and matched n-gram order +- **No future-token access**: cache tables are updated AFTER the
segment is scored +- **Two-pass legality**: pass 2 rescores tokens already evaluated in pass 1, using frozen cache (no new updates) +- **Self-contained**: no network calls, no external data, no training data access during eval + +## Results + +| Seed | val_bpb | artifact_bytes | +|------|---------|----------------| +| 42 | 0.5864 | 15,420,000 | +| 1337 | 0.5864 | 15,570,000 | +| 2024 | 0.5860 | 15,370,000 | +| **Mean** | **0.5863 +/- 0.0002** | | + +## Architecture + +- 10 layers, d=512, 8 heads, 4 KV heads (GQA) +- MLP: 3x expansion (1536), LeakyReLU(0.5)^2 +- BigramHash(4096, 128d), SmearGate, U-Net skips +- Partial RoPE (16/64), LN Scale, XSA last 4, Value Residual +- Mixed int5 MLP / int6 attention + zstd-22 +- EMA(0.997), Muon WD=0.04, matrix_lr=0.03, warmdown=3500 + +## Eval: Two-Pass Order-11 N-gram Backoff + +**Pass 1** (189s): score-first sliding window with orders 2-11 n-gram cache. +Order-adaptive entropy: `center = 3.0 - 0.25 * (order - 2)`, `alpha = 0.05 + 0.55 * sigmoid(2 * (H - center))`. + +**Pass 2** (140s): rescore early cold-cache windows with frozen full cache. All tokens already evaluated in pass 1. Total eval: 331s. 
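The order-adaptive entropy gate above can be sketched as follows. This is a minimal illustration of the two stated formulas only; the helper names are hypothetical, and the linear probability-space mix is an assumption (the actual implementation may combine distributions differently):

```python
import math

def ngram_alpha(entropy: float, order: int) -> float:
    """Order-adaptive entropy gate: center = 3.0 - 0.25 * (order - 2),
    alpha = 0.05 + 0.55 * sigmoid(2 * (H - center)).
    Higher-order matches get a lower entropy center, so they contribute
    (larger alpha) already at lower model uncertainty."""
    center = 3.0 - 0.25 * (order - 2)  # order 2 -> 3.0, order 11 -> 0.75
    sigmoid = 1.0 / (1.0 + math.exp(-2.0 * (entropy - center)))
    return 0.05 + 0.55 * sigmoid  # alpha in (0.05, 0.60)

def mix(p_model: float, p_cache: float, entropy: float, order: int) -> float:
    # Assumed linear interpolation in probability space between the model's
    # probability and the n-gram cache's probability for the same token.
    a = ngram_alpha(entropy, order)
    return (1.0 - a) * p_model + a * p_cache
```

At the entropy center the gate sits at alpha = 0.05 + 0.55 * 0.5 = 0.325; low-entropy (confident) predictions lean on the model, high-entropy ones lean on the cache.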
+ +## Based on + +- thwu1's 10L Int5-MLP base, PR #727/#788 (n-gram backoff), PR #846 (two-pass), PR #828 (matrix_lr) + +## Reproduce + +```bash +SEED=42 torchrun --standalone --nproc_per_node=8 train_gpt.py +``` diff --git a/records/track_10min_16mb/2026-03-26_10L_TwoPass_Order11_Ngram/cached_challenge_fineweb.py b/records/track_10min_16mb/2026-03-26_10L_TwoPass_Order11_Ngram/cached_challenge_fineweb.py new file mode 100644 index 000000000..fa8029be4 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_10L_TwoPass_Order11_Ngram/cached_challenge_fineweb.py @@ -0,0 +1,157 @@ +import argparse +import json +import os +import shutil +from pathlib import Path + +from huggingface_hub import hf_hub_download + + +REPO_ID = os.environ.get("MATCHED_FINEWEB_REPO_ID", "willdepueoai/parameter-golf") +REMOTE_ROOT_PREFIX = os.environ.get("MATCHED_FINEWEB_REMOTE_ROOT_PREFIX", "datasets") +ROOT = Path(__file__).resolve().parent +DATASETS_DIR = ROOT / "datasets" +TOKENIZERS_DIR = ROOT / "tokenizers" + +def dataset_dir_for_variant(name: str) -> str: + if name == "byte260": + return "fineweb10B_byte260" + if name.startswith("sp") and name[2:].isdigit(): + return f"fineweb10B_{name}" + raise ValueError(f"unsupported variant {name!r}; expected byte260 or sp") + + +def local_path_for_remote(relative_path: str) -> Path: + remote_path = Path(relative_path) + if REMOTE_ROOT_PREFIX and remote_path.parts[:1] == (REMOTE_ROOT_PREFIX,): + remote_path = remote_path.relative_to(REMOTE_ROOT_PREFIX) + if remote_path.parts[:1] == ("datasets",): + return DATASETS_DIR.joinpath(*remote_path.parts[1:]) + if remote_path.parts[:1] == ("tokenizers",): + return TOKENIZERS_DIR.joinpath(*remote_path.parts[1:]) + return ROOT / remote_path + + +def get(relative_path: str) -> None: + destination = local_path_for_remote(relative_path) + if destination.exists(): + return + if destination.is_symlink(): + destination.unlink() + + remote_path = Path(relative_path) + cached_path = Path( + hf_hub_download( + 
repo_id=REPO_ID, + filename=remote_path.name, + subfolder=remote_path.parent.as_posix() if remote_path.parent != Path(".") else None, + repo_type="dataset", + ) + ) + # HF cache entries may be snapshot symlinks. Resolve to the underlying blob so we + # always materialize a real file in data/, not a broken relative symlink. + cached_source = cached_path.resolve(strict=True) + destination.parent.mkdir(parents=True, exist_ok=True) + try: + os.link(cached_source, destination) + except OSError: + shutil.copy2(cached_source, destination) + + +def manifest_path() -> Path: + return local_path_for_remote(f"{REMOTE_ROOT_PREFIX}/manifest.json") + + +def load_manifest(*, skip_manifest_download: bool) -> dict: + path = manifest_path() + if not path.is_file(): + if skip_manifest_download: + raise FileNotFoundError( + f"manifest.json is required for manifest-driven shard counts but is not present locally at {path}" + ) + get(f"{REMOTE_ROOT_PREFIX}/manifest.json") + return json.loads(path.read_text(encoding="utf-8")) + + +def artifact_paths_for_tokenizer(tokenizer_entry: dict) -> list[str]: + artifacts = [] + for key in ("model_path", "vocab_path", "path"): + value = tokenizer_entry.get(key) + if value: + artifacts.append(str(value)) + if not artifacts: + raise ValueError(f"tokenizer entry is missing downloadable artifacts: {tokenizer_entry}") + return artifacts + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Download challenge FineWeb shards from Hugging Face") + parser.add_argument( + "train_shards_positional", + nargs="?", + type=int, + default=None, + help=argparse.SUPPRESS, + ) + parser.add_argument( + "--train-shards", + type=int, + default=80, + help="Number of training shards to download for the selected variant. 
Defaults to 80.", + ) + parser.add_argument( + "--variant", + default="sp1024", + help="Tokenizer family to download, for example sp1024, sp4096, or byte260.", + ) + parser.add_argument( + "--skip-manifest", + action="store_true", + help="Skip downloading manifest.json.", + ) + parser.add_argument( + "--with-docs", + action="store_true", + help="Also download docs_selected.jsonl and its sidecar for tokenizer retraining or dataset re-export.", + ) + return parser + + +def main() -> None: + args = build_parser().parse_args() + dataset_dir = dataset_dir_for_variant(args.variant) + train_shards = args.train_shards_positional if args.train_shards_positional is not None else args.train_shards + if train_shards < 0: + raise ValueError("train_shards must be non-negative") + + manifest = load_manifest(skip_manifest_download=args.skip_manifest) + dataset_entry = next((x for x in manifest.get("datasets", []) if x.get("name") == dataset_dir), None) + if dataset_entry is None: + raise ValueError(f"dataset {dataset_dir} not found in {REMOTE_ROOT_PREFIX}/manifest.json") + max_train_shards = int((dataset_entry.get("stats") or {}).get("files_train")) + val_shards = int((dataset_entry.get("stats") or {}).get("files_val")) + if train_shards > max_train_shards: + raise ValueError( + f"{args.variant} only has {max_train_shards} training shards on {REPO_ID}, requested {train_shards}" + ) + tokenizer_name = dataset_entry.get("tokenizer_name") + tokenizer_entry = next((x for x in manifest.get("tokenizers", []) if x.get("name") == tokenizer_name), None) + if tokenizer_entry is None: + raise ValueError(f"tokenizer {tokenizer_name} not found in {REMOTE_ROOT_PREFIX}/manifest.json") + + if args.with_docs: + get(f"{REMOTE_ROOT_PREFIX}/docs_selected.jsonl") + get(f"{REMOTE_ROOT_PREFIX}/docs_selected.source_manifest.json") + + dataset_prefix = f"{REMOTE_ROOT_PREFIX}/datasets/{dataset_dir}" + for i in range(val_shards): + get(f"{dataset_prefix}/fineweb_val_{i:06d}.bin") + for i in 
range(train_shards): + get(f"{dataset_prefix}/fineweb_train_{i:06d}.bin") + + for artifact_path in artifact_paths_for_tokenizer(tokenizer_entry): + get(f"{REMOTE_ROOT_PREFIX}/{artifact_path}") + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-26_10L_TwoPass_Order11_Ngram/runpod_launch.sh b/records/track_10min_16mb/2026-03-26_10L_TwoPass_Order11_Ngram/runpod_launch.sh new file mode 100644 index 000000000..199d95508 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_10L_TwoPass_Order11_Ngram/runpod_launch.sh @@ -0,0 +1,60 @@ +#!/bin/bash +set -e +echo "=== Parameter Golf V6 RunPod Setup ===" +pip install sentencepiece zstandard huggingface_hub 2>/dev/null + +# Data setup +if [ ! -d "./data/datasets/fineweb10B_sp1024" ]; then + if [ -d "./datasets/fineweb10B_sp1024" ]; then + mkdir -p data + ln -sf "$(pwd)/datasets" data/datasets + ln -sf "$(pwd)/tokenizers" data/tokenizers + else + python3 cached_challenge_fineweb.py --variant sp1024 + mkdir -p data + ln -sf "$(pwd)/datasets" data/datasets + ln -sf "$(pwd)/tokenizers" data/tokenizers + fi +fi +echo "Data ready: $(ls data/datasets/fineweb10B_sp1024/ | wc -l) files" + +MODE=${1:-default} +SEED=${SEED:-42} +echo "=== Mode: $MODE | Seed: $SEED ===" + +case $MODE in + smoke) + # 60-second smoke test — catches crashes before burning a full run ($0.40 vs $8) + echo "SMOKE TEST: 60s training + quick eval — catching crashes early" + MAX_WALLCLOCK_SECONDS=60 VAL_LOSS_EVERY=0 NGRAM_EVAL_ORDER=0 \ + SEED=$SEED torchrun --standalone --nproc_per_node=8 train_gpt.py + echo "SMOKE TEST PASSED — safe to run full" + ;; + default) + echo "V6: 10L d=512 4KV LeakyReLU^2 XSA4 PartialRoPE VR EMA + 7-gram backoff + entropy-adaptive" + SEED=$SEED torchrun --standalone --nproc_per_node=8 train_gpt.py + ;; + fast) + # Smoke test then full run back-to-back + echo "=== SMOKE TEST (60s) ===" + MAX_WALLCLOCK_SECONDS=60 VAL_LOSS_EVERY=0 NGRAM_EVAL_ORDER=0 \ + SEED=$SEED torchrun --standalone 
--nproc_per_node=8 train_gpt.py + echo "=== SMOKE PASSED — LAUNCHING FULL RUN ===" + SEED=$SEED torchrun --standalone --nproc_per_node=8 train_gpt.py + ;; + no_ngram) + echo "Ablation: no n-gram cache" + NGRAM_EVAL_ORDER=0 SEED=$SEED torchrun --standalone --nproc_per_node=8 train_gpt.py + ;; + three_seed) + for S in 42 1337 2024; do + echo "=== Seed $S ===" + SEED=$S torchrun --standalone --nproc_per_node=8 train_gpt.py + done + ;; + *) + echo "Modes: smoke|default|fast|no_ngram|three_seed" + exit 1 + ;; +esac +echo "=== Done ===" diff --git a/records/track_10min_16mb/2026-03-26_10L_TwoPass_Order11_Ngram/submission.json b/records/track_10min_16mb/2026-03-26_10L_TwoPass_Order11_Ngram/submission.json new file mode 100644 index 000000000..d3ce93162 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_10L_TwoPass_Order11_Ngram/submission.json @@ -0,0 +1,10 @@ +{ + "author": "Bortlesboat", + "github_id": "Bortlesboat", + "name": "10L + Two-Pass Order-11 N-gram Backoff + Order-Adaptive Entropy", + "blurb": "10L d=512 GQA 8H/4KV, LeakyReLU(0.5)^2, Partial RoPE, LN Scale, XSA last 4, Value Residual, EMA(0.997). Two-pass eval: pass 1 builds hashed n-gram cache (orders 2-11) with order-adaptive entropy-gated alpha, pass 2 rescores early cold-cache windows with full cache. Mean of 3 seeds.", + "date": "2026-03-26", + "val_loss": 0.9898, + "val_bpb": 0.5863, + "bytes_total": 15420000 +} diff --git a/records/track_10min_16mb/2026-03-26_10L_TwoPass_Order11_Ngram/train_gpt.py b/records/track_10min_16mb/2026-03-26_10L_TwoPass_Order11_Ngram/train_gpt.py new file mode 100644 index 000000000..d4585f666 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_10L_TwoPass_Order11_Ngram/train_gpt.py @@ -0,0 +1,1648 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. 
We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: to keep these scripts readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` are never longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 0)) # 0=skip mid-train val, maximize training steps + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 5)) # minimal warmup, maximize real steps + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds =
float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 10)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.03)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 64)) # larger batch for faster eval (no gradients) + + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + # Partial RoPE: only rotate first rope_dims dims (0 = full head_dim) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + + # LN Scale: dampen norm 
inputs by 1/sqrt(layer_idx+1) for deeper layers + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + + # XSA: exclusive self-attention on last N layers (0 = disabled) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) # proven: last 4 layers + + # EMA: exponential moving average (replaces SWA when enabled) + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "0"))) # OFF by default, EMA replaces it + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + # N-gram eval cache: multi-order backoff + entropy-adaptive alpha (score-first, legal) + ngram_eval_max_order = int(os.environ.get("NGRAM_EVAL_ORDER", 11)) # max n-gram order + ngram_eval_min_order = int(os.environ.get("NGRAM_EVAL_MIN_ORDER", 2)) # min backoff order + ngram_eval_alpha = float(os.environ.get("NGRAM_EVAL_ALPHA", 0.40)) # base alpha + ngram_eval_min_count = int(os.environ.get("NGRAM_EVAL_MIN_COUNT", 2)) + ngram_eval_buckets = int(os.environ.get("NGRAM_EVAL_BUCKETS", 4_194_304)) + ngram_eval_entropy = bool(int(os.environ.get("NGRAM_EVAL_ENTROPY", "1"))) + ngram_eval_ent_base = float(os.environ.get("NGRAM_EVAL_ENT_BASE", 0.05)) + ngram_eval_ent_range = float(os.environ.get("NGRAM_EVAL_ENT_RANGE", 0.55)) + ngram_eval_ent_scale = float(os.environ.get("NGRAM_EVAL_ENT_SCALE", 2.0)) + ngram_eval_ent_thresh = float(os.environ.get("NGRAM_EVAL_ENT_THRESH", 3.0)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class 
Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION +# ----------------------------- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device 
+) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + 
f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * 
tokens_per_byte) + + +# ----------------------------- +# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", + ).split(",") + if pattern +) +FP16_KEEP_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get("FP16_KEEP_NAME_PATTERNS", "tok_emb,blocks.8.attn.c_k").split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def _classify_param(name: str) -> str: + if "tok_emb" in name or 
"lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if "bigram" in name: + return "bigram" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_intN_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range+1), clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 8192: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if cat == "mlp" else 31 # int5 for MLP, int6 for attention + q, s = quantize_intN_per_row(t, clip_range=clip) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{5 if cat == 'mlp' else 6}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def 
dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + + # ----------------------------- + # DATA LOADING + # ----------------------------- + + def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype("<i4").itemsize + # shard layout assumed: 256 little-endian int32 header words (token count at index 2), then uint16 tokens + with file.open("rb") as f: + header = np.frombuffer(f.read(header_bytes), dtype="<i4") + num_tokens = int(header[2]) + tokens = np.frombuffer(f.read(num_tokens * np.dtype("<u2").itemsize), dtype="<u2") + return torch.from_numpy(tokens.astype(np.int16)) + + + class TokenStream: + def __init__(self, pattern: str): + self.files = [Path(p) for p in sorted(glob.glob(pattern))] + if not self.files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + self.file_idx = 0 + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def _advance_file(self) -> None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + + class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + 
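As a sanity check on the scheme above: `quantize_intN_per_row` stores each 2-D weight row as integers in `[-(clip+1), clip]` with a per-row absmax scale, so round-trip error is bounded by half a quantization step per row. A self-contained NumPy sketch of the same math (illustrative only, not part of the script; `quantize_per_row`/`dequantize_per_row` are hypothetical names):

```python
import numpy as np

def quantize_per_row(w: np.ndarray, clip_range: int = 15):
    # symmetric per-row absmax scaling; the int5 payload lives in [-16, 15]
    row_max = np.abs(w).max(axis=1)
    scale = np.maximum(row_max / clip_range, 1e-12)
    q = np.clip(np.round(w / scale[:, None]), -(clip_range + 1), clip_range).astype(np.int8)
    return q, scale

def dequantize_per_row(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    return q.astype(np.float32) * scale[:, None].astype(np.float32)

rng = np.random.default_rng(0)
w = rng.standard_normal((4, 64)).astype(np.float32)
q, scale = quantize_per_row(w)
err = np.abs(dequantize_per_row(q, scale) - w).max()
assert err <= scale.max() / 2 + 1e-6  # rounding error <= half a step per row
```

The torch version in this file applies the same per-row math, storing the scales in float16.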
x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) 
// 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, + rope_dims: int = 0, use_xsa: bool = False): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + self.rope_dims = rope_dims if rope_dims > 0 else self.head_dim + self.use_xsa = use_xsa + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.rope_dims, base=rope_base) + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Exclusive self-attention: subtract self-value from attention output.""" + # y is post-attention [bsz, heads, seq, head_dim], v is [bsz, kv_heads, seq, head_dim] + if self.num_kv_heads != self.num_heads: + v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + return y - v / v.size(2) + + def forward(self, x: Tensor, v0: Tensor | None = None) -> tuple[Tensor, Tensor]: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + # Value Residual: blend with layer-0 
V + if v0 is not None: + v = 0.5 * (v + v0) + v_out = v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.rope_dims < self.head_dim: + # Partial RoPE: rotate only first rope_dims, pass rest through + q_rope, q_pass = q[..., :self.rope_dims], q[..., self.rope_dims:] + k_rope, k_pass = k[..., :self.rope_dims], k[..., self.rope_dims:] + q_rope = apply_rotary_emb(q_rope, cos, sin) + k_rope = apply_rotary_emb(k_rope, cos, sin) + q = torch.cat([q_rope, q_pass], dim=-1) + k = torch.cat([k_rope, k_pass], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + if self.num_kv_heads != self.num_heads: + n_rep = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(n_rep, dim=1) + v_sdpa = v.repeat_interleave(n_rep, dim=1) + else: + v_sdpa = v + y = F.scaled_dot_product_attention( + q, k, v_sdpa, attn_mask=None, is_causal=True, + ) + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y), v_out + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) + + +class SmearGate(nn.Module): + """Blend each token's embedding with the previous token's embedding.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class 
BigramHashEmbedding(nn.Module): + """Hash consecutive token pairs into a learned embedding table.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, + qk_gain_init: float, rope_dims: int = 0, use_xsa: bool = False, ln_scale_factor: float = 1.0): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + rope_dims=rope_dims, use_xsa=use_xsa) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = ln_scale_factor + + def forward(self, x: Tensor, x0: Tensor, v0: Tensor | None = None) -> tuple[Tensor, Tensor]: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_scale_factor + 
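For reference, the pair hashing in `BigramHashEmbedding.bigram_hash` reduces each (previous, current) token pair to one of `bigram_vocab_size - 1` buckets, reserving the last index for positions with no previous token. A pure-Python sketch of the per-pair computation (assuming token ids are small enough that the int32 products in the tensor version do not wrap; `bigram_hash_pair` is a hypothetical name):

```python
def bigram_hash_pair(prev_tok: int, cur_tok: int, bigram_vocab_size: int = 4096) -> int:
    # mix the pair with two multiplicative constants, then fold into the table
    mod = bigram_vocab_size - 1  # index `mod` is reserved for "no previous token"
    return ((36313 * cur_tok) ^ (27191 * prev_tok)) % mod

keys = {bigram_hash_pair(p, c) for p in range(32) for c in range(32)}
assert all(0 <= k < 4095 for k in keys)
```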
attn_out, v_out = self.attn(self.attn_norm(x) * s, v0=v0) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s) + return x, v_out + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + rope_dims: int = 0, + ln_scale: bool = False, + xsa_last_n: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.smear = SmearGate(model_dim) + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + rope_dims=rope_dims, + use_xsa=(i >= num_layers - xsa_last_n) if xsa_last_n > 0 else False, + ln_scale_factor=1.0 / math.sqrt(i + 1) if ln_scale else 1.0) + for i in range(num_layers) + ]) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + 
nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def _forward_body(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + v0: Tensor | None = None + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x, v_out = self.blocks[i](x, x0, v0=v0) + if v0 is None: + v0 = v_out + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x, _ = self.blocks[self.num_encoder_layers + i](x, x0, v0=v0) + return x + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self._forward_body(input_ids) + x = self.final_norm(x).reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self._forward_body(input_ids) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / 
self.logit_softcap) + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + eval_start = time.perf_counter() + eval_budget_s = 570.0 # 30s margin from 10-min eval budget + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + eval_elapsed = time.perf_counter() - eval_start + if eval_elapsed > eval_budget_s: + if rank == 0: + print(f" FAILSAFE: eval time {eval_elapsed:.0f}s exceeds {eval_budget_s}s budget, returning partial results", flush=True) + break + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, 
logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +def eval_val_sliding_ngram( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + """Sliding eval with multi-order n-gram backoff + entropy-adaptive alpha (score-first, legal).""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + max_order = args.ngram_eval_max_order + min_order = args.ngram_eval_min_order + buckets = args.ngram_eval_buckets + min_count = 
args.ngram_eval_min_count + use_entropy = args.ngram_eval_entropy + ent_base = args.ngram_eval_ent_base + ent_range = args.ngram_eval_ent_range + ent_scale = args.ngram_eval_ent_scale + ent_thresh = args.ngram_eval_ent_thresh + base_alpha = args.ngram_eval_alpha + n_orders = max_order - min_order + 1 + + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + val_np = val_tokens.numpy() + ctx_tables = [np.zeros((buckets,), dtype=np.uint32) for _ in range(n_orders)] + full_tables = [np.zeros((buckets,), dtype=np.uint32) for _ in range(n_orders)] + mask = np.uint64(buckets - 1) + primes = np.array( + [np.uint64(36313), np.uint64(27191), np.uint64(51647), np.uint64(81929), + np.uint64(131071), np.uint64(175447), np.uint64(209591), np.uint64(262147), + np.uint64(314159), np.uint64(393241), np.uint64(524287)], + dtype=np.uint64, + ) + + if rank == 0: + print(f"ngram_cache:enabled orders={min_order}-{max_order} backoff " + f"entropy={use_entropy} alpha={base_alpha} " + f"ent_base={ent_base} ent_range={ent_range} " + f"min_count={min_count} buckets={buckets}", flush=True) + + loss_sum = 0.0 + token_count = 0.0 + byte_count = 0.0 + per_window_scores: dict[int, tuple[float, float]] = {} # ws -> (loss, bytes) for pass 2 replacement + + eval_start = time.perf_counter() + eval_budget_s = 570.0 + # Pre-allocate eval buffers (avoid per-batch allocation) + x_buf = torch.zeros(batch_seqs, seq_len, dtype=torch.int64, device=device) + y_buf = torch.zeros(batch_seqs, seq_len, dtype=torch.int64, device=device) + base_model.eval() + # Compile eval path for faster inference + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), 
batch_seqs): + eval_elapsed = time.perf_counter() - eval_start + if eval_elapsed > eval_budget_s: + if rank == 0: + print(f" FAILSAFE: ngram eval time {eval_elapsed:.0f}s exceeds budget", flush=True) + break + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = x_buf[:bsz] + y_batch = y_buf[:bsz] + x_batch.zero_() + y_batch.zero_() + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + seg_len = wlen - s + if seg_len <= 0: + continue + + seg_nll = nll[i, s:wlen].to(torch.float64).cpu().numpy() + seg_model_p = np.exp(-seg_nll) + n_seg = len(seg_nll) + global_j = np.arange(ws + s + 1, ws + wlen + 1, dtype=np.int64) + + # Entropy-adaptive alpha + if use_entropy: + with torch.no_grad(): + lp = F.log_softmax(logits[i, s:wlen].float(), dim=-1) + seg_ent = -(lp.exp() * lp).sum(dim=-1).cpu().numpy() + alpha_per_tok = ent_base + ent_range / ( + 1.0 + np.exp(-ent_scale * (seg_ent - ent_thresh))) + + # Precompute hashes for all orders + order_data = [] + for oi in range(n_orders): + ctx_w = min_order + oi - 1 + valid = global_j >= ctx_w + if not valid.any(): + order_data.append(None) + continue + v_idx = np.nonzero(valid)[0] + jv = global_j[v_idx] + ctx_hash = np.zeros(len(jv), dtype=np.uint64) + for k in range(ctx_w): + tok = val_np[jv - (ctx_w - k)].astype(np.uint64) + ctx_hash ^= tok * primes[k % len(primes)] + ctx_key = (ctx_hash & mask).astype(np.int64) + tgt_np = val_np[jv].astype(np.uint64) 
+ full_key = ((ctx_hash ^ (tgt_np * primes[ctx_w % len(primes)])) & mask).astype(np.int64) + order_data.append((v_idx, ctx_key, full_key)) + + # Multi-order backoff: highest order first, track matched order + best_p_ng = np.full(n_seg, -1.0) + best_order = np.full(n_seg, -1, dtype=np.int32) + for oi in range(n_orders - 1, -1, -1): + if order_data[oi] is None: + continue + v_idx, ctx_key, full_key = order_data[oi] + ctx_counts = ctx_tables[oi][ctx_key].astype(np.float64) + full_counts = full_tables[oi][full_key].astype(np.float64) + has_match = (ctx_counts >= float(min_count)) & (full_counts > 0) + needs_fill = has_match & (best_p_ng[v_idx] < 0) + if needs_fill.any(): + fill_idx = v_idx[needs_fill] + p = np.minimum(full_counts[needs_fill], ctx_counts[needs_fill]) / np.maximum(ctx_counts[needs_fill], 1.0) + best_p_ng[fill_idx] = np.clip(p, 0.0, 1.0) + best_order[fill_idx] = min_order + oi + + # Mix model probability with n-gram (order-adaptive entropy gating) + has_match = best_p_ng >= 0 + if has_match.any(): + if use_entropy: + matched_orders = best_order[has_match].astype(np.float64) + # Order-adaptive: higher-order matches trust n-gram at lower entropy + center = ent_thresh - 0.25 * (matched_orders - float(min_order)) + alpha = ent_base + ent_range / ( + 1.0 + np.exp(-ent_scale * (seg_ent[has_match] - center))) + else: + alpha = base_alpha + seg_model_p[has_match] = (1.0 - alpha) * seg_model_p[has_match] + alpha * best_p_ng[has_match] + seg_nll = -np.log(np.clip(seg_model_p, 1e-12, 1.0)) + + # Score-first: update ALL order tables AFTER scoring + for oi in range(n_orders): + if order_data[oi] is None: + continue + v_idx, ctx_key, full_key = order_data[oi] + np.add.at(ctx_tables[oi], ctx_key, 1) + np.add.at(full_tables[oi], full_key, 1) + + win_loss = float(seg_nll.sum()) + loss_sum += win_loss + token_count += float(seg_len) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) + win_bytes = float(tb.sum().item()) + byte_count += win_bytes + per_window_scores[ws] = (win_loss, win_bytes) + + if rank == 0 and (bi // batch_seqs) % 200 == 0 and bi > 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + cur_bpb = (loss_sum / max(token_count, 1.0)) / math.log(2.0) * (token_count / max(byte_count, 1.0)) + elapsed = time.perf_counter() - eval_start + print(f" ngram_eval [{pct:5.1f}%] bpb={cur_bpb:.6f} t={elapsed:.0f}s", flush=True) + + # PASS 2: rescore early windows with full cache (no cache updates) + pass1_elapsed = time.perf_counter() - eval_start + pass2_budget = eval_budget_s - pass1_elapsed - 30 + rescore_limit = total_tokens // 4 # rescore first 25% of windows + if pass2_budget > 30 and len(per_window_scores) > 0: + if rank == 0: + p1_bpb = (loss_sum / max(token_count, 1.0)) / math.log(2.0) * (token_count / max(byte_count, 1.0)) + print(f" pass1 done: bpb={p1_bpb:.6f} t={pass1_elapsed:.0f}s — pass 2 rescore ({pass2_budget:.0f}s budget)", flush=True) + for bi in range(0, len(my_windows), batch_seqs): + if time.perf_counter() - eval_start > eval_budget_s - 10: + break + batch_ws = my_windows[bi:bi + batch_seqs] + if batch_ws[0] > rescore_limit: + break + bsz = len(batch_ws) + x_batch = x_buf[:bsz]; y_batch = y_buf[:bsz] + x_batch.zero_(); y_batch.zero_() + wlens2: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens2.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + with torch.inference_mode(): + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = 
wlens2[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + seg_len = wlen - s + if seg_len <= 0: + continue + seg_nll = nll[i, s:wlen].to(torch.float64).cpu().numpy() + seg_model_p = np.exp(-seg_nll) + n_seg = len(seg_nll) + global_j = np.arange(ws + s + 1, ws + wlen + 1, dtype=np.int64) + if use_entropy: + with torch.inference_mode(): + lp = F.log_softmax(logits[i, s:wlen].float(), dim=-1) + seg_ent = -(lp.exp() * lp).sum(dim=-1).cpu().numpy() + best_p_ng = np.full(n_seg, -1.0) + best_ord = np.full(n_seg, -1, dtype=np.int32) + for oi in range(n_orders - 1, -1, -1): + ctx_w = min_order + oi - 1 + valid = global_j >= ctx_w + if not valid.any(): + continue + v_idx = np.nonzero(valid)[0] + jv = global_j[v_idx] + ctx_hash = np.zeros(len(jv), dtype=np.uint64) + for k in range(ctx_w): + tok = val_np[jv - (ctx_w - k)].astype(np.uint64) + ctx_hash ^= tok * primes[k % len(primes)] + ctx_key = (ctx_hash & mask).astype(np.int64) + tgt2 = val_np[jv].astype(np.uint64) + full_key = ((ctx_hash ^ (tgt2 * primes[ctx_w % len(primes)])) & mask).astype(np.int64) + cc = ctx_tables[oi][ctx_key].astype(np.float64) + fc = full_tables[oi][full_key].astype(np.float64) + hm = (cc >= float(min_count)) & (fc > 0) + nf = hm & (best_p_ng[v_idx] < 0) + if nf.any(): + fi = v_idx[nf] + p = np.minimum(fc[nf], cc[nf]) / np.maximum(cc[nf], 1.0) + best_p_ng[fi] = np.clip(p, 0.0, 1.0) + best_ord[fi] = min_order + oi + hm = best_p_ng >= 0 + if hm.any(): + if use_entropy: + mo = best_ord[hm].astype(np.float64) + center = ent_thresh - 0.25 * (mo - float(min_order)) + a = ent_base + ent_range / (1.0 + np.exp(-ent_scale * (seg_ent[hm] - center))) + else: + a = base_alpha + seg_model_p[hm] = (1.0 - a) * seg_model_p[hm] + a * best_p_ng[hm] + p2_nll = -np.log(np.clip(seg_model_p, 1e-12, 1.0)) + p2_loss_val = float(p2_nll.sum()) + tgt = y_batch[i, s:wlen]; prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + 
p2_byte_val = float(tb.sum().item()) + # Replace pass 1 score with pass 2 score for this window + if ws in per_window_scores: + old_loss, old_bytes = per_window_scores[ws] + loss_sum = loss_sum - old_loss + p2_loss_val + byte_count = byte_count - old_bytes + p2_byte_val + per_window_scores[ws] = (p2_loss_val, p2_byte_val) + if rank == 0: + p2_bpb = (loss_sum / max(token_count, 1.0)) / math.log(2.0) * (token_count / max(byte_count, 1.0)) + print(f" pass2 done: bpb={p2_bpb:.6f} t={time.perf_counter()-eval_start:.0f}s", flush=True) + + _loss = torch.tensor(loss_sum, device=device, dtype=torch.float64) + _toks = torch.tensor(token_count, device=device, dtype=torch.float64) + _bytes = torch.tensor(byte_count, device=device, dtype=torch.float64) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(_toks, op=dist.ReduceOp.SUM) + dist.all_reduce(_bytes, op=dist.ReduceOp.SUM) + + val_loss = _loss.item() / max(_toks.item(), 1.0) + val_bpb = val_loss / math.log(2.0) * (_toks.item() / max(_bytes.item(), 1.0)) + # Coverage check: warn if eval was cut short + total_expected = sum(1 for ws in window_starts + if (min(ws + seq_len, total_tokens) - ws - (0 if ws == 0 else max(min(ws + seq_len, total_tokens) - ws - stride, 0))) > 0) + coverage = _toks.item() / max(total_expected * stride, 1.0) # approximate + elapsed = time.perf_counter() - eval_start + if rank == 0: + print(f" ngram_eval DONE: bpb={val_bpb:.6f} tokens={_toks.item():.0f} t={elapsed:.0f}s", flush=True) + if elapsed >= eval_budget_s - 10: + print(f" WARNING: eval used {elapsed:.0f}s of {eval_budget_s}s budget — results may be from partial coverage", flush=True) + base_model.train() + return val_loss, val_bpb + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 
= torch.compile(zeropower_via_newtonschulz5) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + try: + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + except ImportError: + pass + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + 
torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + xsa_last_n=args.xsa_last_n, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + torch._dynamo.config.optimize_ddp = False + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = 
list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.04, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in 
base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, 
grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # EMA shadow model (kept on GPU to avoid PCIe bottleneck) + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = {name: t.detach().clone() for name, t in base_model.state_dict().items()} + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap 
train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + + # EMA update every 10 steps (GPU-resident, amortize overhead) + if ema_state is not None and step % 10 == 0: + decay = args.ema_decay ** 10 # compensate for batched updates + with torch.no_grad(): + for name, param in base_model.state_dict().items(): + ema_state[name].lerp_(param.detach(), 1.0 - decay) + + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + 
swa_state[name] += t.detach().cpu() + swa_count += 1 + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # Apply EMA if enabled (overrides SWA) + if args.ema_enabled and ema_state is not None: + log0("ema:applying shadow model") + current_state = base_model.state_dict() + ema_applied = { + name: tensor.to(dtype=current_state[name].dtype, device=current_state[name].device) + for name, tensor in ema_state.items() + } + base_model.load_state_dict(ema_applied, strict=True) + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} 
bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # Magnitude pruning: zero out smallest weights to improve compression + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), 0.03) + mask = param.abs() < threshold + param.masked_fill_(mask, 0.0) + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + total_bytes = quant_file_bytes + code_bytes + log0(f"Total submission size: {total_bytes} bytes ({total_bytes/1e6:.2f} MB)") + if total_bytes > 16_000_000: + log0(f"FAILSAFE: artifact {total_bytes} bytes EXCEEDS 16MB limit! 
Aborting eval.") + sys.exit(1) + log0(f"SIZE CHECK PASSED: {total_bytes/1e6:.2f} MB < 16.00 MB") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + base_model.load_state_dict(deq_state, strict=True) + + # Sliding window eval on int6-roundtripped weights + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.ngram_eval_max_order >= 2 and args.eval_stride > 0: + log0(f"final_eval_mode:sliding_ngram orders={args.ngram_eval_min_order}-{args.ngram_eval_max_order} " + f"alpha={args.ngram_eval_alpha} entropy={args.ngram_eval_entropy} stride:{args.eval_stride}") + q_val_loss, q_val_bpb = eval_val_sliding_ngram( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + batch_seqs=args.eval_batch_seqs, + ) + elif args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + 
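+    # Units note (restating the bpb conversion used by the eval helpers above):
+    # every val_bpb figure is derived from mean next-token NLL in nats as
+    #     bits_per_token = val_loss / math.log(2.0)
+    #     val_bpb = bits_per_token * (token_count / byte_count)
+    # so bpb improves both when the model predicts better and when the
+    # tokenizer covers more bytes per token.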
+    log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")
+
+    if distributed:
+        dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    main()
+# fixes applied
+# tuned
diff --git a/records/track_10min_16mb/2026-03-26_10L_TwoPass_Order11_Ngram/v7_seed1337_2024.log b/records/track_10min_16mb/2026-03-26_10L_TwoPass_Order11_Ngram/v7_seed1337_2024.log
new file mode 100644
index 000000000..b34c4912d
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-26_10L_TwoPass_Order11_Ngram/v7_seed1337_2024.log
@@ -0,0 +1,264 @@
+=== Seed 1337 ===
+W0326 17:03:56.792000 124377791210112 torch/distributed/run.py:779]
+W0326 17:03:56.792000 124377791210112 torch/distributed/run.py:779] *****************************************
+W0326 17:03:56.792000 124377791210112 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+W0326 17:03:56.792000 124377791210112 torch/distributed/run.py:779] *****************************************
+logs/8dd0535e-c348-44df-ad0a-6e1c4a4cda56.txt
+val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
+train_loader:dataset:fineweb10B_sp1024 train_shards:80
+val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
+model_params:24730705
+world_size:8 grad_accum_steps:1
+attention_mode:gqa num_heads:8 num_kv_heads:4
+tie_embeddings:True embed_lr:0.03 matrix_lr:0.03 scalar_lr:0.02
+train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:5 max_wallclock_seconds:600.000
+seed:1337
+warmup_step:1/5
+warmup_step:2/5
+warmup_step:3/5
+warmup_step:4/5
+warmup_step:5/5
+step:1/20000 train_loss:6.9296 train_time:216ms step_avg:215.78ms
+step:2/20000 train_loss:7.7639 train_time:303ms step_avg:151.67ms
+step:3/20000 train_loss:7.1031 train_time:401ms step_avg:133.77ms
+step:4/20000 train_loss:7.9357 train_time:499ms step_avg:124.75ms
+step:5/20000 train_loss:8.4837 train_time:597ms step_avg:119.34ms
+step:6/20000 train_loss:8.2503 train_time:694ms step_avg:115.64ms
+step:7/20000 train_loss:7.5696 train_time:792ms step_avg:113.09ms
+step:8/20000 train_loss:6.8958 train_time:889ms step_avg:111.14ms
+step:9/20000 train_loss:6.4975 train_time:987ms step_avg:109.63ms
+step:10/20000 train_loss:6.2617 train_time:1096ms step_avg:109.59ms
+step:100/20000 train_loss:3.1574 train_time:9909ms step_avg:99.09ms
+step:200/20000 train_loss:2.3861 train_time:19803ms step_avg:99.01ms
+step:300/20000 train_loss:2.5441 train_time:29705ms step_avg:99.02ms
+step:400/20000 train_loss:2.4201 train_time:39653ms step_avg:99.13ms
+step:500/20000 train_loss:2.3996 train_time:49568ms step_avg:99.14ms
+step:600/20000 train_loss:2.3430 train_time:59559ms step_avg:99.26ms
+step:700/20000 train_loss:2.3510 train_time:69567ms step_avg:99.38ms
+step:800/20000 train_loss:2.2459 train_time:79569ms step_avg:99.46ms
+step:900/20000 train_loss:2.1378 train_time:89595ms step_avg:99.55ms
+step:1000/20000 train_loss:2.2855 train_time:99562ms step_avg:99.56ms
+step:1100/20000 train_loss:2.3304 train_time:109568ms step_avg:99.61ms
+step:1200/20000 train_loss:2.3641 train_time:119586ms step_avg:99.66ms
+step:1300/20000 train_loss:2.1040 train_time:129604ms step_avg:99.70ms
+step:1400/20000 train_loss:2.1972 train_time:139617ms step_avg:99.73ms
+step:1500/20000 train_loss:2.2325 train_time:149562ms step_avg:99.71ms
+step:1600/20000 train_loss:2.0814 train_time:159566ms step_avg:99.73ms
+step:1700/20000 train_loss:2.1533 train_time:169575ms step_avg:99.75ms
+step:1800/20000 train_loss:2.1602 train_time:179571ms step_avg:99.76ms
+step:1900/20000 train_loss:2.1369 train_time:189529ms step_avg:99.75ms
+step:2000/20000 train_loss:2.0774 train_time:199538ms step_avg:99.77ms
+step:2100/20000 train_loss:2.0602 train_time:209542ms step_avg:99.78ms
+step:2200/20000 train_loss:2.1384 train_time:219540ms step_avg:99.79ms
+step:2300/20000 train_loss:2.1186 train_time:229554ms step_avg:99.81ms
+step:2400/20000 train_loss:2.0771 train_time:239512ms step_avg:99.80ms
+step:2500/20000 train_loss:2.1787 train_time:249522ms step_avg:99.81ms
+step:2600/20000 train_loss:2.1204 train_time:259524ms step_avg:99.82ms
+step:2700/20000 train_loss:2.1105 train_time:269530ms step_avg:99.83ms
+step:2800/20000 train_loss:2.1624 train_time:279545ms step_avg:99.84ms
+step:2900/20000 train_loss:2.0265 train_time:289571ms step_avg:99.85ms
+step:3000/20000 train_loss:2.1591 train_time:299567ms step_avg:99.86ms
+step:3100/20000 train_loss:2.0353 train_time:309564ms step_avg:99.86ms
+step:3200/20000 train_loss:2.1697 train_time:319564ms step_avg:99.86ms
+step:3300/20000 train_loss:2.0658 train_time:329503ms step_avg:99.85ms
+step:3400/20000 train_loss:2.0166 train_time:339490ms step_avg:99.85ms
+step:3500/20000 train_loss:2.1733 train_time:349473ms step_avg:99.85ms
+step:3600/20000 train_loss:2.0879 train_time:359467ms step_avg:99.85ms
+step:3700/20000 train_loss:2.0844 train_time:369462ms step_avg:99.85ms
+step:3800/20000 train_loss:2.0628 train_time:379398ms step_avg:99.84ms
+step:3900/20000 train_loss:2.0640 train_time:389389ms step_avg:99.84ms
+step:4000/20000 train_loss:1.9633 train_time:399362ms step_avg:99.84ms
+step:4100/20000 train_loss:1.9994 train_time:409336ms step_avg:99.84ms
+step:4200/20000 train_loss:2.1348 train_time:419316ms step_avg:99.84ms
+step:4300/20000 train_loss:2.0440 train_time:429245ms step_avg:99.82ms
+step:4400/20000 train_loss:2.0147 train_time:439215ms step_avg:99.82ms
+step:4500/20000 train_loss:2.1072 train_time:449196ms step_avg:99.82ms
+step:4600/20000 train_loss:1.8222 train_time:459173ms step_avg:99.82ms
+step:4700/20000 train_loss:2.2205 train_time:469104ms step_avg:99.81ms
+step:4800/20000 train_loss:2.4139 train_time:479079ms step_avg:99.81ms
+step:4900/20000 train_loss:2.0305 train_time:489057ms step_avg:99.81ms
+step:5000/20000 train_loss:2.0805 train_time:499032ms step_avg:99.81ms
+step:5100/20000 train_loss:2.1071 train_time:508998ms step_avg:99.80ms
+step:5200/20000 train_loss:2.0213 train_time:518912ms step_avg:99.79ms
+step:5300/20000 train_loss:1.9841 train_time:528894ms step_avg:99.79ms
+step:5400/20000 train_loss:2.0260 train_time:538868ms step_avg:99.79ms
+step:5500/20000 train_loss:1.9922 train_time:548844ms step_avg:99.79ms
+step:5600/20000 train_loss:1.9294 train_time:558820ms step_avg:99.79ms
+step:5700/20000 train_loss:1.9873 train_time:568735ms step_avg:99.78ms
+step:5800/20000 train_loss:1.9701 train_time:578708ms step_avg:99.78ms
+step:5900/20000 train_loss:1.8728 train_time:588675ms step_avg:99.78ms
+step:6000/20000 train_loss:1.9155 train_time:598646ms step_avg:99.77ms
+step:6014/20000 val_loss:1.9516 val_bpb:1.1559 train_time:600036ms step_avg:99.77ms
+stopping_early: wallclock_cap train_time:600036ms step:6014/20000
+peak memory allocated: 25196 MiB reserved: 25954 MiB
+ema:applying shadow model
+Serialized model: 96864150 bytes
+Code size: 74693 bytes
+Total submission size: 96938843 bytes
+Serialized model int6+zstd: 15494848 bytes
+Total submission size: 15569541 bytes (15.57 MB)
+SIZE CHECK PASSED: 15.57 MB < 16.00 MB
+/workspace/train_gpt.py:1605: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
+  quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu")
+final_eval_mode:sliding_ngram orders=2-11 alpha=0.4 entropy=True stride:64
+ngram_cache:enabled orders=2-11 backoff entropy=True alpha=0.4 ent_base=0.05 ent_range=0.55 min_count=2 buckets=4194304
+ ngram_eval [ 10.6%] bpb=1.109455 t=29s
+ ngram_eval [ 21.2%] bpb=1.068450 t=46s
+ ngram_eval [ 31.8%] bpb=1.011958 t=63s
+ ngram_eval [ 42.3%] bpb=0.947462 t=80s
+ ngram_eval [ 52.9%] bpb=0.884724 t=97s
+ ngram_eval [ 63.5%] bpb=0.826742 t=114s
+ ngram_eval [ 74.0%] bpb=0.777297 t=132s
+ ngram_eval [ 84.6%] bpb=0.734486 t=149s
+ ngram_eval [ 95.2%] bpb=0.697006 t=166s
+ pass1 done: bpb=0.681880 t=186s — pass 2 rescore (354s budget)
+ pass2 done: bpb=0.313750 t=326s
+ ngram_eval DONE: bpb=0.586354 tokens=62023616 t=328s
+final_int8_zlib_roundtrip val_loss:0.9900 val_bpb:0.5864 eval_time:328181ms
+final_int8_zlib_roundtrip_exact val_loss:0.99003173 val_bpb:0.58635411
+=== Seed 2024 ===
+W0326 17:21:40.590000 139373416821376 torch/distributed/run.py:779]
+W0326 17:21:40.590000 139373416821376 torch/distributed/run.py:779] *****************************************
+W0326 17:21:40.590000 139373416821376 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+W0326 17:21:40.590000 139373416821376 torch/distributed/run.py:779] *****************************************
+logs/b2187c77-14cb-49a7-b776-d9056b8613a3.txt
+val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
+train_loader:dataset:fineweb10B_sp1024 train_shards:80
+val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
+model_params:24730705
+world_size:8 grad_accum_steps:1
+attention_mode:gqa num_heads:8 num_kv_heads:4
+tie_embeddings:True embed_lr:0.03 matrix_lr:0.03 scalar_lr:0.02
+train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:5 max_wallclock_seconds:600.000
+seed:2024
+warmup_step:1/5
+warmup_step:2/5
+warmup_step:3/5
+warmup_step:4/5
+warmup_step:5/5
+step:1/20000 train_loss:6.9305 train_time:216ms step_avg:216.11ms
+step:2/20000 train_loss:7.7296 train_time:304ms step_avg:151.76ms
+step:3/20000 train_loss:7.0936 train_time:402ms step_avg:133.87ms
+step:4/20000 train_loss:8.0126 train_time:500ms step_avg:125.10ms
+step:5/20000 train_loss:8.4751 train_time:600ms step_avg:119.91ms
+step:6/20000 train_loss:8.3002 train_time:697ms step_avg:116.22ms
+step:7/20000 train_loss:7.6261 train_time:795ms step_avg:113.55ms
+step:8/20000 train_loss:6.9188 train_time:893ms step_avg:111.59ms
+step:9/20000 train_loss:6.4766 train_time:991ms step_avg:110.12ms
+step:10/20000 train_loss:6.1230 train_time:1101ms step_avg:110.14ms
+step:100/20000 train_loss:3.1602 train_time:9925ms step_avg:99.25ms
+step:200/20000 train_loss:2.3645 train_time:19843ms step_avg:99.22ms
+step:300/20000 train_loss:2.5143 train_time:29759ms step_avg:99.20ms
+step:400/20000 train_loss:2.4014 train_time:39703ms step_avg:99.26ms
+step:500/20000 train_loss:2.3887 train_time:49641ms step_avg:99.28ms
+step:600/20000 train_loss:2.3275 train_time:59651ms step_avg:99.42ms
+step:700/20000 train_loss:2.3508 train_time:69674ms step_avg:99.53ms
+step:800/20000 train_loss:2.2375 train_time:79682ms step_avg:99.60ms
+step:900/20000 train_loss:2.1300 train_time:89708ms step_avg:99.68ms
+step:1000/20000 train_loss:2.2775 train_time:99670ms step_avg:99.67ms
+step:1100/20000 train_loss:2.3276 train_time:109686ms step_avg:99.71ms
+step:1200/20000 train_loss:2.3636 train_time:119697ms step_avg:99.75ms
+step:1300/20000 train_loss:2.1062 train_time:129717ms step_avg:99.78ms
+step:1400/20000 train_loss:2.1949 train_time:139741ms step_avg:99.81ms
+step:1500/20000 train_loss:2.2350 train_time:149727ms step_avg:99.82ms
+step:1600/20000 train_loss:2.0831 train_time:159752ms step_avg:99.84ms
+step:1700/20000 train_loss:2.1476 train_time:169773ms step_avg:99.87ms
+step:1800/20000 train_loss:2.1622 train_time:179798ms step_avg:99.89ms
+step:1900/20000 train_loss:2.1349 train_time:189769ms step_avg:99.88ms
+step:2000/20000 train_loss:2.0767 train_time:199788ms step_avg:99.89ms
+step:2100/20000 train_loss:2.0582 train_time:209805ms step_avg:99.91ms
+step:2200/20000 train_loss:2.1548 train_time:219818ms step_avg:99.92ms
+step:2300/20000 train_loss:2.1194 train_time:229832ms step_avg:99.93ms
+step:2400/20000 train_loss:2.0813 train_time:239790ms step_avg:99.91ms
+step:2500/20000 train_loss:2.1799 train_time:249802ms step_avg:99.92ms
+step:2600/20000 train_loss:2.1200 train_time:259799ms step_avg:99.92ms
+step:2700/20000 train_loss:2.1088 train_time:269804ms step_avg:99.93ms
+step:2800/20000 train_loss:2.1567 train_time:279792ms step_avg:99.93ms
+step:2900/20000 train_loss:2.0280 train_time:289751ms step_avg:99.91ms
+step:3000/20000 train_loss:2.1654 train_time:299759ms step_avg:99.92ms
+step:3100/20000 train_loss:2.0344 train_time:309757ms step_avg:99.92ms
+step:3200/20000 train_loss:2.1707 train_time:319750ms step_avg:99.92ms
+step:3300/20000 train_loss:2.0634 train_time:329685ms step_avg:99.90ms
+step:3400/20000 train_loss:2.0144 train_time:339684ms step_avg:99.91ms
+step:3500/20000 train_loss:2.1698 train_time:349673ms step_avg:99.91ms
+step:3600/20000 train_loss:2.0868 train_time:359673ms step_avg:99.91ms
+step:3700/20000 train_loss:2.0830 train_time:369672ms step_avg:99.91ms
+step:3800/20000 train_loss:2.0647 train_time:379618ms step_avg:99.90ms
+step:3900/20000 train_loss:2.0668 train_time:389616ms step_avg:99.90ms
+step:4000/20000 train_loss:1.9660 train_time:399608ms step_avg:99.90ms
+step:4100/20000 train_loss:2.0018 train_time:409598ms step_avg:99.90ms
+step:4200/20000 train_loss:2.1359 train_time:419595ms step_avg:99.90ms
+step:4300/20000 train_loss:2.0452 train_time:429612ms step_avg:99.91ms
+step:4400/20000 train_loss:2.0181 train_time:439593ms step_avg:99.91ms
+step:4500/20000 train_loss:2.1098 train_time:449573ms step_avg:99.91ms
+step:4600/20000 train_loss:1.8218 train_time:459549ms step_avg:99.90ms
+step:4700/20000 train_loss:2.2167 train_time:469470ms step_avg:99.89ms
+step:4800/20000 train_loss:2.4093 train_time:479445ms step_avg:99.88ms
+step:4900/20000 train_loss:2.0320 train_time:489427ms step_avg:99.88ms
+step:5000/20000 train_loss:2.0879 train_time:499400ms step_avg:99.88ms
+step:5100/20000 train_loss:2.1082 train_time:509371ms step_avg:99.88ms
+step:5200/20000 train_loss:2.0213 train_time:519292ms step_avg:99.86ms
+step:5300/20000 train_loss:1.9890 train_time:529273ms step_avg:99.86ms
+step:5400/20000 train_loss:2.0276 train_time:539239ms step_avg:99.86ms
+step:5500/20000 train_loss:1.9921 train_time:549213ms step_avg:99.86ms
+step:5600/20000 train_loss:1.9288 train_time:559183ms step_avg:99.85ms
+step:5700/20000 train_loss:1.9879 train_time:569100ms step_avg:99.84ms
+step:5800/20000 train_loss:1.9691 train_time:579082ms step_avg:99.84ms
+step:5900/20000 train_loss:1.8715 train_time:589052ms step_avg:99.84ms
+step:6000/20000 train_loss:1.9150 train_time:599025ms step_avg:99.84ms
+step:6010/20000 val_loss:1.9535 val_bpb:1.1570 train_time:600022ms step_avg:99.84ms
+stopping_early: wallclock_cap train_time:600022ms step:6010/20000
+peak memory allocated: 25196 MiB reserved: 25954 MiB
+ema:applying shadow model
+Serialized model: 96864150 bytes
+Code size: 74693 bytes
+Total submission size: 96938843 bytes
+Serialized model int6+zstd: 15292867 bytes
+Total submission size: 15367560 bytes (15.37 MB)
+SIZE CHECK PASSED: 15.37 MB < 16.00 MB
+/workspace/train_gpt.py:1605: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
+  quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu")
+final_eval_mode:sliding_ngram orders=2-11 alpha=0.4 entropy=True stride:64
+ngram_cache:enabled orders=2-11 backoff entropy=True alpha=0.4 ent_base=0.05 ent_range=0.55 min_count=2 buckets=4194304
+ ngram_eval [ 10.6%] bpb=1.108912 t=29s
+ ngram_eval [ 21.2%] bpb=1.068250 t=46s
+ ngram_eval [ 31.8%] bpb=1.011687 t=63s
+ ngram_eval [ 42.3%] bpb=0.947163 t=80s
+ ngram_eval [ 52.9%] bpb=0.884461 t=97s
+ ngram_eval [ 63.5%] bpb=0.826488 t=114s
+ ngram_eval [ 74.0%] bpb=0.777036 t=131s
+ ngram_eval [ 84.6%] bpb=0.734185 t=148s
+ ngram_eval [ 95.2%] bpb=0.696669 t=165s
+ pass1 done: bpb=0.681548 t=186s — pass 2 rescore (354s budget)
+ pass2 done: bpb=0.313319 t=325s
+ ngram_eval DONE: bpb=0.585981 tokens=62023616 t=332s
+final_int8_zlib_roundtrip val_loss:0.9894 val_bpb:0.5860 eval_time:332204ms
+final_int8_zlib_roundtrip_exact val_loss:0.98940154 val_bpb:0.58598088
diff --git a/records/track_10min_16mb/2026-03-26_10L_TwoPass_Order11_Ngram/v7_seed42.log b/records/track_10min_16mb/2026-03-26_10L_TwoPass_Order11_Ngram/v7_seed42.log
new file mode 100644
index 000000000..c3f5358b6
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-26_10L_TwoPass_Order11_Ngram/v7_seed42.log
@@ -0,0 +1,264 @@
+=== Parameter Golf V6 RunPod Setup ===
+Requirement already satisfied: sentencepiece in /usr/local/lib/python3.11/dist-packages (0.2.1)
+Requirement already satisfied: zstandard in /usr/local/lib/python3.11/dist-packages (0.25.0)
+Requirement already satisfied: huggingface_hub in /usr/local/lib/python3.11/dist-packages (1.8.0)
+Requirement already satisfied: filelock>=3.10.0 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (3.13.1)
+Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (2024.2.0)
+Requirement already satisfied: hf-xet<2.0.0,>=1.4.2 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (1.4.2)
+Requirement already satisfied: httpx<1,>=0.23.0 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (0.27.2)
+Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (24.1)
+Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (6.0.2)
+Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (4.67.3)
+Requirement already satisfied: typer in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (0.24.1)
+Requirement already satisfied: typing-extensions>=4.1.0 in /usr/local/lib/python3.11/dist-packages (from huggingface_hub) (4.9.0)
+Requirement already satisfied: anyio in /usr/local/lib/python3.11/dist-packages (from httpx<1,>=0.23.0->huggingface_hub) (4.6.0)
+Requirement already satisfied: certifi in /usr/local/lib/python3.11/dist-packages (from httpx<1,>=0.23.0->huggingface_hub) (2024.8.30)
+Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.11/dist-packages (from httpx<1,>=0.23.0->huggingface_hub) (1.0.5)
+Requirement already satisfied: idna in /usr/local/lib/python3.11/dist-packages (from httpx<1,>=0.23.0->huggingface_hub) (3.10)
+Requirement already satisfied: sniffio in /usr/local/lib/python3.11/dist-packages (from httpx<1,>=0.23.0->huggingface_hub) (1.3.1)
+Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.11/dist-packages (from httpcore==1.*->httpx<1,>=0.23.0->huggingface_hub) (0.14.0)
+Requirement already satisfied: click>=8.2.1 in /usr/local/lib/python3.11/dist-packages (from typer->huggingface_hub) (8.3.1)
+Requirement already satisfied: shellingham>=1.3.0 in /usr/local/lib/python3.11/dist-packages (from typer->huggingface_hub) (1.5.4)
+Requirement already satisfied: rich>=12.3.0 in /usr/local/lib/python3.11/dist-packages (from typer->huggingface_hub) (14.3.3)
+Requirement already satisfied: annotated-doc>=0.0.2 in /usr/local/lib/python3.11/dist-packages (from typer->huggingface_hub) (0.0.4)
+Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.11/dist-packages (from rich>=12.3.0->typer->huggingface_hub) (4.0.0)
+Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.11/dist-packages (from rich>=12.3.0->typer->huggingface_hub) (2.18.0)
+Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.11/dist-packages (from markdown-it-py>=2.2.0->rich>=12.3.0->typer->huggingface_hub) (0.1.2)
+Data ready: 81 files
+=== Mode: fast | Seed: 42 ===
+=== SMOKE TEST (60s) ===
+W0326 16:37:56.816000 136133524791936 torch/distributed/run.py:779]
+W0326 16:37:56.816000 136133524791936 torch/distributed/run.py:779] *****************************************
+W0326 16:37:56.816000 136133524791936 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+W0326 16:37:56.816000 136133524791936 torch/distributed/run.py:779] *****************************************
+logs/6cd0790d-006f-4d6b-b8d7-a66be05fc961.txt
+val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
+train_loader:dataset:fineweb10B_sp1024 train_shards:80
+val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
+model_params:24730705
+world_size:8 grad_accum_steps:1
+attention_mode:gqa num_heads:8 num_kv_heads:4
+tie_embeddings:True embed_lr:0.03 matrix_lr:0.03 scalar_lr:0.02
+train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:5 max_wallclock_seconds:60.000
+seed:42
+warmup_step:1/5
+warmup_step:2/5
+warmup_step:3/5
+warmup_step:4/5
+warmup_step:5/5
+step:1/20000 train_loss:6.9311 train_time:203ms step_avg:203.45ms
+step:2/20000 train_loss:7.7074 train_time:291ms step_avg:145.47ms
+step:3/20000 train_loss:7.5092 train_time:389ms step_avg:129.58ms
+step:4/20000 train_loss:7.1247 train_time:486ms step_avg:121.62ms
+step:5/20000 train_loss:6.7138 train_time:584ms step_avg:116.83ms
+step:6/20000 train_loss:6.4142 train_time:682ms step_avg:113.62ms
+step:7/20000 train_loss:6.1668 train_time:780ms step_avg:111.42ms
+step:8/20000 train_loss:6.0001 train_time:878ms step_avg:109.71ms
+step:9/20000 train_loss:5.8437 train_time:975ms step_avg:108.36ms
+step:10/20000 train_loss:5.7233 train_time:1084ms step_avg:108.43ms
+step:100/20000 train_loss:3.4484 train_time:9887ms step_avg:98.87ms
+step:200/20000 train_loss:2.7017 train_time:19762ms step_avg:98.81ms
+step:300/20000 train_loss:2.7795 train_time:29718ms step_avg:99.06ms
+step:400/20000 train_loss:2.6303 train_time:39685ms step_avg:99.21ms
+step:500/20000 train_loss:2.5823 train_time:49565ms step_avg:99.13ms
+step:600/20000 train_loss:2.5369 train_time:59551ms step_avg:99.25ms
+step:605/20000 val_loss:2.5943 val_bpb:1.5365 train_time:60056ms step_avg:99.27ms
+stopping_early: wallclock_cap train_time:60056ms step:605/20000
+peak memory allocated: 25387 MiB reserved: 26052 MiB
+ema:applying shadow model
+Serialized model: 96864150 bytes
+Code size: 74693 bytes
+Total submission size: 96938843 bytes
+Serialized model int6+zstd: 15308056 bytes
+Total submission size: 15382749 bytes (15.38 MB)
+SIZE CHECK PASSED: 15.38 MB < 16.00 MB
+/workspace/train_gpt.py:1605: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
+ quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu")
+final_eval_mode:sliding_window stride:64 batch_seqs:64
+ sliding_eval [  0.1%] 64/121136 windows running_bpb=1.953489
+ sliding_eval [  2.7%] 3264/121136 windows running_bpb=1.973931
+ sliding_eval [  5.3%] 6464/121136 windows running_bpb=2.006778
+ sliding_eval [  8.0%] 9664/121136 windows running_bpb=2.012843
+ sliding_eval [ 10.6%] 12864/121136 windows running_bpb=2.001132
+ sliding_eval [ 13.3%] 16064/121136 windows running_bpb=2.010183
+ sliding_eval [ 15.9%] 19264/121136 windows running_bpb=2.004012
+ sliding_eval [ 18.5%] 22464/121136 windows running_bpb=2.004759
+ sliding_eval [ 21.2%] 25664/121136 windows running_bpb=2.010635
+ sliding_eval [ 23.8%] 28864/121136 windows running_bpb=2.014743
+ sliding_eval [ 26.5%] 32064/121136 windows running_bpb=2.016716
+ sliding_eval [ 29.1%] 35264/121136 windows running_bpb=2.014031
+ sliding_eval [ 31.8%] 38464/121136 windows running_bpb=2.013862
+ sliding_eval [ 34.4%] 41664/121136 windows running_bpb=2.016488
+ sliding_eval [ 37.0%] 44864/121136 windows running_bpb=2.018577
+ sliding_eval [ 39.7%] 48064/121136 windows running_bpb=2.016559
+ sliding_eval [ 42.3%] 51264/121136 windows running_bpb=2.018850
+ sliding_eval [ 45.0%] 54464/121136 windows running_bpb=2.019458
+ sliding_eval [ 47.6%] 57664/121136 windows running_bpb=2.021461
+ sliding_eval [ 50.2%] 60864/121136 windows running_bpb=2.018279
+ sliding_eval [ 52.9%] 64064/121136 windows running_bpb=2.018077
+ sliding_eval [ 55.5%] 67264/121136 windows running_bpb=2.016107
+ sliding_eval [ 58.2%] 70464/121136 windows running_bpb=2.014263
+ sliding_eval [ 60.8%] 73664/121136 windows running_bpb=2.014411
+ sliding_eval [ 63.5%] 76864/121136 windows running_bpb=2.014282
+ sliding_eval [ 66.1%] 80064/121136 windows running_bpb=2.014640
+ sliding_eval [ 68.7%] 83264/121136 windows running_bpb=2.016442
+ sliding_eval [ 71.4%] 86464/121136 windows running_bpb=2.016527
+ sliding_eval [ 74.0%] 89664/121136 windows running_bpb=2.017595
+ sliding_eval [ 76.7%] 92864/121136 windows running_bpb=2.017657
+ sliding_eval [ 79.3%] 96064/121136 windows running_bpb=2.018340
+ sliding_eval [ 81.9%] 99264/121136 windows running_bpb=2.021236
+ sliding_eval [ 84.6%] 102464/121136 windows running_bpb=2.020969
+ sliding_eval [ 87.2%] 105664/121136 windows running_bpb=2.019871
+ sliding_eval [ 89.9%] 108864/121136 windows running_bpb=2.020678
+ sliding_eval [ 92.5%] 112064/121136 windows running_bpb=2.020367
+ sliding_eval [ 95.2%] 115264/121136 windows running_bpb=2.021376
+ sliding_eval [ 97.8%] 118464/121136 windows running_bpb=2.022046
+final_int8_zlib_roundtrip val_loss:3.3598 val_bpb:1.9899 eval_time:251451ms
+final_int8_zlib_roundtrip_exact val_loss:3.35982587 val_bpb:1.98988339
+=== SMOKE PASSED — LAUNCHING FULL RUN ===
+W0326 16:45:43.047000 126394035827328 torch/distributed/run.py:779]
+W0326 16:45:43.047000 126394035827328 torch/distributed/run.py:779] *****************************************
+W0326 16:45:43.047000 126394035827328 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+W0326 16:45:43.047000 126394035827328 torch/distributed/run.py:779] *****************************************
+logs/10db030f-61d6-485f-a058-538b0f3c498d.txt
+val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
+train_loader:dataset:fineweb10B_sp1024 train_shards:80
+val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
+model_params:24730705
+world_size:8 grad_accum_steps:1
+attention_mode:gqa num_heads:8 num_kv_heads:4
+tie_embeddings:True embed_lr:0.03 matrix_lr:0.03 scalar_lr:0.02
+train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:5 max_wallclock_seconds:600.000
+seed:42
+warmup_step:1/5
+warmup_step:2/5
+warmup_step:3/5
+warmup_step:4/5
+warmup_step:5/5
+step:1/20000 train_loss:6.9311 train_time:213ms step_avg:212.51ms
+step:2/20000 train_loss:7.7074 train_time:300ms step_avg:149.92ms
+step:3/20000 train_loss:7.1305 train_time:398ms step_avg:132.60ms
+step:4/20000 train_loss:8.1579 train_time:496ms step_avg:123.88ms
+step:5/20000 train_loss:8.4603 train_time:593ms step_avg:118.68ms
+step:6/20000 train_loss:8.2635 train_time:691ms step_avg:115.11ms
+step:7/20000 train_loss:7.5430 train_time:789ms step_avg:112.73ms
+step:8/20000 train_loss:6.9729 train_time:886ms step_avg:110.80ms
+step:9/20000 train_loss:6.5174 train_time:985ms step_avg:109.46ms
+step:10/20000 train_loss:6.2141 train_time:1095ms step_avg:109.48ms
+step:100/20000 train_loss:3.1552 train_time:9925ms step_avg:99.25ms
+step:200/20000 train_loss:2.3592 train_time:19845ms step_avg:99.23ms
+step:300/20000 train_loss:2.5308 train_time:29800ms step_avg:99.33ms
+step:400/20000 train_loss:2.4065 train_time:39768ms step_avg:99.42ms
+step:500/20000 train_loss:2.3899 train_time:49712ms step_avg:99.42ms
+step:600/20000 train_loss:2.3353 train_time:59730ms step_avg:99.55ms
+step:700/20000 train_loss:2.3543 train_time:69771ms step_avg:99.67ms
+step:800/20000 train_loss:2.2414 train_time:79810ms step_avg:99.76ms
+step:900/20000 train_loss:2.1325 train_time:89874ms step_avg:99.86ms
+step:1000/20000 train_loss:2.2826 train_time:99870ms step_avg:99.87ms
+step:1100/20000 train_loss:2.3315 train_time:109892ms step_avg:99.90ms
+step:1200/20000 train_loss:2.3642 train_time:119929ms step_avg:99.94ms
+step:1300/20000 train_loss:2.1092 train_time:129949ms step_avg:99.96ms
+step:1400/20000 train_loss:2.1943 train_time:139995ms step_avg:100.00ms
+step:1500/20000 train_loss:2.2346 train_time:149971ms step_avg:99.98ms
+step:1600/20000 train_loss:2.0868 train_time:160007ms step_avg:100.00ms
+step:1700/20000 train_loss:2.1491 train_time:170024ms step_avg:100.01ms
+step:1800/20000 train_loss:2.1609 train_time:180117ms step_avg:100.07ms
+step:1900/20000 train_loss:2.1342 train_time:190079ms step_avg:100.04ms
+step:2000/20000 train_loss:2.0805 train_time:200110ms step_avg:100.06ms
+step:2100/20000 train_loss:2.0578 train_time:210125ms step_avg:100.06ms
+step:2200/20000 train_loss:2.1432 train_time:220149ms step_avg:100.07ms
+step:2300/20000 train_loss:2.1170 train_time:230164ms step_avg:100.07ms
+step:2400/20000 train_loss:2.0764 train_time:240108ms step_avg:100.04ms
+step:2500/20000 train_loss:2.1832 train_time:250110ms step_avg:100.04ms
+step:2600/20000 train_loss:2.1206 train_time:260135ms step_avg:100.05ms
+step:2700/20000 train_loss:2.1105 train_time:270160ms step_avg:100.06ms
+step:2800/20000 train_loss:2.1607 train_time:280168ms step_avg:100.06ms
+step:2900/20000 train_loss:2.0283 train_time:290102ms step_avg:100.04ms
+step:3000/20000 train_loss:2.1637 train_time:300113ms step_avg:100.04ms
+step:3100/20000 train_loss:2.0392 train_time:310112ms step_avg:100.04ms
+step:3200/20000 train_loss:2.1737 train_time:320113ms step_avg:100.04ms
+step:3300/20000 train_loss:2.0687 train_time:330057ms step_avg:100.02ms
+step:3400/20000 train_loss:2.0186 train_time:340065ms step_avg:100.02ms
+step:3500/20000 train_loss:2.1715 train_time:350073ms step_avg:100.02ms
+step:3600/20000 train_loss:2.0887 train_time:360073ms step_avg:100.02ms
+step:3700/20000 train_loss:2.0874 train_time:370079ms step_avg:100.02ms
+step:3800/20000 train_loss:2.0642 train_time:380031ms step_avg:100.01ms
+step:3900/20000 train_loss:2.0667 train_time:390031ms step_avg:100.01ms
+step:4000/20000 train_loss:1.9639 train_time:400042ms step_avg:100.01ms
+step:4100/20000 train_loss:2.0006 train_time:410023ms step_avg:100.01ms
+step:4200/20000 train_loss:2.1345 train_time:420036ms step_avg:100.01ms
+step:4300/20000 train_loss:2.0466 train_time:429960ms step_avg:99.99ms
+step:4400/20000 train_loss:2.0177 train_time:439946ms step_avg:99.99ms
+step:4500/20000 train_loss:2.1120 train_time:449933ms step_avg:99.99ms
+step:4600/20000 train_loss:1.8270 train_time:459932ms step_avg:99.99ms
+step:4700/20000 train_loss:2.2224 train_time:469852ms step_avg:99.97ms
+step:4800/20000 train_loss:2.4112 train_time:479836ms step_avg:99.97ms
+step:4900/20000 train_loss:2.0314 train_time:489822ms step_avg:99.96ms
+step:5000/20000 train_loss:2.0864 train_time:499815ms step_avg:99.96ms
+step:5100/20000 train_loss:2.1090 train_time:509812ms step_avg:99.96ms
+step:5200/20000 train_loss:2.0191 train_time:519749ms step_avg:99.95ms
+step:5300/20000 train_loss:1.9852 train_time:529746ms step_avg:99.95ms
+step:5400/20000 train_loss:2.0233 train_time:539731ms step_avg:99.95ms
+step:5500/20000 train_loss:1.9939 train_time:549733ms step_avg:99.95ms
+step:5600/20000 train_loss:1.9293 train_time:559723ms step_avg:99.95ms
+step:5700/20000 train_loss:1.9900 train_time:569655ms step_avg:99.94ms
+step:5800/20000 train_loss:1.9675 train_time:579661ms step_avg:99.94ms
+step:5900/20000 train_loss:1.8747 train_time:589659ms step_avg:99.94ms
+step:6000/20000 train_loss:1.9180 train_time:599663ms step_avg:99.94ms
+step:6004/20000 val_loss:1.9539 val_bpb:1.1572 train_time:600065ms step_avg:99.94ms
+stopping_early: wallclock_cap train_time:600065ms step:6004/20000
+peak memory allocated: 25196 MiB reserved: 25954 MiB
+ema:applying shadow model
+Serialized model: 96864150 bytes
+Code size: 74693 bytes
+Total submission size: 96938843 bytes
+Serialized model int6+zstd: 15342411 bytes
+Total submission size: 15417104 bytes (15.42 MB)
+SIZE CHECK PASSED: 15.42 MB < 16.00 MB
+/workspace/train_gpt.py:1605: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
+ quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu")
+final_eval_mode:sliding_ngram orders=2-11 alpha=0.4 entropy=True stride:64
+ngram_cache:enabled orders=2-11 backoff entropy=True alpha=0.4 ent_base=0.05 ent_range=0.55 min_count=2 buckets=4194304
+ ngram_eval [ 10.6%] bpb=1.109855 t=31s
+ ngram_eval [ 21.2%] bpb=1.068941 t=48s
+ ngram_eval [ 31.8%] bpb=1.012453 t=65s
+ ngram_eval [ 42.3%] bpb=0.947853 t=82s
+ ngram_eval [ 52.9%] bpb=0.885035 t=99s
+ ngram_eval [ 63.5%] bpb=0.826994 t=116s
+ ngram_eval [ 74.0%] bpb=0.777498 t=133s
+ ngram_eval [ 84.6%] bpb=0.734604 t=150s
+ ngram_eval [ 95.2%] bpb=0.697058 t=167s
+ pass1 done: bpb=0.681932 t=189s — pass 2 rescore (351s budget)
+ pass2 done: bpb=0.313496 t=328s
+ ngram_eval DONE: bpb=0.586406 tokens=62023616 t=331s
+final_int8_zlib_roundtrip val_loss:0.9901 val_bpb:0.5864 eval_time:331064ms
+final_int8_zlib_roundtrip_exact val_loss:0.99011876 val_bpb:0.58640565
+=== Done ===