Commit e9c7bea

[moe training] update readme with links, cleanup (#3239)
1 parent 53b5efd commit e9c7bea

File tree: 1 file changed (+40, -35 lines)


torchao/prototype/moe_training/README.md

Lines changed: 40 additions & 35 deletions
@@ -2,32 +2,38 @@
 
 This prototype provides:
 
-1. Quantized building block for low precision MoE training: `_scaled_grouped_mm`. It is a differentiable drop-in replacement for `torch._grouped_mm` that dynamically quantizes inputs using the given recipe, performs a scaled grouped GEMM, then returns the results in original precision. See runnable [example](#torchao_scaled_grouped_mm-example-forward--backward-pass) of a forward and backward pass below.
+1. Quantized building block for low precision MoE training: [_quantize_then_scaled_grouped_mm](https://github.com/pytorch/ao/blob/53b5efdac921a38fd15e8d3ac8191c3927140287/torchao/prototype/moe_training/scaled_grouped_mm.py#L42). It is a differentiable drop-in replacement for `torch._grouped_mm` that dynamically quantizes inputs using the given recipe, performs a scaled grouped GEMM, then returns the results in original precision. See runnable [example](#torchao_scaled_grouped_mm-example-forward--backward-pass) of a forward and backward pass below.
     - Using MXFP8 on a B200 GPU, this provides:
-        - ~1.4x - 1.8x speedups over bfloat16 `torch._grouped_mm` for Llama4 17b 16e shapes (depending on the `M` dimension, i.e. batch_size * seq_len)
-        - ~1.15 - 1.3x speedups over bfloat16 `torch._grouped_mm` for DeepSeekV3 671b shapes (depending on the `M` dimension, i.e. batch_size * seq_len)
+        - **~1.4x - 1.8x speedups** over bfloat16 `torch._grouped_mm` for Llama4 Scout shapes
+        - **~1.15 - 1.3x speedups** over bfloat16 `torch._grouped_mm` for DeepSeekV3 671b shapes
+    - We also provide the following convenience functions for specific recipes:
+        - [_to_mxfp8_then_scaled_grouped_mm](https://github.com/pytorch/ao/blob/53b5efdac921a38fd15e8d3ac8191c3927140287/torchao/prototype/moe_training/scaled_grouped_mm.py#L677)
+        - [_to_fp8_rowwise_then_scaled_grouped_mm](https://github.com/pytorch/ao/blob/53b5efdac921a38fd15e8d3ac8191c3927140287/torchao/prototype/moe_training/scaled_grouped_mm.py#L678)
 
 
-2. [TorchTitan](https://github.com/pytorch/torchtitan/tree/main) integration of torchao's dynamically quantized `_scaled_grouped_mm`: pretrain DeepSeekV3/Llama4 with MXFP8 grouped GEMMs by adding the flag to your training command: `--model.converters="quantize.grouped_mm.mx" [--quantize.grouped_mm.mx.fqns="experts"]`
 
-3. `quantize_(...)` API support for model conversion: this swaps all `torch._grouped_mm` ops in your model definition to use torchao `_scaled_grouped_mm` under the hood (see [example](#model-conversion-api-example-end-to-end-training) below).
+2. [TorchTitan](https://github.com/pytorch/torchtitan/tree/main) integration: pretrain DeepSeekV3/Llama4 with MXFP8 grouped GEMMs by adding the flag to your training command: `--model.converters="quantize.grouped_mm.mx" --quantize.grouped_mm.mx.fqns="experts"`
+
+3. Model conversion API to swap all `torch._grouped_mm` ops in your model definition to use torchao `_quantize_then_scaled_grouped_mm` under the hood (see [example](#model-conversion-api-example-end-to-end-training) below).
 
 
 ## Table of Contents
 
 - [Examples](#examples)
-- [Performance Benchmarks](#performance-benchmarks-mxfp8)
 - [System Requirements](#system-requirements)
+- [Microbenchmarks](#microbenchmarks)
+- [Single MoE layer benchmarks](#benchmark-single-moe-layer-forward--backward-pass)
+- [E2E training benchmarks](#end-to-end-training-benchmark-with-torchtitan-llama4-scout-vs-bfloat16-baseline)
 - [Implementation Details for Developers](#implementation-details-for-developers)
 - [Limitations](#limitations)
 
 ## Examples
-#### torchao_scaled_grouped_mm example: forward + backward pass
+#### _quantize_then_scaled_grouped_mm usage
 ```python
 import torch
 from torch.nn import functional as F
 from torchao.prototype.moe_training import (
-    _quantize_then_scaled_grouped_mm as torchao_scaled_grouped_mm
+    _quantize_then_scaled_grouped_mm
 )
 from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 from torchao.prototype.moe_training.utils import generate_jagged_offs
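Note: the hunks above and below only show the changed fragments of this usage example. For reference, a stitched-together, self-contained sketch of the same forward + backward flow follows; the tensor shapes and the `scaling_type=MoEScalingType.MXFP8` keyword are assumptions for illustration, and the full example in the README itself is authoritative.

```python
# Hedged sketch of the forward + backward flow shown in fragments in this diff.
# Shapes and the `scaling_type` keyword name are assumptions; see the README example.
import torch
from torchao.prototype.moe_training import _quantize_then_scaled_grouped_mm
from torchao.prototype.moe_training.conversion_utils import MoEScalingType
from torchao.prototype.moe_training.utils import generate_jagged_offs

num_groups, total_M, N, K = 8, 16384, 8192, 5120  # assumed Llama4-like shapes

# A holds the routed tokens (2D), B holds per-expert weights (3D).
A = torch.randn(total_M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
B = torch.randn(num_groups, N, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)

# Jagged group boundaries along the M dimension, one offset per expert group.
offs = generate_jagged_offs(num_groups, total_M, device="cuda")

# Forward: inputs are dynamically quantized, a scaled grouped GEMM runs,
# and the result comes back in the original bfloat16 precision.
out = _quantize_then_scaled_grouped_mm(
    A,
    B.transpose(-2, -1),
    offs=offs,
    scaling_type=MoEScalingType.MXFP8,  # assumed keyword name for the recipe
)

# Backward: gradients flow to both A and B through the differentiable op.
out.sum().backward()
```

Because the op is differentiable, it can be dropped into an existing MoE forward pass without changing the surrounding autograd code.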
@@ -42,7 +48,7 @@ B = torch.randn(num_groups, N, K, dtype=torch.bfloat16, device="cuda", requires_
 offs = generate_jagged_offs(num_groups, total_M, device="cuda")
 
 # Forward and backward example
-out = torchao_scaled_grouped_mm(
+out = _quantize_then_scaled_grouped_mm(
     A,
     B.transpose(-2, -1),
     offs=offs,
@@ -127,39 +133,18 @@ for step in range(10):
     # backward pass
     out_loss.backward()
     optimizer.step()
-
+    optimizer.zero_grad()
 ```
 
 ## System requirements
 - torchao 0.14+
 - For MXFP8 MoE training, CUDA 12.8+ and SM100+ GPU arch are required.
 - For FP8 rowwise MoE training, CUDA 12.4+ and SM89+ GPU arch are required.
 
-## Performance benchmarks: MXFP8
-
-
-### Single MoE layer forward + backward pass vs bfloat16 baseline
-
-| Model | total_M | N | K | bf16 time (ms) | mxfp8 time (ms) | speedup |
-|--------------|---------|------|------|---------------|-----------------|---------|
-| Llama4 16e | 131072 | 8192 | 5120 | 275.270 | 192.420 | 1.431x |
-| DeepSeekV3 | 131072 | 2048 | 7168 | 92.032 | 80.182 | 1.148x |
-
-To reproduce these benchmarks, on a B200 GPU machine, run the following commands:
-
-Llama4 17b 16e shapes:
-```bash
-CUDA_VISIBLE_DEVICES=6 python benchmarks/prototype/moe_training/bench_moe_layer.py --recipe mxfp8 --local_batch_size=16 --dim=5120 --hidden_dim=8192 --local_num_experts=8
-```
-
-DeepSeekV3 671b shapes:
-```bash
-CUDA_VISIBLE_DEVICES=6 python benchmarks/prototype/moe_training/bench_moe_layer.py --recipe mxfp8 --local_batch_size=16 --dim=7168 --hidden_dim=2048 --local_num_experts=8
-```
 
-### Individual bfloat16 torch._grouped_mm op vs torchao_scaled_grouped_mm
+## Microbenchmarks
 
-**MXFP8 with Llama4 17b 16e shapes** (with G=1-8 to simulate different degrees of expert parallelism)
+**MXFP8 with Llama4 17b 16e shapes** (with G from 1 to 8 to simulate different degrees of expert parallelism)
 
 | M,N,K,G | bf16_fwd_bwd_us | scaled_fwd_bwd_us | scaled_fwd_bwd_speedup |
 | ----------------------- | --------------: | ----------------: | ---------------------: |
@@ -168,7 +153,7 @@ CUDA_VISIBLE_DEVICES=6 python benchmarks/prototype/moe_training/bench_moe_layer.
 | (128000, 8192, 5120, 4) | 39189.20 | 23945.50 | 1.637x |
 | (128000, 8192, 5120, 8) | 37700.70 | 22170.60 | 1.700x |
 
-**MXFP8 with DeepSeekV3** (with G=-8 to simulate different degrees of expert parallelism)
+**MXFP8 with DeepSeekV3** (with G from 1 to 8 to simulate different degrees of expert parallelism)
 
 | M,N,K,G | bf16_fwd_bwd_us | scaled_fwd_bwd_us | scaled_fwd_bwd_speedup |
 | ----------------------- | --------------: | ----------------: | ---------------------: |
@@ -183,8 +168,28 @@ To reproduce this benchmark, on a B200 GPU machine, run the following command:
 - torchao: `0.14.0+gitc7b8e13da`
 - torch: `2.10.0a0+gitf6de195`
 
+## Benchmark: single MoE layer forward + backward pass
+
+| Model | total_M | N | K | bf16 time (ms) | mxfp8 time (ms) | speedup |
+|--------------|---------|------|------|---------------|-----------------|---------|
+| Llama4 16e | 131072 | 8192 | 5120 | 275.270 | 192.420 | 1.431x |
+| DeepSeekV3 | 131072 | 2048 | 7168 | 92.032 | 80.182 | 1.148x |
+
+To reproduce these benchmarks, on a B200 GPU machine, run the following commands:
+
+Llama4 17b 16e shapes:
+```bash
+CUDA_VISIBLE_DEVICES=6 python benchmarks/prototype/moe_training/bench_moe_layer.py --recipe mxfp8 --local_batch_size=16 --dim=5120 --hidden_dim=8192 --local_num_experts=8
+```
+
+DeepSeekV3 671b shapes:
+```bash
+CUDA_VISIBLE_DEVICES=6 python benchmarks/prototype/moe_training/bench_moe_layer.py --recipe mxfp8 --local_batch_size=16 --dim=7168 --hidden_dim=2048 --local_num_experts=8
+```
+
+
 
-#### End-to-end training: Llama4 16e MoE layer vs bfloat16 baseline with TorchTitan
+## End-to-end training benchmark with TorchTitan: Llama4 Scout vs bfloat16 baseline
 - Single node benchmarks with 4xB200
 - Llama4 16e default configs; FSDP=4, EP=4; AC=none; compile=True; seq_len=8192; local_bs=8
 - Reduced num layers from 48 -> 2 to avoid OOM in single node setting
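
Note on item 3 above (the model conversion API): the README's end-to-end conversion example is referenced but not shown in this diff. Below is a rough, hypothetical sketch of how such a swap could be wired up with `quantize_`, assuming a `MoETrainingConfig` class (and its `scaling_type` argument) in `conversion_utils`; those names are not confirmed by this diff, and the README's own example is authoritative.

```python
# Hypothetical sketch only: `MoETrainingConfig` and its `scaling_type` argument
# are assumed names; see the README's model conversion example for the real API.
import torch
from torch import nn
from torchao.quantization import quantize_
from torchao.prototype.moe_training.conversion_utils import (
    MoEScalingType,
    MoETrainingConfig,  # assumed config class name
)


class TinyExperts(nn.Module):
    """Toy stand-in for an MoE experts module that calls torch._grouped_mm."""

    def __init__(self, num_experts: int, dim: int, hidden_dim: int):
        super().__init__()
        self.w = nn.Parameter(
            torch.randn(num_experts, hidden_dim, dim, dtype=torch.bfloat16)
        )

    def forward(self, x: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
        # (total_tokens, dim) tokens, grouped along the M dimension by offs.
        return torch._grouped_mm(x, self.w.transpose(-2, -1), offs=offs)


class TinyMoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.experts = TinyExperts(num_experts=4, dim=512, hidden_dim=1024)

    def forward(self, x, offs):
        return self.experts(x, offs)


model = TinyMoE().cuda()

# Swap torch._grouped_mm in modules whose FQN contains "experts", mirroring the
# `--quantize.grouped_mm.mx.fqns="experts"` TorchTitan flag in item 2 above.
# MXFP8 requires CUDA 12.8+ and an SM100+ GPU per the system requirements.
quantize_(
    model,
    MoETrainingConfig(scaling_type=MoEScalingType.MXFP8),
    filter_fn=lambda mod, fqn: "experts" in fqn,
)
```

After conversion, the experts' grouped GEMMs run through torchao's `_quantize_then_scaled_grouped_mm` under the hood, as described in item 3.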
