records/track_10min_16mb/2026-03-29_FusedMLP_Brotli_Memmap/README.md
# Record: Fused MLP (Triton+CUTLASS EVT) + MLP 3.5× + Mixed int5/int6 + Brotli
**val_bpb: 1.1125** (3-seed mean) | **1.8784 nats** | **~14.52 MB** | 8×H100 SXM, 600s | No TTT

Continuation of our merged PR 1019 (current SOTA, 1.1138 BPB). Fused MLP kernels recover the throughput lost to the wider MLP; mechanistic analysis identified the MLP as the capacity bottleneck, motivating the 3.5× MLP expansion, which Hessian-based mixed int5/int6 quantization keeps within the artifact budget.

Our merged PR 1019 (current SOTA): **1.88059 nats** (1.1138 BPB). Delta: **−0.00215 nats** (−0.0013 BPB).
Prior leaderboard SOTA (our PR 549): **1.89002 nats** (1.1194 BPB). Delta: **−0.01158 nats** (−0.0069 BPB). Welch's t = −17.63, df ≈ 3.24, p < 0.01.

## Results (3-seed)

| Seed | Steps | ms/step | Post-EMA BPB | **Sliding BPB** | val_loss (nats) | Artifact (bytes) |
|------|-------|---------|--------------|-----------------|-----------------|----------|
| 314 | 6,844 | 87.7 | 1.1253 | **1.1123** | 1.87802 | 14,519,698 |
| 999 | 6,846 | 87.7 | 1.1256 | **1.1124** | 1.87821 | 14,517,302 |
| 1337 | 6,828 | 87.7 | 1.1261 | **1.1129** | 1.87910 | 14,525,480 |
| **Mean** | **6,839** | **87.7** | | **1.1125** | **1.8784** | |

## Changes vs PR 1019

### 1. Fused Triton TMA Forward MLP Kernel

Fuses `F.linear(x, up_w) -> LeakyReLU(0.5) -> square` into a single Triton TMA kernel. The raw matmul output never hits HBM — activation computed in-register before first store. Backward uses explicit cuBLAS matmuls to preserve torch.compile's cross-layer fusion.
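As a reference for what the kernel computes (a plain-Python sketch of the fused op's semantics, not the Triton TMA implementation; shapes and values here are illustrative):

```python
# Reference semantics of the fused forward path:
# F.linear(x, up_w) -> LeakyReLU(0.5) -> square,
# written out on plain Python lists so the math is explicit.

def linear(x, w):
    # F.linear(x, w) = x @ w.T; x is (M, K), w is (N, K) -> (M, N)
    return [[sum(xi * wi for xi, wi in zip(row, col)) for col in w] for row in x]

def fused_forward(x, up_w):
    pre = linear(x, up_w)
    # LeakyReLU(0.5) then square -- applied in-register in the real kernel,
    # so the raw matmul output never reaches HBM
    return [[(p if p > 0 else 0.5 * p) ** 2 for p in row] for row in pre]

x = [[1.0, -2.0]]                    # (M=1, K=2)
up_w = [[3.0, 0.5], [-1.0, 1.0]]     # (N=2, K=2)
out = fused_forward(x, up_w)         # [[4.0, 2.25]]
```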

Builds on our kernel profiling in PR 670 (abaybektursun).

### 2. CUTLASS EVT Backward MLP Fusion

Fuses `(go @ down_w) * act_grad` into a single CUTLASS 3.x kernel via Epilogue Visitor Tree. The multiply happens in the GEMM epilogue while tiles are still in registers. Uses `KernelTmaWarpSpecializedPingpong` on sm90a with 128x128 tiles.

| Variant | dpre time | vs Unfused |
|---|---|---|
| cuBLAS unfused | 1.199 ms | baseline |
| Triton precomp | 1.130 ms | -0.069 ms |
| **CUTLASS Pingpong** | **1.095 ms** | **-0.104 ms** |

CUTLASS EVT is a hard dependency — no silent fallback.
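For reference, the operation the EVT fuses, sketched unfused in plain Python (illustrative shapes; the real kernel performs the multiply in the GEMM epilogue rather than as a second pass):

```python
# Unfused reference for dpre = (go @ down_w) * act_grad, with down_w as (K, N).

def matmul(a, b):
    # a is (M, K), b is (K, N) -> (M, N)
    return [[sum(a[m][k] * b[k][n] for k in range(len(b)))
             for n in range(len(b[0]))] for m in range(len(a))]

go = [[1.0, 2.0]]            # (M=1, K=2)
down_w = [[1.0, -1.0],       # (K=2, N=2)
          [0.5,  2.0]]
act_grad = [[2.0, 0.5]]      # (M=1, N=2)

gemm = matmul(go, down_w)                 # step 1: the GEMM
dpre = [[g * a for g, a in zip(gr, ar)]   # step 2: fused into the epilogue
        for gr, ar in zip(gemm, act_grad)]
```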

### 3. Pre-Computed Activation Gradient

Store `act_grad = where(pre > 0, 2*pre, 0.5*pre)` in forward instead of `pre`. Zero extra memory cost. Derive `post = 0.5 * act_grad * c0` via algebraic identity. Eliminates `where()` branching from both forward and backward, and enables the CUTLASS EVT to use a trivial 3-node epilogue tree (`multiplies(AccFetch, AuxLoad)`) with no conditional logic.
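The algebra can be checked directly (a sketch; we read `c0` as the pre-activation, written `pre` below). With `f(pre) = LeakyReLU(0.5)(pre)` and `post = f(pre)^2`, the chain rule gives `d(post)/d(pre) = 2*f(pre)*f'(pre)`, which collapses to exactly the stored `act_grad`:

```python
def act_grad(pre):
    # where(pre > 0, 2*pre, 0.5*pre): 2*f(pre)*f'(pre) for f = LeakyReLU(0.5)
    return 2.0 * pre if pre > 0 else 0.5 * pre

def post(pre):
    f = pre if pre > 0 else 0.5 * pre   # LeakyReLU(0.5)
    return f * f

for pre in (-3.0, -0.5, 0.7, 2.0):
    # identity: post = 0.5 * act_grad * pre -- no where() needed to
    # reconstruct the activation from the stored gradient
    assert abs(post(pre) - 0.5 * act_grad(pre) * pre) < 1e-12
```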

### 4. Brotli-11 Compression (replaces LZMA-9)

-581 KB (-5.9%) vs LZMA-9. Independently discovered; also used in PR 1089 (mikeapedia).
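A minimal sketch of the compression switch using the `brotli` package and the stdlib `lzma` (the payload here is a stand-in, not the real packed artifact):

```python
import brotli
import lzma

payload = bytes(range(256)) * 512   # illustrative stand-in for the weight artifact

b = brotli.compress(payload, quality=11)   # quality=11 is Brotli's maximum effort
x = lzma.compress(payload, preset=9)       # the LZMA-9 baseline it replaces

assert brotli.decompress(b) == payload     # lossless round-trip
```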

### 5. Memmap Multi-Shard Data Pipeline + GPU Prefetch

Coprime-stride sampling, daemon thread, CUDA stream prefetch. Credit: DeepReinforce (PR 726).
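The coprime-stride idea, as we understand the PR 726 technique, can be sketched in a few lines: stepping through `n` indices with a stride coprime to `n` visits every index exactly once before repeating, giving a cheap deterministic shuffle with no shuffle buffer:

```python
from math import gcd

def coprime_stride_order(n, stride):
    # gcd(n, stride) == 1 guarantees (i * stride) % n is a permutation of range(n)
    assert gcd(n, stride) == 1, "stride must be coprime to n"
    return [(i * stride) % n for i in range(n)]

order = coprime_stride_order(10, 7)
# order visits all of 0..9 exactly once, in a scrambled order
```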

## Negative Results

- **Turbo-Muon (AOL + Polar Express NS4):** +0.0018 BPB worse on 8×H100, and the artifact exceeds 16 MB. Its early convergence advantage at step 500 does not hold at 7000+ steps. Reverted to standard NS5.
- **2:4 Structured Sparsity:** +0.672 BPB. Dead.

## Architecture

| Component | Setting | Source |
|-----------|---------|--------|
| Layers | 11 (512d, 8 GQA / 4 KV heads) | Baseline |
| MLP | 3x (1536), LeakyReLU(0.5)^2 | PR 493 (parinzee) |
| MLP Forward | **Fused Triton TMA kernel** | **This work** (profiling: our PR 670) |
| MLP Backward | **CUTLASS EVT Pingpong + pre-computed act_grad** | **This work** |
| Attention | XSA on all 11 layers | PR 478 (gowtham0992) |
| BigramHash | 3072 x 112 | Our PR 1019 (concept: PR 162 (raahilshah)) |
| RoPE | Partial (16/64 dims) | PR 315 (jfprincz) |
| LN Scale | 1/sqrt(layer+1) | PR 315 (jfprincz) |
| VE128 | Layers 9-10 | PR 374 (unnir) |
| SmearGate | Position-mixing gate | PR 65 (aquariouseworkman) |
| U-Net skips | Encoder-decoder connections | PR 289 |
| Weight avg | EMA(0.997) + SWA(every 50) | PR 401 (newjordan) |
| Quantization | Full Hessian GPTQ int6 (AR self-gen calibration) | Our PR 1019 (GPTQ: PR 535 (raahilshah)) |
| Compression | **Brotli quality=11** | **This work** (independently: PR 1089 (mikeapedia)) |
| Data Pipeline | **Memmap multi-shard + GPU prefetch** | PR 726 (DeepReinforce) |
| Warmdown | 4000 iterations | PR 364 (shikhar1729) |
| Optimizer | Parallel Muon (NS5) | Our PR 399 |
| Late QAT | STE at LR scale < 0.15 | PR 286 (chris-buckley) |
| Selective pruning | +/-1 by reconstruction error | PR 609 (saml212) |
| Flash Attention 3 | Hopper kernels | PR 122 (mtybadger) |

**Calibration legality:** AR self-generated (64 seqs x 2048 tokens, temp=0.8). No val data, no train data accessed during quantization. Same method as our PR 1019.
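To make the quantization value ranges concrete, here is a minimal symmetric int6 round-trip (a sketch only; this is not the Hessian-based GPTQ procedure used for the record, just the per-value arithmetic it bottoms out in):

```python
# int6 spans [-32, 31]; int5 spans [-16, 15].

def quantize_int6(ws):
    # symmetric per-group scale so the largest magnitude maps near +/-31
    scale = max(abs(w) for w in ws) / 31.0
    q = [max(-32, min(31, round(w / scale))) for w in ws]
    return q, scale

def dequantize(q, scale):
    return [qi * scale for qi in q]

ws = [0.9, -0.31, 0.05, -0.88]
q, scale = quantize_int6(ws)
recon = dequantize(q, scale)
# rounding error is bounded by half a quantization step
assert max(abs(a - b) for a, b in zip(ws, recon)) <= scale / 2 + 1e-9
```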

## Setup & Reproduction

```bash
# 1. Python dependencies
pip install torch==2.9.1 --index-url https://download.pytorch.org/whl/cu128
pip install flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291
pip install sentencepiece brotli

# 2. CUTLASS headers (header-only, no build needed for CUTLASS itself)
cd /opt && git clone --depth 1 --branch v3.7.0 https://github.com/NVIDIA/cutlass

# 3. Build CUTLASS EVT extension
cd cutlass_evt_fusion
CUTLASS_PATH=/opt/cutlass python3 setup.py build_ext --inplace
cd ..

# 4. Set library paths (auto-detect from Python packages)
export LD_LIBRARY_PATH=$(python3 -c "import torch; print(torch.__path__[0] + '/lib')"):$(python3 -c "import nvidia.cuda_runtime; print(nvidia.cuda_runtime.__path__[0] + '/lib')"):${LD_LIBRARY_PATH:-}

# 5. Download data
python3 data/cached_challenge_fineweb.py --variant sp1024

# 6. Train (3 seeds)
for SEED in 314 999 1337; do
  BIGRAM_VOCAB_SIZE=3072 BIGRAM_DIM=112 WARMDOWN_ITERS=4000 SEED=$SEED \
    torchrun --standalone --nproc_per_node=8 train_gpt.py 2>&1 | tee train_seed${SEED}.log
done
```
cutlass_evt_fusion/csrc/gemm_act_grad.cu
// CUTLASS 3.x EVT kernel: fused GEMM * elementwise multiply
// Computes: dpre = (go @ down_w.T) * act_grad
// Where act_grad = f'(pre) is pre-computed in the forward pass.
//
// Layout convention:
// go: (M, K) bf16 row-major
// down_w: (K, N) bf16 row-major — CUTLASS B(N,K) with RowMajor layout
// act_grad: (M, N) bf16 row-major
// dpre: (M, N) bf16 row-major output

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp"
#include "cute/tensor.hpp"
#include "cutlass/util/packed_stride.hpp"
#include <cstdlib>   // EXIT_FAILURE
#include <iostream>

using namespace cute;

// --- Type aliases ---

using ElementAcc = float;
using ElementCompute = float;
using ElementOutput = cutlass::bfloat16_t;
using ElementAux = cutlass::bfloat16_t;

using namespace cutlass::epilogue::fusion;

// --- Tile / schedule configuration ---

using TileShape = Shape<_128, _256, _64>;
using ClusterShape = Shape<_1, _1, _1>;
using EpilogueTile = cutlass::epilogue::collective::EpilogueTileAuto;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;

// --- Resolve AuxLoad types via EpilogueDescriptor ---

using EpiDesc = cutlass::epilogue::collective::detail::EpilogueDescriptor<
    TileShape, EpilogueTile, ElementOutput, ElementOutput, EpilogueSchedule>;

using AuxDesc = cutlass::epilogue::collective::detail::AuxLoadDescriptor<
    EpiDesc, cutlass::layout::RowMajor, ElementAux>;

// --- EVT tree: acc * aux_load (builtin multiply) ---

using AuxLoad = Sm90AuxLoad<
    AuxDesc::Stages,
    typename EpiDesc::EpilogueTile,
    typename AuxDesc::Element,
    typename AuxDesc::Stride,
    typename AuxDesc::SmemLayoutAtom,
    typename AuxDesc::CopyOpS2R>;

// Compute node: builtin multiply(acc, act_grad)
using Compute = Sm90Compute<
    cutlass::multiplies,
    ElementOutput,
    ElementCompute,
    cutlass::FloatRoundStyle::round_to_nearest>;

// Tree: root = Multiply(child0 = AccFetch, child1 = AuxLoad)
using EVT = Sm90EVT<Compute, Sm90AccFetch, AuxLoad>;

// --- CollectiveBuilder + Kernel type ---

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90,
    cutlass::arch::OpClassTensorOp,
    TileShape,
    ClusterShape,
    EpilogueTile,
    ElementAcc, ElementCompute,
    ElementOutput, cutlass::layout::RowMajor, /* AlignC */ 8,
    ElementOutput, cutlass::layout::RowMajor, /* AlignD */ 8,
    EpilogueSchedule,
    EVT
>::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90,
    cutlass::arch::OpClassTensorOp,
    ElementOutput, cutlass::layout::RowMajor, /* AlignA */ 8,
    ElementOutput, cutlass::layout::RowMajor, /* AlignB */ 8,
    ElementAcc,
    TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAutoCarveout<
        sizeof(typename CollectiveEpilogue::SharedStorage)>,
    cutlass::gemm::KernelTmaWarpSpecializedCooperative
>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    Shape<int, int, int, int>,
    CollectiveMainloop,
    CollectiveEpilogue>;

using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

// --- Host launcher ---

void launch_gemm_mul(
    void const* ptr_go,        // (M, K) bf16 row-major
    void const* ptr_down_w,    // (K, N) bf16 row-major = RowMajor B(N,K) for CUTLASS
    void const* ptr_act_grad,  // (M, N) bf16 row-major
    void* ptr_dpre,            // (M, N) bf16 row-major output
    int M, int N, int K,
    cudaStream_t stream)
{
    using StrideA = cutlass::gemm::TagToStrideA_t<cutlass::layout::RowMajor>;
    using StrideB = cutlass::gemm::TagToStrideB_t<cutlass::layout::RowMajor>;
    using StrideC = cutlass::gemm::TagToStrideC_t<cutlass::layout::RowMajor>;

    int L = 1;
    auto prob_shape = make_shape(M, N, K, L);

    auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
    auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
    auto stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
    auto stride_Aux = cutlass::make_cute_packed_stride(
        typename AuxDesc::Stride{}, cute::make_shape(M, N, L));

    typename EVT::Arguments evt_args {
        {},  // Sm90AccFetch: no args
        {    // Sm90AuxLoad: pointer + null_default + stride
            static_cast<ElementAux const*>(ptr_act_grad),
            ElementAux(0),
            stride_Aux
        },
        {}   // Sm90Compute (multiplies): no args
    };

    typename GemmOp::Arguments args {
        cutlass::gemm::GemmUniversalMode::kGemm,
        prob_shape,
        {  // Mainloop
            static_cast<ElementOutput const*>(ptr_go),
            stride_A,
            static_cast<ElementOutput const*>(ptr_down_w),
            stride_B,
        },
        {  // Epilogue: {thread_args, ptr_C, stride_C, ptr_D, stride_D}
            evt_args,
            static_cast<ElementOutput const*>(ptr_dpre),  // ptr_C (unused but TMA needs valid ptr)
            stride_C,
            static_cast<ElementOutput*>(ptr_dpre),        // ptr_D (output)
            stride_C,
        }
    };

    GemmOp gemm_op;
    size_t workspace_size = GemmOp::get_workspace_size(args);
    void* workspace = nullptr;
    if (workspace_size > 0) {
        cudaMalloc(&workspace, workspace_size);
    }

    auto status = gemm_op.initialize(args, workspace, stream);
    if (status != cutlass::Status::kSuccess) {
        std::cerr << "CUTLASS initialize failed: " << cutlassGetStatusString(status) << std::endl;
        if (workspace) cudaFree(workspace);
        exit(EXIT_FAILURE);
    }

    status = gemm_op.run(stream);
    if (status != cutlass::Status::kSuccess) {
        cudaError_t cuda_err = cudaStreamSynchronize(stream);
        std::cerr << "CUTLASS run failed: " << cutlassGetStatusString(status)
                  << " CUDA: " << cudaGetErrorString(cuda_err) << std::endl;
        if (workspace) cudaFree(workspace);
        exit(EXIT_FAILURE);
    }

    if (workspace) cudaFree(workspace);
}
cutlass_evt_fusion/csrc/torch_binding.cpp
// PyTorch C++ extension: CUTLASS EVT fused GEMM * elementwise multiply
// dpre = (go @ down_w.T) * act_grad
// Pass down_w directly (K, N) — NOT down_w.T.contiguous()

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>  // at::cuda::getCurrentCUDAStream
#include <c10/cuda/CUDAStream.h>

void launch_gemm_mul(
void const*, void const*, void const*, void*, int, int, int, cudaStream_t);

at::Tensor gemm_mul(at::Tensor go, at::Tensor down_w, at::Tensor act_grad) {
    TORCH_CHECK(go.is_cuda() && go.is_contiguous());
    TORCH_CHECK(down_w.is_cuda() && down_w.is_contiguous());
    TORCH_CHECK(act_grad.is_cuda() && act_grad.is_contiguous());
    TORCH_CHECK(go.scalar_type() == at::kBFloat16);
    TORCH_CHECK(down_w.scalar_type() == at::kBFloat16);
    TORCH_CHECK(act_grad.scalar_type() == at::kBFloat16);

    int M = go.size(0);
    int K = go.size(1);
    int N = down_w.size(1);  // down_w is (K, N) row-major

    TORCH_CHECK(down_w.size(0) == K,
        "K mismatch: go has K=", K, " but down_w has size(0)=", down_w.size(0));
    TORCH_CHECK(act_grad.size(0) == M && act_grad.size(1) == N,
        "act_grad shape must be (M, N), got (", act_grad.size(0), ", ", act_grad.size(1), ")");

    at::Tensor dpre = at::empty({M, N}, go.options());

    launch_gemm_mul(
        go.data_ptr(), down_w.data_ptr(), act_grad.data_ptr(), dpre.data_ptr(),
        M, N, K,
        at::cuda::getCurrentCUDAStream());

    return dpre;
}

TORCH_LIBRARY(cutlass_evt, m) {
    m.def("gemm_mul(Tensor go, Tensor down_w, Tensor act_grad) -> Tensor");
}

TORCH_LIBRARY_IMPL(cutlass_evt, CUDA, m) {
    m.impl("gemm_mul", &gemm_mul);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
cutlass_evt_fusion/setup.py
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import os

CUTLASS_PATH = os.environ.get("CUTLASS_PATH", "/opt/cutlass")

setup(
    name="cutlass_evt_fusion",
    ext_modules=[
        CUDAExtension(
            name="cutlass_evt_fusion",
            sources=[
                "csrc/gemm_act_grad.cu",
                "csrc/torch_binding.cpp",
            ],
            include_dirs=[
                f"{CUTLASS_PATH}/include",
                f"{CUTLASS_PATH}/tools/util/include",
            ],
            extra_compile_args={
                "nvcc": [
                    "-std=c++17",
                    "-arch=sm_90a",
                    "-O3",
                    "--use_fast_math",
                    "--expt-relaxed-constexpr",
                    "-DNDEBUG",
                    "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
                ],
            },
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)
# FlashAttention 3 must be installed separately; see README.md
sentencepiece
brotli>=1.1
{
  "name": "Fused MLP (Triton+CUTLASS EVT) + MLP 3.5× + Mixed int5/int6 + Brotli",
  "author": "Abay Bektursun",
  "github_id": "abaybektursun",
  "date": "2026-03-30",
  "val_loss": 1.87844412,
  "val_bpb": 1.11252336,
  "bytes_total": 14525480,
  "blurb": "Fused Triton TMA forward + CUTLASS EVT backward MLP kernels, pre-computed activation gradient, MLP 3.5x (1792 hidden dim, motivated by SVD analysis showing 94.4% MLP utilization), Hessian-based mixed int5/int6 quantization (motivated by per-matrix quant sensitivity showing MLP = 80% of damage), Brotli-11 compression, LR floor 0.05, memmap multi-shard pipeline. AR self-gen GPTQ. 3-seed mean (314/999/1337): 1.1125 BPB / 1.8784 nats. Delta vs prior leaderboard SOTA: -0.0116 nats. Welch's t=-17.63, p<0.01."
}