diff --git a/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/README.md b/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/README.md new file mode 100644 index 0000000000..0ff2eb29d8 --- /dev/null +++ b/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/README.md @@ -0,0 +1,134 @@ +# Loader FullGPTQ XSA11 + Online Ngram Agreement + +**val_bpb: 1.11085863** (4-seed mean, std 0.00030217) | **15,953,221 bytes worst case** | 8xH100 SXM + +Improves the current README leader at `1.1194` by **0.00592043 nats/byte** and **0.00854137 bpb** on the bundled 4-seed subset (`42`, `1337`, `2025`, `15`). + +All four bundled seed logs and the included code files correspond to the packaged submission in this folder. + +## Results (8xH100 80GB SXM) + +| Seed | step_avg | steps | Standard sliding bpb | Online-pass LLM bpb | **Online best-agree bpb** | Online gain | Eval time | Total bytes | +|------|----------:|------:|---------------------:|--------------------:|--------------------------:|------------:|----------:|------------:| +| 42 | 91.59ms | 6443 | 1.11343872 | 1.11372806 | **1.11058356** | 0.00314451 | 481.04s | 15817813 | +| 1337 | 91.40ms | 6456 | 1.11408566 | 1.11437756 | **1.11126660** | 0.00311096 | 461.62s | 15953221 | +| 2025 | 91.39ms | 6457 | 1.11352210 | 1.11381798 | **1.11068499** | 0.00313300 | 462.35s | 15842301 | +| 15 | 91.47ms | 6451 | 1.11372333 | 1.11402056 | **1.11089935** | 0.00312121 | 466.09s | 15841741 | +| **Mean** | **91.46ms** | **6452** | **1.11369245** | **1.11398604** | **1.11085863 (std 0.00030217)** | **0.00312742** | **467.78s** | **15953221 worst case** | + +Using the bundled four-seed subset and testing against the null hypothesis that the gain over `1.1194` is at most `0.005 nats/byte`, the one-sided t-test gives **t = 8.7892**, **df = 3**, **p = 0.00155**. 
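The quoted t statistic can be reproduced directly from the per-seed bpb column of the results table. A standard-library sketch (the `1.1194` baseline and the `0.005` nats/byte null hypothesis are the values quoted above; no other inputs are assumed):

```python
import math
from statistics import mean, stdev

LN2 = math.log(2.0)

# Per-seed online best-agree bpb (seeds 42, 1337, 2025, 15) from the table above.
best_agree_bpb = [1.11058356, 1.11126660, 1.11068499, 1.11089935]

# Gains over the 1.1194 README leader, converted from bits/byte to nats/byte.
gains_nats = [(1.1194 - b) * LN2 for b in best_agree_bpb]

# One-sided one-sample t-test against H0: gain <= 0.005 nats/byte.
m, s = mean(gains_nats), stdev(gains_nats)
t = (m - 0.005) / (s / math.sqrt(len(gains_nats)))
print(f"mean_gain={m:.8f} nats/byte t={t:.4f} df={len(gains_nats) - 1}")
# → mean_gain ≈ 0.00592043, t ≈ 8.789, df = 3
```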
The mean online best-agree score is **0.76998852 nats/byte** with a 95% CI of **[0.76965525, 0.77032180]**. + +## High-Level Takeaways + +- The eval-time agreement techniques appear to reduce BPB reliably: all four bundled seeds improve by about `0.0031` BPB versus the matched online LLM baseline, with very little variance in the gain. +- The inference path is still not as optimized as it could be. The current implementation is already fast enough for the budget, but the runtime breakdown suggests there is still headroom in the online state, blending, and probability-extraction path. + +## Summary + +This submission keeps the `Loader_FullGPTQ_XSA11_BigramHash2816` training stack from PR #1060 as the base point, retunes the training schedule to use `WARMDOWN_ITERS=4000`, and adds a single-pass online n-gram agreement evaluator at the end of `train_gpt.py`. + +The online evaluator combines three causal prefix-only experts: + +- token n-gram top-token hints +- within-word continuation hints +- word-start first-token hints + +At each scored position it chooses at most one hinted token, optionally adds a small agreement boost when multiple experts support the same token, and applies that boost to a single fully normalized distribution derived from the model's own probabilities. + +## Why The Eval Is Valid + +The justification is the same four conditions used for causal evaluation in this challenge. + +1. **Strict causal dependence** + The expert state at position `t` depends only on the artifact and the strict prefix. The online token and within-word state are updated only from already-scored tokens, and the word-start state is also maintained online from the prefix only. + +2. **Full normalized distribution** + The base model defines a full normalized distribution over the official vocabulary. The online path does not target-condition on the realized token. 
Instead it picks at most one prefix-derived hinted token and applies a logit-style boost to that token while renormalizing the whole distribution. + +3. **Score-before-update** + The score for position `t` is taken from the pre-update state. Only after the score is fixed does the evaluator update the online expert state with the current token. + +4. **Single left-to-right pass** + Evaluation is one forward pass over the validation stream in the official order. There is no rescoring pass, no retrospective revision, and no selection among multiple executions. + +The implementation also keeps the metric calculation honest: + +- BPB uses the sentencepiece byte-length lookup tables from `train_gpt.py` +- the full validation set is scored +- validation order is preserved +- GPTQ calibration stays in the training phase via `GPTQ_RESERVE_MS` + +## Why Many Earlier N-gram Caches Were Invalid + +A number of earlier n-gram-style submissions got very low BPB by exploiting the evaluation harness rather than by defining a valid causal probability model. The main failure modes were: + +- **Target-conditioned lookup.** Some implementations asked the cache about the realized next token itself, or used the realized token to decide whether a cache hit existed. 
  That makes the reported `P(x_t | x_{<t})` depend on the realized target token, so it is no longer a valid causal probability.

[...]

`=> 1.10955484 bpb` in `462.67s`

The measured bottlenecks in the benchmark were the online overlay itself rather than the neural forward pass:

- online state maintenance
- chunk blending / agreement logic
- model forward plus targeted probability extraction

## Eval-Time Improvements Tried

Before settling on the final path, I tried and discarded several slower or less defensible variants:

- cache-heavy offline / shared-cache evaluation flows
- exact-phrase cache variants whose legality story did not hold up for a per-seed online submission
- a Python-only online prototype, before moving the hot n-gram state into a native helper
- an earlier multi-GPU design that communicated too much per-token state

The final version uses a local-only distributed design, a native open-addressing online n-gram table in `online_ngram_state.c`, and targeted `logsumexp` / gather extraction rather than a full-vocab `log_softmax` pass for every scored token.
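The last point can be sketched in a few lines of NumPy. The shapes and names below are illustrative, not the submission's actual tensors (the real path does the same gather / `logsumexp` on GPU with torch); the point is that one normalizer per row plus a couple of gathers gives the same log-probabilities as materializing the full `T x V` log-softmax matrix:

```python
import numpy as np

rng = np.random.default_rng(0)
T, V = 8, 4096                        # scored positions, vocab size (illustrative)
logits = rng.normal(size=(T, V))
targets = rng.integers(0, V, size=T)  # realized next tokens
hints = rng.integers(0, V, size=T)    # one hinted token per position

# Full-vocab path: materializes the whole T x V log-prob matrix
# just to read a couple of entries per row.
shift = logits - logits.max(axis=-1, keepdims=True)
log_probs = shift - np.log(np.exp(shift).sum(axis=-1, keepdims=True))
full_true = log_probs[np.arange(T), targets]

# Targeted path: one numerically stable logsumexp per row, then gather
# only the logits that are actually scored (true token + hinted token).
m = logits.max(axis=-1)
log_norm = m + np.log(np.exp(logits - m[:, None]).sum(axis=-1))
true_logprob = logits[np.arange(T), targets] - log_norm
hint_logprob = logits[np.arange(T), hints] - log_norm

assert np.allclose(full_true, true_logprob)
```

The saving is in memory traffic, not math: the normalizer over the vocab axis is computed either way, but the targeted path never writes the `T x V` log-prob tensor back out.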
+ +## Run Command + +```bash +SEED=1337 \ +BIGRAM_VOCAB_SIZE=2816 BIGRAM_DIM=112 XSA_LAST_N=11 \ +USE_GPTQ=1 TTT_ENABLED=0 ONLINE_BEST_AGREE_EVAL=1 EVAL_COMPILE=0 \ +MAX_WALLCLOCK_SECONDS=600 GPTQ_RESERVE_MS=10000 \ +WARMDOWN_ITERS=4000 TIED_EMBED_LR=0.035 ITERATIONS=6700 \ +torchrun --standalone --nproc_per_node=8 train_gpt.py +``` + +## Included Files + +- `train_gpt.py` +- `requirements.txt` +- `online_best_agree_eval.py` +- `online_ngram_state.c` +- `train_seed15.log` +- `train_seed1337.log` +- `train_seed2025.log` +- `train_seed42.log` + +## Credits + +- **Base training / quantized eval stack**: PR #1060 `Loader_FullGPTQ_XSA11_BigramHash2816` +- **This submission's main addition**: integrated online token / within-word / word-start agreement eval path, packaged so it runs inside the record folder and stays within the official evaluation budget diff --git a/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/online_best_agree_eval.py b/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/online_best_agree_eval.py new file mode 100644 index 0000000000..dd94dbb344 --- /dev/null +++ b/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/online_best_agree_eval.py @@ -0,0 +1,671 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import ctypes +import math +import os +import subprocess +import time +from collections import deque +from pathlib import Path + +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F + + +SCRIPT_DIR = Path(__file__).resolve().parent +ONLINE_NGRAM_SRC = SCRIPT_DIR / "online_ngram_state.c" +ONLINE_NGRAM_LIB = SCRIPT_DIR / "libonline_ngram_state.so" + +WHITESPACE_BYTE_IDS = {9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 36} +EDGE_PUNCT = ".,:;!?()[]{}<>\"'`" + + +def normalize_word(text: str, mode: str) -> str: + text = text.strip() + if mode == "lower": + return text.lower() 
+ if mode == "identity": + return text + if mode == "strip_punct_lower": + return text.strip(EDGE_PUNCT).lower() + raise ValueError(f"Unknown word normalization mode: {mode}") + + +def apply_boost( + llm_true_probs: np.ndarray, + llm_hint_probs: np.ndarray, + hit_mask: np.ndarray, + gate_mask: np.ndarray, + boost: float | np.ndarray, +) -> np.ndarray: + boosted = llm_true_probs.astype(np.float64, copy=True) + if not gate_mask.any(): + return boosted + + if np.isscalar(boost): + scale = math.exp(float(boost)) + hit_gate = gate_mask & hit_mask + miss_gate = gate_mask & ~hit_mask + boosted[hit_gate] = (scale * llm_true_probs[hit_gate]) / ( + 1.0 - llm_true_probs[hit_gate] + scale * llm_true_probs[hit_gate] + ) + boosted[miss_gate] = llm_true_probs[miss_gate] / ( + 1.0 - llm_hint_probs[miss_gate] + scale * llm_hint_probs[miss_gate] + ) + return boosted + + boost_arr = boost.astype(np.float64, copy=False) + scale = np.ones(llm_true_probs.shape, dtype=np.float64) + scale[gate_mask] = np.exp(boost_arr[gate_mask]) + denom = 1.0 - llm_hint_probs + scale * llm_hint_probs + boosted = llm_true_probs / denom + hit_gate = gate_mask & hit_mask + boosted[hit_gate] *= scale[hit_gate] + return boosted + + +def expected_gain(top_prob: np.ndarray, llm_hint_prob: np.ndarray, boost: float) -> np.ndarray: + q = np.clip(llm_hint_prob.astype(np.float64, copy=False), 1e-12, 1.0 - 1e-12) + p = np.clip(top_prob.astype(np.float64, copy=False), 0.0, 1.0) + log_norm = np.log1p(q * (math.exp(boost) - 1.0)) + return (p * boost - log_norm).astype(np.float32) + + +def compute_best_agreement_chunk( + *, + llm_chunk: np.ndarray, + true_targets: np.ndarray, + token_top_prob: np.ndarray, + token_top_token: np.ndarray, + token_hint_probs: np.ndarray, + within_top_prob: np.ndarray, + within_top_token: np.ndarray, + within_valid: np.ndarray, + within_hint_probs: np.ndarray, + word_top_prob: np.ndarray, + word_top_token: np.ndarray, + word_hint_probs: np.ndarray, + token_threshold: float, + token_boost: 
float, + within_tau: float, + within_boost: float, + word_tau: float, + word_boost: float, + agree_add_boost: float, +) -> np.ndarray: + token_hit = token_top_token == true_targets + token_gate = token_top_prob >= token_threshold + token_exp_gain = expected_gain(token_top_prob, token_hint_probs, token_boost) + + within_hit = within_top_token == true_targets + within_gate = within_valid & (within_top_prob >= within_tau) + within_exp_gain = expected_gain(within_top_prob, within_hint_probs, within_boost) + + word_hit = word_top_token == true_targets + word_gate = word_top_prob >= word_tau + word_exp_gain = expected_gain(word_top_prob, word_hint_probs, word_boost) + + within_pick = within_gate & (~token_gate | (within_exp_gain > token_exp_gain)) + token_pick_tw = token_gate & ~within_pick + tw_gate = token_pick_tw | within_pick + + word_pick = word_gate & ((~tw_gate) | (token_pick_tw & (word_exp_gain > token_exp_gain)) | (within_pick & (word_exp_gain > within_exp_gain))) + token_pick = token_pick_tw & ~word_pick + within_pick_final = within_pick & ~word_pick + chosen_gate = token_pick | within_pick_final | word_pick + + chosen_hint_probs = np.zeros(llm_chunk.shape, dtype=np.float64) + chosen_hint_probs[token_pick] = token_hint_probs[token_pick] + chosen_hint_probs[within_pick_final] = within_hint_probs[within_pick_final] + chosen_hint_probs[word_pick] = word_hint_probs[word_pick] + + chosen_hit = np.zeros(llm_chunk.shape, dtype=np.bool_) + chosen_hit[token_pick] = token_hit[token_pick] + chosen_hit[within_pick_final] = within_hit[within_pick_final] + chosen_hit[word_pick] = word_hit[word_pick] + + chosen_boost = np.zeros(llm_chunk.shape, dtype=np.float64) + chosen_boost[token_pick] = token_boost + chosen_boost[within_pick_final] = within_boost + chosen_boost[word_pick] = word_boost + + selected_token = np.zeros(llm_chunk.shape, dtype=np.uint16) + selected_token[token_pick] = token_top_token[token_pick] + selected_token[within_pick_final] = 
within_top_token[within_pick_final] + selected_token[word_pick] = word_top_token[word_pick] + + agree_count = np.zeros(llm_chunk.shape, dtype=np.uint8) + agree_count += (token_gate & (token_top_token == selected_token)).astype(np.uint8) + agree_count += (within_gate & (within_top_token == selected_token)).astype(np.uint8) + agree_count += (word_gate & (word_top_token == selected_token)).astype(np.uint8) + agree_any = chosen_gate & (agree_count >= 2) + + agree_boost = chosen_boost.copy() + agree_boost[agree_any] += agree_add_boost + return apply_boost(llm_chunk, chosen_hint_probs, chosen_hit, chosen_gate, agree_boost) + + +def dist_max_float(value: float, device: torch.device, world_size: int) -> float: + if world_size <= 1: + return float(value) + tensor = torch.tensor([value], dtype=torch.float64, device=device) + dist.all_reduce(tensor, op=dist.ReduceOp.MAX) + return float(tensor.item()) + + +def suggest_table_bits(expected_entries: int, load_factor: float) -> int: + expected_entries = max(int(expected_entries), 1) + size = 1 + while size * load_factor < expected_entries: + size <<= 1 + return max(size.bit_length() - 1, 10) + + +def loss_to_bpb(total_loss: float, total_bytes: float) -> float: + return total_loss / (total_bytes * math.log(2.0)) + + +def loss_to_nats_per_byte(total_loss: float, total_bytes: float) -> float: + return total_loss / total_bytes + + +def build_chunk_windows(total_targets: int, seq_len: int, stride: int, chunk_tokens: int) -> list[list[int]]: + window_starts = [ + ws + for ws in range(0, total_targets, stride) + if min(ws + seq_len, total_targets) - ws >= stride or ws == 0 + ] + full_num_chunks = (total_targets + chunk_tokens - 1) // chunk_tokens + chunk_windows: list[list[int]] = [[] for _ in range(full_num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_targets) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // chunk_tokens, full_num_chunks - 1) + 
chunk_windows[ci].append(ws) + return chunk_windows + + +def ensure_online_ngram_lib(log0) -> ctypes.CDLL: + needs_build = (not ONLINE_NGRAM_LIB.exists()) or ( + ONLINE_NGRAM_SRC.stat().st_mtime_ns > ONLINE_NGRAM_LIB.stat().st_mtime_ns + ) + if needs_build: + log0(f"building_native_ngram_helper src={ONLINE_NGRAM_SRC.name}") + subprocess.run( + [ + "gcc", + "-O3", + "-march=native", + "-shared", + "-fPIC", + "-o", + str(ONLINE_NGRAM_LIB), + str(ONLINE_NGRAM_SRC), + ], + check=True, + ) + lib = ctypes.CDLL(str(ONLINE_NGRAM_LIB)) + lib.online_ngram_state_create.restype = ctypes.c_void_p + lib.online_ngram_state_create.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int] + lib.online_ngram_state_destroy.restype = None + lib.online_ngram_state_destroy.argtypes = [ctypes.c_void_p] + lib.online_ngram_state_seed_prefix_token.restype = None + lib.online_ngram_state_seed_prefix_token.argtypes = [ctypes.c_void_p, ctypes.c_uint16] + lib.online_ngram_state_process_chunk.restype = ctypes.c_int + lib.online_ngram_state_process_chunk.argtypes = [ + ctypes.c_void_p, + ctypes.POINTER(ctypes.c_uint16), + ctypes.c_int64, + ctypes.POINTER(ctypes.c_uint8), + ctypes.POINTER(ctypes.c_uint8), + ctypes.POINTER(ctypes.c_uint16), + ctypes.POINTER(ctypes.c_float), + ctypes.POINTER(ctypes.c_uint16), + ctypes.POINTER(ctypes.c_float), + ctypes.POINTER(ctypes.c_uint8), + ] + return lib + + +class OnlineNgramState: + def __init__( + self, + *, + lib: ctypes.CDLL, + token_ctx_len: int, + token_table_bits: int, + within_table_bits: int, + starts_new_word_lut: np.ndarray, + boundary_lut: np.ndarray, + seed_prefix_token: int, + ) -> None: + self.lib = lib + self.state = lib.online_ngram_state_create(token_ctx_len, token_table_bits, within_table_bits) + if not self.state: + raise RuntimeError( + "Failed to allocate native online ngram state. 
" + f"token_table_bits={token_table_bits} within_table_bits={within_table_bits}" + ) + self.starts_new_word_lut = np.ascontiguousarray(starts_new_word_lut.astype(np.uint8, copy=False)) + self.boundary_lut = np.ascontiguousarray(boundary_lut.astype(np.uint8, copy=False)) + self.lib.online_ngram_state_seed_prefix_token(self.state, ctypes.c_uint16(int(seed_prefix_token))) + + def close(self) -> None: + if self.state: + self.lib.online_ngram_state_destroy(self.state) + self.state = None + + def __del__(self) -> None: + self.close() + + def process_chunk( + self, + chunk_tokens: np.ndarray, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + chunk_tokens = np.ascontiguousarray(chunk_tokens.astype(np.uint16, copy=False)) + n = int(chunk_tokens.size) + token_top_token = np.zeros(n, dtype=np.uint16) + token_top_prob = np.zeros(n, dtype=np.float32) + within_top_token = np.zeros(n, dtype=np.uint16) + within_top_prob = np.zeros(n, dtype=np.float32) + within_valid = np.zeros(n, dtype=np.uint8) + rc = self.lib.online_ngram_state_process_chunk( + self.state, + chunk_tokens.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), + ctypes.c_int64(n), + self.starts_new_word_lut.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)), + self.boundary_lut.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)), + token_top_token.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), + token_top_prob.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), + within_top_token.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), + within_top_prob.ctypes.data_as(ctypes.POINTER(ctypes.c_float)), + within_valid.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8)), + ) + if rc != 0: + raise RuntimeError(f"Native online ngram chunk processing failed rc={rc}") + return token_top_token, token_top_prob, within_top_token, within_top_prob, within_valid.astype(bool) + + +class WordStartState: + def __init__( + self, + *, + sp: spm.SentencePieceProcessor, + order: int, + normalize_mode: str, + ) -> None: + self.sp = sp + 
self.ctx_w = max(order - 1, 0) + self.normalize_mode = normalize_mode + self.prev_word_ids: deque[int] = deque(maxlen=self.ctx_w) + self.current_word_tokens: list[int] = [] + self.word_to_id: dict[str, int] = {} + self.next_word_id = 1 + self.ctx_total: dict[tuple[int, ...], int] = {} + self.pair_count: dict[tuple[tuple[int, ...], int], int] = {} + self.ctx_best_token: dict[tuple[int, ...], int] = {} + self.ctx_best_count: dict[tuple[int, ...], int] = {} + + def _flush_current_word(self) -> None: + if not self.current_word_tokens: + return + text = normalize_word( + self.sp.decode(self.current_word_tokens), + self.normalize_mode, + ) + if text: + word_id = self.word_to_id.get(text) + if word_id is None: + word_id = self.next_word_id + self.word_to_id[text] = word_id + self.next_word_id += 1 + if self.ctx_w > 0: + self.prev_word_ids.append(word_id) + self.current_word_tokens = [] + + def process_chunk( + self, + chunk_tokens: np.ndarray, + *, + starts_new_word_lut: np.ndarray, + boundary_lut: np.ndarray, + ) -> tuple[np.ndarray, np.ndarray]: + chunk_tokens = np.ascontiguousarray(chunk_tokens.astype(np.uint16, copy=False)) + top_token = np.zeros(chunk_tokens.size, dtype=np.uint16) + top_prob = np.zeros(chunk_tokens.size, dtype=np.float32) + for i, tok_u16 in enumerate(chunk_tokens): + tok = int(tok_u16) + is_boundary = bool(boundary_lut[tok]) + is_word_start = bool(starts_new_word_lut[tok]) or not self.current_word_tokens + if is_boundary: + self._flush_current_word() + continue + if bool(starts_new_word_lut[tok]): + self._flush_current_word() + + ctx_key: tuple[int, ...] 
| None = None + if is_word_start and len(self.prev_word_ids) >= self.ctx_w: + ctx_key = tuple(self.prev_word_ids) if self.ctx_w > 0 else () + total = self.ctx_total.get(ctx_key, 0) + if total > 0: + top_token[i] = np.uint16(self.ctx_best_token[ctx_key]) + top_prob[i] = np.float32(self.ctx_best_count[ctx_key] / total) + + if is_word_start: + if ctx_key is not None: + pair_key = (ctx_key, tok) + pair = self.pair_count.get(pair_key, 0) + 1 + self.pair_count[pair_key] = pair + total = self.ctx_total.get(ctx_key, 0) + 1 + self.ctx_total[ctx_key] = total + best_count = self.ctx_best_count.get(ctx_key, 0) + if pair > best_count: + self.ctx_best_count[ctx_key] = pair + self.ctx_best_token[ctx_key] = tok + self.current_word_tokens = [tok] + else: + self.current_word_tokens.append(tok) + return top_token, top_prob + + +def build_piece_luts( + *, + tokenizer_path: str, + vocab_size: int, +) -> tuple[spm.SentencePieceProcessor, np.ndarray, np.ndarray]: + sp = spm.SentencePieceProcessor(model_file=tokenizer_path) + pieces = [sp.id_to_piece(i) for i in range(sp.vocab_size())] + starts_new_word_lut = np.zeros(vocab_size, dtype=np.uint8) + for i, piece in enumerate(pieces): + starts_new_word_lut[i] = 1 if piece.startswith("▁") else 0 + boundary_lut = np.zeros(vocab_size, dtype=np.uint8) + bos_id = sp.bos_id() + if bos_id >= 0 and bos_id < vocab_size: + boundary_lut[bos_id] = 1 + for tok in range(min(sp.vocab_size(), vocab_size)): + if sp.is_byte(tok) and tok in WHITESPACE_BYTE_IDS: + boundary_lut[tok] = 1 + return sp, starts_new_word_lut, boundary_lut + + +def compile_logits_fn(model: torch.nn.Module, *, seq_len: int, device: torch.device, log0): + if os.environ.get("EVAL_COMPILE", "0") != "1": + log0("eval-pass-online: using eager logits path") + return model.forward_logits + log0("eval-pass-online: compiling logits path") + compiled = torch.compile(model.forward_logits, dynamic=False, fullgraph=True) + dummy = torch.zeros(1, seq_len, dtype=torch.int64, device=device) + with 
torch.autocast(device_type="cuda", dtype=torch.bfloat16): + if hasattr(torch.compiler, "cudagraph_mark_step_begin"): + torch.compiler.cudagraph_mark_step_begin() + _ = compiled(dummy) + del dummy + log0("eval-pass-online: compile warmup done") + return compiled + + +def partition_windows(windows: list[int], rank: int, world_size: int) -> list[int]: + start = (len(windows) * rank) // world_size + end = (len(windows) * (rank + 1)) // world_size + return windows[start:end] + + +def eval_val_sliding_online_best_agree( + *, + args, + base_model: torch.nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: torch.Tensor, + base_bytes_lut: torch.Tensor, + has_leading_space_lut: torch.Tensor, + is_boundary_token_lut: torch.Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, + log0=print, +) -> tuple[float, float, dict[str, float]]: + startup_t0 = time.perf_counter() + seq_len = eval_seq_len or args.train_seq_len + chunk_tokens = int(os.environ.get("CHUNK_TOKENS", "131072")) + token_order = int(os.environ.get("TOKEN_ORDER", "16")) + token_threshold = float(os.environ.get("TOKEN_THRESHOLD", "0.800")) + token_boost = float(os.environ.get("TOKEN_BOOST", "2.625")) + within_tau = float(os.environ.get("WITHIN_TAU", "0.450")) + within_boost = float(os.environ.get("WITHIN_BOOST", "0.750")) + word_order = int(os.environ.get("WORD_ORDER", "4")) + word_normalize = os.environ.get("WORD_NORMALIZE", "strip_punct_lower") + word_tau = float(os.environ.get("WORD_TAU", "0.650")) + word_boost = float(os.environ.get("WORD_BOOST", "0.750")) + agree_add_boost = float(os.environ.get("AGREE_ADD_BOOST", "0.500")) + + total_targets = val_tokens.numel() - 1 + tokens_np = val_tokens.cpu().numpy().astype(np.uint16, copy=False) + sp, starts_new_word_lut, boundary_lut = build_piece_luts( + tokenizer_path=args.tokenizer_path, + vocab_size=args.vocab_size, + ) + token_table_bits = int( + os.environ.get( + "TOKEN_TABLE_BITS", + 
str(suggest_table_bits(total_targets, load_factor=0.55)), + ) + ) + within_table_bits = int( + os.environ.get( + "WITHIN_TABLE_BITS", + str(suggest_table_bits(max(total_targets // 2, 1), load_factor=0.60)), + ) + ) + online_lib = ensure_online_ngram_lib(log0) + ngram_state = OnlineNgramState( + lib=online_lib, + token_ctx_len=max(token_order - 1, 0), + token_table_bits=token_table_bits, + within_table_bits=within_table_bits, + starts_new_word_lut=starts_new_word_lut, + boundary_lut=boundary_lut, + seed_prefix_token=int(tokens_np[0]), + ) + word_state = WordStartState( + sp=sp, + order=word_order, + normalize_mode=word_normalize, + ) + + compiled_logits = compile_logits_fn(base_model, seq_len=seq_len, device=device, log0=log0 if rank == 0 else (lambda *_: None)) + startup_s = time.perf_counter() - startup_t0 + startup_max_s = dist_max_float(startup_s, device, world_size) + if rank == 0: + log0( + f"online_best_agree:start total_targets={total_targets} seq_len={seq_len} stride={stride} " + f"chunk_tokens={chunk_tokens} batch_seqs={batch_seqs} token_order={token_order} " + f"word_order={word_order} startup_max={startup_max_s:.2f}s" + ) + + chunk_windows = build_chunk_windows(total_targets, seq_len, stride, chunk_tokens) + + llm_loss_sum = 0.0 + best_agree_loss_sum = 0.0 + byte_sum = 0.0 + token_count = 0.0 + state_time_s = 0.0 + input_time_s = 0.0 + forward_time_s = 0.0 + blend_time_s = 0.0 + loop_t0 = time.perf_counter() + + try: + with torch.inference_mode(): + for ci, windows in enumerate(chunk_windows): + if not windows: + continue + chunk_t0 = ci * chunk_tokens + chunk_t1 = min((ci + 1) * chunk_tokens, total_targets) + chunk_target_tokens = np.ascontiguousarray(tokens_np[chunk_t0 + 1 : chunk_t1 + 1], dtype=np.uint16) + + t_state0 = time.perf_counter() + token_top_token, token_top_prob, within_top_token, within_top_prob, within_valid = ngram_state.process_chunk( + chunk_target_tokens + ) + word_top_token, word_top_prob = word_state.process_chunk( + 
chunk_target_tokens, + starts_new_word_lut=starts_new_word_lut, + boundary_lut=boundary_lut, + ) + state_time_s += time.perf_counter() - t_state0 + + my_windows = partition_windows(windows, rank, world_size) + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi : bi + batch_seqs] + if not batch_ws: + continue + t_input0 = time.perf_counter() + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + token_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + within_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + word_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + score_starts: list[int] = [] + + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_targets) + wlen = end - ws + wlens.append(wlen) + local = val_tokens[ws : end + 1].to(device=device, dtype=torch.int64) + x_batch[i, :wlen] = local[:-1] + y_batch[i, :wlen] = local[1:] + s = 0 if ws == 0 else max(wlen - stride, 0) + score_starts.append(s) + if wlen - s <= 0: + continue + c0 = ws + s - chunk_t0 + c1 = ws + wlen - chunk_t0 + token_batch[i, s:wlen] = torch.from_numpy( + np.asarray(token_top_token[c0:c1], dtype=np.int64) + ).to(device=device, dtype=torch.int64) + within_batch[i, s:wlen] = torch.from_numpy( + np.asarray(within_top_token[c0:c1], dtype=np.int64) + ).to(device=device, dtype=torch.int64) + word_batch[i, s:wlen] = torch.from_numpy( + np.asarray(word_top_token[c0:c1], dtype=np.int64) + ).to(device=device, dtype=torch.int64) + input_time_s += time.perf_counter() - t_input0 + + t_forward0 = time.perf_counter() + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + if hasattr(torch.compiler, "cudagraph_mark_step_begin"): + torch.compiler.cudagraph_mark_step_begin() + logits = compiled_logits(x_batch) + logits_f = logits.float() + log_norm = 
torch.logsumexp(logits_f, dim=-1) + true_logits = logits_f.gather(-1, y_batch.unsqueeze(-1)).squeeze(-1) + token_logits = logits_f.gather(-1, token_batch.unsqueeze(-1)).squeeze(-1) + within_logits = logits_f.gather(-1, within_batch.unsqueeze(-1)).squeeze(-1) + word_logits = logits_f.gather(-1, word_batch.unsqueeze(-1)).squeeze(-1) + true_probs = (true_logits - log_norm).exp() + token_hint = (token_logits - log_norm).exp() + within_hint = (within_logits - log_norm).exp() + word_hint = (word_logits - log_norm).exp() + forward_time_s += time.perf_counter() - t_forward0 + + t_blend0 = time.perf_counter() + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = score_starts[i] + if wlen - s <= 0: + continue + c0 = ws + s - chunk_t0 + c1 = ws + wlen - chunk_t0 + llm_chunk = true_probs[i, s:wlen].detach().cpu().numpy().astype(np.float64, copy=False) + token_hint_chunk = token_hint[i, s:wlen].detach().cpu().numpy().astype(np.float32, copy=False) + within_hint_chunk = within_hint[i, s:wlen].detach().cpu().numpy().astype(np.float32, copy=False) + word_hint_chunk = word_hint[i, s:wlen].detach().cpu().numpy().astype(np.float32, copy=False) + best_agree_chunk = compute_best_agreement_chunk( + llm_chunk=llm_chunk, + true_targets=chunk_target_tokens[c0:c1], + token_top_prob=np.asarray(token_top_prob[c0:c1], dtype=np.float32), + token_top_token=np.asarray(token_top_token[c0:c1], dtype=np.uint16), + token_hint_probs=token_hint_chunk, + within_top_prob=np.asarray(within_top_prob[c0:c1], dtype=np.float32), + within_top_token=np.asarray(within_top_token[c0:c1], dtype=np.uint16), + within_valid=np.asarray(within_valid[c0:c1], dtype=np.bool_), + within_hint_probs=within_hint_chunk, + word_top_prob=np.asarray(word_top_prob[c0:c1], dtype=np.float32), + word_top_token=np.asarray(word_top_token[c0:c1], dtype=np.uint16), + word_hint_probs=word_hint_chunk, + token_threshold=token_threshold, + token_boost=token_boost, + within_tau=within_tau, + within_boost=within_boost, + 
word_tau=word_tau, + word_boost=word_boost, + agree_add_boost=agree_add_boost, + ) + llm_loss_sum += float((-np.log(np.clip(llm_chunk, 1e-12, 1.0))).sum()) + best_agree_loss_sum += float((-np.log(np.clip(best_agree_chunk, 1e-12, 1.0))).sum()) + token_count += float(c1 - c0) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_sum += float(tb.sum().item()) + blend_time_s += time.perf_counter() - t_blend0 + finally: + ngram_state.close() + + llm_loss_t = torch.tensor([llm_loss_sum], dtype=torch.float64, device=device) + best_loss_t = torch.tensor([best_agree_loss_sum], dtype=torch.float64, device=device) + byte_sum_t = torch.tensor([byte_sum], dtype=torch.float64, device=device) + token_count_t = torch.tensor([token_count], dtype=torch.float64, device=device) + if world_size > 1: + dist.all_reduce(llm_loss_t, op=dist.ReduceOp.SUM) + dist.all_reduce(best_loss_t, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_sum_t, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count_t, op=dist.ReduceOp.SUM) + + state_max_s = dist_max_float(state_time_s, device, world_size) + input_max_s = dist_max_float(input_time_s, device, world_size) + forward_max_s = dist_max_float(forward_time_s, device, world_size) + blend_max_s = dist_max_float(blend_time_s, device, world_size) + loop_total_max_s = dist_max_float(time.perf_counter() - loop_t0, device, world_size) + + llm_total_loss = float(llm_loss_t.item()) + best_total_loss = float(best_loss_t.item()) + total_bytes = float(byte_sum_t.item()) + total_token_count = float(token_count_t.item()) + llm_bpb = loss_to_bpb(llm_total_loss, total_bytes) + best_agree_bpb = loss_to_bpb(best_total_loss, total_bytes) + + timings = { + "llm_bpb": llm_bpb, + "best_agree_bpb": best_agree_bpb, + "gain_bpb": llm_bpb - best_agree_bpb, + "startup_max_s": startup_max_s, + "loop_total_max_s": loop_total_max_s, + "state_max_s": 
state_max_s,
        "input_max_s": input_max_s,
        "forward_max_s": forward_max_s,
        "blend_max_s": blend_max_s,
        "llm_nats_per_byte": loss_to_nats_per_byte(llm_total_loss, total_bytes),
        "best_agree_nats_per_byte": loss_to_nats_per_byte(best_total_loss, total_bytes),
        "gain_nats_per_byte": loss_to_nats_per_byte(llm_total_loss, total_bytes)
        - loss_to_nats_per_byte(best_total_loss, total_bytes),
    }
    if rank == 0:
        log0(
            f"online_best_agree:done llm_bpb={llm_bpb:.8f} best_agree_bpb={best_agree_bpb:.8f} "
            f"gain_bpb={llm_bpb - best_agree_bpb:.8f} startup_max={startup_max_s:.2f}s "
            f"loop_total_max={loop_total_max_s:.2f}s state_max={state_max_s:.2f}s "
            f"input_max={input_max_s:.2f}s forward_max={forward_max_s:.2f}s blend_max={blend_max_s:.2f}s"
        )
    return best_total_loss / max(total_token_count, 1.0), best_agree_bpb, timings
diff --git a/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/online_ngram_state.c b/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/online_ngram_state.c
new file mode 100644
index 0000000000..f8472a6f05
--- /dev/null
+++ b/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/online_ngram_state.c
@@ -0,0 +1,433 @@
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#define COEFF_COUNT 32

static const uint64_t ROLLING_COEFFS[COEFF_COUNT] = {
    36313ULL, 27191ULL, 51647ULL, 81929ULL, 131071ULL, 196613ULL,
    262147ULL, 393241ULL, 524309ULL, 655373ULL, 786433ULL, 917521ULL,
    1048583ULL, 1179653ULL, 1310729ULL, 1441801ULL, 1572869ULL, 1703941ULL,
    1835017ULL, 1966087ULL, 2097169ULL, 2228243ULL, 2359319ULL, 2490389ULL,
    2621471ULL, 2752549ULL, 2883617ULL, 3014687ULL, 3145757ULL, 3276833ULL,
    3407903ULL, 3538973ULL,
};

static const uint64_t PAIR_MIX = 1000003ULL;
static const uint64_t PREFIX_BASE = 1099511628211ULL;
static const uint64_t LEN_MIX = 0x9E3779B185EBCA87ULL;
static const uint64_t TABLE_MIX = 0x9e3779b97f4a7c15ULL;

typedef struct {
    uint64_t key;
uint32_t total; + uint32_t top_count; + uint16_t top_tok; + uint16_t _pad; +} CtxBucket; + +typedef struct { + uint64_t key; + uint32_t count; + uint32_t _pad; +} PairBucket; + +typedef struct { + int token_ctx_len; + int token_prefix_len; + int token_head; + uint16_t *token_ring; + + CtxBucket *token_ctx_tbl; + uint8_t *token_ctx_used; + size_t token_ctx_mask; + + PairBucket *token_pair_tbl; + uint8_t *token_pair_used; + size_t token_pair_mask; + + uint64_t within_hash; + uint32_t within_len; + + CtxBucket *within_ctx_tbl; + uint8_t *within_ctx_used; + size_t within_ctx_mask; + + PairBucket *within_pair_tbl; + uint8_t *within_pair_used; + size_t within_pair_mask; +} OnlineNgramState; + +static inline size_t mix_index(uint64_t key, size_t mask) { + return (size_t)((key * TABLE_MIX) & mask); +} + +static inline size_t find_ctx_slot( + CtxBucket *tbl, + uint8_t *used, + size_t mask, + uint64_t key, + int *found +) { + size_t idx = mix_index(key, mask); + for (size_t probe = 0; probe <= mask; ++probe) { + if (!used[idx]) { + *found = 0; + return idx; + } + if (tbl[idx].key == key) { + *found = 1; + return idx; + } + idx = (idx + 1U) & mask; + } + *found = -1; + return 0; +} + +static inline size_t find_pair_slot( + PairBucket *tbl, + uint8_t *used, + size_t mask, + uint64_t key, + int *found +) { + size_t idx = mix_index(key, mask); + for (size_t probe = 0; probe <= mask; ++probe) { + if (!used[idx]) { + *found = 0; + return idx; + } + if (tbl[idx].key == key) { + *found = 1; + return idx; + } + idx = (idx + 1U) & mask; + } + *found = -1; + return 0; +} + +static inline uint64_t token_pair_key(uint64_t ctx_key, uint16_t tok, int ctx_len) { + return (ctx_key * PAIR_MIX) ^ (((uint64_t)tok) * ROLLING_COEFFS[(size_t)ctx_len % COEFF_COUNT]); +} + +static inline uint64_t within_pair_key(uint64_t ctx_key, uint16_t tok) { + return (ctx_key * PAIR_MIX) ^ (((uint64_t)tok) * ROLLING_COEFFS[0]); +} + +static inline uint64_t extend_prefix_hash(uint64_t current_hash, uint16_t tok, 
uint32_t pos) { + return (current_hash * PREFIX_BASE) ^ (((uint64_t)tok + 1ULL) * ROLLING_COEFFS[(size_t)pos % COEFF_COUNT]); +} + +static inline uint32_t pair_increment( + PairBucket *tbl, + uint8_t *used, + size_t mask, + uint64_t key +) { + int found = 0; + size_t idx = find_pair_slot(tbl, used, mask, key, &found); + if (found < 0) { + return 0U; + } + if (!found) { + used[idx] = 1U; + tbl[idx].key = key; + tbl[idx].count = 1U; + return 1U; + } + tbl[idx].count += 1U; + return tbl[idx].count; +} + +static inline int ctx_increment( + CtxBucket *tbl, + uint8_t *used, + size_t mask, + uint64_t key, + uint16_t tok, + uint32_t pair_count +) { + int found = 0; + size_t idx = find_ctx_slot(tbl, used, mask, key, &found); + if (found < 0) { + return -1; + } + if (!found) { + used[idx] = 1U; + tbl[idx].key = key; + tbl[idx].total = 1U; + tbl[idx].top_count = pair_count; + tbl[idx].top_tok = tok; + return 0; + } + tbl[idx].total += 1U; + if (pair_count > tbl[idx].top_count) { + tbl[idx].top_count = pair_count; + tbl[idx].top_tok = tok; + } + return 0; +} + +static inline uint64_t token_context_hash(const OnlineNgramState *st) { + uint64_t h = 0ULL; + if (st->token_ctx_len <= 0) { + return h; + } + for (int j = 0; j < st->token_ctx_len; ++j) { + const int ring_idx = (st->token_head + j) % st->token_ctx_len; + h ^= ((uint64_t)st->token_ring[ring_idx]) * ROLLING_COEFFS[(size_t)j]; + } + return h; +} + +static inline void token_push(OnlineNgramState *st, uint16_t tok) { + if (st->token_ctx_len <= 0) { + return; + } + if (st->token_prefix_len < st->token_ctx_len) { + st->token_ring[st->token_prefix_len] = tok; + st->token_prefix_len += 1; + return; + } + st->token_ring[st->token_head] = tok; + st->token_head = (st->token_head + 1) % st->token_ctx_len; +} + +static void *xcalloc(size_t count, size_t size) { + if (count == 0 || size == 0) { + return NULL; + } + return calloc(count, size); +} + +static int alloc_tables( + size_t table_bits, + CtxBucket **ctx_tbl, + uint8_t 
**ctx_used, + size_t *ctx_mask, + PairBucket **pair_tbl, + uint8_t **pair_used, + size_t *pair_mask +) { + const size_t size = 1ULL << table_bits; + *ctx_tbl = (CtxBucket *)xcalloc(size, sizeof(CtxBucket)); + *ctx_used = (uint8_t *)xcalloc(size, sizeof(uint8_t)); + *pair_tbl = (PairBucket *)xcalloc(size, sizeof(PairBucket)); + *pair_used = (uint8_t *)xcalloc(size, sizeof(uint8_t)); + if (!*ctx_tbl || !*ctx_used || !*pair_tbl || !*pair_used) { + return -1; + } + *ctx_mask = size - 1U; + *pair_mask = size - 1U; + return 0; +} + +void *online_ngram_state_create( + int token_ctx_len, + int token_table_bits, + int within_table_bits +) { + if (token_ctx_len < 0 || token_table_bits <= 0 || within_table_bits <= 0) { + return NULL; + } + OnlineNgramState *st = (OnlineNgramState *)calloc(1, sizeof(OnlineNgramState)); + if (!st) { + return NULL; + } + st->token_ctx_len = token_ctx_len; + if (token_ctx_len > 0) { + st->token_ring = (uint16_t *)xcalloc((size_t)token_ctx_len, sizeof(uint16_t)); + if (!st->token_ring) { + free(st); + return NULL; + } + } + if (alloc_tables( + (size_t)token_table_bits, + &st->token_ctx_tbl, + &st->token_ctx_used, + &st->token_ctx_mask, + &st->token_pair_tbl, + &st->token_pair_used, + &st->token_pair_mask + ) != 0) { + free(st->token_ring); + free(st); + return NULL; + } + if (alloc_tables( + (size_t)within_table_bits, + &st->within_ctx_tbl, + &st->within_ctx_used, + &st->within_ctx_mask, + &st->within_pair_tbl, + &st->within_pair_used, + &st->within_pair_mask + ) != 0) { + free(st->token_pair_used); + free(st->token_pair_tbl); + free(st->token_ctx_used); + free(st->token_ctx_tbl); + free(st->token_ring); + free(st); + return NULL; + } + return (void *)st; +} + +void online_ngram_state_destroy(void *ptr) { + OnlineNgramState *st = (OnlineNgramState *)ptr; + if (!st) { + return; + } + free(st->within_pair_used); + free(st->within_pair_tbl); + free(st->within_ctx_used); + free(st->within_ctx_tbl); + free(st->token_pair_used); + 
free(st->token_pair_tbl); + free(st->token_ctx_used); + free(st->token_ctx_tbl); + free(st->token_ring); + free(st); +} + +void online_ngram_state_seed_prefix_token(void *ptr, uint16_t tok) { + OnlineNgramState *st = (OnlineNgramState *)ptr; + if (!st) { + return; + } + token_push(st, tok); +} + +int online_ngram_state_process_chunk( + void *ptr, + const uint16_t *tokens, + int64_t n_tokens, + const uint8_t *starts_new_word_lut, + const uint8_t *boundary_lut, + uint16_t *token_top_token, + float *token_top_prob, + uint16_t *within_top_token, + float *within_top_prob, + uint8_t *within_valid +) { + OnlineNgramState *st = (OnlineNgramState *)ptr; + if (!st || !tokens || n_tokens < 0) { + return -1; + } + for (int64_t i = 0; i < n_tokens; ++i) { + const uint16_t tok = tokens[i]; + const uint8_t is_boundary = boundary_lut[tok]; + const uint8_t is_new_word = starts_new_word_lut[tok]; + + uint64_t token_ctx_key = 0ULL; + if (st->token_ctx_len == 0 || st->token_prefix_len >= st->token_ctx_len) { + token_ctx_key = token_context_hash(st); + int found = 0; + size_t idx = find_ctx_slot( + st->token_ctx_tbl, + st->token_ctx_used, + st->token_ctx_mask, + token_ctx_key, + &found + ); + if (found > 0) { + token_top_token[i] = st->token_ctx_tbl[idx].top_tok; + token_top_prob[i] = + (float)st->token_ctx_tbl[idx].top_count / (float)st->token_ctx_tbl[idx].total; + } else { + token_top_token[i] = 0U; + token_top_prob[i] = 0.0f; + } + } else { + token_top_token[i] = 0U; + token_top_prob[i] = 0.0f; + } + + uint64_t within_ctx_key = 0ULL; + if (!is_boundary && !is_new_word && st->within_len > 0U) { + within_ctx_key = st->within_hash ^ ((uint64_t)st->within_len * LEN_MIX); + int found = 0; + size_t idx = find_ctx_slot( + st->within_ctx_tbl, + st->within_ctx_used, + st->within_ctx_mask, + within_ctx_key, + &found + ); + within_valid[i] = 1U; + if (found > 0) { + within_top_token[i] = st->within_ctx_tbl[idx].top_tok; + within_top_prob[i] = + (float)st->within_ctx_tbl[idx].top_count / 
(float)st->within_ctx_tbl[idx].total; + } else { + within_top_token[i] = 0U; + within_top_prob[i] = 0.0f; + } + } else { + within_valid[i] = 0U; + within_top_token[i] = 0U; + within_top_prob[i] = 0.0f; + } + + if (st->token_ctx_len == 0 || st->token_prefix_len >= st->token_ctx_len) { + const uint64_t pair_key = token_pair_key(token_ctx_key, tok, st->token_ctx_len); + const uint32_t pair_count = pair_increment( + st->token_pair_tbl, + st->token_pair_used, + st->token_pair_mask, + pair_key + ); + if (pair_count == 0U) { + return -2; + } + if (ctx_increment( + st->token_ctx_tbl, + st->token_ctx_used, + st->token_ctx_mask, + token_ctx_key, + tok, + pair_count + ) != 0) { + return -3; + } + } + token_push(st, tok); + + if (is_boundary) { + st->within_hash = 0ULL; + st->within_len = 0U; + continue; + } + if (is_new_word || st->within_len == 0U) { + st->within_hash = extend_prefix_hash(0ULL, tok, 0U); + st->within_len = 1U; + continue; + } + const uint32_t within_pair_count = pair_increment( + st->within_pair_tbl, + st->within_pair_used, + st->within_pair_mask, + within_pair_key(within_ctx_key, tok) + ); + if (within_pair_count == 0U) { + return -4; + } + if (ctx_increment( + st->within_ctx_tbl, + st->within_ctx_used, + st->within_ctx_mask, + within_ctx_key, + tok, + within_pair_count + ) != 0) { + return -5; + } + st->within_hash = extend_prefix_hash(st->within_hash, tok, st->within_len); + st->within_len += 1U; + } + return 0; +} diff --git a/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/requirements.txt b/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/requirements.txt new file mode 100644 index 0000000000..2c75ca7833 --- /dev/null +++ b/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/requirements.txt @@ -0,0 +1,6 @@ +# Runs here used PyTorch 2.11.0+cu126; torch CUDA build 12.6, NVIDIA driver reported CUDA 12.7. 
+torch>=2.9.0 +numpy +sentencepiece +zstandard +flash-attn diff --git a/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/submission.json b/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/submission.json new file mode 100644 index 0000000000..efacf46cca --- /dev/null +++ b/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/submission.json @@ -0,0 +1,9 @@ +{ + "name": "Loader FullGPTQ XSA11 + Online Ngram Agreement", + "val_bpb": 1.11085863, + "bytes_total": 15953221, + "blurb": "Built on PR #1060 Loader_FullGPTQ_XSA11_BigramHash2816, retuned to use WARMDOWN_ITERS=4000, and adds a single-pass online causal token+within-word+word-start+agreement eval path that stays under the 10-minute eval budget. The bundled 4-seed subset (42, 1337, 2025, 15) averages 1.11085863 val_bpb with +0.00592043 nats/byte over the current 1.1194 leaderboard entry; the one-sided t-test against a 0.005 nats/byte bar gives p=0.00155.", + "author": "Anirudh Rahul", + "github_id": "AnirudhRahul", + "date": "2026-03-30" +} diff --git a/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/train_gpt.py b/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/train_gpt.py new file mode 100644 index 0000000000..bb46a7efb8 --- /dev/null +++ b/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/train_gpt.py @@ -0,0 +1,2214 @@ +from __future__ import annotations +import copy +import glob +import importlib.util +import io +import lzma +import math +import os +import random +import subprocess +import sys +import time +import uuid +from pathlib import Path +import numpy as np +import sentencepiece as spm +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor, nn +from torch.nn.parallel import DistributedDataParallel as DDP +from flash_attn_interface import flash_attn_func as flash_attn_3_func 
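The README reports the same gain in two units: nats/byte and bits/byte (bpb), which differ only by a factor of ln 2. A minimal sketch of that conversion, consistent with the `loss_to_nats_per_byte` helper used in the eval code (the function body here is an assumption inferred from its name and call sites):

```python
import math

def loss_to_nats_per_byte(total_loss_nats: float, total_bytes: float) -> float:
    # Summed per-token NLL in nats, normalized by the byte count of the text.
    return total_loss_nats / total_bytes

def nats_per_byte_to_bpb(nats_per_byte: float) -> float:
    # bits-per-byte is the same quantity expressed in log-2 units.
    return nats_per_byte / math.log(2.0)

# The README's headline gain of 0.00592043 nats/byte converts to bpb:
gain_bpb = nats_per_byte_to_bpb(0.00592043)
```

This reproduces the 0.00854137 bpb figure quoted alongside the nats/byte gain in the README header.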
+TORCH_COMPILE_DISABLE = bool(int(os.environ.get("TORCH_COMPILE_DISABLE", "0"))) +TORCH_COMPILE_DYNAMIC = bool(int(os.environ.get("TORCH_COMPILE_DYNAMIC", "0"))) +TORCH_COMPILE_FULLGRAPH = bool(int(os.environ.get("TORCH_COMPILE_FULLGRAPH", "1"))) +TORCH_COMPILE_BACKEND = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") +TORCH_COMPILE_MODE = os.environ.get("TORCH_COMPILE_MODE") or None +TORCH_COMPILE_EVAL_MODE = os.environ.get("TORCH_COMPILE_EVAL_MODE") or None +TORCH_COMPILE_EVAL_DISABLE = bool(int(os.environ.get( + "TORCH_COMPILE_EVAL_DISABLE", + "1" if (TORCH_COMPILE_MODE or "").startswith("max-autotune") else "0", +))) +def cudagraph_step_begin() -> None: + if TORCH_COMPILE_DISABLE or TORCH_COMPILE_BACKEND != "inductor": + return + if hasattr(torch.compiler, "cudagraph_mark_step_begin"): + torch.compiler.cudagraph_mark_step_begin() +def compile_with_env(fn, *, is_eval: bool = False): + if TORCH_COMPILE_DISABLE: + return fn + if is_eval and TORCH_COMPILE_EVAL_DISABLE: + return fn + kwargs = { + "dynamic": TORCH_COMPILE_DYNAMIC, + "fullgraph": TORCH_COMPILE_FULLGRAPH, + "backend": TORCH_COMPILE_BACKEND, + } + compile_mode = TORCH_COMPILE_EVAL_MODE if is_eval and TORCH_COMPILE_EVAL_MODE is not None else TORCH_COMPILE_MODE + if compile_mode is not None: + kwargs["mode"] = compile_mode + return torch.compile(fn, **kwargs) + + +_ONLINE_BEST_AGREE_EVAL_MOD = None + + +def load_local_module(path: Path, module_name: str): + spec = importlib.util.spec_from_file_location(module_name, path) + if spec is None or spec.loader is None: + raise RuntimeError(f"Failed to load module from {path}") + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +def load_online_best_agree_eval_mod(): + global _ONLINE_BEST_AGREE_EVAL_MOD + if _ONLINE_BEST_AGREE_EVAL_MOD is None: + helper_path = Path(__file__).resolve().parent / "online_best_agree_eval.py" + _ONLINE_BEST_AGREE_EVAL_MOD = load_local_module(helper_path, 
"pr1060_online_best_agree_eval_runtime") + return _ONLINE_BEST_AGREE_EVAL_MOD + + +def submission_code_bytes() -> int: + code_paths = [ + Path(__file__).resolve(), + Path(__file__).resolve().parent / "online_best_agree_eval.py", + Path(__file__).resolve().parent / "online_ngram_state.c", + ] + total = 0 + for path in code_paths: + if path.exists(): + total += len(path.read_bytes()) + return total + + +class Hyperparameters: + data_path = os.environ.get("DATA_PATH", "./data/datasets/fineweb10B_sp1024") + train_files = os.path.join(data_path, "fineweb_train_*.bin") + val_files = os.path.join(data_path, "fineweb_val_*.bin") + tokenizer_path = os.environ.get("TOKENIZER_PATH", "./data/tokenizers/fineweb_1024_bpe.model") + run_id = os.environ.get("RUN_ID", str(uuid.uuid4())) + seed = int(os.environ.get("SEED", 1337)) + val_batch_size = int(os.environ.get("VAL_BATCH_SIZE", 524_288)) + val_loss_every = int(os.environ.get("VAL_LOSS_EVERY", 4000)) + train_log_every = int(os.environ.get("TRAIN_LOG_EVERY", 500)) + iterations = int(os.environ.get("ITERATIONS", 20000)) + warmdown_iters = int(os.environ.get("WARMDOWN_ITERS", 3500)) + warmup_steps = int(os.environ.get("WARMUP_STEPS", 20)) + train_batch_tokens = int(os.environ.get("TRAIN_BATCH_TOKENS", 786_432)) + train_seq_len = int(os.environ.get("TRAIN_SEQ_LEN", 2048)) + eval_seq_len = int(os.environ.get("EVAL_SEQ_LEN", 2048)) + max_wallclock_seconds = float(os.environ.get("MAX_WALLCLOCK_SECONDS", 600.0)) + lr_schedule_reference_step_ms = float(os.environ.get("LR_SCHEDULE_REFERENCE_STEP_MS", "0")) + qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5)) + vocab_size = int(os.environ.get("VOCAB_SIZE", 1024)) + num_layers = int(os.environ.get("NUM_LAYERS", 11)) + num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4)) + model_dim = int(os.environ.get("MODEL_DIM", 512)) + num_heads = int(os.environ.get("NUM_HEADS", 8)) + mlp_mult = float(os.environ.get("MLP_MULT", 3.0)) + tie_embeddings = bool(int(os.environ.get("TIE_EMBEDDINGS", 
"1"))) + rope_base = float(os.environ.get("ROPE_BASE", 10000.0)) + logit_softcap = float(os.environ.get("LOGIT_SOFTCAP", 30.0)) + embed_lr = float(os.environ.get("EMBED_LR", 0.6)) + head_lr = float(os.environ.get("HEAD_LR", 0.008)) + tied_embed_lr = float(os.environ.get("TIED_EMBED_LR", 0.035)) + tied_embed_init_std = float(os.environ.get("TIED_EMBED_INIT_STD", 0.005)) + matrix_lr = float(os.environ.get("MATRIX_LR", 0.025)) + scalar_lr = float(os.environ.get("SCALAR_LR", 0.025)) + muon_momentum = float(os.environ.get("MUON_MOMENTUM", 0.99)) + muon_backend_steps = int(os.environ.get("MUON_BACKEND_STEPS", 5)) + muon_momentum_warmup_start = float(os.environ.get("MUON_MOMENTUM_WARMUP_START", 0.92)) + muon_momentum_warmup_steps = int(os.environ.get("MUON_MOMENTUM_WARMUP_STEPS", 1500)) + beta1 = float(os.environ.get("BETA1", 0.9)) + beta2 = float(os.environ.get("BETA2", 0.95)) + adam_eps = float(os.environ.get("ADAM_EPS", 1e-8)) + grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) + eval_stride = int(os.environ.get("EVAL_STRIDE", 64)) + muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) + swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1"))) + swa_every = int(os.environ.get("SWA_EVERY", 50)) + swa_start_step = int(os.environ.get("SWA_START_STEP", "0")) + muon_wd = float(os.environ.get("MUON_WD", 0.04)) + adam_wd = float(os.environ.get("ADAM_WD", 0.04)) + bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048)) + bigram_dim = int(os.environ.get("BIGRAM_DIM", 128)) + xsa_last_n = int(os.environ.get("XSA_LAST_N", 4)) + rope_dims = int(os.environ.get("ROPE_DIMS", 16)) + ln_scale = bool(int(os.environ.get("LN_SCALE", "1"))) + ve_enabled = bool(int(os.environ.get("VE_ENABLED", "1"))) + ve_dim = int(os.environ.get("VE_DIM", 128)) + ve_layers = os.environ.get("VE_LAYERS", "9,10") + ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "0"))) + ttt_lr = float(os.environ.get("TTT_LR", 0.002)) + ttt_epochs = int(os.environ.get("TTT_EPOCHS", 3)) + 
ttt_chunk_tokens = int(os.environ.get("TTT_CHUNK_TOKENS", 32768)) + ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 2)) + ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9)) + ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32)) + ttt_grad_clip = float(os.environ.get("TTT_GRAD_CLIP", 1.0)) + negative_slope = float(os.environ.get("NEGATIVE_SLOPE", 0.5)) + use_gptq = bool(int(os.environ.get("USE_GPTQ", "0"))) + gptq_calib_samples = int(os.environ.get("GPTQ_CALIB_SAMPLES", "64")) + gptq_reserve_ms = float(os.environ.get("GPTQ_RESERVE_MS", "14000")) + quant_clip_range = int(os.environ.get("QUANT_CLIP_RANGE", 31)) + +# --- Batched Newton-Schulz orthogonalization --- + +def zeropower_via_newtonschulz5(G: Tensor, steps: int = 5, eps: float = 1e-7) -> Tensor: + """Batched Newton-Schulz orthogonalization. G: (B,M,N) or (M,N).""" + a, b, c = (3.4445, -4.7750, 2.0315) + was_2d = G.ndim == 2 + if was_2d: + G = G.unsqueeze(0) + X = G.bfloat16() + transposed = X.size(-2) > X.size(-1) + if transposed: + X = X.mT + X = X / (X.norm(dim=(-2, -1), keepdim=True) + eps) + for _ in range(steps): + A = X @ X.mT + B = b * A + c * (A @ A) + X = a * X + B @ X + if transposed: + X = X.mT + if was_2d: + X = X.squeeze(0) + return X + +# --- Parallel Muon optimizer --- + +class Muon(torch.optim.Optimizer): + """Parallel Muon: post-backward reduce-scatter -> local NS5 -> all-gather. + + No DDP for bank params. After backward, this optimizer: + 1. Launches async reduce-scatter for all banks (biggest first) + 2. Returns control so Adam can step on small params while RS is in-flight + 3. Waits for each RS, runs local NS5 on the shard, launches async all-gather + 4. 
Each all-gather overlaps with next bank's NS5 + """ + def __init__(self, params, lr: float, momentum: float, backend_steps: int, + nesterov: bool = True, weight_decay: float = 0.0): + super().__init__( + params, + dict(lr=lr, momentum=momentum, backend_steps=backend_steps, + nesterov=nesterov, weight_decay=weight_decay), + ) + self._built = False + + def _build(self): + self._distributed = dist.is_available() and dist.is_initialized() + self._world_size = dist.get_world_size() if self._distributed else 1 + self._rank = dist.get_rank() if self._distributed else 0 + ws = self._world_size + + self._bank_meta = [] + for group in self.param_groups: + for p in group["params"]: + B = p.shape[0] + padded_B = ((B + ws - 1) // ws) * ws + shard_B = padded_B // ws + tail = p.shape[1:] + dev = p.device + self._bank_meta.append({ + 'p': p, + 'B': B, + 'padded_grad': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'shard_mom': torch.zeros(shard_B, *tail, device=dev, dtype=torch.bfloat16), + 'full_update': torch.zeros(padded_B, *tail, device=dev, dtype=torch.bfloat16), + 'scale': max(1, p.shape[-2] / p.shape[-1]) ** 0.5, + }) + # Sort by size descending -- launch biggest reduce-scatters first + self._bank_meta.sort(key=lambda m: -m['p'].numel()) + self._built = True + + def launch_reduce_scatters(self): + """Phase 1: launch async reduce-scatter for all banks. 
Call right after backward.""" + if not self._built: + self._build() + if not self._distributed: + return + self._rs_futures = [] + for m in self._bank_meta: + p = m['p'] + if p.grad is None: + self._rs_futures.append(None) + continue + pg = m['padded_grad'] + pg[:m['B']].copy_(p.grad.bfloat16()) + if pg.shape[0] > m['B']: + pg[m['B']:].zero_() + fut = dist.reduce_scatter_tensor(m['shard'], pg, op=dist.ReduceOp.AVG, async_op=True) + self._rs_futures.append(fut) + + @torch.no_grad() + def step(self, closure=None): + """Phase 3: wait for RS, local NS5, all-gather. Call AFTER Adam steps.""" + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if not self._built: + self._build() + + for group in self.param_groups: + lr = group["lr"] + momentum = group["momentum"] + backend_steps = group["backend_steps"] + nesterov = group["nesterov"] + wd = group.get("weight_decay", 0.0) + + prev_ag_handle = None + prev_m = None + + sharded = self._distributed and hasattr(self, '_rs_futures') + + for i, m in enumerate(self._bank_meta): + p = m['p'] + if p.grad is None: + continue + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if sharded and self._rs_futures[i] is not None: + self._rs_futures[i].wait() + g = m['shard'] + buf = m['shard_mom'] + else: + g = p.grad.bfloat16() + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(g) + buf = state["momentum_buffer"] + + buf.mul_(momentum).add_(g) + if nesterov: + update = g.add(buf, alpha=momentum) + else: + update = buf + + update = zeropower_via_newtonschulz5(update, steps=backend_steps) + + if sharded: + prev_ag_handle = dist.all_gather_into_tensor( + m['full_update'], update, async_op=True) + prev_m = m + else: + if wd > 0.0: + p.data.mul_(1.0 - lr * wd) + 
p.add_(update.to(dtype=p.dtype), alpha=-lr * m['scale']) + + if prev_ag_handle is not None: + prev_ag_handle.wait() + pp = prev_m['p'] + upd = prev_m['full_update'][:prev_m['B']] + if wd > 0.0: + pp.data.mul_(1.0 - lr * wd) + pp.add_(upd.to(dtype=pp.dtype), alpha=-lr * prev_m['scale']) + + if hasattr(self, '_rs_futures'): + del self._rs_futures + + return loss + +# --- Tokenizer evaluation helpers --- + +def build_sentencepiece_luts( + sp: spm.SentencePieceProcessor, vocab_size: int, device: torch.device +) -> tuple[Tensor, Tensor, Tensor]: + sp_vocab_size = int(sp.vocab_size()) + table_size = max(sp_vocab_size, vocab_size) + base_bytes_np = np.zeros((table_size,), dtype=np.int16) + has_leading_space_np = np.zeros((table_size,), dtype=np.bool_) + is_boundary_token_np = np.ones((table_size,), dtype=np.bool_) + for token_id in range(sp_vocab_size): + if sp.is_control(token_id) or sp.is_unknown(token_id) or sp.is_unused(token_id): + continue + is_boundary_token_np[token_id] = False + if sp.is_byte(token_id): + base_bytes_np[token_id] = 1 + continue + piece = sp.id_to_piece(token_id) + if piece.startswith("\u2581"): + has_leading_space_np[token_id] = True + piece = piece[1:] + base_bytes_np[token_id] = len(piece.encode("utf-8")) + return ( + torch.tensor(base_bytes_np, dtype=torch.int16, device=device), + torch.tensor(has_leading_space_np, dtype=torch.bool, device=device), + torch.tensor(is_boundary_token_np, dtype=torch.bool, device=device), + ) +def load_validation_tokens(pattern: str, seq_len: int) -> Tensor: + files = [Path(p) for p in sorted(glob.glob(pattern))] + if not files: + raise FileNotFoundError(f"No files found for pattern: {pattern}") + tokens = torch.cat([load_data_shard(file) for file in files]).contiguous() + usable = ((tokens.numel() - 1) // seq_len) * seq_len + if usable <= 0: + raise ValueError(f"Validation split is too short for TRAIN_SEQ_LEN={seq_len}") + return tokens[: usable + 1] +def eval_val( + args: Hyperparameters, + model: nn.Module, + 
rank: int, + world_size: int, + device: torch.device, + grad_accum_steps: int, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + seq_len = eval_seq_len or args.train_seq_len + prime_rotary_caches(model, device, seq_len) + local_batch_tokens = args.val_batch_size // (world_size * grad_accum_steps) + if local_batch_tokens < seq_len: + raise ValueError( + "VAL_BATCH_SIZE must provide at least one sequence per rank; " + f"got VAL_BATCH_SIZE={args.val_batch_size}, WORLD_SIZE={world_size}, " + f"GRAD_ACCUM_STEPS={grad_accum_steps}, seq_len={seq_len}" + ) + local_batch_seqs = local_batch_tokens // seq_len + total_seqs = (val_tokens.numel() - 1) // seq_len + seq_start = (total_seqs * rank) // world_size + seq_end = (total_seqs * (rank + 1)) // world_size + val_loss_sum = torch.zeros((), device=device, dtype=torch.float64) + val_token_count = torch.zeros((), device=device, dtype=torch.float64) + val_byte_count = torch.zeros((), device=device, dtype=torch.float64) + model.eval() + with torch.inference_mode(): + for batch_seq_start in range(seq_start, seq_end, local_batch_seqs): + batch_seq_end = min(batch_seq_start + local_batch_seqs, seq_end) + raw_start = batch_seq_start * seq_len + raw_end = batch_seq_end * seq_len + 1 + local = val_tokens[raw_start:raw_end].to(device=device, dtype=torch.int64, non_blocking=True) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + cudagraph_step_begin() + batch_loss = model(x, y).detach() + batch_token_count = float(y.numel()) + val_loss_sum += batch_loss.to(torch.float64) * batch_token_count + val_token_count += batch_token_count + prev_ids = x.reshape(-1) + tgt_ids = y.reshape(-1) + token_bytes = base_bytes_lut[tgt_ids].to(dtype=torch.int16) + token_bytes += (has_leading_space_lut[tgt_ids] & 
~is_boundary_token_lut[prev_ids]).to(dtype=torch.int16) + val_byte_count += token_bytes.to(torch.float64).sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(val_loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(val_token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(val_byte_count, op=dist.ReduceOp.SUM) + val_loss = val_loss_sum / val_token_count + bits_per_token = val_loss.item() / math.log(2.0) + tokens_per_byte = val_token_count.item() / val_byte_count.item() + model.train() + return float(val_loss.item()), float(bits_per_token * tokens_per_byte) + +# --- Quantization helpers --- + +CONTROL_TENSOR_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "CONTROL_TENSOR_NAME_PATTERNS", + "attn_scale,attn_scales,mlp_scale,mlp_scales,resid_mix,resid_mixes,q_gain,skip_weight,skip_weights,smear,dtg_gate,ve_layer_scales,ve_shared.scale,attn_gate,vr_lambda", + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_FP32_NAME_PATTERNS = tuple( + pattern + for pattern in os.environ.get( + "INT8_KEEP_FLOAT_FP32_NAME_PATTERNS", + ",".join(CONTROL_TENSOR_NAME_PATTERNS), + ).split(",") + if pattern +) +INT8_KEEP_FLOAT_MAX_NUMEL = 65_536 +INT8_KEEP_FLOAT_STORE_DTYPE = torch.float16 +INT8_PER_ROW_SCALE_DTYPE = torch.float16 +INT8_CLIP_PERCENTILE = 99.99984 +INT8_CLIP_Q = INT8_CLIP_PERCENTILE / 100.0 +def tensor_nbytes(t: Tensor) -> int: + return int(t.numel()) * int(t.element_size()) +def keep_float_tensor(name: str, t: Tensor, passthrough_orig_dtypes: dict[str, str]) -> Tensor: + if any(pattern in name for pattern in INT8_KEEP_FLOAT_FP32_NAME_PATTERNS): + return t.float().contiguous() + if t.dtype in {torch.float32, torch.bfloat16}: + passthrough_orig_dtypes[name] = str(t.dtype).removeprefix("torch.") + return t.to(dtype=INT8_KEEP_FLOAT_STORE_DTYPE).contiguous() + return t +def quantize_float_tensor(t: Tensor) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + clip_abs = ( + torch.quantile(t32.abs(), INT8_CLIP_Q, dim=1) + if 
t32.numel() + else torch.empty((t32.shape[0],), dtype=torch.float32) + ) + clipped = torch.maximum(torch.minimum(t32, clip_abs[:, None]), -clip_abs[:, None]) + scale = (clip_abs / 127.0).clamp_min(1.0 / 127.0) + q = torch.clamp(torch.round(clipped / scale[:, None]), -127, 127).to(torch.int8).contiguous() + return q, scale.to(dtype=INT8_PER_ROW_SCALE_DTYPE).contiguous() + clip_abs = float(torch.quantile(t32.abs().flatten(), INT8_CLIP_Q).item()) if t32.numel() else 0.0 + scale = torch.tensor(clip_abs / 127.0 if clip_abs > 0 else 1.0, dtype=torch.float32) + q = torch.clamp(torch.round(torch.clamp(t32, -clip_abs, clip_abs) / scale), -127, 127).to(torch.int8).contiguous() + return q, scale +def quantize_state_dict_int8(state_dict: dict[str, Tensor]): + quantized: dict[str, Tensor] = {} + scales: dict[str, Tensor] = {} + dtypes: dict[str, str] = {} + passthrough: dict[str, Tensor] = {} + passthrough_orig_dtypes: dict[str, str] = {} + qmeta: dict[str, dict[str, object]] = {} + stats = dict.fromkeys( + ("param_count", "num_tensors", "num_float_tensors", "num_nonfloat_tensors", "baseline_tensor_bytes", "int8_payload_bytes"), + 0, + ) + for name, tensor in state_dict.items(): + t = tensor.detach().to("cpu").contiguous() + stats["param_count"] += int(t.numel()) + stats["num_tensors"] += 1 + stats["baseline_tensor_bytes"] += tensor_nbytes(t) + if not t.is_floating_point(): + stats["num_nonfloat_tensors"] += 1 + passthrough[name] = t + stats["int8_payload_bytes"] += tensor_nbytes(t) + continue + if t.numel() <= INT8_KEEP_FLOAT_MAX_NUMEL: + kept = keep_float_tensor(name, t, passthrough_orig_dtypes) + passthrough[name] = kept + stats["int8_payload_bytes"] += tensor_nbytes(kept) + continue + stats["num_float_tensors"] += 1 + q, s = quantize_float_tensor(t) + if s.ndim > 0: + qmeta[name] = {"scheme": "per_row", "axis": 0} + quantized[name] = q + scales[name] = s + dtypes[name] = str(t.dtype).removeprefix("torch.") + stats["int8_payload_bytes"] += tensor_nbytes(q) + tensor_nbytes(s) 
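The per-row int8 path in `quantize_float_tensor` above clips each row at a high-quantile absolute value, derives a per-row scale of `clip/127` (floored at `1/127`), and rounds into `[-127, 127]`. A standalone NumPy sketch of the same math, assuming nothing beyond what the torch code shows (the real code stores scales as fp16 and uses `INT8_CLIP_PERCENTILE = 99.99984`):

```python
import numpy as np

def quantize_per_row(w: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    # Symmetric per-row int8: clip threshold per row, scale = clip / 127,
    # floored at 1/127 as in the torch implementation.
    clip = np.quantile(np.abs(w), 0.9999984, axis=1)
    scale = np.maximum(clip / 127.0, 1.0 / 127.0)
    clipped = np.clip(w, -clip[:, None], clip[:, None])
    q = np.clip(np.round(clipped / scale[:, None]), -127, 127)
    return q.astype(np.int8), scale.astype(np.float16)

def dequantize_per_row(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    # Inverse: widen to float and multiply back by the per-row scale.
    return q.astype(np.float32) * scale.astype(np.float32)[:, None]

rng = np.random.default_rng(0)
w = rng.normal(scale=0.02, size=(4, 256)).astype(np.float32)
q, s = quantize_per_row(w)
w_hat = dequantize_per_row(q, s)
```

Note the `1/127` floor: rows whose values are all small (as here) quantize on a coarser grid than their own range would suggest, bounding the worst-case step size rather than the relative error.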
+ obj: dict[str, object] = { + "__quant_format__": "int8_clean_per_row_v1", + "quantized": quantized, + "scales": scales, + "dtypes": dtypes, + "passthrough": passthrough, + } + if qmeta: + obj["qmeta"] = qmeta + if passthrough_orig_dtypes: + obj["passthrough_orig_dtypes"] = passthrough_orig_dtypes + return obj, stats +def dequantize_state_dict_int8(obj: dict[str, object]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + qmeta = obj.get("qmeta", {}) + passthrough_orig_dtypes = obj.get("passthrough_orig_dtypes", {}) + for name, q in obj["quantized"].items(): + dtype = getattr(torch, obj["dtypes"][name]) + s = obj["scales"][name] + if qmeta.get(name, {}).get("scheme") == "per_row" or s.ndim > 0: + s = s.to(dtype=torch.float32) + out[name] = (q.float() * s.view(q.shape[0], *([1] * (q.ndim - 1)))).to(dtype=dtype).contiguous() + else: + scale = float(s.item()) + out[name] = (q.float() * scale).to(dtype=dtype).contiguous() + for name, t in obj["passthrough"].items(): + out_t = t.detach().to("cpu").contiguous() + orig_dtype = passthrough_orig_dtypes.get(name) + if isinstance(orig_dtype, str): + out_t = out_t.to(dtype=getattr(torch, orig_dtype)).contiguous() + out[name] = out_t + return out + +# --- Data loading --- + +def load_data_shard(file: Path) -> Tensor: + header_bytes = 256 * np.dtype(" int: + key = str(file) + cached = _SHARD_NTOKENS_CACHE.get(key) + if cached is not None: + return cached + header = np.fromfile(file, dtype=" np.memmap: + key = str(file) + mm = _MMAP_CACHE.get(key) + if mm is not None: + return mm + n = _read_num_tokens(file) + mm = np.memmap(file, mode="r", dtype=" int: + if n <= 1: + return 1 + while True: + s = int(self._rng.integers(1, n)) + if math.gcd(s, n) == 1: + return s + def _reset_cursor(self, si: int, seq_len: int) -> None: + nt = int(self._num_tokens[si]) + max_phase = min(seq_len - 1, max(0, nt - seq_len - 1)) + phase = int(self._rng.integers(max_phase + 1)) if max_phase > 0 else 0 + bc = (nt - 1 - phase) // seq_len + 
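The loader's `_pick_coprime_stride` above leans on a standard number-theory fact: stepping through `bc` blocks with a stride coprime to `bc` visits every block exactly once before wrapping, so each shard cursor enumerates a full permutation without bookkeeping. A minimal standalone sketch of that property (using Python's `random` in place of the loader's numpy generator):

```python
import math
import random

def pick_coprime_stride(rng: random.Random, n: int) -> int:
    # Mirrors the loader's helper: rejection-sample until gcd(stride, n) == 1.
    if n <= 1:
        return 1
    while True:
        s = rng.randrange(1, n)
        if math.gcd(s, n) == 1:
            return s

rng = random.Random(0)
n = 12
stride = pick_coprime_stride(rng, n)
start = 5
# Because gcd(stride, n) == 1, these n steps hit every residue mod n once.
visited = [(start + i * stride) % n for i in range(n)]
```

Combined with the random `start` and `phase`, this gives shuffled, non-repeating coverage of each shard's sequence windows.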
self._cursor_phase[si] = phase + self._cursor_block_count[si] = bc + self._cursor_next[si] = 0 + self._cursor_start[si] = int(self._rng.integers(bc)) if bc > 1 else 0 + self._cursor_stride[si] = self._pick_coprime_stride(bc) + self._cursor_init[si] = True + def _ensure_cursor(self, si: int, seq_len: int) -> None: + if not self._cursor_init[si] or self._cursor_next[si] >= self._cursor_block_count[si]: + self._reset_cursor(si, seq_len) + def _take_from_shard(self, si: int, seq_len: int, count: int, out: list[tuple[int, int]]) -> None: + rem = count + while rem > 0: + self._ensure_cursor(si, seq_len) + bc = int(self._cursor_block_count[si]) + ni = int(self._cursor_next[si]) + take = min(rem, bc - ni) + phase = int(self._cursor_phase[si]) + start = int(self._cursor_start[si]) + stride = int(self._cursor_stride[si]) + for j in range(take): + bi = (start + (ni + j) * stride) % bc + out.append((si, phase + bi * seq_len)) + self._cursor_next[si] = ni + take + rem -= take + def _init_pipeline(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> None: + local_tokens = global_tokens // (self.world_size * grad_accum_steps) + num_seqs = local_tokens // seq_len + global_num_seqs = num_seqs * self.world_size + self._cfg = (local_tokens, seq_len, num_seqs, global_num_seqs) + bbc = (self._num_tokens - 1) // seq_len + eligible = bbc > 0 + self._eligible_shards = np.nonzero(eligible)[0].astype(np.int64) + self._base_block_counts = bbc[self._eligible_shards].astype(np.int64) + def _sample_global_windows(self) -> list[tuple[int, int]]: + assert self._cfg is not None and self._eligible_shards is not None + _, seq_len, _, gns = self._cfg + ec = int(self._eligible_shards.size) + progress = min(self._batches_built / 1800.0, 1.0) + remaining = np.empty(ec, dtype=np.float64) + for i, si in enumerate(self._eligible_shards.tolist()): + if self._cursor_init[si]: + r = int(self._cursor_block_count[si]) - int(self._cursor_next[si]) + remaining[i] = float(max(r, 1)) + else: + 
remaining[i] = float(self._base_block_counts[i]) + alpha = 0.90 - 0.40 * progress + weights = np.power(remaining, alpha) + ws = float(weights.sum()) + if not np.isfinite(ws) or ws <= 0.0: + weights = np.ones(ec, dtype=np.float64) + ws = float(weights.sum()) + probs = weights / ws + low = min(max(8, self.world_size), ec, gns) + high = min(max(32, self.world_size * 8), ec, gns) + mix = max(1, min(int(round(low + progress * (high - low))), ec, gns)) + cp = self._rng.choice(ec, size=mix, replace=False, p=probs) + cs = self._eligible_shards[cp] + cpr = probs[cp].copy() + cpr /= cpr.sum() + counts = np.ones(mix, dtype=np.int64) + extra = gns - mix + if extra > 0: + counts += self._rng.multinomial(extra, cpr).astype(np.int64) + perm = self._rng.permutation(mix) + cs, counts = cs[perm], counts[perm] + buckets: list[list[tuple[int, int]]] = [] + for si, cnt in zip(cs.tolist(), counts.tolist()): + b: list[tuple[int, int]] = [] + self._take_from_shard(int(si), seq_len, int(cnt), b) + if b: + if len(b) > 1: + bp = self._rng.permutation(len(b)) + b = [b[int(k)] for k in bp.tolist()] + buckets.append(b) + windows: list[tuple[int, int]] = [] + active = [i for i, bk in enumerate(buckets) if bk] + while active: + order = self._rng.permutation(len(active)) + new_active: list[int] = [] + for oi in order.tolist(): + bi = active[oi] + if buckets[bi]: + windows.append(buckets[bi].pop()) + if buckets[bi]: + new_active.append(bi) + active = new_active + return windows + def next_batch(self, global_tokens: int, seq_len: int, grad_accum_steps: int) -> tuple[Tensor, Tensor]: + if self._cfg is None: + self._init_pipeline(global_tokens, seq_len, grad_accum_steps) + _, _, num_seqs, gns = self._cfg + gw = self._sample_global_windows() + local_w = gw[self.rank::self.world_size] + x = torch.empty((num_seqs, seq_len), dtype=torch.int64) + y = torch.empty((num_seqs, seq_len), dtype=torch.int64) + for slot, (si, pos) in enumerate(local_w): + mm = _get_shard_memmap(self.files[si]) + window = 
torch.as_tensor(np.array(mm[pos:pos + seq_len + 1], dtype=np.int64)) + x[slot] = window[:-1] + y[slot] = window[1:] + self._batches_built += 1 + return x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) + +# --- Transformer modules --- + +class RMSNorm(nn.Module): + def __init__(self, eps: float | None = None): + super().__init__() + self.eps = eps + def forward(self, x: Tensor) -> Tensor: + return F.rms_norm(x, (x.size(-1),), eps=self.eps) +class CastedLinear(nn.Linear): + def forward(self, x: Tensor) -> Tensor: + w = self.weight.to(x.dtype) + bias = self.bias.to(x.dtype) if self.bias is not None else None + return F.linear(x, w, bias) +def restore_low_dim_params_to_fp32(module: nn.Module) -> None: + with torch.no_grad(): + for name, param in module.named_parameters(): + if (param.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS)) and param.dtype != torch.float32: + param.data = param.data.float() +class Rotary(nn.Module): + def __init__(self, dim: int, base: float = 10000.0, train_seq_len: int = 1024, rope_dims: int = 0): + super().__init__() + self.dim = dim + self.base = base + self.train_seq_len = train_seq_len + self.rope_dims = rope_dims if rope_dims > 0 else dim + inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached: Tensor | None = None + self._sin_cached: Tensor | None = None + def _build_cache(self, seq_len: int, device: torch.device) -> None: + rd = self.rope_dims + if seq_len > self.train_seq_len: + scale = seq_len / self.train_seq_len + new_base = self.base * (scale ** (rd / (rd - 2))) + inv_freq = 1.0 / (new_base ** (torch.arange(0, rd, 2, dtype=torch.float32, device=device) / rd)) + else: + inv_freq = self.inv_freq.to(device) + t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + freqs = torch.outer(t, inv_freq) + self._cos_cached = 
freqs.cos()[None, :, None, :].contiguous() + self._sin_cached = freqs.sin()[None, :, None, :].contiguous() + self._seq_len_cached = seq_len + def prime_cache(self, seq_len: int, device: torch.device) -> None: + with torch.no_grad(): + self._build_cache(seq_len, device) + def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]: + if ( + self._cos_cached is None + or self._sin_cached is None + or self._seq_len_cached < seq_len + or self._cos_cached.device != device + ): + self.prime_cache(seq_len, device) + cos = self._cos_cached[:, :seq_len] + sin = self._sin_cached[:, :seq_len] + if cos.dtype != dtype: + cos = cos.to(dtype=dtype) + if sin.dtype != dtype: + sin = sin.to(dtype=dtype) + return cos, sin +def prime_rotary_caches(module: nn.Module, device: torch.device, *seq_lens: int) -> None: + raw_module = getattr(module, "_orig_mod", module) + max_seq_len = max((int(seq_len) for seq_len in seq_lens if int(seq_len) > 0), default=0) + if max_seq_len <= 0: + return + for submodule in raw_module.modules(): + if isinstance(submodule, Rotary): + submodule.prime_cache(max_seq_len, device) +def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor, rope_dims: int = 0) -> Tensor: + if rope_dims > 0 and rope_dims < x.size(-1): + x_rope, x_pass = x[..., :rope_dims], x[..., rope_dims:] + half = rope_dims // 2 + x1, x2 = x_rope[..., :half], x_rope[..., half:] + x_rope = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + return torch.cat((x_rope, x_pass), dim=-1) + half = x.size(-1) // 2 + x1, x2 = x[..., :half], x[..., half:] + return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1) + +class CausalSelfAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + rope_base: float, + qk_gain_init: float, + ): + super().__init__() + if dim % num_heads != 0: + raise ValueError("model_dim must be divisible by num_heads") + if num_heads % num_kv_heads != 0: + raise 
ValueError("num_heads must be divisible by num_kv_heads") + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = dim // num_heads + if self.head_dim % 2 != 0: + raise ValueError("head_dim must be even for RoPE") + # No CastedLinear -- weights come from banks + self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32)) + self.rope_dims = 0 # set by GPT.__init__ for partial RoPE + self.rotary = Rotary(self.head_dim, base=rope_base, train_seq_len=1024) + self.use_xsa = False # set by GPT.__init__ for deep layers only + def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor: + """Efficient XSA: subtract self-value projection via GQA-aware reshape (no repeat_interleave). + y: [B, T, H, D], v: [B, T, Hkv, D]. H must be divisible by Hkv.""" + B, T, H, D = y.shape + Hkv = v.size(-2) + group = H // Hkv + y_g = y.reshape(B, T, Hkv, group, D) # [B, T, Hkv, group, D] + vn = F.normalize(v, dim=-1).unsqueeze(-2) # [B, T, Hkv, 1, D] -- broadcast ready + proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn + return (y_g - proj).reshape(B, T, H, D) + def forward(self, x: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, v_embed: Tensor | None = None) -> Tensor: + if getattr(self, '_save_gptq', False): + self._gptq_qkv_in = x.detach() + bsz, seqlen, dim = x.shape + q = F.linear(x, q_w.to(x.dtype)).reshape(bsz, seqlen, self.num_heads, self.head_dim) + k = F.linear(x, k_w.to(x.dtype)).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + v = F.linear(x, v_w.to(x.dtype)) + if v_embed is not None: + v = v + v_embed + v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim) + q = F.rms_norm(q, (q.size(-1),)) + k = F.rms_norm(k, (k.size(-1),)) + cos, sin = self.rotary(seqlen, x.device, q.dtype) + q = apply_rotary_emb(q, cos, sin, self.rope_dims) + k = apply_rotary_emb(k, cos, sin, self.rope_dims) + q = q * self.q_gain.to(dtype=q.dtype)[None, None, :, None] + y = flash_attn_3_func(q, k, v, causal=True) + if 
self.use_xsa: + y = self._xsa_efficient(y, v) + y = y.reshape(bsz, seqlen, dim) + if getattr(self, '_save_gptq', False): + self._gptq_o_in = y.detach() + return F.linear(y, out_w.to(x.dtype)) + +class SmearGate(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.gate = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + def forward(self, x: Tensor) -> Tensor: + g = torch.sigmoid(self.gate.to(dtype=x.dtype))[None, None, :] + x_prev = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1) + return (1 - g) * x + g * x_prev + +class BigramHashEmbedding(nn.Module): + def __init__(self, bigram_vocab_size: int, bigram_dim: int, model_dim: int): + super().__init__() + self.bigram_vocab_size = bigram_vocab_size + self.embed = nn.Embedding(bigram_vocab_size, bigram_dim) + nn.init.zeros_(self.embed.weight) + self.proj = CastedLinear(bigram_dim, model_dim, bias=False) if bigram_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.05, dtype=torch.float32)) + def bigram_hash(self, tokens: Tensor) -> Tensor: + t = tokens.to(torch.int32) + mod = self.bigram_vocab_size - 1 + out = torch.empty_like(t) + out[..., 0] = mod + out[..., 1:] = torch.bitwise_xor(36313 * t[..., 1:], 27191 * t[..., :-1]) % mod + return out.long() + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(self.bigram_hash(token_ids)) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class ValueEmbedding(nn.Module): + """Reinject token identity into attention values at specific layers. 
+ Each table maps vocab tokens to a low-dim embedding, projected to model_dim.""" + def __init__(self, vocab_size: int, ve_dim: int, model_dim: int): + super().__init__() + self.embed = nn.Embedding(vocab_size, ve_dim) + nn.init.normal_(self.embed.weight, std=0.01) + self.proj = CastedLinear(ve_dim, model_dim, bias=False) if ve_dim != model_dim else None + if self.proj is not None: + nn.init.zeros_(self.proj.weight) + self.scale = nn.Parameter(torch.tensor(0.1, dtype=torch.float32)) + def forward(self, token_ids: Tensor) -> Tensor: + h = self.embed(token_ids) + if self.proj is not None: + h = self.proj(h) + return h * self.scale.to(dtype=h.dtype) + +class MLP(nn.Module): + def __init__(self, dim: int, mlp_mult: int, neg_slope: float = 0.5): + super().__init__() + self.neg_slope = neg_slope + # No CastedLinear -- weights come from banks + def forward(self, x: Tensor, up_w: Tensor, down_w: Tensor) -> Tensor: + if getattr(self, '_save_gptq', False): + self._gptq_up_in = x.detach() + x = F.leaky_relu(F.linear(x, up_w.to(x.dtype)), negative_slope=self.neg_slope) + x2 = x.square() + if getattr(self, '_save_gptq', False): + self._gptq_down_in = x2.detach() + return F.linear(x2, down_w.to(x.dtype)) + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + rope_base: float, + qk_gain_init: float, + layer_idx: int = 0, + ln_scale: bool = False, + neg_slope: float = 0.5, + ): + super().__init__() + self.layer_idx = layer_idx + self.attn_norm = RMSNorm() + self.mlp_norm = RMSNorm() + self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init) + self.mlp = MLP(dim, mlp_mult, neg_slope=neg_slope) + self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float()) + self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale 
else 1.0 + def forward(self, x: Tensor, x0: Tensor, q_w: Tensor, k_w: Tensor, v_w: Tensor, out_w: Tensor, up_w: Tensor, down_w: Tensor, v_embed: Tensor | None = None) -> Tensor: + mix = self.resid_mix.to(dtype=x.dtype) + x_in = mix[0][None, None, :] * x + mix[1][None, None, :] * x0 + attn_out = self.attn(self.attn_norm(x_in) * self.ln_scale_factor, q_w, k_w, v_w, out_w, v_embed=v_embed) + x_out = x_in + self.attn_scale.to(dtype=x_in.dtype)[None, None, :] * attn_out + mlp_out = self.mlp_scale.to(dtype=x_out.dtype)[None, None, :] * self.mlp(self.mlp_norm(x_out) * self.ln_scale_factor, up_w, down_w) + return x_out + mlp_out + +class GPT(nn.Module): + def __init__( + self, + vocab_size: int, + num_layers: int, + model_dim: int, + num_heads: int, + num_kv_heads: int, + mlp_mult: int, + tie_embeddings: bool, + tied_embed_init_std: float, + logit_softcap: float, + rope_base: float, + qk_gain_init: float, + bigram_vocab_size: int = 0, + bigram_dim: int = 128, + xsa_last_n: int = 0, + rope_dims: int = 0, + ln_scale: bool = False, + ve_enabled: bool = False, + ve_dim: int = 128, + ve_layers: str = "9,10", + neg_slope: float = 0.5, + ): + super().__init__() + self._ve_target_dim = num_kv_heads * (model_dim // num_heads) # kv_dim for value projection + if logit_softcap <= 0.0: + raise ValueError(f"logit_softcap must be positive, got {logit_softcap}") + self.tie_embeddings = tie_embeddings + self.tied_embed_init_std = tied_embed_init_std + self.logit_softcap = logit_softcap + self.tok_emb = nn.Embedding(vocab_size, model_dim) + self.bigram = BigramHashEmbedding(bigram_vocab_size, bigram_dim, model_dim) if bigram_vocab_size > 0 else None + self.smear = SmearGate(model_dim) + self.num_encoder_layers = num_layers // 2 + self.num_decoder_layers = num_layers - self.num_encoder_layers + self.num_skip_weights = min(self.num_encoder_layers, self.num_decoder_layers) + self.skip_weights = nn.Parameter(torch.ones(self.num_skip_weights, model_dim, dtype=torch.float32)) + # Parameter banks: 
contiguous 3D tensors for batched optimizer + head_dim = model_dim // num_heads + kv_dim = num_kv_heads * head_dim + mlp_dim = int(mlp_mult * model_dim) + self.num_layers = num_layers + self.qo_bank = nn.Parameter(torch.empty(2 * num_layers, model_dim, model_dim)) + self.kv_bank = nn.Parameter(torch.empty(2 * num_layers, kv_dim, model_dim)) + self.mlp_up_bank = nn.Parameter(torch.empty(num_layers, mlp_dim, model_dim)) + self.mlp_down_bank = nn.Parameter(torch.empty(num_layers, model_dim, mlp_dim)) + self.blocks = nn.ModuleList( + [ + Block( + model_dim, + num_heads, + num_kv_heads, + mlp_mult, + rope_base, + qk_gain_init, + layer_idx=i, + ln_scale=ln_scale, + neg_slope=neg_slope, + ) + for i in range(num_layers) + ] + ) + if rope_dims > 0: + head_dim = model_dim // num_heads + for block in self.blocks: + block.attn.rope_dims = rope_dims + block.attn.rotary = Rotary(head_dim, base=rope_base, train_seq_len=1024, rope_dims=rope_dims) + self.ve_layer_indices = [int(x) for x in ve_layers.split(",") if x.strip()] if ve_enabled else [] + kv_dim_ve = self._ve_target_dim + if self.ve_layer_indices: + self.ve_shared = ValueEmbedding(vocab_size, ve_dim, kv_dim_ve) + self.ve_layer_scales = nn.ParameterList( + [nn.Parameter(torch.ones(1, dtype=torch.float32)) for _ in self.ve_layer_indices] + ) + else: + self.ve_shared = None + self.ve_layer_scales = nn.ParameterList() + self.value_embeds = nn.ModuleList() # keep empty for compat + self.final_norm = RMSNorm() + self.lm_head = None if tie_embeddings else CastedLinear(model_dim, vocab_size, bias=False) + if self.lm_head is not None: + self.lm_head._zero_init = True + if xsa_last_n > 0: + for i in range(max(0, num_layers - xsa_last_n), num_layers): + self.blocks[i].attn.use_xsa = True + self._init_weights() + def _init_weights(self) -> None: + if self.tie_embeddings: + nn.init.normal_(self.tok_emb.weight, mean=0.0, std=self.tied_embed_init_std) + n = self.num_layers + proj_scale = 1.0 / math.sqrt(2 * n) + # Init banks: orthogonal, 
with proj layers scaled down and out/down zero-init + for i in range(n): + nn.init.orthogonal_(self.qo_bank.data[i], gain=1.0) # Q + nn.init.zeros_(self.qo_bank.data[n + i]) # Out (zero init) + nn.init.orthogonal_(self.kv_bank.data[i], gain=1.0) # K + nn.init.orthogonal_(self.kv_bank.data[n + i], gain=1.0) # V + nn.init.orthogonal_(self.mlp_up_bank.data[i], gain=1.0) # MLP up + nn.init.zeros_(self.mlp_down_bank.data[i]) # MLP down (zero init) + # Scale proj layers (out_proj and mlp_down are "proj" layers) + self.qo_bank.data[n + i].mul_(proj_scale) + self.mlp_down_bank.data[i].mul_(proj_scale) + # Init remaining nn.Linear modules (bigram proj, lm_head) + for name, module in self.named_modules(): + if isinstance(module, nn.Linear): + if getattr(module, "_zero_init", False): + nn.init.zeros_(module.weight) + elif module.weight.ndim == 2 and module.weight.shape[0] >= 64 and module.weight.shape[1] >= 64: + nn.init.orthogonal_(module.weight, gain=1.0) + def _get_ve(self, layer_idx: int, input_ids: Tensor, ve_cache: dict | None = None) -> Tensor | None: + """Get value embedding for a specific layer using shared table + per-layer scale.""" + if self.ve_shared is None or layer_idx not in self.ve_layer_indices: + return None + if ve_cache is not None and 've' not in ve_cache: + ve_cache['ve'] = self.ve_shared(input_ids) + ve_base = ve_cache['ve'] if ve_cache is not None else self.ve_shared(input_ids) + ve_idx = self.ve_layer_indices.index(layer_idx) + return ve_base * self.ve_layer_scales[ve_idx].to(dtype=ve_base.dtype) + def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor: + n = self.num_layers + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, + self.qo_bank[i], self.kv_bank[i], 
self.kv_bank[n + i], + self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], + v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, + self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], + self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], + v_embed=ve) + x = self.final_norm(x) + x_flat = x.reshape(-1, x.size(-1)) + targets = target_ids.reshape(-1) + if self.tie_embeddings: + logits_proj = F.linear(x_flat, self.tok_emb.weight) + else: + if self.lm_head is None: + raise RuntimeError("lm_head is required when tie_embeddings=False") + logits_proj = self.lm_head(x_flat) + logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + return F.cross_entropy(logits.float(), targets, reduction="mean") + def forward_logits(self, input_ids: Tensor) -> Tensor: + """Return logits (bsz, seq_len, vocab) without computing loss.""" + n = self.num_layers + x = self.tok_emb(input_ids) + if self.bigram is not None: + x = x + self.bigram(input_ids) + x = F.rms_norm(x, (x.size(-1),)) + x = self.smear(x) + x0 = x + skips: list[Tensor] = [] + ve_cache: dict = {} + for i in range(self.num_encoder_layers): + ve = self._get_ve(i, input_ids, ve_cache) + x = self.blocks[i](x, x0, + self.qo_bank[i], self.kv_bank[i], self.kv_bank[n + i], + self.qo_bank[n + i], self.mlp_up_bank[i], self.mlp_down_bank[i], + v_embed=ve) + skips.append(x) + for i in range(self.num_decoder_layers): + bi = self.num_encoder_layers + i + if skips: + x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop() + ve = self._get_ve(bi, input_ids, ve_cache) + x = self.blocks[bi](x, x0, + self.qo_bank[bi], self.kv_bank[bi], self.kv_bank[n + bi], + self.qo_bank[n + bi], self.mlp_up_bank[bi], self.mlp_down_bank[bi], + v_embed=ve) + x = 
self.final_norm(x) + if self.tie_embeddings: + logits_proj = F.linear(x, self.tok_emb.weight) + else: + logits_proj = self.lm_head(x) + return self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap) + +# --- Sliding window evaluation --- + +def eval_val_sliding( + args: Hyperparameters, + base_model: nn.Module, + rank: int, + world_size: int, + device: torch.device, + val_tokens: Tensor, + base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, + is_boundary_token_lut: Tensor, + stride: int, + batch_seqs: int = 32, + eval_seq_len: int | None = None, +) -> tuple[float, float]: + """Sliding window evaluation: each token scored with maximum context.""" + seq_len = eval_seq_len or args.train_seq_len + prime_rotary_caches(base_model, device, seq_len) + total_tokens = val_tokens.numel() - 1 + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= 1] + total_windows = len(window_starts) + my_s = (total_windows * rank) // world_size + my_e = (total_windows * (rank + 1)) // world_size + my_windows = window_starts[my_s:my_e] + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + base_model.eval() + compiled_logits = compile_with_env(base_model.forward_logits, is_eval=True) + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk[:-1] + y_batch[i, :wlen] = chunk[1:] + with torch.autocast(device_type="cuda", 
dtype=torch.bfloat16): + cudagraph_step_begin() + logits = compiled_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), + reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt = y_batch[i, s:wlen] + prev = x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + val_loss = (loss_sum / token_count).item() + bits_per_token = val_loss / math.log(2.0) + tokens_per_byte = token_count.item() / byte_count.item() + base_model.train() + return val_loss, bits_per_token * tokens_per_byte + + +def eval_val_sliding_ttt( + args: Hyperparameters, base_model: nn.Module, rank: int, world_size: int, + device: torch.device, val_tokens: Tensor, base_bytes_lut: Tensor, + has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor, + stride: int, batch_seqs: int = 32, log0=print, +) -> tuple[float, float]: + """Legal score-first TTT (PR #461 recipe): score each chunk with sliding windows, + then train on it. 
Every token scored BEFORE any update that could use it.""" + seq_len = args.train_seq_len + prime_rotary_caches(base_model, device, seq_len) + total_tokens = val_tokens.numel() - 1 + ttt_chunk = args.ttt_chunk_tokens + + # Pre-compute all window starts + window_starts = [ws for ws in range(0, total_tokens, stride) + if min(ws + seq_len, total_tokens) - ws >= stride or ws == 0] + + # Assign each window to a chunk based on the first token it scores + num_chunks = (total_tokens + ttt_chunk - 1) // ttt_chunk + chunk_windows: list[list[int]] = [[] for _ in range(num_chunks)] + for ws in window_starts: + end = min(ws + seq_len, total_tokens) + wlen = end - ws + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_start = ws + s + ci = min(scored_start // ttt_chunk, num_chunks - 1) + chunk_windows[ci].append(ws) + + log0(f"ttt_sliding:start chunks={num_chunks} chunk_tokens={ttt_chunk} " + f"total_windows={len(window_starts)} stride={stride} " + f"ttt_lr={args.ttt_lr} ttt_epochs={args.ttt_epochs} " + f"freeze_blocks={args.ttt_freeze_blocks}") + + loss_sum = torch.zeros((), device=device, dtype=torch.float64) + token_count = torch.zeros((), device=device, dtype=torch.float64) + byte_count = torch.zeros((), device=device, dtype=torch.float64) + + # Freeze first N blocks + frozen_block_ids = set(range(min(args.ttt_freeze_blocks, len(base_model.blocks)))) + ttt_params = [] + for name, p in base_model.named_parameters(): + freeze = False + for bi in frozen_block_ids: + if f"blocks.{bi}." 
in name: + freeze = True + break + if freeze: + p.requires_grad_(False) + else: + p.requires_grad_(True) + ttt_params.append(p) + + log0(f"ttt_sliding:params unfrozen={sum(p.numel() for p in ttt_params)} " + f"frozen={sum(p.numel() for p in base_model.parameters() if not p.requires_grad)}") + + optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum) + t0 = time.perf_counter() + + for ci in range(num_chunks): + windows = chunk_windows[ci] + if not windows: + continue + chunk_start = ci * ttt_chunk + chunk_end = min((ci + 1) * ttt_chunk, total_tokens) + + # --- Phase 1: SCORE this chunk's windows (inference_mode) --- + my_s = (len(windows) * rank) // world_size + my_e = (len(windows) * (rank + 1)) // world_size + my_windows = windows[my_s:my_e] + + base_model.eval() + with torch.inference_mode(): + for bi in range(0, len(my_windows), batch_seqs): + batch_ws = my_windows[bi:bi + batch_seqs] + bsz = len(batch_ws) + x_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + y_batch = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device) + wlens: list[int] = [] + for i, ws in enumerate(batch_ws): + end = min(ws + seq_len, total_tokens) + wlen = end - ws + wlens.append(wlen) + chunk_tok = val_tokens[ws:end + 1].to(dtype=torch.int64, device=device) + x_batch[i, :wlen] = chunk_tok[:-1] + y_batch[i, :wlen] = chunk_tok[1:] + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + logits = base_model.forward_logits(x_batch) + nll = F.cross_entropy( + logits.reshape(-1, logits.size(-1)).float(), + y_batch.reshape(-1), reduction="none", + ).reshape(bsz, seq_len) + for i, ws in enumerate(batch_ws): + wlen = wlens[i] + s = 0 if ws == 0 else max(wlen - stride, 0) + scored_nll = nll[i, s:wlen].to(torch.float64) + loss_sum += scored_nll.sum() + token_count += float(wlen - s) + tgt, prev = y_batch[i, s:wlen], x_batch[i, s:wlen] + tb = base_bytes_lut[tgt].to(torch.float64) + tb += (has_leading_space_lut[tgt] & 
~is_boundary_token_lut[prev]).to(torch.float64) + byte_count += tb.sum() + + # --- Phase 2: TRAIN on this chunk (already scored = legal) --- + is_last_chunk = (ci == num_chunks - 1) + if not is_last_chunk and args.ttt_epochs > 0: + base_model.train() + chunk_seqs = (chunk_end - chunk_start) // seq_len + if chunk_seqs > 0: + cos_lr = args.ttt_lr * 0.5 * (1.0 + math.cos(math.pi * ci / max(num_chunks - 1, 1))) + for pg in optimizer.param_groups: + pg['lr'] = cos_lr + my_seq_s = (chunk_seqs * rank) // world_size + my_seq_e = (chunk_seqs * (rank + 1)) // world_size + my_chunk_seqs = my_seq_e - my_seq_s + for _ep in range(args.ttt_epochs): + for bs in range(0, my_chunk_seqs, args.ttt_batch_seqs): + be = min(bs + args.ttt_batch_seqs, my_chunk_seqs) + actual_bs = my_seq_s + bs + start_tok = chunk_start + actual_bs * seq_len + end_tok = chunk_start + (my_seq_s + be) * seq_len + 1 + if end_tok > val_tokens.numel(): + continue + local = val_tokens[start_tok:end_tok].to(device=device, dtype=torch.int64) + x = local[:-1].reshape(-1, seq_len) + y = local[1:].reshape(-1, seq_len) + optimizer.zero_grad(set_to_none=True) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + loss = base_model(x, y) + loss.backward() + if world_size > 1: + for p in ttt_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + torch.nn.utils.clip_grad_norm_(ttt_params, args.ttt_grad_clip) + optimizer.step() + + if rank == 0 and (ci % 10 == 0 or ci == num_chunks - 1): + elapsed = time.perf_counter() - t0 + rl = loss_sum.item() / max(token_count.item(), 1) + rbpb = rl / math.log(2.0) * (token_count.item() / max(byte_count.item(), 1)) if token_count.item() > 0 else 0.0 + log0(f" ttt_chunk [{ci+1}/{num_chunks}] bpb={rbpb:.6f} time={elapsed:.1f}s") + + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM) + dist.all_reduce(token_count, op=dist.ReduceOp.SUM) + dist.all_reduce(byte_count, op=dist.ReduceOp.SUM) + + val_loss = 
(loss_sum / token_count).item() + val_bpb = val_loss / math.log(2.0) * (token_count.item() / byte_count.item()) + + for p in base_model.parameters(): + p.requires_grad_(True) + base_model.eval() + + log0(f"ttt_sliding:done val_loss={val_loss:.6f} val_bpb={val_bpb:.6f} " + f"elapsed={time.perf_counter() - t0:.1f}s") + return val_loss, val_bpb + + +# --- GPTQ-lite int6 quantization --- + +def _classify_param(name: str) -> str: + if "tok_emb" in name or "lm_head" in name: + return "embed" + if ".mlp." in name: + return "mlp" + if ".attn." in name or (".proj." in name and ".mlp." not in name): + return "attn" + return "other" +def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tensor]: + t32 = t.float() + if t32.ndim == 2: + best_q, best_s, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(t32.abs(), pct, dim=1) + else: + row_clip = t32.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + q = torch.clamp(torch.round(t32 / s.float()[:, None]), -clip_range, clip_range).to(torch.int8) + recon = q.float() * s.float()[:, None] + err = (t32 - recon).pow(2).mean().item() + if err < best_err: + best_q, best_s, best_err = q, s, err + return best_q, best_s + amax = t32.abs().max().item() + scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16) + q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8) + return q, scale + +def _unbank_state_dict(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]: + """Convert 3D bank tensors into individual 2D tensors with standard names.""" + out: dict[str, Tensor] = {} + n = num_layers + for name, tensor in sd.items(): + if name == "qo_bank": + for i in range(n): + out[f"blocks.{i}.attn.c_q.weight"] = tensor[i] + out[f"blocks.{i}.attn.proj.weight"] = tensor[n + i] + elif name == "kv_bank": + for i in range(n): + 
out[f"blocks.{i}.attn.c_k.weight"] = tensor[i] + out[f"blocks.{i}.attn.c_v.weight"] = tensor[n + i] + elif name == "mlp_up_bank": + for i in range(n): + out[f"blocks.{i}.mlp.fc.weight"] = tensor[i] + elif name == "mlp_down_bank": + for i in range(n): + out[f"blocks.{i}.mlp.proj.weight"] = tensor[i] + else: + out[name] = tensor + return out + +def _rebank_state_dict(sd: dict[str, Tensor], num_layers: int, template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + """Convert individual 2D tensors back into 3D bank tensors.""" + out: dict[str, Tensor] = {} + n = num_layers + # Reconstruct banks from individual weight keys + qo_slices = [None] * (2 * n) + kv_slices = [None] * (2 * n) + up_slices = [None] * n + down_slices = [None] * n + consumed = set() + for i in range(n): + qk = f"blocks.{i}.attn.c_q.weight" + if qk in sd: + qo_slices[i] = sd[qk] + consumed.add(qk) + ok = f"blocks.{i}.attn.proj.weight" + if ok in sd: + qo_slices[n + i] = sd[ok] + consumed.add(ok) + kk = f"blocks.{i}.attn.c_k.weight" + if kk in sd: + kv_slices[i] = sd[kk] + consumed.add(kk) + vk = f"blocks.{i}.attn.c_v.weight" + if vk in sd: + kv_slices[n + i] = sd[vk] + consumed.add(vk) + fk = f"blocks.{i}.mlp.fc.weight" + if fk in sd: + up_slices[i] = sd[fk] + consumed.add(fk) + dk = f"blocks.{i}.mlp.proj.weight" + if dk in sd: + down_slices[i] = sd[dk] + consumed.add(dk) + out["qo_bank"] = torch.stack(qo_slices).to(dtype=template_sd["qo_bank"].dtype) + out["kv_bank"] = torch.stack(kv_slices).to(dtype=template_sd["kv_bank"].dtype) + out["mlp_up_bank"] = torch.stack(up_slices).to(dtype=template_sd["mlp_up_bank"].dtype) + out["mlp_down_bank"] = torch.stack(down_slices).to(dtype=template_sd["mlp_down_bank"].dtype) + for name, tensor in sd.items(): + if name not in consumed: + out[name] = tensor + return out + +def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str], clip_range: int = 31, + hessians: dict[str, Tensor] | None = None): + num_layers_total = max( + (int(k.split(".")[1]) for 
k in state_dict if k.startswith("blocks.")), + default=0, + ) + 1 + late_k_layers = set(range(num_layers_total - 2, num_layers_total)) + result: dict[str, Tensor] = {} + meta: dict[str, object] = {} + gptq_count, naive_count = 0, 0 + for name, tensor in state_dict.items(): + t = tensor.detach().cpu().contiguous() + cat = _classify_param(name) + if not t.is_floating_point() or t.numel() <= 65536: + result[name] = t.to(torch.float16) if t.is_floating_point() else t + meta[name] = "passthrough" + continue + if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): + result[name] = t.float() + meta[name] = "passthrough_ctrl" + continue + if cat in int6_cats and t.ndim >= 1: + H = hessians.get(name) if hessians else None + if H is not None and t.ndim == 2: + q, s = gptq_quantize_weight(t, H.cpu(), clip_range=clip_range) + gptq_count += 1 + else: + q, s = quantize_int6_per_row(t, clip_range=clip_range) + naive_count += 1 + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int6"} + else: + q, s = quantize_float_tensor(t) + result[name + ".q"] = q + result[name + ".scale"] = s + meta[name] = {"type": "int8"} + if hessians: + print(f"gptq_quantize: {gptq_count} GPTQ layers, {naive_count} naive layers", flush=True) + return result, meta +def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object], + template_sd: dict[str, Tensor]) -> dict[str, Tensor]: + out: dict[str, Tensor] = {} + for name, orig in template_sd.items(): + info = meta.get(name) + if info is None: + continue + orig_dtype = orig.dtype + if info in ("passthrough", "passthrough_ctrl", "passthrough_fp16"): + t = result[name] + if t.dtype == torch.float16 and orig_dtype in (torch.float32, torch.bfloat16): + t = t.to(orig_dtype) + out[name] = t + continue + q, s = result[name + ".q"], result[name + ".scale"] + if s.ndim > 0: + out[name] = (q.float() * s.float().view(q.shape[0], *([1] * (q.ndim - 1)))).to(orig_dtype) + else: + out[name] = (q.float() * 
float(s.item())).to(orig_dtype) + return out + +# --- Full Hessian GPTQ --- + +def gptq_quantize_weight(W: Tensor, H: Tensor, clip_range: int = 31, + block_size: int = 128, percdamp: float = 0.01) -> tuple[Tensor, Tensor]: + """GPTQ with Cholesky error compensation and actorder (Frantar et al., ICLR 2023).""" + W_orig = W.float().clone() + rows, cols = W_orig.shape + H = H.float().clone() + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + damp = percdamp * H.diag().mean() + H.diagonal().add_(damp) + perm = torch.argsort(H.diag(), descending=True) + invperm = torch.argsort(perm) + W_perm = W_orig[:, perm].clone() + W_perm[:, dead[perm]] = 0 + H = H[perm][:, perm] + try: + Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H)) + Hinv = torch.linalg.cholesky(Hinv, upper=True) + except torch.linalg.LinAlgError: + return quantize_int6_per_row(W_orig, clip_range) + best_q, best_scale, best_err = None, None, float('inf') + for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]: + if pct < 1.0: + row_clip = torch.quantile(W_orig.abs(), pct, dim=1) + else: + row_clip = W_orig.abs().amax(dim=1) + s = (row_clip / clip_range).clamp_min(1.0 / clip_range).to(torch.float16) + sf = s.float() + Q = torch.zeros(rows, cols, dtype=torch.int8) + W_work = W_perm.clone() + for i1 in range(0, cols, block_size): + i2 = min(i1 + block_size, cols) + W_block = W_work[:, i1:i2].clone() + Hinv_block = Hinv[i1:i2, i1:i2] + Err = torch.zeros(rows, i2 - i1) + for j in range(i2 - i1): + w_col = W_block[:, j] + d = Hinv_block[j, j] + q_col = torch.clamp(torch.round(w_col / sf), -clip_range, clip_range) + Q[:, i1 + j] = q_col.to(torch.int8) + err = (w_col - q_col.float() * sf) / d + Err[:, j] = err + W_block[:, j:] -= err.unsqueeze(1) * Hinv_block[j, j:].unsqueeze(0) + if i2 < cols: + W_work[:, i2:] -= Err @ Hinv[i1:i2, i2:] + recon = Q.float() * sf[:, None] + mse = (W_perm - recon).pow(2).mean().item() + if mse < best_err: + best_q, best_scale, best_err = Q, s, mse + best_q = best_q[:, invperm] + return 
best_q, best_scale + +def _init_hessians(nl: int, dim: int, mlp_dim: int, device: torch.device) -> dict[str, Tensor]: + h: dict[str, Tensor] = {} + for i in range(nl): + for k in ['c_q', 'c_k', 'c_v']: + h[f'blocks.{i}.attn.{k}.weight'] = torch.zeros(dim, dim, dtype=torch.float32, device=device) + h[f'blocks.{i}.attn.proj.weight'] = torch.zeros(dim, dim, dtype=torch.float32, device=device) + h[f'blocks.{i}.mlp.fc.weight'] = torch.zeros(dim, dim, dtype=torch.float32, device=device) + h[f'blocks.{i}.mlp.proj.weight'] = torch.zeros(mlp_dim, mlp_dim, dtype=torch.float32, device=device) + return h + +def _accum_hessians(hessians: dict[str, Tensor], blocks: nn.ModuleList, dim: int, mlp_dim: int) -> None: + for i, block in enumerate(blocks): + qkv_in = block.attn._gptq_qkv_in.float().reshape(-1, dim) + h_qkv = qkv_in.t() @ qkv_in + hessians[f'blocks.{i}.attn.c_q.weight'] += h_qkv + hessians[f'blocks.{i}.attn.c_k.weight'] += h_qkv + hessians[f'blocks.{i}.attn.c_v.weight'] += h_qkv + o_in = block.attn._gptq_o_in.float().reshape(-1, dim) + hessians[f'blocks.{i}.attn.proj.weight'] += o_in.t() @ o_in + up_in = block.mlp._gptq_up_in.float().reshape(-1, dim) + hessians[f'blocks.{i}.mlp.fc.weight'] += up_in.t() @ up_in + down_in = block.mlp._gptq_down_in.float().reshape(-1, mlp_dim) + hessians[f'blocks.{i}.mlp.proj.weight'] += down_in.t() @ down_in + +def _finalize_hessians(hessians: dict[str, Tensor], num_batches: int) -> None: + for name in hessians: + hessians[name] = hessians[name].cpu() / num_batches + damp = 0.01 * torch.diag(hessians[name]).mean().clamp_min(1e-6) + hessians[name] += damp * torch.eye(hessians[name].shape[0]) + +def gptq_collect_hessians(base_model: nn.Module, train_loader, device: torch.device, + num_batches: int, batch_tokens: int, seq_len: int, + grad_accum_steps: int) -> dict[str, Tensor]: + """Collect Hessians H = X^T X from training data.""" + nl = base_model.num_layers + dim = base_model.tok_emb.weight.shape[1] + mlp_dim = 
base_model.mlp_up_bank.shape[1] + hessians = _init_hessians(nl, dim, mlp_dim, device) + for block in base_model.blocks: + block.attn._save_gptq = True + block.mlp._save_gptq = True + base_model.eval() + with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.bfloat16): + for _ in range(num_batches): + x, y = train_loader.next_batch(batch_tokens, seq_len, grad_accum_steps) + base_model(x, y) + _accum_hessians(hessians, base_model.blocks, dim, mlp_dim) + for block in base_model.blocks: + block.attn._save_gptq = False + block.mlp._save_gptq = False + _finalize_hessians(hessians, num_batches) + base_model.train() + return hessians + +# --- Training --- + +def main() -> None: + code = Path(__file__).read_text(encoding="utf-8") + args = Hyperparameters() + # zeropower_via_newtonschulz5 runs eagerly with bmm -- do NOT compile + distributed = "RANK" in os.environ and "WORLD_SIZE" in os.environ + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if world_size <= 0: + raise ValueError(f"WORLD_SIZE must be positive, got {world_size}") + if 8 % world_size != 0: + raise ValueError(f"WORLD_SIZE={world_size} must divide 8 so grad_accum_steps stays integral") + grad_accum_steps = 8 // world_size + grad_scale = 1.0 / grad_accum_steps + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required") + device = torch.device("cuda", local_rank) + torch.cuda.set_device(device) + if distributed: + dist.init_process_group(backend="nccl", device_id=device) + dist.barrier() + master_process = rank == 0 + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + from torch.backends.cuda import enable_cudnn_sdp, enable_flash_sdp, enable_math_sdp, enable_mem_efficient_sdp + enable_cudnn_sdp(False) + enable_flash_sdp(True) + enable_mem_efficient_sdp(False) + enable_math_sdp(False) + logfile = None + if master_process: + 
os.makedirs("logs", exist_ok=True) + logfile = f"logs/{args.run_id}.txt" + print(logfile) + def log0(msg: str, console: bool = True) -> None: + if not master_process: + return + if console: + print(msg) + if logfile is not None: + with open(logfile, "a", encoding="utf-8") as f: + print(msg, file=f) + log0(code, console=False) + log0("=" * 100, console=False) + log0(f"Running Python {sys.version}", console=False) + log0(f"Running PyTorch {torch.__version__}", console=False) + log0( + subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=False).stdout, + console=False, + ) + log0("=" * 100, console=False) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + if not args.tokenizer_path.endswith(".model"): + raise ValueError(f"Script only setup for SentencePiece .model file: {args.tokenizer_path}") + sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path) + if int(sp.vocab_size()) != args.vocab_size: + raise ValueError( + f"VOCAB_SIZE={args.vocab_size} does not match tokenizer vocab_size={int(sp.vocab_size())}" + ) + dataset_dir = Path(args.data_path).resolve() + actual_train_files = len(list(dataset_dir.glob("fineweb_train_*.bin"))) + effective_eval_seq_len = args.eval_seq_len if args.eval_seq_len > 0 else args.train_seq_len + val_seq_len = max(args.train_seq_len, effective_eval_seq_len) + val_tokens = load_validation_tokens(args.val_files, val_seq_len) + base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts( + sp, args.vocab_size, device + ) + log0(f"val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path={args.tokenizer_path}") + log0(f"train_loader:dataset:{dataset_dir.name} train_shards:{actual_train_files}") + log0(f"val_loader:shards pattern={args.val_files} tokens:{val_tokens.numel() - 1}") + base_model = GPT( + vocab_size=args.vocab_size, + num_layers=args.num_layers, + model_dim=args.model_dim, + 
num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads, + mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, + tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, + rope_base=args.rope_base, + qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, + bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, + ln_scale=args.ln_scale, + ve_enabled=args.ve_enabled, + ve_dim=args.ve_dim, + ve_layers=args.ve_layers, + neg_slope=args.negative_slope, + ).to(device).bfloat16() + # Banks stay FP32 (like CastedLinear weights), cast to BF16 in forward + base_model.qo_bank.data = base_model.qo_bank.data.float() + base_model.kv_bank.data = base_model.kv_bank.data.float() + base_model.mlp_up_bank.data = base_model.mlp_up_bank.data.float() + base_model.mlp_down_bank.data = base_model.mlp_down_bank.data.float() + for module in base_model.modules(): + if isinstance(module, CastedLinear): + module.float() + restore_low_dim_params_to_fp32(base_model) + prime_rotary_caches(base_model, device, args.train_seq_len, effective_eval_seq_len) + # No DDP -- Parallel Muon handles bank grad communication via reduce-scatter, + # and non-bank grads are manually all-reduced before Adam steps. 
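+    # The FP32-master / BF16-compute pattern used above (banks and CastedLinear
+    # weights stay FP32; forward casts to BF16) can be sketched in miniature.
+    # Note: `CastedLinearSketch` below is an illustrative stand-in, not the
+    # submission's actual CastedLinear class.

```python
# Minimal sketch of the mixed-precision pattern: master weights are kept in
# FP32 for optimizer/accumulation stability, and cast to the activation dtype
# only at use time so the matmul itself runs in BF16.
# (Hedged example: class name and shapes are illustrative, not from the repo.)
import torch
import torch.nn as nn
import torch.nn.functional as F

class CastedLinearSketch(nn.Linear):
    def __init__(self, in_features: int, out_features: int):
        super().__init__(in_features, out_features, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Weight storage stays FP32; cast per-call to match the input dtype.
        return F.linear(x, self.weight.to(x.dtype))

layer = CastedLinearSketch(8, 4).float()      # master weights in FP32
x = torch.randn(2, 8, dtype=torch.bfloat16)   # BF16 activations
y = layer(x)
assert layer.weight.dtype == torch.float32 and y.dtype == torch.bfloat16
```

+    # The optimizer then updates the FP32 master copy directly, avoiding the
+    # BF16 rounding that would otherwise swallow small gradient updates.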
+ compiled_model = compile_with_env(base_model) + model = compiled_model + + # Optimizer split: + # - 4 parameter banks -> Muon (batched Newton-Schulz) + # - token embedding -> Adam + # - scalars/control tensors -> Adam + # - bigram proj, VE proj -> Adam (small matrix params not worth banking) + matrix_params = [ + base_model.qo_bank, base_model.kv_bank, + base_model.mlp_up_bank, base_model.mlp_down_bank, + ] + block_named_params = list(base_model.blocks.named_parameters()) + scalar_params = [ + p + for name, p in block_named_params + if p.ndim < 2 or any(pattern in name for pattern in CONTROL_TENSOR_NAME_PATTERNS) + ] + if base_model.skip_weights.numel() > 0: + scalar_params.append(base_model.skip_weights) + scalar_params.append(base_model.smear.gate) + if base_model.bigram is not None: + scalar_params.append(base_model.bigram.scale) + token_lr = args.tied_embed_lr if args.tie_embeddings else args.embed_lr + tok_params = [{"params": [base_model.tok_emb.weight], "lr": token_lr, "base_lr": token_lr}] + if base_model.bigram is not None: + tok_params.append({"params": [base_model.bigram.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.bigram.proj is not None: + scalar_params.append(base_model.bigram.proj.weight) + if base_model.ve_shared is not None: + tok_params.append({"params": [base_model.ve_shared.embed.weight], "lr": token_lr, "base_lr": token_lr}) + if base_model.ve_shared.proj is not None: + scalar_params.append(base_model.ve_shared.proj.weight) + scalar_params.append(base_model.ve_shared.scale) + for s in base_model.ve_layer_scales: + scalar_params.append(s) + optimizer_tok = torch.optim.AdamW( + tok_params, + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + optimizer_muon = Muon( + matrix_params, + lr=args.matrix_lr, + momentum=args.muon_momentum, + backend_steps=args.muon_backend_steps, + weight_decay=args.muon_wd, + ) + for group in optimizer_muon.param_groups: + group["base_lr"] = 
args.matrix_lr + optimizer_scalar = torch.optim.AdamW( + [{"params": scalar_params, "lr": args.scalar_lr, "base_lr": args.scalar_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + weight_decay=args.adam_wd, + fused=True, + ) + # Non-bank params that need manual all-reduce (replicated across GPUs) + replicated_params = list(optimizer_tok.param_groups[0]["params"]) + for pg in optimizer_tok.param_groups[1:]: + replicated_params.extend(pg["params"]) + replicated_params.extend(scalar_params) + + optimizer_head = None + if base_model.lm_head is not None: + optimizer_head = torch.optim.Adam( + [{"params": [base_model.lm_head.weight], "lr": args.head_lr, "base_lr": args.head_lr}], + betas=(args.beta1, args.beta2), + eps=args.adam_eps, + fused=True, + ) + replicated_params.append(base_model.lm_head.weight) + optimizers: list[torch.optim.Optimizer] = [optimizer_tok, optimizer_muon, optimizer_scalar] + if optimizer_head is not None: + optimizers.append(optimizer_head) + log0(f"model_params:{sum(p.numel() for p in base_model.parameters())}") + xsa_layers = [i for i, b in enumerate(base_model.blocks) if b.attn.use_xsa] + log0(f"XSA:last_{args.xsa_last_n} active_layers:{xsa_layers}") + log0(f"world_size:{world_size} grad_accum_steps:{grad_accum_steps}") + log0("sdp_backends:cudnn=False flash=True mem_efficient=False math=False") + log0(f"attention_mode:gqa num_heads:{args.num_heads} num_kv_heads:{args.num_kv_heads}") + log0( + f"torch_compile:disable={TORCH_COMPILE_DISABLE} backend:{TORCH_COMPILE_BACKEND} " + f"mode:{TORCH_COMPILE_MODE or 'default'} dynamic:{TORCH_COMPILE_DYNAMIC} " + f"fullgraph:{TORCH_COMPILE_FULLGRAPH}" + ) + log0( + f"torch_compile_eval:disable={TORCH_COMPILE_EVAL_DISABLE} " + f"mode:{TORCH_COMPILE_EVAL_MODE or 'inherit'}" + ) + log0( + f"tie_embeddings:{args.tie_embeddings} embed_lr:{token_lr} " + f"head_lr:{args.head_lr if base_model.lm_head is not None else 0.0} " + f"matrix_lr:{args.matrix_lr} scalar_lr:{args.scalar_lr}" + ) + log0( + 
f"swa:enabled={args.swa_enabled} every:{args.swa_every} " + f"start_step:{args.swa_start_step or 'auto'}" + ) + log0( + f"lr_schedule:warmdown_iters:{args.warmdown_iters} " + f"step_ms_ref:{args.lr_schedule_reference_step_ms or 'auto'}" + ) + log0( + f"train_batch_tokens:{args.train_batch_tokens} train_seq_len:{args.train_seq_len} " + f"iterations:{args.iterations} warmup_steps:{args.warmup_steps} " + f"max_wallclock_seconds:{args.max_wallclock_seconds:.3f}" + ) + log0(f"seed:{args.seed}") + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + def zero_grad_all() -> None: + for opt in optimizers: + opt.zero_grad(set_to_none=True) + max_wallclock_ms = 1000.0 * args.max_wallclock_seconds if args.max_wallclock_seconds > 0 else None + if args.use_gptq and max_wallclock_ms is not None: + max_wallclock_ms -= args.gptq_reserve_ms + log0(f"gptq:reserving {args.gptq_reserve_ms:.0f}ms from training budget, effective={max_wallclock_ms:.0f}ms") + def lr_mul(step: int, elapsed_ms: float) -> float: + if args.warmdown_iters <= 0: + return 1.0 + if max_wallclock_ms is None: + warmdown_start = max(args.iterations - args.warmdown_iters, 0) + return max((args.iterations - step) / max(args.warmdown_iters, 1), 0.0) if warmdown_start <= step < args.iterations else 1.0 + step_ms = elapsed_ms / max(step, 1) + if args.lr_schedule_reference_step_ms > 0: + step_ms = max(step_ms, args.lr_schedule_reference_step_ms) + warmdown_ms = args.warmdown_iters * step_ms + remaining_ms = max(max_wallclock_ms - elapsed_ms, 0.0) + return remaining_ms / max(warmdown_ms, 1e-9) if remaining_ms <= warmdown_ms else 1.0 + if args.warmup_steps > 0: + initial_model_state = {name: tensor.detach().cpu().clone() for name, tensor in base_model.state_dict().items()} + initial_optimizer_states = [copy.deepcopy(opt.state_dict()) for opt in optimizers] + model.train() + for warmup_step in range(args.warmup_steps): + zero_grad_all() + for micro_step in range(grad_accum_steps): + x, y = 
train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + cudagraph_step_begin() + warmup_loss = model(x, y) + (warmup_loss * grad_scale).backward() + # All-reduce all grads for warmup (simple, not optimized) + if distributed: + for p in base_model.parameters(): + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + for opt in optimizers: + opt.step() + zero_grad_all() + if args.warmup_steps <= 20 or (warmup_step + 1) % 10 == 0 or warmup_step + 1 == args.warmup_steps: + log0(f"warmup_step:{warmup_step + 1}/{args.warmup_steps}") + base_model.load_state_dict(initial_model_state, strict=True) + for opt, state in zip(optimizers, initial_optimizer_states, strict=True): + opt.load_state_dict(state) + zero_grad_all() + train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + swa_state: dict[str, Tensor] | None = None + swa_count = 0 + ema_state = {name: t.detach().float().clone() for name, t in base_model.state_dict().items()} + ema_decay = 0.997 + training_time_ms = 0.0 + stop_after_step: int | None = None + torch.cuda.synchronize() + t0 = time.perf_counter() + step = 0 + while True: + last_step = step == args.iterations or (stop_after_step is not None and step >= stop_after_step) + should_validate = last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0) + if should_validate: + torch.cuda.synchronize() + training_time_ms += 1000.0 * (time.perf_counter() - t0) + val_loss, val_bpb = eval_val( + args, + model, + rank, + world_size, + device, + grad_accum_steps, + val_tokens, + base_bytes_lut, + has_leading_space_lut, + is_boundary_token_lut, + ) + log0( + f"step:{step}/{args.iterations} val_loss:{val_loss:.4f} val_bpb:{val_bpb:.4f} " + f"train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms / max(step, 1):.2f}ms" + ) + torch.cuda.synchronize() + t0 = time.perf_counter() + if last_step: + if 
stop_after_step is not None and step < args.iterations: + log0( + f"stopping_early: wallclock_cap train_time:{training_time_ms:.0f}ms " + f"step:{step}/{args.iterations}" + ) + break + elapsed_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0) + scale = lr_mul(step, elapsed_ms) + zero_grad_all() + train_loss = torch.zeros((), device=device) + for micro_step in range(grad_accum_steps): + x, y = train_loader.next_batch(args.train_batch_tokens, args.train_seq_len, grad_accum_steps) + with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True): + cudagraph_step_begin() + loss = model(x, y) + train_loss += loss.detach() + (loss * grad_scale).backward() + train_loss /= grad_accum_steps + frac = min(step / args.muon_momentum_warmup_steps, 1.0) if args.muon_momentum_warmup_steps > 0 else 1.0 + muon_momentum = (1 - frac) * args.muon_momentum_warmup_start + frac * args.muon_momentum + for group in optimizer_muon.param_groups: + group["momentum"] = muon_momentum + for opt in optimizers: + for group in opt.param_groups: + group["lr"] = group["base_lr"] * scale + if args.grad_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(base_model.parameters(), args.grad_clip_norm) + # === 3-phase overlapped optimizer step === + # Phase 1: Launch async reduce-scatter for banks (biggest first) + optimizer_muon.launch_reduce_scatters() + # Phase 2: All-reduce non-bank grads + step Adam (while bank RS is in-flight) + if distributed: + for p in replicated_params: + if p.grad is not None: + dist.all_reduce(p.grad, op=dist.ReduceOp.AVG) + optimizer_tok.step() + optimizer_scalar.step() + if optimizer_head is not None: + optimizer_head.step() + # Phase 3: Wait for RS, local NS5, all-gather (banks processed last) + optimizer_muon.step() + zero_grad_all() + # EMA update + with torch.no_grad(): + for name, t in base_model.state_dict().items(): + ema_state[name].mul_(ema_decay).add_(t.detach().float(), alpha=1.0 - ema_decay) + step += 1 + approx_training_time_ms = 
training_time_ms + 1000.0 * (time.perf_counter() - t0) + swa_ready = step >= args.swa_start_step if args.swa_start_step > 0 else scale < 0.2 + if args.swa_enabled and swa_ready and step % args.swa_every == 0: + if swa_state is None: + swa_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()} + swa_count = 1 + log0(f"swa:start step:{step}") + else: + for name, t in base_model.state_dict().items(): + swa_state[name] += t.detach().cpu() + swa_count += 1 + should_log_train = ( + args.train_log_every > 0 + and (step <= 10 or step % args.train_log_every == 0 or stop_after_step is not None) + ) + if should_log_train: + log0( + f"step:{step}/{args.iterations} train_loss:{train_loss.item():.4f} " + f"train_time:{approx_training_time_ms:.0f}ms step_avg:{approx_training_time_ms / step:.2f}ms" + ) + reached_cap = max_wallclock_ms is not None and approx_training_time_ms >= max_wallclock_ms + if distributed and max_wallclock_ms is not None: + reached_cap_tensor = torch.tensor(int(reached_cap), device=device) + dist.all_reduce(reached_cap_tensor, op=dist.ReduceOp.MAX) + reached_cap = bool(reached_cap_tensor.item()) + if stop_after_step is None and reached_cap: + stop_after_step = step + log0( + f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB " + f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB" + ) + # Apply EMA weights + log0("ema:applying EMA weights") + current_state = base_model.state_dict() + avg_state = {name: t.to(dtype=current_state[name].dtype) for name, t in ema_state.items()} + base_model.load_state_dict(avg_state, strict=True) + torch.cuda.synchronize() + t_diag = time.perf_counter() + diag_val_loss, diag_val_bpb = eval_val( + args, compiled_model, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + ) + torch.cuda.synchronize() + log0( + f"DIAGNOSTIC post_ema val_loss:{diag_val_loss:.4f} val_bpb:{diag_val_bpb:.4f} " 
+ f"eval_time:{1000.0 * (time.perf_counter() - t_diag):.0f}ms" + ) + export_sd = base_model.state_dict() + if master_process: + torch.save(export_sd, "final_model.pt") + model_bytes = os.path.getsize("final_model.pt") + code_bytes = submission_code_bytes() + log0(f"Serialized model: {model_bytes} bytes") + log0(f"Code size: {code_bytes} bytes") + # Unbank 3D tensors into individual 2D tensors for quantization + sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()} + unbanked_sd = _unbank_state_dict(sd_cpu, args.num_layers) + # GPTQ calibration: collect Hessians from training data + gptq_hessians = None + if args.use_gptq: + t_gptq = time.perf_counter() + log0(f"gptq:calibrating with {args.gptq_calib_samples} batches (training data)...") + calib_loader = DistributedTokenLoader(args.train_files, rank, world_size, device) + gptq_hessians = gptq_collect_hessians( + base_model, calib_loader, device, num_batches=args.gptq_calib_samples, + batch_tokens=args.train_batch_tokens, seq_len=args.train_seq_len, + grad_accum_steps=grad_accum_steps) + del calib_loader + gptq_elapsed = time.perf_counter() - t_gptq + log0(f"gptq:calibrated {len(gptq_hessians)} layers in {gptq_elapsed:.1f}s") + torch.cuda.empty_cache() + quant_result, quant_meta = mixed_quantize_int6(unbanked_sd, {"mlp", "attn"}, clip_range=args.quant_clip_range, hessians=gptq_hessians) + quant_buf = io.BytesIO() + torch.save({"w": quant_result, "m": quant_meta}, quant_buf) + quant_raw = quant_buf.getvalue() + quant_blob = lzma.compress(quant_raw, preset=6) + if master_process: + with open("final_model.int6.ptz", "wb") as f: + f.write(quant_blob) + quant_file_bytes = len(quant_blob) + code_bytes = submission_code_bytes() + log0(f"Serialized model int6+lzma: {quant_file_bytes} bytes") + log0(f"Total submission size int6+lzma: {quant_file_bytes + code_bytes} bytes") + if distributed: + dist.barrier() + with open("final_model.int6.ptz", "rb") as f: + quant_blob_disk = f.read() + quant_state = torch.load( + 
io.BytesIO(lzma.decompress(quant_blob_disk)), + map_location="cpu", + ) + deq_unbanked = dequantize_mixed_int6(quant_state["w"], quant_state["m"], unbanked_sd) + # Re-bank the dequantized tensors + deq_state = _rebank_state_dict(deq_unbanked, args.num_layers, sd_cpu) + eval_model = GPT( + vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim, + num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult, + tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std, + logit_softcap=args.logit_softcap, rope_base=args.rope_base, qk_gain_init=args.qk_gain_init, + bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim, + xsa_last_n=args.xsa_last_n, + rope_dims=args.rope_dims, ln_scale=args.ln_scale, + ve_enabled=args.ve_enabled, ve_dim=args.ve_dim, ve_layers=args.ve_layers, + neg_slope=args.negative_slope, + ).to(device).bfloat16() + eval_model.qo_bank.data = eval_model.qo_bank.data.float() + eval_model.kv_bank.data = eval_model.kv_bank.data.float() + eval_model.mlp_up_bank.data = eval_model.mlp_up_bank.data.float() + eval_model.mlp_down_bank.data = eval_model.mlp_down_bank.data.float() + for m in eval_model.modules(): + if isinstance(m, CastedLinear): + m.float() + restore_low_dim_params_to_fp32(eval_model) + eval_model.load_state_dict(deq_state, strict=True) + prime_rotary_caches(eval_model, device, effective_eval_seq_len) + compiled_eval = compile_with_env(eval_model, is_eval=True) + torch.cuda.synchronize() + t_qeval = time.perf_counter() + q_val_loss, q_val_bpb = eval_val( + args, compiled_eval, rank, world_size, device, grad_accum_steps, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + eval_seq_len=effective_eval_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_roundtrip val_loss:{q_val_loss:.4f} val_bpb:{q_val_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms" + ) + log0(f"final_int6_roundtrip_exact 
val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}") + sw_seq_len = effective_eval_seq_len + if args.eval_stride > 0 and args.eval_stride < sw_seq_len: + torch.cuda.synchronize() + t_slide = time.perf_counter() + sw_val_loss, sw_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window val_loss:{sw_val_loss:.4f} val_bpb:{sw_val_bpb:.4f} " + f"stride:{args.eval_stride} eval_time:{1000.0 * (time.perf_counter() - t_slide):.0f}ms" + ) + log0(f"final_int6_sliding_window_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw_val_loss:.8f} val_bpb:{sw_val_bpb:.8f}") + if bool(int(os.environ.get("ONLINE_BEST_AGREE_EVAL", "0"))): + torch.cuda.synchronize() + t_online = time.perf_counter() + online_best_agree_mod = load_online_best_agree_eval_mod() + online_val_loss, online_val_bpb, online_timings = online_best_agree_mod.eval_val_sliding_online_best_agree( + args=args, + base_model=eval_model, + rank=rank, + world_size=world_size, + device=device, + val_tokens=val_tokens, + base_bytes_lut=base_bytes_lut, + has_leading_space_lut=has_leading_space_lut, + is_boundary_token_lut=is_boundary_token_lut, + stride=args.eval_stride, + batch_seqs=int(os.environ.get("BATCH_SEQS", "32")), + eval_seq_len=sw_seq_len, + log0=log0, + ) + torch.cuda.synchronize() + online_elapsed_ms = 1000.0 * (time.perf_counter() - t_online) + log0( + f"online_best_agree_sliding_window val_loss:{online_val_loss:.4f} " + f"val_bpb:{online_val_bpb:.4f} stride:{args.eval_stride} eval_time:{online_elapsed_ms:.0f}ms" + ) + log0( + f"online_best_agree_sliding_window_exact val_loss:{online_val_loss:.8f} " + f"val_bpb:{online_val_bpb:.8f}" + ) + log0( + "online_best_agree_compare " + f"llm_bpb:{online_timings['llm_bpb']:.8f} " + 
f"best_agree_bpb:{online_timings['best_agree_bpb']:.8f} " + f"gain_bpb:{online_timings['gain_bpb']:.8f} " + f"llm_nats_per_byte:{online_timings['llm_nats_per_byte']:.8f} " + f"best_agree_nats_per_byte:{online_timings['best_agree_nats_per_byte']:.8f} " + f"gain_nats_per_byte:{online_timings['gain_nats_per_byte']:.8f}" + ) + log0( + "online_best_agree_timing " + f"startup_max:{online_timings['startup_max_s']:.2f}s " + f"loop_total_max:{online_timings['loop_total_max_s']:.2f}s " + f"state_max:{online_timings['state_max_s']:.2f}s " + f"input_max:{online_timings['input_max_s']:.2f}s " + f"forward_max:{online_timings['forward_max_s']:.2f}s " + f"blend_max:{online_timings['blend_max_s']:.2f}s " + f"wallclock:{online_elapsed_ms / 1000.0:.2f}s" + ) + if args.eval_stride != 64 and 64 < sw_seq_len: + torch.cuda.synchronize() + t_slide64 = time.perf_counter() + sw64_val_loss, sw64_val_bpb = eval_val_sliding( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=64, + eval_seq_len=sw_seq_len, + ) + torch.cuda.synchronize() + log0( + f"final_int6_sliding_window_s64 val_loss:{sw64_val_loss:.4f} val_bpb:{sw64_val_bpb:.4f} " + f"stride:64 eval_time:{1000.0 * (time.perf_counter() - t_slide64):.0f}ms" + ) + log0(f"final_int6_sliding_window_s64_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + log0(f"final_int8_zlib_roundtrip_exact val_loss:{sw64_val_loss:.8f} val_bpb:{sw64_val_bpb:.8f}") + # Legal score-first TTT (PR #461 recipe) + if args.ttt_enabled: + torch.cuda.synchronize() + t_ttt = time.perf_counter() + ttt_loss, ttt_bpb = eval_val_sliding_ttt( + args, eval_model, rank, world_size, device, + val_tokens, base_bytes_lut, has_leading_space_lut, is_boundary_token_lut, + stride=args.eval_stride, log0=log0, + ) + torch.cuda.synchronize() + log0(f"legal_ttt val_loss:{ttt_loss:.4f} val_bpb:{ttt_bpb:.4f} " + f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms") + log0(f"legal_ttt_exact 
val_loss:{ttt_loss:.8f} val_bpb:{ttt_bpb:.8f}") + if distributed: + dist.destroy_process_group() +if __name__ == "__main__": + main() diff --git a/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/train_seed1337.log b/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/train_seed1337.log new file mode 100644 index 0000000000..f95227e770 --- /dev/null +++ b/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/train_seed1337.log @@ -0,0 +1,97 @@ +W0330 15:39:45.304000 3782698 site-packages/torch/distributed/run.py:851] +W0330 15:39:45.304000 3782698 site-packages/torch/distributed/run.py:851] ***************************************** +W0330 15:39:45.304000 3782698 site-packages/torch/distributed/run.py:851] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0330 15:39:45.304000 3782698 site-packages/torch/distributed/run.py:851] ***************************************** +logs/6527ed86-0f0a-403c-a2bc-08e19781bbd4.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/root/parameter-golf-pr1089/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/root/parameter-golf-pr1089/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27038812 +XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +torch_compile:disable=False backend:inductor mode:default dynamic:False fullgraph:True +torch_compile_eval:disable=False mode:inherit +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +swa:enabled=True every:50 start_step:auto +lr_schedule:warmdown_iters:4000 step_ms_ref:auto +train_batch_tokens:786432 train_seq_len:2048 iterations:6700 warmup_steps:20 max_wallclock_seconds:600.000 +seed:1337 +gptq:reserving 10000ms from training budget, effective=590000ms +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/6700 val_loss:6.9309 val_bpb:4.1049 train_time:0ms step_avg:0.01ms +step:1/6700 train_loss:6.9300 train_time:136ms step_avg:136.39ms +step:2/6700 train_loss:8.8448 train_time:183ms step_avg:91.56ms +step:3/6700 train_loss:7.5937 train_time:271ms step_avg:90.17ms +step:4/6700 train_loss:7.1277 train_time:358ms step_avg:89.45ms +step:5/6700 train_loss:7.0348 train_time:450ms step_avg:90.08ms +step:6/6700 train_loss:7.0677 
train_time:538ms step_avg:89.67ms +step:7/6700 train_loss:6.9591 train_time:628ms step_avg:89.65ms +step:8/6700 train_loss:6.7365 train_time:720ms step_avg:89.95ms +step:9/6700 train_loss:6.4988 train_time:805ms step_avg:89.49ms +step:10/6700 train_loss:6.1238 train_time:898ms step_avg:89.76ms +step:500/6700 train_loss:2.3154 train_time:45416ms step_avg:90.83ms +step:1000/6700 train_loss:2.1729 train_time:90949ms step_avg:90.95ms +step:1500/6700 train_loss:2.1551 train_time:136484ms step_avg:90.99ms +step:2000/6700 train_loss:2.0963 train_time:182104ms step_avg:91.05ms +step:2500/6700 train_loss:2.0442 train_time:227721ms step_avg:91.09ms +step:3000/6700 train_loss:2.0434 train_time:273357ms step_avg:91.12ms +step:3500/6700 train_loss:2.0287 train_time:319008ms step_avg:91.15ms +step:4000/6700 train_loss:1.9903 train_time:364652ms step_avg:91.16ms +step:4000/6700 val_loss:2.0135 val_bpb:1.1925 train_time:364702ms step_avg:91.18ms +step:4500/6700 train_loss:1.9725 train_time:410287ms step_avg:91.17ms +step:5000/6700 train_loss:1.9405 train_time:455904ms step_avg:91.18ms +step:5500/6700 train_loss:1.9362 train_time:501512ms step_avg:91.18ms +swa:start step:5700 +step:6000/6700 train_loss:2.0035 train_time:547629ms step_avg:91.27ms +step:6456/6700 val_loss:1.9154 val_bpb:1.1344 train_time:590095ms step_avg:91.40ms +stopping_early: wallclock_cap train_time:590095ms step:6456/6700 +peak memory allocated: 23042 MiB reserved: 23182 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9139 val_bpb:1.1335 eval_time:2120ms +Serialized model: 106235254 bytes +Code size: 103317 bytes +gptq:calibrating with 64 batches (training data)... 
+gptq:calibrated 66 layers in 6.8s +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +Serialized model int6+lzma: 15849904 bytes +Total submission size int6+lzma: 15953221 bytes +final_int6_roundtrip val_loss:1.9206 val_bpb:1.1375 eval_time:2079ms +final_int6_roundtrip_exact val_loss:1.92060357 val_bpb:1.13748962 +final_int6_sliding_window val_loss:1.8811 val_bpb:1.1141 stride:64 eval_time:84064ms +final_int6_sliding_window_exact val_loss:1.88108198 val_bpb:1.11408566 +final_int8_zlib_roundtrip_exact val_loss:1.88108198 val_bpb:1.11408566 +eval-pass-online: using eager logits path +online_best_agree:start total_targets=62021632 seq_len=2048 stride=64 chunk_tokens=131072 batch_seqs=32 token_order=16 word_order=4 startup_max=0.00s +online_best_agree:done llm_bpb=1.11437756 best_agree_bpb=1.11126660 gain_bpb=0.00311096 startup_max=0.00s loop_total_max=457.46s state_max=209.73s input_max=13.63s forward_max=45.68s blend_max=191.59s +online_best_agree_sliding_window val_loss:1.8763 val_bpb:1.1113 stride:64 eval_time:461618ms +online_best_agree_sliding_window_exact val_loss:1.87632213 val_bpb:1.11126660 +online_best_agree_compare llm_bpb:1.11437756 best_agree_bpb:1.11126660 gain_bpb:0.00311096 llm_nats_per_byte:0.77242766 best_agree_nats_per_byte:0.77027131 gain_nats_per_byte:0.00215636 +online_best_agree_timing startup_max:0.00s loop_total_max:457.46s state_max:209.73s input_max:13.63s forward_max:45.68s blend_max:191.59s wallclock:461.62s diff --git a/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/train_seed15.log b/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/train_seed15.log new file 
mode 100644 index 0000000000..86fa7469b4 --- /dev/null +++ b/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/train_seed15.log @@ -0,0 +1,97 @@ +W0330 18:26:01.198000 3838175 site-packages/torch/distributed/run.py:851] +W0330 18:26:01.198000 3838175 site-packages/torch/distributed/run.py:851] ***************************************** +W0330 18:26:01.198000 3838175 site-packages/torch/distributed/run.py:851] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0330 18:26:01.198000 3838175 site-packages/torch/distributed/run.py:851] ***************************************** +logs/4cd3ddf2-4f4a-42b2-b8ba-dd45bf498994.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/root/parameter-golf-pr1089/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/root/parameter-golf-pr1089/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27038812 +XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +torch_compile:disable=False backend:inductor mode:default dynamic:False fullgraph:True +torch_compile_eval:disable=False mode:inherit +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +swa:enabled=True every:50 start_step:auto +lr_schedule:warmdown_iters:4000 step_ms_ref:auto +train_batch_tokens:786432 train_seq_len:2048 iterations:6700 warmup_steps:20 max_wallclock_seconds:600.000 +seed:15 +gptq:reserving 10000ms from training budget, effective=590000ms +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 
+warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/6700 val_loss:6.9315 val_bpb:4.1052 train_time:0ms step_avg:0.01ms +step:1/6700 train_loss:6.9317 train_time:136ms step_avg:136.09ms +step:2/6700 train_loss:8.8673 train_time:181ms step_avg:90.67ms +step:3/6700 train_loss:7.5880 train_time:266ms step_avg:88.70ms +step:4/6700 train_loss:7.1464 train_time:357ms step_avg:89.26ms +step:5/6700 train_loss:7.0683 train_time:447ms step_avg:89.47ms +step:6/6700 train_loss:7.0660 train_time:535ms step_avg:89.13ms +step:7/6700 train_loss:6.9594 train_time:627ms step_avg:89.51ms +step:8/6700 train_loss:6.7488 train_time:716ms step_avg:89.48ms +step:9/6700 train_loss:6.4295 train_time:803ms step_avg:89.19ms +step:10/6700 train_loss:6.0690 train_time:895ms step_avg:89.51ms +step:500/6700 train_loss:2.3131 train_time:45437ms step_avg:90.87ms +step:1000/6700 train_loss:2.1733 train_time:90975ms step_avg:90.97ms +step:1500/6700 train_loss:2.1528 train_time:136542ms step_avg:91.03ms +step:2000/6700 train_loss:2.0963 train_time:182204ms step_avg:91.10ms +step:2500/6700 train_loss:2.0455 train_time:227875ms step_avg:91.15ms +step:3000/6700 train_loss:2.0415 train_time:273557ms step_avg:91.19ms +step:3500/6700 train_loss:2.0247 train_time:319246ms step_avg:91.21ms +step:4000/6700 train_loss:1.9926 train_time:364915ms step_avg:91.23ms +step:4000/6700 val_loss:2.0128 val_bpb:1.1921 train_time:364969ms step_avg:91.24ms +step:4500/6700 train_loss:1.9706 train_time:410575ms step_avg:91.24ms +step:5000/6700 train_loss:1.9381 train_time:456203ms step_avg:91.24ms +step:5500/6700 train_loss:1.9350 train_time:501823ms step_avg:91.24ms +swa:start step:5700 +step:6000/6700 train_loss:2.0016 train_time:548003ms step_avg:91.33ms +step:6451/6700 val_loss:1.9148 val_bpb:1.1341 train_time:590059ms step_avg:91.47ms 
+stopping_early: wallclock_cap train_time:590059ms step:6451/6700 +peak memory allocated: 23042 MiB reserved: 23182 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9134 val_bpb:1.1332 eval_time:2120ms +Serialized model: 106235254 bytes +Code size: 103317 bytes +gptq:calibrating with 64 batches (training data)... +gptq:calibrated 66 layers in 6.8s +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +Serialized model int6+lzma: 15738424 bytes +Total submission size int6+lzma: 15841741 bytes +final_int6_roundtrip val_loss:1.9201 val_bpb:1.1372 eval_time:2088ms +final_int6_roundtrip_exact val_loss:1.92011797 val_bpb:1.13720202 +final_int6_sliding_window val_loss:1.8805 val_bpb:1.1137 stride:64 eval_time:79809ms +final_int6_sliding_window_exact val_loss:1.88047022 val_bpb:1.11372333 +final_int8_zlib_roundtrip_exact val_loss:1.88047022 val_bpb:1.11372333 +eval-pass-online: using eager logits path +online_best_agree:start total_targets=62021632 seq_len=2048 stride=64 chunk_tokens=131072 batch_seqs=32 token_order=16 word_order=4 startup_max=0.00s +online_best_agree:done llm_bpb=1.11402056 best_agree_bpb=1.11089935 gain_bpb=0.00312121 startup_max=0.00s loop_total_max=462.20s state_max=212.69s input_max=13.79s forward_max=45.74s blend_max=192.36s +online_best_agree_sliding_window val_loss:1.8757 val_bpb:1.1109 stride:64 eval_time:466092ms +online_best_agree_sliding_window_exact val_loss:1.87570205 val_bpb:1.11089935 +online_best_agree_compare llm_bpb:1.11402056 best_agree_bpb:1.11089935 gain_bpb:0.00312121 llm_nats_per_byte:0.77218021 best_agree_nats_per_byte:0.77001675 gain_nats_per_byte:0.00216346 +online_best_agree_timing startup_max:0.00s 
loop_total_max:462.20s state_max:212.69s input_max:13.79s forward_max:45.74s blend_max:192.36s wallclock:466.09s diff --git a/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/train_seed2025.log b/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/train_seed2025.log new file mode 100644 index 0000000000..e1ee491439 --- /dev/null +++ b/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/train_seed2025.log @@ -0,0 +1,97 @@ +W0330 16:01:26.240000 3789541 site-packages/torch/distributed/run.py:851] +W0330 16:01:26.240000 3789541 site-packages/torch/distributed/run.py:851] ***************************************** +W0330 16:01:26.240000 3789541 site-packages/torch/distributed/run.py:851] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. +W0330 16:01:26.240000 3789541 site-packages/torch/distributed/run.py:851] ***************************************** +logs/0146acf7-45f3-4875-8aa0-204e20fd6481.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/root/parameter-golf-pr1089/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/root/parameter-golf-pr1089/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27038812 +XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +torch_compile:disable=False backend:inductor mode:default dynamic:False fullgraph:True +torch_compile_eval:disable=False mode:inherit +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +swa:enabled=True every:50 start_step:auto +lr_schedule:warmdown_iters:4000 step_ms_ref:auto 
+train_batch_tokens:786432 train_seq_len:2048 iterations:6700 warmup_steps:20 max_wallclock_seconds:600.000 +seed:2025 +gptq:reserving 10000ms from training budget, effective=590000ms +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/6700 val_loss:6.9266 val_bpb:4.1023 train_time:0ms step_avg:0.01ms +step:1/6700 train_loss:6.9268 train_time:137ms step_avg:136.67ms +step:2/6700 train_loss:8.4714 train_time:181ms step_avg:90.34ms +step:3/6700 train_loss:7.4467 train_time:272ms step_avg:90.64ms +step:4/6700 train_loss:7.4673 train_time:360ms step_avg:89.97ms +step:5/6700 train_loss:7.2236 train_time:448ms step_avg:89.63ms +step:6/6700 train_loss:7.0313 train_time:539ms step_avg:89.91ms +step:7/6700 train_loss:6.8004 train_time:629ms step_avg:89.80ms +step:8/6700 train_loss:6.6174 train_time:718ms step_avg:89.79ms +step:9/6700 train_loss:6.4211 train_time:810ms step_avg:89.96ms +step:10/6700 train_loss:6.1057 train_time:899ms step_avg:89.89ms +step:500/6700 train_loss:2.3152 train_time:45398ms step_avg:90.80ms +step:1000/6700 train_loss:2.1744 train_time:90883ms step_avg:90.88ms +step:1500/6700 train_loss:2.1541 train_time:136418ms step_avg:90.95ms +step:2000/6700 train_loss:2.0964 train_time:182027ms step_avg:91.01ms +step:2500/6700 train_loss:2.0474 train_time:227672ms step_avg:91.07ms +step:3000/6700 train_loss:2.0439 train_time:273308ms step_avg:91.10ms +step:3500/6700 train_loss:2.0273 train_time:318973ms step_avg:91.14ms +step:4000/6700 train_loss:1.9881 train_time:364615ms step_avg:91.15ms +step:4000/6700 val_loss:2.0134 val_bpb:1.1924 train_time:364668ms step_avg:91.17ms +step:4500/6700 train_loss:1.9725 train_time:410237ms 
step_avg:91.16ms +step:5000/6700 train_loss:1.9433 train_time:455871ms step_avg:91.17ms +step:5500/6700 train_loss:1.9368 train_time:501473ms step_avg:91.18ms +swa:start step:5700 +step:6000/6700 train_loss:2.0058 train_time:547588ms step_avg:91.26ms +step:6457/6700 val_loss:1.9147 val_bpb:1.1340 train_time:590128ms step_avg:91.39ms +stopping_early: wallclock_cap train_time:590128ms step:6457/6700 +peak memory allocated: 23042 MiB reserved: 23182 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9133 val_bpb:1.1332 eval_time:2121ms +Serialized model: 106235254 bytes +Code size: 103317 bytes +gptq:calibrating with 64 batches (training data)... +gptq:calibrated 66 layers in 6.8s +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +Serialized model int6+lzma: 15738984 bytes +Total submission size int6+lzma: 15842301 bytes +final_int6_roundtrip val_loss:1.9199 val_bpb:1.1371 eval_time:2081ms +final_int6_roundtrip_exact val_loss:1.91991864 val_bpb:1.13708396 +final_int6_sliding_window val_loss:1.8801 val_bpb:1.1135 stride:64 eval_time:79537ms +final_int6_sliding_window_exact val_loss:1.88013045 val_bpb:1.11352210 +final_int8_zlib_roundtrip_exact val_loss:1.88013045 val_bpb:1.11352210 +eval-pass-online: using eager logits path +online_best_agree:start total_targets=62021632 seq_len=2048 stride=64 chunk_tokens=131072 batch_seqs=32 token_order=16 word_order=4 startup_max=0.00s +online_best_agree:done llm_bpb=1.11381798 best_agree_bpb=1.11068499 gain_bpb=0.00313300 startup_max=0.00s loop_total_max=458.64s state_max=211.08s input_max=13.54s forward_max=44.44s blend_max=191.75s +online_best_agree_sliding_window val_loss:1.8753 val_bpb:1.1107 
stride:64 eval_time:462347ms +online_best_agree_sliding_window_exact val_loss:1.87534011 val_bpb:1.11068499 +online_best_agree_compare llm_bpb:1.11381798 best_agree_bpb:1.11068499 gain_bpb:0.00313300 llm_nats_per_byte:0.77203979 best_agree_nats_per_byte:0.76986817 gain_nats_per_byte:0.00217163 +online_best_agree_timing startup_max:0.00s loop_total_max:458.64s state_max:211.08s input_max:13.54s forward_max:44.44s blend_max:191.75s wallclock:462.35s diff --git a/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/train_seed42.log b/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/train_seed42.log new file mode 100644 index 0000000000..b55458d933 --- /dev/null +++ b/records/track_10min_16mb/2026-03-29_Loader_FullGPTQ_XSA11_OnlineNgramAgreement/train_seed42.log @@ -0,0 +1,97 @@ +W0330 16:23:08.108000 3796131 site-packages/torch/distributed/run.py:851] +W0330 16:23:08.108000 3796131 site-packages/torch/distributed/run.py:851] ***************************************** +W0330 16:23:08.108000 3796131 site-packages/torch/distributed/run.py:851] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
+W0330 16:23:08.108000 3796131 site-packages/torch/distributed/run.py:851] ***************************************** +logs/486fe152-81cc-42dd-8ccd-2b362e75da1c.txt +val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/root/parameter-golf-pr1089/data/tokenizers/fineweb_1024_bpe.model +train_loader:dataset:fineweb10B_sp1024 train_shards:80 +val_loader:shards pattern=/root/parameter-golf-pr1089/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632 +model_params:27038812 +XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] +world_size:8 grad_accum_steps:1 +sdp_backends:cudnn=False flash=True mem_efficient=False math=False +attention_mode:gqa num_heads:8 num_kv_heads:4 +torch_compile:disable=False backend:inductor mode:default dynamic:False fullgraph:True +torch_compile_eval:disable=False mode:inherit +tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025 +swa:enabled=True every:50 start_step:auto +lr_schedule:warmdown_iters:4000 step_ms_ref:auto +train_batch_tokens:786432 train_seq_len:2048 iterations:6700 warmup_steps:20 max_wallclock_seconds:600.000 +seed:42 +gptq:reserving 10000ms from training budget, effective=590000ms +warmup_step:1/20 +warmup_step:2/20 +warmup_step:3/20 +warmup_step:4/20 +warmup_step:5/20 +warmup_step:6/20 +warmup_step:7/20 +warmup_step:8/20 +warmup_step:9/20 +warmup_step:10/20 +warmup_step:11/20 +warmup_step:12/20 +warmup_step:13/20 +warmup_step:14/20 +warmup_step:15/20 +warmup_step:16/20 +warmup_step:17/20 +warmup_step:18/20 +warmup_step:19/20 +warmup_step:20/20 +step:0/6700 val_loss:6.9297 val_bpb:4.1041 train_time:0ms step_avg:0.01ms +step:1/6700 train_loss:6.9291 train_time:137ms step_avg:137.42ms +step:2/6700 train_loss:8.6518 train_time:182ms step_avg:90.83ms +step:3/6700 train_loss:7.4750 train_time:269ms step_avg:89.51ms +step:4/6700 train_loss:7.2164 train_time:360ms step_avg:90.10ms +step:5/6700 train_loss:7.0829 train_time:449ms step_avg:89.86ms +step:6/6700 train_loss:6.9904 
train_time:537ms step_avg:89.56ms +step:7/6700 train_loss:6.9313 train_time:628ms step_avg:89.73ms +step:8/6700 train_loss:6.6865 train_time:719ms step_avg:89.91ms +step:9/6700 train_loss:6.4638 train_time:806ms step_avg:89.59ms +step:10/6700 train_loss:6.1012 train_time:898ms step_avg:89.78ms +step:500/6700 train_loss:2.3086 train_time:45440ms step_avg:90.88ms +step:1000/6700 train_loss:2.1687 train_time:91047ms step_avg:91.05ms +step:1500/6700 train_loss:2.1492 train_time:136710ms step_avg:91.14ms +step:2000/6700 train_loss:2.0961 train_time:182435ms step_avg:91.22ms +step:2500/6700 train_loss:2.0443 train_time:228163ms step_avg:91.27ms +step:3000/6700 train_loss:2.0406 train_time:273895ms step_avg:91.30ms +step:3500/6700 train_loss:2.0288 train_time:319637ms step_avg:91.32ms +step:4000/6700 train_loss:1.9907 train_time:365395ms step_avg:91.35ms +step:4000/6700 val_loss:2.0123 val_bpb:1.1918 train_time:365443ms step_avg:91.36ms +step:4500/6700 train_loss:1.9712 train_time:411121ms step_avg:91.36ms +step:5000/6700 train_loss:1.9404 train_time:456829ms step_avg:91.37ms +step:5500/6700 train_loss:1.9358 train_time:502547ms step_avg:91.37ms +swa:start step:5700 +step:6000/6700 train_loss:2.0001 train_time:548800ms step_avg:91.47ms +step:6443/6700 val_loss:1.9146 val_bpb:1.1339 train_time:590093ms step_avg:91.59ms +stopping_early: wallclock_cap train_time:590093ms step:6443/6700 +peak memory allocated: 23042 MiB reserved: 23182 MiB +ema:applying EMA weights +DIAGNOSTIC post_ema val_loss:1.9132 val_bpb:1.1331 eval_time:2120ms +Serialized model: 106235254 bytes +Code size: 103317 bytes +gptq:calibrating with 64 batches (training data)... 
+gptq:calibrated 66 layers in 6.8s +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +gptq_quantize: 66 GPTQ layers, 0 naive layers +Serialized model int6+lzma: 15714496 bytes +Total submission size int6+lzma: 15817813 bytes +final_int6_roundtrip val_loss:1.9198 val_bpb:1.1370 eval_time:2087ms +final_int6_roundtrip_exact val_loss:1.91977902 val_bpb:1.13700127 +final_int6_sliding_window val_loss:1.8800 val_bpb:1.1134 stride:64 eval_time:80093ms +final_int6_sliding_window_exact val_loss:1.87998967 val_bpb:1.11343872 +final_int8_zlib_roundtrip_exact val_loss:1.87998967 val_bpb:1.11343872 +eval-pass-online: using eager logits path +online_best_agree:start total_targets=62021632 seq_len=2048 stride=64 chunk_tokens=131072 batch_seqs=32 token_order=16 word_order=4 startup_max=0.00s +online_best_agree:done llm_bpb=1.11372806 best_agree_bpb=1.11058356 gain_bpb=0.00314451 startup_max=0.00s loop_total_max=477.25s state_max=230.05s input_max=13.77s forward_max=44.62s blend_max=192.70s +online_best_agree_sliding_window val_loss:1.8752 val_bpb:1.1106 stride:64 eval_time:481038ms +online_best_agree_sliding_window_exact val_loss:1.87516885 val_bpb:1.11058356 +online_best_agree_compare llm_bpb:1.11372806 best_agree_bpb:1.11058356 gain_bpb:0.00314451 llm_nats_per_byte:0.77197747 best_agree_nats_per_byte:0.76979786 gain_nats_per_byte:0.00217961 +online_best_agree_timing startup_max:0.00s loop_total_max:477.25s state_max:230.05s input_max:13.77s forward_max:44.62s blend_max:192.70s wallclock:481.04s
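A quick standalone sanity check on the logged metrics (illustrative only; the constants are copied from the seed-42 `online_best_agree_compare` line above): the bpb and nats/byte columns should differ by exactly a factor of ln 2, and the reported gain should equal the difference of the two bpb columns up to rounding in the last printed digit.

```python
import math

# Values copied from the seed-42 "online_best_agree_compare" log line.
llm_bpb = 1.11372806
best_agree_bpb = 1.11058356
llm_nats_per_byte = 0.77197747
gain_bpb = 0.00314451

# 1 bit = ln(2) nats, so nats/byte = bpb * ln(2).
assert abs(llm_bpb * math.log(2) - llm_nats_per_byte) < 1e-6

# The reported gain is the difference of the two bpb columns
# (the log prints it from unrounded values, hence the loose tolerance).
assert abs((llm_bpb - best_agree_bpb) - gain_bpb) < 1e-7
```

The same relationship holds for every seed's compare line, which is a cheap way to confirm the bpb and nats/byte figures in the README table were derived from the same underlying losses.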