# Record: 11L XSA-all + Full GPTQ (Budget-Legal) + Parallel Muon + Selective Pruning

**val_bpb: 1.1178** (3-seed mean, std 0.0001) | **15.95 MB** max artifact | 8xH100 SXM, ~596s total compute

## Update (2026-03-26)

This PR was updated to fix a GPTQ budget violation identified in [issue #677](https://github.com/openai/parameter-golf/issues/677). The previous version trained for the full 600s, then ran GPTQ calibration for ~46s on top, exceeding the 600s artifact-production budget. The fix reserves 14s from the training budget for GPTQ calibration (`gptq_reserve_ms = 14000.0`), ensuring total compute (training ~586s + GPTQ ~10s = ~596s) stays within the 600s limit. All results below use the fixed code with fresh 3-seed runs.
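The reservation logic boils down to subtracting the GPTQ allowance from the wallclock cap before checking the training stop condition. A minimal sketch (the helper names are mine, not from the PR; `gptq_reserve_ms` matches the config value described below):

```python
import time

BUDGET_MS = 600_000.0        # 600s artifact-production cap
GPTQ_RESERVE_MS = 14_000.0   # gptq_reserve_ms: time held back for calibration

def remaining_train_ms(start: float) -> float:
    """Training wallclock left after reserving GPTQ calibration time."""
    elapsed_ms = (time.perf_counter() - start) * 1000.0
    return (BUDGET_MS - GPTQ_RESERVE_MS) - elapsed_ms

def should_stop_training(start: float) -> bool:
    # Training stops ~14s early so the ~10s GPTQ pass still fits in 600s.
    return remaining_train_ms(start) <= 0.0
```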

## Results (3 seeds, 8xH100 SXM)

| Seed | Steps | ms/step | Sliding BPB (stride 64) | val_loss | Artifact | Train Time | GPTQ Time | Total |
|------|-------|---------|--------------------|----------|----------|------------|-----------|-------|
| 1337 | 6,674 | ~88 | **1.1177** | 1.8871 | 15,929,433 bytes | 586,128ms | 9,786ms | 595,915ms |
| 42 | 6,732 | ~87 | 1.1179 | 1.8875 | 15,949,353 bytes | 586,050ms | 9,792ms | 595,842ms |
| 7 | 6,731 | ~87 | 1.1179 | 1.8875 | 15,946,145 bytes | 586,066ms | 9,823ms | 595,889ms |

**Mean: 1.1178 | Std: 0.0001**

## Key Techniques

### XSA on All 11 Layers
Standard practice applies Exclusive Self-Attention to only the last 4 layers. Applying it to all 11 forces cross-position information mixing from layer 0 onward, improving representation quality. Zero new parameters, just a config change; the ablation shows -0.0016 BPB vs. XSA-last-4.
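Since this is a config-only change, the entire diff amounts to widening the active-layer list (a sketch; the field name is mine, but it matches the `active_layers:[0, ..., 10]` line in the log below):

```python
n_layers = 11

# Standard practice: XSA only on the last 4 layers.
xsa_last4 = list(range(n_layers - 4, n_layers))   # [7, 8, 9, 10]

# This record: XSA on every layer; same parameter count, config-only change.
xsa_all = list(range(n_layers))
```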

### Full Hessian GPTQ (Budget-Legal)
- 64-batch GPU Hessian calibration from training data
- Column-wise int6 quantization with Cholesky error compensation, block size 128, percdamp 0.01
- QAT STE aligned to export quantizer using row-maximum (amax) clipping with [-32, 31] range
- **Budget reservation:** `gptq_reserve_ms = 14000.0` — training stops ~14s early so GPTQ calibration fits within 600s
- Log verification: `gptq:budget_check train:586128ms + gptq:9786ms = 595915ms (budget:600000ms)`
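The column-wise pass with Cholesky error compensation follows the standard GPTQ recipe. A minimal sketch under the settings above (per-row amax scales, int6 range [-32, 31], percdamp 0.01); the 128-column block structure of the real kernel is omitted for brevity, and the function name is mine:

```python
import torch

def gptq_quantize_int6(W, H, percdamp=0.01):
    """Column-wise GPTQ with Cholesky error compensation (sketch).

    W: (rows, cols) weight. H: (cols, cols) Hessian ~ 2 * X @ X.T
    accumulated from calibration activations.
    """
    W = W.clone().float()
    cols = W.shape[1]
    scale = (W.abs().amax(dim=1) / 31.0).clamp_min(1e-8)  # per-row amax clipping
    H = H + percdamp * torch.diag(H).mean() * torch.eye(cols)
    # Upper Cholesky factor of H^-1 drives the error propagation.
    Hinv = torch.linalg.cholesky(
        torch.cholesky_inverse(torch.linalg.cholesky(H)), upper=True)
    Q = torch.zeros_like(W)
    for i in range(cols):
        w = W[:, i]
        q = torch.clamp(torch.round(w / scale), -32, 31)
        Q[:, i] = q
        err = (w - q * scale) / Hinv[i, i]
        # Fold this column's rounding error into the unquantized columns.
        W[:, i + 1:] -= err.unsqueeze(1) * Hinv[i, i + 1:].unsqueeze(0)
    return Q, scale
```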

### Parallel Muon Optimizer with Parameter Banking
- Weight matrices stored in contiguous parameter banks (qo_bank, kv_bank, mlp_up_bank, mlp_down_bank)
- 3-phase overlapped optimizer step: async reduce-scatter -> batched Newton-Schulz orthogonalization -> async all-gather
- Eliminates DDP double-communication overhead, achieving ~87ms/step (~6,700 steps in 586s)
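The orthogonalization phase is the standard Muon Newton-Schulz iteration with the reference quintic coefficients; a minimal batched sketch (the banking layout and the async reduce-scatter/all-gather phases are omitted):

```python
import torch

def newton_schulz_orthogonalize(G, steps=5, eps=1e-7):
    """Batched Newton-Schulz iteration that Muon applies to momentum
    matrices, driving singular values toward 1 (quintic variant)."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G.float()
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps)  # bound spectral norm
    transposed = X.shape[-2] > X.shape[-1]
    if transposed:
        X = X.mT          # iterate on the wide orientation
    for _ in range(steps):
        A = X @ X.mT
        X = a * X + (b * A + c * A @ A) @ X
    if transposed:
        X = X.mT
    return X
```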

### Selective Magnitude Pruning
Post-GPTQ, quantized values equal to +/-1 are sorted by their reconstruction error (proportional to scale^2) and zeroed least-impactful-first until the compressed artifact fits the size target; a binary search over the prune count finds the exact cut-off.
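A minimal single-tensor sketch of that search, assuming the size check is an LZMA compression of the packed weights (the real pipeline measures the full serialized artifact; function name is mine). The binary search relies on compressed size decreasing roughly monotonically in the prune count:

```python
import lzma
import torch

def selective_prune(Q, scale, target_bytes):
    """Return the smallest number of +/-1 values to zero (least-impactful
    first, impact = scale**2 of the row) so the blob fits target_bytes."""
    cand = (Q.abs() == 1).nonzero(as_tuple=False)      # +/-1 candidates
    impact = (scale.squeeze(1) ** 2)[cand[:, 0]]       # reconstruction error
    order = torch.argsort(impact)                      # least impactful first

    def compressed_size(k):
        P = Q.clone()
        idx = cand[order[:k]]
        P[idx[:, 0], idx[:, 1]] = 0                    # zeros compress well
        return len(lzma.compress(P.to(torch.int8).numpy().tobytes(), preset=6))

    lo, hi = 0, cand.shape[0]                          # binary search on k
    while lo < hi:
        mid = (lo + hi) // 2
        if compressed_size(mid) <= target_bytes:
            hi = mid
        else:
            lo = mid + 1
    return lo
```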

### LZMA Compression
LZMA preset 6 replaces zstd level 22, giving a better compression ratio on the int6-quantized weights.
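The Python standard library's `lzma` module covers this directly; a tiny round-trip sketch (the payload here is a stand-in for the packed int6 weight bytes):

```python
import lzma

# Stand-in payload; the real pipeline compresses the serialized int6 weights.
payload = bytes([i % 64 for i in range(1 << 20)])

blob = lzma.compress(payload, preset=6)   # preset 6: ratio/speed trade-off
assert lzma.decompress(blob) == payload   # lossless round-trip
print(f"{len(payload)} -> {len(blob)} bytes")
```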

## Architecture

- 11 transformer layers, dim=512, 8 heads, 4 KV heads (GQA)
- 3x MLP expansion (hidden=1536) with **LeakyReLU(0.5)^2** activation
- **XSA on all 11 layers** (Exclusive Self-Attention)
- Partial RoPE (16/64 dims) + NTK-aware scaling
- LN Scale Factor 1/sqrt(layer_idx+1)
- U-Net skip connections (5 encoder, 6 decoder)
- SmearGate temporal gating
- BigramHash (2048 buckets, 128-dim)
- Shared Value Embedding (dim=128, layers 9-10)
- FlashAttention 3 (Hopper native kernels)
- Orthogonal init, logit softcap 30, tied embeddings
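Two of the scalar tricks above are one-liners worth spelling out; a minimal sketch of the logit softcap (cap 30) and the per-layer LN scale factor (function names are mine):

```python
import torch

def softcap_logits(logits: torch.Tensor, cap: float = 30.0) -> torch.Tensor:
    """Logit softcap: smoothly bounds logits to (-cap, cap) via tanh."""
    return cap * torch.tanh(logits / cap)

def ln_scale(layer_idx: int) -> float:
    """LN scale factor 1/sqrt(layer_idx + 1), shrinking deeper layers."""
    return (layer_idx + 1) ** -0.5
```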

## Training

- Parallel Muon optimizer (matrices): lr=0.025, momentum=0.99, WD=0.04, 5 Newton-Schulz steps
- AdamW (embeddings): lr=0.035, (scalars): lr=0.025, WD=0.04
- Gradient clip: 0.3
- Batch: 786,432 tokens/step, seq_len=2048
- Warmdown: 3,500 iters (wallclock-based)
- EMA (decay=0.997) + Tight SWA (every 50 steps, scale<0.2)
- Late QAT: STE int6 fake-quantization when LR scale<0.15
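The late-QAT fake quantizer can be sketched as a straight-through estimator aligned to the export quantizer described above (per-row amax clipping, int6 range [-32, 31]); the function name is mine:

```python
import torch

def ste_fake_quant_int6(w: torch.Tensor) -> torch.Tensor:
    """STE int6 fake-quantization: forward sees quantized weights,
    the gradient passes straight through to the fp weights."""
    scale = (w.abs().amax(dim=1, keepdim=True) / 31.0).clamp_min(1e-8)
    q = torch.clamp(torch.round(w / scale), -32, 31)   # int6 grid
    wq = q * scale                                     # dequantized forward value
    # Straight-through estimator: value of wq, gradient of w.
    return w + (wq - w).detach()
```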

## Quantization & Compression

- Full GPTQ with 64-batch GPU Hessian calibration, block_size=128, percdamp=0.01
- Int6 per-row with amax clipping, range [-32, 31]
- Selective magnitude pruning (target 15.9MB)
- Small tensors + tok_emb.weight in fp16
- LZMA preset 6 compression

## Compliance

- [x] 3 seeds, all total compute <= 600s on 8xH100 SXM (verified: max 595,915ms)
- [x] GPTQ calibration WITHIN training budget (14s reserved, verified via `gptq:budget_check`)
- [x] All artifacts <= 16,000,000 bytes (max: 15,949,353)
- [x] No TTT on validation data
- [x] No training data accessed during evaluation
- [x] No network calls during evaluation
- [x] Sliding window eval stride=64, consistent across seeds (std=0.0001)
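The stride-64 sliding-window evaluation can be sketched as scoring only the last `stride` tokens of each window, so every scored position sees near-full left context. The sketch below (names are mine) returns bits per token; the record's bpb additionally normalizes by the tokenizer's bytes-per-token ratio:

```python
import math
import torch

@torch.no_grad()
def sliding_window_bits_per_token(model, tokens, seq_len=2048, stride=64):
    """Sliding-window eval: only the last `stride` targets of each
    window contribute to the loss (sketch)."""
    total_nll, total_tok = 0.0, 0
    for start in range(0, tokens.numel() - seq_len, stride):
        window = tokens[start : start + seq_len + 1]
        logits = model(window[:-1].unsqueeze(0))            # (1, seq_len, vocab)
        logp = torch.log_softmax(logits[0, -stride:], dim=-1)
        tgt = window[-stride:]
        total_nll += float(-logp[torch.arange(stride), tgt].sum())
        total_tok += stride
    return (total_nll / total_tok) / math.log(2)            # nats -> bits
```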

## Run Command

```bash
SEED=1337 TARGET_MB=15.9 torchrun --standalone --nproc_per_node=8 train_gpt.py
```
---
{
"author": "Raahil Shah",
"github_id": "raahilshah",
"name": "11L XSA-all + Full GPTQ (budget-legal) + Parallel Muon + LZMA + Selective Pruning",
"blurb": "XSA on all 11 layers, Hessian-aware GPTQ with 14s budget reservation (train 586s + GPTQ 10s = 596s total), amax-aligned QAT, Parallel Muon optimizer with parameter banking, LZMA compression, selective magnitude pruning. LeakyReLU(0.5)² activation, EMA(0.997), Tight SWA, VE128, Partial RoPE 16/64, LN Scale, BigramHash(2048), U-Net skips.",
"date": "2026-03-26T00:00:00Z",
"val_loss": 1.88736798,
"val_bpb": 1.11780859,
"pre_quant_val_loss": 1.9210,
"pre_quant_val_bpb": 1.1377,
"bytes_total": 15949353,
"seeds": {
"1337": {"val_bpb": 1.11765772, "val_loss": 1.88711325, "bytes_total": 15929433},
"42": {"val_bpb": 1.11789104, "val_loss": 1.88750721, "bytes_total": 15949353},
"7": {"val_bpb": 1.11787700, "val_loss": 1.88748353, "bytes_total": 15946145}
},
"mean_val_bpb": 1.11780859,
"std_val_bpb": 0.00010683
}
---
logs/88f192b5-5eef-4f0f-99ee-b94b7e4b0298.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:26993756
mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0
XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
world_size:8 grad_accum_steps:1
sdp_backends:cudnn=False flash=True mem_efficient=False math=False
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025
train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:1337
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9309 val_bpb:4.1049 train_time:0ms step_avg:0.02ms
step:1/20000 train_loss:6.9317 train_time:5842ms step_avg:5842.44ms
step:2/20000 train_loss:8.6935 train_time:5886ms step_avg:2943.25ms
step:3/20000 train_loss:7.5958 train_time:5971ms step_avg:1990.36ms
step:4/20000 train_loss:7.3348 train_time:6057ms step_avg:1514.13ms
step:5/20000 train_loss:7.2640 train_time:6141ms step_avg:1228.14ms
step:6/20000 train_loss:7.1178 train_time:6224ms step_avg:1037.40ms
step:7/20000 train_loss:6.9214 train_time:6308ms step_avg:901.16ms
step:8/20000 train_loss:6.8019 train_time:6392ms step_avg:798.97ms
step:9/20000 train_loss:6.4159 train_time:6476ms step_avg:719.52ms
step:10/20000 train_loss:6.0449 train_time:6560ms step_avg:656.01ms
step:500/20000 train_loss:2.3890 train_time:49041ms step_avg:98.08ms
step:1000/20000 train_loss:2.2584 train_time:92205ms step_avg:92.20ms
step:1500/20000 train_loss:2.2044 train_time:135515ms step_avg:90.34ms
step:2000/20000 train_loss:2.0479 train_time:178935ms step_avg:89.47ms
step:2500/20000 train_loss:2.1542 train_time:222374ms step_avg:88.95ms
step:3000/20000 train_loss:2.1427 train_time:265854ms step_avg:88.62ms
step:3500/20000 train_loss:2.1571 train_time:309328ms step_avg:88.38ms
step:4000/20000 train_loss:1.9474 train_time:352778ms step_avg:88.19ms
step:4000/20000 val_loss:2.0405 val_bpb:1.2085 train_time:352821ms step_avg:88.21ms
step:4500/20000 train_loss:2.0987 train_time:396236ms step_avg:88.05ms
step:5000/20000 train_loss:2.0803 train_time:439652ms step_avg:87.93ms
step:5500/20000 train_loss:1.9943 train_time:483113ms step_avg:87.84ms
swa:start step:6000
step:6000/20000 train_loss:1.9186 train_time:526547ms step_avg:87.76ms
late_qat:enabled step:6150 scale:0.1499
step:6500/20000 train_loss:2.0554 train_time:570717ms step_avg:87.80ms
step:6674/20000 val_loss:1.9226 val_bpb:1.1387 train_time:586128ms step_avg:87.82ms
stopping_early: wallclock_cap train_time:586128ms step:6674/20000
peak memory allocated: 22861 MiB reserved: 23032 MiB
ema:applying EMA weights
DIAGNOSTIC post_ema val_loss:1.9210 val_bpb:1.1377 eval_time:2075ms
Serialized model: 106158518 bytes
Code size: 113761 bytes
gptq:building calibration model...
gptq:calibrating with 64 training batches...
gptq:calibrated 68 layers in 9.8s
gptq:budget_check train:586128ms + gptq:9786ms = 595915ms (budget:600000ms)
gptq_quantize: 66 GPTQ layers, 0 naive layers
gptq_quantize: 66 GPTQ layers, 0 naive layers
gptq_quantize: 66 GPTQ layers, 0 naive layers
gptq_quantize: 66 GPTQ layers, 0 naive layers
gptq_quantize: 66 GPTQ layers, 0 naive layers
gptq_quantize: 66 GPTQ layers, 0 naive layers
gptq_quantize: 66 GPTQ layers, 0 naive layers
gptq_quantize: 66 GPTQ layers, 0 naive layers
selective_prune: 4110461 ±1 candidates, unpruned=15.19MB target=15.9MB
Serialized model int6+lzma: 15815672 bytes
Total submission size int6+lzma: 15929433 bytes
Total submission size int8+zlib: 15929433 bytes
final_int6_roundtrip val_loss:1.9268 val_bpb:1.1411 eval_time:49161ms
final_int6_roundtrip_exact val_loss:1.92677762 val_bpb:1.14114624
final_int6_sliding_window val_loss:1.8871 val_bpb:1.1177 stride:64 eval_time:102894ms
final_int6_sliding_window_exact val_loss:1.88711325 val_bpb:1.11765772
final_int8_zlib_roundtrip_exact val_loss:1.88711325 val_bpb:1.11765772