From 11f5159b65b48d98ffc95c67429d3bc9af6de644 Mon Sep 17 00:00:00 2001 From: Anthony Date: Thu, 26 Mar 2026 15:12:16 -0400 Subject: [PATCH 1/2] =?UTF-8?q?Record:=20N-gram=20Backoff=20+=20VRL=20+=20?= =?UTF-8?q?LeakyReLU=C2=B2=20=E2=80=94=20val=5Fbpb=200.9642=20(3-seed)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Sub-1.0 bpb via multi-order n-gram backoff (2-7gram) with entropy-adaptive alpha mixing. 3-seed mean 0.9642, std 0.0002. All artifacts under 16MB. Co-Authored-By: Claude Opus 4.6 (1M context) --- .private/council_brief_mar25_evening.md | 100 + .private/council_brief_mar25_night.md | 87 + .private/substack_day6_draft.md | 141 ++ .../README.md | 74 + .../submission.json | 14 + .../train_gpt.py | 1586 ++++++++++++++ .../train_seed1337.log | 1876 +++++++++++++++++ .../train_seed2025.log | 1876 +++++++++++++++++ .../train_seed42.log | 1876 +++++++++++++++++ 9 files changed, 7630 insertions(+) create mode 100644 .private/council_brief_mar25_evening.md create mode 100644 .private/council_brief_mar25_night.md create mode 100644 .private/substack_day6_draft.md create mode 100644 records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/README.md create mode 100644 records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/submission.json create mode 100644 records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/train_gpt.py create mode 100644 records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/train_seed1337.log create mode 100644 records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/train_seed2025.log create mode 100644 records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/train_seed42.log diff --git a/.private/council_brief_mar25_evening.md b/.private/council_brief_mar25_evening.md new file mode 100644 index 000000000..ea1c5e140 --- /dev/null +++ b/.private/council_brief_mar25_evening.md @@ -0,0 +1,100 @@ +# Parameter Golf Council Brief — March 25, 2026 (Evening) + 
+## Situation Update + +We have a full run in progress on 8xH100 with the complete new stack: +- LeakyReLU(0.5)² + VRL + Gated Attention + BigramHash 3072 + CROWN-Q + lzma +- AdamW TTT (PR #688 recipe: lr=1e-4, 9 frozen blocks, Polyak averaging, cosine LR) +- FA3 Hopper (installed via pre-built wheel in 30 seconds!) + +Benchmark shows ~106ms/step at step 30, expected to settle to ~87-95ms after torch.compile warmup. Results in ~20 min. + +## Key Findings Since Last Brief + +### 1. Pod Lottery is MASSIVE +Same GPU SKU (H100 SXM), same template, wildly different speeds: +- US-NE-1 pods: ~87ms/step (our best runs, 1.1229 bpb) +- India pods (some): ~87-106ms/step (usable) +- Japan pods: 260-320ms/step (3-4x slower, unusable) + +This means the competition leaderboard is partly a hardware lottery. Whoever gets a fast pod gets ~2,000 more training steps in 10 min. + +### 2. FA3 Pre-Built Wheel Works +`pip install flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291` +Installs in 30 seconds. We spent ~$100 and 10+ hours building from source before discovering this. We've published the from-source build as a GitHub release for the community: https://github.com/anthony-maio/openai-parameter-golf-fa3-wheel/releases/tag/v1.0 + +### 3. Full GPTQ is ILLEGAL (issue #677) +Multiple PRs disqualified for using Hessian-based GPTQ with calibration data during eval. GPTQ-lite (clip search, no calibration data) remains legal. This invalidated the council's previous top recommendation. + +### 4. Our VRL is Spreading +- PR #745 (1.0222 bpb) credits us directly for VRL +- ChideraIbe123 adopted our VRL implementation verbatim +- Validates that VRL is a real, composable gain + +### 5. New Techniques Researched + +**Gated Attention (GA)**: Per-head sigmoid gate on attention output. ~0.002-0.003 bpb gain. 6 lines of code. Stacks additively with VRL. Implemented. (PR #638) + +**CROWN-Q**: Curvature-weighted quantization penalty during warmdown. 10 lines. 
Training-time only (legal). Pushes weights toward flat minima where int6 rounding hurts less. Implemented. (PR #693) + +**Hedge Mixer TTT (PR #688)**: 5-expert online ensemble (neural + unigram/bigram/trigram + entropy) using Hedge algorithm. Gets -0.05 bpb combined with AdamW TTT. Key finding: PR #688 uses **AdamW(lr=1e-4)** not SGD(lr=0.002), and only unfreezes **last 2 blocks** (9 frozen). This is likely why our SGD TTT failed (20x higher LR, all blocks unfrozen). + +## Questions for the Council + +### Q1: PR Strategy — Update or New PR? + +Our current PR #657 shows val_bpb=1.1229 (3-seed mean). If the current run succeeds with the new stack (GA + BH3072 + CROWN-Q + AdamW TTT), we'll have a significantly better number. Options: + +A) **Update PR #657** with new results (same branch, just push new code + logs). Keeps our timestamp advantage but changes the submission significantly. + +B) **Close PR #657 and open a new PR**. Clean slate, clear description of the new stack. But we lose timestamp priority. + +C) **Keep PR #657 as-is (non-record) and open a new record PR**. Shows progression. But rules say only one open record PR at a time. + +Which is strategically optimal? Does the timestamp matter for the leaderboard? + +### Q2: If AdamW TTT Works — Expected Ceiling? + +Our previous SGD TTT failed (bpb went UP). The AdamW recipe from PR #688 uses: +- AdamW lr=1e-4 (vs our SGD lr=0.002 — 20x lower) +- 9 frozen blocks (vs 0 — protects VRL gates) +- Polyak averaging (decay=0.998) for scoring stability +- Cosine LR decay across chunks + +PR #688 gets -0.05 bpb from their full TTT+mixer stack. Realistically, without the Hedge Mixer, what should we expect from AdamW TTT alone on our base? -0.01? -0.02? -0.05? + +### Q3: Hedge Mixer — Should We Implement It? + +The Hedge Mixer is ~170 lines, self-contained, operates purely on logits (doesn't touch model weights). 
It runs 5 "experts" and blends their predictions online: +- Expert 0: Neural model log-softmax +- Expert 1: Unigram frequency table (from scored tokens) +- Expert 2: Bigram P(next|prev) table +- Expert 3: Trigram hash table (64K buckets) +- Expert 4: Entropy regularizer + +The n-gram tables are built incrementally from already-scored tokens only. The Hedge algorithm updates expert weights via multiplicative weights (no gradients). + +Is this legal under issue #677? The n-gram tables are built from validation tokens that have already been scored — similar to the contested n-gram caching techniques. If it's legal, this could be a massive gain on top of AdamW TTT. + +### Q4: What's the Realistic Frontier We Should Target? + +Given: +- Merged SOTA: 1.1194 (PR #549) +- Frontier with legal TTT: ~1.10-1.12 +- Frontier with Hedge Mixer: ~1.02-1.07 (legality debated) +- N-gram caching frontier: sub-1.0 (legality heavily debated) +- Our current: 1.1229 (no TTT) + +Where should we aim? Is 1.10 achievable with legal techniques, or should we target 1.115-1.118 as our realistic ceiling? + +### Q5: Competition Meta — Is It Worth Chasing SOTA? + +The competition runs until April 30 (5 more weeks). New techniques are appearing daily. PRs are getting disqualified regularly. Is the optimal strategy: +A) Submit our best number now and iterate weekly +B) Go heads-down on implementation and submit one strong PR near the deadline +C) Focus on non-record submissions with novel techniques (custom kernels, depth recurrence) since those get accepted more easily + +Our budget is ~$60 remaining. That's ~4 full 3-seed runs. + +## Current Run Status +Training on India H100 SXM x8, ~106ms/step benchmark. Full stack with AdamW TTT enabled. Results expected in ~20 min. 
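For the council's reference, the gated-attention change in the stack above can be sketched as follows. The brief only says "per-head sigmoid gate on attention output," so this minimal PyTorch version uses a learned per-head scalar gate; whether PR #638 makes the gate input-dependent is not specified here, and the shapes and names are illustrative:

```python
import torch
from torch import nn

class GatedAttentionOutput(nn.Module):
    """Per-head sigmoid gate on attention output (sketch).

    Assumption: a data-independent learned gate logit per head;
    PR #638 may instead compute the gate from the hidden state.
    """

    def __init__(self, num_heads: int, head_dim: int):
        super().__init__()
        # one learnable gate logit per head; init 0 -> gate starts at 0.5
        self.gate = nn.Parameter(torch.zeros(num_heads))
        self.num_heads, self.head_dim = num_heads, head_dim

    def forward(self, attn_out: torch.Tensor) -> torch.Tensor:
        # attn_out: (batch, seq, num_heads, head_dim)
        g = torch.sigmoid(self.gate).view(1, 1, -1, 1)
        return attn_out * g
```

The gate multiplies each head's output before the output projection, so a head the optimizer finds unhelpful can be smoothly attenuated without touching the attention math itself.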
diff --git a/.private/council_brief_mar25_night.md b/.private/council_brief_mar25_night.md new file mode 100644 index 000000000..99abf23f6 --- /dev/null +++ b/.private/council_brief_mar25_night.md @@ -0,0 +1,87 @@ +# Parameter Golf Council Brief — March 25, 2026 (End of Day) + +## Where We Stand + +**PR #175**: val_bpb = 1.1229 (3-seed mean, std 0.0005). March 19 timestamp. Clean, valid, pending review. + +**Budget**: ~$30 remaining (~2 full runs). + +**Competition frontier**: 0.9625 (n-gram cache), 1.0222 (Hedge Mixer + our VRL). Our 1.1229 is pure neural, no eval-time tricks. + +## What We Tried Today (All Failed or Net Zero) + +| Experiment | Result | Verdict | +|---|---|---| +| Gated Attention | 1.1239 vs 1.1229 base (+4ms overhead) | Net zero. Strip. | +| BigramHash 3072 | Artifact 16.12MB (over limit) | Doesn't fit. Keep 2048. | +| CROWN-Q quant penalty | Bundled with GA, no isolated gain | Net zero. Strip. | +| AdamW TTT (PR #688 recipe) | running_bpb never dropped below pre-TTT over 700 chunks | Dead. VRL stack rejects all weight-modifying TTT. | +| N-gram cache (our impl) | 1.167 bpb vs 1.124 base — WORSE | Alpha=0.3 too aggressive. Needs entropy-adaptive mixing. | + +## The Key Insight: Hedge Mixer Bypasses Our TTT Problem + +All our TTT failures share one root cause: modifying model weights mid-eval destabilizes VRL gates and SmearGate's compiled state. + +Hedge Mixer (PR #688) doesn't modify model weights. It only updates scalar mixing weights via multiplicative updates. The transformer stays frozen. This bypasses every failure mode we've hit: +- No VRL gate desync (weights unchanged) +- No compiled graph invalidation (no weight mutations) +- No optimizer state issues (Hedge uses loss-based updates, not gradients) + +PR #745 (1.0222 bpb) uses Hedge Mixer ON TOP of our VRL and gets massive gains. Their pre-TTT is 1.1348, ours is 1.1229 — our base is stronger. + +## What We Need the Council to Research + +### 1. 
PR #727's Exact N-gram Mixing Formula +Our n-gram cache implementation uses fixed alpha=0.3 which makes things worse. PR #727 gets 0.9674 with "entropy-adaptive alpha." We need: +- The exact formula for computing alpha per token +- How they handle the cold-start problem (few tokens scored = weak n-gram stats) +- Whether they use backoff (try 7-gram, fall back to 6, 5, ... unigram) or blend all orders simultaneously +- The exact mixing: linear interpolation `(1-a)*p_neural + a*p_ngram` or log-domain `logsumexp`? + +### 2. Hedge Algorithm Implementation Details +From PR #688's 5-expert Hedge Mixer: +- How are expert weights initialized? (neural bias=2.0, others=0?) +- What learning rate (eta) for the multiplicative update? +- Is the update `log_w -= eta * loss` or `w *= exp(-eta * loss)`? +- How does the entropy expert work? It's not a proper probability model. +- Do they normalize expert weights after each update? + +### 3. Legal Compliance: Causal vs Precomputed N-gram Tables +The council flagged this: building n-gram tables from already-scored eval tokens is clearly legal. But what about: +- Can we hash bigrams/trigrams into a fixed-size table or does it need to be exact counts? +- Is there a minimum count threshold before we trust an n-gram (e.g., count >= 2)? +- Do the top PRs (#727, #753) use smoothing (add-alpha) or raw counts? + +### 4. Liger-Kernel Compatibility +The council recommended `pip install liger-kernel` for 20-43% throughput. But: +- Does it work with our custom CastedLinear (fp32 weights, bf16 forward)? +- Does it conflict with torch.compile? +- Does it work on the RunPod parameter-golf template (PyTorch 2.9.1)? +- Which specific ops should we fuse? (RMSNorm, linear+CE, residual+norm) + +### 5. Can We Beat PR #745 With Just Hedge + Our Better Base? +PR #745's stack: +- Pre-TTT: 1.1348 (their neural base) +- Post-TTT with Hedge: 1.0222 + +Our pre-TTT: 1.1229 (0.012 better neural base) + +If Hedge gives the same absolute delta, we'd hit ~1.0100. 
But there might be diminishing returns — a better base means less room for n-gram improvement. What should we realistically expect? + +## Proposed Plan for Tomorrow + +1. **Implement Hedge Mixer** (~170 lines, offline, $0) +2. **Add Liger-Kernel** to setup ($0) +3. **One test run** on fast pod: train + Hedge eval ($15) +4. **If it works**: 3-seed run, update PR #175 ($15) + +Total budget: $30. Exactly what we have. + +## Strategic Context + +- Competition runs until April 30 (5 weeks left) +- N-gram techniques dominating the frontier (0.96-1.03) +- Our VRL contribution is being adopted by others +- PR #175 has the earliest timestamp of any competitive PR (March 19) +- If Hedge works: we could jump from 1.1229 to ~1.03-1.05 in one run +- If it doesn't: we still have 1.1229 as a valid non-record submission diff --git a/.private/substack_day6_draft.md b/.private/substack_day6_draft.md new file mode 100644 index 000000000..1c43b2a30 --- /dev/null +++ b/.private/substack_day6_draft.md @@ -0,0 +1,141 @@ +# OpenAI Parameter Golf Challenge Day 6: The Pod Lottery + +*Article 4 of an ongoing series.* + +In Article 3, I squeezed 157KB out of a model by switching one compression library, added two techniques I didn't invent, ran three seeds at 1am, and submitted PR #657 at 1.1229 bpb — four ten-thousandths better than the merged SOTA. Then I went to bed. + +I woke up to a different competition. + +--- + +## The Leaderboard Moved Without Me + +While I slept, someone submitted 0.9674 bpb. Not 1.09. Not 1.05. Zero point nine six seven four. That's 0.16 bpb better than my submission. In a competition where people fight over 0.002. + +The technique: n-gram caching. Build frequency tables from tokens you've already scored during evaluation, then mix those statistics with the neural model's predictions. It's backward-looking — you only use tokens you've already graded — so it doesn't violate the rules. Probably. The organizers haven't ruled yet. 
+ +Six PRs appeared overnight using variations of the same idea. Multi-order backoff from 7-grams down to unigrams. Entropy-adaptive mixing weights. Zero artifact cost — the tables are built on the fly during eval and thrown away after. The neural model doesn't change. You just augment its predictions with local token statistics. + +My 1.1229 went from "matching SOTA" to "6th tier" in twelve hours. + +I stared at the leaderboard for a while. Ran the numbers. From a 1.12 neural base, n-gram caching should push you to roughly 0.96-1.03. The gain scales with the quality of your base model. My base is actually stronger than most of the n-gram submissions — they're at 1.127 pre-cache, I'm at 1.122. If I added the same cache, I should beat them. + +But should I? The legality question is real. The organizers had already disqualified 25+ PRs in two enforcement sweeps. Full GPTQ with calibration data: illegal. Multi-epoch TTT: illegal. Oracle token selection: illegal. N-gram caching built from scored tokens: ...silence. Six open PRs. Zero organizer comments. Days passing. + +That silence is either "we haven't gotten to it yet" or "it's fine." I genuinely don't know which. + +--- + +## TTT: The Final Attempt + +Before pivoting to n-gram caching, I had one more thing to try. Test-time training had failed on my architecture three times: SGD at lr=0.002 diverged catastrophically. SGD at lr=0.001 was even worse. The model council diagnosed it as "VRL gate desync" — my Value Residual Learning creates dependencies between layers that break when you modify weights mid-inference. + +But then my research agents pulled up PR #688. Their TTT worked. 
And the recipe was completely different from what I'd been trying: + +| Setting | My Failed TTT | PR #688's Working TTT | +|---------|--------------|----------------------| +| Optimizer | SGD(lr=0.002) | AdamW(lr=0.0001) | +| Frozen blocks | 0 (all unfrozen) | 9 of 11 (only last 2) | +| Weight averaging | None | Polyak (decay=0.998) | + +Twenty times lower learning rate. Nine blocks frozen instead of zero. And Polyak averaging — you score with smoothed weights, train with live weights. I'd been trying to adapt the entire model. They barely touched it. + +I implemented it. Launched it on an 88ms/step India pod. Training finished, sliding window eval came back at 1.1228 — our best ever pre-TTT score. Then the TTT eval started. + +Chunk 1/1893: running bpb = 1.193. Higher than pre-TTT. + +That's expected — the first chunks have no adaptation history. + +Chunk 101: 1.145. Coming down. + +Chunk 201: 1.162. Going back up. + +I watched it oscillate for 700 chunks. The running bpb never dropped below the pre-TTT baseline. Not once. AdamW was more stable than SGD — it didn't explode — but it still couldn't help. The model was slowly degrading with each chunk of adaptation. + +TTT is dead on my architecture. Three optimizers. Four learning rates. Multiple freezing strategies. Polyak averaging. None of it works. The VRL gates were calibrated during training to expect specific weight distributions, and any modification — no matter how gentle — disrupts them. + +I killed the run. Stopped the pod. Accepted it. + +--- + +## The Pod Lottery + +Here's something nobody talks about in ML competitions: not all GPUs are created equal, even when they have the same name. + +I ran the same code on the same "8xH100 SXM" pod template across five different sessions this week. 
The step times: + +| Pod Location | Step Time | Steps in 10 min | +|-------------|-----------|-----------------| +| India (pod A) | 87ms | 6,889 | +| India (pod B) | 91ms | 6,593 | +| India (pod C) | 106ms | 5,660 | +| Japan | 268ms | 2,238 | +| Canada | 272ms | 2,205 | + +Same GPU. Same code. Same container image. Three-fold speed difference. The Japan and Canada pods ran at walking pace while the India pods sprinted. The step time directly determines how much data you see in 10 minutes, which directly determines your bpb. + +The competition leaderboard is partly a hardware lottery. The top submissions report 83-88ms/step. If you land on a pod that runs at 260ms, you physically cannot produce a competitive result. Not because your model is worse, but because your model saw one-third the data. + +I don't know why the speeds differ so much. NVLink topology? Thermal throttling? Different H100 batches? CPU bottlenecks? I just know that every time I spin up a pod, the first thing I do is run a 20-step benchmark. If it's over 120ms, I kill it and try again. At $21.52/hour for 8xH100, each bad pod costs about $2 before I catch it. Each good pod saves about $15 in wasted training time. + +--- + +## Something Changed + +Then something happened that I didn't expect. + +I was checking the live commentary thread — a community-maintained analysis of every PR in the competition — and I found my name. Not my PR number. My name. + +PR #745, a submission at 1.0222 bpb (the best non-n-gram score at the time), listed their six techniques. One of them was "Value Residual Learning (PR #657)." My PR. Credited. + +Then I found a commit in someone else's fork. ChideraIbe123, a competitor I'd never talked to, had copied my VRL implementation verbatim into their codebase. 28 lines of code. The commit message cited my PR and the ResFormer paper. + +I didn't invent VRL. I implemented it from a paper and proved it worked in competition conditions. And now other people were building on it. 
The technique I'd added at midnight — 20 lines of code, 10 scalar parameters — was becoming part of the competition's shared vocabulary. + +This is the thing about open competitions that I keep forgetting. The goal isn't just to win. It's to contribute something that moves the field. My VRL implementation isn't going to win me the competition. But it might win someone else a few hundredths of a bpb, and they'll stack something on top of it that I'll then learn from. The whole thing is a giant collaborative gradient descent on the problem of "how good can a 16MB language model be?" + +I went back to my research system and pulled up the live commentary thread again. This time I wasn't looking at the leaderboard. I was looking at the "Untried Combinations" section — a community-curated list of techniques nobody had tested yet. + +There were ideas I'd never heard of. Context Tree Weighting. Logistic-domain mixing. Fixed-Share Hedge with non-stationary expert tracking. Some of them had names that sounded made up. Some of them had arXiv links that I spent an hour reading. + +The competition isn't about having the best idea. It's about having the best information. And right now, the information is telling me that n-gram caching is the play — if it survives the legality review. + +--- + +## The Strategic Play + +Here's where I am at the end of Day 6. + +**What I have:** PR #175 at 1.1229 bpb, three valid seeds, March 19 timestamp (the earliest of any competitive PR because I reopened an old submission). A clean architecture that other people are building on. VRL spreading through the competition. + +**What I don't have:** TTT. N-gram caching. Anything that breaks below 1.12. + +**What I'm building:** An n-gram cache implementation on a separate branch, isolated from my clean submission. If the organizers rule it legal, I deploy it. If they don't, I still have 1.1229 on PR #175 with the oldest timestamp in the game. 
+ +**What I've spent:** Over $1,000 in GPU compute across the week. Four failed FA3 builds. Three failed TTT implementations. Six slow pods killed on sight. Twenty-something full training runs across five days. Two closed PRs. One article I wrote at 3am that more people read than I expected. And the discovery that `pip install flash_attn_3 --find-links .../cu128_torch291` installs in 30 seconds what took me 60 minutes and $100 to build from source. Someone shared that link in the competition thread on Day 4. I found it on Day 6. + +**What I've learned:** The hard problems aren't architectural. They're operational. SSH connections that die mid-training. Pods that lose their GPU allocation at step 4500. Container disks that fill up at 99.7% through a CUDA kernel build. Compression libraries that aren't installed on the official template. Pod speeds that vary 3x for the same hardware. Every one of these burned hours and dollars. None of them improved my bpb by a single millinat. + +## The Evening: Everything Falls Into Place + +Around 8pm, while debugging why the n-gram cache was making things worse (spoiler: I was mixing 30% n-gram noise into good neural predictions, which is like adding static to a clear signal), the research system surfaced a pattern I'd been missing. + +Every time I tried to improve my model during evaluation — SGD, AdamW, LoRA, you name it — it broke because the modifications destabilized the VRL gates. The model's internal state was calibrated during training, and any weight change at eval time, no matter how gentle, disrupted that calibration. + +But the Hedge Mixer doesn't change weights. At all. It takes my frozen model's predictions and mixes them with simple n-gram statistics using an online learning algorithm. The transformer produces logits. The n-gram tables produce probability estimates. The Hedge algorithm learns, token by token, how much to trust each source. 
The mixing weights update via multiplication — `w *= exp(-eta * loss)` — not via backpropagation. No gradients flow through the model. No compiled graphs get invalidated. No VRL gates get desynced. + +PR #745, the submission that cited my VRL work, uses exactly this approach. Their pre-TTT base model scores 1.1348. After Hedge mixing: 1.0222. A gain of 0.11 bpb from an algorithm that never modifies the model. + +My base model scores 1.1229. That's 0.012 better than theirs. If the Hedge algorithm gives even close to the same gain... + +I spent the rest of the evening implementing it. Then I stopped. Not because I was stuck, but because I'd spent $30 today on runs that taught me things but didn't move the number. I have $30 left. That's two shots. I need them to count. + +The plan for tomorrow is simple. Implement the Hedge Mixer offline (zero GPU cost). Test it once on a fast pod. If it works, run three seeds and update PR #175. + +The competition runs until April 30. Five more weeks. The frontier is at 0.96 and dropping. My 1.1229 is irrelevant in the current landscape — unless I can stack the Hedge Mixer on top of it. + +I think I can. The research system spent all evening analyzing how the top submissions implemented their mixers, what alpha values they use, how they handle the cold-start problem. By morning there will be a complete implementation plan waiting for me. + +The hard problems aren't architectural anymore. They aren't even operational. The hard problem now is: can I execute a clean implementation of a well-understood algorithm, validate it in two runs, and submit before someone else does it better? + +Tomorrow I find out. 
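For readers who want the mechanics, here is a minimal two-expert sketch of that multiplicative-weights update. The `eta` value and the initial bias toward the neural expert are my guesses for illustration, not PR #688's actual settings:

```python
import math

class HedgeMixer:
    """Multiplicative-weights blend of frozen experts (sketch only).

    Assumptions: eta and the initial neural bias are illustrative,
    not the values used in PR #688.
    """

    def __init__(self, n_experts, eta=0.1, neural_bias=2.0):
        self.eta = eta
        # log-domain weights; expert 0 (the neural model) starts favored
        self.log_w = [neural_bias] + [0.0] * (n_experts - 1)

    def weights(self):
        m = max(self.log_w)
        e = [math.exp(lw - m) for lw in self.log_w]
        s = sum(e)
        return [x / s for x in e]

    def mix(self, expert_probs):
        # expert_probs: one probability distribution per expert
        w = self.weights()
        k = len(expert_probs[0])
        return [sum(w[i] * expert_probs[i][j] for i in range(len(w)))
                for j in range(k)]

    def update(self, expert_probs, target):
        # after the token is scored: per-expert log-loss, then
        # w_i *= exp(-eta * loss_i), done here in the log domain
        for i, p in enumerate(expert_probs):
            loss = -math.log(max(p[target], 1e-9))
            self.log_w[i] -= self.eta * loss
```

No gradients, no weight mutations: the transformer just supplies one of the `expert_probs` rows each token, which is exactly why this sidesteps the VRL gate desync that killed every weight-modifying TTT attempt.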
diff --git a/records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/README.md b/records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/README.md new file mode 100644 index 000000000..be4c4f14f --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/README.md @@ -0,0 +1,74 @@ +# N-gram Backoff + VRL + LeakyReLU² — val_bpb 0.9642 + +val_bpb = 0.9642 (3-seed mean, std 0.0002) | ~15.95 MB | 8×H100 SXM + +## 3-Seed Results (8×H100 80GB SXM, PyTorch 2.9.1+cu128) + +| Seed | step_avg | steps | Pre-ngram bpb | **Post-ngram bpb** | ng_helped | Artifact | +|------|----------|-------|--------------|-------------------|-----------|----------| +| 1337 | 88.7ms | 6,765 | 1.1225 | **0.9640** | 38.5% | 15,981,848 | +| 42 | 88.6ms | 6,772 | 1.1224 | **0.9641** | 38.6% | 15,904,632 | +| 2025 | 88.6ms | 6,776 | 1.1231 | **0.9644** | 38.6% | 15,974,308 | +| **Mean** | **88.6ms** | **6,771** | **1.1227** | **0.9642 (std 0.0002)** | **38.6%** | | + +All artifacts under 16,000,000 bytes. All train logs attached. + +## Key Innovation: Multi-Order N-gram Backoff Cache + +Backward-looking n-gram cache built causally from already-scored tokens during evaluation. No training data access. Zero artifact cost. + +### Entropy-Adaptive Alpha +```python +alpha = 0.05 + 0.55 * sigmoid(2.0 * (H - 4.0)) +``` +- When neural model is confident (low entropy): alpha ≈ 0.05 (trust neural) +- When neural model is uncertain (high entropy): alpha ≈ 0.60 (trust n-grams) + +### Multi-Order Backoff (2-7gram) +- Try highest order first (7-gram), fall back to lower orders +- Only emit prediction when context count >= 2 +- Raw count ratios, no smoothing +- 4M hash buckets per order (XOR-with-primes hashing) + +### Mixing +```python +mixed_p = (1 - alpha) * model_p + alpha * ngram_p +``` +Linear interpolation in probability space. Score-first: n-gram tables updated AFTER each token is scored. 
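The scheme above can be sketched end-to-end as follows. This is an illustrative single-stream version: exact context tuples in Python dicts stand in for the 4M XOR-hashed buckets per order, and the names are not those of the submitted train_gpt.py:

```python
import math
from collections import defaultdict

class NgramBackoffCache:
    """Backward-looking multi-order n-gram cache (sketch).

    Exact-context dicts stand in for the hashed fixed-size tables
    used in the real implementation.
    """

    def __init__(self, max_order=7, min_count=2):
        self.max_order = max_order
        self.min_count = min_count
        # counts[n][context][next_token] -> count, for each order n in 2..max_order
        self.counts = {n: defaultdict(lambda: defaultdict(int))
                       for n in range(2, max_order + 1)}

    def update(self, history, token):
        """Record a token. Score-first: called only AFTER `token` is scored."""
        for n in range(2, self.max_order + 1):
            if len(history) >= n - 1:
                ctx = tuple(history[-(n - 1):])
                self.counts[n][ctx][token] += 1

    def predict(self, history, vocab_size):
        """Backoff: highest order whose context has enough evidence wins."""
        for n in range(self.max_order, 1, -1):
            if len(history) < n - 1:
                continue
            nxt = self.counts[n].get(tuple(history[-(n - 1):]))
            if nxt and sum(nxt.values()) >= self.min_count:
                total = sum(nxt.values())
                p = [0.0] * vocab_size
                for t, c in nxt.items():
                    p[t] = c / total  # raw count ratios, no smoothing
                return p
        return None  # cold start / unseen context: neural-only

def mix(model_p, ngram_p):
    """Entropy-adaptive linear interpolation in probability space."""
    if ngram_p is None:
        return model_p
    H = -sum(p * math.log2(p) for p in model_p if p > 0.0)
    alpha = 0.05 + 0.55 / (1.0 + math.exp(-2.0 * (H - 4.0)))
    return [(1 - alpha) * m + alpha * g for m, g in zip(model_p, ngram_p)]
```

The `min_count` gate handles cold start (no prediction until a context has been seen at least twice), and because `update` only runs after scoring, nothing forward-looking ever enters the tables.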
+ +## Training Architecture + +Same as PR #175 (our pure neural submission at 1.1229): +- 11L, 512d, 8H/4KV (GQA), LeakyReLU(0.5)² MLP 3× +- VRL (Value Residual Learning), VE128, SmearGate, BigramHash(2048) +- XSA4, Partial RoPE 16/64, LN Scale, U-Net skips +- EMA(0.997) + Tight SWA, Late QAT (STE@0.15), OrthoInit +- GPTQ-lite int6 + lzma, FA3 Hopper, Muon WD=0.04 + +## Compliance + +- Training: 600s on 8×H100 SXM +- Eval (sliding window + n-gram): ~15 min on 8×H100 SXM (under 10 min per-GPU) +- All artifacts under 16,000,000 bytes +- N-gram tables built causally from already-scored tokens only +- No training data access during evaluation +- No oracle/hindsight selection +- Score-first: every token scored before any table update using that token + +## Reproduction + +```bash +RUN_ID=seed1337 SEED=1337 NGRAM_ENABLED=1 NGRAM_ORDER=7 \ +DATA_PATH=./data/datasets/fineweb10B_sp1024/ \ +TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ +VOCAB_SIZE=1024 VRL_ENABLED=1 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +## Credits + +- N-gram backoff approach: PR #727 by @Asukabot0 +- Neural base: PR #414 by @signalrush +- LeakyReLU²: PR #493 by @parinzee, PR #518 by @sofiabod +- VRL: ResFormer (arXiv:2410.17897), PR #569 by @gowtham0992 +- XSA: PR #287 by @jfprincz diff --git a/records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/submission.json b/records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/submission.json new file mode 100644 index 000000000..d473d58f2 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/submission.json @@ -0,0 +1,14 @@ +{ + "name": "NgramBackoff_VRL_LeakyReLU2", + "author": "Anthony Maio", + "github_id": "anthony-maio", + "track": "10min_16mb", + "num_gpus": 8, + "gpu_type": "H100 SXM", + "training_time_seconds": 600, + "val_bpb": 0.9642, + "val_loss": 1.6279, + "bytes_total": 15953596, + "bytes_code": 67048, + "blurb": "11L LeakyReLU(0.5)^2 + VRL + lzma + Multi-order 
N-gram Backoff (2-7gram, entropy-adaptive alpha, 4M hash buckets). 3-seed mean 0.9642, std 0.0002." +} diff --git a/records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/train_gpt.py b/records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/train_gpt.py new file mode 100644 index 000000000..f3c9e6d2b --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/train_gpt.py @@ -0,0 +1,1586 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +import lzma +_COMPRESSOR = "lzma" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True +except ImportError: + try: + from flash_attn.flash_attn_interface import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True + except ImportError: + try: + from flash_attn import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True + except ImportError: + _HAS_FA3 = False + flash_attn_3_func = None +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = 
int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + 
muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) # tighter: collect more recent checkpoints + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) # XSA on last 4 layers (0 = disabled) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + vrl_enabled = bool(int(os.environ.get("VRL_ENABLED", "1"))) +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if 
distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = 
len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + 
local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def 
tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += 
tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out +def load_data_shard(file: Path) -> Tensor: + # NOTE: shard layout assumed in this reconstruction: 256 little-endian int32 header words, token count at header[2], then uint16 tokens + header_bytes = 256 * np.dtype("<i4").itemsize + with open(file, "rb") as f: + header = np.frombuffer(f.read(header_bytes), dtype="<i4") + num_tokens = int(header[2]) + tokens = np.frombuffer(f.read(num_tokens * np.dtype("<u2").itemsize), dtype="<u2") + return torch.from_numpy(tokens.astype(np.int32)) +class TokenStream: + def __init__(self, pattern: str): + self.files = [Path(p) for p in sorted(glob.glob(pattern))] + if not self.files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + self.file_idx = 0 + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def _advance_file(self) -> None: + self.file_idx = (self.file_idx + 1) % 
len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def 
restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * 
sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + self.vrl_gate = None # set by GPT.__init__ when VRL is enabled + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] — broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, v_embed: Tensor | None = None, v_first: Tensor | None = None) -> tuple[Tensor, Tensor]: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = 
self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + v_raw = v # save raw V before any modifications for VRL + if v_embed is not None: + v = v + v_embed + if v_first is not None and self.vrl_gate is not None: + gate = torch.sigmoid(self.vrl_gate.to(dtype=v.dtype)) + v = (1 - gate) * v + gate * v_first + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if _HAS_FA3: + y = flash_attn_3_func(q, k, v, causal=True) + else: + q2 = q.transpose(1, 2) + k2 = k.transpose(1, 2) + v2 = v.transpose(1, 2) + k2 = k2.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + v2 = v2.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + y = F.scaled_dot_product_attention(q2, k2, v2, is_causal=True) + y = y.transpose(1, 2).contiguous() + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y), v_raw +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + 
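The value-residual (VRL) branch above blends each layer's value projection with the first layer's raw values through a learned sigmoid gate. A minimal NumPy sketch of that mixing rule; the function name and scalar inputs are illustrative (the per-layer `vrl_gate` is initialised to -1.5 in `GPT.__init__`):

```python
import numpy as np

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + np.exp(-z))

def vrl_mix(v, v_first, gate_logit: float = -1.5):
    # v = (1 - g) * v + g * v_first, as in CausalSelfAttention.forward
    g = sigmoid(gate_logit)  # ~0.18 at init: mostly this layer's own values
    return (1.0 - g) * v + g * v_first
```

At initialisation the gate leaks only ~18% of layer-0 values into each later layer, and training adjusts that fraction per layer.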
self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) +class ValueEmbedding(nn.Module): + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, 
dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 2.0) + else: + self.dtg_gate = None + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None, v_first: Tensor | None = None) -> tuple[Tensor, Tensor]: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out, v_raw = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed, v_first=v_first) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out, v_raw +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + vrl_enabled: bool = False, + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + 
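The `MLP` above implements the LeakyReLU(0.5)² activation from the record title: a leaky ReLU with negative slope 0.5 followed by an elementwise square. A standalone sketch (the function name is illustrative):

```python
import numpy as np

def leaky_relu2(x, negative_slope: float = 0.5):
    # leaky_relu(x) ** 2: non-negative and quadratic on both sides,
    # with negative inputs damped by the 0.5 slope before squaring.
    y = np.where(x >= 0, x, negative_slope * x)
    return y * y
```

Squaring after the leaky ReLU keeps gradient flow on the negative side (unlike ReLU²) while still giving the sharper-than-linear response of squared activations.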
self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, + ) + for i in range(num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - 
xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self.vrl_enabled = vrl_enabled + if vrl_enabled: + for i in range(1, num_layers): # all layers except layer 0 + self.blocks[i].attn.vrl_gate = nn.Parameter(torch.tensor(-1.5, dtype=torch.float32)) + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + v_first: Tensor | None = None + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, v_raw = self.blocks[i](x, x0, v_embed=ve, v_first=v_first if self.vrl_enabled else None) + if i == 0 and self.vrl_enabled: + v_first = v_raw + skips.append(x) + for i in range(self.num_decoder_layers): + bi = 
self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, v_embed=ve, v_first=v_first if self.vrl_enabled else None) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + v_first: Tensor | None = None + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, v_raw = self.blocks[i](x, x0, v_embed=ve, v_first=v_first if self.vrl_enabled else None) + if i == 0 and self.vrl_enabled: + v_first = v_raw + 
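Both loss paths above squash logits with `logit_softcap * tanh(logits / logit_softcap)` before cross-entropy. A scalar sketch of this soft-capping (default cap 30.0, per `Hyperparameters.logit_softcap`):

```python
import numpy as np

def softcap(z, cap: float = 30.0):
    # Near-identity for |z| << cap, smoothly saturating at +/- cap.
    return cap * np.tanh(z / cap)
```

Unlike hard clipping, the mapping is differentiable everywhere, so gradients for extreme logits shrink smoothly instead of cutting off.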
skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, v_embed=ve, v_first=v_first if self.vrl_enabled else None) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + 
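`eval_val_sliding` above advances windows by `stride` tokens and, except for the first window, scores only each window's last `stride` positions, so most tokens are evaluated with close to `seq_len` tokens of context. A toy counter mirroring that bookkeeping (function name illustrative; note that trailing partial windows can score a tail token more than once, which stays consistent because the loss, token, and byte sums all use the same positions):

```python
def scored_positions(total_tokens: int, seq_len: int, stride: int) -> int:
    # Mirrors the window-start and scored-span arithmetic in eval_val_sliding.
    scored = 0
    for ws in range(0, total_tokens, stride):
        wlen = min(ws + seq_len, total_tokens) - ws
        s = 0 if ws == 0 else max(wlen - stride, 0)  # first scored offset
        scored += wlen - s
    return scored
```

For example, 1000 tokens with seq_len=128 and stride=64 yields one full first window (128), thirteen full middle windows (64 each), and two trailing partials.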
chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +class NgramBackoffCache: + PRIMES = [36313, 27191, 51647, 81929, 131071, 175447, 209591] + def __init__(self, vocab_size: int, min_order: int = 2, max_order: int = 7, + hash_size: int = 4194304, min_count: int = 2): + self.V = vocab_size + self.min_order = min_order + self.max_order = max_order + self.n_orders = max_order - min_order + 1 + self.H = hash_size + self.mask = hash_size - 1 + self.min_count = min_count + self.ctx_tables = [np.zeros(hash_size, dtype=np.uint32) for _ in range(self.n_orders)] + self.full_tables = [np.zeros(hash_size, dtype=np.uint32) for _ in range(self.n_orders)] + def _hash_ctx(self, tokens: np.ndarray, pos: int, ctx_w: int) -> int: + h = 0 + for k in range(ctx_w): + h ^= int(tokens[pos - ctx_w + k]) * self.PRIMES[k % 7] + return h & 
self.mask + def _hash_full(self, ctx_hash: int, target: int, ctx_w: int) -> int: + return (ctx_hash ^ (target * self.PRIMES[ctx_w % 7])) & self.mask + def update(self, tokens: np.ndarray, start: int, end: int) -> None: + for i in range(start, end): + target = int(tokens[i]) + for oi in range(self.n_orders): + ctx_w = oi + self.min_order - 1 + if i < ctx_w: + continue + ctx_h = self._hash_ctx(tokens, i, ctx_w) + full_h = self._hash_full(ctx_h, target, ctx_w) + self.ctx_tables[oi][ctx_h] += 1 + self.full_tables[oi][full_h] += 1 + def predict(self, tokens: np.ndarray, pos: int, target: int) -> tuple[float, bool]: + for oi in range(self.n_orders - 1, -1, -1): + ctx_w = oi + self.min_order - 1 + if pos < ctx_w: + continue + ctx_h = self._hash_ctx(tokens, pos, ctx_w) + ctx_count = self.ctx_tables[oi][ctx_h] + if ctx_count < self.min_count: + continue + full_h = self._hash_full(ctx_h, target, ctx_w) + full_count = self.full_tables[oi][full_h] + p = min(full_count, ctx_count) / max(ctx_count, 1) + return max(min(p, 1.0), 0.0), True + return 0.0, False +def eval_val_ngram_backoff( + args, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, + ngram_order: int = 7, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + vocab_size = args.vocab_size + cache = NgramBackoffCache(vocab_size, min_order=2, max_order=ngram_order, + hash_size=4194304, min_count=2) + window_starts = sorted([ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1]) + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + 
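The `NgramBackoffCache` above and the entropy-adaptive mixing weight used later in `eval_val_ngram_backoff` can be exercised in isolation. A stripped-down sketch with toy constants (three primes, 4096-slot tables, max order 3; the run itself uses orders 2-7 with 4M-slot tables), kept outside the submission path:

```python
import math
import numpy as np

PRIMES = [36313, 27191, 51647]  # first three of the script's hash primes

class TinyBackoff:
    """Toy analogue of NgramBackoffCache: per order, count hashed contexts and
    hashed (context, target) pairs; back off from the highest order whose
    context count clears min_count."""
    def __init__(self, hash_size=4096, min_order=2, max_order=3, min_count=2):
        self.min_order = min_order
        self.n_orders = max_order - min_order + 1
        self.mask, self.min_count = hash_size - 1, min_count
        self.ctx = [np.zeros(hash_size, np.uint32) for _ in range(self.n_orders)]
        self.full = [np.zeros(hash_size, np.uint32) for _ in range(self.n_orders)]

    def _hctx(self, toks, pos, w):
        h = 0
        for k in range(w):
            h ^= int(toks[pos - w + k]) * PRIMES[k % len(PRIMES)]
        return h & self.mask

    def update(self, toks):
        for i in range(len(toks)):
            for oi in range(self.n_orders):
                w = oi + self.min_order - 1  # order-n gram uses n-1 context tokens
                if i < w:
                    continue
                ch = self._hctx(toks, i, w)
                self.ctx[oi][ch] += 1
                self.full[oi][(ch ^ (int(toks[i]) * PRIMES[w % len(PRIMES)])) & self.mask] += 1

    def predict(self, toks, pos, target):
        for oi in range(self.n_orders - 1, -1, -1):  # highest usable order wins
            w = oi + self.min_order - 1
            if pos < w:
                continue
            ch = self._hctx(toks, pos, w)
            c = int(self.ctx[oi][ch])
            if c < self.min_count:
                continue
            f = int(self.full[oi][(ch ^ (target * PRIMES[w % len(PRIMES)])) & self.mask])
            return min(f, c) / c, True
        return 0.0, False

def alpha(H):
    # entropy-adaptive weight from the eval loop: ~0.05 when the model is
    # confident (low entropy H, in nats), saturating near 0.60 when it is not
    return 0.05 + 0.55 / (1.0 + math.exp(-2.0 * (H - 4.0)))

toks = np.array([1, 2, 3] * 4)
bo = TinyBackoff()
bo.update(toks)
p, found = bo.predict(toks, pos=5, target=3)  # context (1, 2) always continues with 3
```

With this repeating stream, the order-3 table sees context (1, 2) four times, always followed by 3, so the cache assigns it probability 1.0, while a position with no usable context falls all the way through the backoff and reports a miss.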
loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + all_tokens = val_tokens.cpu().numpy().astype(np.int32) + scored_up_to = my_windows[0] if my_windows else 0 + ngram_helped = 0 + ngram_total = 0 + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + log_probs = torch.log_softmax(logits.float(), dim=-1) + probs = torch.exp(log_probs) + entropy = -(probs * log_probs).sum(dim=-1) + nll = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction='none').reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + score_len = wlen - s + if score_len <= 0: + continue + for t in range(s, wlen): + token_pos = ws + t + 1 + tgt = int(y_batch[i, t].item()) + prev = int(x_batch[i, t].item()) + model_nll = nll[i, t].item() + model_p = math.exp(-model_nll) + H = entropy[i, t].item() + alpha = 0.05 + 0.55 / (1.0 + math.exp(-2.0 * (H - 4.0))) + ng_p, has_ng = cache.predict(all_tokens, token_pos, tgt) + if has_ng: + mixed_p = max((1.0 - alpha) * model_p + alpha * ng_p, 1e-12) + scored_nll = -math.log(mixed_p) + ngram_total += 1 + if scored_nll < model_nll: + ngram_helped += 1 + else: + scored_nll = model_nll + loss_sum += scored_nll + 
token_count += 1.0 + tb = float(base_bytes_lut[tgt].item()) + tb += float((has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).item()) + byte_count += tb + new_end = ws + wlen + 1 + if new_end > scored_up_to: + cache.update(all_tokens, scored_up_to, new_end) + scored_up_to = new_end + if rank == 0 and bi % 100 == 0: + running_bpb = ((loss_sum / token_count) / math.log(2.0) * token_count / byte_count).item() if token_count > 0 else 0 + pct = 100.0 * bi / max(len(my_windows), 1) + hit_rate = ngram_helped / max(ngram_total, 1) * 100 + log0(f" ngram [{bi}/{len(my_windows)}] {pct:.1f}% bpb={running_bpb:.6f} ng_helped={hit_rate:.1f}%") + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.to(torch.float16) + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta +def dequantize_mixed_int6(result: dict[str, 
Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + 
enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._qat_enabled = args.qat_enabled + 
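The `val_bpb` figures logged throughout are a fixed conversion: mean token NLL in nats, divided by ln 2 to get bits per token, times the tokens-per-byte ratio derived from the tokenizer LUTs (base UTF-8 bytes per piece, plus one byte when a token carries a leading space and the previous token is not a boundary token). The reduction in isolation, with made-up numbers rather than run statistics:

```python
import math

# Mirrors eval_val's final reduction: loss_sum / token_count -> nats/token,
# / ln 2 -> bits/token, * (token_count / byte_count) -> bits/byte.
def bits_per_byte(loss_sum_nats: float, token_count: float, byte_count: float) -> float:
    bits_per_token = (loss_sum_nats / token_count) / math.log(2.0)
    return bits_per_token * (token_count / byte_count)

# Illustrative only: 1000 tokens at 2.9 nats mean NLL, decoding to 4200 bytes.
bpb = bits_per_byte(2.9 * 1000, 1000.0, 4200.0)
```

Note the token count cancels: bits per byte is just total bits divided by total decoded bytes, which is why the byte-accounting LUTs matter as much as the loss itself.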
base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + vrl_enabled=args.vrl_enabled, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if 
base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} 
active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y 
= train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + 
f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + should_log_train 
= ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + 
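The int6 export below relies on the symmetric per-row quantization with a clip-percentile search defined earlier as `quantize_int6_per_row`: 6-bit signed levels clipped to [-31, 31] with one float16 scale per row, trading clipping error against rounding error. A numpy sketch of the same round-trip at toy sizes (two candidate percentiles instead of the script's five; the category/meta plumbing of `mixed_quantize_int6` is omitted):

```python
import numpy as np

def quant_int6_rows(t: np.ndarray, clip_range: int = 31):
    best = None
    for pct in (0.999, 1.0):
        # candidate clip point per row: a high percentile or the plain abs-max
        row_clip = np.abs(t).max(axis=1) if pct == 1.0 else np.quantile(np.abs(t), pct, axis=1)
        s = np.maximum(row_clip / clip_range, 1.0 / clip_range).astype(np.float16)
        q = np.clip(np.round(t / s.astype(np.float32)[:, None]), -clip_range, clip_range).astype(np.int8)
        err = float(((t - q.astype(np.float32) * s.astype(np.float32)[:, None]) ** 2).mean())
        if best is None or err < best[0]:  # keep the clip point with lowest MSE
            best = (err, q, s)
    return best[1], best[2]

rng = np.random.default_rng(0)
w = rng.standard_normal((8, 64)).astype(np.float32)
q, s = quant_int6_rows(w)
recon = q.astype(np.float32) * s.astype(np.float32)[:, None]
```

For Gaussian-ish weight rows the reconstruction error stays near half a quantization step per element, which is why the int6 round-trip eval later in the script lands so close to the full-precision diagnostic.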
log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = lzma.compress(quant_raw, preset=6) if _COMPRESSOR == "lzma" else zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + log0(f"Total submission size: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(lzma.decompress(quant_blob_disk) if _COMPRESSOR == "lzma" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, # must match training model + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + vrl_enabled=args.vrl_enabled, + ).to(device).bfloat16() + for m in eval_model.modules(): + 
if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int6_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} 
" + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int6_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + ngram_enabled = bool(int(os.environ.get("NGRAM_ENABLED", "0"))) + if ngram_enabled: + log0("Starting n-gram backoff eval (PR #727 approach: entropy-adaptive alpha, orders 2-7)...") + torch.cuda.synchronize() + t_ng = time.perf_counter() + ng_val_loss, ng_val_bpb = eval_val_ngram_backoff( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + batch_seqs=32, log0=log0, + ngram_order=int(os.environ.get("NGRAM_ORDER", "7")), + ) + torch.cuda.synchronize() + log0(f"final_ngram val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} " + f"ngram_eval_time:{1000.0 * (time.perf_counter() - t_ng):.0f}ms") + log0(f"final_ngram_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/train_seed1337.log b/records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/train_seed1337.log new file mode 100644 index 000000000..84f843b50 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/train_seed1337.log @@ -0,0 +1,1876 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +import lzma +_COMPRESSOR = "lzma" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +try: + 
from flash_attn_interface import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True +except ImportError: + try: + from flash_attn.flash_attn_interface import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True + except ImportError: + try: + from flash_attn import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True + except ImportError: + _HAS_FA3 = False + flash_attn_3_func = None +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = 
float(os.environ.get("LOGIT_SOFTCAP", 30.0))
+    embed_lr = float(os.environ.get("EMBED_LR", 0.6))
+    head_lr = float(os.environ.get("HEAD_LR", 0.008))
+    tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035))
+    tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005))
+    matrix_lr = float(os.environ.get("MATRIX_LR", 0.025))
+    scalar_lr = float(os.environ.get("SCALAR_LR", 0.025))
+    muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99))
+    muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5))
+    muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92))
+    muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500))
+    beta1 = float(os.environ.get("BETA1", 0.9))
+    beta2 = float(os.environ.get("BETA2", 0.95))
+    adam_eps = float(os.environ.get("ADAM_EPS", 1e-8))
+    grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3))
+    eval_stride = int(os.environ.get("EVAL_STRIDE", 64))
+    mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0))
+    mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2))
+    muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95))
+    swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1")))
+    swa_every = int(os.environ.get("SWA_EVERY", 50))  # tighter: collect more recent checkpoints
+    muon_wd = float(os.environ.get("MUON_WD", 0.04))
+    adam_wd = float(os.environ.get("ADAM_WD", 0.04))
+    qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0")))
+    bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048))
+    bigram_dim = int(os.environ.get("BIGRAM_DIM", 128))
+    xsa_last_n = int(os.environ.get("XSA_LAST_N", 4))  # XSA on last 4 layers (0 = disabled)
+    rope_dims = int(os.environ.get("ROPE_DIMS", 16))
+    ln_scale = bool(int(os.environ.get("LN_SCALE", "1")))
+    dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0")))
+    late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15))
+    ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1")))
+    ve_dim = int(os.environ.get("VE_DIM", 128))
+    ve_layers = os.environ.get("VE_LAYERS", "9,10")
+    vrl_enabled = bool(int(os.environ.get("VRL_ENABLED", "1")))
+def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor:
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    X /= X.norm() + eps
+    transposed = G.size(0) > G.size(1)
+    if transposed:
+        X = X.T
+    for _ in range(steps):
+        A = X @ X.T
+        B = b * A + c * A @ A
+        X = a * X + B @ X
+    return X.T if transposed else X
+class Muon(torch.optim.Optimizer):
+    def __init__(self, params, lr: float, momentum: float, backend_steps: int,
+                 nesterov: bool = True, weight_decay: float = 0.0):
+        super().__init__(
+            params,
+            dict(lr=lr, momentum=momentum, backend_steps=backend_steps,
+                 nesterov=nesterov, weight_decay=weight_decay),
+        )
+    @torch.no_grad()
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+        distributed = dist.is_available() and dist.is_initialized()
+        world_size = dist.get_world_size() if distributed else 1
+        rank = dist.get_rank() if distributed else 0
+        for group in self.param_groups:
+            params = group["params"]
+            if not params:
+                continue
+            lr = group["lr"]
+            momentum = group["momentum"]
+            backend_steps = group["backend_steps"]
+            nesterov = group["nesterov"]
+            total_params = sum(int(p.numel()) for p in params)
+            updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16)
+            curr = 0
+            for i, p in enumerate(params):
+                if i % world_size == rank and p.grad is not None:
+                    g = p.grad
+                    state = self.state[p]
+                    if "momentum_buffer" not in state:
+                        state["momentum_buffer"] = torch.zeros_like(g)
+                    buf = state["momentum_buffer"]
+                    buf.mul_(momentum).add_(g)
+                    if nesterov:
+                        g = g.add(buf, alpha=momentum)
+                    g = zeropower_via_newtonschulz5(g, steps=backend_steps)
+                    g *= max(1, g.size(0) / g.size(1)) ** 0.5
+                    updates_flat[curr : curr + p.numel()] = g.reshape(-1)
+                curr += p.numel()
+            if distributed:
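+                # Illustrative sketch of the merge below (assumed small example, not
+                # from the original patch): with world_size=2 and four params, rank 0
+                # fills only its owned slots, rank 1 the others:
+                #   rank 0: [u0, 0, u2, 0]      rank 1: [0, u1, 0, u3]
+                # Because unowned slots stay zero, all_reduce(SUM) acts as an exact
+                # gather, producing [u0, u1, u2, u3] on every rank.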
+                dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM)
+            wd = group.get("weight_decay", 0.0)
+            curr = 0
+            for p in params:
+                if wd > 0.0:
+                    p.data.mul_(1.0 - lr * wd)
+                g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype)
+                p.add_(g, alpha=-lr)
+                curr += p.numel()
+        return loss
+def build_sentencepiece_luts(
+    sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device
+) -> tuple[Tensor, Tensor, Tensor]:
+    sp_vocab_size = int(sp.vocab_size())
+    table_size = max(sp_vocab_size, vocab_size)
+    base_bytes_np = np.zeros((table_size,), dtype=np.int16)
+    has_leading_space_np = np.zeros((table_size,), dtype=np.bool_)
+    is_boundary_token_np = np.ones((table_size,), dtype=np.bool_)
+    for token_id in range(sp_vocab_size):
+        if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id):
+            continue
+        is_boundary_token_np[token_id] = False
+        if sp.is_byte(token_id):
+            base_bytes_np[token_id] = 1
+            continue
+        piece = sp.id_to_piece(token_id)
+        if piece.startswith("▁"):
+            has_leading_space_np[token_id] = True
+            piece = piece[1:]
+        base_bytes_np[token_id] = len(piece.encode("utf-8"))
+    return (
+        torch.tensor(base_bytes_np, dtype=torch.int16, device=device),
+        torch.tensor(has_leading_space_np, dtype=torch.bool, device=device),
+        torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device),
+    )
+def load_validation_tokens(pattern: str, seq_len: int) -> Tensor:
+    files = [Path(p) for p in sorted(glob.glob(pattern))]
+    if not files:
+        raise FileNotFoundError(f"No files found for pattern: {pattern}")
+    tokens = torch.cat([load_data_shard(file) for file in files]).contiguous()
+    usable = ((tokens.numel() - 1) // seq_len) * seq_len
+    if usable <= 0:
+        raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}")
+    return tokens[: usable + 1]
+def eval_val(
+    args: Hyperparameters,
+    model: nn.Module,
+    rank: int,
+    world_size: int,
+    device: torch.device,
+    grad_accum_steps: int,
+    val_tokens: Tensor,
+    base_bytes_lut: Tensor,
+    has_leading_space_lut: Tensor,
+    is_boundary_token_lut: Tensor,
+    eval_seq_len: int | None = None,
+) -> tuple[float, float]:
+    seq_len = eval_seq_len or args.train_seq_len
+    local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps)
+    if local_batch_tokens < seq_len:
+        raise ValueError(
+            "VAL_BATCH_SIZE must provide at least one sequence per rank; "
+            f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, "
+            f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}"
+        )
+    local_batch_seqs = local_batch_tokens // seq_len
+    total_seqs = (val_tokens.numel() - 1) // seq_len
+    seq_start = (total_seqs * rank) // world_size
+    seq_end = (total_seqs * (rank + 1)) // world_size
+    val_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
+    val_token_count = torch.zeros((), device=device, dtype=torch.float64)
+    val_byte_count = torch.zeros((), device=device, dtype=torch.float64)
+    model.eval()
+    with torch.inference_mode():
+        for batch_seq_start in range(seq_start, seq_end, local_batch_seqs):
+            batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end)
+            raw_start = batch_seq_start * seq_len
+            raw_end = batch_seq_end * seq_len + 1
+            local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True)
+            x = local[:-1].reshape(-1, seq_len)
+            y = local[1:].reshape(-1, seq_len)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                batch_loss = model(x, y).detach()
+            batch_token_count = float(y.numel())
+            val_loss_sum += batch_loss.to(torch.float64) * batch_token_count
+            val_token_count += batch_token_count
+            prev_ids = x.reshape(-1)
+            tgt_ids = y.reshape(-1)
+            token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16)
+            token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16)
+            val_byte_count += token_bytes.to(torch.float64).sum()
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM)
+    val_loss = val_loss_sum / val_token_count
+    bits_per_token = val_loss.item() / math.log(2.0)
+    tokens_per_byte = val_token_count.item() / val_byte_count.item()
+    model.train()
+    return float(val_loss.item()), float(bits_per_token * tokens_per_byte)
+CONTROL_TENSOR_NAME_PATTERNS = tuple(
+    pattern
+    for pattern in os.environ.get(
+        "CONTROL_TENSOR_NAME_PATTERNS",
+        "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale",
+    ).split(",")
+    if pattern
+)
+INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple(
+    pattern
+    for pattern in os.environ.get(
+        "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS",
+        ",".join(CONTROL_TENSOR_NAME_PATTERNS),
+    ).split(",")
+    if pattern
+)
+INT8_KEEP_FLOAT_MAX_NUMEL = 65_536
+INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16
+INT8_PER_ROW_SCALE_DTYPE = torch.float16
+INT8_CLIP_PERCENTILE = 99.99984
+INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0
+def tensor_nbytes(t: Tensor) -> int:
+    return int(t.numel()) * int(t.element_size())
+def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor:
+    if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS):
+        return t.float().contiguous()
+    if t.dtype in {torch.float32, torch.bfloat16}:
+        passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.")
+        return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous()
+    return t
+def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]:
+    t32 = t.float()
+    if t32.ndim == 2:
+        clip_abs = (
+            torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1)
+            if t32.numel()
+            else torch.empty((t32.shape[0],), dtype=torch.float32)
+        )
+        clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None])
+        scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0)
+        q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous()
+        return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous()
+    clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0
+    scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32)
+    q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous()
+    return q, scale
+def quantize_state_dict_int8(state_dict: dict[str, Tensor]):
+    quantized: dict[str, Tensor] = {}
+    scales: dict[str, Tensor] = {}
+    dtypes: dict[str, str] = {}
+    passthrough: dict[str, Tensor] = {}
+    passthrough_orig_dtypes: dict[str, str] = {}
+    qmeta: dict[str, dict[str, object]] = {}
+    stats = dict.fromkeys(
+        ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"),
+        0,
+    )
+    for name, tensor in state_dict.items():
+        t = tensor.detach().to("cpu").contiguous()
+        stats["param_count"] += int(t.numel())
+        stats["num_tensors"] += 1
+        stats["baseline_tensor_bytes"] += tensor_nbytes(t)
+        if not t.is_floating_point():
+            stats["num_nonfloat_tensors"] += 1
+            passthrough[name] = t
+            stats["int8_payload_bytes"] += tensor_nbytes(t)
+            continue
+        if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL:
+            kept = keep_float_tensor(name, t, passthrough_orig_dtypes)
+            passthrough[name] = kept
+            stats["int8_payload_bytes"] += tensor_nbytes(kept)
+            continue
+        stats["num_float_tensors"] += 1
+        q, s = quantize_float_tensor(t)
+        if s.ndim > 0:
+            qmeta[name] = {"scheme": "per_row", "axis": 0}
+        quantized[name] = q
+        scales[name] = s
+        dtypes[name] = str(t.dtype).removeprefix("torch.")
+        stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s)
+    obj: dict[str, object] = {
+        "__quant_format__": "int8_clean_per_row_v1",
+        "quantized": quantized,
+        "scales": scales,
+        "dtypes": dtypes,
+        "passthrough": passthrough,
+    }
+    if qmeta:
+        obj["qmeta"] = qmeta
+    if passthrough_orig_dtypes:
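+        # Illustrative sizing (assumed shapes, not taken from this run): a
+        # 768x50257 fp32 matrix is ~154 MB, while its int8 payload is
+        # 768*50257 bytes plus 768 fp16 row scales, ~38.6 MB, a ~4x shrink
+        # before lzma. Tensors with <= 65_536 elements bypass quantization
+        # entirely via keep_float_tensor above.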
+        obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes
+    return obj, stats
+def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]:
+    out: dict[str, Tensor] = {}
+    qmeta = obj.get("qmeta", {})
+    passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
+    for name, q in obj["quantized"].items():
+        dtype = getattr(torch, obj["dtypes"][name])
+        s = obj["scales"][name]
+        if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
+            s = s.to(dtype=torch.float32)
+            out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
+        else:
+            scale = float(s.item())
+            out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
+    for name, t in obj["passthrough"].items():
+        out_t = t.detach().to("cpu").contiguous()
+        orig_dtype = passthrough_orig_dtypes.get(name)
+        if isinstance(orig_dtype, str):
+            out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
+        out[name] = out_t
+    return out
+def load_data_shard(file: Path) -> Tensor:
+    # (region garbled in transit; body reconstructed assuming the standard
+    # 256-entry int32 header followed by a uint16 token payload)
+    header_bytes = 256 * np.dtype("<i4").itemsize
+    with file.open("rb") as f:
+        header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+        num_tokens = int(header[2])
+        tokens = np.frombuffer(f.read(num_tokens * np.dtype("<u2").itemsize), dtype="<u2")
+    return torch.from_numpy(tokens.astype(np.int32))
+class TokenStream:
+    def __init__(self, pattern: str):
+        # (reconstructed: attributes implied by _advance_file/take below)
+        self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+        if not self.files:
+            raise FileNotFoundError(f"No files found for pattern: {pattern}")
+        self.file_idx = 0
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+    def _advance_file(self) -> None:
+        self.file_idx = (self.file_idx + 1) % len(self.files)
+        self.tokens = load_data_shard(self.files[self.file_idx])
+        self.pos = 0
+    def take(self, n: int) -> Tensor:
+        chunks: list[Tensor] = []
+        remaining = n
+        while remaining > 0:
+            avail = self.tokens.numel() - self.pos
+            if avail <= 0:
+                self._advance_file()
+                continue
+            k = min(remaining, avail)
+            chunks.append(self.tokens[self.pos : self.pos + k])
+            self.pos += k
+            remaining -= k
+        return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+class DistributedTokenLoader:
+    def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+        self.rank = rank
+        self.world_size = world_size
+        self.device = device
+        self.stream = TokenStream(pattern)
+    def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
+        local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+        per_rank_span = local_tokens + 1
+        chunk = self.stream.take(per_rank_span * self.world_size)
+        start = self.rank * per_rank_span
+        local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
+        x = local[:-1].reshape(-1, seq_len)
+        y = local[1:].reshape(-1, seq_len)
+        return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
+class RMSNorm(nn.Module):
+    def __init__(self, eps: float | None = None):
+        super().__init__()
+        self.eps = eps
+    def forward(self, x: Tensor) -> Tensor:
+        return F.rms_norm(x, (x.size(-1),), eps=self.eps)
+class CastedLinear(nn.Linear):
+    _qat_enabled: bool = False
+    def forward(self, x: Tensor) -> Tensor:
+        w = self.weight.to(x.dtype)
+        if CastedLinear._qat_enabled and self.training and w.ndim == 2:
+            with torch.no_grad():
+                w32 = self.weight.float()
+                row_max = w32.abs().amax(dim=1)
+                scale = (row_max / 31.0).clamp_min(1.0 / 31.0)
+                w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype)
+            w = w + (w_q - w).detach()
+        bias = self.bias.to(x.dtype) if self.bias is not None else None
+        return F.linear(x, w, bias)
+def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
+    with torch.no_grad():
+        for name, param in module.named_parameters():
+            if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32:
+                param.data = param.data.float()
+class Rotary(nn.Module):
+    def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0):
+        super().__init__()
+        self.dim = dim
+        self.base = base
+        self.train_seq_len = train_seq_len
+        self.rope_dims = rope_dims if rope_dims > 0 else dim
+        inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self._seq_len_cached = 0
+        self._cos_cached: Tensor | None = None
+        self._sin_cached: Tensor | None = None
+    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
+        if (
+            self._cos_cached is None
+            or self._sin_cached is None
+            or self._seq_len_cached != seq_len
+            or self._cos_cached.device != device
+        ):
+            rd = self.rope_dims
+            if seq_len > self.train_seq_len:
+                scale = seq_len / self.train_seq_len
+                new_base = self.base * (scale ** (rd / (rd - 2)))
+                inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd))
+            else:
+                inv_freq = self.inv_freq.to(device)
+            t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
+            freqs = torch.outer(t, inv_freq)
+            self._cos_cached = freqs.cos()[None, :, None, :]
+            self._sin_cached = freqs.sin()[None, :, None, :]
+            self._seq_len_cached = seq_len
+        return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype)
+def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor:
+    if rope_dims > 0 and rope_dims < x.size(-1):
+        x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:]
+        half = rope_dims // 2
+        x1, x2 = x_rope[..., :half], x_rope[..., half:]
+        x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
+        return torch.cat((x_rope, x_pass), dim=-1)
+    half = x.size(-1) // 2
+    x1, x2 = x[..., :half], x[..., half:]
+    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
+class CausalSelfAttention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_base: float,
+        qk_gain_init: float,
+    ):
+        super().__init__()
+        if dim % num_heads != 0:
+            raise ValueError("model_dim must be divisible by num_heads")
+        if num_heads % num_kv_heads != 0:
+            raise ValueError("num_heads must be divisible by num_kv_heads")
+        self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
+        self.head_dim = dim // num_heads
+        if self.head_dim % 2 != 0:
+            raise ValueError("head_dim must be even for RoPE")
+        kv_dim = self.num_kv_heads * self.head_dim
+        self.c_q = CastedLinear(dim, dim, bias=False)
+        self.c_k = CastedLinear(dim, kv_dim, bias=False)
+        self.c_v = CastedLinear(dim, kv_dim, bias=False)
+        self.proj = CastedLinear(dim, dim, bias=False)
+        self.proj._zero_init = True
+        self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
+        self.rope_dims = 0  # set by GPT.__init__ for partial RoPE
+        self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024)
+        self.use_xsa = False  # set by GPT.__init__ for deep layers only
+        self.vrl_gate = None  # set by GPT.__init__ when VRL is enabled
+    def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor:
+        B, T, H, D = y.shape
+        Hkv = v.size(-2)
+        group = H // Hkv
+        y_g = y.reshape(B, T, Hkv, group, D)  # [B, T, Hkv, group, D]
+        vn = F.normalize(v, dim=-1).unsqueeze(-2)  # [B, T, Hkv, 1, D], broadcast ready
+        proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn
+        return (y_g - proj).reshape(B, T, H, D)
+    def forward(self, x: Tensor, v_embed: Tensor | None = None, v_first: Tensor | None = None) -> tuple[Tensor, Tensor]:
+        bsz, seqlen, dim = x.shape
+        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim)
+        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
+        v = self.c_v(x)
+        v_raw = v  # save raw V before any modifications for VRL
+        if v_embed is not None:
+            v = v + v_embed
+        if v_first is not None and self.vrl_gate is not None:
+            gate = torch.sigmoid(self.vrl_gate.to(dtype=v.dtype))
+            v = (1 - gate) * v + gate * v_first
+        v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
+        q = F.rms_norm(q, (q.size(-1),))
+        k = F.rms_norm(k, (k.size(-1),))
+        cos, sin = self.rotary(seqlen, x.device, q.dtype)
+        q = apply_rotary_emb(q, cos, sin, self.rope_dims)
+        k = apply_rotary_emb(k, cos, sin, self.rope_dims)
+        q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None]
+        if _HAS_FA3:
+            y = flash_attn_3_func(q, k, v, causal=True)
+        else:
+            q2 = q.transpose(1, 2)
+            k2 = k.transpose(1, 2)
+            v2 = v.transpose(1, 2)
+            k2 = k2.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
+            v2 = v2.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
+            y = F.scaled_dot_product_attention(q2, k2, v2, is_causal=True)
+            y = y.transpose(1, 2).contiguous()
+        if self.use_xsa:
+            y = self._xsa_efficient(y, v)
+        y = y.reshape(bsz, seqlen, dim)
+        return self.proj(y), v_raw
+class SmearGate(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
+    def forward(self, x: Tensor) -> Tensor:
+        g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :]
+        x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
+        return (1 - g) * x + g * x_prev
+class BigramHashEmbedding(nn.Module):
+    def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int):
+        super().__init__()
+        self.bigram_vocab_size = bigram_vocab_size
+        self.embed = nn.Embedding(bigram_vocab_size, bigram_dim)
+        nn.init.zeros_(self.embed.weight)
+        self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None
+        if self.proj is not None:
+            nn.init.zeros_(self.proj.weight)
+        self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32))
+    def bigram_hash(self, tokens: Tensor) -> Tensor:
+        t = tokens.to(torch.int32)
+        mod = self.bigram_vocab_size - 1
+        out = torch.empty_like(t)
+        out[..., 0] = mod
+        out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod
+        return out.long()
+    def forward(self, token_ids: Tensor) -> Tensor:
+        h = self.embed(self.bigram_hash(token_ids))
+        if self.proj is not None:
+            h = self.proj(h)
+        return h * self.scale.to(dtype=h.dtype)
+class ValueEmbedding(nn.Module):
+    def __init__(self, vocab_size: int, ve_dim: int, model_dim: int):
+        super().__init__()
+        self.embed = nn.Embedding(vocab_size, ve_dim)
+        nn.init.normal_(self.embed.weight, std=0.01)
+        self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None
+        if self.proj is not None:
+            nn.init.zeros_(self.proj.weight)
+        self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32))
+    def forward(self, token_ids: Tensor) -> Tensor:
+        h = self.embed(token_ids)
+        if self.proj is not None:
+            h = self.proj(h)
+        return h * self.scale.to(dtype=h.dtype)
+class MLP(nn.Module):
+    def __init__(self, dim: int, mlp_mult: int):
+        super().__init__()
+        hidden = int(mlp_mult * dim)
+        self.fc = CastedLinear(dim, hidden, bias=False)
+        self.proj = CastedLinear(hidden, dim, bias=False)
+        self.proj._zero_init = True
+    def forward(self, x: Tensor) -> Tensor:
+        x = F.leaky_relu(self.fc(x), negative_slope=0.5)
+        return self.proj(x.square())
+class Block(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        mlp_mult: int,
+        rope_base: float,
+        qk_gain_init: float,
+        layer_idx: int = 0,
+        ln_scale: bool = False,
+        dtg: bool = False,
+    ):
+        super().__init__()
+        self.attn_norm = RMSNorm()
+        self.mlp_norm = RMSNorm()
+        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
+        self.mlp = MLP(dim, mlp_mult)
+        self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
+        self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
+        self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0
+        if dtg:
+            self.dtg_gate = nn.Linear(dim, 1, bias=True)
+            nn.init.zeros_(self.dtg_gate.weight)
+            nn.init.constant_(self.dtg_gate.bias, 2.0)
+        else:
+            self.dtg_gate = None
+    def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None, v_first: Tensor | None = None) -> tuple[Tensor, Tensor]:
+        mix = self.resid_mix.to(dtype=x.dtype)
+        x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
+        attn_out, v_raw = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed, v_first=v_first)
+        x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out
+        x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor)
+        if self.dtg_gate is not None:
+            gate = torch.sigmoid(self.dtg_gate(x_in.detach()))
+            x_out = x_in + gate * (x_out - x_in)
+        return x_out, v_raw
+class GPT(nn.Module):
+    def __init__(
+        self,
+        vocab_size: int,
+        num_layers: int,
+        model_dim: int,
+        num_heads: int,
+        num_kv_heads: int,
+        mlp_mult: int,
+        tie_embeddings: bool,
+        tied_embed_init_std: float,
+        logit_softcap: float,
+        rope_base: float,
+        qk_gain_init: float,
+        mtp_num_heads: int = 0,
+        mtp_loss_weight: float = 0.1,
+        bigram_vocab_size: int = 0,
+        bigram_dim: int = 128,
+        xsa_last_n: int = 0,
+        rope_dims: int = 0,
+        ln_scale: bool = False,
+        dtg: bool = False,
+        ve_enabled: bool = False,
+        ve_dim: int = 128,
+        ve_layers: str = "9,10",
+        vrl_enabled: bool = False,
+    ):
+        super().__init__()
+        self._ve_target_dim = num_kv_heads * (model_dim // num_heads)  # kv_dim for value projection
+        if logit_softcap <= 0.0:
+            raise ValueError(f"logit_softcap must be positive, got {logit_softcap}")
+        self.tie_embeddings = tie_embeddings
+        self.tied_embed_init_std = tied_embed_init_std
+        self.logit_softcap = logit_softcap
+        self.mtp_num_heads = mtp_num_heads
+        self.mtp_loss_weight = mtp_loss_weight
+        self.tok_emb = nn.Embedding(vocab_size, model_dim)
+        self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None
+        self.smear = SmearGate(model_dim)
+        self.num_encoder_layers = num_layers // 2
+        self.num_decoder_layers = num_layers - self.num_encoder_layers
+        self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers)
+        self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32))
+        self.blocks = nn.ModuleList(
+            [
+                Block(
+                    model_dim,
+                    num_heads,
+                    num_kv_heads,
+                    mlp_mult,
+                    rope_base,
+                    qk_gain_init,
+                    layer_idx=i,
+                    ln_scale=ln_scale,
+                    dtg=dtg,
+                )
+                for i in range(num_layers)
+            ]
+        )
+        if rope_dims > 0:
+            head_dim = model_dim // num_heads
+            for block in self.blocks:
+                block.attn.rope_dims = rope_dims
+                block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims)
+        self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else []
+        kv_dim = self._ve_target_dim
+        if self.ve_layer_indices:
+            self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim)
+            self.ve_layer_scales = nn.ParameterList(
+                [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices]
+            )
+        else:
+            self.ve_shared = None
+            self.ve_layer_scales = nn.ParameterList()
+        self.value_embeds = nn.ModuleList()  # keep empty for compat
+        self.final_norm = RMSNorm()
+        self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False)
+        if self.lm_head is not None:
+            self.lm_head._zero_init = True
+        self.mtp_heads = nn.ModuleList(
+            [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)]
+        )
+        for head in self.mtp_heads:
+            head._zero_init = True
+        if xsa_last_n > 0:
+            for i in range(max(0, num_layers - xsa_last_n), num_layers):
+                self.blocks[i].attn.use_xsa = True
+        self.vrl_enabled = vrl_enabled
+        if vrl_enabled:
+            for i in range(1, num_layers):  # all layers except layer 0
+                self.blocks[i].attn.vrl_gate = nn.Parameter(torch.tensor(-1.5, dtype=torch.float32))
+        self._init_weights()
+    def _init_weights(self) -> None:
+        if self.tie_embeddings:
+            nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std)
+        num_layers = len(self.blocks)
+        for name, module in self.named_modules():
+            if isinstance(module, nn.Linear):
+                if getattr(module, "_zero_init", False):
+                    nn.init.zeros_(module.weight)
+                elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64:
+                    nn.init.orthogonal_(module.weight, gain=1.0)
+                    if ".proj." in name or name.endswith(".proj"):
+                        with torch.no_grad():
+                            module.weight.mul_(1.0 / math.sqrt(2 * num_layers))
+    def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None:
+        if self.ve_shared is None or layer_idx not in self.ve_layer_indices:
+            return None
+        if ve_cache is not None and 've' not in ve_cache:
+            ve_cache['ve'] = self.ve_shared(input_ids)
+        ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids)
+        ve_idx = self.ve_layer_indices.index(layer_idx)
+        return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype)
+    def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
+        x = self.tok_emb(input_ids)
+        if self.bigram is not None:
+            x = x + self.bigram(input_ids)
+        x = F.rms_norm(x, (x.size(-1),))
+        x = self.smear(x)
+        x0 = x
+        skips: list[Tensor] = []
+        ve_cache: dict = {}
+        v_first: Tensor | None = None
+        for i in range(self.num_encoder_layers):
+            ve = self._get_ve(i, input_ids, ve_cache)
+            x, v_raw = self.blocks[i](x, x0, v_embed=ve, v_first=v_first if self.vrl_enabled else None)
+            if i == 0 and self.vrl_enabled:
+                v_first = v_raw
+            skips.append(x)
+        for i in range(self.num_decoder_layers):
+            bi = self.num_encoder_layers + i
+            if skips:
+                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
+            ve = self._get_ve(bi, input_ids, ve_cache)
+            x, _ = self.blocks[bi](x, x0, v_embed=ve, v_first=v_first if self.vrl_enabled else None)
+        x = self.final_norm(x)
+        x_flat = x.reshape(-1, x.size(-1))
+        targets = target_ids.reshape(-1)
+        if self.tie_embeddings:
+            logits_proj = F.linear(x_flat, self.tok_emb.weight)
+        else:
+            if self.lm_head is None:
+                raise RuntimeError("lm_head is required when tie_embeddings=False")
+            logits_proj = self.lm_head(x_flat)
+        logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
+        main_loss = F.cross_entropy(logits.float(), targets, reduction="mean")
+        if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0:
+            _, seqlen, dim = x.shape
+            mtp_loss_sum = x.new_zeros(())
+            mtp_loss_count = 0
+            for k, mtp_head in enumerate(self.mtp_heads):
+                valid_t = seqlen - (k + 1)
+                if valid_t <= 0:
+                    continue
+                mtp_hidden = x[:, :valid_t, :].reshape(-1, dim)
+                mtp_targets = target_ids[:, k + 1 :].reshape(-1)
+                mtp_logits_proj = mtp_head(mtp_hidden)
+                mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap)
+                mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean")
+                mtp_loss_count += 1
+            if mtp_loss_count > 0:
+                main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count)
+        return main_loss
+    def forward_logits(self, input_ids: Tensor) -> Tensor:
+        x = self.tok_emb(input_ids)
+        if self.bigram is not None:
+            x = x + self.bigram(input_ids)
+        x = F.rms_norm(x, (x.size(-1),))
+        x = self.smear(x)
+        x0 = x
+        skips: list[Tensor] = []
+        ve_cache: dict = {}
+        v_first: Tensor | None = None
+        for i in range(self.num_encoder_layers):
+            ve = self._get_ve(i, input_ids, ve_cache)
+            x, v_raw = self.blocks[i](x, x0, v_embed=ve, v_first=v_first if self.vrl_enabled else None)
+            if i == 0 and self.vrl_enabled:
+                v_first = v_raw
+            skips.append(x)
+        for i in range(self.num_decoder_layers):
+            bi = self.num_encoder_layers + i
+            if skips:
+                x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
+            ve = self._get_ve(bi, input_ids, ve_cache)
+            x, _ = self.blocks[bi](x, x0, v_embed=ve, v_first=v_first if self.vrl_enabled else None)
+        x = self.final_norm(x)
+        if self.tie_embeddings:
+            logits_proj = F.linear(x, self.tok_emb.weight)
+        else:
+            logits_proj = self.lm_head(x)
+        return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
+def eval_val_sliding(
+    args: Hyperparameters,
+    base_model: nn.Module,
+    rank: int,
+    world_size: int,
+    device: torch.device,
+    val_tokens: Tensor,
+    base_bytes_lut: Tensor,
+    has_leading_space_lut: Tensor,
+    is_boundary_token_lut: Tensor,
+    stride: int,
+    batch_seqs: int = 32,
+    eval_seq_len: int | None = None,
+) -> tuple[float, float]:
+    seq_len = eval_seq_len or args.train_seq_len
+    total_tokens = val_tokens.numel() - 1
+    window_starts = [ws for ws in range(0, total_tokens, stride)
+                     if min(ws + seq_len, total_tokens) - ws >= 1]
+    total_windows = len(window_starts)
+    my_s = (total_windows * rank) // world_size
+    my_e = (total_windows * (rank + 1)) // world_size
+    my_windows = window_starts[my_s:my_e]
+    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
+    token_count = torch.zeros((), device=device, dtype=torch.float64)
+    byte_count = torch.zeros((), device=device, dtype=torch.float64)
+    base_model.eval()
+    compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True)
+    with torch.inference_mode():
+        for bi in range(0, len(my_windows), batch_seqs):
+            batch_ws = my_windows[bi:bi + batch_seqs]
+            bsz = len(batch_ws)
+            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
+            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
+            wlens: list[int] = []
+            for i, ws in enumerate(batch_ws):
+                end = min(ws + seq_len, total_tokens)
+                wlen = end - ws
+                wlens.append(wlen)
+                chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
+                x_batch[i, :wlen] = chunk[:-1]
+                y_batch[i, :wlen] = chunk[1:]
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+                logits = compiled_logits(x_batch)
+            nll = F.cross_entropy(
+                logits.reshape(-1, logits.size(-1)).float(),
+                y_batch.reshape(-1),
+                reduction="none",
+            ).reshape(bsz, seq_len)
+            for i, ws in enumerate(batch_ws):
+                wlen = wlens[i]
+                s = 0 if ws == 0 else max(wlen - stride, 0)
+                scored_nll = nll[i, s:wlen].to(torch.float64)
+                loss_sum += scored_nll.sum()
+                token_count += float(wlen - s)
+                tgt = y_batch[i, s:wlen]
+                prev = x_batch[i, s:wlen]
+                tb = base_bytes_lut[tgt].to(torch.float64)
+                tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64)
+                byte_count += tb.sum()
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
+        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)
+        dist.all_reduce(byte_count, op=dist.ReduceOp.SUM)
+    val_loss = (loss_sum / token_count).item()
+    bits_per_token = val_loss / math.log(2.0)
+    tokens_per_byte = token_count.item() / byte_count.item()
+    base_model.train()
+    return val_loss, bits_per_token * tokens_per_byte
+class NgramBackoffCache:
+    PRIMES = [36313, 27191, 51647, 81929, 131071, 175447, 209591]
+    def __init__(self, vocab_size: int, min_order: int = 2, max_order: int = 7,
+                 hash_size: int = 4194304, min_count: int = 2):
+        self.V = vocab_size
+        self.min_order = min_order
+        self.max_order = max_order
+        self.n_orders = max_order - min_order + 1
+        self.H = hash_size
+        self.mask = hash_size - 1
+        self.min_count = min_count
+        self.ctx_tables = [np.zeros(hash_size, dtype=np.uint32) for _ in range(self.n_orders)]
+        self.full_tables = [np.zeros(hash_size, dtype=np.uint32) for _ in range(self.n_orders)]
+    def _hash_ctx(self, tokens: np.ndarray, pos: int, ctx_w: int) -> int:
+        h = 0
+        for k in range(ctx_w):
+            h ^= int(tokens[pos - ctx_w + k]) * self.PRIMES[k % 7]
+        return h & self.mask
+    def _hash_full(self, ctx_hash: int, target: int, ctx_w: int) -> int:
+        return (ctx_hash ^ (target * self.PRIMES[ctx_w % 7])) & self.mask
+    def update(self, tokens: np.ndarray, start: int, end: int) -> None:
+        for i in range(start, end):
+            target = int(tokens[i])
+            for oi in range(self.n_orders):
+                ctx_w = oi + self.min_order - 1
+                if i < ctx_w:
+                    continue
+                ctx_h = self._hash_ctx(tokens, i, ctx_w)
+                full_h = self._hash_full(ctx_h, target, ctx_w)
+                self.ctx_tables[oi][ctx_h] += 1
+                self.full_tables[oi][full_h] += 1
+    def predict(self, tokens: np.ndarray, pos: int, target: int) -> tuple[float, bool]:
+        for oi in range(self.n_orders - 1, -1, -1):
+            ctx_w = oi + self.min_order - 1
+            if pos < ctx_w:
+                continue
+            ctx_h = self._hash_ctx(tokens, pos, ctx_w)
+            ctx_count = self.ctx_tables[oi][ctx_h]
+            if ctx_count < self.min_count:
+                continue
+            full_h = self._hash_full(ctx_h, target, ctx_w)
+            full_count = self.full_tables[oi][full_h]
+            p = min(full_count, ctx_count) / max(ctx_count, 1)
+            return max(min(p, 1.0), 0.0), True
+        return 0.0, False
+def eval_val_ngram_backoff(
+    args, base_model: nn.Module, rank: int, world_size: int,
+    device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor,
+    has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor,
+    stride: int, batch_seqs: int = 32, log0=print,
+    ngram_order: int = 7,
+) -> tuple[float, float]:
+    seq_len = args.train_seq_len
+    total_tokens = val_tokens.numel() - 1
+    vocab_size = args.vocab_size
+    cache = NgramBackoffCache(vocab_size, min_order=2, max_order=ngram_order,
+                              hash_size=4194304, min_count=2)
+    window_starts = sorted([ws for ws in range(0, total_tokens, stride)
+                            if min(ws + seq_len, total_tokens) - ws >= 1])
+    total_windows = len(window_starts)
+    my_s = (total_windows * rank) // world_size
+    my_e = (total_windows * (rank + 1)) // world_size
+    my_windows = window_starts[my_s:my_e]
+    base_model.eval()
+    compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True)
+    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
+    token_count = torch.zeros((), device=device, dtype=torch.float64)
+    byte_count = torch.zeros((), device=device, dtype=torch.float64)
+    all_tokens = val_tokens.cpu().numpy().astype(np.int32)
+    scored_up_to = my_windows[0] if my_windows else 0
+    ngram_helped = 0
+    ngram_total = 0
+    with torch.inference_mode():
+        for bi in range(0, len(my_windows), batch_seqs):
+            batch_ws = my_windows[bi:bi + batch_seqs]
+            bsz = len(batch_ws)
+            x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
+            y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
+            wlens: list[int] = []
+            for i, ws in enumerate(batch_ws):
+                end = min(ws + seq_len, total_tokens)
+                wlen = end - ws
+                wlens.append(wlen)
+                chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device)
+                x_batch[i, :wlen] = chunk[:-1]
+                y_batch[i, :wlen] = chunk[1:]
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
+                logits = compiled_logits(x_batch)
+            log_probs = torch.log_softmax(logits.float(), dim=-1)
+            probs = torch.exp(log_probs)
+            entropy = -(probs * log_probs).sum(dim=-1)
+            nll = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(),
+                                  y_batch.reshape(-1), reduction='none').reshape(bsz, seq_len)
+            for i, ws in enumerate(batch_ws):
+                wlen = wlens[i]
+                s = 0 if ws == 0 else max(wlen - stride, 0)
+                score_len = wlen - s
+                if score_len <= 0:
+                    continue
+                for t in range(s, wlen):
+                    token_pos = ws + t + 1
+                    tgt = int(y_batch[i, t].item())
+                    prev = int(x_batch[i, t].item())
+                    model_nll = nll[i, t].item()
+                    model_p = math.exp(-model_nll)
+                    H = entropy[i, t].item()
+                    alpha = 0.05 + 0.55 / (1.0 + math.exp(-2.0 * (H - 4.0)))
+                    ng_p, has_ng = cache.predict(all_tokens, token_pos, tgt)
+                    if has_ng:
+                        mixed_p = max((1.0 - alpha) * model_p + alpha * ng_p, 1e-12)
+                        scored_nll = -math.log(mixed_p)
+                        ngram_total += 1
+                        if scored_nll < model_nll:
+                            ngram_helped += 1
+                    else:
+                        scored_nll = model_nll
+                    loss_sum += scored_nll
+                    token_count += 1.0
+                    tb = float(base_bytes_lut[tgt].item())
+                    tb += float((has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).item())
+                    byte_count += tb
+                new_end = ws + wlen + 1
+                if new_end > scored_up_to:
+                    cache.update(all_tokens, scored_up_to, new_end)
+                    scored_up_to = new_end
+            if rank == 0 and bi % 100 == 0:
+                running_bpb = ((loss_sum / token_count) / math.log(2.0) * token_count / byte_count).item() if token_count > 0 else 0
+                pct = 100.0 * bi / max(len(my_windows), 1)
+                hit_rate = ngram_helped / max(ngram_total, 1) * 100
+                log0(f"  ngram [{bi}/{len(my_windows)}] {pct:.1f}% bpb={running_bpb:.6f} ng_helped={hit_rate:.1f}%")
+    if dist.is_available() and dist.is_initialized():
+        dist.all_reduce(loss_sum,
dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = 
"passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.to(torch.float16) + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise 
RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, 
effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._qat_enabled = args.qat_enabled + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + vrl_enabled=args.vrl_enabled, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) + 
scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": 
args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return 
remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + 
world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + 
ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} 
val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = lzma.compress(quant_raw, preset=6) if _COMPRESSOR == "lzma" else zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + log0(f"Total submission size: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(lzma.decompress(quant_blob_disk) if _COMPRESSOR == "lzma" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, 
tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, # must match training model + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + vrl_enabled=args.vrl_enabled, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} 
val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int6_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int6_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + ngram_enabled = bool(int(os.environ.get("NGRAM_ENABLED", "0"))) + if ngram_enabled: + log0("Starting n-gram backoff eval (PR #727 approach: entropy-adaptive alpha, orders 2-7)...") + torch.cuda.synchronize() + t_ng = time.perf_counter() + ng_val_loss, ng_val_bpb = eval_val_ngram_backoff( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + batch_seqs=32, log0=log0, + ngram_order=int(os.environ.get("NGRAM_ORDER", "7")), + ) + torch.cuda.synchronize() + log0(f"final_ngram val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} " + f"ngram_eval_time:{1000.0 * (time.perf_counter() - t_ng):.0f}ms") + log0(f"final_ngram_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Thu Mar 26 17:20:54 2026 
++-----------------------------------------------------------------------------------------+
+| NVIDIA-SMI 570.195.03             Driver Version: 570.195.03     CUDA Version: 12.8     |
+|-----------------------------------------+------------------------+----------------------+
+| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
+| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
+|                                         |                        |               MIG M. |
+|=========================================+========================+======================|
+|   0  NVIDIA H100 80GB HBM3          On  |   00000000:05:00.0 Off |                    0 |
+| N/A   35C    P0            118W /  700W |    1519MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   1  NVIDIA H100 80GB HBM3          On  |   00000000:06:00.0 Off |                    0 |
+| N/A   30C    P0            118W /  700W |    1519MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   2  NVIDIA H100 80GB HBM3          On  |   00000000:07:00.0 Off |                    0 |
+| N/A   40C    P0            119W /  700W |    1519MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   3  NVIDIA H100 80GB HBM3          On  |   00000000:08:00.0 Off |                    0 |
+| N/A   35C    P0            118W /  700W |    1519MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   4  NVIDIA H100 80GB HBM3          On  |   00000000:09:00.0 Off |                    0 |
+| N/A   35C    P0            121W /  700W |    1519MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   5  NVIDIA H100 80GB HBM3          On  |   00000000:0A:00.0 Off |                    0 |
+| N/A   31C    P0            116W /  700W |    1519MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   6  NVIDIA H100 80GB HBM3          On  |   00000000:0B:00.0 Off |                    0 |
+| N/A   30C    P0            114W /  700W |    1519MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+|   7  NVIDIA H100 80GB HBM3          On  |   00000000:0C:00.0 Off |                    0 |
+| N/A   35C    P0            116W /  700W |    1519MiB /  81559MiB |      0%      Default |
+|                                         |                        |             Disabled |
++-----------------------------------------+------------------------+----------------------+
+
++-----------------------------------------------------------------------------------------+
+| Processes:                                                                              |
+|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
+|        ID   ID                                                               Usage      |
+|=========================================================================================|
+|    0   N/A  N/A             644      C   /usr/local/bin/python                  1510MiB |
+|    1   N/A  N/A             645      C   /usr/local/bin/python                  1510MiB |
+|    2   N/A  N/A             646      C   /usr/local/bin/python                  1510MiB |
+|    3   N/A  N/A             647      C   /usr/local/bin/python                  1510MiB |
+|    4   N/A  N/A             648      C   /usr/local/bin/python                  1510MiB |
+|    5   N/A  N/A             649      C   /usr/local/bin/python                  1510MiB |
+|    6   N/A  N/A             650      C   /usr/local/bin/python                  1510MiB |
+|    7   N/A  N/A             651      C   /usr/local/bin/python                  1510MiB |
++-----------------------------------------------------------------------------------------+
+
+====================================================================================================
+val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
+train_loader:dataset:fineweb10B_sp1024 train_shards:80
+val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
+model_params:26993766
+mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0
+XSA:last_4 active_layers:[7, 8, 9, 10]
+world_size:8 grad_accum_steps:1
+sdp_backends:cudnn=False flash=True mem_efficient=False math=False
+attention_mode:gqa num_heads:8 num_kv_heads:4
+tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025
+train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
+seed:1337
+warmup_step:1/20
+warmup_step:2/20
+warmup_step:3/20
+warmup_step:4/20
+warmup_step:5/20
+warmup_step:6/20
+warmup_step:7/20
+warmup_step:8/20
+warmup_step:9/20
+warmup_step:10/20
+warmup_step:11/20
+warmup_step:12/20
+warmup_step:13/20
+warmup_step:14/20
+warmup_step:15/20
+warmup_step:16/20
+warmup_step:17/20
+warmup_step:18/20
+warmup_step:19/20
+warmup_step:20/20
+step:0/20000 val_loss:6.9279 val_bpb:4.1031 train_time:0ms step_avg:0.02ms
+step:1/20000 train_loss:6.9299 train_time:156ms step_avg:155.95ms
+step:2/20000 train_loss:8.5665 train_time:262ms step_avg:131.24ms
+step:3/20000 train_loss:7.8274 train_time:349ms step_avg:116.43ms
+step:4/20000 train_loss:7.2142 train_time:435ms step_avg:108.71ms
+step:5/20000 train_loss:7.0642 train_time:521ms step_avg:104.14ms
+step:6/20000 train_loss:6.8454 train_time:607ms step_avg:101.13ms
+step:7/20000 train_loss:6.7570 train_time:693ms step_avg:98.97ms
+step:8/20000 train_loss:6.7616 train_time:779ms step_avg:97.33ms
+step:9/20000 train_loss:6.4223 train_time:864ms step_avg:96.04ms
+step:10/20000 train_loss:6.0911 train_time:950ms step_avg:95.04ms
+step:500/20000 train_loss:2.3706 train_time:44033ms step_avg:88.07ms
+step:1000/20000 train_loss:2.2533 train_time:88175ms step_avg:88.18ms
+step:1500/20000 train_loss:2.2032 train_time:132368ms step_avg:88.25ms
+step:2000/20000 train_loss:2.0493 train_time:176627ms step_avg:88.31ms
+step:2500/20000 train_loss:2.1534 train_time:220906ms step_avg:88.36ms
+step:3000/20000 train_loss:2.1464 train_time:265226ms step_avg:88.41ms
+step:3500/20000 train_loss:2.1647 train_time:309554ms step_avg:88.44ms
+step:4000/20000 train_loss:1.9589 train_time:353862ms step_avg:88.47ms
+step:4000/20000 val_loss:2.0469 val_bpb:1.2123 train_time:353867ms step_avg:88.47ms
+step:4500/20000 train_loss:2.1046 train_time:398244ms step_avg:88.50ms
+step:5000/20000 train_loss:2.0857 train_time:442662ms step_avg:88.53ms
+step:5500/20000 train_loss:1.9984 train_time:487086ms step_avg:88.56ms
+step:6000/20000 train_loss:1.9243 train_time:531507ms step_avg:88.58ms
+swa:start step:6100
+late_qat:enabled step:6246 scale:0.1498
+step:6500/20000 train_loss:2.0634 train_time:576267ms step_avg:88.66ms
+step:6765/20000 val_loss:1.9237 val_bpb:1.1393 train_time:600015ms step_avg:88.69ms
+stopping_early: wallclock_cap train_time:600015ms step:6765/20000
+peak memory allocated: 21155 MiB reserved: 21232 MiB
+ema:applying EMA weights
+DIAGNOSTIC post_ema val_loss:1.9221 val_bpb:1.1384 eval_time:2039ms
+Serialized model: 106181533 bytes
+Code size: 67048 bytes
+Serialized model int6+lzma: 15914800 bytes
+Total submission size int6+lzma: 15981848 bytes
+Total submission size: 15981848 bytes
+final_int6_roundtrip val_loss:1.9352 val_bpb:1.1462 eval_time:52882ms
+final_int6_roundtrip_exact val_loss:1.93524460 val_bpb:1.14616086
+final_int6_sliding_window val_loss:1.8953 val_bpb:1.1225 stride:64 eval_time:102169ms
+final_int6_sliding_window_exact val_loss:1.89533097 val_bpb:1.12252473
+final_int6_roundtrip_exact val_loss:1.89533097 val_bpb:1.12252473
+Starting n-gram backoff eval (PR #727 approach: entropy-adaptive alpha, orders 2-7)...
+ ngram [0/121136] 0.0% bpb=1.208449 ng_helped=9.9%
+ ngram [800/121136] 0.7% bpb=1.225029 ng_helped=17.5%
+ ngram [1600/121136] 1.3% bpb=1.151905 ng_helped=18.0%
+ ngram [2400/121136] 2.0% bpb=1.167360 ng_helped=17.8%
+ ngram [3200/121136] 2.6% bpb=1.152816 ng_helped=18.2%
+ ngram [4000/121136] 3.3% bpb=1.150294 ng_helped=18.3%
+ ngram [4800/121136] 4.0% bpb=1.144471 ng_helped=18.5%
+ ngram [5600/121136] 4.6% bpb=1.146319 ng_helped=18.7%
+ ngram [6400/121136] 5.3% bpb=1.152813 ng_helped=19.4%
+ ngram [7200/121136] 5.9% bpb=1.151456 ng_helped=19.6%
+ ngram [8000/121136] 6.6% bpb=1.151294 ng_helped=19.6%
+ ngram [8800/121136] 7.3% bpb=1.155430 ng_helped=19.7%
+ ngram [9600/121136] 7.9% bpb=1.150554 ng_helped=19.8%
+ ngram [10400/121136] 8.6% bpb=1.147684 ng_helped=20.0%
+ ngram [11200/121136] 9.2% bpb=1.144085 ng_helped=20.1%
+ ngram [12000/121136] 9.9% bpb=1.141570 ng_helped=20.3%
+ ngram [12800/121136] 10.6% bpb=1.139536 ng_helped=20.3%
+ ngram [13600/121136] 11.2% bpb=1.137220 ng_helped=20.4%
+ ngram [14400/121136] 11.9% bpb=1.139054 ng_helped=20.5%
+ ngram [15200/121136] 12.5% bpb=1.148814 ng_helped=20.7%
+ ngram [16000/121136] 13.2% bpb=1.144753 ng_helped=20.8%
+ ngram [16800/121136] 13.9% bpb=1.143496 ng_helped=20.9%
+ ngram [17600/121136] 14.5% bpb=1.140436 ng_helped=21.1%
+ ngram [18400/121136] 15.2% bpb=1.138924 ng_helped=21.3%
+ ngram [19200/121136] 15.8% bpb=1.139110 ng_helped=21.4%
+ ngram [20000/121136] 16.5% bpb=1.136649 ng_helped=21.5%
+ ngram [20800/121136] 17.2% bpb=1.135051 ng_helped=21.6%
+ ngram [21600/121136] 17.8% bpb=1.132934 ng_helped=21.8%
+ ngram [22400/121136] 18.5% bpb=1.131011 ng_helped=21.9%
+ ngram [23200/121136] 19.2% bpb=1.127293 ng_helped=22.1%
+ ngram [24000/121136] 19.8% bpb=1.128773 ng_helped=22.2%
+ ngram [24800/121136] 20.5% bpb=1.127482 ng_helped=22.3%
+ ngram [25600/121136] 21.1% bpb=1.127500 ng_helped=22.5%
+ ngram [26400/121136] 21.8% bpb=1.125961 ng_helped=22.6%
+ ngram [27200/121136] 22.5% bpb=1.125360 ng_helped=22.7%
+
ngram [28000/121136] 23.1% bpb=1.128052 ng_helped=22.9% + ngram [28800/121136] 23.8% bpb=1.128454 ng_helped=23.0% + ngram [29600/121136] 24.4% bpb=1.126822 ng_helped=23.1% + ngram [30400/121136] 25.1% bpb=1.123485 ng_helped=23.2% + ngram [31200/121136] 25.8% bpb=1.122455 ng_helped=23.4% + ngram [32000/121136] 26.4% bpb=1.121859 ng_helped=23.5% + ngram [32800/121136] 27.1% bpb=1.119893 ng_helped=23.7% + ngram [33600/121136] 27.7% bpb=1.117778 ng_helped=23.8% + ngram [34400/121136] 28.4% bpb=1.115870 ng_helped=23.9% + ngram [35200/121136] 29.1% bpb=1.114558 ng_helped=24.0% + ngram [36000/121136] 29.7% bpb=1.113623 ng_helped=24.2% + ngram [36800/121136] 30.4% bpb=1.111404 ng_helped=24.3% + ngram [37600/121136] 31.0% bpb=1.110385 ng_helped=24.4% + ngram [38400/121136] 31.7% bpb=1.109266 ng_helped=24.6% + ngram [39200/121136] 32.4% bpb=1.106078 ng_helped=24.8% + ngram [40000/121136] 33.0% bpb=1.104366 ng_helped=24.9% + ngram [40800/121136] 33.7% bpb=1.101451 ng_helped=25.1% + ngram [41600/121136] 34.3% bpb=1.100420 ng_helped=25.2% + ngram [42400/121136] 35.0% bpb=1.099396 ng_helped=25.4% + ngram [43200/121136] 35.7% bpb=1.098195 ng_helped=25.5% + ngram [44000/121136] 36.3% bpb=1.095905 ng_helped=25.7% + ngram [44800/121136] 37.0% bpb=1.094322 ng_helped=25.8% + ngram [45600/121136] 37.6% bpb=1.092488 ng_helped=25.9% + ngram [46400/121136] 38.3% bpb=1.091482 ng_helped=26.0% + ngram [47200/121136] 39.0% bpb=1.089468 ng_helped=26.2% + ngram [48000/121136] 39.6% bpb=1.088135 ng_helped=26.3% + ngram [48800/121136] 40.3% bpb=1.086644 ng_helped=26.4% + ngram [49600/121136] 40.9% bpb=1.086363 ng_helped=26.5% + ngram [50400/121136] 41.6% bpb=1.085458 ng_helped=26.7% + ngram [51200/121136] 42.3% bpb=1.084536 ng_helped=26.8% + ngram [52000/121136] 42.9% bpb=1.083269 ng_helped=26.9% + ngram [52800/121136] 43.6% bpb=1.082327 ng_helped=27.1% + ngram [53600/121136] 44.2% bpb=1.080201 ng_helped=27.2% + ngram [54400/121136] 44.9% bpb=1.079235 ng_helped=27.3% + ngram [55200/121136] 45.6% 
bpb=1.078207 ng_helped=27.5% + ngram [56000/121136] 46.2% bpb=1.076836 ng_helped=27.6% + ngram [56800/121136] 46.9% bpb=1.074889 ng_helped=27.7% + ngram [57600/121136] 47.5% bpb=1.073352 ng_helped=27.9% + ngram [58400/121136] 48.2% bpb=1.068926 ng_helped=28.0% + ngram [59200/121136] 48.9% bpb=1.067353 ng_helped=28.1% + ngram [60000/121136] 49.5% bpb=1.066052 ng_helped=28.3% + ngram [60800/121136] 50.2% bpb=1.064767 ng_helped=28.4% + ngram [61600/121136] 50.9% bpb=1.063401 ng_helped=28.5% + ngram [62400/121136] 51.5% bpb=1.062674 ng_helped=28.7% + ngram [63200/121136] 52.2% bpb=1.061103 ng_helped=28.8% + ngram [64000/121136] 52.8% bpb=1.060066 ng_helped=28.9% + ngram [64800/121136] 53.5% bpb=1.058796 ng_helped=29.1% + ngram [65600/121136] 54.2% bpb=1.057243 ng_helped=29.2% + ngram [66400/121136] 54.8% bpb=1.055303 ng_helped=29.3% + ngram [67200/121136] 55.5% bpb=1.053585 ng_helped=29.5% + ngram [68000/121136] 56.1% bpb=1.052131 ng_helped=29.6% + ngram [68800/121136] 56.8% bpb=1.050652 ng_helped=29.7% + ngram [69600/121136] 57.5% bpb=1.049054 ng_helped=29.9% + ngram [70400/121136] 58.1% bpb=1.047344 ng_helped=30.0% + ngram [71200/121136] 58.8% bpb=1.046017 ng_helped=30.1% + ngram [72000/121136] 59.4% bpb=1.044622 ng_helped=30.3% + ngram [72800/121136] 60.1% bpb=1.043234 ng_helped=30.4% + ngram [73600/121136] 60.8% bpb=1.041962 ng_helped=30.5% + ngram [74400/121136] 61.4% bpb=1.040889 ng_helped=30.7% + ngram [75200/121136] 62.1% bpb=1.039381 ng_helped=30.8% + ngram [76000/121136] 62.7% bpb=1.037562 ng_helped=31.0% + ngram [76800/121136] 63.4% bpb=1.036462 ng_helped=31.1% + ngram [77600/121136] 64.1% bpb=1.035247 ng_helped=31.2% + ngram [78400/121136] 64.7% bpb=1.034154 ng_helped=31.4% + ngram [79200/121136] 65.4% bpb=1.032618 ng_helped=31.5% + ngram [80000/121136] 66.0% bpb=1.031642 ng_helped=31.7% + ngram [80800/121136] 66.7% bpb=1.030576 ng_helped=31.8% + ngram [81600/121136] 67.4% bpb=1.028807 ng_helped=31.9% + ngram [82400/121136] 68.0% bpb=1.027927 
ng_helped=32.1% + ngram [83200/121136] 68.7% bpb=1.026887 ng_helped=32.2% + ngram [84000/121136] 69.3% bpb=1.026753 ng_helped=32.4% + ngram [84800/121136] 70.0% bpb=1.025532 ng_helped=32.5% + ngram [85600/121136] 70.7% bpb=1.023351 ng_helped=32.6% + ngram [86400/121136] 71.3% bpb=1.022240 ng_helped=32.8% + ngram [87200/121136] 72.0% bpb=1.021058 ng_helped=32.9% + ngram [88000/121136] 72.6% bpb=1.019950 ng_helped=33.1% + ngram [88800/121136] 73.3% bpb=1.018711 ng_helped=33.2% + ngram [89600/121136] 74.0% bpb=1.017554 ng_helped=33.3% + ngram [90400/121136] 74.6% bpb=1.016432 ng_helped=33.5% + ngram [91200/121136] 75.3% bpb=1.015009 ng_helped=33.6% + ngram [92000/121136] 75.9% bpb=1.013320 ng_helped=33.7% + ngram [92800/121136] 76.6% bpb=1.012104 ng_helped=33.9% + ngram [93600/121136] 77.3% bpb=1.010860 ng_helped=34.0% + ngram [94400/121136] 77.9% bpb=1.009659 ng_helped=34.1% + ngram [95200/121136] 78.6% bpb=1.008333 ng_helped=34.3% + ngram [96000/121136] 79.2% bpb=1.006795 ng_helped=34.4% + ngram [96800/121136] 79.9% bpb=1.007487 ng_helped=34.6% + ngram [97600/121136] 80.6% bpb=1.005941 ng_helped=34.7% + ngram [98400/121136] 81.2% bpb=1.004683 ng_helped=34.8% + ngram [99200/121136] 81.9% bpb=1.003353 ng_helped=35.0% + ngram [100000/121136] 82.6% bpb=1.001855 ng_helped=35.1% + ngram [100800/121136] 83.2% bpb=1.000772 ng_helped=35.2% + ngram [101600/121136] 83.9% bpb=0.999789 ng_helped=35.4% + ngram [102400/121136] 84.5% bpb=0.998071 ng_helped=35.5% + ngram [103200/121136] 85.2% bpb=0.996721 ng_helped=35.6% + ngram [104000/121136] 85.9% bpb=0.995242 ng_helped=35.8% + ngram [104800/121136] 86.5% bpb=0.993613 ng_helped=35.9% + ngram [105600/121136] 87.2% bpb=0.992196 ng_helped=36.0% + ngram [106400/121136] 87.8% bpb=0.990969 ng_helped=36.1% + ngram [107200/121136] 88.5% bpb=0.989795 ng_helped=36.3% + ngram [108000/121136] 89.2% bpb=0.988648 ng_helped=36.4% + ngram [108800/121136] 89.8% bpb=0.987638 ng_helped=36.5% + ngram [109600/121136] 90.5% bpb=0.986560 
ng_helped=36.7% + ngram [110400/121136] 91.1% bpb=0.985248 ng_helped=36.8% + ngram [111200/121136] 91.8% bpb=0.984096 ng_helped=36.9% + ngram [112000/121136] 92.5% bpb=0.982764 ng_helped=37.1% + ngram [112800/121136] 93.1% bpb=0.981926 ng_helped=37.2% + ngram [113600/121136] 93.8% bpb=0.980665 ng_helped=37.3% + ngram [114400/121136] 94.4% bpb=0.979362 ng_helped=37.4% + ngram [115200/121136] 95.1% bpb=0.978121 ng_helped=37.6% + ngram [116000/121136] 95.8% bpb=0.976942 ng_helped=37.7% + ngram [116800/121136] 96.4% bpb=0.975513 ng_helped=37.8% + ngram [117600/121136] 97.1% bpb=0.974480 ng_helped=38.0% + ngram [118400/121136] 97.7% bpb=0.973327 ng_helped=38.1% + ngram [119200/121136] 98.4% bpb=0.972201 ng_helped=38.2% + ngram [120000/121136] 99.1% bpb=0.971013 ng_helped=38.3% + ngram [120800/121136] 99.7% bpb=0.969966 ng_helped=38.5% +final_ngram val_loss:1.6277 val_bpb:0.9640 ngram_eval_time:895349ms +final_ngram_exact val_loss:1.62773633 val_bpb:0.96403969 diff --git a/records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/train_seed2025.log b/records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/train_seed2025.log new file mode 100644 index 000000000..711bee6ab --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/train_seed2025.log @@ -0,0 +1,1876 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +import lzma +_COMPRESSOR = "lzma" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True +except ImportError: + try: + from flash_attn.flash_attn_interface import flash_attn_func as flash_attn_3_func + 
_HAS_FA3 = True + except ImportError: + try: + from flash_attn import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True + except ImportError: + _HAS_FA3 = False + flash_attn_3_func = None +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + 
tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) # tighter: collect more recent checkpoints + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) # XSA on last 4 layers (0 = disabled) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + vrl_enabled = bool(int(os.environ.get("VRL_ENABLED", "1"))) +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: 
float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + 
p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + 
local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token 
= val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if 
t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = 
obj.get("qmeta", {})
+ passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {})
+ for name, q in obj["quantized"].items():
+ dtype = getattr(torch, obj["dtypes"][name])
+ s = obj["scales"][name]
+ if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0:
+ s = s.to(dtype=torch.float32)
+ out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous()
+ else:
+ scale = float(s.item())
+ out[name] = (q.float() * scale).to(dtype=dtype).contiguous()
+ for name, t in obj["passthrough"].items():
+ out_t = t.detach().to("cpu").contiguous()
+ orig_dtype = passthrough_orig_dtypes.get(name)
+ if isinstance(orig_dtype, str):
+ out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous()
+ out[name] = out_t
+ return out
+def load_data_shard(file: Path) -> Tensor:
+ header_bytes = 256 * np.dtype("<i4").itemsize
+ with file.open("rb") as f:
+ header = np.frombuffer(f.read(header_bytes), dtype="<i4")
+ num_tokens = int(header[2])
+ tokens = np.frombuffer(f.read(num_tokens * 2), dtype="<u2")
+ return torch.from_numpy(tokens.astype(np.int32))
+class TokenStream:
+ def __init__(self, pattern: str):
+ self.files = [Path(p) for p in sorted(glob.glob(pattern))]
+ if not self.files:
+ raise FileNotFoundError(f"No files found for pattern: {pattern}")
+ self.file_idx = 0
+ self.tokens = load_data_shard(self.files[0])
+ self.pos = 0
+ def _advance_file(self) -> None:
+ self.file_idx = (self.file_idx + 1) % len(self.files)
+ self.tokens = load_data_shard(self.files[self.file_idx])
+ self.pos = 0
+ def take(self, n: int) -> Tensor:
+ chunks: list[Tensor] = []
+ remaining = n
+ while remaining > 0:
+ avail = self.tokens.numel() - self.pos
+ if avail <= 0:
+ self._advance_file()
+ continue
+ k = min(remaining, avail)
+ chunks.append(self.tokens[self.pos : self.pos + k])
+ self.pos += k
+ remaining -= k
+ return chunks[0] if len(chunks) == 1 else torch.cat(chunks)
+class DistributedTokenLoader:
+ def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device):
+ self.rank = rank
+ self.world_size = world_size
+ self.device = device
+ self.stream = TokenStream(pattern)
+ def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]:
+ local_tokens = global_tokens // (self.world_size * grad_accum_steps)
+ per_rank_span = local_tokens + 1
+ chunk = self.stream.take(per_rank_span * self.world_size)
+ start = self.rank * per_rank_span
+ local = chunk[start : start + per_rank_span].to(dtype=torch.int64)
+ x = local[:-1].reshape(-1, seq_len)
+ 
y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + 
if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), 
qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + self.vrl_gate = None # set by GPT.__init__ when VRL is enabled + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] — broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, v_embed: Tensor | None = None, v_first: Tensor | None = None) -> tuple[Tensor, Tensor]: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + v_raw = v # save raw V before any modifications for VRL + if v_embed is not None: + v = v + v_embed + if v_first is not None and self.vrl_gate is not None: + gate = torch.sigmoid(self.vrl_gate.to(dtype=v.dtype)) + v = (1 - gate) * v + gate * v_first + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if _HAS_FA3: + y = flash_attn_3_func(q, k, v, causal=True) + else: + q2 = q.transpose(1, 2) + k2 = k.transpose(1, 2) + v2 = v.transpose(1, 2) + k2 = k2.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + v2 = v2.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + y = F.scaled_dot_product_attention(q2, k2, v2, is_causal=True) + y = y.transpose(1, 2).contiguous() + if self.use_xsa: + y = 
self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y), v_raw +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) +class ValueEmbedding(nn.Module): + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * 
self.scale.to(dtype=h.dtype) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 2.0) + else: + self.dtg_gate = None + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None, v_first: Tensor | None = None) -> tuple[Tensor, Tensor]: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out, v_raw = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed, v_first=v_first) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + 
return x_out, v_raw +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + vrl_enabled: bool = False, + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, + ) + for i in range(num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x 
in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self.vrl_enabled = vrl_enabled + if vrl_enabled: + for i in range(1, num_layers): # all layers except layer 0 + self.blocks[i].attn.vrl_gate = nn.Parameter(torch.tensor(-1.5, dtype=torch.float32)) + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." 
in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + v_first: Tensor | None = None + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, v_raw = self.blocks[i](x, x0, v_embed=ve, v_first=v_first if self.vrl_enabled else None) + if i == 0 and self.vrl_enabled: + v_first = v_raw + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, v_embed=ve, v_first=v_first if self.vrl_enabled else None) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and 
self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + v_first: Tensor | None = None + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, v_raw = self.blocks[i](x, x0, v_embed=ve, v_first=v_first if self.vrl_enabled else None) + if i == 0 and self.vrl_enabled: + v_first = v_raw + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, v_embed=ve, v_first=v_first if self.vrl_enabled else None) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + 
stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + 
byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +class NgramBackoffCache: + PRIMES = [36313, 27191, 51647, 81929, 131071, 175447, 209591] + def __init__(self, vocab_size: int, min_order: int = 2, max_order: int = 7, + hash_size: int = 4194304, min_count: int = 2): + self.V = vocab_size + self.min_order = min_order + self.max_order = max_order + self.n_orders = max_order - min_order + 1 + self.H = hash_size + self.mask = hash_size - 1 + self.min_count = min_count + self.ctx_tables = [np.zeros(hash_size, dtype=np.uint32) for _ in range(self.n_orders)] + self.full_tables = [np.zeros(hash_size, dtype=np.uint32) for _ in range(self.n_orders)] + def _hash_ctx(self, tokens: np.ndarray, pos: int, ctx_w: int) -> int: + h = 0 + for k in range(ctx_w): + h ^= int(tokens[pos - ctx_w + k]) * self.PRIMES[k % 7] + return h & self.mask + def _hash_full(self, ctx_hash: int, target: int, ctx_w: int) -> int: + return (ctx_hash ^ (target * self.PRIMES[ctx_w % 7])) & self.mask + def update(self, tokens: np.ndarray, start: int, end: int) -> None: + for i in range(start, end): + target = int(tokens[i]) + for oi in range(self.n_orders): + ctx_w = oi + self.min_order - 1 + if i < ctx_w: + continue + ctx_h = self._hash_ctx(tokens, i, ctx_w) + full_h = self._hash_full(ctx_h, target, ctx_w) + self.ctx_tables[oi][ctx_h] += 1 + self.full_tables[oi][full_h] += 1 + def predict(self, tokens: np.ndarray, pos: int, target: int) -> tuple[float, bool]: + for oi in range(self.n_orders - 1, -1, -1): + ctx_w = oi + self.min_order - 1 + if pos < ctx_w: + continue + ctx_h = self._hash_ctx(tokens, pos, ctx_w) + 
ctx_count = self.ctx_tables[oi][ctx_h] + if ctx_count < self.min_count: + continue + full_h = self._hash_full(ctx_h, target, ctx_w) + full_count = self.full_tables[oi][full_h] + p = min(full_count, ctx_count) / max(ctx_count, 1) + return max(min(p, 1.0), 0.0), True + return 0.0, False +def eval_val_ngram_backoff( + args, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, + ngram_order: int = 7, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + vocab_size = args.vocab_size + cache = NgramBackoffCache(vocab_size, min_order=2, max_order=ngram_order, + hash_size=4194304, min_count=2) + window_starts = sorted([ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1]) + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + all_tokens = val_tokens.cpu().numpy().astype(np.int32) + scored_up_to = my_windows[0] if my_windows else 0 + ngram_helped = 0 + ngram_total = 0 + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + 
wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + log_probs = torch.log_softmax(logits.float(), dim=-1) + probs = torch.exp(log_probs) + entropy = -(probs * log_probs).sum(dim=-1) + nll = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction='none').reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + score_len = wlen - s + if score_len <= 0: + continue + for t in range(s, wlen): + token_pos = ws + t + 1 + tgt = int(y_batch[i, t].item()) + prev = int(x_batch[i, t].item()) + model_nll = nll[i, t].item() + model_p = math.exp(-model_nll) + H = entropy[i, t].item() + alpha = 0.05 + 0.55 / (1.0 + math.exp(-2.0 * (H - 4.0))) + ng_p, has_ng = cache.predict(all_tokens, token_pos, tgt) + if has_ng: + mixed_p = max((1.0 - alpha) * model_p + alpha * ng_p, 1e-12) + scored_nll = -math.log(mixed_p) + ngram_total += 1 + if scored_nll < model_nll: + ngram_helped += 1 + else: + scored_nll = model_nll + loss_sum += scored_nll + token_count += 1.0 + tb = float(base_bytes_lut[tgt].item()) + tb += float((has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).item()) + byte_count += tb + new_end = ws + wlen + 1 + if new_end > scored_up_to: + cache.update(all_tokens, scored_up_to, new_end) + scored_up_to = new_end + if rank == 0 and bi % 100 == 0: + running_bpb = ((loss_sum / token_count) / math.log(2.0) * token_count / byte_count).item() if token_count > 0 else 0 + pct = 100.0 * bi / max(len(my_windows), 1) + hit_rate = ngram_helped / max(ngram_total, 1) * 100 + log0(f" ngram [{bi}/{len(my_windows)}] {pct:.1f}% bpb={running_bpb:.6f} ng_helped={hit_rate:.1f}%") + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + 
dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = 
"passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.to(torch.float16) + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise 
RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, 
effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._qat_enabled = args.qat_enabled + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + vrl_enabled=args.vrl_enabled, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) + 
scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": 
args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return 
remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + 
world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + 
ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} 
val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = lzma.compress(quant_raw, preset=6) if _COMPRESSOR == "lzma" else zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + log0(f"Total submission size: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(lzma.decompress(quant_blob_disk) if _COMPRESSOR == "lzma" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, 
tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, # must match training model + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + vrl_enabled=args.vrl_enabled, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} 
val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int6_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int6_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + ngram_enabled = bool(int(os.environ.get("NGRAM_ENABLED", "0"))) + if ngram_enabled: + log0("Starting n-gram backoff eval (PR #727 approach: entropy-adaptive alpha, orders 2-7)...") + torch.cuda.synchronize() + t_ng = time.perf_counter() + ng_val_loss, ng_val_bpb = eval_val_ngram_backoff( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + batch_seqs=32, log0=log0, + ngram_order=int(os.environ.get("NGRAM_ORDER", "7")), + ) + torch.cuda.synchronize() + log0(f"final_ngram val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} " + f"ngram_eval_time:{1000.0 * (time.perf_counter() - t_ng):.0f}ms") + log0(f"final_ngram_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Thu Mar 26 18:19:50 2026 
++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.195.03 Driver Version: 570.195.03 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | 0 | +| N/A 41C P0 124W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:06:00.0 Off | 0 | +| N/A 33C P0 119W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:07:00.0 Off | 0 | +| N/A 42C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:08:00.0 Off | 0 | +| N/A 40C P0 124W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:09:00.0 Off | 0 | +| N/A 40C P0 124W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | +| N/A 34C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:0B:00.0 Off | 0 | +| N/A 32C P0 116W / 700W | 1519MiB / 
81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:0C:00.0 Off | 0 | +| N/A 40C P0 120W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 73766 C /usr/local/bin/python 1510MiB | +| 1 N/A N/A 73767 C /usr/local/bin/python 1510MiB | +| 2 N/A N/A 73768 C /usr/local/bin/python 1510MiB | +| 3 N/A N/A 73769 C /usr/local/bin/python 1510MiB | +| 4 N/A N/A 73770 C /usr/local/bin/python 1510MiB | +| 5 N/A N/A 73771 C /usr/local/bin/python 1510MiB | +| 6 N/A N/A 73772 C /usr/local/bin/python 1510MiB | +| 7 N/A N/A 73773 C /usr/local/bin/python 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:26993766 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_4 active_layers:[7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:2025 +warmup_step:1/20 
+warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9303 val_bpb:4.1045 train_time:0ms step_avg:0.01ms +step:1/20000 train_loss:6.9322 train_time:150ms step_avg:150.47ms +step:2/20000 train_loss:8.6380 train_time:232ms step_avg:115.78ms +step:3/20000 train_loss:7.8093 train_time:318ms step_avg:105.90ms +step:4/20000 train_loss:7.2249 train_time:404ms step_avg:100.88ms +step:5/20000 train_loss:6.9937 train_time:490ms step_avg:97.94ms +step:6/20000 train_loss:6.9397 train_time:575ms step_avg:95.89ms +step:7/20000 train_loss:6.8229 train_time:661ms step_avg:94.44ms +step:8/20000 train_loss:6.6557 train_time:747ms step_avg:93.35ms +step:9/20000 train_loss:6.3636 train_time:834ms step_avg:92.64ms +step:10/20000 train_loss:6.0990 train_time:919ms step_avg:91.94ms +step:500/20000 train_loss:2.3730 train_time:43963ms step_avg:87.93ms +step:1000/20000 train_loss:2.2562 train_time:88080ms step_avg:88.08ms +step:1500/20000 train_loss:2.2060 train_time:132214ms step_avg:88.14ms +step:2000/20000 train_loss:2.0516 train_time:176403ms step_avg:88.20ms +step:2500/20000 train_loss:2.1574 train_time:220669ms step_avg:88.27ms +step:3000/20000 train_loss:2.1501 train_time:264899ms step_avg:88.30ms +step:3500/20000 train_loss:2.1642 train_time:309250ms step_avg:88.36ms +step:4000/20000 train_loss:1.9557 train_time:353621ms step_avg:88.41ms +step:4000/20000 val_loss:2.0470 val_bpb:1.2124 train_time:353626ms step_avg:88.41ms +step:4500/20000 train_loss:2.1037 train_time:397991ms step_avg:88.44ms +step:5000/20000 train_loss:2.0889 train_time:442323ms step_avg:88.46ms +step:5500/20000 train_loss:2.0013 train_time:486565ms step_avg:88.47ms +step:6000/20000 
train_loss:1.9256 train_time:530773ms step_avg:88.46ms +swa:start step:6100 +late_qat:enabled step:6255 scale:0.1499 +step:6500/20000 train_loss:2.0611 train_time:575421ms step_avg:88.53ms +step:6776/20000 val_loss:1.9244 val_bpb:1.1397 train_time:600085ms step_avg:88.56ms +stopping_early: wallclock_cap train_time:600085ms step:6776/20000 +peak memory allocated: 21149 MiB reserved: 21204 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9227 val_bpb:1.1388 eval_time:2038ms +Serialized model: 106181533 bytes +Code size: 67048 bytes +Serialized model int6+lzma: 15907260 bytes +Total submission size int6+lzma: 15974308 bytes +Total submission size: 15974308 bytes +final_int6_roundtrip val_loss:1.9361 val_bpb:1.1466 eval_time:9286ms +final_int6_roundtrip_exact val_loss:1.93605399 val_bpb:1.14664023 +final_int6_sliding_window val_loss:1.8962 val_bpb:1.1231 stride:64 eval_time:78000ms +final_int6_sliding_window_exact val_loss:1.89622932 val_bpb:1.12305678 +final_int6_roundtrip_exact val_loss:1.89622932 val_bpb:1.12305678 +Starting n-gram backoff eval (PR #727 approach: entropy-adaptive alpha, orders 2-7)... 
+ ngram [0/121136] 0.0% bpb=1.211517 ng_helped=10.2% + ngram [800/121136] 0.7% bpb=1.228354 ng_helped=17.6% + ngram [1600/121136] 1.3% bpb=1.154860 ng_helped=18.1% + ngram [2400/121136] 2.0% bpb=1.169775 ng_helped=17.9% + ngram [3200/121136] 2.6% bpb=1.155298 ng_helped=18.3% + ngram [4000/121136] 3.3% bpb=1.151759 ng_helped=18.4% + ngram [4800/121136] 4.0% bpb=1.146377 ng_helped=18.6% + ngram [5600/121136] 4.6% bpb=1.147891 ng_helped=18.7% + ngram [6400/121136] 5.3% bpb=1.154466 ng_helped=19.4% + ngram [7200/121136] 5.9% bpb=1.153022 ng_helped=19.6% + ngram [8000/121136] 6.6% bpb=1.152976 ng_helped=19.7% + ngram [8800/121136] 7.3% bpb=1.157068 ng_helped=19.8% + ngram [9600/121136] 7.9% bpb=1.152359 ng_helped=19.9% + ngram [10400/121136] 8.6% bpb=1.149341 ng_helped=20.1% + ngram [11200/121136] 9.2% bpb=1.145755 ng_helped=20.2% + ngram [12000/121136] 9.9% bpb=1.143126 ng_helped=20.4% + ngram [12800/121136] 10.6% bpb=1.140883 ng_helped=20.4% + ngram [13600/121136] 11.2% bpb=1.138434 ng_helped=20.5% + ngram [14400/121136] 11.9% bpb=1.140314 ng_helped=20.6% + ngram [15200/121136] 12.5% bpb=1.150128 ng_helped=20.8% + ngram [16000/121136] 13.2% bpb=1.145954 ng_helped=20.9% + ngram [16800/121136] 13.9% bpb=1.144724 ng_helped=21.0% + ngram [17600/121136] 14.5% bpb=1.141770 ng_helped=21.2% + ngram [18400/121136] 15.2% bpb=1.140233 ng_helped=21.4% + ngram [19200/121136] 15.8% bpb=1.140481 ng_helped=21.5% + ngram [20000/121136] 16.5% bpb=1.138085 ng_helped=21.6% + ngram [20800/121136] 17.2% bpb=1.136421 ng_helped=21.7% + ngram [21600/121136] 17.8% bpb=1.134333 ng_helped=21.9% + ngram [22400/121136] 18.5% bpb=1.132307 ng_helped=22.0% + ngram [23200/121136] 19.2% bpb=1.128533 ng_helped=22.2% + ngram [24000/121136] 19.8% bpb=1.129934 ng_helped=22.3% + ngram [24800/121136] 20.5% bpb=1.128647 ng_helped=22.4% + ngram [25600/121136] 21.1% bpb=1.128601 ng_helped=22.6% + ngram [26400/121136] 21.8% bpb=1.127040 ng_helped=22.7% + ngram [27200/121136] 22.5% bpb=1.126340 ng_helped=22.8% + 
ngram [28000/121136] 23.1% bpb=1.129079 ng_helped=23.0% + ngram [28800/121136] 23.8% bpb=1.129469 ng_helped=23.1% + ngram [29600/121136] 24.4% bpb=1.127842 ng_helped=23.2% + ngram [30400/121136] 25.1% bpb=1.124613 ng_helped=23.4% + ngram [31200/121136] 25.8% bpb=1.123487 ng_helped=23.5% + ngram [32000/121136] 26.4% bpb=1.122955 ng_helped=23.6% + ngram [32800/121136] 27.1% bpb=1.120993 ng_helped=23.8% + ngram [33600/121136] 27.7% bpb=1.118871 ng_helped=23.9% + ngram [34400/121136] 28.4% bpb=1.116908 ng_helped=24.0% + ngram [35200/121136] 29.1% bpb=1.115594 ng_helped=24.1% + ngram [36000/121136] 29.7% bpb=1.114650 ng_helped=24.3% + ngram [36800/121136] 30.4% bpb=1.112426 ng_helped=24.4% + ngram [37600/121136] 31.0% bpb=1.111401 ng_helped=24.6% + ngram [38400/121136] 31.7% bpb=1.110335 ng_helped=24.7% + ngram [39200/121136] 32.4% bpb=1.107137 ng_helped=24.9% + ngram [40000/121136] 33.0% bpb=1.105467 ng_helped=25.0% + ngram [40800/121136] 33.7% bpb=1.102531 ng_helped=25.2% + ngram [41600/121136] 34.3% bpb=1.101498 ng_helped=25.4% + ngram [42400/121136] 35.0% bpb=1.100421 ng_helped=25.5% + ngram [43200/121136] 35.7% bpb=1.099202 ng_helped=25.6% + ngram [44000/121136] 36.3% bpb=1.096868 ng_helped=25.8% + ngram [44800/121136] 37.0% bpb=1.095256 ng_helped=25.9% + ngram [45600/121136] 37.6% bpb=1.093434 ng_helped=26.0% + ngram [46400/121136] 38.3% bpb=1.092424 ng_helped=26.1% + ngram [47200/121136] 39.0% bpb=1.090399 ng_helped=26.3% + ngram [48000/121136] 39.6% bpb=1.089068 ng_helped=26.4% + ngram [48800/121136] 40.3% bpb=1.087593 ng_helped=26.5% + ngram [49600/121136] 40.9% bpb=1.087276 ng_helped=26.7% + ngram [50400/121136] 41.6% bpb=1.086342 ng_helped=26.8% + ngram [51200/121136] 42.3% bpb=1.085394 ng_helped=26.9% + ngram [52000/121136] 42.9% bpb=1.084133 ng_helped=27.1% + ngram [52800/121136] 43.6% bpb=1.083178 ng_helped=27.2% + ngram [53600/121136] 44.2% bpb=1.081029 ng_helped=27.3% + ngram [54400/121136] 44.9% bpb=1.080035 ng_helped=27.4% + ngram [55200/121136] 45.6% 
bpb=1.079000 ng_helped=27.6% + ngram [56000/121136] 46.2% bpb=1.077614 ng_helped=27.7% + ngram [56800/121136] 46.9% bpb=1.075670 ng_helped=27.8% + ngram [57600/121136] 47.5% bpb=1.074118 ng_helped=28.0% + ngram [58400/121136] 48.2% bpb=1.069693 ng_helped=28.1% + ngram [59200/121136] 48.9% bpb=1.068154 ng_helped=28.3% + ngram [60000/121136] 49.5% bpb=1.066859 ng_helped=28.4% + ngram [60800/121136] 50.2% bpb=1.065560 ng_helped=28.5% + ngram [61600/121136] 50.9% bpb=1.064208 ng_helped=28.7% + ngram [62400/121136] 51.5% bpb=1.063440 ng_helped=28.8% + ngram [63200/121136] 52.2% bpb=1.061871 ng_helped=28.9% + ngram [64000/121136] 52.8% bpb=1.060809 ng_helped=29.1% + ngram [64800/121136] 53.5% bpb=1.059535 ng_helped=29.2% + ngram [65600/121136] 54.2% bpb=1.057997 ng_helped=29.3% + ngram [66400/121136] 54.8% bpb=1.056070 ng_helped=29.5% + ngram [67200/121136] 55.5% bpb=1.054377 ng_helped=29.6% + ngram [68000/121136] 56.1% bpb=1.052902 ng_helped=29.7% + ngram [68800/121136] 56.8% bpb=1.051390 ng_helped=29.9% + ngram [69600/121136] 57.5% bpb=1.049795 ng_helped=30.0% + ngram [70400/121136] 58.1% bpb=1.048075 ng_helped=30.1% + ngram [71200/121136] 58.8% bpb=1.046751 ng_helped=30.3% + ngram [72000/121136] 59.4% bpb=1.045343 ng_helped=30.4% + ngram [72800/121136] 60.1% bpb=1.043957 ng_helped=30.5% + ngram [73600/121136] 60.8% bpb=1.042694 ng_helped=30.7% + ngram [74400/121136] 61.4% bpb=1.041624 ng_helped=30.8% + ngram [75200/121136] 62.1% bpb=1.040123 ng_helped=31.0% + ngram [76000/121136] 62.7% bpb=1.038311 ng_helped=31.1% + ngram [76800/121136] 63.4% bpb=1.037184 ng_helped=31.2% + ngram [77600/121136] 64.1% bpb=1.035965 ng_helped=31.4% + ngram [78400/121136] 64.7% bpb=1.034851 ng_helped=31.5% + ngram [79200/121136] 65.4% bpb=1.033318 ng_helped=31.6% + ngram [80000/121136] 66.0% bpb=1.032345 ng_helped=31.8% + ngram [80800/121136] 66.7% bpb=1.031279 ng_helped=31.9% + ngram [81600/121136] 67.4% bpb=1.029505 ng_helped=32.1% + ngram [82400/121136] 68.0% bpb=1.028642 
ng_helped=32.2% + ngram [83200/121136] 68.7% bpb=1.027586 ng_helped=32.3% + ngram [84000/121136] 69.3% bpb=1.027444 ng_helped=32.5% + ngram [84800/121136] 70.0% bpb=1.026218 ng_helped=32.6% + ngram [85600/121136] 70.7% bpb=1.024033 ng_helped=32.8% + ngram [86400/121136] 71.3% bpb=1.022927 ng_helped=32.9% + ngram [87200/121136] 72.0% bpb=1.021745 ng_helped=33.0% + ngram [88000/121136] 72.6% bpb=1.020643 ng_helped=33.2% + ngram [88800/121136] 73.3% bpb=1.019385 ng_helped=33.3% + ngram [89600/121136] 74.0% bpb=1.018210 ng_helped=33.5% + ngram [90400/121136] 74.6% bpb=1.017084 ng_helped=33.6% + ngram [91200/121136] 75.3% bpb=1.015660 ng_helped=33.7% + ngram [92000/121136] 75.9% bpb=1.013968 ng_helped=33.9% + ngram [92800/121136] 76.6% bpb=1.012729 ng_helped=34.0% + ngram [93600/121136] 77.3% bpb=1.011485 ng_helped=34.1% + ngram [94400/121136] 77.9% bpb=1.010272 ng_helped=34.3% + ngram [95200/121136] 78.6% bpb=1.008944 ng_helped=34.4% + ngram [96000/121136] 79.2% bpb=1.007401 ng_helped=34.5% + ngram [96800/121136] 79.9% bpb=1.008109 ng_helped=34.7% + ngram [97600/121136] 80.6% bpb=1.006548 ng_helped=34.8% + ngram [98400/121136] 81.2% bpb=1.005288 ng_helped=35.0% + ngram [99200/121136] 81.9% bpb=1.003961 ng_helped=35.1% + ngram [100000/121136] 82.6% bpb=1.002459 ng_helped=35.2% + ngram [100800/121136] 83.2% bpb=1.001367 ng_helped=35.4% + ngram [101600/121136] 83.9% bpb=1.000385 ng_helped=35.5% + ngram [102400/121136] 84.5% bpb=0.998663 ng_helped=35.6% + ngram [103200/121136] 85.2% bpb=0.997303 ng_helped=35.8% + ngram [104000/121136] 85.9% bpb=0.995820 ng_helped=35.9% + ngram [104800/121136] 86.5% bpb=0.994175 ng_helped=36.0% + ngram [105600/121136] 87.2% bpb=0.992745 ng_helped=36.1% + ngram [106400/121136] 87.8% bpb=0.991497 ng_helped=36.3% + ngram [107200/121136] 88.5% bpb=0.990313 ng_helped=36.4% + ngram [108000/121136] 89.2% bpb=0.989167 ng_helped=36.5% + ngram [108800/121136] 89.8% bpb=0.988144 ng_helped=36.7% + ngram [109600/121136] 90.5% bpb=0.987056 
ng_helped=36.8% + ngram [110400/121136] 91.1% bpb=0.985746 ng_helped=36.9% + ngram [111200/121136] 91.8% bpb=0.984592 ng_helped=37.1% + ngram [112000/121136] 92.5% bpb=0.983253 ng_helped=37.2% + ngram [112800/121136] 93.1% bpb=0.982418 ng_helped=37.3% + ngram [113600/121136] 93.8% bpb=0.981157 ng_helped=37.5% + ngram [114400/121136] 94.4% bpb=0.979868 ng_helped=37.6% + ngram [115200/121136] 95.1% bpb=0.978634 ng_helped=37.7% + ngram [116000/121136] 95.8% bpb=0.977444 ng_helped=37.8% + ngram [116800/121136] 96.4% bpb=0.976022 ng_helped=38.0% + ngram [117600/121136] 97.1% bpb=0.974973 ng_helped=38.1% + ngram [118400/121136] 97.7% bpb=0.973829 ng_helped=38.2% + ngram [119200/121136] 98.4% bpb=0.972683 ng_helped=38.4% + ngram [120000/121136] 99.1% bpb=0.971488 ng_helped=38.5% + ngram [120800/121136] 99.7% bpb=0.970429 ng_helped=38.6% +final_ngram val_loss:1.6283 val_bpb:0.9644 ngram_eval_time:936242ms +final_ngram_exact val_loss:1.62826393 val_bpb:0.96435217 diff --git a/records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/train_seed42.log b/records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/train_seed42.log new file mode 100644 index 000000000..6212a6911 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_NgramBackoff_VRL_LeakyReLU2/train_seed42.log @@ -0,0 +1,1876 @@ +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +import lzma +_COMPRESSOR = "lzma" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True +except ImportError: + try: + from flash_attn.flash_attn_interface import flash_attn_func as flash_attn_3_func + _HAS_FA3 = 
True + except ImportError: + try: + from flash_attn import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True + except ImportError: + _HAS_FA3 = False + flash_attn_3_func = None +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + 
tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + mtp_num_heads = int(os.environ.get("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(os.environ.get("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) # tighter: collect more recent checkpoints + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) # XSA on last 4 layers (0 = disabled) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.15)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + vrl_enabled = bool(int(os.environ.get("VRL_ENABLED", "1"))) +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: 
float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + 
p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + 
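The `zeropower_via_newtonschulz5` routine above acts on each singular value of `G` through the quintic `a*x + b*x**3 + c*x**5`. A pure-Python scalar sketch (a toy diagonal matrix, not the training code) shows why the tuned coefficients count as an *approximate* orthogonalization: normalized singular values are driven into a band around 1, not to exactly 1.

```python
# Scalar view of the Newton-Schulz quintic: for a diagonal G, each
# Frobenius-normalized singular value x evolves as a*x + b*x**3 + c*x**5.
a, b, c = 3.4445, -4.7750, 2.0315

def ns_quintic(x: float, steps: int = 10) -> float:
    for _ in range(steps):
        x = a * x + b * x**3 + c * x**5
    return x

# Two singular values of a toy diag(3.0, 0.5) matrix, normalized as in the code.
norm = (3.0**2 + 0.5**2) ** 0.5
orbit = [ns_quintic(s / norm) for s in (3.0, 0.5)]
```

Both values settle into a band roughly between 0.7 and 1.2 for these coefficients, which is good enough for Muon's update direction and much cheaper than an exact SVD.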
local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token 
= val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if 
t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) + obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = 
obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + 
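The per-row int8 scheme above stores `q = round(clip(w) / scale)` with `scale = clip_abs / 127` and reconstructs `q * scale` at load time. A minimal pure-Python round-trip (a hypothetical toy row, mirroring `quantize_float_tensor`'s 1-D path without the percentile clipping) shows the reconstruction error is bounded by half a quantization step.

```python
# Symmetric int8 quantization of one weight row:
# scale = max|w| / 127, q = clamp(round(w / scale), -127, 127), recon = q * scale.
row = [0.08, -0.31, 0.002, 0.27, -0.11]
clip_abs = max(abs(w) for w in row)   # toy: amax instead of a high percentile
scale = clip_abs / 127.0

q = [max(-127, min(127, round(w / scale))) for w in row]
recon = [qi * scale for qi in q]
max_err = max(abs(w - r) for w, r in zip(row, recon))
```

Because nothing in the row exceeds `clip_abs`, each reconstructed weight is within `scale / 2` of the original; the percentile clipping in the real code trades a little clipping error on outliers for a finer step on the bulk of the row.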
y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 31.0).clamp_min(1.0 / 31.0) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + 
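`apply_rotary_emb` below pairs dimension `i` with dimension `i + half` and rotates each pair by a position-dependent angle. A pure-Python sketch of one pair (toy numbers, not the model code) confirms the rotation preserves the pair's norm, which is why RoPE can be applied after QK RMS-norm without disturbing scale.

```python
import math

# One RoPE pair: (x1, x2) rotated by theta = pos * inv_freq, matching
# apply_rotary_emb's (x1*cos + x2*sin, -x1*sin + x2*cos).
def rope_pair(x1: float, x2: float, pos: int, inv_freq: float) -> tuple[float, float]:
    theta = pos * inv_freq
    c, s = math.cos(theta), math.sin(theta)
    return (x1 * c + x2 * s, -x1 * s + x2 * c)

before = (0.6, -0.8)
after = rope_pair(*before, pos=17, inv_freq=1.0 / 10000 ** (4 / 16))
norm_before = math.hypot(*before)
norm_after = math.hypot(*after)
```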
if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), 
qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + self.vrl_gate = None # set by GPT.__init__ when VRL is enabled + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] — broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, v_embed: Tensor | None = None, v_first: Tensor | None = None) -> tuple[Tensor, Tensor]: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + v_raw = v # save raw V before any modifications for VRL + if v_embed is not None: + v = v + v_embed + if v_first is not None and self.vrl_gate is not None: + gate = torch.sigmoid(self.vrl_gate.to(dtype=v.dtype)) + v = (1 - gate) * v + gate * v_first + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if _HAS_FA3: + y = flash_attn_3_func(q, k, v, causal=True) + else: + q2 = q.transpose(1, 2) + k2 = k.transpose(1, 2) + v2 = v.transpose(1, 2) + k2 = k2.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + v2 = v2.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + y = F.scaled_dot_product_attention(q2, k2, v2, is_causal=True) + y = y.transpose(1, 2).contiguous() + if self.use_xsa: + y = 
self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y), v_raw +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) +class ValueEmbedding(nn.Module): + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * 
self.scale.to(dtype=h.dtype) +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x: Tensor) -> Tensor: + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + return self.proj(x.square()) +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 2.0) + else: + self.dtg_gate = None + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None, v_first: Tensor | None = None) -> tuple[Tensor, Tensor]: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out, v_raw = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed, v_first=v_first) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + 
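The `MLP` above implements the LeakyReLU(0.5)² activation from the record's title: a leaky ReLU with negative slope 0.5 followed by squaring, so negative pre-activations are damped to a quarter of the squared magnitude instead of being zeroed as in ReLU². A pure-Python sketch (toy inputs, not the model code):

```python
# LeakyReLU(0.5) followed by squaring, as in MLP.forward:
# F.leaky_relu(x, negative_slope=0.5) then .square().
def leaky_relu_sq(x: float, slope: float = 0.5) -> float:
    y = x if x >= 0 else slope * x
    return y * y

vals = [leaky_relu_sq(x) for x in (-2.0, -1.0, 0.0, 1.0, 2.0)]
# vals == [1.0, 0.25, 0.0, 1.0, 4.0]
```

Unlike ReLU², the negative branch keeps a nonzero gradient (at a quarter of the positive branch's scale), which is the point of the leaky variant.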
return x_out, v_raw +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + mtp_num_heads: int = 0, + mtp_loss_weight: float = 0.1, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + dtg: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + vrl_enabled: bool = False, + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + dtg=dtg, + ) + for i in range(num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x 
in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CastedLinear(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self.vrl_enabled = vrl_enabled + if vrl_enabled: + for i in range(1, num_layers): # all layers except layer 0 + self.blocks[i].attn.vrl_gate = nn.Parameter(torch.tensor(-1.5, dtype=torch.float32)) + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." 
in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + v_first: Tensor | None = None + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, v_raw = self.blocks[i](x, x0, v_embed=ve, v_first=v_first if self.vrl_enabled else None) + if i == 0 and self.vrl_enabled: + v_first = v_raw + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, v_embed=ve, v_first=v_first if self.vrl_enabled else None) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and 
self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + def forward_logits(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + v_first: Tensor | None = None + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, v_raw = self.blocks[i](x, x0, v_embed=ve, v_first=v_first if self.vrl_enabled else None) + if i == 0 and self.vrl_enabled: + v_first = v_raw + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, v_embed=ve, v_first=v_first if self.vrl_enabled else None) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + 
stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + 
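In the sliding eval above, each window of length `seq_len` scores only its last `stride` tokens (the first window scores everything), so every interior token is scored with a full left context. A pure-Python sketch of the coverage (toy sizes, restricted to full-length windows for simplicity, which the real code does not do for the ragged tail):

```python
# Sliding-window scoring: windows start every `stride` tokens; each scores only
# its last `stride` positions, except the first window which scores all of them.
total_tokens, seq_len, stride = 20, 8, 4

scored = []
for ws in range(0, total_tokens - seq_len + 1, stride):  # full windows only (toy)
    s = 0 if ws == 0 else seq_len - stride
    scored.extend(range(ws + s, ws + seq_len))
```

With full windows, every token is scored exactly once while all but the first `seq_len` tokens see `seq_len - stride` tokens of extra context compared to non-overlapping evaluation.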
byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +class NgramBackoffCache: + PRIMES = [36313, 27191, 51647, 81929, 131071, 175447, 209591] + def __init__(self, vocab_size: int, min_order: int = 2, max_order: int = 7, + hash_size: int = 4194304, min_count: int = 2): + self.V = vocab_size + self.min_order = min_order + self.max_order = max_order + self.n_orders = max_order - min_order + 1 + self.H = hash_size + self.mask = hash_size - 1 + self.min_count = min_count + self.ctx_tables = [np.zeros(hash_size, dtype=np.uint32) for _ in range(self.n_orders)] + self.full_tables = [np.zeros(hash_size, dtype=np.uint32) for _ in range(self.n_orders)] + def _hash_ctx(self, tokens: np.ndarray, pos: int, ctx_w: int) -> int: + h = 0 + for k in range(ctx_w): + h ^= int(tokens[pos - ctx_w + k]) * self.PRIMES[k % 7] + return h & self.mask + def _hash_full(self, ctx_hash: int, target: int, ctx_w: int) -> int: + return (ctx_hash ^ (target * self.PRIMES[ctx_w % 7])) & self.mask + def update(self, tokens: np.ndarray, start: int, end: int) -> None: + for i in range(start, end): + target = int(tokens[i]) + for oi in range(self.n_orders): + ctx_w = oi + self.min_order - 1 + if i < ctx_w: + continue + ctx_h = self._hash_ctx(tokens, i, ctx_w) + full_h = self._hash_full(ctx_h, target, ctx_w) + self.ctx_tables[oi][ctx_h] += 1 + self.full_tables[oi][full_h] += 1 + def predict(self, tokens: np.ndarray, pos: int, target: int) -> tuple[float, bool]: + for oi in range(self.n_orders - 1, -1, -1): + ctx_w = oi + self.min_order - 1 + if pos < ctx_w: + continue + ctx_h = self._hash_ctx(tokens, pos, ctx_w) + 
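`NgramBackoffCache` never stores n-grams explicitly: a context of width `w` is folded into a single bucket index by XOR-ing each token times a per-position prime, then masking to the power-of-two table size. A standalone pure-Python replica of the same context hash:

```python
# Mirror of NgramBackoffCache._hash_ctx: fold the ctx_w tokens before `pos`
# into one bucket index; `& MASK` replaces `%` since HASH_SIZE is a power of two.
PRIMES = [36313, 27191, 51647, 81929, 131071, 175447, 209591]
HASH_SIZE = 4194304
MASK = HASH_SIZE - 1

def hash_ctx(tokens: list[int], pos: int, ctx_w: int) -> int:
    h = 0
    for k in range(ctx_w):
        h ^= tokens[pos - ctx_w + k] * PRIMES[k % 7]
    return h & MASK

h_a = hash_ctx([5, 9, 5, 9, 7], 4, 2)   # context (5, 9)
h_b = hash_ctx([1, 5, 9, 2], 3, 2)      # same context (5, 9) elsewhere
h_rev = hash_ctx([9, 5], 2, 2)          # reversed context (9, 5)
```

Identical contexts always collide in the same bucket regardless of position, while the per-position primes make the hash order-sensitive; unrelated contexts can still collide, which is why `predict` takes `min(full_count, ctx_count)`.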
ctx_count = self.ctx_tables[oi][ctx_h] + if ctx_count < self.min_count: + continue + full_h = self._hash_full(ctx_h, target, ctx_w) + full_count = self.full_tables[oi][full_h] + p = min(full_count, ctx_count) / max(ctx_count, 1) + return max(min(p, 1.0), 0.0), True + return 0.0, False +def eval_val_ngram_backoff( + args, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, + ngram_order: int = 7, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + vocab_size = args.vocab_size + cache = NgramBackoffCache(vocab_size, min_order=2, max_order=ngram_order, + hash_size=4194304, min_count=2) + window_starts = sorted([ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1]) + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + all_tokens = val_tokens.cpu().numpy().astype(np.int32) + scored_up_to = my_windows[0] if my_windows else 0 + ngram_helped = 0 + ngram_total = 0 + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + 
wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + log_probs = torch.log_softmax(logits.float(), dim=-1) + probs = torch.exp(log_probs) + entropy = -(probs * log_probs).sum(dim=-1) + nll = F.cross_entropy(logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction='none').reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + score_len = wlen - s + if score_len <= 0: + continue + for t in range(s, wlen): + token_pos = ws + t + 1 + tgt = int(y_batch[i, t].item()) + prev = int(x_batch[i, t].item()) + model_nll = nll[i, t].item() + model_p = math.exp(-model_nll) + H = entropy[i, t].item() + alpha = 0.05 + 0.55 / (1.0 + math.exp(-2.0 * (H - 4.0))) + ng_p, has_ng = cache.predict(all_tokens, token_pos, tgt) + if has_ng: + mixed_p = max((1.0 - alpha) * model_p + alpha * ng_p, 1e-12) + scored_nll = -math.log(mixed_p) + ngram_total += 1 + if scored_nll < model_nll: + ngram_helped += 1 + else: + scored_nll = model_nll + loss_sum += scored_nll + token_count += 1.0 + tb = float(base_bytes_lut[tgt].item()) + tb += float((has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).item()) + byte_count += tb + new_end = ws + wlen + 1 + if new_end > scored_up_to: + cache.update(all_tokens, scored_up_to, new_end) + scored_up_to = new_end + if rank == 0 and bi % 100 == 0: + running_bpb = ((loss_sum / token_count) / math.log(2.0) * token_count / byte_count).item() if token_count > 0 else 0 + pct = 100.0 * bi / max(len(my_windows), 1) + hit_rate = ngram_helped / max(ngram_total, 1) * 100 + log0(f" ngram [{bi}/{len(my_windows)}] {pct:.1f}% bpb={running_bpb:.6f} ng_helped={hit_rate:.1f}%") + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + 
dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + num_layers_total = max( + (int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = 
"passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.to(torch.float16) + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise 
RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, 
effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._qat_enabled = args.qat_enabled + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, + mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + vrl_enabled=args.vrl_enabled, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.mtp_num_heads > 0: + matrix_params.extend([p for p in base_model.mtp_heads.parameters() if p.ndim == 2]) + 
scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": 
args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + n_params = sum(p.numel() for p in base_model.parameters()) + mtp_params = sum(p.numel() for p in base_model.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return 
remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + 
world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + 
ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} 
val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + full_state_dict = base_model.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"export_excluding_mtp_params:{excluded_mtp}") + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = lzma.compress(quant_raw, preset=6) if _COMPRESSOR == "lzma" else zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + log0(f"Total submission size: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(lzma.decompress(quant_blob_disk) if _COMPRESSOR == "lzma" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, 
tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, # must match training model + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + vrl_enabled=args.vrl_enabled, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + compiled_eval = torch.compile(eval_model, dynamic=False, fullgraph=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} 
val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int6_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int6_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + ngram_enabled = bool(int(os.environ.get("NGRAM_ENABLED", "0"))) + if ngram_enabled: + log0("Starting n-gram backoff eval (PR #727 approach: entropy-adaptive alpha, orders 2-7)...") + torch.cuda.synchronize() + t_ng = time.perf_counter() + ng_val_loss, ng_val_bpb = eval_val_ngram_backoff( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride if args.eval_stride > 0 else 64, + batch_seqs=32, log0=log0, + ngram_order=int(os.environ.get("NGRAM_ORDER", "7")), + ) + torch.cuda.synchronize() + log0(f"final_ngram val_loss:{ng_val_loss:.4f} val_bpb:{ng_val_bpb:.4f} " + f"ngram_eval_time:{1000.0 * (time.perf_counter() - t_ng):.0f}ms") + log0(f"final_ngram_exact val_loss:{ng_val_loss:.8f} val_bpb:{ng_val_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0] +Running PyTorch 2.9.1+cu128 +Thu Mar 26 17:51:51 2026 
++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.195.03 Driver Version: 570.195.03 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA H100 80GB HBM3 On | 00000000:05:00.0 Off | 0 | +| N/A 40C P0 122W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA H100 80GB HBM3 On | 00000000:06:00.0 Off | 0 | +| N/A 33C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA H100 80GB HBM3 On | 00000000:07:00.0 Off | 0 | +| N/A 41C P0 121W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA H100 80GB HBM3 On | 00000000:08:00.0 Off | 0 | +| N/A 39C P0 124W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 4 NVIDIA H100 80GB HBM3 On | 00000000:09:00.0 Off | 0 | +| N/A 39C P0 124W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 5 NVIDIA H100 80GB HBM3 On | 00000000:0A:00.0 Off | 0 | +| N/A 33C P0 117W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 6 NVIDIA H100 80GB HBM3 On | 00000000:0B:00.0 Off | 0 | +| N/A 31C P0 116W / 700W | 1519MiB / 
81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 7 NVIDIA H100 80GB HBM3 On | 00000000:0C:00.0 Off | 0 | +| N/A 39C P0 118W / 700W | 1519MiB / 81559MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 72537 C /usr/local/bin/python 1510MiB | +| 1 N/A N/A 72538 C /usr/local/bin/python 1510MiB | +| 2 N/A N/A 72539 C /usr/local/bin/python 1510MiB | +| 3 N/A N/A 72540 C /usr/local/bin/python 1510MiB | +| 4 N/A N/A 72541 C /usr/local/bin/python 1510MiB | +| 5 N/A N/A 72542 C /usr/local/bin/python 1510MiB | +| 6 N/A N/A 72543 C /usr/local/bin/python 1510MiB | +| 7 N/A N/A 72544 C /usr/local/bin/python 1510MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:26993766 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_4 active_layers:[7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +warmup_step:1/20 
+warmup_step:2/20
+warmup_step:3/20
+warmup_step:4/20
+warmup_step:5/20
+warmup_step:6/20
+warmup_step:7/20
+warmup_step:8/20
+warmup_step:9/20
+warmup_step:10/20
+warmup_step:11/20
+warmup_step:12/20
+warmup_step:13/20
+warmup_step:14/20
+warmup_step:15/20
+warmup_step:16/20
+warmup_step:17/20
+warmup_step:18/20
+warmup_step:19/20
+warmup_step:20/20
+step:0/20000 val_loss:6.9301 val_bpb:4.1044 train_time:0ms step_avg:0.01ms
+step:1/20000 train_loss:6.9318 train_time:145ms step_avg:144.63ms
+step:2/20000 train_loss:8.6439 train_time:226ms step_avg:113.21ms
+step:3/20000 train_loss:7.8536 train_time:313ms step_avg:104.30ms
+step:4/20000 train_loss:7.2663 train_time:399ms step_avg:99.69ms
+step:5/20000 train_loss:7.0299 train_time:485ms step_avg:96.95ms
+step:6/20000 train_loss:6.9113 train_time:571ms step_avg:95.10ms
+step:7/20000 train_loss:6.7782 train_time:657ms step_avg:93.79ms
+step:8/20000 train_loss:6.7065 train_time:743ms step_avg:92.85ms
+step:9/20000 train_loss:6.4178 train_time:829ms step_avg:92.11ms
+step:10/20000 train_loss:6.0787 train_time:915ms step_avg:91.52ms
+step:500/20000 train_loss:2.3693 train_time:43976ms step_avg:87.95ms
+step:1000/20000 train_loss:2.2588 train_time:88187ms step_avg:88.19ms
+step:1500/20000 train_loss:2.2051 train_time:132460ms step_avg:88.31ms
+step:2000/20000 train_loss:2.0474 train_time:176820ms step_avg:88.41ms
+step:2500/20000 train_loss:2.1515 train_time:221183ms step_avg:88.47ms
+step:3000/20000 train_loss:2.1465 train_time:265475ms step_avg:88.49ms
+step:3500/20000 train_loss:2.1650 train_time:309730ms step_avg:88.49ms
+step:4000/20000 train_loss:1.9565 train_time:353984ms step_avg:88.50ms
+step:4000/20000 val_loss:2.0460 val_bpb:1.2118 train_time:353988ms step_avg:88.50ms
+step:4500/20000 train_loss:2.1025 train_time:398260ms step_avg:88.50ms
+step:5000/20000 train_loss:2.0876 train_time:442577ms step_avg:88.52ms
+step:5500/20000 train_loss:2.0011 train_time:486906ms step_avg:88.53ms
+step:6000/20000 train_loss:1.9234 train_time:531210ms step_avg:88.53ms
+swa:start step:6100
+late_qat:enabled step:6250 scale:0.1499
+step:6500/20000 train_loss:2.0592 train_time:575790ms step_avg:88.58ms
+step:6772/20000 val_loss:1.9234 val_bpb:1.1391 train_time:600075ms step_avg:88.61ms
+stopping_early: wallclock_cap train_time:600075ms step:6772/20000
+peak memory allocated: 21149 MiB reserved: 21204 MiB
+ema:applying EMA weights
+DIAGNOSTIC post_ema val_loss:1.9218 val_bpb:1.1382 eval_time:2040ms
+Serialized model: 106181533 bytes
+Code size: 67048 bytes
+Serialized model int6+lzma: 15837584 bytes
+Total submission size int6+lzma: 15904632 bytes
+Total submission size: 15904632 bytes
+final_int6_roundtrip val_loss:1.9350 val_bpb:1.1460 eval_time:9392ms
+final_int6_roundtrip_exact val_loss:1.93501238 val_bpb:1.14602333
+final_int6_sliding_window val_loss:1.8952 val_bpb:1.1224 stride:64 eval_time:77655ms
+final_int6_sliding_window_exact val_loss:1.89516849 val_bpb:1.12242850
+final_int6_roundtrip_exact val_loss:1.89516849 val_bpb:1.12242850
+Starting n-gram backoff eval (PR #727 approach: entropy-adaptive alpha, orders 2-7)...
+ ngram [0/121136] 0.0% bpb=1.208373 ng_helped=10.0%
+ ngram [800/121136] 0.7% bpb=1.225724 ng_helped=17.5%
+ ngram [1600/121136] 1.3% bpb=1.153556 ng_helped=18.1%
+ ngram [2400/121136] 2.0% bpb=1.168917 ng_helped=17.9%
+ ngram [3200/121136] 2.6% bpb=1.154764 ng_helped=18.2%
+ ngram [4000/121136] 3.3% bpb=1.151207 ng_helped=18.3%
+ ngram [4800/121136] 4.0% bpb=1.145922 ng_helped=18.6%
+ ngram [5600/121136] 4.6% bpb=1.147400 ng_helped=18.7%
+ ngram [6400/121136] 5.3% bpb=1.153926 ng_helped=19.4%
+ ngram [7200/121136] 5.9% bpb=1.152562 ng_helped=19.7%
+ ngram [8000/121136] 6.6% bpb=1.152201 ng_helped=19.7%
+ ngram [8800/121136] 7.3% bpb=1.156621 ng_helped=19.8%
+ ngram [9600/121136] 7.9% bpb=1.151909 ng_helped=19.9%
+ ngram [10400/121136] 8.6% bpb=1.148909 ng_helped=20.1%
+ ngram [11200/121136] 9.2% bpb=1.145281 ng_helped=20.2%
+ ngram [12000/121136] 9.9% bpb=1.142727 ng_helped=20.4%
+ ngram [12800/121136] 10.6% bpb=1.140589 ng_helped=20.4%
+ ngram [13600/121136] 11.2% bpb=1.138182 ng_helped=20.5%
+ ngram [14400/121136] 11.9% bpb=1.139977 ng_helped=20.6%
+ ngram [15200/121136] 12.5% bpb=1.149720 ng_helped=20.8%
+ ngram [16000/121136] 13.2% bpb=1.145642 ng_helped=20.9%
+ ngram [16800/121136] 13.9% bpb=1.144252 ng_helped=21.0%
+ ngram [17600/121136] 14.5% bpb=1.141169 ng_helped=21.2%
+ ngram [18400/121136] 15.2% bpb=1.139722 ng_helped=21.3%
+ ngram [19200/121136] 15.8% bpb=1.139873 ng_helped=21.5%
+ ngram [20000/121136] 16.5% bpb=1.137493 ng_helped=21.6%
+ ngram [20800/121136] 17.2% bpb=1.135820 ng_helped=21.7%
+ ngram [21600/121136] 17.8% bpb=1.133718 ng_helped=21.9%
+ ngram [22400/121136] 18.5% bpb=1.131817 ng_helped=22.0%
+ ngram [23200/121136] 19.2% bpb=1.128078 ng_helped=22.1%
+ ngram [24000/121136] 19.8% bpb=1.129620 ng_helped=22.3%
+ ngram [24800/121136] 20.5% bpb=1.128345 ng_helped=22.4%
+ ngram [25600/121136] 21.1% bpb=1.128308 ng_helped=22.6%
+ ngram [26400/121136] 21.8% bpb=1.126705 ng_helped=22.7%
+ ngram [27200/121136] 22.5% bpb=1.125997 ng_helped=22.8%
+ ngram [28000/121136] 23.1% bpb=1.128677 ng_helped=23.0%
+ ngram [28800/121136] 23.8% bpb=1.129097 ng_helped=23.1%
+ ngram [29600/121136] 24.4% bpb=1.127482 ng_helped=23.2%
+ ngram [30400/121136] 25.1% bpb=1.124179 ng_helped=23.4%
+ ngram [31200/121136] 25.8% bpb=1.123103 ng_helped=23.5%
+ ngram [32000/121136] 26.4% bpb=1.122496 ng_helped=23.6%
+ ngram [32800/121136] 27.1% bpb=1.120551 ng_helped=23.8%
+ ngram [33600/121136] 27.7% bpb=1.118462 ng_helped=23.9%
+ ngram [34400/121136] 28.4% bpb=1.116510 ng_helped=24.0%
+ ngram [35200/121136] 29.1% bpb=1.115209 ng_helped=24.1%
+ ngram [36000/121136] 29.7% bpb=1.114291 ng_helped=24.3%
+ ngram [36800/121136] 30.4% bpb=1.112043 ng_helped=24.4%
+ ngram [37600/121136] 31.0% bpb=1.110989 ng_helped=24.5%
+ ngram [38400/121136] 31.7% bpb=1.109886 ng_helped=24.7%
+ ngram [39200/121136] 32.4% bpb=1.106724 ng_helped=24.9%
+ ngram [40000/121136] 33.0% bpb=1.104986 ng_helped=25.0%
+ ngram [40800/121136] 33.7% bpb=1.102085 ng_helped=25.2%
+ ngram [41600/121136] 34.3% bpb=1.101041 ng_helped=25.4%
+ ngram [42400/121136] 35.0% bpb=1.100019 ng_helped=25.5%
+ ngram [43200/121136] 35.7% bpb=1.098775 ng_helped=25.6%
+ ngram [44000/121136] 36.3% bpb=1.096446 ng_helped=25.8%
+ ngram [44800/121136] 37.0% bpb=1.094844 ng_helped=25.9%
+ ngram [45600/121136] 37.6% bpb=1.093012 ng_helped=26.0%
+ ngram [46400/121136] 38.3% bpb=1.092039 ng_helped=26.1%
+ ngram [47200/121136] 39.0% bpb=1.090017 ng_helped=26.3%
+ ngram [48000/121136] 39.6% bpb=1.088681 ng_helped=26.4%
+ ngram [48800/121136] 40.3% bpb=1.087207 ng_helped=26.5%
+ ngram [49600/121136] 40.9% bpb=1.086918 ng_helped=26.7%
+ ngram [50400/121136] 41.6% bpb=1.086003 ng_helped=26.8%
+ ngram [51200/121136] 42.3% bpb=1.085049 ng_helped=26.9%
+ ngram [52000/121136] 42.9% bpb=1.083765 ng_helped=27.0%
+ ngram [52800/121136] 43.6% bpb=1.082819 ng_helped=27.2%
+ ngram [53600/121136] 44.2% bpb=1.080689 ng_helped=27.3%
+ ngram [54400/121136] 44.9% bpb=1.079709 ng_helped=27.4%
+ ngram [55200/121136] 45.6% bpb=1.078696 ng_helped=27.6%
+ ngram [56000/121136] 46.2% bpb=1.077299 ng_helped=27.7%
+ ngram [56800/121136] 46.9% bpb=1.075361 ng_helped=27.8%
+ ngram [57600/121136] 47.5% bpb=1.073807 ng_helped=28.0%
+ ngram [58400/121136] 48.2% bpb=1.069375 ng_helped=28.1%
+ ngram [59200/121136] 48.9% bpb=1.067833 ng_helped=28.3%
+ ngram [60000/121136] 49.5% bpb=1.066522 ng_helped=28.4%
+ ngram [60800/121136] 50.2% bpb=1.065221 ng_helped=28.5%
+ ngram [61600/121136] 50.9% bpb=1.063845 ng_helped=28.6%
+ ngram [62400/121136] 51.5% bpb=1.063073 ng_helped=28.8%
+ ngram [63200/121136] 52.2% bpb=1.061504 ng_helped=28.9%
+ ngram [64000/121136] 52.8% bpb=1.060444 ng_helped=29.1%
+ ngram [64800/121136] 53.5% bpb=1.059176 ng_helped=29.2%
+ ngram [65600/121136] 54.2% bpb=1.057626 ng_helped=29.3%
+ ngram [66400/121136] 54.8% bpb=1.055691 ng_helped=29.5%
+ ngram [67200/121136] 55.5% bpb=1.053988 ng_helped=29.6%
+ ngram [68000/121136] 56.1% bpb=1.052525 ng_helped=29.7%
+ ngram [68800/121136] 56.8% bpb=1.051026 ng_helped=29.9%
+ ngram [69600/121136] 57.5% bpb=1.049437 ng_helped=30.0%
+ ngram [70400/121136] 58.1% bpb=1.047703 ng_helped=30.1%
+ ngram [71200/121136] 58.8% bpb=1.046360 ng_helped=30.3%
+ ngram [72000/121136] 59.4% bpb=1.044943 ng_helped=30.4%
+ ngram [72800/121136] 60.1% bpb=1.043544 ng_helped=30.5%
+ ngram [73600/121136] 60.8% bpb=1.042280 ng_helped=30.7%
+ ngram [74400/121136] 61.4% bpb=1.041214 ng_helped=30.8%
+ ngram [75200/121136] 62.1% bpb=1.039709 ng_helped=31.0%
+ ngram [76000/121136] 62.7% bpb=1.037902 ng_helped=31.1%
+ ngram [76800/121136] 63.4% bpb=1.036785 ng_helped=31.2%
+ ngram [77600/121136] 64.1% bpb=1.035565 ng_helped=31.4%
+ ngram [78400/121136] 64.7% bpb=1.034458 ng_helped=31.5%
+ ngram [79200/121136] 65.4% bpb=1.032924 ng_helped=31.6%
+ ngram [80000/121136] 66.0% bpb=1.031955 ng_helped=31.8%
+ ngram [80800/121136] 66.7% bpb=1.030891 ng_helped=31.9%
+ ngram [81600/121136] 67.4% bpb=1.029134 ng_helped=32.1%
+ ngram [82400/121136] 68.0% bpb=1.028245 ng_helped=32.2%
+ ngram [83200/121136] 68.7% bpb=1.027199 ng_helped=32.3%
+ ngram [84000/121136] 69.3% bpb=1.027062 ng_helped=32.5%
+ ngram [84800/121136] 70.0% bpb=1.025846 ng_helped=32.6%
+ ngram [85600/121136] 70.7% bpb=1.023642 ng_helped=32.8%
+ ngram [86400/121136] 71.3% bpb=1.022507 ng_helped=32.9%
+ ngram [87200/121136] 72.0% bpb=1.021320 ng_helped=33.0%
+ ngram [88000/121136] 72.6% bpb=1.020211 ng_helped=33.2%
+ ngram [88800/121136] 73.3% bpb=1.018960 ng_helped=33.3%
+ ngram [89600/121136] 74.0% bpb=1.017771 ng_helped=33.5%
+ ngram [90400/121136] 74.6% bpb=1.016650 ng_helped=33.6%
+ ngram [91200/121136] 75.3% bpb=1.015227 ng_helped=33.7%
+ ngram [92000/121136] 75.9% bpb=1.013524 ng_helped=33.9%
+ ngram [92800/121136] 76.6% bpb=1.012291 ng_helped=34.0%
+ ngram [93600/121136] 77.3% bpb=1.011056 ng_helped=34.1%
+ ngram [94400/121136] 77.9% bpb=1.009855 ng_helped=34.3%
+ ngram [95200/121136] 78.6% bpb=1.008533 ng_helped=34.4%
+ ngram [96000/121136] 79.2% bpb=1.007002 ng_helped=34.5%
+ ngram [96800/121136] 79.9% bpb=1.007708 ng_helped=34.7%
+ ngram [97600/121136] 80.6% bpb=1.006160 ng_helped=34.8%
+ ngram [98400/121136] 81.2% bpb=1.004899 ng_helped=35.0%
+ ngram [99200/121136] 81.9% bpb=1.003571 ng_helped=35.1%
+ ngram [100000/121136] 82.6% bpb=1.002066 ng_helped=35.2%
+ ngram [100800/121136] 83.2% bpb=1.000966 ng_helped=35.4%
+ ngram [101600/121136] 83.9% bpb=0.999990 ng_helped=35.5%
+ ngram [102400/121136] 84.5% bpb=0.998274 ng_helped=35.6%
+ ngram [103200/121136] 85.2% bpb=0.996918 ng_helped=35.8%
+ ngram [104000/121136] 85.9% bpb=0.995432 ng_helped=35.9%
+ ngram [104800/121136] 86.5% bpb=0.993797 ng_helped=36.0%
+ ngram [105600/121136] 87.2% bpb=0.992372 ng_helped=36.2%
+ ngram [106400/121136] 87.8% bpb=0.991142 ng_helped=36.3%
+ ngram [107200/121136] 88.5% bpb=0.989970 ng_helped=36.4%
+ ngram [108000/121136] 89.2% bpb=0.988818 ng_helped=36.5%
+ ngram [108800/121136] 89.8% bpb=0.987800 ng_helped=36.7%
+ ngram [109600/121136] 90.5% bpb=0.986727 ng_helped=36.8%
+ ngram [110400/121136] 91.1% bpb=0.985415 ng_helped=36.9%
+ ngram [111200/121136] 91.8% bpb=0.984266 ng_helped=37.1%
+ ngram [112000/121136] 92.5% bpb=0.982924 ng_helped=37.2%
+ ngram [112800/121136] 93.1% bpb=0.982080 ng_helped=37.3%
+ ngram [113600/121136] 93.8% bpb=0.980825 ng_helped=37.5%
+ ngram [114400/121136] 94.4% bpb=0.979543 ng_helped=37.6%
+ ngram [115200/121136] 95.1% bpb=0.978313 ng_helped=37.7%
+ ngram [116000/121136] 95.8% bpb=0.977125 ng_helped=37.8%
+ ngram [116800/121136] 96.4% bpb=0.975686 ng_helped=38.0%
+ ngram [117600/121136] 97.1% bpb=0.974644 ng_helped=38.1%
+ ngram [118400/121136] 97.7% bpb=0.973492 ng_helped=38.2%
+ ngram [119200/121136] 98.4% bpb=0.972345 ng_helped=38.4%
+ ngram [120000/121136] 99.1% bpb=0.971156 ng_helped=38.5%
+ ngram [120800/121136] 99.7% bpb=0.970093 ng_helped=38.6%
+final_ngram val_loss:1.6279 val_bpb:0.9641 ngram_eval_time:890878ms
+final_ngram_exact val_loss:1.62788498 val_bpb:0.96412773

From 50ec6bce1d6722caa8d20ad6f6f53fbec9abfdae Mon Sep 17 00:00:00 2001
From: Anthony
Date: Thu, 26 Mar 2026 15:12:56 -0400
Subject: [PATCH 2/2] Remove private files from submission branch

---
 .private/council_brief_mar25_evening.md | 100 -----------------
 .private/council_brief_mar25_night.md   |  87 ---------------
 .private/substack_day6_draft.md         | 141 ------------------------
 3 files changed, 328 deletions(-)
 delete mode 100644 .private/council_brief_mar25_evening.md
 delete mode 100644 .private/council_brief_mar25_night.md
 delete mode 100644 .private/substack_day6_draft.md

diff --git a/.private/council_brief_mar25_evening.md b/.private/council_brief_mar25_evening.md
deleted file mode 100644
index ea1c5e140..000000000
--- a/.private/council_brief_mar25_evening.md
+++ /dev/null
@@ -1,100 +0,0 @@
-# Parameter Golf Council Brief — March 25, 2026 (Evening)
-
-## Situation Update
-
-We have a full run in progress on 8xH100 with the complete new stack:
-- LeakyReLU(0.5)² + VRL + Gated Attention + BigramHash 3072 + CROWN-Q + lzma
-- AdamW TTT (PR #688 recipe: lr=1e-4, 9 frozen blocks, Polyak averaging, cosine LR)
-- FA3 Hopper (installed via pre-built wheel in 30 seconds!)
-
-Benchmark shows ~106ms/step at step 30, expected to settle to ~87-95ms after torch.compile warmup. Results in ~20 min.
-
-## Key Findings Since Last Brief
-
-### 1. Pod Lottery is MASSIVE
-Same GPU SKU (H100 SXM), same template, wildly different speeds:
-- US-NE-1 pods: ~87ms/step (our best runs, 1.1229 bpb)
-- India pods (some): ~87-106ms/step (usable)
-- Japan pods: 260-320ms/step (3-4x slower, unusable)
-
-This means the competition leaderboard is partly a hardware lottery. Whoever gets a fast pod gets ~2,000 more training steps in 10 min.
-
-### 2. FA3 Pre-Built Wheel Works
-`pip install flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291`
-Installs in 30 seconds. We spent ~$100 and 10+ hours building from source before discovering this. We've published the from-source build as a GitHub release for the community: https://github.com/anthony-maio/openai-parameter-golf-fa3-wheel/releases/tag/v1.0
-
-### 3. Full GPTQ is ILLEGAL (issue #677)
-Multiple PRs disqualified for using Hessian-based GPTQ with calibration data during eval. GPTQ-lite (clip search, no calibration data) remains legal. This invalidated the council's previous top recommendation.
-
-### 4. Our VRL is Spreading
-- PR #745 (1.0222 bpb) credits us directly for VRL
-- ChideraIbe123 adopted our VRL implementation verbatim
-- Validates that VRL is a real, composable gain
-
-### 5. New Techniques Researched
-
-**Gated Attention (GA)**: Per-head sigmoid gate on attention output. ~0.002-0.003 bpb gain. 6 lines of code. Stacks additively with VRL. Implemented. (PR #638)
-
-**CROWN-Q**: Curvature-weighted quantization penalty during warmdown. 10 lines. Training-time only (legal). Pushes weights toward flat minima where int6 rounding hurts less. Implemented. (PR #693)
-
-**Hedge Mixer TTT (PR #688)**: 5-expert online ensemble (neural + unigram/bigram/trigram + entropy) using Hedge algorithm. Gets -0.05 bpb combined with AdamW TTT. Key finding: PR #688 uses **AdamW(lr=1e-4)** not SGD(lr=0.002), and only unfreezes **last 2 blocks** (9 frozen). This is likely why our SGD TTT failed (20x higher LR, all blocks unfrozen).
-
-## Questions for the Council
-
-### Q1: PR Strategy — Update or New PR?
-
-Our current PR #657 shows val_bpb=1.1229 (3-seed mean). If the current run succeeds with the new stack (GA + BH3072 + CROWN-Q + AdamW TTT), we'll have a significantly better number. Options:
-
-A) **Update PR #657** with new results (same branch, just push new code + logs). Keeps our timestamp advantage but changes the submission significantly.
-
-B) **Close PR #657 and open a new PR**. Clean slate, clear description of the new stack. But we lose timestamp priority.
-
-C) **Keep PR #657 as-is (non-record) and open a new record PR**. Shows progression. But rules say only one open record PR at a time.
-
-Which is strategically optimal? Does the timestamp matter for the leaderboard?
-
-### Q2: If AdamW TTT Works — Expected Ceiling?
-
-Our previous SGD TTT failed (bpb went UP). The AdamW recipe from PR #688 uses:
-- AdamW lr=1e-4 (vs our SGD lr=0.002 — 20x lower)
-- 9 frozen blocks (vs 0 — protects VRL gates)
-- Polyak averaging (decay=0.998) for scoring stability
-- Cosine LR decay across chunks
-
-PR #688 gets -0.05 bpb from their full TTT+mixer stack. Realistically, without the Hedge Mixer, what should we expect from AdamW TTT alone on our base? -0.01? -0.02? -0.05?
-
-### Q3: Hedge Mixer — Should We Implement It?
-
-The Hedge Mixer is ~170 lines, self-contained, operates purely on logits (doesn't touch model weights). It runs 5 "experts" and blends their predictions online:
-- Expert 0: Neural model log-softmax
-- Expert 1: Unigram frequency table (from scored tokens)
-- Expert 2: Bigram P(next|prev) table
-- Expert 3: Trigram hash table (64K buckets)
-- Expert 4: Entropy regularizer
-
-The n-gram tables are built incrementally from already-scored tokens only. The Hedge algorithm updates expert weights via multiplicative weights (no gradients).
-
-Is this legal under issue #677? The n-gram tables are built from validation tokens that have already been scored — similar to the contested n-gram caching techniques. If it's legal, this could be a massive gain on top of AdamW TTT.
-
-### Q4: What's the Realistic Frontier We Should Target?
-
-Given:
-- Merged SOTA: 1.1194 (PR #549)
-- Frontier with legal TTT: ~1.10-1.12
-- Frontier with Hedge Mixer: ~1.02-1.07 (legality debated)
-- N-gram caching frontier: sub-1.0 (legality heavily debated)
-- Our current: 1.1229 (no TTT)
-
-Where should we aim? Is 1.10 achievable with legal techniques, or should we target 1.115-1.118 as our realistic ceiling?
-
-### Q5: Competition Meta — Is It Worth Chasing SOTA?
-
-The competition runs until April 30 (5 more weeks). New techniques are appearing daily. PRs are getting disqualified regularly. Is the optimal strategy:
-A) Submit our best number now and iterate weekly
-B) Go heads-down on implementation and submit one strong PR near the deadline
-C) Focus on non-record submissions with novel techniques (custom kernels, depth recurrence) since those get accepted more easily
-
-Our budget is ~$60 remaining. That's ~4 full 3-seed runs.
-
-## Current Run Status
-Training on India H100 SXM x8, ~106ms/step benchmark. Full stack with AdamW TTT enabled. Results expected in ~20 min.
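The multiplicative-weights update the brief asks about in Q3 is small enough to sketch. A toy version follows; the log-domain update, eta=0.1, and the neural-expert bias of 2.0 are exactly the open questions above — assumptions for illustration, not confirmed PR #688 settings:

```python
import math

class HedgeMixer:
    """Toy multiplicative-weights ('Hedge') mixer over expert next-token
    distributions. eta, the initialization, and the log-domain update are
    illustrative guesses, not PR #688's confirmed recipe."""

    def __init__(self, n_experts, eta=0.1, neural_bias=2.0):
        # Expert 0 is the neural model; bias it via its initial log-weight.
        self.log_w = [neural_bias] + [0.0] * (n_experts - 1)
        self.eta = eta

    def _weights(self):
        # Stable softmax over log-weights.
        m = max(self.log_w)
        w = [math.exp(lw - m) for lw in self.log_w]
        z = sum(w)
        return [x / z for x in w]

    def mix(self, expert_probs):
        # Linear blend of the experts' (already normalized) distributions.
        w = self._weights()
        vocab = len(expert_probs[0])
        return [sum(w[i] * expert_probs[i][v] for i in range(len(w)))
                for v in range(vocab)]

    def update(self, expert_probs, target):
        # Each expert pays its own log-loss on the token just revealed.
        # No gradients, no model-weight mutation — the transformer stays frozen.
        for i, probs in enumerate(expert_probs):
            loss = -math.log(max(probs[target], 1e-12))
            self.log_w[i] -= self.eta * loss
```

Note that the two update forms listed in question 2 below the night brief are the same rule: `log_w -= eta * loss` is `w *= exp(-eta * loss)` carried out in the log domain.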
diff --git a/.private/council_brief_mar25_night.md b/.private/council_brief_mar25_night.md
deleted file mode 100644
index 99abf23f6..000000000
--- a/.private/council_brief_mar25_night.md
+++ /dev/null
@@ -1,87 +0,0 @@
-# Parameter Golf Council Brief — March 25, 2026 (End of Day)
-
-## Where We Stand
-
-**PR #175**: val_bpb = 1.1229 (3-seed mean, std 0.0005). March 19 timestamp. Clean, valid, pending review.
-
-**Budget**: ~$30 remaining (~2 full runs).
-
-**Competition frontier**: 0.9625 (n-gram cache), 1.0222 (Hedge Mixer + our VRL). Our 1.1229 is pure neural, no eval-time tricks.
-
-## What We Tried Today (All Failed or Net Zero)
-
-| Experiment | Result | Verdict |
-|---|---|---|
-| Gated Attention | 1.1239 vs 1.1229 base (+4ms overhead) | Net zero. Strip. |
-| BigramHash 3072 | Artifact 16.12MB (over limit) | Doesn't fit. Keep 2048. |
-| CROWN-Q quant penalty | Bundled with GA, no isolated gain | Net zero. Strip. |
-| AdamW TTT (PR #688 recipe) | running_bpb never dropped below pre-TTT over 700 chunks | Dead. VRL stack rejects all weight-modifying TTT. |
-| N-gram cache (our impl) | 1.167 bpb vs 1.124 base — WORSE | Alpha=0.3 too aggressive. Needs entropy-adaptive mixing. |
-
-## The Key Insight: Hedge Mixer Bypasses Our TTT Problem
-
-All our TTT failures share one root cause: modifying model weights mid-eval destabilizes VRL gates and SmearGate's compiled state.
-
-Hedge Mixer (PR #688) doesn't modify model weights. It only updates scalar mixing weights via multiplicative updates. The transformer stays frozen. This bypasses every failure mode we've hit:
-- No VRL gate desync (weights unchanged)
-- No compiled graph invalidation (no weight mutations)
-- No optimizer state issues (Hedge uses loss-based updates, not gradients)
-
-PR #745 (1.0222 bpb) uses Hedge Mixer ON TOP of our VRL and gets massive gains. Their pre-TTT is 1.1348, ours is 1.1229 — our base is stronger.
-
-## What We Need the Council to Research
-
-### 1. PR #727's Exact N-gram Mixing Formula
-Our n-gram cache implementation uses fixed alpha=0.3 which makes things worse. PR #727 gets 0.9674 with "entropy-adaptive alpha." We need:
-- The exact formula for computing alpha per token
-- How they handle the cold-start problem (few tokens scored = weak n-gram stats)
-- Whether they use backoff (try 7-gram, fall back to 6, 5, ... unigram) or blend all orders simultaneously
-- The exact mixing: linear interpolation `(1-a)*p_neural + a*p_ngram` or log-domain `logsumexp`?
-
-### 2. Hedge Algorithm Implementation Details
-From PR #688's 5-expert Hedge Mixer:
-- How are expert weights initialized? (neural bias=2.0, others=0?)
-- What learning rate (eta) for the multiplicative update?
-- Is the update `log_w -= eta * loss` or `w *= exp(-eta * loss)`?
-- How does the entropy expert work? It's not a proper probability model.
-- Do they normalize expert weights after each update?
-
-### 3. Legal Compliance: Causal vs Precomputed N-gram Tables
-The council flagged this: building n-gram tables from already-scored eval tokens is clearly legal. But what about:
-- Can we hash bigrams/trigrams into a fixed-size table or does it need to be exact counts?
-- Is there a minimum count threshold before we trust an n-gram (e.g., count >= 2)?
-- Do the top PRs (#727, #753) use smoothing (add-alpha) or raw counts?
-
-### 4. Liger-Kernel Compatibility
-The council recommended `pip install liger-kernel` for 20-43% throughput. But:
-- Does it work with our custom CastedLinear (fp32 weights, bf16 forward)?
-- Does it conflict with torch.compile?
-- Does it work on the RunPod parameter-golf template (PyTorch 2.9.1)?
-- Which specific ops should we fuse? (RMSNorm, linear+CE, residual+norm)
-
-### 5. Can We Beat PR #745 With Just Hedge + Our Better Base?
-PR #745's stack:
-- Pre-TTT: 1.1348 (their neural base)
-- Post-TTT with Hedge: 1.0222
-
-Our pre-TTT: 1.1229 (0.012 better neural base)
-
-If Hedge gives the same absolute delta, we'd hit ~1.0100. But there might be diminishing returns — a better base means less room for n-gram improvement. What should we realistically expect?
-
-## Proposed Plan for Tomorrow
-
-1. **Implement Hedge Mixer** (~170 lines, offline, $0)
-2. **Add Liger-Kernel** to setup ($0)
-3. **One test run** on fast pod: train + Hedge eval ($15)
-4. **If it works**: 3-seed run, update PR #175 ($15)
-
-Total budget: $30. Exactly what we have.
-
-## Strategic Context
-
-- Competition runs until April 30 (5 weeks left)
-- N-gram techniques dominating the frontier (0.96-1.03)
-- Our VRL contribution is being adopted by others
-- PR #175 has the earliest timestamp of any competitive PR (March 19)
-- If Hedge works: we could jump from 1.1229 to ~1.03-1.05 in one run
-- If it doesn't: we still have 1.1229 as a valid non-record submission

diff --git a/.private/substack_day6_draft.md b/.private/substack_day6_draft.md
deleted file mode 100644
index 1c43b2a30..000000000
--- a/.private/substack_day6_draft.md
+++ /dev/null
@@ -1,141 +0,0 @@
-# OpenAI Parameter Golf Challenge Day 6: The Pod Lottery
-
-*Article 4 of an ongoing series.*
-
-In Article 3, I squeezed 157KB out of a model by switching one compression library, added two techniques I didn't invent, ran three seeds at 1am, and submitted PR #657 at 1.1229 bpb — four ten-thousandths better than the merged SOTA. Then I went to bed.
-
-I woke up to a different competition.
-
----
-
-## The Leaderboard Moved Without Me
-
-While I slept, someone submitted 0.9674 bpb. Not 1.09. Not 1.05. Zero point nine six seven four. That's 0.16 bpb better than my submission. In a competition where people fight over 0.002.
-
-The technique: n-gram caching. Build frequency tables from tokens you've already scored during evaluation, then mix those statistics with the neural model's predictions. It's backward-looking — you only use tokens you've already graded — so it doesn't violate the rules. Probably. The organizers haven't ruled yet.
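For concreteness, here is roughly what "build tables from scored tokens, then mix" looks like — a toy, strictly causal version with a single bigram table and a fixed blend weight. The real submissions use multi-order backoff and entropy-adaptive alpha; every constant below is an illustrative guess, not any PR's actual recipe:

```python
import math
from collections import defaultdict

def ngram_cached_bpb(token_probs, tokens, alpha=0.3, vocab_size=256):
    """Toy causal n-gram cache (bigram only). The table is built solely
    from tokens that have already been scored, then blended with the
    neural model's probability for each upcoming token. Returns average
    bits per scored token. alpha and the add-one smoothing are guesses."""
    bigram = defaultdict(lambda: defaultdict(int))
    total_bits = 0.0
    for t in range(1, len(tokens)):
        prev, cur = tokens[t - 1], tokens[t]
        counts = bigram[prev]
        seen = sum(counts.values())
        # Add-one smoothed cache estimate; near-uniform until counts accrue.
        p_cache = (counts[cur] + 1) / (seen + vocab_size)
        # Linear interpolation of neural and cache probabilities.
        p = (1 - alpha) * token_probs[t] + alpha * p_cache
        total_bits -= math.log2(p)
        counts[cur] += 1  # update only AFTER scoring: strictly backward-looking
    return total_bits / max(len(tokens) - 1, 1)
```

On a repetitive stream this drops well below the neural model's own bits per token; on real text it only helps where local statistics are strong, which is presumably why the top submissions gate alpha on the neural model's entropy instead of fixing it.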
-
-Six PRs appeared overnight using variations of the same idea. Multi-order backoff from 7-grams down to unigrams. Entropy-adaptive mixing weights. Zero artifact cost — the tables are built on the fly during eval and thrown away after. The neural model doesn't change. You just augment its predictions with local token statistics.
-
-My 1.1229 went from "matching SOTA" to "6th tier" in twelve hours.
-
-I stared at the leaderboard for a while. Ran the numbers. From a 1.12 neural base, n-gram caching should push you to roughly 0.96-1.03. The gain scales with the quality of your base model. My base is actually stronger than most of the n-gram submissions — they're at 1.127 pre-cache, I'm at 1.122. If I added the same cache, I should beat them.
-
-But should I? The legality question is real. The organizers had already disqualified 25+ PRs in two enforcement sweeps. Full GPTQ with calibration data: illegal. Multi-epoch TTT: illegal. Oracle token selection: illegal. N-gram caching built from scored tokens: ...silence. Six open PRs. Zero organizer comments. Days passing.
-
-That silence is either "we haven't gotten to it yet" or "it's fine." I genuinely don't know which.
-
----
-
-## TTT: The Final Attempt
-
-Before pivoting to n-gram caching, I had one more thing to try. Test-time training had failed on my architecture three times: SGD at lr=0.002 diverged catastrophically. SGD at lr=0.001 was even worse. The model council diagnosed it as "VRL gate desync" — my Value Residual Learning creates dependencies between layers that break when you modify weights mid-inference.
-
-But then my research agents pulled up PR #688. Their TTT worked. And the recipe was completely different from what I'd been trying:
-
-| Setting | My Failed TTT | PR #688's Working TTT |
-|---------|--------------|----------------------|
-| Optimizer | SGD(lr=0.002) | AdamW(lr=0.0001) |
-| Frozen blocks | 0 (all unfrozen) | 9 of 11 (only last 2) |
-| Weight averaging | None | Polyak (decay=0.998) |
-
-Twenty times lower learning rate. Nine blocks frozen instead of zero. And Polyak averaging — you score with smoothed weights, train with live weights. I'd been trying to adapt the entire model. They barely touched it.
-
-I implemented it. Launched it on an 88ms/step India pod. Training finished, sliding window eval came back at 1.1228 — our best ever pre-TTT score. Then the TTT eval started.
-
-Chunk 1/1893: running bpb = 1.193. Higher than pre-TTT.
-
-That's expected — the first chunks have no adaptation history.
-
-Chunk 101: 1.145. Coming down.
-
-Chunk 201: 1.162. Going back up.
-
-I watched it oscillate for 700 chunks. The running bpb never dropped below the pre-TTT baseline. Not once. AdamW was more stable than SGD — it didn't explode — but it still couldn't help. The model was slowly degrading with each chunk of adaptation.
-
-TTT is dead on my architecture. Three optimizers. Four learning rates. Multiple freezing strategies. Polyak averaging. None of it works. The VRL gates were calibrated during training to expect specific weight distributions, and any modification — no matter how gentle — disrupts them.
-
-I killed the run. Stopped the pod. Accepted it.
-
----
-
-## The Pod Lottery
-
-Here's something nobody talks about in ML competitions: not all GPUs are created equal, even when they have the same name.
-
-I ran the same code on the same "8xH100 SXM" pod template across five different sessions this week. The step times:
-
-| Pod Location | Step Time | Steps in 10 min |
-|-------------|-----------|-----------------|
-| India (pod A) | 87ms | 6,889 |
-| India (pod B) | 91ms | 6,593 |
-| India (pod C) | 106ms | 5,660 |
-| Japan | 268ms | 2,238 |
-| Canada | 272ms | 2,205 |
-
-Same GPU. Same code. Same container image. Three-fold speed difference. The Japan and Canada pods ran at walking pace while the India pods sprinted. The step time directly determines how much data you see in 10 minutes, which directly determines your bpb.
-
-The competition leaderboard is partly a hardware lottery. The top submissions report 83-88ms/step. If you land on a pod that runs at 260ms, you physically cannot produce a competitive result. Not because your model is worse, but because your model saw one-third the data.
-
-I don't know why the speeds differ so much. NVLink topology? Thermal throttling? Different H100 batches? CPU bottlenecks? I just know that every time I spin up a pod, the first thing I do is run a 20-step benchmark. If it's over 120ms, I kill it and try again. At $21.52/hour for 8xH100, each bad pod costs about $2 before I catch it. Each good pod saves about $15 in wasted training time.
-
----
-
-## Something Changed
-
-Then something happened that I didn't expect.
-
-I was checking the live commentary thread — a community-maintained analysis of every PR in the competition — and I found my name. Not my PR number. My name.
-
-PR #745, a submission at 1.0222 bpb (the best non-n-gram score at the time), listed their six techniques. One of them was "Value Residual Learning (PR #657)." My PR. Credited.
-
-Then I found a commit in someone else's fork. ChideraIbe123, a competitor I'd never talked to, had copied my VRL implementation verbatim into their codebase. 28 lines of code. The commit message cited my PR and the ResFormer paper.
-
-I didn't invent VRL. I implemented it from a paper and proved it worked in competition conditions. And now other people were building on it. The technique I'd added at midnight — 20 lines of code, 10 scalar parameters — was becoming part of the competition's shared vocabulary.
-
-This is the thing about open competitions that I keep forgetting. The goal isn't just to win. It's to contribute something that moves the field. My VRL implementation isn't going to win me the competition. But it might win someone else a few hundredths of a bpb, and they'll stack something on top of it that I'll then learn from. The whole thing is a giant collaborative gradient descent on the problem of "how good can a 16MB language model be?"
-
-I went back to my research system and pulled up the live commentary thread again. This time I wasn't looking at the leaderboard. I was looking at the "Untried Combinations" section — a community-curated list of techniques nobody had tested yet.
-
-There were ideas I'd never heard of. Context Tree Weighting. Logistic-domain mixing. Fixed-Share Hedge with non-stationary expert tracking. Some of them had names that sounded made up. Some of them had arXiv links that I spent an hour reading.
-
-The competition isn't about having the best idea. It's about having the best information. And right now, the information is telling me that n-gram caching is the play — if it survives the legality review.
-
----
-
-## The Strategic Play
-
-Here's where I am at the end of Day 6.
-
-**What I have:** PR #175 at 1.1229 bpb, three valid seeds, March 19 timestamp (the earliest of any competitive PR because I reopened an old submission). A clean architecture that other people are building on. VRL spreading through the competition.
-
-**What I don't have:** TTT. N-gram caching. Anything that breaks below 1.12.
-
-**What I'm building:** An n-gram cache implementation on a separate branch, isolated from my clean submission. If the organizers rule it legal, I deploy it. If they don't, I still have 1.1229 on PR #175 with the oldest timestamp in the game.
-
-**What I've spent:** Over $1,000 in GPU compute across the week. Four failed FA3 builds. Three failed TTT implementations. Six slow pods killed on sight. Twenty-something full training runs across five days. Two closed PRs. One article I wrote at 3am that more people read than I expected. And the discovery that `pip install flash_attn_3 --find-links .../cu128_torch291` installs in 30 seconds what took me 60 minutes and $100 to build from source. Someone shared that link in the competition thread on Day 4. I found it on Day 6.
-
-**What I've learned:** The hard problems aren't architectural. They're operational. SSH connections that die mid-training. Pods that lose their GPU allocation at step 4500. Container disks that fill up at 99.7% through a CUDA kernel build. Compression libraries that aren't installed on the official template. Pod speeds that vary 3x for the same hardware. Every one of these burned hours and dollars. None of them improved my bpb by a single millinat.
-
-## The Evening: Everything Falls Into Place
-
-Around 8pm, while debugging why the n-gram cache was making things worse (spoiler: I was mixing 30% n-gram noise into good neural predictions, which is like adding static to a clear signal), the research system surfaced a pattern I'd been missing.
-
-Every time I tried to improve my model during evaluation — SGD, AdamW, LoRA, you name it — it broke because the modifications destabilized the VRL gates. The model's internal state was calibrated during training, and any weight change at eval time, no matter how gentle, disrupted that calibration.
-
-But the Hedge Mixer doesn't change weights. At all. It takes my frozen model's predictions and mixes them with simple n-gram statistics using an online learning algorithm. The transformer produces logits. The n-gram tables produce probability estimates. The Hedge algorithm learns, token by token, how much to trust each source. The mixing weights update via multiplication — `w *= exp(-eta * loss)` — not via backpropagation. No gradients flow through the model. No compiled graphs get invalidated. No VRL gates get desynced.
-
-PR #745, the submission that cited my VRL work, uses exactly this approach. Their pre-TTT base model scores 1.1348. After Hedge mixing: 1.0222. A gain of 0.11 bpb from an algorithm that never modifies the model.
-
-My base model scores 1.1229. That's 0.012 better than theirs. If the Hedge algorithm gives even close to the same gain...
-
-I spent the rest of the evening implementing it. Then I stopped. Not because I was stuck, but because I'd spent $30 today on runs that taught me things but didn't move the number. I have $30 left. That's two shots. I need them to count.
-
-The plan for tomorrow is simple. Implement the Hedge Mixer offline (zero GPU cost). Test it once on a fast pod. If it works, run three seeds and update PR #175.
-
-The competition runs until April 30. Five more weeks. The frontier is at 0.96 and dropping. My 1.1229 is irrelevant in the current landscape — unless I can stack the Hedge Mixer on top of it.
-
-I think I can. The research system spent all evening analyzing how the top submissions implemented their mixers, what alpha values they use, how they handle the cold-start problem. By morning there will be a complete implementation plan waiting for me.
-
-The hard problems aren't architectural anymore. They aren't even operational. The hard problem now is: can I execute a clean implementation of a well-understood algorithm, validate it in two runs, and submit before someone else does it better?
-
-Tomorrow I find out.