Record: Fused MLP (Triton+CUTLASS EVT) + MLP 3.5× + Mixed int5/int6 + SLOT + Brotli — 1.1088 BPB (3-seed mean)#1105
Open
abaybektursun wants to merge 1 commit intoopenai:mainfrom
Conversation
ba665dd to
64ce201
Compare
9b27cf4 to
c27131c
Compare
… Brotli — val_bpb 1.1125 (3-seed mean) Seed 314: 1.1123 BPB / 1.87802 nats, 14.52 MB, 6844 steps, 87.7ms/step Seed 999: 1.1124 BPB / 1.87821 nats, 14.52 MB, 6846 steps, 87.7ms/step Seed 1337: 1.1129 BPB / 1.87910 nats, 14.53 MB, 6828 steps, 87.7ms/step Delta vs merged SOTA (our PR 1019): -0.00215 nats (-0.0013 BPB). Delta vs prior leaderboard (our PR 549): -0.01158 nats. Welch's t = -17.63, p < 0.01. Changes from PR 1019: 1. Fused Triton TMA forward + CUTLASS EVT backward MLP kernels 2. Pre-computed activation gradient (branch-free backward) 3. MLP 3.5x (1792 hidden dim, motivated by SVD analysis) 4. Hessian-based mixed int5/int6 quantization (motivated by quant sensitivity) 5. Brotli-11 compression (-581KB vs LZMA-9) 6. LR floor 0.05 7. Memmap multi-shard data pipeline (PR 726) Negative: Turbo-Muon +0.0018 BPB worse at scale, reverted to NS5. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
c27131c to
0df40cc
Compare
|
@abaybektursun - this is a fantastic write-up! Congrats on the SLOT improvement. If you need to free up even more room, you should check out the shrink.py script I used in PR 1089. I was able to shrink the train_gpt.py file by ~100KB. That might let you reduce pruning and/or promote one more group to int6. |
Contributor
Author
|
Ohhh I think with newer Pytroch performance and speed will be even better! I will try it when I can get my hands around 8xH100s |
6 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Results: val_bpb 1.1088 (3-seed mean) | 1.8722 nats | 8×H100 SXM | 600s | ~14.52 MB
Mixed quantization: 10 layers int6, 56 layers int5, no pruning needed. SLOT eval adds ~54s (well within 10-min budget).
Our merged PR 1019 (current SOTA): 1.88059 nats (1.1138 BPB). Delta: −0.00836 nats (−0.0050 BPB). Welch's t = −9.98, df ≈ 3.20, p < 0.01. Clears 0.005-nat threshold.
What does −0.005 BPB look like? Side-by-side generation (temp=0.8)
Prompt (50 tokens): "Insurance Company Declares Living Man Dead George Johannesen is very much alive. Which is why it was so surpr"
The old model drifts into incoherence ("Rachel Drobles... techniques of the car industry... Lyon Man is dead"). The new model stays on topic — insurance, health measurement, living man — and maintains grammatical coherence throughout. Both are wrong (the real text is about a cancelled driver's license), but the new model's errors are at least topically plausible. Remarkable that a 0.005 BPB difference — just 0.4% relative — produces a visible jump in coherence. Attention head analysis shows why: the model shifted from pattern-matching (induction heads: 2→1) to stronger sequence continuation (previous-token heads: 22→28), consistent with staying on topic instead of drifting.
Prior results: full stack without SLOT (val_bpb 1.1125, 3-seed)
Delta vs PR 1019: −0.00215 nats (below 0.005 threshold). Delta vs PR 549: −0.01158 nats, t = −17.63, p < 0.01.
Prior results: fused kernels + Brotli only (val_bpb 1.1138, 3-seed)
Delta vs PR 549: −0.00943 nats. Welch's t = −10.26, df ≈ 3.78, p < 0.01.
Throughput recovery
Our PR 1019 (now merged as SOTA) traded throughput for quality — full Hessian GPTQ and BigramHash 3072×112 added 3.3ms/step. Fused MLP kernels recover that regression. Mechanistic analysis of that model identified MLP as the capacity bottleneck, leading to MLP 3.5× (enabled by mixed quantization + Brotli headroom).
Changes vs our PR 1019
1. Fused MLP Kernels: Triton TMA Forward + CUTLASS EVT Backward
Forward (Triton TMA): Fuses
F.linear(x, up_w) → LeakyReLU(0.5) → squareinto a single kernel. The 302MB intermediate never touches HBM.Backward (CUTLASS EVT): Fuses
(go @ down_w.T) * act_gradinto a single CUTLASS 3.x kernel via Epilogue Visitor Tree. The elementwise multiply runs in the GEMM epilogue while tiles are still in registers — eliminating one 302MB write + read per layer.Key design insight — pre-computed activation gradient: We store the activation gradient in the forward pass instead of the pre-activation:
The identity
post = 0.5 · act_grad · preholds for both signs because:This eliminates all branching from the backward, reducing the CUTLASS EVT epilogue to a trivial 3-node tree:
Sm90EVT<multiplies, AccFetch, AuxLoad>. No conditionals in the kernel.CUTLASS EVT is a hard dependency — no silent fallback.
Kernel benchmarks + incremental deltas (2×H100)
Per-layer kernel timing:
CUTLASS vs Triton: +0.032 ms/layer, +0.347 ms/step kernel-level.
End-to-end training (35 steps, seed=42):
Kernel-level 0.347ms translates to 0.43ms end-to-end (cache/scheduling interactions).
8×H100: 86.7ms (our PR 1019, unfused) → 83.5ms (this PR) = −3.2ms/step (−3.7%).
Step-time profile — where all 313ms goes (2×H100, Nsight)
Why surgical fusion, not full-MLP autograd.Function: The 21.6% from torch.compile's cross-layer fusions (RMSNorm backward, residual adds, RoPE backward) only exists because these ops are visible to the compiler. Wrapping the full MLP backward in
autograd.Functionmakes it opaque to Inductor — all backward GEMMs plus cross-layer fusion run in eager mode, 2.7× slower net (identified in our PR 670). We fuse only forward and one backward GEMM+pointwise, preserving the compiler's scope.Top individual kernels:
Wall-clock breakdown: forward+backward compute ~94%, NCCL ~1.6%, CPU overhead ~4.1%.
2. Brotli-11 Compression (replaces LZMA-9)
−581 KB (−5.9%) vs LZMA-9. Independently discovered; PR 1089 (mikeapedia) also uses Brotli.
3. Memmap Multi-Shard Data Pipeline + GPU Prefetch
Coprime-stride sampling, daemon thread, CUDA stream prefetch. Credit: DeepReinforce (PR 726).
4. MLP 3.5× (1536 → 1792 hidden dim)
Motivated by mechanistic analysis: SVD analysis of our PR 1019 model showed MLP at 94.4% rank utilization (fully packed) while attention Q sat at 72.6% (spare capacity). The model was parameter-starved in MLP, not attention — so we made MLP wider.
Increases hidden dim from 3.0 × 512 = 1536 to 3.5 × 512 = 1792. Model goes from 27.07M to 29.95M params (+2.88M). At uniform int6, the 29.95M model compresses to 17.36 MB — 1.36 MB over the 16 MB limit. This is what makes mixed quantization (change 5) necessary.
Impact: −0.003 BPB from capacity, +13ms/step on 2×H100 (bigger GEMMs). Credit: PR 185 (dttdrv), PR 344 (aryanbhosale).
5. Mixed int5/int6 Quantization (Hessian-based)
Motivated by mechanistic analysis: Per-matrix quantization sensitivity showed MLP accounts for 80% of int6 quantization damage (MLP_down: +0.0039 BPB total, all Q matrices: +0.0003 BPB total — a 13× gap). Giving more bits to MLP is the optimal allocation.
Instead of uniform int6 for all layers, use int5 as default and promote the top 10 most sensitive layers to int6 based on Hessian trace ranking. Sensitivity = trace(H) where H = X^TX collected during GPTQ calibration. MLP projection layers in early blocks are most sensitive — they get int6; the remaining 56 layers get int5.
Uniform int5 loses ~0.019 BPB (catastrophic). Targeted Hessian-based allocation keeps quality loss under ~0.003 BPB while saving ~1.5 MB — exactly the headroom MLP 3.5× needs to fit under 16 MB. The wider MLP also made the model 3.6× less sensitive to quantization overall — information distributed across more dimensions means no single weight is load-bearing.
Credit: mixed quant concept PR 76 (Will DePue), gradient-guided PR 332 (saml212), Hessian-based PR 1089 (mikeapedia).
6. LR Floor (0.05)
During warmdown, learning rate normally decays to 0. With
lr_floor=0.05, it stops at 5% of peak instead. Prevents the optimizer from stalling, which helps with quantization-sensitive weight distributions still being refined at end of training.Impact: ~0.001 BPB. Credit: PR 130 (mohosy).
7. SLOT Eval (Selective Logit Offset Tuning)
Eval-time adaptation: tunes a 512-dim delta vector at the last hidden layer using AdamW (lr=0.003, 5 steps) on validation tokens already scored. −0.0037 BPB on top of the full stack (1.1123 → 1.1086). Adds 54s eval time (148s total vs 94s without), well within the 10-minute eval budget.
This is what pushes us past the 0.005-nat threshold vs our merged PR 1019.
Credit: PR 609 (saml212).
Negative Results
Architecture
Calibration legality: AR self-generated (64 seqs × 2048 tokens, temp=0.8). No val data, no train data accessed during quantization. Same method as our PR 1019.
Setup & Reproduction
🤖 Generated with Claude Code