Non-Record: 11L Parallel Muon + LeakyReLU² MLP3x + Legal TTT (val_bpb 1.1253) #635
Closed
aryanbhosale wants to merge 15 commits into openai:main from
Conversation
…bpb=1.1330, 8xH100)
- 31.4M params, 11L 512d 8H/4KV MLP3.5x(1792)
- LeakyReLU(0.5)^2, SmearGate, BigramHash(10240), TrigramHash(4096)
- Value Residual, Gated Attention, XSA-all-11, Partial RoPE(16/64)
- Muon lr=0.03, EMA(0.997), Late QAT, int6 GPTQ-lite + zstd-22
- 3-seed: 1.1334/1.1322/1.1334, mean=1.1330, std=0.0007
- Developed via 30-experiment autoresearch on 1xH100

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- P0: Full Hessian GPTQ with 256 calibration samples, column reorder, blocksize=128, Cholesky error compensation (expected -0.005 to -0.008 BPB)
- P1: LZMA preset 6 compression (replaces zstd-22)
- P2: Selective magnitude pruning (zero bottom 10% of ±1 values)
- P3: Drop TrigramHash default (USE_TRIGRAMHASH=0)
- P4: Muon lr=0.025 (tuned for 8xH100 step count)
- P5: TTT rewrite: sliding-window score-first, SGD momentum=0.9
- Fix: QAT clamp range [-31, 31] to match the symmetric int6 grid

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
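The QAT clamp fix is the key detail for a symmetric int6 grid: clamping to [-31, 31] rather than [-32, 31] keeps the grid symmetric around zero, so round-to-nearest at export time matches what QAT trained for. A minimal fake-quantization sketch (hypothetical helper name, not the PR's actual code):

```python
import torch

def fake_quant_int6(w: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Symmetric int6 fake-quantization for QAT: grid is {-31, ..., 31}.

    Clamping to [-31, 31] (not [-32, 31]) keeps the grid symmetric,
    matching the round-to-nearest quantizer used at inference.
    """
    scale = w.abs().max().clamp_min(eps) / 31.0
    q = torch.clamp(torch.round(w / scale), -31, 31)
    # Straight-through estimator: forward uses the quantized values,
    # backward passes gradients through unchanged.
    return w + (q * scale - w).detach()
```

With an asymmetric [-32, 31] clamp, the most negative bin has no positive mirror, which subtly biases the trained weight distribution.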
The float64 CPU matmuls were extremely slow (on the order of hours). Using float32 on the GPU completes the same computation in seconds; Hessians are moved to CPU after collection. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
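Concretely, GPTQ Hessian collection is just accumulating 2·XᵀX over calibration batches, so the fix amounts to doing that accumulation in float32 on the GPU and parking the result on CPU afterwards. A minimal sketch (hypothetical helper name):

```python
import torch

def accumulate_hessian(H: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Accumulate the GPTQ Hessian H += 2 * X^T X for one calibration batch.

    x: [n_samples, in_features] layer inputs, cast to float32 on H's
    device. float64 CPU matmuls here took hours; float32 on GPU runs
    in seconds at negligible accuracy cost for this statistic.
    """
    x = x.to(H.device, dtype=torch.float32)
    H.addmm_(x.t(), x, alpha=2.0)  # in-place: H = H + 2 * x^T x
    return H

# After all calibration batches, park the Hessian on CPU to free VRAM:
# hessians[name] = H.cpu()
```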
attn_gate (nn.Linear) requires autocast since model mixes bf16/fp32. Hooks already convert inputs to float32 for Hessian accumulation. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Use cholesky() + cholesky_inverse() instead of linalg.inv() for stability
- Handle dead columns (zero Hessian diagonal) by zeroing the corresponding weights
- Match the standard IST-DASLab/gptq reference implementation

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
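A sketch of that recipe, assuming the usual GPTQ damping step (the damping ratio and function name are illustrative, not taken from the PR):

```python
import torch

def stable_hessian_inverse(H: torch.Tensor, damp_ratio: float = 0.01):
    """Invert a GPTQ Hessian via Cholesky, handling dead columns.

    Columns whose diagonal is zero (inputs that never activated) get a
    unit diagonal so the factorization succeeds; the caller zeroes the
    corresponding weights. Cholesky + cholesky_inverse is both faster
    and better conditioned than linalg.inv for SPD matrices.
    """
    dead = torch.diag(H) == 0
    H = H.clone()
    H[dead, dead] = 1.0
    # Dampen with a small multiple of the mean diagonal for conditioning.
    damp = damp_ratio * torch.diag(H).mean()
    H += damp * torch.eye(H.shape[0], device=H.device)
    L = torch.linalg.cholesky(H)
    Hinv = torch.cholesky_inverse(L)
    return Hinv, dead
```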
Full Hessian GPTQ degraded quantized BPB by 0.16 (1.25 -> 1.42) because QAT trains weights for simple round-to-nearest, not GPTQ error compensation. GPTQ-lite (per-row clip search) matches QAT and gives ~0.01 degradation. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
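The PR does not include its GPTQ-lite code here; the following is one plausible reading of "per-row clip search", sketched under the assumption of a symmetric int6 grid (function name and grid size are illustrative):

```python
import torch

def quantize_row_clip_search(w: torch.Tensor, n_grid: int = 20, max_q: int = 31):
    """GPTQ-lite: per-row clipping-scale search for symmetric int6.

    Instead of Hessian-based error compensation (which fights weights
    trained by QAT for plain round-to-nearest), try a grid of clipping
    fractions per output row and keep the scale with the lowest
    round-trip MSE. w: [out_features, in_features].
    """
    best_err = torch.full((w.shape[0],), float("inf"))
    best_q = torch.zeros_like(w)
    best_scale = torch.zeros(w.shape[0])
    absmax = w.abs().amax(dim=1)
    for i in range(1, n_grid + 1):
        # Candidate clip at fraction i/n_grid of the row's absmax.
        scale = (absmax * i / n_grid).clamp_min(1e-8) / max_q
        q = torch.clamp(torch.round(w / scale[:, None]), -max_q, max_q)
        err = ((q * scale[:, None] - w) ** 2).sum(dim=1)
        better = err < best_err
        best_err = torch.where(better, err, best_err)
        best_q[better] = q[better]
        best_scale = torch.where(better, scale, best_scale)
    return best_q.to(torch.int8), best_scale
```

Because the search only picks a scale and still rounds to nearest, it stays consistent with what QAT optimized for, which matches the ~0.01 BPB degradation reported above.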
…adation Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
LZMA preset 6 compresses worse than zstd-22 for int6 quantized tensors. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Force-pushed from 6c76d99 to 30e109f.
- BIGRAM_VOCAB_SIZE 10240 → 2048 (saves ~1M params, matches community SOTA)
- Magnitude pruning at 3% (not 10%), matching PR openai#634's validated approach
- Remove unused collect_hessians/gptq_quantize_weight (dead code, ~155 lines)
- Clean up the quantize_state_dict_int8 signature

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Even 3% pruning causes 0.18 BPB degradation, because QAT optimizes weights for exact round-to-nearest. PR openai#634 applies pruning without QAT, a different regime. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
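For reference, the operation that was tried and reverted is plain magnitude pruning over the quantized tensor. A minimal sketch (hypothetical helper name):

```python
import torch

def magnitude_prune(q: torch.Tensor, frac: float = 0.03) -> torch.Tensor:
    """Zero the smallest-magnitude `frac` of nonzero entries.

    This is the step the run above reverted: under QAT, even 3% of
    entries pruned cost 0.18 BPB, because the weights were trained
    for exact round-to-nearest, not for missing entries.
    """
    flat = q.abs().flatten().float()
    nonzero = flat[flat > 0]
    if nonzero.numel() == 0:
        return q
    k = int(frac * nonzero.numel())
    if k == 0:
        return q
    # k-th smallest nonzero magnitude becomes the pruning threshold.
    thresh = nonzero.kthvalue(k).values
    return torch.where(q.abs().float() <= thresh, torch.zeros_like(q), q)
```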
Major rewrite to use the Parallel Muon architecture:
- 4 contiguous 3D parameter banks (qo_bank, kv_bank, mlp_up_bank, mlp_down_bank)
- Batched Newton-Schulz via torch.bmm for all bank params
- 3-phase async optimizer: reduce-scatter → Adam on non-bank → NS5 + all-gather
- No DDP: manual gradient sync for non-bank params
- Bank-aware Block/Attention/MLP take weight tensors as forward args
- torch.compile with fullgraph=True (no DDP wrapper)
- _unbank/_rebank state_dict for quantization compatibility

Expected ~7000 steps on 8xH100 (vs ~3800 before), ~84ms/step

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
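The batched Newton-Schulz step is the heart of Parallel Muon: because each bank is a contiguous [n_params, rows, cols] tensor, the whole bank is orthogonalized with a handful of torch.bmm kernels instead of one small matmul per layer. A sketch using the quintic coefficients popularized by Muon (3.4445, -4.7750, 2.0315); exact constants and step count in the PR may differ:

```python
import torch

def batched_newton_schulz5(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately orthogonalize a bank of gradients via NS iteration.

    G: [n_params, rows, cols], one 3D parameter bank. All matmuls are
    torch.bmm over the bank dimension, so every layer's gradient is
    processed in the same batched kernels.
    """
    a, b, c = 3.4445, -4.7750, 2.0315  # Muon's quintic coefficients
    X = G / (G.norm(dim=(1, 2), keepdim=True) + 1e-7)
    # Work with the short side first so A = X X^T stays small.
    transposed = X.shape[1] > X.shape[2]
    if transposed:
        X = X.transpose(1, 2)
    for _ in range(steps):
        A = torch.bmm(X, X.transpose(1, 2))
        B = b * A + c * torch.bmm(A, A)
        X = a * X + torch.bmm(B, X)
    return X.transpose(1, 2) if transposed else X
```

After a few iterations the singular values of each slice are pushed toward 1, which is what Muon uses in place of the raw gradient.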
FA3 operates on [B,T,H,D] natively (no transposes needed). Falls back to SDPA when flash_attn_interface is not installed. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
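A minimal sketch of that dispatch, assuming q, k, v arrive as [B, T, H, D] (the defensive tuple handling is mine, since FA3 builds differ in what flash_attn_func returns):

```python
import torch
import torch.nn.functional as F

try:
    # Hopper-only Flash Attention 3: consumes [B, T, H, D] directly.
    from flash_attn_interface import flash_attn_func
    HAVE_FA3 = True
except ImportError:
    HAVE_FA3 = False

def attention(q, k, v, causal: bool = True):
    """q, k, v: [B, T, H, D]. Uses FA3 when installed, else SDPA.

    SDPA expects [B, H, T, D], so the fallback transposes on the way
    in and back out; FA3 needs no transposes at all.
    """
    if HAVE_FA3:
        out = flash_attn_func(q, k, v, causal=causal)
        # Some FA3 builds return (out, lse); keep just the output.
        return out[0] if isinstance(out, tuple) else out
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    o = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    return o.transpose(1, 2)
```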
Profiling showed 178ms/step vs the 84ms/step reference. Root causes and fixes:
- MLP 3.5x → 3.0x: -17% FLOPs per step
- XSA all 11 → last 4: saves compute on 7 layers
- EMA on GPU: remove the .cpu() transfer that stalled every step
- Bigram 2048 → 1536: matches the reference exactly

These are the exact settings from merged SOTA PR openai#549 (1.1194 BPB, 83ms/step).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
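The EMA fix is worth spelling out: calling .cpu() on the shadow weights forces a device sync every step. Keeping the shadow copy on the GPU turns the update into a few fused kernels. A sketch using the _foreach batched ops (helper name is mine):

```python
import torch

@torch.no_grad()
def ema_update(ema_params, model_params, decay: float = 0.997):
    """In-place EMA update kept entirely on the training device.

    ema_params and model_params are parallel lists of tensors on the
    same device. The _foreach ops apply ema = decay * ema +
    (1 - decay) * param across all tensors in batched kernels, with
    no per-step host/device transfer.
    """
    torch._foreach_mul_(ema_params, decay)
    torch._foreach_add_(ema_params, model_params, alpha=1.0 - decay)
```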
…253, 3-seed)

3-seed mean val_bpb: 1.1253 (std 0.0002) on 8xH100 SXM, 600s training.

Architecture: 11L 512d 8H/4KV, Parallel Muon with parameter banking, LeakyReLU(0.5)² MLP 3x, SmearGate, BigramHash(1536), Value Residual, Gated Attention, XSA4, Partial RoPE 16/64, EMA(0.997)+SWA, Late QAT, GPTQ-lite int6 + zstd-22, legal score-first TTT. ~90ms/step, ~6700 steps per seed. Flash Attention 3.
Record: 11L Parallel Muon + LeakyReLU² MLP3x + Legal Score-First TTT
val_bpb = 1.1253 (3-seed mean, std 0.0002) | ~15 MB | 8×H100 SXM
3-Seed Results (8×H100 80GB SXM, PyTorch 2.9.1+cu128)
Key Techniques
- Parallel Muon: batched Newton-Schulz via torch.bmm, 3-phase async reduce-scatter/all-gather. ~90ms/step, ~6700 steps.
- Legal TTT: each sliding window is scored first under inference_mode, then SGD(lr=0.002, momentum=0.9) for 3 epochs, all blocks unfrozen, cosine LR decay.
- torch.compile(fullgraph=True), no DDP.

Architecture (29.8M params)
11L, 512d, 8H/4KV (GQA), MLP 3×, SmearGate, BigramHash(1536), Value Residual, Gated Attention, XSA4, Partial RoPE(16/64), U-Net skips, OrthoInit, tied embeddings, logit softcap 30.0.
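The "legal" part of the score-first TTT above is that each window's loss is recorded under inference_mode before the model takes any SGD step on that window, so no score ever sees its own tokens during training. A minimal sketch of that loop (helper name is mine; the PR's version also adds cosine LR decay):

```python
import torch

def score_first_ttt(model, windows, lr: float = 0.002,
                    momentum: float = 0.9, epochs: int = 3):
    """Legal test-time training: score each window BEFORE adapting to it.

    `windows` yields (inputs, targets) pairs; `model(x, y)` is assumed
    to return a scalar loss. The recorded losses come from the
    not-yet-adapted weights, which is what makes the TTT legal;
    afterwards the model trains on the same window with SGD momentum.
    """
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    losses = []
    for x, y in windows:
        with torch.inference_mode():
            losses.append(model(x, y).item())  # scored before any update
        for _ in range(epochs):
            opt.zero_grad(set_to_none=True)
            model(x, y).backward()
            opt.step()
    return losses
```

Earlier windows still benefit later ones (the sliding-window part), but causality is preserved: adaptation only ever flows forward.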
Credits