
Commit e8acfbb

[mxfp8 moe training] integrate triton quant/dequant kernels into mxfp8 all to all
Parent: 82ded0b

File tree

2 files changed, +63 -52 lines
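In short, the commit swaps the eager `to_mx` / `to_dtype` quantize/dequantize path in the mxfp8 all-to-all for the triton kernels `triton_to_mxfp8_dim0` / `triton_mxfp8_dequant_dim0`, and restructures the benchmark to time forward and backward separately. A minimal sketch of the kernel-call change, using only the signatures visible in the diffs below (the input tensor, its shape, and the CUDA device are illustrative assumptions):

```python
import torch

from torchao.prototype.mx_formats.kernels import (
    triton_mxfp8_dequant_dim0,
    triton_to_mxfp8_dim0,
)

block_size = 32
x = torch.randn(1024, 8192, dtype=torch.bfloat16, device="cuda")  # illustrative shape

# Quantize. Note the (data, scales) return order, flipped relative to to_mx,
# which returned (scales, data).
x_data, x_scales = triton_to_mxfp8_dim0(x, inner_block_size=block_size)

# Dequantize. The triton kernel takes the scales as-is (no float8_e8m0fnu view)
# plus the target high-precision dtype and the block size.
x_hp = triton_mxfp8_dequant_dim0(x_data, x_scales, torch.bfloat16, block_size)
```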

benchmarks/prototype/moe_training/mxfp8/bench_all_to_all_v.py

Lines changed: 48 additions & 31 deletions

@@ -42,8 +42,10 @@ class ExperimentConfig:
 
 @dataclass(frozen=True)
 class ExperimentResult:
-    bf16_ms: float
-    mxfp8_ms: float
+    bf16_fwd_ms: float
+    mxfp8_fwd_ms: float
+    bf16_bwd_ms: float
+    mxfp8_bwd_ms: float
 
 
 @dataclass(frozen=True)
@@ -67,7 +69,7 @@ def get_configs() -> List[ExperimentConfig]:
     return configs
 
 
-def default_a2a_fwd_bwd(
+def default_a2a_fwd(
     routed_input: torch.Tensor,
     labels: torch.Tensor,
     output_splits_list: list[int],
@@ -81,15 +83,11 @@ def default_a2a_fwd_bwd(
         device_mesh.get_group(),
     )
     routed_input = torch.ops._c10d_functional.wait_tensor(routed_input)
-
-    loss = F.mse_loss(routed_input, labels)
-    loss.backward()
-
     torch.cuda.synchronize()
     return routed_input
 
 
-def mxfp8_a2a_fwd_bwd(
+def mxfp8_a2a_fwd(
     routed_input: torch.Tensor,
     labels: torch.Tensor,
     output_splits_list: list[int],
@@ -102,16 +100,17 @@ def mxfp8_a2a_fwd_bwd(
         input_splits_list,
         device_mesh.get_group(),
     )
-
-    loss = F.mse_loss(routed_input, labels)
-    loss.backward()
     torch.cuda.synchronize()
     return routed_input
 
 
-# Compile target funcs
-default_a2a_sync_compiled = torch.compile(default_a2a_fwd_bwd)
-mxfp8_a2a_sync_compiled = torch.compile(mxfp8_a2a_fwd_bwd)
+def mse_loss_and_bwd(
+    routed_input: torch.Tensor,
+    labels: torch.Tensor,
+):
+    loss = F.mse_loss(routed_input, labels)
+    loss.backward()
+    torch.cuda.synchronize()
 
 
 def run_experiment(
@@ -149,62 +148,78 @@ def warmup(func_no_args):
 
     # Bench default a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
     warmup(
-        lambda: default_a2a_sync_compiled(
+        lambda: default_a2a_fwd(
             ref_x, labels, output_splits_list, input_splits_list, mesh
         )
     )
     start_sec = time.perf_counter()
-    default_a2a_sync_compiled(
+    bf16_output = default_a2a_fwd(
         ref_x, labels, output_splits_list, input_splits_list, mesh
     )
     end_sec = time.perf_counter()
-    bf16_ms = (end_sec - start_sec) * 1e3
+    bf16_fwd_ms = (end_sec - start_sec) * 1e3
     if args.profile:
         profile_fn(
-            default_a2a_sync_compiled,
+            default_a2a_fwd,
             ref_x,
             labels,
             output_splits_list,
             input_splits_list,
             mesh,
             distributed=True,
-            profile_name="all_to_all_single_autograd",
+            profile_name="default_a2a_fwd",
         )
 
+    # Bench default a2a bwd
+    warmup(lambda: mse_loss_and_bwd(bf16_output, labels))
+    start_sec = time.perf_counter()
+    mse_loss_and_bwd(bf16_output, labels)
+    end_sec = time.perf_counter()
+    bf16_bwd_ms = (end_sec - start_sec) * 1e3
+
     # Bench mxfp8 sync a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
     warmup(
-        lambda: mxfp8_a2a_sync_compiled(
-            x, labels, output_splits_list, input_splits_list, mesh
-        )
+        lambda: mxfp8_a2a_fwd(x, labels, output_splits_list, input_splits_list, mesh)
     )
     start_sec = time.perf_counter()
-    mxfp8_a2a_sync_compiled(x, labels, output_splits_list, input_splits_list, mesh)
+    mxfp8_output = mxfp8_a2a_fwd(x, labels, output_splits_list, input_splits_list, mesh)
    end_sec = time.perf_counter()
     mxfp8_ms = (end_sec - start_sec) * 1e3
     if args.profile:
         profile_fn(
-            mxfp8_a2a_sync_compiled,
+            mxfp8_a2a_fwd,
             x,
             labels,
             output_splits_list,
             input_splits_list,
             mesh,
             distributed=True,
-            profile_name="to_mxfp8_a2a_dequant",
+            profile_name="mxfp8_a2a_fwd",
         )
 
+    # Bench mxfp8 a2a bwd
+    warmup(lambda: mse_loss_and_bwd(mxfp8_output, labels))
+    start_sec = time.perf_counter()
+    mse_loss_and_bwd(mxfp8_output, labels)
+    end_sec = time.perf_counter()
+    mxfp8_bwd_ms = (end_sec - start_sec) * 1e3
+
     return ExperimentResult(
-        bf16_ms=bf16_ms,
-        mxfp8_ms=mxfp8_ms,
+        bf16_fwd_ms=bf16_fwd_ms,
+        mxfp8_fwd_ms=mxfp8_ms,
+        bf16_bwd_ms=bf16_bwd_ms,
+        mxfp8_bwd_ms=mxfp8_bwd_ms,
     )
 
 
 def print_results(experiments: List[Experiment]):
     headers = [
         "input_shape",
         "num_splits",
-        "bf16_ms",
-        "mxfp8_ms",
+        "fwd_bf16_ms",
+        "fwd_mxfp8_ms",
+        "bwd_bf16_ms",
+        "bwd_mxfp8_ms",
     ]
     rows = []
     num_splits = dist.get_world_size()
@@ -213,8 +228,10 @@ def print_results(experiments: List[Experiment]):
             [
                 str(experiment.config.input_shape),
                 num_splits,
-                experiment.result.bf16_ms,
-                experiment.result.mxfp8_ms,
+                experiment.result.bf16_fwd_ms,
+                experiment.result.mxfp8_fwd_ms,
+                experiment.result.bf16_bwd_ms,
+                experiment.result.mxfp8_bwd_ms,
             ]
         )
     print(tabulate(rows, headers=headers))
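The restructured benchmark above no longer fuses the loss and backward into the timed all-to-all call: the forward is timed on its own, and `mse_loss_and_bwd` is timed as a second measurement. It also drops the `torch.compile` wrappers and benchmarks the eager functions directly. A condensed illustration of the split-timing pattern (`bench_fwd_then_bwd`, its warmup count, and its argument list are hypothetical; `run_a2a_fwd` stands in for `default_a2a_fwd` / `mxfp8_a2a_fwd`):

```python
import time

import torch
import torch.nn.functional as F


def bench_fwd_then_bwd(run_a2a_fwd, x, labels, warmup_iters=3):
    # Warm up, then time the all-to-all forward on its own.
    for _ in range(warmup_iters):
        run_a2a_fwd(x, labels)
    start = time.perf_counter()
    out = run_a2a_fwd(x, labels)
    fwd_ms = (time.perf_counter() - start) * 1e3

    # Time loss + backward as a separate measurement, mirroring mse_loss_and_bwd.
    start = time.perf_counter()
    loss = F.mse_loss(out, labels)
    loss.backward()
    torch.cuda.synchronize()  # ensure backward kernels finish before stopping the clock
    bwd_ms = (time.perf_counter() - start) * 1e3
    return fwd_ms, bwd_ms
```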

torchao/prototype/moe_training/kernels/mxfp8/comms.py

Lines changed: 15 additions & 21 deletions

@@ -11,7 +11,10 @@
     blockwise_barrier,
     sync_threads,
 )
-from torchao.prototype.mx_formats.config import ScaleCalculationMode
+from torchao.prototype.mx_formats.kernels import (
+    triton_mxfp8_dequant_dim0,
+    triton_to_mxfp8_dim0,
+)
 from torchao.prototype.mx_formats.mx_tensor import to_dtype, to_mx
 
 
@@ -473,11 +476,9 @@ def forward(
         """
         # Quantize input
         block_size = 32
-        input_scales, input_data = to_mx(
+        input_data, input_scales = triton_to_mxfp8_dim0(
             input,
-            elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
-            scaling_mode=ScaleCalculationMode.RCEIL,
+            inner_block_size=block_size,
         )
 
         # Dispatch data (async)
@@ -501,20 +502,17 @@ def forward(
         output_data = torch.ops._c10d_functional.wait_tensor(output_data)
 
         # Dequantize output
-        lowp_dtype = output_data.dtype
         hp_dtype = input.dtype
-        hp_output = to_dtype(
+        triton_hp_output = triton_mxfp8_dequant_dim0(
             output_data,
-            output_scales.view(torch.float8_e8m0fnu),
-            lowp_dtype,
-            block_size,
+            output_scales,
             hp_dtype,
+            block_size,
         )
-
         ctx.input_splits = input_splits
         ctx.output_splits = output_splits
         ctx.group = group
-        return hp_output
+        return triton_hp_output
 
     @staticmethod
     def backward(ctx, grad_output_hp):
@@ -529,11 +527,9 @@ def backward(ctx, grad_output_hp):
 
         # Quantize grad_output
         block_size = 32
-        grad_out_scales, grad_out_data = to_mx(
+        grad_out_data, grad_out_scales = triton_to_mxfp8_dim0(
             grad_output_hp,
-            elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
-            scaling_mode=ScaleCalculationMode.RCEIL,
+            inner_block_size=block_size,
         )
 
         # Dispatch data (async)
@@ -557,13 +553,11 @@ def backward(ctx, grad_output_hp):
         grad_input_scales = torch.ops._c10d_functional.wait_tensor(grad_input_scales)
 
         hp_dtype = grad_output_hp.dtype
-        lowp_dtype = grad_input_data.dtype
-        grad_input_hp = to_dtype(
+        grad_input_hp = triton_mxfp8_dequant_dim0(
             grad_input_data,
-            grad_input_scales.view(torch.float8_e8m0fnu),
-            lowp_dtype,
-            block_size,
+            grad_input_scales,
             hp_dtype,
+            block_size,
         )
         return grad_input_hp, None, None, None
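Taken together, the `comms.py` changes leave the autograd structure untouched and only swap the quantize/dequantize implementation on both sides of the collective. A condensed skeleton of that structure (the class name and the `dispatch` helper are hypothetical placeholders; the real code issues async `all_to_all_single` ops via `_c10d_functional` and waits on the returned tensors):

```python
import torch

from torchao.prototype.mx_formats.kernels import (
    triton_mxfp8_dequant_dim0,
    triton_to_mxfp8_dim0,
)


def dispatch(data, scales, output_splits, input_splits, group):
    # Hypothetical stand-in for the async all-to-all of data and scales;
    # identity here so the sketch is self-contained on a single rank.
    return data, scales


class MXFP8AllToAllSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, output_splits, input_splits, group):
        block_size = 32
        # Quantize, dispatch, then dequantize back to the input's dtype.
        data, scales = triton_to_mxfp8_dim0(input, inner_block_size=block_size)
        out_data, out_scales = dispatch(data, scales, output_splits, input_splits, group)
        ctx.input_splits = input_splits
        ctx.output_splits = output_splits
        ctx.group = group
        return triton_mxfp8_dequant_dim0(out_data, out_scales, input.dtype, block_size)

    @staticmethod
    def backward(ctx, grad_output_hp):
        block_size = 32
        # Gradients are quantized too, and routed back along the inverse splits.
        data, scales = triton_to_mxfp8_dim0(grad_output_hp, inner_block_size=block_size)
        g_data, g_scales = dispatch(
            data, scales, ctx.input_splits, ctx.output_splits, ctx.group
        )
        grad_input_hp = triton_mxfp8_dequant_dim0(
            g_data, g_scales, grad_output_hp.dtype, block_size
        )
        return grad_input_hp, None, None, None
```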
