Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Non-record: 11L XSA + SwiGLU + LoRA TTT (1xH100 PCIe)

**val_bpb: 1.1573** (LoRA TTT) | **15.02 MB** artifact | 1xH100 PCIe, ~80 min

## Key Techniques

1. **XSA (Cross-token Self-Attention)** on last 4 layers — removes self-value projection, forcing attention to contribute cross-position context. -0.005 BPB.
2. **SwiGLU 3x MLP** — gated activation `swish(gate(x)) * up(x)`. More parameter-efficient than ReLU². +0.004.
3. **SmearGate** — blends each token embedding with the previous token's embedding, giving bigram context at the embedding layer. Critical: +1.80 loss without it.
4. **U-Net skip connections** — encoder (L0-L4) saves skip outputs, decoder (L6-L10) adds them back. Ensures gradient flow through all 11 layers.
5. **Orthogonal initialization** — all weight matrices initialized orthogonally. Required for SmearGate to work.
6. **Muon optimizer with WD=0.04** — decoupled weight decay shrinks weights for better quantization + generalization.
7. **Stochastic Weight Averaging** — averages 15 checkpoints during warmdown for smoother quantized weights.
8. **Mixed quantization** — int5 (MLP) + int6 (attention) + int8 (embeddings) + zstd-22 compression. Fits in 15.02 MB.
9. **LoRA TTT** — per-document test-time training with rank-8 LoRA on Q and V projections. Score-then-train per 256-token chunk (legal: every token scored before being trained on). -0.034 BPB.

## Results

| Eval Method | val_loss | val_bpb | Delta |
|-------------|----------|---------|-------|
| Pre-quant (SWA) | 1.9800 | 1.1727 | — |
| Int8+zlib roundtrip | 1.9969 | 1.1826 | +0.010 |
| Mixed quant (int5/int6/int8+zstd) | 1.9913 | 1.1930 | +0.020 |
| **LoRA TTT (mixed quant)** | **1.9724** | **1.1573** | **-0.015** |

## Architecture

```
11L, 512d, 8H/4KV (GQA), SwiGLU 3x MLP
XSA on L7-L10, SmearGate, U-Net skips
OrthoInit, Muon WD=0.04, SWA (15 checkpoints)
Mixed quant: int5-MLP + int6-attn + int8-embed + zstd-22
LoRA TTT: rank-8, Q+V, LR=0.05, score-then-train, 256-token chunks
```

## Training Configuration

- **GPU**: 1xH100 PCIe (RunPod) — grad accumulation 8 steps to match 524K batch
- **Wallclock**: ~4850s (~80 min) — NOT a 10-min record submission
- **Batch**: 524,288 tokens/step (grad_accum=8 × seq_len=2048 × micro_batch=32)
- **Sequence length**: 2048
- **Warmdown**: 3000 iterations
- **Steps completed**: 7,926 / 20,000 (wallclock cap)

## Why Non-Record

This ran on 1xH100 PCIe for ~80 minutes (not 8xH100 in 10 min). The architecture and training are identical to what would run on 8xH100 — only the batch parallelism differs.

## Development Journey

18 experiments over 5 days, from val_bpb=3.10 (wrong batch size) to 1.1573:

| Experiment | val_bpb | What changed |
|-----------|---------|-------------|
| 1 (baseline) | 3.10 | Wrong batch size |
| 2 | 1.46 | Fixed batch to 65K |
| 6 | 1.312 | 1200s training, warmdown=600 |
| 10 | 1.283 | + SmearGate, OrthoInit, MLP 3x, WD |
| 13 | — | SwiGLU > ReLU² (+0.004) |
| 14-8x | 1.202 | 11 layers + SWA on 8xH100 |
| 15 | 1.187 | + seq_len=2048 |
| 17 | 1.183 | + XSA (last 4 layers) |
| + Quant | 1.191 | int5+int6+int8+zstd (15 MB) |
| **+ LoRA TTT** | **1.157** | Per-document adaptation at eval |

Total compute cost: ~$50 across all experiments.

## What Didn't Work

| Technique | Result | Why |
|-----------|--------|-----|
| Register token | +0.002 worse | Step overhead > marginal benefit |
| Layer looping + wider | +0.034 worse | Step time from wider dim |
| Data sampling (juncture) | +0.002 worse | Shard-level too coarse |
| Hard example mining | +0.040 worse | Destroys Muon weight geometry |
| Partial RoPE (16/64) | +0.015 worse | Head dim too small |
| EMA (replacing SWA) | +0.015 worse | Over-smoothed warmdown weights |
| BigramHash | 0.000 | SmearGate makes it redundant |
| SGD TTT | +0.018 worse | Modifying dequantized weights directly breaks them |

## Command

```bash
RUN_ID=exp17_xsa \
MAX_WALLCLOCK_SECONDS=4850 \
TRAIN_BATCH_TOKENS=524288 \
WARMDOWN_ITERS=3000 \
MUON_WD=0.04 \
NUM_LAYERS=11 \
TRAIN_SEQ_LEN=2048 \
MLP_MULT=3 \
MATRIX_LR=0.04 \
SCALAR_LR=0.04 \
python train_gpt.py
```

## Included Files

- `train_gpt.py` — full training + quantization + LoRA TTT eval script
- `train.log` — training log from 1xH100 run
- `submission.json` — metadata
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"name": "11L XSA + SwiGLU + SWA + Mixed Quant + LoRA TTT (1xH100 PCIe)",
"author": "swapp1990",
"github_id": "swapp1990",
"date": "2026-03-24",
"val_bpb": 1.1573,
"val_loss": 1.9724,
"pre_quant_val_bpb": 1.1727,
"pre_quant_val_loss": 1.9800,
"step_stop": 7926,
"wallclock_seconds": 4850,
"gpu": "1xH100 PCIe (RunPod)",
"track": "non-record-16mb",
"bytes_total": 15793319,
"bytes_model": 15727804,
"bytes_code": 65515,
"blurb": "11-layer transformer with XSA (cross-token self-attention on last 4 layers), SwiGLU 3x MLP, SmearGate, U-Net skip connections, orthogonal init, Muon optimizer with WD=0.04, and stochastic weight averaging. Trained on 1xH100 PCIe for ~80 min with batch=524K seq=2048 (grad accumulation). Mixed quantization (int5 MLP + int6 attn + int8 embed + zstd) fits in 15.02 MB. LoRA TTT (rank-8, score-then-train per chunk) brings val_bpb from 1.191 to 1.157. 18 experiments over 5 days."
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
logs/06226eeb-9a28-46a0-ba37-47715b3f2521.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:27092057
world_size:1 grad_accum_steps:8
sdp_backends:cudnn=False flash=True mem_efficient=False math=False
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04
train_batch_tokens:524288 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:4850.000
seed:1337
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9277 val_bpb:4.1030 train_time:0ms step_avg:0.01ms
step:1/20000 train_loss:6.9295 train_time:578ms step_avg:577.95ms
step:2/20000 train_loss:9.7430 train_time:1181ms step_avg:590.66ms
step:3/20000 train_loss:8.6021 train_time:1787ms step_avg:595.71ms
step:4/20000 train_loss:8.7606 train_time:2796ms step_avg:699.05ms
step:5/20000 train_loss:8.3936 train_time:3421ms step_avg:684.26ms
step:6/20000 train_loss:8.2054 train_time:4025ms step_avg:670.80ms
step:7/20000 train_loss:7.7758 train_time:4625ms step_avg:660.69ms
step:8/20000 train_loss:7.2954 train_time:5294ms step_avg:661.75ms
step:9/20000 train_loss:6.8198 train_time:5924ms step_avg:658.17ms
step:10/20000 train_loss:6.4736 train_time:6506ms step_avg:650.65ms
step:200/20000 train_loss:2.7500 train_time:125492ms step_avg:627.46ms
step:400/20000 train_loss:2.3249 train_time:246819ms step_avg:617.05ms
step:600/20000 train_loss:2.4240 train_time:367212ms step_avg:612.02ms
step:800/20000 train_loss:2.2708 train_time:488724ms step_avg:610.91ms
step:1000/20000 train_loss:2.2825 train_time:613127ms step_avg:613.13ms
step:1000/20000 val_loss:2.2538 val_bpb:1.3348 train_time:613139ms step_avg:613.14ms
step:1200/20000 train_loss:2.2151 train_time:734714ms step_avg:612.26ms
step:1400/20000 train_loss:2.2454 train_time:856615ms step_avg:611.87ms
step:1600/20000 train_loss:2.1343 train_time:978149ms step_avg:611.34ms
step:1800/20000 train_loss:2.1791 train_time:1099034ms step_avg:610.57ms
step:2000/20000 train_loss:2.1387 train_time:1219142ms step_avg:609.57ms
step:2000/20000 val_loss:2.1475 val_bpb:1.2719 train_time:1219151ms step_avg:609.58ms
step:2200/20000 train_loss:2.0670 train_time:1342246ms step_avg:610.11ms
step:2400/20000 train_loss:2.1141 train_time:1462622ms step_avg:609.43ms
step:2600/20000 train_loss:2.1813 train_time:1584180ms step_avg:609.30ms
step:2800/20000 train_loss:2.1298 train_time:1704582ms step_avg:608.78ms
step:3000/20000 train_loss:2.0552 train_time:1825266ms step_avg:608.42ms
step:3000/20000 val_loss:2.1026 val_bpb:1.2453 train_time:1825424ms step_avg:608.47ms
step:3200/20000 train_loss:2.1302 train_time:1950560ms step_avg:609.55ms
step:3400/20000 train_loss:2.1032 train_time:2072471ms step_avg:609.55ms
step:3600/20000 train_loss:2.0625 train_time:2196370ms step_avg:610.10ms
step:3800/20000 train_loss:2.1419 train_time:2319658ms step_avg:610.44ms
step:4000/20000 train_loss:2.0426 train_time:2442476ms step_avg:610.62ms
step:4000/20000 val_loss:2.0758 val_bpb:1.2294 train_time:2442641ms step_avg:610.66ms
step:4200/20000 train_loss:2.1144 train_time:2570120ms step_avg:611.93ms
step:4400/20000 train_loss:2.1164 train_time:2695921ms step_avg:612.71ms
step:4600/20000 train_loss:1.9531 train_time:2819014ms step_avg:612.83ms
step:4800/20000 train_loss:2.0661 train_time:2939820ms step_avg:612.46ms
step:5000/20000 train_loss:2.0568 train_time:3060813ms step_avg:612.16ms
SWA: started at step 5000 (warmdown_start~4922)
step:5000/20000 val_loss:2.0563 val_bpb:1.2179 train_time:3061106ms step_avg:612.22ms
step:5200/20000 train_loss:2.0657 train_time:3181771ms step_avg:611.88ms
step:5400/20000 train_loss:2.0772 train_time:3303276ms step_avg:611.72ms
step:5600/20000 train_loss:2.0720 train_time:3427130ms step_avg:611.99ms
step:5800/20000 train_loss:2.0417 train_time:3548669ms step_avg:611.84ms
step:6000/20000 train_loss:2.1672 train_time:3670464ms step_avg:611.74ms
step:6000/20000 val_loss:2.0271 val_bpb:1.2006 train_time:3670951ms step_avg:611.83ms
step:6200/20000 train_loss:2.0269 train_time:3794142ms step_avg:611.96ms
step:6400/20000 train_loss:2.0329 train_time:3914942ms step_avg:611.71ms
step:6600/20000 train_loss:1.9839 train_time:4035937ms step_avg:611.51ms
step:6800/20000 train_loss:2.1208 train_time:4158062ms step_avg:611.48ms
step:7000/20000 train_loss:2.0063 train_time:4279307ms step_avg:611.33ms
step:7000/20000 val_loss:1.9992 val_bpb:1.1840 train_time:4279531ms step_avg:611.36ms
step:7200/20000 train_loss:2.0013 train_time:4402050ms step_avg:611.40ms
step:7400/20000 train_loss:1.9981 train_time:4523744ms step_avg:611.32ms
step:7600/20000 train_loss:1.9361 train_time:4645429ms step_avg:611.24ms
step:7800/20000 train_loss:1.9874 train_time:4773436ms step_avg:611.98ms
step:7926/20000 val_loss:1.9800 val_bpb:1.1727 train_time:4850090ms step_avg:611.92ms
stopping_early: wallclock_cap train_time:4850090ms step:7926/20000
peak memory allocated: 18476 MiB reserved: 18726 MiB
SWA: averaging 15 checkpoints
SWA: applied
Serialized model: 106317248 bytes
Code size: 57899 bytes
Total submission size: 106375147 bytes
Serialized model int8+zlib: 24655764 bytes (payload:27333986 raw_torch:27397078 payload_ratio:3.89x)
Total submission size int8+zlib: 24713663 bytes
/runpod-volume/parameter-golf/train_gpt.py:1319: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
quant_state = torch.load(io.BytesIO(zlib.decompress(quant_blob_disk)), map_location="cpu")
final_int8_zlib_roundtrip val_loss:1.9969 val_bpb:1.1826 eval_time:18994ms
final_int8_zlib_roundtrip_exact val_loss:1.99685268 val_bpb:1.18264864
Loading