Repro script with more details inline. Issue first seen in the Falcon HF model with a 2D mesh.
"""
Minimal repro: index_put with incompatible shardings on 2D mesh.
Run: python examples/repro_index_put_2d.py
== Bug ==
On a 2D mesh, autoparallel's apply_sharding fails to insert a redistribute
before aten.index_put, causing Inductor to see incompatible local tensor shapes
and crash with:
Broadcast failed in ExpandView([16, 32, 1, 64], [4, 128, 1, 64])
The bug is in autoparallel, not DTensor. In normal DTensor execution, inputs
to an op already have concrete placements resolved one-at-a-time. In
autoparallel, the solver optimizes all placements simultaneously and can pick
RS(0) for index_put's `self` and RS(1) for `values` (from different upstream
paths). The solver correctly identifies the redistribution cost (transition_cost
= 1 in the verbose output), but apply_sharding doesn't actually insert the
redistribute into the compiled graph.
== How the model triggers index_put ==
Falcon-7b uses multi-query attention with list-based advanced indexing:
fused_qkv[..., [-2], :] # extract single K head
Dynamo decomposes `[-2]` (a list index, i.e., advanced indexing) into:
full(zeros) -> index_put(zeros, indices, values, accumulate=True)
in the joint forward+backward graph. The index_put scatters gradients back
through the list-index read during the backward pass.
== Why both branches are needed ==
The forward has two uses of `h`: a slice `h[..., :-2, :]` (shape [B,S,71,D])
and a list-index `h[..., [-2], :]` (shape [B,S,1,D]). These produce outputs
with different shapes. On a 2D mesh, the solver assigns different shardings to
these two paths (e.g., S(0) vs S(1) on the second mesh dim), which become
incompatible when the backward's index_put tries to combine them.
A single use of `h` doesn't trigger the bug — the solver has no reason to pick
conflicting shardings when there's only one path.
== Why small tensors don't trigger it ==
The solver only picks the mixed RS(0)/RS(1) strategy when the tensors are large
enough that the communication cost savings outweigh the redistribution cost.
With small tensors, the solver defaults to compatible shardings. The Falcon-scale
dimensions (4672 hidden, 73 heads, 64 head_dim) are needed to trigger the
cost-driven divergence.
"""
import torch
import torch.nn as nn
from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.tensor.placement_types import Replicate, Shard
from torch.testing._internal.distributed.fake_pg import FakeStore

from autoparallel.api import AutoParallel


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4672, 4672, bias=False)

    def forward(self, x):
        h = self.proj(x)
        h = h.view(x.shape[0], x.shape[1], 73, 64)
        # Both branches needed: backward of the list-index ([-2]) generates
        # index_put to scatter gradients, and having two uses of h with
        # different output shapes causes the solver to pick incompatible
        # shardings on the 2D mesh.
        q = h[..., :-2, :]
        k = h[..., [-2], :]
        return q + k.expand_as(q)


def main():
    fake_store = FakeStore()
    torch.distributed.init_process_group(
        "fake", store=fake_store, rank=0, world_size=8
    )
    mesh = torch.distributed.device_mesh.init_device_mesh(
        "cuda", (2, 4), mesh_dim_names=("dim0", "dim1")
    )
    with torch.device("meta"):
        model = Model()

    def input_fn():
        return torch.randn(16, 128, 4672, device="cuda")

    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    )
    with AutoParallel(model, input_fn, mesh, mp_policy, compile=True) as autop:
        autop.add_input_constraints([(Shard(0), Replicate())])
        placement = autop.optimize_placement(verbose=True)
        autop.apply_placement(placement)
    print("OK")


if __name__ == "__main__":
    main()
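A standalone sanity check of the decomposition described in the docstring (single-process, no autoparallel; the shapes here are small and illustrative): the backward of a list-index read really is a scatter that can be rebuilt with aten.index_put on a zeros tensor.

```python
import torch

# Backward of a list-index read: grad is zero except at the indexed slot.
x = torch.randn(2, 3, 5, 4, requires_grad=True)
x[..., [-2], :].sum().backward()

# Rebuild the same gradient explicitly via the
# full(zeros) -> index_put(..., accumulate=True) pattern Dynamo emits.
# Index 3 == -2 for a size-5 dim; None entries skip the leading dims.
manual = torch.ops.aten.index_put(
    torch.zeros(2, 3, 5, 4),
    [None, None, torch.tensor([3])],
    torch.ones(2, 3, 1, 4),
    True,  # accumulate
)
print(torch.equal(x.grad, manual))  # True
```

This is the op that ends up with the incompatible RS(0)/RS(1) inputs in the 2D-mesh repro above.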