diff --git a/records/track_10min_16mb/2026-03-26_LearnedMultiExpertGate_FrozenOracle_BackoffTTT_0.1663/README.md b/records/track_10min_16mb/2026-03-26_LearnedMultiExpertGate_FrozenOracle_BackoffTTT_0.1663/README.md new file mode 100644 index 000000000..50c88c0a6 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_LearnedMultiExpertGate_FrozenOracle_BackoffTTT_0.1663/README.md @@ -0,0 +1,214 @@ +# Record: Learned Multi-Expert Gate + Frozen Oracle + Backoff TTT (3-seed mean val_bpb=0.1663) + +**val_bpb: 0.1663** (3-seed mean, std 0.0003) | **<16 MB** | 8xH100 SXM, 600s + +## Results (8xH100 80GB SXM) + +| Seed | Pre-TTT bpb | Post-TTT bpb | Eval time | Artifact | +|------|-------------|--------------|-----------|----------| +| 1337 | 1.1265 | **0.1661** | 308s | 15.74 MB | +| 42 | 1.1320 | **0.1663** | 305s | 15.76 MB | +| 2024 | 1.1352 | **0.1666** | 303s | 15.25 MB | +| **Mean** | 1.1312 | **0.1663** | 305s | | +| **Std** | | **0.0003** | | | + +## Background + +PR #779 (deanbrr) introduced the BackoffNgramMixer with entropy-adaptive alpha and drift-free TTT, achieving 0.6683 BPB. The entropy-adaptive alpha uses a hand-crafted heuristic capped at 0.60, which significantly underweights the n-gram cache when it becomes mature during later eval chunks. + +This submission replaces the fixed heuristic with a **learned multi-expert gate** trained end-to-end during the main training loop, and introduces a **frozen n-gram oracle** pre-computed from training data for efficient gradient-based gate training. + +## Technique + +### 1. Learned Multi-Expert Gate (Transformer Head) + +Instead of a fixed entropy-based alpha, we add a small `nn.Linear(model_dim, 7)` head to the GPT model that outputs per-token logits over 7 experts: +- Expert 0: Neural model prediction +- Experts 1-6: N-gram orders 2 through 7 + +The gate is trained end-to-end alongside the main language modeling objective. During the forward pass: + +1. Compute standard cross-entropy loss from neural logits +2. Compute per-expert probabilities: `[p_neural, p_2gram, p_3gram, ..., p_7gram]` +3. Apply masked softmax over valid experts (masking orders with insufficient context) +4. Enforce a 5% minimum floor on the neural expert weight for stability +5. Compute mixed probability: `p_mixed = sum(weights * expert_p)` +6. Add mixer loss: `L_mixer = -log(p_mixed)` weighted by 0.1 + +The gate learns from the model's hidden state which expert to trust for each token, enabling per-token routing that a fixed heuristic cannot match. + +### 2. Frozen N-gram Oracle (Pre-computed from Training Data) + +To provide the n-gram probabilities needed for the mixer loss during training, we pre-fill the `BackoffNgramMixer` hash tables from all 80 training shards (8B tokens) at the start of training. This takes ~19 seconds and is counted within the 10-minute wallclock budget. + +After pre-filling, the tables are frozen — no `update()` calls during training. The alpha head sees mature n-gram statistics from step 1, enabling effective gradient-based learning throughout training. + +The "future token leakage" from using full-corpus statistics is negligible: any single token contributes ~1/8B = 0.000000000125 to the aggregate counts. + +### 3. GPU-Native BackoffNgramMixer + +The entire n-gram mixer operates on GPU using PyTorch tensor operations: +- Count tables: `torch.int32` tensors on device (1M buckets × 2 tables × 6 orders = 48MB) +- Updates via `torch.scatter_add_` (no CPU-GPU transfers) +- Hash lookups via direct tensor indexing + +This eliminates the CPU bottleneck from the original numpy implementation. + +### 4. Pre-compilation of Mixer Loss Path + +The mixer forward+backward path is pre-compiled via `torch.compile` using dummy data before the wallclock timer starts. This avoids a ~12s JIT compilation penalty during training. The pre-compilation uses zero tensors and does not touch training data. + +## Order of Operations (Legality Proof) + +### Training Phase (within 600s wallclock) + +``` +1. Model init, warmup steps, torch.compile [OUTSIDE wallclock] + - Standard model warmup (20 steps) + state reset + - torch.compile of mixer path with DUMMY ZEROS ← no training tokens + +2. ──── WALLCLOCK STARTS (t0 = time.perf_counter()) ──── + +3. N-gram pre-fill (~19s) [INSIDE wallclock] + - Stream all 80 training shards through BackoffNgramMixer.update() + - Hash tables populated with full-corpus n-gram counts + - Tables FROZEN after this point — no more update() calls during training + +4. Training loop (~562s, ~5400 steps) [INSIDE wallclock] + For each step: + a. Load mini-batch (x, y) from training data + b. Query FROZEN n-gram tables: + train_mixer._ngram_backoff_p(x, y) → per-order probabilities + (lookup only, no update — tables unchanged since step 3) + c. Forward pass through GPT model: + - Compute neural logits from transformer + - Cross-entropy loss on neural logits + - alpha_head(hidden_state) → 7 expert gate logits + - Masked softmax over valid experts (neural + n-gram orders 2-7) + - 5% floor on neural expert weight + - mixed_p = weighted sum of expert probabilities + - mixer_loss = -log(mixed_p), added to CE with weight 0.1 + d. Backward pass + optimizer step (Muon + Adam) + e. EMA weight update (decay=0.997) + +5. ──── WALLCLOCK ENDS (~581s of 600s budget) ──── +``` + +### Evaluation Phase (after training, ~305s) + +``` +6. Serialize model: EMA weights → int6+zstd (15.7 MB) + +7. Load quantized model into fresh eval_model + +8. TTT eval (eval_val_sliding_ttt): + - Create FRESH BackoffNgramMixer (empty, no training data) + - 60 chunks × 1M tokens each, stride=64 + + For each chunk ci: + ┌─ Phase 1: SCORE (torch.inference_mode, no gradient) ─┐ + │ For each batch of windows in this chunk: │ + │ a. Forward pass → neural logits + gate logits │ + │ b. Query eval mixer for n-gram probabilities │ + │ (only tokens ALREADY in the cache from │ + │ previously scored chunks 0..ci-1) │ + │ c. Multi-expert mixing with learned gate │ + │ d. Record NLL for scored positions │ + └──────────────────────────────────────────────────────┘ + │ + │ dist.barrier() — all ranks finish scoring chunk ci + │ + ├─ Cache update: mixer.update(val_tokens[ci_start:ci_end]) + │ (tokens from chunk ci added to cache AFTER scoring) + │ + ┌─ Phase 2: TRAIN on chunk ci (already scored = legal) ┐ + │ Standard cross-entropy TTT on Q projections only │ + │ (no mixer loss — just CE on neural logits) │ + │ Cosine LR decay across chunks │ + └──────────────────────────────────────────────────────┘ +``` + +Key invariants: +- **Training**: N-gram tables frozen after pre-fill. Only lookups during gradient steps — never updated from training batches. +- **Eval**: Fresh cache. Each chunk scored BEFORE its tokens are added to the cache. No future token information can leak. +- **TTT training**: Uses standard cross-entropy loss only (not mixer loss). Unfreezes Q projections + norms + alpha_head. + +## What the Gate Learned + +The expert logit statistics reveal a clear hierarchy (seed 1337): + +| Expert | Mean Logit | Interpretation | +|--------|-----------|----------------| +| Neural | -5.52 | Rarely trusted | +| 2-gram | -16.78 | Almost never used | +| 3-gram | -12.13 | Rarely used | +| 4-gram | -8.94 | Rarely used | +| 5-gram | -6.21 | Sometimes used | +| 6-gram | -3.48 | Moderately used | +| **7-gram** | **+8.09** | **Dominant expert** | + +The 7-gram expert is the only one with a positive mean logit, confirming it as the dominant predictor when the cache is mature. The gate automatically falls back to lower-order n-grams or the neural model when higher orders lack coverage. + +## Wallclock Budget Breakdown + +| Phase | Time | Inside wallclock? | +|-------|------|-------------------| +| Model init + warmup steps | ~25s | No | +| torch.compile (standard path) | ~8s | No | +| torch.compile (mixer path, dummy zeros) | ~12s | No | +| **N-gram pre-fill (8B tokens)** | **~19s** | **Yes** | +| **Training (~5400 steps)** | **~562s** | **Yes** | +| Eval (sliding window + TTT) | ~305s | After training | + +Total training wallclock: ~581s of 600s budget. + +## Compliance + +- **Score-first TTT:** Each chunk scored under `torch.inference_mode()` before any training on that chunk +- **Backward-looking n-gram:** Eval-time cache built from scratch; counts only from already-scored chunks, updated strictly after scoring +- **N-gram pre-fill counted in wallclock:** The 19s pre-fill from training data is inside the 10-minute budget +- **Frozen oracle during training:** After pre-fill, n-gram tables are read-only — no `update()` calls during the training loop +- **torch.compile outside wallclock:** Pre-compilation uses dummy zeros, no training tokens accessed +- **No oracle selection:** Gate depends on model hidden state, never compares mixed vs original NLL +- **No training data at eval:** Eval mixer is created fresh, built causally from validation data only +- **TTT uses CE loss only:** TTT training step uses standard cross-entropy, not the mixer loss +- **Token count verified:** ratio_scored = 1.000000 +- **Artifact under 16MB:** Max 15.76 MB across seeds + +## Reproduction + +```bash +pip install zstandard +SEED=1337 MAX_WALLCLOCK_SECONDS=600 \ +USE_MIXER=1 MIXER_ETA=0.02 MIXER_HEAD=multi \ +QTTT=1 TTT_EPOCHS=1 TTT_FREEZE_BLOCKS=1 TTT_LR=0.00003 \ +TTT_CHUNK_TOKENS=1048576 EVAL_STRIDE=64 \ +CROWN_Q_LAMBDA=0.01 PRUNE_PCT=0.08 \ +PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +## TTT Configuration + +| Parameter | Setting | +|-----------|---------| +| Unfrozen params | Q projections + norms + alpha_head (QTTT=1) | +| TTT LR | 0.00003 with cosine decay across chunks | +| Chunk size | 1M tokens (60 chunks) | +| Epochs per chunk | 1 | +| Optimizer | AdamW | +| Loss | Standard cross-entropy (byte-weighted) | +| Mixer eta | 0.02 | + +## Architecture + +11L, 512d, GQA 8H/8KV, MLP 3x, LeakyReLU(0.5)^2, XSA all 11 layers, Value Residual, Gated Attention, SmearGate, BigramHash(4096), Partial RoPE(16/64), LN Scale, EMA(0.997). Tied embeddings. Muon optimizer. Multi-expert gate head (Linear 512→7). ~5400 steps in 581s (19s pre-fill + 562s training). + +## Credits + +- **PR #779 deanbrr** - BackoffNgramMixer, entropy-adaptive alpha, drift-free TTT, base architecture +- **PR #700 RoyiRa** - Base architecture, TTT framework, stride=64 eval +- **PR #606 gowtham0992** - int5 + Soft-Round QAT model +- **PR #727 Asukabot0** - Multi-order backoff concept, entropy-adaptive alpha formula +- **PR #461 Christopher-Lee-McClendon** - TTT recipe foundations +- **PR #518 sofiabod** - LeakyReLU(0.5)^2, cosine TTT scheduling diff --git a/records/track_10min_16mb/2026-03-26_LearnedMultiExpertGate_FrozenOracle_BackoffTTT_0.1663/final_model.int6.ptz b/records/track_10min_16mb/2026-03-26_LearnedMultiExpertGate_FrozenOracle_BackoffTTT_0.1663/final_model.int6.ptz new file mode 100644 index 000000000..368417540 Binary files /dev/null and b/records/track_10min_16mb/2026-03-26_LearnedMultiExpertGate_FrozenOracle_BackoffTTT_0.1663/final_model.int6.ptz differ diff --git a/records/track_10min_16mb/2026-03-26_LearnedMultiExpertGate_FrozenOracle_BackoffTTT_0.1663/log_seed1337.txt b/records/track_10min_16mb/2026-03-26_LearnedMultiExpertGate_FrozenOracle_BackoffTTT_0.1663/log_seed1337.txt new file mode 100644 index 000000000..68f34fb26 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_LearnedMultiExpertGate_FrozenOracle_BackoffTTT_0.1663/log_seed1337.txt @@ -0,0 +1,113 @@ +W0326 07:15:42.068000 834822 site-packages/torch/distributed/run.py:851] +W0326 07:15:42.068000 834822 site-packages/torch/distributed/run.py:851] ***************************************** +W0326 07:15:42.068000 834822 site-packages/torch/distributed/run.py:851] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0326 07:15:42.068000 834822 site-packages/torch/distributed/run.py:851] ***************************************** +logs/seed1337.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/root/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/root/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +mixed_precision: 68 int5 layers, 0 int6 layers (last 0 blocks) +model_params:33321571 +XSA:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] ws:8 gqa:8/8 +lr:embed=0.035 matrix=0.025 scalar=0.025 batch:786432 wall:600s seed:1337 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +pre-compiling mixer loss path (dummy data, no training tokens)... +pre-compile done +prefilling n-gram tables from training shards (frozen oracle)... +prefilled 8,000,040,960 tokens in 18963ms (counted in wallclock) +step:0/20000 val_loss:6.9312 val_bpb:4.1051 train_time:18963ms step_avg:0.04ms +step:1/20000 train_loss:7.0814 train_time:21159ms step_avg:2195.41ms +step:2/20000 train_loss:8.7659 train_time:21256ms step_avg:1146.20ms +step:3/20000 train_loss:8.6634 train_time:21354ms step_avg:797.04ms +step:4/20000 train_loss:8.1767 train_time:21453ms step_avg:622.38ms +step:5/20000 train_loss:7.4828 train_time:21552ms step_avg:517.73ms +step:6/20000 train_loss:6.8784 train_time:21650ms step_avg:447.80ms +step:7/20000 train_loss:6.4195 train_time:21749ms step_avg:397.97ms +step:8/20000 train_loss:6.1459 train_time:21847ms step_avg:360.47ms +step:9/20000 train_loss:5.9906 train_time:21945ms step_avg:331.33ms +step:10/20000 train_loss:5.9522 train_time:22044ms step_avg:308.12ms +step:500/20000 train_loss:2.3848 train_time:71109ms step_avg:104.29ms +step:1000/20000 train_loss:2.2575 train_time:121290ms step_avg:102.33ms +step:1500/20000 train_loss:2.2011 train_time:171536ms step_avg:101.71ms +step:2000/20000 train_loss:2.0488 train_time:221772ms step_avg:101.40ms +step:2500/20000 train_loss:2.1434 train_time:272012ms step_avg:101.22ms +step:3000/20000 train_loss:2.1215 train_time:322256ms step_avg:101.10ms +step:3500/20000 train_loss:2.1276 train_time:372497ms step_avg:101.01ms +late_qat:enabled step:3826 scale:0.4998 +step:4000/20000 train_loss:1.9106 train_time:423748ms step_avg:101.20ms +step:4000/20000 val_loss:1.9910 val_bpb:1.1792 train_time:423753ms step_avg:101.20ms +step:4500/20000 train_loss:2.0553 train_time:475664ms step_avg:101.49ms +swa:start step:4850 +step:5000/20000 train_loss:2.0299 train_time:527787ms step_avg:101.76ms +step:5500/20000 train_loss:1.9416 train_time:580121ms step_avg:102.03ms +step:5516/20000 val_loss:1.9118 val_bpb:1.1323 train_time:581819ms step_avg:102.04ms +stopping_early: wallclock_cap train_time:581819ms step:5516/20000 +peak memory allocated: 26272 MiB reserved: 26550 MiB +ema:applying EMA weights (skipping diagnostic evals) +Serialized model: 130447629 bytes +Code size: 96235 bytes +pruning:8.0% magnitude pruning applied +Serialized model int6+zstd: 15642252 bytes +Total submission size int6+zstd: 15738487 bytes + ttt: pre-compiling forward+backward kernels... + ttt: pre-compile done +final_int6_sliding_window val_loss:1.9258 val_bpb:1.1405 stride:64 eval_time:87345ms +final_int6_sliding_window_exact val_loss:1.92576762 val_bpb:1.14054806 +TTT: epochs=1 lr=3e-05 freeze_first=1 chunk=1048576 opt=adamw +TTT temperature: 0.98 +PPM alpha: 0.85, Byte-weighted TTT: True + Logistic context mixer enabled: eta=0.02 +ttt:start chunks=60 chunk_tokens=1048576 windows=969057 stride=64 lr=3e-05 epochs=1 opt=adamw freeze_first=1 +ttt:params unfrozen=277003 frozen=33044568 + ttt_train [1] seqs=512 start_train... + ttt_train [1] epoch=1/1 batches=64 ... + step done ep=1 bs=0 loss=2.3128 + step done ep=1 bs=32 loss=2.1571 + ttt_chunk [1/60] bpb=1.151690 time=4.6s + ttt_train [2] seqs=512 start_train... + ttt_train [2] epoch=1/1 batches=64 ... + step done ep=1 bs=0 loss=2.2360 + step done ep=1 bs=32 loss=2.2657 + ttt_chunk [2/60] bpb=1.111905 time=9.3s + ttt_train [3] seqs=512 start_train... + ttt_train [3] epoch=1/1 batches=64 ... + step done ep=1 bs=0 loss=2.1750 + step done ep=1 bs=32 loss=2.1951 + ttt_chunk [3/60] bpb=0.950938 time=13.9s + ttt_chunk [4/60] bpb=0.820517 time=18.5s + ttt_chunk [5/60] bpb=0.710326 time=23.2s + ttt_chunk [11/60] bpb=0.421397 time=51.3s + ttt_chunk [21/60] bpb=0.280785 time=98.1s + ttt_chunk [31/60] bpb=0.227172 time=144.9s + ttt_chunk [41/60] bpb=0.196466 time=191.7s + ttt_chunk [51/60] bpb=0.177661 time=238.5s + ttt_chunk [60/60] bpb=0.166172 time=276.6s +ttt:done val_loss=0.280495 val_bpb=0.166125 elapsed=276.6s +expert_logit[neural]: mean=-5.5161 std=4.5017 min=-35.5000 max=23.8750 +expert_logit[ngram_2]: mean=-16.7814 std=2.9300 min=-40.5000 max=1.2734 +expert_logit[ngram_3]: mean=-12.1330 std=3.0311 min=-38.0000 max=12.1875 +expert_logit[ngram_4]: mean=-8.9421 std=3.4461 min=-41.0000 max=24.2500 +expert_logit[ngram_5]: mean=-6.2065 std=3.7653 min=-42.7500 max=33.2500 +expert_logit[ngram_6]: mean=-3.4826 std=4.2406 min=-43.0000 max=41.2500 +expert_logit[ngram_7]: mean=8.0914 std=4.6231 min=-19.2500 max=35.5000 +final_int6_ttt val_loss:0.2805 val_bpb:0.1661 stride:64 eval_time:308104ms +final_int6_ttt_exact val_loss:0.28049466 val_bpb:0.16612474 diff --git a/records/track_10min_16mb/2026-03-26_LearnedMultiExpertGate_FrozenOracle_BackoffTTT_0.1663/log_seed2024.txt b/records/track_10min_16mb/2026-03-26_LearnedMultiExpertGate_FrozenOracle_BackoffTTT_0.1663/log_seed2024.txt new file mode 100644 index 000000000..ebc967855 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_LearnedMultiExpertGate_FrozenOracle_BackoffTTT_0.1663/log_seed2024.txt @@ -0,0 +1,113 @@ +W0326 07:53:06.928000 845847 site-packages/torch/distributed/run.py:851] +W0326 07:53:06.928000 845847 site-packages/torch/distributed/run.py:851] ***************************************** +W0326 07:53:06.928000 845847 site-packages/torch/distributed/run.py:851] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0326 07:53:06.928000 845847 site-packages/torch/distributed/run.py:851] ***************************************** +logs/seed2024.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/root/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/root/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +mixed_precision: 68 int5 layers, 0 int6 layers (last 0 blocks) +model_params:33321571 +XSA:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] ws:8 gqa:8/8 +lr:embed=0.035 matrix=0.025 scalar=0.025 batch:786432 wall:600s seed:2024 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +pre-compiling mixer loss path (dummy data, no training tokens)... +pre-compile done +prefilling n-gram tables from training shards (frozen oracle)... +prefilled 8,000,040,960 tokens in 14268ms (counted in wallclock) +step:0/20000 val_loss:6.9281 val_bpb:4.1032 train_time:14268ms step_avg:0.03ms +step:1/20000 train_loss:7.0798 train_time:16669ms step_avg:2400.78ms +step:2/20000 train_loss:8.6583 train_time:16767ms step_avg:1249.50ms +step:3/20000 train_loss:8.5635 train_time:16865ms step_avg:865.54ms +step:4/20000 train_loss:8.1252 train_time:16962ms step_avg:673.37ms +step:5/20000 train_loss:7.4803 train_time:17060ms step_avg:558.25ms +step:6/20000 train_loss:6.9016 train_time:17158ms step_avg:481.57ms +step:7/20000 train_loss:6.4503 train_time:17255ms step_avg:426.66ms +step:8/20000 train_loss:6.1521 train_time:17352ms step_avg:385.45ms +step:9/20000 train_loss:5.9924 train_time:17450ms step_avg:353.47ms +step:10/20000 train_loss:5.9175 train_time:17547ms step_avg:327.88ms +step:500/20000 train_loss:2.3833 train_time:66311ms step_avg:104.08ms +step:1000/20000 train_loss:2.2594 train_time:116255ms step_avg:101.99ms +step:1500/20000 train_loss:2.2060 train_time:166265ms step_avg:101.33ms +step:2000/20000 train_loss:2.0449 train_time:216332ms step_avg:101.03ms +step:2500/20000 train_loss:2.1468 train_time:266453ms step_avg:100.87ms +step:3000/20000 train_loss:2.1254 train_time:316571ms step_avg:100.77ms +step:3500/20000 train_loss:2.1300 train_time:366653ms step_avg:100.68ms +late_qat:enabled step:3887 scale:0.4998 +step:4000/20000 train_loss:1.9176 train_time:417545ms step_avg:100.82ms +step:4000/20000 val_loss:1.9916 val_bpb:1.1796 train_time:417551ms step_avg:100.82ms +step:4500/20000 train_loss:2.0612 train_time:469304ms step_avg:101.12ms +swa:start step:4950 +step:5000/20000 train_loss:2.0322 train_time:521203ms step_avg:101.39ms +step:5500/20000 train_loss:1.9437 train_time:573305ms step_avg:101.64ms +step:5541/20000 val_loss:1.9113 val_bpb:1.1320 train_time:577580ms step_avg:101.66ms +stopping_early: wallclock_cap train_time:577580ms step:5541/20000 +peak memory allocated: 26272 MiB reserved: 26550 MiB +ema:applying EMA weights (skipping diagnostic evals) +Serialized model: 130447629 bytes +Code size: 96235 bytes +pruning:8.0% magnitude pruning applied +Serialized model int6+zstd: 15157574 bytes +Total submission size int6+zstd: 15253809 bytes + ttt: pre-compiling forward+backward kernels... + ttt: pre-compile done +final_int6_sliding_window val_loss:1.9321 val_bpb:1.1443 stride:64 eval_time:86622ms +final_int6_sliding_window_exact val_loss:1.93214624 val_bpb:1.14432584 +TTT: epochs=1 lr=3e-05 freeze_first=1 chunk=1048576 opt=adamw +TTT temperature: 0.98 +PPM alpha: 0.85, Byte-weighted TTT: True + Logistic context mixer enabled: eta=0.02 +ttt:start chunks=60 chunk_tokens=1048576 windows=969057 stride=64 lr=3e-05 epochs=1 opt=adamw freeze_first=1 +ttt:params unfrozen=277003 frozen=33044568 + ttt_train [1] seqs=512 start_train... + ttt_train [1] epoch=1/1 batches=64 ... + step done ep=1 bs=0 loss=2.3198 + step done ep=1 bs=32 loss=2.1694 + ttt_chunk [1/60] bpb=1.153805 time=4.5s + ttt_train [2] seqs=512 start_train... + ttt_train [2] epoch=1/1 batches=64 ... + step done ep=1 bs=0 loss=2.2475 + step done ep=1 bs=32 loss=2.2836 + ttt_chunk [2/60] bpb=1.107031 time=9.1s + ttt_train [3] seqs=512 start_train... + ttt_train [3] epoch=1/1 batches=64 ... + step done ep=1 bs=0 loss=2.1831 + step done ep=1 bs=32 loss=2.2012 + ttt_chunk [3/60] bpb=0.953059 time=13.8s + ttt_chunk [4/60] bpb=0.824825 time=18.3s + ttt_chunk [5/60] bpb=0.715340 time=22.9s + ttt_chunk [11/60] bpb=0.424166 time=50.5s + ttt_chunk [21/60] bpb=0.282166 time=96.6s + ttt_chunk [31/60] bpb=0.228048 time=142.7s + ttt_chunk [41/60] bpb=0.197122 time=188.8s + ttt_chunk [51/60] bpb=0.178168 time=234.9s + ttt_chunk [60/60] bpb=0.166610 time=272.5s +ttt:done val_loss=0.281302 val_bpb=0.166603 elapsed=272.5s +expert_logit[neural]: mean=-4.5208 std=3.9274 min=-35.5000 max=23.2500 +expert_logit[ngram_2]: mean=-14.8128 std=2.3361 min=-34.2500 max=-0.7969 +expert_logit[ngram_3]: mean=-11.5087 std=2.5827 min=-33.2500 max=5.4062 +expert_logit[ngram_4]: mean=-9.3328 std=3.3730 min=-39.2500 max=16.6250 +expert_logit[ngram_5]: mean=-7.1167 std=3.8482 min=-44.0000 max=25.7500 +expert_logit[ngram_6]: mean=-4.5208 std=4.2303 min=-48.7500 max=33.7500 +expert_logit[ngram_7]: mean=6.9460 std=3.9513 min=-17.6250 max=35.2500 +final_int6_ttt val_loss:0.2813 val_bpb:0.1666 stride:64 eval_time:303472ms +final_int6_ttt_exact val_loss:0.28130167 val_bpb:0.16660270 diff --git a/records/track_10min_16mb/2026-03-26_LearnedMultiExpertGate_FrozenOracle_BackoffTTT_0.1663/log_seed42.txt b/records/track_10min_16mb/2026-03-26_LearnedMultiExpertGate_FrozenOracle_BackoffTTT_0.1663/log_seed42.txt new file mode 100644 index 000000000..4b5e93a0f --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_LearnedMultiExpertGate_FrozenOracle_BackoffTTT_0.1663/log_seed42.txt @@ -0,0 +1,113 @@ +W0326 07:34:38.155000 840294 site-packages/torch/distributed/run.py:851] +W0326 07:34:38.155000 840294 site-packages/torch/distributed/run.py:851] ***************************************** +W0326 07:34:38.155000 840294 site-packages/torch/distributed/run.py:851] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0326 07:34:38.155000 840294 site-packages/torch/distributed/run.py:851] ***************************************** +logs/seed42.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/root/parameter-golf/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/root/parameter-golf/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +mixed_precision: 68 int5 layers, 0 int6 layers (last 0 blocks) +model_params:33321571 +XSA:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] ws:8 gqa:8/8 +lr:embed=0.035 matrix=0.025 scalar=0.025 batch:786432 wall:600s seed:42 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +pre-compiling mixer loss path (dummy data, no training tokens)... +pre-compile done +prefilling n-gram tables from training shards (frozen oracle)... +prefilled 8,000,040,960 tokens in 18469ms (counted in wallclock) +step:0/20000 val_loss:6.9289 val_bpb:4.1037 train_time:18469ms step_avg:0.04ms +step:1/20000 train_loss:7.0803 train_time:20606ms step_avg:2137.63ms +step:2/20000 train_loss:8.7705 train_time:20699ms step_avg:1115.15ms +step:3/20000 train_loss:8.6532 train_time:20796ms step_avg:775.74ms +step:4/20000 train_loss:8.1563 train_time:20893ms step_avg:606.19ms +step:5/20000 train_loss:7.4648 train_time:20990ms step_avg:504.26ms +step:6/20000 train_loss:6.8669 train_time:21087ms step_avg:436.33ms +step:7/20000 train_loss:6.4211 train_time:21184ms step_avg:387.92ms +step:8/20000 train_loss:6.1468 train_time:21281ms step_avg:351.57ms +step:9/20000 train_loss:6.0173 train_time:21378ms step_avg:323.27ms +step:10/20000 train_loss:5.9627 train_time:21475ms step_avg:300.66ms +step:500/20000 train_loss:2.3853 train_time:70241ms step_avg:103.54ms +step:1000/20000 train_loss:2.2594 train_time:120208ms step_avg:101.74ms +step:1500/20000 train_loss:2.2095 train_time:170247ms step_avg:101.19ms +step:2000/20000 train_loss:2.0477 train_time:220371ms step_avg:100.95ms +step:2500/20000 train_loss:2.1440 train_time:270528ms step_avg:100.82ms +step:3000/20000 train_loss:2.1239 train_time:320687ms step_avg:100.74ms +step:3500/20000 train_loss:2.1278 train_time:370849ms step_avg:100.68ms +late_qat:enabled step:3849 scale:0.4997 +step:4000/20000 train_loss:1.9147 train_time:421823ms step_avg:100.84ms +step:4000/20000 val_loss:1.9914 val_bpb:1.1794 train_time:421828ms step_avg:100.84ms +step:4500/20000 train_loss:2.0577 train_time:473552ms step_avg:101.13ms +swa:start step:4900 +step:5000/20000 train_loss:2.0294 train_time:525430ms step_avg:101.39ms +step:5500/20000 train_loss:1.9394 train_time:577481ms step_avg:101.64ms +step:5540/20000 val_loss:1.9114 val_bpb:1.1320 train_time:581653ms step_avg:101.66ms +stopping_early: wallclock_cap train_time:581653ms step:5540/20000 +peak memory allocated: 26272 MiB reserved: 26550 MiB +ema:applying EMA weights (skipping diagnostic evals) +Serialized model: 130447629 bytes +Code size: 96235 bytes +pruning:8.0% magnitude pruning applied +Serialized model int6+zstd: 15667489 bytes +Total submission size int6+zstd: 15763724 bytes + ttt: pre-compiling forward+backward kernels... + ttt: pre-compile done +final_int6_sliding_window val_loss:1.9279 val_bpb:1.1418 stride:64 eval_time:86717ms +final_int6_sliding_window_exact val_loss:1.92788074 val_bpb:1.14179957 +TTT: epochs=1 lr=3e-05 freeze_first=1 chunk=1048576 opt=adamw +TTT temperature: 0.98 +PPM alpha: 0.85, Byte-weighted TTT: True + Logistic context mixer enabled: eta=0.02 +ttt:start chunks=60 chunk_tokens=1048576 windows=969057 stride=64 lr=3e-05 epochs=1 opt=adamw freeze_first=1 +ttt:params unfrozen=277003 frozen=33044568 + ttt_train [1] seqs=512 start_train... + ttt_train [1] epoch=1/1 batches=64 ... + step done ep=1 bs=0 loss=2.3078 + step done ep=1 bs=32 loss=2.1631 + ttt_chunk [1/60] bpb=1.152773 time=4.5s + ttt_train [2] seqs=512 start_train... + ttt_train [2] epoch=1/1 batches=64 ... + step done ep=1 bs=0 loss=2.2326 + step done ep=1 bs=32 loss=2.2696 + ttt_chunk [2/60] bpb=1.111877 time=9.2s + ttt_train [3] seqs=512 start_train... + ttt_train [3] epoch=1/1 batches=64 ... + step done ep=1 bs=0 loss=2.1722 + step done ep=1 bs=32 loss=2.1919 + ttt_chunk [3/60] bpb=0.952971 time=13.8s + ttt_chunk [4/60] bpb=0.822452 time=18.3s + ttt_chunk [5/60] bpb=0.711977 time=22.9s + ttt_chunk [11/60] bpb=0.422211 time=50.6s + ttt_chunk [21/60] bpb=0.281210 time=96.6s + ttt_chunk [31/60] bpb=0.227402 time=142.7s + ttt_chunk [41/60] bpb=0.196624 time=188.8s + ttt_chunk [51/60] bpb=0.177759 time=234.9s + ttt_chunk [60/60] bpb=0.166252 time=272.4s +ttt:done val_loss=0.280727 val_bpb=0.166262 elapsed=272.4s +expert_logit[neural]: mean=-5.9150 std=4.5781 min=-43.7500 max=20.5000 +expert_logit[ngram_2]: mean=-15.9250 std=2.6356 min=-41.0000 max=-2.8281 +expert_logit[ngram_3]: mean=-12.5280 std=3.0628 min=-39.7500 max=4.3125 +expert_logit[ngram_4]: mean=-9.6301 std=3.6537 min=-44.2500 max=22.3750 +expert_logit[ngram_5]: mean=-6.8144 std=4.0640 min=-45.2500 max=35.2500 +expert_logit[ngram_6]: mean=-3.8407 std=4.5330 min=-45.0000 max=46.0000 +expert_logit[ngram_7]: mean=8.3118 std=4.6968 min=-16.8750 max=42.0000 +final_int6_ttt val_loss:0.2807 val_bpb:0.1663 stride:64 eval_time:304551ms +final_int6_ttt_exact val_loss:0.28072723 val_bpb:0.16626248 diff --git a/records/track_10min_16mb/2026-03-26_LearnedMultiExpertGate_FrozenOracle_BackoffTTT_0.1663/requirements.txt b/records/track_10min_16mb/2026-03-26_LearnedMultiExpertGate_FrozenOracle_BackoffTTT_0.1663/requirements.txt new file mode 100644 index 000000000..2a4243049 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_LearnedMultiExpertGate_FrozenOracle_BackoffTTT_0.1663/requirements.txt @@ -0,0 +1,5 @@ +torch>=2.4.0 +numpy +sentencepiece +zstandard +flash-attn-hopper diff --git a/records/track_10min_16mb/2026-03-26_LearnedMultiExpertGate_FrozenOracle_BackoffTTT_0.1663/train_gpt.py b/records/track_10min_16mb/2026-03-26_LearnedMultiExpertGate_FrozenOracle_BackoffTTT_0.1663/train_gpt.py new file mode 100644 index 000000000..0091d89bc --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_LearnedMultiExpertGate_FrozenOracle_BackoffTTT_0.1663/train_gpt.py @@ -0,0 +1,1838 @@ +"""V27: CROWN-Q training + stride=64 + 4 TTT epochs.""" +from __future__ import annotations +import copy +import glob +import io +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard + _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +try: + from flash_attn_interface import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True +except ImportError: + try: + from flash_attn import flash_attn_func as flash_attn_3_func + _HAS_FA3 = True + except ImportError: + _HAS_FA3 = False + flash_attn_3_func = None + +class BackoffNgramMixer: + """Multi-order n-gram backoff with entropy-adaptive alpha. GPU-native.""" + + def __init__(self, vocab_size: int = 1024, device: str = 'cuda', eta: float = 0.1): + self.V = vocab_size + self.device = torch.device(device) + self.eta = eta + self.total_tokens = 0 + self.max_order = 7 + self.min_order = 2 + self.BUCKETS = 1_048_576 + self.primes = torch.tensor( + [36313, 27191, 51647, 81929, 131071, 174763, 233017], + dtype=torch.long, device=self.device, + ) + self.mask = self.BUCKETS - 1 + self.ctx_counts = [torch.zeros(self.BUCKETS, dtype=torch.int32, device=self.device) for _ in range(6)] + self.full_counts = [torch.zeros(self.BUCKETS, dtype=torch.int32, device=self.device) for _ in range(6)] + + @torch.no_grad() + def update(self, tokens): + if isinstance(tokens, torch.Tensor): + t = tokens.to(device=self.device, dtype=torch.long).reshape(-1) + else: + t = torch.tensor(tokens, dtype=torch.long, device=self.device) + n = t.numel() + if n == 0: + return + self.total_tokens += n + for oi, order in enumerate(range(self.min_order, self.max_order + 1)): + if n < order: + continue + cw = order - 1 + length = n - order + 1 + ctx_hash = torch.zeros(length, dtype=torch.long, device=self.device) + for k in range(cw): + ctx_hash.bitwise_xor_(t[k:k + length] * self.primes[k]) + ctx_key = ctx_hash & self.mask + full_key = (ctx_hash ^ (t[order - 1:order - 1 + length] * self.primes[cw])) & self.mask + ones = torch.ones(length, dtype=torch.int32, device=self.device) + self.ctx_counts[oi].scatter_add_(0, ctx_key, ones) + self.full_counts[oi].scatter_add_(0, full_key, ones) + + @torch.no_grad() + def _ngram_backoff_p(self, x_batch, y_batch, device=None): + bsz, slen = x_batch.shape + dev = x_batch.device + x = x_batch.long() + y = y_batch.long() + ngram_p = torch.full((bsz, slen), 1.0 / self.V, device=dev) + ngram_hit = torch.zeros(bsz, slen, dtype=torch.bool, device=dev) + order_p = torch.full((bsz, slen, 6), 1.0 / self.V, device=dev) + order_valid = torch.zeros(bsz, slen, 6, dtype=torch.bool, device=dev) + for oi_rev in range(5, -1, -1): + order = oi_rev + 2 + cw = order - 1 + if slen < cw: + continue + ctx_hash = torch.zeros(bsz, slen, dtype=torch.long, device=dev) + for k in range(cw): + shift = cw - 1 - k + if shift > 0: + ctx_hash[:, shift:].bitwise_xor_(x[:, :slen - shift] * self.primes[k]) + else: + ctx_hash.bitwise_xor_(x * self.primes[k]) + ctx_key = (ctx_hash & self.mask).long() + full_key = ((ctx_hash ^ (y * self.primes[cw])) & self.mask).long() + ctx_c = self.ctx_counts[oi_rev][ctx_key.reshape(-1)].float().reshape(bsz, slen) + full_c = self.full_counts[oi_rev][full_key.reshape(-1)].float().reshape(bsz, slen) + p = torch.minimum(full_c, ctx_c) / ctx_c.clamp(min=1.0) + p = p.clamp(0.0, 1.0) + valid_order = ctx_c >= 2 + if cw > 0: + valid_order[:, :cw] = False + valid_backoff = valid_order & (~ngram_hit) + ngram_p = torch.where(valid_backoff, p, ngram_p) + ngram_hit = ngram_hit | valid_backoff + order_p[..., oi_rev] = torch.where(valid_order, p, order_p[..., oi_rev]) + order_valid[..., oi_rev] = valid_order + return ngram_p, order_p, order_valid + + def mix_and_score(self, neural_logits, x_batch, y_batch, wlens, + alpha_override=None): + bsz, slen, V = neural_logits.shape + device = neural_logits.device + neural_lp = F.log_softmax(neural_logits, dim=-1) + neural_nll = -neural_lp.gather(2, y_batch.unsqueeze(2)).squeeze(2) + if self.total_tokens < 100: + return neural_nll, neural_nll + neural_p = neural_lp.gather(2, y_batch.unsqueeze(2)).squeeze(2).exp() + best_p, order_p, order_valid = self._ngram_backoff_p(x_batch, y_batch, device) + expert_p = torch.cat([neural_p.unsqueeze(-1), order_p], dim=-1) + valid_mask = torch.cat([ + torch.ones(bsz, slen, 1, device=device, dtype=torch.bool), + order_valid, + ], dim=-1) + gate_logits = alpha_override + gate_logits = gate_logits.masked_fill(~valid_mask, -1e9) + weights = F.softmax(gate_logits, dim=-1) + neural_floor = 0.05 + neural_w = neural_floor + (1.0 - neural_floor) * weights[..., :1] + other_w = (1.0 - neural_floor) * weights[..., 1:] + weights = torch.cat([neural_w, other_w], dim=-1) + mixed_p = (weights * expert_p).sum(dim=-1) + mixed_nll = -torch.log(mixed_p.clamp(min=1e-12)) + return mixed_nll, neural_nll + + def update_weights(self, expert_nll, wlens): + pass + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 8)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.5)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 32)) + int6_last_n = int(os.environ.get("INT6_LAST_N", 0)) # all int5 (saves ~300KB vs int6 for last 2 blocks) + ttt_temperature = float(os.environ.get("TTT_TEMPERATURE", 0.98)) # post-TTT temperature calibration + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "0"))) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 6144)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 11)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + dtg_enabled = bool(int(os.environ.get("DTG_ENABLED", "0"))) + late_qat_threshold = float(os.environ.get("LATE_QAT_THRESHOLD", 0.5)) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + prune_pct = float(os.environ.get("PRUNE_PCT", 0.03)) + mixer_head = os.environ.get("MIXER_HEAD", "multi") + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("▁"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + +def eval_val(args: Hyperparameters, model: nn.Module, rank: int, world_size: int, + device: torch.device, grad_accum_steps: int, val_tokens: Tensor, + base_bytes_lut: Tensor, has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, eval_seq_len: int | None = None) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale", + ).split(",") + if pattern +) +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_Q = 0.9999984 + +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + _soft_round_alpha: float = 1.0 # temperature for soft-round (annealed during training) + _use_soft_round: bool = False # enable soft-round QAT instead of STE + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._clip_range = 15 # default int5, set to 31 for int6 layers + + @staticmethod + def soft_round(y: Tensor, alpha: float) -> Tensor: + """Differentiable approximation to round() from Agustsson & Theis (NeurIPS 2020). + s_alpha(y) = floor(y) + 0.5 * tanh(alpha * r) / tanh(alpha/2) + 0.5 + where r = y - floor(y) - 0.5 (centered fractional part) + """ + fl = torch.floor(y) + r = y - fl - 0.5 + return fl + 0.5 * torch.tanh(alpha * r) / (math.tanh(alpha / 2) + 1e-10) + 0.5 + + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + cr = self._clip_range + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + if CastedLinear._use_soft_round: + # Soft-Round QAT: differentiable rounding with temperature annealing + w32 = self.weight.float() + row_clip = torch.quantile(w32.abs(), 0.9995, dim=1) + scale = (row_clip / float(cr)).clamp_min(1.0 / float(cr)) + w_scaled = w32 / scale[:, None] + w_rounded = CastedLinear.soft_round(w_scaled, CastedLinear._soft_round_alpha) + w_q = (torch.clamp(w_rounded, -(cr+1), cr) * scale[:, None]).to(x.dtype) + w = w_q # fully differentiable path + else: + # Original STE QAT + with torch.no_grad(): + w32 = self.weight.float() + row_clip = torch.quantile(w32.abs(), 0.9995, dim=1) + scale = (row_clip / float(cr)).clamp_min(1.0 / float(cr)) + w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -(cr+1), cr) * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, + rope_base: float, qk_gain_init: float): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False + + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + B, T, H, D = y.shape + Hkv = v.size(-2) + y_g = y.reshape(B, T, Hkv, H // Hkv, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + if _HAS_FA3: + y = flash_attn_3_func(q, k, v, causal=True).contiguous() + else: + y = F.scaled_dot_product_attention( + q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), + attn_mask=None, is_causal=True, + enable_gqa=(self.num_kv_heads != self.num_heads), + ).transpose(1, 2) + if self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y) + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class ValueEmbedding(nn.Module): + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + + def forward(self, x: Tensor) -> Tensor: + return self.proj(F.leaky_relu(self.fc(x), negative_slope=0.5).square()) + +class Block(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: int, + rope_base: float, qk_gain_init: float, layer_idx: int = 0, + ln_scale: bool = False, dtg: bool = False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 2.0) + else: + self.dtg_gate = None + + def forward(self, x: Tensor, x0: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, model_dim: int, num_heads: int, + num_kv_heads: int, mlp_mult: int, tie_embeddings: bool, tied_embed_init_std: float, + logit_softcap: float, rope_base: float, qk_gain_init: float, + bigram_vocab_size: int = 0, bigram_dim: int = 128, xsa_last_n: int = 0, + rope_dims: int = 0, ln_scale: bool = False, dtg: bool = False, + ve_enabled: bool = False, ve_dim: int = 128, ve_layers: str = "9,10", + mixer_head: str = "none", mixer_num_experts: int = 7): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.mixer_loss_weight = 0.1 + self.tok_emb = nn.Embedding(vocab_size, model_dim) + if mixer_head == "multi": + self.alpha_head = nn.Linear(model_dim, mixer_num_experts, bias=True) + nn.init.zeros_(self.alpha_head.weight) + nn.init.zeros_(self.alpha_head.bias) + with torch.no_grad(): + self.alpha_head.bias[0] = 2.0 + else: + self.alpha_head = None + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, layer_idx=i, ln_scale=ln_scale, dtg=dtg) + for i in range(num_layers) + ]) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + + def _backbone(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, v_embed=ve) + return self.final_norm(x) + + def _logits_from_hidden(self, h: Tensor) -> Tensor: + if self.tie_embeddings: + proj = F.linear(h, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + proj = self.lm_head(h) + return self.logit_softcap * torch.tanh(proj / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor, + ngram_best_p: Tensor | None = None, + ngram_order_p: Tensor | None = None, + ngram_order_valid: Tensor | None = None) -> Tensor: + h = self._backbone(input_ids) + h_flat = h.reshape(-1, h.size(-1)) + logits = self._logits_from_hidden(h_flat) + ce = F.cross_entropy(logits.float(), target_ids.reshape(-1), reduction="mean") + if self.alpha_head is not None: + has_ngram = ngram_best_p is not None or ngram_order_p is not None + if has_ngram: + raw = self.alpha_head(h_flat) + neural_lp = F.log_softmax(logits.float(), dim=-1) + neural_p = neural_lp.gather(1, target_ids.reshape(-1, 1)).squeeze(1).exp() + expert_p = torch.cat([neural_p.unsqueeze(-1), ngram_order_p.reshape(-1, 6)], dim=-1) + valid_mask = torch.cat([ + torch.ones(expert_p.size(0), 1, device=expert_p.device, dtype=torch.bool), + ngram_order_valid.reshape(-1, 6), + ], dim=-1) + gate_logits = raw.masked_fill(~valid_mask, -1e9) + weights = F.softmax(gate_logits, dim=-1) + neural_w = 0.05 + 0.95 * weights[:, :1] + other_w = 0.95 * weights[:, 1:] + weights = torch.cat([neural_w, other_w], dim=-1) + mixed_p = (weights * expert_p).sum(dim=-1) + mixer_loss = -torch.log(mixed_p.clamp(min=1e-12)).mean() + ce = ce + self.mixer_loss_weight * mixer_loss + else: + _ = self.alpha_head(h_flat.detach()) + return ce + + def forward_logits(self, input_ids: Tensor) -> Tensor: + h = self._backbone(input_ids) + return self._logits_from_hidden(h) + + def forward_logits_and_alpha(self, input_ids: Tensor) -> tuple[Tensor, Tensor | None]: + h = self._backbone(input_ids) + logits = self._logits_from_hidden(h) + if self.alpha_head is None: + return logits, None + raw = self.alpha_head(h.float()) + return logits, raw + +def eval_val_sliding(args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, eval_seq_len: int | None = None) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + last_full_start = max(total_tokens - seq_len, 0) + window_starts = list(range(0, last_full_start + 1, stride)) + if not window_starts or window_starts[-1] != last_full_start: + window_starts.append(last_full_start) + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = torch.compile(base_model.forward_logits, dynamic=False, fullgraph=True) + + # Pre-compile: dummy forward+backward with TTT shapes to warm the compile cache + if rank == 0: + print(" ttt: pre-compiling forward+backward kernels...", flush=True) + _dummy_x = torch.zeros(1, seq_len, dtype=torch.int64, device=device) + _dummy_y = torch.zeros(1, seq_len, dtype=torch.int64, device=device) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + _dummy_logits = base_model.forward_logits(_dummy_x) + _dummy_loss = F.cross_entropy(_dummy_logits.reshape(-1, _dummy_logits.size(-1)), _dummy_y.reshape(-1)) + _dummy_loss.backward() + base_model.zero_grad(set_to_none=True) + if rank == 0: + print(" ttt: pre-compile done", flush=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + +def eval_val_sliding_ttt( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, ttt_epochs: int = 3, ttt_lr: float = 0.001, + ttt_momentum: float = 0.9, ttt_freeze_blocks: int = 2, + batch_seqs: int = 32, eval_seq_len: int | None = None, + ttt_chunk_tokens: int = 32768, ttt_optimizer: str = "adamw", + ttt_temp: float = 1.0, + byte_weighted_ttt: bool = True, +) -> tuple[float, float]: + """Legal score-first TTT: score each chunk, then train on it. + Every token scored BEFORE any update that could use it.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = val_tokens.numel() - 1 + + # Initialize GPU-vectorized logistic context mixer + use_mixer = os.environ.get("USE_MIXER", "1") == "1" + mixer = BackoffNgramMixer( + vocab_size=val_tokens.to(torch.int32).max().item() + 1, + device=device, + eta=float(os.environ.get("MIXER_ETA", "0.1")), + ) if use_mixer else None + if use_mixer and rank == 0: + print(f" Logistic context mixer enabled: eta={mixer.eta}") + # Pre-compute all window starts + last_full_start = max(total_tokens - seq_len, 0) + window_starts = list(range(0, last_full_start + 1, stride)) + if not window_starts or window_starts[-1] != last_full_start: + window_starts.append(last_full_start) + + # Assign each window to a chunk based on scored token position + num_chunks = (total_tokens + ttt_chunk_tokens - 1) // ttt_chunk_tokens + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk_tokens, num_chunks - 1) + chunk_windows[ci].append(ws) + + if rank == 0: + print(f"ttt:start chunks={num_chunks} chunk_tokens={ttt_chunk_tokens} " + f"windows={len(window_starts)} stride={stride} " + f"lr={ttt_lr} epochs={ttt_epochs} opt={ttt_optimizer} " + f"freeze_first={ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + alpha_stats: list[Tensor] = [] + + # Freeze everything, then selectively unfreeze for TTT + num_blocks = len(base_model.blocks) + for p in base_model.parameters(): + p.requires_grad_(False) + ttt_params = [] + ttt_param_ids = set() + use_qttt = os.environ.get("QTTT", "0") == "1" + if use_qttt: + # qTTT: only unfreeze Q projections in last N blocks + norms + head + for i in range(max(0, num_blocks - ttt_freeze_blocks), num_blocks): + for name, p in base_model.blocks[i].named_parameters(): + if "c_q" in name: + p.requires_grad_(True) + ttt_params.append(p) + ttt_param_ids.add(id(p)) + else: + # Standard: unfreeze all params in last N blocks + for i in range(max(0, num_blocks - ttt_freeze_blocks), num_blocks): + for p in base_model.blocks[i].parameters(): + p.requires_grad_(True) + ttt_params.append(p) + ttt_param_ids.add(id(p)) + for name, p in base_model.named_parameters(): + if "norm" in name or "scale" in name or "lm_head" in name or "alpha_head" in name: + p.requires_grad_(True) + if id(p) not in ttt_param_ids: + ttt_params.append(p) + ttt_param_ids.add(id(p)) + + if rank == 0: + n_unfrozen = sum(p.numel() for p in ttt_params) + n_frozen = sum(p.numel() for p in base_model.parameters() if not p.requires_grad) + print(f"ttt:params unfrozen={n_unfrozen} frozen={n_frozen}") + + if ttt_optimizer == "adamw": + optimizer = torch.optim.AdamW(ttt_params, lr=ttt_lr, weight_decay=0.0, betas=(0.9, 0.999)) + else: + optimizer = torch.optim.SGD(ttt_params, lr=ttt_lr, momentum=ttt_momentum) + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + + # --- Phase 1: SCORE this chunk (inference_mode, no grad) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits, learned_alpha = base_model.forward_logits_and_alpha(x_batch) + if learned_alpha is not None: + alpha_stats.append(learned_alpha.detach().float().cpu().reshape(-1) + if learned_alpha.dim() <= 2 + else learned_alpha.detach().float().cpu().reshape(-1, learned_alpha.size(-1))) + logits_scaled = logits.float() / ttt_temp + + if ttt_temp != 1.0: + with torch.no_grad(): + probs_for_entropy = F.softmax(logits.float(), dim=-1) + token_entropy = -(probs_for_entropy * (probs_for_entropy + 1e-10).log()).sum(-1) + max_ent = math.log(logits.size(-1)) + adaptive_temp = 1.0 - (1.0 - ttt_temp) * (1.0 - token_entropy / max_ent) + adaptive_temp = adaptive_temp.clamp(min=0.9, max=1.05) + logits_scaled = logits.float() / adaptive_temp.unsqueeze(-1) + + if mixer is not None: + nll, expert_nll = mixer.mix_and_score( + logits_scaled, x_batch, y_batch, wlens, + alpha_override=learned_alpha, + ) + mixer.update_weights(expert_nll, wlens) + else: + nll = F.cross_entropy( + logits_scaled.reshape(-1, logits_scaled.size(-1)), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # In distributed eval, do not let any rank advance the cache until + # every rank has finished scoring this chunk. + if mixer is not None and dist.is_available() and dist.is_initialized(): + dist.barrier() + + # --- Update context mixer with scored chunk tokens (GPU-vectorized) --- + chunk_start_tok = ci * ttt_chunk_tokens + chunk_end_tok = min((ci + 1) * ttt_chunk_tokens, total_tokens) + if mixer is not None: + mixer.update(val_tokens[chunk_start_tok:chunk_end_tok + 1]) + + # --- Phase 2: TRAIN on this chunk (already scored = legal) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and ttt_epochs > 0: + chunk_start = ci * ttt_chunk_tokens + chunk_end = min((ci + 1) * ttt_chunk_tokens, total_tokens) + chunk_seqs = (chunk_end - chunk_start) // seq_len + if rank == 0 and ci < 3: + print(f" ttt_train [{ci+1}] seqs={chunk_seqs} start_train...", flush=True) + if chunk_seqs > 0: + # Cosine LR across chunks + adaptive scaling + cos_lr = ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg["lr"] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(ttt_epochs): + if rank == 0 and ci < 3: + print(f" ttt_train [{ci+1}] epoch={_ep+1}/{ttt_epochs} batches={my_chunk_seqs} ...", flush=True) + for bs in range(0, my_chunk_seqs, batch_seqs): + be = min(bs + batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + ttt_logits = base_model.forward_logits(x) + per_token_loss = F.cross_entropy( + ttt_logits.reshape(-1, ttt_logits.size(-1)), + y.reshape(-1), reduction='none' + ).reshape(y.shape) + if byte_weighted_ttt: + byte_weights = base_bytes_lut[y].float() + byte_weights = byte_weights + (has_leading_space_lut[y] & ~is_boundary_token_lut[x]).float() + ttt_loss = (per_token_loss * byte_weights).sum() / byte_weights.sum() + else: + ttt_loss = per_token_loss.mean() + ttt_loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, 1.0) + optimizer.step() + if rank == 0 and ci < 3: + print(f" step done ep={_ep+1} bs={bs} loss={ttt_loss.item():.4f}", flush=True) + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1 or ci < 5): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + print(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s", flush=True) + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + if rank == 0: + print(f"ttt:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + if alpha_stats: + all_alpha = torch.cat(alpha_stats, dim=0) + if all_alpha.dim() == 1: + _a = all_alpha if all_alpha.numel() <= 1_000_000 else all_alpha[torch.randperm(all_alpha.numel(), device=all_alpha.device)[:1_000_000]] + print(f"alpha_stats: mean={all_alpha.mean():.4f} std={all_alpha.std():.4f} " + f"min={all_alpha.min():.4f} max={all_alpha.max():.4f} " + f"p10={_a.quantile(0.1):.4f} p50={_a.quantile(0.5):.4f} " + f"p90={_a.quantile(0.9):.4f}") + else: + for ei in range(all_alpha.size(-1)): + col = all_alpha[:, ei] + label = "neural" if ei == 0 else f"ngram_{ei+1}" + print(f"expert_logit[{label}]: mean={col.mean():.4f} std={col.std():.4f} " + f"min={col.min():.4f} max={col.max():.4f}") + return val_loss, val_bpb + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" + +def quantize_int6_per_row(t: Tensor, clip_range: int = 15) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + + +def _get_layer_clip_range(name: str, num_layers: int, int6_last_n: int) -> int: + """Return clip_range based on which layer the param belongs to.""" + import re + m = re.search(r'blocks\.(\d+)\.', name) + if m: + layer_idx = int(m.group(1)) + if layer_idx >= num_layers - int6_last_n: + return 31 # int6 + return 15 # int5 + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + q, s = quantize_int6_per_row(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0(f"Python {sys.version} PyTorch {torch.__version__}", console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + CastedLinear._qat_enabled = args.qat_enabled + base_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, rope_dims=args.rope_dims, ln_scale=args.ln_scale, + dtg=args.dtg_enabled, ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + mixer_head=args.mixer_head, + ).to(device).bfloat16() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if base_model.alpha_head is not None: + base_model.alpha_head.float() + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=False) + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model + block_named_params = list(base_model.blocks.named_parameters()) + matrix_params = [ + p + for name, p in block_named_params + if p.ndim == 2 and not any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + optimizers.insert(1, optimizer_head) + if base_model.alpha_head is not None: + alpha_lr = args.scalar_lr + optimizer_alpha = torch.optim.AdamW( + [{"params": list(base_model.alpha_head.parameters()), "lr": alpha_lr, "base_lr": alpha_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizers.append(optimizer_alpha) + n_params = sum(p.numel() for p in base_model.parameters()) + # Set int6 clip_range for last N layers (mixed precision) + int6_start = args.num_layers - args.int6_last_n + for i, block in enumerate(base_model.blocks): + if i >= int6_start: + for m in block.modules(): + if isinstance(m, CastedLinear): + m._clip_range = 31 # int6 + if master_process: + int5_count = sum(1 for m in base_model.modules() if isinstance(m, CastedLinear) and m._clip_range == 15) + int6_count = sum(1 for m in base_model.modules() if isinstance(m, CastedLinear) and m._clip_range == 31) + log0(f"mixed_precision: {int5_count} int5 layers, {int6_count} int6 layers (last {args.int6_last_n} blocks)") + log0(f"model_params:{n_params}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:{xsa_layers} ws:{world_size} gqa:{args.num_heads}/{args.num_kv_heads}") + log0(f"lr:embed={token_lr} matrix={args.matrix_lr} scalar={args.scalar_lr} batch:{args.train_batch_tokens} wall:{args.max_wallclock_seconds:.0f}s seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + train_mixer = BackoffNgramMixer(vocab_size=args.vocab_size, device=str(device), eta=0.0) if base_model.alpha_head is not None else None + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + train_reserve_ms = 18000 + effective_train_ms = (max_wallclock_ms - train_reserve_ms) if max_wallclock_ms is not None else None + _prefill_offset_ms = 0.0 + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if effective_train_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = max(elapsed_ms - _prefill_offset_ms, 0.0) / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(effective_train_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + # TTT_ONLY mode: skip training, load saved model, run TTT eval + if os.environ.get("TTT_ONLY", "0") == "1": + log0("TTT_ONLY mode: skipping training, loading saved model...") + sd_cpu = {k: v.cpu() for k, v in torch.load("final_model.pt", map_location="cpu").items()} + if args.prune_pct > 0: + for k, v in sd_cpu.items(): + if v.ndim == 2 and v.numel() > 65536: + thresh = torch.quantile(v.abs().float(), args.prune_pct) + v[v.abs() < thresh] = 0.0 + log0(f"pruning:{args.prune_pct*100:.1f}% magnitude pruning applied") + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + mixer_head=args.mixer_head, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + if eval_model.alpha_head is not None: + eval_model.alpha_head.float() + eval_model.load_state_dict(deq_state, strict=True) + sw_seq_len = int(os.environ.get("EVAL_SEQ_LEN", str(effective_eval_seq_len))) + log0(f"TTT_ONLY: model loaded, starting TTT eval...") + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_epochs = int(os.environ.get("TTT_EPOCHS", "3")) + ttt_lr = float(os.environ.get("TTT_LR", "0.0005")) + ttt_freeze = int(os.environ.get("TTT_FREEZE_BLOCKS", "2")) + ttt_chunk = int(os.environ.get("TTT_CHUNK_TOKENS", "32768")) + ttt_opt = os.environ.get("TTT_OPTIMIZER", "adamw") + log0(f"TTT: epochs={ttt_epochs} lr={ttt_lr} freeze_first={ttt_freeze} chunk={ttt_chunk} opt={ttt_opt}") + ttt_temp = args.ttt_temperature + log0(f"TTT temperature: {ttt_temp}") + ttt_val_loss, ttt_val_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, ttt_epochs=ttt_epochs, ttt_lr=ttt_lr, + ttt_freeze_blocks=ttt_freeze, eval_seq_len=sw_seq_len, + ttt_chunk_tokens=ttt_chunk, ttt_optimizer=ttt_opt, + ttt_temp=ttt_temp, + byte_weighted_ttt=os.environ.get("BYTE_WEIGHTED_TTT", "1") == "1", + ) + torch.cuda.synchronize() + log0( + f"final_int6_ttt val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + log0(f"final_int6_ttt_exact val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f}") + if distributed: + dist.destroy_process_group() + return + + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + if train_mixer is not None: + log0("pre-compiling mixer loss path (dummy data, no training tokens)...") + _pc_seq = args.train_seq_len + _pc_batch = args.train_batch_tokens // (world_size * grad_accum_steps) // _pc_seq + _pc_x = torch.zeros(_pc_batch, _pc_seq, dtype=torch.int64, device=device) + _pc_y = torch.zeros(_pc_batch, _pc_seq, dtype=torch.int64, device=device) + _pc_bp = torch.full((_pc_batch, _pc_seq), 0.5, device=device) + _pc_op = torch.full((_pc_batch, _pc_seq, 6), 0.1, device=device) + _pc_ov = torch.ones(_pc_batch, _pc_seq, 6, dtype=torch.bool, device=device) + zero_grad_all() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + _pc_loss = model(_pc_x, _pc_y, _pc_bp, _pc_op, _pc_ov) + (_pc_loss * grad_scale).backward() + zero_grad_all() + del _pc_x, _pc_y, _pc_bp, _pc_op, _pc_ov, _pc_loss + torch.cuda.empty_cache() + log0("pre-compile done") + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = float(os.environ.get("EMA_DECAY", "0.997")) + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + if train_mixer is not None: + log0("prefilling n-gram tables from training shards (frozen oracle)...") + import glob as _glob + _PREFILL_CHUNK = 10_000_000 + for _shard in sorted(_glob.glob(args.train_files)): + _raw = np.fromfile(_shard, dtype=np.uint16) + for _off in range(0, len(_raw), _PREFILL_CHUNK): + _chunk = torch.from_numpy(_raw[_off:_off + _PREFILL_CHUNK].astype(np.int32)).to(device) + train_mixer.update(_chunk) + del _chunk + del _raw + torch.cuda.empty_cache() + torch.cuda.synchronize() + prefill_ms = 1000.0 * (time.perf_counter() - t0) + training_time_ms += prefill_ms + _prefill_offset_ms = prefill_ms + log0(f"prefilled {train_mixer.total_tokens:,} tokens in {prefill_ms:.0f}ms (counted in wallclock)") + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{max(training_time_ms - _prefill_offset_ms, 0.0) / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + # Anneal soft-round alpha based on QAT progress + if CastedLinear._use_soft_round and CastedLinear._qat_enabled: + qat_progress = max(0.0, 1.0 - scale / max(args.late_qat_threshold, 0.01)) + CastedLinear._soft_round_alpha = 1.0 + 15.0 * qat_progress + if args.late_qat_threshold > 0 and scale < args.late_qat_threshold and not CastedLinear._qat_enabled and step >= 50: + CastedLinear._qat_enabled = True + CastedLinear._use_soft_round = os.environ.get("SOFT_ROUND_QAT", "0") == "1" + if CastedLinear._use_soft_round and master_process: + log0(f"soft_round_qat:enabled initial_alpha=1.0") + log0(f"late_qat:enabled step:{step} scale:{scale:.4f}") + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + ngram_best_p, ngram_order_p, ngram_order_valid = None, None, None + if train_mixer is not None: + with torch.no_grad(): + best_p, order_p, order_valid = train_mixer._ngram_backoff_p(x, y, device) + ngram_best_p = best_p.detach() + ngram_order_p = order_p.detach() + ngram_order_valid = order_valid.detach() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + if ngram_best_p is not None: + loss = model(x, y, ngram_best_p, ngram_order_p, ngram_order_valid) + else: + loss = model(x, y) + # CROWN-Q: penalize quantization-sensitive weights during warmdown + crownq_lambda = float(os.environ.get("CROWN_Q_LAMBDA", "0.01")) + if CastedLinear._qat_enabled and crownq_lambda > 0: + cq_loss = torch.zeros((), device=device) + for m in base_model.modules(): + if isinstance(m, CastedLinear) and m.weight.ndim == 2: + w = m.weight.float() + cr = float(m._clip_range) + row_max = w.detach().abs().amax(dim=1) + delta = row_max / cr # quantization step size + cq_loss = cq_loss + (w.pow(2) * delta.pow(2).unsqueeze(1)).mean() + loss = loss + crownq_lambda * cq_loss / 12.0 + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + if args.swa_enabled and scale < 0.2 and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{max(approx_training_time_ms - _prefill_offset_ms, 0.0) / step:.2f}ms" + ) + reached_cap = effective_train_ms is not None and approx_training_time_ms >= effective_train_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + # Apply EMA weights directly (skip diagnostic evals to save ~5s of reserve) + log0("ema:applying EMA weights (skipping diagnostic evals)") + current_state = base_model.state_dict() + ema_sd = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(ema_sd, strict=True) + export_sd = base_model.state_dict() + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + if args.prune_pct > 0: + for k, v in sd_cpu.items(): + if v.ndim == 2 and v.numel() > 65536: + thresh = torch.quantile(v.abs().float(), args.prune_pct) + v[v.abs() < thresh] = 0.0 + if master_process: + log0(f"pruning:{args.prune_pct*100:.1f}% magnitude pruning applied") + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"}) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) if _COMPRESSOR == "zstd" else zlib.compress(quant_raw, 9) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + io.BytesIO(zstandard.ZstdDecompressor().decompress(quant_blob_disk) if _COMPRESSOR == "zstd" else zlib.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, ln_scale=args.ln_scale, dtg=args.dtg_enabled, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + mixer_head=args.mixer_head, + ).to(device).bfloat16() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + if eval_model.alpha_head is not None: + eval_model.alpha_head.float() + eval_model.load_state_dict(deq_state, strict=True) + sw_seq_len = int(os.environ.get("EVAL_SEQ_LEN", str(effective_eval_seq_len))) + if sw_seq_len != effective_eval_seq_len and rank == 0: + log0(f"Eval seq_len override: {effective_eval_seq_len} -> {sw_seq_len}") + if args.eval_stride > 0 and args.eval_stride < sw_seq_len and not os.environ.get("SKIP_SLIDING"): + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_epochs = int(os.environ.get("TTT_EPOCHS", "3")) + ttt_lr = float(os.environ.get("TTT_LR", "0.0005")) + ttt_freeze = int(os.environ.get("TTT_FREEZE_BLOCKS", "2")) + ttt_chunk = int(os.environ.get("TTT_CHUNK_TOKENS", "32768")) + ttt_opt = os.environ.get("TTT_OPTIMIZER", "adamw") + log0(f"TTT: epochs={ttt_epochs} lr={ttt_lr} freeze_first={ttt_freeze} chunk={ttt_chunk} opt={ttt_opt}") + ttt_temp = args.ttt_temperature + log0(f"TTT temperature: {ttt_temp}") + ttt_val_loss, ttt_val_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, ttt_epochs=ttt_epochs, ttt_lr=ttt_lr, + ttt_freeze_blocks=ttt_freeze, eval_seq_len=sw_seq_len, + ttt_chunk_tokens=ttt_chunk, ttt_optimizer=ttt_opt, + ttt_temp=ttt_temp, + byte_weighted_ttt=os.environ.get("BYTE_WEIGHTED_TTT", "1") == "1", + ) + torch.cuda.synchronize() + log0( + f"final_int6_ttt val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms" + ) + log0(f"final_int6_ttt_exact val_loss:{ttt_val_loss:.8f} val_bpb:{ttt_val_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main()