
Record: GPTQ + Legal TTT (3-seed mean val_bpb=1.1195)#529

Open
EthanYangTW wants to merge 2 commits into openai:main from EthanYangTW:submission/gptq-qat-ema-3seed

Conversation

@EthanYangTW

11L XSA11 + GPTQ + Early QAT + EMA 0.997 + Legal Score-First AdamW TTT

val_bpb (3-seed mean): 1.1195 (std: 0.0005)

Seed   val_bpb   Artifact
1337   1.1189    15.96 MB
42     1.1197    15.75 MB
7      1.1198    15.54 MB

Improvements over #503 (1.1218)

  • GPTQ quantization: Hessian-aware error compensation with column reordering, 256-sample calibration (-0.0024 BPB quant tax reduction)
  • Early QAT (threshold 0.5): ~1750 QAT steps (3x more than before), so the model adapts to quantization noise for longer
  • EMA decay 0.997: tuned down from 0.9985
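A minimal sketch of the two knobs above: QAT fake-quantization switches on once a fraction of training has elapsed, and an exponential moving average of the weights is kept with decay 0.997. Function names (`qat_active`, `ema_update`) are illustrative, not taken from the PR's code.

```python
def qat_active(step: int, total_steps: int, threshold: float = 0.5) -> bool:
    """Enable fake-quantization once `threshold` of training has elapsed."""
    return step / total_steps >= threshold


def ema_update(ema: dict, params: dict, decay: float = 0.997) -> None:
    """In-place exponential moving average of parameter values."""
    for name, value in params.items():
        ema[name] = decay * ema[name] + (1.0 - decay) * value
```

With ~6737 total steps, a 0.5 threshold activates QAT around step ~3369, versus a later switch-on in the previous submission.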

Architecture (unchanged from #374 base)

  • 11 layers, model_dim=512, 8H/4KV (GQA), MLP 3x relu²
  • XSA on all 11 layers
  • Partial RoPE 16/64, LN Scale, SmearGate + OrthoInit
  • BigramHash 2048, Shared VE128 (layers 9,10)
  • FA3 Hopper, ~89ms/step → ~6737 steps in 600s
  • ~27M params, int6 + zstd-22, 2% magnitude pruning

Legal Score-First TTT

  • Chunks of 131072 tokens, stride=32 sliding window
  • For each chunk: score first (inference_mode), then adapt
  • AdamW (lr=0.0001, wd=0.0), 3 epochs per chunk, cosine LR
  • Last 2 blocks + norms + lm_head unfrozen (4.7M / 27M params)
  • Every token scored BEFORE any gradient update using it
  • Manual grad all_reduce (no DDP wrapper)
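The loop above can be sketched as follows; the key "legal" property is that each chunk's loss is recorded under `inference_mode` before any gradient step touches it. This is a simplified single-process sketch (no manual grad all_reduce), and it assumes the model's forward returns a scalar loss.

```python
import torch


def ttt_eval(model, chunks, lr=1e-4, epochs=3):
    """Score-first test-time training: score each chunk, then adapt on it."""
    total_loss, total_tokens = 0.0, 0
    trainable = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(trainable, lr=lr, weight_decay=0.0)
    for chunk in chunks:
        # 1) Score the chunk BEFORE any gradient update that uses it.
        with torch.inference_mode():
            loss = model(chunk).item()
        total_loss += loss * chunk.numel()
        total_tokens += chunk.numel()
        # 2) Only then adapt on the already-scored tokens.
        for _ in range(epochs):
            opt.zero_grad(set_to_none=True)
            model(chunk).backward()
            opt.step()
    return total_loss / total_tokens
```

In the PR this runs over 131072-token chunks with only the last two blocks, norms, and lm_head unfrozen; the sketch omits that freezing for brevity.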

Quantization Pipeline

  1. Early QAT (threshold 0.5): fake int6 STE with 0.9995 percentile clipping
  2. GPTQ (post-training): 256-sample Hessian calibration, per-row optimal scales, column reordering, block-128 Cholesky error compensation
  3. int6 quantization (range [-31, 31]) stored as int8
  4. 2% magnitude pruning
  5. zstd-22 compression
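Steps 3-5 can be sketched as below. This is a NumPy stand-in, not the PR's exporter: the per-row scale here is a plain max-based choice rather than the GPTQ-derived one, and `zlib` stands in for zstd level 22 so the sketch stays stdlib-only.

```python
import zlib

import numpy as np


def export_int6(w: np.ndarray, prune_frac: float = 0.02) -> bytes:
    """Prune, quantize to symmetric int6 stored as int8, and compress."""
    # 2% magnitude pruning: zero the smallest-magnitude weights.
    k = int(w.size * prune_frac)
    if k:
        thresh = np.partition(np.abs(w).ravel(), k - 1)[k - 1]
        w = np.where(np.abs(w) <= thresh, 0.0, w)
    # Symmetric int6: quantize each row into [-31, 31], store as int8.
    scale = np.maximum(np.abs(w).max(axis=1, keepdims=True) / 31.0, 1e-8)
    q = np.clip(np.round(w / scale), -31, 31).astype(np.int8)
    # Compress the raw int8 bytes (zstd level 22 in the actual pipeline).
    return zlib.compress(q.tobytes(), 9)
```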

Timing

  • Training: 600s (89ms/step, ~6737 steps)
  • Sliding window eval: ~151s
  • TTT (3 epochs): ~465s
  • Total eval: ~616s

Compute

8xH100 SXM, ~20 min/seed. Three seeds for verification.

Copilot AI review requested due to automatic review settings March 23, 2026 12:59

Copilot AI left a comment


Pull request overview

Updates train_gpt.py to a record-grade training/eval pipeline featuring GPTQ-based int6 quantization and “legal” score-first test-time training (TTT), aligning the script with the PR’s reported val_bpb improvements.

Changes:

  • Added GPTQ Hessian calibration + mixed int6/int8 quantization export with optional zstd compression.
  • Expanded the model/training stack (e.g., XSA, smear gate, bigram hash embedding, VE, SWA/EMA, early/late QAT) and updated hyperparameters.
  • Added sliding-window evaluation and score-first TTT evaluation paths.


else:
    self.ve_shared = None
    self.ve_layer_scales = nn.ParameterList()
    self.value_embeds = nn.ModuleList()

Copilot AI Mar 23, 2026


self.value_embeds is initialized but never used (no subsequent reads/writes). If it’s leftover from an earlier design, remove it; if it’s intended for future per-layer VE modules, it should be populated/used to avoid dead code.

Suggested change
self.value_embeds = nn.ModuleList()

Hard stop: To keep readable for newcomers, let's make sure `train_gpt.py` and `train_gpt_mlx.py` never are longer than 1500 lines.
"""

"""V7: GPTQ quantization + Early QAT 0.5 + EMA 0.997. Fork of #503."""

Copilot AI Mar 23, 2026


The PR removes the introductory docstring that documents the intent/constraints of train_gpt.py, and the file now exceeds the previously stated 1500-line hard stop (file ends at line 1579). If this script is meant to remain a newcomer-friendly baseline, consider restoring the guidance and/or moving record-grade logic into /records (or splitting into modules) so train_gpt.py stays under the line limit.

Suggested change
"""V7: GPTQ quantization + Early QAT 0.5 + EMA 0.997. Fork of #503."""
"""Baseline GPT training script.
This file is intended to remain a **newcomer-friendly baseline**: a single,
readable script that can be skimmed top-to-bottom to understand how the model,
data loading, training loop, and evaluation fit together.
Design / usage notes:
- Configuration is driven primarily by environment variables (see
`Hyperparameters` below) so that this file can stay code-only and easily
copy-pasted.
- The implementation favors clarity over chasing absolute record scores.
If you are adding highly experimental, benchmark- or record-focused logic,
prefer putting it into `/records` or separate modules and calling into it
from here instead of inlining large new subsystems.
- As a soft guideline, keep this script to ~1500 lines by factoring out
heavy utilities into dedicated modules. This helps keep the baseline
approachable for new contributors.
Current variant: V7, with GPTQ quantization, early QAT (0.5), and EMA 0.997.
This is a fork/evolution of the earlier #503 baseline.
"""

w32 = self.weight.float()
row_clip = torch.quantile(w32.abs(), 0.9995, dim=1)
scale = (row_clip / 31.0).clamp_min(1.0 / 31.0)
w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype)

Copilot AI Mar 23, 2026


The QAT fake-quant path clamps to [-32, 31], while the rest of the int6 pipeline (e.g., clip_range=31) implies a symmetric [-31, 31] range. This mismatch changes the quantization behavior during QAT vs export and can hurt calibration/roundtrip fidelity. Align the clamp range with the export quantizer (or document why the asymmetric range is intentional).

Suggested change
w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -32, 31) * scale[:, None]).to(x.dtype)
w_q = (torch.clamp(torch.round(w32 / scale[:, None]), -31, 31) * scale[:, None]).to(x.dtype)
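A NumPy stand-in for the fake-quant path above, using the suggested symmetric [-31, 31] clamp so QAT never sees a code (-32) that the int6 exporter cannot emit. The function name is illustrative.

```python
import numpy as np


def fake_quant_sym(w: np.ndarray, clip_q: float = 0.9995) -> np.ndarray:
    """Symmetric int6 fake-quantization with per-row percentile clipping."""
    row_clip = np.quantile(np.abs(w), clip_q, axis=1)
    scale = np.maximum(row_clip / 31.0, 1.0 / 31.0)
    q = np.clip(np.round(w / scale[:, None]), -31, 31)  # symmetric range
    return q * scale[:, None]
```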

Comment on lines +1061 to +1068
if not t.is_floating_point() or t.numel() <= 65536:
    result[name] = t.to(torch.float16) if t.is_floating_point() else t
    meta[name] = "passthrough"
    continue
if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
    result[name] = t.float()
    meta[name] = "passthrough_ctrl"
    continue

Copilot AI Mar 23, 2026


In mixed_quantize_int6_gptq, the early t.numel() <= 65536 passthrough runs before the control-tensor check. This means small control/scaling tensors matching CONTROL_TENSOR_NAME_PATTERNS will be downcast to fp16, losing precision, despite the later intent to keep them fp32. Reorder the checks so control tensors are handled before the size-based passthrough (or explicitly exempt them from downcasting).

Suggested change
if not t.is_floating_point() or t.numel() <= 65536:
    result[name] = t.to(torch.float16) if t.is_floating_point() else t
    meta[name] = "passthrough"
    continue
if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
    result[name] = t.float()
    meta[name] = "passthrough_ctrl"
    continue
# Non-floating tensors are passed through unchanged.
if not t.is_floating_point():
    result[name] = t
    meta[name] = "passthrough"
    continue
# Control/scaling tensors are kept in float32, regardless of size.
if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
    result[name] = t.float()
    meta[name] = "passthrough_ctrl"
    continue
# Small floating-point tensors are downcast to float16 and passed through.
if t.numel() <= 65536:
    result[name] = t.to(torch.float16)
    meta[name] = "passthrough"
    continue
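The ordering issue can be illustrated with a tiny routing function (names and thresholds mirror the snippet above; the dtype strings stand in for actual tensor conversions). A small control tensor such as a 512-element `ln_scale` hits the size passthrough first and is downcast to fp16 unless the control-name check runs before it.

```python
# Hypothetical patterns matching the PR's CONTROL_TENSOR_NAME_PATTERNS idea.
CONTROL_PATTERNS = ("ln_scale", "ve_layer_scales")


def route(name: str, numel: int, control_first: bool) -> str:
    """Return which export path a float tensor takes, given check order."""
    is_ctrl = any(p in name for p in CONTROL_PATTERNS)
    if control_first and is_ctrl:
        return "fp32"  # passthrough_ctrl, checked before the size cutoff
    if numel <= 65536:
        return "fp16"  # small-tensor passthrough
    if is_ctrl:
        return "fp32"  # only reached by large control tensors in the old order
    return "int6"      # quantized path
```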

Comment on lines +1102 to +1109
if not t.is_floating_point() or t.numel() <= 65536:
    result[name] = t.to(torch.float16) if t.is_floating_point() else t
    meta[name] = "passthrough"
    continue
if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
    result[name] = t.float()
    meta[name] = "passthrough_ctrl"
    continue

Copilot AI Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue as above in mixed_quantize_int6: the size-based passthrough happens before the control-tensor check, so small control/scaling tensors will be downcast to fp16. Reorder/exempt control tensors to preserve fp32 precision for these parameters.

Suggested change
if not t.is_floating_point() or t.numel() <= 65536:
    result[name] = t.to(torch.float16) if t.is_floating_point() else t
    meta[name] = "passthrough"
    continue
if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
    result[name] = t.float()
    meta[name] = "passthrough_ctrl"
    continue
# Always preserve floating-point control/scaling tensors in FP32,
# regardless of their size, to avoid precision loss.
if t.is_floating_point() and any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
    result[name] = t.float()
    meta[name] = "passthrough_ctrl"
    continue
if not t.is_floating_point() or t.numel() <= 65536:
    result[name] = t.to(torch.float16) if t.is_floating_point() else t
    meta[name] = "passthrough"
    continue

self._cos_cached = freqs.cos()[None, None, :, :]
self._sin_cached = freqs.sin()[None, None, :, :]
rd = self.rope_dims
if seq_len > self.train_seq_len:

Copilot AI Mar 23, 2026


Rotary.forward computes rd / (rd - 2) when seq_len > train_seq_len. If rope_dims is set to 2 via env/config, this will divide by zero and crash on longer sequences. Add a guard/validation (e.g., require rope_dims > 2 when using the NTK-style base rescaling) or handle the rd==2 case explicitly.

Suggested change
if seq_len > self.train_seq_len:
if seq_len > self.train_seq_len:
    if rd <= 2:
        raise ValueError(
            f"rope_dims must be greater than 2 when using NTK-style RoPE rescaling "
            f"for seq_len ({seq_len}) > train_seq_len ({self.train_seq_len}); got rope_dims={rd}"
        )
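The guard matters because the standard NTK-aware RoPE base rescaling raises the length ratio to the power rd / (rd - 2), which is undefined at rd == 2. A standalone sketch (function name illustrative; formula assumed from the usual NTK-aware scaling, not copied from the PR):

```python
def ntk_base(base: float, seq_len: int, train_seq_len: int, rd: int) -> float:
    """NTK-style RoPE base rescaling for sequences longer than trained."""
    if seq_len <= train_seq_len:
        return base
    if rd <= 2:
        raise ValueError(f"rope_dims must be > 2 for NTK rescaling; got {rd}")
    # Exponent rd / (rd - 2) blows up as rd -> 2, hence the guard above.
    return base * (seq_len / train_seq_len) ** (rd / (rd - 2))
```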

# Background on Muon: https://kellerjordan.github.io/posts/muon/
grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.3))
eval_stride = int(os.environ.get("EVAL_STRIDE", 32))
muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95))

Copilot AI Mar 23, 2026


Hyperparameters.muon_beta2 is defined but never used anywhere in the script (no references beyond the env read). If Muon is meant to support a beta2 term, wire it through; otherwise consider removing this hyperparameter to avoid confusion/misconfiguration.

Suggested change
muon_beta2 = float(os.environ.get("MUON_BETA2", 0.95))
