From 40d8baf3dbc2b252e6080986710ab950b1468b24 Mon Sep 17 00:00:00 2001 From: Mato Date: Tue, 24 Mar 2026 13:11:59 -0400 Subject: [PATCH] =?UTF-8?q?PROTEUS=20v9=20=E2=80=94=2011L=20INT6=20+=20sin?= =?UTF-8?q?gle-epoch=20LoRA=20TTT=20(mean=20val=5Fbpb=3D1.1526,=203=20seed?= =?UTF-8?q?s)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.6 (1M context) --- .../2026-03-24_PROTEUS_v9/README.md | 72 + .../2026-03-24_PROTEUS_v9/submission.json | 18 + .../2026-03-24_PROTEUS_v9/train_gpt.py | 1493 +++++++++++++++++ .../2026-03-24_PROTEUS_v9/train_seed1337.log | 351 ++++ .../2026-03-24_PROTEUS_v9/train_seed2024.log | 352 ++++ .../2026-03-24_PROTEUS_v9/train_seed42.log | 351 ++++ 6 files changed, 2637 insertions(+) create mode 100644 records/track_10min_16mb/2026-03-24_PROTEUS_v9/README.md create mode 100644 records/track_10min_16mb/2026-03-24_PROTEUS_v9/submission.json create mode 100644 records/track_10min_16mb/2026-03-24_PROTEUS_v9/train_gpt.py create mode 100644 records/track_10min_16mb/2026-03-24_PROTEUS_v9/train_seed1337.log create mode 100644 records/track_10min_16mb/2026-03-24_PROTEUS_v9/train_seed2024.log create mode 100644 records/track_10min_16mb/2026-03-24_PROTEUS_v9/train_seed42.log diff --git a/records/track_10min_16mb/2026-03-24_PROTEUS_v9/README.md b/records/track_10min_16mb/2026-03-24_PROTEUS_v9/README.md new file mode 100644 index 000000000..4588b6fd2 --- /dev/null +++ b/records/track_10min_16mb/2026-03-24_PROTEUS_v9/README.md @@ -0,0 +1,72 @@ +# PROTEUS v9 — Parameter Golf Submission + +**Built with [PROTEUS](https://lightspeedup.com) by LightSpeedUp** + +## Result + +**Mean val_bpb: 1.1526** (3 seeds, std: 0.0004) + +| Seed | Post-Quant BPB | TTT BPB (1 epoch) | Steps | Step Avg | +|------|---------------|-------------------|-------|----------| +| 42 | 1.1804 | 1.1527 | 6989 | 85.7ms | +| 1337 | 1.1749 | 1.1529 | 6997 | 85.8ms | +| 2024 | 1.1771 | 1.1522 | 7093 | 84.6ms | + +## TTT 
Legality — Single Epoch, Score-Then-Train + +This submission uses **single-epoch TTT** (`TTT_EPOCHS=1`), addressing the ruling on PR #568 ([comment](https://github.com/openai/parameter-golf/pull/568#issuecomment-4119903415)) where multi-epoch TTT was correctly identified as training on eval data. + +**Why single-epoch is different:** + +In single-epoch, each chunk is processed left-to-right within the document: +1. Forward pass on chunk → **score** (accumulate loss for BPB) +2. **Train** LoRA on that chunk's loss (backward-looking) +3. Advance to next chunk + +Each token is scored **exactly once**, **before** being trained on. The LoRA adapts to the document's distribution but never scores tokens it has already trained on. This is the same score-then-train pattern as PR #77 (merged), applied once per document. + +**What changed from v7/v8:** `TTT_EPOCHS` reduced from 2-5 to 1. No other code changes. + +## Architecture + +- 11 transformer layers, dim=512, 8 heads / 4 KV heads (GQA) +- MLP 3x expansion (1536 hidden), relu² activation +- SmearGate + BigramHash(2048, dim=128) + OrthoInit +- Depth-scaled residual: `1/sqrt(layer_idx + 1)` per block +- U-Net skip connections, tied embeddings +- RoPE base 50K with NTK-aware eval scaling +- XSA on last 4 layers +- 26.8M parameters + +## Training + +- Muon optimizer (matrix_lr=0.025, WD=0.04, momentum=0.99) +- AdamW for embeddings/scalars (WD=0.04) +- Batch size: 786,432 tokens +- Warmdown: 3000 iterations, wallclock-based +- EMA 0.997 every step +- 3% magnitude pruning before export +- Gradient clipping: 0.3 + +## Quantization + +- **INT6 GPTQ-lite** — 5 clip percentiles per row, pick lowest MSE +- FP16 for tied embeddings +- FP32 for control tensors (scales, mixes, gains) +- zstd-22 compression +- Artifact: ~15.4 MB (96.3% of 16MB budget) +- Quant gap: 0.006 BPB + +## Test-Time Training (TTT) + +- **Single epoch** — each token scored once before training +- LoRA rank 8, targets: Q + V projections + LM head +- Optimizer: 
Adam (lr=0.01, betas 0.9/0.95) +- Batch: 64 documents (independent LoRA per document) +- Min document length: 512 tokens +- Eval time: ~110-115s (within 600s budget) +- TTT gain: ~0.025 BPB over post-quant baseline + +## Platform + +Trained on RunPod 8xH100 SXM, PyTorch 2.8.0+cu128. diff --git a/records/track_10min_16mb/2026-03-24_PROTEUS_v9/submission.json b/records/track_10min_16mb/2026-03-24_PROTEUS_v9/submission.json new file mode 100644 index 000000000..3a7915943 --- /dev/null +++ b/records/track_10min_16mb/2026-03-24_PROTEUS_v9/submission.json @@ -0,0 +1,18 @@ +{ + "author": "Mato (LightSpeedUp)", + "github_id": "MatoTeziTanka", + "name": "PROTEUS v9", + "blurb": "11L, INT6 GPTQ-lite, depth-scaled residual, single-epoch score-then-train LoRA TTT (batch=64). Built with PROTEUS by LightSpeedUp — lightspeedup.com", + "date": "2026-03-24T12:00:00Z", + "val_loss": 1.9461, + "val_bpb": 1.1526, + "bytes_total": 15408603, + "bytes_code": 71033, + "seeds": { + "42": {"val_bpb": 1.1527, "ttt_epochs": 1, "ttt_min_doc": 512}, + "1337": {"val_bpb": 1.1529, "ttt_epochs": 1, "ttt_min_doc": 512}, + "2024": {"val_bpb": 1.1522, "ttt_epochs": 1, "ttt_min_doc": 512} + }, + "mean_val_bpb": 1.1526, + "std_val_bpb": 0.0004 +} diff --git a/records/track_10min_16mb/2026-03-24_PROTEUS_v9/train_gpt.py b/records/track_10min_16mb/2026-03-24_PROTEUS_v9/train_gpt.py new file mode 100644 index 000000000..5449f9f5c --- /dev/null +++ b/records/track_10min_16mb/2026-03-24_PROTEUS_v9/train_gpt.py @@ -0,0 +1,1493 @@ +"""Good launching-off point for new participants, not SOTA config. Competitive submissions stay in /records. 
+Hard stop: train_gpt.py and train_gpt_mlx.py must never be longer than 1500 lines.""" + +from __future__ import annotations + +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +try: + import zstandard as zstd + HAVE_ZSTD = True +except ImportError: + HAVE_ZSTD = False +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 1000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 50)) + + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 1024)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 0)) # disabled: hurts with depth_scale, wastes 15 min + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = 
int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = int(os.environ.get("MLP_MULT", 2)) + mlp_hidden = int(os.environ.get("MLP_HIDDEN", 1536)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 50000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.03)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.02)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.02)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + + ema_decay = float(os.environ.get("EMA_DECAY", 0.999)) + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) + ema_every = int(os.environ.get("EMA_EVERY", 10)) + + ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8)) + ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01)) + ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256)) + ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024)) + ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64)) + ttt_min_doc_len = int(os.environ.get("TTT_MIN_DOC_LEN", 512)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 1)) + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, 
b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0: + p.data.mul_(1.0 - wd * lr) + g = updates_flat[curr : curr + 
p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + + return loss + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + 
local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + + val_loss = val_loss_sum / val_token_count + 
bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t + +def quantize_float_tensor(t: Tensor, bits: int = 8) -> tuple[Tensor, Tensor]: + max_val = 127 if bits == 8 else (2 ** (bits - 1)) - 1 # int6: 31, int8: 127 + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / float(max_val)).clamp_min(1.0 / float(max_val)) + q = torch.clamp(torch.round(clipped / scale[:, None]), -max_val, max_val).to(torch.int8).contiguous() + return q, 
scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / float(max_val) if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -max_val, max_val).to(torch.int8).contiguous() + return q, scale + +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + + if name == "tok_emb.weight": + kept = t.to(dtype=torch.float16).contiguous() + passthrough[name] = kept + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t, bits=6) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + + obj: dict[str, object] = { + 
"__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats + +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + 
+ def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, self.weight.to(x.dtype), bias) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + if seq_len > self.train_seq_len: + scale = seq_len / 
self.train_seq_len + new_base = self.base * (scale ** (self.dim / (self.dim - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, None, :, :] + self._sin_cached = freqs.sin()[None, None, :, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor: + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + + def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x) + (q_delta if q_delta is not None else 0) + k = self.c_k(x) + v = self.c_v(x) + (v_delta if v_delta is not None else 0) + q = 
q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin) + k = apply_rotary_emb(k, cos, sin) + q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None] + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=None, is_causal=True, enable_gqa=(self.num_kv_heads != self.num_heads), + ) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + +class SmearGate(nn.Module): + """Learned token blending gate — injects bigram context at embedding layer.""" + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + +class BigramHashEmbedding(nn.Module): + """Token-pair hash embedding — learned bigram features at near-zero param cost.""" + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return 
out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int, mlp_hidden: int = 0): + super().__init__() + hidden = mlp_hidden if mlp_hidden > 0 else mlp_mult * dim + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, + rope_base: float, qk_gain_init: float, mlp_hidden: int = 0, layer_idx: int = 0): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, mlp_hidden) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.register_buffer("depth_scale", torch.tensor(1.0 / math.sqrt(layer_idx + 1))) + + def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + ds = self.depth_scale.to(dtype=x.dtype) + n = self.attn_norm(x) + qd = q_delta_fn(n) if q_delta_fn is not None else None + vd = v_delta_fn(n) if v_delta_fn is not None else None + attn_out = self.attn(n, qd, vd) + x = x + ds * self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out + x = x + ds * self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x)) + return x + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, model_dim: 
int, num_heads: int, + num_kv_heads: int, mlp_mult: int, mlp_hidden: int, tie_embeddings: bool, + tied_embed_init_std: float, logit_softcap: float, rope_base: float, qk_gain_init: float): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(2048, 128, model_dim) + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + mlp_hidden=mlp_hidden, layer_idx=i) + for i in range(num_layers) + ]) + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." 
in name or name.endswith(".proj"):
                        # Output projections additionally shrunk by
                        # 1/sqrt(2 * num_layers) so residual-stream variance
                        # stays bounded as all blocks' contributions accumulate.
                        with torch.no_grad():
                            module.weight.mul_(1.0 / math.sqrt(2 * num_layers))

    def _embed(self, input_ids: Tensor) -> tuple[Tensor, Tensor]:
        """Shared embedding logic for forward and get_logits.

        Returns the post-smear activations twice: once as the running hidden
        state and once as the frozen reference x0 passed to every block.
        """
        x = self.tok_emb(input_ids)
        if self.bigram is not None:
            # Additive hashed-bigram embedding on top of the token embedding.
            x = x + self.bigram(input_ids)
        x = F.rms_norm(x, (x.size(-1),))
        x = self.smear(x)
        return x, x # (x, x0)

    def _run_blocks(self, x: Tensor, x0: Tensor, lora=None) -> Tensor:
        """Run all transformer blocks with optional LoRA deltas.

        U-Net wiring: encoder-half outputs are pushed on a stack and popped
        into the decoder half, gated per-channel by `skip_weights`. `lora`,
        when given, supplies per-layer Q/V delta callables for TTT.
        """
        skips: list[Tensor] = []
        for i in range(self.num_encoder_layers):
            qd_fn = lora.q_loras[i] if lora is not None else None
            vd_fn = lora.v_loras[i] if lora is not None else None
            x = self.blocks[i](x, x0, qd_fn, vd_fn)
            skips.append(x)
        for i in range(self.num_decoder_layers):
            bi = self.num_encoder_layers + i
            if skips:
                # LIFO pop pairs encoder layer k with decoder layer (n-1-k).
                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
            qd_fn = lora.q_loras[bi] if lora is not None else None
            vd_fn = lora.v_loras[bi] if lora is not None else None
            x = self.blocks[bi](x, x0, qd_fn, vd_fn)
        return x

    def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor:
        """Cross-entropy loss over softcapped logits.

        Without `lora`: returns the scalar mean loss (training path).
        With `lora`: returns per-token loss of shape (bsz, seqlen) so TTT
        callers can slice out exactly the chunk being scored/trained.
        """
        x, x0 = self._embed(input_ids)
        x = self._run_blocks(x, x0, lora)
        x_norm = self.final_norm(x)
        if self.tie_embeddings:
            logits_proj = F.linear(x_norm.reshape(-1, x_norm.size(-1)), self.tok_emb.weight)
        else:
            if self.lm_head is None:
                raise RuntimeError("lm_head required when tie_embeddings=False")
            logits_proj = self.lm_head(x_norm.reshape(-1, x_norm.size(-1)))
        if lora is not None:
            lora_delta = lora.lm_head_lora(x_norm) # (bsz, seqlen, V)
            bsz, seqlen, V = lora_delta.shape
            logits = logits_proj.reshape(bsz, seqlen, V) + lora_delta
            # Softcap is applied AFTER adding the LoRA delta so the adapter
            # trains against the same saturating head as the base model.
            logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap)
            return F.cross_entropy(
                logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none"
            ).reshape(bsz, seqlen)
        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
        return F.cross_entropy(logits.float(), target_ids.reshape(-1), reduction="mean")

    @torch.no_grad()
    def get_logits(self, input_ids: Tensor, lora=None) -> Tensor:
        """Return softcapped logits (no loss); used by the sliding-window eval."""
        x, x0 = self._embed(input_ids)
        x = self._run_blocks(x, x0, lora)
        x_norm = self.final_norm(x)
        if self.tie_embeddings:
            logits_proj = F.linear(x_norm, self.tok_emb.weight)
        else:
            logits_proj = self.lm_head(x_norm)
        if lora is not None:
            logits_proj = logits_proj + lora.lm_head_lora(x_norm)
        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)

# Token id the tokenizer emits at document starts; used to split the
# validation stream into documents for per-doc TTT.
BOS_ID = 1

class BatchedLinearLoRA(nn.Module):
    """Per-batch-element LoRA adapter for a linear layer. Delta = x @ Aᵀ @ Bᵀ."""
    def __init__(self, bsz: int, in_features: int, out_features: int, rank: int):
        super().__init__()
        self.in_features = in_features
        self.A = nn.Parameter(torch.empty(bsz, rank, in_features)) # down-projection
        self.B = nn.Parameter(torch.zeros(bsz, out_features, rank)) # up-projection
        self.reset()

    def forward(self, x: Tensor) -> Tensor:
        # Batched matmul: (bsz, seq, in) -> (bsz, seq, rank) -> (bsz, seq, out).
        return (x @ self.A.transpose(1, 2)) @ self.B.transpose(1, 2)

    def reset(self) -> None:
        # Standard LoRA init: A uniform in +-1/sqrt(fan_in), B zeroed, so the
        # initial delta is exactly zero.
        bound = 1.0 / math.sqrt(self.in_features)
        with torch.no_grad():
            self.A.uniform_(-bound, bound)
            self.B.zero_()

class BatchedTTTLoRA(nn.Module):
    """All LoRA adapters for one batch: LM head and Q/V per block. 
    q_loras[i] and v_loras[i] are callables that take normed hidden state and
    return the additive delta passed into CausalSelfAttention."""
    def __init__(self, bsz: int, model: GPT, rank: int):
        super().__init__()
        dim = model.tok_emb.embedding_dim
        vocab = model.tok_emb.num_embeddings
        self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank)
        self.q_loras = nn.ModuleList()
        self.v_loras = nn.ModuleList()
        for block in model.blocks:
            # Q and V output widths differ under GQA, so size each adapter
            # from the base projection it shadows.
            q_out = block.attn.c_q.weight.shape[0]
            v_out = block.attn.c_v.weight.shape[0]
            self.q_loras.append(BatchedLinearLoRA(bsz, dim, q_out, rank))
            self.v_loras.append(BatchedLinearLoRA(bsz, dim, v_out, rank))

    def reset(self) -> None:
        # Re-initialize every adapter in place (A random, B zero).
        for m in self.modules():
            if isinstance(m, BatchedLinearLoRA):
                m.reset()

def _reset_ttt_optimizer(opt: torch.optim.Adam) -> None:
    # Zero Adam moments/step in place instead of rebuilding the optimizer,
    # avoiding state-dict reallocation between documents.
    # NOTE(review): assumes the tensor-valued "step" state of modern PyTorch
    # Adam ("step".fill_(0)); older releases stored it as a python int — confirm
    # against the pinned torch version.
    for group in opt.param_groups:
        for p in group["params"]:
            s = opt.state.get(p)
            if not s:
                continue
            s["exp_avg"].zero_()
            s["exp_avg_sq"].zero_()
            s["step"].fill_(0)

def _build_ttt_optimizer(lora: BatchedTTTLoRA, args: Hyperparameters) -> torch.optim.Adam:
    # Fresh Adam over all LoRA parameters; LR/betas come from hyperparameters.
    return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr,
                            betas=(args.beta1, args.beta2), eps=1e-10)

def _find_docs(all_tokens: Tensor) -> list[tuple[int, int]]:
    """Return (start_offset, length) for each document at BOS boundaries."""
    bos_positions = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].cpu().numpy()
    docs = []
    for i in range(len(bos_positions)):
        start = int(bos_positions[i])
        # Each doc runs through the next BOS inclusive (+1), so the final
        # prediction target is the following document's BOS token.
        end = int(bos_positions[i + 1]) + 1 if i + 1 < len(bos_positions) else all_tokens.numel()
        if end - start >= 2:
            docs.append((start, end - start))
    return docs

def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int):
    """Return (win_start, win_len, chunk_offset, chunk_len) for chunk ci of a doc.

    The window is the trailing eval_seq_len tokens ending at the chunk's end,
    giving each scored chunk maximal left context within the budget.
    """
    chunk_start = ci * chunk_size
    # The last chunk ends at pred_len (may be shorter than chunk_size).
    chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size
    win_start = max(0, chunk_end - eval_seq_len)
    win_len = chunk_end - win_start
    chunk_offset = chunk_start - win_start
    chunk_len = chunk_end - chunk_start
    return win_start, win_len, chunk_offset, chunk_len

def _ttt_one_doc(base_model, all_tokens, ds, dl, lora, opt, chunk_size, eval_seq_len,
                 device, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
                 loss_sum, byte_sum, token_count, num_epochs):
    """TTT on a single document: score-then-train per chunk, multiple epochs.

    Scoring accumulators (loss_sum / byte_sum / token_count) are updated in
    place on the FINAL epoch only; earlier epochs just adapt the LoRA.
    NOTE(review): with num_epochs == 1 each token is scored before any
    training has seen it; with num_epochs > 1 the final-epoch scores cover
    tokens already trained on in earlier epochs (the multi-epoch legality
    concern discussed in this submission's README).
    """
    pred_len = dl - 1
    nc = (pred_len + chunk_size - 1) // chunk_size
    lora.reset()
    _reset_ttt_optimizer(opt)
    for epoch in range(num_epochs):
        for ci in range(nc):
            cs = ci * chunk_size
            ce = min((ci + 1) * chunk_size, pred_len)
            cl = ce - cs
            # Trailing context window ending at this chunk
            # (same arithmetic as _compute_chunk_window).
            ws = max(0, ce - eval_seq_len)
            wl = ce - ws
            co = cs - ws
            x = all_tokens[ds + ws : ds + ws + wl].to(dtype=torch.int64, device=device).unsqueeze(0)
            y = all_tokens[ds + ws + 1 : ds + ws + wl + 1].to(dtype=torch.int64, device=device).unsqueeze(0)
            # The final chunk is scored but never trained on.
            needs_train = ci < nc - 1
            if needs_train:
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    ptl = base_model(x, y, lora=lora)
            else:
                with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    ptl = base_model(x, y, lora=lora)
            if epoch == num_epochs - 1:
                with torch.no_grad():
                    loss_sum += ptl[0, co : co + cl].to(torch.float64).sum()
                    token_count += cl
                    tgt = y[0, co : co + cl]
                    px = x[0, co : co + cl]
                    # Byte count per target token; a leading space adds one
                    # byte unless the previous token already ends a boundary.
                    tb = base_bytes_lut[tgt].to(torch.float64)
                    tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[px]).to(torch.float64)
                    byte_sum += tb.sum()
            if needs_train:
                opt.zero_grad()
                ptl[0, co : co + cl].mean().backward()
                opt.step()

def eval_val_ttt_lora(
    args: Hyperparameters,
    base_model: GPT,
    rank: int,
    world_size: int,
    device: torch.device,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """TTT eval: per-doc LoRA adaptation, score-then-train, multiple 
epochs."""
    files = sorted(glob.glob(args.val_files))
    all_tokens = torch.cat([load_data_shard(Path(f)) for f in files])
    docs = _find_docs(all_tokens)
    # Contiguous shard of the document list for this rank.
    rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size]
    short_docs = [d for d in rank_docs if d[1] < args.ttt_min_doc_len]
    long_docs = [d for d in rank_docs if d[1] >= args.ttt_min_doc_len]
    master = rank == 0
    if master:
        print(f"ttt:rank0 short={len(short_docs)} long={len(long_docs)} epochs={args.ttt_epochs} batch={args.ttt_batch_size}")

    # Freeze the base model: only LoRA parameters receive gradients below.
    base_model.eval()
    for p in base_model.parameters():
        p.requires_grad_(False)

    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    byte_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)

    # Short documents: plain scoring with no adaptation (too little signal
    # for TTT to pay off).
    t0 = time.perf_counter()
    with torch.no_grad():
        for ds, dl in short_docs:
            x = all_tokens[ds : ds + dl - 1].to(device=device, dtype=torch.int64).unsqueeze(0)
            y = all_tokens[ds + 1 : ds + dl].to(device=device, dtype=torch.int64).unsqueeze(0)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                loss = base_model(x, y)
            # forward() without lora returns the mean loss; re-weight by count.
            n = dl - 1
            loss_sum += loss.to(torch.float64) * n
            token_count += n
            prev_ids = x.reshape(-1)
            tgt_ids = y.reshape(-1)
            tb = base_bytes_lut[tgt_ids].to(torch.float64)
            tb += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(torch.float64)
            byte_sum += tb.sum()
    if master:
        print(f"ttt:short_docs time={1000*(time.perf_counter()-t0):.0f}ms tokens={int(token_count.item())}")

    # Sort by chunk count so batched docs step through similar numbers of
    # chunks (less wasted padding work per batch).
    long_docs.sort(key=lambda d: (d[1] - 2) // args.ttt_chunk_size)
    batch_size = args.ttt_batch_size
    chunk_size = args.ttt_chunk_size
    eval_seq_len = args.ttt_eval_seq_len
    lora = BatchedTTTLoRA(batch_size, base_model, args.ttt_lora_rank).to(device)
    opt = _build_ttt_optimizer(lora, args)
    t1 = time.perf_counter()
    for bi in range(0, len(long_docs), batch_size):
        batch = long_docs[bi : bi + batch_size]
        bsz = len(batch)
        if bsz == batch_size:
            # Full batch: reuse the preallocated adapters/optimizer, reset in place.
            cur_lora, cur_opt = lora, opt
            cur_lora.reset()
            _reset_ttt_optimizer(cur_opt)
        else:
            # Ragged final batch: build right-sized adapters once.
            cur_lora = BatchedTTTLoRA(bsz, base_model, args.ttt_lora_rank).to(device)
            cur_opt = _build_ttt_optimizer(cur_lora, args)
        pred_lens = [dl - 1 for _, dl in batch]
        num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens]
        max_nc = max(num_chunks)
        for epoch in range(args.ttt_epochs):
            for ci in range(max_nc):
                active = [ci < nc for nc in num_chunks]
                # Reference window sized as if the chunk ends exactly at
                # (ci+1)*chunk_size; every doc's real window fits inside it
                # and shorter rows are left zero-padded (padding is never
                # scored or trained on — only co:co+cl below).
                ws_ref, wl_ref, _, _ = _compute_chunk_window(ci, (ci+1)*chunk_size, ci+1, chunk_size, eval_seq_len)
                x = torch.zeros(bsz, wl_ref, dtype=torch.int64, device=device)
                y = torch.zeros(bsz, wl_ref, dtype=torch.int64, device=device)
                doc_info = []
                for b in range(bsz):
                    if not active[b]:
                        doc_info.append((0, 0)); continue
                    ds, dl = batch[b]
                    ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len)
                    toks = all_tokens[ds+ws : ds+ws+wl+1].to(dtype=torch.int64, device=device)
                    x[b, :wl] = toks[:-1]; y[b, :wl] = toks[1:]
                    doc_info.append((co, cl))
                # One grad-enabled forward for the whole batch if any doc
                # still trains; backward only touches those docs' rows.
                needs_train = any(ci < nc-1 for nc in num_chunks)
                if needs_train:
                    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                        ptl = base_model(x, y, lora=cur_lora)
                else:
                    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                        ptl = base_model(x, y, lora=cur_lora)
                if epoch == args.ttt_epochs - 1:
                    with torch.no_grad():
                        for b in range(bsz):
                            if not active[b]: continue
                            co, cl = doc_info[b]
                            loss_sum += ptl[b, co:co+cl].to(torch.float64).sum()
                            token_count += cl
                            tgt = y[b, co:co+cl]; px = x[b, co:co+cl]
                            tb = base_bytes_lut[tgt].to(torch.float64)
                            tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[px]).to(torch.float64)
                            byte_sum += tb.sum()
                if needs_train:
                    train_loss = torch.zeros(bsz, device=device)
                    for b in range(bsz):
                        # A doc's final chunk is scored but never trained on.
                        if ci >= num_chunks[b]-1: continue
                        co, cl = doc_info[b]
                        if cl > 0: train_loss[b] = ptl[b, co:co+cl].mean()
                    cur_opt.zero_grad()
                    train_loss.sum().backward()
                    cur_opt.step()
        # Progress log every 5 full batches.
        if master and (bi + batch_size) % (batch_size * 5) == 0:
            elapsed = 1000 * (time.perf_counter() - t1)
            avg_loss = loss_sum.item() / max(token_count.item(), 1)
            print(f"ttt:batch {bi//batch_size+1}/{(len(long_docs)+batch_size-1)//batch_size} time={elapsed:.0f}ms avg_loss={avg_loss:.4f}")
    if master:
        print(f"ttt:long_docs time={1000*(time.perf_counter()-t1):.0f}ms docs={len(long_docs)}")

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)

    val_loss = float(loss_sum.item() / max(token_count.item(), 1))
    val_bpb = float((loss_sum.item() / math.log(2.0)) / max(byte_sum.item(), 1))
    # Restore training-mode defaults for the caller.
    base_model.train()
    for p in base_model.parameters():
        p.requires_grad_(True)
    return val_loss, val_bpb

def eval_val_sliding(
    args, base_model: nn.Module, rank: int, world_size: int, device: torch.device,
    val_tokens: Tensor, base_bytes_lut: Tensor, has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor, eval_seq_len: int, eval_stride: int,
) -> tuple[float, float]:
    """Sliding-window eval: score only the last `eval_stride` positions of
    each window (full window for the first), so each token is scored once
    with near-maximal context.

    NOTE(review): windows start at range(0, total - eval_seq_len + 1, stride);
    if (total - eval_seq_len) is not a multiple of the stride, the final tail
    of the token stream is never scored — confirm this is intended.
    """
    total_tokens = val_tokens.numel() - 1
    all_starts = list(range(0, total_tokens - eval_seq_len + 1, eval_stride))
    # Round-robin window assignment across ranks.
    my_starts = all_starts[rank::world_size]

    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)

    base_model.eval()
    with torch.inference_mode():
        for start in my_starts:
            end = start + eval_seq_len
            x = val_tokens[start:end].to(device=device, dtype=torch.int64).unsqueeze(0)
            y = val_tokens[start + 1:end + 1].to(device=device, dtype=torch.int64).unsqueeze(0)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = base_model.get_logits(x)
            score_from = eval_seq_len 
- eval_stride
            if start == 0:
                score_from = 0
            suffix_logits = logits[0, score_from:].float()
            suffix_targets = y[0, score_from:]
            per_pos_loss = F.cross_entropy(suffix_logits, suffix_targets, reduction="none")
            val_loss_sum += per_pos_loss.to(torch.float64).sum()
            val_token_count += per_pos_loss.numel()
            prev_ids = x[0, score_from:]
            tgt_ids = y[0, score_from:]
            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
            val_byte_count += token_bytes.to(torch.float64).sum()

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)

    val_loss = val_loss_sum / val_token_count
    bits_per_token = val_loss.item() / math.log(2.0)
    tokens_per_byte = val_token_count.item() / val_byte_count.item()
    base_model.train()
    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)

def main() -> None:
    """Entry point: warmup, wallclock-capped training, SWA/EMA + pruning,
    quantized export, then the post-quant / sliding / TTT eval suite."""
    global zeropower_via_newtonschulz5

    # Own source is logged and counted toward the submission size budget.
    code = Path(__file__).read_text(encoding="utf-8")
    args = Hyperparameters()
    zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5)

    # ---- distributed / device setup ----
    distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if world_size <= 0:
        raise ValueError(f"WORLD_SIZE must be positive, got {world_size}")
    # Global batch is built from 8 micro-batches, split across ranks.
    if 8 % world_size != 0:
        raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral")
    grad_accum_steps = 8 // world_size
    grad_scale = 1.0 / grad_accum_steps
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required")
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)
    if distributed:
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    master_process = rank == 0

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # Force flash-attention SDP only.
    from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp
    enable_cudnn_sdp(False); enable_flash_sdp(True); enable_mem_efficient_sdp(False); enable_math_sdp(False)

    # ---- logging (rank 0 only) ----
    logfile = None
    if master_process:
        os.makedirs("logs", exist_ok=True)
        logfile = f"logs/{args.run_id}.txt"
        print(logfile)

    def log0(msg: str, console: bool = True) -> None:
        # Rank-0-only logger; console=False writes to the logfile only.
        if not master_process:
            return
        if console:
            print(msg)
        if logfile is not None:
            with open(logfile, "a", encoding="utf-8") as f:
                print(msg, file=f)

    log0(code, console=False)
    log0("=" * 100, console=False)
    log0(f"Running Python {sys.version}", console=False)
    log0(f"Running PyTorch {torch.__version__}", console=False)
    log0(
        subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout,
        console=False,
    )
    log0("=" * 100, console=False)

    # ---- seeding, tokenizer, validation data ----
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if not args.tokenizer_path.endswith(".model"):
        raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}")
    sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
    if int(sp.vocab_size()) != args.vocab_size:
        raise ValueError(f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}")
    dataset_dir = Path(args.data_path).resolve()
    actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin")))
    effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len
    val_seq_len = max(args.train_seq_len, effective_eval_seq_len)
    val_tokens = load_validation_tokens(args.val_files, val_seq_len)
    base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(sp, args.vocab_size, device)
    log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}")
    log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files} val_tokens:{val_tokens.numel() - 1}")

    # ---- model: bf16 body, CastedLinear + low-dim params kept fp32 ----
    base_model = GPT(
        vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim,
        num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult,
        mlp_hidden=args.mlp_hidden, tie_embeddings=args.tie_embeddings,
        tied_embed_init_std=args.tied_embed_init_std, logit_softcap=args.logit_softcap,
        rope_base=args.rope_base, qk_gain_init=args.qk_gain_init,
    ).to(device).bfloat16()
    for module in base_model.modules():
        if isinstance(module, CastedLinear):
            module.float()
    restore_low_dim_params_to_fp32(base_model)
    compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
    model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model

    # ---- optimizers: Muon for 2-D block matrices, AdamW elsewhere ----
    block_named_params = list(base_model.blocks.named_parameters())
    matrix_params = [
        p for name, p in block_named_params
        if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
    ]
    scalar_params = [
        p for name, p in block_named_params
        if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)
    ]
    if base_model.skip_weights.numel() > 0:
        scalar_params.append(base_model.skip_weights)
    token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr
    # "base_lr" is stashed per group so lr_mul() can rescale every step.
    optimizer_tok = torch.optim.AdamW(
        [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}],
        betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=0.04, fused=True,
    )
    optimizer_muon = Muon(matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum,
                          backend_steps=args.muon_backend_steps, weight_decay=0.04)
    for group in optimizer_muon.param_groups:
        group["base_lr"] = args.matrix_lr
    optimizer_scalar = torch.optim.AdamW(
        [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}],
        betas=(args.beta1, args.beta2), eps=args.adam_eps, weight_decay=0.04, fused=True,
    )
    optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar]
    if base_model.lm_head is not None:
        # NOTE(review): the untied head uses Adam (no weight decay) while the
        # other groups use AdamW with wd=0.04 — confirm this asymmetry is
        # intentional.
        optimizer_head = torch.optim.Adam(
            [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}],
            betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True,
        )
        optimizers.insert(1, optimizer_head)

    n_params = sum(p.numel() for p in base_model.parameters())
    log0(f"model_params:{n_params} world_size:{world_size} grad_accum_steps:{grad_accum_steps}")
    log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}")
    log0(f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}")
    log0(f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} iterations:{args.iterations} warmup_steps:{args.warmup_steps} max_wallclock_seconds:{args.max_wallclock_seconds:.3f}")
    log0(f"seed:{args.seed} ema_enabled:{args.ema_enabled} ema_decay:{args.ema_decay} ema_every:{args.ema_every}")
    log0(f"ttt_lora_rank:{args.ttt_lora_rank} ttt_lora_lr:{args.ttt_lora_lr} ttt_chunk_size:{args.ttt_chunk_size}")

    ema_state: dict[str, Tensor] = {}
    _ema_updated = False
    if args.ema_enabled:
        for name, p in base_model.named_parameters():
            ema_state[name] = p.data.float().clone()

    train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)

    def zero_grad_all() -> None:
        for opt in optimizers:
            opt.zero_grad(set_to_none=True)

    max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None

    def lr_mul(step: int, elapsed_ms: float) -> float:
        # LR multiplier: 1.0 during the main phase, linear warmdown at the
        # end. With a wallclock cap the warmdown window is predicted from the
        # observed ms/step; otherwise it is the last warmdown_iters steps.
        if args.warmdown_iters <= 0:
            return 1.0
        if max_wallclock_ms is None:
            warmdown_start = max(args.iterations - args.warmdown_iters, 0)
            return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0
        step_ms = elapsed_ms / max(step, 1)
        warmdown_ms = args.warmdown_iters * step_ms
        remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0)
        return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0

    # ---- warmup: run a few steps to trigger compilation/allocator paths,
    # then restore the initial model/optimizer state and a fresh loader ----
    if args.warmup_steps > 0:
        initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()}
        initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers]
        model.train()
        for warmup_step in range(args.warmup_steps):
            zero_grad_all()
            for micro_step in range(grad_accum_steps):
                if distributed:
                    # Only sync DDP gradients on the last micro-step.
                    model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
                x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                    warmup_loss = model(x, y)
                (warmup_loss * grad_scale).backward()
            for opt in optimizers:
                opt.step()
            zero_grad_all()
            if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps:
                log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}")
        base_model.load_state_dict(initial_model_state, strict=True)
        for opt, state in zip(optimizers, initial_optimizer_states, strict=True):
            opt.load_state_dict(state)
        zero_grad_all()
        if distributed:
            model.require_backward_grad_sync = True
        train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
        if args.ema_enabled:
            for name, p in base_model.named_parameters():
                ema_state[name] = p.data.float().clone()

    # ---- main training loop (wallclock-capped) ----
    training_time_ms = 0.0
    prev_log_ms = 0.0
    swa_state: dict[str, Tensor] | None = None
    swa_count = 0
    stop_after_step: int | None = None
    wall_start = time.perf_counter()
    torch.cuda.synchronize()
    t0 = time.perf_counter()

    step = 0
    while True:
        last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step)

        should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)
        if should_validate:
            # Validation time is excluded from training_time_ms bookkeeping.
            torch.cuda.synchronize()
            training_time_ms += 1000.0 * (time.perf_counter() - t0)
            val_loss, val_bpb = eval_val(
                args, model, rank, world_size, device, grad_accum_steps,
                val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
            )
            log0(
                f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} "
                f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms"
            )
            torch.cuda.synchronize()
            t0 = time.perf_counter()

        if last_step:
            if stop_after_step is not None and step < args.iterations:
                log0(
                    f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms "
                    f"step:{step}/{args.iterations}"
                )
            break

        elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
        scale = lr_mul(step, elapsed_ms)

        zero_grad_all()
        train_loss = torch.zeros((), device=device)
        for micro_step in range(grad_accum_steps):
            if distributed:
                model.require_backward_grad_sync = micro_step == grad_accum_steps - 1
            x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
                loss = model(x, y)
            train_loss += loss.detach()
            (loss * grad_scale).backward()
        train_loss /= grad_accum_steps

        # Muon momentum warms up linearly to its final value.
        frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0
        muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum
        for group in optimizer_muon.param_groups:
            group["momentum"] = muon_momentum

        for opt in optimizers:
            for group in opt.param_groups:
                group["lr"] = group["base_lr"] * scale

        if args.grad_clip_norm > 0:
            torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm)
        for opt in optimizers:
            opt.step()
        zero_grad_all()

        if args.ema_enabled and step > 0 and step % args.ema_every == 0:
            _ema_updated = True
            with torch.no_grad():
                for name, p in base_model.named_parameters():
                    # decay**ema_every compensates for updating only every
                    # ema_every steps.
                    ema_state[name].lerp_(p.data.float(), 1.0 - args.ema_decay ** args.ema_every)

        # SWA: once the LR warmdown passes below 0.2, snapshot every 50 steps.
        # NOTE(review): checkpoints are summed in the weights' own dtype on
        # CPU (bf16 for most tensors) — precision of the running sum may
        # degrade with many snapshots; consider accumulating in fp32.
        if scale < 0.2 and step % 50 == 0:
            if swa_state is None:
                swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}
                swa_count = 1
                log0(f"swa:start step={step}")
            else:
                for name, t in base_model.state_dict().items():
                    swa_state[name] += t.detach().cpu()
                swa_count += 1

        step += 1
        approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
        should_log_train = (
            args.train_log_every > 0
            and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)
        )
        if should_log_train:
            mem_mb = torch.cuda.max_memory_allocated() // 1024 // 1024
            # NOTE(review): step_ms is computed but never used (the log line
            # below recomputes approx/step directly) — dead local.
            step_ms = (approx_training_time_ms - (training_time_ms if step <= 1 else 0)) / max(step, 1)
            this_step_ms = approx_training_time_ms - prev_log_ms if step > 1 else approx_training_time_ms
            prev_log_ms = approx_training_time_ms
            log0(
                f"step:{step}/{args.iterations} train_loss:{train_loss.item():.6f} "
                f"lr_scale:{scale:.4f} muon_mom:{muon_momentum:.4f} "
                f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms "
                f"this_step:{this_step_ms:.1f}ms mem:{mem_mb}MiB swa_n:{swa_count}"
            )

        # All ranks must agree on hitting the wallclock cap (MAX-reduce) so
        # they stop at the same step.
        reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms
        if distributed and max_wallclock_ms is not None:
            reached_cap_tensor = torch.tensor(int(reached_cap), device=device)
            dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX)
            reached_cap = bool(reached_cap_tensor.item())
        if stop_after_step is None and reached_cap:
            stop_after_step = step

    train_wall_ms = 1000.0 * (time.perf_counter() - wall_start)
    log0(
        f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
        f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
    )
    log0(f"phase:train wall_ms:{train_wall_ms:.0f} steps:{step} step_avg:{training_time_ms/max(step,1):.2f}ms")
    phase_t = time.perf_counter()

    # ---- post-process: SWA average takes priority over EMA ----
    if swa_state is not None and swa_count > 1:
        log0(f"swa:applying averaged {swa_count} checkpoints")
        current_state = base_model.state_dict()
        averaged = {
            name: (tensor / swa_count).to(dtype=current_state[name].dtype)
            for name, tensor in swa_state.items()
        }
        base_model.load_state_dict(averaged, strict=True)
    elif args.ema_enabled and _ema_updated:
        log0("Applying EMA weights for export...")
        with torch.no_grad():
            for name, p in base_model.named_parameters():
                if name in ema_state:
                    p.data.copy_(ema_state[name].to(dtype=p.dtype, device=p.device))

    # ---- magnitude pruning: zero the smallest ~3% of large-matrix weights
    # (threshold estimated from a 1M-element random sample) ----
    with torch.no_grad():
        all_weights = []
        for name, p in base_model.named_parameters():
            if p.ndim == 2 and p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL:
                all_weights.append(p.data.abs().flatten())
        if all_weights:
            all_abs = torch.cat(all_weights)
            sample = all_abs[torch.randperm(len(all_abs), device=all_abs.device)[:min(1_000_000, len(all_abs))]]
            idx = int(len(sample) * 0.03)
            threshold = float(sample.float().sort().values[idx].item())
            pruned = 0
            for name, p in base_model.named_parameters():
                if p.ndim == 2 and p.numel() > INT8_KEEP_FLOAT_MAX_NUMEL:
                    mask = p.data.abs() < threshold
                    pruned += mask.sum().item()
                    p.data[mask] = 0.0
            log0(f"pruning: zeroed {pruned:,} weights ({100*pruned/all_abs.numel():.1f}%) below {threshold:.6f}")

    log0(f"phase:postprocess wall_ms:{1000.0*(time.perf_counter()-phase_t):.0f} (swa+ema+pruning)")
    phase_t = time.perf_counter()

    torch.cuda.synchronize()
    t_prequant = time.perf_counter()
    prequant_loss, prequant_bpb = eval_val(
        args, model, rank, world_size, device, grad_accum_steps,
        val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
        eval_seq_len=effective_eval_seq_len,
    )
    torch.cuda.synchronize()
    log0(
        f"pre_quant_eval val_loss:{prequant_loss:.4f} val_bpb:{prequant_bpb:.4f} "
        f"eval_time:{1000.0 * (time.perf_counter() - t_prequant):.0f}ms"
    )
    log0(f"pre_quant_eval_exact val_loss:{prequant_loss:.8f} val_bpb:{prequant_bpb:.8f}")

    if master_process:
        torch.save(base_model.state_dict(), "final_model.pt")
        model_bytes = os.path.getsize("final_model.pt")
        code_bytes = len(code.encode("utf-8"))
        log0(f"Serialized model: {model_bytes} bytes")
        log0(f"Code size: {code_bytes} bytes")
        log0(f"Total submission size: {model_bytes + code_bytes} bytes")

    # ---- quantize + compress export ----
    # NOTE(review): the function/file names say "int8" but the quant log line
    # reports bits:6 — the naming is stale relative to the 6-bit scheme.
    quant_obj, quant_stats = quantize_state_dict_int8(base_model.state_dict())
    if master_process:
        for name in sorted(quant_obj.get("quantized", {}).keys()):
            q = quant_obj["quantized"][name]
            s = quant_obj["scales"][name]
            log0(f"quant_tensor:{name} shape:{list(q.shape)} bits:6 scale_range:[{s.float().min():.6f},{s.float().max():.6f}]")
        for name in sorted(quant_obj.get("passthrough", {}).keys()):
            t = quant_obj["passthrough"][name]
            log0(f"passthrough_tensor:{name} shape:{list(t.shape)} dtype:{t.dtype} bytes:{t.numel() * t.element_size()}")
    quant_buf = io.BytesIO()
    torch.save(quant_obj, quant_buf)
    quant_raw = quant_buf.getvalue()
    if HAVE_ZSTD:
        cctx = zstd.ZstdCompressor(level=22)
        quant_blob = cctx.compress(quant_raw)
        compress_label = "zstd-22"
    else:
        quant_blob = zlib.compress(quant_raw, level=9)
        compress_label = "zlib-9"
    quant_raw_bytes = len(quant_raw)
    if master_process:
        with open("final_model.int8.ptz", "wb") as f:
            f.write(quant_blob)
        quant_file_bytes = os.path.getsize("final_model.int8.ptz")
        code_bytes = len(code.encode("utf-8"))
        ratio = quant_stats["baseline_tensor_bytes"] / max(quant_stats["int8_payload_bytes"], 1)
        log0(
            f"Serialized model {compress_label}: {quant_file_bytes} bytes "
            f"(payload:{quant_stats['int8_payload_bytes']} raw_torch:{quant_raw_bytes} payload_ratio:{ratio:.2f}x)"
        )
        total_size = quant_file_bytes + code_bytes
        log0(f"Total submission size {compress_label}: {total_size} bytes")
        if total_size > 16_000_000:
            log0(f"WARNING: Total size {total_size} exceeds 16MB limit!")
        else:
            log0(f"Size check PASSED: {total_size} / 16,000,000 ({100*total_size/16_000_000:.1f}%)")

    log0(f"phase:serialize wall_ms:{1000.0*(time.perf_counter()-phase_t):.0f} (quant+compress+save)")
    phase_t = time.perf_counter()

    # ---- round-trip: reload the compressed artifact from disk so all later
    # evals score exactly what would be submitted ----
    if distributed:
        dist.barrier()
    with open("final_model.int8.ptz", "rb") as f:
        quant_blob_disk = f.read()
    if HAVE_ZSTD:
        dctx = zstd.ZstdDecompressor()
        quant_raw_disk = dctx.decompress(quant_blob_disk)
    else:
        quant_raw_disk = zlib.decompress(quant_blob_disk)
    quant_state = torch.load(io.BytesIO(quant_raw_disk), map_location="cpu")
    base_model.load_state_dict(dequantize_state_dict_int8(quant_state), strict=True)
    torch.cuda.synchronize()
    t_qeval = time.perf_counter()
    q_val_loss, q_val_bpb = eval_val(
        args, model, rank, world_size, device, grad_accum_steps,
        val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
        eval_seq_len=effective_eval_seq_len,
    )
    torch.cuda.synchronize()
    log0(
        f"final_int8_zlib_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} "
        f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms "
        f"eval_seq_len:{effective_eval_seq_len}"
    )
    log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")
    quant_gap_bpb = q_val_bpb - prequant_bpb
    log0(f"quant_gap: {quant_gap_bpb:.6f} BPB (pre:{prequant_bpb:.6f} post:{q_val_bpb:.6f})")
    log0(f"phase:postquant_eval wall_ms:{1000.0*(time.perf_counter()-phase_t):.0f}")
    phase_t = time.perf_counter()

    if args.eval_stride > 0:
        torch.cuda.synchronize()
        t_slide = time.perf_counter()
        s_val_loss, s_val_bpb = eval_val_sliding(
            args, base_model, rank, world_size, device,
            val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
            eval_seq_len=effective_eval_seq_len, eval_stride=args.eval_stride,
        )
        torch.cuda.synchronize()
        log0(
            f"final_sliding_window val_loss:{s_val_loss:.4f} val_bpb:{s_val_bpb:.4f} "
            f"eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms "
            f"stride:{args.eval_stride} seq_len:{effective_eval_seq_len}"
        )
        log0(f"final_sliding_window_exact val_loss:{s_val_loss:.8f} val_bpb:{s_val_bpb:.8f}")

    # ---- TTT eval on a fresh, UNcompiled copy (grad-through-LoRA requires
    # leaving the fullgraph-compiled wrapper behind) ----
    torch.cuda.synchronize()
    torch._dynamo.reset()
    ttt_model = GPT(vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim,
                    num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult,
                    mlp_hidden=args.mlp_hidden, tie_embeddings=args.tie_embeddings,
                    tied_embed_init_std=args.tied_embed_init_std, logit_softcap=args.logit_softcap,
                    rope_base=args.rope_base, qk_gain_init=args.qk_gain_init,
                    ).to(device)
    ttt_model.load_state_dict(base_model.state_dict(), strict=True)
    t_ttt = time.perf_counter()
    ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora(
        args, ttt_model, rank, world_size, device,
        base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
    )
    torch.cuda.synchronize()
    log0(
        f"final_ttt_lora val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} "
        f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms "
        f"lora_rank:{args.ttt_lora_rank} chunk_size:{args.ttt_chunk_size}"
    )
    log0(f"final_ttt_lora_exact val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f}")
    ttt_gap_bpb = ttt_val_bpb - q_val_bpb
    log0(f"ttt_gain: {-ttt_gap_bpb:.6f} BPB gain over int8 (int8:{q_val_bpb:.6f} ttt:{ttt_val_bpb:.6f})")
    log0(f"phase:ttt_eval wall_ms:{1000.0*(time.perf_counter()-phase_t):.0f}")
    total_wall_ms = 1000.0 * (time.perf_counter() - wall_start)
    log0(f"phase:TOTAL wall_ms:{total_wall_ms:.0f} ({total_wall_ms/60000:.1f} min)")
    log0(f"phase_breakdown: train:{training_time_ms:.0f}ms postprocess:see_above 
serialize:see_above eval:see_above ttt:see_above") + + if distributed: + dist.destroy_process_group() + +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-24_PROTEUS_v9/train_seed1337.log b/records/track_10min_16mb/2026-03-24_PROTEUS_v9/train_seed1337.log new file mode 100644 index 000000000..e2b2aba4d --- /dev/null +++ b/records/track_10min_16mb/2026-03-24_PROTEUS_v9/train_seed1337.log @@ -0,0 +1,351 @@ +W0324 01:45:38.936000 11184 torch/distributed/run.py:851] +W0324 01:45:38.936000 11184 torch/distributed/run.py:851] ***************************************** +W0324 01:45:38.936000 11184 torch/distributed/run.py:851] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0324 01:45:38.936000 11184 torch/distributed/run.py:851] ***************************************** +logs/proteus_v9_1337.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/tmp/pgolf-repo/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 val_tokens:62021632 +model_params:26829913 world_size:8 grad_accum_steps:1 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:786432 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 ema_enabled:True ema_decay:0.999 ema_every:10 +ttt_lora_rank:8 ttt_lora_lr:0.01 ttt_chunk_size:256 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:1/20000 train_loss:6.932616 
lr_scale:1.0000 muon_mom:0.9200 train_time:210ms step_avg:209.86ms this_step:209.9ms mem:20875MiB swa_n:0 +step:2/20000 train_loss:8.042912 lr_scale:0.9242 muon_mom:0.9200 train_time:284ms step_avg:142.08ms this_step:74.3ms mem:20875MiB swa_n:0 +step:3/20000 train_loss:7.521690 lr_scale:1.0000 muon_mom:0.9201 train_time:367ms step_avg:122.49ms this_step:83.3ms mem:20875MiB swa_n:0 +step:4/20000 train_loss:6.979456 lr_scale:1.0000 muon_mom:0.9201 train_time:452ms step_avg:112.89ms this_step:84.1ms mem:20875MiB swa_n:0 +step:5/20000 train_loss:6.836759 lr_scale:1.0000 muon_mom:0.9202 train_time:536ms step_avg:107.15ms this_step:84.2ms mem:20875MiB swa_n:0 +step:6/20000 train_loss:6.869736 lr_scale:1.0000 muon_mom:0.9202 train_time:620ms step_avg:103.30ms this_step:84.1ms mem:20875MiB swa_n:0 +step:7/20000 train_loss:6.755961 lr_scale:1.0000 muon_mom:0.9203 train_time:704ms step_avg:100.59ms this_step:84.3ms mem:20875MiB swa_n:0 +step:8/20000 train_loss:6.649573 lr_scale:1.0000 muon_mom:0.9203 train_time:788ms step_avg:98.45ms this_step:83.5ms mem:20875MiB swa_n:0 +step:9/20000 train_loss:6.351079 lr_scale:1.0000 muon_mom:0.9204 train_time:872ms step_avg:96.90ms this_step:84.5ms mem:20875MiB swa_n:0 +step:10/20000 train_loss:6.104794 lr_scale:1.0000 muon_mom:0.9204 train_time:957ms step_avg:95.74ms this_step:85.3ms mem:20875MiB swa_n:0 +step:50/20000 train_loss:3.984925 lr_scale:1.0000 muon_mom:0.9223 train_time:4363ms step_avg:87.26ms this_step:3405.4ms mem:20875MiB swa_n:0 +step:100/20000 train_loss:3.234884 lr_scale:1.0000 muon_mom:0.9246 train_time:8639ms step_avg:86.39ms this_step:4275.7ms mem:20875MiB swa_n:0 +step:150/20000 train_loss:2.936165 lr_scale:1.0000 muon_mom:0.9270 train_time:12982ms step_avg:86.54ms this_step:4343.2ms mem:20875MiB swa_n:0 +step:200/20000 train_loss:2.463255 lr_scale:1.0000 muon_mom:0.9293 train_time:17255ms step_avg:86.27ms this_step:4273.1ms mem:20875MiB swa_n:0 +step:250/20000 train_loss:2.557545 lr_scale:1.0000 muon_mom:0.9316 
train_time:21529ms step_avg:86.12ms this_step:4274.1ms mem:20875MiB swa_n:0 +step:300/20000 train_loss:2.620667 lr_scale:1.0000 muon_mom:0.9340 train_time:25865ms step_avg:86.22ms this_step:4335.7ms mem:20875MiB swa_n:0 +step:350/20000 train_loss:2.590869 lr_scale:1.0000 muon_mom:0.9363 train_time:30146ms step_avg:86.13ms this_step:4281.1ms mem:20875MiB swa_n:0 +step:400/20000 train_loss:2.481411 lr_scale:1.0000 muon_mom:0.9386 train_time:34484ms step_avg:86.21ms this_step:4338.0ms mem:20875MiB swa_n:0 +step:450/20000 train_loss:2.429981 lr_scale:1.0000 muon_mom:0.9410 train_time:38777ms step_avg:86.17ms this_step:4292.9ms mem:20875MiB swa_n:0 +step:500/20000 train_loss:2.450544 lr_scale:1.0000 muon_mom:0.9433 train_time:43069ms step_avg:86.14ms this_step:4292.6ms mem:20875MiB swa_n:0 +step:550/20000 train_loss:2.396213 lr_scale:1.0000 muon_mom:0.9456 train_time:47442ms step_avg:86.26ms this_step:4372.2ms mem:20875MiB swa_n:0 +step:600/20000 train_loss:2.380006 lr_scale:1.0000 muon_mom:0.9480 train_time:51740ms step_avg:86.23ms this_step:4298.3ms mem:20875MiB swa_n:0 +step:650/20000 train_loss:2.377447 lr_scale:1.0000 muon_mom:0.9503 train_time:56099ms step_avg:86.31ms this_step:4358.8ms mem:20875MiB swa_n:0 +step:700/20000 train_loss:2.393336 lr_scale:1.0000 muon_mom:0.9526 train_time:60407ms step_avg:86.30ms this_step:4308.3ms mem:20875MiB swa_n:0 +step:750/20000 train_loss:2.376116 lr_scale:1.0000 muon_mom:0.9550 train_time:64711ms step_avg:86.28ms this_step:4304.6ms mem:20875MiB swa_n:0 +step:800/20000 train_loss:2.286803 lr_scale:1.0000 muon_mom:0.9573 train_time:69074ms step_avg:86.34ms this_step:4362.4ms mem:20875MiB swa_n:0 +step:850/20000 train_loss:2.277967 lr_scale:1.0000 muon_mom:0.9596 train_time:73374ms step_avg:86.32ms this_step:4299.9ms mem:20875MiB swa_n:0 +step:900/20000 train_loss:2.174158 lr_scale:1.0000 muon_mom:0.9620 train_time:77730ms step_avg:86.37ms this_step:4356.7ms mem:20875MiB swa_n:0 +step:950/20000 train_loss:2.258668 lr_scale:1.0000 
muon_mom:0.9643 train_time:82036ms step_avg:86.35ms this_step:4305.4ms mem:20875MiB swa_n:0 +step:1000/20000 train_loss:2.308533 lr_scale:1.0000 muon_mom:0.9666 train_time:86335ms step_avg:86.33ms this_step:4299.2ms mem:20875MiB swa_n:0 +step:1050/20000 train_loss:2.271148 lr_scale:1.0000 muon_mom:0.9690 train_time:90690ms step_avg:86.37ms this_step:4355.1ms mem:20875MiB swa_n:0 +step:1100/20000 train_loss:2.373785 lr_scale:1.0000 muon_mom:0.9713 train_time:94989ms step_avg:86.35ms this_step:4299.3ms mem:20875MiB swa_n:0 +step:1150/20000 train_loss:2.289152 lr_scale:1.0000 muon_mom:0.9736 train_time:99343ms step_avg:86.38ms this_step:4353.2ms mem:20875MiB swa_n:0 +step:1200/20000 train_loss:2.398424 lr_scale:1.0000 muon_mom:0.9760 train_time:103641ms step_avg:86.37ms this_step:4298.1ms mem:20875MiB swa_n:0 +step:1250/20000 train_loss:2.297347 lr_scale:1.0000 muon_mom:0.9783 train_time:107939ms step_avg:86.35ms this_step:4298.0ms mem:20875MiB swa_n:0 +step:1300/20000 train_loss:2.152571 lr_scale:1.0000 muon_mom:0.9806 train_time:112292ms step_avg:86.38ms this_step:4353.5ms mem:20875MiB swa_n:0 +step:1350/20000 train_loss:2.288565 lr_scale:1.0000 muon_mom:0.9830 train_time:116588ms step_avg:86.36ms this_step:4295.6ms mem:20875MiB swa_n:0 +step:1400/20000 train_loss:2.229816 lr_scale:1.0000 muon_mom:0.9853 train_time:120937ms step_avg:86.38ms this_step:4349.1ms mem:20875MiB swa_n:0 +step:1450/20000 train_loss:2.168086 lr_scale:1.0000 muon_mom:0.9876 train_time:125221ms step_avg:86.36ms this_step:4284.0ms mem:20875MiB swa_n:0 +step:1500/20000 train_loss:2.262217 lr_scale:1.0000 muon_mom:0.9900 train_time:129509ms step_avg:86.34ms this_step:4288.2ms mem:20875MiB swa_n:0 +step:1550/20000 train_loss:2.228302 lr_scale:1.0000 muon_mom:0.9900 train_time:133856ms step_avg:86.36ms this_step:4347.3ms mem:20875MiB swa_n:0 +step:1600/20000 train_loss:2.124883 lr_scale:1.0000 muon_mom:0.9900 train_time:138143ms step_avg:86.34ms this_step:4286.6ms mem:20875MiB swa_n:0 
+step:1650/20000 train_loss:2.240795 lr_scale:1.0000 muon_mom:0.9900 train_time:142425ms step_avg:86.32ms this_step:4282.1ms mem:20875MiB swa_n:0 +step:1700/20000 train_loss:2.179134 lr_scale:1.0000 muon_mom:0.9900 train_time:146766ms step_avg:86.33ms this_step:4340.9ms mem:20875MiB swa_n:0 +step:1750/20000 train_loss:2.240073 lr_scale:1.0000 muon_mom:0.9900 train_time:151049ms step_avg:86.31ms this_step:4282.9ms mem:20875MiB swa_n:0 +step:1800/20000 train_loss:2.225754 lr_scale:1.0000 muon_mom:0.9900 train_time:155382ms step_avg:86.32ms this_step:4333.2ms mem:20875MiB swa_n:0 +step:1850/20000 train_loss:2.076551 lr_scale:1.0000 muon_mom:0.9900 train_time:159661ms step_avg:86.30ms this_step:4278.7ms mem:20875MiB swa_n:0 +step:1900/20000 train_loss:2.170799 lr_scale:1.0000 muon_mom:0.9900 train_time:163943ms step_avg:86.29ms this_step:4281.9ms mem:20875MiB swa_n:0 +step:1950/20000 train_loss:2.065707 lr_scale:1.0000 muon_mom:0.9900 train_time:168278ms step_avg:86.30ms this_step:4335.7ms mem:20875MiB swa_n:0 +step:2000/20000 train_loss:2.109908 lr_scale:1.0000 muon_mom:0.9900 train_time:172554ms step_avg:86.28ms this_step:4276.1ms mem:20875MiB swa_n:0 +step:2050/20000 train_loss:2.150260 lr_scale:1.0000 muon_mom:0.9900 train_time:176889ms step_avg:86.29ms this_step:4334.3ms mem:20875MiB swa_n:0 +step:2100/20000 train_loss:2.078492 lr_scale:1.0000 muon_mom:0.9900 train_time:181165ms step_avg:86.27ms this_step:4275.8ms mem:20875MiB swa_n:0 +step:2150/20000 train_loss:2.184106 lr_scale:1.0000 muon_mom:0.9900 train_time:185434ms step_avg:86.25ms this_step:4269.3ms mem:20875MiB swa_n:0 +step:2200/20000 train_loss:2.234651 lr_scale:1.0000 muon_mom:0.9900 train_time:189765ms step_avg:86.26ms this_step:4331.1ms mem:20875MiB swa_n:0 +step:2250/20000 train_loss:2.219973 lr_scale:1.0000 muon_mom:0.9900 train_time:194039ms step_avg:86.24ms this_step:4274.4ms mem:20875MiB swa_n:0 +step:2300/20000 train_loss:2.151929 lr_scale:1.0000 muon_mom:0.9900 train_time:198366ms 
step_avg:86.25ms this_step:4327.0ms mem:20875MiB swa_n:0 +step:2350/20000 train_loss:2.209223 lr_scale:1.0000 muon_mom:0.9900 train_time:202643ms step_avg:86.23ms this_step:4277.0ms mem:20875MiB swa_n:0 +step:2400/20000 train_loss:2.111561 lr_scale:1.0000 muon_mom:0.9900 train_time:206916ms step_avg:86.21ms this_step:4272.4ms mem:20875MiB swa_n:0 +step:2450/20000 train_loss:2.118166 lr_scale:1.0000 muon_mom:0.9900 train_time:211241ms step_avg:86.22ms this_step:4325.6ms mem:20875MiB swa_n:0 +step:2500/20000 train_loss:2.209846 lr_scale:1.0000 muon_mom:0.9900 train_time:215510ms step_avg:86.20ms this_step:4268.5ms mem:20875MiB swa_n:0 +step:2550/20000 train_loss:2.236708 lr_scale:1.0000 muon_mom:0.9900 train_time:219830ms step_avg:86.21ms this_step:4320.0ms mem:20875MiB swa_n:0 +step:2600/20000 train_loss:2.144707 lr_scale:1.0000 muon_mom:0.9900 train_time:224170ms step_avg:86.22ms this_step:4340.2ms mem:20875MiB swa_n:0 +step:2650/20000 train_loss:2.120826 lr_scale:1.0000 muon_mom:0.9900 train_time:228442ms step_avg:86.20ms this_step:4272.4ms mem:20875MiB swa_n:0 +step:2700/20000 train_loss:2.140333 lr_scale:1.0000 muon_mom:0.9900 train_time:232767ms step_avg:86.21ms this_step:4324.1ms mem:20875MiB swa_n:0 +step:2750/20000 train_loss:2.073480 lr_scale:1.0000 muon_mom:0.9900 train_time:237029ms step_avg:86.19ms this_step:4262.5ms mem:20875MiB swa_n:0 +step:2800/20000 train_loss:2.189676 lr_scale:1.0000 muon_mom:0.9900 train_time:241354ms step_avg:86.20ms this_step:4325.2ms mem:20875MiB swa_n:0 +step:2850/20000 train_loss:2.102473 lr_scale:1.0000 muon_mom:0.9900 train_time:245610ms step_avg:86.18ms this_step:4255.8ms mem:20875MiB swa_n:0 +step:2900/20000 train_loss:2.068917 lr_scale:1.0000 muon_mom:0.9900 train_time:249877ms step_avg:86.16ms this_step:4267.0ms mem:20875MiB swa_n:0 +step:2950/20000 train_loss:2.120285 lr_scale:1.0000 muon_mom:0.9900 train_time:254197ms step_avg:86.17ms this_step:4320.1ms mem:20875MiB swa_n:0 +step:3000/20000 train_loss:2.194982 
lr_scale:1.0000 muon_mom:0.9900 train_time:258459ms step_avg:86.15ms this_step:4261.8ms mem:20875MiB swa_n:0 +step:3050/20000 train_loss:2.077816 lr_scale:1.0000 muon_mom:0.9900 train_time:262721ms step_avg:86.14ms this_step:4261.9ms mem:20875MiB swa_n:0 +step:3100/20000 train_loss:2.082474 lr_scale:1.0000 muon_mom:0.9900 train_time:267035ms step_avg:86.14ms this_step:4314.2ms mem:20875MiB swa_n:0 +step:3150/20000 train_loss:2.009187 lr_scale:1.0000 muon_mom:0.9900 train_time:271296ms step_avg:86.13ms this_step:4260.7ms mem:20875MiB swa_n:0 +step:3200/20000 train_loss:2.209429 lr_scale:1.0000 muon_mom:0.9900 train_time:275608ms step_avg:86.13ms this_step:4312.4ms mem:20875MiB swa_n:0 +step:3250/20000 train_loss:2.091130 lr_scale:1.0000 muon_mom:0.9900 train_time:279869ms step_avg:86.11ms this_step:4261.0ms mem:20875MiB swa_n:0 +step:3300/20000 train_loss:2.116302 lr_scale:1.0000 muon_mom:0.9900 train_time:284130ms step_avg:86.10ms this_step:4260.6ms mem:20875MiB swa_n:0 +step:3350/20000 train_loss:2.132177 lr_scale:1.0000 muon_mom:0.9900 train_time:288451ms step_avg:86.10ms this_step:4321.1ms mem:20875MiB swa_n:0 +step:3400/20000 train_loss:2.069460 lr_scale:1.0000 muon_mom:0.9900 train_time:292715ms step_avg:86.09ms this_step:4264.0ms mem:20875MiB swa_n:0 +step:3450/20000 train_loss:2.153106 lr_scale:1.0000 muon_mom:0.9900 train_time:297037ms step_avg:86.10ms this_step:4322.3ms mem:20875MiB swa_n:0 +step:3500/20000 train_loss:2.221123 lr_scale:1.0000 muon_mom:0.9900 train_time:301297ms step_avg:86.08ms this_step:4259.7ms mem:20875MiB swa_n:0 +step:3550/20000 train_loss:1.965941 lr_scale:1.0000 muon_mom:0.9900 train_time:305554ms step_avg:86.07ms this_step:4257.0ms mem:20875MiB swa_n:0 +step:3600/20000 train_loss:2.135553 lr_scale:1.0000 muon_mom:0.9900 train_time:309862ms step_avg:86.07ms this_step:4307.9ms mem:20875MiB swa_n:0 +step:3650/20000 train_loss:2.026684 lr_scale:1.0000 muon_mom:0.9900 train_time:314120ms step_avg:86.06ms this_step:4258.0ms mem:20875MiB 
swa_n:0 +step:3700/20000 train_loss:2.130899 lr_scale:1.0000 muon_mom:0.9900 train_time:318435ms step_avg:86.06ms this_step:4315.3ms mem:20875MiB swa_n:0 +step:3750/20000 train_loss:1.963938 lr_scale:1.0000 muon_mom:0.9900 train_time:322693ms step_avg:86.05ms this_step:4257.5ms mem:20875MiB swa_n:0 +step:3800/20000 train_loss:2.117962 lr_scale:1.0000 muon_mom:0.9900 train_time:326955ms step_avg:86.04ms this_step:4262.6ms mem:20875MiB swa_n:0 +step:3850/20000 train_loss:2.135210 lr_scale:1.0000 muon_mom:0.9900 train_time:331274ms step_avg:86.05ms this_step:4318.2ms mem:20875MiB swa_n:0 +step:3900/20000 train_loss:2.120062 lr_scale:1.0000 muon_mom:0.9900 train_time:335525ms step_avg:86.03ms this_step:4251.4ms mem:20875MiB swa_n:0 +step:3950/20000 train_loss:2.219340 lr_scale:1.0000 muon_mom:0.9900 train_time:339836ms step_avg:86.03ms this_step:4311.1ms mem:20875MiB swa_n:0 +step:4000/20000 train_loss:2.020036 lr_scale:0.9918 muon_mom:0.9900 train_time:344101ms step_avg:86.03ms this_step:4264.7ms mem:20875MiB swa_n:0 +step:4050/20000 train_loss:2.135771 lr_scale:0.9754 muon_mom:0.9900 train_time:348362ms step_avg:86.02ms this_step:4260.7ms mem:20875MiB swa_n:0 +step:4100/20000 train_loss:2.077421 lr_scale:0.9586 muon_mom:0.9900 train_time:352680ms step_avg:86.02ms this_step:4318.9ms mem:20875MiB swa_n:0 +step:4150/20000 train_loss:2.157136 lr_scale:0.9422 muon_mom:0.9900 train_time:356940ms step_avg:86.01ms this_step:4259.5ms mem:20875MiB swa_n:0 +step:4200/20000 train_loss:2.205117 lr_scale:0.9255 muon_mom:0.9900 train_time:361253ms step_avg:86.01ms this_step:4312.6ms mem:20875MiB swa_n:0 +step:4250/20000 train_loss:2.158592 lr_scale:0.9090 muon_mom:0.9900 train_time:365516ms step_avg:86.00ms this_step:4262.9ms mem:20875MiB swa_n:0 +step:4300/20000 train_loss:2.100661 lr_scale:0.8926 muon_mom:0.9900 train_time:369776ms step_avg:85.99ms this_step:4260.2ms mem:20875MiB swa_n:0 +step:4350/20000 train_loss:2.118916 lr_scale:0.8759 muon_mom:0.9900 train_time:374089ms 
step_avg:86.00ms this_step:4313.5ms mem:20875MiB swa_n:0 +step:4400/20000 train_loss:2.082472 lr_scale:0.8594 muon_mom:0.9900 train_time:378350ms step_avg:85.99ms this_step:4261.3ms mem:20875MiB swa_n:0 +step:4450/20000 train_loss:2.089025 lr_scale:0.8430 muon_mom:0.9900 train_time:382608ms step_avg:85.98ms this_step:4257.1ms mem:20875MiB swa_n:0 +step:4500/20000 train_loss:2.163146 lr_scale:0.8263 muon_mom:0.9900 train_time:386923ms step_avg:85.98ms this_step:4315.0ms mem:20875MiB swa_n:0 +step:4550/20000 train_loss:2.171390 lr_scale:0.8098 muon_mom:0.9900 train_time:391181ms step_avg:85.97ms this_step:4258.2ms mem:20875MiB swa_n:0 +step:4600/20000 train_loss:1.905894 lr_scale:0.7931 muon_mom:0.9900 train_time:395493ms step_avg:85.98ms this_step:4312.0ms mem:20875MiB swa_n:0 +step:4650/20000 train_loss:2.100536 lr_scale:0.7767 muon_mom:0.9900 train_time:399752ms step_avg:85.97ms this_step:4259.0ms mem:20875MiB swa_n:0 +step:4700/20000 train_loss:2.296093 lr_scale:0.7602 muon_mom:0.9900 train_time:404008ms step_avg:85.96ms this_step:4256.1ms mem:20875MiB swa_n:0 +step:4750/20000 train_loss:2.063973 lr_scale:0.7435 muon_mom:0.9900 train_time:408326ms step_avg:85.96ms this_step:4317.8ms mem:20875MiB swa_n:0 +step:4800/20000 train_loss:2.508613 lr_scale:0.7270 muon_mom:0.9900 train_time:412582ms step_avg:85.95ms this_step:4256.4ms mem:20875MiB swa_n:0 +step:4850/20000 train_loss:2.152172 lr_scale:0.7103 muon_mom:0.9900 train_time:416898ms step_avg:85.96ms this_step:4315.8ms mem:20875MiB swa_n:0 +step:4900/20000 train_loss:2.103719 lr_scale:0.6938 muon_mom:0.9900 train_time:421164ms step_avg:85.95ms this_step:4266.4ms mem:20875MiB swa_n:0 +step:4950/20000 train_loss:2.151520 lr_scale:0.6773 muon_mom:0.9900 train_time:425423ms step_avg:85.94ms this_step:4258.3ms mem:20875MiB swa_n:0 +step:5000/20000 train_loss:2.154581 lr_scale:0.6605 muon_mom:0.9900 train_time:429744ms step_avg:85.95ms this_step:4321.0ms mem:20875MiB swa_n:0 +step:5050/20000 train_loss:2.136774 
lr_scale:0.6441 muon_mom:0.9900 train_time:434002ms step_avg:85.94ms this_step:4258.3ms mem:20875MiB swa_n:0 +step:5100/20000 train_loss:2.164791 lr_scale:0.6273 muon_mom:0.9900 train_time:438321ms step_avg:85.95ms this_step:4318.7ms mem:20875MiB swa_n:0 +step:5150/20000 train_loss:2.075886 lr_scale:0.6109 muon_mom:0.9900 train_time:442570ms step_avg:85.94ms this_step:4249.1ms mem:20875MiB swa_n:0 +step:5200/20000 train_loss:2.086723 lr_scale:0.5944 muon_mom:0.9900 train_time:446825ms step_avg:85.93ms this_step:4255.7ms mem:20875MiB swa_n:0 +step:5250/20000 train_loss:2.108717 lr_scale:0.5777 muon_mom:0.9900 train_time:451143ms step_avg:85.93ms this_step:4317.2ms mem:20875MiB swa_n:0 +step:5300/20000 train_loss:2.059011 lr_scale:0.5612 muon_mom:0.9900 train_time:455403ms step_avg:85.93ms this_step:4260.9ms mem:20875MiB swa_n:0 +step:5350/20000 train_loss:1.972533 lr_scale:0.5445 muon_mom:0.9900 train_time:459706ms step_avg:85.93ms this_step:4302.2ms mem:20875MiB swa_n:0 +step:5400/20000 train_loss:2.090426 lr_scale:0.5280 muon_mom:0.9900 train_time:463965ms step_avg:85.92ms this_step:4259.7ms mem:20875MiB swa_n:0 +step:5450/20000 train_loss:2.113805 lr_scale:0.5116 muon_mom:0.9900 train_time:468217ms step_avg:85.91ms this_step:4251.7ms mem:20875MiB swa_n:0 +step:5500/20000 train_loss:2.056925 lr_scale:0.4948 muon_mom:0.9900 train_time:472533ms step_avg:85.92ms this_step:4316.3ms mem:20875MiB swa_n:0 +step:5550/20000 train_loss:2.052504 lr_scale:0.4783 muon_mom:0.9900 train_time:476795ms step_avg:85.91ms this_step:4261.7ms mem:20875MiB swa_n:0 +step:5600/20000 train_loss:2.013119 lr_scale:0.4616 muon_mom:0.9900 train_time:481103ms step_avg:85.91ms this_step:4308.3ms mem:20875MiB swa_n:0 +step:5650/20000 train_loss:2.095280 lr_scale:0.4451 muon_mom:0.9900 train_time:485366ms step_avg:85.91ms this_step:4262.9ms mem:20875MiB swa_n:0 +step:5700/20000 train_loss:2.056323 lr_scale:0.4286 muon_mom:0.9900 train_time:489622ms step_avg:85.90ms this_step:4255.6ms mem:20875MiB 
swa_n:0 +step:5750/20000 train_loss:2.138524 lr_scale:0.4118 muon_mom:0.9900 train_time:493940ms step_avg:85.90ms this_step:4318.6ms mem:20875MiB swa_n:0 +step:5800/20000 train_loss:2.048851 lr_scale:0.3953 muon_mom:0.9900 train_time:498195ms step_avg:85.90ms this_step:4254.3ms mem:20875MiB swa_n:0 +step:5850/20000 train_loss:2.173103 lr_scale:0.3789 muon_mom:0.9900 train_time:502506ms step_avg:85.90ms this_step:4311.0ms mem:20875MiB swa_n:0 +step:5900/20000 train_loss:1.948658 lr_scale:0.3621 muon_mom:0.9900 train_time:506757ms step_avg:85.89ms this_step:4251.1ms mem:20875MiB swa_n:0 +step:5950/20000 train_loss:2.004745 lr_scale:0.3456 muon_mom:0.9900 train_time:511016ms step_avg:85.88ms this_step:4258.8ms mem:20875MiB swa_n:0 +step:6000/20000 train_loss:1.996868 lr_scale:0.3289 muon_mom:0.9900 train_time:515327ms step_avg:85.89ms this_step:4311.3ms mem:20875MiB swa_n:0 +step:6050/20000 train_loss:2.012749 lr_scale:0.3124 muon_mom:0.9900 train_time:519580ms step_avg:85.88ms this_step:4252.6ms mem:20875MiB swa_n:0 +step:6100/20000 train_loss:1.968962 lr_scale:0.2959 muon_mom:0.9900 train_time:523839ms step_avg:85.88ms this_step:4259.7ms mem:20875MiB swa_n:0 +step:6150/20000 train_loss:2.071243 lr_scale:0.2791 muon_mom:0.9900 train_time:528157ms step_avg:85.88ms this_step:4318.2ms mem:20875MiB swa_n:0 +step:6200/20000 train_loss:2.003320 lr_scale:0.2626 muon_mom:0.9900 train_time:532411ms step_avg:85.87ms this_step:4254.0ms mem:20875MiB swa_n:0 +step:6250/20000 train_loss:2.120007 lr_scale:0.2458 muon_mom:0.9900 train_time:536732ms step_avg:85.88ms this_step:4320.3ms mem:20875MiB swa_n:0 +step:6300/20000 train_loss:1.987865 lr_scale:0.2293 muon_mom:0.9900 train_time:540995ms step_avg:85.87ms this_step:4263.0ms mem:20875MiB swa_n:0 +step:6350/20000 train_loss:2.082985 lr_scale:0.2128 muon_mom:0.9900 train_time:545251ms step_avg:85.87ms this_step:4255.8ms mem:20875MiB swa_n:0 +step:6400/20000 train_loss:2.044551 lr_scale:0.1960 muon_mom:0.9900 train_time:549562ms 
step_avg:85.87ms this_step:4312.0ms mem:20875MiB swa_n:0 +swa:start step=6400 +step:6450/20000 train_loss:2.117548 lr_scale:0.1792 muon_mom:0.9900 train_time:553908ms step_avg:85.88ms this_step:4345.0ms mem:20875MiB swa_n:1 +step:6500/20000 train_loss:2.121013 lr_scale:0.1623 muon_mom:0.9900 train_time:558257ms step_avg:85.89ms this_step:4349.9ms mem:20875MiB swa_n:2 +step:6550/20000 train_loss:2.083954 lr_scale:0.1456 muon_mom:0.9900 train_time:562553ms step_avg:85.89ms this_step:4295.9ms mem:20875MiB swa_n:3 +step:6600/20000 train_loss:1.900329 lr_scale:0.1290 muon_mom:0.9900 train_time:566842ms step_avg:85.89ms this_step:4288.7ms mem:20875MiB swa_n:4 +step:6650/20000 train_loss:1.853269 lr_scale:0.1121 muon_mom:0.9900 train_time:571185ms step_avg:85.89ms this_step:4343.0ms mem:20875MiB swa_n:5 +step:6700/20000 train_loss:1.983597 lr_scale:0.0954 muon_mom:0.9900 train_time:575487ms step_avg:85.89ms this_step:4302.1ms mem:20875MiB swa_n:6 +step:6750/20000 train_loss:2.133729 lr_scale:0.0785 muon_mom:0.9900 train_time:579845ms step_avg:85.90ms this_step:4357.6ms mem:20875MiB swa_n:7 +step:6800/20000 train_loss:2.063642 lr_scale:0.0618 muon_mom:0.9900 train_time:584136ms step_avg:85.90ms this_step:4291.5ms mem:20875MiB swa_n:8 +step:6850/20000 train_loss:1.875348 lr_scale:0.0452 muon_mom:0.9900 train_time:588419ms step_avg:85.90ms this_step:4283.3ms mem:20875MiB swa_n:9 +step:6900/20000 train_loss:1.874371 lr_scale:0.0284 muon_mom:0.9900 train_time:592759ms step_avg:85.91ms this_step:4339.3ms mem:20875MiB swa_n:10 +step:6950/20000 train_loss:1.997379 lr_scale:0.0117 muon_mom:0.9900 train_time:597044ms step_avg:85.91ms this_step:4284.7ms mem:20875MiB swa_n:11 +step:6985/20000 val_loss:1.9767 val_bpb:1.1707 train_time:600085ms step_avg:85.91ms +stopping_early: wallclock_cap train_time:600085ms step:6985/20000 +peak memory allocated: 20875 MiB reserved: 21082 MiB +phase:train wall_ms:611031 steps:6985 step_avg:85.91ms +swa:applying averaged 12 checkpoints +pruning: 
zeroed 1,341,022 weights (5.0%) below 0.007446 +phase:postprocess wall_ms:151 (swa+ema+pruning) +pre_quant_eval val_loss:1.9641 val_bpb:1.1633 eval_time:16590ms +pre_quant_eval_exact val_loss:1.96410637 val_bpb:1.16325443 +Serialized model: 105792597 bytes +Code size: 71033 bytes +Total submission size: 105863630 bytes +quant_tensor:bigram.embed.weight shape:[2048, 128] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.0.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.046204] +quant_tensor:blocks.0.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.033325] +quant_tensor:blocks.0.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.0.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.0.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.060059] +quant_tensor:blocks.0.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.1.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.083069] +quant_tensor:blocks.1.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.042877] +quant_tensor:blocks.1.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.1.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.1.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.039764] +quant_tensor:blocks.1.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.076538] +quant_tensor:blocks.10.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.045471] +quant_tensor:blocks.10.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.10.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.032623] +quant_tensor:blocks.10.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.10.mlp.fc.weight shape:[1536, 512] 
bits:6 scale_range:[0.032257,0.055847] +quant_tensor:blocks.10.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.095398] +quant_tensor:blocks.2.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.039307] +quant_tensor:blocks.2.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.2.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.032562] +quant_tensor:blocks.2.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.2.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.032898] +quant_tensor:blocks.2.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.3.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.044464] +quant_tensor:blocks.3.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.3.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.3.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.3.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.181519] +quant_tensor:blocks.3.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.085754] +quant_tensor:blocks.4.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.036896] +quant_tensor:blocks.4.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.4.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.034851] +quant_tensor:blocks.4.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.4.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.034332] +quant_tensor:blocks.4.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.5.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.039124] +quant_tensor:blocks.5.attn.c_q.weight shape:[512, 512] 
bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.5.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.5.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.5.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.037079] +quant_tensor:blocks.5.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.6.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.041992] +quant_tensor:blocks.6.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.6.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.035522] +quant_tensor:blocks.6.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.6.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.035461] +quant_tensor:blocks.6.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.7.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.032898] +quant_tensor:blocks.7.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.7.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.034363] +quant_tensor:blocks.7.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.7.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.036865] +quant_tensor:blocks.7.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.8.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.046234] +quant_tensor:blocks.8.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.8.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.040863] +quant_tensor:blocks.8.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.035095] +quant_tensor:blocks.8.mlp.fc.weight shape:[1536, 512] 
bits:6 scale_range:[0.032257,0.048035] +quant_tensor:blocks.8.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.9.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.043640] +quant_tensor:blocks.9.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.9.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.9.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.9.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.039154] +quant_tensor:blocks.9.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +passthrough_tensor:bigram.proj.weight shape:[512, 128] dtype:torch.float16 bytes:131072 +passthrough_tensor:bigram.scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.0.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.0.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.0.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.0.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.0.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.1.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.1.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.1.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.1.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.1.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.10.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.10.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.10.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.10.mlp_scale shape:[512] dtype:torch.float32 
bytes:2048 +passthrough_tensor:blocks.10.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.2.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.2.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.2.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.2.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.2.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.3.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.3.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.3.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.3.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.3.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.4.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.4.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.4.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.4.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.4.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.5.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.5.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.5.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.5.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.5.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.6.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.6.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.6.depth_scale shape:[] dtype:torch.float16 bytes:2 
+passthrough_tensor:blocks.6.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.6.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.7.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.7.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.7.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.7.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.7.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.8.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.8.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.8.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.8.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.8.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.9.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.9.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.9.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.9.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.9.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:skip_weights shape:[5, 512] dtype:torch.float32 bytes:10240 +passthrough_tensor:smear.gate shape:[512] dtype:torch.float16 bytes:1024 +passthrough_tensor:tok_emb.weight shape:[1024, 512] dtype:torch.float16 bytes:1048576 +Serialized model zstd-22: 15350889 bytes (payload:27578744 raw_torch:27638331 payload_ratio:3.83x) +Total submission size zstd-22: 15421922 bytes +Size check PASSED: 15421922 / 16,000,000 (96.4%) +phase:serialize wall_ms:39081 (quant+compress+save) +final_int8_zlib_roundtrip val_loss:1.9838 val_bpb:1.1749 eval_time:2167ms eval_seq_len:2048 
+final_int8_zlib_roundtrip_exact val_loss:1.98380334 val_bpb:1.17492008 +quant_gap: 0.011666 BPB (pre:1.163254 post:1.174920) +phase:postquant_eval wall_ms:2688 +ttt:rank0 short=2393 long=3857 epochs=1 batch=64 +ttt:short_docs time=23776ms tokens=732712 +ttt:batch 5/61 time=1015ms avg_loss=2.0049 +ttt:batch 10/61 time=1927ms avg_loss=1.9940 +ttt:batch 15/61 time=2837ms avg_loss=1.9823 +ttt:batch 20/61 time=4405ms avg_loss=1.9638 +ttt:batch 25/61 time=5978ms avg_loss=1.9583 +ttt:batch 30/61 time=8309ms avg_loss=1.9514 +ttt:batch 35/61 time=10935ms avg_loss=1.9465 +ttt:batch 40/61 time=14177ms avg_loss=1.9433 +ttt:batch 45/61 time=18340ms avg_loss=1.9396 +ttt:batch 50/61 time=23676ms avg_loss=1.9402 +ttt:batch 55/61 time=31304ms avg_loss=1.9340 +ttt:batch 60/61 time=54392ms avg_loss=1.9314 +ttt:long_docs time=62488ms docs=3857 +final_ttt_lora val_loss:1.9466 val_bpb:1.1529 eval_time:112336ms lora_rank:8 chunk_size:256 +final_ttt_lora_exact val_loss:1.94661285 val_bpb:1.15289371 +ttt_gain: 0.022026 BPB gain over int8 (int8:1.174920 ttt:1.152894) +phase:ttt_eval wall_ms:113056 +phase:TOTAL wall_ms:766008 (12.8 min) +phase_breakdown: train:600085ms postprocess:see_above serialize:see_above eval:see_above ttt:see_above diff --git a/records/track_10min_16mb/2026-03-24_PROTEUS_v9/train_seed2024.log b/records/track_10min_16mb/2026-03-24_PROTEUS_v9/train_seed2024.log new file mode 100644 index 000000000..278320c2b --- /dev/null +++ b/records/track_10min_16mb/2026-03-24_PROTEUS_v9/train_seed2024.log @@ -0,0 +1,352 @@ +W0324 01:59:14.131000 12900 torch/distributed/run.py:851] +W0324 01:59:14.131000 12900 torch/distributed/run.py:851] ***************************************** +W0324 01:59:14.131000 12900 torch/distributed/run.py:851] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0324 01:59:14.131000 12900 torch/distributed/run.py:851] ***************************************** +logs/proteus_v9_2024.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/tmp/pgolf-repo/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 val_tokens:62021632 +model_params:26829913 world_size:8 grad_accum_steps:1 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:786432 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:2024 ema_enabled:True ema_decay:0.999 ema_every:10 +ttt_lora_rank:8 ttt_lora_lr:0.01 ttt_chunk_size:256 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:1/20000 train_loss:6.931914 lr_scale:1.0000 muon_mom:0.9200 train_time:199ms step_avg:199.12ms this_step:199.1ms mem:20875MiB swa_n:0 +step:2/20000 train_loss:8.041641 lr_scale:0.9997 muon_mom:0.9200 train_time:268ms step_avg:133.86ms this_step:68.6ms mem:20875MiB swa_n:0 +step:3/20000 train_loss:7.454000 lr_scale:1.0000 muon_mom:0.9201 train_time:351ms step_avg:117.02ms this_step:83.3ms mem:20875MiB swa_n:0 +step:4/20000 train_loss:6.998686 lr_scale:1.0000 muon_mom:0.9201 train_time:434ms step_avg:108.60ms this_step:83.4ms mem:20875MiB swa_n:0 +step:5/20000 train_loss:6.865939 lr_scale:1.0000 muon_mom:0.9202 train_time:518ms step_avg:103.51ms this_step:83.1ms mem:20875MiB swa_n:0 +step:6/20000 train_loss:6.852402 lr_scale:1.0000 muon_mom:0.9202 train_time:601ms step_avg:100.17ms this_step:83.5ms mem:20875MiB swa_n:0 +step:7/20000 train_loss:6.738769 lr_scale:1.0000 muon_mom:0.9203 
train_time:685ms step_avg:97.80ms this_step:83.6ms mem:20875MiB swa_n:0 +step:8/20000 train_loss:6.619420 lr_scale:1.0000 muon_mom:0.9203 train_time:768ms step_avg:96.00ms this_step:83.4ms mem:20875MiB swa_n:0 +step:9/20000 train_loss:6.400832 lr_scale:1.0000 muon_mom:0.9204 train_time:852ms step_avg:94.67ms this_step:84.0ms mem:20875MiB swa_n:0 +step:10/20000 train_loss:6.107140 lr_scale:1.0000 muon_mom:0.9204 train_time:936ms step_avg:93.55ms this_step:83.5ms mem:20875MiB swa_n:0 +step:50/20000 train_loss:3.984745 lr_scale:1.0000 muon_mom:0.9223 train_time:4338ms step_avg:86.75ms this_step:3402.0ms mem:20875MiB swa_n:0 +step:100/20000 train_loss:3.243303 lr_scale:1.0000 muon_mom:0.9246 train_time:8629ms step_avg:86.29ms this_step:4291.2ms mem:20875MiB swa_n:0 +step:150/20000 train_loss:2.935160 lr_scale:1.0000 muon_mom:0.9270 train_time:12997ms step_avg:86.65ms this_step:4368.5ms mem:20875MiB swa_n:0 +step:200/20000 train_loss:2.454417 lr_scale:1.0000 muon_mom:0.9293 train_time:17250ms step_avg:86.25ms this_step:4252.6ms mem:20875MiB swa_n:0 +step:250/20000 train_loss:2.549796 lr_scale:1.0000 muon_mom:0.9316 train_time:21513ms step_avg:86.05ms this_step:4263.2ms mem:20875MiB swa_n:0 +step:300/20000 train_loss:2.615655 lr_scale:1.0000 muon_mom:0.9340 train_time:25836ms step_avg:86.12ms this_step:4322.7ms mem:20875MiB swa_n:0 +step:350/20000 train_loss:2.591411 lr_scale:1.0000 muon_mom:0.9363 train_time:30105ms step_avg:86.01ms this_step:4268.9ms mem:20875MiB swa_n:0 +step:400/20000 train_loss:2.472593 lr_scale:1.0000 muon_mom:0.9386 train_time:34432ms step_avg:86.08ms this_step:4327.5ms mem:20875MiB swa_n:0 +step:450/20000 train_loss:2.428872 lr_scale:1.0000 muon_mom:0.9410 train_time:38711ms step_avg:86.02ms this_step:4278.8ms mem:20875MiB swa_n:0 +step:500/20000 train_loss:2.449510 lr_scale:1.0000 muon_mom:0.9433 train_time:42988ms step_avg:85.98ms this_step:4277.2ms mem:20875MiB swa_n:0 +step:550/20000 train_loss:2.390684 lr_scale:1.0000 muon_mom:0.9456 
train_time:47325ms step_avg:86.05ms this_step:4336.7ms mem:20875MiB swa_n:0 +step:600/20000 train_loss:2.379141 lr_scale:1.0000 muon_mom:0.9480 train_time:51618ms step_avg:86.03ms this_step:4293.2ms mem:20875MiB swa_n:0 +step:650/20000 train_loss:2.374645 lr_scale:1.0000 muon_mom:0.9503 train_time:55969ms step_avg:86.11ms this_step:4351.4ms mem:20875MiB swa_n:0 +step:700/20000 train_loss:2.391678 lr_scale:1.0000 muon_mom:0.9526 train_time:60261ms step_avg:86.09ms this_step:4291.4ms mem:20875MiB swa_n:0 +step:750/20000 train_loss:2.373483 lr_scale:1.0000 muon_mom:0.9550 train_time:64552ms step_avg:86.07ms this_step:4291.1ms mem:20875MiB swa_n:0 +step:800/20000 train_loss:2.286453 lr_scale:1.0000 muon_mom:0.9573 train_time:68902ms step_avg:86.13ms this_step:4350.0ms mem:20875MiB swa_n:0 +step:850/20000 train_loss:2.279479 lr_scale:1.0000 muon_mom:0.9596 train_time:73193ms step_avg:86.11ms this_step:4291.0ms mem:20875MiB swa_n:0 +step:900/20000 train_loss:2.171954 lr_scale:1.0000 muon_mom:0.9620 train_time:77542ms step_avg:86.16ms this_step:4348.7ms mem:20875MiB swa_n:0 +step:950/20000 train_loss:2.262181 lr_scale:1.0000 muon_mom:0.9643 train_time:81839ms step_avg:86.15ms this_step:4297.4ms mem:20875MiB swa_n:0 +step:1000/20000 train_loss:2.310415 lr_scale:1.0000 muon_mom:0.9666 train_time:86127ms step_avg:86.13ms this_step:4288.4ms mem:20875MiB swa_n:0 +step:1050/20000 train_loss:2.270665 lr_scale:1.0000 muon_mom:0.9690 train_time:90474ms step_avg:86.17ms this_step:4346.5ms mem:20875MiB swa_n:0 +step:1100/20000 train_loss:2.376440 lr_scale:1.0000 muon_mom:0.9713 train_time:94767ms step_avg:86.15ms this_step:4293.1ms mem:20875MiB swa_n:0 +step:1150/20000 train_loss:2.285896 lr_scale:1.0000 muon_mom:0.9736 train_time:99108ms step_avg:86.18ms this_step:4340.7ms mem:20875MiB swa_n:0 +step:1200/20000 train_loss:2.393761 lr_scale:1.0000 muon_mom:0.9760 train_time:103397ms step_avg:86.16ms this_step:4288.8ms mem:20875MiB swa_n:0 +step:1250/20000 train_loss:2.296191 
lr_scale:1.0000 muon_mom:0.9783 train_time:107678ms step_avg:86.14ms this_step:4281.6ms mem:20875MiB swa_n:0 +step:1300/20000 train_loss:2.154675 lr_scale:1.0000 muon_mom:0.9806 train_time:112024ms step_avg:86.17ms this_step:4346.0ms mem:20875MiB swa_n:0 +step:1350/20000 train_loss:2.291740 lr_scale:1.0000 muon_mom:0.9830 train_time:116301ms step_avg:86.15ms this_step:4276.5ms mem:20875MiB swa_n:0 +step:1400/20000 train_loss:2.226450 lr_scale:1.0000 muon_mom:0.9853 train_time:120640ms step_avg:86.17ms this_step:4339.7ms mem:20875MiB swa_n:0 +step:1450/20000 train_loss:2.164248 lr_scale:1.0000 muon_mom:0.9876 train_time:124912ms step_avg:86.15ms this_step:4271.7ms mem:20875MiB swa_n:0 +step:1500/20000 train_loss:2.259241 lr_scale:1.0000 muon_mom:0.9900 train_time:129186ms step_avg:86.12ms this_step:4273.6ms mem:20875MiB swa_n:0 +step:1550/20000 train_loss:2.227330 lr_scale:1.0000 muon_mom:0.9900 train_time:133518ms step_avg:86.14ms this_step:4331.9ms mem:20875MiB swa_n:0 +step:1600/20000 train_loss:2.119856 lr_scale:1.0000 muon_mom:0.9900 train_time:137791ms step_avg:86.12ms this_step:4273.2ms mem:20875MiB swa_n:0 +step:1650/20000 train_loss:2.237877 lr_scale:1.0000 muon_mom:0.9900 train_time:142067ms step_avg:86.10ms this_step:4276.2ms mem:20875MiB swa_n:0 +step:1700/20000 train_loss:2.178303 lr_scale:1.0000 muon_mom:0.9900 train_time:146398ms step_avg:86.12ms this_step:4330.9ms mem:20875MiB swa_n:0 +step:1750/20000 train_loss:2.238083 lr_scale:1.0000 muon_mom:0.9900 train_time:150668ms step_avg:86.10ms this_step:4269.8ms mem:20875MiB swa_n:0 +step:1800/20000 train_loss:2.229345 lr_scale:1.0000 muon_mom:0.9900 train_time:154995ms step_avg:86.11ms this_step:4327.4ms mem:20875MiB swa_n:0 +step:1850/20000 train_loss:2.070625 lr_scale:1.0000 muon_mom:0.9900 train_time:159265ms step_avg:86.09ms this_step:4270.5ms mem:20875MiB swa_n:0 +step:1900/20000 train_loss:2.171952 lr_scale:1.0000 muon_mom:0.9900 train_time:163532ms step_avg:86.07ms this_step:4266.1ms mem:20875MiB 
swa_n:0 +step:1950/20000 train_loss:2.065742 lr_scale:1.0000 muon_mom:0.9900 train_time:167859ms step_avg:86.08ms this_step:4327.8ms mem:20875MiB swa_n:0 +step:2000/20000 train_loss:2.107166 lr_scale:1.0000 muon_mom:0.9900 train_time:172124ms step_avg:86.06ms this_step:4264.7ms mem:20875MiB swa_n:0 +step:2050/20000 train_loss:2.148781 lr_scale:1.0000 muon_mom:0.9900 train_time:176447ms step_avg:86.07ms this_step:4322.7ms mem:20875MiB swa_n:0 +step:2100/20000 train_loss:2.076511 lr_scale:1.0000 muon_mom:0.9900 train_time:180713ms step_avg:86.05ms this_step:4266.2ms mem:20875MiB swa_n:0 +step:2150/20000 train_loss:2.181700 lr_scale:1.0000 muon_mom:0.9900 train_time:184978ms step_avg:86.04ms this_step:4264.8ms mem:20875MiB swa_n:0 +step:2200/20000 train_loss:2.233525 lr_scale:1.0000 muon_mom:0.9900 train_time:189298ms step_avg:86.04ms this_step:4320.2ms mem:20875MiB swa_n:0 +step:2250/20000 train_loss:2.210792 lr_scale:1.0000 muon_mom:0.9900 train_time:193563ms step_avg:86.03ms this_step:4265.3ms mem:20875MiB swa_n:0 +step:2300/20000 train_loss:2.148972 lr_scale:1.0000 muon_mom:0.9900 train_time:197886ms step_avg:86.04ms this_step:4322.4ms mem:20875MiB swa_n:0 +step:2350/20000 train_loss:2.206292 lr_scale:1.0000 muon_mom:0.9900 train_time:202152ms step_avg:86.02ms this_step:4266.7ms mem:20875MiB swa_n:0 +step:2400/20000 train_loss:2.109140 lr_scale:1.0000 muon_mom:0.9900 train_time:206417ms step_avg:86.01ms this_step:4264.4ms mem:20875MiB swa_n:0 +step:2450/20000 train_loss:2.118556 lr_scale:1.0000 muon_mom:0.9900 train_time:210730ms step_avg:86.01ms this_step:4313.4ms mem:20875MiB swa_n:0 +step:2500/20000 train_loss:2.212800 lr_scale:1.0000 muon_mom:0.9900 train_time:214981ms step_avg:85.99ms this_step:4250.4ms mem:20875MiB swa_n:0 +step:2550/20000 train_loss:2.239536 lr_scale:1.0000 muon_mom:0.9900 train_time:219290ms step_avg:86.00ms this_step:4309.1ms mem:20875MiB swa_n:0 +step:2600/20000 train_loss:2.142347 lr_scale:1.0000 muon_mom:0.9900 train_time:223546ms 
step_avg:85.98ms this_step:4256.5ms mem:20875MiB swa_n:0 +step:2650/20000 train_loss:2.116741 lr_scale:1.0000 muon_mom:0.9900 train_time:227802ms step_avg:85.96ms this_step:4255.7ms mem:20875MiB swa_n:0 +step:2700/20000 train_loss:2.136790 lr_scale:1.0000 muon_mom:0.9900 train_time:232113ms step_avg:85.97ms this_step:4311.4ms mem:20875MiB swa_n:0 +step:2750/20000 train_loss:2.068666 lr_scale:1.0000 muon_mom:0.9900 train_time:236370ms step_avg:85.95ms this_step:4256.2ms mem:20875MiB swa_n:0 +step:2800/20000 train_loss:2.189669 lr_scale:1.0000 muon_mom:0.9900 train_time:240681ms step_avg:85.96ms this_step:4311.6ms mem:20875MiB swa_n:0 +step:2850/20000 train_loss:2.103628 lr_scale:1.0000 muon_mom:0.9900 train_time:244938ms step_avg:85.94ms this_step:4256.5ms mem:20875MiB swa_n:0 +step:2900/20000 train_loss:2.069944 lr_scale:1.0000 muon_mom:0.9900 train_time:249191ms step_avg:85.93ms this_step:4253.4ms mem:20875MiB swa_n:0 +step:2950/20000 train_loss:2.115621 lr_scale:1.0000 muon_mom:0.9900 train_time:253496ms step_avg:85.93ms this_step:4304.6ms mem:20875MiB swa_n:0 +step:3000/20000 train_loss:2.193537 lr_scale:1.0000 muon_mom:0.9900 train_time:257746ms step_avg:85.92ms this_step:4250.7ms mem:20875MiB swa_n:0 +step:3050/20000 train_loss:2.080265 lr_scale:1.0000 muon_mom:0.9900 train_time:261995ms step_avg:85.90ms this_step:4249.0ms mem:20875MiB swa_n:0 +step:3100/20000 train_loss:2.077717 lr_scale:1.0000 muon_mom:0.9900 train_time:266310ms step_avg:85.91ms this_step:4314.5ms mem:20875MiB swa_n:0 +step:3150/20000 train_loss:2.010258 lr_scale:1.0000 muon_mom:0.9900 train_time:270559ms step_avg:85.89ms this_step:4249.5ms mem:20875MiB swa_n:0 +step:3200/20000 train_loss:2.209013 lr_scale:1.0000 muon_mom:0.9900 train_time:274865ms step_avg:85.90ms this_step:4305.3ms mem:20875MiB swa_n:0 +step:3250/20000 train_loss:2.088322 lr_scale:1.0000 muon_mom:0.9900 train_time:279121ms step_avg:85.88ms this_step:4255.9ms mem:20875MiB swa_n:0 +step:3300/20000 train_loss:2.110989 
lr_scale:1.0000 muon_mom:0.9900 train_time:283371ms step_avg:85.87ms this_step:4250.4ms mem:20875MiB swa_n:0 +step:3350/20000 train_loss:2.136043 lr_scale:1.0000 muon_mom:0.9900 train_time:287679ms step_avg:85.87ms this_step:4308.1ms mem:20875MiB swa_n:0 +step:3400/20000 train_loss:2.072887 lr_scale:1.0000 muon_mom:0.9900 train_time:291931ms step_avg:85.86ms this_step:4252.3ms mem:20875MiB swa_n:0 +step:3450/20000 train_loss:2.156906 lr_scale:1.0000 muon_mom:0.9900 train_time:296239ms step_avg:85.87ms this_step:4308.1ms mem:20875MiB swa_n:0 +step:3500/20000 train_loss:2.221806 lr_scale:1.0000 muon_mom:0.9900 train_time:300491ms step_avg:85.85ms this_step:4251.6ms mem:20875MiB swa_n:0 +step:3550/20000 train_loss:1.968783 lr_scale:1.0000 muon_mom:0.9900 train_time:304744ms step_avg:85.84ms this_step:4252.8ms mem:20875MiB swa_n:0 +step:3600/20000 train_loss:2.136259 lr_scale:1.0000 muon_mom:0.9900 train_time:309054ms step_avg:85.85ms this_step:4309.7ms mem:20875MiB swa_n:0 +step:3650/20000 train_loss:2.024872 lr_scale:1.0000 muon_mom:0.9900 train_time:313299ms step_avg:85.84ms this_step:4245.6ms mem:20875MiB swa_n:0 +step:3700/20000 train_loss:2.130675 lr_scale:1.0000 muon_mom:0.9900 train_time:317609ms step_avg:85.84ms this_step:4310.1ms mem:20875MiB swa_n:0 +step:3750/20000 train_loss:1.962400 lr_scale:1.0000 muon_mom:0.9900 train_time:321855ms step_avg:85.83ms this_step:4245.8ms mem:20875MiB swa_n:0 +step:3800/20000 train_loss:2.118789 lr_scale:1.0000 muon_mom:0.9900 train_time:326106ms step_avg:85.82ms this_step:4251.3ms mem:20875MiB swa_n:0 +step:3850/20000 train_loss:2.133793 lr_scale:1.0000 muon_mom:0.9900 train_time:330415ms step_avg:85.82ms this_step:4308.6ms mem:20875MiB swa_n:0 +step:3900/20000 train_loss:2.122054 lr_scale:1.0000 muon_mom:0.9900 train_time:334666ms step_avg:85.81ms this_step:4251.1ms mem:20875MiB swa_n:0 +step:3950/20000 train_loss:2.218469 lr_scale:1.0000 muon_mom:0.9900 train_time:338959ms step_avg:85.81ms this_step:4293.1ms mem:20875MiB 
swa_n:0 +step:4000/20000 train_loss:2.021121 lr_scale:0.9978 muon_mom:0.9900 train_time:343215ms step_avg:85.80ms this_step:4256.0ms mem:20875MiB swa_n:0 +step:4050/20000 train_loss:2.134112 lr_scale:0.9814 muon_mom:0.9900 train_time:347461ms step_avg:85.79ms this_step:4246.2ms mem:20875MiB swa_n:0 +step:4100/20000 train_loss:2.076350 lr_scale:0.9647 muon_mom:0.9900 train_time:351765ms step_avg:85.80ms this_step:4303.8ms mem:20875MiB swa_n:0 +step:4150/20000 train_loss:2.156129 lr_scale:0.9482 muon_mom:0.9900 train_time:356015ms step_avg:85.79ms this_step:4249.9ms mem:20875MiB swa_n:0 +step:4200/20000 train_loss:2.202910 lr_scale:0.9315 muon_mom:0.9900 train_time:360321ms step_avg:85.79ms this_step:4305.7ms mem:20875MiB swa_n:0 +step:4250/20000 train_loss:2.157139 lr_scale:0.9151 muon_mom:0.9900 train_time:364566ms step_avg:85.78ms this_step:4245.1ms mem:20875MiB swa_n:0 +step:4300/20000 train_loss:2.099376 lr_scale:0.8987 muon_mom:0.9900 train_time:368814ms step_avg:85.77ms this_step:4248.1ms mem:20875MiB swa_n:0 +step:4350/20000 train_loss:2.116194 lr_scale:0.8819 muon_mom:0.9900 train_time:373129ms step_avg:85.78ms this_step:4314.7ms mem:20875MiB swa_n:0 +step:4400/20000 train_loss:2.084675 lr_scale:0.8655 muon_mom:0.9900 train_time:377369ms step_avg:85.77ms this_step:4240.6ms mem:20875MiB swa_n:0 +step:4450/20000 train_loss:2.089081 lr_scale:0.8491 muon_mom:0.9900 train_time:381616ms step_avg:85.76ms this_step:4246.9ms mem:20875MiB swa_n:0 +step:4500/20000 train_loss:2.162761 lr_scale:0.8323 muon_mom:0.9900 train_time:385919ms step_avg:85.76ms this_step:4302.8ms mem:20875MiB swa_n:0 +step:4550/20000 train_loss:2.175592 lr_scale:0.8159 muon_mom:0.9900 train_time:390171ms step_avg:85.75ms this_step:4252.0ms mem:20875MiB swa_n:0 +step:4600/20000 train_loss:1.907365 lr_scale:0.7991 muon_mom:0.9900 train_time:394471ms step_avg:85.75ms this_step:4299.6ms mem:20875MiB swa_n:0 +step:4650/20000 train_loss:2.099530 lr_scale:0.7827 muon_mom:0.9900 train_time:398716ms 
step_avg:85.75ms this_step:4245.1ms mem:20875MiB swa_n:0 +step:4700/20000 train_loss:2.295781 lr_scale:0.7663 muon_mom:0.9900 train_time:402968ms step_avg:85.74ms this_step:4251.9ms mem:20875MiB swa_n:0 +step:4750/20000 train_loss:2.064718 lr_scale:0.7495 muon_mom:0.9900 train_time:407276ms step_avg:85.74ms this_step:4308.9ms mem:20875MiB swa_n:0 +step:4800/20000 train_loss:2.507038 lr_scale:0.7330 muon_mom:0.9900 train_time:411524ms step_avg:85.73ms this_step:4247.5ms mem:20875MiB swa_n:0 +step:4850/20000 train_loss:2.154274 lr_scale:0.7163 muon_mom:0.9900 train_time:415825ms step_avg:85.74ms this_step:4300.6ms mem:20875MiB swa_n:0 +step:4900/20000 train_loss:2.101693 lr_scale:0.6998 muon_mom:0.9900 train_time:420073ms step_avg:85.73ms this_step:4248.1ms mem:20875MiB swa_n:0 +step:4950/20000 train_loss:2.152763 lr_scale:0.6834 muon_mom:0.9900 train_time:424322ms step_avg:85.72ms this_step:4249.4ms mem:20875MiB swa_n:0 +step:5000/20000 train_loss:2.152505 lr_scale:0.6666 muon_mom:0.9900 train_time:428632ms step_avg:85.73ms this_step:4309.6ms mem:20875MiB swa_n:0 +step:5050/20000 train_loss:2.136249 lr_scale:0.6501 muon_mom:0.9900 train_time:432874ms step_avg:85.72ms this_step:4242.4ms mem:20875MiB swa_n:0 +step:5100/20000 train_loss:2.164412 lr_scale:0.6333 muon_mom:0.9900 train_time:437185ms step_avg:85.72ms this_step:4311.1ms mem:20875MiB swa_n:0 +step:5150/20000 train_loss:2.073886 lr_scale:0.6169 muon_mom:0.9900 train_time:441430ms step_avg:85.71ms this_step:4244.7ms mem:20875MiB swa_n:0 +step:5200/20000 train_loss:2.090128 lr_scale:0.6004 muon_mom:0.9900 train_time:445674ms step_avg:85.71ms this_step:4244.4ms mem:20875MiB swa_n:0 +step:5250/20000 train_loss:2.105075 lr_scale:0.5837 muon_mom:0.9900 train_time:449983ms step_avg:85.71ms this_step:4308.5ms mem:20875MiB swa_n:0 +step:5300/20000 train_loss:2.055272 lr_scale:0.5672 muon_mom:0.9900 train_time:454227ms step_avg:85.70ms this_step:4244.3ms mem:20875MiB swa_n:0 +step:5350/20000 train_loss:1.974428 
lr_scale:0.5505 muon_mom:0.9900 train_time:458525ms step_avg:85.71ms this_step:4298.2ms mem:20875MiB swa_n:0 +step:5400/20000 train_loss:2.090054 lr_scale:0.5340 muon_mom:0.9900 train_time:462779ms step_avg:85.70ms this_step:4254.3ms mem:20875MiB swa_n:0 +step:5450/20000 train_loss:2.115666 lr_scale:0.5175 muon_mom:0.9900 train_time:467026ms step_avg:85.69ms this_step:4246.8ms mem:20875MiB swa_n:0 +step:5500/20000 train_loss:2.059743 lr_scale:0.5007 muon_mom:0.9900 train_time:471331ms step_avg:85.70ms this_step:4305.2ms mem:20875MiB swa_n:0 +step:5550/20000 train_loss:2.054671 lr_scale:0.4842 muon_mom:0.9900 train_time:475584ms step_avg:85.69ms this_step:4252.5ms mem:20875MiB swa_n:0 +step:5600/20000 train_loss:2.015575 lr_scale:0.4674 muon_mom:0.9900 train_time:479890ms step_avg:85.69ms this_step:4305.9ms mem:20875MiB swa_n:0 +step:5650/20000 train_loss:2.092720 lr_scale:0.4510 muon_mom:0.9900 train_time:484140ms step_avg:85.69ms this_step:4249.9ms mem:20875MiB swa_n:0 +step:5700/20000 train_loss:2.057725 lr_scale:0.4344 muon_mom:0.9900 train_time:488392ms step_avg:85.68ms this_step:4252.6ms mem:20875MiB swa_n:0 +step:5750/20000 train_loss:2.137746 lr_scale:0.4177 muon_mom:0.9900 train_time:492701ms step_avg:85.69ms this_step:4308.4ms mem:20875MiB swa_n:0 +step:5800/20000 train_loss:2.049301 lr_scale:0.4012 muon_mom:0.9900 train_time:496945ms step_avg:85.68ms this_step:4243.9ms mem:20875MiB swa_n:0 +step:5850/20000 train_loss:2.174310 lr_scale:0.3847 muon_mom:0.9900 train_time:501250ms step_avg:85.68ms this_step:4305.4ms mem:20875MiB swa_n:0 +step:5900/20000 train_loss:1.952830 lr_scale:0.3679 muon_mom:0.9900 train_time:505494ms step_avg:85.68ms this_step:4244.1ms mem:20875MiB swa_n:0 +step:5950/20000 train_loss:1.999254 lr_scale:0.3514 muon_mom:0.9900 train_time:509741ms step_avg:85.67ms this_step:4247.1ms mem:20875MiB swa_n:0 +step:6000/20000 train_loss:1.996382 lr_scale:0.3347 muon_mom:0.9900 train_time:514050ms step_avg:85.67ms this_step:4308.5ms mem:20875MiB 
swa_n:0 +step:6050/20000 train_loss:2.014052 lr_scale:0.3181 muon_mom:0.9900 train_time:518300ms step_avg:85.67ms this_step:4249.9ms mem:20875MiB swa_n:0 +step:6100/20000 train_loss:1.972045 lr_scale:0.3016 muon_mom:0.9900 train_time:522551ms step_avg:85.66ms this_step:4251.5ms mem:20875MiB swa_n:0 +step:6150/20000 train_loss:2.071825 lr_scale:0.2848 muon_mom:0.9900 train_time:526859ms step_avg:85.67ms this_step:4307.6ms mem:20875MiB swa_n:0 +step:6200/20000 train_loss:2.006609 lr_scale:0.2683 muon_mom:0.9900 train_time:531113ms step_avg:85.66ms this_step:4253.7ms mem:20875MiB swa_n:0 +step:6250/20000 train_loss:2.118243 lr_scale:0.2516 muon_mom:0.9900 train_time:535408ms step_avg:85.67ms this_step:4295.7ms mem:20875MiB swa_n:0 +step:6300/20000 train_loss:1.988789 lr_scale:0.2351 muon_mom:0.9900 train_time:539656ms step_avg:85.66ms this_step:4248.1ms mem:20875MiB swa_n:0 +step:6350/20000 train_loss:2.080419 lr_scale:0.2186 muon_mom:0.9900 train_time:543906ms step_avg:85.65ms this_step:4249.3ms mem:20875MiB swa_n:0 +step:6400/20000 train_loss:2.045061 lr_scale:0.2018 muon_mom:0.9900 train_time:548211ms step_avg:85.66ms this_step:4305.7ms mem:20875MiB swa_n:0 +step:6450/20000 train_loss:2.118379 lr_scale:0.1853 muon_mom:0.9900 train_time:552461ms step_avg:85.65ms this_step:4249.4ms mem:20875MiB swa_n:0 +swa:start step=6450 +step:6500/20000 train_loss:2.121874 lr_scale:0.1681 muon_mom:0.9900 train_time:556856ms step_avg:85.67ms this_step:4395.0ms mem:20875MiB swa_n:1 +step:6550/20000 train_loss:2.086864 lr_scale:0.1515 muon_mom:0.9900 train_time:561133ms step_avg:85.67ms this_step:4277.3ms mem:20875MiB swa_n:2 +step:6600/20000 train_loss:1.900066 lr_scale:0.1349 muon_mom:0.9900 train_time:565408ms step_avg:85.67ms this_step:4274.6ms mem:20875MiB swa_n:3 +step:6650/20000 train_loss:1.855953 lr_scale:0.1180 muon_mom:0.9900 train_time:569748ms step_avg:85.68ms this_step:4340.5ms mem:20875MiB swa_n:4 +step:6700/20000 train_loss:1.987305 lr_scale:0.1013 muon_mom:0.9900 
train_time:574020ms step_avg:85.67ms this_step:4272.4ms mem:20875MiB swa_n:5 +step:6750/20000 train_loss:2.131429 lr_scale:0.0845 muon_mom:0.9900 train_time:578343ms step_avg:85.68ms this_step:4322.7ms mem:20875MiB swa_n:6 +step:6800/20000 train_loss:2.058727 lr_scale:0.0679 muon_mom:0.9900 train_time:582623ms step_avg:85.68ms this_step:4279.6ms mem:20875MiB swa_n:7 +step:6850/20000 train_loss:1.874189 lr_scale:0.0512 muon_mom:0.9900 train_time:586918ms step_avg:85.68ms this_step:4294.8ms mem:20875MiB swa_n:8 +step:6900/20000 train_loss:1.873953 lr_scale:0.0343 muon_mom:0.9900 train_time:591253ms step_avg:85.69ms this_step:4335.0ms mem:20875MiB swa_n:9 +step:6950/20000 train_loss:1.997867 lr_scale:0.0176 muon_mom:0.9900 train_time:595538ms step_avg:85.69ms this_step:4285.5ms mem:20875MiB swa_n:10 +step:7000/20000 train_loss:1.846325 lr_scale:0.0007 muon_mom:0.9900 train_time:599875ms step_avg:85.70ms this_step:4336.7ms mem:20875MiB swa_n:11 +step:7001/20000 val_loss:1.9761 val_bpb:1.1704 train_time:600007ms step_avg:85.70ms +stopping_early: wallclock_cap train_time:600007ms step:7001/20000 +peak memory allocated: 20875 MiB reserved: 21082 MiB +phase:train wall_ms:610872 steps:7001 step_avg:85.70ms +swa:applying averaged 12 checkpoints +pruning: zeroed 1,336,252 weights (5.0%) below 0.007368 +phase:postprocess wall_ms:145 (swa+ema+pruning) +pre_quant_eval val_loss:1.9645 val_bpb:1.1635 eval_time:16443ms +pre_quant_eval_exact val_loss:1.96450683 val_bpb:1.16349160 +Serialized model: 105792597 bytes +Code size: 71033 bytes +Total submission size: 105863630 bytes +quant_tensor:bigram.embed.weight shape:[2048, 128] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.0.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.053558] +quant_tensor:blocks.0.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.035767] +quant_tensor:blocks.0.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.032257] 
+quant_tensor:blocks.0.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.0.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.043854] +quant_tensor:blocks.0.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.1.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.094421] +quant_tensor:blocks.1.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.040527] +quant_tensor:blocks.1.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.1.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.033813] +quant_tensor:blocks.1.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.035828] +quant_tensor:blocks.1.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.083496] +quant_tensor:blocks.10.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.047607] +quant_tensor:blocks.10.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.035553] +quant_tensor:blocks.10.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.039612] +quant_tensor:blocks.10.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032959] +quant_tensor:blocks.10.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.052399] +quant_tensor:blocks.10.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.076050] +quant_tensor:blocks.2.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.042175] +quant_tensor:blocks.2.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.2.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.032593] +quant_tensor:blocks.2.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.2.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.147461] +quant_tensor:blocks.2.mlp.proj.weight shape:[512, 1536] bits:6 
scale_range:[0.032257,0.063416] +quant_tensor:blocks.3.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.059479] +quant_tensor:blocks.3.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.033051] +quant_tensor:blocks.3.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.034149] +quant_tensor:blocks.3.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.3.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.059937] +quant_tensor:blocks.3.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.4.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.036652] +quant_tensor:blocks.4.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.4.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.032562] +quant_tensor:blocks.4.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.4.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.034821] +quant_tensor:blocks.4.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.5.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.045410] +quant_tensor:blocks.5.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.5.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.5.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.5.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.035583] +quant_tensor:blocks.5.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.6.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.033264] +quant_tensor:blocks.6.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.6.attn.c_v.weight shape:[256, 512] bits:6 
scale_range:[0.032257,0.034607] +quant_tensor:blocks.6.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.6.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.040161] +quant_tensor:blocks.6.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.7.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.035034] +quant_tensor:blocks.7.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032562] +quant_tensor:blocks.7.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.036011] +quant_tensor:blocks.7.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.037323] +quant_tensor:blocks.7.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.050598] +quant_tensor:blocks.7.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.8.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.054108] +quant_tensor:blocks.8.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.033447] +quant_tensor:blocks.8.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.038116] +quant_tensor:blocks.8.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.8.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.040527] +quant_tensor:blocks.8.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.9.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.046631] +quant_tensor:blocks.9.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.034149] +quant_tensor:blocks.9.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.042572] +quant_tensor:blocks.9.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.9.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.036438] +quant_tensor:blocks.9.mlp.proj.weight shape:[512, 1536] bits:6 
scale_range:[0.032257,0.032257] +passthrough_tensor:bigram.proj.weight shape:[512, 128] dtype:torch.float16 bytes:131072 +passthrough_tensor:bigram.scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.0.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.0.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.0.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.0.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.0.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.1.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.1.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.1.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.1.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.1.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.10.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.10.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.10.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.10.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.10.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.2.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.2.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.2.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.2.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.2.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.3.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.3.attn_scale shape:[512] dtype:torch.float32 bytes:2048 
+passthrough_tensor:blocks.3.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.3.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.3.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.4.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.4.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.4.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.4.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.4.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.5.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.5.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.5.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.5.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.5.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.6.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.6.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.6.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.6.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.6.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.7.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.7.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.7.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.7.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.7.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.8.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.8.attn_scale 
shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.8.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.8.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.8.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.9.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.9.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.9.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.9.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.9.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:skip_weights shape:[5, 512] dtype:torch.float32 bytes:10240 +passthrough_tensor:smear.gate shape:[512] dtype:torch.float16 bytes:1024 +passthrough_tensor:tok_emb.weight shape:[1024, 512] dtype:torch.float16 bytes:1048576 +Serialized model zstd-22: 15328866 bytes (payload:27578744 raw_torch:27638331 payload_ratio:3.83x) +Total submission size zstd-22: 15399899 bytes +Size check PASSED: 15399899 / 16,000,000 (96.2%) +phase:serialize wall_ms:40267 (quant+compress+save) +final_int8_zlib_roundtrip val_loss:1.9875 val_bpb:1.1771 eval_time:2168ms eval_seq_len:2048 +final_int8_zlib_roundtrip_exact val_loss:1.98745774 val_bpb:1.17708442 +quant_gap: 0.013593 BPB (pre:1.163492 post:1.177084) +phase:postquant_eval wall_ms:2330 +ttt:rank0 short=2393 long=3857 epochs=1 batch=64 +ttt:short_docs time=23775ms tokens=732712 +ttt:batch 5/61 time=977ms avg_loss=2.0046 +ttt:batch 10/61 time=1886ms avg_loss=1.9937 +ttt:batch 15/61 time=2795ms avg_loss=1.9822 +ttt:batch 20/61 time=4358ms avg_loss=1.9633 +ttt:batch 25/61 time=5925ms avg_loss=1.9575 +ttt:batch 30/61 time=8255ms avg_loss=1.9506 +ttt:batch 35/61 time=10879ms avg_loss=1.9453 +ttt:batch 40/61 time=14114ms avg_loss=1.9421 +ttt:batch 45/61 time=18273ms avg_loss=1.9385 +ttt:batch 50/61 time=23596ms avg_loss=1.9390 +ttt:batch 
55/61 time=31158ms avg_loss=1.9328 +ttt:batch 60/61 time=54061ms avg_loss=1.9300 +ttt:long_docs time=62123ms docs=3857 +final_ttt_lora val_loss:1.9454 val_bpb:1.1522 eval_time:111848ms lora_rank:8 chunk_size:256 +final_ttt_lora_exact val_loss:1.94543436 val_bpb:1.15219574 +ttt_gain: 0.024889 BPB gain over int8 (int8:1.177084 ttt:1.152196) +phase:ttt_eval wall_ms:112573 +phase:TOTAL wall_ms:766188 (12.8 min) +phase_breakdown: train:600007ms postprocess:see_above serialize:see_above eval:see_above ttt:see_above diff --git a/records/track_10min_16mb/2026-03-24_PROTEUS_v9/train_seed42.log b/records/track_10min_16mb/2026-03-24_PROTEUS_v9/train_seed42.log new file mode 100644 index 000000000..717e5de2d --- /dev/null +++ b/records/track_10min_16mb/2026-03-24_PROTEUS_v9/train_seed42.log @@ -0,0 +1,351 @@ +W0324 01:30:30.630000 1901 torch/distributed/run.py:851] +W0324 01:30:30.630000 1901 torch/distributed/run.py:851] ***************************************** +W0324 01:30:30.630000 1901 torch/distributed/run.py:851] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0324 01:30:30.630000 1901 torch/distributed/run.py:851] ***************************************** +logs/proteus_v9_42.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/tmp/pgolf-repo/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 val_tokens:62021632 +model_params:26829913 world_size:8 grad_accum_steps:1 +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.03 head_lr:0.0 matrix_lr:0.02 scalar_lr:0.02 +train_batch_tokens:786432 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 ema_enabled:True ema_decay:0.999 ema_every:10 +ttt_lora_rank:8 ttt_lora_lr:0.01 ttt_chunk_size:256 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:1/20000 train_loss:6.932050 lr_scale:1.0000 muon_mom:0.9200 train_time:209ms step_avg:208.79ms this_step:208.8ms mem:20877MiB swa_n:0 +step:2/20000 train_loss:8.088524 lr_scale:0.8313 muon_mom:0.9200 train_time:309ms step_avg:154.40ms this_step:100.0ms mem:20877MiB swa_n:0 +step:3/20000 train_loss:7.506648 lr_scale:1.0000 muon_mom:0.9201 train_time:392ms step_avg:130.82ms this_step:83.7ms mem:20877MiB swa_n:0 +step:4/20000 train_loss:6.846169 lr_scale:1.0000 muon_mom:0.9201 train_time:476ms step_avg:119.09ms this_step:83.9ms mem:20877MiB swa_n:0 +step:5/20000 train_loss:6.764615 lr_scale:1.0000 muon_mom:0.9202 train_time:561ms step_avg:112.18ms this_step:84.5ms mem:20877MiB swa_n:0 +step:6/20000 train_loss:6.878400 lr_scale:1.0000 muon_mom:0.9202 train_time:645ms step_avg:107.43ms this_step:83.7ms mem:20877MiB swa_n:0 +step:7/20000 train_loss:6.722106 lr_scale:1.0000 muon_mom:0.9203 
train_time:728ms step_avg:104.06ms this_step:83.8ms mem:20877MiB swa_n:0 +step:8/20000 train_loss:6.636737 lr_scale:1.0000 muon_mom:0.9203 train_time:813ms step_avg:101.59ms this_step:84.3ms mem:20877MiB swa_n:0 +step:9/20000 train_loss:6.390717 lr_scale:1.0000 muon_mom:0.9204 train_time:897ms step_avg:99.69ms this_step:84.5ms mem:20877MiB swa_n:0 +step:10/20000 train_loss:6.135527 lr_scale:1.0000 muon_mom:0.9204 train_time:982ms step_avg:98.17ms this_step:84.5ms mem:20877MiB swa_n:0 +step:50/20000 train_loss:3.998825 lr_scale:1.0000 muon_mom:0.9223 train_time:4391ms step_avg:87.83ms this_step:3409.7ms mem:20877MiB swa_n:0 +step:100/20000 train_loss:3.263013 lr_scale:1.0000 muon_mom:0.9246 train_time:9066ms step_avg:90.66ms this_step:4674.3ms mem:20877MiB swa_n:0 +step:150/20000 train_loss:2.950302 lr_scale:1.0000 muon_mom:0.9270 train_time:13814ms step_avg:92.10ms this_step:4748.6ms mem:20877MiB swa_n:0 +step:200/20000 train_loss:2.474711 lr_scale:1.0000 muon_mom:0.9293 train_time:18109ms step_avg:90.54ms this_step:4294.3ms mem:20877MiB swa_n:0 +step:250/20000 train_loss:2.554001 lr_scale:1.0000 muon_mom:0.9316 train_time:22392ms step_avg:89.57ms this_step:4283.9ms mem:20877MiB swa_n:0 +step:300/20000 train_loss:2.625889 lr_scale:1.0000 muon_mom:0.9340 train_time:26721ms step_avg:89.07ms this_step:4328.8ms mem:20877MiB swa_n:0 +step:350/20000 train_loss:2.590093 lr_scale:1.0000 muon_mom:0.9363 train_time:31014ms step_avg:88.61ms this_step:4293.2ms mem:20877MiB swa_n:0 +step:400/20000 train_loss:2.482168 lr_scale:1.0000 muon_mom:0.9386 train_time:35355ms step_avg:88.39ms this_step:4340.8ms mem:20877MiB swa_n:0 +step:450/20000 train_loss:2.428273 lr_scale:1.0000 muon_mom:0.9410 train_time:39639ms step_avg:88.09ms this_step:4283.8ms mem:20877MiB swa_n:0 +step:500/20000 train_loss:2.450206 lr_scale:1.0000 muon_mom:0.9433 train_time:43927ms step_avg:87.85ms this_step:4288.2ms mem:20877MiB swa_n:0 +step:550/20000 train_loss:2.394403 lr_scale:1.0000 muon_mom:0.9456 
train_time:48270ms step_avg:87.76ms this_step:4342.4ms mem:20877MiB swa_n:0 +step:600/20000 train_loss:2.377735 lr_scale:1.0000 muon_mom:0.9480 train_time:52566ms step_avg:87.61ms this_step:4296.4ms mem:20877MiB swa_n:0 +step:650/20000 train_loss:2.380548 lr_scale:1.0000 muon_mom:0.9503 train_time:56920ms step_avg:87.57ms this_step:4354.2ms mem:20877MiB swa_n:0 +step:700/20000 train_loss:2.393714 lr_scale:1.0000 muon_mom:0.9526 train_time:61214ms step_avg:87.45ms this_step:4294.1ms mem:20877MiB swa_n:0 +step:750/20000 train_loss:2.374254 lr_scale:1.0000 muon_mom:0.9550 train_time:65517ms step_avg:87.36ms this_step:4303.1ms mem:20877MiB swa_n:0 +step:800/20000 train_loss:2.288073 lr_scale:1.0000 muon_mom:0.9573 train_time:69874ms step_avg:87.34ms this_step:4356.9ms mem:20877MiB swa_n:0 +step:850/20000 train_loss:2.280262 lr_scale:1.0000 muon_mom:0.9596 train_time:74181ms step_avg:87.27ms this_step:4306.4ms mem:20877MiB swa_n:0 +step:900/20000 train_loss:2.178795 lr_scale:1.0000 muon_mom:0.9620 train_time:78529ms step_avg:87.25ms this_step:4347.9ms mem:20877MiB swa_n:0 +step:950/20000 train_loss:2.260362 lr_scale:1.0000 muon_mom:0.9643 train_time:82827ms step_avg:87.19ms this_step:4298.7ms mem:20877MiB swa_n:0 +step:1000/20000 train_loss:2.312956 lr_scale:1.0000 muon_mom:0.9666 train_time:87122ms step_avg:87.12ms this_step:4294.2ms mem:20877MiB swa_n:0 +step:1050/20000 train_loss:2.273446 lr_scale:1.0000 muon_mom:0.9690 train_time:91474ms step_avg:87.12ms this_step:4352.2ms mem:20877MiB swa_n:0 +step:1100/20000 train_loss:2.372427 lr_scale:1.0000 muon_mom:0.9713 train_time:96158ms step_avg:87.42ms this_step:4684.3ms mem:20877MiB swa_n:0 +step:1150/20000 train_loss:2.289131 lr_scale:1.0000 muon_mom:0.9736 train_time:100500ms step_avg:87.39ms this_step:4341.9ms mem:20877MiB swa_n:0 +step:1200/20000 train_loss:2.397977 lr_scale:1.0000 muon_mom:0.9760 train_time:104788ms step_avg:87.32ms this_step:4287.5ms mem:20877MiB swa_n:0 +step:1250/20000 train_loss:2.295452 
lr_scale:1.0000 muon_mom:0.9783 train_time:109077ms step_avg:87.26ms this_step:4289.6ms mem:20877MiB swa_n:0 +step:1300/20000 train_loss:2.151206 lr_scale:1.0000 muon_mom:0.9806 train_time:113424ms step_avg:87.25ms this_step:4346.9ms mem:20877MiB swa_n:0 +step:1350/20000 train_loss:2.286731 lr_scale:1.0000 muon_mom:0.9830 train_time:117716ms step_avg:87.20ms this_step:4291.7ms mem:20877MiB swa_n:0 +step:1400/20000 train_loss:2.231361 lr_scale:1.0000 muon_mom:0.9853 train_time:122063ms step_avg:87.19ms this_step:4347.5ms mem:20877MiB swa_n:0 +step:1450/20000 train_loss:2.168387 lr_scale:1.0000 muon_mom:0.9876 train_time:126351ms step_avg:87.14ms this_step:4287.4ms mem:20877MiB swa_n:0 +step:1500/20000 train_loss:2.259269 lr_scale:1.0000 muon_mom:0.9900 train_time:130643ms step_avg:87.10ms this_step:4292.1ms mem:20877MiB swa_n:0 +step:1550/20000 train_loss:2.223686 lr_scale:1.0000 muon_mom:0.9900 train_time:134981ms step_avg:87.08ms this_step:4338.1ms mem:20877MiB swa_n:0 +step:1600/20000 train_loss:2.122992 lr_scale:1.0000 muon_mom:0.9900 train_time:139265ms step_avg:87.04ms this_step:4284.0ms mem:20877MiB swa_n:0 +step:1650/20000 train_loss:2.238400 lr_scale:1.0000 muon_mom:0.9900 train_time:143550ms step_avg:87.00ms this_step:4284.8ms mem:20877MiB swa_n:0 +step:1700/20000 train_loss:2.178571 lr_scale:1.0000 muon_mom:0.9900 train_time:147892ms step_avg:87.00ms this_step:4342.2ms mem:20877MiB swa_n:0 +step:1750/20000 train_loss:2.236989 lr_scale:1.0000 muon_mom:0.9900 train_time:152180ms step_avg:86.96ms this_step:4288.3ms mem:20877MiB swa_n:0 +step:1800/20000 train_loss:2.228819 lr_scale:1.0000 muon_mom:0.9900 train_time:156517ms step_avg:86.95ms this_step:4337.2ms mem:20877MiB swa_n:0 +step:1850/20000 train_loss:2.072316 lr_scale:1.0000 muon_mom:0.9900 train_time:160805ms step_avg:86.92ms this_step:4288.0ms mem:20877MiB swa_n:0 +step:1900/20000 train_loss:2.176599 lr_scale:1.0000 muon_mom:0.9900 train_time:165092ms step_avg:86.89ms this_step:4286.3ms mem:20877MiB 
swa_n:0 +step:1950/20000 train_loss:2.065437 lr_scale:1.0000 muon_mom:0.9900 train_time:169427ms step_avg:86.89ms this_step:4335.4ms mem:20877MiB swa_n:0 +step:2000/20000 train_loss:2.112659 lr_scale:1.0000 muon_mom:0.9900 train_time:173722ms step_avg:86.86ms this_step:4295.0ms mem:20877MiB swa_n:0 +step:2050/20000 train_loss:2.151097 lr_scale:1.0000 muon_mom:0.9900 train_time:178068ms step_avg:86.86ms this_step:4346.2ms mem:20877MiB swa_n:0 +step:2100/20000 train_loss:2.076366 lr_scale:1.0000 muon_mom:0.9900 train_time:182349ms step_avg:86.83ms this_step:4280.9ms mem:20877MiB swa_n:0 +step:2150/20000 train_loss:2.183758 lr_scale:1.0000 muon_mom:0.9900 train_time:186634ms step_avg:86.81ms this_step:4284.8ms mem:20877MiB swa_n:0 +step:2200/20000 train_loss:2.228511 lr_scale:1.0000 muon_mom:0.9900 train_time:190971ms step_avg:86.80ms this_step:4336.7ms mem:20877MiB swa_n:0 +step:2250/20000 train_loss:2.217029 lr_scale:1.0000 muon_mom:0.9900 train_time:195244ms step_avg:86.78ms this_step:4273.7ms mem:20877MiB swa_n:0 +step:2300/20000 train_loss:2.146494 lr_scale:1.0000 muon_mom:0.9900 train_time:199568ms step_avg:86.77ms this_step:4323.9ms mem:20877MiB swa_n:0 +step:2350/20000 train_loss:2.208739 lr_scale:1.0000 muon_mom:0.9900 train_time:203847ms step_avg:86.74ms this_step:4278.5ms mem:20877MiB swa_n:0 +step:2400/20000 train_loss:2.111727 lr_scale:1.0000 muon_mom:0.9900 train_time:208134ms step_avg:86.72ms this_step:4287.0ms mem:20877MiB swa_n:0 +step:2450/20000 train_loss:2.119874 lr_scale:1.0000 muon_mom:0.9900 train_time:212466ms step_avg:86.72ms this_step:4331.8ms mem:20877MiB swa_n:0 +step:2500/20000 train_loss:2.210375 lr_scale:1.0000 muon_mom:0.9900 train_time:216739ms step_avg:86.70ms this_step:4273.3ms mem:20877MiB swa_n:0 +step:2550/20000 train_loss:2.238856 lr_scale:1.0000 muon_mom:0.9900 train_time:221063ms step_avg:86.69ms this_step:4323.7ms mem:20877MiB swa_n:0 +step:2600/20000 train_loss:2.142381 lr_scale:1.0000 muon_mom:0.9900 train_time:225352ms 
step_avg:86.67ms this_step:4289.2ms mem:20877MiB swa_n:0 +step:2650/20000 train_loss:2.119841 lr_scale:1.0000 muon_mom:0.9900 train_time:229640ms step_avg:86.66ms this_step:4288.0ms mem:20877MiB swa_n:0 +step:2700/20000 train_loss:2.135784 lr_scale:1.0000 muon_mom:0.9900 train_time:233975ms step_avg:86.66ms this_step:4334.8ms mem:20877MiB swa_n:0 +step:2750/20000 train_loss:2.072520 lr_scale:1.0000 muon_mom:0.9900 train_time:238251ms step_avg:86.64ms this_step:4276.8ms mem:20877MiB swa_n:0 +step:2800/20000 train_loss:2.187783 lr_scale:1.0000 muon_mom:0.9900 train_time:242595ms step_avg:86.64ms this_step:4343.9ms mem:20877MiB swa_n:0 +step:2850/20000 train_loss:2.103125 lr_scale:1.0000 muon_mom:0.9900 train_time:246878ms step_avg:86.62ms this_step:4282.7ms mem:20877MiB swa_n:0 +step:2900/20000 train_loss:2.067446 lr_scale:1.0000 muon_mom:0.9900 train_time:251164ms step_avg:86.61ms this_step:4286.3ms mem:20877MiB swa_n:0 +step:2950/20000 train_loss:2.116747 lr_scale:1.0000 muon_mom:0.9900 train_time:255513ms step_avg:86.61ms this_step:4348.8ms mem:20877MiB swa_n:0 +step:3000/20000 train_loss:2.195950 lr_scale:1.0000 muon_mom:0.9900 train_time:259796ms step_avg:86.60ms this_step:4283.1ms mem:20877MiB swa_n:0 +step:3050/20000 train_loss:2.077259 lr_scale:1.0000 muon_mom:0.9900 train_time:264073ms step_avg:86.58ms this_step:4276.8ms mem:20877MiB swa_n:0 +step:3100/20000 train_loss:2.082485 lr_scale:1.0000 muon_mom:0.9900 train_time:268419ms step_avg:86.59ms this_step:4346.5ms mem:20877MiB swa_n:0 +step:3150/20000 train_loss:2.011695 lr_scale:1.0000 muon_mom:0.9900 train_time:272703ms step_avg:86.57ms this_step:4283.2ms mem:20877MiB swa_n:0 +step:3200/20000 train_loss:2.210581 lr_scale:1.0000 muon_mom:0.9900 train_time:277025ms step_avg:86.57ms this_step:4322.8ms mem:20877MiB swa_n:0 +step:3250/20000 train_loss:2.087856 lr_scale:1.0000 muon_mom:0.9900 train_time:281301ms step_avg:86.55ms this_step:4276.0ms mem:20877MiB swa_n:0 +step:3300/20000 train_loss:2.114615 
lr_scale:1.0000 muon_mom:0.9900 train_time:285576ms step_avg:86.54ms this_step:4275.1ms mem:20877MiB swa_n:0 +step:3350/20000 train_loss:2.135349 lr_scale:1.0000 muon_mom:0.9900 train_time:289910ms step_avg:86.54ms this_step:4333.7ms mem:20877MiB swa_n:0 +step:3400/20000 train_loss:2.070315 lr_scale:1.0000 muon_mom:0.9900 train_time:294197ms step_avg:86.53ms this_step:4286.9ms mem:20877MiB swa_n:0 +step:3450/20000 train_loss:2.154049 lr_scale:1.0000 muon_mom:0.9900 train_time:298545ms step_avg:86.53ms this_step:4347.5ms mem:20877MiB swa_n:0 +step:3500/20000 train_loss:2.221444 lr_scale:1.0000 muon_mom:0.9900 train_time:302817ms step_avg:86.52ms this_step:4272.4ms mem:20877MiB swa_n:0 +step:3550/20000 train_loss:1.967561 lr_scale:1.0000 muon_mom:0.9900 train_time:307089ms step_avg:86.50ms this_step:4272.0ms mem:20877MiB swa_n:0 +step:3600/20000 train_loss:2.136210 lr_scale:1.0000 muon_mom:0.9900 train_time:311409ms step_avg:86.50ms this_step:4320.4ms mem:20877MiB swa_n:0 +step:3650/20000 train_loss:2.025363 lr_scale:1.0000 muon_mom:0.9900 train_time:315670ms step_avg:86.48ms this_step:4260.7ms mem:20877MiB swa_n:0 +step:3700/20000 train_loss:2.132116 lr_scale:1.0000 muon_mom:0.9900 train_time:319988ms step_avg:86.48ms this_step:4318.1ms mem:20877MiB swa_n:0 +step:3750/20000 train_loss:1.963250 lr_scale:1.0000 muon_mom:0.9900 train_time:324253ms step_avg:86.47ms this_step:4264.8ms mem:20877MiB swa_n:0 +step:3800/20000 train_loss:2.117608 lr_scale:1.0000 muon_mom:0.9900 train_time:328519ms step_avg:86.45ms this_step:4266.2ms mem:20877MiB swa_n:0 +step:3850/20000 train_loss:2.133176 lr_scale:1.0000 muon_mom:0.9900 train_time:332839ms step_avg:86.45ms this_step:4319.8ms mem:20877MiB swa_n:0 +step:3900/20000 train_loss:2.123193 lr_scale:1.0000 muon_mom:0.9900 train_time:337111ms step_avg:86.44ms this_step:4271.5ms mem:20877MiB swa_n:0 +step:3950/20000 train_loss:2.219206 lr_scale:0.9973 muon_mom:0.9900 train_time:341439ms step_avg:86.44ms this_step:4328.8ms mem:20877MiB 
swa_n:0 +step:4000/20000 train_loss:2.022180 lr_scale:0.9808 muon_mom:0.9900 train_time:345727ms step_avg:86.43ms this_step:4288.1ms mem:20877MiB swa_n:0 +step:4050/20000 train_loss:2.133180 lr_scale:0.9644 muon_mom:0.9900 train_time:350007ms step_avg:86.42ms this_step:4279.9ms mem:20877MiB swa_n:0 +step:4100/20000 train_loss:2.074527 lr_scale:0.9478 muon_mom:0.9900 train_time:354329ms step_avg:86.42ms this_step:4322.1ms mem:20877MiB swa_n:0 +step:4150/20000 train_loss:2.155265 lr_scale:0.9314 muon_mom:0.9900 train_time:358605ms step_avg:86.41ms this_step:4275.8ms mem:20877MiB swa_n:0 +step:4200/20000 train_loss:2.203617 lr_scale:0.9146 muon_mom:0.9900 train_time:362949ms step_avg:86.42ms this_step:4343.8ms mem:20877MiB swa_n:0 +step:4250/20000 train_loss:2.157001 lr_scale:0.8981 muon_mom:0.9900 train_time:367238ms step_avg:86.41ms this_step:4288.6ms mem:20877MiB swa_n:0 +step:4300/20000 train_loss:2.099408 lr_scale:0.8817 muon_mom:0.9900 train_time:371516ms step_avg:86.40ms this_step:4278.4ms mem:20877MiB swa_n:0 +step:4350/20000 train_loss:2.117557 lr_scale:0.8649 muon_mom:0.9900 train_time:375855ms step_avg:86.40ms this_step:4339.0ms mem:20877MiB swa_n:0 +step:4400/20000 train_loss:2.081449 lr_scale:0.8486 muon_mom:0.9900 train_time:380125ms step_avg:86.39ms this_step:4270.0ms mem:20877MiB swa_n:0 +step:4450/20000 train_loss:2.085667 lr_scale:0.8322 muon_mom:0.9900 train_time:384396ms step_avg:86.38ms this_step:4270.6ms mem:20877MiB swa_n:0 +step:4500/20000 train_loss:2.161180 lr_scale:0.8156 muon_mom:0.9900 train_time:388713ms step_avg:86.38ms this_step:4317.8ms mem:20877MiB swa_n:0 +step:4550/20000 train_loss:2.165619 lr_scale:0.7992 muon_mom:0.9900 train_time:392978ms step_avg:86.37ms this_step:4265.0ms mem:20877MiB swa_n:0 +step:4600/20000 train_loss:1.905738 lr_scale:0.7825 muon_mom:0.9900 train_time:397301ms step_avg:86.37ms this_step:4322.2ms mem:20877MiB swa_n:0 +step:4650/20000 train_loss:2.098183 lr_scale:0.7662 muon_mom:0.9900 train_time:401568ms 
step_avg:86.36ms this_step:4267.3ms mem:20877MiB swa_n:0 +step:4700/20000 train_loss:2.292833 lr_scale:0.7498 muon_mom:0.9900 train_time:405834ms step_avg:86.35ms this_step:4265.9ms mem:20877MiB swa_n:0 +step:4750/20000 train_loss:2.061094 lr_scale:0.7331 muon_mom:0.9900 train_time:410161ms step_avg:86.35ms this_step:4327.2ms mem:20877MiB swa_n:0 +step:4800/20000 train_loss:2.508126 lr_scale:0.7167 muon_mom:0.9900 train_time:414431ms step_avg:86.34ms this_step:4269.7ms mem:20877MiB swa_n:0 +step:4850/20000 train_loss:2.152413 lr_scale:0.7000 muon_mom:0.9900 train_time:418749ms step_avg:86.34ms this_step:4318.5ms mem:20877MiB swa_n:0 +step:4900/20000 train_loss:2.102630 lr_scale:0.6836 muon_mom:0.9900 train_time:423022ms step_avg:86.33ms this_step:4272.6ms mem:20877MiB swa_n:0 +step:4950/20000 train_loss:2.147992 lr_scale:0.6672 muon_mom:0.9900 train_time:427285ms step_avg:86.32ms this_step:4263.6ms mem:20877MiB swa_n:0 +step:5000/20000 train_loss:2.152503 lr_scale:0.6505 muon_mom:0.9900 train_time:431605ms step_avg:86.32ms this_step:4320.1ms mem:20877MiB swa_n:0 +step:5050/20000 train_loss:2.136943 lr_scale:0.6341 muon_mom:0.9900 train_time:435872ms step_avg:86.31ms this_step:4266.2ms mem:20877MiB swa_n:0 +step:5100/20000 train_loss:2.163745 lr_scale:0.6174 muon_mom:0.9900 train_time:440201ms step_avg:86.31ms this_step:4328.9ms mem:20877MiB swa_n:0 +step:5150/20000 train_loss:2.077157 lr_scale:0.6010 muon_mom:0.9900 train_time:444467ms step_avg:86.30ms this_step:4266.7ms mem:20877MiB swa_n:0 +step:5200/20000 train_loss:2.088799 lr_scale:0.5846 muon_mom:0.9900 train_time:448731ms step_avg:86.29ms this_step:4263.7ms mem:20877MiB swa_n:0 +step:5250/20000 train_loss:2.107197 lr_scale:0.5679 muon_mom:0.9900 train_time:453049ms step_avg:86.29ms this_step:4317.8ms mem:20877MiB swa_n:0 +step:5300/20000 train_loss:2.055757 lr_scale:0.5515 muon_mom:0.9900 train_time:457313ms step_avg:86.29ms this_step:4264.7ms mem:20877MiB swa_n:0 +step:5350/20000 train_loss:1.971710 
lr_scale:0.5348 muon_mom:0.9900 train_time:461625ms step_avg:86.28ms this_step:4311.2ms mem:20877MiB swa_n:0 +step:5400/20000 train_loss:2.090475 lr_scale:0.5184 muon_mom:0.9900 train_time:465896ms step_avg:86.28ms this_step:4271.7ms mem:20877MiB swa_n:0 +step:5450/20000 train_loss:2.110433 lr_scale:0.5019 muon_mom:0.9900 train_time:470166ms step_avg:86.27ms this_step:4269.9ms mem:20877MiB swa_n:0 +step:5500/20000 train_loss:2.058485 lr_scale:0.4852 muon_mom:0.9900 train_time:474491ms step_avg:86.27ms this_step:4324.6ms mem:20877MiB swa_n:0 +step:5550/20000 train_loss:2.052030 lr_scale:0.4687 muon_mom:0.9900 train_time:478761ms step_avg:86.26ms this_step:4270.3ms mem:20877MiB swa_n:0 +step:5600/20000 train_loss:2.012795 lr_scale:0.4517 muon_mom:0.9900 train_time:483153ms step_avg:86.28ms this_step:4392.5ms mem:20877MiB swa_n:0 +step:5650/20000 train_loss:2.096638 lr_scale:0.4353 muon_mom:0.9900 train_time:487412ms step_avg:86.27ms this_step:4258.5ms mem:20877MiB swa_n:0 +step:5700/20000 train_loss:2.051742 lr_scale:0.4188 muon_mom:0.9900 train_time:491681ms step_avg:86.26ms this_step:4268.9ms mem:20877MiB swa_n:0 +step:5750/20000 train_loss:2.133353 lr_scale:0.4021 muon_mom:0.9900 train_time:495999ms step_avg:86.26ms this_step:4317.6ms mem:20877MiB swa_n:0 +step:5800/20000 train_loss:2.050149 lr_scale:0.3857 muon_mom:0.9900 train_time:500266ms step_avg:86.25ms this_step:4267.6ms mem:20877MiB swa_n:0 +step:5850/20000 train_loss:2.175231 lr_scale:0.3692 muon_mom:0.9900 train_time:504593ms step_avg:86.26ms this_step:4326.9ms mem:20877MiB swa_n:0 +step:5900/20000 train_loss:1.952179 lr_scale:0.3525 muon_mom:0.9900 train_time:508857ms step_avg:86.25ms this_step:4264.2ms mem:20877MiB swa_n:0 +step:5950/20000 train_loss:2.000511 lr_scale:0.3360 muon_mom:0.9900 train_time:513128ms step_avg:86.24ms this_step:4270.7ms mem:20877MiB swa_n:0 +step:6000/20000 train_loss:1.992597 lr_scale:0.3193 muon_mom:0.9900 train_time:517444ms step_avg:86.24ms this_step:4316.4ms mem:20877MiB 
swa_n:0 +step:6050/20000 train_loss:2.012013 lr_scale:0.3029 muon_mom:0.9900 train_time:521712ms step_avg:86.23ms this_step:4267.5ms mem:20877MiB swa_n:0 +step:6100/20000 train_loss:1.966169 lr_scale:0.2864 muon_mom:0.9900 train_time:525987ms step_avg:86.23ms this_step:4275.1ms mem:20877MiB swa_n:0 +step:6150/20000 train_loss:2.068375 lr_scale:0.2696 muon_mom:0.9900 train_time:530314ms step_avg:86.23ms this_step:4327.2ms mem:20877MiB swa_n:0 +step:6200/20000 train_loss:2.001224 lr_scale:0.2531 muon_mom:0.9900 train_time:534586ms step_avg:86.22ms this_step:4272.2ms mem:20877MiB swa_n:0 +step:6250/20000 train_loss:2.119708 lr_scale:0.2364 muon_mom:0.9900 train_time:538906ms step_avg:86.23ms this_step:4320.2ms mem:20877MiB swa_n:0 +step:6300/20000 train_loss:1.987534 lr_scale:0.2200 muon_mom:0.9900 train_time:543166ms step_avg:86.22ms this_step:4259.7ms mem:20877MiB swa_n:0 +step:6350/20000 train_loss:2.079607 lr_scale:0.2035 muon_mom:0.9900 train_time:547428ms step_avg:86.21ms this_step:4262.1ms mem:20877MiB swa_n:0 +step:6400/20000 train_loss:2.043034 lr_scale:0.1868 muon_mom:0.9900 train_time:551752ms step_avg:86.21ms this_step:4323.7ms mem:20877MiB swa_n:0 +swa:start step=6400 +step:6450/20000 train_loss:2.114308 lr_scale:0.1700 muon_mom:0.9900 train_time:556107ms step_avg:86.22ms this_step:4354.7ms mem:20877MiB swa_n:1 +step:6500/20000 train_loss:2.120173 lr_scale:0.1531 muon_mom:0.9900 train_time:560460ms step_avg:86.22ms this_step:4353.7ms mem:20877MiB swa_n:2 +step:6550/20000 train_loss:2.084322 lr_scale:0.1364 muon_mom:0.9900 train_time:564772ms step_avg:86.22ms this_step:4311.5ms mem:20877MiB swa_n:3 +step:6600/20000 train_loss:1.895434 lr_scale:0.1198 muon_mom:0.9900 train_time:569067ms step_avg:86.22ms this_step:4295.2ms mem:20877MiB swa_n:4 +step:6650/20000 train_loss:1.852380 lr_scale:0.1030 muon_mom:0.9900 train_time:573433ms step_avg:86.23ms this_step:4365.8ms mem:20877MiB swa_n:5 +step:6700/20000 train_loss:1.984610 lr_scale:0.0864 muon_mom:0.9900 
train_time:577726ms step_avg:86.23ms this_step:4293.6ms mem:20877MiB swa_n:6 +step:6750/20000 train_loss:2.129882 lr_scale:0.0695 muon_mom:0.9900 train_time:582086ms step_avg:86.23ms this_step:4359.5ms mem:20877MiB swa_n:7 +step:6800/20000 train_loss:2.058971 lr_scale:0.0529 muon_mom:0.9900 train_time:586378ms step_avg:86.23ms this_step:4291.6ms mem:20877MiB swa_n:8 +step:6850/20000 train_loss:1.874111 lr_scale:0.0363 muon_mom:0.9900 train_time:590683ms step_avg:86.23ms this_step:4305.1ms mem:20877MiB swa_n:9 +step:6900/20000 train_loss:1.873940 lr_scale:0.0194 muon_mom:0.9900 train_time:595049ms step_avg:86.24ms this_step:4365.8ms mem:20877MiB swa_n:10 +step:6950/20000 train_loss:1.996027 lr_scale:0.0028 muon_mom:0.9900 train_time:599347ms step_avg:86.24ms this_step:4298.7ms mem:20877MiB swa_n:11 +step:6958/20000 val_loss:1.9762 val_bpb:1.1704 train_time:600069ms step_avg:86.24ms +stopping_early: wallclock_cap train_time:600069ms step:6958/20000 +peak memory allocated: 20877 MiB reserved: 20910 MiB +phase:train wall_ms:628422 steps:6958 step_avg:86.24ms +swa:applying averaged 12 checkpoints +pruning: zeroed 1,342,228 weights (5.0%) below 0.007436 +phase:postprocess wall_ms:160 (swa+ema+pruning) +pre_quant_eval val_loss:1.9693 val_bpb:1.1663 eval_time:40538ms +pre_quant_eval_exact val_loss:1.96928844 val_bpb:1.16632354 +Serialized model: 105792597 bytes +Code size: 71033 bytes +Total submission size: 105863630 bytes +quant_tensor:bigram.embed.weight shape:[2048, 128] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.0.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.057220] +quant_tensor:blocks.0.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.039520] +quant_tensor:blocks.0.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.0.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.0.mlp.fc.weight shape:[1536, 512] bits:6 
scale_range:[0.032257,0.045197] +quant_tensor:blocks.0.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.1.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.102722] +quant_tensor:blocks.1.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.037140] +quant_tensor:blocks.1.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.033386] +quant_tensor:blocks.1.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.1.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.050964] +quant_tensor:blocks.1.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.065308] +quant_tensor:blocks.10.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.046509] +quant_tensor:blocks.10.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.10.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.10.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.10.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.046783] +quant_tensor:blocks.10.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.123657] +quant_tensor:blocks.2.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.037292] +quant_tensor:blocks.2.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.2.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.2.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.2.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.099976] +quant_tensor:blocks.2.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.145020] +quant_tensor:blocks.3.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.036743] +quant_tensor:blocks.3.attn.c_q.weight shape:[512, 512] 
bits:6 scale_range:[0.032257,0.032654] +quant_tensor:blocks.3.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.033112] +quant_tensor:blocks.3.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.3.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.039032] +quant_tensor:blocks.3.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.4.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.037689] +quant_tensor:blocks.4.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.4.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.4.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.4.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.032318] +quant_tensor:blocks.4.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.5.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.039703] +quant_tensor:blocks.5.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.5.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.033966] +quant_tensor:blocks.5.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.5.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.035583] +quant_tensor:blocks.5.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.6.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.032623] +quant_tensor:blocks.6.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.036957] +quant_tensor:blocks.6.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.032562] +quant_tensor:blocks.6.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.6.mlp.fc.weight shape:[1536, 512] 
bits:6 scale_range:[0.032257,0.040588] +quant_tensor:blocks.6.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.7.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.033661] +quant_tensor:blocks.7.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.7.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.038269] +quant_tensor:blocks.7.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.034698] +quant_tensor:blocks.7.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.033875] +quant_tensor:blocks.7.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.8.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.049225] +quant_tensor:blocks.8.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.033386] +quant_tensor:blocks.8.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.035126] +quant_tensor:blocks.8.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.8.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.038300] +quant_tensor:blocks.8.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.9.attn.c_k.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.049988] +quant_tensor:blocks.9.attn.c_q.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.037628] +quant_tensor:blocks.9.attn.c_v.weight shape:[256, 512] bits:6 scale_range:[0.032257,0.053040] +quant_tensor:blocks.9.attn.proj.weight shape:[512, 512] bits:6 scale_range:[0.032257,0.032257] +quant_tensor:blocks.9.mlp.fc.weight shape:[1536, 512] bits:6 scale_range:[0.032257,0.035187] +quant_tensor:blocks.9.mlp.proj.weight shape:[512, 1536] bits:6 scale_range:[0.032257,0.035645] +passthrough_tensor:bigram.proj.weight shape:[512, 128] dtype:torch.float16 bytes:131072 +passthrough_tensor:bigram.scale shape:[] dtype:torch.float16 
bytes:2 +passthrough_tensor:blocks.0.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.0.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.0.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.0.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.0.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.1.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.1.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.1.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.1.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.1.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.10.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.10.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.10.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.10.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.10.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.2.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.2.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.2.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.2.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.2.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.3.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.3.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.3.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.3.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 
+passthrough_tensor:blocks.3.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.4.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.4.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.4.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.4.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.4.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.5.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.5.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.5.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.5.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.5.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.6.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.6.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.6.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.6.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.6.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.7.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.7.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.7.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.7.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.7.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.8.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.8.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.8.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.8.mlp_scale 
shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.8.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:blocks.9.attn.q_gain shape:[8] dtype:torch.float32 bytes:32 +passthrough_tensor:blocks.9.attn_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.9.depth_scale shape:[] dtype:torch.float16 bytes:2 +passthrough_tensor:blocks.9.mlp_scale shape:[512] dtype:torch.float32 bytes:2048 +passthrough_tensor:blocks.9.resid_mix shape:[2, 512] dtype:torch.float32 bytes:4096 +passthrough_tensor:skip_weights shape:[5, 512] dtype:torch.float32 bytes:10240 +passthrough_tensor:smear.gate shape:[512] dtype:torch.float16 bytes:1024 +passthrough_tensor:tok_emb.weight shape:[1024, 512] dtype:torch.float16 bytes:1048576 +Serialized model zstd-22: 15332954 bytes (payload:27578744 raw_torch:27638331 payload_ratio:3.83x) +Total submission size zstd-22: 15403987 bytes +Size check PASSED: 15403987 / 16,000,000 (96.3%) +phase:serialize wall_ms:64004 (quant+compress+save) +final_int8_zlib_roundtrip val_loss:1.9930 val_bpb:1.1804 eval_time:2176ms eval_seq_len:2048 +final_int8_zlib_roundtrip_exact val_loss:1.99297925 val_bpb:1.18035457 +quant_gap: 0.014031 BPB (pre:1.166324 post:1.180355) +phase:postquant_eval wall_ms:2968 +ttt:rank0 short=2393 long=3857 epochs=1 batch=64 +ttt:short_docs time=24200ms tokens=732712 +ttt:batch 5/61 time=1028ms avg_loss=2.0038 +ttt:batch 10/61 time=1948ms avg_loss=1.9926 +ttt:batch 15/61 time=2866ms avg_loss=1.9810 +ttt:batch 20/61 time=4437ms avg_loss=1.9626 +ttt:batch 25/61 time=6015ms avg_loss=1.9570 +ttt:batch 30/61 time=8355ms avg_loss=1.9505 +ttt:batch 35/61 time=10987ms avg_loss=1.9454 +ttt:batch 40/61 time=14235ms avg_loss=1.9423 +ttt:batch 45/61 time=18410ms avg_loss=1.9388 +ttt:batch 50/61 time=23756ms avg_loss=1.9393 +ttt:batch 55/61 time=31400ms avg_loss=1.9332 +ttt:batch 60/61 time=54597ms avg_loss=1.9309 +ttt:long_docs time=62797ms docs=3857 +final_ttt_lora val_loss:1.9463 
val_bpb:1.1527 eval_time:114191ms lora_rank:8 chunk_size:256 +final_ttt_lora_exact val_loss:1.94632700 val_bpb:1.15272441 +ttt_gain: 0.027630 BPB gain over int8 (int8:1.180355 ttt:1.152724) +phase:ttt_eval wall_ms:114910 +phase:TOTAL wall_ms:810464 (13.5 min) +phase_breakdown: train:600069ms postprocess:see_above serialize:see_above eval:see_above ttt:see_above