diff --git a/records/track_10min_16mb/2026-03-30_V18_FusedTritonOp/README.md b/records/track_10min_16mb/2026-03-30_V18_FusedTritonOp/README.md new file mode 100644 index 0000000000..9134c3a136 --- /dev/null +++ b/records/track_10min_16mb/2026-03-30_V18_FusedTritonOp/README.md @@ -0,0 +1,82 @@ +# Turbo-Muon + EngramLite + ParamBanking + GPTQ Reserve Opt (val_bpb 1.1126) + +**val_bpb: 1.1126** (3-seed mean, std 0.0003) | **~15.98 MB** | 8xH100 SXM, 600s train, ~120s eval + +Built on [PR #1089](https://github.com/openai/parameter-golf/pull/1089) by @mikeapedia. Fused Triton MLP architecture from [PR #1072](https://github.com/openai/parameter-golf/pull/1072) by @vimeto, forward-only fusion insight from [PR #1105](https://github.com/openai/parameter-golf/pull/1105) by @abaybektursun. + +## Results (8xH100 SXM, SWA applied, no TTT) + +| Seed | Sliding BPB | val_loss (nats) | Artifact | +|------|-------------|-----------------|----------| +| 1337 | **1.1126** | 1.87857 | 15,981,856 | +| 42 | **1.1123** | 1.87803 | 15,984,349 | +| 999 | **1.1129** | 1.87900 | 15,985,912 | +| **Mean +/- Std** | **1.1126 +/- 0.0003** | **1.87853** | | + +vs merged leaderboard SOTA ([PR #549](https://github.com/openai/parameter-golf/pull/549), 1.1194 BPB, 1.89002 nats): **-0.01149 nats** (-0.0068 BPB). Note: open PRs #1089 (1.1091) and #1105 (1.1138) achieve better scores. + +## What's New vs PR #1089 + +### 1. GPTQ Reserve Optimization +Reduced GPTQ calibration reserve from 14s to 9s. Calibration consistently completes in ~8.4s across all runs, so 14s wastes 5+ seconds of training budget. Recovers ~55 extra training steps at ~105ms/step. + +### 2. Forward-Only Fused Triton MLP Kernel Architecture +Designed a `torch.library.triton_op`-based fused kernel for `matmul + LeakyReLU(0.3) + square` with standard PyTorch backward (cuBLAS matmuls + elementwise ops). This architecture addresses two known issues: +- PR #1072's `torch.autograd.Function` crashes `torch.compile(fullgraph=True)` due to FakeTensor data pointer access +- PR #1105 showed Triton backward forces eager mode (2.7x slower) + +Our solution: `triton_op` + `wrap_triton` for compile-safe forward, `register_autograd` with standard ops for backward. The kernel code is included but **hard-disabled** — it produces NaN on PyTorch 2.9 due to a TTIR analysis bug. The scored runs use the standard MLP path. This is included as experimental code for future work. + +### 3. Centralized Activation Parameters +All `negative_slope` references unified via `_NEGATIVE_SLOPE = 0.3` constant with derived `_SLOPE_SQ = _NEGATIVE_SLOPE ** 2`. + +## Architecture (from PR #1089) + +- 11L, 512d, 8H/4KV (GQA), MLP 3.5x LeakyReLU(0.3)^2 +- Turbo-Muon optimizer (AOL preconditioning + Polar Express coefficients + row_col normalization, 4 Newton-Schulz iterations) +- EngramLite hash embeddings (bigram + trigram, 2 heads, 8192 buckets) +- Parameter Banking (3D bank tensors for batched Newton-Schulz via torch.bmm) +- U-Net sigmoid-gated skip connections + ValueEmbedding (layers 9-10) +- SmearGate, Partial RoPE(16), LN Scale +- SWA (threshold=0.2, every 50 steps, 14 snapshots) + EMA(0.997) fallback +- Mixed-precision GPTQ: int5 base + selective int6/int7 promotion by Hessian sensitivity +- Brotli-11 + byte-shuffle compression +- F.scaled_dot_product_attention (auto-selects FA3 backend) + +## Timing + +| Phase | Time | +|-------|------| +| Training (~5,668 steps @ 104ms) | 591s | +| GPTQ calibration + quantization | 9s (reserved) | +| Sliding window eval (stride=64) | ~120s | + +## Reproduction + +```bash +# Use official template: runpod/parameter-golf:latest (PyTorch 2.9.1+cu128) +# Or any 8xH100 SXM pod with PyTorch >= 2.6 + +pip install brotli sentencepiece +pip install flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291 + +GPTQ_RESERVE_MS=9000 SEED=1337 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +## Rule Compliance + +- [x] Standard F.cross_entropy scoring (softmax, sum=1) +- [x] No eval-time training data access +- [x] Artifact < 16,000,000 bytes (all 3 seeds) +- [x] Training < 600s, eval < 600s +- [x] Causal sliding-window evaluation on full validation split (stride=64) +- [x] 3-seed verification: delta = -0.01149 nats vs SOTA (> 0.005 threshold) +- [x] No n-gram caching, no external downloads during eval + +## Credits + +- **Turbo-Muon + EngramLite + ParamBanking**: [PR #1089](https://github.com/openai/parameter-golf/pull/1089) by @mikeapedia +- **Fused Triton MLP kernel design**: [PR #1072](https://github.com/openai/parameter-golf/pull/1072) by @vimeto +- **Forward-only fusion insight**: [PR #1105](https://github.com/openai/parameter-golf/pull/1105) by @abaybektursun +- **Base scaffold**: [PR #549](https://github.com/openai/parameter-golf/pull/549) by @abaybektursun diff --git a/records/track_10min_16mb/2026-03-30_V18_FusedTritonOp/requirements.txt b/records/track_10min_16mb/2026-03-30_V18_FusedTritonOp/requirements.txt new file mode 100644 index 0000000000..dad7e4008d --- /dev/null +++ b/records/track_10min_16mb/2026-03-30_V18_FusedTritonOp/requirements.txt @@ -0,0 +1,3 @@ +torch>=2.6.0 +brotli +sentencepiece diff --git a/records/track_10min_16mb/2026-03-30_V18_FusedTritonOp/runpod_setup.sh b/records/track_10min_16mb/2026-03-30_V18_FusedTritonOp/runpod_setup.sh new file mode 100644 index 0000000000..6de4731f89 --- /dev/null +++ b/records/track_10min_16mb/2026-03-30_V18_FusedTritonOp/runpod_setup.sh @@ -0,0 +1,80 @@ +#!/bin/bash +# V17 RunPod Setup — PR #1089 (TurboMuon) + PR #1072 (Fused Triton Kernel) +# USAGE: +# bash runpod_setup.sh # Setup (PyTorch upgrade, deps) +# bash runpod_setup.sh run # Run training +set -e + +if [ "$1" = "run" ]; then + # ---- RUN MODE ---- + echo "=== V17 FusedTurboMuon ===" + echo "Config: SEED=${SEED:-1337} GPTQ_RESERVE_MS=${GPTQ_RESERVE_MS:-9000}" + echo "Starting in 3s..." + sleep 3 + GPTQ_RESERVE_MS=${GPTQ_RESERVE_MS:-9000} \ + SEED=${SEED:-1337} \ + torchrun --standalone --nproc_per_node=8 train_gpt.py + exit 0 +fi + +# ---- SETUP MODE ---- +echo "=============================================" +echo " V17 FUSED TURBOMUON — POD SETUP" +echo "=============================================" + +# 1. Check CUDA driver +DRIVER=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -1) +echo "CUDA Driver: $DRIVER" + +# 2. Check current PyTorch +CURRENT_PT=$(python3 -c "import torch; print(torch.__version__)" 2>/dev/null || echo "none") +echo "Current PyTorch: $CURRENT_PT" + +# 3. Install deps (brotli required for compression, sentencepiece for tokenizer) +pip install brotli sentencepiece 2>&1 | tail -2 + +# 4. Install FA3 for SDPA backend acceleration (30-second wheel install) +python3 -c "from flash_attn_interface import flash_attn_func" 2>/dev/null && echo "FA3: already installed" || { + echo "Installing FA3 pre-built wheel..." + pip install flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291 2>&1 | tail -3 + python3 -c "from flash_attn_interface import flash_attn_func; print('FA3: OK')" 2>/dev/null || echo "FA3: not available (SDPA will use FA2 or math backend)" +} + +# 5. Symlink data if needed +[ -L data ] || [ -d data ] || ln -sf /workspace/data data +[ -d data/datasets/fineweb10B_sp1024 ] && echo "Data: OK" || echo "WARNING: Data not found at data/datasets/fineweb10B_sp1024" +[ -f data/tokenizers/fineweb_1024_bpe.model ] && echo "Tokenizer: OK" || echo "WARNING: Tokenizer not found" + +# 6. Check Triton +python3 -c " +import torch +print(f'PyTorch {torch.__version__}, CUDA {torch.version.cuda}') +try: + import triton + from triton.tools.tensor_descriptor import TensorDescriptor + print(f'Triton {triton.__version__} + TensorDescriptor: OK → Fused MLP kernel ENABLED') +except Exception as e: + print(f'Triton not available: {e} → Standard MLP path (slower)') +" + +echo "" +echo "=============================================" +echo " SETUP COMPLETE" +echo "=============================================" +echo "" +echo " V17 Stack:" +echo " Turbo-Muon + EngramLite + Parameter Banking (PR #1089)" +echo " Fused Triton MLP kernel (PR #1072, if Triton available)" +echo " Mixed-precision GPTQ int5/int6/int7 + Brotli compression" +echo " GPTQ reserve optimized to 9s (from 14s default)" +echo "" +echo "RUN COMMANDS:" +echo "" +echo " # Single seed test:" +echo " SEED=1337 bash runpod_setup.sh run" +echo "" +echo " # 3-seed submission:" +echo " for S in 1337 42 999; do" +echo " SEED=\$S bash runpod_setup.sh run | tee run_seed\$S.log" +echo " done" +echo "=============================================" diff --git a/records/track_10min_16mb/2026-03-30_V18_FusedTritonOp/submission.json b/records/track_10min_16mb/2026-03-30_V18_FusedTritonOp/submission.json new file mode 100644 index 0000000000..9d4033370c --- /dev/null +++ b/records/track_10min_16mb/2026-03-30_V18_FusedTritonOp/submission.json @@ -0,0 +1,9 @@ +{ + "name": "Turbo-Muon + EngramLite + ParamBanking + GPTQ Reserve Optimization", + "val_bpb": 1.1126, + "bytes_total": 15985912, + "blurb": "PR #1089 stack (Turbo-Muon, EngramLite, Parameter Banking, mixed GPTQ, brotli) with GPTQ reserve optimization (14s to 9s, +55 training steps). Includes experimental fused Triton MLP kernel architecture (disabled, pending PT2.11 compat). 3-seed mean: 1.1126 (std 0.0003). Built on PR #1089.", + "author": "Bortlesboat", + "github_id": "Bortlesboat", + "date": "2026-03-30" +} diff --git a/records/track_10min_16mb/2026-03-30_V18_FusedTritonOp/train_gpt.py b/records/track_10min_16mb/2026-03-30_V18_FusedTritonOp/train_gpt.py new file mode 100644 index 0000000000..830795c752 --- /dev/null +++ b/records/track_10min_16mb/2026-03-30_V18_FusedTritonOp/train_gpt.py @@ -0,0 +1,2672 @@ +from __future__ import annotations +import copy +import glob +import io +import lzma +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +# Compression preference: brotli > lzma > zstd > zlib +# brotli q=11 typically beats lzma preset=9 by 1-5% on quantized weights +try: + import brotli as _brotli_probe # noqa: F401 + _COMPRESSOR = "brotli" +except ImportError: + _COMPRESSOR = "lzma" + +# Byte-shuffle preprocessing: reorder bytes by stride position before compression. +# For multi-byte values (float16 scales), grouping same-position bytes +# together creates runs of similar values -> better entropy coding. Lossless & fast (<1s). +_BYTE_SHUFFLE = True + +# --- Fused Triton MLP kernel via torch.library.triton_op --- +# Fuses: x @ up_w.T -> leaky_relu(neg_slope) -> square into one kernel pass +# Uses triton_op (not autograd.Function) for torch.compile(fullgraph=True) compat +# Backward uses standard PyTorch ops (PR #1105: Triton bwd = 2.7x slower eager) +HAS_FUSED_MLP = False +IS_ROCM = hasattr(torch.version, 'hip') and torch.version.hip is not None +_NEGATIVE_SLOPE = 0.3 +_SLOPE_SQ = _NEGATIVE_SLOPE ** 2 +try: + import triton + import triton.language as tl + from triton.tools.tensor_descriptor import TensorDescriptor + from torch.library import triton_op, wrap_triton + + @triton.jit + def _fused_leaky_relu_sq_kernel( + a_desc, b_desc, c_desc, pre_desc, + M, N, K, + NEG_SLOPE: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + NUM_SMS: tl.constexpr, + ): + dtype = tl.bfloat16 + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + k_tiles = tl.cdiv(K, BLOCK_SIZE_K) + num_tiles = num_pid_m * num_pid_n + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for ki in range(k_tiles): + offs_k = ki * BLOCK_SIZE_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_bn, offs_k]) + accumulator = tl.dot(a, b.T, accumulator) + acc = tl.reshape(accumulator, (BLOCK_SIZE_M, 2, BLOCK_SIZE_N // 2)) + acc = tl.permute(acc, (0, 2, 1)) + acc0, acc1 = tl.split(acc) + pre0 = acc0.to(dtype) + pre_desc.store([offs_am, offs_bn], pre0) + post0 = tl.where(pre0 > 0, pre0, NEG_SLOPE * pre0) + post0 = post0 * post0 + c_desc.store([offs_am, offs_bn], post0) + pre1 = acc1.to(dtype) + pre_desc.store([offs_am, offs_bn + BLOCK_SIZE_N // 2], pre1) + post1 = tl.where(pre1 > 0, pre1, NEG_SLOPE * pre1) + post1 = post1 * post1 + c_desc.store([offs_am, offs_bn + BLOCK_SIZE_N // 2], post1) + + @triton_op("paramgolf::fused_up_activation", mutates_args=()) + def _fused_up_activation_flat(x: Tensor, up_w: Tensor) -> tuple[Tensor, Tensor]: + M, K = x.shape + N = up_w.shape[0] + post = torch.empty((M, N), device=x.device, dtype=x.dtype) + pre = torch.empty((M, N), device=x.device, dtype=x.dtype) + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 64 + a_desc = TensorDescriptor.from_tensor(x, [BLOCK_M, BLOCK_K]) + b_desc = TensorDescriptor.from_tensor(up_w, [BLOCK_N, BLOCK_K]) + c_desc = TensorDescriptor.from_tensor(post, [BLOCK_M, BLOCK_N // 2]) + pre_desc = TensorDescriptor.from_tensor(pre, [BLOCK_M, BLOCK_N // 2]) + grid = lambda META: (min(NUM_SMS, triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N)),) + wrap_triton(_fused_leaky_relu_sq_kernel)[grid]( + a_desc, b_desc, c_desc, pre_desc, M, N, K, + NEG_SLOPE=_NEGATIVE_SLOPE, + BLOCK_SIZE_M=BLOCK_M, BLOCK_SIZE_N=BLOCK_N, BLOCK_SIZE_K=BLOCK_K, + GROUP_SIZE_M=1, NUM_SMS=NUM_SMS, num_stages=4, num_warps=8) + return post, pre + + def _setup_context(ctx, inputs, output): + x, up_w = inputs + _post, pre = output + ctx.save_for_backward(x, up_w, pre) + + def _backward(ctx, grad_post, _grad_pre): + x, up_w, pre = ctx.saved_tensors + act_grad = torch.where(pre > 0, 2.0 * pre, (2.0 * _SLOPE_SQ) * pre) + dpre = grad_post * act_grad + dW_up = dpre.T @ x + dx = dpre @ up_w + return dx, dW_up + + _fused_up_activation_flat.register_autograd(_backward, setup_context=_setup_context) + + def _fused_up_activation(x: Tensor, up_w: Tensor) -> Tensor: + orig_shape = x.shape + x_flat = x.reshape(-1, x.shape[-1]) + post, _pre = _fused_up_activation_flat(x_flat, up_w) + return post.view(orig_shape[:-1] + (post.shape[-1],)) + + HAS_FUSED_MLP = False # Hard-disabled: triton_op NaN on PT2.9, enable via FUSED_MLP=1 env +except (ImportError, Exception): + HAS_FUSED_MLP = False +_BYTE_SHUFFLE_STRIDE = 2 # 2 = optimal for float16-heavy data + +_BSHF_MAGIC = b"BSHF" # 4-byte magic for byte-shuffled data + + +def _byte_shuffle(data: bytes, stride: int = 2) -> bytes: + """Transpose byte stream by stride using numpy vectorized indexing. + + Groups byte[i % stride] positions together. For stride=2 on float16 data, + this puts all high bytes together and all low bytes together, dramatically + improving compression. Prepend 4-byte magic + 1-byte stride header. + """ + if stride <= 1 or len(data) < stride: + return data # no-op + src = np.frombuffer(data, dtype=np.uint8) + n = len(src) + out = np.empty(n, dtype=np.uint8) + dest_off = 0 + for pos in range(stride): + chunk = src[pos::stride] # every stride-th byte starting at pos + out[dest_off:dest_off + len(chunk)] = chunk + dest_off += len(chunk) + return _BSHF_MAGIC + bytes([stride]) + out.tobytes() + + +def _byte_unshuffle(data: bytes) -> bytes: + """Inverse of _byte_shuffle using numpy. Auto-detects BSHF magic header.""" + if len(data) < 5 or data[:4] != _BSHF_MAGIC: + return data # no shuffle header -> old format, return as-is + stride = data[4] + if stride < 2: + return data[5:] # invalid stride, just strip header + payload = np.frombuffer(data, dtype=np.uint8, offset=5) + n = len(payload) + out = np.empty(n, dtype=np.uint8) + src_off = 0 + for pos in range(stride): + chunk_len = n // stride + (1 if pos < n % stride else 0) + out[pos::stride][:chunk_len] = payload[src_off:src_off + chunk_len] + src_off += chunk_len + return out.tobytes() + + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + lr_floor = float(os.environ.get("LR_FLOOR", 0.05)) # Minimum LR as fraction of peak (prevents sharp quant-sensitive minima) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.5)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 4)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + swa_threshold = float(os.environ.get("SWA_THRESHOLD", 0.2)) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 11)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1"))) + ema_decay = float(os.environ.get("EMA_DECAY", 0.997)) + # New hyperparameters from our improvements + embed_beta1 = float(os.environ.get("EMBED_BETA1", 0.7)) + head_beta1 = float(os.environ.get("HEAD_BETA1", 0.7)) + muon_post_norm = os.environ.get("MUON_POST_NORM", "row_col") + qat_threshold = float(os.environ.get("QAT_THRESHOLD", 0.15)) + qat_clip_pct = float(os.environ.get("QAT_CLIP_PCT", 0.9995)) + late_qat = bool(int(os.environ.get("LATE_QAT", "1"))) + mixed_precision = bool(int(os.environ.get("MIXED_PRECISION", "1"))) + # mp_promote removed — dynamic allocation based on artifact size budget (see _allocate_bits_mixed) + target_bytes_limit = int(os.environ.get("TARGET_BYTES", 16_000_000)) + # N-gram params + ngram_buckets = int(os.environ.get("NGRAM_BUCKETS", 8192)) + ngram_heads = int(os.environ.get("NGRAM_HEADS", 2)) + ngram_orders = int(os.environ.get("NGRAM_ORDERS", 2)) + ngram_dim_per_head = int(os.environ.get("NGRAM_DIM_PER_HEAD", 32)) + # GPTQ calibration + gptq_calib_batches = int(os.environ.get("GPTQ_CALIB_BATCHES", 64)) + gptq_block_size = int(os.environ.get("GPTQ_BLOCK_SIZE", 128)) + gptq_damp = float(os.environ.get("GPTQ_DAMP", 0.01)) + gptq_reserve_ms = float(os.environ.get("GPTQ_RESERVE_MS", 14000.0)) # Reserve from training budget for GPTQ calibration (PR #634) + gptq_col_order = os.environ.get("GPTQ_COL_ORDER", "desc") # "desc" (actorder) or "asc" (PR#753-style) + gptq_single_pass = bool(int(os.environ.get("GPTQ_SINGLE_PASS", "1"))) # Pre-compute scales, run GPTQ once + soft_round_qat = bool(int(os.environ.get("SOFT_ROUND_QAT", "1"))) # 1=soft-round, 0=STE + # Snapshot: save unbanked_sd + hessians after training, or load them to skip training + snapshot_post_hessian = bool(int(os.environ.get("SNAPSHOT_POST_HESSIAN", "0"))) # save snapshot and exit + load_snapshot = os.environ.get("LOAD_SNAPSHOT", "") # path to snapshot file; skips training if set + +# --- Turbo-Muon Newton-Schulz orthogonalization --- + +# Polar Express optimal degree-5 coefficients (Amsel et al., arXiv:2505.16932). +# With AOL preconditioning we skip iter 1 -- AOL already contracts the singular value +# range via Gershgorin scaling, so we start from iteration 2. +_POLAR_COEFFS_FULL = [ + (8.28721201814563, -23.595886519098837, 17.300387312530933), # iter 1 (Frobenius init only) + (4.107059111542203, -2.9478499167379106, 0.5448431082926601), # iter 2 + (3.9486908534822946, -2.908902115962949, 0.5518191394370137), # iter 3 + (3.3184196573706015, -2.488488024314874, 0.51004894012372), # iter 4 + (2.300652019954817, -1.6689039845747493, 0.4188073119525673), # iter 5 + (1.891301407787398, -1.2679958271945868, 0.37680408948524835), # iter 6 + (1.875, -1.25, 0.375), # iter 7+: converged fixed point +] +_AOL_POLAR_COEFFS = _POLAR_COEFFS_FULL[1:] # skip iter 1 when using AOL + + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 4, eps: float = 1e-7) -> Tensor: + """Turbo-Muon: Newton-Schulz with left-Gram AOL + Polar Express coefficients. + + Supports both 2D (M, N) and batched 3D (B, M, N) input. + Uses left Gram (X@X.T) throughout -- always (m*m) where m <= n, giving + tighter Gershgorin bounds and up to 9x cheaper matmuls for non-square matrices. + """ + X = G.bfloat16() + if X.ndim == 2: + # --- 2D path (single matrix) --- + transposed = X.size(0) > X.size(1) + if transposed: + X = X.T + A = X @ X.T + s = 1.0 / (A.abs().sum(dim=1).sqrt() + eps) + X = s.unsqueeze(1) * X + A = s.unsqueeze(0) * A * s.unsqueeze(1) + for i in range(steps): + a, b, c = _AOL_POLAR_COEFFS[min(i, len(_AOL_POLAR_COEFFS) - 1)] + if i > 0: + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + else: + # --- 3D batched path (B, M, N) --- + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + A = X @ X.mT # (B, m, m) + s = 1.0 / (A.abs().sum(dim=-1).sqrt() + eps) # (B, m) + X = s.unsqueeze(-1) * X # (B, m, n) + A = s.unsqueeze(-2) * A * s.unsqueeze(-1) # (B, m, m) + for i in range(steps): + a, b, c = _AOL_POLAR_COEFFS[min(i, len(_AOL_POLAR_COEFFS) - 1)] + if i > 0: + A = X @ X.mT + B = b * A + c * A @ A + X = a * X + B @ X + return X.mT if transposed else X + + +def _post_ns_normalize(X: Tensor, mode: str) -> Tensor: + """Muon+ post-NS normalization: equalize per-neuron update magnitudes. + + Modes: "none" (passthrough), "row", "col", "row_col". + Supports both 2D and batched 3D input. + Norms computed in float32 for numerical stability (X is typically bf16 from NS). + """ + if mode == "none": + return X + if mode in ("row", "row_col"): + X = X / (X.float().norm(dim=-1, keepdim=True).to(X.dtype) + 1e-7) + if mode in ("col", "row_col"): + X = X / (X.float().norm(dim=-2, keepdim=True).to(X.dtype) + 1e-7) + return X + + +# --- Parallel Muon optimizer --- + +class Muon(torch.optim.Optimizer): + """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather. + + No DDP for bank params. After backward, this optimizer: + 1. Launches async reduce-scatter for all banks (biggest first) + 2. Returns control so Adam can step on small params while RS is in-flight + 3. Waits for each RS, runs local NS5 on the shard, launches async all-gather + 4. Each all-gather overlaps with next bank's NS5 + """ + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0, + post_norm: str = "none"): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay, post_norm=post_norm), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + 'p': p, + 'B': B, + 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + # Sort by size descending -- launch biggest reduce-scatters first + self._bank_meta.sort(key=lambda m: -m['p'].numel()) + self._built = True + + def launch_reduce_scatters(self): + """Phase 1: launch async reduce-scatter for all banks. Call right after backward.""" + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m['p'] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m['padded_grad'] + pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0] > m['B']: + pg[m['B']:].zero_() + fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps.""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self._built: + self._build() + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + post_norm = group.get("post_norm", "none") + + prev_ag_handle = None + prev_m = None + + sharded = self._distributed and hasattr(self, '_rs_futures') + + for i, m in enumerate(self._bank_meta): + p = m['p'] + if p.grad is None: + continue + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + + if sharded and self._rs_futures[i] is not None: + self._rs_futures[i].wait() + g = m['shard'] + buf = m['shard_mom'] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + update = _post_ns_normalize(update, post_norm) + + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m['full_update'], update, async_op=True) + prev_m = m + else: + p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + + if hasattr(self, '_rs_futures'): + del self._rs_futures + + return loss + +# --- Tokenizer evaluation helpers --- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# --- Quantization helpers --- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,skip_gate,skip_gates,smear,ve_layer_scales,ve_shared.scale,ngram_gate", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + # Keep tok_emb.weight in fp16 passthrough (tied embeddings — quantization errors degrade all tokens) + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL or name == "tok_emb.weight": + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + +# --- Data loading --- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# --- Transformer modules --- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + _qat_clip_pct: float = 0.9995 + _qat_default_bits: int = 5 + _qat_soft_round: bool = False + _qat_soft_alpha: Tensor = None # type: ignore[assignment] # CUDA tensor, ramped 1→16; None until QAT starts + + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + bits = getattr(self, '_qat_bits', CastedLinear._qat_default_bits) + if CastedLinear._qat_soft_round: + w = _apply_qat_soft_round(w, self.weight, bits, CastedLinear._qat_soft_alpha) + else: + w = _apply_qat_ste(w, self.weight, bits) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + + +def _apply_qat_ste(w_cast: Tensor, w_fp32: Tensor, bits: int) -> Tensor: + """Straight-through estimator QAT for a weight matrix.""" + if bits <= 0: + return w_cast + qmax = (1 << (bits - 1)) - 1 + qmin = -(1 << (bits - 1)) + with torch.no_grad(): + w32 = w_fp32.float() + row_max = w32.abs().amax(dim=-1) * CastedLinear._qat_clip_pct + scale = (row_max / float(qmax)).clamp_min(1.0 / float(qmax)) + w_q = (torch.clamp(torch.round(w32 / scale.unsqueeze(-1)), qmin, qmax) * scale.unsqueeze(-1)).to(w_cast.dtype) + return w_cast + (w_q - w_cast).detach() + + +def _apply_qat_soft_round(w_cast: Tensor, w_fp32: Tensor, bits: int, alpha: "float | Tensor") -> Tensor: + """Differentiable quantization using soft-round (sigmoid approximation of round()). + + Instead of STE (zero gradient through round), uses sigmoid(alpha*(frac-0.5)) which + provides real gradient signal pushing weights toward quantization grid points. + alpha ramps from 1 (smooth) to 16 (near-hard rounding) during QAT. + """ + if bits <= 0: + return w_cast + out_dtype = w_cast.dtype + qmax = (1 << (bits - 1)) - 1 + qmin = -(1 << (bits - 1)) + # Scale: detached so gradient doesn't flow through row_max computation + row_max = w_fp32.float().detach().abs().amax(dim=-1) * CastedLinear._qat_clip_pct + scale = (row_max / float(qmax)).clamp_min(1.0 / float(qmax)) + # Differentiable quantization + w_scaled = w_fp32.float() / scale.unsqueeze(-1) + w_clamped = w_scaled.clamp(float(qmin), float(qmax)) + w_floor = w_clamped.detach().floor() # integer part (no grad) + frac = w_clamped - w_floor # fractional part (grad flows through w_clamped) + soft = w_floor + torch.sigmoid(alpha * (frac - 0.5)) + w_q = soft * scale.unsqueeze(-1) + return w_q.to(out_dtype) + + +def _apply_bank_qat(w: Tensor, bits: int, dtype: torch.dtype) -> Tensor: + """Apply QAT to a bank weight slice, returning the result in the target dtype.""" + w_cast = w.to(dtype) + if CastedLinear._qat_enabled and torch.is_grad_enabled() and w.ndim == 2: + if CastedLinear._qat_soft_round: + return _apply_qat_soft_round(w_cast, w, bits, CastedLinear._qat_soft_alpha) + return _apply_qat_ste(w_cast, w, bits) + return w_cast + + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + # No CastedLinear -- weights come from banks + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). + y: [B, T, H, D], v: [B, T, Hkv, D]. H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] -- broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, v_embed: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + bsz, seqlen, dim = x.shape + _qb = CastedLinear._qat_default_bits + q = F.linear(x, _apply_bank_qat(q_w, _qb, x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, _apply_bank_qat(k_w, _qb, x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, _apply_bank_qat(v_w, _qb, x.dtype)) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + # SDPA expects (B, H, T, D) + y = F.scaled_dot_product_attention( + q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), + is_causal=True, enable_gqa=(self.num_kv_heads != self.num_heads), + ).transpose(1, 2) + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return F.linear(y, _apply_bank_qat(out_w, _qb, y.dtype)), None + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = F.pad(x[:, :-1], (0, 0, 1, 0)) # prepend zero row, avoids cat+alloc + return torch.lerp(x, x_prev, g) + +class EngramLite(nn.Module): + """Multi-head hash-based n-gram embedding with learned gating (Engram-lite). + + Replaces BigramHashEmbedding with: multi-head hashing for collision resistance, + bigram+trigram coverage, and a per-dim learned gate to suppress n-gram signal + when the Transformer's own reasoning is more informative. + """ + def __init__(self, num_buckets: int, num_heads: int, num_orders: int, dim_per_head: int, model_dim: int): + super().__init__() + self.num_buckets = num_buckets + self.num_heads = num_heads + self.num_orders = num_orders + self.dim_per_head = dim_per_head + total_slots = num_orders * num_heads * num_buckets + concat_dim = num_orders * num_heads * dim_per_head + self.embed = nn.Embedding(total_slots, dim_per_head) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(concat_dim, model_dim, bias=False) + self.proj._zero_init = True + self.ngram_gate = nn.Parameter(torch.zeros(model_dim, dtype=torch.float32)) + + def forward(self, input_ids: Tensor) -> Tensor: + B = self.num_buckets + prev_ids = F.pad(input_ids[:, :-1], (1, 0), value=0) + # Bigram hashes (2 heads, independent prime-based mixing) + bi_h0 = (prev_ids * 1009 + input_ids) % B + bi_h1 = ((prev_ids * 2719 + 314159) ^ (input_ids * 3137)) % B + indices = [bi_h0, bi_h1 + B] + # Trigram hashes (2 heads) if enabled + if self.num_orders >= 2: + pp_ids = F.pad(prev_ids[:, :-1], (1, 0), value=0) + tri_h0 = ((pp_ids * 36313) ^ (prev_ids * 27191) ^ (input_ids * 4903)) % B + tri_h1 = ((pp_ids * 7919) ^ (prev_ids * 4391) ^ (input_ids * 6151)) % B + offset = 2 * B + indices.extend([tri_h0 + offset, tri_h1 + offset + B]) + # Unified lookup + concat + all_idx = torch.stack(indices, dim=-1) + all_emb = self.embed(all_idx) + flat = all_emb.reshape(*input_ids.shape, -1) + out = self.proj(flat) + gate = torch.sigmoid(self.ngram_gate.to(dtype=out.dtype))[None, None, :] + return out * gate + +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. + Each table maps vocab tokens to a low-dim embedding, projected to kv_dim.""" + def __init__(self, vocab_size: int, ve_dim: int, kv_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, kv_dim, bias=False) if ve_dim != kv_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float): + super().__init__() + # No CastedLinear -- weights come from banks + def forward(self, x: Tensor, up_w: Tensor, down_w: Tensor) -> Tensor: + _qb = CastedLinear._qat_default_bits + up_w_q = _apply_bank_qat(up_w, _qb, x.dtype) + down_w_q = _apply_bank_qat(down_w, _qb, x.dtype) + if HAS_FUSED_MLP and x.is_cuda and not IS_ROCM: + # Fused Triton forward via triton_op — backward uses standard PyTorch ops + post = _fused_up_activation(x, up_w_q) + return F.linear(post, down_w_q) + x = F.leaky_relu(F.linear(x, up_w_q), negative_slope=_NEGATIVE_SLOPE) + return F.linear(x.square(), down_w_q) + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + def forward(self, x: Tensor, x0: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, up_w: Tensor, down_w: Tensor, v_embed: Tensor | None = None) -> tuple[Tensor, Tensor | None]: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out, raw_v = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w, v_embed=v_embed) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + return x_out, raw_v + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: float, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + ngram_buckets: int = 0, + ngram_heads: int = 2, + ngram_orders: int = 2, + ngram_dim_per_head: int = 32, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = EngramLite(ngram_buckets, ngram_heads, ngram_orders, ngram_dim_per_head, model_dim) if ngram_buckets > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.skip_gates = nn.Parameter(torch.zeros(self.num_skip_weights, model_dim, dtype=torch.float32)) + # Parameter banks: contiguous 3D tensors for batched optimizer + head_dim = model_dim // num_heads + kv_dim = num_kv_heads * head_dim + mlp_dim = int(mlp_mult * model_dim) + self.num_layers = num_layers + self.qo_bank = nn.Parameter(torch.empty(2 * num_layers, model_dim, model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * num_layers, kv_dim, model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(num_layers, mlp_dim, model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(num_layers, model_dim, mlp_dim)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + ) + for i in range(num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim_ve = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim_ve) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + mimetic_alpha = 0.05 + head_dim = self.qo_bank.shape[1] // (self.blocks[0].attn.num_heads) # model_dim // num_heads + num_kv_heads = self.blocks[0].attn.num_kv_heads + num_heads = self.blocks[0].attn.num_heads + group = num_heads // num_kv_heads + # Init banks: orthogonal, with proj layers scaled down and out zero-init via mimetic V-O + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) # Q + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) # K + nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) # V + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) # MLP up + nn.init.zeros_(self.mlp_down_bank.data[i]) # MLP down (zero init) + # Mimetic V-O Init for output projection: O_h = -alpha * V_h per head + v_w = self.kv_bank.data[n + i] # (kv_dim, model_dim) + o_w = torch.zeros_like(self.qo_bank.data[n + i]) # (model_dim, model_dim) + for kv_h in range(num_kv_heads): + v_block = v_w[kv_h * head_dim : (kv_h + 1) * head_dim, :] + for g_idx in range(group): + q_h = kv_h * group + g_idx + o_w[q_h * head_dim : (q_h + 1) * head_dim, :] = -mimetic_alpha * v_block + self.qo_bank.data[n + i].copy_(o_w) + # Init remaining nn.Linear modules (bigram proj, lm_head) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + def _compute_ve_base(self, input_ids: Tensor) -> Tensor | None: + """Precompute shared value embedding base once per forward pass.""" + if self.ve_shared is None: + return None + return self.ve_shared(input_ids) + + def _get_ve(self, layer_idx: int, ve_base: Tensor | None) -> Tensor | None: + """Get value embedding for a specific layer using precomputed base + per-layer scale.""" + if ve_base is None or layer_idx not in self.ve_layer_indices: + return None + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + n = self.num_layers + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_base = self._compute_ve_base(input_ids) + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, ve_base) + x, _raw_v = self.blocks[i](x, x0, + self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], + self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], + v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + skip = skips.pop() + g = torch.sigmoid(self.skip_gates[i].to(dtype=x.dtype))[None, None, :] + scaled_skip = self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skip + x = torch.lerp(scaled_skip, x, g) + ve = self._get_ve(bi, ve_base) + x, _raw_v = self.blocks[bi](x, x0, + self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], + self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], + v_embed=ve) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + n = self.num_layers + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_base = self._compute_ve_base(input_ids) + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, ve_base) + x, _raw_v = self.blocks[i](x, x0, + self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], + self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], + v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + skip = skips.pop() + g = torch.sigmoid(self.skip_gates[i].to(dtype=x.dtype))[None, None, :] + scaled_skip = self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skip + x = torch.lerp(scaled_skip, x, g) + ve = self._get_ve(bi, ve_base) + x, _raw_v = self.blocks[bi](x, x0, + self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], + self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], + v_embed=ve) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + +# --- Sliding window evaluation --- + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=True, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +# --- GPTQ-lite int6 quantization --- + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + +def _precompute_row_scales(W: Tensor, qmax: int) -> Tensor: + """Pre-compute optimal per-row scales by searching percentile clipping thresholds. + PR #753-style: find best scale once, then run GPTQ once with fixed scales.""" + t32 = W.float() + best_s = t32.abs().amax(dim=1) / float(qmax) + best_s = best_s.clamp_min(1.0 / float(qmax)) + best_err = torch.full((t32.shape[0],), float("inf"), device=t32.device) + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / float(qmax)).clamp_min(1.0 / float(qmax)) + q = torch.clamp(torch.round(t32 / s[:, None]), -qmax, qmax) + recon = q * s[:, None] + err = (t32 - recon).pow(2).mean(dim=1) + improved = err < best_err + best_s[improved] = s[improved] + best_err[improved] = err[improved] + return best_s + +def _gptq_block_sweep(W: Tensor, Hinv: Tensor, sf: Tensor, qmin: int, qmax: int, block_size: int) -> Tensor: + """Run one GPTQ block-column sweep with fixed per-row scales. Returns quantized Q.""" + rows, cols = W.shape + Q = torch.zeros_like(W, dtype=torch.int8) + W_work = W.clone() + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + count = i2 - i1 + W1 = W_work[:, i1:i2].clone() + Q1 = torch.zeros(rows, count, dtype=torch.int8, device=W.device) + Err1 = torch.zeros(rows, count, device=W.device) + Hinv1 = Hinv[i1:i2, i1:i2] + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + q = torch.clamp(torch.round(w / sf), qmin, qmax).to(torch.int8) + Q1[:, i] = q + err = (w - q.float() * sf) / d + W1[:, i:] -= err.unsqueeze(1) * Hinv1[i, i:].unsqueeze(0) + Err1[:, i] = err + Q[:, i1:i2] = Q1 + if i2 < cols: + W_work[:, i2:] -= Err1 @ Hinv[i1:i2, i2:] + return Q + +def quantize_int6_gptq(weight, hessian=None, clip_range=31, block_size=128, damp_factor=0.01, + col_order="desc", single_pass=False): + """Full GPTQ: Hessian-aware int6 quantization with Cholesky error compensation. + If hessian is None, falls back to percentile search.""" + device = hessian.device if hessian is not None else weight.device + t32 = weight.float().to(device) + if t32.ndim != 2 or hessian is None: + return _quantize_int6_percentile(t32, clip_range) + rows, cols = t32.shape + H = hessian.float().clone() + diag = torch.diag(H) + dead = diag == 0 + # Compute damp from non-dead columns only, then add to all diagonals + damp = damp_factor * (torch.mean(diag[~dead]) if not dead.all() else torch.tensor(1.0, device=H.device)) if dead.any() else damp_factor * torch.mean(diag) + diag_idx = torch.arange(cols, device=H.device) + H[diag_idx, diag_idx] += damp + # Reset dead diag to just damp (smallest value) so they sort LAST in descending actorder + H[dead, dead] = damp + perm = torch.argsort(torch.diag(H), descending=(col_order == "desc")) + inv_perm = torch.argsort(perm) + W = t32[:, perm].clone() + W[:, dead[perm]] = 0 + H = H[perm][:, perm] + try: + Hinv = torch.linalg.cholesky(H) + Hinv = torch.cholesky_inverse(Hinv) + Hinv = torch.linalg.cholesky(Hinv, upper=True) + except torch.linalg.LinAlgError: + return _quantize_int6_percentile(t32, clip_range) + # Single-pass mode: pre-compute scales once, run GPTQ once + if single_pass: + best_s = _precompute_row_scales(W, clip_range) + sf = best_s.float() + Q = _gptq_block_sweep(W, Hinv, sf, -clip_range, clip_range, block_size) + Q = Q[:, inv_perm] + return Q.cpu(), best_s.to(torch.float16).cpu() + # Multi-pass: search 5 percentiles, run GPTQ per percentile, pick best by MSE + best_q = None; best_scale = None; best_err = float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + sf = s.float() + Q = _gptq_block_sweep(W, Hinv, sf, -clip_range, clip_range, block_size) + recon = Q.float() * sf[:, None] + mse = (W - recon).pow(2).mean().item() + if mse < best_err: + best_q, best_scale, best_err = Q, s, mse + best_q = best_q[:, inv_perm] + return best_q.cpu(), best_scale.cpu() + +def _quantize_int6_percentile(t32, clip_range=31): + """Fallback: percentile search (for 1D or no-Hessian cases).""" + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + +def _unbank_state_dict(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]: + """Convert 3D bank tensors into individual 2D tensors with standard names.""" + out: dict[str, Tensor] = {} + n = num_layers + for name, tensor in sd.items(): + if name == "qo_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_q.weight"] = tensor[i] + out[f"blocks.{i}.attn.proj.weight"] = tensor[n + i] + elif name == "kv_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_k.weight"] = tensor[i] + out[f"blocks.{i}.attn.c_v.weight"] = tensor[n + i] + elif name == "mlp_up_bank": + for i in range(n): + out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] + elif name == "mlp_down_bank": + for i in range(n): + out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] + else: + out[name] = tensor + return out + +def _rebank_state_dict(sd: dict[str, Tensor], num_layers: int, template_sd: dict[str, Tensor] | None = None) -> dict[str, Tensor]: + """Convert individual 2D tensors back into 3D bank tensors.""" + out: dict[str, Tensor] = {} + n = num_layers + consumed: set[str] = set() + qo_slices: list[Tensor | None] = [None] * (2 * n) + kv_slices: list[Tensor | None] = [None] * (2 * n) + up_slices: list[Tensor | None] = [None] * n + down_slices: list[Tensor | None] = [None] * n + _BANK_MAP = { + "attn.c_q.weight": (qo_slices, 0), + "attn.proj.weight": (qo_slices, n), + "attn.c_k.weight": (kv_slices, 0), + "attn.c_v.weight": (kv_slices, n), + "mlp.fc.weight": (up_slices, 0), + "mlp.proj.weight": (down_slices, 0), + } + for i in range(n): + for suffix, (target_list, offset) in _BANK_MAP.items(): + key = f"blocks.{i}.{suffix}" + if key in sd: + target_list[offset + i] = sd[key] + consumed.add(key) + # Stack into bank tensors — validate all slices are present + for bank_name, slices in [("qo_bank", qo_slices), ("kv_bank", kv_slices), ("mlp_up_bank", up_slices), ("mlp_down_bank", down_slices)]: + if not any(s is not None for s in slices): + continue + missing = [i for i, s in enumerate(slices) if s is None] + if missing: + raise ValueError(f"_rebank_state_dict: {bank_name} missing slice indices {missing}") + out[bank_name] = torch.stack(slices) # type: ignore[arg-type] + # Pass through non-banked params + for key, val in sd.items(): + if key not in consumed: + out[key] = val + return out + +# --- Non-banked model for Hessian collection --- +# This mirrors the unbanked state dict keys: blocks.{i}.attn.c_q/c_k/c_v/proj, blocks.{i}.mlp.fc/proj + +class _HessianAttn(nn.Module): + """Non-banked attention with CastedLinear layers for Hessian hooks.""" + def __init__(self, dim, num_heads, num_kv_heads, rope_base, qk_gain_init): + super().__init__() + self.num_heads, self.num_kv_heads = num_heads, num_kv_heads + self.head_dim = dim // num_heads + kv_dim = num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False + def _xsa_efficient(self, y, v): + B, T, H, D = y.shape; Hkv = v.size(-2); group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x, v_embed=None): + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = F.scaled_dot_product_attention( + q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), + is_causal=True, enable_gqa=(self.num_kv_heads != self.num_heads), + ).transpose(1, 2) + if self.use_xsa: + y = self._xsa_efficient(y, v) + return self.proj(y.reshape(bsz, seqlen, dim)) + +class _HessianMLP(nn.Module): + """Non-banked MLP with CastedLinear layers for Hessian hooks.""" + def __init__(self, dim, mlp_mult): + super().__init__() + self.fc = CastedLinear(dim, int(mlp_mult * dim), bias=False) + self.proj = CastedLinear(int(mlp_mult * dim), dim, bias=False) + def forward(self, x): + return self.proj(F.leaky_relu(self.fc(x), negative_slope=_NEGATIVE_SLOPE).square()) + +class _HessianBlock(nn.Module): + def __init__(self, dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, layer_idx=0, ln_scale=False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = _HessianAttn(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = _HessianMLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + def forward(self, x, x0, v_embed=None): + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + return x_out + +class _HessianGPT(nn.Module): + """Non-banked GPT model matching unbanked state dict keys for Hessian collection.""" + def __init__(self, vocab_size, num_layers, model_dim, num_heads, num_kv_heads, + mlp_mult, tie_embeddings, logit_softcap, rope_base, qk_gain_init, + ngram_buckets=0, ngram_heads=2, ngram_orders=2, ngram_dim_per_head=32, + xsa_last_n=0, rope_dims=0, ln_scale=False, + ve_enabled=False, ve_dim=128, ve_layers="9,10"): + super().__init__() + self.tie_embeddings = tie_embeddings + self.logit_softcap = logit_softcap + self.num_layers = num_layers + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = EngramLite(ngram_buckets, ngram_heads, ngram_orders, ngram_dim_per_head, model_dim) if ngram_buckets > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.skip_gates = nn.Parameter(torch.zeros(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([ + _HessianBlock(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + layer_idx=i, ln_scale=ln_scale) + for i in range(num_layers) + ]) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + kv_dim = num_kv_heads * (model_dim // num_heads) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList([nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices]) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + def _get_ve(self, layer_idx, input_ids, ve_cache): + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_cache['ve'] * self.ve_layer_scales[ve_idx].to(dtype=ve_cache['ve'].dtype) + def forward(self, input_ids, target_ids): + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips = [] + ve_cache = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + skip = skips.pop() + g = torch.sigmoid(self.skip_gates[i].to(dtype=x.dtype))[None, None, :] + scaled_skip = self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skip + x = torch.lerp(scaled_skip, x, g) + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + logits_proj = F.linear(x_flat, self.tok_emb.weight) if self.tie_embeddings else self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + +def collect_hessians(hessian_model, train_loader, args, device, grad_accum_steps, num_batches=256): + """Run calibration batches through a non-banked model, collecting H = X^T X for each CastedLinear.""" + hessians = {} + hooks = [] + for name, module in hessian_model.named_modules(): + if isinstance(module, CastedLinear): + param_name = name + ".weight" + cols = module.weight.shape[1] + hessians[param_name] = torch.zeros(cols, cols, dtype=torch.float32, device=device) + def make_hook(pname): + def hook_fn(mod, inp, out): + x = inp[0].detach() + if x.ndim == 3: + x = x.reshape(-1, x.shape[-1]) + hessians[pname] += (x.T @ x).float() # bf16 matmul, fp32 accumulate + return hook_fn + h = module.register_forward_hook(make_hook(param_name)) + hooks.append(h) + hessian_model.eval() + with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16): + for _ in range(num_batches): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + hessian_model(x, y) + for h in hooks: + h.remove() + ws = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1 + for name in hessians: + H = hessians[name] + if ws > 1: + dist.all_reduce(H, op=dist.ReduceOp.SUM) + H /= (num_batches * ws) + hessians[name] = H.cpu() + hessian_model.train() + return hessians + +# --- Mixed precision bit allocation --- + +def _bits_to_range(bits: int) -> tuple[int, int]: + """Convert bit-width to (qmin, qmax).""" + return -(1 << (bits - 1)), (1 << (bits - 1)) - 1 + +# Dynamic mixed-precision constants (tune after observing estimate vs actual) +_MP_BYTES_PER_PARAM_INT5 = 0.46 # estimated compressed bytes per quantized param at int5 +_MP_COST_PER_EXTRA_BIT = 0.24 # additional compressed bytes per param per extra bit above int5 +_MP_NON_WEIGHT_COMPRESS = 0.55 # compression ratio for non-quantized tensors (fp16 embeds, scales) +_MP_PRUNE_HEADROOM_FRAC = 0.02 # reserve 2% of byte budget so selective pruning only trims a small % + + +def _allocate_bits_mixed( + hessian_map: dict[str, Tensor], + state_dict: dict[str, Tensor], + target_bytes: int = 16_000_000, + code_bytes: int = 0, +) -> tuple[dict[str, int], list[tuple[str, int, float]], dict[str, float]]: + """Dynamically allocate int5-int7 per tensor group based on Hessian sensitivity + and compressed artifact size budget. + + Greedy: promotes most-sensitive groups first (top group -> int7, rest -> int6) + until estimated compressed size approaches target_bytes minus pruning headroom. + Returns (tensor_name -> bits, [(group_key, bits, sensitivity)], estimate_info).""" + # 1. Group tensors by (layer, type) -- compute sensitivity and numel per group + group_traces: dict[str, list[float]] = {} + group_numel: dict[str, int] = {} + tensor_to_group: dict[str, str] = {} + for name, H in hessian_map.items(): + trace_val = float(torch.trace(H).item()) / H.shape[0] + if not name.startswith("blocks."): + continue + dot2 = name.index(".", 7) + layer_idx = int(name[7:dot2]) + gtype = "attn" if ".attn." in name else "mlp" if ".mlp." in name else "other" + gkey = f"layer.{layer_idx}.{gtype}" + group_traces.setdefault(gkey, []).append(trace_val) + tensor_to_group[name] = gkey + w = state_dict.get(name) + if w is not None: + group_numel[gkey] = group_numel.get(gkey, 0) + w.numel() + group_sensitivity = {k: sum(v) / len(v) for k, v in group_traces.items()} + ranked = sorted(group_sensitivity.items(), key=lambda x: x[1], reverse=True) + + # 2. Estimate baseline compressed size (all quantized weights at int5) + total_quant_numel = sum(group_numel.values()) + non_weight_raw = sum( + t.numel() * t.element_size() for name, t in state_dict.items() + if name not in hessian_map + ) + base_estimate = ( + code_bytes + + int(non_weight_raw * _MP_NON_WEIGHT_COMPRESS) + + int(total_quant_numel * _MP_BYTES_PER_PARAM_INT5) + ) + budget = int(target_bytes * (1.0 - _MP_PRUNE_HEADROOM_FRAC)) - base_estimate + + # Early exit: if base estimate already exceeds budget, return all int5 + if budget <= 0: + bit_allocation = {tname: 5 for tname in tensor_to_group} + log_entries = [(gkey, 5, group_sensitivity[gkey]) for gkey, _ in ranked] + estimate_info = { + "base_mb": base_estimate / 1e6, "promoted_mb": 0.0, + "total_mb": base_estimate / 1e6, "budget_mb": target_bytes / 1e6, + "headroom_kb": 0.0, "prune_room_bytes": target_bytes - base_estimate, + "warning": "budget_exhausted", + } + return bit_allocation, log_entries, estimate_info + + # 3. Greedy promotion: most sensitive first + # Pass 1: try int7 for top group. Pass 2: fill remaining with int6. + group_bits: dict[str, int] = {gkey: 5 for gkey, _ in ranked} + estimated_extra = 0 + if ranked: + top_gkey = ranked[0][0] + top_numel = group_numel.get(top_gkey, 0) + cost_int7 = int(top_numel * _MP_COST_PER_EXTRA_BIT * 2) + cost_int6 = int(top_numel * _MP_COST_PER_EXTRA_BIT * 1) + if top_numel > 0 and cost_int7 <= budget: + group_bits[top_gkey] = 7 + estimated_extra += cost_int7 + elif top_numel > 0 and cost_int6 <= budget: + group_bits[top_gkey] = 6 # int7 doesn't fit, fall back to int6 + estimated_extra += cost_int6 + for gkey, _sens in ranked: + if group_bits[gkey] > 5: + continue # already promoted + numel = group_numel.get(gkey, 0) + if numel == 0: + continue + cost = int(numel * _MP_COST_PER_EXTRA_BIT) # 1 extra bit for int6 + if estimated_extra + cost <= budget: + group_bits[gkey] = 6 + estimated_extra += cost + # don't break -- smaller groups later in the list may still fit + + # 4. Map back to tensor names + bit_allocation: dict[str, int] = {} + for tname, gkey in tensor_to_group.items(): + bit_allocation[tname] = group_bits[gkey] + log_entries = [(gkey, group_bits[gkey], group_sensitivity[gkey]) for gkey, _ in ranked] + + # 5. Size estimate metadata for caller to log via log0() + est_total = base_estimate + estimated_extra + headroom_bytes = int(target_bytes * _MP_PRUNE_HEADROOM_FRAC) + estimate_info = { + "base_mb": base_estimate / 1e6, + "promoted_mb": estimated_extra / 1e6, + "total_mb": est_total / 1e6, + "budget_mb": target_bytes / 1e6, + "headroom_kb": headroom_bytes / 1e3, + "prune_room_bytes": target_bytes - est_total - headroom_bytes, + } + return bit_allocation, log_entries, estimate_info + + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str], hessians: dict[str, Tensor] | None = None, + bit_allocation: dict[str, int] | None = None, gptq_damp: float = 0.01, + block_size: int = 128, col_order: str = "desc", single_pass: bool = False): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536 or name == "tok_emb.weight": + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + # Use mixed precision bit allocation if available + bits = bit_allocation.get(name, 6) if bit_allocation else 6 + qmin, qmax = _bits_to_range(bits) + cr = qmax + H = hessians.get(name) if hessians else None + if H is not None: + q, s = quantize_int6_gptq(t, hessian=H, clip_range=cr, block_size=block_size, + damp_factor=gptq_damp, col_order=col_order, single_pass=single_pass) + else: + q, s = quantize_int6_per_row(t, clip_range=cr) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": f"int{bits}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + +# --- Training --- + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ and world_size > 1 + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(True) + enable_flash_sdp(True) + enable_mem_efficient_sdp(True) + enable_math_sdp(True) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + if not args.load_snapshot: + CastedLinear._qat_enabled = False # late_qat enables mid-run when LR scale drops + CastedLinear._qat_soft_round = False + CastedLinear._qat_clip_pct = args.qat_clip_pct + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + ngram_buckets=args.ngram_buckets, + ngram_heads=args.ngram_heads, + ngram_orders=args.ngram_orders, + ngram_dim_per_head=args.ngram_dim_per_head, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + ).to(device).bfloat16() + # Banks stay FP32 (like CastedLinear weights), cast to BF16 in forward + base_model.qo_bank.data = base_model.qo_bank.data.float() + base_model.kv_bank.data = base_model.kv_bank.data.float() + base_model.mlp_up_bank.data = base_model.mlp_up_bank.data.float() + base_model.mlp_down_bank.data = base_model.mlp_down_bank.data.float() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + # No DDP -- Parallel Muon handles bank grad communication via reduce-scatter, + # and non-bank grads are manually all-reduced before Adam steps. + # Compile NS functions for ~2x speedup on the orthogonalization hot path + global zeropower_via_newtonschulz5, _post_ns_normalize + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + _post_ns_normalize = torch.compile(_post_ns_normalize) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model = compiled_model + + # Optimizer split: + # - 4 parameter banks -> Muon (batched Newton-Schulz) + # - token embedding -> Adam (with embed_beta1) + # - scalars/control tensors -> Adam + # - EngramLite proj -> Muon (small 2D matrix) + matrix_params = [ + base_model.qo_bank, base_model.kv_bank, + base_model.mlp_up_bank, base_model.mlp_down_bank, + ] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + if hasattr(base_model, 'skip_gates') and base_model.skip_gates.numel() > 0: + scalar_params.append(base_model.skip_gates) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.ngram_gate) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + ve_proj_weight = None + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + ve_proj_weight = base_model.ve_shared.proj.weight # routed to Muon below + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.embed_beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + post_norm=args.muon_post_norm, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + # Non-bank params that need manual all-reduce (replicated across GPUs) + replicated_params = list(optimizer_tok.param_groups[0]["params"]) + for pg in optimizer_tok.param_groups[1:]: + replicated_params.extend(pg["params"]) + replicated_params.extend(scalar_params) + + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + + # EngramLite proj -> Muon (small 2D matrix, not banked) + if base_model.bigram is not None and base_model.bigram.proj is not None: + optimizer_bigram_proj = Muon( + [base_model.bigram.proj.weight], + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + post_norm=args.muon_post_norm, + ) + for group in optimizer_bigram_proj.param_groups: + group["base_lr"] = args.matrix_lr + optimizers.append(optimizer_bigram_proj) + replicated_params.append(base_model.bigram.proj.weight) + + # VE proj -> Muon (2D matrix, benefits from Newton-Schulz) + if ve_proj_weight is not None: + optimizer_ve_proj = Muon( + [ve_proj_weight], + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + post_norm=args.muon_post_norm, + ) + for group in optimizer_ve_proj.param_groups: + group["base_lr"] = args.matrix_lr + optimizers.append(optimizer_ve_proj) + replicated_params.append(ve_proj_weight) + + optimizer_head = None + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.head_beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + replicated_params.append(base_model.lm_head.weight) + optimizers.append(optimizer_head) + + # Pre-build flat buffer for coalesced all-reduce of non-bank grads (saves ~0.5-1ms/step on multi-GPU) + _nb_grad_numel = [p.numel() for p in replicated_params] + _nb_grad_buf = torch.zeros(sum(_nb_grad_numel), device=device, dtype=torch.float32) if distributed else None + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=True flash=True mem_efficient=True math=True") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"muon:post_norm:{args.muon_post_norm} embed_beta1:{args.embed_beta1} head_beta1:{args.head_beta1}") + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + # Reserve time for GPTQ calibration so LR warmdown, QAT, and wallclock cap all see the effective budget + if max_wallclock_ms is not None and args.gptq_calib_batches > 0 and args.gptq_reserve_ms > 0: + max_wallclock_ms -= args.gptq_reserve_ms + log0(f"gptq:reserving {args.gptq_reserve_ms:.0f}ms from training budget (effective cap: {max_wallclock_ms:.0f}ms)") + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), args.lr_floor) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = min(args.warmdown_iters * step_ms, max_wallclock_ms) + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return max(remaining_ms / max(warmdown_ms, 1e-9), args.lr_floor) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + # All-reduce all grads for warmup (simple, not optimized) + if distributed: + for p in base_model.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + # Reset Muon shard_mom buffers — not captured by state_dict() + if optimizer_muon._built: + for m in optimizer_muon._bank_meta: + m['shard_mom'].zero_() + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + ema_state: dict[str, Tensor] | None = None + _ema_pairs: list[tuple[Tensor, Tensor]] | None = None + if args.ema_enabled: + ema_state = {name: p.data.detach().float().clone() for name, p in base_model.named_parameters()} + _ema_pairs = [(ema_state[name], p) for name, p in base_model.named_parameters()] + training_time_ms = 0.0 + _qat_start_step = 0 + _qat_total_steps = 0 + _train_loss = torch.zeros((), device=device) + _cap_tensor = torch.zeros(1, device=device, dtype=torch.int32) if distributed else None + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + if last_step: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + log0( + f"step:{step}/{args.iterations} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat and args.qat_threshold > 0 and not CastedLinear._qat_enabled and scale < args.qat_threshold: + CastedLinear._qat_enabled = True + _qat_start_step = step + # Estimate remaining QAT steps using best available info + if stop_after_step is not None: + _qat_total_steps = max(stop_after_step - step, 1) + elif max_wallclock_ms is not None: + _step_ms = elapsed_ms / max(step, 1) + _qat_total_steps = max(int((max_wallclock_ms - elapsed_ms) / max(_step_ms, 1e-9)), 1) + else: + _qat_total_steps = max(args.iterations - step, 1) + if args.soft_round_qat: + CastedLinear._qat_soft_round = True + CastedLinear._qat_soft_alpha = torch.tensor(1.0, device=device) + log0(f"late_qat:enabled step:{step} scale:{scale:.4f} soft_round={args.soft_round_qat} est_steps:{_qat_total_steps}") + elif CastedLinear._qat_soft_round: + # Soft-round alpha ramp: 1→16 over QAT phase (tensor value, no torch.compile recompiles) + # Re-estimate if wallclock cap set after QAT started (tighter bound) + if stop_after_step is not None: + _qat_total_steps = max(stop_after_step - _qat_start_step, 1) + qat_progress = min((step - _qat_start_step) / _qat_total_steps, 1.0) + with torch.no_grad(): + CastedLinear._qat_soft_alpha.fill_(1.0 + 15.0 * qat_progress) + zero_grad_all() + _train_loss.zero_() + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + _train_loss += loss.detach() + (loss * grad_scale).backward() + _train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + # === 3-phase overlapped optimizer step === + # Phase 1: Launch async reduce-scatter for banks (biggest first) + optimizer_muon.launch_reduce_scatters() + # Phase 2: All-reduce non-bank grads (coalesced) + step Adam (while bank RS is in-flight) + if distributed: + offset = 0 + for p, n in zip(replicated_params, _nb_grad_numel): + if p.grad is not None: + _nb_grad_buf[offset:offset + n].copy_(p.grad.reshape(-1)) + else: + _nb_grad_buf[offset:offset + n].zero_() + offset += n + dist.all_reduce(_nb_grad_buf, op=dist.ReduceOp.AVG) + offset = 0 + for p, n in zip(replicated_params, _nb_grad_numel): + if p.grad is not None: + p.grad.copy_(_nb_grad_buf[offset:offset + n].reshape_as(p.grad)) + offset += n + optimizer_tok.step() + optimizer_scalar.step() + # Step non-bank Muon optimizers (e.g., bigram proj) + for opt in optimizers: + if opt is not optimizer_muon and opt is not optimizer_tok and opt is not optimizer_scalar: + if optimizer_head is not None and opt is optimizer_head: + opt.step() + elif isinstance(opt, Muon): + opt.step() + # Phase 3: Wait for RS, local NS5, all-gather (banks processed last) + optimizer_muon.step() + zero_grad_all() + # EMA update (skip .float() for already-fp32 bank params) + if _ema_pairs is not None: + _ema_weight = 1.0 - args.ema_decay + with torch.no_grad(): + for ema_buf, p in _ema_pairs: + ema_buf.lerp_( + p.data if p.data.dtype == torch.float32 else p.data.float(), + _ema_weight, + ) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < args.swa_threshold and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: p.data.detach().float().clone() for name, p in base_model.named_parameters()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, p in base_model.named_parameters(): + swa_state[name].add_(p.data if p.data.dtype == torch.float32 else p.data.float()) + swa_count += 1 + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{_train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + _cap_tensor.fill_(int(reached_cap)) + dist.all_reduce(_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + # Apply weight averaging: SWA if available, else EMA + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints (source=raw)") + with torch.no_grad(): + for name, p in base_model.named_parameters(): + if name in swa_state: + p.data.copy_((swa_state[name] / swa_count).to(dtype=p.dtype)) + del swa_state + elif ema_state is not None: + log0("ema:applying EMA weights") + with torch.no_grad(): + for name, p in base_model.named_parameters(): + if name in ema_state: + p.data.copy_(ema_state[name].to(dtype=p.dtype)) + del ema_state + full_state_dict = base_model.state_dict() + export_sd = full_state_dict + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + # Unbank 3D tensors into individual 2D tensors for quantization + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + unbanked_sd = _unbank_state_dict(sd_cpu, args.num_layers) + + # Full GPTQ: collect Hessians via a temporary non-banked model + _gptq_t0 = time.perf_counter() + log0(f"gptq:building non-banked model for Hessian collection...") + hessian_model = _HessianGPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, logit_softcap=args.logit_softcap, + rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + ngram_buckets=args.ngram_buckets, ngram_heads=args.ngram_heads, + ngram_orders=args.ngram_orders, ngram_dim_per_head=args.ngram_dim_per_head, + xsa_last_n=args.xsa_last_n, rope_dims=args.rope_dims, ln_scale=args.ln_scale, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + ).to(device).bfloat16() + for m in hessian_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(hessian_model) + # Load unbanked weights into the non-banked model + hessian_model.load_state_dict( + {k: v.to(device) for k, v in unbanked_sd.items() if k in hessian_model.state_dict()}, + strict=False, + ) + log0(f"gptq:calibrating with {args.gptq_calib_batches} batches...") + calib_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + hessians = collect_hessians(hessian_model, calib_loader, args, device, grad_accum_steps, + num_batches=args.gptq_calib_batches) + log0(f"gptq:collected hessians for {len(hessians)} layers") + _gptq_ms = 1000.0 * (time.perf_counter() - _gptq_t0) + _full_budget_ms = (max_wallclock_ms + args.gptq_reserve_ms) if max_wallclock_ms is not None else 0 + log0(f"gptq:budget_check train:{approx_training_time_ms:.0f}ms + gptq:{_gptq_ms:.0f}ms = {approx_training_time_ms + _gptq_ms:.0f}ms (budget:{_full_budget_ms:.0f}ms)") + del hessian_model + torch.cuda.empty_cache() + + # Snapshot: save unbanked_sd + hessians + sd_cpu so compression can be re-run without retraining + if args.snapshot_post_hessian: + if master_process: + snap_path = os.environ.get("SNAPSHOT_PATH", "snapshot_post_hessian.pt") + log0(f"snapshot:saving to {snap_path}...") + torch.save({"unbanked_sd": unbanked_sd, "hessians": {k: v.cpu() for k, v in hessians.items()}, "sd_cpu": sd_cpu}, snap_path) + snap_mb = os.path.getsize(snap_path) / 1e6 + log0(f"snapshot:saved {snap_mb:.1f}MB — exiting (use LOAD_SNAPSHOT={snap_path} to resume compression)") + if distributed: + dist.barrier() # all ranks wait for rank 0 to finish saving + dist.destroy_process_group() + return + + # --- Snapshot restore: load pre-computed state, skip training --- + if args.load_snapshot: + log0(f"snapshot:loading from {args.load_snapshot}...") + _snap = torch.load(args.load_snapshot, map_location="cpu", weights_only=True) + unbanked_sd = _snap["unbanked_sd"] + hessians = {k: v.to(device) for k, v in _snap["hessians"].items()} + sd_cpu = _snap["sd_cpu"] + del _snap + log0(f"snapshot:restored {len(unbanked_sd)} unbanked params, {len(hessians)} hessians") + + # Mixed precision bit allocation + mp_bit_allocation: dict[str, int] | None = None + if args.mixed_precision and hessians: + _mp_code_bytes = len(code.encode("utf-8")) + mp_bit_allocation, mp_log, mp_est = _allocate_bits_mixed( + hessians, unbanked_sd, target_bytes=args.target_bytes_limit, code_bytes=_mp_code_bytes, + ) + log0( + f"mixed_precision:estimate base={mp_est['base_mb']:.2f}MB + promoted={mp_est['promoted_mb']:.2f}MB " + f"= {mp_est['total_mb']:.2f}MB (budget={mp_est['budget_mb']:.1f}MB, " + f"headroom={mp_est['headroom_kb']:.0f}KB, prune_room={mp_est['prune_room_bytes']:+.0f}B)" + ) + promoted = sum(1 for _, b, _ in mp_log if b > 5) + for gkey, bits, sens in mp_log: + log0(f"mixed_precision: {gkey} -> int{bits} (sensitivity={sens:.4e})") + counts: dict[int, int] = {} + for b in mp_bit_allocation.values(): + counts[b] = counts.get(b, 0) + 1 + log0(f"mixed_precision: {' '.join(f'int{b}:{n}' for b, n in sorted(counts.items()))} ({promoted} groups promoted)") + + quant_result, quant_meta = mixed_quantize_int6(unbanked_sd, {"mlp", "attn"}, hessians=hessians, + bit_allocation=mp_bit_allocation, gptq_damp=args.gptq_damp, + block_size=args.gptq_block_size, col_order=args.gptq_col_order, + single_pass=args.gptq_single_pass) + # NOVEL: Selective +/-1 and +/-2 pruning by reconstruction error + target_bytes = args.target_bytes_limit + code_bytes_est = len(code.encode("utf-8")) + ones_info = [] # (tensor_key, flat_idx, error) + for name, info in quant_meta.items(): + if not isinstance(info, dict): continue + qk, sk = name + ".q", name + ".scale" + if qk not in quant_result or sk not in quant_result: continue + q, s = quant_result[qk], quant_result[sk] + if s.ndim > 0: + # Extended pruning: both +/-1 and +/-2 values + mask = (q.abs() <= 2) & (q.abs() > 0) + if mask.any(): + row_idx = torch.arange(q.shape[0]).unsqueeze(1).expand_as(q)[mask] + flat_idx = torch.arange(q.numel()).reshape(q.shape)[mask] + abs_vals = q.abs()[mask].float() + errors = s.float()[row_idx].pow(2) * abs_vals.pow(2) + for fi, err in zip(flat_idx.tolist(), errors.tolist()): + ones_info.append((qk, fi, err)) + if ones_info: + ones_info.sort(key=lambda x: x[2]) + + # Pre-group by tensor for vectorized pruning: O(len(ones_info)) once + _pg_pos: dict[str, list[int]] = {} + _pg_idx: dict[str, list[int]] = {} + for global_pos, (tname, fidx, _err) in enumerate(ones_info): + if tname in _pg_pos: + _pg_pos[tname].append(global_pos) + _pg_idx[tname].append(fidx) + else: + _pg_pos[tname] = [global_pos] + _pg_idx[tname] = [fidx] + prune_groups: dict[str, tuple[Tensor, Tensor]] = {} + for tname in _pg_pos: + prune_groups[tname] = ( + torch.tensor(_pg_pos[tname], dtype=torch.long), + torch.tensor(_pg_idx[tname], dtype=torch.long), + ) + del _pg_pos, _pg_idx + + def _compress_quant(qr, qm, fast=False): + buf = io.BytesIO() + torch.save({"w": qr, "m": qm}, buf) + raw = buf.getvalue() + if _BYTE_SHUFFLE: + raw = _byte_shuffle(raw, _BYTE_SHUFFLE_STRIDE) + if fast: + blob = zlib.compress(raw, 1) + elif _COMPRESSOR == "brotli": + import brotli + blob = brotli.compress(raw, quality=11) + elif _COMPRESSOR == "lzma": + blob = lzma.compress(raw, preset=9) + else: + blob = zlib.compress(raw, 9) + return len(blob) + code_bytes_est + + def _try_prune(n, fast=False): + """Trial-prune n entries using vectorized scatter, return compressed size.""" + n = min(n, len(ones_info)) + # Only clone tensors that will be modified + tmp = dict(quant_result) # shallow copy — shares unmodified tensors + for tname, (positions, flat_idxs) in prune_groups.items(): + count = int(torch.searchsorted(positions, n).item()) + if count > 0: + tmp[tname] = quant_result[tname].clone() + tmp[tname].view(-1)[flat_idxs[:count]] = 0 + return _compress_quant(tmp, quant_meta, fast=fast) + + def _apply_prune_inplace(n): + """Apply pruning in-place to quant_result.""" + n = min(n, len(ones_info)) + for tname, (positions, flat_idxs) in prune_groups.items(): + count = int(torch.searchsorted(positions, n).item()) + if count > 0: + quant_result[tname].view(-1)[flat_idxs[:count]] = 0 + + no_sz = _try_prune(0) + log0(f"selective_prune: {len(ones_info)} +/-1,+/-2 candidates, unpruned={no_sz/(1024*1024):.2f}MB target={target_bytes/1e6:.2f}MB") + if no_sz <= target_bytes: + log0("selective_prune: already fits, no pruning needed") + else: + # Calibrate fast-vs-real compressor ratio + fast_unpruned = _try_prune(0, fast=True) + fast_full = _try_prune(len(ones_info), fast=True) + real_full = _try_prune(len(ones_info)) + log0(f"selective_prune: full prune={real_full/(1024*1024):.2f}MB") + if real_full > target_bytes: + log0("selective_prune: even full prune not enough, applying all") + _apply_prune_inplace(len(ones_info)) + else: + fast_delta = fast_unpruned - fast_full + real_delta = no_sz - real_full + ratio = real_delta / max(fast_delta, 1) + fast_target = fast_unpruned - int((no_sz - target_bytes) / max(ratio, 0.01)) + log0(f"selective_prune: fast/real ratio={ratio:.3f} fast_target={fast_target/(1024*1024):.2f}MB") + # Binary search using fast compressor (~0.5s/probe vs ~30-60s) + lo, hi = 0, len(ones_info) + while lo < hi: + mid = (lo + hi) // 2 + sz = _try_prune(mid, fast=True) + if sz <= fast_target: hi = mid + else: lo = mid + 1 + # Verify with real compressor and adjust if needed + real_sz = _try_prune(lo) + if real_sz > target_bytes: + while lo < len(ones_info) and real_sz > target_bytes: + lo += max(1, len(ones_info) // 200) + lo = min(lo, len(ones_info)) + real_sz = _try_prune(lo) + log0(f"selective_prune: pruning {lo}/{len(ones_info)} values ({100*lo/max(len(ones_info),1):.1f}%) to fit {target_bytes/1e6:.2f}MB") + _apply_prune_inplace(lo) + + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _BYTE_SHUFFLE: + quant_raw = _byte_shuffle(quant_raw, _BYTE_SHUFFLE_STRIDE) + if _COMPRESSOR == "brotli": + import brotli + quant_blob = brotli.compress(quant_raw, quality=11) + elif _COMPRESSOR == "lzma": + quant_blob = lzma.compress(quant_raw, preset=9) + else: + quant_blob = zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + # Decompress with cascade + if _COMPRESSOR == "brotli": + import brotli + quant_decompressed = brotli.decompress(quant_blob_disk) + elif _COMPRESSOR == "lzma": + quant_decompressed = lzma.decompress(quant_blob_disk) + else: + quant_decompressed = zlib.decompress(quant_blob_disk) + quant_decompressed = _byte_unshuffle(quant_decompressed) + quant_state = torch.load( + io.BytesIO(quant_decompressed), + map_location="cpu", + weights_only=False, + ) + deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], unbanked_sd) + # Re-bank the dequantized tensors + deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + ngram_buckets=args.ngram_buckets, ngram_heads=args.ngram_heads, + ngram_orders=args.ngram_orders, ngram_dim_per_head=args.ngram_dim_per_head, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, ln_scale=args.ln_scale, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + ).to(device).bfloat16() + eval_model.qo_bank.data = eval_model.qo_bank.data.float() + eval_model.kv_bank.data = eval_model.kv_bank.data.float() + eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float() + eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-30_V18_FusedTritonOp/train_seed1337.log b/records/track_10min_16mb/2026-03-30_V18_FusedTritonOp/train_seed1337.log new file mode 100644 index 0000000000..4d518c58eb --- /dev/null +++ b/records/track_10min_16mb/2026-03-30_V18_FusedTritonOp/train_seed1337.log @@ -0,0 +1,105 @@ +W0331 02:27:59.206000 40330 torch/distributed/run.py:803] +W0331 02:27:59.206000 40330 torch/distributed/run.py:803] ***************************************** +W0331 02:27:59.206000 40330 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0331 02:27:59.206000 40330 torch/distributed/run.py:803] ***************************************** +logs/cba3bb5f-5b3e-4bc2-a4e5-27c2531104bc.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:30666843 +XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=True flash=True mem_efficient=True math=True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +muon:post_norm:row_col embed_beta1:0.7 head_beta1:0.7 +seed:1337 +gptq:reserving 9000ms from training budget (effective cap: 591000ms) +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:1/20000 train_loss:6.9328 train_time:734ms step_avg:733.61ms +step:2/20000 train_loss:8.3873 train_time:773ms step_avg:386.54ms +step:3/20000 train_loss:7.3414 train_time:874ms step_avg:291.27ms +step:4/20000 train_loss:8.1605 train_time:975ms step_avg:243.73ms +step:5/20000 train_loss:8.3465 train_time:1077ms step_avg:215.38ms +step:6/20000 train_loss:7.8704 train_time:1178ms step_avg:196.38ms +step:7/20000 train_loss:7.1267 train_time:1280ms step_avg:182.91ms +step:8/20000 train_loss:6.6741 train_time:1381ms step_avg:172.65ms +step:9/20000 train_loss:6.2596 train_time:1482ms step_avg:164.72ms +step:10/20000 train_loss:5.7769 train_time:1584ms step_avg:158.38ms +step:500/20000 train_loss:2.3386 train_time:52355ms step_avg:104.71ms +step:1000/20000 train_loss:2.2241 train_time:104586ms step_avg:104.59ms +step:1500/20000 train_loss:2.1770 train_time:156778ms step_avg:104.52ms +step:2000/20000 train_loss:2.0193 train_time:208951ms step_avg:104.48ms +step:2500/20000 train_loss:2.1178 train_time:261082ms step_avg:104.43ms +step:3000/20000 train_loss:2.0981 train_time:313174ms step_avg:104.39ms +step:3500/20000 train_loss:2.1090 train_time:365282ms step_avg:104.37ms +step:4000/20000 train_loss:1.8988 train_time:417368ms step_avg:104.34ms +step:4500/20000 train_loss:2.0416 train_time:469468ms step_avg:104.33ms +swa:start step:5000 +step:5000/20000 train_loss:2.0166 train_time:521551ms step_avg:104.31ms +late_qat:enabled step:5141 scale:0.1498 soft_round=True est_steps:524 +step:5500/20000 train_loss:1.9267 train_time:573618ms step_avg:104.29ms +step:5668/20000 train_time:591136ms step_avg:104.29ms +stopping_early: wallclock_cap train_time:591136ms step:5668/20000 +peak memory allocated: 24764 MiB reserved: 24776 MiB +swa:applying averaged 14 checkpoints (source=raw) +Serialized model: 119281399 bytes +Code size: 130434 bytes +gptq:building non-banked model for Hessian collection... +gptq:calibrating with 64 batches... +gptq:collected hessians for 68 layers +gptq:budget_check train:591070ms + gptq:7470ms = 598540ms (budget:600000ms) +mixed_precision:estimate base=15.33MB + promoted=0.19MB = 15.52MB (budget=16.0MB, headroom=320KB, prune_room=+158729B) +mixed_precision: layer.0.mlp -> int5 (sensitivity=2.3771e+07) +mixed_precision: layer.1.mlp -> int5 (sensitivity=3.4570e+06) +mixed_precision: layer.2.mlp -> int5 (sensitivity=1.2335e+06) +mixed_precision: layer.3.mlp -> int5 (sensitivity=7.3738e+05) +mixed_precision: layer.4.mlp -> int5 (sensitivity=4.3265e+05) +mixed_precision: layer.6.mlp -> int5 (sensitivity=3.5973e+05) +mixed_precision: layer.5.mlp -> int5 (sensitivity=3.1436e+05) +mixed_precision: layer.7.mlp -> int5 (sensitivity=2.1511e+05) +mixed_precision: layer.0.attn -> int6 (sensitivity=1.5172e+05) +mixed_precision: layer.1.attn -> int5 (sensitivity=1.4245e+05) +mixed_precision: layer.2.attn -> int5 (sensitivity=1.1758e+05) +mixed_precision: layer.4.attn -> int5 (sensitivity=9.2693e+04) +mixed_precision: layer.3.attn -> int5 (sensitivity=8.0671e+04) +mixed_precision: layer.8.mlp -> int5 (sensitivity=6.3116e+04) +mixed_precision: layer.5.attn -> int5 (sensitivity=4.8939e+04) +mixed_precision: layer.6.attn -> int5 (sensitivity=4.1592e+04) +mixed_precision: layer.7.attn -> int5 (sensitivity=3.7880e+04) +mixed_precision: layer.8.attn -> int5 (sensitivity=2.8107e+04) +mixed_precision: layer.9.mlp -> int5 (sensitivity=2.5858e+04) +mixed_precision: layer.10.mlp -> int5 (sensitivity=2.4475e+04) +mixed_precision: layer.9.attn -> int5 (sensitivity=2.0457e+04) +mixed_precision: layer.10.attn -> int5 (sensitivity=1.4297e+04) +mixed_precision: int5:62 int6:4 (1 groups promoted) +selective_prune: 12317356 +/-1,+/-2 candidates, unpruned=15.27MB target=16.00MB +selective_prune: full prune=10.93MB +selective_prune: fast/real ratio=1.287 fast_target=17.88MB +selective_prune: pruning 103549/12317356 values (0.8%) to fit 16.00MB +Serialized model int6+brotli: 15851422 bytes +Total submission size int6+brotli: 15981856 bytes +final_int6_roundtrip val_loss:1.9187 val_bpb:1.1364 eval_time:49629ms +final_int6_roundtrip_exact val_loss:1.91873346 val_bpb:1.13638203 +final_int6_sliding_window val_loss:1.8786 val_bpb:1.1126 stride:64 eval_time:121946ms +final_int6_sliding_window_exact val_loss:1.87856613 val_bpb:1.11259562 diff --git a/records/track_10min_16mb/2026-03-30_V18_FusedTritonOp/train_seed42.log b/records/track_10min_16mb/2026-03-30_V18_FusedTritonOp/train_seed42.log new file mode 100644 index 0000000000..e60ac0b572 --- /dev/null +++ b/records/track_10min_16mb/2026-03-30_V18_FusedTritonOp/train_seed42.log @@ -0,0 +1,105 @@ +W0331 02:49:51.609000 96139 torch/distributed/run.py:803] +W0331 02:49:51.609000 96139 torch/distributed/run.py:803] ***************************************** +W0331 02:49:51.609000 96139 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0331 02:49:51.609000 96139 torch/distributed/run.py:803] ***************************************** +logs/4e4b46db-1cd1-45c6-96e3-62daacb8f3cb.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:30666843 +XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=True flash=True mem_efficient=True math=True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +muon:post_norm:row_col embed_beta1:0.7 head_beta1:0.7 +seed:42 +gptq:reserving 9000ms from training budget (effective cap: 591000ms) +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:1/20000 train_loss:6.9326 train_time:662ms step_avg:661.53ms +step:2/20000 train_loss:8.3597 train_time:702ms step_avg:350.84ms +step:3/20000 train_loss:7.2960 train_time:802ms step_avg:267.26ms +step:4/20000 train_loss:8.1308 train_time:903ms step_avg:225.87ms +step:5/20000 train_loss:8.3119 train_time:1005ms step_avg:200.97ms +step:6/20000 train_loss:7.7904 train_time:1107ms step_avg:184.42ms +step:7/20000 train_loss:7.1321 train_time:1207ms step_avg:172.50ms +step:8/20000 train_loss:6.5198 train_time:1309ms step_avg:163.64ms +step:9/20000 train_loss:6.1464 train_time:1410ms step_avg:156.65ms +step:10/20000 train_loss:5.7179 train_time:1512ms step_avg:151.15ms +step:500/20000 train_loss:2.3255 train_time:52134ms step_avg:104.27ms +step:1000/20000 train_loss:2.2197 train_time:104310ms step_avg:104.31ms +step:1500/20000 train_loss:2.1716 train_time:156476ms step_avg:104.32ms +step:2000/20000 train_loss:2.0175 train_time:208636ms step_avg:104.32ms +step:2500/20000 train_loss:2.1173 train_time:260749ms step_avg:104.30ms +step:3000/20000 train_loss:2.0978 train_time:312837ms step_avg:104.28ms +step:3500/20000 train_loss:2.1095 train_time:364875ms step_avg:104.25ms +step:4000/20000 train_loss:1.8982 train_time:416870ms step_avg:104.22ms +step:4500/20000 train_loss:2.0404 train_time:468855ms step_avg:104.19ms +swa:start step:5000 +step:5000/20000 train_loss:2.0163 train_time:520829ms step_avg:104.17ms +late_qat:enabled step:5149 scale:0.1498 soft_round=True est_steps:524 +step:5500/20000 train_loss:1.9274 train_time:572800ms step_avg:104.15ms +step:5675/20000 train_time:591101ms step_avg:104.16ms +stopping_early: wallclock_cap train_time:591101ms step:5675/20000 +peak memory allocated: 24756 MiB reserved: 24854 MiB +swa:applying averaged 14 checkpoints (source=raw) +Serialized model: 119281399 bytes +Code size: 130434 bytes +gptq:building non-banked model for Hessian collection... +gptq:calibrating with 64 batches... +gptq:collected hessians for 68 layers +gptq:budget_check train:591034ms + gptq:7671ms = 598705ms (budget:600000ms) +mixed_precision:estimate base=15.33MB + promoted=0.19MB = 15.52MB (budget=16.0MB, headroom=320KB, prune_room=+158729B) +mixed_precision: layer.0.mlp -> int5 (sensitivity=2.6946e+07) +mixed_precision: layer.1.mlp -> int5 (sensitivity=2.3823e+06) +mixed_precision: layer.2.mlp -> int5 (sensitivity=1.9261e+06) +mixed_precision: layer.3.mlp -> int5 (sensitivity=5.6601e+05) +mixed_precision: layer.4.mlp -> int5 (sensitivity=4.8485e+05) +mixed_precision: layer.6.mlp -> int5 (sensitivity=3.7520e+05) +mixed_precision: layer.5.mlp -> int5 (sensitivity=3.4501e+05) +mixed_precision: layer.1.attn -> int6 (sensitivity=1.9168e+05) +mixed_precision: layer.0.attn -> int5 (sensitivity=1.4614e+05) +mixed_precision: layer.7.mlp -> int5 (sensitivity=1.4134e+05) +mixed_precision: layer.3.attn -> int5 (sensitivity=1.2848e+05) +mixed_precision: layer.2.attn -> int5 (sensitivity=8.4035e+04) +mixed_precision: layer.8.mlp -> int5 (sensitivity=6.5881e+04) +mixed_precision: layer.5.attn -> int5 (sensitivity=5.1862e+04) +mixed_precision: layer.6.attn -> int5 (sensitivity=5.0278e+04) +mixed_precision: layer.4.attn -> int5 (sensitivity=5.0249e+04) +mixed_precision: layer.7.attn -> int5 (sensitivity=3.5763e+04) +mixed_precision: layer.9.mlp -> int5 (sensitivity=2.5019e+04) +mixed_precision: layer.10.mlp -> int5 (sensitivity=2.2116e+04) +mixed_precision: layer.9.attn -> int5 (sensitivity=2.1195e+04) +mixed_precision: layer.8.attn -> int5 (sensitivity=1.7439e+04) +mixed_precision: layer.10.attn -> int5 (sensitivity=1.2450e+04) +mixed_precision: int5:62 int6:4 (1 groups promoted) +selective_prune: 12299580 +/-1,+/-2 candidates, unpruned=15.28MB target=16.00MB +selective_prune: full prune=10.95MB +selective_prune: fast/real ratio=1.286 fast_target=17.87MB +selective_prune: pruning 124998/12299580 values (1.0%) to fit 16.00MB +Serialized model int6+brotli: 15853915 bytes +Total submission size int6+brotli: 15984349 bytes +final_int6_roundtrip val_loss:1.9183 val_bpb:1.1361 eval_time:7685ms +final_int6_roundtrip_exact val_loss:1.91829847 val_bpb:1.13612441 +final_int6_sliding_window val_loss:1.8780 val_bpb:1.1123 stride:64 eval_time:98372ms +final_int6_sliding_window_exact val_loss:1.87802743 val_bpb:1.11227657 diff --git a/records/track_10min_16mb/2026-03-30_V18_FusedTritonOp/train_seed999.log b/records/track_10min_16mb/2026-03-30_V18_FusedTritonOp/train_seed999.log new file mode 100644 index 0000000000..780dc3b707 --- /dev/null +++ b/records/track_10min_16mb/2026-03-30_V18_FusedTritonOp/train_seed999.log @@ -0,0 +1,105 @@ +W0331 03:07:47.307000 97232 torch/distributed/run.py:803] +W0331 03:07:47.307000 97232 torch/distributed/run.py:803] ***************************************** +W0331 03:07:47.307000 97232 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0331 03:07:47.307000 97232 torch/distributed/run.py:803] ***************************************** +logs/a736232f-96c1-46af-9453-31869b3bd615.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:30666843 +XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=True flash=True mem_efficient=True math=True +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +muon:post_norm:row_col embed_beta1:0.7 head_beta1:0.7 +seed:999 +gptq:reserving 9000ms from training budget (effective cap: 591000ms) +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:1/20000 train_loss:6.9309 train_time:652ms step_avg:652.43ms +step:2/20000 train_loss:8.4451 train_time:692ms step_avg:346.16ms +step:3/20000 train_loss:7.2380 train_time:793ms step_avg:264.39ms +step:4/20000 train_loss:8.3873 train_time:895ms step_avg:223.75ms +step:5/20000 train_loss:8.4942 train_time:997ms step_avg:199.36ms +step:6/20000 train_loss:7.8448 train_time:1100ms step_avg:183.29ms +step:7/20000 train_loss:7.1675 train_time:1203ms step_avg:171.79ms +step:8/20000 train_loss:6.6310 train_time:1303ms step_avg:162.92ms +step:9/20000 train_loss:6.2091 train_time:1405ms step_avg:156.10ms +step:10/20000 train_loss:5.7519 train_time:1506ms step_avg:150.60ms +step:500/20000 train_loss:2.3331 train_time:52256ms step_avg:104.51ms +step:1000/20000 train_loss:2.2237 train_time:104447ms step_avg:104.45ms +step:1500/20000 train_loss:2.1744 train_time:156696ms step_avg:104.46ms +step:2000/20000 train_loss:2.0222 train_time:208829ms step_avg:104.41ms +step:2500/20000 train_loss:2.1198 train_time:260951ms step_avg:104.38ms +step:3000/20000 train_loss:2.1015 train_time:313062ms step_avg:104.35ms +step:3500/20000 train_loss:2.1102 train_time:365127ms step_avg:104.32ms +step:4000/20000 train_loss:1.8988 train_time:417173ms step_avg:104.29ms +step:4500/20000 train_loss:2.0438 train_time:469180ms step_avg:104.26ms +swa:start step:5000 +step:5000/20000 train_loss:2.0165 train_time:521172ms step_avg:104.23ms +late_qat:enabled step:5145 scale:0.1499 soft_round=True est_steps:524 +step:5500/20000 train_loss:1.9289 train_time:573152ms step_avg:104.21ms +step:5672/20000 train_time:591082ms step_avg:104.21ms +stopping_early: wallclock_cap train_time:591082ms step:5672/20000 +peak memory allocated: 24756 MiB reserved: 24854 MiB +swa:applying averaged 14 checkpoints (source=raw) +Serialized model: 119281399 bytes +Code size: 130434 bytes +gptq:building non-banked model for Hessian collection... +gptq:calibrating with 64 batches... +gptq:collected hessians for 68 layers +gptq:budget_check train:591015ms + gptq:7682ms = 598697ms (budget:600000ms) +mixed_precision:estimate base=15.33MB + promoted=0.19MB = 15.52MB (budget=16.0MB, headroom=320KB, prune_room=+158729B) +mixed_precision: layer.0.mlp -> int5 (sensitivity=3.3350e+07) +mixed_precision: layer.1.mlp -> int5 (sensitivity=2.3080e+06) +mixed_precision: layer.2.mlp -> int5 (sensitivity=1.7590e+06) +mixed_precision: layer.3.mlp -> int5 (sensitivity=6.5708e+05) +mixed_precision: layer.4.mlp -> int5 (sensitivity=3.8841e+05) +mixed_precision: layer.5.mlp -> int5 (sensitivity=3.5902e+05) +mixed_precision: layer.6.mlp -> int5 (sensitivity=3.5574e+05) +mixed_precision: layer.1.attn -> int6 (sensitivity=1.6027e+05) +mixed_precision: layer.0.attn -> int5 (sensitivity=1.5045e+05) +mixed_precision: layer.7.mlp -> int5 (sensitivity=1.2484e+05) +mixed_precision: layer.3.attn -> int5 (sensitivity=8.4910e+04) +mixed_precision: layer.2.attn -> int5 (sensitivity=8.2695e+04) +mixed_precision: layer.4.attn -> int5 (sensitivity=8.1111e+04) +mixed_precision: layer.5.attn -> int5 (sensitivity=5.9694e+04) +mixed_precision: layer.8.mlp -> int5 (sensitivity=5.5861e+04) +mixed_precision: layer.7.attn -> int5 (sensitivity=4.6659e+04) +mixed_precision: layer.6.attn -> int5 (sensitivity=3.8782e+04) +mixed_precision: layer.9.mlp -> int5 (sensitivity=3.0594e+04) +mixed_precision: layer.10.mlp -> int5 (sensitivity=2.2785e+04) +mixed_precision: layer.8.attn -> int5 (sensitivity=1.6645e+04) +mixed_precision: layer.10.attn -> int5 (sensitivity=1.5905e+04) +mixed_precision: layer.9.attn -> int5 (sensitivity=1.5543e+04) +mixed_precision: int5:62 int6:4 (1 groups promoted) +selective_prune: 12311794 +/-1,+/-2 candidates, unpruned=15.28MB target=16.00MB +selective_prune: full prune=10.93MB +selective_prune: fast/real ratio=1.286 fast_target=17.87MB +selective_prune: pruning 105390/12311794 values (0.9%) to fit 16.00MB +Serialized model int6+brotli: 15855478 bytes +Total submission size int6+brotli: 15985912 bytes +final_int6_roundtrip val_loss:1.9189 val_bpb:1.1365 eval_time:7536ms +final_int6_roundtrip_exact val_loss:1.91889254 val_bpb:1.13647625 +final_int6_sliding_window val_loss:1.8790 val_bpb:1.1129 stride:64 eval_time:98091ms +final_int6_sliding_window_exact val_loss:1.87899921 val_bpb:1.11285212