From ccd3575f6edbe2c50bfef6485ee6055572f860fa Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Fri, 13 Mar 2026 13:31:06 +0000
Subject: [PATCH] Fix group name lookup crash for flattened mesh collectives in
 debug_helpers

make_custom_runtime_estimation assumed every collective's process group
corresponds to a single mesh dimension, using
groups_name.index(group_name) to resolve it. This crashes with ValueError
when a collective uses the flattened mesh group created by mesh._flatten()
in _apply_placement_common, since that group isn't in
mesh.get_all_groups().

The fix resolves the group name against per-dimension groups first, then
falls back to mesh._flatten() (which is cached/idempotent) with
mesh_dim=0. All downstream shape and topology references use cost_mesh so
the cost estimation is correct for both per-dimension and flattened
collectives.

Authored with Claude.
---
 autoparallel/graph_passes/debug_helpers.py | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/autoparallel/graph_passes/debug_helpers.py b/autoparallel/graph_passes/debug_helpers.py
index 9436191d..d75436f4 100644
--- a/autoparallel/graph_passes/debug_helpers.py
+++ b/autoparallel/graph_passes/debug_helpers.py
@@ -125,11 +125,20 @@ def custom_runtime_estimation(node: torch.fx.Node, override_size=None):
     target = node.target
     if target == torch.ops._c10d_functional.wait_tensor.default:
         return 0
-    # TODO: figure out mesh without reading from global scope
-    mesh_topo = MeshTopoInfo.build_from_mesh(mesh)
-    groups_name = tuple(g.group_name for g in mesh.get_all_groups())
+    # Resolve group_name to a mesh and dimension. Collectives may
+    # use per-dimension groups or the flattened mesh group created
+    # by mesh._flatten() in _apply_placement_common.
     group_name = get_group_name(node)
-    mesh_dim = groups_name.index(group_name)
+    groups_name = tuple(g.group_name for g in mesh.get_all_groups())
+    if group_name in groups_name:
+        cost_mesh = mesh
+        mesh_dim = groups_name.index(group_name)
+    elif mesh.ndim > 1:
+        cost_mesh = mesh._flatten()
+        mesh_dim = 0
+    else:
+        return 0
+    mesh_topo = MeshTopoInfo.build_from_mesh(cost_mesh)
     t = node.args[0].meta["val"]  # type: ignore[union-attr]
     comm_bytes_gb = t.numel() * t.itemsize / 2**30
     if override_size is not None:
@@ -138,7 +147,7 @@
         torch.ops._c10d_functional.all_gather_into_tensor.default,
         torch.ops._c10d_functional.all_gather_into_tensor_out.default,
     }:
-        comm_bytes_gb *= mesh.shape[mesh_dim]
+        comm_bytes_gb *= cost_mesh.shape[mesh_dim]
         return allgather_cost(comm_bytes_gb, mesh_topo, mesh_dim)
     elif target == torch.ops._c10d_functional.reduce_scatter_tensor.default:
         return reduce_scatter_cost(comm_bytes_gb, mesh_topo, mesh_dim)