Record: GPTQ + Legal TTT (3-seed mean val_bpb=1.1195)#529
EthanYangTW wants to merge 2 commits into openai:main from
Conversation
Improvements over previous submission (1.1218):
- GPTQ quantization: Hessian-aware error compensation, -0.0024 BPB
- Early QAT (threshold 0.5): 3x more QAT steps
- EMA 0.997 (tuned from 0.9985)

3-seed results:
- Seed 1337: 1.1189 (15.96 MB)
- Seed 42: 1.1197 (15.75 MB)
- Seed 7: 1.1198 (15.54 MB)
- Mean: 1.1195 (std 0.0005)
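For reference, the EMA decay change (0.9985 → 0.997) is a standard exponential moving average over model parameters. A minimal sketch with illustrative names (not the script's actual API):

```python
import torch

@torch.no_grad()
def ema_update(ema_params, model_params, decay=0.997):
    # Blend each EMA parameter toward the live parameter:
    #   ema <- decay * ema + (1 - decay) * param
    # A decay of 0.997 averages over roughly 1 / (1 - 0.997) ~ 333
    # recent steps; the tuned value trades smoothing for responsiveness.
    for e, p in zip(ema_params, model_params):
        e.mul_(decay).add_(p, alpha=1.0 - decay)
```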
Pull request overview
Updates train_gpt.py to a record-grade training/eval pipeline featuring GPTQ-based int6 quantization and “legal” score-first test-time training (TTT), aligning the script with the PR’s reported val_bpb improvements.
Changes:
- Added GPTQ Hessian calibration + mixed int6/int8 quantization export with optional zstd compression.
- Expanded the model/training stack (e.g., XSA, smear gate, bigram hash embedding, VE, SWA/EMA, early/late QAT) and updated hyperparameters.
- Added sliding-window evaluation and score-first TTT evaluation paths.
| else: |
|     self.ve_shared = None |
|     self.ve_layer_scales = nn.ParameterList() |
|     self.value_embeds = nn.ModuleList() |
self.value_embeds is initialized but never used (no subsequent reads/writes). If it’s leftover from an earlier design, remove it; if it’s intended for future per-layer VE modules, it should be populated/used to avoid dead code.
| self.value_embeds = nn.ModuleList() |
| Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines. |
| """ |
| """V7: GPTQ quantization + Early QAT 0.5 + EMA 0.997. Fork of #503.""" |
The PR removes the introductory docstring that documents the intent/constraints of train_gpt.py, and the file now exceeds the previously stated 1500-line hard stop (file ends at line 1579). If this script is meant to remain a newcomer-friendly baseline, consider restoring the guidance and/or moving record-grade logic into /records (or splitting into modules) so train_gpt.py stays under the line limit.
| """V7: GPTQ quantization + Early QAT 0.5 + EMA 0.997. Fork of #503.""" | |
| """Baseline GPT training script. | |
| This file is intended to remain a **newcomer-friendly baseline**: a single, | |
| readable script that can be skimmed top-to-bottom to understand how the model, | |
| data loading, training loop, and evaluation fit together. | |
| Design / usage notes: | |
| - Configuration is driven primarily by environment variables (see | |
| `Hyperparameters` below) so that this file can stay code-only and easily | |
| copy-pasted. | |
| - The implementation favors clarity over chasing absolute record scores. | |
| If you are adding highly experimental, benchmark- or record-focused logic, | |
| prefer putting it into `/records` or separate modules and calling into it | |
| from here instead of inlining large new subsystems. | |
| - As a soft guideline, keep this script to ~1500 lines by factoring out | |
| heavy utilities into dedicated modules. This helps keep the baseline | |
| approachable for new contributors. | |
| Current variant: V7, with GPTQ quantization, early QAT (0.5), and EMA 0.997. | |
| This is a fork/evolution of the earlier #503 baseline. | |
| """ |
| w32 = self.weight.float() |
| row_clip = torch.quantile(w32.abs(), 0.9995, dim=1) |
| scale = (row_clip / 31.0).clamp_min(1.0 / 31.0) |
| w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) |
The QAT fake-quant path clamps to [-32, 31], while the rest of the int6 pipeline (e.g., clip_range=31) implies a symmetric [-31, 31] range. This mismatch changes the quantization behavior during QAT vs export and can hurt calibration/roundtrip fidelity. Align the clamp range with the export quantizer (or document why the asymmetric range is intentional).
| w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype) | |
| w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -31, 31) * scale[:, None]).to(x.dtype) |
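To illustrate the symmetric range the review asks for, here is a minimal fake-quant sketch using a per-row absmax scale (the script's actual quantizer uses a quantile-based clip; this is a simplified stand-in):

```python
import torch

def fake_quant_int6_symmetric(w: torch.Tensor) -> torch.Tensor:
    # Per-row absmax scale mapped onto the symmetric int6 grid [-31, 31],
    # matching an export path that clips at 31 (clip_range=31). Clamping
    # to [-32, 31] instead would let QAT place weights at a level the
    # export quantizer never emits, degrading roundtrip fidelity.
    scale = (w.abs().amax(dim=1) / 31.0).clamp_min(1e-8)
    q = torch.clamp(torch.round(w / scale[:, None]), -31, 31)
    return q * scale[:, None]
```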
| if not t.is_floating_point() or t.numel() <= 65536: |
| result[name] = t.to(torch.float16) if t.is_floating_point() else t |
| meta[name] = "passthrough" |
| continue |
| if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): |
| result[name] = t.float() |
| meta[name] = "passthrough_ctrl" |
| continue |
In mixed_quantize_int6_gptq, the early t.numel() <= 65536 passthrough runs before the control-tensor check. This means small control/scaling tensors matching CONTROL_TENSOR_NAME_PATTERNS will be downcast to fp16, losing precision, despite the later intent to keep them fp32. Reorder the checks so control tensors are handled before the size-based passthrough (or explicitly exempt them from downcasting).
| if not t.is_floating_point() or t.numel() <= 65536: | |
| result[name] = t.to(torch.float16) if t.is_floating_point() else t | |
| meta[name] = "passthrough" | |
| continue | |
| if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): | |
| result[name] = t.float() | |
| meta[name] = "passthrough_ctrl" | |
| continue | |
| # Non-floating tensors are passed through unchanged. | |
| if not t.is_floating_point(): | |
| result[name] = t | |
| meta[name] = "passthrough" | |
| continue | |
| # Control/scaling tensors are kept in float32, regardless of size. | |
| if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): | |
| result[name] = t.float() | |
| meta[name] = "passthrough_ctrl" | |
| continue | |
| # Small floating-point tensors are downcast to float16 and passed through. | |
| if t.numel() <= 65536: | |
| result[name] = t.to(torch.float16) | |
| meta[name] = "passthrough" | |
| continue |
| if not t.is_floating_point() or t.numel() <= 65536: |
| result[name] = t.to(torch.float16) if t.is_floating_point() else t |
| meta[name] = "passthrough" |
| continue |
| if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): |
| result[name] = t.float() |
| meta[name] = "passthrough_ctrl" |
| continue |
Same issue as above in mixed_quantize_int6: the size-based passthrough happens before the control-tensor check, so small control/scaling tensors will be downcast to fp16. Reorder/exempt control tensors to preserve fp32 precision for these parameters.
| if not t.is_floating_point() or t.numel() <= 65536: | |
| result[name] = t.to(torch.float16) if t.is_floating_point() else t | |
| meta[name] = "passthrough" | |
| continue | |
| if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): | |
| result[name] = t.float() | |
| meta[name] = "passthrough_ctrl" | |
| continue | |
| # Always preserve floating-point control/scaling tensors in FP32, | |
| # regardless of their size, to avoid precision loss. | |
| if t.is_floating_point() and any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS): | |
| result[name] = t.float() | |
| meta[name] = "passthrough_ctrl" | |
| continue | |
| if not t.is_floating_point() or t.numel() <= 65536: | |
| result[name] = t.to(torch.float16) if t.is_floating_point() else t | |
| meta[name] = "passthrough" | |
| continue |
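The ordering fix can be checked in isolation. A standalone sketch (the helper name and pattern list are illustrative, not the script's actual definitions) that tests the control branch before the size-based passthrough:

```python
import torch

# Illustrative patterns; the script's CONTROL_TENSOR_NAME_PATTERNS may differ.
CONTROL_TENSOR_NAME_PATTERNS = ("scale", "gate")

def passthrough_dtype(name: str, t: torch.Tensor) -> torch.dtype:
    """Return the dtype a tensor would be stored in by the passthrough logic."""
    # Control/scaling tensors are checked FIRST, so even tiny "scale"
    # tensors keep full fp32 precision instead of being downcast.
    if t.is_floating_point() and any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
        return torch.float32
    # Non-float tensors pass through unchanged; small float tensors go to fp16.
    if not t.is_floating_point() or t.numel() <= 65536:
        return torch.float16 if t.is_floating_point() else t.dtype
    # Large float tensors would continue into the int6 quantizer.
    return t.dtype
```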
| self._cos_cached = freqs.cos()[None, None, :, :] |
| self._sin_cached = freqs.sin()[None, None, :, :] |
| rd = self.rope_dims |
| if seq_len > self.train_seq_len: |
Rotary.forward computes rd / (rd - 2) when seq_len > train_seq_len. If rope_dims is set to 2 via env/config, this will divide by zero and crash on longer sequences. Add a guard/validation (e.g., require rope_dims > 2 when using the NTK-style base rescaling) or handle the rd==2 case explicitly.
| if seq_len > self.train_seq_len: | |
| if seq_len > self.train_seq_len: | |
| if rd <= 2: | |
| raise ValueError( | |
| f"rope_dims must be greater than 2 when using NTK-style RoPE rescaling " | |
| f"for seq_len ({seq_len}) > train_seq_len ({self.train_seq_len}); got rope_dims={rd}" | |
| ) |
| # Background on Muon: https://kellerjordan.github.io/posts/muon/ |
| grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3)) |
| eval_stride = int(os.environ.get("EVAL_STRIDE", 32)) |
| muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) |
Hyperparameters.muon_beta2 is defined but never used anywhere in the script (no references beyond the env read). If Muon is meant to support a beta2 term, wire it through; otherwise consider removing this hyperparameter to avoid confusion/misconfiguration.
| muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95)) |
XSA11 + GPTQ + Early QAT + EMA 0.997 + Legal Score-First AdamW TTT
val_bpb (3-seed mean): 1.1195 (std: 0.0005)
Improvements over #503 (1.1218)
Architecture (unchanged from #374 base)
Legal Score-First TTT
Quantization Pipeline
Timing
Compute
8xH100 SXM, ~20 min/seed. Three seeds for verification.