@@ -42,8 +42,10 @@ class ExperimentConfig:
4242
4343@dataclass (frozen = True )
4444class ExperimentResult :
45- bf16_ms : float
46- mxfp8_ms : float
45+ bf16_fwd_ms : float
46+ mxfp8_fwd_ms : float
47+ bf16_bwd_ms : float
48+ mxfp8_bwd_ms : float
4749
4850
4951@dataclass (frozen = True )
@@ -67,7 +69,7 @@ def get_configs() -> List[ExperimentConfig]:
6769 return configs
6870
6971
70- def default_a2a_fwd_bwd (
72+ def default_a2a_fwd (
7173 routed_input : torch .Tensor ,
7274 labels : torch .Tensor ,
7375 output_splits_list : list [int ],
@@ -81,15 +83,11 @@ def default_a2a_fwd_bwd(
8183 device_mesh .get_group (),
8284 )
8385 routed_input = torch .ops ._c10d_functional .wait_tensor (routed_input )
84-
85- loss = F .mse_loss (routed_input , labels )
86- loss .backward ()
87-
8886 torch .cuda .synchronize ()
8987 return routed_input
9088
9189
92- def mxfp8_a2a_fwd_bwd (
90+ def mxfp8_a2a_fwd (
9391 routed_input : torch .Tensor ,
9492 labels : torch .Tensor ,
9593 output_splits_list : list [int ],
@@ -102,16 +100,17 @@ def mxfp8_a2a_fwd_bwd(
102100 input_splits_list ,
103101 device_mesh .get_group (),
104102 )
105-
106- loss = F .mse_loss (routed_input , labels )
107- loss .backward ()
108103 torch .cuda .synchronize ()
109104 return routed_input
110105
111106
112- # Compile target funcs
113- default_a2a_sync_compiled = torch .compile (default_a2a_fwd_bwd )
114- mxfp8_a2a_sync_compiled = torch .compile (mxfp8_a2a_fwd_bwd )
107+ def mse_loss_and_bwd (
108+ routed_input : torch .Tensor ,
109+ labels : torch .Tensor ,
110+ ):
111+ loss = F .mse_loss (routed_input , labels )
112+ loss .backward ()
113+ torch .cuda .synchronize ()
115114
116115
117116def run_experiment (
@@ -149,62 +148,78 @@ def warmup(func_no_args):
149148
150149 # Bench default a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
151150 warmup (
152- lambda : default_a2a_sync_compiled (
151+ lambda : default_a2a_fwd (
153152 ref_x , labels , output_splits_list , input_splits_list , mesh
154153 )
155154 )
156155 start_sec = time .perf_counter ()
157- default_a2a_sync_compiled (
156+ bf16_output = default_a2a_fwd (
158157 ref_x , labels , output_splits_list , input_splits_list , mesh
159158 )
160159 end_sec = time .perf_counter ()
161- bf16_ms = (end_sec - start_sec ) * 1e3
160+ bf16_fwd_ms = (end_sec - start_sec ) * 1e3
162161 if args .profile :
163162 profile_fn (
164- default_a2a_sync_compiled ,
163+ default_a2a_fwd ,
165164 ref_x ,
166165 labels ,
167166 output_splits_list ,
168167 input_splits_list ,
169168 mesh ,
170169 distributed = True ,
171- profile_name = "all_to_all_single_autograd " ,
170+ profile_name = "default_a2a_fwd " ,
172171 )
173172
173+ # Bench default a2a bwd
174+ warmup (lambda : mse_loss_and_bwd (bf16_output , labels ))
175+ start_sec = time .perf_counter ()
176+ mse_loss_and_bwd (bf16_output , labels )
177+ end_sec = time .perf_counter ()
178+ bf16_bwd_ms = (end_sec - start_sec ) * 1e3
179+
174180 # Bench mxfp8 sync a2a (exclude d2h sync from preparing input splits_list and output_splits_list)
175181 warmup (
176- lambda : mxfp8_a2a_sync_compiled (
177- x , labels , output_splits_list , input_splits_list , mesh
178- )
182+ lambda : mxfp8_a2a_fwd (x , labels , output_splits_list , input_splits_list , mesh )
179183 )
180184 start_sec = time .perf_counter ()
181- mxfp8_a2a_sync_compiled (x , labels , output_splits_list , input_splits_list , mesh )
185+ mxfp8_output = mxfp8_a2a_fwd (x , labels , output_splits_list , input_splits_list , mesh )
182186 end_sec = time .perf_counter ()
183187 mxfp8_ms = (end_sec - start_sec ) * 1e3
184188 if args .profile :
185189 profile_fn (
186- mxfp8_a2a_sync_compiled ,
190+ mxfp8_a2a_fwd ,
187191 x ,
188192 labels ,
189193 output_splits_list ,
190194 input_splits_list ,
191195 mesh ,
192196 distributed = True ,
193- profile_name = "to_mxfp8_a2a_dequant " ,
197+ profile_name = "mxfp8_a2a_fwd " ,
194198 )
195199
200+ # Bench mxfp8 a2a bwd
201+ warmup (lambda : mse_loss_and_bwd (mxfp8_output , labels ))
202+ start_sec = time .perf_counter ()
203+ mse_loss_and_bwd (mxfp8_output , labels )
204+ end_sec = time .perf_counter ()
205+ mxfp8_bwd_ms = (end_sec - start_sec ) * 1e3
206+
196207 return ExperimentResult (
197- bf16_ms = bf16_ms ,
198- mxfp8_ms = mxfp8_ms ,
208+ bf16_fwd_ms = bf16_fwd_ms ,
209+ mxfp8_fwd_ms = mxfp8_ms ,
210+ bf16_bwd_ms = bf16_bwd_ms ,
211+ mxfp8_bwd_ms = mxfp8_bwd_ms ,
199212 )
200213
201214
202215def print_results (experiments : List [Experiment ]):
203216 headers = [
204217 "input_shape" ,
205218 "num_splits" ,
206- "bf16_ms" ,
207- "mxfp8_ms" ,
219+ "fwd_bf16_ms" ,
220+ "fwd_mxfp8_ms" ,
221+ "bwd_bf16_ms" ,
222+ "bwd_mxfp8_ms" ,
208223 ]
209224 rows = []
210225 num_splits = dist .get_world_size ()
@@ -213,8 +228,10 @@ def print_results(experiments: List[Experiment]):
213228 [
214229 str (experiment .config .input_shape ),
215230 num_splits ,
216- experiment .result .bf16_ms ,
217- experiment .result .mxfp8_ms ,
231+ experiment .result .bf16_fwd_ms ,
232+ experiment .result .mxfp8_fwd_ms ,
233+ experiment .result .bf16_bwd_ms ,
234+ experiment .result .mxfp8_bwd_ms ,
218235 ]
219236 )
220237 print (tabulate (rows , headers = headers ))
0 commit comments