This prototype provides:

1. Quantized building block for low precision MoE training: [`_quantize_then_scaled_grouped_mm`](https://github.com/pytorch/ao/blob/53b5efdac921a38fd15e8d3ac8191c3927140287/torchao/prototype/moe_training/scaled_grouped_mm.py#L42). It is a differentiable drop-in replacement for `torch._grouped_mm` that dynamically quantizes inputs using the given recipe, performs a scaled grouped GEMM, then returns the results in the original precision. See the runnable [example](#torchao_scaled_grouped_mm-example-forward--backward-pass) of a forward and backward pass below, and the minimal sketch after this list.
   - Using MXFP8 on a B200 GPU, this provides:
     - **~1.4x - 1.8x speedups** over bfloat16 `torch._grouped_mm` for Llama4 Scout shapes (depending on the `M` dimension, i.e. batch_size * seq_len)
     - **~1.15x - 1.3x speedups** over bfloat16 `torch._grouped_mm` for DeepSeekV3 671b shapes (likewise depending on the `M` dimension)
   - We also provide convenience functions for specific recipes.

2. [TorchTitan](https://github.com/pytorch/torchtitan/tree/main) integration: pretrain DeepSeekV3/Llama4 with MXFP8 grouped GEMMs by adding the flags to your training command: `--model.converters="quantize.grouped_mm.mx" --quantize.grouped_mm.mx.fqns="experts"`

3. Model conversion API to swap all `torch._grouped_mm` ops in your model definition to use torchao `_quantize_then_scaled_grouped_mm` under the hood (see the [example](#model-conversion-api-example-end-to-end-training) below, and the conversion sketch after this list).
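
As a quick orientation before the full runnable example referenced in item 1, here is a minimal sketch of a forward and backward pass through `_quantize_then_scaled_grouped_mm`. The tensor shapes are illustrative, and the `offs`/`out_dtype` keyword names are assumed to mirror `torch._grouped_mm`; check [scaled_grouped_mm.py](https://github.com/pytorch/ao/blob/53b5efdac921a38fd15e8d3ac8191c3927140287/torchao/prototype/moe_training/scaled_grouped_mm.py) for the exact signature and the available scaling recipes.

```python
import torch

from torchao.prototype.moe_training.scaled_grouped_mm import (
    _quantize_then_scaled_grouped_mm,
)

# 2D activations: total_M routed tokens with model dim K.
# 3D expert weights: one (N, K) matrix per expert, passed transposed.
total_M, K, N, n_experts = 256, 1024, 2048, 4
A = torch.randn(total_M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
B = torch.randn(n_experts, N, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)

# offs[i] marks the end of expert i's token group along the M dimension.
offs = torch.tensor([64, 128, 192, 256], dtype=torch.int32, device="cuda")

# Forward: inputs are dynamically quantized per the recipe, a scaled grouped
# GEMM runs in low precision, and the result comes back in bfloat16.
out = _quantize_then_scaled_grouped_mm(
    A, B.transpose(-2, -1), offs=offs, out_dtype=torch.bfloat16
)

# Backward: the op is differentiable, so gradients flow to A and B as usual.
out.sum().backward()
print(A.grad.shape, B.grad.shape)
```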
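
For the model conversion API in item 3, the sketch below shows how such a conversion is typically wired up with torchao's `quantize_`. It assumes the prototype exposes an `MoETrainingConfig` in `conversion_utils.py`; the toy module, the config name, and its defaults are stand-ins for illustration rather than the canonical integration, so consult that module and the end-to-end example below for the real API.

```python
import torch
from torch import nn

from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization import quantize_

# Toy stand-in for an MoE experts module whose forward calls torch._grouped_mm.
class ToyExperts(nn.Module):
    def __init__(self, n_experts: int, dim: int):
        super().__init__()
        self.w = nn.Parameter(
            torch.randn(n_experts, dim, dim, dtype=torch.bfloat16)
        )

    def forward(self, x: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
        return torch._grouped_mm(x, self.w.transpose(-2, -1), offs=offs)

model = ToyExperts(n_experts=4, dim=1024).cuda()

# Convert only the experts modules; everything else is left untouched.
def filter_fn(module: nn.Module, fqn: str) -> bool:
    return isinstance(module, ToyExperts)

# After conversion, torch._grouped_mm calls inside matching modules dispatch
# to torchao's _quantize_then_scaled_grouped_mm.
quantize_(model, config=MoETrainingConfig(), filter_fn=filter_fn)
```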