@@ -313,16 +313,18 @@ def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor]
313313
314314 num_samples , max_experts = dummy_grad_mask .shape
315315
316- inputs_per_expert = zip (* (tensor [alive_ii ].split (1 , dim = 0 ) for tensor in flat_inputs_cpu ))
316+ alive_ii_cpu = alive_ii .cpu ()
317+ alive_jj_cpu = alive_jj .cpu ()
318+ inputs_per_expert = zip (* (tensor [alive_ii_cpu ].split (1 , dim = 0 ) for tensor in flat_inputs_cpu ))
317319 grad_outputs_per_expert = zip (
318- * (tensor [alive_ii , alive_jj ].split (1 , dim = 0 ) for tensor in flat_grad_outputs_cpu )
320+ * (tensor [alive_ii_cpu , alive_jj_cpu ].split (1 , dim = 0 ) for tensor in flat_grad_outputs_cpu )
319321 )
320322 backward_schema = tuple (nested_flatten ((info ["forward_schema" ], info ["outputs_schema" ])))
321323
322324 # dispatch tasks to all remote experts, collect responses
323325 pending_tasks = {}
324326 for i , j , inputs_ij , grad_outputs_ij in zip (
325- alive_ii . cpu (). numpy (), alive_jj . cpu () .numpy (), inputs_per_expert , grad_outputs_per_expert
327+ alive_ii_cpu . numpy (), alive_jj_cpu .numpy (), inputs_per_expert , grad_outputs_per_expert
326328 ):
327329 expert : RemoteExpert = expert_per_sample [i .item ()][j .item ()]
328330 stub = get_server_stub (expert .p2p , expert .peer_id )
0 commit comments