From 0444ef13e132a596356801dc7fa1da5bca7f1269 Mon Sep 17 00:00:00 2001 From: callithyia Date: Thu, 26 Mar 2026 21:33:32 +0800 Subject: [PATCH] =?UTF-8?q?Record:=200.3212=20BPB=20=E2=80=94=20Complement?= =?UTF-8?q?ary=20N-gram=2065K=20+=20Int5=20GPTQ=20+=20LoRA=20TTT?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 3-seed mean 0.3212 (std 0.0003). Complementary training + order-9 n-gram eval cache with 65K-token chunks + Full Hessian GPTQ int5 + LoRA TTT with Polyak averaging. --- .../README.md | 87 + .../requirements.txt | 3 + .../submission.json | 33 + .../train_gpt.py | 2181 +++++++++++++++++ .../train_seed1337.log | 524 ++++ .../train_seed2024.log | 523 ++++ .../train_seed42.log | 523 ++++ 7 files changed, 3874 insertions(+) create mode 100644 records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/README.md create mode 100644 records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/requirements.txt create mode 100644 records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/submission.json create mode 100644 records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/train_gpt.py create mode 100644 records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/train_seed1337.log create mode 100644 records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/train_seed2024.log create mode 100644 records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/train_seed42.log diff --git a/records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/README.md b/records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/README.md new file mode 100644 index 000000000..7802afbca --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/README.md @@ -0,0 +1,87 @@ +# Record: 0.3212 BPB — Complementary N-gram 65K + Int5 GPTQ + LoRA TTT + +**Complementary training + Order-9 n-gram eval cache (65K chunks) + Full Hessian GPTQ Int5 + LoRA TTT + Polyak averaging** + +**val_bpb: 0.3212** (3-seed mean, std 0.0003) | **~14.9 MB** artifact | 8xH100 SXM, 600s train + ~570s eval + +## Results (3 seeds, 8xH100 SXM) + +| Seed | Steps | ms/step | val_bpb | Post-quant BPB | Artifact | +|------|-------|---------|---------|----------------|----------| +| 1337 | 5,457 | 101 | **0.3211** | 1.1817 | 14,965,401 bytes | +| 42 | 5,437 | 101 | **0.3210** | 1.1794 | 14,926,117 bytes | +| 2024 | 5,498 | 101 | **0.3216** | 1.1831 | 14,874,853 bytes | +| **Mean** | **5,464** | **101** | **0.3212** | **1.1814** | **14,922,124 bytes** | +| **Std** | **31** | **0** | **0.0003** | **0.0019** | **45,330 bytes** | + +## Architecture + +- 11 transformer layers, dim=512, GQA 8Q/4KV, head_dim=64 +- MLP 3.0x expansion (hidden=1536) with LeakyReLU(0.9) squared +- XSA on last 4 layers (layers 7-10) +- Value Residual Learning on layers 1-10 +- Gated Attention with bias=4.0 on all layers +- BigramHash 4096-bucket embedding +- Logit softcap 30.0 +- EMA decay 0.997 +- ~27.3M parameters + +## Key Techniques + +### Training +- **Complementary training** (COMPLEMENT_ALPHA=0.50): Downweights bigram-predictable tokens in the loss, making the model deliberately weaker where n-grams are strong. The n-gram cache handles those tokens at eval. 
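+
+  A condensed sketch of the weighting (it mirrors the `TNT` tracker in `train_gpt.py` below; the standalone function and argument names here are illustrative):
+
+  ```python
+  import torch
+
+  def complement_weights(bi_counts, bi_totals, x, y, alpha=0.50, V=1024):
+      # bi_counts: flattened (V*V,) bigram counts; bi_totals: (V,) context totals
+      pair = bi_counts[(x * V + y).long()].float()  # count of (prev, next)
+      total = bi_totals[x.long()].float()           # count of prev
+      p = pair / (total + 1.0)                      # smoothed bigram p(next | prev)
+      return (1.0 - alpha * p).clamp(min=0.1)       # downweight what n-grams already know
+
+  # per_tok = F.cross_entropy(logits, targets, reduction="none")
+  # loss = (per_tok * complement_weights(counts, totals, x_flat, y_flat)).mean()
+  ```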
+- **Parallel Muon** optimizer with Newton-Schulz5, per-group banking, encoder/decoder LR split (0.025/0.05) +- **WSD learning rate schedule** (75% stable, cosine decay) +- **Late QAT**: Soft-Round quantization-aware training triggered at 85% wallclock + +### Quantization +- **Full Hessian GPTQ Int5**: Activation-order column permutation, Cholesky error compensation, 256-batch calibration +- **LZMA compression** (preset 9 extreme): ~14.8MB artifact + +### Evaluation (single pass, ~570s) +- **Order-9 n-gram backoff cache**: 4M hash buckets, orders 2-9, entropy-adaptive alpha blending +- **65K-token chunks** (65,536): Cache updates 15x more frequently than standard 1M chunks. Reduces cold-cache penalty on early tokens. +- **Per-order entropy centers + multipliers**: Orders 5-9 boosted 2x, orders 2-3 suppressed 0.3x. Per-order sigmoid centers shift trust toward higher orders. +- **LoRA TTT** (rank 8, Q+V on blocks 9-10): AdamW lr=0.003, Polyak averaging decay=0.998. Adapts model weights causally per chunk. +- **Score-first protocol**: Each chunk scored before cache update (backward-looking compliant). + +### What's Novel +- First combination of complementary training + order-9 n-gram cache + 65K chunks + LoRA TTT with Polyak averaging +- Per-order entropy centers combined with per-order multipliers for alpha computation +- Full Hessian GPTQ with Soft-Round QAT (not naive quantization) + +## Setup and Run + +```bash +cd /workspace +git clone https://github.com/openai/parameter-golf.git pgolf +cd pgolf +pip install --break-system-packages -r requirements.txt zstandard +python data/cached_challenge_fineweb.py --variant sp1024 + +# Run (single seed) +SEED=1337 MAX_WALLCLOCK_SECONDS=600 PROG_SEQ_ENABLED=0 \ + torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +## Compliance + +- [x] 3 seeds run on 8xH100 SXM +- [x] All seeds train in <=600s +- [x] All seeds eval in <=600s (~570s) +- [x] Artifact <=16,000,000 bytes (~14.9MB) +- [x] No validation data accessed during training +- [x] TTT is backward-looking (score-first per chunk) +- [x] No network calls during evaluation +- [x] No multi-pass rescoring +- [x] Reproducible from single script with seed + +## Credits + +Built on techniques from: +- **PR #809** (@quietsmile): Per-order multipliers, entropy-adaptive alpha, order-9 n-gram backoff cache +- **PR #803** (@travispchen): Complementary training (bigram-weighted loss) +- **PR #798** (@travispchen): Per-order entropy centers, drift-free TTT, Polyak averaging +- **PR #840** (@quietsmile): 65K-token chunk size for n-gram eval +- **PR #779** (@lukacf): Integrated TTT + n-gram eval loop concept +- **PR #414** (@signalrush): GPTQ + EMA + warmdown baseline diff --git a/records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/requirements.txt b/records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/requirements.txt new file mode 100644 index 000000000..201fcc881 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/requirements.txt @@ -0,0 +1,3 @@ +# Non-standard dependencies beyond the competition template +# (torch, numpy, sentencepiece, flash-attn are pre-installed) +zstandard diff --git a/records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/submission.json b/records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/submission.json new file mode 100644 index 000000000..00d72e059 --- /dev/null +++ 
b/records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/submission.json @@ -0,0 +1,33 @@ +{ + "name": "Complementary N-gram 65K + Int5 GPTQ + LoRA TTT", + "val_bpb": 0.3212, + "val_bpb_std": 0.0003, + "n_seeds": 3, + "seeds": [1337, 42, 2024], + "bytes_total": 14965401, + "blurb": "Complementary training (alpha=0.50) + order-9 n-gram eval cache with 65K-token chunks (15x cache refresh) + entropy-adaptive alpha with per-order centers and multipliers + Full Hessian GPTQ int5 + LZMA + LoRA TTT (rank 8, Polyak decay 0.998) + LeakyReLU(0.9)^2 + XSA-4 + VRL + Gated Attention + Parallel Muon. 3-seed mean: 0.3212 (std 0.0003).", + "author": "callithyia", + "github_id": "callithyia", + "date": "2026-03-26", + "model_params": 27301064, + "num_layers": 11, + "model_dim": 512, + "num_heads": 8, + "num_kv_heads": 4, + "mlp_mult": 3.0, + "vocab_size": 1024, + "train_seq_len": 2048, + "eval_stride": 64, + "ngram_order": 9, + "ngram_chunk_tokens": 65536, + "complement_alpha": 0.5, + "ttt_epochs": 1, + "ttt_lora_rank": 8, + "ttt_polyak_decay": 0.998, + "compression": "lzma", + "quantization": "gptq_int5", + "step_avg_ms": 101, + "train_time_seconds": 600, + "eval_time_seconds": 572, + "hardware": "8xH100 SXM (RunPod)" +} diff --git a/records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/train_gpt.py b/records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/train_gpt.py new file mode 100644 index 000000000..7d7d13d44 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/train_gpt.py @@ -0,0 +1,2181 @@ +import copy +import glob +import io +import json +import lzma +import math +import os +import random +import subprocess +import sys +import time +import uuid +import zlib +from pathlib import Path +try: + import zstandard + _HAS_ZSTD = True +except ImportError: + _HAS_ZSTD = False +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +from flash_attn_interface import flash_attn_func as flash_attn_3_func +_E=os.environ.get +_sync=torch.cuda.synchronize +_now=time.perf_counter +class HP: + data_path = _E("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = _E("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = _E("RUN_ID", str(uuid.uuid4())) + seed = int(_E("SEED", 1337)) + val_batch_size = int(_E("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(_E("VAL_LOSS_EVERY", 4000)) + train_log_every = int(_E("TRAIN_LOG_EVERY", 500)) + iterations = int(_E("ITERATIONS", 20000)) + warmdown_iters = int(_E("WARMDOWN_ITERS", 3500)) + warmup_steps = int(_E("WARMUP_STEPS", 20)) + train_batch_tokens = int(_E("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(_E("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(_E("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(_E("MAX_WALLCLOCK_SECONDS", 600.0)) + qk_gain_init = float(_E("QK_GAIN_INIT", 1.5)) + vocab_size = int(_E("VOCAB_SIZE", 1024)) + num_layers = int(_E("NUM_LAYERS", 11)) + num_kv_heads = int(_E("NUM_KV_HEADS", 4)) + model_dim = int(_E("MODEL_DIM", 512)) + num_heads = int(_E("NUM_HEADS", 8)) + mlp_mult = float(_E("MLP_MULT", 3.0)) + tie_embeddings = bool(int(_E("TIE_EMBEDDINGS", "1"))) + rope_base = float(_E("ROPE_BASE", 10000.0)) + logit_softcap = 
float(_E("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(_E("EMBED_LR", 0.6)) + head_lr = float(_E("HEAD_LR", 0.008)) + tied_embed_lr = float(_E("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(_E("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(_E("MATRIX_LR", 0.025)) + scalar_lr = float(_E("SCALAR_LR", 0.025)) + muon_momentum = float(_E("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(_E("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(_E("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(_E("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(_E("BETA1", 0.9)) + beta2 = float(_E("BETA2", 0.95)) + adam_eps = float(_E("ADAM_EPS", 1e-8)) + grad_clip_norm = float(_E("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(_E("EVAL_STRIDE", 64)) + mtp_num_heads = int(_E("MTP_NUM_HEADS", 0)) + mtp_loss_weight = float(_E("MTP_LOSS_WEIGHT", 0.2)) + muon_beta2 = float(_E("MUON_BETA2", 0.95)) + muon_wd = float(_E("MUON_WD", 0.04)) + adam_wd = float(_E("ADAM_WD", 0.04)) + qat_enabled = bool(int(_E("QAT_ENABLED", "0"))) + bigram_vocab_size = int(_E("BIGRAM_VOCAB_SIZE", 4096)) # CHANGE 6: 2048 -> 4096 + bigram_dim = int(_E("BIGRAM_DIM", 128)) + xsa_last_n = int(_E("XSA_LAST_N", 4)) + rope_dims = int(_E("ROPE_DIMS", 16)) + ln_scale = bool(int(_E("LN_SCALE", "1"))) + dtg_enabled = bool(int(_E("DTG_ENABLED", "0"))) + late_qat_threshold = float(_E("LATE_QAT_THRESHOLD", 0.15)) + ve_enabled = bool(int(_E("VE_ENABLED", "1"))) + ve_dim = int(_E("VE_DIM", 128)) + ve_layers = _E("VE_LAYERS", "9,10") + vrl_enabled = bool(int(_E("VRL_ENABLED", "1"))) + gated_attn_enabled = bool(int(_E("GATED_ATTN_ENABLED", "1"))) + gated_attn_bias_init = float(_E("GATED_ATTN_BIAS_INIT", 4.0)) + wsd_stable_frac = float(_E("WSD_STABLE_FRAC", 0.75)) + qat_trigger_frac = float(_E("QAT_TRIGGER_FRAC", 0.85)) + prog_seq_enabled = bool(int(_E("PROG_SEQ_ENABLED", "0"))) + prog_seq_phase1_len = int(_E("PROG_SEQ_PHASE1_LEN", 512)) + prog_seq_phase2_len = int(_E("PROG_SEQ_PHASE2_LEN", 1024)) + prog_seq_phase3_len = int(_E("PROG_SEQ_PHASE3_LEN", 2048)) + prog_seq_frac1 = float(_E("PROG_SEQ_FRAC1", 0.33)) + prog_seq_frac2 = float(_E("PROG_SEQ_FRAC2", 0.55)) + gptq_n_samples = int(_E("GPTQ_N_SAMPLES", 256)) + gptq_block_size = int(_E("GPTQ_BLOCK_SIZE", 128)) + gptq_percdamp = float(_E("GPTQ_PERCDAMP", 0.01)) + ttt_epochs = int(_E("TTT_EPOCHS", 1)) + ttt_lr = float(_E("TTT_LR", 0.003)) # CHANGE 4: 0.003 for shared LoRA (0.01 was for per-doc BatchedTTTLoRA) + ttt_chunk_tokens = int(_E("TTT_CHUNK_TOKENS", 65536)) + ttt_freeze_blocks = int(_E("TTT_FREEZE_BLOCKS", 9)) + ttt_grad_clip = float(_E("TTT_GRAD_CLIP", 1.0)) + ttt_temperature = float(_E("TTT_TEMPERATURE", 0.98)) + ttt_momentum = float(_E("TTT_MOMENTUM", 0.9)) + ttt_lora_rank = int(_E("TTT_LORA_RANK", 8)) # CHANGE 4: LoRA rank + ttt_polyak_decay = float(_E("TTT_POLYAK_DECAY", 0.998)) # CHANGE 4: Polyak + decoder_lr_mult = float(_E("DECODER_LR_MULT", 2.0)) + ngram_enabled = bool(int(_E("NGRAM_ENABLED", "1"))) + ngram_order = int(_E("NGRAM_EVAL_ORDER", 9)) + ngram_min_order = int(_E("NGRAM_EVAL_MIN_ORDER", 2)) + ngram_alpha = float(_E("NGRAM_EVAL_ALPHA", 0.30)) + ngram_adaptive = bool(int(_E("NGRAM_EVAL_ADAPTIVE", "1"))) + ngram_alpha_min = float(_E("NGRAM_EVAL_ALPHA_MIN", 0.12)) # CHANGE 2: 0.12 balances complement training + multipliers + ngram_alpha_max = float(_E("NGRAM_EVAL_ALPHA_MAX", 0.60)) + ngram_ent_center = float(_E("NGRAM_EVAL_ENTROPY_CENTER", 3.0)) + ngram_ent_scale = float(_E("NGRAM_EVAL_ENTROPY_SCALE", 2.0)) + ngram_min_count = 
int(_E("NGRAM_EVAL_MIN_COUNT", 2)) + ngram_buckets = int(_E("NGRAM_EVAL_BUCKETS", 4194304)) + ngram_batch_seqs = int(_E("NGRAM_EVAL_BATCH_SEQS", 32)) + ngram_chunk_tokens = int(_E("NGRAM_EVAL_CHUNK_TOKENS", 1_000_000)) + complement_alpha = float(_E("COMPLEMENT_ALPHA", 0.50)) # CHANGE 3: 0.20 -> 0.50 + prune_pct = float(_E("PRUNE_PCT", 0.03)) # CHANGE 8: 0.02 -> 0.03 +def _ns5(G , steps = 10, eps = 1e-7): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= X.norm() + eps + transposed = G.size(0) > G.size(1) + if transposed: + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + return X.T if transposed else X +def _ns5b(G , steps = 10, eps = 1e-7): + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + norms = X.flatten(1).norm(dim=1).unsqueeze(-1).unsqueeze(-1).clamp_min(eps) + X = X / norms + transposed = X.size(1) > X.size(2) + if transposed: + X = X.transpose(1, 2) + for _ in range(steps): + A = X.bmm(X.transpose(1, 2)) + B = b * A + c * A.bmm(A) + X = a * X + B.bmm(X) + if transposed: + X = X.transpose(1, 2) + return X +class PM(torch.optim.Optimizer): + def __init__(self, params, lr , momentum , backend_steps , + nesterov = True, weight_decay = 0.0): + defaults = dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay) + param_list = list(params) + super().__init__(param_list, defaults) + self._bb() + def _bb(self): + self._group_banks: list[dict[tuple[int, int], list[int]]] = [] + for group in self.param_groups: + banks: dict[tuple[int, int], list[int]] = {} + for idx, p in enumerate(group["params"]): + if p.ndim == 2: + shape_key = (p.shape[0], p.shape[1]) + if shape_key not in banks: + banks[shape_key] = [] + banks[shape_key].append(idx) + self._group_banks.append(banks) + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + distributed = dist.is_available() and dist.is_initialized() + world_size = dist.get_world_size() if distributed else 1 + rank = dist.get_rank() if distributed else 0 + for gi, group in enumerate(self.param_groups): + params = group["params"] + if not params: + continue + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + total_params = sum(int(p.numel()) for p in params) + updates_flat = torch.zeros(total_params, device=params[0].device, dtype=torch.bfloat16) + param_updates = {} + momentum_grads = {} + for i, p in enumerate(params): + if i % world_size != rank or p.grad is None: + continue + g = p.grad + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + buf.mul_(momentum).add_(g) + if nesterov: + g = g.add(buf, alpha=momentum) + momentum_grads[i] = g + group_banks = self._group_banks[gi] if gi < len(self._group_banks) else {} + for shape_key, indices in group_banks.items(): + bank_indices = [idx for idx in indices if idx in momentum_grads] + if not bank_indices: + continue + if len(bank_indices) >= 2: + stacked = torch.stack([momentum_grads[idx] for idx in bank_indices]) + ns_result = _ns5b(stacked, steps=backend_steps) + for j, idx in enumerate(bank_indices): + p = params[idx] + g = ns_result[j] + g *= max(1, g.size(0) / g.size(1)) ** 0.5 + param_updates[idx] = g + else: + idx = bank_indices[0] + g = _ns5(momentum_grads[idx], steps=backend_steps) + g *= max(1, g.size(0) / 
g.size(1)) ** 0.5 + param_updates[idx] = g + for idx, g in momentum_grads.items(): + if idx not in param_updates: + param_updates[idx] = g + curr = 0 + for i, p in enumerate(params): + if i in param_updates: + updates_flat[curr : curr + p.numel()] = param_updates[i].reshape(-1) + curr += p.numel() + if distributed: + dist.all_reduce(updates_flat, op=dist.ReduceOp.SUM) + curr = 0 + for p in params: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + g = updates_flat[curr : curr + p.numel()].view_as(p).to(dtype=p.dtype) + p.add_(g, alpha=-lr) + curr += p.numel() + return loss +def _sp_luts( + sp: spm.SentencePieceProcessor, vocab_size , device ): + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def _ld_shard(file ): + header_bytes = 256 * np.dtype("<i4").itemsize + with open(file, "rb") as f: + header = np.frombuffer(f.read(header_bytes), dtype="<i4") + num_tokens = int(header[2]) + tokens = np.frombuffer(f.read(num_tokens * 2), dtype="<u2") + return torch.from_numpy(tokens.astype(np.int32)) +class TS: + def __init__(self, pattern ): + self.files = sorted(glob.glob(pattern)) + if not self.files: + raise FileNotFoundError(f"no shards match {pattern}") + self.file_idx = 0 + self.tokens = _ld_shard(self.files[self.file_idx]) + self.pos = 0 + def _af(self): + self.file_idx = (self.file_idx + 1) % len(self.files) + self.tokens = _ld_shard(self.files[self.file_idx]) + self.pos = 0 + def take(self, n ): + remaining = n + chunks = [] + while remaining > 0: + avail = self.tokens.numel() - self.pos + if avail <= 0: + self._af() + continue + k = min(remaining, avail) + chunks.append(self.tokens[self.pos : self.pos + k]) + self.pos += k + remaining -= k + return chunks[0] if len(chunks) == 1 else torch.cat(chunks) +class DTL: + def __init__(self, pattern , rank , world_size , device ): + self.rank = rank + self.world_size = world_size + self.device = device + self.stream = TS(pattern) + def next_batch(self, global_tokens , seq_len , gas ): + local_tokens = global_tokens // (self.world_size * gas) + per_rank_span = local_tokens + 1 + chunk = self.stream.take(per_rank_span * self.world_size) + start = self.rank * per_rank_span + local = chunk[start : start + per_rank_span].to(dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) +def _ev( + args , + model , + rank , + world_size , + device , + gas , + vtok , + bblut , + hlslut , + ibtlut , + eval_seq_len = None, +): + seq_len = eval_seq_len or args.train_seq_len + local_batch_tokens = args.val_batch_size // (world_size * gas) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={gas}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (vtok.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): 
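+ # each rank scores a contiguous slice of val sequences; loss, token, and UTF-8 byte sums are all-reduced across ranks after the loop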
batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = vtok[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = bblut[tgt_ids].to(dtype=torch.int16) + token_bytes += (hlslut[tgt_ids] & ~ibtlut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# ============================================================ +# CHANGE 4: LoRA Layer for TTT +# ============================================================ +class LoRALayer(nn.Module): + """Low-rank adapter for TTT. Adds A @ B to the base linear's output.""" + def __init__(self, in_features, out_features, rank=8, device=None, dtype=None): + super().__init__() + self.rank = rank + self.lora_A = nn.Parameter(torch.randn(rank, in_features, device=device, dtype=dtype) * (1.0 / math.sqrt(in_features))) + self.lora_B = nn.Parameter(torch.zeros(out_features, rank, device=device, dtype=dtype)) + def forward(self, x): + # x: (..., in_features) -> (..., out_features) + return F.linear(F.linear(x, self.lora_A), self.lora_B) + +# ============================================================ +# N-gram hash/score/update helpers +# ============================================================ +_NHP = np.array( + [36313, 27191, 51647, 81929, 131071, 174763, 233021, 283721, 347239], + dtype=np.uint64, +) + +# CHANGE 5: Per-order entropy centers from PR #798 +_NGRAM_ENT_CENTERS = {9: 2.5, 8: 2.75, 7: 3.0, 6: 3.2, 5: 3.5, 4: 3.8, 3: 4.2, 2: 4.5} + +def _ngh( + val_np , + global_j , + ctx_width , + primes , + mask , +): + valid = global_j >= ctx_width + if not valid.any(): + empty = np.array([], dtype=np.int64) + return empty, empty, empty + v_idx = np.nonzero(valid)[0] + jv = global_j[v_idx] + ctx_hash = np.zeros(len(jv), dtype=np.uint64) + for k in range(ctx_width): + tok = val_np[jv - (ctx_width - k)].astype(np.uint64) + ctx_hash ^= tok * primes[k % len(primes)] + ctx_key = (ctx_hash & mask).astype(np.int64) + tgt = val_np[jv].astype(np.uint64) + full_key = ((ctx_hash ^ (tgt * primes[ctx_width % len(primes)])) & mask).astype(np.int64) + return v_idx, ctx_key, full_key +def _ngs( + seg_model_p , + seg_logits , + val_np , + global_j , + ctx_tables: dict[int, np.ndarray], + full_tables: dict[int, np.ndarray], + *, + max_order = 9, + min_order = 2, + adaptive = True, + alpha_fixed = 0.30, + alpha_min = 0.05, + alpha_max = 0.60, + ent_center = 3.0, + ent_scale = 2.0, + min_count = 2, + primes = _NHP, + mask = np.uint64(4_194_304 - 1), +): + seg_len = len(seg_model_p) + if adaptive: + with torch.no_grad(): + log_probs = F.log_softmax(seg_logits.float(), 
dim=-1) + probs = log_probs.exp() + entropy = -(probs * log_probs).sum(dim=-1).cpu().numpy() + _OM = {2: 0.3, 3: 0.3, 4: 1.0} + p_ng = np.zeros(seg_len, dtype=np.float64) + ng_matched = np.zeros(seg_len, dtype=np.bool_) + ng_order = np.zeros(seg_len, dtype=np.int32) + for n in range(max_order, min_order - 1, -1): + ctx_width = n - 1 + v_idx, ctx_key, full_key = _ngh( + val_np, global_j, ctx_width, primes, mask, + ) + if len(v_idx) == 0: + continue + still_need = ~ng_matched[v_idx] + if not still_need.any(): + continue + v_idx = v_idx[still_need] + ctx_key = ctx_key[still_need] + full_key = full_key[still_need] + ctx_counts = ctx_tables[n][ctx_key].astype(np.float64) + full_counts = full_tables[n][full_key].astype(np.float64) + has_data = ctx_counts >= float(min_count) + if has_data.any(): + p = np.minimum(full_counts, ctx_counts) / np.maximum(ctx_counts, 1.0) + p = np.clip(p, 0.0, 1.0) + hit_idx = v_idx[has_data] + p_ng[hit_idx] = p[has_data] + ng_matched[hit_idx] = True + ng_order[hit_idx] = n + mixed_p = seg_model_p.copy() + if ng_matched.any(): + m_idx = np.nonzero(ng_matched)[0] + matched_orders = ng_order[m_idx] + if adaptive: + # CHANGE 5: Use per-order entropy centers combined with per-order multipliers + centers = np.array([_NGRAM_ENT_CENTERS.get(o, ent_center) for o in matched_orders], dtype=np.float64) + sig = 1.0 / (1.0 + np.exp(-ent_scale * (entropy[m_idx] - centers))) + a = alpha_min + (alpha_max - alpha_min) * sig + # Apply per-order multipliers on top of per-order centers + for order_val, mult in _OM.items(): + order_mask = matched_orders == order_val + a[order_mask] *= mult + high_order_mask = matched_orders >= 5 + a[high_order_mask] *= 2.0 + a = np.clip(a, 0.0, 0.95) + else: + a = np.full(len(m_idx), alpha_fixed) + mixed_p[m_idx] = (1.0 - a) * mixed_p[m_idx] + a * p_ng[m_idx] + return -np.log(np.clip(mixed_p, 1e-12, 1.0)) +def _ngu( + val_np , + global_j , + ctx_tables: dict[int, np.ndarray], + full_tables: dict[int, np.ndarray], + *, + max_order = 9, + min_order = 2, + primes = _NHP, + mask = np.uint64(4_194_304 - 1), +): + buckets = int(mask) + 1 + for n in range(min_order, max_order + 1): + ctx_width = n - 1 + v_idx, ctx_key, full_key = _ngh( + val_np, global_j, ctx_width, primes, mask, + ) + if len(v_idx) == 0: + continue + ctx_tables[n] += np.bincount(ctx_key, minlength=buckets).astype(ctx_tables[n].dtype) + full_tables[n] += np.bincount(full_key, minlength=buckets).astype(full_tables[n].dtype) + +# ============================================================ +# CHANGE 1: INTEGRATED SINGLE-PASS EVAL (TTT + N-gram combined) +# ============================================================ +def _ev_integrated( + args , + bm , + rank , + world_size , + device , + vtok , + bblut , + hlslut , + ibtlut , + stride , + batch_seqs = 32, + eval_seq_len = None, + log_fn=None, + time_budget_s = 600.0, +): + """Single-pass evaluation: for each chunk, score with sliding window, + compute n-gram probabilities, blend, accumulate BPB, update n-gram cache, + then train TTT (LoRA) on scored tokens.""" + seq_len = eval_seq_len or args.train_seq_len + total_tokens = vtok.numel() - 1 + chunk_tokens = args.ttt_chunk_tokens + ttt_epochs = args.ttt_epochs + ttt_lr = args.ttt_lr + ttt_grad_clip = args.ttt_grad_clip + ttt_temperature = args.ttt_temperature + lora_rank = args.ttt_lora_rank + polyak_decay = args.ttt_polyak_decay + + # N-gram config + max_order = args.ngram_order + min_order = max(args.ngram_min_order, 2) + adaptive = args.ngram_adaptive + alpha_fixed = args.ngram_alpha + alpha_min 
= args.ngram_alpha_min + alpha_max = args.ngram_alpha_max + ent_center = args.ngram_ent_center + ent_scale = args.ngram_ent_scale + min_count = args.ngram_min_count + buckets = args.ngram_buckets + mask = np.uint64(buckets - 1) + + # Initialize n-gram tables + ctx_tables = { + n: np.zeros(buckets, dtype=np.uint32) for n in range(min_order, max_order + 1) + } + full_tables = { + n: np.zeros(buckets, dtype=np.uint32) for n in range(min_order, max_order + 1) + } + val_np = vtok.numpy() + + # Build window lists per chunk + all_window_starts = [ + ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1 + ] + num_chunks = max((total_tokens + chunk_tokens - 1) // chunk_tokens, 1) + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in all_window_starts: + wlen = min(ws + seq_len, total_tokens) - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // chunk_tokens, num_chunks - 1) + chunk_windows[ci].append(ws) + + distributed = dist.is_available() and dist.is_initialized() + num_layers = len(bm.blocks) + + # ---- CHANGE 4: Setup LoRA adapters on Q and V of last 2 blocks ---- + lora_layers = {} # {(block_idx, 'q'|'v'): LoRALayer} + unfreeze_start = max(num_layers - 2, 0) # last 2 blocks + for bi in range(unfreeze_start, num_layers): + block = bm.blocks[bi] + q_in = block.attn.c_q.in_features + q_out = block.attn.c_q.out_features + v_in = block.attn.c_v.in_features + v_out = block.attn.c_v.out_features + lora_q = LoRALayer(q_in, q_out, rank=lora_rank, device=device, dtype=torch.float32) + lora_v = LoRALayer(v_in, v_out, rank=lora_rank, device=device, dtype=torch.float32) + lora_layers[(bi, 'q')] = lora_q + lora_layers[(bi, 'v')] = lora_v + + # Collect all LoRA parameters + lora_params = [] + for lora in lora_layers.values(): + lora_params.extend(list(lora.parameters())) + + # Polyak EMA state for LoRA params + polyak_state = {id(p): p.data.clone() for p in lora_params} + + if log_fn: + total_lora = sum(p.numel() for p in lora_params) + log_fn(f"integrated_eval: LoRA rank={lora_rank} on Q,V of blocks {unfreeze_start}-{num_layers-1}, " + f"{total_lora} LoRA params, {num_chunks} chunks, polyak_decay={polyak_decay}, " + f"ngram order={min_order}-{max_order} adaptive={adaptive} alpha=[{alpha_min},{alpha_max}]") + + # Monkey-patch forward to inject LoRA + # Save original forward methods + _orig_attn_fwd = {} + for bi in range(unfreeze_start, num_layers): + block = bm.blocks[bi] + _orig_attn_fwd[bi] = block.attn.forward + + def _make_lora_attn_forward(block_idx, orig_fwd, loras): + lora_q = loras[(block_idx, 'q')] + lora_v = loras[(block_idx, 'v')] + def lora_forward(x, v_embed=None, v0=None, vr_lambda=None): + attn = bm.blocks[block_idx].attn + bsz, seqlen, dim = x.shape + q = attn.c_q(x) + lora_q(x) # Add LoRA delta to Q + q = q.reshape(bsz, seqlen, attn.num_heads, attn.head_dim) + k = attn.c_k(x).reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + v = attn.c_v(x) + lora_v(x) # Add LoRA delta to V + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, attn.num_kv_heads, attn.head_dim) + raw_v = v + if v0 is not None and vr_lambda is not None: + lam = torch.softmax(vr_lambda, dim=0) + v = lam[0] * v0 + lam[1] * v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = attn.rotary(seqlen, x.device, q.dtype) + q = _arope(q, cos, sin, attn.rope_dims) + k = _arope(k, cos, sin, attn.rope_dims) + q = q * attn.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = 
flash_attn_3_func(q, k, v, causal=True) + if attn.use_xsa: + y = attn._xsa(y, v) + if attn.use_gated_attn: + gate = torch.sigmoid(attn.attn_gate(x)) + y = y * gate.unsqueeze(-1) + y = y.reshape(bsz, seqlen, dim) + return attn.proj(y), raw_v + return lora_forward + + def _install_lora(loras, use_polyak=False): + """Install LoRA-augmented forward on attention layers. If use_polyak, + swap in polyak-averaged LoRA weights for scoring.""" + if use_polyak: + # Swap in polyak weights + saved = {} + for p in lora_params: + saved[id(p)] = p.data.clone() + p.data.copy_(polyak_state[id(p)]) + for bi in range(unfreeze_start, num_layers): + bm.blocks[bi].attn.forward = _make_lora_attn_forward(bi, _orig_attn_fwd[bi], loras) + if use_polyak: + return saved + return None + + def _restore_lora_weights(saved): + """Restore training weights after polyak scoring.""" + if saved is not None: + for p in lora_params: + p.data.copy_(saved[id(p)]) + + def _uninstall_lora(): + """Restore original forward methods.""" + for bi in range(unfreeze_start, num_layers): + bm.blocks[bi].attn.forward = _orig_attn_fwd[bi] + + # Freeze all base model params for TTT + for p in bm.parameters(): + p.requires_grad_(False) + + # ---- Accumulators ---- + model_loss_sum = 0.0 + ngram_loss_sum = 0.0 + token_count = 0.0 + byte_count = 0.0 + + t0 = _now() + + # ---- AdamW optimizer for LoRA (CHANGE 4) ---- + ttt_optimizer = torch.optim.AdamW( + [{"params": lora_params, "lr": ttt_lr}], + lr=ttt_lr, + betas=(0.9, 0.999), + weight_decay=0.0, + ) + + for ci in range(num_chunks): + chunk_ws_all = chunk_windows[ci] + if not chunk_ws_all: + # Still need to update ngram cache for this chunk + chunk_start = ci * chunk_tokens + chunk_end = min((ci + 1) * chunk_tokens, total_tokens) + chunk_global_j = np.arange(chunk_start + 1, chunk_end + 1, dtype=np.int64) + _ngu(val_np, chunk_global_j, ctx_tables, full_tables, + max_order=max_order, min_order=min_order, primes=_NHP, mask=mask) + continue + + # Shard windows across ranks for scoring + my_s = (len(chunk_ws_all) * rank) // world_size + my_e = (len(chunk_ws_all) * (rank + 1)) // world_size + my_windows = chunk_ws_all[my_s:my_e] + + chunk_model_loss = 0.0 + chunk_ngram_loss = 0.0 + chunk_token_count = 0.0 + chunk_byte_count = 0.0 + + # ---- SCORE phase: install LoRA with Polyak weights ---- + bm.eval() + saved_weights = _install_lora(lora_layers, use_polyak=(ci > 0)) + + with torch.inference_mode(): + for bi_w in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi_w:bi_w + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + seg = vtok[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = seg[:-1] + y_batch[i, :wlen] = seg[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = bm.fwd_l(x_batch) + logits_f = logits.float() + if ttt_temperature != 1.0: + logits_f = logits_f / ttt_temperature + nll = F.cross_entropy( + logits_f.reshape(-1, logits_f.size(-1)), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + seg_len_val = wlen - s + if seg_len_val <= 0: + continue + seg_nll = nll[i, s:wlen].to(torch.float64).cpu().numpy() + seg_model_p = np.exp(-seg_nll) + global_j = np.arange(ws 
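+ # absolute val_np positions of this window's scored targets (only the non-overlapping tail of length wlen - s is scored)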
+ s + 1, ws + wlen + 1, dtype=np.int64) + chunk_model_loss += float(seg_nll.sum()) + # N-gram scoring integrated here + mixed_nll = _ngs( + seg_model_p, + logits_f[i, s:wlen], + val_np, + global_j, + ctx_tables, + full_tables, + max_order=max_order, + min_order=min_order, + adaptive=adaptive, + alpha_fixed=alpha_fixed, + alpha_min=alpha_min, + alpha_max=alpha_max, + ent_center=ent_center, + ent_scale=ent_scale, + min_count=min_count, + primes=_NHP, + mask=mask, + ) + chunk_ngram_loss += float(mixed_nll.sum()) + chunk_token_count += float(seg_len_val) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = bblut[tgt].to(torch.float64) + tb += (hlslut[tgt] & ~ibtlut[prev]).to(torch.float64) + chunk_byte_count += float(tb.sum().item()) + + # Restore training LoRA weights after polyak scoring + _restore_lora_weights(saved_weights) + + # Aggregate across ranks + _cm = torch.tensor(chunk_model_loss, device=device, dtype=torch.float64) + _cn = torch.tensor(chunk_ngram_loss, device=device, dtype=torch.float64) + _ct = torch.tensor(chunk_token_count, device=device, dtype=torch.float64) + _cb = torch.tensor(chunk_byte_count, device=device, dtype=torch.float64) + if distributed: + dist.all_reduce(_cm, op=dist.ReduceOp.SUM) + dist.all_reduce(_cn, op=dist.ReduceOp.SUM) + dist.all_reduce(_ct, op=dist.ReduceOp.SUM) + dist.all_reduce(_cb, op=dist.ReduceOp.SUM) + model_loss_sum += _cm.item() + ngram_loss_sum += _cn.item() + token_count += _ct.item() + byte_count += _cb.item() + + # ---- UPDATE n-gram cache AFTER scoring ---- + chunk_start = ci * chunk_tokens + chunk_end = min((ci + 1) * chunk_tokens, total_tokens) + chunk_global_j = np.arange(chunk_start + 1, chunk_end + 1, dtype=np.int64) + _ngu( + val_np, chunk_global_j, ctx_tables, full_tables, + max_order=max_order, min_order=min_order, + primes=_NHP, mask=mask, + ) + + # ---- TRAIN LoRA on this chunk (score-first: we already scored) ---- + if ci < num_chunks - 1: + _install_lora(lora_layers, use_polyak=False) # install with training weights + chunk_len = chunk_end - chunk_start + if chunk_len >= seq_len: + num_seqs = chunk_len // seq_len + train_tokens = vtok[chunk_start : chunk_start + num_seqs * seq_len + 1] + bm.train() + for epoch in range(ttt_epochs): + for si in range(0, num_seqs, batch_seqs): + batch_end_s = min(si + batch_seqs, num_seqs) + actual_bsz = batch_end_s - si + t_start = si * seq_len + t_end = batch_end_s * seq_len + 1 + local_tok = train_tokens[t_start:t_end].to(dtype=torch.int64, device=device) + x_train = local_tok[:-1].reshape(actual_bsz, seq_len) + y_train = local_tok[1:].reshape(actual_bsz, seq_len) + ttt_optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits_train = bm.fwd_l(x_train) + per_tok_loss = F.cross_entropy( + logits_train.float().reshape(-1, logits_train.size(-1)), + y_train.reshape(-1), + reduction="none", + ).reshape(actual_bsz, seq_len) + # CHANGE 7: Byte-weighted TTT loss + byte_weights = bblut[y_train].float() + byte_weights += (hlslut[y_train] & ~ibtlut[x_train]).float() + byte_weights = byte_weights.clamp(min=1.0) + loss = (per_tok_loss * byte_weights).sum() / byte_weights.sum() + loss.backward() + if distributed: + for p in lora_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + if ttt_grad_clip > 0: + torch.nn.utils.clip_grad_norm_(lora_params, ttt_grad_clip) + ttt_optimizer.step() + # Polyak update after training on this chunk + with torch.no_grad(): + for p in lora_params: + 
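+ # Polyak/EMA average of the LoRA weights: chunks after the first are scored with these averaged weights, while training continues on the raw ones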
polyak_state[id(p)].mul_(polyak_decay).add_(p.data, alpha=1.0 - polyak_decay) + + if log_fn and (ci + 1) % 5 == 0: + elapsed = _now() - t0 + cur_model_bpb = (model_loss_sum / max(token_count, 1.0)) / math.log(2.0) * (token_count / max(byte_count, 1.0)) + cur_ngram_bpb = (ngram_loss_sum / max(token_count, 1.0)) / math.log(2.0) * (token_count / max(byte_count, 1.0)) + log_fn( + f"integrated_eval: chunk {ci + 1}/{num_chunks} " + f"model_bpb={cur_model_bpb:.6f} ngram_bpb={cur_ngram_bpb:.6f} " + f"delta={cur_ngram_bpb - cur_model_bpb:.6f} t={elapsed:.0f}s" + ) + + # Per-chunk time guard: abort if eval budget exceeded + elapsed_eval = _now() - t0 + if time_budget_s > 0 and elapsed_eval > time_budget_s: + if log_fn: + log_fn(f"integrated_eval: ABORTING at chunk {ci + 1}/{num_chunks} -- " + f"elapsed {elapsed_eval:.0f}s exceeds {time_budget_s:.0f}s budget") + break + + # Restore original attention forwards + _uninstall_lora() + + model_val_loss = model_loss_sum / max(token_count, 1.0) + ngram_val_loss = ngram_loss_sum / max(token_count, 1.0) + tpb = token_count / max(byte_count, 1.0) + model_val_bpb = model_val_loss / math.log(2.0) * tpb + ngram_val_bpb = ngram_val_loss / math.log(2.0) * tpb + if log_fn: + elapsed = _now() - t0 + log_fn( + f"integrated_eval: DONE model_bpb={model_val_bpb:.4f} ngram_bpb={ngram_val_bpb:.4f} " + f"delta={ngram_val_bpb - model_val_bpb:.4f} elapsed={elapsed:.0f}s" + ) + return model_val_loss, model_val_bpb, ngram_val_loss, ngram_val_bpb + + +_CTRL = tuple( + pattern + for pattern in _E( + "_CTRL", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale,attn_gate,vr_lambda", + ).split(",") + if pattern +) +_PRSD = torch.float16 +_CLQ = 0.9999984 +def _qf(t ): + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), _CLQ, dim=1) + if t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=_PRSD).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), _CLQ).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def _qi5(t , clip_range = 15): + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale +def _brs(W , clip_range = 15): + W32 = W.float() + nrows = W32.shape[0] + best_scales = torch.zeros(nrows, dtype=torch.float32, 
device=W.device) + best_err = torch.full((nrows,), float('inf'), dtype=torch.float32, device=W.device) + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(W32.abs(), pct, dim=1) + else: + row_clip = W32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range) + q = torch.clamp(torch.round(W32 / s[:, None]), -clip_range, clip_range) + recon = q * s[:, None] + row_mse = (W32 - recon).pow(2).mean(dim=1) + improved = row_mse < best_err + best_scales[improved] = s[improved] + best_err[improved] = row_mse[improved] + return best_scales +def _gptq_qw( + W , + H , + clip_range = 15, + block_size = 128, + percdamp = 0.01, +): + W = W.float().clone() + nrows, ncols = W.shape + H = H.float().clone() + row_scale = _brs(W, clip_range) + diag = torch.diag(H) + damp = percdamp * diag.mean() + diag += damp + H[range(ncols), range(ncols)] = diag + perm = torch.argsort(torch.diag(H)) + invperm = torch.argsort(perm) + W = W[:, perm] + H = H[perm][:, perm] + try: + L = torch.linalg.cholesky(H) + Hinv = torch.cholesky_inverse(L) + except torch.linalg.LinAlgError: + Hinv = torch.diag(1.0 / torch.diag(H).clamp_min(1e-10)) + Q = torch.zeros_like(W) + for i1 in range(0, ncols, block_size): + i2 = min(i1 + block_size, ncols) + W_block = W[:, i1:i2].clone() + Q_block = torch.zeros_like(W_block) + Err_block = torch.zeros_like(W_block) + Hinv_block = Hinv[i1:i2, i1:i2] + for j in range(i2 - i1): + w_col = W_block[:, j] + d = Hinv_block[j, j] + q_col = torch.clamp(torch.round(w_col / row_scale), -clip_range, clip_range) + Q_block[:, j] = q_col + deq_col = q_col * row_scale + err = (w_col - deq_col) / d.clamp_min(1e-10) + Err_block[:, j] = err + if j + 1 < i2 - i1: + W_block[:, j + 1:] -= err[:, None] * Hinv_block[j, j + 1:][None, :] + Q[:, i1:i2] = Q_block + if i2 < ncols: + W[:, i2:] -= Err_block @ Hinv[i1:i2, i2:] + Q = Q[:, invperm] + return Q.to(torch.int8), row_scale.to(torch.float16) +def _gptq_cal( + model , + train_pattern , + device , + n_samples = 256, + seq_len = 2048, +): + hessians = {} + n_seen = {} + hooks = [] + for name, module in model.named_modules(): + if isinstance(module, (nn.Linear, CL)): + if module.weight.ndim == 2 and module.weight.numel() > 65536: + def make_hook(layer_name , in_features ): + def hook_fn(mod, inp, out): + x = inp[0] + if x.ndim == 3: + x = x.reshape(-1, x.size(-1)) + x = x.float() + xtx = x.T @ x + if layer_name not in hessians: + hessians[layer_name] = torch.zeros(in_features, in_features, device=device) + n_seen[layer_name] = 0 + hessians[layer_name] += xtx + n_seen[layer_name] += x.shape[0] + return hook_fn + h = module.register_forward_hook(make_hook(name, module.in_features)) + hooks.append(h) + stream = TS(train_pattern) + model.eval() + samples_run = 0 + with torch.inference_mode(): + while samples_run < n_samples: + batch_size = min(4, n_samples - samples_run) + tokens = stream.take(batch_size * (seq_len + 1)).to(dtype=torch.int64, device=device) + tokens = tokens[:batch_size * (seq_len + 1)].reshape(batch_size, seq_len + 1) + x = tokens[:, :seq_len] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + _ = model.fwd_l(x) + samples_run += batch_size + for h in hooks: + h.remove() + for name in hessians: + if n_seen[name] > 0: + hessians[name] = hessians[name].clone() / n_seen[name] + model.train() + return hessians +def _clsp(name ): + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." 
not in name): + return "attn" + return "other" +def _mq5g( + state_dict: dict[str, Tensor], + int5_cats: set[str], + hessians: dict[str, Tensor], + block_size = 128, + percdamp = 0.01, + prune_pct = 0.0, # CHANGE 8: magnitude pruning +): + result = {} + meta = {} + gptq_count = 0 + naive_count = 0 + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _clsp(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in _CTRL): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + # CHANGE 8: Apply magnitude pruning before quantization + if prune_pct > 0 and t.ndim == 2: + with torch.no_grad(): + abs_t = t.float().abs() + threshold = torch.quantile(abs_t.flatten(), prune_pct) + t = t.clone() + t[abs_t < threshold] = 0.0 + if cat in int5_cats and t.ndim >= 1: + h_key = name.rsplit(".", 1)[0] if name.endswith((".weight", ".bias")) else name + if t.ndim == 2 and h_key in hessians: + H = hessians[h_key].cpu() + if H.shape[0] == t.shape[1] and H.shape[1] == t.shape[1]: + q, s = _gptq_qw(t, H, clip_range=15, block_size=block_size, percdamp=percdamp) + gptq_count += 1 + else: + q, s = _qi5(t) + naive_count += 1 + else: + q, s = _qi5(t) + naive_count += 1 + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int5"} + else: + q, s = _qf(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + return result, meta, gptq_count, naive_count +def _dq5(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]): + out = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * float(s.item())).to(orig_dtype) + return out +class RN(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x ): + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CL(nn.Linear): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._qat_enabled = False + self.register_buffer("_soft_round_alpha", torch.tensor(1.0), persistent=False) + def forward(self, x ): + w = self.weight.to(x.dtype) + if self._qat_enabled and self.training and w.ndim == 2: + alpha = self._soft_round_alpha.item() + with torch.no_grad(): + w32 = self.weight.float() + row_max = w32.abs().amax(dim=1) + scale = (row_max / 15.0).clamp_min(1.0 / 15.0) + w_scaled = w32 / scale[:, None] + w_rounded = torch.round(w_scaled) + diff = w_rounded - w_scaled + w_soft = w_scaled + (1.0 / (2.0 * alpha)) * torch.tanh(alpha * diff) + w_soft = torch.clamp(w_soft, -15, 15) + w_q = (w_soft * scale[:, None]).to(x.dtype) + w = w + (w_q - w).detach() + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def _sqat(module , enabled , alpha = 1.0): + for m in module.modules(): + if isinstance(m, CL): + m._qat_enabled = enabled + m._soft_round_alpha.fill_(alpha) +def _fp32(module ): + with 
torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in _CTRL)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim , base = 10000.0, train_seq_len = 1024, rope_dims = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def _ic(self): + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + def forward(self, seq_len , device , dtype: torch.dtype): + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached != seq_len + or self._cos_cached.device != device + ): + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = freqs.cos()[None, :, None, :] + self._sin_cached = freqs.sin()[None, :, None, :] + self._seq_len_cached = seq_len + return self._cos_cached.to(dtype=dtype), self._sin_cached.to(dtype=dtype) +def _arope(x , cos , sin , rope_dims = 0): + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) +class CSA(nn.Module): + def __init__(self, dim , num_heads , num_kv_heads , rope_base , qk_gain_init , + gated_attn = False, gated_attn_bias_init = 4.0): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + kv_dim = self.num_kv_heads * self.head_dim + self.c_q = CL(dim, dim, bias=False) + self.c_k = CL(dim, kv_dim, bias=False) + self.c_v = CL(dim, kv_dim, bias=False) + self.proj = CL(dim, dim, bias=False) + self.proj._zero_init = True + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False + self.use_gated_attn = gated_attn + if gated_attn: + self.attn_gate = nn.Linear(dim, num_heads, bias=True) + nn.init.zeros_(self.attn_gate.weight) + nn.init.constant_(self.attn_gate.bias, gated_attn_bias_init) + def _xsa(self, y , v ): + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) + vn = F.normalize(v, dim=-1).unsqueeze(-2) + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return 
(y_g - proj).reshape(B, T, H, D) + def forward(self, x , v_embed: Tensor | None = None, + v0 = None, vr_lambda: nn.Parameter | None = None, + ): + bsz, seqlen, dim = x.shape + q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = self.c_v(x) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + raw_v = v + if v0 is not None and vr_lambda is not None: + lam = torch.softmax(vr_lambda, dim=0) + v = lam[0] * v0 + lam[1] * v + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = _arope(q, cos, sin, self.rope_dims) + k = _arope(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if self.use_xsa: + y = self._xsa(y, v) + if self.use_gated_attn: + gate = torch.sigmoid(self.attn_gate(x)) + y = y * gate.unsqueeze(-1) + y = y.reshape(bsz, seqlen, dim) + return self.proj(y), raw_v +class SG(nn.Module): + def __init__(self, dim ): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x ): + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev +class BHE(nn.Module): + def __init__(self, bigram_vocab_size , bigram_dim , model_dim ): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CL(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def _bh(self, tokens ): + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids ): + h = self.embed(self._bh(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) +class VEm(nn.Module): + def __init__(self, vocab_size , ve_dim , model_dim ): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CL(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids ): + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) +class MLP(nn.Module): + def __init__(self, dim , mlp_mult ): + super().__init__() + hidden = int(mlp_mult * dim) + self.fc = CL(dim, hidden, bias=False) + self.proj = CL(hidden, dim, bias=False) + self.proj._zero_init = True + def forward(self, x ): + x = F.leaky_relu(self.fc(x), negative_slope=0.9) + return self.proj(x.square()) +class Block(nn.Module): + def __init__(self, dim , num_heads , num_kv_heads , mlp_mult , + rope_base , qk_gain_init , layer_idx = 0, + ln_scale = False, dtg = False, + gated_attn = False, gated_attn_bias_init = 4.0, + vrl_enabled = False): + super().__init__() + self.layer_idx = layer_idx + self.attn_norm = RN() + self.mlp_norm = RN() + self.attn = CSA(dim, num_heads, 
num_kv_heads, rope_base, qk_gain_init, + gated_attn=gated_attn, gated_attn_bias_init=gated_attn_bias_init) + self.mlp = MLP(dim, mlp_mult) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0 + if dtg: + self.dtg_gate = nn.Linear(dim, 1, bias=True) + nn.init.zeros_(self.dtg_gate.weight) + nn.init.constant_(self.dtg_gate.bias, 2.0) + else: + self.dtg_gate = None + if vrl_enabled and layer_idx > 0: + self.vr_lambda = nn.Parameter(torch.tensor([0.5, 0.5], dtype=torch.float32)) + else: + self.vr_lambda = None + def forward(self, x , x0 , v_embed: Tensor | None = None, + v0 = None): + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out, raw_v = self.attn( + self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed, + v0=v0, vr_lambda=self.vr_lambda, + ) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + x_out = x_out + self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor) + if self.dtg_gate is not None: + gate = torch.sigmoid(self.dtg_gate(x_in.detach())) + x_out = x_in + gate * (x_out - x_in) + return x_out, raw_v +class TNT: + def __init__(self, vocab_size , device , complement_alpha = 0.5): + self.V = vocab_size + self.alpha = complement_alpha + self.bi_counts = torch.zeros(vocab_size, vocab_size, device=device, dtype=torch.int64) + self.bi_totals = torch.zeros(vocab_size, device=device, dtype=torch.int64) + @torch.no_grad() + def update(self, x , y ): + xf = x.reshape(-1) + yf = y.reshape(-1) + ones = torch.ones(xf.numel(), device=xf.device, dtype=torch.int64) + self.bi_counts.reshape(-1).scatter_add_(0, (xf * self.V + yf).long(), ones) + self.bi_totals.scatter_add_(0, xf.long(), ones) + def get_weights(self, x , y ): + xf = x.reshape(-1) + yf = y.reshape(-1) + total = self.bi_totals[xf.long()].float() + count = self.bi_counts.reshape(-1)[(xf * self.V + yf).long()].float() + ngram_prob = count / (total + 1.0) + return (1.0 - self.alpha * ngram_prob).clamp(min=0.1) +class GPT(nn.Module): + def __init__(self, vocab_size , num_layers , model_dim , num_heads , + num_kv_heads , mlp_mult , tie_embeddings , tied_embed_init_std , + logit_softcap , rope_base , qk_gain_init , + mtp_num_heads = 0, mtp_loss_weight = 0.1, + bigram_vocab_size = 0, bigram_dim = 128, xsa_last_n = 0, + rope_dims = 0, ln_scale = False, dtg = False, + ve_enabled = False, ve_dim = 128, ve_layers = "9,10", + vrl_enabled = False, gated_attn = False, gated_attn_bias_init = 4.0): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) + self.vrl_enabled = vrl_enabled + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.mtp_num_heads = mtp_num_heads + self.mtp_loss_weight = mtp_loss_weight + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BHE(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SG(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, 
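+ # U-Net-style wiring: each decoder layer adds back the mirrored encoder activation, scaled per-dim by a learned skip weight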
self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + self.blocks = nn.ModuleList([ + Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init, + layer_idx=i, ln_scale=ln_scale, dtg=dtg, + gated_attn=gated_attn, gated_attn_bias_init=gated_attn_bias_init, + vrl_enabled=vrl_enabled) + for i in range(num_layers) + ]) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = VEm(vocab_size, ve_dim, kv_dim) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() + self.final_norm = RN() + self.lm_head = None if tie_embeddings else CL(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + self.mtp_heads = nn.ModuleList( + [CL(model_dim, vocab_size, bias=False) for _ in range(mtp_num_heads)] + ) + for head in self.mtp_heads: + head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._iw() + def _iw(self): + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + num_layers = len(self.blocks) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + if ".proj." 
in name or name.endswith(".proj"): + with torch.no_grad(): + module.weight.mul_(1.0 / math.sqrt(2 * num_layers)) + def _get_ve(self, layer_idx , input_ids , ve_cache: dict | None = None): + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def _irc(self): + for block in self.blocks: + block.attn.rotary._ic() + def forward(self, input_ids , target_ids ): + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips = [] + ve_cache: dict = {} + v0 = None + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, v_embed=ve, v0=v0) + if i == 0 and self.vrl_enabled: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, v_embed=ve, v0=v0) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + if hasattr(self, '_ngram_tracker') and self._ngram_tracker is not None and self.training: + per_tok_loss = F.cross_entropy(logits.float(), targets, reduction="none") + weights = self._ngram_tracker.get_weights(input_ids, target_ids) + main_loss = (per_tok_loss * weights).mean() + else: + main_loss = F.cross_entropy(logits.float(), targets, reduction="mean") + if self.training and self.mtp_num_heads > 0 and self.mtp_loss_weight > 0.0: + _, seqlen, dim = x.shape + mtp_loss_sum = x.new_zeros(()) + mtp_loss_count = 0 + for k, mtp_head in enumerate(self.mtp_heads): + valid_t = seqlen - (k + 1) + if valid_t <= 0: + continue + mtp_hidden = x[:, :valid_t, :].reshape(-1, dim) + mtp_targets = target_ids[:, k + 1 :].reshape(-1) + mtp_logits_proj = mtp_head(mtp_hidden) + mtp_logits = self.logit_softcap * torch.tanh(mtp_logits_proj / self.logit_softcap) + mtp_loss_sum = mtp_loss_sum + F.cross_entropy(mtp_logits.float(), mtp_targets, reduction="mean") + mtp_loss_count += 1 + if mtp_loss_count > 0: + main_loss = main_loss + self.mtp_loss_weight * (mtp_loss_sum / mtp_loss_count) + return main_loss + def fwd_l(self, input_ids ): + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips = [] + ve_cache: dict = {} + v0 = None + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x, raw_v = self.blocks[i](x, x0, v_embed=ve, v0=v0) + if i == 0 and self.vrl_enabled: + v0 = raw_v + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x, _ = self.blocks[bi](x, x0, 
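+            # U-Net skip wiring (same as in forward above): the 5 encoder outputs are
+            # pushed onto `skips` and popped in reverse, so decoder layer i adds
+            # encoder layer (4 - i)'s activation with a learned per-channel gain:
+            #     x = x + skip_weights[i] * skips.pop()
+            # With 11 layers (5 encoder / 6 decoder) the last decoder layer gets no skip.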
v_embed=ve, v0=v0) + x = self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) +def _gsl(args , elapsed_ms , max_wallclock_ms ): + if not args.prog_seq_enabled: + return args.train_seq_len + frac = elapsed_ms / max(max_wallclock_ms, 1e-9) + if frac < args.prog_seq_frac1: + return args.prog_seq_phase1_len + elif frac < args.prog_seq_frac2: + return args.prog_seq_phase2_len + else: + return args.prog_seq_phase3_len +def main(): + global _ns5 + code = Path(__file__).read_text(encoding="utf-8") + args = HP() + _ns5 = torch.compile(_ns5) + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(_E("RANK", "0")) + world_size = int(_E("WORLD_SIZE", "1")) + local_rank = int(_E("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so gas stays integral") + gas = 8 // world_size + gsc = 1.0 / gas + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + mp = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if mp: + os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg , console = True): + if not mp: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0(subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, console=False) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError(f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}") + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + eesl = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, eesl) + vtok = _ld_val(args.val_files, val_seq_len) + bblut, hlslut, ibtlut = _sp_luts(sp, args.vocab_size, device) + log0(f"bpb:sp={args.tokenizer_path}") + log0(f"tl:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val:{args.val_files} tokens:{vtok.numel() - 1}") + log0(f"v:opti-ms2 act:lr09sq xsa:last_{args.xsa_last_n} " + f"qat:sr wd:WSD gptq:fh ttt:lora_polyak_adamw compression:lzma " + f"optimizer:PM prog_seq:enabled vrl:{args.vrl_enabled} 
gated_attn:{args.gated_attn_enabled} " + f"decoder_lr_mult:{args.decoder_lr_mult} ngram:{args.ngram_enabled} ngram_order:{args.ngram_order} " + f"complement_alpha:{args.complement_alpha} prune_pct:{args.prune_pct}") + bm = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=args.mtp_num_heads, mtp_loss_weight=args.mtp_loss_weight, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, rope_dims=args.rope_dims, ln_scale=args.ln_scale, + dtg=args.dtg_enabled, ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + vrl_enabled=args.vrl_enabled, gated_attn=args.gated_attn_enabled, + gated_attn_bias_init=args.gated_attn_bias_init, + ).to(device).bfloat16() + for module in bm.modules(): + if isinstance(module, CL): + module.float() + _fp32(bm) + complement_alpha = args.complement_alpha + if complement_alpha > 0: + ngram_tracker = TNT(args.vocab_size, device, complement_alpha=complement_alpha) + bm._ngram_tracker = ngram_tracker + log0(f"comp:on alpha={complement_alpha}") + else: + bm._ngram_tracker = None + log0("comp:off") + cm = torch.compile(bm, dynamic=False, fullgraph=False) + model = DDP(cm, device_ids=[local_rank], broadcast_buffers=False) if distributed else cm + ebi = set(range(bm.num_encoder_layers)) + dbi = set(range(bm.num_encoder_layers, len(bm.blocks))) + emp = [] + dmp = [] + scp = [] + for block_idx, block in enumerate(bm.blocks): + for name, p in block.named_parameters(): + if p.ndim == 2 and not any(pattern in name for pattern in _CTRL): + if block_idx in ebi: + emp.append(p) + else: + dmp.append(p) + elif p.ndim < 2 or any(pattern in name for pattern in _CTRL): + scp.append(p) + if bm.mtp_num_heads > 0: + emp.extend([p for p in bm.mtp_heads.parameters() if p.ndim == 2]) + if bm.skip_weights.numel() > 0: + scp.append(bm.skip_weights) + scp.append(bm.smear.gate) + if bm.bigram is not None: + scp.append(bm.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [bm.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if bm.bigram is not None: + tok_params.append({"params": [bm.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if bm.bigram.proj is not None: + emp.append(bm.bigram.proj.weight) + if bm.ve_shared is not None: + tok_params.append({"params": [bm.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if bm.ve_shared.proj is not None: + emp.append(bm.ve_shared.proj.weight) + scp.append(bm.ve_shared.scale) + for s in bm.ve_layer_scales: + scp.append(s) + opt_t = torch.optim.AdamW( + tok_params, betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.adam_wd, fused=True, + ) + encoder_lr = args.matrix_lr + decoder_lr = args.matrix_lr * args.decoder_lr_mult + amp2 = emp + dmp + epi = {id(p) for p in emp} + opt_m = PM( + amp2, lr=args.matrix_lr, momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, weight_decay=args.muon_wd, + ) + enc_params = [p for p in amp2 if id(p) in epi] + dec_params = [p for p in amp2 if id(p) not in epi] + opt_m.param_groups = [] + if enc_params: + opt_m.add_param_group({ + "params": enc_params, "lr": encoder_lr, "base_lr": encoder_lr, + "momentum": 
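+            # Encoder/decoder LR split for the Muon groups being built here: 2-D
+            # weights from the first 5 blocks form an encoder group at matrix_lr
+            # (0.025 in the logged runs); the remaining blocks form a decoder group
+            # at matrix_lr * decoder_lr_mult (0.05 at the logged mult of 2.0), with
+            # identical momentum / Newton-Schulz settings in both groups.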
args.muon_momentum, "backend_steps": args.muon_backend_steps, + "nesterov": True, "weight_decay": args.muon_wd, + }) + if dec_params: + opt_m.add_param_group({ + "params": dec_params, "lr": decoder_lr, "base_lr": decoder_lr, + "momentum": args.muon_momentum, "backend_steps": args.muon_backend_steps, + "nesterov": True, "weight_decay": args.muon_wd, + }) + opt_m._bb() + opt_s = torch.optim.AdamW( + [{"params": scp, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, + weight_decay=args.adam_wd, fused=True, + ) + optimizers = [opt_t, opt_m, opt_s] + if bm.lm_head is not None: + opt_h = torch.optim.Adam( + [{"params": [bm.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), eps=args.adam_eps, fused=True, + ) + optimizers.insert(1, opt_h) + n_params = sum(p.numel() for p in bm.parameters()) + mtp_params = sum(p.numel() for p in bm.mtp_heads.parameters()) + log0(f"model_params:{n_params}") + log0(f"mtp_num_heads:{args.mtp_num_heads} mtp_loss_weight:{args.mtp_loss_weight} mtp_params:{mtp_params}") + xsa_layers = [i for i, b in enumerate(bm.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} gas:{gas}") + log0(f"seed:{args.seed}") + log0(f"lr_schedule:WSD wsd_stable_frac:{args.wsd_stable_frac} decay_shape:cosine qat_trigger_frac:{args.qat_trigger_frac}") + log0(f"optimizer:PM (NS)") + if args.prog_seq_enabled: + log0(f"pseq:on phases={args.prog_seq_phase1_len}/{args.prog_seq_phase2_len}/{args.prog_seq_phase3_len} " + f"fracs={args.prog_seq_frac1:.2f}/{args.prog_seq_frac2:.2f}") + log0(f"ttt_config: optimizer=AdamW(LoRA) epochs={args.ttt_epochs} lr={args.ttt_lr} " + f"lora_rank={args.ttt_lora_rank} polyak_decay={args.ttt_polyak_decay} " + f"chunk={args.ttt_chunk_tokens} " + f"grad_clip={args.ttt_grad_clip} temperature={args.ttt_temperature}") + log0(f"decoder_lr_mult:{args.decoder_lr_mult} encoder_matrix_lr:{encoder_lr} decoder_matrix_lr:{decoder_lr}") + ga_layers = [i for i, b in enumerate(bm.blocks) if b.attn.use_gated_attn] + vrl_layers = [i for i, b in enumerate(bm.blocks) if b.vr_lambda is not None] + log0(f"gated_attn:{'enabled' if args.gated_attn_enabled else 'disabled'} layers:{ga_layers}") + log0(f"value_residual:{'enabled' if args.vrl_enabled else 'disabled'} layers:{vrl_layers}") + log0(f"gptq_config: samples={args.gptq_n_samples} block_size={args.gptq_block_size} damp={args.gptq_percdamp}") + log0(f"prune_pct:{args.prune_pct}") + tl = DTL(args.train_files, rank, world_size, device) + def zero_grad_all(): + for opt in optimizers: + opt.zero_grad(set_to_none=True) + if args.max_wallclock_seconds <= 0: + raise RuntimeError( + f"FATAL: MAX_WALLCLOCK_SECONDS={args.max_wallclock_seconds} disables the training time cap. " + f"Competition rules require MAX_WALLCLOCK_SECONDS=600. NEVER set to 0 on paid runs." 
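+        # WSD schedule, computed by lr_mul below: scale stays at 1.0 for the first
+        # wsd_stable_frac (0.75) of the wallclock budget, then cosine-decays as
+        #     scale = 0.5 * (1 + cos(pi * progress)).
+        # E.g. at 510,055 ms of a 600 s cap: progress = 60,055/150,000 ~= 0.40 and
+        # scale ~= 0.654, matching the logged "late_qat:enabled ... scale:0.6540".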
+ ) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds + def lr_mul(step , elapsed_ms ): + if max_wallclock_ms is None: + total = max(args.iterations, 1) + stable_end = int(args.wsd_stable_frac * total) + if step <= stable_end: + return 1.0 + decay_steps = max(total - stable_end, 1) + progress = min((step - stable_end) / decay_steps, 1.0) + return 0.5 * (1.0 + math.cos(math.pi * progress)) + stable_ms = args.wsd_stable_frac * max_wallclock_ms + if elapsed_ms <= stable_ms: + return 1.0 + decay_ms = max(max_wallclock_ms - stable_ms, 1e-9) + progress = min((elapsed_ms - stable_ms) / decay_ms, 1.0) + return 0.5 * (1.0 + math.cos(math.pi * progress)) + if args.warmup_steps > 0: + ims = {name: tensor.detach().cpu().clone() for name, tensor in bm.state_dict().items()} + ios = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + warmup_seq = args.prog_seq_phase1_len if args.prog_seq_enabled else args.train_seq_len + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(gas): + if distributed: + model.require_backward_grad_sync = micro_step == gas - 1 + x, y = tl.next_batch(args.train_batch_tokens, warmup_seq, gas) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + warmup_loss = model(x, y) + (warmup_loss * gsc).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + if args.prog_seq_enabled: + for extra_seq in [args.prog_seq_phase2_len, args.prog_seq_phase3_len]: + if extra_seq != warmup_seq: + zero_grad_all() + x, y = tl.next_batch(args.train_batch_tokens, extra_seq, gas) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss_w = model(x, y) + (loss_w * gsc).backward() + for opt in optimizers: + opt.step() + zero_grad_all() + log0(f"warmup:primed seq_len={extra_seq}") + bm.load_state_dict(ims, strict=True) + for opt, state in zip(optimizers, ios, strict=True): + opt.load_state_dict(state) + zero_grad_all() + if distributed: + model.require_backward_grad_sync = True + tl = DTL(args.train_files, rank, world_size, device) + ema_state = {name: t.detach().float().clone() for name, t in bm.state_dict().items()} + ema_decay = 0.997 + swa_state = None + swa_count = 0 + swa_every = 50 + swa_scale_threshold = 0.2 + ttms = 0.0 + stop_after_step = None + csl = args.prog_seq_phase1_len if args.prog_seq_enabled else args.train_seq_len + qat_start_ms = None + _sync() + t0 = _now() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + _sync() + ttms += 1000.0 * (_now() - t0) + val_loss, val_bpb = _ev( + args, model, rank, world_size, device, gas, + vtok, bblut, hlslut, ibtlut, + ) + log0(f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{ttms:.0f}ms step_avg:{ttms / max(step, 1):.2f}ms " + f"seq_len:{csl}") + _sync() + t0 = _now() + if last_step: + if stop_after_step is not None and step < args.iterations: + log0(f"early_stop train_time:{ttms:.0f}ms step:{step}/{args.iterations}") + break + elapsed_ms = ttms + 1000.0 * (_now() - t0) + if args.prog_seq_enabled and max_wallclock_ms is not None: + new_seq_len = _gsl(args, elapsed_ms, max_wallclock_ms) + if new_seq_len != csl: + 
log0(f"prog_seq:transition seq_len {csl}->{new_seq_len} step:{step} elapsed_ms:{elapsed_ms:.0f}") + csl = new_seq_len + bm._irc() + scale = lr_mul(step, elapsed_ms) + is_qat_on = any(isinstance(m, CL) and m._qat_enabled for m in bm.modules()) + if not is_qat_on and args.late_qat_threshold > 0: + trigger = False + if max_wallclock_ms is not None: + if elapsed_ms >= args.qat_trigger_frac * max_wallclock_ms: + trigger = True + else: + if step >= int(args.qat_trigger_frac * args.iterations): + trigger = True + if trigger: + _sqat(bm, True, alpha=1.0) + qat_start_ms = elapsed_ms + log0(f"late_qat:enabled step:{step} scale:{scale:.4f} elapsed_ms:{elapsed_ms:.0f} trigger:wallclock@{args.qat_trigger_frac}") + elif is_qat_on and qat_start_ms is not None: + remaining_ms = max((max_wallclock_ms or float('inf')) - elapsed_ms, 0) + qat_total_ms = max(elapsed_ms - qat_start_ms, 1e-9) + if max_wallclock_ms is not None: + total_qat_ms = max(max_wallclock_ms - qat_start_ms, 1e-9) + alpha_frac = min(qat_total_ms / total_qat_ms, 1.0) + else: + alpha_frac = min(qat_total_ms / max(1000.0 * args.max_wallclock_seconds * (1.0 - args.qat_trigger_frac), 1e-9), 1.0) + alpha = 1.0 + 15.0 * alpha_frac + _sqat(bm, True, alpha=alpha) + zero_grad_all() + train_loss = torch.zeros((), device=device) + micro_batches: list[tuple[Tensor, Tensor]] = [] + for micro_step in range(gas): + if distributed: + model.require_backward_grad_sync = micro_step == gas - 1 + x, y = tl.next_batch(args.train_batch_tokens, csl, gas) + if bm._ngram_tracker is not None: + micro_batches.append((x.detach(), y.detach())) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + loss = model(x, y) + train_loss += loss.detach() + (loss * gsc).backward() + train_loss /= gas + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in opt_m.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(bm.parameters(), args.grad_clip_norm) + for opt in optimizers: + opt.step() + zero_grad_all() + if bm._ngram_tracker is not None: + for mb_x, mb_y in micro_batches: + bm._ngram_tracker.update(mb_x, mb_y) + with torch.no_grad(): + for name, t in bm.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + if scale < swa_scale_threshold and step % swa_every == 0: + with torch.no_grad(): + if swa_state is None: + swa_state = {name: t.clone() for name, t in ema_state.items()} + swa_count = 1 + else: + for name, t in ema_state.items(): + swa_state[name] += t + swa_count += 1 + step += 1 + approx_ttms = ttms + 1000.0 * (_now() - t0) + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0(f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_ttms:.0f}ms step_avg:{approx_ttms / step:.2f}ms " + f"seq_len:{csl}") + reached_cap = max_wallclock_ms is not None and approx_ttms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and 
reached_cap: + stop_after_step = step + log0(f"peak_mem: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB") + log0(f"steps:{step}") + log0(f"swa_n:{swa_count}") + if swa_state is not None and swa_count > 0: + log0(f"swa:blending EMA+SWA (0.7*EMA + 0.3*SWA, {swa_count} SWA checkpoints)") + swa_avg = {name: (swa_state[name] / swa_count) for name in swa_state} + avg_state = {} + for name, ema_t in ema_state.items(): + dtype = bm.state_dict()[name].dtype + avg_state[name] = (0.7 * ema_t + 0.3 * swa_avg[name]).to(dtype=dtype) + else: + log0("ema:applying EMA weights (no SWA checkpoints collected)") + avg_state = {name: t.to(dtype=bm.state_dict()[name].dtype) + for name, t in ema_state.items()} + bm.load_state_dict(avg_state, strict=True) + _sync() + t_diag = _now() + diag_val_loss, diag_val_bpb = _ev( + args, cm, rank, world_size, device, gas, + vtok, bblut, hlslut, ibtlut, + ) + _sync() + log0(f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " + f"eval_time:{1000.0 * (_now() - t_diag):.0f}ms") + log0(f"gptq:calibrating with {args.gptq_n_samples} samples...") + _sync() + t_cal = _now() + hessians = _gptq_cal(bm, args.train_files, device, n_samples=args.gptq_n_samples, seq_len=args.train_seq_len) + _sync() + log0(f"gptq:calibration done in {1000.0 * (_now() - t_cal):.0f}ms, {len(hessians)} layers") + full_state_dict = bm.state_dict() + export_sd = {k: v for k, v in full_state_dict.items() if "mtp_heads" not in k} + excluded_mtp = sum(int(t.numel()) for k, t in full_state_dict.items() if "mtp_heads" in k) + if excluded_mtp > 0: + log0(f"mtp_excl:{excluded_mtp}") + if mp: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = len(code.encode("utf-8")) + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + log0(f"gptq:quantizing with block_size={args.gptq_block_size} percdamp={args.gptq_percdamp} prune_pct={args.prune_pct}...") + _sync() + t_quant = _now() + quant_result, quant_meta, gptq_count, naive_count = _mq5g( + sd_cpu, {"mlp", "attn"}, hessians, + block_size=args.gptq_block_size, percdamp=args.gptq_percdamp, + prune_pct=args.prune_pct, # CHANGE 8: pass prune_pct + ) + _sync() + log0(f"gptq:quantization done in {1000.0 * (_now() - t_quant):.0f}ms (gptq:{gptq_count} naive:{naive_count})") + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = lzma.compress(quant_raw, preset=9 | lzma.PRESET_EXTREME) + best_compressor = "lzma" + log0(f"compression:{best_compressor} raw_size:{len(quant_raw)} compressed_size:{len(quant_blob)} " + f"ratio:{len(quant_raw)/len(quant_blob):.2f}x") + if mp: + with open("final_model.int5.ptz", "wb") as f: + f.write(b"LZMA") + f.write(quant_blob) + quant_file_bytes = 4 + len(quant_blob) + code_bytes = len(code.encode("utf-8")) + total_bytes = quant_file_bytes + code_bytes + log0(f"Serialized model int5+{best_compressor}: {quant_file_bytes} bytes") + log0(f"Total submission size: {total_bytes} bytes") + if total_bytes > 16_000_000: + raise RuntimeError( + f"FATAL: Total submission size {total_bytes} exceeds 16MB limit! " + f"Delta: +{total_bytes - 16_000_000} bytes. " + f"Artifact: {quant_file_bytes} bytes, Code: {code_bytes} bytes." 
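+            # Size arithmetic from the logged seed-1337 run: 14,908,344-byte LZMA blob
+            # + 4-byte "LZMA" magic = 14,908,348-byte artifact; plus 99,811 bytes of
+            # script = 15,008,159 bytes total, 991,841 bytes under the 16,000,000 cap.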
+ ) + else: + log0(f"Size budget OK: {16_000_000 - total_bytes} bytes remaining") + if distributed: + dist.barrier() + with open("final_model.int5.ptz", "rb") as f: + comp_id = f.read(4) + quant_blob_disk = f.read() + if comp_id == b"LZMA": + raw_bytes = lzma.decompress(quant_blob_disk) + elif comp_id == b"ZSTD": + raw_bytes = zstandard.ZstdDecompressor().decompress(quant_blob_disk) + elif comp_id == b"ZLIB": + raw_bytes = zlib.decompress(quant_blob_disk) + else: + raise ValueError(f"Unknown compressor ID: {comp_id!r}") + quant_state = torch.load(io.BytesIO(raw_bytes), map_location="cpu", weights_only=False) + deq_state = _dq5(quant_state["w"], quant_state["m"], sd_cpu) + em = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + mtp_num_heads=0, mtp_loss_weight=0.0, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, rope_dims=args.rope_dims, ln_scale=args.ln_scale, + dtg=args.dtg_enabled, ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + vrl_enabled=args.vrl_enabled, gated_attn=args.gated_attn_enabled, + gated_attn_bias_init=args.gated_attn_bias_init, + ).to(device).bfloat16() + for m in em.modules(): + if isinstance(m, CL): + m.float() + _fp32(em) + em.load_state_dict(deq_state, strict=True) + ce = torch.compile(em, dynamic=False, fullgraph=False) + _sync() + t_eval_start = _now() # Start of 600s eval budget + EVAL_BUDGET_S = 600.0 + t_qeval = _now() + q_val_loss, q_val_bpb = _ev( + args, ce, rank, world_size, device, gas, + vtok, bblut, hlslut, ibtlut, + eval_seq_len=eesl, + ) + _sync() + log0(f"final_int5_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (_now() - t_qeval):.0f}ms") + log0(f"final_int5_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + + # ============================================================ + # CHANGE 1: Single integrated eval pass (TTT + N-gram combined) + # ============================================================ + eval_elapsed = _now() - t_eval_start + eval_remaining = EVAL_BUDGET_S - eval_elapsed + log0(f"eval_budget: {eval_elapsed:.0f}s elapsed, {eval_remaining:.0f}s remaining") + + ngram_time_ms = 0.0 + ngram_val_bpb = q_val_bpb + ngram_val_loss = q_val_loss + ngram_model_bpb = q_val_bpb + ngram_model_loss = q_val_loss + ttt_val_bpb = q_val_bpb + ttt_val_loss = q_val_loss + + if eval_remaining < 60: + log0("eval_budget: SKIPPING integrated eval -- insufficient time") + elif args.ngram_enabled: + log0("integrated_eval: starting single-pass TTT+N-gram evaluation") + _sync() + t_integrated = _now() + ngram_model_loss, ngram_model_bpb, ngram_val_loss, ngram_val_bpb = _ev_integrated( + args, em, rank, world_size, device, + vtok, bblut, hlslut, ibtlut, + stride=args.eval_stride, + batch_seqs=args.ngram_batch_seqs, + eval_seq_len=eesl, + log_fn=log0, + time_budget_s=eval_remaining, + ) + _sync() + ngram_time_ms = 1000.0 * (_now() - t_integrated) + ttt_val_bpb = ngram_model_bpb # model BPB (with TTT LoRA but without n-gram) + ttt_val_loss = ngram_model_loss + log0(f"final_integrated model_val_loss:{ngram_model_loss:.4f} model_val_bpb:{ngram_model_bpb:.4f}") + log0(f"final_integrated ngram_val_loss:{ngram_val_loss:.4f} 
ngram_val_bpb:{ngram_val_bpb:.4f}") + log0(f"final_integrated delta_bpb:{ngram_val_bpb - ngram_model_bpb:.4f}") + log0(f"final_integrated_exact model_bpb:{ngram_model_bpb:.8f} ngram_bpb:{ngram_val_bpb:.8f}") + log0(f"final_integrated_time:{ngram_time_ms:.0f}ms") + else: + log0("ngram:DISABLED (set NGRAM_ENABLED=1 to enable)") + + if mp: + best_bpb = ngram_val_bpb if args.ngram_enabled else ttt_val_bpb + submission = { + "name": "opti-ms2", + "github_id": _E("GITHUB_ID", "callithyia"), + "variant": "opti-ms2", + "description": "Optimised MS2: integrated eval, LoRA TTT, per-order entropy centers, pruning", + "base": "ms2v2", + "val_bpb": round(best_bpb, 8), + "ttt_val_bpb": round(ttt_val_bpb, 8), + "ngram_model_bpb": round(ngram_model_bpb, 8), + "ngram_val_bpb": round(ngram_val_bpb, 8), + "ngram_delta_bpb": round(ngram_val_bpb - ngram_model_bpb, 8) if args.ngram_enabled else 0.0, + "seed": args.seed, + "wsd_stable_frac": args.wsd_stable_frac, + "qat_trigger_frac": args.qat_trigger_frac, + "prog_seq_enabled": args.prog_seq_enabled, + "ttt_optimizer": "AdamW(LoRA)", + "ttt_epochs": args.ttt_epochs, + "ttt_lr": args.ttt_lr, + "ttt_lora_rank": args.ttt_lora_rank, + "ttt_polyak_decay": args.ttt_polyak_decay, + "ttt_temperature": args.ttt_temperature, + "vrl_enabled": args.vrl_enabled, + "gated_attn_enabled": args.gated_attn_enabled, + "decoder_lr_mult": args.decoder_lr_mult, + "compression": "lzma", + "ngram_enabled": args.ngram_enabled, + "ngram_order": args.ngram_order, + "ngram_min_order": args.ngram_min_order, + "ngram_adaptive": args.ngram_adaptive, + "ngram_alpha_min": args.ngram_alpha_min, + "ngram_alpha_max": args.ngram_alpha_max, + "ngram_ent_center": args.ngram_ent_center, + "ngram_ent_scale": args.ngram_ent_scale, + "ngram_min_count": args.ngram_min_count, + "ngram_buckets": args.ngram_buckets, + "complement_alpha": args.complement_alpha, + "prune_pct": args.prune_pct, + "activation": "leaky_relu_0.9_sq", + "total_steps": step, + "ttms": round(ttms, 1), + "integrated_eval_time_ms": round(ngram_time_ms, 1), + } + with open("submission.json", "w") as _f: + json.dump(submission, _f, indent=2) + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/train_seed1337.log b/records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/train_seed1337.log new file mode 100644 index 000000000..6155bf6ca --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/train_seed1337.log @@ -0,0 +1,524 @@ +W0326 12:06:11.759000 97967 torch/distributed/run.py:803] +W0326 12:06:11.759000 97967 torch/distributed/run.py:803] ***************************************** +W0326 12:06:11.759000 97967 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0326 12:06:11.759000 97967 torch/distributed/run.py:803] ***************************************** +logs/83238d57-5c52-47e7-bbc0-85438627a03b.txt +bpb:sp=./data/tokenizers/fineweb_1024_bpe.model +tl:dataset:fineweb10B_sp1024 train_shards:80 +val:./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +v:opti-ms2 act:lr09sq xsa:last_4 qat:sr wd:WSD gptq:fh ttt:lora_polyak_adamw compression:lzma optimizer:PM prog_seq:enabled vrl:True gated_attn:True decoder_lr_mult:2.0 ngram:True ngram_order:9 complement_alpha:0.5 prune_pct:0.03 +comp:on alpha=0.5 +model_params:27301064 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_4 active_layers:[7, 8, 9, 10] +world_size:8 gas:1 +seed:1337 +lr_schedule:WSD wsd_stable_frac:0.75 decay_shape:cosine qat_trigger_frac:0.85 +optimizer:PM (NS) +ttt_config: optimizer=AdamW(LoRA) epochs=1 lr=0.003 lora_rank=8 polyak_decay=0.998 chunk=65536 grad_clip=1.0 temperature=0.98 +decoder_lr_mult:2.0 encoder_matrix_lr:0.025 decoder_matrix_lr:0.05 +gated_attn:enabled layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +value_residual:enabled layers:[1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +gptq_config: samples=256 block_size=128 damp=0.01 +prune_pct:0.03 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9302 val_bpb:4.1045 train_time:0ms step_avg:0.04ms seq_len:2048 +step:1/20000 train_loss:6.9329 train_time:183ms step_avg:183.33ms seq_len:2048 +step:2/20000 train_loss:8.5659 train_time:288ms step_avg:144.24ms seq_len:2048 +step:3/20000 train_loss:7.7740 train_time:379ms step_avg:126.35ms seq_len:2048 +step:4/20000 train_loss:7.1435 train_time:470ms step_avg:117.50ms seq_len:2048 +step:5/20000 train_loss:6.8554 train_time:560ms step_avg:112.09ms seq_len:2048 +step:6/20000 train_loss:6.7643 train_time:652ms step_avg:108.61ms seq_len:2048 +step:7/20000 train_loss:6.6589 train_time:742ms step_avg:105.96ms seq_len:2048 +step:8/20000 train_loss:6.6349 train_time:833ms step_avg:104.11ms seq_len:2048 +step:9/20000 train_loss:6.3300 train_time:923ms step_avg:102.61ms seq_len:2048 +step:10/20000 train_loss:5.9178 train_time:1014ms step_avg:101.39ms seq_len:2048 +step:500/20000 train_loss:2.3234 train_time:47489ms step_avg:94.98ms seq_len:2048 +step:1000/20000 train_loss:2.2409 train_time:96875ms step_avg:96.87ms seq_len:2048 +step:1500/20000 train_loss:2.2015 train_time:147187ms step_avg:98.12ms seq_len:2048 +step:2000/20000 train_loss:2.0499 train_time:197327ms step_avg:98.66ms seq_len:2048 +step:2500/20000 train_loss:2.1613 train_time:248060ms step_avg:99.22ms seq_len:2048 +step:3000/20000 train_loss:2.1562 train_time:298871ms step_avg:99.62ms seq_len:2048 +step:3500/20000 train_loss:2.1770 train_time:349421ms step_avg:99.83ms seq_len:2048 +step:4000/20000 train_loss:1.9890 train_time:400276ms step_avg:100.07ms seq_len:2048 +step:4000/20000 val_loss:2.1016 val_bpb:1.2447 train_time:400282ms step_avg:100.07ms seq_len:2048 +step:4500/20000 train_loss:2.1494 train_time:450987ms step_avg:100.22ms seq_len:2048 +step:5000/20000 train_loss:2.1221 train_time:501669ms step_avg:100.33ms seq_len:2048 +late_qat:enabled step:5083 scale:0.6540 elapsed_ms:510055 trigger:wallclock@0.85 +[rank0]:W0326 12:16:21.051000 98035 
torch/_dynamo/variables/tensor.py:1048] [0/4] Graph break from `Tensor.item()`, consider setting:
+[rank0]:W0326 12:16:21.051000 98035 torch/_dynamo/variables/tensor.py:1048] [0/4] torch._dynamo.config.capture_scalar_outputs = True
+[rank0]:W0326 12:16:21.051000 98035 torch/_dynamo/variables/tensor.py:1048] [0/4] or:
+[rank0]:W0326 12:16:21.051000 98035 torch/_dynamo/variables/tensor.py:1048] [0/4] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
+[rank0]:W0326 12:16:21.051000 98035 torch/_dynamo/variables/tensor.py:1048] [0/4] to include these operations in the captured graph.
+[rank0]:W0326 12:16:21.051000 98035 torch/_dynamo/variables/tensor.py:1048] [0/4]
+[rank0]:W0326 12:16:21.051000 98035 torch/_dynamo/variables/tensor.py:1048] [0/4] Graph break: from user code at:
+[rank0]:W0326 12:16:21.051000 98035 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1496, in forward
+[rank0]:W0326 12:16:21.051000 98035 torch/_dynamo/variables/tensor.py:1048] [0/4] x = x + self.bigram(input_ids)
+[rank0]:W0326 12:16:21.051000 98035 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1320, in forward
+[rank0]:W0326 12:16:21.051000 98035 torch/_dynamo/variables/tensor.py:1048] [0/4] h = self.proj(h)
+[rank0]:W0326 12:16:21.051000 98035 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1158, in forward
+[rank0]:W0326 12:16:21.051000 98035 torch/_dynamo/variables/tensor.py:1048] [0/4] alpha = self._soft_round_alpha.item()
+[rank0]:W0326 12:16:21.051000 98035 torch/_dynamo/variables/tensor.py:1048] [0/4]
+[rank0]:W0326 12:16:21.051000 98035 torch/_dynamo/variables/tensor.py:1048] [0/4]
+[... identical Graph break warning emitted by ranks 1-7; duplicates omitted ...]
+[rank0]:W0326 12:16:28.551000 98035 torch/_dynamo/convert_frame.py:1358] [9/8] torch._dynamo hit config.recompile_limit (8)
+[rank0]:W0326 12:16:28.551000 98035 torch/_dynamo/convert_frame.py:1358] [9/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1373)
+[rank0]:W0326 12:16:28.551000 98035 torch/_dynamo/convert_frame.py:1358] [9/8] last reason: 9/7: self.ln_scale_factor == 0.35355339059327373 # self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed, # orkspace/pgolf/train_gpt.py:1378 in forward
+[rank0]:W0326 12:16:28.551000 98035 torch/_dynamo/convert_frame.py:1358] [9/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
+[rank0]:W0326 12:16:28.551000 98035 torch/_dynamo/convert_frame.py:1358] [9/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
+[... identical recompile_limit warning emitted by ranks 1-7; duplicates omitted ...]
+[rank3]:W0326 12:16:33.058000 98038 torch/_dynamo/convert_frame.py:1358] [4/8] torch._dynamo hit config.recompile_limit (8)
+[rank3]:W0326 12:16:33.058000 98038 torch/_dynamo/convert_frame.py:1358] [4/8] function: 'torch_dynamo_resume_in_forward_at_1158' (/workspace/pgolf/train_gpt.py:1158)
+[rank3]:W0326 12:16:33.058000 98038 torch/_dynamo/convert_frame.py:1358] [4/8] last reason: 4/7: tensor 'w' size mismatch at index 0. expected 512, actual 256
+[rank3]:W0326 12:16:33.058000 98038 torch/_dynamo/convert_frame.py:1358] [4/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
+[rank3]:W0326 12:16:33.058000 98038 torch/_dynamo/convert_frame.py:1358] [4/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
+[... identical recompile_limit warning emitted by the other seven ranks; duplicates omitted ...]
+step:5500/20000 train_loss:1.9783 train_time:596598ms step_avg:108.47ms seq_len:2048
+step:5519/20000 val_loss:1.9822 val_bpb:1.1739 train_time:600118ms step_avg:108.74ms seq_len:2048
+early_stop train_time:600118ms step:5519/20000
+peak_mem: 24934 MiB reserved: 25978 MiB
+steps:5519
+swa_n:5
+swa:blending EMA+SWA (0.7*EMA + 0.3*SWA, 5 SWA checkpoints)
+DIAGNOSTIC post_ema val_loss:1.9842 val_bpb:1.1752 eval_time:2086ms
+gptq:calibrating with 256 samples...
+gptq:calibration done in 1143ms, 66 layers
+Serialized model: 106893257 bytes
+Code size: 99811 bytes
+gptq:quantizing with block_size=128 percdamp=0.01 prune_pct=0.03...
+gptq:quantization done in 11260ms (gptq:66 naive:0)
+compression:lzma raw_size:27632195 compressed_size:14908344 ratio:1.85x
+Serialized model int5+lzma: 14908348 bytes
+Total submission size: 15008159 bytes
+Size budget OK: 991841 bytes remaining
+[rank7]:W0326 12:19:02.968000 98042 torch/_dynamo/convert_frame.py:1358] [0/8] torch._dynamo hit config.recompile_limit (8)
+[rank7]:W0326 12:19:02.968000 98042 torch/_dynamo/convert_frame.py:1358] [0/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1493)
+[rank7]:W0326 12:19:02.968000 98042 torch/_dynamo/convert_frame.py:1358] [0/8] last reason: 0/7: self._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._cos_cached is None # self._cos_cached is None # orkspace/pgolf/train_gpt.py:1200 in forward
+[rank7]:W0326 12:19:02.968000 98042 torch/_dynamo/convert_frame.py:1358] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
+[rank7]:W0326 12:19:02.968000 98042 torch/_dynamo/convert_frame.py:1358] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
+[rank2]:W0326 12:19:03.019000 98037 torch/_dynamo/convert_frame.py:1358] [0/8] torch._dynamo hit config.recompile_limit (8)
+[rank2]:W0326 12:19:03.019000 98037 torch/_dynamo/convert_frame.py:1358] [0/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1493)
+[rank2]:W0326 12:19:03.019000 98037 torch/_dynamo/convert_frame.py:1358] [0/8] last reason: 0/7: self._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._cos_cached is None # self._cos_cached is None # orkspace/pgolf/train_gpt.py:1200 in forward
+[rank2]:W0326 12:19:03.019000 98037 torch/_dynamo/convert_frame.py:1358] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
+[rank2]:W0326 12:19:03.019000 98037 torch/_dynamo/convert_frame.py:1358] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html +[rank4]:W0326 12:19:03.064000 98039 torch/_dynamo/convert_frame.py:1358] [0/8] torch._dynamo hit config.recompile_limit (8) +[rank4]:W0326 12:19:03.064000 98039 torch/_dynamo/convert_frame.py:1358] [0/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1493) +[rank4]:W0326 12:19:03.064000 98039 torch/_dynamo/convert_frame.py:1358] [0/8] last reason: 0/7: self._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._cos_cached is None # self._cos_cached is None # orkspace/pgolf/train_gpt.py:1200 in forward +[rank4]:W0326 12:19:03.064000 98039 torch/_dynamo/convert_frame.py:1358] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". +[rank4]:W0326 12:19:03.064000 98039 torch/_dynamo/convert_frame.py:1358] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html +[rank1]:W0326 12:19:03.072000 98036 torch/_dynamo/convert_frame.py:1358] [0/8] torch._dynamo hit config.recompile_limit (8) +[rank1]:W0326 12:19:03.072000 98036 torch/_dynamo/convert_frame.py:1358] [0/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1493) +[rank1]:W0326 12:19:03.072000 98036 torch/_dynamo/convert_frame.py:1358] [0/8] last reason: 0/7: self._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._cos_cached is None # self._cos_cached is None # orkspace/pgolf/train_gpt.py:1200 in forward +[rank1]:W0326 12:19:03.072000 98036 torch/_dynamo/convert_frame.py:1358] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". +[rank1]:W0326 12:19:03.072000 98036 torch/_dynamo/convert_frame.py:1358] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html +[rank6]:W0326 12:19:03.080000 98041 torch/_dynamo/convert_frame.py:1358] [0/8] torch._dynamo hit config.recompile_limit (8) +[rank6]:W0326 12:19:03.080000 98041 torch/_dynamo/convert_frame.py:1358] [0/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1493) +[rank6]:W0326 12:19:03.080000 98041 torch/_dynamo/convert_frame.py:1358] [0/8] last reason: 0/7: self._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._cos_cached is None # self._cos_cached is None # orkspace/pgolf/train_gpt.py:1200 in forward +[rank6]:W0326 12:19:03.080000 98041 torch/_dynamo/convert_frame.py:1358] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". +[rank6]:W0326 12:19:03.080000 98041 torch/_dynamo/convert_frame.py:1358] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html +[rank3]:W0326 12:19:03.212000 98038 torch/_dynamo/convert_frame.py:1358] [0/8] torch._dynamo hit config.recompile_limit (8) +[rank3]:W0326 12:19:03.212000 98038 torch/_dynamo/convert_frame.py:1358] [0/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1493) +[rank3]:W0326 12:19:03.212000 98038 torch/_dynamo/convert_frame.py:1358] [0/8] last reason: 0/7: self._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._cos_cached is None # self._cos_cached is None # orkspace/pgolf/train_gpt.py:1200 in forward +[rank3]:W0326 12:19:03.212000 98038 torch/_dynamo/convert_frame.py:1358] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". 
+[rank3]:W0326 12:19:03.212000 98038 torch/_dynamo/convert_frame.py:1358] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html +[rank5]:W0326 12:19:03.796000 98040 torch/_dynamo/convert_frame.py:1358] [0/8] torch._dynamo hit config.recompile_limit (8) +[rank5]:W0326 12:19:03.796000 98040 torch/_dynamo/convert_frame.py:1358] [0/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1493) +[rank5]:W0326 12:19:03.796000 98040 torch/_dynamo/convert_frame.py:1358] [0/8] last reason: 0/7: self._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._cos_cached is None # self._cos_cached is None # orkspace/pgolf/train_gpt.py:1200 in forward +[rank5]:W0326 12:19:03.796000 98040 torch/_dynamo/convert_frame.py:1358] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". +[rank5]:W0326 12:19:03.796000 98040 torch/_dynamo/convert_frame.py:1358] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html +[rank0]:W0326 12:19:03.970000 98035 torch/_dynamo/convert_frame.py:1358] [0/8] torch._dynamo hit config.recompile_limit (8) +[rank0]:W0326 12:19:03.970000 98035 torch/_dynamo/convert_frame.py:1358] [0/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1493) +[rank0]:W0326 12:19:03.970000 98035 torch/_dynamo/convert_frame.py:1358] [0/8] last reason: 0/7: self._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._cos_cached is None # self._cos_cached is None # orkspace/pgolf/train_gpt.py:1200 in forward +[rank0]:W0326 12:19:03.970000 98035 torch/_dynamo/convert_frame.py:1358] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". +[rank0]:W0326 12:19:03.970000 98035 torch/_dynamo/convert_frame.py:1358] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html +final_int5_roundtrip val_loss:1.9959 val_bpb:1.1821 eval_time:9100ms +final_int5_roundtrip_exact val_loss:1.99589206 val_bpb:1.18207970 +eval_budget: 9s elapsed, 591s remaining +integrated_eval: starting single-pass TTT+N-gram evaluation +integrated_eval: LoRA rank=8 on Q,V of blocks 9-10, 28672 LoRA params, 947 chunks, polyak_decay=0.998, ngram order=2-9 adaptive=True alpha=[0.12,0.6] +integrated_eval: chunk 5/947 model_bpb=1.152144 ngram_bpb=1.191478 delta=0.039334 t=3s +integrated_eval: chunk 10/947 model_bpb=1.161709 ngram_bpb=1.218182 delta=0.056473 t=6s +integrated_eval: chunk 15/947 model_bpb=1.170380 ngram_bpb=1.236355 delta=0.065975 t=9s +integrated_eval: chunk 20/947 model_bpb=1.166396 ngram_bpb=1.237253 delta=0.070858 t=13s +integrated_eval: chunk 25/947 model_bpb=1.165490 ngram_bpb=1.236945 delta=0.071455 t=16s +integrated_eval: chunk 30/947 model_bpb=1.170108 ngram_bpb=1.239205 delta=0.069097 t=19s +integrated_eval: chunk 35/947 model_bpb=1.168355 ngram_bpb=1.232541 delta=0.064186 t=22s +integrated_eval: chunk 40/947 model_bpb=1.164408 ngram_bpb=1.221845 delta=0.057437 t=25s +integrated_eval: chunk 45/947 model_bpb=1.163693 ngram_bpb=1.212002 delta=0.048309 t=28s +integrated_eval: chunk 50/947 model_bpb=1.164917 ngram_bpb=1.202344 delta=0.037427 t=31s +integrated_eval: chunk 55/947 model_bpb=1.165682 ngram_bpb=1.190389 delta=0.024707 t=34s +integrated_eval: chunk 60/947 model_bpb=1.161224 ngram_bpb=1.172652 delta=0.011428 t=37s +integrated_eval: chunk 65/947 model_bpb=1.161240 ngram_bpb=1.158054 delta=-0.003187 t=40s +integrated_eval: chunk 70/947 model_bpb=1.159361 ngram_bpb=1.140656 delta=-0.018705 t=43s 
+integrated_eval: chunk 75/947 model_bpb=1.159543 ngram_bpb=1.124795 delta=-0.034748 t=46s +integrated_eval: chunk 80/947 model_bpb=1.160532 ngram_bpb=1.108803 delta=-0.051729 t=49s +integrated_eval: chunk 85/947 model_bpb=1.162654 ngram_bpb=1.092887 delta=-0.069767 t=52s +integrated_eval: chunk 90/947 model_bpb=1.162829 ngram_bpb=1.075853 delta=-0.086976 t=55s +integrated_eval: chunk 95/947 model_bpb=1.165948 ngram_bpb=1.060590 delta=-0.105357 t=58s +integrated_eval: chunk 100/947 model_bpb=1.165192 ngram_bpb=1.043001 delta=-0.122191 t=61s +integrated_eval: chunk 105/947 model_bpb=1.163784 ngram_bpb=1.025071 delta=-0.138714 t=64s +integrated_eval: chunk 110/947 model_bpb=1.164973 ngram_bpb=1.008732 delta=-0.156240 t=67s +integrated_eval: chunk 115/947 model_bpb=1.164618 ngram_bpb=0.991985 delta=-0.172633 t=71s +integrated_eval: chunk 120/947 model_bpb=1.165230 ngram_bpb=0.976095 delta=-0.189134 t=74s +integrated_eval: chunk 125/947 model_bpb=1.164417 ngram_bpb=0.959846 delta=-0.204571 t=77s +integrated_eval: chunk 130/947 model_bpb=1.164379 ngram_bpb=0.944106 delta=-0.220272 t=79s +integrated_eval: chunk 135/947 model_bpb=1.163202 ngram_bpb=0.928309 delta=-0.234893 t=82s +integrated_eval: chunk 140/947 model_bpb=1.164196 ngram_bpb=0.913471 delta=-0.250725 t=85s +integrated_eval: chunk 145/947 model_bpb=1.164182 ngram_bpb=0.898430 delta=-0.265752 t=88s +integrated_eval: chunk 150/947 model_bpb=1.164487 ngram_bpb=0.883994 delta=-0.280492 t=91s +integrated_eval: chunk 155/947 model_bpb=1.164823 ngram_bpb=0.870004 delta=-0.294819 t=94s +integrated_eval: chunk 160/947 model_bpb=1.165462 ngram_bpb=0.856451 delta=-0.309011 t=97s +integrated_eval: chunk 165/947 model_bpb=1.165143 ngram_bpb=0.843056 delta=-0.322087 t=100s +integrated_eval: chunk 170/947 model_bpb=1.164689 ngram_bpb=0.829788 delta=-0.334901 t=103s +integrated_eval: chunk 175/947 model_bpb=1.164985 ngram_bpb=0.817231 delta=-0.347755 t=106s +integrated_eval: chunk 180/947 model_bpb=1.166220 ngram_bpb=0.806245 delta=-0.359975 t=109s +integrated_eval: chunk 185/947 model_bpb=1.165969 ngram_bpb=0.794336 delta=-0.371633 t=112s +integrated_eval: chunk 190/947 model_bpb=1.166026 ngram_bpb=0.782831 delta=-0.383195 t=115s +integrated_eval: chunk 195/947 model_bpb=1.166201 ngram_bpb=0.771738 delta=-0.394463 t=119s +integrated_eval: chunk 200/947 model_bpb=1.166143 ngram_bpb=0.760999 delta=-0.405144 t=121s +integrated_eval: chunk 205/947 model_bpb=1.165103 ngram_bpb=0.750365 delta=-0.414738 t=125s +integrated_eval: chunk 210/947 model_bpb=1.165176 ngram_bpb=0.740235 delta=-0.424941 t=128s +integrated_eval: chunk 215/947 model_bpb=1.165757 ngram_bpb=0.730414 delta=-0.435343 t=131s +integrated_eval: chunk 220/947 model_bpb=1.164994 ngram_bpb=0.720627 delta=-0.444367 t=134s +integrated_eval: chunk 225/947 model_bpb=1.165108 ngram_bpb=0.711188 delta=-0.453920 t=137s +integrated_eval: chunk 230/947 model_bpb=1.164882 ngram_bpb=0.702020 delta=-0.462862 t=139s +integrated_eval: chunk 235/947 model_bpb=1.164665 ngram_bpb=0.693109 delta=-0.471556 t=142s +integrated_eval: chunk 240/947 model_bpb=1.164192 ngram_bpb=0.684683 delta=-0.479509 t=145s +integrated_eval: chunk 245/947 model_bpb=1.164322 ngram_bpb=0.676543 delta=-0.487779 t=148s +integrated_eval: chunk 250/947 model_bpb=1.164189 ngram_bpb=0.668436 delta=-0.495753 t=151s +integrated_eval: chunk 255/947 model_bpb=1.163555 ngram_bpb=0.660631 delta=-0.502924 t=155s +integrated_eval: chunk 260/947 model_bpb=1.163334 ngram_bpb=0.653172 delta=-0.510162 t=158s +integrated_eval: chunk 265/947 
model_bpb=1.163873 ngram_bpb=0.645896 delta=-0.517977 t=161s +integrated_eval: chunk 270/947 model_bpb=1.164015 ngram_bpb=0.638634 delta=-0.525381 t=164s +integrated_eval: chunk 275/947 model_bpb=1.163479 ngram_bpb=0.631620 delta=-0.531859 t=167s +integrated_eval: chunk 280/947 model_bpb=1.163398 ngram_bpb=0.624791 delta=-0.538607 t=170s +integrated_eval: chunk 285/947 model_bpb=1.162897 ngram_bpb=0.618275 delta=-0.544621 t=173s +integrated_eval: chunk 290/947 model_bpb=1.162673 ngram_bpb=0.611931 delta=-0.550742 t=176s +integrated_eval: chunk 295/947 model_bpb=1.162122 ngram_bpb=0.605794 delta=-0.556328 t=179s +integrated_eval: chunk 300/947 model_bpb=1.162242 ngram_bpb=0.599773 delta=-0.562468 t=182s +integrated_eval: chunk 305/947 model_bpb=1.161855 ngram_bpb=0.593937 delta=-0.567918 t=185s +integrated_eval: chunk 310/947 model_bpb=1.161831 ngram_bpb=0.588197 delta=-0.573634 t=188s +integrated_eval: chunk 315/947 model_bpb=1.161604 ngram_bpb=0.582525 delta=-0.579079 t=191s +integrated_eval: chunk 320/947 model_bpb=1.160955 ngram_bpb=0.576953 delta=-0.584003 t=194s +integrated_eval: chunk 325/947 model_bpb=1.160538 ngram_bpb=0.571591 delta=-0.588946 t=197s +integrated_eval: chunk 330/947 model_bpb=1.160267 ngram_bpb=0.566427 delta=-0.593839 t=200s +integrated_eval: chunk 335/947 model_bpb=1.160073 ngram_bpb=0.561337 delta=-0.598736 t=203s +integrated_eval: chunk 340/947 model_bpb=1.159334 ngram_bpb=0.556290 delta=-0.603044 t=206s +integrated_eval: chunk 345/947 model_bpb=1.159444 ngram_bpb=0.551428 delta=-0.608016 t=209s +integrated_eval: chunk 350/947 model_bpb=1.158870 ngram_bpb=0.546626 delta=-0.612244 t=212s +integrated_eval: chunk 355/947 model_bpb=1.158680 ngram_bpb=0.542088 delta=-0.616592 t=215s +integrated_eval: chunk 360/947 model_bpb=1.158440 ngram_bpb=0.537648 delta=-0.620792 t=218s +integrated_eval: chunk 365/947 model_bpb=1.158694 ngram_bpb=0.533328 delta=-0.625366 t=221s +integrated_eval: chunk 370/947 model_bpb=1.158998 ngram_bpb=0.529091 delta=-0.629907 t=224s +integrated_eval: chunk 375/947 model_bpb=1.158438 ngram_bpb=0.524887 delta=-0.633552 t=227s +integrated_eval: chunk 380/947 model_bpb=1.158640 ngram_bpb=0.520908 delta=-0.637732 t=230s +integrated_eval: chunk 385/947 model_bpb=1.158512 ngram_bpb=0.517013 delta=-0.641498 t=233s +integrated_eval: chunk 390/947 model_bpb=1.158693 ngram_bpb=0.513206 delta=-0.645487 t=236s +integrated_eval: chunk 395/947 model_bpb=1.158697 ngram_bpb=0.509518 delta=-0.649180 t=239s +integrated_eval: chunk 400/947 model_bpb=1.158582 ngram_bpb=0.505735 delta=-0.652847 t=242s +integrated_eval: chunk 405/947 model_bpb=1.158496 ngram_bpb=0.502155 delta=-0.656341 t=245s +integrated_eval: chunk 410/947 model_bpb=1.158454 ngram_bpb=0.498588 delta=-0.659866 t=248s +integrated_eval: chunk 415/947 model_bpb=1.158251 ngram_bpb=0.495164 delta=-0.663087 t=251s +integrated_eval: chunk 420/947 model_bpb=1.158117 ngram_bpb=0.491832 delta=-0.666285 t=254s +integrated_eval: chunk 425/947 model_bpb=1.158070 ngram_bpb=0.488571 delta=-0.669498 t=257s +integrated_eval: chunk 430/947 model_bpb=1.158217 ngram_bpb=0.485396 delta=-0.672821 t=260s +integrated_eval: chunk 435/947 model_bpb=1.158365 ngram_bpb=0.482208 delta=-0.676158 t=263s +integrated_eval: chunk 440/947 model_bpb=1.158537 ngram_bpb=0.479215 delta=-0.679322 t=266s +integrated_eval: chunk 445/947 model_bpb=1.158083 ngram_bpb=0.476286 delta=-0.681797 t=269s +integrated_eval: chunk 450/947 model_bpb=1.158193 ngram_bpb=0.473431 delta=-0.684761 t=272s +integrated_eval: chunk 455/947 model_bpb=1.158001 
ngram_bpb=0.470525 delta=-0.687475 t=275s +integrated_eval: chunk 460/947 model_bpb=1.158148 ngram_bpb=0.467690 delta=-0.690458 t=278s +integrated_eval: chunk 465/947 model_bpb=1.158030 ngram_bpb=0.465000 delta=-0.693031 t=281s +integrated_eval: chunk 470/947 model_bpb=1.158412 ngram_bpb=0.462222 delta=-0.696190 t=284s +integrated_eval: chunk 475/947 model_bpb=1.158830 ngram_bpb=0.459552 delta=-0.699277 t=287s +integrated_eval: chunk 480/947 model_bpb=1.158897 ngram_bpb=0.456784 delta=-0.702113 t=290s +integrated_eval: chunk 485/947 model_bpb=1.159488 ngram_bpb=0.454149 delta=-0.705340 t=292s +integrated_eval: chunk 490/947 model_bpb=1.159640 ngram_bpb=0.451483 delta=-0.708157 t=296s +integrated_eval: chunk 495/947 model_bpb=1.159695 ngram_bpb=0.448929 delta=-0.710766 t=299s +integrated_eval: chunk 500/947 model_bpb=1.159941 ngram_bpb=0.446398 delta=-0.713543 t=301s +integrated_eval: chunk 505/947 model_bpb=1.160200 ngram_bpb=0.443965 delta=-0.716235 t=305s +integrated_eval: chunk 510/947 model_bpb=1.160458 ngram_bpb=0.441534 delta=-0.718923 t=308s +integrated_eval: chunk 515/947 model_bpb=1.160907 ngram_bpb=0.439103 delta=-0.721804 t=311s +integrated_eval: chunk 520/947 model_bpb=1.161374 ngram_bpb=0.436714 delta=-0.724660 t=314s +integrated_eval: chunk 525/947 model_bpb=1.161298 ngram_bpb=0.434368 delta=-0.726930 t=317s +integrated_eval: chunk 530/947 model_bpb=1.161485 ngram_bpb=0.432054 delta=-0.729431 t=320s +integrated_eval: chunk 535/947 model_bpb=1.161624 ngram_bpb=0.429890 delta=-0.731734 t=322s +integrated_eval: chunk 540/947 model_bpb=1.161698 ngram_bpb=0.427645 delta=-0.734054 t=326s +integrated_eval: chunk 545/947 model_bpb=1.161960 ngram_bpb=0.425447 delta=-0.736513 t=329s +integrated_eval: chunk 550/947 model_bpb=1.162170 ngram_bpb=0.423318 delta=-0.738852 t=332s +integrated_eval: chunk 555/947 model_bpb=1.161965 ngram_bpb=0.421217 delta=-0.740748 t=335s +integrated_eval: chunk 560/947 model_bpb=1.161793 ngram_bpb=0.419139 delta=-0.742654 t=338s +integrated_eval: chunk 565/947 model_bpb=1.161636 ngram_bpb=0.417111 delta=-0.744526 t=341s +integrated_eval: chunk 570/947 model_bpb=1.161368 ngram_bpb=0.415074 delta=-0.746294 t=344s +integrated_eval: chunk 575/947 model_bpb=1.161466 ngram_bpb=0.413122 delta=-0.748343 t=347s +integrated_eval: chunk 580/947 model_bpb=1.161393 ngram_bpb=0.411123 delta=-0.750270 t=350s +integrated_eval: chunk 585/947 model_bpb=1.161116 ngram_bpb=0.409156 delta=-0.751959 t=353s +integrated_eval: chunk 590/947 model_bpb=1.160909 ngram_bpb=0.407227 delta=-0.753682 t=355s +integrated_eval: chunk 595/947 model_bpb=1.161085 ngram_bpb=0.405342 delta=-0.755743 t=358s +integrated_eval: chunk 600/947 model_bpb=1.161322 ngram_bpb=0.403509 delta=-0.757813 t=361s +integrated_eval: chunk 605/947 model_bpb=1.160929 ngram_bpb=0.401711 delta=-0.759219 t=364s +integrated_eval: chunk 610/947 model_bpb=1.161349 ngram_bpb=0.400012 delta=-0.761337 t=367s +integrated_eval: chunk 615/947 model_bpb=1.161230 ngram_bpb=0.398238 delta=-0.762991 t=370s +integrated_eval: chunk 620/947 model_bpb=1.160980 ngram_bpb=0.396456 delta=-0.764524 t=373s +integrated_eval: chunk 625/947 model_bpb=1.160560 ngram_bpb=0.394793 delta=-0.765767 t=376s +integrated_eval: chunk 630/947 model_bpb=1.160265 ngram_bpb=0.393122 delta=-0.767143 t=379s +integrated_eval: chunk 635/947 model_bpb=1.160118 ngram_bpb=0.391454 delta=-0.768664 t=382s +integrated_eval: chunk 640/947 model_bpb=1.159781 ngram_bpb=0.389762 delta=-0.770018 t=385s +integrated_eval: chunk 645/947 model_bpb=1.159558 ngram_bpb=0.388134 
delta=-0.771425 t=388s +integrated_eval: chunk 650/947 model_bpb=1.159507 ngram_bpb=0.386539 delta=-0.772968 t=391s +integrated_eval: chunk 655/947 model_bpb=1.159254 ngram_bpb=0.384965 delta=-0.774289 t=394s +integrated_eval: chunk 660/947 model_bpb=1.158971 ngram_bpb=0.383373 delta=-0.775598 t=397s +integrated_eval: chunk 665/947 model_bpb=1.158742 ngram_bpb=0.381825 delta=-0.776916 t=400s +integrated_eval: chunk 670/947 model_bpb=1.158657 ngram_bpb=0.380317 delta=-0.778340 t=403s +integrated_eval: chunk 675/947 model_bpb=1.158532 ngram_bpb=0.378854 delta=-0.779679 t=406s +integrated_eval: chunk 680/947 model_bpb=1.158702 ngram_bpb=0.377467 delta=-0.781235 t=409s +integrated_eval: chunk 685/947 model_bpb=1.158888 ngram_bpb=0.376087 delta=-0.782801 t=412s +integrated_eval: chunk 690/947 model_bpb=1.159292 ngram_bpb=0.374758 delta=-0.784533 t=415s +integrated_eval: chunk 695/947 model_bpb=1.159054 ngram_bpb=0.373356 delta=-0.785698 t=418s +integrated_eval: chunk 700/947 model_bpb=1.159124 ngram_bpb=0.372006 delta=-0.787117 t=421s +integrated_eval: chunk 705/947 model_bpb=1.159293 ngram_bpb=0.370713 delta=-0.788580 t=424s +integrated_eval: chunk 710/947 model_bpb=1.159399 ngram_bpb=0.369382 delta=-0.790017 t=427s +integrated_eval: chunk 715/947 model_bpb=1.159364 ngram_bpb=0.368181 delta=-0.791182 t=430s +integrated_eval: chunk 720/947 model_bpb=1.159831 ngram_bpb=0.366913 delta=-0.792918 t=433s +integrated_eval: chunk 725/947 model_bpb=1.159830 ngram_bpb=0.365668 delta=-0.794162 t=436s +integrated_eval: chunk 730/947 model_bpb=1.159744 ngram_bpb=0.364436 delta=-0.795308 t=438s +integrated_eval: chunk 735/947 model_bpb=1.160404 ngram_bpb=0.363211 delta=-0.797193 t=441s +integrated_eval: chunk 740/947 model_bpb=1.160356 ngram_bpb=0.361994 delta=-0.798362 t=444s +integrated_eval: chunk 745/947 model_bpb=1.160697 ngram_bpb=0.360846 delta=-0.799851 t=447s +integrated_eval: chunk 750/947 model_bpb=1.160758 ngram_bpb=0.359693 delta=-0.801065 t=450s +integrated_eval: chunk 755/947 model_bpb=1.160866 ngram_bpb=0.358518 delta=-0.802348 t=454s +integrated_eval: chunk 760/947 model_bpb=1.160944 ngram_bpb=0.357388 delta=-0.803555 t=457s +integrated_eval: chunk 765/947 model_bpb=1.161214 ngram_bpb=0.356263 delta=-0.804951 t=460s +integrated_eval: chunk 770/947 model_bpb=1.161301 ngram_bpb=0.355103 delta=-0.806199 t=462s +integrated_eval: chunk 775/947 model_bpb=1.161626 ngram_bpb=0.353978 delta=-0.807648 t=465s +integrated_eval: chunk 780/947 model_bpb=1.161722 ngram_bpb=0.352852 delta=-0.808870 t=469s +integrated_eval: chunk 785/947 model_bpb=1.161935 ngram_bpb=0.351752 delta=-0.810182 t=472s +integrated_eval: chunk 790/947 model_bpb=1.162171 ngram_bpb=0.350690 delta=-0.811481 t=475s +integrated_eval: chunk 795/947 model_bpb=1.162194 ngram_bpb=0.349554 delta=-0.812640 t=478s +integrated_eval: chunk 800/947 model_bpb=1.162461 ngram_bpb=0.348455 delta=-0.814005 t=481s +integrated_eval: chunk 805/947 model_bpb=1.162674 ngram_bpb=0.347354 delta=-0.815320 t=484s +integrated_eval: chunk 810/947 model_bpb=1.162593 ngram_bpb=0.346326 delta=-0.816268 t=486s +integrated_eval: chunk 815/947 model_bpb=1.162685 ngram_bpb=0.345266 delta=-0.817420 t=489s +integrated_eval: chunk 820/947 model_bpb=1.162739 ngram_bpb=0.344212 delta=-0.818527 t=492s +integrated_eval: chunk 825/947 model_bpb=1.162797 ngram_bpb=0.343183 delta=-0.819614 t=495s +integrated_eval: chunk 830/947 model_bpb=1.162978 ngram_bpb=0.342161 delta=-0.820817 t=498s +integrated_eval: chunk 835/947 model_bpb=1.163238 ngram_bpb=0.341162 delta=-0.822076 
t=501s +integrated_eval: chunk 840/947 model_bpb=1.163330 ngram_bpb=0.340151 delta=-0.823179 t=504s +integrated_eval: chunk 845/947 model_bpb=1.163438 ngram_bpb=0.339155 delta=-0.824283 t=507s +integrated_eval: chunk 850/947 model_bpb=1.163711 ngram_bpb=0.338169 delta=-0.825542 t=510s +integrated_eval: chunk 855/947 model_bpb=1.163742 ngram_bpb=0.337182 delta=-0.826560 t=513s +integrated_eval: chunk 860/947 model_bpb=1.163668 ngram_bpb=0.336224 delta=-0.827444 t=516s +integrated_eval: chunk 865/947 model_bpb=1.163710 ngram_bpb=0.335276 delta=-0.828434 t=519s +integrated_eval: chunk 870/947 model_bpb=1.163535 ngram_bpb=0.334318 delta=-0.829217 t=522s +integrated_eval: chunk 875/947 model_bpb=1.163411 ngram_bpb=0.333368 delta=-0.830043 t=525s +integrated_eval: chunk 880/947 model_bpb=1.163485 ngram_bpb=0.332447 delta=-0.831038 t=528s +integrated_eval: chunk 885/947 model_bpb=1.163464 ngram_bpb=0.331522 delta=-0.831942 t=531s +integrated_eval: chunk 890/947 model_bpb=1.163422 ngram_bpb=0.330632 delta=-0.832790 t=534s +integrated_eval: chunk 895/947 model_bpb=1.163117 ngram_bpb=0.329732 delta=-0.833385 t=537s +integrated_eval: chunk 900/947 model_bpb=1.163056 ngram_bpb=0.328841 delta=-0.834215 t=540s +integrated_eval: chunk 905/947 model_bpb=1.163014 ngram_bpb=0.327963 delta=-0.835051 t=543s +integrated_eval: chunk 910/947 model_bpb=1.163107 ngram_bpb=0.327090 delta=-0.836017 t=546s +integrated_eval: chunk 915/947 model_bpb=1.162979 ngram_bpb=0.326210 delta=-0.836769 t=549s +integrated_eval: chunk 920/947 model_bpb=1.163067 ngram_bpb=0.325398 delta=-0.837669 t=552s +integrated_eval: chunk 925/947 model_bpb=1.162853 ngram_bpb=0.324563 delta=-0.838290 t=555s +integrated_eval: chunk 930/947 model_bpb=1.162820 ngram_bpb=0.323719 delta=-0.839101 t=558s +integrated_eval: chunk 935/947 model_bpb=1.162792 ngram_bpb=0.322906 delta=-0.839886 t=561s +integrated_eval: chunk 940/947 model_bpb=1.162585 ngram_bpb=0.322094 delta=-0.840491 t=563s +integrated_eval: chunk 945/947 model_bpb=1.162608 ngram_bpb=0.321306 delta=-0.841302 t=567s +integrated_eval: DONE model_bpb=1.1627 ngram_bpb=0.3211 delta=-0.8416 elapsed=568s +final_integrated model_val_loss:1.9631 model_val_bpb:1.1627 +final_integrated ngram_val_loss:0.5422 ngram_val_bpb:0.3211 +final_integrated delta_bpb:-0.8416 +final_integrated_exact model_bpb:1.16265931 ngram_bpb:0.32109743 +final_integrated_time:568076ms diff --git a/records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/train_seed2024.log b/records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/train_seed2024.log new file mode 100644 index 000000000..ff3e27a56 --- /dev/null +++ b/records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/train_seed2024.log @@ -0,0 +1,523 @@ +W0326 13:02:28.973000 102565 torch/distributed/run.py:803] +W0326 13:02:28.973000 102565 torch/distributed/run.py:803] ***************************************** +W0326 13:02:28.973000 102565 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0326 13:02:28.973000 102565 torch/distributed/run.py:803] ***************************************** +logs/150bb9c0-6a2d-43b7-9691-78b301d31374.txt +bpb:sp=./data/tokenizers/fineweb_1024_bpe.model +tl:dataset:fineweb10B_sp1024 train_shards:80 +val:./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +v:opti-ms2 act:lr09sq xsa:last_4 qat:sr wd:WSD gptq:fh ttt:lora_polyak_adamw compression:lzma optimizer:PM prog_seq:enabled vrl:True gated_attn:True decoder_lr_mult:2.0 ngram:True ngram_order:9 complement_alpha:0.5 prune_pct:0.03 +comp:on alpha=0.5 +model_params:27301064 +mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0 +XSA:last_4 active_layers:[7, 8, 9, 10] +world_size:8 gas:1 +seed:2024 +lr_schedule:WSD wsd_stable_frac:0.75 decay_shape:cosine qat_trigger_frac:0.85 +optimizer:PM (NS) +ttt_config: optimizer=AdamW(LoRA) epochs=1 lr=0.003 lora_rank=8 polyak_decay=0.998 chunk=65536 grad_clip=1.0 temperature=0.98 +decoder_lr_mult:2.0 encoder_matrix_lr:0.025 decoder_matrix_lr:0.05 +gated_attn:enabled layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +value_residual:enabled layers:[1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +gptq_config: samples=256 block_size=128 damp=0.01 +prune_pct:0.03 +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/20000 val_loss:6.9306 val_bpb:4.1047 train_time:0ms step_avg:0.04ms seq_len:2048 +step:1/20000 train_loss:6.9322 train_time:187ms step_avg:187.27ms seq_len:2048 +step:2/20000 train_loss:8.6505 train_time:274ms step_avg:136.81ms seq_len:2048 +step:3/20000 train_loss:7.8014 train_time:364ms step_avg:121.49ms seq_len:2048 +step:4/20000 train_loss:7.0552 train_time:456ms step_avg:113.97ms seq_len:2048 +step:5/20000 train_loss:6.8392 train_time:546ms step_avg:109.21ms seq_len:2048 +step:6/20000 train_loss:6.7524 train_time:639ms step_avg:106.54ms seq_len:2048 +step:7/20000 train_loss:6.6348 train_time:729ms step_avg:104.19ms seq_len:2048 +step:8/20000 train_loss:6.5551 train_time:820ms step_avg:102.46ms seq_len:2048 +step:9/20000 train_loss:6.3128 train_time:910ms step_avg:101.10ms seq_len:2048 +step:10/20000 train_loss:5.8934 train_time:1000ms step_avg:100.03ms seq_len:2048 +step:500/20000 train_loss:2.3282 train_time:48827ms step_avg:97.65ms seq_len:2048 +step:1000/20000 train_loss:2.2478 train_time:99371ms step_avg:99.37ms seq_len:2048 +step:1500/20000 train_loss:2.2075 train_time:149901ms step_avg:99.93ms seq_len:2048 +step:2000/20000 train_loss:2.0546 train_time:200739ms step_avg:100.37ms seq_len:2048 +step:2500/20000 train_loss:2.1624 train_time:251699ms step_avg:100.68ms seq_len:2048 +step:3000/20000 train_loss:2.1573 train_time:302339ms step_avg:100.78ms seq_len:2048 +step:3500/20000 train_loss:2.1834 train_time:352922ms step_avg:100.83ms seq_len:2048 +step:4000/20000 train_loss:1.9876 train_time:403842ms step_avg:100.96ms seq_len:2048 +step:4000/20000 val_loss:2.1013 val_bpb:1.2445 train_time:403848ms step_avg:100.96ms seq_len:2048 +step:4500/20000 train_loss:2.1531 train_time:453314ms step_avg:100.74ms seq_len:2048 +step:5000/20000 train_loss:2.1227 train_time:504064ms step_avg:100.81ms seq_len:2048 +late_qat:enabled step:5060 scale:0.6537 elapsed_ms:510082 trigger:wallclock@0.85 +[rank2]:W0326 13:12:36.308000 102635 
torch/_dynamo/variables/tensor.py:1048] [0/4] Graph break from `Tensor.item()`, consider setting: +[rank2]:W0326 13:12:36.308000 102635 torch/_dynamo/variables/tensor.py:1048] [0/4] torch._dynamo.config.capture_scalar_outputs = True +[rank2]:W0326 13:12:36.308000 102635 torch/_dynamo/variables/tensor.py:1048] [0/4] or: +[rank2]:W0326 13:12:36.308000 102635 torch/_dynamo/variables/tensor.py:1048] [0/4] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1 +[rank2]:W0326 13:12:36.308000 102635 torch/_dynamo/variables/tensor.py:1048] [0/4] to include these operations in the captured graph. +[rank2]:W0326 13:12:36.308000 102635 torch/_dynamo/variables/tensor.py:1048] [0/4] +[rank2]:W0326 13:12:36.308000 102635 torch/_dynamo/variables/tensor.py:1048] [0/4] Graph break: from user code at: +[rank2]:W0326 13:12:36.308000 102635 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1496, in forward +[rank2]:W0326 13:12:36.308000 102635 torch/_dynamo/variables/tensor.py:1048] [0/4] x = x + self.bigram(input_ids) +[rank2]:W0326 13:12:36.308000 102635 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1320, in forward +[rank2]:W0326 13:12:36.308000 102635 torch/_dynamo/variables/tensor.py:1048] [0/4] h = self.proj(h) +[rank2]:W0326 13:12:36.308000 102635 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1158, in forward +[rank2]:W0326 13:12:36.308000 102635 torch/_dynamo/variables/tensor.py:1048] [0/4] alpha = self._soft_round_alpha.item() +[rank2]:W0326 13:12:36.308000 102635 torch/_dynamo/variables/tensor.py:1048] [0/4] +[rank2]:W0326 13:12:36.308000 102635 torch/_dynamo/variables/tensor.py:1048] [0/4] +[rank0]:W0326 13:12:36.312000 102633 torch/_dynamo/variables/tensor.py:1048] [0/4] Graph break from `Tensor.item()`, consider setting: +[rank0]:W0326 13:12:36.312000 102633 torch/_dynamo/variables/tensor.py:1048] [0/4] torch._dynamo.config.capture_scalar_outputs = True +[rank0]:W0326 13:12:36.312000 102633 torch/_dynamo/variables/tensor.py:1048] [0/4] or: +[rank0]:W0326 13:12:36.312000 102633 torch/_dynamo/variables/tensor.py:1048] [0/4] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1 +[rank0]:W0326 13:12:36.312000 102633 torch/_dynamo/variables/tensor.py:1048] [0/4] to include these operations in the captured graph. 
+[rank0]:W0326 13:12:36.312000 102633 torch/_dynamo/variables/tensor.py:1048] [0/4] +[rank0]:W0326 13:12:36.312000 102633 torch/_dynamo/variables/tensor.py:1048] [0/4] Graph break: from user code at: +[rank0]:W0326 13:12:36.312000 102633 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1496, in forward +[rank0]:W0326 13:12:36.312000 102633 torch/_dynamo/variables/tensor.py:1048] [0/4] x = x + self.bigram(input_ids) +[rank0]:W0326 13:12:36.312000 102633 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1320, in forward +[rank0]:W0326 13:12:36.312000 102633 torch/_dynamo/variables/tensor.py:1048] [0/4] h = self.proj(h) +[rank0]:W0326 13:12:36.312000 102633 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1158, in forward +[rank0]:W0326 13:12:36.312000 102633 torch/_dynamo/variables/tensor.py:1048] [0/4] alpha = self._soft_round_alpha.item() +[rank0]:W0326 13:12:36.312000 102633 torch/_dynamo/variables/tensor.py:1048] [0/4] +[rank0]:W0326 13:12:36.312000 102633 torch/_dynamo/variables/tensor.py:1048] [0/4] +[rank5]:W0326 13:12:36.313000 102638 torch/_dynamo/variables/tensor.py:1048] [0/4] Graph break from `Tensor.item()`, consider setting: +[rank5]:W0326 13:12:36.313000 102638 torch/_dynamo/variables/tensor.py:1048] [0/4] torch._dynamo.config.capture_scalar_outputs = True +[rank5]:W0326 13:12:36.313000 102638 torch/_dynamo/variables/tensor.py:1048] [0/4] or: +[rank5]:W0326 13:12:36.313000 102638 torch/_dynamo/variables/tensor.py:1048] [0/4] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1 +[rank5]:W0326 13:12:36.313000 102638 torch/_dynamo/variables/tensor.py:1048] [0/4] to include these operations in the captured graph. +[rank5]:W0326 13:12:36.313000 102638 torch/_dynamo/variables/tensor.py:1048] [0/4] +[rank5]:W0326 13:12:36.313000 102638 torch/_dynamo/variables/tensor.py:1048] [0/4] Graph break: from user code at: +[rank5]:W0326 13:12:36.313000 102638 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1496, in forward +[rank5]:W0326 13:12:36.313000 102638 torch/_dynamo/variables/tensor.py:1048] [0/4] x = x + self.bigram(input_ids) +[rank5]:W0326 13:12:36.313000 102638 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1320, in forward +[rank5]:W0326 13:12:36.313000 102638 torch/_dynamo/variables/tensor.py:1048] [0/4] h = self.proj(h) +[rank5]:W0326 13:12:36.313000 102638 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1158, in forward +[rank5]:W0326 13:12:36.313000 102638 torch/_dynamo/variables/tensor.py:1048] [0/4] alpha = self._soft_round_alpha.item() +[rank5]:W0326 13:12:36.313000 102638 torch/_dynamo/variables/tensor.py:1048] [0/4] +[rank5]:W0326 13:12:36.313000 102638 torch/_dynamo/variables/tensor.py:1048] [0/4] +[rank3]:W0326 13:12:36.321000 102636 torch/_dynamo/variables/tensor.py:1048] [0/4] Graph break from `Tensor.item()`, consider setting: +[rank3]:W0326 13:12:36.321000 102636 torch/_dynamo/variables/tensor.py:1048] [0/4] torch._dynamo.config.capture_scalar_outputs = True +[rank3]:W0326 13:12:36.321000 102636 torch/_dynamo/variables/tensor.py:1048] [0/4] or: +[rank3]:W0326 13:12:36.321000 102636 torch/_dynamo/variables/tensor.py:1048] [0/4] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1 +[rank3]:W0326 13:12:36.321000 102636 torch/_dynamo/variables/tensor.py:1048] [0/4] to include these operations in the captured graph. 
+[rank3]:W0326 13:12:36.321000 102636 torch/_dynamo/variables/tensor.py:1048] [0/4] +[rank3]:W0326 13:12:36.321000 102636 torch/_dynamo/variables/tensor.py:1048] [0/4] Graph break: from user code at: +[rank3]:W0326 13:12:36.321000 102636 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1496, in forward +[rank3]:W0326 13:12:36.321000 102636 torch/_dynamo/variables/tensor.py:1048] [0/4] x = x + self.bigram(input_ids) +[rank3]:W0326 13:12:36.321000 102636 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1320, in forward +[rank3]:W0326 13:12:36.321000 102636 torch/_dynamo/variables/tensor.py:1048] [0/4] h = self.proj(h) +[rank3]:W0326 13:12:36.321000 102636 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1158, in forward +[rank3]:W0326 13:12:36.321000 102636 torch/_dynamo/variables/tensor.py:1048] [0/4] alpha = self._soft_round_alpha.item() +[rank3]:W0326 13:12:36.321000 102636 torch/_dynamo/variables/tensor.py:1048] [0/4] +[rank3]:W0326 13:12:36.321000 102636 torch/_dynamo/variables/tensor.py:1048] [0/4] +[rank1]:W0326 13:12:36.345000 102634 torch/_dynamo/variables/tensor.py:1048] [0/4] Graph break from `Tensor.item()`, consider setting: +[rank1]:W0326 13:12:36.345000 102634 torch/_dynamo/variables/tensor.py:1048] [0/4] torch._dynamo.config.capture_scalar_outputs = True +[rank1]:W0326 13:12:36.345000 102634 torch/_dynamo/variables/tensor.py:1048] [0/4] or: +[rank1]:W0326 13:12:36.345000 102634 torch/_dynamo/variables/tensor.py:1048] [0/4] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1 +[rank1]:W0326 13:12:36.345000 102634 torch/_dynamo/variables/tensor.py:1048] [0/4] to include these operations in the captured graph. +[rank1]:W0326 13:12:36.345000 102634 torch/_dynamo/variables/tensor.py:1048] [0/4] +[rank1]:W0326 13:12:36.345000 102634 torch/_dynamo/variables/tensor.py:1048] [0/4] Graph break: from user code at: +[rank1]:W0326 13:12:36.345000 102634 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1496, in forward +[rank1]:W0326 13:12:36.345000 102634 torch/_dynamo/variables/tensor.py:1048] [0/4] x = x + self.bigram(input_ids) +[rank1]:W0326 13:12:36.345000 102634 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1320, in forward +[rank1]:W0326 13:12:36.345000 102634 torch/_dynamo/variables/tensor.py:1048] [0/4] h = self.proj(h) +[rank1]:W0326 13:12:36.345000 102634 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1158, in forward +[rank1]:W0326 13:12:36.345000 102634 torch/_dynamo/variables/tensor.py:1048] [0/4] alpha = self._soft_round_alpha.item() +[rank1]:W0326 13:12:36.345000 102634 torch/_dynamo/variables/tensor.py:1048] [0/4] +[rank1]:W0326 13:12:36.345000 102634 torch/_dynamo/variables/tensor.py:1048] [0/4] +[rank4]:W0326 13:12:36.348000 102637 torch/_dynamo/variables/tensor.py:1048] [0/4] Graph break from `Tensor.item()`, consider setting: +[rank4]:W0326 13:12:36.348000 102637 torch/_dynamo/variables/tensor.py:1048] [0/4] torch._dynamo.config.capture_scalar_outputs = True +[rank4]:W0326 13:12:36.348000 102637 torch/_dynamo/variables/tensor.py:1048] [0/4] or: +[rank4]:W0326 13:12:36.348000 102637 torch/_dynamo/variables/tensor.py:1048] [0/4] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1 +[rank4]:W0326 13:12:36.348000 102637 torch/_dynamo/variables/tensor.py:1048] [0/4] to include these operations in the captured graph. 
+[rank4]:W0326 13:12:36.348000 102637 torch/_dynamo/variables/tensor.py:1048] [0/4] +[rank4]:W0326 13:12:36.348000 102637 torch/_dynamo/variables/tensor.py:1048] [0/4] Graph break: from user code at: +[rank4]:W0326 13:12:36.348000 102637 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1496, in forward +[rank4]:W0326 13:12:36.348000 102637 torch/_dynamo/variables/tensor.py:1048] [0/4] x = x + self.bigram(input_ids) +[rank4]:W0326 13:12:36.348000 102637 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1320, in forward +[rank4]:W0326 13:12:36.348000 102637 torch/_dynamo/variables/tensor.py:1048] [0/4] h = self.proj(h) +[rank4]:W0326 13:12:36.348000 102637 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1158, in forward +[rank4]:W0326 13:12:36.348000 102637 torch/_dynamo/variables/tensor.py:1048] [0/4] alpha = self._soft_round_alpha.item() +[rank4]:W0326 13:12:36.348000 102637 torch/_dynamo/variables/tensor.py:1048] [0/4] +[rank4]:W0326 13:12:36.348000 102637 torch/_dynamo/variables/tensor.py:1048] [0/4] +[rank7]:W0326 13:12:36.953000 102640 torch/_dynamo/variables/tensor.py:1048] [0/4] Graph break from `Tensor.item()`, consider setting: +[rank7]:W0326 13:12:36.953000 102640 torch/_dynamo/variables/tensor.py:1048] [0/4] torch._dynamo.config.capture_scalar_outputs = True +[rank7]:W0326 13:12:36.953000 102640 torch/_dynamo/variables/tensor.py:1048] [0/4] or: +[rank7]:W0326 13:12:36.953000 102640 torch/_dynamo/variables/tensor.py:1048] [0/4] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1 +[rank7]:W0326 13:12:36.953000 102640 torch/_dynamo/variables/tensor.py:1048] [0/4] to include these operations in the captured graph. +[rank7]:W0326 13:12:36.953000 102640 torch/_dynamo/variables/tensor.py:1048] [0/4] +[rank7]:W0326 13:12:36.953000 102640 torch/_dynamo/variables/tensor.py:1048] [0/4] Graph break: from user code at: +[rank7]:W0326 13:12:36.953000 102640 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1496, in forward +[rank7]:W0326 13:12:36.953000 102640 torch/_dynamo/variables/tensor.py:1048] [0/4] x = x + self.bigram(input_ids) +[rank7]:W0326 13:12:36.953000 102640 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1320, in forward +[rank7]:W0326 13:12:36.953000 102640 torch/_dynamo/variables/tensor.py:1048] [0/4] h = self.proj(h) +[rank7]:W0326 13:12:36.953000 102640 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1158, in forward +[rank7]:W0326 13:12:36.953000 102640 torch/_dynamo/variables/tensor.py:1048] [0/4] alpha = self._soft_round_alpha.item() +[rank7]:W0326 13:12:36.953000 102640 torch/_dynamo/variables/tensor.py:1048] [0/4] +[rank7]:W0326 13:12:36.953000 102640 torch/_dynamo/variables/tensor.py:1048] [0/4] +[rank6]:W0326 13:12:37.111000 102639 torch/_dynamo/variables/tensor.py:1048] [0/4] Graph break from `Tensor.item()`, consider setting: +[rank6]:W0326 13:12:37.111000 102639 torch/_dynamo/variables/tensor.py:1048] [0/4] torch._dynamo.config.capture_scalar_outputs = True +[rank6]:W0326 13:12:37.111000 102639 torch/_dynamo/variables/tensor.py:1048] [0/4] or: +[rank6]:W0326 13:12:37.111000 102639 torch/_dynamo/variables/tensor.py:1048] [0/4] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1 +[rank6]:W0326 13:12:37.111000 102639 torch/_dynamo/variables/tensor.py:1048] [0/4] to include these operations in the captured graph. 
+[rank6]:W0326 13:12:37.111000 102639 torch/_dynamo/variables/tensor.py:1048] [0/4] +[rank6]:W0326 13:12:37.111000 102639 torch/_dynamo/variables/tensor.py:1048] [0/4] Graph break: from user code at: +[rank6]:W0326 13:12:37.111000 102639 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1496, in forward +[rank6]:W0326 13:12:37.111000 102639 torch/_dynamo/variables/tensor.py:1048] [0/4] x = x + self.bigram(input_ids) +[rank6]:W0326 13:12:37.111000 102639 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1320, in forward +[rank6]:W0326 13:12:37.111000 102639 torch/_dynamo/variables/tensor.py:1048] [0/4] h = self.proj(h) +[rank6]:W0326 13:12:37.111000 102639 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1158, in forward +[rank6]:W0326 13:12:37.111000 102639 torch/_dynamo/variables/tensor.py:1048] [0/4] alpha = self._soft_round_alpha.item() +[rank6]:W0326 13:12:37.111000 102639 torch/_dynamo/variables/tensor.py:1048] [0/4] +[rank6]:W0326 13:12:37.111000 102639 torch/_dynamo/variables/tensor.py:1048] [0/4] +[rank7]:W0326 13:12:44.295000 102640 torch/_dynamo/convert_frame.py:1358] [9/8] torch._dynamo hit config.recompile_limit (8) +[rank7]:W0326 13:12:44.295000 102640 torch/_dynamo/convert_frame.py:1358] [9/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1373) +[rank7]:W0326 13:12:44.295000 102640 torch/_dynamo/convert_frame.py:1358] [9/8] last reason: 9/7: self.ln_scale_factor == 0.35355339059327373 # self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed, # orkspace/pgolf/train_gpt.py:1378 in forward +[rank7]:W0326 13:12:44.295000 102640 torch/_dynamo/convert_frame.py:1358] [9/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". +[rank7]:W0326 13:12:44.295000 102640 torch/_dynamo/convert_frame.py:1358] [9/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html +[rank2]:W0326 13:12:44.307000 102635 torch/_dynamo/convert_frame.py:1358] [9/8] torch._dynamo hit config.recompile_limit (8) +[rank2]:W0326 13:12:44.307000 102635 torch/_dynamo/convert_frame.py:1358] [9/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1373) +[rank2]:W0326 13:12:44.307000 102635 torch/_dynamo/convert_frame.py:1358] [9/8] last reason: 9/7: self.ln_scale_factor == 0.35355339059327373 # self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed, # orkspace/pgolf/train_gpt.py:1378 in forward +[rank2]:W0326 13:12:44.307000 102635 torch/_dynamo/convert_frame.py:1358] [9/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". +[rank2]:W0326 13:12:44.307000 102635 torch/_dynamo/convert_frame.py:1358] [9/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html +[rank0]:W0326 13:12:44.324000 102633 torch/_dynamo/convert_frame.py:1358] [9/8] torch._dynamo hit config.recompile_limit (8) +[rank0]:W0326 13:12:44.324000 102633 torch/_dynamo/convert_frame.py:1358] [9/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1373) +[rank0]:W0326 13:12:44.324000 102633 torch/_dynamo/convert_frame.py:1358] [9/8] last reason: 9/7: self.ln_scale_factor == 0.35355339059327373 # self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed, # orkspace/pgolf/train_gpt.py:1378 in forward +[rank0]:W0326 13:12:44.324000 102633 torch/_dynamo/convert_frame.py:1358] [9/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". 
+[rank0]:W0326 13:12:44.324000 102633 torch/_dynamo/convert_frame.py:1358] [9/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html +[rank5]:W0326 13:12:44.394000 102638 torch/_dynamo/convert_frame.py:1358] [9/8] torch._dynamo hit config.recompile_limit (8) +[rank5]:W0326 13:12:44.394000 102638 torch/_dynamo/convert_frame.py:1358] [9/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1373) +[rank5]:W0326 13:12:44.394000 102638 torch/_dynamo/convert_frame.py:1358] [9/8] last reason: 9/7: self.ln_scale_factor == 0.35355339059327373 # self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed, # orkspace/pgolf/train_gpt.py:1378 in forward +[rank5]:W0326 13:12:44.394000 102638 torch/_dynamo/convert_frame.py:1358] [9/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". +[rank5]:W0326 13:12:44.394000 102638 torch/_dynamo/convert_frame.py:1358] [9/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html +[rank6]:W0326 13:12:44.481000 102639 torch/_dynamo/convert_frame.py:1358] [9/8] torch._dynamo hit config.recompile_limit (8) +[rank6]:W0326 13:12:44.481000 102639 torch/_dynamo/convert_frame.py:1358] [9/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1373) +[rank6]:W0326 13:12:44.481000 102639 torch/_dynamo/convert_frame.py:1358] [9/8] last reason: 9/7: self.ln_scale_factor == 0.35355339059327373 # self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed, # orkspace/pgolf/train_gpt.py:1378 in forward +[rank6]:W0326 13:12:44.481000 102639 torch/_dynamo/convert_frame.py:1358] [9/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". +[rank6]:W0326 13:12:44.481000 102639 torch/_dynamo/convert_frame.py:1358] [9/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html +[rank3]:W0326 13:12:44.695000 102636 torch/_dynamo/convert_frame.py:1358] [9/8] torch._dynamo hit config.recompile_limit (8) +[rank3]:W0326 13:12:44.695000 102636 torch/_dynamo/convert_frame.py:1358] [9/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1373) +[rank3]:W0326 13:12:44.695000 102636 torch/_dynamo/convert_frame.py:1358] [9/8] last reason: 9/7: self.ln_scale_factor == 0.35355339059327373 # self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed, # orkspace/pgolf/train_gpt.py:1378 in forward +[rank3]:W0326 13:12:44.695000 102636 torch/_dynamo/convert_frame.py:1358] [9/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". +[rank3]:W0326 13:12:44.695000 102636 torch/_dynamo/convert_frame.py:1358] [9/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html +[rank1]:W0326 13:12:45.117000 102634 torch/_dynamo/convert_frame.py:1358] [9/8] torch._dynamo hit config.recompile_limit (8) +[rank1]:W0326 13:12:45.117000 102634 torch/_dynamo/convert_frame.py:1358] [9/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1373) +[rank1]:W0326 13:12:45.117000 102634 torch/_dynamo/convert_frame.py:1358] [9/8] last reason: 9/7: self.ln_scale_factor == 0.35355339059327373 # self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed, # orkspace/pgolf/train_gpt.py:1378 in forward +[rank1]:W0326 13:12:45.117000 102634 torch/_dynamo/convert_frame.py:1358] [9/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". 
+[rank1]:W0326 13:12:45.117000 102634 torch/_dynamo/convert_frame.py:1358] [9/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html +[rank4]:W0326 13:12:45.542000 102637 torch/_dynamo/convert_frame.py:1358] [9/8] torch._dynamo hit config.recompile_limit (8) +[rank4]:W0326 13:12:45.542000 102637 torch/_dynamo/convert_frame.py:1358] [9/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1373) +[rank4]:W0326 13:12:45.542000 102637 torch/_dynamo/convert_frame.py:1358] [9/8] last reason: 9/7: self.ln_scale_factor == 0.35355339059327373 # self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed, # orkspace/pgolf/train_gpt.py:1378 in forward +[rank4]:W0326 13:12:45.542000 102637 torch/_dynamo/convert_frame.py:1358] [9/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". +[rank4]:W0326 13:12:45.542000 102637 torch/_dynamo/convert_frame.py:1358] [9/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html +[rank5]:W0326 13:12:49.198000 102638 torch/_dynamo/convert_frame.py:1358] [4/8] torch._dynamo hit config.recompile_limit (8) +[rank5]:W0326 13:12:49.198000 102638 torch/_dynamo/convert_frame.py:1358] [4/8] function: 'torch_dynamo_resume_in_forward_at_1158' (/workspace/pgolf/train_gpt.py:1158) +[rank5]:W0326 13:12:49.198000 102638 torch/_dynamo/convert_frame.py:1358] [4/8] last reason: 4/7: tensor 'w' size mismatch at index 0. expected 512, actual 256 +[rank5]:W0326 13:12:49.198000 102638 torch/_dynamo/convert_frame.py:1358] [4/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". +[rank5]:W0326 13:12:49.198000 102638 torch/_dynamo/convert_frame.py:1358] [4/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html +[rank3]:W0326 13:12:49.280000 102636 torch/_dynamo/convert_frame.py:1358] [4/8] torch._dynamo hit config.recompile_limit (8) +[rank3]:W0326 13:12:49.280000 102636 torch/_dynamo/convert_frame.py:1358] [4/8] function: 'torch_dynamo_resume_in_forward_at_1158' (/workspace/pgolf/train_gpt.py:1158) +[rank3]:W0326 13:12:49.280000 102636 torch/_dynamo/convert_frame.py:1358] [4/8] last reason: 4/7: tensor 'w' size mismatch at index 0. expected 512, actual 256 +[rank3]:W0326 13:12:49.280000 102636 torch/_dynamo/convert_frame.py:1358] [4/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". +[rank3]:W0326 13:12:49.280000 102636 torch/_dynamo/convert_frame.py:1358] [4/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html +[rank6]:W0326 13:12:49.332000 102639 torch/_dynamo/convert_frame.py:1358] [4/8] torch._dynamo hit config.recompile_limit (8) +[rank6]:W0326 13:12:49.332000 102639 torch/_dynamo/convert_frame.py:1358] [4/8] function: 'torch_dynamo_resume_in_forward_at_1158' (/workspace/pgolf/train_gpt.py:1158) +[rank6]:W0326 13:12:49.332000 102639 torch/_dynamo/convert_frame.py:1358] [4/8] last reason: 4/7: tensor 'w' size mismatch at index 0. expected 512, actual 256 +[rank6]:W0326 13:12:49.332000 102639 torch/_dynamo/convert_frame.py:1358] [4/8] To log all recompilation reasons, use TORCH_LOGS="recompiles". 
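The `[9/8]` recompiles above are guard churn on a Python float: Dynamo specializes on `self.ln_scale_factor` (0.35355339059327373 = 2^-1.5), so a schedule that changes its value compiles a fresh graph each time until the limit of 8 is hit. A minimal sketch of the usual workaround, with a hypothetical module (not the code in `train_gpt.py`): carry the changing scalar in a 0-dim tensor buffer mutated in place, so no per-value guard is installed.

```python
import torch
from torch import nn

class ScaledNormBlock(nn.Module):
    """Hypothetical reduction of the guarded pattern: a float attribute that
    changes during training becomes a Dynamo guard and recompiles per value;
    a 0-dim buffer mutated in place does not."""

    def __init__(self, dim: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        # 2**-1.5 == 0.35355339059327373, the value seen in the failed guard
        self.register_buffer("ln_scale", torch.tensor(2.0 ** -1.5))

    def set_scale(self, value: float) -> None:
        self.ln_scale.fill_(value)  # in-place update, no new guard value

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
        return (x * rms * self.weight) * self.ln_scale
```

The `[4/8]` limit is a different cause (shape specialization: `tensor 'w' size mismatch ... expected 512, actual 256`), which progressive sequence/width changes would also trigger per shape.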
+step:5498/20000 val_loss:1.9824 val_bpb:1.1741 train_time:600133ms step_avg:109.15ms seq_len:2048
+early_stop train_time:600133ms step:5498/20000
+peak_mem: 24934 MiB reserved: 25978 MiB
+steps:5498
+swa_n:4
+swa:blending EMA+SWA (0.7*EMA + 0.3*SWA, 4 SWA checkpoints)
+DIAGNOSTIC post_ema val_loss:1.9853 val_bpb:1.1758 eval_time:2090ms
+gptq:calibrating with 256 samples...
+gptq:calibration done in 1134ms, 66 layers
+Serialized model: 106893257 bytes
+Code size: 99809 bytes
+gptq:quantizing with block_size=128 percdamp=0.01 prune_pct=0.03...
+gptq:quantization done in 11261ms (gptq:66 naive:0)
+compression:lzma raw_size:27632195 compressed_size:14775040 ratio:1.87x
+Serialized model int5+lzma: 14775044 bytes
+Total submission size: 14874853 bytes
+Size budget OK: 1125147 bytes remaining
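The `gptq:` lines correspond to the Full Hessian GPTQ pass described in the README. A minimal sketch of the core column loop, under stated assumptions (symmetric per-row int5, no activation-order permutation, `H` accumulated from the 256 calibration batches; the function name is hypothetical):

```python
import torch

def gptq_quantize(W: torch.Tensor, H: torch.Tensor, percdamp: float = 0.01,
                  bits: int = 5) -> torch.Tensor:
    """Minimal GPTQ sketch. W: (out_features, in_features) weight;
    H: (in, in) Hessian from calibration activations."""
    W = W.clone().float()
    n = W.shape[1]
    # dampen the Hessian diagonal for stability (percdamp=0.01 in the log)
    damp = percdamp * torch.mean(torch.diag(H))
    H = H + damp * torch.eye(n, device=W.device)
    # upper-triangular Cholesky factor of H^-1 drives error compensation
    Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H))
    Hinv = torch.linalg.cholesky(Hinv, upper=True)
    qmax = 2 ** (bits - 1) - 1                 # 15 for int5
    scale = W.abs().amax(dim=1) / qmax         # per-row symmetric scale
    for i in range(n):
        w = W[:, i]
        q = torch.clamp(torch.round(w / scale), -qmax - 1, qmax)
        dq = q * scale
        W[:, i] = dq
        # push this column's rounding error onto not-yet-quantized columns
        err = (w - dq) / Hinv[i, i]
        if i + 1 < n:
            W[:, i + 1:] -= err.unsqueeze(1) * Hinv[i, i + 1:].unsqueeze(0)
    return W
```

The real pass adds activation-order permutation, 128-column blocking, and 3% pruning on top of this skeleton.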
+[rank3]:W0326 13:15:20.592000 102636 torch/_dynamo/convert_frame.py:1358] [0/8] torch._dynamo hit config.recompile_limit (8)
+[rank3]:W0326 13:15:20.592000 102636 torch/_dynamo/convert_frame.py:1358] [0/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1493)
+[rank3]:W0326 13:15:20.592000 102636 torch/_dynamo/convert_frame.py:1358] [0/8] last reason: 0/7: self._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._cos_cached is None # self._cos_cached is None # orkspace/pgolf/train_gpt.py:1200 in forward
+[rank3]:W0326 13:15:20.592000 102636 torch/_dynamo/convert_frame.py:1358] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
+[rank3]:W0326 13:15:20.592000 102636 torch/_dynamo/convert_frame.py:1358] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
+final_int5_roundtrip val_loss:1.9975 val_bpb:1.1831 eval_time:9034ms
+final_int5_roundtrip_exact val_loss:1.99753416 val_bpb:1.18305224
+eval_budget: 9s elapsed, 591s remaining
+integrated_eval: starting single-pass TTT+N-gram evaluation
+integrated_eval: LoRA rank=8 on Q,V of blocks 9-10, 28672 LoRA params, 947 chunks, polyak_decay=0.998, ngram order=2-9 adaptive=True alpha=[0.12,0.6]
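The score-first loop that produces the chunk lines below works roughly as follows; a minimal sketch with hypothetical interfaces (`model(chunk)` returning a causal-LM loss), scoring each 65,536-token chunk with the Polyak-averaged adapters before taking the AdamW step on it:

```python
import torch

def integrated_ttt_eval(model, lora_params, chunks, lr=3e-3, polyak=0.998):
    # Sketch of the single-pass TTT loop: only the 28,672 LoRA parameters
    # train; the base model stays frozen.
    opt = torch.optim.AdamW(lora_params, lr=lr)
    avg = [p.detach().clone() for p in lora_params]    # Polyak average
    nats, tokens = 0.0, 0
    for chunk in chunks:                               # 65,536 tokens each
        with torch.no_grad():                          # 1) score first,
            live = [p.detach().clone() for p in lora_params]
            for p, a in zip(lora_params, avg):         #    using the averaged
                p.copy_(a)                             #    (stable) adapters
            nats += model(chunk).loss.item() * chunk.numel()
            tokens += chunk.numel()
            for p, l in zip(lora_params, live):        #    restore live weights
                p.copy_(l)
        opt.zero_grad(set_to_none=True)                # 2) only then adapt on
        model(chunk).loss.backward()                   #    the chunk just scored
        torch.nn.utils.clip_grad_norm_(lora_params, 1.0)  # grad_clip=1.0
        opt.step()
        with torch.no_grad():                          # 3) fold into average
            for a, p in zip(avg, lora_params):
                a.mul_(polyak).add_(p, alpha=1.0 - polyak)
    return nats / tokens                               # mean loss, nats/token
```

Scoring before adapting keeps the protocol backward-looking: the weights used on chunk k never saw chunk k.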
+integrated_eval: chunk 5/947 model_bpb=1.152087 ngram_bpb=1.191128 delta=0.039041 t=3s
+integrated_eval: chunk 10/947 model_bpb=1.162450 ngram_bpb=1.218559 delta=0.056109 t=6s
+integrated_eval: chunk 15/947 model_bpb=1.171348 ngram_bpb=1.236910 delta=0.065562 t=9s
+integrated_eval: chunk 20/947 model_bpb=1.167225 ngram_bpb=1.237705 delta=0.070481 t=13s
+integrated_eval: chunk 25/947 model_bpb=1.166313 ngram_bpb=1.237437 delta=0.071123 t=16s
+integrated_eval: chunk 30/947 model_bpb=1.171212 ngram_bpb=1.239919 delta=0.068707 t=19s
+integrated_eval: chunk 35/947 model_bpb=1.169424 ngram_bpb=1.233245 delta=0.063820 t=22s
+integrated_eval: chunk 40/947 model_bpb=1.165457 ngram_bpb=1.222604 delta=0.057148 t=25s
+integrated_eval: chunk 45/947 model_bpb=1.164740 ngram_bpb=1.212757 delta=0.048018 t=28s
+integrated_eval: chunk 50/947 model_bpb=1.165922 ngram_bpb=1.203051 delta=0.037129 t=32s
+integrated_eval: chunk 55/947 model_bpb=1.166655 ngram_bpb=1.191041 delta=0.024385 t=35s
+integrated_eval: chunk 60/947 model_bpb=1.162205 ngram_bpb=1.173324 delta=0.011119 t=38s
+integrated_eval: chunk 65/947 model_bpb=1.162242 ngram_bpb=1.158728 delta=-0.003514 t=41s
+integrated_eval: chunk 70/947 model_bpb=1.160361 ngram_bpb=1.141286 delta=-0.019075 t=44s
+integrated_eval: chunk 75/947 model_bpb=1.160593 ngram_bpb=1.125433 delta=-0.035160 t=47s
+integrated_eval: chunk 80/947 model_bpb=1.161552 ngram_bpb=1.109434 delta=-0.052119 t=50s
+integrated_eval: chunk 85/947 model_bpb=1.163614 ngram_bpb=1.093458 delta=-0.070156 t=53s
+integrated_eval: chunk 90/947 model_bpb=1.163775 ngram_bpb=1.076402 delta=-0.087373 t=56s
+integrated_eval: chunk 95/947 model_bpb=1.166877 ngram_bpb=1.061095 delta=-0.105782 t=59s
+integrated_eval: chunk 100/947 model_bpb=1.166090 ngram_bpb=1.043513 delta=-0.122577 t=62s
+integrated_eval: chunk 105/947 model_bpb=1.164737 ngram_bpb=1.025587 delta=-0.139150 t=65s
+integrated_eval: chunk 110/947 model_bpb=1.165893 ngram_bpb=1.009259 delta=-0.156635 t=68s
+integrated_eval: chunk 115/947 model_bpb=1.165543 ngram_bpb=0.992541 delta=-0.173002 t=71s
+integrated_eval: chunk 120/947 model_bpb=1.166137 ngram_bpb=0.976679 delta=-0.189457 t=74s
+integrated_eval: chunk 125/947 model_bpb=1.165376 ngram_bpb=0.960453 delta=-0.204923 t=77s
+integrated_eval: chunk 130/947 model_bpb=1.165341 ngram_bpb=0.944717 delta=-0.220624 t=80s
+integrated_eval: chunk 135/947 model_bpb=1.164171 ngram_bpb=0.928948 delta=-0.235223 t=83s
+integrated_eval: chunk 140/947 model_bpb=1.165206 ngram_bpb=0.914111 delta=-0.251095 t=86s
+integrated_eval: chunk 145/947 model_bpb=1.165162 ngram_bpb=0.899063 delta=-0.266100 t=89s
+integrated_eval: chunk 150/947 model_bpb=1.165477 ngram_bpb=0.884622 delta=-0.280856 t=92s
+integrated_eval: chunk 155/947 model_bpb=1.165818 ngram_bpb=0.870625 delta=-0.295194 t=95s
+integrated_eval: chunk 160/947 model_bpb=1.166448 ngram_bpb=0.857083 delta=-0.309365 t=98s
+integrated_eval: chunk 165/947 model_bpb=1.166138 ngram_bpb=0.843695 delta=-0.322443 t=102s
+integrated_eval: chunk 170/947 model_bpb=1.165717 ngram_bpb=0.830435 delta=-0.335282 t=105s
+integrated_eval: chunk 175/947 model_bpb=1.166023 ngram_bpb=0.817888 delta=-0.348135 t=108s
+integrated_eval: chunk 180/947 model_bpb=1.167282 ngram_bpb=0.806902 delta=-0.360380 t=111s
+integrated_eval: chunk 185/947 model_bpb=1.167017 ngram_bpb=0.794998 delta=-0.372020 t=114s
+integrated_eval: chunk 190/947 model_bpb=1.167104 ngram_bpb=0.783492 delta=-0.383612 t=117s
+integrated_eval: chunk 195/947 model_bpb=1.167286 ngram_bpb=0.772410 delta=-0.394876 t=120s
+integrated_eval: chunk 200/947 model_bpb=1.167189 ngram_bpb=0.761662 delta=-0.405528 t=123s
+integrated_eval: chunk 205/947 model_bpb=1.166140 ngram_bpb=0.751022 delta=-0.415118 t=126s
+integrated_eval: chunk 210/947 model_bpb=1.166185 ngram_bpb=0.740893 delta=-0.425292 t=130s
+integrated_eval: chunk 215/947 model_bpb=1.166791 ngram_bpb=0.731077 delta=-0.435714 t=133s
+integrated_eval: chunk 220/947 model_bpb=1.166018 ngram_bpb=0.721293 delta=-0.444725 t=136s
+integrated_eval: chunk 225/947 model_bpb=1.166146 ngram_bpb=0.711856 delta=-0.454290 t=139s
+integrated_eval: chunk 230/947 model_bpb=1.165928 ngram_bpb=0.702693 delta=-0.463235 t=142s
+integrated_eval: chunk 235/947 model_bpb=1.165673 ngram_bpb=0.693775 delta=-0.471898 t=145s
+integrated_eval: chunk 240/947 model_bpb=1.165222 ngram_bpb=0.685348 delta=-0.479874 t=148s
+integrated_eval: chunk 245/947 model_bpb=1.165327 ngram_bpb=0.677206 delta=-0.488121 t=151s
+integrated_eval: chunk 250/947 model_bpb=1.165189 ngram_bpb=0.669097 delta=-0.496092 t=154s
+integrated_eval: chunk 255/947 model_bpb=1.164564 ngram_bpb=0.661300 delta=-0.503264 t=157s
+integrated_eval: chunk 260/947 model_bpb=1.164336 ngram_bpb=0.653852 delta=-0.510484 t=160s
+integrated_eval: chunk 265/947 model_bpb=1.164854 ngram_bpb=0.646573 delta=-0.518281 t=163s
+integrated_eval: chunk 270/947 model_bpb=1.164992 ngram_bpb=0.639312 delta=-0.525681 t=167s
+integrated_eval: chunk 275/947 model_bpb=1.164469 ngram_bpb=0.632302 delta=-0.532167 t=170s
+integrated_eval: chunk 280/947 model_bpb=1.164368 ngram_bpb=0.625467 delta=-0.538901 t=173s
+integrated_eval: chunk 285/947 model_bpb=1.163865 ngram_bpb=0.618956 delta=-0.544909 t=176s
+integrated_eval: chunk 290/947 model_bpb=1.163638 ngram_bpb=0.612612 delta=-0.551026 t=179s
+integrated_eval: chunk 295/947 model_bpb=1.163102 ngram_bpb=0.606480 delta=-0.556622 t=182s
+integrated_eval: chunk 300/947 model_bpb=1.163195 ngram_bpb=0.600456 delta=-0.562739 t=185s
+integrated_eval: chunk 305/947 model_bpb=1.162815 ngram_bpb=0.594623 delta=-0.568191 t=188s
+integrated_eval: chunk 310/947 model_bpb=1.162780 ngram_bpb=0.588883 delta=-0.573897 t=191s
+integrated_eval: chunk 315/947 model_bpb=1.162556 ngram_bpb=0.583203 delta=-0.579353 t=194s
+integrated_eval: chunk 320/947 model_bpb=1.161903 ngram_bpb=0.577630 delta=-0.584272 t=197s
+integrated_eval: chunk 325/947 model_bpb=1.161496 ngram_bpb=0.572276 delta=-0.589219 t=200s
+integrated_eval: chunk 330/947 model_bpb=1.161210 ngram_bpb=0.567113 delta=-0.594097 t=203s
+integrated_eval: chunk 335/947 model_bpb=1.161008 ngram_bpb=0.562019 delta=-0.598989 t=206s
+integrated_eval: chunk 340/947 model_bpb=1.160253 ngram_bpb=0.556970 delta=-0.603283 t=209s
+integrated_eval: chunk 345/947 model_bpb=1.160365 ngram_bpb=0.552108 delta=-0.608256 t=212s
+integrated_eval: chunk 350/947 model_bpb=1.159802 ngram_bpb=0.547305 delta=-0.612497 t=215s
+integrated_eval: chunk 355/947 model_bpb=1.159594 ngram_bpb=0.542767 delta=-0.616827 t=218s
+integrated_eval: chunk 360/947 model_bpb=1.159345 ngram_bpb=0.538327 delta=-0.621018 t=221s
+integrated_eval: chunk 365/947 model_bpb=1.159605 ngram_bpb=0.534011 delta=-0.625594 t=224s
+integrated_eval: chunk 370/947 model_bpb=1.159875 ngram_bpb=0.529771 delta=-0.630104 t=227s
+integrated_eval: chunk 375/947 model_bpb=1.159294 ngram_bpb=0.525560 delta=-0.633734 t=230s
+integrated_eval: chunk 380/947 model_bpb=1.159485 ngram_bpb=0.521579 delta=-0.637905 t=233s
+integrated_eval: chunk 385/947 model_bpb=1.159352 ngram_bpb=0.517681 delta=-0.641670 t=236s
+integrated_eval: chunk 390/947 model_bpb=1.159544 ngram_bpb=0.513875 delta=-0.645668 t=239s
+integrated_eval: chunk 395/947 model_bpb=1.159542 ngram_bpb=0.510185 delta=-0.649357 t=242s
+integrated_eval: chunk 400/947 model_bpb=1.159417 ngram_bpb=0.506398 delta=-0.653019 t=245s
+integrated_eval: chunk 405/947 model_bpb=1.159324 ngram_bpb=0.502816 delta=-0.656508 t=248s
+integrated_eval: chunk 410/947 model_bpb=1.159275 ngram_bpb=0.499244 delta=-0.660031 t=251s
+integrated_eval: chunk 415/947 model_bpb=1.159059 ngram_bpb=0.495816 delta=-0.663243 t=255s
+integrated_eval: chunk 420/947 model_bpb=1.158922 ngram_bpb=0.492482 delta=-0.666441 t=258s
+integrated_eval: chunk 425/947 model_bpb=1.158884 ngram_bpb=0.489220 delta=-0.669664 t=261s
+integrated_eval: chunk 430/947 model_bpb=1.159025 ngram_bpb=0.486046 delta=-0.672980 t=264s
+integrated_eval: chunk 435/947 model_bpb=1.159177 ngram_bpb=0.482857 delta=-0.676320 t=267s
+integrated_eval: chunk 440/947 model_bpb=1.159351 ngram_bpb=0.479864 delta=-0.679488 t=270s
+integrated_eval: chunk 445/947 model_bpb=1.158895 ngram_bpb=0.476935 delta=-0.681960 t=273s
+integrated_eval: chunk 450/947 model_bpb=1.159006 ngram_bpb=0.474079 delta=-0.684926 t=276s
+integrated_eval: chunk 455/947 model_bpb=1.158806 ngram_bpb=0.471173 delta=-0.687633 t=279s
+integrated_eval: chunk 460/947 model_bpb=1.158942 ngram_bpb=0.468334 delta=-0.690608 t=282s
+integrated_eval: chunk 465/947 model_bpb=1.158820 ngram_bpb=0.465646 delta=-0.693174 t=285s
+integrated_eval: chunk 470/947 model_bpb=1.159206 ngram_bpb=0.462866 delta=-0.696339 t=288s
+integrated_eval: chunk 475/947 model_bpb=1.159638 ngram_bpb=0.460196 delta=-0.699442 t=291s
+integrated_eval: chunk 480/947 model_bpb=1.159706 ngram_bpb=0.457425 delta=-0.702281 t=294s
+integrated_eval: chunk 485/947 model_bpb=1.160288 ngram_bpb=0.454785 delta=-0.705503 t=297s
+integrated_eval: chunk 490/947 model_bpb=1.160428 ngram_bpb=0.452117 delta=-0.708312 t=300s
+integrated_eval: chunk 495/947 model_bpb=1.160474 ngram_bpb=0.449562 delta=-0.710912 t=303s
+integrated_eval: chunk 500/947 model_bpb=1.160712 ngram_bpb=0.447028 delta=-0.713684 t=307s
+integrated_eval: chunk 505/947 model_bpb=1.160970 ngram_bpb=0.444594 delta=-0.716376 t=310s
+integrated_eval: chunk 510/947 model_bpb=1.161219 ngram_bpb=0.442161 delta=-0.719059 t=312s
+integrated_eval: chunk 515/947 model_bpb=1.161676 ngram_bpb=0.439731 delta=-0.721946 t=316s
+integrated_eval: chunk 520/947 model_bpb=1.162146 ngram_bpb=0.437344 delta=-0.724802 t=319s
+integrated_eval: chunk 525/947 model_bpb=1.162063 ngram_bpb=0.434996 delta=-0.727067 t=322s
+integrated_eval: chunk 530/947 model_bpb=1.162259 ngram_bpb=0.432680 delta=-0.729579 t=325s
+integrated_eval: chunk 535/947 model_bpb=1.162405 ngram_bpb=0.430515 delta=-0.731891 t=328s
+integrated_eval: chunk 540/947 model_bpb=1.162479 ngram_bpb=0.428272 delta=-0.734207 t=331s
+integrated_eval: chunk 545/947 model_bpb=1.162723 ngram_bpb=0.426075 delta=-0.736648 t=334s
+integrated_eval: chunk 550/947 model_bpb=1.162940 ngram_bpb=0.423945 delta=-0.738995 t=337s
+integrated_eval: chunk 555/947 model_bpb=1.162732 ngram_bpb=0.421840 delta=-0.740892 t=340s
+integrated_eval: chunk 560/947 model_bpb=1.162559 ngram_bpb=0.419760 delta=-0.742799 t=342s
+integrated_eval: chunk 565/947 model_bpb=1.162404 ngram_bpb=0.417730 delta=-0.744674 t=345s
+integrated_eval: chunk 570/947 model_bpb=1.162127 ngram_bpb=0.415690 delta=-0.746437 t=349s
+integrated_eval: chunk 575/947 model_bpb=1.162226 ngram_bpb=0.413737 delta=-0.748489 t=352s
+integrated_eval: chunk 580/947 model_bpb=1.162148 ngram_bpb=0.411737 delta=-0.750411 t=355s
+integrated_eval: chunk 585/947 model_bpb=1.161870 ngram_bpb=0.409769 delta=-0.752101 t=358s
+integrated_eval: chunk 590/947 model_bpb=1.161664 ngram_bpb=0.407838 delta=-0.753827 t=361s
+integrated_eval: chunk 595/947 model_bpb=1.161839 ngram_bpb=0.405951 delta=-0.755888 t=364s
+integrated_eval: chunk 600/947 model_bpb=1.162067 ngram_bpb=0.404115 delta=-0.757952 t=367s
+integrated_eval: chunk 605/947 model_bpb=1.161681 ngram_bpb=0.402316 delta=-0.759366 t=370s
+integrated_eval: chunk 610/947 model_bpb=1.162113 ngram_bpb=0.400615 delta=-0.761497 t=373s
+integrated_eval: chunk 615/947 model_bpb=1.161989 ngram_bpb=0.398840 delta=-0.763149 t=376s
+integrated_eval: chunk 620/947 model_bpb=1.161742 ngram_bpb=0.397056 delta=-0.764687 t=379s
+integrated_eval: chunk 625/947 model_bpb=1.161314 ngram_bpb=0.395391 delta=-0.765924 t=382s
+integrated_eval: chunk 630/947 model_bpb=1.161013 ngram_bpb=0.393719 delta=-0.767294 t=385s
+integrated_eval: chunk 635/947 model_bpb=1.160866 ngram_bpb=0.392049 delta=-0.768818 t=388s
+integrated_eval: chunk 640/947 model_bpb=1.160534 ngram_bpb=0.390355 delta=-0.770179 t=391s
+integrated_eval: chunk 645/947 model_bpb=1.160305 ngram_bpb=0.388724 delta=-0.771581 t=394s
+integrated_eval: chunk 650/947 model_bpb=1.160255 ngram_bpb=0.387129 delta=-0.773126 t=397s
+integrated_eval: chunk 655/947 model_bpb=1.160001 ngram_bpb=0.385554 delta=-0.774447 t=400s
+integrated_eval: chunk 660/947 model_bpb=1.159715 ngram_bpb=0.383961 delta=-0.775754 t=403s
+integrated_eval: chunk 665/947 model_bpb=1.159468 ngram_bpb=0.382411 delta=-0.777057 t=406s
+integrated_eval: chunk 670/947 model_bpb=1.159383 ngram_bpb=0.380900 delta=-0.778483 t=409s
+integrated_eval: chunk 675/947 model_bpb=1.159257 ngram_bpb=0.379434 delta=-0.779822 t=412s
+integrated_eval: chunk 680/947 model_bpb=1.159420 ngram_bpb=0.378044 delta=-0.781376 t=415s
+integrated_eval: chunk 685/947 model_bpb=1.159601 ngram_bpb=0.376661 delta=-0.782940 t=418s
+integrated_eval: chunk 690/947 model_bpb=1.160004 ngram_bpb=0.375335 delta=-0.784669 t=421s
+integrated_eval: chunk 695/947 model_bpb=1.159767 ngram_bpb=0.373928 delta=-0.785839 t=424s
+integrated_eval: chunk 700/947 model_bpb=1.159835 ngram_bpb=0.372576 delta=-0.787259 t=427s
+integrated_eval: chunk 705/947 model_bpb=1.160004 ngram_bpb=0.371280 delta=-0.788724 t=430s
+integrated_eval: chunk 710/947 model_bpb=1.160101 ngram_bpb=0.369946 delta=-0.790155 t=433s
+integrated_eval: chunk 715/947 model_bpb=1.160059 ngram_bpb=0.368744 delta=-0.791316 t=436s
+integrated_eval: chunk 720/947 model_bpb=1.160527 ngram_bpb=0.367470 delta=-0.793057 t=439s
+integrated_eval: chunk 725/947 model_bpb=1.160526 ngram_bpb=0.366222 delta=-0.794303 t=442s
+integrated_eval: chunk 730/947 model_bpb=1.160437 ngram_bpb=0.364990 delta=-0.795447 t=445s
+integrated_eval: chunk 735/947 model_bpb=1.161095 ngram_bpb=0.363762 delta=-0.797333 t=448s
+integrated_eval: chunk 740/947 model_bpb=1.161045 ngram_bpb=0.362541 delta=-0.798504 t=451s
+integrated_eval: chunk 745/947 model_bpb=1.161382 ngram_bpb=0.361390 delta=-0.799993 t=454s
+integrated_eval: chunk 750/947 model_bpb=1.161448 ngram_bpb=0.360237 delta=-0.801211 t=457s
+integrated_eval: chunk 755/947 model_bpb=1.161548 ngram_bpb=0.359058 delta=-0.802490 t=460s
+integrated_eval: chunk 760/947 model_bpb=1.161624 ngram_bpb=0.357926 delta=-0.803697 t=463s
+integrated_eval: chunk 765/947 model_bpb=1.161884 ngram_bpb=0.356799 delta=-0.805085 t=466s
+integrated_eval: chunk 770/947 model_bpb=1.161967 ngram_bpb=0.355636 delta=-0.806331 t=469s
+integrated_eval: chunk 775/947 model_bpb=1.162292 ngram_bpb=0.354510 delta=-0.807782 t=472s
+integrated_eval: chunk 780/947 model_bpb=1.162387 ngram_bpb=0.353381 delta=-0.809006 t=475s
+integrated_eval: chunk 785/947 model_bpb=1.162597 ngram_bpb=0.352279 delta=-0.810318 t=478s
+integrated_eval: chunk 790/947 model_bpb=1.162825 ngram_bpb=0.351213 delta=-0.811611 t=481s
+integrated_eval: chunk 795/947 model_bpb=1.162843 ngram_bpb=0.350075 delta=-0.812768 t=484s
+integrated_eval: chunk 800/947 model_bpb=1.163111 ngram_bpb=0.348975 delta=-0.814137 t=487s
+integrated_eval: chunk 805/947 model_bpb=1.163324 ngram_bpb=0.347870 delta=-0.815454 t=490s
+integrated_eval: chunk 810/947 model_bpb=1.163244 ngram_bpb=0.346836 delta=-0.816407 t=493s
+integrated_eval: chunk 815/947 model_bpb=1.163328 ngram_bpb=0.345774 delta=-0.817554 t=496s
+integrated_eval: chunk 820/947 model_bpb=1.163376 ngram_bpb=0.344718 delta=-0.818657 t=499s
+integrated_eval: chunk 825/947 model_bpb=1.163425 ngram_bpb=0.343686 delta=-0.819739 t=502s
+integrated_eval: chunk 830/947 model_bpb=1.163602 ngram_bpb=0.342662 delta=-0.820939 t=505s
+integrated_eval: chunk 835/947 model_bpb=1.163863 ngram_bpb=0.341662 delta=-0.822201 t=508s
+integrated_eval: chunk 840/947 model_bpb=1.163954 ngram_bpb=0.340649 delta=-0.823304 t=511s
+integrated_eval: chunk 845/947 model_bpb=1.164059 ngram_bpb=0.339653 delta=-0.824406 t=514s
+integrated_eval: chunk 850/947 model_bpb=1.164330 ngram_bpb=0.338666 delta=-0.825664 t=517s
+integrated_eval: chunk 855/947 model_bpb=1.164351 ngram_bpb=0.337676 delta=-0.826676 t=520s
+integrated_eval: chunk 860/947 model_bpb=1.164278 ngram_bpb=0.336717 delta=-0.827561 t=523s
+integrated_eval: chunk 865/947 model_bpb=1.164316 ngram_bpb=0.335768 delta=-0.828548 t=526s
+integrated_eval: chunk 870/947 model_bpb=1.164134 ngram_bpb=0.334810 delta=-0.829324 t=529s
+integrated_eval: chunk 875/947 model_bpb=1.164005 ngram_bpb=0.333859 delta=-0.830146 t=532s
+integrated_eval: chunk 880/947 model_bpb=1.164081 ngram_bpb=0.332936 delta=-0.831145 t=536s
+integrated_eval: chunk 885/947 model_bpb=1.164063 ngram_bpb=0.332009 delta=-0.832054 t=539s
+integrated_eval: chunk 890/947 model_bpb=1.164012 ngram_bpb=0.331115 delta=-0.832897 t=542s
+integrated_eval: chunk 895/947 model_bpb=1.163705 ngram_bpb=0.330211 delta=-0.833494 t=545s
+integrated_eval: chunk 900/947 model_bpb=1.163641 ngram_bpb=0.329317 delta=-0.834324 t=548s
+integrated_eval: chunk 905/947 model_bpb=1.163600 ngram_bpb=0.328437 delta=-0.835163 t=551s
+integrated_eval: chunk 910/947 model_bpb=1.163693 ngram_bpb=0.327562 delta=-0.836131 t=554s
+integrated_eval: chunk 915/947 model_bpb=1.163565 ngram_bpb=0.326680 delta=-0.836885 t=557s
+integrated_eval: chunk 920/947 model_bpb=1.163641 ngram_bpb=0.325865 delta=-0.837776 t=560s
+integrated_eval: chunk 925/947 model_bpb=1.163423 ngram_bpb=0.325028 delta=-0.838396 t=563s
+integrated_eval: chunk 930/947 model_bpb=1.163385 ngram_bpb=0.324182 delta=-0.839203 t=566s
+integrated_eval: chunk 935/947 model_bpb=1.163358 ngram_bpb=0.323368 delta=-0.839990 t=569s
+integrated_eval: chunk 940/947 model_bpb=1.163148 ngram_bpb=0.322554 delta=-0.840594 t=572s
+integrated_eval: chunk 945/947 model_bpb=1.163170 ngram_bpb=0.321763 delta=-0.841407 t=575s
+integrated_eval: DONE model_bpb=1.1632 ngram_bpb=0.3216 delta=-0.8417 elapsed=576s
+final_integrated model_val_loss:1.9641 model_val_bpb:1.1632
+final_integrated ngram_val_loss:0.5429 ngram_val_bpb:0.3216
+final_integrated delta_bpb:-0.8417
+final_integrated_exact model_bpb:1.16322535 ngram_bpb:0.32155373
+final_integrated_time:576383ms
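The n-gram side of the blend is what drives `ngram_bpb` down from ~1.19 to 0.32 as the cache warms over the pass. A rough sketch of the entropy-adaptive mixing implied by `adaptive=True alpha=[0.12,0.6]`; the per-order centers and multipliers below are assumed shapes for illustration, not the values in `train_gpt.py`:

```python
import math

ALPHA_LO, ALPHA_HI = 0.12, 0.60   # the alpha=[0.12,0.6] band in the log line
# README: orders 5-9 boosted 2x, orders 2-3 suppressed 0.3x (order 4 assumed neutral)
ORDER_MULT = {o: (0.3 if o <= 3 else 1.0 if o == 4 else 2.0) for o in range(2, 10)}
ORDER_CENTER = {o: 2.5 - 0.1 * o for o in range(2, 10)}   # assumed shape, nats

def blend_alpha(order: int, ngram_entropy: float) -> float:
    """Trust the n-gram more when its predictive entropy is low and the
    matched backoff order is high; sigmoid gate, clamped to the band."""
    gate = 1.0 / (1.0 + math.exp(ngram_entropy - ORDER_CENTER[order]))
    a = ALPHA_HI * gate * ORDER_MULT[order]
    return min(ALPHA_HI, max(ALPHA_LO, a))

def mix(p_model: float, p_ngram: float, order: int, ngram_entropy: float) -> float:
    a = blend_alpha(order, ngram_entropy)
    return a * p_ngram + (1.0 - a) * p_model
```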
diff --git a/records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/train_seed42.log b/records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/train_seed42.log
new file mode 100644
index 000000000..66e60c1d7
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-26_ComplementaryNgram65K_Int5GPTQ_LoRATTT/train_seed42.log
@@ -0,0 +1,523 @@
+W0326 12:39:51.344000 100312 torch/distributed/run.py:803]
+W0326 12:39:51.344000 100312 torch/distributed/run.py:803] *****************************************
+W0326 12:39:51.344000 100312 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
+W0326 12:39:51.344000 100312 torch/distributed/run.py:803] *****************************************
+logs/d6b5c142-1eaf-490e-8361-1c5a98426a5e.txt
+bpb:sp=./data/tokenizers/fineweb_1024_bpe.model
+tl:dataset:fineweb10B_sp1024 train_shards:80
+val:./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
+v:opti-ms2 act:lr09sq xsa:last_4 qat:sr wd:WSD gptq:fh ttt:lora_polyak_adamw compression:lzma optimizer:PM prog_seq:enabled vrl:True gated_attn:True decoder_lr_mult:2.0 ngram:True ngram_order:9 complement_alpha:0.5 prune_pct:0.03
+comp:on alpha=0.5
+model_params:27301064
+mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0
+XSA:last_4 active_layers:[7, 8, 9, 10]
+world_size:8 gas:1
+seed:42
+lr_schedule:WSD wsd_stable_frac:0.75 decay_shape:cosine qat_trigger_frac:0.85
+optimizer:PM (NS)
+ttt_config: optimizer=AdamW(LoRA) epochs=1 lr=0.003 lora_rank=8 polyak_decay=0.998 chunk=65536 grad_clip=1.0 temperature=0.98
+decoder_lr_mult:2.0 encoder_matrix_lr:0.025 decoder_matrix_lr:0.05
+gated_attn:enabled layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+value_residual:enabled layers:[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+gptq_config: samples=256 block_size=128 damp=0.01
+prune_pct:0.03
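`lr_schedule:WSD wsd_stable_frac:0.75 decay_shape:cosine` maps to the usual warmup-stable-decay shape. A minimal sketch: the 20-step warmup matches the `warmup_step` lines below, and since the run is wallclock-bounded (`early_stop` at 600s), `total_steps` would be the projected step budget rather than the nominal 20,000.

```python
import math

def wsd_lr_mult(step: int, total_steps: int, warmup_steps: int = 20,
                stable_frac: float = 0.75) -> float:
    """Warmup-Stable-Decay multiplier: short warmup, flat LR for 75% of the
    budget, then a cosine decay to zero (decay_shape:cosine)."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    stable_end = int(stable_frac * total_steps)
    if step < stable_end:
        return 1.0
    t = (step - stable_end) / max(1, total_steps - stable_end)
    return 0.5 * (1.0 + math.cos(math.pi * t))
```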
+warmup_step:1/20
+warmup_step:2/20
+warmup_step:3/20
+warmup_step:4/20
+warmup_step:5/20
+warmup_step:6/20
+warmup_step:7/20
+warmup_step:8/20
+warmup_step:9/20
+warmup_step:10/20
+warmup_step:11/20
+warmup_step:12/20
+warmup_step:13/20
+warmup_step:14/20
+warmup_step:15/20
+warmup_step:16/20
+warmup_step:17/20
+warmup_step:18/20
+warmup_step:19/20
+warmup_step:20/20
+step:0/20000 val_loss:6.9303 val_bpb:4.1045 train_time:0ms step_avg:0.04ms seq_len:2048
+step:1/20000 train_loss:6.9317 train_time:196ms step_avg:196.07ms seq_len:2048
+step:2/20000 train_loss:8.4864 train_time:315ms step_avg:157.48ms seq_len:2048
+step:3/20000 train_loss:7.7149 train_time:415ms step_avg:138.49ms seq_len:2048
+step:4/20000 train_loss:7.0917 train_time:514ms step_avg:128.52ms seq_len:2048
+step:5/20000 train_loss:6.8974 train_time:614ms step_avg:122.87ms seq_len:2048
+step:6/20000 train_loss:6.7654 train_time:713ms step_avg:118.84ms seq_len:2048
+step:7/20000 train_loss:6.6837 train_time:812ms step_avg:115.99ms seq_len:2048
+step:8/20000 train_loss:6.6431 train_time:909ms step_avg:113.59ms seq_len:2048
+step:9/20000 train_loss:6.3277 train_time:1008ms step_avg:112.02ms seq_len:2048
+step:10/20000 train_loss:5.9285 train_time:1106ms step_avg:110.63ms seq_len:2048
+step:500/20000 train_loss:2.3229 train_time:50503ms step_avg:101.01ms seq_len:2048
+step:1000/20000 train_loss:2.2379 train_time:102003ms step_avg:102.00ms seq_len:2048
+step:1500/20000 train_loss:2.1980 train_time:152902ms step_avg:101.93ms seq_len:2048
+step:2000/20000 train_loss:2.0515 train_time:203717ms step_avg:101.86ms seq_len:2048
+step:2500/20000 train_loss:2.1581 train_time:255203ms step_avg:102.08ms seq_len:2048
+step:3000/20000 train_loss:2.1532 train_time:306475ms step_avg:102.16ms seq_len:2048
+step:3500/20000 train_loss:2.1749 train_time:357876ms step_avg:102.25ms seq_len:2048
+step:4000/20000 train_loss:1.9868 train_time:409896ms step_avg:102.47ms seq_len:2048
+step:4000/20000 val_loss:2.0960 val_bpb:1.2414 train_time:409903ms step_avg:102.48ms seq_len:2048
+step:4500/20000 train_loss:2.1441 train_time:458408ms step_avg:101.87ms seq_len:2048
+step:5000/20000 train_loss:2.1105 train_time:508970ms step_avg:101.79ms seq_len:2048
+late_qat:enabled step:5011 scale:0.6538 elapsed_ms:510073 trigger:wallclock@0.85
+[rank3]:W0326 12:49:57.054000 100383 torch/_dynamo/variables/tensor.py:1048] [0/4] Graph break from `Tensor.item()`, consider setting:
+[rank3]:W0326 12:49:57.054000 100383 torch/_dynamo/variables/tensor.py:1048] [0/4] torch._dynamo.config.capture_scalar_outputs = True
+[rank3]:W0326 12:49:57.054000 100383 torch/_dynamo/variables/tensor.py:1048] [0/4] or:
+[rank3]:W0326 12:49:57.054000 100383 torch/_dynamo/variables/tensor.py:1048] [0/4] env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
+[rank3]:W0326 12:49:57.054000 100383 torch/_dynamo/variables/tensor.py:1048] [0/4] to include these operations in the captured graph.
+[rank3]:W0326 12:49:57.054000 100383 torch/_dynamo/variables/tensor.py:1048] [0/4]
+[rank3]:W0326 12:49:57.054000 100383 torch/_dynamo/variables/tensor.py:1048] [0/4] Graph break: from user code at:
+[rank3]:W0326 12:49:57.054000 100383 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1496, in forward
+[rank3]:W0326 12:49:57.054000 100383 torch/_dynamo/variables/tensor.py:1048] [0/4] x = x + self.bigram(input_ids)
+[rank3]:W0326 12:49:57.054000 100383 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1320, in forward
+[rank3]:W0326 12:49:57.054000 100383 torch/_dynamo/variables/tensor.py:1048] [0/4] h = self.proj(h)
+[rank3]:W0326 12:49:57.054000 100383 torch/_dynamo/variables/tensor.py:1048] [0/4] File "/workspace/pgolf/train_gpt.py", line 1158, in forward
+[rank3]:W0326 12:49:57.054000 100383 torch/_dynamo/variables/tensor.py:1048] [0/4] alpha = self._soft_round_alpha.item()
+[rank3]:W0326 12:49:57.054000 100383 torch/_dynamo/variables/tensor.py:1048] [0/4]
+[rank3]:W0326 12:49:57.054000 100383 torch/_dynamo/variables/tensor.py:1048] [0/4]
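The graph break above is Dynamo refusing to trace `self._soft_round_alpha.item()` once late QAT kicks in. The warning's own suggestion fixes it; alternatively the Soft-Round step can stay fully tensorized so no scalar ever leaves the graph. A sketch, assuming the standard soft-rounding form (the exact QAT function in `train_gpt.py` may differ):

```python
import torch

# Option 1: follow the warning's suggestion and keep the scalar read
# inside the captured graph.
torch._dynamo.config.capture_scalar_outputs = True

# Option 2: avoid .item() entirely by keeping alpha a tensor.
def soft_round(w: torch.Tensor, alpha: torch.Tensor) -> torch.Tensor:
    """Soft rounding: small alpha is close to identity, large alpha
    approaches round(w), so QAT can anneal toward hard quantization."""
    floor = torch.floor(w)
    r = w - floor                       # fractional part in [0, 1)
    z = torch.tanh(alpha / 2.0)         # normalizer so endpoints stay fixed
    return floor + 0.5 * (torch.tanh(alpha * (r - 0.5)) / z + 1.0)
```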
+[rank6]:W0326 12:50:04.651000 100386 torch/_dynamo/convert_frame.py:1358] [9/8] torch._dynamo hit config.recompile_limit (8)
+[rank6]:W0326 12:50:04.651000 100386 torch/_dynamo/convert_frame.py:1358] [9/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1373)
+[rank6]:W0326 12:50:04.651000 100386 torch/_dynamo/convert_frame.py:1358] [9/8] last reason: 9/7: self.ln_scale_factor == 0.35355339059327373 # self.attn_norm(x_in) * self.ln_scale_factor, v_embed=v_embed, # orkspace/pgolf/train_gpt.py:1378 in forward
+[rank6]:W0326 12:50:04.651000 100386 torch/_dynamo/convert_frame.py:1358] [9/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
+[rank6]:W0326 12:50:04.651000 100386 torch/_dynamo/convert_frame.py:1358] [9/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
+[rank5]:W0326 12:50:09.140000 100385 torch/_dynamo/convert_frame.py:1358] [4/8] torch._dynamo hit config.recompile_limit (8)
+[rank5]:W0326 12:50:09.140000 100385 torch/_dynamo/convert_frame.py:1358] [4/8] function: 'torch_dynamo_resume_in_forward_at_1158' (/workspace/pgolf/train_gpt.py:1158)
+[rank5]:W0326 12:50:09.140000 100385 torch/_dynamo/convert_frame.py:1358] [4/8] last reason: 4/7: tensor 'w' size mismatch at index 0. expected 512, actual 256
+[rank5]:W0326 12:50:09.140000 100385 torch/_dynamo/convert_frame.py:1358] [4/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
+[rank5]:W0326 12:50:09.140000 100385 torch/_dynamo/convert_frame.py:1358] [4/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
+[rank0]:W0326 12:50:09.193000 100380 torch/_dynamo/convert_frame.py:1358] [4/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
+[rank4]:W0326 12:50:09.202000 100384 torch/_dynamo/convert_frame.py:1358] [4/8] torch._dynamo hit config.recompile_limit (8)
+[rank4]:W0326 12:50:09.202000 100384 torch/_dynamo/convert_frame.py:1358] [4/8] function: 'torch_dynamo_resume_in_forward_at_1158' (/workspace/pgolf/train_gpt.py:1158)
+[rank4]:W0326 12:50:09.202000 100384 torch/_dynamo/convert_frame.py:1358] [4/8] last reason: 4/7: tensor 'w' size mismatch at index 0. expected 512, actual 256
+[rank4]:W0326 12:50:09.202000 100384 torch/_dynamo/convert_frame.py:1358] [4/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
+[rank4]:W0326 12:50:09.202000 100384 torch/_dynamo/convert_frame.py:1358] [4/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
+[rank3]:W0326 12:50:09.205000 100383 torch/_dynamo/convert_frame.py:1358] [4/8] torch._dynamo hit config.recompile_limit (8)
+[rank3]:W0326 12:50:09.205000 100383 torch/_dynamo/convert_frame.py:1358] [4/8] function: 'torch_dynamo_resume_in_forward_at_1158' (/workspace/pgolf/train_gpt.py:1158)
+[rank3]:W0326 12:50:09.205000 100383 torch/_dynamo/convert_frame.py:1358] [4/8] last reason: 4/7: tensor 'w' size mismatch at index 0. expected 512, actual 256
+[rank3]:W0326 12:50:09.205000 100383 torch/_dynamo/convert_frame.py:1358] [4/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
+[rank3]:W0326 12:50:09.205000 100383 torch/_dynamo/convert_frame.py:1358] [4/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
+[rank1]:W0326 12:50:09.210000 100381 torch/_dynamo/convert_frame.py:1358] [4/8] torch._dynamo hit config.recompile_limit (8)
+[rank1]:W0326 12:50:09.210000 100381 torch/_dynamo/convert_frame.py:1358] [4/8] function: 'torch_dynamo_resume_in_forward_at_1158' (/workspace/pgolf/train_gpt.py:1158)
+[rank1]:W0326 12:50:09.210000 100381 torch/_dynamo/convert_frame.py:1358] [4/8] last reason: 4/7: tensor 'w' size mismatch at index 0. expected 512, actual 256
+[rank1]:W0326 12:50:09.210000 100381 torch/_dynamo/convert_frame.py:1358] [4/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
+[rank1]:W0326 12:50:09.210000 100381 torch/_dynamo/convert_frame.py:1358] [4/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
+[rank7]:W0326 12:50:09.238000 100387 torch/_dynamo/convert_frame.py:1358] [4/8] torch._dynamo hit config.recompile_limit (8)
+[rank7]:W0326 12:50:09.238000 100387 torch/_dynamo/convert_frame.py:1358] [4/8] function: 'torch_dynamo_resume_in_forward_at_1158' (/workspace/pgolf/train_gpt.py:1158)
+[rank7]:W0326 12:50:09.238000 100387 torch/_dynamo/convert_frame.py:1358] [4/8] last reason: 4/7: tensor 'w' size mismatch at index 0. expected 512, actual 256
+[rank7]:W0326 12:50:09.238000 100387 torch/_dynamo/convert_frame.py:1358] [4/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
+[rank7]:W0326 12:50:09.238000 100387 torch/_dynamo/convert_frame.py:1358] [4/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
+[rank2]:W0326 12:50:09.244000 100382 torch/_dynamo/convert_frame.py:1358] [4/8] torch._dynamo hit config.recompile_limit (8)
+[rank2]:W0326 12:50:09.244000 100382 torch/_dynamo/convert_frame.py:1358] [4/8] function: 'torch_dynamo_resume_in_forward_at_1158' (/workspace/pgolf/train_gpt.py:1158)
+[rank2]:W0326 12:50:09.244000 100382 torch/_dynamo/convert_frame.py:1358] [4/8] last reason: 4/7: tensor 'w' size mismatch at index 0. expected 512, actual 256
+[rank2]:W0326 12:50:09.244000 100382 torch/_dynamo/convert_frame.py:1358] [4/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
+[rank2]:W0326 12:50:09.244000 100382 torch/_dynamo/convert_frame.py:1358] [4/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
+step:5437/20000 val_loss:1.9782 val_bpb:1.1716 train_time:600114ms step_avg:110.38ms seq_len:2048
+early_stop train_time:600114ms step:5437/20000
+peak_mem: 24934 MiB reserved: 25978 MiB
+steps:5437
+swa_n:4
+swa:blending EMA+SWA (0.7*EMA + 0.3*SWA, 4 SWA checkpoints)
+DIAGNOSTIC post_ema val_loss:1.9795 val_bpb:1.1724 eval_time:2094ms
+gptq:calibrating with 256 samples...
+gptq:calibration done in 1181ms, 66 layers
+Serialized model: 106893257 bytes
+Code size: 99809 bytes
+gptq:quantizing with block_size=128 percdamp=0.01 prune_pct=0.03...
+gptq:quantization done in 11203ms (gptq:66 naive:0)
+compression:lzma raw_size:27632195 compressed_size:14826304 ratio:1.86x
+Serialized model int5+lzma: 14826308 bytes
+Total submission size: 14926117 bytes
+Size budget OK: 1073883 bytes remaining
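
> Editor's note on the size accounting above: a minimal sketch of the int5-quantize-then-LZMA stage, assuming simple symmetric per-block scales. The actual GPTQ activation-order permutation and Cholesky error compensation happen before this step, and real int5 packing would bit-pack rather than store int8; `BLOCK` mirrors the logged `block_size=128`, everything else is illustrative.

```python
# Sketch only: per-block int5 quantization + LZMA, not the repo's exact code.
import lzma
import numpy as np

BLOCK = 128          # matches gptq block_size=128 in the log
QMAX = 15            # symmetric int5 range [-15, 15]

def quantize_int5(w: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Quantize an fp32 (rows, cols) matrix to int5 codes with per-block scales."""
    q = np.empty_like(w, dtype=np.int8)
    scales = []
    for start in range(0, w.shape[1], BLOCK):
        blk = w[:, start:start + BLOCK]
        s = np.abs(blk).max() / QMAX + 1e-12          # per-block scale
        q[:, start:start + BLOCK] = np.clip(np.round(blk / s), -QMAX, QMAX)
        scales.append(s)
    return q, np.asarray(scales, dtype=np.float32)

def compress(payload: bytes) -> bytes:
    # preset 9 + extreme is the setting at which the log reports ~1.86x
    return lzma.compress(payload, preset=9 | lzma.PRESET_EXTREME)

w = np.random.randn(512, 1536).astype(np.float32)     # toy weight matrix
codes, scales = quantize_int5(w)
blob = compress(codes.tobytes() + scales.tobytes())
assert len(blob) <= 16_000_000, "over the 16 MB artifact budget"
```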
+[rank4]:W0326 12:52:39.523000 100384 torch/_dynamo/convert_frame.py:1358] [0/8] torch._dynamo hit config.recompile_limit (8)
+[rank4]:W0326 12:52:39.523000 100384 torch/_dynamo/convert_frame.py:1358] [0/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1493)
+[rank4]:W0326 12:52:39.523000 100384 torch/_dynamo/convert_frame.py:1358] [0/8] last reason: 0/7: self._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._cos_cached is None # self._cos_cached is None # orkspace/pgolf/train_gpt.py:1200 in forward
+[rank4]:W0326 12:52:39.523000 100384 torch/_dynamo/convert_frame.py:1358] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
+[rank4]:W0326 12:52:39.523000 100384 torch/_dynamo/convert_frame.py:1358] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
+[rank6]:W0326 12:52:39.656000 100386 torch/_dynamo/convert_frame.py:1358] [0/8] torch._dynamo hit config.recompile_limit (8)
+[rank6]:W0326 12:52:39.656000 100386 torch/_dynamo/convert_frame.py:1358] [0/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1493)
+[rank6]:W0326 12:52:39.656000 100386 torch/_dynamo/convert_frame.py:1358] [0/8] last reason: 0/7: self._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._cos_cached is None # self._cos_cached is None # orkspace/pgolf/train_gpt.py:1200 in forward
+[rank6]:W0326 12:52:39.656000 100386 torch/_dynamo/convert_frame.py:1358] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
+[rank6]:W0326 12:52:39.656000 100386 torch/_dynamo/convert_frame.py:1358] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
+[rank1]:W0326 12:52:39.670000 100381 torch/_dynamo/convert_frame.py:1358] [0/8] torch._dynamo hit config.recompile_limit (8)
+[rank1]:W0326 12:52:39.670000 100381 torch/_dynamo/convert_frame.py:1358] [0/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1493)
+[rank1]:W0326 12:52:39.670000 100381 torch/_dynamo/convert_frame.py:1358] [0/8] last reason: 0/7: self._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._cos_cached is None # self._cos_cached is None # orkspace/pgolf/train_gpt.py:1200 in forward
+[rank1]:W0326 12:52:39.670000 100381 torch/_dynamo/convert_frame.py:1358] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
+[rank1]:W0326 12:52:39.670000 100381 torch/_dynamo/convert_frame.py:1358] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
+[rank7]:W0326 12:52:39.744000 100387 torch/_dynamo/convert_frame.py:1358] [0/8] torch._dynamo hit config.recompile_limit (8)
+[rank7]:W0326 12:52:39.744000 100387 torch/_dynamo/convert_frame.py:1358] [0/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1493)
+[rank7]:W0326 12:52:39.744000 100387 torch/_dynamo/convert_frame.py:1358] [0/8] last reason: 0/7: self._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._cos_cached is None # self._cos_cached is None # orkspace/pgolf/train_gpt.py:1200 in forward
+[rank7]:W0326 12:52:39.744000 100387 torch/_dynamo/convert_frame.py:1358] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
+[rank7]:W0326 12:52:39.744000 100387 torch/_dynamo/convert_frame.py:1358] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
+[rank3]:W0326 12:52:39.852000 100383 torch/_dynamo/convert_frame.py:1358] [0/8] torch._dynamo hit config.recompile_limit (8)
+[rank3]:W0326 12:52:39.852000 100383 torch/_dynamo/convert_frame.py:1358] [0/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1493)
+[rank3]:W0326 12:52:39.852000 100383 torch/_dynamo/convert_frame.py:1358] [0/8] last reason: 0/7: self._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._cos_cached is None # self._cos_cached is None # orkspace/pgolf/train_gpt.py:1200 in forward
+[rank3]:W0326 12:52:39.852000 100383 torch/_dynamo/convert_frame.py:1358] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
+[rank3]:W0326 12:52:39.852000 100383 torch/_dynamo/convert_frame.py:1358] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
+[rank5]:W0326 12:52:39.876000 100385 torch/_dynamo/convert_frame.py:1358] [0/8] torch._dynamo hit config.recompile_limit (8)
+[rank5]:W0326 12:52:39.876000 100385 torch/_dynamo/convert_frame.py:1358] [0/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1493)
+[rank5]:W0326 12:52:39.876000 100385 torch/_dynamo/convert_frame.py:1358] [0/8] last reason: 0/7: self._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._cos_cached is None # self._cos_cached is None # orkspace/pgolf/train_gpt.py:1200 in forward
+[rank5]:W0326 12:52:39.876000 100385 torch/_dynamo/convert_frame.py:1358] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
+[rank5]:W0326 12:52:39.876000 100385 torch/_dynamo/convert_frame.py:1358] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
+[rank0]:W0326 12:52:40.560000 100380 torch/_dynamo/convert_frame.py:1358] [0/8] torch._dynamo hit config.recompile_limit (8)
+[rank0]:W0326 12:52:40.560000 100380 torch/_dynamo/convert_frame.py:1358] [0/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1493)
+[rank0]:W0326 12:52:40.560000 100380 torch/_dynamo/convert_frame.py:1358] [0/8] last reason: 0/7: self._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._cos_cached is None # self._cos_cached is None # orkspace/pgolf/train_gpt.py:1200 in forward
+[rank0]:W0326 12:52:40.560000 100380 torch/_dynamo/convert_frame.py:1358] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
+[rank0]:W0326 12:52:40.560000 100380 torch/_dynamo/convert_frame.py:1358] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
+[rank2]:W0326 12:52:40.976000 100382 torch/_dynamo/convert_frame.py:1358] [0/8] torch._dynamo hit config.recompile_limit (8)
+[rank2]:W0326 12:52:40.976000 100382 torch/_dynamo/convert_frame.py:1358] [0/8] function: 'forward' (/workspace/pgolf/train_gpt.py:1493)
+[rank2]:W0326 12:52:40.976000 100382 torch/_dynamo/convert_frame.py:1358] [0/8] last reason: 0/7: self._modules['blocks']._modules['0']._modules['attn']._modules['rotary']._cos_cached is None # self._cos_cached is None # orkspace/pgolf/train_gpt.py:1200 in forward
+[rank2]:W0326 12:52:40.976000 100382 torch/_dynamo/convert_frame.py:1358] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
+[rank2]:W0326 12:52:40.976000 100382 torch/_dynamo/convert_frame.py:1358] [0/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
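
> Editor's note on the recompile warnings above (frame ids [0/8], [4/8], [9/8]): they are benign once compilation settles, but if the retraces ever cost wallclock, two standard knobs could be tried. The snippet is a hypothetical placement, not something train_gpt.py is confirmed to do.

```python
import torch

# Raise the per-frame recompile budget; the log shows the default of 8 being hit.
torch._dynamo.config.recompile_limit = 32

# The [4/8] guard trips on w.size(0) alternating between 512 (Q width) and
# 256 (KV width); marking that dim dynamic lets one graph serve both shapes:
# torch._dynamo.mark_dynamic(w, 0)
```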
+final_int5_roundtrip val_loss:1.9913 val_bpb:1.1794 eval_time:9559ms
+final_int5_roundtrip_exact val_loss:1.99134378 val_bpb:1.17938595
+eval_budget: 10s elapsed, 590s remaining
+integrated_eval: starting single-pass TTT+N-gram evaluation
+integrated_eval: LoRA rank=8 on Q,V of blocks 9-10, 28672 LoRA params, 947 chunks, polyak_decay=0.998, ngram order=2-9 adaptive=True alpha=[0.12,0.6]
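
> Editor's note: the "28672 LoRA params" figure checks out against rank-8 A/B pairs on the Q and V projections of two blocks, assuming 512-dim inputs, a 512-wide Q projection and a 256-wide V projection (consistent with the 512-vs-256 guard in the recompile warnings above):

```python
# LoRA adds r*d_in (matrix A) + r*d_out (matrix B) per adapted projection.
r, d_in, d_q, d_v, blocks = 8, 512, 512, 256, 2      # assumed shapes
per_block = r * (d_in + d_q) + r * (d_in + d_v)      # Q pair + V pair = 14336
assert per_block * blocks == 28_672                  # matches the logged count
```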
+integrated_eval: chunk 5/947 model_bpb=1.148766 ngram_bpb=1.187873 delta=0.039107 t=4s
+integrated_eval: chunk 10/947 model_bpb=1.159124 ngram_bpb=1.215160 delta=0.056036 t=7s
+integrated_eval: chunk 15/947 model_bpb=1.167699 ngram_bpb=1.233252 delta=0.065553 t=10s
+integrated_eval: chunk 20/947 model_bpb=1.163665 ngram_bpb=1.234059 delta=0.070394 t=13s
+integrated_eval: chunk 25/947 model_bpb=1.162883 ngram_bpb=1.233776 delta=0.070892 t=16s
+integrated_eval: chunk 30/947 model_bpb=1.167605 ngram_bpb=1.236002 delta=0.068397 t=19s
+integrated_eval: chunk 35/947 model_bpb=1.165784 ngram_bpb=1.229355 delta=0.063571 t=22s
+integrated_eval: chunk 40/947 model_bpb=1.161885 ngram_bpb=1.218678 delta=0.056793 t=25s
+integrated_eval: chunk 45/947 model_bpb=1.161174 ngram_bpb=1.208883 delta=0.047708 t=29s
+integrated_eval: chunk 50/947 model_bpb=1.162299 ngram_bpb=1.199111 delta=0.036812 t=32s
+integrated_eval: chunk 55/947 model_bpb=1.163051 ngram_bpb=1.187190 delta=0.024139 t=35s
+integrated_eval: chunk 60/947 model_bpb=1.158536 ngram_bpb=1.169455 delta=0.010919 t=38s
+integrated_eval: chunk 65/947 model_bpb=1.158546 ngram_bpb=1.154936 delta=-0.003611 t=41s
+integrated_eval: chunk 70/947 model_bpb=1.156659 ngram_bpb=1.137579 delta=-0.019080 t=44s
+integrated_eval: chunk 75/947 model_bpb=1.156826 ngram_bpb=1.121821 delta=-0.035005 t=47s
+integrated_eval: chunk 80/947 model_bpb=1.157750 ngram_bpb=1.105855 delta=-0.051895 t=49s
+integrated_eval: chunk 85/947 model_bpb=1.159869 ngram_bpb=1.089992 delta=-0.069876 t=53s
+integrated_eval: chunk 90/947 model_bpb=1.160050 ngram_bpb=1.073030 delta=-0.087020 t=56s
+integrated_eval: chunk 95/947 model_bpb=1.163187 ngram_bpb=1.057823 delta=-0.105364 t=59s
+integrated_eval: chunk 100/947 model_bpb=1.162379 ngram_bpb=1.040291 delta=-0.122088 t=62s
+integrated_eval: chunk 105/947 model_bpb=1.161011 ngram_bpb=1.022431 delta=-0.138580 t=65s
+integrated_eval: chunk 110/947 model_bpb=1.162145 ngram_bpb=1.006158 delta=-0.155987 t=69s
+integrated_eval: chunk 115/947 model_bpb=1.161797 ngram_bpb=0.989497 delta=-0.172300 t=72s
+integrated_eval: chunk 120/947 model_bpb=1.162377 ngram_bpb=0.973677 delta=-0.188700 t=75s
+integrated_eval: chunk 125/947 model_bpb=1.161596 ngram_bpb=0.957491 delta=-0.204105 t=78s
+integrated_eval: chunk 130/947 model_bpb=1.161546 ngram_bpb=0.941809 delta=-0.219737 t=81s
+integrated_eval: chunk 135/947 model_bpb=1.160375 ngram_bpb=0.926084 delta=-0.234291 t=84s
+integrated_eval: chunk 140/947 model_bpb=1.161380 ngram_bpb=0.911305 delta=-0.250075 t=87s
+integrated_eval: chunk 145/947 model_bpb=1.161348 ngram_bpb=0.896304 delta=-0.265044 t=90s
+integrated_eval: chunk 150/947 model_bpb=1.161667 ngram_bpb=0.881907 delta=-0.279760 t=93s
+integrated_eval: chunk 155/947 model_bpb=1.162025 ngram_bpb=0.867980 delta=-0.294045 t=96s
+integrated_eval: chunk 160/947 model_bpb=1.162624 ngram_bpb=0.854486 delta=-0.308138 t=99s
+integrated_eval: chunk 165/947 model_bpb=1.162315 ngram_bpb=0.841139 delta=-0.321175 t=102s
+integrated_eval: chunk 170/947 model_bpb=1.161908 ngram_bpb=0.827937 delta=-0.333970 t=105s
+integrated_eval: chunk 175/947 model_bpb=1.162191 ngram_bpb=0.815424 delta=-0.346767 t=108s
+integrated_eval: chunk 180/947 model_bpb=1.163454 ngram_bpb=0.804483 delta=-0.358971 t=111s
+integrated_eval: chunk 185/947 model_bpb=1.163224 ngram_bpb=0.792613 delta=-0.370611 t=114s
+integrated_eval: chunk 190/947 model_bpb=1.163289 ngram_bpb=0.781142 delta=-0.382147 t=117s
+integrated_eval: chunk 195/947 model_bpb=1.163456 ngram_bpb=0.770094 delta=-0.393362 t=120s
+integrated_eval: chunk 200/947 model_bpb=1.163374 ngram_bpb=0.759395 delta=-0.403980 t=123s
+integrated_eval: chunk 205/947 model_bpb=1.162357 ngram_bpb=0.748793 delta=-0.413564 t=126s
+integrated_eval: chunk 210/947 model_bpb=1.162421 ngram_bpb=0.738701 delta=-0.423720 t=129s
+integrated_eval: chunk 215/947 model_bpb=1.163029 ngram_bpb=0.728925 delta=-0.434104 t=132s
+integrated_eval: chunk 220/947 model_bpb=1.162257 ngram_bpb=0.719175 delta=-0.443082 t=135s
+integrated_eval: chunk 225/947 model_bpb=1.162397 ngram_bpb=0.709775 delta=-0.452621 t=138s
+integrated_eval: chunk 230/947 model_bpb=1.162183 ngram_bpb=0.700648 delta=-0.461535 t=141s
+integrated_eval: chunk 235/947 model_bpb=1.161932 ngram_bpb=0.691762 delta=-0.470171 t=144s
+integrated_eval: chunk 240/947 model_bpb=1.161468 ngram_bpb=0.683366 delta=-0.478102 t=147s
+integrated_eval: chunk 245/947 model_bpb=1.161614 ngram_bpb=0.675255 delta=-0.486359 t=150s
+integrated_eval: chunk 250/947 model_bpb=1.161482 ngram_bpb=0.667176 delta=-0.494305 t=153s
+integrated_eval: chunk 255/947 model_bpb=1.160852 ngram_bpb=0.659400 delta=-0.501452 t=156s
+integrated_eval: chunk 260/947 model_bpb=1.160649 ngram_bpb=0.651977 delta=-0.508673 t=159s
+integrated_eval: chunk 265/947 model_bpb=1.161201 ngram_bpb=0.644729 delta=-0.516472 t=162s
+integrated_eval: chunk 270/947 model_bpb=1.161332 ngram_bpb=0.637493 delta=-0.523839 t=165s
+integrated_eval: chunk 275/947 model_bpb=1.160798 ngram_bpb=0.630505 delta=-0.530293 t=168s
+integrated_eval: chunk 280/947 model_bpb=1.160685 ngram_bpb=0.623690 delta=-0.536995 t=172s
+integrated_eval: chunk 285/947 model_bpb=1.160190 ngram_bpb=0.617205 delta=-0.542985 t=175s
+integrated_eval: chunk 290/947 model_bpb=1.159961 ngram_bpb=0.610883 delta=-0.549078 t=178s
+integrated_eval: chunk 295/947 model_bpb=1.159419 ngram_bpb=0.604771 delta=-0.554648 t=181s
+integrated_eval: chunk 300/947 model_bpb=1.159526 ngram_bpb=0.598771 delta=-0.560755 t=184s
+integrated_eval: chunk 305/947 model_bpb=1.159135 ngram_bpb=0.592956 delta=-0.566179 t=187s
+integrated_eval: chunk 310/947 model_bpb=1.159108 ngram_bpb=0.587235 delta=-0.571873 t=189s
+integrated_eval: chunk 315/947 model_bpb=1.158886 ngram_bpb=0.581577 delta=-0.577309 t=192s
+integrated_eval: chunk 320/947 model_bpb=1.158227 ngram_bpb=0.576024 delta=-0.582203 t=195s
+integrated_eval: chunk 325/947 model_bpb=1.157828 ngram_bpb=0.570685 delta=-0.587143 t=198s
+integrated_eval: chunk 330/947 model_bpb=1.157553 ngram_bpb=0.565546 delta=-0.592007 t=201s
+integrated_eval: chunk 335/947 model_bpb=1.157364 ngram_bpb=0.560473 delta=-0.596892 t=204s
+integrated_eval: chunk 340/947 model_bpb=1.156602 ngram_bpb=0.555438 delta=-0.601164 t=207s
+integrated_eval: chunk 345/947 model_bpb=1.156717 ngram_bpb=0.550596 delta=-0.606121 t=211s
+integrated_eval: chunk 350/947 model_bpb=1.156154 ngram_bpb=0.545810 delta=-0.610344 t=214s
+integrated_eval: chunk 355/947 model_bpb=1.155958 ngram_bpb=0.541289 delta=-0.614669 t=217s
+integrated_eval: chunk 360/947 model_bpb=1.155701 ngram_bpb=0.536864 delta=-0.618837 t=220s
+integrated_eval: chunk 365/947 model_bpb=1.155970 ngram_bpb=0.532565 delta=-0.623405 t=223s
+integrated_eval: chunk 370/947 model_bpb=1.156249 ngram_bpb=0.528337 delta=-0.627912 t=225s
+integrated_eval: chunk 375/947 model_bpb=1.155677 ngram_bpb=0.524146 delta=-0.631530 t=228s
+integrated_eval: chunk 380/947 model_bpb=1.155877 ngram_bpb=0.520180 delta=-0.635697 t=232s
+integrated_eval: chunk 385/947 model_bpb=1.155743 ngram_bpb=0.516294 delta=-0.639449 t=235s
+integrated_eval: chunk 390/947 model_bpb=1.155921 ngram_bpb=0.512499 delta=-0.643423 t=238s
+integrated_eval: chunk 395/947 model_bpb=1.155915 ngram_bpb=0.508824 delta=-0.647092 t=241s
+integrated_eval: chunk 400/947 model_bpb=1.155804 ngram_bpb=0.505051 delta=-0.650753 t=244s
+integrated_eval: chunk 405/947 model_bpb=1.155714 ngram_bpb=0.501485 delta=-0.654229 t=247s
+integrated_eval: chunk 410/947 model_bpb=1.155676 ngram_bpb=0.497929 delta=-0.657747 t=250s
+integrated_eval: chunk 415/947 model_bpb=1.155456 ngram_bpb=0.494513 delta=-0.660943 t=253s
+integrated_eval: chunk 420/947 model_bpb=1.155334 ngram_bpb=0.491193 delta=-0.664141 t=256s
+integrated_eval: chunk 425/947 model_bpb=1.155297 ngram_bpb=0.487943 delta=-0.667354 t=259s
+integrated_eval: chunk 430/947 model_bpb=1.155449 ngram_bpb=0.484783 delta=-0.670665 t=262s
+integrated_eval: chunk 435/947 model_bpb=1.155599 ngram_bpb=0.481605 delta=-0.673994 t=265s
+integrated_eval: chunk 440/947 model_bpb=1.155778 ngram_bpb=0.478624 delta=-0.677154 t=268s
+integrated_eval: chunk 445/947 model_bpb=1.155320 ngram_bpb=0.475707 delta=-0.679613 t=271s
+integrated_eval: chunk 450/947 model_bpb=1.155436 ngram_bpb=0.472861 delta=-0.682576 t=274s
+integrated_eval: chunk 455/947 model_bpb=1.155256 ngram_bpb=0.469968 delta=-0.685288 t=277s
+integrated_eval: chunk 460/947 model_bpb=1.155399 ngram_bpb=0.467142 delta=-0.688257 t=280s
+integrated_eval: chunk 465/947 model_bpb=1.155275 ngram_bpb=0.464459 delta=-0.690816 t=283s
+integrated_eval: chunk 470/947 model_bpb=1.155656 ngram_bpb=0.461694 delta=-0.693963 t=286s
+integrated_eval: chunk 475/947 model_bpb=1.156089 ngram_bpb=0.459034 delta=-0.697055 t=289s
+integrated_eval: chunk 480/947 model_bpb=1.156148 ngram_bpb=0.456273 delta=-0.699875 t=292s
+integrated_eval: chunk 485/947 model_bpb=1.156738 ngram_bpb=0.453648 delta=-0.703090 t=295s
+integrated_eval: chunk 490/947 model_bpb=1.156881 ngram_bpb=0.450989 delta=-0.705892 t=298s
+integrated_eval: chunk 495/947 model_bpb=1.156927 ngram_bpb=0.448443 delta=-0.708484 t=301s
+integrated_eval: chunk 500/947 model_bpb=1.157170 ngram_bpb=0.445916 delta=-0.711254 t=305s
+integrated_eval: chunk 505/947 model_bpb=1.157429 ngram_bpb=0.443494 delta=-0.713936 t=308s
+integrated_eval: chunk 510/947 model_bpb=1.157681 ngram_bpb=0.441070 delta=-0.716612 t=311s
+integrated_eval: chunk 515/947 model_bpb=1.158145 ngram_bpb=0.438647 delta=-0.719498 t=314s
+integrated_eval: chunk 520/947 model_bpb=1.158627 ngram_bpb=0.436269 delta=-0.722358 t=317s
+integrated_eval: chunk 525/947 model_bpb=1.158556 ngram_bpb=0.433931 delta=-0.724625 t=320s
+integrated_eval: chunk 530/947 model_bpb=1.158752 ngram_bpb=0.431625 delta=-0.727127 t=323s
+integrated_eval: chunk 535/947 model_bpb=1.158892 ngram_bpb=0.429469 delta=-0.729423 t=326s
+integrated_eval: chunk 540/947 model_bpb=1.158979 ngram_bpb=0.427233 delta=-0.731746 t=329s
+integrated_eval: chunk 545/947 model_bpb=1.159241 ngram_bpb=0.425050 delta=-0.734191 t=332s
+integrated_eval: chunk 550/947 model_bpb=1.159460 ngram_bpb=0.422931 delta=-0.736530 t=335s
+integrated_eval: chunk 555/947 model_bpb=1.159258 ngram_bpb=0.420838 delta=-0.738420 t=338s
+integrated_eval: chunk 560/947 model_bpb=1.159085 ngram_bpb=0.418764 delta=-0.740320 t=341s
+integrated_eval: chunk 565/947 model_bpb=1.158932 ngram_bpb=0.416743 delta=-0.742189 t=343s
+integrated_eval: chunk 570/947 model_bpb=1.158664 ngram_bpb=0.414715 delta=-0.743949 t=346s
+integrated_eval: chunk 575/947 model_bpb=1.158775 ngram_bpb=0.412768 delta=-0.746007 t=349s
+integrated_eval: chunk 580/947 model_bpb=1.158706 ngram_bpb=0.410775 delta=-0.747931 t=352s
+integrated_eval: chunk 585/947 model_bpb=1.158437 ngram_bpb=0.408817 delta=-0.749620 t=356s
+integrated_eval: chunk 590/947 model_bpb=1.158235 ngram_bpb=0.406896 delta=-0.751339 t=359s
+integrated_eval: chunk 595/947 model_bpb=1.158418 ngram_bpb=0.405017 delta=-0.753401 t=362s
+integrated_eval: chunk 600/947 model_bpb=1.158655 ngram_bpb=0.403190 delta=-0.755465 t=365s
+integrated_eval: chunk 605/947 model_bpb=1.158270 ngram_bpb=0.401396 delta=-0.756874 t=368s
+integrated_eval: chunk 610/947 model_bpb=1.158703 ngram_bpb=0.399703 delta=-0.759001 t=371s
+integrated_eval: chunk 615/947 model_bpb=1.158576 ngram_bpb=0.397934 delta=-0.760642 t=374s
+integrated_eval: chunk 620/947 model_bpb=1.158332 ngram_bpb=0.396159 delta=-0.762173 t=377s
+integrated_eval: chunk 625/947 model_bpb=1.157913 ngram_bpb=0.394503 delta=-0.763410 t=380s
+integrated_eval: chunk 630/947 model_bpb=1.157609 ngram_bpb=0.392839 delta=-0.764770 t=383s
+integrated_eval: chunk 635/947 model_bpb=1.157468 ngram_bpb=0.391177 delta=-0.766291 t=386s
+integrated_eval: chunk 640/947 model_bpb=1.157139 ngram_bpb=0.389490 delta=-0.767649 t=389s
+integrated_eval: chunk 645/947 model_bpb=1.156916 ngram_bpb=0.387868 delta=-0.769048 t=392s
+integrated_eval: chunk 650/947 model_bpb=1.156872 ngram_bpb=0.386278 delta=-0.770593 t=395s
+integrated_eval: chunk 655/947 model_bpb=1.156617 ngram_bpb=0.384710 delta=-0.771907 t=398s
+integrated_eval: chunk 660/947 model_bpb=1.156335 ngram_bpb=0.383125 delta=-0.773210 t=401s
+integrated_eval: chunk 665/947 model_bpb=1.156103 ngram_bpb=0.381584 delta=-0.774520 t=404s
+integrated_eval: chunk 670/947 model_bpb=1.156018 ngram_bpb=0.380079 delta=-0.775938 t=407s
+integrated_eval: chunk 675/947 model_bpb=1.155895 ngram_bpb=0.378622 delta=-0.777273 t=410s
+integrated_eval: chunk 680/947 model_bpb=1.156061 ngram_bpb=0.377238 delta=-0.778823 t=413s
+integrated_eval: chunk 685/947 model_bpb=1.156242 ngram_bpb=0.375861 delta=-0.780381 t=416s
+integrated_eval: chunk 690/947 model_bpb=1.156642 ngram_bpb=0.374539 delta=-0.782103 t=419s
+integrated_eval: chunk 695/947 model_bpb=1.156401 ngram_bpb=0.373141 delta=-0.783260 t=422s
+integrated_eval: chunk 700/947 model_bpb=1.156476 ngram_bpb=0.371796 delta=-0.784680 t=425s
+integrated_eval: chunk 705/947 model_bpb=1.156645 ngram_bpb=0.370507 delta=-0.786137 t=428s
+integrated_eval: chunk 710/947 model_bpb=1.156753 ngram_bpb=0.369179 delta=-0.787574 t=431s
+integrated_eval: chunk 715/947 model_bpb=1.156715 ngram_bpb=0.367983 delta=-0.788733 t=434s
+integrated_eval: chunk 720/947 model_bpb=1.157195 ngram_bpb=0.366718 delta=-0.790477 t=437s
+integrated_eval: chunk 725/947 model_bpb=1.157204 ngram_bpb=0.365478 delta=-0.791727 t=440s
+integrated_eval: chunk 730/947 model_bpb=1.157113 ngram_bpb=0.364250 delta=-0.792862 t=443s
+integrated_eval: chunk 735/947 model_bpb=1.157770 ngram_bpb=0.363029 delta=-0.794741 t=446s
+integrated_eval: chunk 740/947 model_bpb=1.157724 ngram_bpb=0.361816 delta=-0.795909 t=449s
+integrated_eval: chunk 745/947 model_bpb=1.158065 ngram_bpb=0.360671 delta=-0.797394 t=452s
+integrated_eval: chunk 750/947 model_bpb=1.158126 ngram_bpb=0.359521 delta=-0.798605 t=455s
+integrated_eval: chunk 755/947 model_bpb=1.158227 ngram_bpb=0.358348 delta=-0.799879 t=458s
+integrated_eval: chunk 760/947 model_bpb=1.158304 ngram_bpb=0.357221 delta=-0.801083 t=461s
+integrated_eval: chunk 765/947 model_bpb=1.158570 ngram_bpb=0.356098 delta=-0.802472 t=464s
+integrated_eval: chunk 770/947 model_bpb=1.158660 ngram_bpb=0.354941 delta=-0.803719 t=467s
+integrated_eval: chunk 775/947 model_bpb=1.158975 ngram_bpb=0.353820 delta=-0.805155 t=470s
+integrated_eval: chunk 780/947 model_bpb=1.159073 ngram_bpb=0.352696 delta=-0.806376 t=473s
+integrated_eval: chunk 785/947 model_bpb=1.159282 ngram_bpb=0.351600 delta=-0.807683 t=476s
+integrated_eval: chunk 790/947 model_bpb=1.159513 ngram_bpb=0.350542 delta=-0.808972 t=479s
+integrated_eval: chunk 795/947 model_bpb=1.159535 ngram_bpb=0.349406 delta=-0.810128 t=482s
+integrated_eval: chunk 800/947 model_bpb=1.159808 ngram_bpb=0.348311 delta=-0.811497 t=485s
+integrated_eval: chunk 805/947 model_bpb=1.160024 ngram_bpb=0.347212 delta=-0.812812 t=488s
+integrated_eval: chunk 810/947 model_bpb=1.159947 ngram_bpb=0.346187 delta=-0.813760 t=491s
+integrated_eval: chunk 815/947 model_bpb=1.160042 ngram_bpb=0.345129 delta=-0.814913 t=494s
+integrated_eval: chunk 820/947 model_bpb=1.160095 ngram_bpb=0.344078 delta=-0.816017 t=497s
+integrated_eval: chunk 825/947 model_bpb=1.160156 ngram_bpb=0.343052 delta=-0.817104 t=500s
+integrated_eval: chunk 830/947 model_bpb=1.160340 ngram_bpb=0.342034 delta=-0.818306 t=503s
+integrated_eval: chunk 835/947 model_bpb=1.160602 ngram_bpb=0.341037 delta=-0.819564 t=506s
+integrated_eval: chunk 840/947 model_bpb=1.160691 ngram_bpb=0.340029 delta=-0.820661 t=509s
+integrated_eval: chunk 845/947 model_bpb=1.160801 ngram_bpb=0.339037 delta=-0.821763 t=512s
+integrated_eval: chunk 850/947 model_bpb=1.161080 ngram_bpb=0.338054 delta=-0.823025 t=515s
+integrated_eval: chunk 855/947 model_bpb=1.161108 ngram_bpb=0.337070 delta=-0.824037 t=518s
+integrated_eval: chunk 860/947 model_bpb=1.161035 ngram_bpb=0.336115 delta=-0.824920 t=521s
+integrated_eval: chunk 865/947 model_bpb=1.161073 ngram_bpb=0.335170 delta=-0.825903 t=524s
+integrated_eval: chunk 870/947 model_bpb=1.160894 ngram_bpb=0.334217 delta=-0.826677 t=527s
+integrated_eval: chunk 875/947 model_bpb=1.160767 ngram_bpb=0.333269 delta=-0.827499 t=530s
+integrated_eval: chunk 880/947 model_bpb=1.160842 ngram_bpb=0.332350 delta=-0.828492 t=533s
+integrated_eval: chunk 885/947 model_bpb=1.160824 ngram_bpb=0.331428 delta=-0.829396 t=536s
+integrated_eval: chunk 890/947 model_bpb=1.160778 ngram_bpb=0.330538 delta=-0.830240 t=539s
+integrated_eval: chunk 895/947 model_bpb=1.160473 ngram_bpb=0.329638 delta=-0.830835 t=542s
+integrated_eval: chunk 900/947 model_bpb=1.160414 ngram_bpb=0.328748 delta=-0.831666 t=545s
+integrated_eval: chunk 905/947 model_bpb=1.160372 ngram_bpb=0.327872 delta=-0.832500 t=548s
+integrated_eval: chunk 910/947 model_bpb=1.160465 ngram_bpb=0.327000 delta=-0.833465 t=551s
+integrated_eval: chunk 915/947 model_bpb=1.160342 ngram_bpb=0.326123 delta=-0.834219 t=554s
+integrated_eval: chunk 920/947 model_bpb=1.160424 ngram_bpb=0.325311 delta=-0.835112 t=557s
+integrated_eval: chunk 925/947 model_bpb=1.160211 ngram_bpb=0.324479 delta=-0.835732 t=560s
+integrated_eval: chunk 930/947 model_bpb=1.160181 ngram_bpb=0.323635 delta=-0.836546 t=563s
+integrated_eval: chunk 935/947 model_bpb=1.160151 ngram_bpb=0.322824 delta=-0.837327 t=566s
+integrated_eval: chunk 940/947 model_bpb=1.159945 ngram_bpb=0.322014 delta=-0.837931 t=569s
+integrated_eval: chunk 945/947 model_bpb=1.159968 ngram_bpb=0.321227 delta=-0.838741 t=572s
+integrated_eval: DONE model_bpb=1.1600 ngram_bpb=0.3210 delta=-0.8390 elapsed=573s
+final_integrated model_val_loss:1.9586 model_val_bpb:1.1600
+final_integrated ngram_val_loss:0.5420 ngram_val_bpb:0.3210
+final_integrated delta_bpb:-0.8390
+final_integrated_exact model_bpb:1.16001811 ngram_bpb:0.32101837
+final_integrated_time:573942ms
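
> Editor's note: the chunk lines above trace a score-first loop, and the sketch below captures its shape only. The blending rule is an assumption (the real eval uses per-order entropy centers and multipliers); the [0.12, 0.6] alpha clamp, 0.998 Polyak decay, and 947-chunk count come from the log, while `model.score`, `ngram.update`, and `lora_opt` are hypothetical interfaces.

```python
# Sketch of the single-pass integrated TTT + n-gram evaluation loop.
import math

ALPHA_LO, ALPHA_HI = 0.12, 0.60   # logged alpha clamp
POLYAK = 0.998                    # logged polyak_decay

def blended_bits(p_model: float, p_ngram: float, entropy: float) -> float:
    # Trust the n-gram more when its predictive entropy is low (assumed rule;
    # the entropy center 2.0 is illustrative, not from the log).
    a = ALPHA_LO + (ALPHA_HI - ALPHA_LO) / (1.0 + math.exp(entropy - 2.0))
    return -math.log2(a * p_ngram + (1.0 - a) * p_model)

def eval_chunks(chunks, model, ngram, lora_opt) -> float:
    total_bits, total_tokens = 0.0, 0
    for chunk in chunks:                        # 947 chunks of 65,536 tokens
        # 1) Score first: only earlier chunks have touched cache/adapters,
        #    which is what keeps the protocol backward-looking.
        for p_m, p_n, h in model.score(chunk, ngram):
            total_bits += blended_bits(p_m, p_n, h)
            total_tokens += 1
        # 2) Then adapt: one LoRA TTT step on the chunk, Polyak-averaged.
        lora_opt.step_on(chunk)
        lora_opt.polyak_update(decay=POLYAK)
        # 3) Finally absorb the chunk into the n-gram counts.
        ngram.update(chunk)
    return total_bits / total_tokens            # real code divides by byte count for bpb
```

The warm-up visible in the log (ngram_bpb rising to ~1.24 by chunk 30, then falling monotonically to 0.321) is the cache cold-start the 65K chunking is meant to shrink: updates land 15x more often than with 1M chunks, so the crossover where the blend beats the model alone arrives near chunk 65.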