
Record: Parallel Muon + Parameter Banking — 81.87ms/step, val_bpb 1.1247 (3-seed mean)#399

Open
abaybektursun wants to merge 4 commits into openai:main from abaybektursun:submission/parallel-muon-82ms

Conversation

abaybektursun (Contributor) commented Mar 22, 2026

Novel Contribution: Parameter Banking + Parallel Muon

This submission introduces Parameter Banking, a weight layout restructuring that enables batched optimizer operations, combined with an adapted Parallel Muon communication strategy. Together, these provide a 3.4% training throughput improvement that is architecture-agnostic and composes with any Muon-based training stack. The approach has since been adopted by subsequent competition submissions (e.g., PR #549).

Pure systems optimization — model architecture and hyperparameters are unchanged.

3-Seed Results (8×H100 80GB SXM, PyTorch 2.9.1+cu128, 600s)

| Seed | step_avg | steps | int6 sliding val_bpb | artifact |
| --- | --- | --- | --- | --- |
| 1337 | 81.86 ms | 7,331 | 1.1241 | 15,830,960 bytes |
| 42 | 81.88 ms | 7,328 | 1.1253 | 15,819,728 bytes |
| 2025 | 81.86 ms | 7,330 | 1.1247 | 15,796,052 bytes |
| Mean | 81.87 ms | 7,330 | 1.1247 (std 0.0006) | ~15.8 MB |

Technical Approach

1. Parameter Banking (novel)

We restructure 66 separate nn.Linear weight matrices into 4 contiguous 3D nn.Parameter tensors, grouped by shape:

  • qo_bank: (22, 512, 512) — Q + Out projections
  • kv_bank: (22, 256, 512) — K + V projections
  • mlp_up_bank: (11, 1536, 512) — MLP up
  • mlp_down_bank: (11, 512, 1536) — MLP down

Forward pass uses F.linear(x, bank[layer_idx]) — compiles identically to nn.Linear under torch.compile. Verified: banked forward+backward = 72.33ms vs baseline 72.59ms.
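As a minimal sketch of this layout (the name `qo_bank` and the init scale are illustrative, not the PR's actual identifiers), indexing a 3D bank recovers exactly what a per-layer `nn.Linear` would compute:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# One contiguous bank replaces 22 separate (512, 512) nn.Linear weights.
qo_bank = nn.Parameter(torch.randn(22, 512, 512) * 0.02)

def banked_linear(x: torch.Tensor, layer_idx: int) -> torch.Tensor:
    # qo_bank[layer_idx] is a (512, 512) view; F.linear computes x @ W.T,
    # matching nn.Linear(512, 512, bias=False) with weight W.
    return F.linear(x, qo_bank[layer_idx])

x = torch.randn(4, 512)
y = banked_linear(x, 0)  # same result as x @ qo_bank[0].T
```

Because each layer's weight is just a slice of the same tensor, `torch.compile` sees the same matmul it would for a standalone `nn.Linear`.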

The key benefit: Newton-Schulz orthogonalization (used by Muon) becomes a single torch.bmm over the batch dimension, replacing 66 sequential small GEMMs. This reduces optimizer time from 19.7ms to 1.3ms (15× faster).
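A sketch of the batched orthogonalization, assuming the quintic Newton-Schulz coefficients used by Muon in modded-nanogpt (the PR's exact constants and step count may differ); the leading bank dimension turns the per-layer GEMMs into one batched matmul per iteration:

```python
import torch

@torch.no_grad()
def newton_schulz_banked(bank: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Quintic-iteration coefficients (assumed, not copied from this PR);
    # drives singular values toward ~1 without computing an exact SVD.
    a, b, c = 3.4445, -4.7750, 2.0315
    X = bank / (bank.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    for _ in range(steps):
        A = X @ X.transpose(-2, -1)  # batched: one bmm covers every matrix in the bank
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    return X
```

Each `@` here dispatches a single batched GEMM over the bank's leading dimension, which is where the sequential-GEMM overhead disappears.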

2. Parallel Muon (adapted from arXiv:2511.07464)

Standard DDP is incompatible with parameter banking: each bank's gradient aggregates across all 11 layers and only becomes available at the end of the backward pass, destroying compute-communication overlap (a +4ms regression).

Our solution removes DDP for banked parameters and schedules communication explicitly:

  1. Launch async reduce_scatter for all banks (biggest first)
  2. all_reduce + Adam step on small replicated params (while bank RS is in-flight)
  3. Wait for RS, local batched NS on each GPU's shard, async all_gather

This follows the DDP-free communication pattern from modded-nanogpt, adapted to work with our banking structure.
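The three phases above can be sketched roughly as follows (a non-authoritative outline: `adam_step` and `ns_orthogonalize` are placeholder callables, and it assumes each bank's gradient splits evenly across ranks, whereas a real implementation would pad or flatten):

```python
import torch
import torch.distributed as dist

def biggest_first(banks):
    # Launch the largest reduce_scatter first so the longest transfer
    # overlaps with the most subsequent compute.
    return sorted(banks, key=lambda b: b.numel(), reverse=True)

def banked_optimizer_step(banks, small_params, adam_step, ns_orthogonalize,
                          world_size):
    handles = []
    for bank in biggest_first(banks):          # 1. async reduce_scatter per bank
        shard = torch.empty_like(bank.grad.chunk(world_size)[0])
        h = dist.reduce_scatter_tensor(shard, bank.grad, async_op=True)
        handles.append((bank, shard, h))
    for p in small_params:                     # 2. overlap: small replicated params
        dist.all_reduce(p.grad)
    adam_step(small_params)
    for bank, shard, h in handles:             # 3. wait, local batched NS, gather
        h.wait()
        dist.all_gather_into_tensor(bank.grad, ns_orthogonalize(shard))
```

The key property is that step 2's small-parameter work executes while the bank reduce_scatters are still in flight.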

Engineering notes

| Approach | Result | Lesson |
| --- | --- | --- |
| Non-surgery batching (keep 66 params, batch in optimizer) | 85.73 ms | Gather/scatter kernel overhead offsets the speedup |
| DDP with banks | 88.8 ms (+4 ms) | Bank grads only available at the end of backward |
| Polar Express (arXiv:2505.16932) | 82 ms, 16.2 MB | PE weights compress ~190 KB worse than NS |
| Parameter Banking + Parallel Muon | 81.87 ms, 15.8 MB | Architecture-agnostic, composable |

Compatibility analysis

| Base PR | Speed | Score | Finding |
| --- | --- | --- | --- |
| #315 (EMA only) | -3.4% | -0.0006 BPB | Extra steps improve the EMA monotonically |
| #374 (Tight SWA) | -3.5% | +0.001 | SWA averages warmdown weights; extra steps don't enter the window |
| #401 (EMA+SWA) | -2.8% | +0.0005 | Same SWA dilution |
| #398 (TTT) | -2.3% | +0.004 | A more-converged model has less room for TTT adaptation |

Key finding: The throughput advantage translates to quality gains exclusively for EMA-based models, where every additional step monotonically refines the exponential moving average.

Credits

🤖 Generated with Claude Code

Systems optimization built on PR openai#315 by @jfprincz (11L XSA4+EMA, 1.1248 bpb).
Same architecture, same hyperparameters, only optimizer changed.

82.14ms/step vs 84.76ms baseline = 7,306 steps vs 7,079 in 600s.
Pre-quant val_bpb 1.1421 (identical to baseline).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
abaybektursun and others added 2 commits March 22, 2026 00:13
…1.1248)

Unbank state dict before quantization so int6 per-row scales match baseline.
Rebank after dequantization for roundtrip eval.

Results: 82.13ms/step, 7,306 steps, int6 sliding window val_bpb 1.1238.
Artifact: 16.06MB (int6+zstd).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
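The unbank/rebank round-trip described in this commit might look like the following sketch (the key names `qo_bank` and `qo.{i}.weight` are hypothetical):

```python
import torch

def unbank(state_dict, bank_key="qo_bank", prefix="qo"):
    # Split a (L, out, in) bank into L per-layer 2D weights so that
    # per-row int6 quantization scales match the unbanked baseline.
    bank = state_dict.pop(bank_key)
    for i, w in enumerate(bank.unbind(0)):
        state_dict[f"{prefix}.{i}.weight"] = w.contiguous()
    return state_dict

def rebank(state_dict, n, bank_key="qo_bank", prefix="qo"):
    # Restack the per-layer weights after dequantization for roundtrip eval.
    ws = [state_dict.pop(f"{prefix}.{i}.weight") for i in range(n)]
    state_dict[bank_key] = torch.stack(ws, dim=0)
    return state_dict
```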
Seeds 42, 1337, 2025: mean 82.08ms/step, val_bpb 1.1239 (std 0.0001).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@abaybektursun changed the title from "Record: Parallel Muon + Parameter Banking — 82.14ms/step (3.1% faster than PR #315)" to "Record: Parallel Muon + Parameter Banking — 82.08ms/step (3.2% faster than PR #315)" Mar 22, 2026
@abaybektursun abaybektursun force-pushed the submission/parallel-muon-82ms branch from 5f4d141 to 4db0057 Compare March 22, 2026 15:24
@abaybektursun changed the title from "Record: Parallel Muon + Parameter Banking — 82.08ms/step (3.2% faster than PR #315)" to "Record: Parallel Muon + Parameter Banking + Polar Express — 82.14ms/step (3.1% faster than PR #315)" Mar 22, 2026
Replaced Polar Express with standard Newton-Schulz + switched to lzma compression.
3-seed results: 81.87ms/step mean, 1.1247 sliding bpb mean, all artifacts ~15.8MB.

Seed 1337: 7331 steps, 1.1241 bpb, 15,830,960 bytes
Seed 42:   7328 steps, 1.1253 bpb, 15,819,728 bytes
Seed 2025: 7330 steps, 1.1247 bpb, 15,796,052 bytes

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@abaybektursun changed the title from "Record: Parallel Muon + Parameter Banking + Polar Express — 82.14ms/step (3.1% faster than PR #315)" to "Record: Parallel Muon + Parameter Banking — 81.87ms/step, val_bpb 1.1247 (3-seed mean)" Mar 22, 2026
abaybektursun added a commit to abaybektursun/parameter-golf that referenced this pull request Mar 22, 2026
Legal score-first TTT (PR openai#461 recipe) applied to openai#414 stack with
Parameter Banking + Parallel Muon (first introduced in PR openai#399).

Pre-TTT: 1.1234, post-TTT: 1.1213 (-0.0021). TTT eval: 400s.
Artifact: 15.84 MB. Seed 1337, 8×H100 SXM, PyTorch 2.9.1+cu128.

Every token scored BEFORE model adapts (inference_mode enforced).
SGD+momentum(0.9), 3 epochs/32K chunk, freeze first 2 blocks.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
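A minimal sketch of the score-first protocol described above (model, loss, and chunking are placeholders, and the block-freezing step is omitted for brevity): each chunk is scored under `inference_mode` before the model is allowed to adapt to it.

```python
import torch

def score_first_ttt(model, chunks, loss_fn, epochs=3, lr=1e-3):
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    scores = []
    for x, y in chunks:
        with torch.inference_mode():   # score BEFORE the model adapts
            scores.append(loss_fn(model(x), y).item())
        for _ in range(epochs):        # then adapt on the same chunk
            opt.zero_grad()
            loss_fn(model(x), y).backward()
            opt.step()
    return scores
```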
abaybektursun added a commit to abaybektursun/parameter-golf that referenced this pull request Mar 23, 2026
…d mean)

Legal score-first TTT (PR openai#461 recipe) + BigramHash(3072) + freeze=0
on openai#414 stack with Parameter Banking + Parallel Muon (PR openai#399).

3-seed results (BIGRAM=3072, 3ep, freeze=0, SGD+mom=0.9):
  Seed 1337: 1.1204 bpb, 413s TTT, 15.98 MB
  Seed 42:   1.1216 bpb, 406s TTT, 15.99 MB
  Seed 2025: 1.1221 bpb, 405s TTT, 15.99 MB
  Mean:      1.1214 (std 0.0009)

All artifacts under 16MB. All eval times under 600s.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
abaybektursun added a commit to abaybektursun/parameter-golf that referenced this pull request Mar 23, 2026
…ed mean)

LeakyReLU(0.5)² activation (-0.003 vs relu²) + legal score-first TTT
(PR openai#461 recipe, 3ep SGD, all blocks unfrozen) + BigramHash(1536) on
openai#414 stack with Parameter Banking + Parallel Muon (PR openai#399).

3-seed results:
  Seed 42:   1.1200 bpb, 408s TTT, 15.88 MB
  Seed 2025: 1.1189 bpb, 408s TTT, 15.99 MB
  Seed 1337: pending (log will be added)
  Mean:      1.1195 (std 0.0008)

All artifacts under 16MB. All eval under 10 min.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
abaybektursun added a commit to abaybektursun/parameter-golf that referenced this pull request Mar 23, 2026
…ed mean)

LeakyReLU(0.5)² activation (-0.003 vs relu²) + legal score-first TTT
(PR openai#461 recipe, 3ep SGD, all blocks unfrozen) + BigramHash(1536) on
openai#414 stack with Parameter Banking + Parallel Muon (PR openai#399).

3-seed results:
  Seed 1337: 1.1192 bpb, 410s TTT, 15.98 MB
  Seed 42:   1.1200 bpb, 408s TTT, 15.88 MB
  Seed 2025: 1.1189 bpb, 408s TTT, 15.99 MB
  Mean:      1.1194 (std 0.0006)

All artifacts under 16MB. All eval under 10 min.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Mistobaan pushed a commit to Mistobaan/parameter-golf that referenced this pull request Mar 25, 2026
TimS-ml referenced this pull request in TimS-ml/parameter-golf-autoresearch Mar 26, 2026
nedcut pushed a commit to nedcut/parameter-golf that referenced this pull request Mar 26, 2026
nvemuri4649 pushed a commit to thanushpatlolla/parameter-golf that referenced this pull request Mar 27, 2026