diff --git a/records/track_10min_16mb/2026-03-23_LoquiAuris_10L_LoRA_TTT_1.0865/README.md b/records/track_10min_16mb/2026-03-23_LoquiAuris_10L_LoRA_TTT_1.0865/README.md new file mode 100644 index 000000000..906c4c667 --- /dev/null +++ b/records/track_10min_16mb/2026-03-23_LoquiAuris_10L_LoRA_TTT_1.0865/README.md @@ -0,0 +1,126 @@ +# Record Submission: 10L d=512 EMA + LoRA TTT + +**Author:** Loqui Auris ([@LoquiAuris](https://github.com/LoquiAuris)) +**val_bpb:** 1.0865 (mean of 2 seeds, std=0.0013) +**Artifact size:** 15,810,855 bytes (15.81 MB) +**Training time:** ~10 minutes on 8×H100 + +## Results + +| Seed | Pre-TTT val_bpb | Post-LoRA-TTT val_bpb | Artifact (bytes) | Steps | +|------|----------------|----------------------|------------------|-------| +| 42 | ~1.1610 | 1.0856 | 15,810,855 | 5,978 | +| 1337 | ~1.1610 | 1.0874 | 15,705,529 | 5,969 | +| Mean | | **1.0865 ±0.0013** | | | + +Third seed pending (compute grant). + +## Approach + +### Architecture + +Standard PR #162 transformer stack with the following configuration: + +- 10 layers, d_model=512, 8 attention heads, 4 KV heads (GQA) +- 3× FFN expansion (hidden=1536) with ReLU² activation +- SmearGate: learned blend with previous token representation +- BigramHash: 4096 buckets, dim=128, projected to 512 +- U-Net skip connections between symmetric layer pairs +- RMSNorm, logit softcap=30.0, orthogonal initialization +- RoPE positional encoding (persistent=False) +- Tied embeddings via `F.linear(x, tok_emb.weight)` +- Vocabulary: sp1024 (1,024 BPE tokens) +- ~24.7M parameters + +### Training + +- Optimizer: Muon (matrix_lr=0.02, momentum=0.99 with warmup from 0.92 over 1500 steps) + AdamW for embeddings and scalars +- Weight decay: 0.04 (Muon), 0.01 (AdamW) +- Gradient clipping: 0.3 +- Sequence length: 2048 +- Batch size: 786,432 tokens +- Warmup: 20 steps +- Warmdown: 3000 iterations (wallclock-based cosine schedule) +- EMA: decay=0.997, applied after training completes +- Steps completed: ~5,970 in 600s + +### Quantization & Compression + +- MLP weights: Int6 per-row symmetric (clip=31) +- Attention weights: Int6 per-row symmetric (clip=31) +- Embeddings: FP16 passthrough +- Norms, gates, control tensors: FP32 passthrough +- Compression: zstd level 22 + +### Evaluation: LoRA TTT (Test-Time Training) + +Per-document backward-looking LoRA adaptation during evaluation. This is the key technique that reduces bpb from ~1.161 (pre-TTT) to ~1.087 (post-TTT) — a **0.074 bpb improvement**. + +**How it works:** + +For each document in the validation set: +1. Add ephemeral LoRA adapters (rank=8) to Q projections, V projections, and the LM head +2. Split document into 256-token chunks with 1024-token context windows +3. Process chunks left-to-right over 2 epochs: + - Forward pass with LoRA-adapted model + - Score tokens on the final epoch (record loss for bpb) + - Train LoRA on non-final chunks (backward + optimizer step) +4. Reset LoRA weights + optimizer state before the next document + +**Key details:** +- LoRA rank 8 on Q + V projections + LM head per block +- Adam optimizer (lr=0.01, betas=0.9/0.95) +- Batch: 64 documents per GPU with independent LoRA per document +- Documents < 512 tokens: standard eval without TTT (insufficient context for adaptation) +- 8-GPU sharding: documents distributed across ranks, metrics all-reduced at end +- TTT time: ~245s per seed (within the 600s eval budget) + +**LoRA weights are NOT part of the 16MB artifact.** They are created fresh at eval time, trained on the fly per document, and discarded between documents. Only the base model weights are in the artifact. + +## Key Technique: Fresh Model for LoRA TTT + +`torch.compile` with `fullgraph=True` caches the forward graph from training, which has `None` for all LoRA delta arguments. The compiled graph silently ignores LoRA deltas at eval time — the LoRA additions to Q, V, and logits are treated as dead code by the compiled graph. + +**The fix:** Call `torch._dynamo.reset()` after training, create a fresh uncompiled `GPT` model from `state_dict`, and run LoRA TTT on the uncompiled model. This ensures all LoRA code paths are active during TTT. + +Without this fix, LoRA TTT produces **worse** results than no TTT (1.189 vs 1.161) because the model is effectively running without adaptation while still paying the per-document overhead. + +## Development Process + +This submission builds on the 1.1508 baseline (PR #350) with two additions: + +1. **EMA weight averaging** (decay=0.997) replaced SWA — marginal improvement +2. **LoRA TTT** adapted from PROTEUS v7 (PR #512) — the primary bpb improvement + +The LoRA TTT implementation required solving the `torch.compile` graph caching issue (see above), which was the critical debugging step. Batched document processing (64 docs/GPU) was essential for completing TTT within the eval time budget. + +### Progression + +| Submission | val_bpb | Technique | +|-----------|---------|-----------| +| PR #350 | 1.1508 | Baseline (no TTT) | +| This (pre-TTT) | ~1.1610 | + EMA | +| This (post-TTT) | **1.0865** | + LoRA TTT | + +## Hardware & Cost + +- Training: 8×H100 SXM (RunPod) +- Local testing: Apple Silicon (MPS) for architecture validation +- Total H100 time: ~1 hour for 2 seeds +- Estimated cost: ~$25 in RunPod credits + +## Acknowledgments + +- Training stack: PR #162 (raahilshah), PR #180 (thwu1) +- LoRA TTT approach: PR #512 (MatoTeziTanka), PR #77 (samacqua) +- EMA + TTT: PR #442 (sjp611) +- SmearGate/BigramHash: @unnir +- Muon optimizer, OrthoInit: Parameter Golf community + +## Files + +- `train_gpt.py` — Complete training script with environment variable configuration +- `train_seed42.log` — Training + TTT log (seed 42) +- `train_seed1337.log` — Training + TTT log (seed 1337) +- `submission.json` — Submission metadata +- `README.md` — This file diff --git a/records/track_10min_16mb/2026-03-23_LoquiAuris_10L_LoRA_TTT_1.0865/submission.json b/records/track_10min_16mb/2026-03-23_LoquiAuris_10L_LoRA_TTT_1.0865/submission.json new file mode 100644 index 000000000..3ef868f03 --- /dev/null +++ b/records/track_10min_16mb/2026-03-23_LoquiAuris_10L_LoRA_TTT_1.0865/submission.json @@ -0,0 +1,27 @@ +{ + "author": "Loqui Auris", + "github_id": "LoquiAuris", + "name": "Loqui Auris — 10L + EMA + LoRA TTT", + "blurb": "SP-1024 10x512 GQA-4KV + EMA(0.997) + per-document LoRA TTT (rank=8, 2 epochs, Adam lr=0.01). Batched TTT across 64 docs/GPU with fresh uncompiled model. Int6+zstd-22. Mean val_bpb=1.0865 (2 seeds).", + "date": "2026-03-23T15:30:00Z", + "seeds": { + "42": { + "val_loss": 1.8329802438108778, + "val_bpb": 1.0855941471729662, + "bytes_total": 15810855, + "bytes_code": 78226, + "steps": 5978 + }, + "1337": { + "val_loss": 1.8360795162680648, + "val_bpb": 1.0874297108956943, + "bytes_total": 15705529, + "bytes_code": 78226, + "steps": 5969 + } + }, + "val_loss": 1.8345298800394714, + "val_bpb": 1.0865119290343303, + "bytes_total": 15810855, + "bytes_code": 78226 +} diff --git a/records/track_10min_16mb/2026-03-23_LoquiAuris_10L_LoRA_TTT_1.0865/train_gpt.py b/records/track_10min_16mb/2026-03-23_LoquiAuris_10L_LoRA_TTT_1.0865/train_gpt.py new file mode 100644 index 000000000..f73029dc6 --- /dev/null +++ b/records/track_10min_16mb/2026-03-23_LoquiAuris_10L_LoRA_TTT_1.0865/train_gpt.py @@ -0,0 +1,1754 @@ +""" +The `train_gpt.py` and `train_gpt_mlx.py` scripts are intended as good launching-off points for new participants, not SOTA configs. We'll accept PRs that tune, improve, or simplify these scripts without significantly increasing complexity, but competitive submissions should stay in the `/records` folder. + +Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. +""" + +from __future__ import annotations + +import copy +import glob +import io +import json +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# ----------------------------- +# HYPERPARAMETERS +# ----------------------------- + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp8192") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_8192_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 8192)) + num_layers = int(os.environ.get("NUM_LAYERS", 9)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.01)) + + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32)) + + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 4096)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.5)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + embed_quant = os.environ.get("EMBED_QUANT", "fp16") # 'fp16', 'int6', 'int8' + + # EMA + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "0"))) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + + # TTT (test-time training) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.0005)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 10)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + + # Partial RoPE + rope_dims = int(os.environ.get("ROPE_DIMS", 0)) # 0 = full RoPE (all head dims) + + # LN Scale + ln_scale = bool(int(os.environ.get("LN_SCALE", "0"))) + + # XSA (cross-sequence attention — last N layers) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 0)) + + # TTT extras + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0)) + ttt_mode = os.environ.get("TTT_MODE", "standard") # 'standard', 'lora', or 'none' + + # LoRA TTT + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) + ttt_min_doc_len = int(os.environ.get("TTT_MIN_DOC_LEN", 1024)) + +# ----------------------------- +# MUON OPTIMIZER +# ----------------------------- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() if G.device.type == "cuda" else G.float() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + muon_dtype = torch.bfloat16 if params[0].device.type == "cuda" else torch.float32 + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=muon_dtype) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + + +# ----------------------------- +# TOKENIZER-AGNOSTIC EVALUATION +# ----------------------------- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, TRAIN_SEQ_LEN={args.train_seq_len}" + ) + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float32) + val_token_count = torch.zeros((), device=device, dtype=torch.float32) + val_byte_count = torch.zeros((), device=device, dtype=torch.float32) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type=device.type if device.type in ("cuda","cpu") else "cpu", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.float() * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.float().sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + + +# ----------------------------- +# POST-TRAINING QUANTIZATION (INT8 legacy + INT6 mixed) +# ----------------------------- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,bigram.scale", + ).split(",") + if pattern +) +_default_fp16_keep = os.environ.get( + "FP16_KEEP_NAME_PATTERNS", + f"blocks.{int(os.environ.get('NUM_LAYERS', 9)) - 1}.attn.c_k", +) +FP16_KEEP_NAME_PATTERNS = tuple(p for p in _default_fp16_keep.split(",") if p) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_int6_per_row(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -32, 31).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / 31.0, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -32, 31).to(torch.int8) + return q, scale + +def quantize_int5_per_row(t: Tensor) -> tuple[Tensor, Tensor]: + """Int5: [-16, 15] = 32 levels. Saves ~20% vs Int6 on MLP weights.""" + t32 = t.float() + if t32.ndim == 2: + row_max = t32.abs().amax(dim=1) + scale = (row_max / 15.0).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -16, 15).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / 15.0, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -16, 15).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str], + int5_cats: set[str] | None = None, embed_quant: str = "fp16"): + """Mixed quantization: Int5 for MLP, Int6 for attention, configurable for embeddings.""" + if int5_cats is None: + int5_cats = set() + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + # Embedding quantization (controlled by embed_quant) + if cat == "embed" and t.ndim >= 1: + if embed_quant == "int6": + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + continue + elif embed_quant == "int8": + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + continue + else: # fp16 + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + if any(pattern in name for pattern in FP16_KEEP_NAME_PATTERNS): + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + if cat in int5_cats and t.ndim >= 1: + q, s = quantize_int5_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int5"} + elif cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + + +# ----------------------------- +# DATA LOADING +# ----------------------------- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + + +# ----------------------------- +# TRANSFORMER MODULES +# ----------------------------- + +def _rms_norm(x, normalized_shape, eps=None): + if hasattr(F, 'rms_norm'): + return F.rms_norm(x, normalized_shape, eps=eps) + dims = tuple(range(-len(normalized_shape), 0)) + rms = x.float().pow(2).mean(dims, keepdim=True).add(eps or 1e-6).rsqrt() + return (x.float() * rms).to(x.dtype) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return _rms_norm(x, (x.size(-1),), eps=self.eps) + + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) + freqs = torch.outer(t, self.inv_freq.to(device)) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + cos = self._cos_cached.to(dtype=dtype) + sin = self._sin_cached.to(dtype=dtype) + if cos.requires_grad or not cos.is_leaf: + return cos, sin + return cos.clone(), sin.clone() + + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + + +# ── LoRA TTT Modules ── + +class BatchedLinearLoRA(nn.Module): + """Per-batch-element LoRA adapter. Delta = x @ A^T @ B^T.""" + def __init__(self, bsz: int, in_features: int, out_features: int, rank: int): + super().__init__() + self.in_features = in_features + self.A = nn.Parameter(torch.empty(bsz, rank, in_features)) + self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) + self.reset() + + def forward(self, x: Tensor) -> Tensor: + return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2) + + def reset(self) -> None: + bound = 1.0 / math.sqrt(self.in_features) + with torch.no_grad(): + self.A.uniform_(-bound, bound) + self.B.zero_() + + +class BatchedTTTLoRA(nn.Module): + """All LoRA adapters for one batch: LM head + Q/V per block.""" + def __init__(self, bsz: int, model, rank: int): + super().__init__() + dim = model.tok_emb.embedding_dim + vocab = model.tok_emb.num_embeddings + self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank) + self.q_loras = nn.ModuleList() + self.v_loras = nn.ModuleList() + for block in model.blocks: + q_out = block.attn.c_q.weight.shape[0] + v_out = block.attn.c_v.weight.shape[0] + self.q_loras.append(BatchedLinearLoRA(bsz, dim, q_out, rank)) + self.v_loras.append(BatchedLinearLoRA(bsz, dim, v_out, rank)) + + def reset(self) -> None: + for m in self.modules(): + if isinstance(m, BatchedLinearLoRA): + m.reset() + + +BOS_ID = 1 # SentencePiece BOS token ID + + +def _find_docs(all_tokens: Tensor) -> list[tuple[int, int]]: + """Return (start_offset, length) for each document at BOS boundaries.""" + bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].cpu().numpy() + docs: list[tuple[int, int]] = [] + for i in range(len(bos_positions)): + start = int(bos_positions[i]) + end = int(bos_positions[i + 1]) + 1 if i + 1 < len(bos_positions) else all_tokens.numel() + if end - start >= 2: + docs.append((start, end - start)) + return docs + + +def _reset_ttt_optimizer(opt: torch.optim.Adam) -> None: + for group in opt.param_groups: + for p in group["params"]: + s = opt.state.get(p) + if not s: + continue + s["exp_avg"].zero_() + s["exp_avg_sq"].zero_() + s["step"].fill_(0) + + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, rope_dims: int = 0, use_xsa: bool = False): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + self.rope_dims = rope_dims if rope_dims > 0 else self.head_dim + self.use_xsa = use_xsa + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.rope_dims, base=rope_base) + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """XSA: subtract self-value projection via GQA-aware reshape.""" + B, T, H, D = y.shape + Hkv = v.size(2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, q_delta: Tensor | None = None, v_delta: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q_out = self.c_q(x) + if q_delta is not None: + q_out = q_out + q_delta + q = q_out.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v_out = self.c_v(x) + if v_delta is not None: + v_out = v_out + v_delta + v = v_out.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = _rms_norm(q, (q.size(-1),)) + k = _rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + if self.rope_dims < self.head_dim: + q_rope, q_pass = q[..., :self.rope_dims], q[..., self.rope_dims:] + k_rope, k_pass = k[..., :self.rope_dims], k[..., self.rope_dims:] + q_rope = apply_rotary_emb(q_rope, cos, sin) + k_rope = apply_rotary_emb(k_rope, cos, sin) + q = torch.cat([q_rope, q_pass], dim=-1) + k = torch.cat([k_rope, k_pass], dim=-1) + else: + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + # GQA + if self.num_kv_heads != self.num_heads: + rep = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(rep, dim=1) + v = v.repeat_interleave(rep, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True) + if self.use_xsa: + # XSA: orthogonal projection subtraction + y_xsa = y.transpose(1, 2) # [B, S, H, D] + v_xsa = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) # [B, S, Hkv, D] + y_xsa = self._xsa_efficient(y_xsa, v_xsa) + y = y_xsa.reshape(bsz, seqlen, dim) + else: + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + + +class SmearGate(nn.Module): + """Blend each token's embedding with the previous token's embedding.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + + +class BigramHashEmbedding(nn.Module): + """Hash consecutive token pairs into a learned embedding table.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float, rope_dims: int = 0, ln_scale: bool = False, layer_idx: int = 0, use_xsa: bool = False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=rope_dims, use_xsa=use_xsa) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + # LN Scale: 1/sqrt(layer_idx+1) dampening (not learnable) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + s = self.ln_scale_factor + n = self.attn_norm(x) * s + qd = q_delta_fn(n) if q_delta_fn is not None else None + vd = v_delta_fn(n) if v_delta_fn is not None else None + attn_out = self.attn(n, qd, vd) + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s) + return x + + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + rope_dims: int = 0, + ln_scale: bool = False, + xsa_last_n: int = 0, + ): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.smear = SmearGate(model_dim) + self.blocks = nn.ModuleList( + [ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + rope_dims=rope_dims, ln_scale=ln_scale, layer_idx=i, + use_xsa=(i >= num_layers - xsa_last_n) if xsa_last_n > 0 else False) + for i in range(num_layers) + ] + ) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = _rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + qd_fn = lora.q_loras[i] if lora is not None else None + vd_fn = lora.v_loras[i] if lora is not None else None + x = self.blocks[i](x, x0, qd_fn, vd_fn) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + bi = self.num_encoder_layers + i + qd_fn = lora.q_loras[bi] if lora is not None else None + vd_fn = lora.v_loras[bi] if lora is not None else None + x = self.blocks[bi](x, x0, qd_fn, vd_fn) + x_norm = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x_norm, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x_norm) + if lora is not None: + # Add LM head LoRA delta and return per-token loss + lora_delta = lora.lm_head_lora(x_norm) + logits_proj = logits_proj + lora_delta + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + # Return per-token loss: [bsz, seqlen] + return F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + target_ids.reshape(-1), + reduction="none", + ).reshape(input_ids.shape) + logits = self.logit_softcap * torch.tanh(logits_proj.reshape(-1, logits_proj.size(-1)) / self.logit_softcap) + return F.cross_entropy(logits.float(), target_ids.reshape(-1), reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = _rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + for i in range(self.num_encoder_layers): + x = self.blocks[i](x, x0) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + x = self.blocks[self.num_encoder_layers + i](x, x0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + + loss_sum = torch.zeros((), device=device, dtype=torch.float32) + token_count = torch.zeros((), device=device, dtype=torch.float32) + byte_count = torch.zeros((), device=device, dtype=torch.float32) + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type=device.type if device.type in ("cuda","cpu") else "cpu", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].float() + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].float() + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).float() + byte_count += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + running_bpb = 0.0 + if token_count.item() > 0: + rl = (loss_sum / token_count).item() + running_bpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} windows running_bpb={running_bpb:.6f}", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# ----------------------------- +# LORA TTT EVALUATION +# ----------------------------- + +def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int): + """Return (win_start, win_len, chunk_offset, chunk_len) for chunk ci of a doc.""" + chunk_start = ci * chunk_size + chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size + win_start = max(0, chunk_end - eval_seq_len) + win_len = chunk_end - win_start + chunk_offset = chunk_start - win_start + chunk_len = chunk_end - chunk_start + return win_start, win_len, chunk_offset, chunk_len + + +def _build_ttt_optimizer(lora, args): + return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=(args.beta1, args.beta2), eps=1e-10) + + +def eval_val_ttt_lora( + args, base_model, rank, world_size, device, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + log_fn=None, +) -> tuple[float, float]: + """Batched LoRA TTT evaluation: per-document adaptation, sharded across GPUs.""" + docs = _find_docs(val_tokens) + # Shard docs across ranks + rank_docs = docs[(len(docs) * rank) // world_size: (len(docs) * (rank + 1)) // world_size] + short_docs = [d for d in rank_docs if d[1] < args.ttt_min_doc_len] + long_docs = [d for d in rank_docs if d[1] >= args.ttt_min_doc_len] + master = rank == 0 + if master and log_fn: + log_fn(f"lora_ttt: {len(docs)} total docs, rank0: {len(long_docs)} long + {len(short_docs)} short, batch={args.ttt_batch_seqs}") + + base_model.eval() + for p in base_model.parameters(): + p.requires_grad_(False) + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + byte_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + + # Score short docs without TTT + t0 = time.perf_counter() + with torch.no_grad(): + for ds, dl in short_docs: + x = val_tokens[ds: ds + dl - 1].to(device=device, dtype=torch.int64).unsqueeze(0) + y = val_tokens[ds + 1: ds + dl].to(device=device, dtype=torch.int64).unsqueeze(0) + n = dl - 1 + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + mean_loss = base_model(x, y) + loss_sum += mean_loss.to(torch.float64) * n + token_count += n + tgt = y.reshape(-1) + px = x.reshape(-1) + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[px]).to(torch.float64) + byte_sum += tb.sum() + if master and log_fn: + log_fn(f"lora_ttt: short_docs time={time.perf_counter()-t0:.1f}s tokens={int(token_count.item())}") + + # Sort long docs by chunk count for efficient batching + long_docs.sort(key=lambda d: (d[1] - 2) // args.ttt_chunk_size) + batch_size = args.ttt_batch_seqs + chunk_size = args.ttt_chunk_size + eval_seq_len = args.ttt_eval_seq_len + + # Create reusable LoRA and optimizer + lora = BatchedTTTLoRA(batch_size, base_model, args.ttt_lora_rank).to(device) + opt = _build_ttt_optimizer(lora, args) + + t1 = time.perf_counter() + for bi in range(0, len(long_docs), batch_size): + batch = long_docs[bi: bi + batch_size] + bsz = len(batch) + if bsz == batch_size: + cur_lora, cur_opt = lora, opt + cur_lora.reset() + _reset_ttt_optimizer(cur_opt) + else: + cur_lora = BatchedTTTLoRA(bsz, base_model, args.ttt_lora_rank).to(device) + cur_opt = _build_ttt_optimizer(cur_lora, args) + pred_lens = [dl - 1 for _, dl in batch] + num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens] + max_nc = max(num_chunks) + for epoch in range(args.ttt_epochs): + for ci in range(max_nc): + active = [ci < nc for nc in num_chunks] + ws_ref, wl_ref, _, _ = _compute_chunk_window(ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len) + x = torch.zeros(bsz, wl_ref, dtype=torch.int64, device=device) + y = torch.zeros(bsz, wl_ref, dtype=torch.int64, device=device) + doc_info = [] + for b in range(bsz): + if not active[b]: + doc_info.append((0, 0)) + continue + ds, dl = batch[b] + ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len) + toks = val_tokens[ds + ws: ds + ws + wl + 1].to(dtype=torch.int64, device=device) + x[b, :wl] = toks[:-1] + y[b, :wl] = toks[1:] + doc_info.append((co, cl)) + needs_train = any(ci < nc - 1 for nc in num_chunks) + if needs_train: + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + else: + with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ptl = base_model(x, y, lora=cur_lora) + if epoch == args.ttt_epochs - 1: + with torch.no_grad(): + for b in range(bsz): + if not active[b]: + continue + co, cl = doc_info[b] + loss_sum += ptl[b, co: co + cl].to(torch.float64).sum() + token_count += cl + tgt = y[b, co: co + cl] + px = x[b, co: co + cl] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[px]).to(torch.float64) + byte_sum += tb.sum() + if needs_train: + train_loss = torch.zeros(bsz, device=device) + for b in range(bsz): + if ci >= num_chunks[b] - 1: + continue + co, cl = doc_info[b] + if cl > 0: + train_loss[b] = ptl[b, co: co + cl].mean() + cur_opt.zero_grad() + train_loss.sum().backward() + cur_opt.step() + if master and log_fn and (bi + batch_size) % (batch_size * 5) == 0: + elapsed = time.perf_counter() - t1 + avg_l = loss_sum.item() / max(token_count.item(), 1) + log_fn(f"lora_ttt: batch {bi // batch_size + 1}/{(len(long_docs) + batch_size - 1) // batch_size} time={elapsed:.1f}s avg_loss={avg_l:.4f}") + + if master and log_fn: + log_fn(f"lora_ttt: long_docs time={time.perf_counter()-t1:.1f}s docs={len(long_docs)}") + + # All-reduce across ranks + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + + base_model.train() + for p in base_model.parameters(): + p.requires_grad_(True) + + val_loss = float(loss_sum.item() / max(token_count.item(), 1)) + val_bpb = float((loss_sum.item() / math.log(2.0)) / max(byte_sum.item(), 1)) + if master and log_fn: + log_fn(f"lora_ttt:complete loss:{val_loss:.4f} bpb:{val_bpb:.4f} time:{time.perf_counter()-t0:.1f}s") + return val_loss, val_bpb + + +# ----------------------------- +# TRAINING +# ----------------------------- + +def main() -> None: + global zeropower_via_newtonschulz5 + + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + try: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + except Exception: + pass # torch.compile not available (e.g. Python 3.12) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if torch.cuda.is_available(): + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + elif torch.backends.mps.is_available(): + device = torch.device("mps") + else: + device = torch.device("cpu") + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + try: + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + except ImportError: + enable_flash_sdp = enable_mem_efficient_sdp = enable_math_sdp = enable_cudnn_sdp = lambda x: None + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + try: + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + except FileNotFoundError: + log0("nvidia-smi not available", console=False) + log0("=" * 100, console=False) + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + # MODEL + OPTIMIZER SETUP + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + xsa_last_n=args.xsa_last_n, + ).to(device) + if device.type == "cuda": + base_model = base_model.bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if device.type == "cuda": + try: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + except Exception: + compiled_model = base_model + else: + compiled_model = base_model # skip compile on MPS/CPU + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=torch.cuda.is_available(), + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=0.04, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.weight_decay, + fused=torch.cuda.is_available(), + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=torch.cuda.is_available(), + ) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + + # EMA setup (tracks ALL state_dict items, matching SOTA) + ema_state: dict[str, Tensor] | None = None + if args.ema_enabled: + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + log0(f"ema:enabled decay:{args.ema_decay} tracking {len(ema_state)} tensors") + + # DATA LOADER & MODEL WARMUP + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type=device.type if device.type in ("cuda","cpu") else "cpu", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + # MAIN TRAINING LOOP + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + if torch.cuda.is_available(): torch.cuda.synchronize() + t0 = time.perf_counter() + + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + if torch.cuda.is_available(): torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + if torch.cuda.is_available(): torch.cuda.synchronize() + t0 = time.perf_counter() + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type=device.type if device.type in ("cuda","cpu") else "cpu", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + # SWA: collect checkpoints during warmdown + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + + # EMA: update every step (all state_dict items) + if ema_state is not None: + d = args.ema_decay + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(d).add_(t.detach().float(), alpha=1.0 - d) + + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024 if torch.cuda.is_available() else 0} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024 if torch.cuda.is_available() else 0} MiB" + ) + + # Apply SWA if collected + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = { + name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items() + } + base_model.load_state_dict(avg_state, strict=True) + + # Apply EMA if enabled + if ema_state is not None: + log0(f"ema:applying EMA weights (decay={args.ema_decay})") + current_sd = base_model.state_dict() + avg_state = {name: t.to(dtype=current_sd[name].dtype) for name, t in ema_state.items()} + del ema_state + base_model.load_state_dict(avg_state, strict=True) + del avg_state + + # SERIALIZATION + ROUNDTRIP VALIDATION + if master_process: + torch.save(base_model.state_dict(), "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + # INT6 mixed quantization + zstd/zlib export + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"attn"}, int5_cats={"mlp"}, embed_quant=args.embed_quant) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int8.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int8+zlib: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open("final_model.int8.ptz", "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + base_model.load_state_dict(deq_state, strict=True) + + # TTT (test-time training) + if args.ttt_mode == "standard" and args.ttt_enabled: + log0(f"ttt:standard lr={args.ttt_lr} epochs={args.ttt_epochs} freeze_blocks={args.ttt_freeze_blocks}") + if args.ttt_freeze_blocks > 0: + for i, block in enumerate(base_model.blocks): + if i < args.ttt_freeze_blocks: + for p in block.parameters(): + p.requires_grad_(False) + base_model.train() + ttt_params = [p for p in base_model.parameters() if p.requires_grad] + ttt_opt = torch.optim.AdamW(ttt_params, lr=args.ttt_lr, weight_decay=0.0, + betas=(args.ttt_momentum, 0.999)) + ttt_seq_len = args.train_seq_len + ttt_batch_seqs = args.ttt_batch_seqs + total_seqs = (val_tokens.numel() - 1) // ttt_seq_len + my_start = (total_seqs * rank) // world_size + my_end = (total_seqs * (rank + 1)) // world_size + t_ttt = time.perf_counter() + for ttt_epoch in range(args.ttt_epochs): + _ttt_dtype = torch.float64 if device.type == "cuda" else torch.float32 + epoch_loss = torch.zeros((), device=device, dtype=_ttt_dtype) + epoch_tokens = torch.zeros((), device=device, dtype=_ttt_dtype) + for batch_start in range(my_start, my_end, ttt_batch_seqs): + batch_end = min(batch_start + ttt_batch_seqs, my_end) + raw_start = batch_start * ttt_seq_len + raw_end = batch_end * ttt_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64) + x_ttt = local[:-1].reshape(-1, ttt_seq_len) + y_ttt = local[1:].reshape(-1, ttt_seq_len) + ttt_opt.zero_grad(set_to_none=True) + with torch.autocast(device_type=device.type if device.type in ("cuda", "cpu") else "cpu", + dtype=torch.bfloat16, enabled=device.type == "cuda"): + loss_ttt = base_model(x_ttt, y_ttt) + loss_ttt.backward() + if distributed: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + ttt_opt.step() + epoch_loss += loss_ttt.detach().to(_ttt_dtype) * float(y_ttt.numel()) + epoch_tokens += float(y_ttt.numel()) + if distributed: + dist.all_reduce(epoch_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(epoch_tokens, op=dist.ReduceOp.SUM) + avg_ttt = (epoch_loss / epoch_tokens).item() + log0(f"ttt_epoch:{ttt_epoch+1}/{args.ttt_epochs} loss:{avg_ttt:.4f} time:{time.perf_counter()-t_ttt:.1f}s") + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + log0(f"ttt:complete total_time:{time.perf_counter()-t_ttt:.1f}s") + + # LoRA TTT — per-document adaptation with ephemeral LoRA + if args.ttt_mode == "lora": + log0(f"lora_ttt:starting rank={args.ttt_lora_rank} lr={args.ttt_lora_lr} epochs={args.ttt_epochs} chunk={args.ttt_chunk_size}") + # CRITICAL: Create fresh UNCOMPILED model for TTT. + # torch.compile caches the no-LoRA forward path, silently ignoring LoRA deltas. + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch._dynamo.reset() + ttt_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + rope_dims=args.rope_dims, ln_scale=args.ln_scale, xsa_last_n=args.xsa_last_n, + ).to(device) + ttt_model.load_state_dict(base_model.state_dict(), strict=True) + log0(f"lora_ttt: fresh uncompiled model created for TTT") + t_lora_ttt = time.perf_counter() + q_val_loss, q_val_bpb = eval_val_ttt_lora( + args, ttt_model, rank, world_size, device, val_tokens, + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + log_fn=log0, + ) + del ttt_model + if torch.cuda.is_available(): + torch.cuda.synchronize() + log0( + f"final_lora_ttt val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_lora_ttt):.0f}ms" + ) + log0(f"final_lora_ttt_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + # Skip the standard sliding eval — LoRA TTT IS the eval + else: + # Standard sliding window eval (after optional standard TTT) + if torch.cuda.is_available(): torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, + ) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + if torch.cuda.is_available(): torch.cuda.synchronize() + log0( + f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if master_process: + quant_file_bytes = os.path.getsize("final_model.int8.ptz") + code_bytes = len(code.encode("utf-8")) + summary = { + "run_id": args.run_id, + "model_dim": args.model_dim, + "num_layers": args.num_layers, + "num_heads": args.num_heads, + "num_kv_heads": args.num_kv_heads, + "vocab_size": args.vocab_size, + "mlp_mult": args.mlp_mult, + "train_seq_len": args.train_seq_len, + "embed_quant": args.embed_quant, + "post_quant_bpb": q_val_bpb, + "post_quant_loss": q_val_loss, + "artifact_bytes": quant_file_bytes, + "code_bytes": code_bytes, + "total_bytes": quant_file_bytes + code_bytes, + "headroom_bytes": 16_000_000 - (quant_file_bytes + code_bytes), + "fits_16mb": (quant_file_bytes + code_bytes) <= 16_000_000, + "training_time_ms": training_time_ms, + "steps_completed": step, + "swa_count": swa_count, + "n_params": n_params, + } + os.makedirs("results", exist_ok=True) + summary_path = f"results/summary_{args.run_id}.json" + with open(summary_path, "w") as f: + json.dump(summary, f, indent=2) + log0(f"results_saved:{summary_path}") + log0( + f"SUBMISSION_SUMMARY run_id:{args.run_id} post_quant_bpb:{q_val_bpb:.6f} " + f"artifact:{quant_file_bytes:,} code:{code_bytes:,} total:{quant_file_bytes+code_bytes:,} " + f"fits:{'YES' if summary['fits_16mb'] else 'NO'} headroom:{summary['headroom_bytes']:,}" + ) + + if distributed: + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-23_LoquiAuris_10L_LoRA_TTT_1.0865/train_seed1337.log b/records/track_10min_16mb/2026-03-23_LoquiAuris_10L_LoRA_TTT_1.0865/train_seed1337.log new file mode 100644 index 000000000..53ddc79a4 --- /dev/null +++ b/records/track_10min_16mb/2026-03-23_LoquiAuris_10L_LoRA_TTT_1.0865/train_seed1337.log @@ -0,0 +1,163 @@ +W0323 14:39:00.672000 136962421367424 torch/distributed/run.py:779] +W0323 14:39:00.672000 136962421367424 torch/distributed/run.py:779] ***************************************** +W0323 14:39:00.672000 136962421367424 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0323 14:39:00.672000 136962421367424 torch/distributed/run.py:779] ***************************************** +logs/config1_v4.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:24730705 +world_size:8 grad_accum_steps:1 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +ema:enabled decay:0.997 tracking 106 tensors +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9275 val_bpb:4.1028 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9296 train_time:149ms step_avg:149.01ms +step:2/20000 train_loss:7.6886 train_time:223ms step_avg:111.56ms +step:3/20000 train_loss:7.2769 train_time:321ms step_avg:106.97ms +step:4/20000 train_loss:7.9410 train_time:419ms step_avg:104.73ms +step:5/20000 train_loss:8.3460 train_time:517ms step_avg:103.46ms +step:6/20000 train_loss:8.1653 train_time:616ms step_avg:102.67ms +step:7/20000 train_loss:7.5274 train_time:714ms step_avg:102.04ms +step:8/20000 train_loss:6.8399 train_time:812ms step_avg:101.55ms +step:9/20000 train_loss:6.5264 train_time:910ms step_avg:101.16ms +step:10/20000 train_loss:6.2605 train_time:1008ms step_avg:100.83ms +step:100/20000 train_loss:3.2007 train_time:9877ms step_avg:98.77ms +step:200/20000 train_loss:2.4249 train_time:19936ms step_avg:99.68ms +step:300/20000 train_loss:2.5789 train_time:30011ms step_avg:100.04ms +step:400/20000 train_loss:2.4470 train_time:40104ms step_avg:100.26ms +step:500/20000 train_loss:2.4179 train_time:50046ms step_avg:100.09ms +step:500/20000 val_loss:2.3791 val_bpb:1.4091 train_time:50074ms step_avg:100.15ms +step:600/20000 train_loss:2.3549 train_time:60175ms step_avg:100.29ms +step:700/20000 train_loss:2.3685 train_time:70305ms step_avg:100.44ms +step:800/20000 train_loss:2.2593 train_time:80445ms step_avg:100.56ms +step:900/20000 train_loss:2.1468 train_time:90597ms step_avg:100.66ms +step:1000/20000 train_loss:2.2863 train_time:100579ms step_avg:100.58ms +step:1000/20000 val_loss:2.2406 val_bpb:1.3270 train_time:100607ms step_avg:100.61ms +step:1100/20000 train_loss:2.3399 train_time:110711ms step_avg:100.65ms +step:1200/20000 train_loss:2.3667 train_time:120860ms step_avg:100.72ms +step:1300/20000 train_loss:2.1142 train_time:130982ms step_avg:100.76ms +step:1400/20000 train_loss:2.1969 train_time:141115ms step_avg:100.80ms +step:1500/20000 train_loss:2.2306 train_time:151071ms step_avg:100.71ms +step:1500/20000 val_loss:2.1951 val_bpb:1.3001 train_time:151097ms step_avg:100.73ms +step:1600/20000 train_loss:2.0883 train_time:161186ms step_avg:100.74ms +step:1700/20000 train_loss:2.1544 train_time:171307ms step_avg:100.77ms +step:1800/20000 train_loss:2.1662 train_time:181416ms step_avg:100.79ms +step:1900/20000 train_loss:2.1385 train_time:191341ms step_avg:100.71ms +step:2000/20000 train_loss:2.0781 train_time:201457ms step_avg:100.73ms +step:2000/20000 val_loss:2.1410 val_bpb:1.2680 train_time:201485ms step_avg:100.74ms +step:2100/20000 train_loss:2.0549 train_time:211550ms step_avg:100.74ms +step:2200/20000 train_loss:2.1460 train_time:221647ms step_avg:100.75ms +step:2300/20000 train_loss:2.1151 train_time:231744ms step_avg:100.76ms +step:2400/20000 train_loss:2.0724 train_time:241662ms step_avg:100.69ms +step:2500/20000 train_loss:2.1744 train_time:251751ms step_avg:100.70ms +step:2500/20000 val_loss:2.1109 val_bpb:1.2502 train_time:251778ms step_avg:100.71ms +step:2600/20000 train_loss:2.1113 train_time:261837ms step_avg:100.71ms +step:2700/20000 train_loss:2.1052 train_time:271923ms step_avg:100.71ms +step:2800/20000 train_loss:2.1581 train_time:282013ms step_avg:100.72ms +step:2900/20000 train_loss:2.0292 train_time:291922ms step_avg:100.66ms +step:3000/20000 train_loss:2.1633 train_time:302002ms step_avg:100.67ms +step:3000/20000 val_loss:2.0950 val_bpb:1.2408 train_time:302028ms step_avg:100.68ms +step:3100/20000 train_loss:2.0380 train_time:312075ms step_avg:100.67ms +step:3200/20000 train_loss:2.1678 train_time:322157ms step_avg:100.67ms +step:3300/20000 train_loss:2.0629 train_time:332058ms step_avg:100.62ms +step:3400/20000 train_loss:2.0128 train_time:342138ms step_avg:100.63ms +step:3500/20000 train_loss:2.1718 train_time:352200ms step_avg:100.63ms +step:3500/20000 val_loss:2.0715 val_bpb:1.2269 train_time:352228ms step_avg:100.64ms +step:3600/20000 train_loss:2.0867 train_time:362288ms step_avg:100.64ms +step:3700/20000 train_loss:2.0816 train_time:372365ms step_avg:100.64ms +step:3800/20000 train_loss:2.0586 train_time:382283ms step_avg:100.60ms +step:3900/20000 train_loss:2.0640 train_time:392362ms step_avg:100.61ms +step:4000/20000 train_loss:1.9589 train_time:402445ms step_avg:100.61ms +step:4000/20000 val_loss:2.0508 val_bpb:1.2146 train_time:402472ms step_avg:100.62ms +step:4100/20000 train_loss:1.9955 train_time:412514ms step_avg:100.61ms +step:4200/20000 train_loss:2.1329 train_time:422586ms step_avg:100.62ms +step:4300/20000 train_loss:2.0385 train_time:432480ms step_avg:100.58ms +step:4400/20000 train_loss:2.0162 train_time:442552ms step_avg:100.58ms +step:4500/20000 train_loss:2.1064 train_time:452618ms step_avg:100.58ms +step:4500/20000 val_loss:2.0280 val_bpb:1.2011 train_time:452645ms step_avg:100.59ms +step:4600/20000 train_loss:1.8248 train_time:462692ms step_avg:100.59ms +step:4700/20000 train_loss:2.2155 train_time:472600ms step_avg:100.55ms +step:4800/20000 train_loss:2.4084 train_time:482675ms step_avg:100.56ms +step:4900/20000 train_loss:2.0305 train_time:492746ms step_avg:100.56ms +step:5000/20000 train_loss:2.0832 train_time:502824ms step_avg:100.56ms +step:5000/20000 val_loss:2.0043 val_bpb:1.1870 train_time:502851ms step_avg:100.57ms +step:5100/20000 train_loss:2.1067 train_time:512899ms step_avg:100.57ms +step:5200/20000 train_loss:2.0214 train_time:522800ms step_avg:100.54ms +step:5300/20000 train_loss:1.9841 train_time:532885ms step_avg:100.54ms +step:5400/20000 train_loss:2.0264 train_time:543026ms step_avg:100.56ms +step:5500/20000 train_loss:1.9934 train_time:553104ms step_avg:100.56ms +step:5500/20000 val_loss:1.9792 val_bpb:1.1722 train_time:553131ms step_avg:100.57ms +step:5600/20000 train_loss:1.9348 train_time:563174ms step_avg:100.57ms +step:5700/20000 train_loss:1.9879 train_time:573078ms step_avg:100.54ms +step:5800/20000 train_loss:1.9710 train_time:583157ms step_avg:100.54ms +step:5900/20000 train_loss:1.8806 train_time:593240ms step_avg:100.55ms +step:5969/20000 val_loss:1.9602 val_bpb:1.1609 train_time:600098ms step_avg:100.54ms +stopping_early: wallclock_cap train_time:600098ms step:5969/20000 +peak memory allocated: 23759 MiB reserved: 24230 MiB +ema:applying EMA weights (decay=0.997) +Serialized model: 96864150 bytes +Code size: 78226 bytes +Total submission size: 96942376 bytes +Serialized model int6+zstd: 15627303 bytes +Total submission size int8+zlib: 15705529 bytes +/workspace/train_gpt.py:1600: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +/workspace/train_gpt.py:1600: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +/workspace/train_gpt.py:1600: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +/workspace/train_gpt.py:1600: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +/workspace/train_gpt.py:1600: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +/workspace/train_gpt.py:1600: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +/workspace/train_gpt.py:1600: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +/workspace/train_gpt.py:1600: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +lora_ttt:starting rank=8 lr=0.01 epochs=2 chunk=256 +lora_ttt: fresh uncompiled model created for TTT +lora_ttt: 50000 total docs, rank0: 3857 long + 2393 short, batch=64 +lora_ttt: short_docs time=22.7s tokens=732712 +lora_ttt: batch 5/61 time=2.2s avg_loss=1.9958 +lora_ttt: batch 10/61 time=4.4s avg_loss=1.9706 +lora_ttt: batch 15/61 time=6.5s avg_loss=1.9492 +lora_ttt: batch 20/61 time=10.2s avg_loss=1.9156 +lora_ttt: batch 25/61 time=14.0s avg_loss=1.8994 +lora_ttt: batch 30/61 time=19.6s avg_loss=1.8802 +lora_ttt: batch 35/61 time=25.9s avg_loss=1.8654 +lora_ttt: batch 40/61 time=33.8s avg_loss=1.8535 +lora_ttt: batch 45/61 time=43.9s avg_loss=1.8436 +lora_ttt: batch 50/61 time=56.9s avg_loss=1.8383 +lora_ttt: batch 55/61 time=75.5s avg_loss=1.8275 +lora_ttt: batch 60/61 time=133.2s avg_loss=1.8207 +lora_ttt: long_docs time=153.6s docs=3857 +lora_ttt:complete loss:1.8361 bpb:1.0874 time:243.7s +final_lora_ttt val_loss:1.8361 val_bpb:1.0874 eval_time:244046ms +final_lora_ttt_exact val_loss:1.83607952 val_bpb:1.08742971 +results_saved:results/summary_config1_v4.json +SUBMISSION_SUMMARY run_id:config1_v4 post_quant_bpb:1.087430 artifact:15,627,303 code:78,226 total:15,705,529 fits:YES headroom:294,471 diff --git a/records/track_10min_16mb/2026-03-23_LoquiAuris_10L_LoRA_TTT_1.0865/train_seed42.log b/records/track_10min_16mb/2026-03-23_LoquiAuris_10L_LoRA_TTT_1.0865/train_seed42.log new file mode 100644 index 000000000..1b417bc76 --- /dev/null +++ b/records/track_10min_16mb/2026-03-23_LoquiAuris_10L_LoRA_TTT_1.0865/train_seed42.log @@ -0,0 +1,163 @@ +W0323 15:00:10.245000 132669878153856 torch/distributed/run.py:779] +W0323 15:00:10.245000 132669878153856 torch/distributed/run.py:779] ***************************************** +W0323 15:00:10.245000 132669878153856 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0323 15:00:10.245000 132669878153856 torch/distributed/run.py:779] ***************************************** +logs/seed_42.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:24730705 +world_size:8 grad_accum_steps:1 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +ema:enabled decay:0.997 tracking 106 tensors +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9283 val_bpb:4.1033 train_time:0ms step_avg:0.02ms +step:1/20000 train_loss:6.9311 train_time:160ms step_avg:160.31ms +step:2/20000 train_loss:7.6461 train_time:235ms step_avg:117.34ms +step:3/20000 train_loss:7.2985 train_time:333ms step_avg:110.96ms +step:4/20000 train_loss:8.1111 train_time:431ms step_avg:107.66ms +step:5/20000 train_loss:8.2614 train_time:528ms step_avg:105.66ms +step:6/20000 train_loss:8.1465 train_time:626ms step_avg:104.34ms +step:7/20000 train_loss:7.5114 train_time:724ms step_avg:103.44ms +step:8/20000 train_loss:6.9140 train_time:822ms step_avg:102.69ms +step:9/20000 train_loss:6.4530 train_time:919ms step_avg:102.16ms +step:10/20000 train_loss:6.1755 train_time:1017ms step_avg:101.74ms +step:100/20000 train_loss:3.2064 train_time:9877ms step_avg:98.77ms +step:200/20000 train_loss:2.4350 train_time:19926ms step_avg:99.63ms +step:300/20000 train_loss:2.5814 train_time:29985ms step_avg:99.95ms +step:400/20000 train_loss:2.4461 train_time:40053ms step_avg:100.13ms +step:500/20000 train_loss:2.4155 train_time:49977ms step_avg:99.95ms +step:500/20000 val_loss:2.3809 val_bpb:1.4101 train_time:50004ms step_avg:100.01ms +step:600/20000 train_loss:2.3556 train_time:60082ms step_avg:100.14ms +step:700/20000 train_loss:2.3718 train_time:70194ms step_avg:100.28ms +step:800/20000 train_loss:2.2534 train_time:80307ms step_avg:100.38ms +step:900/20000 train_loss:2.1430 train_time:90409ms step_avg:100.45ms +step:1000/20000 train_loss:2.2879 train_time:100355ms step_avg:100.35ms +step:1000/20000 val_loss:2.2420 val_bpb:1.3278 train_time:100382ms step_avg:100.38ms +step:1100/20000 train_loss:2.3341 train_time:110465ms step_avg:100.42ms +step:1200/20000 train_loss:2.3691 train_time:120573ms step_avg:100.48ms +step:1300/20000 train_loss:2.1125 train_time:130668ms step_avg:100.51ms +step:1400/20000 train_loss:2.1915 train_time:140764ms step_avg:100.55ms +step:1500/20000 train_loss:2.2316 train_time:150684ms step_avg:100.46ms +step:1500/20000 val_loss:2.1949 val_bpb:1.2999 train_time:150711ms step_avg:100.47ms +step:1600/20000 train_loss:2.0890 train_time:160764ms step_avg:100.48ms +step:1700/20000 train_loss:2.1529 train_time:170844ms step_avg:100.50ms +step:1800/20000 train_loss:2.1738 train_time:180936ms step_avg:100.52ms +step:1900/20000 train_loss:2.1346 train_time:190841ms step_avg:100.44ms +step:2000/20000 train_loss:2.0760 train_time:200922ms step_avg:100.46ms +step:2000/20000 val_loss:2.1409 val_bpb:1.2680 train_time:200949ms step_avg:100.47ms +step:2100/20000 train_loss:2.0530 train_time:210990ms step_avg:100.47ms +step:2200/20000 train_loss:2.1636 train_time:221059ms step_avg:100.48ms +step:2300/20000 train_loss:2.1140 train_time:231124ms step_avg:100.49ms +step:2400/20000 train_loss:2.0733 train_time:241025ms step_avg:100.43ms +step:2500/20000 train_loss:2.1760 train_time:251090ms step_avg:100.44ms +step:2500/20000 val_loss:2.1114 val_bpb:1.2505 train_time:251117ms step_avg:100.45ms +step:2600/20000 train_loss:2.1130 train_time:261163ms step_avg:100.45ms +step:2700/20000 train_loss:2.1074 train_time:271230ms step_avg:100.46ms +step:2800/20000 train_loss:2.1599 train_time:281303ms step_avg:100.47ms +step:2900/20000 train_loss:2.0274 train_time:291193ms step_avg:100.41ms +step:3000/20000 train_loss:2.1611 train_time:301261ms step_avg:100.42ms +step:3000/20000 val_loss:2.0944 val_bpb:1.2405 train_time:301289ms step_avg:100.43ms +step:3100/20000 train_loss:2.0359 train_time:311328ms step_avg:100.43ms +step:3200/20000 train_loss:2.1719 train_time:321392ms step_avg:100.44ms +step:3300/20000 train_loss:2.0670 train_time:331286ms step_avg:100.39ms +step:3400/20000 train_loss:2.0171 train_time:341344ms step_avg:100.40ms +step:3500/20000 train_loss:2.1739 train_time:351407ms step_avg:100.40ms +step:3500/20000 val_loss:2.0718 val_bpb:1.2270 train_time:351433ms step_avg:100.41ms +step:3600/20000 train_loss:2.0868 train_time:361469ms step_avg:100.41ms +step:3700/20000 train_loss:2.0847 train_time:371536ms step_avg:100.42ms +step:3800/20000 train_loss:2.0599 train_time:381425ms step_avg:100.38ms +step:3900/20000 train_loss:2.0590 train_time:391481ms step_avg:100.38ms +step:4000/20000 train_loss:1.9614 train_time:401538ms step_avg:100.38ms +step:4000/20000 val_loss:2.0511 val_bpb:1.2148 train_time:401564ms step_avg:100.39ms +step:4100/20000 train_loss:2.0002 train_time:411595ms step_avg:100.39ms +step:4200/20000 train_loss:2.1360 train_time:421660ms step_avg:100.40ms +step:4300/20000 train_loss:2.0426 train_time:431546ms step_avg:100.36ms +step:4400/20000 train_loss:2.0185 train_time:441608ms step_avg:100.37ms +step:4500/20000 train_loss:2.1058 train_time:451682ms step_avg:100.37ms +step:4500/20000 val_loss:2.0283 val_bpb:1.2012 train_time:451709ms step_avg:100.38ms +step:4600/20000 train_loss:1.8254 train_time:461751ms step_avg:100.38ms +step:4700/20000 train_loss:2.2176 train_time:471713ms step_avg:100.36ms +step:4800/20000 train_loss:2.4093 train_time:481772ms step_avg:100.37ms +step:4900/20000 train_loss:2.0320 train_time:491833ms step_avg:100.37ms +step:5000/20000 train_loss:2.0849 train_time:501893ms step_avg:100.38ms +step:5000/20000 val_loss:2.0047 val_bpb:1.1873 train_time:501921ms step_avg:100.38ms +step:5100/20000 train_loss:2.1070 train_time:511964ms step_avg:100.39ms +step:5200/20000 train_loss:2.0199 train_time:521838ms step_avg:100.35ms +step:5300/20000 train_loss:1.9872 train_time:531912ms step_avg:100.36ms +step:5400/20000 train_loss:2.0243 train_time:541978ms step_avg:100.37ms +step:5500/20000 train_loss:1.9946 train_time:552042ms step_avg:100.37ms +step:5500/20000 val_loss:1.9803 val_bpb:1.1729 train_time:552068ms step_avg:100.38ms +step:5600/20000 train_loss:1.9317 train_time:562110ms step_avg:100.38ms +step:5700/20000 train_loss:1.9918 train_time:571993ms step_avg:100.35ms +step:5800/20000 train_loss:1.9743 train_time:582053ms step_avg:100.35ms +step:5900/20000 train_loss:1.8790 train_time:592115ms step_avg:100.36ms +step:5978/20000 val_loss:1.9606 val_bpb:1.1612 train_time:600035ms step_avg:100.37ms +stopping_early: wallclock_cap train_time:600035ms step:5978/20000 +peak memory allocated: 23759 MiB reserved: 24230 MiB +ema:applying EMA weights (decay=0.997) +Serialized model: 96864150 bytes +Code size: 78226 bytes +Total submission size: 96942376 bytes +Serialized model int6+zstd: 15732629 bytes +Total submission size int8+zlib: 15810855 bytes +/workspace/train_gpt.py:1600: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +/workspace/train_gpt.py:1600: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +/workspace/train_gpt.py:1600: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +/workspace/train_gpt.py:1600: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +/workspace/train_gpt.py:1600: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +/workspace/train_gpt.py:1600: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +/workspace/train_gpt.py:1600: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +/workspace/train_gpt.py:1600: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature. + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") +lora_ttt:starting rank=8 lr=0.01 epochs=2 chunk=256 +lora_ttt: fresh uncompiled model created for TTT +lora_ttt: 50000 total docs, rank0: 3857 long + 2393 short, batch=64 +lora_ttt: short_docs time=21.8s tokens=732712 +lora_ttt: batch 5/61 time=2.2s avg_loss=1.9969 +lora_ttt: batch 10/61 time=4.3s avg_loss=1.9709 +lora_ttt: batch 15/61 time=6.5s avg_loss=1.9496 +lora_ttt: batch 20/61 time=10.2s avg_loss=1.9154 +lora_ttt: batch 25/61 time=13.9s avg_loss=1.8991 +lora_ttt: batch 30/61 time=19.4s avg_loss=1.8793 +lora_ttt: batch 35/61 time=25.8s avg_loss=1.8638 +lora_ttt: batch 40/61 time=33.5s avg_loss=1.8512 +lora_ttt: batch 45/61 time=43.6s avg_loss=1.8402 +lora_ttt: batch 50/61 time=56.6s avg_loss=1.8338 +lora_ttt: batch 55/61 time=75.1s avg_loss=1.8225 +lora_ttt: batch 60/61 time=132.7s avg_loss=1.8165 +lora_ttt: long_docs time=153.0s docs=3857 +lora_ttt:complete loss:1.8330 bpb:1.0856 time:246.4s +final_lora_ttt val_loss:1.8330 val_bpb:1.0856 eval_time:246659ms +final_lora_ttt_exact val_loss:1.83298024 val_bpb:1.08559415 +results_saved:results/summary_seed_42.json +SUBMISSION_SUMMARY run_id:seed_42 post_quant_bpb:1.085594 artifact:15,732,629 code:78,226 total:15,810,855 fits:YES headroom:189,145