Submit 1x A100 QAT Fix - 1.4078 BPB (Non-Record) #527
Shuvam-Banerji-Seal wants to merge 2 commits into openai:main
Conversation
Pull request overview
Adds a new non-record submission under records/track_non_record_16mb for a single-A100 QAT run, intended to fix LR scheduling for single-device scaling and avoid a Triton compilation slowdown by switching the QAT clipping estimator to amax.
Changes:
- Introduces a new `train_gpt.py` training script for the submission, including QAT modifications and int6+compression export.
- Adds run artifacts (`train.log`) and submission metadata (`submission.json`, `README.md`).
Reviewed changes
Copilot reviewed 3 out of 4 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| records/track_non_record_16mb/2026-03-23_Single_A100_QAT_FastFix/train_gpt.py | New end-to-end training/eval/export script for the single-A100 QAT submission. |
| records/track_non_record_16mb/2026-03-23_Single_A100_QAT_FastFix/train.log | Captures the run output used to report the submission’s metrics. |
| records/track_non_record_16mb/2026-03-23_Single_A100_QAT_FastFix/submission.json | Declares submission metadata (name, BPB, size, blurb, author, date). |
| records/track_non_record_16mb/2026-03-23_Single_A100_QAT_FastFix/README.md | Documents the submission intent and reported results. |
```python
import numba

try:
    import lz4.frame
except ImportError:
    pass

import os
_COMPRESSOR = os.environ.get("COMPRESSOR", _COMPRESSOR)

import numba
import numba
from numba import njit
```
numba is imported unconditionally, but this script doesn’t appear to rely on it for any executed code path. Since other record scripts don’t require numba, this adds a new hard dependency that can break runs in environments without it; please remove numba/njit usage or make it optional (with a fallback path).
```python
@njit
def fast_byte_parse(data_bytes):
    # This is just a placeholder example to fulfill the numba requirement for fast processing
    pass
```
fast_byte_parse is defined as a @njit function but is unused and contains only pass. If it’s not part of the submission, it should be removed (and the numba dependency dropped); if it is intended, it needs a real implementation and call site.
Suggested change:
```diff
-@njit
-def fast_byte_parse(data_bytes):
-    # This is just a placeholder example to fulfill the numba requirement for fast processing
-    pass
```
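If `numba` usage is intended to stay, one way to make it optional is a no-op fallback decorator, so the script still runs in environments without the dependency. This is a minimal sketch, assuming `@njit` is only used as a plain decorator in this script:

```python
try:
    from numba import njit  # optional speedup; the script must work without it
except ImportError:
    def njit(func=None, **kwargs):
        # Fallback: return the function unchanged so @njit becomes a no-op.
        if func is None:
            return lambda f: f
        return func
```

With this in place, decorated functions behave identically (just uncompiled) when `numba` is absent.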
```python
if _COMPRESSOR == "zstd":
    quant_blob = zstandard.ZstdCompressor(level=22).compress(quant_raw)
elif _COMPRESSOR == "lz4":
    quant_blob = lz4.frame.compress(quant_raw, compression_level=16)
else:
    quant_blob = zlib.compress(quant_raw, 9)
```
_COMPRESSOR can be overridden via the COMPRESSOR env var, but if it’s set to zstd or lz4 in an environment without those optional deps, the later zstandard.* / lz4.frame.* calls will crash. Consider validating availability right after parsing COMPRESSOR (or falling back to zlib with a clear warning/error).
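The validate-then-fall-back pattern suggested here can be sketched as follows. This is a hypothetical standalone `compress_blob` helper (the real script's default `_COMPRESSOR` value and surrounding structure differ); the zlib default here is an assumption:

```python
import os
import warnings
import zlib

try:
    import zstandard
except ImportError:
    zstandard = None
try:
    import lz4.frame
except ImportError:
    lz4 = None

_COMPRESSOR = os.environ.get("COMPRESSOR", "zlib")
# Validate right after parsing the env var so a missing optional dependency
# surfaces here, with a clear warning, rather than crashing deep in the export path.
if _COMPRESSOR == "zstd" and zstandard is None:
    warnings.warn("COMPRESSOR=zstd requested but zstandard is not installed; falling back to zlib")
    _COMPRESSOR = "zlib"
elif _COMPRESSOR == "lz4" and lz4 is None:
    warnings.warn("COMPRESSOR=lz4 requested but lz4 is not installed; falling back to zlib")
    _COMPRESSOR = "zlib"

def compress_blob(raw: bytes) -> bytes:
    if _COMPRESSOR == "zstd":
        return zstandard.ZstdCompressor(level=22).compress(raw)
    if _COMPRESSOR == "lz4":
        return lz4.frame.compress(raw, compression_level=16)
    return zlib.compress(raw, 9)
```

Unknown `COMPRESSOR` values also fall through to zlib, matching the original `else` branch.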
```python
# We must compile for proper scaled_mm, handling simple scaling
# For simplicity in this challenge, we just cast to BF16 or try crude e4m3
pass
```
USE_FP8 is exposed as a flag, but the FP8 branch is currently a no-op (pass). This can lead to confusing behavior where FP8 is “enabled” but silently ignored; either implement the FP8 path, remove the flag, or raise a clear error when USE_FP8=1.
Suggested change:
```diff
-# We must compile for proper scaled_mm, handling simple scaling
-# For simplicity in this challenge, we just cast to BF16 or try crude e4m3
-pass
+# FP8 path is not implemented in this training script.
+# Fail fast instead of silently ignoring USE_FP8 to avoid confusing behavior.
+raise RuntimeError(
+    "USE_FP8=1 was set, but the FP8 path in CastedLinear is not implemented "
+    "in this script. Please unset USE_FP8 or implement FP8 handling."
+)
```
```json
"name": "Single A100 QAT Performance Fix",
"val_bpb": 1.4078,
"bytes_total": 15772699,
"blurb": "Enabled QAT directly within CastedLinear using straight-through estimators. Crucially refactored torch.quantile to .amax(dim=1) to alleviate a 30x compiler performance penalty that was blocking completions. Run fits perfectly in a Single A100 constraint for 10 minutes runtime natively using 2600 steps.",
```
The blurb states the run “fits perfectly … for 10 minutes runtime”, but the included train.log shows training stops at ~600s and then the final sliding-window eval takes ~1,357s (total wall time ~32 min). Please clarify that the 10-minute cap refers to training only, or adjust the default final evaluation mode/stride to stay within the stated runtime.
```markdown
## Results
* **Hardware:** 1x A100 (80GB)
* **Epoch Length:** 10 Minutes (Wallclock Cap - 2600 iterations)
```
README.md claims a 10-minute wallclock cap, but the included train.log shows the final sliding-window evaluation alone takes ~22.6 minutes. Please either qualify that the cap is for training-only or document the end-to-end runtime and/or provide a faster final-eval configuration.
Suggested change:
```diff
-* **Epoch Length:** 10 Minutes (Wallclock Cap - 2600 iterations)
+* **Training Loop Length:** 10 Minutes (Wallclock Cap - 2600 iterations; excludes final sliding-window evaluation)
+* **End-to-End Runtime (Training + Final Sliding-Window Eval):** ~23 Minutes (per `train.log`)
```
```markdown
* **Author:** Shuvam Banerji Seal (https://github.com/Shuvam-Banerji-Seal)
```
The author line is duplicated; please keep it only once to avoid confusion.
Suggested change (keep a single author line):
```markdown
* **Author:** Shuvam Banerji Seal (https://github.com/Shuvam-Banerji-Seal)
```
```python
import os
_COMPRESSOR = os.environ.get("COMPRESSOR", _COMPRESSOR)
import numba

try:
    import lz4.frame
except ImportError:
    pass

import os
_COMPRESSOR = os.environ.get("COMPRESSOR", _COMPRESSOR)

import numba
import numba
```
The import/setup block has duplicated imports and duplicated _COMPRESSOR initialization (multiple import os, repeated try: import lz4.frame, repeated import numba). This is easy to desync and makes the script harder to audit; please consolidate to a single import block and set _COMPRESSOR exactly once.
Suggested change:
```diff
 import os
 _COMPRESSOR = os.environ.get("COMPRESSOR", _COMPRESSOR)
 import numba
 try:
     import lz4.frame
 except ImportError:
     pass
-import os
-_COMPRESSOR = os.environ.get("COMPRESSOR", _COMPRESSOR)
-import numba
-import numba
-_COMPRESSOR = os.environ.get("COMPRESSOR", _COMPRESSOR)
-import numba
```
Single-device (A100) run that tunes hyperparameters down from multi-device scales to ensure proper LR scheduling. Also swaps torch.quantile for w.amax().clamp_min() to evade a 30x compiler performance penalty in Triton.
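The quantile-to-amax swap described above can be illustrated with a small sketch. These are hypothetical helpers (shown with NumPy rather than torch for brevity; the torch form would be `w.abs().amax(dim=1).clamp_min(eps)`), not the submission's actual code:

```python
import numpy as np

def qat_scale_quantile(w: np.ndarray, q: float = 0.999) -> np.ndarray:
    # Original approach: per-row clipping threshold from a high quantile of |w|.
    # In torch, quantile is the op that compiled poorly under Triton.
    return np.quantile(np.abs(w), q, axis=1)

def qat_scale_amax(w: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    # Replacement: per-row absolute max, clamped away from zero so that
    # all-zero rows do not produce a divide-by-zero scale during QAT.
    return np.maximum(np.abs(w).max(axis=1), eps)
```

The amax estimator is cheaper and compiler-friendly at the cost of being more sensitive to outlier weights than a high quantile.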