index_put with incompatible shardings on 2D mesh #348

@aditvenk

Description

Repro script with more details inline. The issue was first seen in the Falcon HF model with a 2D mesh.

"""
Minimal repro: index_put with incompatible shardings on 2D mesh.

Run: python examples/repro_index_put_2d.py

== Bug ==

On a 2D mesh, autoparallel's apply_sharding fails to insert a redistribute
before aten.index_put, causing Inductor to see incompatible local tensor shapes
and crash with:

    Broadcast failed in ExpandView([16, 32, 1, 64], [4, 128, 1, 64])

The bug is in autoparallel, not DTensor. In normal DTensor execution, inputs
to an op already have concrete placements resolved one-at-a-time. In
autoparallel, the solver optimizes all placements simultaneously and can pick
RS(0) for index_put's `self` and RS(1) for `values` (from different upstream
paths). The solver correctly identifies the redistribution cost (transition_cost
= 1 in the verbose output), but apply_sharding doesn't actually insert the
redistribute into the compiled graph.

== How the model triggers index_put ==

Falcon-7b uses multi-query attention with list-based advanced indexing:

    fused_qkv[..., [-2], :]   # extract single K head

Dynamo decomposes `[-2]` (a list index, i.e., advanced indexing) into:

    full(zeros) -> index_put(zeros, indices, values, accumulate=True)

in the joint forward+backward graph. The index_put scatters gradients back
through the list-index read during the backward pass.

== Why both branches are needed ==

The forward has two uses of `h`: a slice `h[..., :-2, :]` (shape [B,S,71,D])
and a list-index `h[..., [-2], :]` (shape [B,S,1,D]). These produce outputs
with different shapes. On a 2D mesh, the solver assigns different shardings to
these two paths (e.g., S(0) vs S(1) on the second mesh dim), which become
incompatible when the backward's index_put tries to combine them.

A single use of `h` doesn't trigger the bug — the solver has no reason to pick
conflicting shardings when there's only one path.

== Why small tensors don't trigger it ==

The solver only picks the mixed RS(0)/RS(1) strategy when the tensors are large
enough that the communication cost savings outweigh the redistribution cost.
With small tensors, the solver defaults to compatible shardings. The Falcon-scale
dimensions (4672 hidden, 73 heads, 64 head_dim) are needed to trigger the
cost-driven divergence.
"""

import torch
import torch.nn as nn
from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.tensor.placement_types import Replicate, Shard
from torch.testing._internal.distributed.fake_pg import FakeStore

from autoparallel.api import AutoParallel


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4672, 4672, bias=False)

    def forward(self, x):
        h = self.proj(x)
        h = h.view(x.shape[0], x.shape[1], 73, 64)
        # Both branches needed: backward of the list-index ([-2]) generates
        # index_put to scatter gradients, and having two uses of h with
        # different output shapes causes the solver to pick incompatible
        # shardings on the 2D mesh.
        q = h[..., :-2, :]
        k = h[..., [-2], :]
        return q + k.expand_as(q)


def main():
    fake_store = FakeStore()
    torch.distributed.init_process_group("fake", store=fake_store, rank=0, world_size=8)
    mesh = torch.distributed.device_mesh.init_device_mesh(
        "cuda", (2, 4), mesh_dim_names=("dim0", "dim1")
    )

    with torch.device("meta"):
        model = Model()

    def input_fn():
        return torch.randn(16, 128, 4672, device="cuda")

    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    )

    with AutoParallel(model, input_fn, mesh, mp_policy, compile=True) as autop:
        autop.add_input_constraints([(Shard(0), Replicate())])
        placement = autop.optimize_placement(verbose=True)
        autop.apply_placement(placement)

    print("OK")


if __name__ == "__main__":
    main()
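The decomposition described in the docstring can be checked in plain eager PyTorch, independent of autoparallel: the gradient of a list-index read is exactly a scatter of the upstream gradient into zeros, i.e. `full(zeros) -> index_put(..., accumulate=True)`. A minimal sketch, with shapes scaled down from the Falcon dimensions:

```python
import torch

# Small stand-in for fused_qkv: [B, S, heads, head_dim]
h = torch.randn(2, 3, 5, 4, requires_grad=True)

# List-based advanced indexing, as in Falcon's fused_qkv[..., [-2], :]
k = h[..., [-2], :]  # shape [2, 3, 1, 4]
k.sum().backward()   # upstream gradient is all ones

# The backward is equivalent to scattering the upstream gradient into a
# zeros tensor at the listed head index (None entries mean "all of this dim"):
zeros = torch.zeros_like(h)
expected = torch.ops.aten.index_put(
    zeros, [None, None, torch.tensor([-2])], torch.ones(2, 3, 1, 4), True
)
assert torch.equal(h.grad, expected)
```

This is the same `index_put` that appears in the joint forward+backward graph; with `accumulate=True`, overlapping reads of `h` would sum their gradient contributions.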
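The local shapes in the `ExpandView` error follow from simple arithmetic over the (2, 4) mesh: the global [16, 128, 1, 64] operand sharded on dim 1 over the size-4 mesh axis gives local [16, 32, 1, 64], while the same global shape sharded on dim 0 over that axis gives [4, 128, 1, 64]. A sketch (`local_shape` is a hypothetical helper for illustration, not part of autoparallel or DTensor):

```python
def local_shape(global_shape, mesh_axis_size, shard_dim):
    """Local shape of a tensor sharded evenly on shard_dim across mesh_axis_size ranks."""
    s = list(global_shape)
    s[shard_dim] //= mesh_axis_size
    return s

g = [16, 128, 1, 64]         # global shape of the index_put operands
print(local_shape(g, 4, 1))  # [16, 32, 1, 64]  -- Shard(1) on the size-4 mesh axis
print(local_shape(g, 4, 0))  # [4, 128, 1, 64]  -- Shard(0) on the size-4 mesh axis
```

These are exactly the two incompatible local shapes Inductor tries (and fails) to broadcast together once the missing redistribute is dropped from the graph.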
