
Non-Record: 11L Parallel Muon + LeakyReLU² MLP3x + Legal TTT (val_bpb 1.1253)#635

Closed
aryanbhosale wants to merge 15 commits into openai:main from aryanbhosale:submission/sota-11l-mlp35x

Conversation


aryanbhosale commented Mar 24, 2026

Record: 11L Parallel Muon + LeakyReLU² MLP3x + Legal Score-First TTT

val_bpb = 1.1253 (3-seed mean, std 0.0002) | ~15 MB | 8×H100 SXM

3-Seed Results (8×H100 80GB SXM, PyTorch 2.9.1+cu128)

| Seed | step_avg | steps | EMA bpb | Quantized bpb | TTT bpb |
|------|----------|-------|---------|---------------|---------|
| 1337 | 91.5 ms | 6,556 | 1.1194 | 1.1291 | 1.1255 |
| 42   | 89.2 ms | 6,726 | 1.1195 | 1.1278 | 1.1253 |
| 2024 | 89.3 ms | 6,722 | 1.1193 | 1.1280 | 1.1251 |
| Mean | 90.0 ms | 6,668 | 1.1194 | 1.1283 | 1.1253 |

Key Techniques

  1. Parallel Muon with parameter banking — 4 contiguous 3D banks, batched Newton-Schulz via torch.bmm, 3-phase async reduce-scatter/all-gather. ~90ms/step, ~6700 steps.
  2. LeakyReLU(0.5)² MLP 3× — preserves negative gradient flow.
  3. Legal Score-First TTT (recipe from PRs #461 and #549) — score each 32K-token chunk with sliding windows under inference_mode, then SGD(lr=0.002, momentum=0.9) for 3 epochs, all blocks unfrozen, cosine LR decay.
  4. EMA(0.997) + SWA — EMA selected as best pre-quant weights (1.1194 BPB, matches merged SOTA pre-TTT).
  5. GPTQ-lite int6 + zstd-22 — per-row 5-percentile clip search, FP16 embedding passthrough.
  6. Flash Attention 3, torch.compile(fullgraph=True), no DDP.
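The "batched Newton-Schulz via torch.bmm" in item 1 can be sketched as follows. This is a minimal illustration, not the PR's implementation: it applies the quintic Newton-Schulz iteration used by the Muon optimizer (coefficients 3.4445, -4.7750, 2.0315) to a whole 3D bank of stacked matrices in one shot, so all per-layer orthogonalizations share a few large `bmm` calls instead of many small matmuls.

```python
import torch

def batched_newton_schulz(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """G: [n, rows, cols] bank of gradient matrices; returns the bank with each
    matrix approximately orthogonalized (singular values pushed toward 1)."""
    a, b, c = 3.4445, -4.7750, 2.0315          # Muon's quintic NS coefficients
    X = G / (G.norm(dim=(1, 2), keepdim=True) + 1e-7)  # per-matrix Frobenius normalize
    transposed = X.size(1) > X.size(2)
    if transposed:                             # iterate on the wide orientation
        X = X.transpose(1, 2)
    for _ in range(steps):
        A = torch.bmm(X, X.transpose(1, 2))    # one batched matmul per term
        B = b * A + c * torch.bmm(A, A)
        X = a * X + torch.bmm(B, X)
    return X.transpose(1, 2) if transposed else X

# e.g. 4 MLP up-projections (512 -> 1536, matching the 3x expansion) banked together
bank = torch.randn(4, 512, 1536)
ortho = batched_newton_schulz(bank)
```

Banking only works because the matrices in each bank share a shape, which is why the PR groups parameters into qo/kv/mlp_up/mlp_down banks.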

Architecture (29.8M params)

11L, 512d, 8H/4KV (GQA), MLP 3×, SmearGate, BigramHash(1536), Value Residual, Gated Attention, XSA4, Partial RoPE(16/64), U-Net skips, OrthoInit, tied embeddings, logit softcap 30.0.
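A minimal sketch of the LeakyReLU(0.5)² MLP from the list above (class name and toy shapes are mine): squaring a leaky ReLU keeps a nonzero gradient through negative pre-activations, unlike ReLU², since d/dx (0.5x)² = 0.5x for x < 0.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeakyReLU2MLP(nn.Module):
    """Hypothetical 3x-expansion MLP block with a squared LeakyReLU(0.5)."""
    def __init__(self, dim: int = 512, expand: int = 3):
        super().__init__()
        self.up = nn.Linear(dim, expand * dim, bias=False)
        self.down = nn.Linear(expand * dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = F.leaky_relu(self.up(x), negative_slope=0.5)
        return self.down(h * h)   # LeakyReLU(0.5)^2: gradient survives x < 0

x = torch.randn(2, 8, 512)
y = LeakyReLU2MLP()(x)
```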

Credits

arbyte77 and others added 9 commits March 25, 2026 16:05
…bpb=1.1330, 8xH100)

- 31.4M params, 11L 512d 8H/4KV MLP3.5x(1792)
- LeakyReLU(0.5)^2, SmearGate, BigramHash(10240), TrigramHash(4096)
- Value Residual, Gated Attention, XSA-all-11, Partial RoPE(16/64)
- Muon lr=0.03, EMA(0.997), Late QAT, int6 GPTQ-lite + zstd-22
- 3-seed: 1.1334/1.1322/1.1334, mean=1.1330, std=0.0007
- Developed via 30-experiment autoresearch on 1xH100
- P0: Full Hessian GPTQ with 256 calibration samples, column reorder,
  blocksize=128, Cholesky error compensation (expected -0.005 to -0.008 BPB)
- P1: LZMA preset 6 compression (replaces zstd-22)
- P2: Selective magnitude pruning (zero bottom 10% of ±1 values)
- P3: Drop TrigramHash default (USE_TRIGRAMHASH=0)
- P4: Muon lr=0.025 (tuned for 8xH100 step count)
- P5: TTT rewrite: sliding-window score-first, SGD momentum=0.9
- Fix: QAT clamp range [-31,31] to match symmetric int6 grid
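The clamp fix above can be illustrated with a small fake-quantization helper (mine, not the PR's code): a symmetric int6 grid has integer levels in [-31, 31], so the QAT clamp must use the same range that round-to-nearest will see at export.

```python
import torch

def fake_quant_int6(w: torch.Tensor) -> torch.Tensor:
    """Round-to-nearest symmetric int6 fake quantization with per-row scales."""
    scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 31.0
    q = torch.clamp(torch.round(w / scale), -31, 31)  # symmetric int6 grid [-31, 31]
    return q * scale  # dequantize (QAT would pass gradients straight through)

w = torch.randn(4, 16)
wq = fake_quant_int6(w)
```

Clamping to [-32, 31] instead would train weights against a grid point the symmetric quantizer never emits, which is the mismatch the fix removes.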

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The float64 CPU matmuls were extremely slow (~hours). Using float32 on GPU
makes this complete in seconds. Hessians moved to CPU after collection.
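The change described can be sketched as below (names illustrative): accumulate the GPTQ-style Hessian H += 2·XᵀX in float32 on the accelerator, and only move the finished Hessian to host memory after collection.

```python
import torch

def accumulate_hessian(H: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """H: [d, d] running Hessian; x: [n, d] batch of layer inputs."""
    x = x.to(H.device, dtype=torch.float32)   # fp32 on-device, not fp64 on CPU
    return H + 2.0 * (x.T @ x)

device = "cuda" if torch.cuda.is_available() else "cpu"
d = 64
H = torch.zeros(d, d, dtype=torch.float32, device=device)
for _ in range(10):
    H = accumulate_hessian(H, torch.randn(32, d))
H_cpu = H.cpu()   # transfer once, after collection, as the commit describes
```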

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
attn_gate (nn.Linear) requires autocast since the model mixes bf16/fp32.
Hooks already convert inputs to float32 for Hessian accumulation.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Use cholesky() + cholesky_inverse() instead of linalg.inv() for stability
- Handle dead columns (zero Hessian diagonal) by zeroing weights
- Match standard IST-DASLab/gptq reference implementation

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Full Hessian GPTQ degraded quantized BPB by 0.16 (1.25 -> 1.42) because
QAT trains weights for simple round-to-nearest, not GPTQ error compensation.
GPTQ-lite (per-row clip search) matches QAT and gives ~0.01 degradation.
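"GPTQ-lite" as described, Hessian-free, might look like the sketch below: for each row, try a few clipping fractions of the row's max magnitude and keep the one minimizing round-to-nearest int6 reconstruction error. The 5-candidate grid here is my assumption, loosely mirroring the "5-percentile clip search" mentioned in the summary.

```python
import torch

def clip_search_int6(w: torch.Tensor) -> torch.Tensor:
    """w: [rows, cols]. Dequantized weights after a per-row clip search."""
    best_err = torch.full((w.size(0),), float("inf"))
    best_wq = torch.zeros_like(w)
    amax = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
    for frac in (1.0, 0.95, 0.9, 0.85, 0.8):          # candidate clip fractions
        scale = amax * frac / 31.0
        wq = torch.clamp(torch.round(w / scale), -31, 31) * scale
        err = (wq - w).pow(2).sum(dim=1)              # per-row reconstruction MSE
        better = err < best_err
        best_wq[better] = wq[better]                  # keep the winning rows
        best_err = torch.minimum(best_err, err)
    return best_wq

w = torch.randn(8, 128)
wq = clip_search_int6(w)
```

Because frac=1.0 (no clipping) is always a candidate, this can only match or beat plain round-to-nearest per row, which is why it composes well with QAT.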

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…adation

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
LZMA preset 6 compresses worse than zstd-22 for int6 quantized tensors.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
arbyte77 force-pushed the submission/sota-11l-mlp35x branch from 6c76d99 to 30e109f on March 25, 2026 10:35
arbyte77 and others added 6 commits March 25, 2026 16:15
- BIGRAM_VOCAB_SIZE 10240→2048 (saves ~1M params, matches community SOTA)
- Magnitude pruning at 3% (not 10%) matching PR openai#634's validated approach
- Remove unused collect_hessians/gptq_quantize_weight (dead code, ~155 lines)
- Clean up quantize_state_dict_int8 signature

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Even 3% pruning causes 0.18 BPB degradation because QAT optimizes weights
for round-to-nearest. PR openai#634 uses pruning without QAT — different regime.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Major rewrite to use Parallel Muon architecture:
- 4 contiguous 3D parameter banks (qo_bank, kv_bank, mlp_up_bank, mlp_down_bank)
- Batched Newton-Schulz via torch.bmm for all bank params
- 3-phase async optimizer: reduce-scatter → Adam on non-bank → NS5+all-gather
- No DDP — manual gradient sync for non-bank params
- Bank-aware Block/Attention/MLP take weight tensors as forward args
- torch.compile with fullgraph=True (no DDP wrapper)
- _unbank/_rebank state_dict for quantization compatibility

Expected ~7000 steps on 8xH100 (vs ~3800 before), ~84ms/step

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
FA3 operates on [B,T,H,D] natively (no transposes needed).
Falls back to SDPA when flash_attn_interface is not installed.
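A sketch of that fallback pattern (mine; in FA3 `flash_attn_interface.flash_attn_func` takes [B, T, H, D] tensors and, in the versions I know, returns an (out, lse) tuple — treat that as an assumption): try the FA3 import, else expand the GQA KV heads and use PyTorch SDPA, which expects [B, H, T, D].

```python
import torch
import torch.nn.functional as F

try:
    from flash_attn_interface import flash_attn_func  # FA3 Hopper kernels
    HAVE_FA3 = True
except ImportError:
    HAVE_FA3 = False

def attention(q, k, v, causal=True):
    """q, k, v: [B, T, H, D] (FA3's native layout); supports fewer KV heads (GQA)."""
    if HAVE_FA3:
        return flash_attn_func(q, k, v, causal=causal)[0]  # assumed (out, lse) return
    g = q.size(2) // k.size(2)                 # GQA: repeat KV heads to match Q heads
    k = k.repeat_interleave(g, dim=2)
    v = v.repeat_interleave(g, dim=2)
    out = F.scaled_dot_product_attention(      # SDPA wants [B, H, T, D]
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=causal)
    return out.transpose(1, 2)                 # back to [B, T, H, D]

q = torch.randn(2, 16, 8, 64)
k = torch.randn(2, 16, 4, 64)   # 4 KV heads for 8 query heads, as in the PR
v = torch.randn(2, 16, 4, 64)
out = attention(q, k, v)
```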

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Profiling showed 178ms/step vs reference 84ms/step. Root causes:
- MLP 3.5x→3.0x: -17% FLOPs per step
- XSA all 11→last 4: saves compute on 7 layers
- EMA on GPU: remove .cpu() transfer that stalled every step
- Bigram 2048→1536: matches reference exactly

These are the exact settings from merged SOTA PR openai#549 (1.1194 BPB, 83ms/step).
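The "EMA on GPU" fix amounts to keeping the shadow weights on the model's device and updating in place, so no host transfer stalls the step. A minimal sketch (class is mine; decay 0.997 is the PR's value):

```python
import torch

class EMA:
    """Exponential moving average of model weights, kept on the model's device."""
    def __init__(self, model: torch.nn.Module, decay: float = 0.997):
        self.decay = decay
        # shadow stays on-device — no .cpu() transfer per step
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}

    @torch.no_grad()
    def update(self, model: torch.nn.Module):
        for k, v in model.state_dict().items():
            # shadow <- decay * shadow + (1 - decay) * v, in place
            self.shadow[k].lerp_(v, 1.0 - self.decay)

model = torch.nn.Linear(4, 4)
ema = EMA(model)
with torch.no_grad():
    model.weight.add_(1.0)   # simulate one training update
ema.update(model)
```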

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…253, 3-seed)

3-seed mean val_bpb: 1.1253 (std 0.0002) on 8xH100 SXM, 600s training

Architecture: 11L 512d 8H/4KV, Parallel Muon with parameter banking,
LeakyReLU(0.5)² MLP 3x, SmearGate, BigramHash(1536), Value Residual,
Gated Attention, XSA4, Partial RoPE 16/64, EMA(0.997)+SWA, Late QAT,
GPTQ-lite int6+zstd-22, legal score-first TTT.

~90ms/step, ~6700 steps per seed. Flash Attention 3.
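The score-first TTT loop quoted above can be sketched on a toy model (the model, loss, and hardest-first ordering are my stand-ins; the real run scores 32K-token chunks with sliding attention windows): phase 1 scores every chunk cheaply under `torch.inference_mode`, phase 2 runs SGD(lr=0.002, momentum=0.9) for 3 epochs with cosine LR decay, all parameters unfrozen.

```python
import torch

model = torch.nn.Linear(16, 16)
chunks = [torch.randn(8, 16) for _ in range(4)]

def chunk_loss(m, x):
    return (m(x) - x).pow(2).mean()   # toy stand-in for the bpb objective

with torch.inference_mode():          # phase 1: scoring, no autograd graphs
    scores = [chunk_loss(model, c).item() for c in chunks]

epochs = 3
opt = torch.optim.SGD(model.parameters(), lr=0.002, momentum=0.9)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs * len(chunks))
order = sorted(range(len(chunks)), key=scores.__getitem__, reverse=True)  # hardest first (assumption)
for _ in range(epochs):               # phase 2: brief adaptation with cosine decay
    for i in order:
        opt.zero_grad()
        chunk_loss(model, chunks[i]).backward()
        opt.step()
        sched.step()
```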
aryanbhosale changed the title from "Non-record: 11L MLP3.5x LeakyReLU(0.5)^2 + Full SOTA Stack (mean val_bpb=1.1330, 8xH100 SXM)" to "Record: 11L Parallel Muon + LeakyReLU² MLP3x + Legal TTT (val_bpb 1.1253)" on Mar 25, 2026
arbyte77 deleted the submission/sota-11l-mlp35x branch on March 25, 2026 18:16
aryanbhosale changed the title from "Record: 11L Parallel Muon + LeakyReLU² MLP3x + Legal TTT (val_bpb 1.1253)" to "Non-Record: 11L Parallel Muon + LeakyReLU² MLP3x + Legal TTT (val_bpb 1.1253)" on Mar 25, 2026