4 changes: 3 additions & 1 deletion .gitignore
Expand Up @@ -8,4 +8,6 @@ data/manifest.json
data/docs_selected.jsonl
.mypy_cache/
.venv
logs/
logs/
final_model.*
sweep.sh
@@ -0,0 +1,58 @@
## Depth Recurrence + Cross-Repeat Skip + Value Embeddings

Beats the naive baseline (1.2244 bpb) by 0.005 bpb while using 3.1x fewer training steps, enabled by stateful depth recurrence.

val_bpb = 1.2196 (sliding-window eval on the int8+zlib roundtrip model, stride=256)
val_bpb = 1.2533 (standard eval on the int8+zlib roundtrip model)

### Architecture

Replaced the baseline's 9 unique transformer blocks with 3 shared blocks repeated 4 times (12 effective layers). Trades unique parameters for effective depth.

Changes from baseline:
- Depth recurrence: 3 blocks x 4 repeats = 12 effective layers (vs 9 in baseline)
- Cross-Repeat Skip (original): each block receives a learned-scale-weighted residual of its own output from the previous repeat, turning stateless recurrence into stateful recurrence. Per-repeat learned scales, ~7.5K params total.
- Value Embeddings: 2 extra embedding tables mixed into the residual stream at each effective layer with learned scales. From snimu's modded-nanogpt record.
- Loop Embedding: learned per-layer vector added before each block as depth-wise positional encoding.
- Model dim 832 (vs 512), 8 heads, 4 KV heads, MLP 2x
- Removed U-Net skip connections (Cross-Repeat Skip covers this role)
- 17.14M params, 12.83MB artifact
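
The loop structure above can be sketched as follows. This is a minimal PyTorch illustration, not the actual `train_gpt.py` code: the `Block`/`RecurrentStack` names are hypothetical, attention and the Value Embedding mixes are omitted for brevity, and the per-channel shape of the skip scales is inferred from the ~7.5K figure ((4-1) repeats x 3 blocks x dim 832 = 7,488 params).

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Stand-in transformer block (attention omitted; just norm + 2x MLP)."""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 2 * dim),  # MLP 2x, as in the config above
            nn.GELU(),
            nn.Linear(2 * dim, dim),
        )

    def forward(self, x):
        return x + self.mlp(self.norm(x))

class RecurrentStack(nn.Module):
    def __init__(self, dim=832, n_blocks=3, n_repeats=4):
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim) for _ in range(n_blocks))
        self.n_repeats = n_repeats
        # Cross-Repeat Skip: one learned scale vector per (repeat > 0, block).
        # (n_repeats - 1) * n_blocks * dim = 3 * 3 * 832 = 7488 params (~7.5K).
        self.skip_scales = nn.Parameter(torch.zeros(n_repeats - 1, n_blocks, dim))
        # Loop embedding: one learned vector per effective layer (depth-wise
        # positional encoding), added before each block application.
        self.loop_emb = nn.Parameter(torch.zeros(n_repeats * n_blocks, dim))

    def forward(self, x):
        # Each block's output from the previous repeat (the recurrent state).
        prev = [None] * len(self.blocks)
        layer = 0
        for r in range(self.n_repeats):
            for b, block in enumerate(self.blocks):
                out = block(x + self.loop_emb[layer])
                x = out
                if r > 0:  # stateful recurrence: mix in last repeat's output
                    x = x + self.skip_scales[r - 1, b] * prev[b]
                prev[b] = out
                layer += 1
        return x
```

With the shared-weight loop, only the 3 blocks' parameters are stored, while the forward pass runs 12 effective layers.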

### Training

LR scaled to 0.3x of the baseline: recurrence amplifies gradients through the 4 passes, so the optimal LR is much lower. Found via a sweep of 10 configs on an RTX 3060.

MATRIX_LR=0.012, SCALAR_LR=0.012, TIED_EMBED_LR=0.015, GRAD_CLIP_NORM=0.3, WARMDOWN_ITERS=3000, TRAIN_SEQ_LEN=1024.

Tested training at sequence length 2048, but 1024 fits more steps into the time budget (133ms vs 253ms per step), which matters more for this architecture. Optimizers: standard Muon + Adam.

### Evaluation

Sliding window eval: window=1024, stride=256 on the int8+zlib roundtrip model. Eval time 209s on 8xH100.

### Results (8xH100, 600s wallclock)

4494 steps, 133ms/step avg. Pre-quant 1.2487, roundtrip 1.2533, sliding window 1.2196. Artifact 12.83MB, quant degradation 0.005 bpb, peak memory ~29GB/GPU.
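
The bpb numbers follow from `val_loss` (nats/token) via the tokenizer's bytes-per-token ratio. A quick conversion sketch; the ~2.436 bytes/token ratio is inferred from the reported loss/bpb pairs, not a logged value:

```python
import math

def nats_to_bpb(val_loss_nats: float, bytes_per_token: float) -> float:
    # bits/byte = (nats/token) / (ln 2 nats/bit) / (bytes/token)
    return val_loss_nats / (math.log(2) * bytes_per_token)

BYTES_PER_TOKEN = 2.4359  # inferred: both reported pairs imply this ratio
print(nats_to_bpb(2.05921204, BYTES_PER_TOKEN))  # ~1.2196 (sliding window)
print(nats_to_bpb(2.11612232, BYTES_PER_TOKEN))  # ~1.2533 (roundtrip)
```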

### Ablations (RTX 3060, 2000 steps each)

- Cross-Repeat Skip: -0.041 bpb
- Value Embeddings (2 tables): -0.079 bpb
- LR x0.3: -0.052 bpb
- Sliding window eval: -0.034 bpb
- WARMDOWN_ITERS=3000: -0.027 bpb

### Development

All experiments, ablations, and hyperparameter sweeps done on a single RTX 3060 12GB. Cloud GPUs (1xH200, 6xH100) used only for validation. Final run on 8xH100.

### Command

```
RUN_ID=submission_8xh100 \
QUANT_LEVELS=127 \
TTT_STEPS=0 \
EVAL_STRIDE=256 \
EVAL_SEQ_LEN=1024 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```
@@ -0,0 +1,16 @@
{
"author": "Ivan Verbovoy",
"github_id": "iverbovoy",
"name": "Depth Recurrence + Cross-Repeat Skip + Value Embeddings + Sliding Window",
"blurb": "3 unique blocks x 4 repeats (12 effective layers), dim=832, with Cross-Repeat Skip (stateful recurrence), 2 Value Embedding tables, LR x0.3, sliding window eval (stride=256). 4494 steps in 600s on 8xH100.",
"date": "2026-03-20T02:00:00Z",
"val_loss": 2.05921204,
"val_bpb": 1.21958209,
"roundtrip_val_loss": 2.11612232,
"roundtrip_val_bpb": 1.25328684,
"step_stop": 4494,
"wallclock_seconds": 600.133,
"bytes_total": 12829176,
"bytes_model_int8_zlib": 12771121,
"bytes_code": 58055
}
@@ -0,0 +1,84 @@
W0320 00:54:42.000000 1050 torch/distributed/run.py:852]
W0320 00:54:42.000000 1050 torch/distributed/run.py:852] *****************************************
W0320 00:54:42.000000 1050 torch/distributed/run.py:852] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0320 00:54:42.000000 1050 torch/distributed/run.py:852] *****************************************
logs/submission_8xh100.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:17140056
world_size:8 grad_accum_steps:1
sdp_backends:cudnn=False flash=True mem_efficient=False math=False
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.015 head_lr:0.0 matrix_lr:0.012 scalar_lr:0.012
train_batch_tokens:524288 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:1337
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9766 val_bpb:4.1319 train_time:0ms step_avg:0.01ms
step:1/20000 train_loss:6.9765 train_time:162ms step_avg:161.95ms
step:2/20000 train_loss:9.0581 train_time:218ms step_avg:109.04ms
step:3/20000 train_loss:7.8439 train_time:342ms step_avg:114.12ms
step:4/20000 train_loss:6.5913 train_time:466ms step_avg:116.40ms
step:5/20000 train_loss:6.1067 train_time:589ms step_avg:117.72ms
step:6/20000 train_loss:6.3514 train_time:712ms step_avg:118.70ms
step:7/20000 train_loss:5.9725 train_time:836ms step_avg:119.39ms
step:8/20000 train_loss:5.8139 train_time:958ms step_avg:119.78ms
step:9/20000 train_loss:5.5629 train_time:1081ms step_avg:120.13ms
step:10/20000 train_loss:5.3728 train_time:1206ms step_avg:120.64ms
step:200/20000 train_loss:2.7739 train_time:26609ms step_avg:133.05ms
step:400/20000 train_loss:2.3107 train_time:53543ms step_avg:133.86ms
step:600/20000 train_loss:2.5249 train_time:80122ms step_avg:133.54ms
step:800/20000 train_loss:2.2710 train_time:106824ms step_avg:133.53ms
step:1000/20000 train_loss:2.3610 train_time:133649ms step_avg:133.65ms
step:1000/20000 val_loss:2.3206 val_bpb:1.3744 train_time:133722ms step_avg:133.72ms
step:1200/20000 train_loss:2.3700 train_time:160457ms step_avg:133.71ms
step:1400/20000 train_loss:2.4196 train_time:187085ms step_avg:133.63ms
step:1600/20000 train_loss:2.0826 train_time:213643ms step_avg:133.53ms
step:1800/20000 train_loss:2.1817 train_time:240257ms step_avg:133.48ms
step:2000/20000 train_loss:2.2342 train_time:266823ms step_avg:133.41ms
step:2000/20000 val_loss:2.2137 val_bpb:1.3111 train_time:266903ms step_avg:133.45ms
step:2200/20000 train_loss:2.0469 train_time:293423ms step_avg:133.37ms
step:2400/20000 train_loss:2.1757 train_time:320078ms step_avg:133.37ms
step:2600/20000 train_loss:2.3756 train_time:346626ms step_avg:133.32ms
step:2800/20000 train_loss:2.2012 train_time:373394ms step_avg:133.35ms
step:3000/20000 train_loss:2.1910 train_time:400062ms step_avg:133.35ms
step:3000/20000 val_loss:2.1585 val_bpb:1.2784 train_time:400147ms step_avg:133.38ms
step:3200/20000 train_loss:2.1485 train_time:426762ms step_avg:133.36ms
step:3400/20000 train_loss:2.1171 train_time:453425ms step_avg:133.36ms
step:3600/20000 train_loss:2.0703 train_time:480073ms step_avg:133.35ms
step:3800/20000 train_loss:2.1774 train_time:506627ms step_avg:133.32ms
step:4000/20000 train_loss:2.1156 train_time:532930ms step_avg:133.23ms
step:4000/20000 val_loss:2.1201 val_bpb:1.2556 train_time:533004ms step_avg:133.25ms
step:4200/20000 train_loss:2.1277 train_time:561906ms step_avg:133.79ms
step:4400/20000 train_loss:2.0541 train_time:588700ms step_avg:133.80ms
step:4494/20000 val_loss:2.1084 val_bpb:1.2487 train_time:600133ms step_avg:133.54ms
stopping_early: wallclock_cap train_time:600133ms step:4494/20000
peak memory allocated: 21771 MiB reserved: 21818 MiB
Serialized model: 63387167 bytes
Code size: 58055 bytes
Total submission size: 63445222 bytes
Serialized model int8+zlib: 12771121 bytes (payload:17243616 raw_torch:17261176 payload_ratio:3.68x)
Total submission size int8+zlib: 12829176 bytes
final_int8_zlib_roundtrip val_loss:2.1161 val_bpb:1.2533 eval_time:3709ms
final_int8_zlib_roundtrip_exact val_loss:2.11612232 val_bpb:1.25328684
final_sliding_window val_loss:2.0592 val_bpb:1.2196 window:1024 stride:256 eval_time:209349ms
final_sliding_window_exact val_loss:2.05921204 val_bpb:1.21958209