Skip to content

Commit 9a76360

Browse files
authored
Fix RemoteMixtureOfExperts and RemoteSwitchMixtureOfExperts backward() on GPU (#626)
1 parent a19b61d commit 9a76360

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

hivemind/moe/client/moe.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -313,16 +313,18 @@ def backward(cls, ctx, *raw_grads: torch.Tensor) -> Tuple[Optional[torch.Tensor]
313313

314314
num_samples, max_experts = dummy_grad_mask.shape
315315

316-
inputs_per_expert = zip(*(tensor[alive_ii].split(1, dim=0) for tensor in flat_inputs_cpu))
316+
alive_ii_cpu = alive_ii.cpu()
317+
alive_jj_cpu = alive_jj.cpu()
318+
inputs_per_expert = zip(*(tensor[alive_ii_cpu].split(1, dim=0) for tensor in flat_inputs_cpu))
317319
grad_outputs_per_expert = zip(
318-
*(tensor[alive_ii, alive_jj].split(1, dim=0) for tensor in flat_grad_outputs_cpu)
320+
*(tensor[alive_ii_cpu, alive_jj_cpu].split(1, dim=0) for tensor in flat_grad_outputs_cpu)
319321
)
320322
backward_schema = tuple(nested_flatten((info["forward_schema"], info["outputs_schema"])))
321323

322324
# dispatch tasks to all remote experts, collect responses
323325
pending_tasks = {}
324326
for i, j, inputs_ij, grad_outputs_ij in zip(
325-
alive_ii.cpu().numpy(), alive_jj.cpu().numpy(), inputs_per_expert, grad_outputs_per_expert
327+
alive_ii_cpu.numpy(), alive_jj_cpu.numpy(), inputs_per_expert, grad_outputs_per_expert
326328
):
327329
expert: RemoteExpert = expert_per_sample[i.item()][j.item()]
328330
stub = get_server_stub(expert.p2p, expert.peer_id)

0 commit comments

Comments
 (0)