From 3b619c756cd8849f476c91635632f19f8c1b8e2f Mon Sep 17 00:00:00 2001 From: Christopher Lee McClendon Date: Mon, 23 Mar 2026 11:01:50 -0400 Subject: [PATCH] =?UTF-8?q?feat:=20LeakyReLU(0.5)=C2=B2=20+=20per-layer=20?= =?UTF-8?q?LR=20legal=20TTT=20(BPB=201.13872)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace ReLU² with LeakyReLU(0.5)² activation (-0.004 BPB pre-TTT) - Add per-layer LR groups: mlp.proj 3x, mlp.fc 0.5x for TTT - Add intra-chunk cosine LR schedule for TTT epochs - 3-seed validation: 1.13912, 1.14024, 1.13872 (mean 1.13936) - Score-first legal TTT with SGD momentum, 30 epochs, freeze-2 - Best seed (7): BPB 1.13872, artifact 15.36 MB --- .../README.md | 167 ++ .../submission.json | 28 + .../train.log | 1836 +++++++++++++++++ .../train_gpt.py | 1471 +++++++++++++ 4 files changed, 3502 insertions(+) create mode 100644 records/track_non_record_16mb/2026-03-23_11L_LeakyReLU_PerLayerLR_LegalTTT/README.md create mode 100644 records/track_non_record_16mb/2026-03-23_11L_LeakyReLU_PerLayerLR_LegalTTT/submission.json create mode 100644 records/track_non_record_16mb/2026-03-23_11L_LeakyReLU_PerLayerLR_LegalTTT/train.log create mode 100644 records/track_non_record_16mb/2026-03-23_11L_LeakyReLU_PerLayerLR_LegalTTT/train_gpt.py diff --git a/records/track_non_record_16mb/2026-03-23_11L_LeakyReLU_PerLayerLR_LegalTTT/README.md b/records/track_non_record_16mb/2026-03-23_11L_LeakyReLU_PerLayerLR_LegalTTT/README.md new file mode 100644 index 000000000..f609bfc26 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-23_11L_LeakyReLU_PerLayerLR_LegalTTT/README.md @@ -0,0 +1,167 @@ +# Depth Recurrence + LeakyReLU(0.5)² + Per-Layer LR Legal TTT (30 Epochs) + +**val_bpb = 1.13872** (best seed) | **3-seed mean: 1.13936 ± 0.0008** | Pre-TTT mean: 1.1574 | TTT gain: **−0.0182** | Artifact: 15.36 MB + +> Non-record unlimited-compute submission (trained on 4×A100-40GB, eval ~3690s on 1×A100). + +--- + +## Headline Result + +This submission integrates three techniques from recent PRs (#518, #481) with our legal score-first TTT recipe, achieving **BPB 1.13872** — improving on our prior 1.14252 (PR #526) by **−0.0038 BPB**. Validated across 3 seeds for reproducibility. + +| Seed | BPB | Δ vs PR #526 | +|------|-----|-------------| +| 1337 | 1.13912 | −0.00340 | +| 42 | 1.14024 | −0.00228 | +| **7** | **1.13872** | **−0.00380** | +| **Mean ± std** | **1.13936 ± 0.0008** | **−0.00316** | + +--- + +## What Changed vs PR #526 + +### Architecture: LeakyReLU(0.5)² Activation (from PR #518) + +Replace `relu(x)²` with `leaky_relu(x, 0.5)²` in the MLP. This preserves negative gradient flow, allowing the model to encode information in both positive and negative activations. The squaring still provides the non-linearity. + +- Pre-TTT BPB improvement: 1.1609 → 1.1574 mean (**−0.0035** across 3 seeds) +- Zero compute overhead, same parameter count +- This accounts for essentially all of the final BPB improvement. + +### TTT Recipe: Per-Layer LR + Intra-Chunk Cosine (from PR #481/#518) + +Two modifications to the TTT recipe, adopted from other PRs: + +1. **Per-layer LR groups**: `mlp.proj` gets 3× LR (higher quantization error), `mlp.fc` gets 0.5× LR. +2. **Intra-chunk cosine decay**: within each chunk's 30 TTT epochs, LR follows `0.5 × (1 + cos(π × step / total_steps))`. + +However, the TTT gain actually went from −0.0184 (PR #526) to −0.0182 (this PR), a **+0.0002 regression**. Without ablations isolating per-layer LR and intra-cosine from the LeakyReLU architecture change, we cannot confirm these TTT modifications help. They may be neutral or slightly negative on this architecture. + +--- + +## Architecture Summary + +| Component | Configuration | +|---|---| +| Layers | 11 logical (10 unique shared BlockCores) | +| Embedding dim | 512 | +| Heads | 8 (64 dim/head), 4 KV heads (GQA) | +| MLP | 3× expansion (1536), **LeakyReLU(0.5)²** activation | +| SmearGate | Learned token-mixing gate on input embeddings | +| Vocab | 1024 (SentencePiece BPE) | +| BigramHash | 2048 features, 128d | +| RoPE | Partial: 16/64 dims, NTK-aware scaling | +| Value Embeddings | 128d on layers 9–10, per-layer scale (init 0.1) | +| LN Scale | `1/√(layer+1)` depth scaling | +| XSA | Cross-sequence attention on last 4 layers | +| U-Net skips | Residual connections across layer pairs | +| Parameters | 24,634,452 total | + +## Training Details + +| Setting | Value | +|---|---| +| Hardware | 4×A100-40GB (NVIDIA) | +| Steps | 5,200 | +| Training wallclock | ~2,509s (~42 min) | +| Optimizer | Muon (hidden/attn) + Adam (embeddings/scalars) | +| SWA | Checkpoints from step 4,650 | +| Late QAT | Enabled at step 4,901 | +| Quantization | Int6 + zstd-22 | + +## TTT Protocol (Legal Score-First + Per-Layer LR) + +``` +for each 32K-token chunk: + 1. model.eval() + torch.inference_mode() + → Forward pass on chunk, accumulate NLL ← SCORE (graded) + 2. model.train() + → SGD(momentum=0.9), 30 epochs ← TRAIN (adaptation) + per-layer LR: mlp.proj 3x, mlp.fc 0.5x + intra-chunk cosine LR decay + inter-chunk cosine LR decay + 3. Advance to next chunk with updated weights +``` + +Every target token is scored exactly once, strictly before any gradient update that could benefit from it. The `torch.inference_mode()` context manager makes gradient leakage during scoring physically impossible. + +| TTT Setting | Value | +|---|---| +| Optimizer | SGD, momentum=0.9 | +| Base learning rate | 0.002 | +| mlp.proj LR mult | 3.0 | +| mlp.fc LR mult | 0.5 | +| Intra-chunk cosine | Enabled | +| Epochs per chunk | 30 | +| Chunk size | 32,768 tokens | +| Stride | 64 | +| Frozen blocks | First 2 (of 11) | +| Trainable params | 19,911,748 / 24,634,452 | +| Eval time | ~3,690s (1×A100) | + +## Quantization & Size + +| Component | Bytes | +|---|---| +| Model (int6 + zstd) | 15,283,215 | +| Code (train_gpt.py) | 74,030 | +| **Total** | **15,357,245** | +| Limit | 16,000,000 | +| Headroom | 642,755 (4.0%) | + +## Comparison to Prior Submissions + +| Metric | PR #456 (1ep) | PR #461 (3ep) | PR #526 (30ep) | This (30ep+) | Δ vs #526 | +|---|---|---|---|---|---| +| **val_bpb** | 1.15321 | 1.14458 | 1.14252 | **1.13872** | **−0.00380** | +| Pre-TTT BPB | 1.1600 | 1.1611 | 1.1609 | 1.1574 (mean) | −0.0035 | +| TTT gain | −0.0068 | −0.0165 | −0.0184 | **−0.0182** | +0.0002 | +| Activation | ReLU² | ReLU² | ReLU² | **LeakyReLU(0.5)²** | new | +| Per-layer LR | No | No | No | **Yes** | new | +| Intra-cosine | No | No | No | **Yes** | new | + +**Key insight**: The entire improvement comes from the better pre-TTT model (−0.0035 mean from LeakyReLU). Per-layer LR and intra-chunk cosine showed no measurable TTT improvement in this data — the TTT gain is −0.0182 vs −0.0184 in PR #526, a slight regression. These TTT modifications require further ablation to determine whether they help independently. + +## Credits + +This submission integrates work from many contributors to the parameter-golf competition: + +- **LeakyReLU(0.5)²** — PR #518 (sofiabod): −0.004 BPB pre-TTT architecture improvement +- **Per-layer LR for TTT** — PR #481 (mrdavtan): differential learning rates for quantization-error recovery +- **Intra-chunk cosine LR** — PR #518 (sofiabod): cosine decay within each chunk's TTT epochs +- **30-epoch legal TTT** — Our prior work (PR #526): SGD + momentum + freeze-2 +- **Score-first protocol** — Our prior work (PR #461): `torch.inference_mode()` during scoring +- **11L depth recurrence** — PR #455 / PR #442: shared BlockCores for weight-efficient depth +- **Partial RoPE, VE128, LN Scale** — PR #374 / PR #455: foundational architecture components +- **SmearGate, BigramHash, XSA** — Community contributions across multiple PRs +- **Muon optimizer** — PR #374 and descendants: Newton-Schulz orthogonal update for matrix params +- **SWA + Late QAT + int6/zstd** — Evolved across many PRs for quantization-aware training pipeline + +## Reproducibility + +```bash +# Environment: Python 3.10+, PyTorch 2.x with CUDA +# From the repo root: +RUN_ID=i39_leaky_perlr \ +NUM_LAYERS=11 \ +UNIQUE_LAYERS=10 \ +DATA_PATH=./data/datasets/fineweb10B_sp1024/ \ +TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \ +VOCAB_SIZE=1024 \ +MAX_WALLCLOCK_SECONDS=0 \ +ITERATIONS=5200 \ +VAL_LOSS_EVERY=500 \ +VE_ENABLED=1 VE_DIM=128 VE_LAYERS=9,10 \ +ROPE_DIMS=16 LN_SCALE=1 \ +BIGRAM_VOCAB_SIZE=2048 \ +XSA_LAST_N=4 EVAL_STRIDE=64 \ +MLP_ACTIVATION=leaky_relu_sq \ +TTT_ENABLED=1 TTT_LR=0.002 TTT_EPOCHS=30 \ +TTT_FREEZE_BLOCKS=2 TTT_BATCH_SEQS=32 TTT_MOMENTUM=0.9 \ +TTT_PERLAYER_LR=1 TTT_PROJ_LR_MULT=3.0 TTT_FC_LR_MULT=0.5 \ +TTT_INTRA_COSINE=1 \ +SEED=7 \ +torchrun --standalone --nproc_per_node=4 \ + records/track_non_record_16mb/2026-03-23_11L_LeakyReLU_PerLayerLR_LegalTTT/train_gpt.py +``` diff --git a/records/track_non_record_16mb/2026-03-23_11L_LeakyReLU_PerLayerLR_LegalTTT/submission.json b/records/track_non_record_16mb/2026-03-23_11L_LeakyReLU_PerLayerLR_LegalTTT/submission.json new file mode 100644 index 000000000..fd71d9ca4 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-23_11L_LeakyReLU_PerLayerLR_LegalTTT/submission.json @@ -0,0 +1,28 @@ +{ + "author": "Chris McClendon", + "github_id": "Christopher-Lee-McClendon", + "name": "11L LeakyReLU PerLayerLR LegalTTT 30ep", + "blurb": "11-layer depth-recurrence GPT with LeakyReLU(0.5)² (PR #518), per-layer LR for TTT (mlp.proj 3x, mlp.fc 0.5x, PR #481), intra-chunk cosine LR, VE128 (layers 9-10), Partial RoPE (16/64), legal score-first TTT (30ep SGD, freeze=2). 3-seed mean 1.13936 ± 0.0008. Trained on 4xA100.", + "date": "2026-03-23", + "track": "non-record-unlimited-compute-16mb", + "val_loss": 1.92267394, + "val_bpb": 1.13871882, + "pre_ttt_val_loss": 1.9533, + "pre_ttt_val_bpb": 1.1569, + "step_stop": 5200, + "wallclock_seconds": 2509, + "eval_time_seconds": 3690, + "bytes_total": 15357245, + "bytes_model_int6_zstd": 15283215, + "bytes_code": 74030, + "gpu": "4xA100-40GB", + "seed": 7, + "seeds_tested": 3, + "seed_results": { + "1337": 1.13912231, + "42": 1.14024348, + "7": 1.13871882 + }, + "seed_mean": 1.13936154, + "seed_std": 0.00078 +} diff --git a/records/track_non_record_16mb/2026-03-23_11L_LeakyReLU_PerLayerLR_LegalTTT/train.log b/records/track_non_record_16mb/2026-03-23_11L_LeakyReLU_PerLayerLR_LegalTTT/train.log new file mode 100644 index 000000000..30f29ec69 --- /dev/null +++ b/records/track_non_record_16mb/2026-03-23_11L_LeakyReLU_PerLayerLR_LegalTTT/train.log @@ -0,0 +1,1836 @@ +logs/i39_s3_55210276.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:24634452 unique_cores:10 +unique_layers:10 mlp_mult:3.0 +matrix_params:23691264 scalar_params:25684 +world_size:4 grad_accum_steps:2 +tie_embeddings:True embed_lr:0.035 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:5200 warmup_steps:20 max_wallclock_seconds:0.000 +seed:7 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/5200 val_loss:6.9310 val_bpb:4.1049 train_time:0ms step_avg:0.01ms +step:1/5200 train_loss:6.9324 train_time:508ms step_avg:508.15ms +step:2/5200 train_loss:8.7688 train_time:981ms step_avg:490.63ms +step:3/5200 train_loss:7.6271 train_time:1471ms step_avg:490.31ms +step:4/5200 train_loss:7.2939 train_time:1975ms step_avg:493.68ms +step:5/5200 train_loss:7.1451 train_time:2456ms step_avg:491.12ms +step:6/5200 train_loss:6.8672 train_time:2941ms step_avg:490.11ms +step:7/5200 train_loss:6.8486 train_time:3422ms step_avg:488.91ms +step:8/5200 train_loss:6.6979 train_time:3906ms step_avg:488.26ms +step:9/5200 train_loss:6.4365 train_time:4391ms step_avg:487.94ms +step:10/5200 train_loss:6.0993 train_time:4875ms step_avg:487.47ms +step:100/5200 train_loss:3.2377 train_time:47701ms step_avg:477.01ms +step:200/5200 train_loss:2.5230 train_time:95726ms step_avg:478.63ms +step:300/5200 train_loss:2.5249 train_time:143720ms step_avg:479.07ms +step:400/5200 train_loss:2.4072 train_time:191504ms step_avg:478.76ms +step:500/5200 train_loss:2.3540 train_time:239197ms step_avg:478.39ms +step:500/5200 val_loss:2.3450 val_bpb:1.3888 train_time:239208ms step_avg:478.42ms +step:600/5200 train_loss:2.3381 train_time:286861ms step_avg:478.10ms +step:700/5200 train_loss:2.3799 train_time:334521ms step_avg:477.89ms +step:800/5200 train_loss:2.2325 train_time:382130ms step_avg:477.66ms +step:900/5200 train_loss:2.1118 train_time:429963ms step_avg:477.74ms +step:1000/5200 train_loss:2.2651 train_time:477565ms step_avg:477.56ms +step:1000/5200 val_loss:2.2144 val_bpb:1.3115 train_time:477575ms step_avg:477.58ms +step:1100/5200 train_loss:2.2484 train_time:525338ms step_avg:477.58ms +step:1200/5200 train_loss:2.2613 train_time:573169ms step_avg:477.64ms +step:1300/5200 train_loss:2.2086 train_time:620829ms step_avg:477.56ms +step:1400/5200 train_loss:2.2295 train_time:668533ms step_avg:477.52ms +step:1500/5200 train_loss:2.1890 train_time:716279ms step_avg:477.52ms +step:1500/5200 val_loss:2.1736 val_bpb:1.2873 train_time:716290ms step_avg:477.53ms +step:1600/5200 train_loss:2.1213 train_time:763990ms step_avg:477.49ms +step:1700/5200 train_loss:2.1580 train_time:811810ms step_avg:477.54ms +step:1800/5200 train_loss:2.1265 train_time:859715ms step_avg:477.62ms +step:1900/5200 train_loss:2.1151 train_time:907551ms step_avg:477.66ms +step:2000/5200 train_loss:2.0170 train_time:955448ms step_avg:477.72ms +step:2000/5200 val_loss:2.1199 val_bpb:1.2556 train_time:955459ms step_avg:477.73ms +step:2100/5200 train_loss:2.0097 train_time:1003368ms step_avg:477.79ms +step:2200/5200 train_loss:2.1313 train_time:1051264ms step_avg:477.85ms +step:2300/5200 train_loss:2.0424 train_time:1099100ms step_avg:477.87ms +step:2400/5200 train_loss:2.0650 train_time:1146897ms step_avg:477.87ms +step:2500/5200 train_loss:2.1276 train_time:1194775ms step_avg:477.91ms +step:2500/5200 val_loss:2.0875 val_bpb:1.2363 train_time:1194785ms step_avg:477.91ms +step:2600/5200 train_loss:2.1213 train_time:1242753ms step_avg:477.98ms +step:2700/5200 train_loss:2.0191 train_time:1290742ms step_avg:478.05ms +step:2800/5200 train_loss:2.1568 train_time:1338726ms step_avg:478.12ms +step:2900/5200 train_loss:2.0460 train_time:1386511ms step_avg:478.11ms +step:3000/5200 train_loss:2.0776 train_time:1434331ms step_avg:478.11ms +step:3000/5200 val_loss:2.0614 val_bpb:1.2209 train_time:1434342ms step_avg:478.11ms +step:3100/5200 train_loss:2.0778 train_time:1482179ms step_avg:478.12ms +step:3200/5200 train_loss:2.1065 train_time:1529969ms step_avg:478.12ms +step:3300/5200 train_loss:2.0627 train_time:1577776ms step_avg:478.11ms +step:3400/5200 train_loss:2.0492 train_time:1625616ms step_avg:478.12ms +step:3500/5200 train_loss:2.1301 train_time:1673566ms step_avg:478.16ms +step:3500/5200 val_loss:2.0380 val_bpb:1.2070 train_time:1673577ms step_avg:478.16ms +step:3600/5200 train_loss:2.0416 train_time:1721464ms step_avg:478.18ms +step:3700/5200 train_loss:2.0425 train_time:1769201ms step_avg:478.16ms +step:3800/5200 train_loss:2.0298 train_time:1816932ms step_avg:478.14ms +step:3900/5200 train_loss:2.0388 train_time:1864678ms step_avg:478.12ms +step:4000/5200 train_loss:2.0838 train_time:1912429ms step_avg:478.11ms +step:4000/5200 val_loss:2.0154 val_bpb:1.1936 train_time:1912440ms step_avg:478.11ms +step:4100/5200 train_loss:2.0041 train_time:1960200ms step_avg:478.10ms +step:4200/5200 train_loss:2.0223 train_time:2008058ms step_avg:478.11ms +step:4300/5200 train_loss:1.9992 train_time:2055831ms step_avg:478.10ms +step:4400/5200 train_loss:1.9372 train_time:2103787ms step_avg:478.13ms +step:4500/5200 train_loss:2.0377 train_time:2151676ms step_avg:478.15ms +step:4500/5200 val_loss:1.9882 val_bpb:1.1775 train_time:2151687ms step_avg:478.15ms +step:4600/5200 train_loss:1.8825 train_time:2199385ms step_avg:478.13ms +swa:start step:4650 +step:4700/5200 train_loss:2.0725 train_time:2249289ms step_avg:478.57ms +step:4800/5200 train_loss:2.1830 train_time:2301203ms step_avg:479.42ms +step:4900/5200 train_loss:1.9485 train_time:2353078ms step_avg:480.22ms +late_qat:enabled step:4901 scale:0.0997 clip_range:31 +step:5000/5200 train_loss:1.9777 train_time:2405158ms step_avg:481.03ms +step:5000/5200 val_loss:1.9602 val_bpb:1.1610 train_time:2407189ms step_avg:481.44ms +step:5100/5200 train_loss:1.9852 train_time:2457292ms step_avg:481.82ms +step:5200/5200 train_loss:1.9884 train_time:2509281ms step_avg:482.55ms +step:5200/5200 val_loss:1.9533 val_bpb:1.1569 train_time:2511303ms step_avg:482.94ms +peak memory allocated: 20223 MiB reserved: 20350 MiB +swa:applying averaged 12 checkpoints +Serialized model: 96746619 bytes +Code size: 74030 bytes +Total submission size: 96820649 bytes +magnitude_pruning: frac=0.03 +=== Weight distribution diagnostics === + OUTLIER cores.0.attn.c_k.weight: max=2.8573 mean=0.1386 ratio=20.6 kurtosis=10.3 + OUTLIER cores.9.mlp.proj.weight: max=2.6011 mean=0.1012 ratio=25.7 kurtosis=0.6 +Serialized model int6+zstd: 15283215 bytes +Total submission size int6+zstd: 15357245 bytes +final_eval_mode:sliding_window_ttt stride:64 chunk_tokens:32768 +ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=30 freeze_blocks=2 +ttt_sliding:params unfrozen=19911748 frozen=4722704 +ttt_sliding:perlayer_lr proj=8(3.0x) fc=8(0.5x) other=78(1x) + ttt_chunk [1/1893] bpb=1.187163 time=2.1s + ttt_chunk [11/1893] bpb=1.143039 time=21.5s + ttt_chunk [21/1893] bpb=1.146252 time=40.9s + ttt_chunk [31/1893] bpb=1.149530 time=60.4s + ttt_chunk [41/1893] bpb=1.138514 time=80.0s + ttt_chunk [51/1893] bpb=1.135680 time=99.5s + ttt_chunk [61/1893] bpb=1.141447 time=119.0s + ttt_chunk [71/1893] bpb=1.138207 time=138.6s + ttt_chunk [81/1893] bpb=1.138302 time=158.1s + ttt_chunk [91/1893] bpb=1.137471 time=177.6s + ttt_chunk [101/1893] bpb=1.140623 time=197.1s + ttt_chunk [111/1893] bpb=1.141896 time=216.6s + ttt_chunk [121/1893] bpb=1.138736 time=236.1s + ttt_chunk [131/1893] bpb=1.138982 time=255.6s + ttt_chunk [141/1893] bpb=1.138685 time=275.1s + ttt_chunk [151/1893] bpb=1.142057 time=294.5s + ttt_chunk [161/1893] bpb=1.144021 time=314.0s + ttt_chunk [171/1893] bpb=1.144846 time=333.5s + ttt_chunk [181/1893] bpb=1.144914 time=353.0s + ttt_chunk [191/1893] bpb=1.148367 time=372.5s + ttt_chunk [201/1893] bpb=1.148661 time=392.0s + ttt_chunk [211/1893] bpb=1.146452 time=411.5s + ttt_chunk [221/1893] bpb=1.148524 time=431.0s + ttt_chunk [231/1893] bpb=1.147960 time=450.5s + ttt_chunk [241/1893] bpb=1.147804 time=470.0s + ttt_chunk [251/1893] bpb=1.146194 time=489.5s + ttt_chunk [261/1893] bpb=1.144612 time=509.0s + ttt_chunk [271/1893] bpb=1.143303 time=528.5s + ttt_chunk [281/1893] bpb=1.145809 time=548.0s + ttt_chunk [291/1893] bpb=1.146605 time=567.5s + ttt_chunk [301/1893] bpb=1.147232 time=586.9s + ttt_chunk [311/1893] bpb=1.148884 time=606.4s + ttt_chunk [321/1893] bpb=1.150381 time=625.9s + ttt_chunk [331/1893] bpb=1.150310 time=645.4s + ttt_chunk [341/1893] bpb=1.150673 time=664.9s + ttt_chunk [351/1893] bpb=1.151971 time=684.4s + ttt_chunk [361/1893] bpb=1.153347 time=703.9s + ttt_chunk [371/1893] bpb=1.152868 time=723.4s + ttt_chunk [381/1893] bpb=1.152787 time=742.9s + ttt_chunk [391/1893] bpb=1.152332 time=762.4s + ttt_chunk [401/1893] bpb=1.151009 time=781.9s + ttt_chunk [411/1893] bpb=1.149860 time=801.4s + ttt_chunk [421/1893] bpb=1.149308 time=820.9s + ttt_chunk [431/1893] bpb=1.150036 time=840.4s + ttt_chunk [441/1893] bpb=1.149868 time=859.9s + ttt_chunk [451/1893] bpb=1.149604 time=879.4s + ttt_chunk [461/1893] bpb=1.148892 time=898.9s + ttt_chunk [471/1893] bpb=1.148520 time=918.4s + ttt_chunk [481/1893] bpb=1.148299 time=937.9s + ttt_chunk [491/1893] bpb=1.147958 time=957.4s + ttt_chunk [501/1893] bpb=1.147417 time=976.9s + ttt_chunk [511/1893] bpb=1.146833 time=996.4s + ttt_chunk [521/1893] bpb=1.145968 time=1015.9s + ttt_chunk [531/1893] bpb=1.146059 time=1035.4s + ttt_chunk [541/1893] bpb=1.145938 time=1054.9s + ttt_chunk [551/1893] bpb=1.144728 time=1074.4s + ttt_chunk [561/1893] bpb=1.145250 time=1093.9s + ttt_chunk [571/1893] bpb=1.144542 time=1113.4s + ttt_chunk [581/1893] bpb=1.144000 time=1132.9s + ttt_chunk [591/1893] bpb=1.143356 time=1152.4s + ttt_chunk [601/1893] bpb=1.144041 time=1171.9s + ttt_chunk [611/1893] bpb=1.143643 time=1191.4s + ttt_chunk [621/1893] bpb=1.143586 time=1210.9s + ttt_chunk [631/1893] bpb=1.144017 time=1230.4s + ttt_chunk [641/1893] bpb=1.143800 time=1249.9s + ttt_chunk [651/1893] bpb=1.143793 time=1269.5s + ttt_chunk [661/1893] bpb=1.143682 time=1289.2s + ttt_chunk [671/1893] bpb=1.143335 time=1308.9s + ttt_chunk [681/1893] bpb=1.143586 time=1328.4s + ttt_chunk [691/1893] bpb=1.144292 time=1347.9s + ttt_chunk [701/1893] bpb=1.143508 time=1367.3s + ttt_chunk [711/1893] bpb=1.144083 time=1386.8s + ttt_chunk [721/1893] bpb=1.143702 time=1406.3s + ttt_chunk [731/1893] bpb=1.144159 time=1425.8s + ttt_chunk [741/1893] bpb=1.144113 time=1445.3s + ttt_chunk [751/1893] bpb=1.143717 time=1464.8s + ttt_chunk [761/1893] bpb=1.143564 time=1484.3s + ttt_chunk [771/1893] bpb=1.143343 time=1503.7s + ttt_chunk [781/1893] bpb=1.143891 time=1523.2s + ttt_chunk [791/1893] bpb=1.143571 time=1542.8s + ttt_chunk [801/1893] bpb=1.143625 time=1562.3s + ttt_chunk [811/1893] bpb=1.143156 time=1581.8s + ttt_chunk [821/1893] bpb=1.142978 time=1601.3s + ttt_chunk [831/1893] bpb=1.142559 time=1620.8s + ttt_chunk [841/1893] bpb=1.141989 time=1640.3s + ttt_chunk [851/1893] bpb=1.141929 time=1659.8s + ttt_chunk [861/1893] bpb=1.142074 time=1679.3s + ttt_chunk [871/1893] bpb=1.142141 time=1698.8s + ttt_chunk [881/1893] bpb=1.142206 time=1718.3s + ttt_chunk [891/1893] bpb=1.142029 time=1737.8s + ttt_chunk [901/1893] bpb=1.142009 time=1757.3s + ttt_chunk [911/1893] bpb=1.141995 time=1776.8s + ttt_chunk [921/1893] bpb=1.142371 time=1796.3s + ttt_chunk [931/1893] bpb=1.142192 time=1815.7s + ttt_chunk [941/1893] bpb=1.142079 time=1835.2s + ttt_chunk [951/1893] bpb=1.142101 time=1854.7s + ttt_chunk [961/1893] bpb=1.141856 time=1874.2s + ttt_chunk [971/1893] bpb=1.142610 time=1893.7s + ttt_chunk [981/1893] bpb=1.142748 time=1913.1s + ttt_chunk [991/1893] bpb=1.142665 time=1932.6s + ttt_chunk [1001/1893] bpb=1.142827 time=1952.1s + ttt_chunk [1011/1893] bpb=1.143142 time=1971.6s + ttt_chunk [1021/1893] bpb=1.143299 time=1991.0s + ttt_chunk [1031/1893] bpb=1.143846 time=2010.5s + ttt_chunk [1041/1893] bpb=1.143496 time=2030.0s + ttt_chunk [1051/1893] bpb=1.143206 time=2049.5s + ttt_chunk [1061/1893] bpb=1.143479 time=2069.0s + ttt_chunk [1071/1893] bpb=1.143976 time=2088.4s + ttt_chunk [1081/1893] bpb=1.143984 time=2107.9s + ttt_chunk [1091/1893] bpb=1.144386 time=2127.4s + ttt_chunk [1101/1893] bpb=1.144535 time=2146.9s + ttt_chunk [1111/1893] bpb=1.144298 time=2166.4s + ttt_chunk [1121/1893] bpb=1.144253 time=2185.8s + ttt_chunk [1131/1893] bpb=1.144132 time=2205.3s + ttt_chunk [1141/1893] bpb=1.143989 time=2225.0s + ttt_chunk [1151/1893] bpb=1.144045 time=2244.6s + ttt_chunk [1161/1893] bpb=1.143460 time=2264.2s + ttt_chunk [1171/1893] bpb=1.144007 time=2283.7s + ttt_chunk [1181/1893] bpb=1.143498 time=2303.2s + ttt_chunk [1191/1893] bpb=1.143213 time=2322.7s + ttt_chunk [1201/1893] bpb=1.143788 time=2342.2s + ttt_chunk [1211/1893] bpb=1.143207 time=2361.7s + ttt_chunk [1221/1893] bpb=1.142895 time=2381.2s + ttt_chunk [1231/1893] bpb=1.142789 time=2400.7s + ttt_chunk [1241/1893] bpb=1.142592 time=2420.2s + ttt_chunk [1251/1893] bpb=1.142346 time=2439.7s + ttt_chunk [1261/1893] bpb=1.142293 time=2459.2s + ttt_chunk [1271/1893] bpb=1.142108 time=2478.7s + ttt_chunk [1281/1893] bpb=1.141915 time=2498.2s + ttt_chunk [1291/1893] bpb=1.141746 time=2517.7s + ttt_chunk [1301/1893] bpb=1.141376 time=2537.2s + ttt_chunk [1311/1893] bpb=1.141054 time=2556.7s + ttt_chunk [1321/1893] bpb=1.140870 time=2576.2s + ttt_chunk [1331/1893] bpb=1.140776 time=2595.7s + ttt_chunk [1341/1893] bpb=1.140660 time=2615.2s + ttt_chunk [1351/1893] bpb=1.140614 time=2634.7s + ttt_chunk [1361/1893] bpb=1.140815 time=2654.2s + ttt_chunk [1371/1893] bpb=1.140663 time=2673.7s + ttt_chunk [1381/1893] bpb=1.140596 time=2693.2s + ttt_chunk [1391/1893] bpb=1.140057 time=2712.7s + ttt_chunk [1401/1893] bpb=1.140086 time=2732.2s + ttt_chunk [1411/1893] bpb=1.140112 time=2751.7s + ttt_chunk [1421/1893] bpb=1.140366 time=2771.2s + ttt_chunk [1431/1893] bpb=1.140222 time=2790.7s + ttt_chunk [1441/1893] bpb=1.140904 time=2810.2s + ttt_chunk [1451/1893] bpb=1.141023 time=2829.7s + ttt_chunk [1461/1893] bpb=1.140755 time=2849.1s + ttt_chunk [1471/1893] bpb=1.141662 time=2868.6s + ttt_chunk [1481/1893] bpb=1.141442 time=2888.1s + ttt_chunk [1491/1893] bpb=1.141442 time=2907.6s + ttt_chunk [1501/1893] bpb=1.141617 time=2927.0s + ttt_chunk [1511/1893] bpb=1.141711 time=2946.5s + ttt_chunk [1521/1893] bpb=1.141739 time=2966.0s + ttt_chunk [1531/1893] bpb=1.141564 time=2985.5s + ttt_chunk [1541/1893] bpb=1.141528 time=3004.9s + ttt_chunk [1551/1893] bpb=1.141869 time=3024.4s + ttt_chunk [1561/1893] bpb=1.142013 time=3043.9s + ttt_chunk [1571/1893] bpb=1.142142 time=3063.4s + ttt_chunk [1581/1893] bpb=1.142261 time=3082.8s + ttt_chunk [1591/1893] bpb=1.142195 time=3102.3s + ttt_chunk [1601/1893] bpb=1.142366 time=3121.8s + ttt_chunk [1611/1893] bpb=1.142436 time=3141.3s + ttt_chunk [1621/1893] bpb=1.142290 time=3160.7s + ttt_chunk [1631/1893] bpb=1.142448 time=3180.2s + ttt_chunk [1641/1893] bpb=1.142320 time=3199.7s + ttt_chunk [1651/1893] bpb=1.142242 time=3219.2s + ttt_chunk [1661/1893] bpb=1.142133 time=3238.6s + ttt_chunk [1671/1893] bpb=1.142496 time=3258.1s + ttt_chunk [1681/1893] bpb=1.142751 time=3277.6s + ttt_chunk [1691/1893] bpb=1.142706 time=3297.1s + ttt_chunk [1701/1893] bpb=1.142696 time=3316.5s + ttt_chunk [1711/1893] bpb=1.142536 time=3336.0s + ttt_chunk [1721/1893] bpb=1.142390 time=3355.5s + ttt_chunk [1731/1893] bpb=1.142342 time=3375.0s + ttt_chunk [1741/1893] bpb=1.142136 time=3394.5s + ttt_chunk [1751/1893] bpb=1.141977 time=3414.0s + ttt_chunk [1761/1893] bpb=1.142060 time=3433.5s + ttt_chunk [1771/1893] bpb=1.141994 time=3453.0s + ttt_chunk [1781/1893] bpb=1.141960 time=3472.5s + ttt_chunk [1791/1893] bpb=1.141546 time=3492.0s + ttt_chunk [1801/1893] bpb=1.141545 time=3511.5s + ttt_chunk [1811/1893] bpb=1.141372 time=3531.0s + ttt_chunk [1821/1893] bpb=1.141395 time=3550.5s + ttt_chunk [1831/1893] bpb=1.141006 time=3570.0s + ttt_chunk [1841/1893] bpb=1.141051 time=3589.5s + ttt_chunk [1851/1893] bpb=1.140835 time=3609.0s + ttt_chunk [1861/1893] bpb=1.140353 time=3628.5s + ttt_chunk [1871/1893] bpb=1.140201 time=3648.0s + ttt_chunk [1881/1893] bpb=1.139810 time=3667.5s + ttt_chunk [1891/1893] bpb=1.139650 time=3687.0s + ttt_chunk [1893/1893] bpb=1.139672 time=3689.2s +ttt_sliding:done val_loss=1.922674 val_bpb=1.138719 elapsed=3689.3s +final_int6_roundtrip val_loss:1.9227 val_bpb:1.1387 eval_time:3689782ms +final_int6_roundtrip_exact val_loss:1.92267394 val_bpb:1.13871882 +or.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 8192: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if any(p in name for p in FP16_KEEP_NAME_PATTERNS): + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if force_int5 else (15 if cat == "mlp" else 31) + q, s = quantize_intN_per_row(t, clip_range=clip, gptq_lite=gptq_lite) + bits = {15: 5, 31: 6, 63: 7}.get(clip, 6) + result[name + ".q"] = q; result[name + ".scale"] = s + meta[name] = {"type": f"int{bits}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q; result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + _qat_clip_range: int = 31 + + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + cr = CastedLinear._qat_clip_range + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale_q = (row_max / float(cr)).clamp_min(1.0 / float(cr)) + w_q = (torch.clamp(torch.round(w32 / scale_q[:, None]), -(cr + 1), cr) * scale_q[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS)) \ + and param.dtype != torch.float32: + param.data = param.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, + rope_dims: int = 0): + super().__init__() + self.dim, self.base = dim, base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + rd = self.rope_dims + inv_freq = 1.0 / (base ** (torch.arange(0, rd, 2, dtype=torch.float32) / rd)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if (self._cos_cached is None or self._sin_cached is None + or self._seq_len_cached != seq_len or self._cos_cached.device != device): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, + rope_base: float, qk_gain_init: float, xsa_enabled: bool = False, + rope_dims: int = 0): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads, self.num_kv_heads = num_heads, num_kv_heads + self.head_dim = dim // num_heads + self.xsa_enabled = xsa_enabled + self.rope_dims = rope_dims + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024, + rope_dims=rope_dims) + + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + # Transpose to [B, H, T, D] for SDPA + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + if _IS_AMPERE_PLUS and self.num_kv_heads != self.num_heads: + y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True, enable_gqa=True) + else: + if self.num_kv_heads != self.num_heads: + repeats = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(repeats, dim=1) + v_for_sdpa = v.repeat_interleave(repeats, dim=1) + else: + v_for_sdpa = v + y = F.scaled_dot_product_attention(q, k, v_for_sdpa, attn_mask=None, is_causal=True) + if self.xsa_enabled: + group_size = self.num_heads // self.num_kv_heads + y_t = y.transpose(1, 2) + y_grouped = y_t.reshape(bsz, seqlen, self.num_kv_heads, group_size, self.head_dim) + vn = F.normalize(v.transpose(1, 2).unsqueeze(3), dim=-1) + dot_prod = (y_grouped * vn).sum(dim=-1, keepdim=True) + y = (y_grouped - dot_prod * vn).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float, activation: str = "relu_sq"): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + self.activation = activation + + def forward(self, x: Tensor) -> Tensor: + if self.activation == "leaky_relu_sq": + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + else: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class ValueEmbedding(nn.Module): + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class BlockCore(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, + rope_base: float, qk_gain_init: float, + xsa_enabled: bool = False, mlp_activation: str = "relu_sq", + rope_dims: int = 0): + super().__init__() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + xsa_enabled=xsa_enabled, rope_dims=rope_dims) + self.mlp = MLP(dim, mlp_mult, activation=mlp_activation) + +class Block(nn.Module): + def __init__(self, dim: int, layer_idx: int = 0, ln_scale: bool = False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor, core: BlockCore, + v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * core.attn( + self.attn_norm(x) * self.ln_scale_factor, v_embed=v_embed) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * core.mlp( + self.mlp_norm(x) * self.ln_scale_factor) + return x + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, model_dim: int, num_heads: int, + num_kv_heads: int, mlp_mult: float, tie_embeddings: bool, + tied_embed_init_std: float, logit_softcap: float, rope_base: float, + qk_gain_init: float, bigram_vocab_size: int = 0, bigram_dim: int = 128, + unique_layers: int = 0, xsa_last_n: int = 0, mlp_activation: str = "relu_sq", + rope_dims: int = 0, ln_scale: bool = False, + ve_enabled: bool = False, ve_dim: int = 128, ve_layers: str = "9,10"): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) \ + if bigram_vocab_size > 0 else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter( + torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.smear = SmearGate(model_dim) + n_cores = unique_layers if (0 < unique_layers < num_layers) else num_layers + xsa_start = max(0, n_cores - xsa_last_n) if xsa_last_n > 0 else n_cores + self.cores = nn.ModuleList([ + BlockCore(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, xsa_enabled=(i >= xsa_start), + mlp_activation=mlp_activation, rope_dims=rope_dims) + for i in range(n_cores) + ]) + self.blocks = nn.ModuleList([ + Block(model_dim, layer_idx=i, ln_scale=ln_scale) + for i in range(num_layers) + ]) + self._core_indices = [i % n_cores for i in range(num_layers)] + if n_cores < num_layers: + from collections import Counter + uses = Counter(self._core_indices) + for core_idx, core in enumerate(self.cores): + n_uses = uses[core_idx] + if n_uses > 1: + scale = 1.0 / n_uses + for p in core.parameters(): + p.register_hook(lambda grad, s=scale: grad * s) + # Value Embedding (VE128) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim = num_kv_heads * (model_dim // num_heads) + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices]) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, CastedLinear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def _get_ve(self, layer_idx: int, input_ids: Tensor, + ve_cache: dict | None = None) -> Tensor | None: + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + + def _forward_body(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, self.cores[self._core_indices[i]], v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + idx = self.num_encoder_layers + i + ve = self._get_ve(idx, input_ids, ve_cache) + x = self.blocks[idx](x, x0, self.cores[self._core_indices[idx]], v_embed=ve) + return self.final_norm(x) + + def _logits(self, x: Tensor) -> Tensor: + if self.tie_embeddings: + raw = F.linear(x, self.tok_emb.weight) + else: + raw = self.lm_head(x) + return self.logit_softcap * torch.tanh(raw / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self._forward_body(input_ids) + x = x.reshape(-1, x.size(-1)) + logits = self._logits(x) + return F.cross_entropy(logits.float(), target_ids.reshape(-1), reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + return self._logits(self._forward_body(input_ids)) + +def eval_val_sliding( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=_HALF_DTYPE): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + rl = (loss_sum / token_count).item() if token_count.item() > 0 else 0.0 + rbpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) if token_count.item() > 0 else 0.0 + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} " + f"windows running_bpb={rbpb:.6f}", flush=True) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + base_model.train() + return val_loss, val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + +def eval_val_sliding_ttt( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Legal score-first TTT: score each chunk with sliding windows, then train on it.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Pre-compute all window starts (same as eval_val_sliding) + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Assign each window to a chunk based on the first token it scores + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + + # BPB accumulators + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Setup TTT optimizer (SGD + momentum for the legal score-first TTT pass) + n_blocks = len(base_model.blocks) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, n_blocks))) + frozen_core_ids = set(base_model._core_indices[i] for i in frozen_block_ids) if frozen_block_ids else set() + + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True; break + if not freeze: + for ci_core in frozen_core_ids: + if f"cores.{ci_core}." in name: + freeze = True; break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + # Per-layer LR groups: mlp.proj gets higher LR (high quant error), mlp.fc gets lower LR + if args.ttt_perlayer_lr: + proj_params, fc_params, other_params = [], [], [] + for name, p in base_model.named_parameters(): + if not p.requires_grad: + continue + if "mlp.proj" in name: + proj_params.append(p) + elif "mlp.fc" in name: + fc_params.append(p) + else: + other_params.append(p) + param_groups = [ + {"params": proj_params, "lr": args.ttt_lr * args.ttt_proj_lr_mult}, + {"params": fc_params, "lr": args.ttt_lr * args.ttt_fc_lr_mult}, + {"params": other_params, "lr": args.ttt_lr}, + ] + log0(f"ttt_sliding:perlayer_lr proj={len(proj_params)}({args.ttt_proj_lr_mult}x) " + f"fc={len(fc_params)}({args.ttt_fc_lr_mult}x) other={len(other_params)}(1x)") + else: + param_groups = [{"params": ttt_params, "lr": args.ttt_lr}] + + optimizer = torch.optim.SGD(param_groups, momentum=args.ttt_momentum) + # Store initial per-group LRs for cosine scheduling + for pg in optimizer.param_groups: + pg['_base_lr'] = pg['lr'] + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + # --- Phase 1: SCORE this chunk's windows (sliding window eval) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=_HALF_DTYPE): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk's tokens (already scored = legal) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + # Cosine decay across chunks (inter-chunk schedule) + inter_cos = 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + + # Store base LR for each param group (for intra-chunk cosine) + base_lrs = [pg['_base_lr'] for pg in optimizer.param_groups] + + # Partition training seqs across ranks + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + # Count steps per epoch for intra-chunk cosine + steps_per_epoch = max(1, (my_chunk_seqs + args.ttt_batch_seqs - 1) // args.ttt_batch_seqs) + total_chunk_steps = args.ttt_epochs * steps_per_epoch + chunk_step = 0 + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + # Intra-chunk cosine LR schedule (within each chunk's TTT epochs) + if args.ttt_intra_cosine and total_chunk_steps > 1: + intra_cos = 0.5 * (1.0 + math.cos(math.pi * chunk_step / total_chunk_steps)) + else: + intra_cos = 1.0 + for i, pg in enumerate(optimizer.param_groups): + pg['lr'] = base_lrs[i] * inter_cos * intra_cos + chunk_step += 1 + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + actual_be = my_seq_s + be + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + actual_be * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + if args.ttt_grad_clip > 0: + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + # Progress log + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + # Final all-reduce + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Restore state + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if _IS_AMPERE_PLUS: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + if _IS_AMPERE_PLUS: + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(False); enable_math_sdp(False) + else: + enable_cudnn_sdp(False); enable_flash_sdp(False) + enable_mem_efficient_sdp(True); enable_math_sdp(True) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: return + if console: print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, + text=True, check=False).stdout, console=False) + log0("=" * 100, console=False) + random.seed(args.seed); np.random.seed(args.seed) + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"VOCAB_SIZE={args.vocab_size} != tokenizer vocab_size={int(sp.vocab_size())}") + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + base_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, unique_layers=args.unique_layers, + xsa_last_n=args.xsa_last_n, mlp_activation=args.mlp_activation, + rope_dims=args.rope_dims, ln_scale=args.ln_scale, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + ).to(device).to(_HALF_DTYPE) + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if _IS_AMPERE_PLUS: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + else: + log0("skipping torch.compile on non-Ampere GPU") + compiled_model = base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], + broadcast_buffers=False) if distributed else compiled_model + + matrix_params, scalar_params = [], [] + for name, p in base_model.cores.named_parameters(): + if p.ndim == 2 and not any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS): + matrix_params.append(p) + else: + scalar_params.append(p) + for name, p in base_model.blocks.named_parameters(): + if p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS): + scalar_params.append(p) + elif p.ndim == 2 and not any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS): + matrix_params.append(p) + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], + "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], + "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + + optimizer_tok = torch.optim.AdamW( + tok_params, betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.adam_wd, fused=True) + optimizer_muon = Muon( + matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, weight_decay=0.04) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.adam_wd, fused=True) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} unique_cores:{len(base_model.cores)}") + log0(f"unique_layers:{args.unique_layers} mlp_mult:{args.mlp_mult}") + log0(f"matrix_params:{sum(p.numel() for p in matrix_params)} " + f"scalar_params:{sum(p.numel() for p in scalar_params)}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}") + log0(f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}") + log0(f"seed:{args.seed}") + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + if warmdown_start <= step < args.iterations: + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) + return 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = { + name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=_HALF_DTYPE, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + + while True: + last_step = step == args.iterations or ( + stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms") + torch.cuda.synchronize() + t0 = time.perf_counter() + if (args.quant_eval_every > 0 and should_validate + and lr_mul(step, training_time_ms) < args.swa_start_frac + and step % args.quant_eval_every == 0 and master_process): + with torch.no_grad(): + sd_snap = {k: v.detach().cpu().clone() for k, v in base_model.state_dict().items()} + qr, qm = mixed_quantize_int6(sd_snap, {"mlp", "attn", "bigram"}) + deq = dequantize_mixed_int6(qr, qm, sd_snap) + orig_sd = base_model.state_dict() + base_model.load_state_dict( + {k: v.to(dtype=orig_sd[k].dtype, device=orig_sd[k].device) for k, v in deq.items()}, + strict=True) + _, q_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"quant_gap step:{step} float_bpb:{val_bpb:.4f} int6_bpb:{q_bpb:.4f} gap:{q_bpb - val_bpb:.4f}") + base_model.load_state_dict(orig_sd, strict=True) + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}") + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=_HALF_DTYPE, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + if args.late_qat and scale < args.qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + CastedLinear._qat_clip_range = 15 if args.all_int5 else 31 + log0(f"late_qat:enabled step:{step} scale:{scale:.4f} clip_range:{CastedLinear._qat_clip_range}") + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + + should_log_train = (args.train_log_every > 0 and + (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)) + if should_log_train: + log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms") + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB") + + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = {name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items()} + base_model.load_state_dict(avg_state, strict=True) + + _model_pt = f"final_model_{args.run_id}.pt" + _model_ptz = f"final_model_{args.run_id}.int8.ptz" + if master_process: + torch.save(base_model.state_dict(), _model_pt) + model_bytes = os.path.getsize(_model_pt) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), args.prune_frac) + param.masked_fill_(param.abs() < threshold, 0.0) + log0(f"magnitude_pruning: frac={args.prune_frac}") + + if master_process: + log0("=== Weight distribution diagnostics ===") + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 8192: + t = param.detach().float() + absmax = t.abs().max().item() + absmean = t.abs().mean().item() + kurtosis = ((t - t.mean()) / t.std()).pow(4).mean().item() - 3.0 + if kurtosis > 5.0 or absmax / absmean > 20.0: + log0(f" OUTLIER {name}: max={absmax:.4f} mean={absmean:.4f} " + f"ratio={absmax/absmean:.1f} kurtosis={kurtosis:.1f}") + + CastedLinear._qat_enabled = False + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}, + gptq_lite=args.gptq_lite, + force_int5=args.all_int5) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) + if master_process: + with open(_model_ptz, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(_model_ptz) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(_model_ptz, "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + base_model.load_state_dict(deq_state, strict=True) + + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.ttt_enabled and args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window_ttt stride:{args.eval_stride} " + f"chunk_tokens:{args.ttt_chunk_tokens}") + q_val_loss, q_val_bpb = eval_val_sliding_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, log0=log0) + elif args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + torch.cuda.synchronize() + log0(f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms") + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + +if __name__ == "__main__": + main() + +==================================================================================================== +Running Python 3.12.7 | packaged by conda-forge | (main, Oct 4 2024, 16:05:46) [GCC 13.3.0] +Running PyTorch 2.10.0+cu128 +Mon Mar 23 09:00:05 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 570.172.08 Driver Version: 570.172.08 CUDA Version: 12.8 | +|-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA A100-PCIE-40GB On | 00000000:17:00.0 Off | 0 | +| N/A 34C P0 47W / 250W | 667MiB / 40960MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 1 NVIDIA A100-PCIE-40GB On | 00000000:65:00.0 Off | 0 | +| N/A 35C P0 46W / 250W | 667MiB / 40960MiB | 10% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 2 NVIDIA A100-PCIE-40GB On | 00000000:CA:00.0 Off | 0 | +| N/A 34C P0 49W / 250W | 667MiB / 40960MiB | 0% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ +| 3 NVIDIA A100-PCIE-40GB On | 00000000:E3:00.0 Off | 0 | +| N/A 34C P0 46W / 250W | 667MiB / 40960MiB | 10% Default | +| | | Disabled | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| 0 N/A N/A 661034 C ...ameter_golf/.venv/bin/python3 658MiB | +| 1 N/A N/A 661035 C ...ameter_golf/.venv/bin/python3 658MiB | +| 2 N/A N/A 661036 C ...ameter_golf/.venv/bin/python3 658MiB | +| 3 N/A N/A 661037 C ...ameter_golf/.venv/bin/python3 658MiB | ++-----------------------------------------------------------------------------------------+ + +==================================================================================================== +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/hpfs/scratch/gpfs/mcclec07/code/parameter_golf/repo/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:24634452 unique_cores:10 +unique_layers:10 mlp_mult:3.0 +matrix_params:23691264 scalar_params:25684 +world_size:4 grad_accum_steps:2 +tie_embeddings:True embed_lr:0.035 matrix_lr:0.025 scalar_lr:0.025 +train_batch_tokens:786432 train_seq_len:2048 iterations:5200 warmup_steps:20 max_wallclock_seconds:0.000 +seed:7 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/5200 val_loss:6.9310 val_bpb:4.1049 train_time:0ms step_avg:0.01ms +step:1/5200 train_loss:6.9324 train_time:508ms step_avg:508.15ms +step:2/5200 train_loss:8.7688 train_time:981ms step_avg:490.63ms +step:3/5200 train_loss:7.6271 train_time:1471ms step_avg:490.31ms +step:4/5200 train_loss:7.2939 train_time:1975ms step_avg:493.68ms +step:5/5200 train_loss:7.1451 train_time:2456ms step_avg:491.12ms +step:6/5200 train_loss:6.8672 train_time:2941ms step_avg:490.11ms +step:7/5200 train_loss:6.8486 train_time:3422ms step_avg:488.91ms +step:8/5200 train_loss:6.6979 train_time:3906ms step_avg:488.26ms +step:9/5200 train_loss:6.4365 train_time:4391ms step_avg:487.94ms +step:10/5200 train_loss:6.0993 train_time:4875ms step_avg:487.47ms +step:100/5200 train_loss:3.2377 train_time:47701ms step_avg:477.01ms +step:200/5200 train_loss:2.5230 train_time:95726ms step_avg:478.63ms +step:300/5200 train_loss:2.5249 train_time:143720ms step_avg:479.07ms +step:400/5200 train_loss:2.4072 train_time:191504ms step_avg:478.76ms +step:500/5200 train_loss:2.3540 train_time:239197ms step_avg:478.39ms +step:500/5200 val_loss:2.3450 val_bpb:1.3888 train_time:239208ms step_avg:478.42ms +step:600/5200 train_loss:2.3381 train_time:286861ms step_avg:478.10ms +step:700/5200 train_loss:2.3799 train_time:334521ms step_avg:477.89ms +step:800/5200 train_loss:2.2325 train_time:382130ms step_avg:477.66ms +step:900/5200 train_loss:2.1118 train_time:429963ms step_avg:477.74ms +step:1000/5200 train_loss:2.2651 train_time:477565ms step_avg:477.56ms +step:1000/5200 val_loss:2.2144 val_bpb:1.3115 train_time:477575ms step_avg:477.58ms +step:1100/5200 train_loss:2.2484 train_time:525338ms step_avg:477.58ms +step:1200/5200 train_loss:2.2613 train_time:573169ms step_avg:477.64ms +step:1300/5200 train_loss:2.2086 train_time:620829ms step_avg:477.56ms +step:1400/5200 train_loss:2.2295 train_time:668533ms step_avg:477.52ms +step:1500/5200 train_loss:2.1890 train_time:716279ms step_avg:477.52ms +step:1500/5200 val_loss:2.1736 val_bpb:1.2873 train_time:716290ms step_avg:477.53ms +step:1600/5200 train_loss:2.1213 train_time:763990ms step_avg:477.49ms +step:1700/5200 train_loss:2.1580 train_time:811810ms step_avg:477.54ms +step:1800/5200 train_loss:2.1265 train_time:859715ms step_avg:477.62ms +step:1900/5200 train_loss:2.1151 train_time:907551ms step_avg:477.66ms +step:2000/5200 train_loss:2.0170 train_time:955448ms step_avg:477.72ms +step:2000/5200 val_loss:2.1199 val_bpb:1.2556 train_time:955459ms step_avg:477.73ms +step:2100/5200 train_loss:2.0097 train_time:1003368ms step_avg:477.79ms +step:2200/5200 train_loss:2.1313 train_time:1051264ms step_avg:477.85ms +step:2300/5200 train_loss:2.0424 train_time:1099100ms step_avg:477.87ms +step:2400/5200 train_loss:2.0650 train_time:1146897ms step_avg:477.87ms +step:2500/5200 train_loss:2.1276 train_time:1194775ms step_avg:477.91ms +step:2500/5200 val_loss:2.0875 val_bpb:1.2363 train_time:1194785ms step_avg:477.91ms +step:2600/5200 train_loss:2.1213 train_time:1242753ms step_avg:477.98ms +step:2700/5200 train_loss:2.0191 train_time:1290742ms step_avg:478.05ms +step:2800/5200 train_loss:2.1568 train_time:1338726ms step_avg:478.12ms +step:2900/5200 train_loss:2.0460 train_time:1386511ms step_avg:478.11ms +step:3000/5200 train_loss:2.0776 train_time:1434331ms step_avg:478.11ms +step:3000/5200 val_loss:2.0614 val_bpb:1.2209 train_time:1434342ms step_avg:478.11ms +step:3100/5200 train_loss:2.0778 train_time:1482179ms step_avg:478.12ms +step:3200/5200 train_loss:2.1065 train_time:1529969ms step_avg:478.12ms +step:3300/5200 train_loss:2.0627 train_time:1577776ms step_avg:478.11ms +step:3400/5200 train_loss:2.0492 train_time:1625616ms step_avg:478.12ms +step:3500/5200 train_loss:2.1301 train_time:1673566ms step_avg:478.16ms +step:3500/5200 val_loss:2.0380 val_bpb:1.2070 train_time:1673577ms step_avg:478.16ms +step:3600/5200 train_loss:2.0416 train_time:1721464ms step_avg:478.18ms +step:3700/5200 train_loss:2.0425 train_time:1769201ms step_avg:478.16ms +step:3800/5200 train_loss:2.0298 train_time:1816932ms step_avg:478.14ms +step:3900/5200 train_loss:2.0388 train_time:1864678ms step_avg:478.12ms +step:4000/5200 train_loss:2.0838 train_time:1912429ms step_avg:478.11ms +step:4000/5200 val_loss:2.0154 val_bpb:1.1936 train_time:1912440ms step_avg:478.11ms +step:4100/5200 train_loss:2.0041 train_time:1960200ms step_avg:478.10ms +step:4200/5200 train_loss:2.0223 train_time:2008058ms step_avg:478.11ms +step:4300/5200 train_loss:1.9992 train_time:2055831ms step_avg:478.10ms +step:4400/5200 train_loss:1.9372 train_time:2103787ms step_avg:478.13ms +step:4500/5200 train_loss:2.0377 train_time:2151676ms step_avg:478.15ms +step:4500/5200 val_loss:1.9882 val_bpb:1.1775 train_time:2151687ms step_avg:478.15ms +step:4600/5200 train_loss:1.8825 train_time:2199385ms step_avg:478.13ms +swa:start step:4650 +step:4700/5200 train_loss:2.0725 train_time:2249289ms step_avg:478.57ms +step:4800/5200 train_loss:2.1830 train_time:2301203ms step_avg:479.42ms +step:4900/5200 train_loss:1.9485 train_time:2353078ms step_avg:480.22ms +late_qat:enabled step:4901 scale:0.0997 clip_range:31 +step:5000/5200 train_loss:1.9777 train_time:2405158ms step_avg:481.03ms +step:5000/5200 val_loss:1.9602 val_bpb:1.1610 train_time:2407189ms step_avg:481.44ms +step:5100/5200 train_loss:1.9852 train_time:2457292ms step_avg:481.82ms +step:5200/5200 train_loss:1.9884 train_time:2509281ms step_avg:482.55ms +step:5200/5200 val_loss:1.9533 val_bpb:1.1569 train_time:2511303ms step_avg:482.94ms +peak memory allocated: 20223 MiB reserved: 20350 MiB +swa:applying averaged 12 checkpoints +Serialized model: 96746619 bytes +Code size: 74030 bytes +Total submission size: 96820649 bytes +magnitude_pruning: frac=0.03 +=== Weight distribution diagnostics === + OUTLIER cores.0.attn.c_k.weight: max=2.8573 mean=0.1386 ratio=20.6 kurtosis=10.3 + OUTLIER cores.9.mlp.proj.weight: max=2.6011 mean=0.1012 ratio=25.7 kurtosis=0.6 +Serialized model int6+zstd: 15283215 bytes +Total submission size int6+zstd: 15357245 bytes +final_eval_mode:sliding_window_ttt stride:64 chunk_tokens:32768 +ttt_sliding:start chunks=1893 chunk_tokens=32768 total_windows=969088 stride=64 ttt_lr=0.002 ttt_epochs=30 freeze_blocks=2 +ttt_sliding:params unfrozen=19911748 frozen=4722704 +ttt_sliding:perlayer_lr proj=8(3.0x) fc=8(0.5x) other=78(1x) + ttt_chunk [1/1893] bpb=1.187163 time=2.1s + ttt_chunk [11/1893] bpb=1.143039 time=21.5s + ttt_chunk [21/1893] bpb=1.146252 time=40.9s + ttt_chunk [31/1893] bpb=1.149530 time=60.4s + ttt_chunk [41/1893] bpb=1.138514 time=80.0s + ttt_chunk [51/1893] bpb=1.135680 time=99.5s + ttt_chunk [61/1893] bpb=1.141447 time=119.0s + ttt_chunk [71/1893] bpb=1.138207 time=138.6s + ttt_chunk [81/1893] bpb=1.138302 time=158.1s + ttt_chunk [91/1893] bpb=1.137471 time=177.6s + ttt_chunk [101/1893] bpb=1.140623 time=197.1s + ttt_chunk [111/1893] bpb=1.141896 time=216.6s + ttt_chunk [121/1893] bpb=1.138736 time=236.1s + ttt_chunk [131/1893] bpb=1.138982 time=255.6s + ttt_chunk [141/1893] bpb=1.138685 time=275.1s + ttt_chunk [151/1893] bpb=1.142057 time=294.5s + ttt_chunk [161/1893] bpb=1.144021 time=314.0s + ttt_chunk [171/1893] bpb=1.144846 time=333.5s + ttt_chunk [181/1893] bpb=1.144914 time=353.0s + ttt_chunk [191/1893] bpb=1.148367 time=372.5s + ttt_chunk [201/1893] bpb=1.148661 time=392.0s + ttt_chunk [211/1893] bpb=1.146452 time=411.5s + ttt_chunk [221/1893] bpb=1.148524 time=431.0s + ttt_chunk [231/1893] bpb=1.147960 time=450.5s + ttt_chunk [241/1893] bpb=1.147804 time=470.0s + ttt_chunk [251/1893] bpb=1.146194 time=489.5s + ttt_chunk [261/1893] bpb=1.144612 time=509.0s + ttt_chunk [271/1893] bpb=1.143303 time=528.5s + ttt_chunk [281/1893] bpb=1.145809 time=548.0s + ttt_chunk [291/1893] bpb=1.146605 time=567.5s + ttt_chunk [301/1893] bpb=1.147232 time=586.9s + ttt_chunk [311/1893] bpb=1.148884 time=606.4s + ttt_chunk [321/1893] bpb=1.150381 time=625.9s + ttt_chunk [331/1893] bpb=1.150310 time=645.4s + ttt_chunk [341/1893] bpb=1.150673 time=664.9s + ttt_chunk [351/1893] bpb=1.151971 time=684.4s + ttt_chunk [361/1893] bpb=1.153347 time=703.9s + ttt_chunk [371/1893] bpb=1.152868 time=723.4s + ttt_chunk [381/1893] bpb=1.152787 time=742.9s + ttt_chunk [391/1893] bpb=1.152332 time=762.4s + ttt_chunk [401/1893] bpb=1.151009 time=781.9s + ttt_chunk [411/1893] bpb=1.149860 time=801.4s + ttt_chunk [421/1893] bpb=1.149308 time=820.9s + ttt_chunk [431/1893] bpb=1.150036 time=840.4s + ttt_chunk [441/1893] bpb=1.149868 time=859.9s + ttt_chunk [451/1893] bpb=1.149604 time=879.4s + ttt_chunk [461/1893] bpb=1.148892 time=898.9s + ttt_chunk [471/1893] bpb=1.148520 time=918.4s + ttt_chunk [481/1893] bpb=1.148299 time=937.9s + ttt_chunk [491/1893] bpb=1.147958 time=957.4s + ttt_chunk [501/1893] bpb=1.147417 time=976.9s + ttt_chunk [511/1893] bpb=1.146833 time=996.4s + ttt_chunk [521/1893] bpb=1.145968 time=1015.9s + ttt_chunk [531/1893] bpb=1.146059 time=1035.4s + ttt_chunk [541/1893] bpb=1.145938 time=1054.9s + ttt_chunk [551/1893] bpb=1.144728 time=1074.4s + ttt_chunk [561/1893] bpb=1.145250 time=1093.9s + ttt_chunk [571/1893] bpb=1.144542 time=1113.4s + ttt_chunk [581/1893] bpb=1.144000 time=1132.9s + ttt_chunk [591/1893] bpb=1.143356 time=1152.4s + ttt_chunk [601/1893] bpb=1.144041 time=1171.9s + ttt_chunk [611/1893] bpb=1.143643 time=1191.4s + ttt_chunk [621/1893] bpb=1.143586 time=1210.9s + ttt_chunk [631/1893] bpb=1.144017 time=1230.4s + ttt_chunk [641/1893] bpb=1.143800 time=1249.9s + ttt_chunk [651/1893] bpb=1.143793 time=1269.5s + ttt_chunk [661/1893] bpb=1.143682 time=1289.2s + ttt_chunk [671/1893] bpb=1.143335 time=1308.9s + ttt_chunk [681/1893] bpb=1.143586 time=1328.4s + ttt_chunk [691/1893] bpb=1.144292 time=1347.9s + ttt_chunk [701/1893] bpb=1.143508 time=1367.3s + ttt_chunk [711/1893] bpb=1.144083 time=1386.8s + ttt_chunk [721/1893] bpb=1.143702 time=1406.3s + ttt_chunk [731/1893] bpb=1.144159 time=1425.8s + ttt_chunk [741/1893] bpb=1.144113 time=1445.3s + ttt_chunk [751/1893] bpb=1.143717 time=1464.8s + ttt_chunk [761/1893] bpb=1.143564 time=1484.3s + ttt_chunk [771/1893] bpb=1.143343 time=1503.7s + ttt_chunk [781/1893] bpb=1.143891 time=1523.2s + ttt_chunk [791/1893] bpb=1.143571 time=1542.8s + ttt_chunk [801/1893] bpb=1.143625 time=1562.3s + ttt_chunk [811/1893] bpb=1.143156 time=1581.8s + ttt_chunk [821/1893] bpb=1.142978 time=1601.3s + ttt_chunk [831/1893] bpb=1.142559 time=1620.8s + ttt_chunk [841/1893] bpb=1.141989 time=1640.3s + ttt_chunk [851/1893] bpb=1.141929 time=1659.8s + ttt_chunk [861/1893] bpb=1.142074 time=1679.3s + ttt_chunk [871/1893] bpb=1.142141 time=1698.8s + ttt_chunk [881/1893] bpb=1.142206 time=1718.3s + ttt_chunk [891/1893] bpb=1.142029 time=1737.8s + ttt_chunk [901/1893] bpb=1.142009 time=1757.3s + ttt_chunk [911/1893] bpb=1.141995 time=1776.8s + ttt_chunk [921/1893] bpb=1.142371 time=1796.3s + ttt_chunk [931/1893] bpb=1.142192 time=1815.7s + ttt_chunk [941/1893] bpb=1.142079 time=1835.2s + ttt_chunk [951/1893] bpb=1.142101 time=1854.7s + ttt_chunk [961/1893] bpb=1.141856 time=1874.2s + ttt_chunk [971/1893] bpb=1.142610 time=1893.7s + ttt_chunk [981/1893] bpb=1.142748 time=1913.1s + ttt_chunk [991/1893] bpb=1.142665 time=1932.6s + ttt_chunk [1001/1893] bpb=1.142827 time=1952.1s + ttt_chunk [1011/1893] bpb=1.143142 time=1971.6s + ttt_chunk [1021/1893] bpb=1.143299 time=1991.0s + ttt_chunk [1031/1893] bpb=1.143846 time=2010.5s + ttt_chunk [1041/1893] bpb=1.143496 time=2030.0s + ttt_chunk [1051/1893] bpb=1.143206 time=2049.5s + ttt_chunk [1061/1893] bpb=1.143479 time=2069.0s + ttt_chunk [1071/1893] bpb=1.143976 time=2088.4s + ttt_chunk [1081/1893] bpb=1.143984 time=2107.9s + ttt_chunk [1091/1893] bpb=1.144386 time=2127.4s + ttt_chunk [1101/1893] bpb=1.144535 time=2146.9s + ttt_chunk [1111/1893] bpb=1.144298 time=2166.4s + ttt_chunk [1121/1893] bpb=1.144253 time=2185.8s + ttt_chunk [1131/1893] bpb=1.144132 time=2205.3s + ttt_chunk [1141/1893] bpb=1.143989 time=2225.0s + ttt_chunk [1151/1893] bpb=1.144045 time=2244.6s + ttt_chunk [1161/1893] bpb=1.143460 time=2264.2s + ttt_chunk [1171/1893] bpb=1.144007 time=2283.7s + ttt_chunk [1181/1893] bpb=1.143498 time=2303.2s + ttt_chunk [1191/1893] bpb=1.143213 time=2322.7s + ttt_chunk [1201/1893] bpb=1.143788 time=2342.2s + ttt_chunk [1211/1893] bpb=1.143207 time=2361.7s + ttt_chunk [1221/1893] bpb=1.142895 time=2381.2s + ttt_chunk [1231/1893] bpb=1.142789 time=2400.7s + ttt_chunk [1241/1893] bpb=1.142592 time=2420.2s + ttt_chunk [1251/1893] bpb=1.142346 time=2439.7s + ttt_chunk [1261/1893] bpb=1.142293 time=2459.2s + ttt_chunk [1271/1893] bpb=1.142108 time=2478.7s + ttt_chunk [1281/1893] bpb=1.141915 time=2498.2s + ttt_chunk [1291/1893] bpb=1.141746 time=2517.7s + ttt_chunk [1301/1893] bpb=1.141376 time=2537.2s + ttt_chunk [1311/1893] bpb=1.141054 time=2556.7s + ttt_chunk [1321/1893] bpb=1.140870 time=2576.2s + ttt_chunk [1331/1893] bpb=1.140776 time=2595.7s + ttt_chunk [1341/1893] bpb=1.140660 time=2615.2s + ttt_chunk [1351/1893] bpb=1.140614 time=2634.7s + ttt_chunk [1361/1893] bpb=1.140815 time=2654.2s + ttt_chunk [1371/1893] bpb=1.140663 time=2673.7s + ttt_chunk [1381/1893] bpb=1.140596 time=2693.2s + ttt_chunk [1391/1893] bpb=1.140057 time=2712.7s + ttt_chunk [1401/1893] bpb=1.140086 time=2732.2s + ttt_chunk [1411/1893] bpb=1.140112 time=2751.7s + ttt_chunk [1421/1893] bpb=1.140366 time=2771.2s + ttt_chunk [1431/1893] bpb=1.140222 time=2790.7s + ttt_chunk [1441/1893] bpb=1.140904 time=2810.2s + ttt_chunk [1451/1893] bpb=1.141023 time=2829.7s + ttt_chunk [1461/1893] bpb=1.140755 time=2849.1s + ttt_chunk [1471/1893] bpb=1.141662 time=2868.6s + ttt_chunk [1481/1893] bpb=1.141442 time=2888.1s + ttt_chunk [1491/1893] bpb=1.141442 time=2907.6s + ttt_chunk [1501/1893] bpb=1.141617 time=2927.0s + ttt_chunk [1511/1893] bpb=1.141711 time=2946.5s + ttt_chunk [1521/1893] bpb=1.141739 time=2966.0s + ttt_chunk [1531/1893] bpb=1.141564 time=2985.5s + ttt_chunk [1541/1893] bpb=1.141528 time=3004.9s + ttt_chunk [1551/1893] bpb=1.141869 time=3024.4s + ttt_chunk [1561/1893] bpb=1.142013 time=3043.9s + ttt_chunk [1571/1893] bpb=1.142142 time=3063.4s + ttt_chunk [1581/1893] bpb=1.142261 time=3082.8s + ttt_chunk [1591/1893] bpb=1.142195 time=3102.3s + ttt_chunk [1601/1893] bpb=1.142366 time=3121.8s + ttt_chunk [1611/1893] bpb=1.142436 time=3141.3s + ttt_chunk [1621/1893] bpb=1.142290 time=3160.7s + ttt_chunk [1631/1893] bpb=1.142448 time=3180.2s + ttt_chunk [1641/1893] bpb=1.142320 time=3199.7s + ttt_chunk [1651/1893] bpb=1.142242 time=3219.2s + ttt_chunk [1661/1893] bpb=1.142133 time=3238.6s + ttt_chunk [1671/1893] bpb=1.142496 time=3258.1s + ttt_chunk [1681/1893] bpb=1.142751 time=3277.6s + ttt_chunk [1691/1893] bpb=1.142706 time=3297.1s + ttt_chunk [1701/1893] bpb=1.142696 time=3316.5s + ttt_chunk [1711/1893] bpb=1.142536 time=3336.0s + ttt_chunk [1721/1893] bpb=1.142390 time=3355.5s + ttt_chunk [1731/1893] bpb=1.142342 time=3375.0s + ttt_chunk [1741/1893] bpb=1.142136 time=3394.5s + ttt_chunk [1751/1893] bpb=1.141977 time=3414.0s + ttt_chunk [1761/1893] bpb=1.142060 time=3433.5s + ttt_chunk [1771/1893] bpb=1.141994 time=3453.0s + ttt_chunk [1781/1893] bpb=1.141960 time=3472.5s + ttt_chunk [1791/1893] bpb=1.141546 time=3492.0s + ttt_chunk [1801/1893] bpb=1.141545 time=3511.5s + ttt_chunk [1811/1893] bpb=1.141372 time=3531.0s + ttt_chunk [1821/1893] bpb=1.141395 time=3550.5s + ttt_chunk [1831/1893] bpb=1.141006 time=3570.0s + ttt_chunk [1841/1893] bpb=1.141051 time=3589.5s + ttt_chunk [1851/1893] bpb=1.140835 time=3609.0s + ttt_chunk [1861/1893] bpb=1.140353 time=3628.5s + ttt_chunk [1871/1893] bpb=1.140201 time=3648.0s + ttt_chunk [1881/1893] bpb=1.139810 time=3667.5s + ttt_chunk [1891/1893] bpb=1.139650 time=3687.0s + ttt_chunk [1893/1893] bpb=1.139672 time=3689.2s +ttt_sliding:done val_loss=1.922674 val_bpb=1.138719 elapsed=3689.3s +final_int6_roundtrip val_loss:1.9227 val_bpb:1.1387 eval_time:3689782ms +final_int6_roundtrip_exact val_loss:1.92267394 val_bpb:1.13871882 + +================================================================================ +Additional seed results (validation only, training logs omitted for brevity): +================================================================================ + +--- Seed 1337 (i39_s1_55210274) --- +seed:1337 +step:5200/5200 val_loss:1.9539 val_bpb:1.1572 train_time:2481497ms step_avg:477.21ms +ttt_sliding:done val_loss=1.923355 val_bpb=1.139122 elapsed=3688.5s +final_int6_roundtrip_exact val_loss:1.92335521 val_bpb:1.13912231 + +--- Seed 42 (i39_s2_55210275) --- +seed:42 +step:5200/5200 val_loss:1.9552 val_bpb:1.1580 train_time:2499683ms step_avg:480.71ms +ttt_sliding:done val_loss=1.925248 val_bpb=1.140243 elapsed=3689.0s +final_int6_roundtrip_exact val_loss:1.92524826 val_bpb:1.14024348 + +--- 3-Seed Summary --- +Seed 7: val_bpb=1.13871882 pre_ttt=1.1569 +Seed 1337: val_bpb=1.13912231 pre_ttt=1.1572 +Seed 42: val_bpb=1.14024348 pre_ttt=1.1580 +Mean: val_bpb=1.13936154 pre_ttt=1.1574 std=0.00079 diff --git a/records/track_non_record_16mb/2026-03-23_11L_LeakyReLU_PerLayerLR_LegalTTT/train_gpt.py b/records/track_non_record_16mb/2026-03-23_11L_LeakyReLU_PerLayerLR_LegalTTT/train_gpt.py new file mode 100644 index 000000000..e3fa83a9d --- /dev/null +++ b/records/track_non_record_16mb/2026-03-23_11L_LeakyReLU_PerLayerLR_LegalTTT/train_gpt.py @@ -0,0 +1,1471 @@ +""" +Parameter Golf: 11L Depth Recurrence + LeakyReLU(0.5)² + Per-Layer LR Legal TTT +11-layer GPT with BigramHash, SmearGate, XSA, U-Net skips, SWA, VE128, +partial RoPE (16/64), LN scale, mixed int5/int6 quantization, and legal TTT. +Depth recurrence (shared BlockCores) enabled via UNIQUE_LAYERS env var. + +Key improvements from PRs #518, #481, #455, #442, #374: +- LeakyReLU(0.5)² activation (PR #518, sofiabod) +- Per-layer LR for TTT: mlp.proj 3×, mlp.fc 0.5× (PR #481, mrdavtan) +- Intra-chunk cosine LR decay for TTT (PR #518) +- 11 layers (vs 10) for more capacity +- Partial RoPE: only 16/64 head dims get rotary embedding +- LN Scale: 1/sqrt(layer_idx+1) scaling on normalized inputs +- ValueEmbedding (VE128): shared embedding added to value projections on deep layers +- XSA on last 4 layers, BigramHash(2048) +- Legal TTT: SGD 30 epochs, freeze first 2 blocks +""" +from __future__ import annotations +import copy, glob, io, math, os, random, subprocess, sys, time, uuid, zlib +from pathlib import Path +try: + import zstandard; _COMPRESSOR = "zstd" +except ImportError: + _COMPRESSOR = "zlib" +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP + +_IS_AMPERE_PLUS = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 +_HALF_DTYPE = torch.bfloat16 if _IS_AMPERE_PLUS else torch.float16 + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 42)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 500)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 100)) + iterations = int(os.environ.get("ITERATIONS", 5200)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3000)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + mlp_activation = os.environ.get("MLP_ACTIVATION", "leaky_relu_sq").lower() + unique_layers = int(os.environ.get("UNIQUE_LAYERS", 10)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", "1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + weight_decay = float(os.environ.get("WEIGHT_DECAY", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.01)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 64)) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.2)) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + late_qat = bool(int(os.environ.get("LATE_QAT", "1"))) + qat_threshold = float(os.environ.get("QAT_THRESHOLD", 0.1)) + all_int5 = bool(int(os.environ.get("ALL_INT5", "0"))) + prune_frac = float(os.environ.get("PRUNE_FRAC", "0.03")) + gptq_lite = bool(int(os.environ.get("GPTQ_LITE", "0"))) + quant_eval_every = int(os.environ.get("QUANT_EVAL_EVERY", "0")) + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 30)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_perlayer_lr = bool(int(os.environ.get("TTT_PERLAYER_LR", "1"))) + ttt_proj_lr_mult = float(os.environ.get("TTT_PROJ_LR_MULT", 3.0)) + ttt_fc_lr_mult = float(os.environ.get("TTT_FC_LR_MULT", 0.5)) + ttt_intra_cosine = bool(int(os.environ.get("TTT_INTRA_COSINE", "1"))) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 10, eps: float = 1e-7) -> Tensor: + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.to(torch.bfloat16 if _IS_AMPERE_PLUS else torch.float32) + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X + +class Muon(torch.optim.Optimizer): + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__(params, dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay)) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for group in self.param_groups: + params = group["params"] + if not params: + continue + lr, momentum = group["lr"], group["momentum"] + backend_steps, nesterov = group["backend_steps"], group["nesterov"] + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=_HALF_DTYPE) + curr = 0 + for i, p in enumerate(params): + if i % world_size == rank and p.grad is not None: + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + g = zeropower_via_newtonschulz5(g, steps=backend_steps) + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + updates_flat[curr : curr + p.numel()] = g.reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + wd = group.get("weight_decay", 0.0) + curr = 0 + for p in params: + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + if wd > 0: + p.data.mul_(1.0 - lr * wd) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) + +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] + +def eval_val( + args: Hyperparameters, model: nn.Module, rank: int, world_size: int, + device: torch.device, grad_accum_steps: int, val_tokens: Tensor, + base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, +) -> tuple[float, float]: + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < args.train_seq_len: + raise ValueError("VAL_BATCH_SIZE too small") + local_batch_seqs = local_batch_tokens // args.train_seq_len + total_seqs = (val_tokens.numel() - 1) // args.train_seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * args.train_seq_len + raw_end = batch_seq_end * args.train_seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, args.train_seq_len) + y = local[1:].reshape(-1, args.train_seq_len) + with torch.autocast(device_type="cuda", dtype=_HALF_DTYPE, enabled=True): + batch_loss = model(x, y).detach() + val_loss_sum += batch_loss.to(torch.float64) * float(y.numel()) + val_token_count += float(y.numel()) + prev_ids, tgt_ids = x.reshape(-1), y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & ~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + p for p in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes," + "q_gain,skip_weight,skip_weights,smear,bigram.scale,ve_layer_scales,ve_shared.scale", + ).split(",") if p +) +FP16_KEEP_NAME_PATTERNS = tuple( + p for p in os.environ.get( + "FP16_KEEP_NAME_PATTERNS", "tok_emb,cores.2.attn.c_k" + ).split(",") if p +) +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_Q = float(os.environ.get("INT8_CLIP_PERCENTILE", "99.99984")) / 100.0 + +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = (torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if t32.numel() else torch.empty((t32.shape[0],), dtype=torch.float32)) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: return "embed" + if ".mlp." in name: return "mlp" + if "bigram" in name: return "bigram" + if ".attn." in name or (".proj." in name and ".mlp." not in name): return "attn" + return "other" + +def quantize_intN_per_row(t: Tensor, clip_range: int = 31, + gptq_lite: bool = False) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + if gptq_lite: + n_cols = t32.shape[1] + sorted_abs, _ = t32.abs().sort(dim=1) + best_q = best_scale = None + best_mse = torch.full((t32.shape[0],), float('inf'), device=t32.device) + for p in (0.95, 0.975, 0.99, 0.995, 1.0): + idx = min(int(p * (n_cols - 1)), n_cols - 1) + row_clip = sorted_abs[:, idx] + sc = (row_clip / clip_range).clamp_min(1e-12).to(torch.float16) + sc = sc.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / sc.float()[:, None]), + -(clip_range + 1), clip_range).to(torch.int8) + deq = q.float() * sc.float()[:, None] + mse = (t32 - deq).pow(2).mean(dim=1) + if best_q is None: + best_q, best_scale, best_mse = q, sc, mse + else: + better = mse < best_mse + best_q[better] = q[better] + best_scale[better] = sc[better] + best_mse[better] = mse[better] + return best_q, best_scale + row_max = t32.abs().amax(dim=1) + scale = (row_max / clip_range).clamp_min(1e-12).to(torch.float16) + scale = scale.clamp_min(torch.finfo(torch.float16).tiny) + q = torch.clamp(torch.round(t32 / scale.float()[:, None]), -(clip_range + 1), clip_range).to(torch.int8) + return q, scale + amax = t32.abs().max().item() + scale = torch.tensor(max(amax / clip_range, 1e-12), dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -(clip_range + 1), clip_range).to(torch.int8) + return q, scale + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str], + gptq_lite: bool = False, force_int5: bool = False): + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 8192: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if any(p in name for p in FP16_KEEP_NAME_PATTERNS): + result[name] = t.to(dtype=torch.float16).contiguous() + meta[name] = "passthrough_fp16" + continue + if cat in int6_cats and t.ndim >= 1: + clip = 15 if force_int5 else (15 if cat == "mlp" else 31) + q, s = quantize_intN_per_row(t, clip_range=clip, gptq_lite=gptq_lite) + bits = {15: 5, 31: 6, 63: 7}.get(clip, 6) + result[name + ".q"] = q; result[name + ".scale"] = s + meta[name] = {"type": f"int{bits}"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q; result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta + +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta[name] + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" None: + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = load_data_shard(self.files[self.file_idx]) + self.pos = 0 + + def take(self, n: int) -> Tensor: + chunks: list[Tensor] = [] + remaining = n + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._advance_file() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) + +class DistributedTokenLoader: + def __init__(self, pattern: str, rank: int, world_size: int, device: torch.device): + self.rank, self.world_size, self.device = rank, world_size, device + self.stream = TokenStream(pattern) + + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) + +class CastedLinear(nn.Linear): + _qat_enabled: bool = False + _qat_clip_range: int = 31 + + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + if CastedLinear._qat_enabled and self.training and w.ndim == 2: + cr = CastedLinear._qat_clip_range + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale_q = (row_max / float(cr)).clamp_min(1.0 / float(cr)) + w_q = (torch.clamp(torch.round(w32 / scale_q[:, None]), -(cr + 1), cr) * scale_q[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) + +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS)) \ + and param.dtype != torch.float32: + param.data = param.data.float() + +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, + rope_dims: int = 0): + super().__init__() + self.dim, self.base = dim, base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + rd = self.rope_dims + inv_freq = 1.0 / (base ** (torch.arange(0, rd, 2, dtype=torch.float32) / rd)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if (self._cos_cached is None or self._sin_cached is None + or self._seq_len_cached != seq_len or self._cos_cached.device != device): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) + +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, + rope_base: float, qk_gain_init: float, xsa_enabled: bool = False, + rope_dims: int = 0): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads, self.num_kv_heads = num_heads, num_kv_heads + self.head_dim = dim // num_heads + self.xsa_enabled = xsa_enabled + self.rope_dims = rope_dims + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CastedLinear(dim, dim, bias=False) + self.c_k = CastedLinear(dim, kv_dim, bias=False) + self.c_v = CastedLinear(dim, kv_dim, bias=False) + self.proj = CastedLinear(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024, + rope_dims=rope_dims) + + def forward(self, x: Tensor, v_embed: Tensor | None = None) -> Tensor: + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + # Transpose to [B, H, T, D] for SDPA + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + if _IS_AMPERE_PLUS and self.num_kv_heads != self.num_heads: + y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True, enable_gqa=True) + else: + if self.num_kv_heads != self.num_heads: + repeats = self.num_heads // self.num_kv_heads + k = k.repeat_interleave(repeats, dim=1) + v_for_sdpa = v.repeat_interleave(repeats, dim=1) + else: + v_for_sdpa = v + y = F.scaled_dot_product_attention(q, k, v_for_sdpa, attn_mask=None, is_causal=True) + if self.xsa_enabled: + group_size = self.num_heads // self.num_kv_heads + y_t = y.transpose(1, 2) + y_grouped = y_t.reshape(bsz, seqlen, self.num_kv_heads, group_size, self.head_dim) + vn = F.normalize(v.transpose(1, 2).unsqueeze(3), dim=-1) + dot_prod = (y_grouped * vn).sum(dim=-1, keepdim=True) + y = (y_grouped - dot_prod * vn).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2) + y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim) + return self.proj(y) + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: float, activation: str = "relu_sq"): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CastedLinear(dim, hidden, bias=False) + self.proj = CastedLinear(hidden, dim, bias=False) + self.proj._zero_init = True + self.activation = activation + + def forward(self, x: Tensor) -> Tensor: + if self.activation == "leaky_relu_sq": + x = F.leaky_relu(self.fc(x), negative_slope=0.5) + else: + x = torch.relu(self.fc(x)) + return self.proj(x.square()) + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class ValueEmbedding(nn.Module): + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class BlockCore(nn.Module): + def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, + rope_base: float, qk_gain_init: float, + xsa_enabled: bool = False, mlp_activation: str = "relu_sq", + rope_dims: int = 0): + super().__init__() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, + xsa_enabled=xsa_enabled, rope_dims=rope_dims) + self.mlp = MLP(dim, mlp_mult, activation=mlp_activation) + +class Block(nn.Module): + def __init__(self, dim: int, layer_idx: int = 0, ln_scale: bool = False): + super().__init__() + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + + def forward(self, x: Tensor, x0: Tensor, core: BlockCore, + v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * core.attn( + self.attn_norm(x) * self.ln_scale_factor, v_embed=v_embed) + x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * core.mlp( + self.mlp_norm(x) * self.ln_scale_factor) + return x + +class GPT(nn.Module): + def __init__(self, vocab_size: int, num_layers: int, model_dim: int, num_heads: int, + num_kv_heads: int, mlp_mult: float, tie_embeddings: bool, + tied_embed_init_std: float, logit_softcap: float, rope_base: float, + qk_gain_init: float, bigram_vocab_size: int = 0, bigram_dim: int = 128, + unique_layers: int = 0, xsa_last_n: int = 0, mlp_activation: str = "relu_sq", + rope_dims: int = 0, ln_scale: bool = False, + ve_enabled: bool = False, ve_dim: int = 128, ve_layers: str = "9,10"): + super().__init__() + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) \ + if bigram_vocab_size > 0 else None + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter( + torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.smear = SmearGate(model_dim) + n_cores = unique_layers if (0 < unique_layers < num_layers) else num_layers + xsa_start = max(0, n_cores - xsa_last_n) if xsa_last_n > 0 else n_cores + self.cores = nn.ModuleList([ + BlockCore(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, + qk_gain_init, xsa_enabled=(i >= xsa_start), + mlp_activation=mlp_activation, rope_dims=rope_dims) + for i in range(n_cores) + ]) + self.blocks = nn.ModuleList([ + Block(model_dim, layer_idx=i, ln_scale=ln_scale) + for i in range(num_layers) + ]) + self._core_indices = [i % n_cores for i in range(num_layers)] + if n_cores < num_layers: + from collections import Counter + uses = Counter(self._core_indices) + for core_idx, core in enumerate(self.cores): + n_uses = uses[core_idx] + if n_uses > 1: + scale = 1.0 / n_uses + for p in core.parameters(): + p.register_hook(lambda grad, s=scale: grad * s) + # Value Embedding (VE128) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim = num_kv_heads * (model_dim // num_heads) + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices]) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self._init_weights() + + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, CastedLinear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + + def _get_ve(self, layer_idx: int, input_ids: Tensor, + ve_cache: dict | None = None) -> Tensor | None: + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + + def _forward_body(self, input_ids: Tensor) -> Tensor: + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, self.cores[self._core_indices[i]], v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + idx = self.num_encoder_layers + i + ve = self._get_ve(idx, input_ids, ve_cache) + x = self.blocks[idx](x, x0, self.cores[self._core_indices[idx]], v_embed=ve) + return self.final_norm(x) + + def _logits(self, x: Tensor) -> Tensor: + if self.tie_embeddings: + raw = F.linear(x, self.tok_emb.weight) + else: + raw = self.lm_head(x) + return self.logit_softcap * torch.tanh(raw / self.logit_softcap) + + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + x = self._forward_body(input_ids) + x = x.reshape(-1, x.size(-1)) + logits = self._logits(x) + return F.cross_entropy(logits.float(), target_ids.reshape(-1), reduction="mean") + + def forward_logits(self, input_ids: Tensor) -> Tensor: + return self._logits(self._forward_body(input_ids)) + +def eval_val_sliding( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, +) -> tuple[float, float]: + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", dtype=_HALF_DTYPE): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if rank == 0 and (bi // batch_seqs) % 50 == 0: + done = min(bi + batch_seqs, len(my_windows)) + pct = done / len(my_windows) * 100 + rl = (loss_sum / token_count).item() if token_count.item() > 0 else 0.0 + rbpb = rl / math.log(2.0) * (token_count.item() / byte_count.item()) if token_count.item() > 0 else 0.0 + print(f" sliding_eval [{pct:5.1f}%] {done}/{len(my_windows)} " + f"windows running_bpb={rbpb:.6f}", flush=True) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + base_model.train() + return val_loss, val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + +def eval_val_sliding_ttt( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Legal score-first TTT: score each chunk with sliding windows, then train on it.""" + seq_len = args.train_seq_len + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Pre-compute all window starts (same as eval_val_sliding) + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Assign each window to a chunk based on the first token it scores + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + + # BPB accumulators + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Setup TTT optimizer (SGD + momentum for the legal score-first TTT pass) + n_blocks = len(base_model.blocks) + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, n_blocks))) + frozen_core_ids = set(base_model._core_indices[i] for i in frozen_block_ids) if frozen_block_ids else set() + + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." in name: + freeze = True; break + if not freeze: + for ci_core in frozen_core_ids: + if f"cores.{ci_core}." in name: + freeze = True; break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + # Per-layer LR groups: mlp.proj gets higher LR (high quant error), mlp.fc gets lower LR + if args.ttt_perlayer_lr: + proj_params, fc_params, other_params = [], [], [] + for name, p in base_model.named_parameters(): + if not p.requires_grad: + continue + if "mlp.proj" in name: + proj_params.append(p) + elif "mlp.fc" in name: + fc_params.append(p) + else: + other_params.append(p) + param_groups = [ + {"params": proj_params, "lr": args.ttt_lr * args.ttt_proj_lr_mult}, + {"params": fc_params, "lr": args.ttt_lr * args.ttt_fc_lr_mult}, + {"params": other_params, "lr": args.ttt_lr}, + ] + log0(f"ttt_sliding:perlayer_lr proj={len(proj_params)}({args.ttt_proj_lr_mult}x) " + f"fc={len(fc_params)}({args.ttt_fc_lr_mult}x) other={len(other_params)}(1x)") + else: + param_groups = [{"params": ttt_params, "lr": args.ttt_lr}] + + optimizer = torch.optim.SGD(param_groups, momentum=args.ttt_momentum) + # Store initial per-group LRs for cosine scheduling + for pg in optimizer.param_groups: + pg['_base_lr'] = pg['lr'] + + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + # --- Phase 1: SCORE this chunk's windows (sliding window eval) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=_HALF_DTYPE): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk's tokens (already scored = legal) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + # Cosine decay across chunks (inter-chunk schedule) + inter_cos = 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + + # Store base LR for each param group (for intra-chunk cosine) + base_lrs = [pg['_base_lr'] for pg in optimizer.param_groups] + + # Partition training seqs across ranks + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + + # Count steps per epoch for intra-chunk cosine + steps_per_epoch = max(1, (my_chunk_seqs + args.ttt_batch_seqs - 1) // args.ttt_batch_seqs) + total_chunk_steps = args.ttt_epochs * steps_per_epoch + chunk_step = 0 + + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + # Intra-chunk cosine LR schedule (within each chunk's TTT epochs) + if args.ttt_intra_cosine and total_chunk_steps > 1: + intra_cos = 0.5 * (1.0 + math.cos(math.pi * chunk_step / total_chunk_steps)) + else: + intra_cos = 1.0 + for i, pg in enumerate(optimizer.param_groups): + pg['lr'] = base_lrs[i] * inter_cos * intra_cos + chunk_step += 1 + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + actual_be = my_seq_s + be + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + actual_be * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + if args.ttt_grad_clip > 0: + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + # Progress log + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + # Final all-reduce + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = (loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + # Restore state + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + +def main() -> None: + global zeropower_via_newtonschulz5 + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + if _IS_AMPERE_PLUS: + zeropower_via_newtonschulz5 = torch.compile(zeropower_via_newtonschulz5) + + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + if _IS_AMPERE_PLUS: + enable_cudnn_sdp(False); enable_flash_sdp(True) + enable_mem_efficient_sdp(False); enable_math_sdp(False) + else: + enable_cudnn_sdp(False); enable_flash_sdp(False) + enable_mem_efficient_sdp(True); enable_math_sdp(True) + + logfile = None + if master_process: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + + def log0(msg: str, console: bool = True) -> None: + if not master_process: return + if console: print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, + text=True, check=False).stdout, console=False) + log0("=" * 100, console=False) + random.seed(args.seed); np.random.seed(args.seed) + torch.manual_seed(args.seed); torch.cuda.manual_seed_all(args.seed) + + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"VOCAB_SIZE={args.vocab_size} != tokenizer vocab_size={int(sp.vocab_size())}") + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + val_tokens = load_validation_tokens(args.val_files, args.train_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + + base_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, unique_layers=args.unique_layers, + xsa_last_n=args.xsa_last_n, mlp_activation=args.mlp_activation, + rope_dims=args.rope_dims, ln_scale=args.ln_scale, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + ).to(device).to(_HALF_DTYPE) + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + if _IS_AMPERE_PLUS: + compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True) + else: + log0("skipping torch.compile on non-Ampere GPU") + compiled_model = base_model + model: nn.Module = DDP(compiled_model, device_ids=[local_rank], + broadcast_buffers=False) if distributed else compiled_model + + matrix_params, scalar_params = [], [] + for name, p in base_model.cores.named_parameters(): + if p.ndim == 2 and not any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS): + matrix_params.append(p) + else: + scalar_params.append(p) + for name, p in base_model.blocks.named_parameters(): + if p.ndim < 2 or any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS): + scalar_params.append(p) + elif p.ndim == 2 and not any(pat in name for pat in CONTROL_TENSOR_NAME_PATTERNS): + matrix_params.append(p) + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], + "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + matrix_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], + "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + matrix_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + + optimizer_tok = torch.optim.AdamW( + tok_params, betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.adam_wd, fused=True) + optimizer_muon = Muon( + matrix_params, lr=args.matrix_lr, momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, weight_decay=0.04) + for group in optimizer_muon.param_groups: + group["base_lr"] = args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.adam_wd, fused=True) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True) + optimizers.insert(1, optimizer_head) + + n_params = sum(p.numel() for p in base_model.parameters()) + log0(f"model_params:{n_params} unique_cores:{len(base_model.cores)}") + log0(f"unique_layers:{args.unique_layers} mlp_mult:{args.mlp_mult}") + log0(f"matrix_params:{sum(p.numel() for p in matrix_params)} " + f"scalar_params:{sum(p.numel() for p in scalar_params)}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0(f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}") + log0(f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}") + log0(f"seed:{args.seed}") + + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + if warmdown_start <= step < args.iterations: + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) + return 1.0 + step_ms = elapsed_ms / max(step, 1) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + + if args.warmup_steps > 0: + initial_model_state = { + name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=_HALF_DTYPE, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + + training_time_ms = 0.0 + stop_after_step: int | None = None + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + + while True: + last_step = step == args.iterations or ( + stop_after_step is not None and step >= stop_after_step) + + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms") + torch.cuda.synchronize() + t0 = time.perf_counter() + if (args.quant_eval_every > 0 and should_validate + and lr_mul(step, training_time_ms) < args.swa_start_frac + and step % args.quant_eval_every == 0 and master_process): + with torch.no_grad(): + sd_snap = {k: v.detach().cpu().clone() for k, v in base_model.state_dict().items()} + qr, qm = mixed_quantize_int6(sd_snap, {"mlp", "attn", "bigram"}) + deq = dequantize_mixed_int6(qr, qm, sd_snap) + orig_sd = base_model.state_dict() + base_model.load_state_dict( + {k: v.to(dtype=orig_sd[k].dtype, device=orig_sd[k].device) for k, v in deq.items()}, + strict=True) + _, q_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + log0(f"quant_gap step:{step} float_bpb:{val_bpb:.4f} int6_bpb:{q_bpb:.4f} gap:{q_bpb - val_bpb:.4f}") + base_model.load_state_dict(orig_sd, strict=True) + + if last_step: + if stop_after_step is not None and step < args.iterations: + log0(f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms step:{step}/{args.iterations}") + break + + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + if distributed: + model.require_backward_grad_sync = micro_step == grad_accum_steps - 1 + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=_HALF_DTYPE, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + + if args.late_qat and scale < args.qat_threshold and not CastedLinear._qat_enabled: + CastedLinear._qat_enabled = True + CastedLinear._qat_clip_range = 15 if args.all_int5 else 31 + log0(f"late_qat:enabled step:{step} scale:{scale:.4f} clip_range:{CastedLinear._qat_clip_range}") + + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + step += 1 + approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + + if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + + should_log_train = (args.train_log_every > 0 and + (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None)) + if should_log_train: + log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms") + + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + + log0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB") + + if args.swa_enabled and swa_state is not None and swa_count > 1: + log0(f"swa:applying averaged {swa_count} checkpoints") + current_state = base_model.state_dict() + avg_state = {name: (tensor / swa_count).to(dtype=current_state[name].dtype) + for name, tensor in swa_state.items()} + base_model.load_state_dict(avg_state, strict=True) + + _model_pt = f"final_model_{args.run_id}.pt" + _model_ptz = f"final_model_{args.run_id}.int8.ptz" + if master_process: + torch.save(base_model.state_dict(), _model_pt) + model_bytes = os.path.getsize(_model_pt) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + log0(f"Total submission size: {model_bytes + code_bytes} bytes") + + with torch.no_grad(): + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 65536: + threshold = torch.quantile(param.abs().float().flatten(), args.prune_frac) + param.masked_fill_(param.abs() < threshold, 0.0) + log0(f"magnitude_pruning: frac={args.prune_frac}") + + if master_process: + log0("=== Weight distribution diagnostics ===") + for name, param in base_model.named_parameters(): + if param.ndim == 2 and param.numel() > 8192: + t = param.detach().float() + absmax = t.abs().max().item() + absmean = t.abs().mean().item() + kurtosis = ((t - t.mean()) / t.std()).pow(4).mean().item() - 3.0 + if kurtosis > 5.0 or absmax / absmean > 20.0: + log0(f" OUTLIER {name}: max={absmax:.4f} mean={absmean:.4f} " + f"ratio={absmax/absmean:.1f} kurtosis={kurtosis:.1f}") + + CastedLinear._qat_enabled = False + sd_cpu = {k: v.detach().cpu() for k, v in base_model.state_dict().items()} + quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn", "bigram"}, + gptq_lite=args.gptq_lite, + force_int5=args.all_int5) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + if _COMPRESSOR == "zstd": + quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw) + else: + quant_blob = zlib.compress(quant_raw, 9) + if master_process: + with open(_model_ptz, "wb") as f: + f.write(quant_blob) + quant_file_bytes = os.path.getsize(_model_ptz) + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model int6+{_COMPRESSOR}: {quant_file_bytes} bytes") + log0(f"Total submission size int6+{_COMPRESSOR}: {quant_file_bytes + code_bytes} bytes") + + if distributed: + dist.barrier() + with open(_model_ptz, "rb") as f: + quant_blob_disk = f.read() + if _COMPRESSOR == "zstd": + decompressed = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + else: + decompressed = zlib.decompress(quant_blob_disk) + quant_state = torch.load(io.BytesIO(decompressed), map_location="cpu") + deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu) + base_model.load_state_dict(deq_state, strict=True) + + torch.cuda.synchronize() + t_qeval = time.perf_counter() + if args.ttt_enabled and args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window_ttt stride:{args.eval_stride} " + f"chunk_tokens:{args.ttt_chunk_tokens}") + q_val_loss, q_val_bpb = eval_val_sliding_ttt( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs, log0=log0) + elif args.eval_stride > 0 and args.eval_stride < args.train_seq_len: + log0(f"final_eval_mode:sliding_window stride:{args.eval_stride} batch_seqs:{args.eval_batch_seqs}") + q_val_loss, q_val_bpb = eval_val_sliding( + args, base_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, batch_seqs=args.eval_batch_seqs) + else: + log0("final_eval_mode:standard") + q_val_loss, q_val_bpb = eval_val( + args, model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut) + torch.cuda.synchronize() + log0(f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms") + log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + if distributed: + dist.destroy_process_group() + +if __name__ == "__main__": + main()