From a862ecb817050acf7b3efd6eba73cd13484ffdfc Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 16 Feb 2026 20:30:29 +0000 Subject: [PATCH 01/22] Add new implementation for OptimizeSharding --- autoparallel/optimize_sharding_new.py | 865 ++++++++++++++++++++++++ examples/example_autoparallel_factor.py | 211 ++++++ 2 files changed, 1076 insertions(+) create mode 100644 autoparallel/optimize_sharding_new.py create mode 100644 examples/example_autoparallel_factor.py diff --git a/autoparallel/optimize_sharding_new.py b/autoparallel/optimize_sharding_new.py new file mode 100644 index 00000000..310550cb --- /dev/null +++ b/autoparallel/optimize_sharding_new.py @@ -0,0 +1,865 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +""" +Factor-based sharding optimization using Integer Linear Programming (ILP). + +This module reformulates the sharding optimization problem using "factors" — +logical dimensions of computation inspired by Shardy's factor-based propagation +(see shardy/dialect/sdy/ir/attrs.td OpShardingRuleAttr). + +Key idea +-------- +Instead of enumerating all placement combinations per op — which is O((d+1)^k) +per tensor where d = tensor dims and k = mesh dims — each op is decomposed into +factors, and the ILP decides which mesh dimension (if any) shards each factor. + + Original ILP variables per op: O(A × (d+1)^(2k)) + Factor ILP variables per op: O(F × k) + + where A = args, F = factors, d = tensor dims, k = mesh dims + +For a matmul on a 4D mesh: ~13,000 → ~12 variables per op. + +Factor extraction +----------------- +Factors are extracted *generically* from existing DTensor OpStrategy objects by +inspecting placement patterns on a single mesh dimension. 
Because most +OpStrategies are Cartesian products of per-mesh-dim "atoms" (via +``expand_to_full_mesh_op_strategy``), each unique non-trivial atom corresponds +to exactly one factor. This means we reuse all existing DTensor op rules +without writing per-op factor definitions. + +Example: ``mm(A[M,K], B[K,N]) -> C[M,N]`` + + 1D atoms (from mesh dim 0): + (C=R, A=R, B=R ) → all-replicate, skip + (C=S(0), A=S(0), B=R ) → Factor "M": {A.dim0, C.dim0} + (C=S(1), A=R, B=S(1)) → Factor "N": {B.dim1, C.dim1} + (C=P, A=S(1), B=S(0)) → Factor "K": {A.dim1, B.dim0}, reduction + + ≡ Shardy's ([i,k],[k,j])->([i,j]) {i=M, j=N, k=K} reduction={k} +""" + +from __future__ import annotations + +import math +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any, Optional + +import pulp +import torch +import torch.fx +from torch._functorch._aot_autograd.descriptors import PlainAOTInput, PlainAOTOutput +from torch._functorch._aot_autograd.fx_utils import ( + get_plain_input_and_grad_nodes, + get_plain_output_and_tangent_nodes, +) +from torch.distributed._tensor.placement_types import TensorMeta +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import OpSpec, OpStrategy +from torch.distributed.tensor.placement_types import Placement, Replicate, Shard +from torch.utils._pytree import tree_map_only + +from .cost_models.compute_estimation import estimate_strategy_runtime_cost +from .shardings.placement_options import get_placement_options +from .shardings.propagation_rules import _create_all_options + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + + +def _is_partial(p: Placement) -> bool: + """Check if a placement is Partial (reduction output).""" + return not isinstance(p, (Shard, Replicate)) + + +@dataclass +class Factor: + """A computation factor — a 
logical dimension of the computation. + + Analogous to Shardy's factor concept (from ``OpShardingRuleAttr``). + For ``C[M,N] = A[M,K] @ B[K,N]``, the factors are M, N, K. + + Attributes + ---------- + id : int + Unique id within the parent ``FactorRule``. + size : int + Size along this factor (e.g. M=1024). + is_reduction : bool + True if this factor is a contraction/reduction dimension (the output + is ``Partial`` when it is sharded). + operand_dims : list[int | None] + For each operand, the tensor dim mapped to this factor (None = not mapped). + result_dims : list[int | None] + For each result, the tensor dim mapped to this factor (None = not mapped). + """ + + id: int + size: int + is_reduction: bool = False + operand_dims: list = field(default_factory=list) + result_dims: list = field(default_factory=list) + + +@dataclass +class FactorRule: + """Factor decomposition for one operation. + + Analogous to Shardy's ``OpShardingRuleAttr``. + + Example — ``mm(A[M,K], B[K,N]) -> C[M,N]``:: + + factors = [ + Factor(0, M, operand_dims=[0, None], result_dims=[0]), # M + Factor(1, N, operand_dims=[None, 1], result_dims=[1]), # N + Factor(2, K, operand_dims=[1, 0], result_dims=[], is_reduction=True), # K + ] + """ + + factors: list[Factor] + num_operands: int + num_results: int + + +# --------------------------------------------------------------------------- +# Union-Find +# --------------------------------------------------------------------------- + + +class UnionFind: + """Disjoint-set (union-find) for merging factors across dataflow edges.""" + + def __init__(self) -> None: + self.parent: dict[int, int] = {} + self.rank: dict[int, int] = {} + + def make_set(self, x: int) -> None: + if x not in self.parent: + self.parent[x] = x + self.rank[x] = 0 + + def find(self, x: int) -> int: + while self.parent[x] != x: + self.parent[x] = self.parent[self.parent[x]] # path halving + x = self.parent[x] + return x + + def union(self, x: int, y: int) -> int: + rx, ry = self.find(x), 
self.find(y) + if rx == ry: + return rx + if self.rank[rx] < self.rank[ry]: + rx, ry = ry, rx + self.parent[ry] = rx + if self.rank[rx] == self.rank[ry]: + self.rank[rx] += 1 + return rx + + +# --------------------------------------------------------------------------- +# Factor extraction from OpStrategy +# --------------------------------------------------------------------------- + + +def _infer_factor_size( + node: torch.fx.Node, + operand_dims: list[int | None], + result_dims: list[int | None], +) -> int: + """Infer the size of a factor from the node's tensor metadata.""" + # Try the node's own output first. + val = node.meta.get("val") + if val is not None and isinstance(val, torch.Tensor): + for d in result_dims: + if d is not None and d < len(val.shape): + return val.shape[d] + # Fall back to operand shapes. + for arg_idx, d in enumerate(operand_dims): + if d is None or arg_idx >= len(node.args): + continue + arg = node.args[arg_idx] + if isinstance(arg, torch.fx.Node): + arg_val = arg.meta.get("val") + if ( + arg_val is not None + and isinstance(arg_val, torch.Tensor) + and d < len(arg_val.shape) + ): + return arg_val.shape[d] + return 1 # fallback + + +def extract_factors_from_strategy( + op_strategy: OpStrategy, + node: torch.fx.Node, +) -> FactorRule: + """Convert an ``OpStrategy`` into a ``FactorRule``. + + Each *unique* per-mesh-dimension placement pattern (excluding all-replicate) + in the strategy set corresponds to one factor. We inspect mesh dim 0, + which is valid because ``expand_to_full_mesh_op_strategy`` replicates the + same ``single_mesh_dim_strategies`` across all mesh dims. + + Parameters + ---------- + op_strategy : OpStrategy + Multi-dim strategy (may have been expanded via Cartesian product). + node : torch.fx.Node + The FX node, used for shape metadata. 
+ + Returns + ------- + FactorRule + """ + if not op_strategy.strategies: + return FactorRule(factors=[], num_operands=0, num_results=1) + + first_spec = op_strategy.strategies[0] + num_operands = len(first_spec.input_specs) if first_spec.input_specs else 0 + + # Collect unique 1-D "atoms" by looking at mesh dim 0. + seen: dict[str, tuple] = {} + for spec in op_strategy.strategies: + out_specs = spec.output_specs + out_p = ( + out_specs.placements[0] + if isinstance(out_specs, DTensorSpec) + else out_specs[0].placements[0] + ) + in_ps = tuple( + s.placements[0] + for s in (spec.input_specs or []) + if isinstance(s, DTensorSpec) + ) + all_ps = (out_p,) + in_ps + if all(isinstance(p, Replicate) for p in all_ps): + continue # skip the all-replicate atom + key = str(all_ps) + if key not in seen: + seen[key] = (out_p, in_ps) + + # Each atom → one Factor. + factors: list[Factor] = [] + for factor_id, (out_p, in_ps) in enumerate(seen.values()): + is_reduction = _is_partial(out_p) + result_dims = [out_p.dim if isinstance(out_p, Shard) else None] + operand_dims = [p.dim if isinstance(p, Shard) else None for p in in_ps] + size = _infer_factor_size(node, operand_dims, result_dims) + factors.append( + Factor( + id=factor_id, + size=size, + is_reduction=is_reduction, + operand_dims=operand_dims, + result_dims=result_dims, + ) + ) + + return FactorRule(factors=factors, num_operands=num_operands, num_results=1) + + +def _placeholder_factor_rule(node: torch.fx.Node) -> FactorRule: + """Create a ``FactorRule`` for a placeholder / get_attr node. + + Each tensor dimension becomes an independent spatial factor. 
+ """ + val = node.meta.get("val") + if val is None or not isinstance(val, torch.Tensor): + return FactorRule(factors=[], num_operands=0, num_results=1) + shape = val.shape + factors = [ + Factor(id=d, size=shape[d], operand_dims=[], result_dims=[d]) + for d in range(len(shape)) + ] + return FactorRule(factors=factors, num_operands=0, num_results=1) + + +# --------------------------------------------------------------------------- +# Factor-based sharding optimizer +# --------------------------------------------------------------------------- + + +class FactorShardingOptimizer: + """Sharding optimizer using factor-based ILP variables. + + Public API mirrors :class:`ShardingOptimizer` where possible so + that it can be used as a drop-in replacement (modulo output format). + + Parameters + ---------- + gm : torch.fx.GraphModule + Traced FX graph (joint forward + backward). + mesh : DeviceMesh + Target device mesh. + rescale_grad_comm_cost_for_mp : float + Scaling factor for gradient communication costs (mixed precision). + """ + + def __init__( + self, + gm: torch.fx.GraphModule, + mesh: Any, + rescale_grad_comm_cost_for_mp: float = 1.0, + ) -> None: + self.gm = gm + self.graph = gm.graph + self.nodes = list(self.graph.nodes) + self.mesh = mesh + self.node_map: dict[torch.fx.Node, int] = { + n: i for i, n in enumerate(self.nodes) + } + + # -- Step 1: build multi-dim strategies (reuses existing DTensor rules) + # NOTE: in a production implementation you would build strategies for a + # *1-D* mesh only (O(d) per node instead of O((d+1)^k)), e.g. via + # flat_mesh = mesh._flatten("flat") + # For this POC we reuse the real mesh so that all existing op rules work + # unchanged. The key savings come from the ILP reformulation below. + self.strats = self._build_sharding_metadata() + + # -- Step 2: extract factor rules from strategies. 
+ self.factor_rules: dict[torch.fx.Node, FactorRule] = {} + self._extract_all_factor_rules() + + # -- Step 3: merge factors across dataflow edges (union-find). + self.uf = UnionFind() + self.factor_keys: dict[tuple[int, int], int] = {} # (node_idx, local) → gid + self._next_gid = 0 + self._build_factor_graph() + + # -- Step 4: collect per-root metadata for cost model. + # root → [(node, factor, local_idx)] + self.factor_ops: dict[ + int, list[tuple[torch.fx.Node, Factor, int]] + ] = defaultdict(list) + self._collect_factor_metadata() + + # -- Step 5: build ILP. + self.prob = pulp.LpProblem("AutoParallel_Factor", pulp.LpMinimize) + self.y_vars: dict[tuple[int, int], pulp.LpVariable] = {} + self._build_ilp() + + # ----------------------------------------------------------------- + # Step 1 — strategy building (mirrors ShardingOptimizer) + # ----------------------------------------------------------------- + + def _build_sharding_metadata(self) -> dict[torch.fx.Node, OpStrategy]: + strats: dict[torch.fx.Node, OpStrategy] = {} + for node in self.graph.nodes: + if node.op == "placeholder": + strats[node] = _create_all_options( + self.mesh, node.meta["val"].shape, tensor=node.meta["val"] + ) + elif node.op == "call_function": + user_strats = tree_map_only( + torch.fx.Node, lambda x: strats[x], node.args + ) + user_args = tree_map_only( + torch.fx.Node, lambda x: x.meta["val"], node.args + ) + user_kwargs = tree_map_only( + torch.fx.Node, lambda x: x.meta["val"], node.kwargs + ) + strats[node] = get_placement_options( + self.mesh, + node.target, + user_strats, + user_args, + user_kwargs, + ) + elif node.op == "get_attr": + strats[node] = _create_all_options( + self.mesh, node.meta["val"].shape, tensor=node.meta["val"] + ) + return strats + + # ----------------------------------------------------------------- + # Step 2 — factor extraction + # ----------------------------------------------------------------- + + def _extract_all_factor_rules(self) -> None: + for node in 
self.graph.nodes: + if node.op in ("placeholder", "get_attr"): + self.factor_rules[node] = _placeholder_factor_rule(node) + elif node.op == "call_function" and node in self.strats: + self.factor_rules[node] = extract_factors_from_strategy( + self.strats[node], node + ) + elif node.op == "output": + self.factor_rules[node] = FactorRule( + factors=[], num_operands=0, num_results=0 + ) + + # ----------------------------------------------------------------- + # Step 3 — factor graph (union-find merging across edges) + # ----------------------------------------------------------------- + + def _alloc_gid(self) -> int: + gid = self._next_gid + self._next_gid += 1 + return gid + + def _build_factor_graph(self) -> None: + # Register every factor. + for node in self.graph.nodes: + rule = self.factor_rules.get(node) + if rule is None: + continue + nidx = self.node_map[node] + for li, _ in enumerate(rule.factors): + gid = self._alloc_gid() + self.factor_keys[(nidx, li)] = gid + self.uf.make_set(gid) + + # Merge factors across producer → consumer edges. + for node in self.graph.nodes: + if node.op != "call_function": + continue + consumer_rule = self.factor_rules.get(node) + if consumer_rule is None: + continue + cidx = self.node_map[node] + + for arg_pos, arg in enumerate(node.args): + if not isinstance(arg, torch.fx.Node): + continue + producer_rule = self.factor_rules.get(arg) + if producer_rule is None: + continue + pidx = self.node_map[arg] + + # Match: consumer operand dim == producer result dim on the + # same positional dimension → same logical factor. 
+ for c_li, c_fac in enumerate(consumer_rule.factors): + if arg_pos >= len(c_fac.operand_dims): + continue + c_dim = c_fac.operand_dims[arg_pos] + if c_dim is None: + continue + for p_li, p_fac in enumerate(producer_rule.factors): + if not p_fac.result_dims: + continue + p_dim = p_fac.result_dims[0] + if p_dim is not None and p_dim == c_dim: + pk = self.factor_keys.get((pidx, p_li)) + ck = self.factor_keys.get((cidx, c_li)) + if pk is not None and ck is not None: + self.uf.union(pk, ck) + + # ----------------------------------------------------------------- + # Step 4 — metadata collection + # ----------------------------------------------------------------- + + def _collect_factor_metadata(self) -> None: + for node in self.graph.nodes: + rule = self.factor_rules.get(node) + if rule is None: + continue + nidx = self.node_map[node] + for li, fac in enumerate(rule.factors): + gid = self.factor_keys.get((nidx, li)) + if gid is None: + continue + root = self.uf.find(gid) + self.factor_ops[root].append((node, fac, li)) + + def _unique_roots(self) -> set[int]: + return {self.uf.find(gid) for gid in self.factor_keys.values()} + + # ----------------------------------------------------------------- + # Step 5 — ILP construction + # ----------------------------------------------------------------- + + def _build_ilp(self) -> None: + roots = self._unique_roots() + + # --- Variables: y[root, mesh_dim] ∈ {0, 1} --- + for r in roots: + for m in range(self.mesh.ndim): + self.y_vars[(r, m)] = pulp.LpVariable(f"y_{r}_m{m}", cat="Binary") + + # --- Constraints --- + self._add_factor_uniqueness(roots) + self._add_tensor_exclusion() + + # --- Objective --- + self._add_objective(roots) + + # ---- constraints ------------------------------------------------ + + def _add_factor_uniqueness(self, roots: set[int]) -> None: + """Each factor is assigned to *at most one* mesh dimension.""" + for r in roots: + self.prob += ( + pulp.lpSum(self.y_vars[(r, m)] for m in range(self.mesh.ndim)) <= 1, 
+ f"fac_uniq_{r}", + ) + + def _add_tensor_exclusion(self) -> None: + """Per tensor per mesh dim, at most one factor can be sharded. + + This encodes the DTensor invariant: a tensor dimension can only appear + as ``Shard(d)`` for a single ``d`` on each mesh dimension. + + Important: multiple factors at the same node may share a root (after + union-find merging, e.g. nheads and head_dim from unflatten both map to + the hidden input dimension). We must deduplicate by root to avoid + counting the same ILP variable twice, which would turn ``sum <= 1`` + into ``2*y <= 1`` and incorrectly force that variable to 0. + """ + cid = 0 + for node in self.graph.nodes: + rule = self.factor_rules.get(node) + if rule is None or not rule.factors: + continue + nidx = self.node_map[node] + + # — result tensor — + for m in range(self.mesh.ndim): + vs = [] + seen_roots: set[int] = set() + for li, fac in enumerate(rule.factors): + if fac.result_dims and fac.result_dims[0] is not None: + gid = self.factor_keys.get((nidx, li)) + if gid is not None: + root = self.uf.find(gid) + if root not in seen_roots: + seen_roots.add(root) + vs.append(self.y_vars[(root, m)]) + if len(vs) > 1: + self.prob += pulp.lpSum(vs) <= 1, f"tex_r_{cid}" + cid += 1 + + # — operand tensors — + for oi in range(rule.num_operands): + for m in range(self.mesh.ndim): + vs = [] + seen_roots: set[int] = set() + for li, fac in enumerate(rule.factors): + if ( + oi < len(fac.operand_dims) + and fac.operand_dims[oi] is not None + ): + gid = self.factor_keys.get((nidx, li)) + if gid is not None: + root = self.uf.find(gid) + if root not in seen_roots: + seen_roots.add(root) + vs.append(self.y_vars[(root, m)]) + if len(vs) > 1: + self.prob += pulp.lpSum(vs) <= 1, f"tex_o_{cid}" + cid += 1 + + # ---- objective -------------------------------------------------- + + def _add_objective(self, roots: set[int]) -> None: + """Build the cost function. 
+ + For each factor *f* assigned to mesh dim *m*: + + * **Reduction factor** → ``+allreduce_cost(output_bytes, mesh.shape[m])`` + * **Spatial factor** → ``-compute_savings(op, mesh.shape[m])`` + + This is a *first-order* (linear) approximation. The true compute cost + depends on the product of all shard sizes, which would make the + objective quadratic. The linear model captures the dominant effects: + allreduce penalties and per-factor parallelism benefits. + """ + terms: list[Any] = [] + + for r in roots: + refs = self.factor_ops.get(r, []) + for m in range(self.mesh.ndim): + mesh_size = self.mesh.shape[m] + var = self.y_vars[(r, m)] + cost = 0.0 + + for node, fac, _ in refs: + if node.op != "call_function": + continue + + if fac.is_reduction: + # Allreduce: ring algorithm ≈ 2·B·(n-1)/n + out_bytes = self._output_bytes(node) + ar_bytes = 2.0 * out_bytes * (mesh_size - 1) / mesh_size + # Rough bandwidth model: 50 GB/s per link + cost += ar_bytes / 50e9 * 1e6 # microseconds + else: + # Compute benefit: work is divided by mesh_size. 
+ compute = self._compute_cost(node) + benefit = compute * (1.0 - 1.0 / mesh_size) + cost -= benefit + + if cost != 0.0: + terms.append(cost * var) + + if terms: + self.prob += pulp.lpSum(terms) + + # ---- cost helpers ----------------------------------------------- + + @staticmethod + def _output_bytes(node: torch.fx.Node) -> float: + val = node.meta.get("val") + if val is not None and isinstance(val, torch.Tensor): + return float(val.numel() * val.element_size()) + return 0.0 + + @staticmethod + def _compute_cost(node: torch.fx.Node) -> float: + try: + return float(estimate_strategy_runtime_cost(node, None)) + except Exception: + return 0.0 + + # ----------------------------------------------------------------- + # User constraints + # ----------------------------------------------------------------- + + def add_node_constraint( + self, + node: torch.fx.Node, + placement: tuple[Placement, ...], + ) -> None: + """Pin a node's output to a specific placement.""" + rule = self.factor_rules.get(node) + if rule is None: + return + nidx = self.node_map[node] + + for m, p in enumerate(placement): + if isinstance(p, Shard): + for li, fac in enumerate(rule.factors): + if fac.result_dims and fac.result_dims[0] == p.dim: + gid = self.factor_keys.get((nidx, li)) + if gid is not None: + root = self.uf.find(gid) + self.prob += ( + self.y_vars[(root, m)] == 1, + f"pin_{nidx}_f{li}_m{m}", + ) + break + elif isinstance(p, Replicate): + seen_roots: set[int] = set() + for li, fac in enumerate(rule.factors): + if fac.result_dims and fac.result_dims[0] is not None: + gid = self.factor_keys.get((nidx, li)) + if gid is not None: + root = self.uf.find(gid) + if root not in seen_roots: + seen_roots.add(root) + self.prob += ( + self.y_vars[(root, m)] == 0, + f"rep_{nidx}_r{root}_m{m}", + ) + + def add_input_constraints( + self, input_placements: list[tuple[Placement, ...] | None] | None = None + ) -> None: + """Constrain input placements (and their corresponding gradients). 
+ + Uses ``get_plain_input_and_grad_nodes`` to correctly map inputs to + their gradient nodes in the joint fwd+bwd graph, matching the + original :class:`ShardingOptimizer` behaviour. + """ + mut_ips = None + if input_placements is not None: + mut_ips = {i: p for i, p in enumerate(input_placements)} + + for desc, (node, grad_node) in get_plain_input_and_grad_nodes( + self.graph + ).items(): + if input_placements is None: + placement = None + else: + assert isinstance(desc, PlainAOTInput) + assert mut_ips is not None + placement = mut_ips.pop(desc.idx, None) + + if placement is not None: + self.add_node_constraint(node, tuple(placement)) + if grad_node is not None: + self.add_node_constraint(grad_node, tuple(placement)) + + def add_output_constraints( + self, output_placements: list[tuple[Placement, ...] | None] | None = None + ) -> None: + """Constrain output placements (and their corresponding tangents). + + Uses ``get_plain_output_and_tangent_nodes`` to correctly map outputs to + their tangent nodes in the joint fwd+bwd graph, matching the + original :class:`ShardingOptimizer` behaviour. 
+ """ + mut_ops = None + if output_placements is not None: + mut_ops = {i: p for i, p in enumerate(output_placements)} + + for desc, (node, tangent_node) in get_plain_output_and_tangent_nodes( + self.graph + ).items(): + if output_placements is None: + placement = None + else: + assert isinstance(desc, PlainAOTOutput) + assert mut_ops is not None + placement = mut_ops.pop(desc.idx, None) + + if placement is not None: + self.add_node_constraint(node, tuple(placement)) + if tangent_node is not None: + self.add_node_constraint(tangent_node, tuple(placement)) + + # ----------------------------------------------------------------- + # Solve + # ----------------------------------------------------------------- + + def get_solution(self, verbose: bool = False) -> dict[torch.fx.Node, DTensorSpec]: + """Solve the factor ILP and reconstruct per-node DTensorSpecs.""" + solver = pulp.PULP_CBC_CMD(msg=verbose) + self.prob.solve(solver) + + if self.prob.status == -1: + diag = self._infeasibility_diagnostics() + raise RuntimeError( + "Factor-based ILP is infeasible. " + "Check that input / output constraints are satisfiable.\n" + diag + ) + + # Extract factor → mesh-dim assignments. + assignment: dict[int, int] = {} # root → mesh_dim + for (root, m), var in self.y_vars.items(): + if var.varValue is not None and var.varValue > 0.5: + assignment[root] = m + + # Reconstruct per-node placements. 
+ result: dict[torch.fx.Node, DTensorSpec] = {} + for node in self.graph.nodes: + if node.op == "output": + continue + rule = self.factor_rules.get(node) + if rule is None: + continue + nidx = self.node_map[node] + placements: list[Placement] = [Replicate()] * self.mesh.ndim + + for li, fac in enumerate(rule.factors): + if not fac.result_dims or fac.result_dims[0] is None: + continue + gid = self.factor_keys.get((nidx, li)) + if gid is None: + continue + root = self.uf.find(gid) + m = assignment.get(root) + if m is not None: + td = fac.result_dims[0] + # A factor may be spatial here but reduction in another op. + # For the *output* placement we use Shard(dim). + placements[m] = Shard(td) + + val = node.meta.get("val") + if val is not None and isinstance(val, torch.Tensor): + tensor_meta = TensorMeta(val.shape, val.stride(), val.dtype) + result[node] = DTensorSpec( + self.mesh, tuple(placements), tensor_meta=tensor_meta + ) + + return result + + # ----------------------------------------------------------------- + # Diagnostics + # ----------------------------------------------------------------- + + def _infeasibility_diagnostics(self) -> str: + """Build a diagnostic string to help debug infeasible ILPs. + + Scans all equality constraints to detect variables pinned to both 0 and + 1 (the most common cause of infeasibility in the factor ILP). + """ + # Collect per-variable equality constraints. 
+ pinned_to_1: dict[str, list[str]] = defaultdict(list) + pinned_to_0: dict[str, list[str]] = defaultdict(list) + for name, c in self.prob.constraints.items(): + # An equality constraint y == v has sense EQ (0) and constant = -v + if c.sense == 0 and len(c) == 1: + # single-variable equality + for var, coeff in c.items(): + val = -c.constant / coeff + if abs(val - 1.0) < 1e-9: + pinned_to_1[var.name].append(name) + elif abs(val) < 1e-9: + pinned_to_0[var.name].append(name) + + conflicts = [] + for var_name in set(pinned_to_1) & set(pinned_to_0): + conflicts.append( + f" Variable {var_name}:\n" + f" pinned to 1 by: {pinned_to_1[var_name]}\n" + f" pinned to 0 by: {pinned_to_0[var_name]}" + ) + + if conflicts: + return "Conflicting constraints found:\n" + "\n".join(conflicts) + return ( + "No direct 0-vs-1 conflicts found; infeasibility may be caused " + "by interacting inequality constraints (tensor exclusion, factor " + "uniqueness)." + ) + + def get_stats(self) -> dict[str, Any]: + """Return ILP size statistics (useful for comparing with original).""" + roots = self._unique_roots() + + # Estimate original variable count. 
+ orig_vars = 0 + for node, strat in self.strats.items(): + if not strat.strategies: + continue + n_out = len(strat.strategies) + first = strat.strategies[0] + n_args = len(first.input_specs) if first.input_specs else 0 + orig_vars += max(n_args, 1) * n_out * n_out + + n_factor_vars = len(self.y_vars) + return { + "num_graph_nodes": len(self.nodes), + "num_unique_factors": len(roots), + "num_factor_ilp_variables": n_factor_vars, + "num_factor_ilp_constraints": len(self.prob.constraints), + "mesh_shape": tuple(self.mesh.shape), + "estimated_original_ilp_variables": orig_vars, + "variable_reduction_ratio": orig_vars / max(n_factor_vars, 1), + } + + def get_log(self, verbose: bool = False) -> str: + """Human-readable summary.""" + lines: list[str] = [] + lines.append(f"Factor ILP status: {pulp.LpStatus[self.prob.status]}") + s = self.get_stats() + lines.append(f"Unique factors: {s['num_unique_factors']}") + lines.append(f"Factor ILP variables: {s['num_factor_ilp_variables']}") + lines.append(f"Factor ILP constraints: {s['num_factor_ilp_constraints']}") + lines.append( + f"Est. original ILP vars: {s['estimated_original_ilp_variables']}" + ) + lines.append(f"Variable reduction: {s['variable_reduction_ratio']:.1f}x") + + if verbose and self.prob.status == 1: + lines.append("") + lines.append("Factor assignments:") + for (root, m), var in sorted(self.y_vars.items()): + if var.varValue is not None and var.varValue > 0.5: + refs = self.factor_ops.get(root, []) + desc = "" + if refs: + _, fac, _ = refs[0] + kind = "reduction" if fac.is_reduction else "spatial" + desc = f" ({kind}, size={fac.size})" + lines.append(f" Factor {root} → mesh dim {m}{desc}") + + return "\n".join(lines) diff --git a/examples/example_autoparallel_factor.py b/examples/example_autoparallel_factor.py new file mode 100644 index 00000000..dd16e967 --- /dev/null +++ b/examples/example_autoparallel_factor.py @@ -0,0 +1,211 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +""" +Comparison of the original enumeration-based ILP vs the factor-based ILP. + +Uses the same transformer Block model as example_autoparallel.py, but instead +of running the full AutoParallel pipeline, it: + 1. Traces the model to obtain the FX graph. + 2. Runs the *original* ShardingOptimizer (enumeration-based). + 3. Runs the *factor-based* FactorShardingOptimizer on the same graph. + 4. Prints a side-by-side comparison of ILP sizes and solutions. + +Usage: + python examples/example_autoparallel_factor.py +""" + +import time + +import torch +from torch import nn +from torch.distributed.fsdp import MixedPrecisionPolicy +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor.placement_types import Replicate, Shard +from torch.testing._internal.distributed.fake_pg import FakeStore + +from autoparallel.api import AutoParallel +from autoparallel.optimize_sharding import ShardingOptimizer +from autoparallel.optimize_sharding_new import FactorShardingOptimizer + +# --------------------------------------------------------------------------- +# Model (same as example_autoparallel.py, minus activation checkpointing for +# simplicity) +# --------------------------------------------------------------------------- + + +class Block(nn.Module): + def __init__(self, nheads, dim1, dim2): + super().__init__() + self.nheads = nheads + bias = False + self.wq = nn.Linear(dim1, dim1, bias=bias) + self.wk = nn.Linear(dim1, dim1, bias=bias) + self.wv = nn.Linear(dim1, dim1, bias=bias) + self.wo = nn.Linear(dim1, dim1, bias=bias) + self.w1 = nn.Linear(dim1, dim2, bias=bias) + self.w2 = nn.Linear(dim2, dim1, bias=bias) + + def forward(self, x): + q = self.wq(x) + k = self.wk(x) + v = self.wv(x) + + q = q.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3) + k = k.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3) + 
v = v.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3) + + o = nn.functional.scaled_dot_product_attention(q, k, v) + o = o.permute(0, 2, 1, 3).flatten(-2) + + o = self.wo(o) + o0 = o + x + + o = self.w1(o0) + o = torch.nn.functional.relu(o) + o = self.w2(o) + + o = o0 + o + return o + + +# --------------------------------------------------------------------------- +# Setup +# --------------------------------------------------------------------------- + +world_size = 64 + +fake_store = FakeStore() +torch.distributed.init_process_group( + "fake", store=fake_store, rank=0, world_size=world_size +) + +mesh = torch.distributed.device_mesh.init_device_mesh( + "cuda", + (world_size // 8, 8), + mesh_dim_names=("dp", "tp"), +) + +bs = 8 * mesh.shape[0] +seq_len = 256 +nheads = 48 +dim1 = 6144 +dim2 = dim1 * 4 + + +def input_fn(): + return torch.rand(bs, seq_len, dim1, device="cuda") + + +x_sharding = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1) + +mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32) + +# --------------------------------------------------------------------------- +# Trace the model (reuse AutoParallel for graph capture only) +# --------------------------------------------------------------------------- + +print("=" * 70) +print("Tracing model...") +print("=" * 70) + +with torch.device("meta"): + model = Block(nheads, dim1, dim2) + +with AutoParallel(model, input_fn, mesh, mp_policy) as autop: + gm = autop.gm # the traced FX graph (joint fwd + bwd) + + # ------------------------------------------------------------------ + # 1. 
Original (enumeration-based) optimizer + # ------------------------------------------------------------------ + print("\n" + "=" * 70) + print("Running ORIGINAL (enumeration-based) ShardingOptimizer") + print("=" * 70) + + t0 = time.perf_counter() + orig_opt = ShardingOptimizer(gm, mesh) + orig_opt.add_grad_param_constraints() + orig_opt.add_sharded_input_constraint([x_sharding]) + orig_opt.add_sharded_output_constraint([x_sharding]) + orig_solution = orig_opt.get_solution(verbose=False) + t_orig = time.perf_counter() - t0 + + print(f" Solve time: {t_orig:.2f}s") + print(f" ILP variables: {len(orig_opt.ds):,}") + print(f" ILP constraints: {len(orig_opt.prob.constraints):,}") + print(f" ILP status: {orig_opt.prob.status}") + + # ------------------------------------------------------------------ + # 2. Factor-based optimizer + # ------------------------------------------------------------------ + print("\n" + "=" * 70) + print("Running FACTOR-BASED FactorShardingOptimizer") + print("=" * 70) + + t0 = time.perf_counter() + factor_opt = FactorShardingOptimizer(gm, mesh) + factor_opt.add_input_constraints([x_sharding]) + factor_opt.add_output_constraints([x_sharding]) + factor_solution = factor_opt.get_solution(verbose=False) + t_factor = time.perf_counter() - t0 + + print(f" Solve time: {t_factor:.2f}s") + print(factor_opt.get_log(verbose=True)) + + # ------------------------------------------------------------------ + # 3. 
Comparison + # ------------------------------------------------------------------ + stats = factor_opt.get_stats() + + print("\n" + "=" * 70) + print("COMPARISON") + print("=" * 70) + print(f" Mesh shape: {tuple(mesh.shape)}") + print(f" Graph nodes: {stats['num_graph_nodes']}") + print() + print(f" Original ILP variables: {len(orig_opt.ds):,}") + print(f" Factor ILP variables: {stats['num_factor_ilp_variables']:,}") + print(f" Variable reduction: {stats['variable_reduction_ratio']:.1f}x") + print() + print(f" Original ILP constraints:{len(orig_opt.prob.constraints):,}") + print(f" Factor ILP constraints: {stats['num_factor_ilp_constraints']:,}") + print() + print(f" Unique factors: {stats['num_unique_factors']}") + + # ------------------------------------------------------------------ + # 4. Show per-node placement comparison + # ------------------------------------------------------------------ + print("\n" + "=" * 70) + print("PER-NODE PLACEMENT COMPARISON (first 30 call_function nodes)") + print("=" * 70) + + call_fn_nodes = [n for n in gm.graph.nodes if n.op == "call_function"] + for node in call_fn_nodes[:30]: + orig_spec = orig_solution.get(node) + factor_spec = factor_solution.get(node) + + if orig_spec is not None and hasattr(orig_spec, "output_specs"): + os = orig_spec.output_specs + if isinstance(os, DTensorSpec): + orig_plc = tuple(os.placements) + elif isinstance(os, (list, tuple)) and os: + orig_plc = tuple(os[0].placements) + else: + orig_plc = "?" + else: + orig_plc = "?" + factor_plc = tuple(factor_spec.placements) if factor_spec is not None else "?" + match = "OK" if str(orig_plc) == str(factor_plc) else "DIFF" + op_name = ( + str(node.target).split(".")[-1] + if hasattr(node.target, "__name__") + else str(node.target) + ) + # Truncate long op names + if len(op_name) > 40: + op_name = op_name[:37] + "..." 
+ print(f" [{match:4s}] {op_name:42s} orig={orig_plc} factor={factor_plc}") + +print("\nDone.") From 9b0e8d64e4afba51b133df78c45419078ab19b80 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 17 Feb 2026 06:13:32 +0000 Subject: [PATCH 02/22] Bugfixes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The representation now allows it — nothing in the constraint system prevents y[root, 0] = 1 AND y[root, 1] = 1. The issue is in the cost model. Let me trace why. The DIFF nodes are weight transposes (t.default). Their dim 0 factor merges with the K (reduction/contraction) factor of the downstream mm. In the current cost model: if fac.is_reduction: cost += allreduce_cost # penalty only else: cost -= compute_benefit # benefit only Reduction factors get only an allreduce penalty, with no compute benefit subtracted. But sharding the K dimension does reduce compute — each device computes a smaller matmul. The original enumeration-based optimizer captures this because its per-strategy costs include both comm and compute for each choice. The factor cost model misses the compute savings for reduction factors, so the optimizer avoids assigning K to any mesh dim. --- autoparallel/optimize_sharding_new.py | 21 ++++++++++++++------- examples/example_autoparallel_factor.py | 7 ++----- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/autoparallel/optimize_sharding_new.py b/autoparallel/optimize_sharding_new.py index 310550cb..9bcf45a1 100644 --- a/autoparallel/optimize_sharding_new.py +++ b/autoparallel/optimize_sharding_new.py @@ -485,7 +485,12 @@ def _build_ilp(self) -> None: self.y_vars[(r, m)] = pulp.LpVariable(f"y_{r}_m{m}", cat="Binary") # --- Constraints --- - self._add_factor_uniqueness(roots) + # NOTE: we intentionally omit a factor-uniqueness constraint here. 
+ # A factor MAY be assigned to multiple mesh dims simultaneously, + # which corresponds to placements like (Shard(0), Shard(0)) where a + # single tensor dim is sharded across several mesh dims. The tensor + # exclusion constraint already prevents invalid combos (two different + # factors claiming the same tensor dim on the same mesh dim). self._add_tensor_exclusion() # --- Objective --- @@ -584,17 +589,19 @@ def _add_objective(self, roots: set[int]) -> None: if node.op != "call_function": continue + # Compute benefit: work is divided by mesh_size + # regardless of whether the factor is a reduction + # or spatial dimension. + compute = self._compute_cost(node) + benefit = compute * (1.0 - 1.0 / mesh_size) + cost -= benefit + if fac.is_reduction: - # Allreduce: ring algorithm ≈ 2·B·(n-1)/n + # Allreduce penalty: ring algorithm ≈ 2·B·(n-1)/n out_bytes = self._output_bytes(node) ar_bytes = 2.0 * out_bytes * (mesh_size - 1) / mesh_size # Rough bandwidth model: 50 GB/s per link cost += ar_bytes / 50e9 * 1e6 # microseconds - else: - # Compute benefit: work is divided by mesh_size. - compute = self._compute_cost(node) - benefit = compute * (1.0 - 1.0 / mesh_size) - cost -= benefit if cost != 0.0: terms.append(cost * var) diff --git a/examples/example_autoparallel_factor.py b/examples/example_autoparallel_factor.py index dd16e967..e64e8de7 100644 --- a/examples/example_autoparallel_factor.py +++ b/examples/example_autoparallel_factor.py @@ -102,6 +102,7 @@ def input_fn(): x_sharding = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1) mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32) +mp_policy = None # --------------------------------------------------------------------------- # Trace the model (reuse AutoParallel for graph capture only) @@ -198,11 +199,7 @@ def input_fn(): orig_plc = "?" factor_plc = tuple(factor_spec.placements) if factor_spec is not None else "?" 
match = "OK" if str(orig_plc) == str(factor_plc) else "DIFF" - op_name = ( - str(node.target).split(".")[-1] - if hasattr(node.target, "__name__") - else str(node.target) - ) + op_name = str(node) # Truncate long op names if len(op_name) > 40: op_name = op_name[:37] + "..." From b5a071657eaa8c7f6ca0e820e568026b8b72fe5b Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 17 Feb 2026 06:22:54 +0000 Subject: [PATCH 03/22] Fix support for ops with list/tuple outputs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three places fixed in optimize_sharding_new.py: 1. _infer_factor_size — now uses _get_primary_tensor() to extract the first tensor from tuple outputs for shape inference 2. _placeholder_factor_rule — same, handles tuple val metadata 3. get_solution — for tuple outputs, produces a tuple of DTensorSpecs, filtering out Shard(d) placements where d exceeds a given output tensor's ndim (since auxiliary outputs like logsumexp may have fewer dimensions than the primary output) --- autoparallel/optimize_sharding_new.py | 56 ++++++++++++++++++------- examples/example_autoparallel_factor.py | 15 +++++-- 2 files changed, 54 insertions(+), 17 deletions(-) diff --git a/autoparallel/optimize_sharding_new.py b/autoparallel/optimize_sharding_new.py index 9bcf45a1..23acc1e1 100644 --- a/autoparallel/optimize_sharding_new.py +++ b/autoparallel/optimize_sharding_new.py @@ -167,6 +167,21 @@ def union(self, x: int, y: int) -> int: # --------------------------------------------------------------------------- +def _get_primary_tensor(val: Any) -> torch.Tensor | None: + """Extract the primary tensor from a node's meta 'val'. + + For single-tensor outputs, returns the tensor directly. + For tuple outputs (multi-output ops like SDPA), returns the first tensor. 
+ """ + if isinstance(val, torch.Tensor): + return val + if isinstance(val, (tuple, list)): + for v in val: + if isinstance(v, torch.Tensor): + return v + return None + + def _infer_factor_size( node: torch.fx.Node, operand_dims: list[int | None], @@ -174,24 +189,20 @@ def _infer_factor_size( ) -> int: """Infer the size of a factor from the node's tensor metadata.""" # Try the node's own output first. - val = node.meta.get("val") - if val is not None and isinstance(val, torch.Tensor): + out = _get_primary_tensor(node.meta.get("val")) + if out is not None: for d in result_dims: - if d is not None and d < len(val.shape): - return val.shape[d] + if d is not None and d < len(out.shape): + return out.shape[d] # Fall back to operand shapes. for arg_idx, d in enumerate(operand_dims): if d is None or arg_idx >= len(node.args): continue arg = node.args[arg_idx] if isinstance(arg, torch.fx.Node): - arg_val = arg.meta.get("val") - if ( - arg_val is not None - and isinstance(arg_val, torch.Tensor) - and d < len(arg_val.shape) - ): - return arg_val.shape[d] + arg_out = _get_primary_tensor(arg.meta.get("val")) + if arg_out is not None and d < len(arg_out.shape): + return arg_out.shape[d] return 1 # fallback @@ -269,10 +280,10 @@ def _placeholder_factor_rule(node: torch.fx.Node) -> FactorRule: Each tensor dimension becomes an independent spatial factor. """ - val = node.meta.get("val") - if val is None or not isinstance(val, torch.Tensor): + out = _get_primary_tensor(node.meta.get("val")) + if out is None: return FactorRule(factors=[], num_operands=0, num_results=1) - shape = val.shape + shape = out.shape factors = [ Factor(id=d, size=shape[d], operand_dims=[], result_dims=[d]) for d in range(len(shape)) @@ -775,6 +786,23 @@ def get_solution(self, verbose: bool = False) -> dict[torch.fx.Node, DTensorSpec result[node] = DTensorSpec( self.mesh, tuple(placements), tensor_meta=tensor_meta ) + elif val is not None and isinstance(val, (tuple, list)): + # Multi-output op (e.g. SDPA). 
The factors describe the + # primary (first) output. For each output tensor, keep only + # Shard placements whose dim is in range for that tensor. + specs = [] + for v in val: + if isinstance(v, torch.Tensor): + plc = tuple( + p if not isinstance(p, Shard) or p.dim < len(v.shape) + else Replicate() + for p in placements + ) + tm = TensorMeta(v.shape, v.stride(), v.dtype) + specs.append(DTensorSpec(self.mesh, plc, tensor_meta=tm)) + else: + specs.append(None) + result[node] = tuple(specs) return result diff --git a/examples/example_autoparallel_factor.py b/examples/example_autoparallel_factor.py index e64e8de7..310dcd77 100644 --- a/examples/example_autoparallel_factor.py +++ b/examples/example_autoparallel_factor.py @@ -178,12 +178,13 @@ def input_fn(): # ------------------------------------------------------------------ # 4. Show per-node placement comparison # ------------------------------------------------------------------ + n_show = 100 print("\n" + "=" * 70) - print("PER-NODE PLACEMENT COMPARISON (first 30 call_function nodes)") + print(f"PER-NODE PLACEMENT COMPARISON (first {n_show} call_function nodes)") print("=" * 70) call_fn_nodes = [n for n in gm.graph.nodes if n.op == "call_function"] - for node in call_fn_nodes[:30]: + for node in call_fn_nodes[:n_show]: orig_spec = orig_solution.get(node) factor_spec = factor_solution.get(node) @@ -197,7 +198,15 @@ def input_fn(): orig_plc = "?" else: orig_plc = "?" - factor_plc = tuple(factor_spec.placements) if factor_spec is not None else "?" + if factor_spec is not None: + if isinstance(factor_spec, DTensorSpec): + factor_plc = tuple(factor_spec.placements) + elif isinstance(factor_spec, (list, tuple)) and factor_spec: + factor_plc = tuple(factor_spec[0].placements) + else: + factor_plc = "?" + else: + factor_plc = "?" 
match = "OK" if str(orig_plc) == str(factor_plc) else "DIFF" op_name = str(node) # Truncate long op names From 2b01ef19110b66c0b1018c53b77473e45797d6ef Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 17 Feb 2026 06:36:52 +0000 Subject: [PATCH 04/22] Add support for Partial and all-gather cost MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Here's a summary of the changes: 1. Partial placement output — Reduction factors (like K in matmul) now produce Partial() instead of being silently skipped. This matches what the original optimizer outputs for nodes like mm_3. 2. FLOP-based compute cost estimator — Replaced the broken estimate_strategy_runtime_cost(node, None) (which was silently returning 0 for everything) with a direct FLOP counter for mm, addmm, bmm, and SDPA. This is why the optimizer wasn't incentivized to shard the batch dimension on both mesh dims — it saw zero compute benefit. 3. All-gather cost at exit edges — For each spatial factor root, we pre-compute "exit edges" where a producer has the factor but its consumer doesn't (via union-find). When the factor is assigned to a mesh dim, an all-gather is needed at each such edge. This cost is added as a linear term in the objective, penalizing factors that would require redistribution. 
--- autoparallel/optimize_sharding_new.py | 136 +++++++++++++++++++++----- 1 file changed, 110 insertions(+), 26 deletions(-) diff --git a/autoparallel/optimize_sharding_new.py b/autoparallel/optimize_sharding_new.py index 23acc1e1..596967f4 100644 --- a/autoparallel/optimize_sharding_new.py +++ b/autoparallel/optimize_sharding_new.py @@ -45,7 +45,6 @@ from __future__ import annotations -import math from collections import defaultdict from dataclasses import dataclass, field from typing import Any, Optional @@ -60,11 +59,10 @@ ) from torch.distributed._tensor.placement_types import TensorMeta from torch.distributed.tensor._dtensor_spec import DTensorSpec -from torch.distributed.tensor._op_schema import OpSpec, OpStrategy -from torch.distributed.tensor.placement_types import Placement, Replicate, Shard +from torch.distributed.tensor._op_schema import OpStrategy +from torch.distributed.tensor.placement_types import Partial, Placement, Replicate, Shard from torch.utils._pytree import tree_map_only -from .cost_models.compute_estimation import estimate_strategy_runtime_cost from .shardings.placement_options import get_placement_options from .shardings.propagation_rules import _create_all_options @@ -75,7 +73,7 @@ def _is_partial(p: Placement) -> bool: """Check if a placement is Partial (reduction output).""" - return not isinstance(p, (Shard, Replicate)) + return isinstance(p, Partial) @dataclass @@ -577,16 +575,20 @@ def _add_tensor_exclusion(self) -> None: def _add_objective(self, roots: set[int]) -> None: """Build the cost function. - For each factor *f* assigned to mesh dim *m*: + For each factor *f* assigned to mesh dim *m* the cost coefficient + includes three components: - * **Reduction factor** → ``+allreduce_cost(output_bytes, mesh.shape[m])`` - * **Spatial factor** → ``-compute_savings(op, mesh.shape[m])`` + 1. **Compute benefit** (all factors): sharding any dimension divides + work by ``mesh.shape[m]``. + 2. 
**Allreduce penalty** (reduction factors): the ``Partial`` output + must be all-reduced before consumers can use it. + 3. **All-gather penalty** (spatial factors at "exit" edges): when a + producer is ``Shard(d)`` on mesh dim *m* but a consumer doesn't + share that factor (via union-find), an all-gather is needed. - This is a *first-order* (linear) approximation. The true compute cost - depends on the product of all shard sizes, which would make the - objective quadratic. The linear model captures the dominant effects: - allreduce penalties and per-factor parallelism benefits. + All three are linear in the ``y`` variables, keeping the ILP linear. """ + ag_bytes = self._compute_redistribution_bytes() terms: list[Any] = [] for r in roots: @@ -611,8 +613,13 @@ def _add_objective(self, roots: set[int]) -> None: # Allreduce penalty: ring algorithm ≈ 2·B·(n-1)/n out_bytes = self._output_bytes(node) ar_bytes = 2.0 * out_bytes * (mesh_size - 1) / mesh_size - # Rough bandwidth model: 50 GB/s per link - cost += ar_bytes / 50e9 * 1e6 # microseconds + cost += ar_bytes / self._BW * 1e6 # microseconds + + # All-gather penalty for exit edges (spatial factors whose + # consumers don't share the root). + if r in ag_bytes: + ag_comm = ag_bytes[r] * (mesh_size - 1) / mesh_size + cost += ag_comm / self._BW * 1e6 if cost != 0.0: terms.append(cost * var) @@ -622,19 +629,96 @@ def _add_objective(self, roots: set[int]) -> None: # ---- cost helpers ----------------------------------------------- + # Rough inter-node bandwidth (bytes/s). 50 GB/s is a reasonable + # default for NVLink / high-end InfiniBand. 
+ _BW: float = 50e9 + @staticmethod def _output_bytes(node: torch.fx.Node) -> float: - val = node.meta.get("val") - if val is not None and isinstance(val, torch.Tensor): + val = _get_primary_tensor(node.meta.get("val")) + if val is not None: return float(val.numel() * val.element_size()) return 0.0 @staticmethod def _compute_cost(node: torch.fx.Node) -> float: - try: - return float(estimate_strategy_runtime_cost(node, None)) - except Exception: + """Estimate compute cost in FLOPs for compute-intensive ops.""" + if node.op != "call_function": return 0.0 + target = node.target + + def _shape(n: Any) -> tuple[int, ...] | None: + if isinstance(n, torch.fx.Node): + t = _get_primary_tensor(n.meta.get("val")) + if t is not None: + return tuple(t.shape) + return None + + # mm(A[M,K], B[K,N]) → 2·M·K·N + if target == torch.ops.aten.mm.default: + a, b = _shape(node.args[0]), _shape(node.args[1]) + if a and b and len(a) == 2 and len(b) == 2: + M, K = a + _, N = b + return 2.0 * M * K * N + + # addmm(bias, A[M,K], B[K,N]) → 2·M·K·N + if target == torch.ops.aten.addmm.default: + a, b = _shape(node.args[1]), _shape(node.args[2]) + if a and b and len(a) == 2 and len(b) == 2: + M, K = a + _, N = b + return 2.0 * M * K * N + + # bmm(A[B,M,K], B[B,K,N]) → 2·B·M·K·N + if target == torch.ops.aten.bmm.default: + a, b = _shape(node.args[0]), _shape(node.args[1]) + if a and b and len(a) == 3 and len(b) == 3: + B, M, K = a + _, _, N = b + return 2.0 * B * M * K * N + + # SDPA — dominated by two bmm-like ops internally: + # scores = Q·K^T → 2·B·H·S·S·D + # output = scores·V → 2·B·H·S·S·D + if "scaled_dot_product" in str(target): + q = _shape(node.args[0]) + if q and len(q) == 4: + B, H, S, D = q + return 2.0 * 2 * B * H * S * S * D # two bmm-equivalent + + return 0.0 + + def _compute_redistribution_bytes(self) -> dict[int, float]: + """For each factor root, total output bytes at "exit" edges. 
+ + An exit edge is a producer→consumer edge where the producer has a + spatial factor with root *R* but the consumer doesn't share *R* + (via union-find). When *R* is assigned to a mesh dim the consumer + needs an all-gather on that edge. + """ + # node_idx → set of factor roots at that node + node_roots: dict[int, set[int]] = defaultdict(set) + for (nidx, _li), gid in self.factor_keys.items(): + node_roots[nidx].add(self.uf.find(gid)) + + ag_bytes: dict[int, float] = defaultdict(float) + for root, refs in self.factor_ops.items(): + for node, fac, _li in refs: + # Only spatial factors produce Shard output. + if fac.is_reduction: + continue + if not fac.result_dims or fac.result_dims[0] is None: + continue + for user in node.users: + if user.op != "call_function": + continue + uidx = self.node_map.get(user) + if uidx is None: + continue + if root not in node_roots.get(uidx, set()): + ag_bytes[root] += self._output_bytes(node) + return dict(ag_bytes) # ----------------------------------------------------------------- # User constraints @@ -767,18 +851,18 @@ def get_solution(self, verbose: bool = False) -> dict[torch.fx.Node, DTensorSpec placements: list[Placement] = [Replicate()] * self.mesh.ndim for li, fac in enumerate(rule.factors): - if not fac.result_dims or fac.result_dims[0] is None: - continue gid = self.factor_keys.get((nidx, li)) if gid is None: continue root = self.uf.find(gid) m = assignment.get(root) - if m is not None: - td = fac.result_dims[0] - # A factor may be spatial here but reduction in another op. - # For the *output* placement we use Shard(dim). - placements[m] = Shard(td) + if m is None: + continue + if fac.is_reduction: + # Reduction factor → output is Partial on this mesh dim. 
+ placements[m] = Partial() + elif fac.result_dims and fac.result_dims[0] is not None: + placements[m] = Shard(fac.result_dims[0]) val = node.meta.get("val") if val is not None and isinstance(val, torch.Tensor): From 42231eb16a9c58306cb4e2b9d3a44ff33015fd92 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 17 Feb 2026 06:48:39 +0000 Subject: [PATCH 05/22] Propagate reduction factors through data-preserving ops MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Here's a summary of the three changes: 1. Reduction factor propagation in _build_factor_graph: When a consumer has a reduction factor with operand_dims[arg_pos] = None (a Partial pass-through atom from view/permute/alias strategies), it now merges with the producer's reduction factor. This makes Partial propagate through data-preserving ops — so view_11 will correctly show Partial(sum) instead of Replicate(). 2. Reduce-scatter cost replaces allreduce: The old code added allreduce cost (2B·(n-1)/n) unconditionally at every node where a factor was a reduction. Now the cost is only added at "exit edges" where the Partial doesn't propagate to the consumer, using reduce-scatter cost (B·(n-1)/n — half of allreduce). Where the reduction factor propagates (e.g., mm → view), there's zero cost. 3. Unified _compute_redistribution_bytes: Returns both ag_bytes (spatial exit edges → all-gather) and rs_bytes (reduction exit edges → reduce-scatter) in a single pass over the graph. 
--- autoparallel/optimize_sharding_new.py | 94 +++++++++++++++++---------- 1 file changed, 61 insertions(+), 33 deletions(-) diff --git a/autoparallel/optimize_sharding_new.py b/autoparallel/optimize_sharding_new.py index 596967f4..c5b8bdc0 100644 --- a/autoparallel/optimize_sharding_new.py +++ b/autoparallel/optimize_sharding_new.py @@ -449,17 +449,32 @@ def _build_factor_graph(self) -> None: if arg_pos >= len(c_fac.operand_dims): continue c_dim = c_fac.operand_dims[arg_pos] - if c_dim is None: - continue - for p_li, p_fac in enumerate(producer_rule.factors): - if not p_fac.result_dims: - continue - p_dim = p_fac.result_dims[0] - if p_dim is not None and p_dim == c_dim: - pk = self.factor_keys.get((pidx, p_li)) - ck = self.factor_keys.get((cidx, c_li)) - if pk is not None and ck is not None: - self.uf.union(pk, ck) + + if c_dim is not None: + # Spatial factor: match by dimension index. + for p_li, p_fac in enumerate(producer_rule.factors): + if not p_fac.result_dims: + continue + p_dim = p_fac.result_dims[0] + if p_dim is not None and p_dim == c_dim: + pk = self.factor_keys.get((pidx, p_li)) + ck = self.factor_keys.get((cidx, c_li)) + if pk is not None and ck is not None: + self.uf.union(pk, ck) + elif c_fac.is_reduction: + # Reduction factor pass-through (Partial → Partial). + # Data-preserving ops (view, permute, alias, …) have + # a strategy atom (out=Partial, in=Partial) which + # produces a factor with operand_dims=[None] and + # result_dims=[None]. Merge it with the producer's + # reduction factor so that Partial propagates. + for p_li, p_fac in enumerate(producer_rule.factors): + if p_fac.is_reduction: + pk = self.factor_keys.get((pidx, p_li)) + ck = self.factor_keys.get((cidx, c_li)) + if pk is not None and ck is not None: + self.uf.union(pk, ck) + break # one-to-one merge # ----------------------------------------------------------------- # Step 4 — metadata collection @@ -580,15 +595,19 @@ def _add_objective(self, roots: set[int]) -> None: 1. 
**Compute benefit** (all factors): sharding any dimension divides work by ``mesh.shape[m]``. - 2. **Allreduce penalty** (reduction factors): the ``Partial`` output - must be all-reduced before consumers can use it. + 2. **Reduce-scatter penalty** (reduction factors at "exit" edges): + when a ``Partial`` output reaches a consumer that doesn't share + the reduction root, a reduce-scatter is needed. Cost ≈ B·(n-1)/n. + (If the reduction factor propagates through data-preserving ops + via union-find, there is zero cost at those edges.) 3. **All-gather penalty** (spatial factors at "exit" edges): when a producer is ``Shard(d)`` on mesh dim *m* but a consumer doesn't share that factor (via union-find), an all-gather is needed. + Cost ≈ B·(n-1)/n. All three are linear in the ``y`` variables, keeping the ILP linear. """ - ag_bytes = self._compute_redistribution_bytes() + ag_bytes, rs_bytes = self._compute_redistribution_bytes() terms: list[Any] = [] for r in roots: @@ -609,14 +628,14 @@ def _add_objective(self, roots: set[int]) -> None: benefit = compute * (1.0 - 1.0 / mesh_size) cost -= benefit - if fac.is_reduction: - # Allreduce penalty: ring algorithm ≈ 2·B·(n-1)/n - out_bytes = self._output_bytes(node) - ar_bytes = 2.0 * out_bytes * (mesh_size - 1) / mesh_size - cost += ar_bytes / self._BW * 1e6 # microseconds + # Reduce-scatter penalty at reduction exit edges. + # Only incurred where Partial doesn't propagate to the + # consumer (i.e. at the point where Partial is resolved). + if r in rs_bytes: + rs_comm = rs_bytes[r] * (mesh_size - 1) / mesh_size + cost += rs_comm / self._BW * 1e6 - # All-gather penalty for exit edges (spatial factors whose - # consumers don't share the root). + # All-gather penalty at spatial exit edges. if r in ag_bytes: ag_comm = ag_bytes[r] * (mesh_size - 1) / mesh_size cost += ag_comm / self._BW * 1e6 @@ -689,13 +708,18 @@ def _shape(n: Any) -> tuple[int, ...] 
| None: return 0.0 - def _compute_redistribution_bytes(self) -> dict[int, float]: + def _compute_redistribution_bytes( + self, + ) -> tuple[dict[int, float], dict[int, float]]: """For each factor root, total output bytes at "exit" edges. - An exit edge is a producer→consumer edge where the producer has a - spatial factor with root *R* but the consumer doesn't share *R* - (via union-find). When *R* is assigned to a mesh dim the consumer - needs an all-gather on that edge. + Returns ``(ag_bytes, rs_bytes)``: + + * **ag_bytes** — for *spatial* factors: bytes needing an all-gather + at edges where the consumer doesn't share the root. + * **rs_bytes** — for *reduction* factors: bytes needing a + reduce-scatter at edges where the ``Partial`` doesn't propagate + to the consumer. """ # node_idx → set of factor roots at that node node_roots: dict[int, set[int]] = defaultdict(set) @@ -703,22 +727,26 @@ def _compute_redistribution_bytes(self) -> dict[int, float]: node_roots[nidx].add(self.uf.find(gid)) ag_bytes: dict[int, float] = defaultdict(float) + rs_bytes: dict[int, float] = defaultdict(float) for root, refs in self.factor_ops.items(): for node, fac, _li in refs: - # Only spatial factors produce Shard output. 
- if fac.is_reduction: - continue - if not fac.result_dims or fac.result_dims[0] is None: - continue for user in node.users: if user.op != "call_function": continue uidx = self.node_map.get(user) if uidx is None: continue - if root not in node_roots.get(uidx, set()): + if root in node_roots.get(uidx, set()): + continue # factor propagates — no redistribution + + if fac.is_reduction: + # Partial exits here → reduce-scatter + rs_bytes[root] += self._output_bytes(node) + elif fac.result_dims and fac.result_dims[0] is not None: + # Shard exits here → all-gather ag_bytes[root] += self._output_bytes(node) - return dict(ag_bytes) + + return dict(ag_bytes), dict(rs_bytes) # ----------------------------------------------------------------- # User constraints From be3a6738ffe1c0cfe9683298922d755a360ad98d Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 17 Feb 2026 07:18:39 +0000 Subject: [PATCH 06/22] Add linearization for all_reduce / reduce_scatter cost MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Here's a summary of the three changes made: 1. Tensor exclusion extended to include reduction factors (line 556-560) Previously, only spatial factors were included in the per-node per-mesh-dim exclusion constraint. A reduction factor (Partial) and a spatial factor (Shard) could end up assigned to the same mesh dim for the same output tensor, which is invalid in DTensor. Now both are included in the sum <= 1 constraint. 2. 
Linearized Partial → Replicate cost (lines 649-694) For each reduction exit edge (root r, consumer u) on mesh dim m: - Base cost = B·(n-1)/n · y[r,m] — reduce-scatter, always incurred - Extra cost = B·(n-1)/n · z — upgrades to all-reduce when consumer is Replicate The auxiliary continuous variable z satisfies: - z ≥ y[r,m] − Σ y[s,m] for consumer's spatial roots s - z ≥ 0 Since z has a positive coefficient and we minimize, the solver naturally sets z = max(0, y[r,m] − Σ y[s,m]): - Consumer has spatial factor on m → z=0, total = B (reduce-scatter) - Consumer fully replicated on m → z=y[r,m], total = 2B (all-reduce) 3. New helper methods (lines 798-838) - _compute_reduction_exit_info() — returns (reduction_root, consumer_nidx) → bytes for each exit edge - _get_spatial_roots_at_node(nidx) — returns the set of spatial factor roots at a given node The auxiliary z variables are continuous (not binary), so they add negligible solve cost. The stats now report the breakdown (e.g., "104 y + 42 z"). --- autoparallel/optimize_sharding_new.py | 124 ++++++++++++++++++++++---- 1 file changed, 108 insertions(+), 16 deletions(-) diff --git a/autoparallel/optimize_sharding_new.py b/autoparallel/optimize_sharding_new.py index c5b8bdc0..bdee765a 100644 --- a/autoparallel/optimize_sharding_new.py +++ b/autoparallel/optimize_sharding_new.py @@ -554,7 +554,10 @@ def _add_tensor_exclusion(self) -> None: vs = [] seen_roots: set[int] = set() for li, fac in enumerate(rule.factors): - if fac.result_dims and fac.result_dims[0] is not None: + # Include both spatial and reduction factors: a + # tensor can only be Shard(d) OR Partial on each + # mesh dim, never both simultaneously. + if (fac.result_dims and fac.result_dims[0] is not None) or fac.is_reduction: gid = self.factor_keys.get((nidx, li)) if gid is not None: root = self.uf.find(gid) @@ -595,19 +598,26 @@ def _add_objective(self, roots: set[int]) -> None: 1. 
**Compute benefit** (all factors): sharding any dimension divides work by ``mesh.shape[m]``. - 2. **Reduce-scatter penalty** (reduction factors at "exit" edges): + 2. **Redistribution penalty** (reduction factors at "exit" edges): when a ``Partial`` output reaches a consumer that doesn't share - the reduction root, a reduce-scatter is needed. Cost ≈ B·(n-1)/n. - (If the reduction factor propagates through data-preserving ops - via union-find, there is zero cost at those edges.) + the reduction root, redistribution is needed. The exact cost + depends on the consumer's placement on that mesh dim: + + - **Partial → Shard** (reduce-scatter): B·(n-1)/n + - **Partial → Replicate** (all-reduce): 2B·(n-1)/n + + This is captured exactly via auxiliary continuous variables that + linearize the product ``y[r,m] · (1 - any_consumer_spatial_on_m)``. 3. **All-gather penalty** (spatial factors at "exit" edges): when a producer is ``Shard(d)`` on mesh dim *m* but a consumer doesn't share that factor (via union-find), an all-gather is needed. Cost ≈ B·(n-1)/n. - All three are linear in the ``y`` variables, keeping the ILP linear. + All three are linear in the ``y`` and ``z`` variables, keeping the + ILP linear. """ - ag_bytes, rs_bytes = self._compute_redistribution_bytes() + ag_bytes, _rs_bytes = self._compute_redistribution_bytes() + exit_info = self._compute_reduction_exit_info() terms: list[Any] = [] for r in roots: @@ -628,13 +638,6 @@ def _add_objective(self, roots: set[int]) -> None: benefit = compute * (1.0 - 1.0 / mesh_size) cost -= benefit - # Reduce-scatter penalty at reduction exit edges. - # Only incurred where Partial doesn't propagate to the - # consumer (i.e. at the point where Partial is resolved). - if r in rs_bytes: - rs_comm = rs_bytes[r] * (mesh_size - 1) / mesh_size - cost += rs_comm / self._BW * 1e6 - # All-gather penalty at spatial exit edges. 
if r in ag_bytes: ag_comm = ag_bytes[r] * (mesh_size - 1) / mesh_size @@ -643,6 +646,50 @@ def _add_objective(self, roots: set[int]) -> None: if cost != 0.0: terms.append(cost * var) + # Reduction exit edges: linearized Partial → {Shard, Replicate} cost. + # + # For each (reduction_root r, consumer u) exit edge, on mesh dim m: + # + # base cost = B·(n-1)/n · y[r,m] (reduce-scatter) + # extra cost = B·(n-1)/n · z (upgrade to all-reduce) + # + # where z is a continuous auxiliary variable satisfying: + # z ≥ y[r,m] − Σ_s y[s,m] for consumer's spatial roots s + # z ≥ 0 (implicit from lowBound=0) + # + # Since z has a positive coefficient and we minimize, the solver + # sets z = max(0, y[r,m] − Σ y[s,m]). + # + # • Consumer has spatial factor on m (Σ≥1) → z=0, total = B (reduce-scatter) + # • Consumer fully replicated on m (Σ=0) → z=y, total = 2B (all-reduce) + z_id = 0 + for (r, uidx), bytes_val in exit_info.items(): + consumer_spatial = self._get_spatial_roots_at_node(uidx) + for m in range(self.mesh.ndim): + mesh_size = self.mesh.shape[m] + y_r_m = self.y_vars[(r, m)] + + # Base reduce-scatter cost (always incurred when y[r,m]=1). + comm_unit = bytes_val * (mesh_size - 1) / mesh_size / self._BW * 1e6 + terms.append(comm_unit * y_r_m) + + # Extra cost for Partial → Replicate (linearized). + valid_roots = [s for s in consumer_spatial if (s, m) in self.y_vars] + if valid_roots: + z = pulp.LpVariable(f"z_pr_{z_id}", lowBound=0) + spatial_sum = pulp.lpSum( + self.y_vars[(s, m)] for s in valid_roots + ) + self.prob += z >= y_r_m - spatial_sum, f"z_pr_lb_{z_id}" + terms.append(comm_unit * z) + z_id += 1 + else: + # No spatial factors at consumer → always all-reduce. + # Extra cost = B·(n-1)/n · y[r,m] (doubling the base). 
+ terms.append(comm_unit * y_r_m) + + self._num_z_vars = z_id + if terms: self.prob += pulp.lpSum(terms) @@ -748,6 +795,48 @@ def _compute_redistribution_bytes( return dict(ag_bytes), dict(rs_bytes) + def _compute_reduction_exit_info(self) -> dict[tuple[int, int], float]: + """For each (reduction_root, consumer_nidx) pair, total bytes at exits. + + Used by the linearized Partial → Replicate cost model to distinguish + reduce-scatter (consumer is Shard) from all-reduce (consumer is + Replicate) on each mesh dimension. + """ + node_roots: dict[int, set[int]] = defaultdict(set) + for (nidx, _li), gid in self.factor_keys.items(): + node_roots[nidx].add(self.uf.find(gid)) + + exit_info: dict[tuple[int, int], float] = defaultdict(float) + for root, refs in self.factor_ops.items(): + for node, fac, _li in refs: + if not fac.is_reduction: + continue + for user in node.users: + if user.op != "call_function": + continue + uidx = self.node_map.get(user) + if uidx is None: + continue + if root in node_roots.get(uidx, set()): + continue # factor propagates — no redistribution + exit_info[(root, uidx)] += self._output_bytes(node) + + return dict(exit_info) + + def _get_spatial_roots_at_node(self, nidx: int) -> set[int]: + """Get unique roots for spatial (non-reduction) result factors at a node.""" + node = self.nodes[nidx] + rule = self.factor_rules.get(node) + if rule is None: + return set() + roots: set[int] = set() + for li, fac in enumerate(rule.factors): + if not fac.is_reduction and fac.result_dims and fac.result_dims[0] is not None: + gid = self.factor_keys.get((nidx, li)) + if gid is not None: + roots.add(self.uf.find(gid)) + return roots + # ----------------------------------------------------------------- # User constraints # ----------------------------------------------------------------- @@ -973,10 +1062,13 @@ def get_stats(self) -> dict[str, Any]: orig_vars += max(n_args, 1) * n_out * n_out n_factor_vars = len(self.y_vars) + n_aux_vars = getattr(self, 
"_num_z_vars", 0) return { "num_graph_nodes": len(self.nodes), "num_unique_factors": len(roots), - "num_factor_ilp_variables": n_factor_vars, + "num_factor_ilp_variables": n_factor_vars + n_aux_vars, + "num_factor_y_variables": n_factor_vars, + "num_factor_z_variables": n_aux_vars, "num_factor_ilp_constraints": len(self.prob.constraints), "mesh_shape": tuple(self.mesh.shape), "estimated_original_ilp_variables": orig_vars, @@ -989,7 +1081,7 @@ def get_log(self, verbose: bool = False) -> str: lines.append(f"Factor ILP status: {pulp.LpStatus[self.prob.status]}") s = self.get_stats() lines.append(f"Unique factors: {s['num_unique_factors']}") - lines.append(f"Factor ILP variables: {s['num_factor_ilp_variables']}") + lines.append(f"Factor ILP variables: {s['num_factor_ilp_variables']} ({s['num_factor_y_variables']} y + {s['num_factor_z_variables']} z)") lines.append(f"Factor ILP constraints: {s['num_factor_ilp_constraints']}") lines.append( f"Est. original ILP vars: {s['estimated_original_ilp_variables']}" From 5e8f680da56b39c04052813ea8657aeaceef556c Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 17 Feb 2026 08:55:02 +0000 Subject: [PATCH 07/22] Fix Partial propagation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Here's what was changed: Fix 1: Partial propagation validation (_build_factor_graph + _merge_reduction_factors) Problem: For add(o, x) in the transformer, o is Partial (from mm's K factor) but x is Shard(0). The old code merged reduction factors one operand at a time — it saw o has Partial and immediately merged, without checking that x ALSO has Partial. The strategy add(Partial, Shard) is invalid in DTensor. Fix: Spatial and reduction factor merging are now separate passes. The new _merge_reduction_factors method (line 470) collects ALL operands that need Partial (operand_dims[arg_pos] = None) and only merges when every one of them can provide a reduction factor from its producer. 
For unary ops (view, permute, alias) with one operand, nothing changes. For binary ops (add, mul), Partial only propagates when both operands produce Partial. Fix 2: Compute cost via estimate_strategy_runtime_cost (line 769) Problem: The old _compute_cost was a manual FLOP estimator covering only mm/addmm/bmm/SDPA. It returned 0 for pointwise ops (add, relu, etc.), giving the optimizer no incentive to shard them. It also returned FLOPs while communication costs were in microseconds — the units were inconsistent, making the compute-vs-communication tradeoff arbitrary. Fix: Now calls estimate_strategy_runtime_cost(node, None) from compute_estimation.py, which: - Counts FLOPs via FlopCounterMode (handles all ops, not just matmuls) - Computes memory read/write bytes (captures memory-bound ops like pointwise add/relu) - Returns max(compute_time, memory_time) in microseconds - Returns 0 for view ops (via _has_zero_cost) Results are cached per node. The communication costs in _add_objective are already in microseconds (bytes / BW * 1e6), so now both terms use consistent units. --- autoparallel/optimize_sharding_new.py | 154 +++++++++++++++----------- 1 file changed, 91 insertions(+), 63 deletions(-) diff --git a/autoparallel/optimize_sharding_new.py b/autoparallel/optimize_sharding_new.py index bdee765a..147ba64b 100644 --- a/autoparallel/optimize_sharding_new.py +++ b/autoparallel/optimize_sharding_new.py @@ -63,6 +63,7 @@ from torch.distributed.tensor.placement_types import Partial, Placement, Replicate, Shard from torch.utils._pytree import tree_map_only +from .cost_models.compute_estimation import estimate_strategy_runtime_cost from .shardings.placement_options import get_placement_options from .shardings.propagation_rules import _create_all_options @@ -350,6 +351,7 @@ def __init__( self._collect_factor_metadata() # -- Step 5: build ILP. 
+ self._cost_cache: dict[torch.fx.Node, float] = {} self.prob = pulp.LpProblem("AutoParallel_Factor", pulp.LpMinimize) self.y_vars: dict[tuple[int, int], pulp.LpVariable] = {} self._build_ilp() @@ -426,7 +428,7 @@ def _build_factor_graph(self) -> None: self.factor_keys[(nidx, li)] = gid self.uf.make_set(gid) - # Merge factors across producer → consumer edges. + # Merge spatial factors across producer → consumer edges. for node in self.graph.nodes: if node.op != "call_function": continue @@ -449,7 +451,6 @@ def _build_factor_graph(self) -> None: if arg_pos >= len(c_fac.operand_dims): continue c_dim = c_fac.operand_dims[arg_pos] - if c_dim is not None: # Spatial factor: match by dimension index. for p_li, p_fac in enumerate(producer_rule.factors): @@ -461,20 +462,79 @@ def _build_factor_graph(self) -> None: ck = self.factor_keys.get((cidx, c_li)) if pk is not None and ck is not None: self.uf.union(pk, ck) - elif c_fac.is_reduction: - # Reduction factor pass-through (Partial → Partial). - # Data-preserving ops (view, permute, alias, …) have - # a strategy atom (out=Partial, in=Partial) which - # produces a factor with operand_dims=[None] and - # result_dims=[None]. Merge it with the producer's - # reduction factor so that Partial propagates. - for p_li, p_fac in enumerate(producer_rule.factors): - if p_fac.is_reduction: - pk = self.factor_keys.get((pidx, p_li)) - ck = self.factor_keys.get((cidx, c_li)) - if pk is not None and ck is not None: - self.uf.union(pk, ck) - break # one-to-one merge + + # Merge reduction factors (Partial → Partial pass-through) in a + # separate pass, after spatial merging is complete. + self._merge_reduction_factors() + + def _merge_reduction_factors(self) -> None: + """Merge reduction factors across edges only when ALL operands agree. 
+ + For data-preserving ops (view, permute, alias, …), a Partial can + propagate through the op because the strategy has an atom + ``(out=Partial, in=Partial)`` which produces a reduction factor with + ``operand_dims=[None]`` and ``result_dims=[None]``. + + For multi-operand ops like add/mul, the reduction factor has + ``operand_dims=[None, None]`` — ALL operands must be Partial. + We must only merge when every operand that maps to ``None`` in the + factor can provide a Partial from its producer. Otherwise, the + resulting placement (e.g. ``add(Partial, Shard)``) is invalid. + """ + for node in self.graph.nodes: + if node.op != "call_function": + continue + consumer_rule = self.factor_rules.get(node) + if consumer_rule is None: + continue + cidx = self.node_map[node] + + for c_li, c_fac in enumerate(consumer_rule.factors): + if not c_fac.is_reduction: + continue + ck = self.factor_keys.get((cidx, c_li)) + if ck is None: + continue + + # Collect producer reduction keys for each operand that + # needs Partial, and validate that ALL can provide it. + merge_pairs: list[tuple[int, int]] = [] + all_valid = True + + for arg_pos, c_od in enumerate(c_fac.operand_dims): + if c_od is not None: + continue # Spatial dim on this operand, not Partial + + # This operand must be Partial for the factor to propagate. + if arg_pos >= len(node.args): + all_valid = False + break + arg = node.args[arg_pos] + if not isinstance(arg, torch.fx.Node): + all_valid = False + break + producer_rule = self.factor_rules.get(arg) + if producer_rule is None: + all_valid = False + break + pidx = self.node_map[arg] + + # Find a reduction factor at the producer. 
+ found = False + for p_li, p_fac in enumerate(producer_rule.factors): + if p_fac.is_reduction: + pk = self.factor_keys.get((pidx, p_li)) + if pk is not None: + merge_pairs.append((pk, ck)) + found = True + break + if not found: + all_valid = False + break + + if all_valid and merge_pairs: + for pk, ck_val in merge_pairs: + self.uf.union(pk, ck_val) # ----------------------------------------------------------------- # Step 4 — metadata collection @@ -706,54 +766,22 @@ def _output_bytes(node: torch.fx.Node) -> float: return float(val.numel() * val.element_size()) return 0.0 - @staticmethod - def _compute_cost(node: torch.fx.Node) -> float: - """Estimate compute cost in FLOPs for compute-intensive ops.""" - if node.op != "call_function": - return 0.0 - target = node.target - - def _shape(n: Any) -> tuple[int, ...] | None: - if isinstance(n, torch.fx.Node): - t = _get_primary_tensor(n.meta.get("val")) - if t is not None: - return tuple(t.shape) - return None - - # mm(A[M,K], B[K,N]) → 2·M·K·N - if target == torch.ops.aten.mm.default: - a, b = _shape(node.args[0]), _shape(node.args[1]) - if a and b and len(a) == 2 and len(b) == 2: - M, K = a - _, N = b - return 2.0 * M * K * N - - # addmm(bias, A[M,K], B[K,N]) → 2·M·K·N - if target == torch.ops.aten.addmm.default: - a, b = _shape(node.args[1]), _shape(node.args[2]) - if a and b and len(a) == 2 and len(b) == 2: - M, K = a - _, N = b - return 2.0 * M * K * N - - # bmm(A[B,M,K], B[B,K,N]) → 2·B·M·K·N - if target == torch.ops.aten.bmm.default: - a, b = _shape(node.args[0]), _shape(node.args[1]) - if a and b and len(a) == 3 and len(b) == 3: - B, M, K = a - _, _, N = b - return 2.0 * B * M * K * N - - # SDPA — dominated by two bmm-like ops internally: - # scores = Q·K^T → 2·B·H·S·S·D - # output = scores·V → 2·B·H·S·S·D - if "scaled_dot_product" in str(target): - q = _shape(node.args[0]) - if q and len(q) == 4: - B, H, S, D = q - return 2.0 * 2 * B * H * S * S * D # two bmm-equivalent + def _compute_cost(self, node: 
torch.fx.Node) -> float: + """Estimate unsharded compute cost for a node in microseconds. - return 0.0 + Uses ``estimate_strategy_runtime_cost`` from ``compute_estimation.py`` + which accounts for both FLOP-bound ops (matmul, bmm, SDPA) and + memory-bound ops (pointwise add, relu, etc.). View ops return 0. + Results are cached per node. + """ + if node in self._cost_cache: + return self._cost_cache[node] + try: + cost = estimate_strategy_runtime_cost(node, None) + except Exception: + cost = 0.0 + self._cost_cache[node] = cost + return cost def _compute_redistribution_bytes( self, From 76020c95ba4a0b8bde47e3c6b4ade5e4dfc86ecc Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 17 Feb 2026 12:40:47 +0000 Subject: [PATCH 08/22] Fix for getitem and tuple outputs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Let me summarize all the changes made: Changes summary 1. import operator added (line 48) — needed to check node.target is operator.getitem in the factor graph builder. 2. _infer_factor_size (lines 191-204) — now handles multi-output nodes correctly. For tuple/list outputs, it indexes result_dims[ri] against val[ri] to find the correct output tensor's shape, rather than always using the first tensor. 3. extract_factors_from_strategy (lines 239-299) — the core fix. Previously it extracted only the first output's placement (out_specs[0].placements[0]) and built a single-element result_dims. Now it: - Determines num_results from the output specs structure - Extracts placements for all output tensors into out_ps - Builds result_dims with one entry per output tensor (e.g., [1, 1, 1] for SDPA backward with 3 outputs that all shard on dim 1) - Sets num_results correctly on the returned FactorRule 4. _build_factor_graph (lines 473-502) — for getitem nodes, determines the getitem index from node.args[1] and uses p_fac.result_dims[result_idx] instead of always result_dims[0]. This correctly connects e.g. 
getitem_4 = sdpa_bwd[0] to the first output's factors of sdpa_bwd, and getitem_5 = sdpa_bwd[1] to the second output. 5. _add_tensor_exclusion (lines 650-672) — now iterates over each result tensor separately (for ri in range(rule.num_results)), generating per-result exclusion constraints. Different output tensors of the same op can independently have different shardings. 6. _compute_redistribution_bytes (line 863) — changed from fac.result_dims[0] is not None to any(d is not None for d in fac.result_dims). 7. _get_spatial_roots_at_node (line 905) — same pattern, checks any result dim. 8. add_node_constraint (lines 929, 941) — Shard matching checks any(d == p.dim for d in fac.result_dims), Replicate matching checks any(d is not None for d in fac.result_dims). 9. get_solution (lines 1063-1091) — multi-output reconstruction now builds per-output-tensor placements using fac.result_dims[ri] for each output index ri, rather than copying a single set of placements and filtering by shape. 10. Removed the commented-out IPython embed debug line. How it works with your example For _scaled_dot_product_efficient_attention_backward with 3 gradient outputs (grad_q, grad_k, grad_v), the factor rule now produces: - Factor "nheads": result_dims=[1, 1, 1] (dim 1 of all 3 outputs) - Factor "batch": result_dims=[0, 0, 0] (dim 0 of all 3 outputs) - num_results=3 When getitem_4 = sdpa_bwd[0] is processed in the factor graph, it uses result_idx=0, so it matches against result_dims[0] of each producer factor — correctly connecting to the first output's dimensions. Similarly, getitem_5 = sdpa_bwd[1] uses result_idx=1, connecting to the second output. 
--- autoparallel/optimize_sharding_new.py | 161 +++++++++++++++++--------- 1 file changed, 109 insertions(+), 52 deletions(-) diff --git a/autoparallel/optimize_sharding_new.py b/autoparallel/optimize_sharding_new.py index 147ba64b..170e6849 100644 --- a/autoparallel/optimize_sharding_new.py +++ b/autoparallel/optimize_sharding_new.py @@ -45,6 +45,7 @@ from __future__ import annotations +import operator from collections import defaultdict from dataclasses import dataclass, field from typing import Any, Optional @@ -188,11 +189,19 @@ def _infer_factor_size( ) -> int: """Infer the size of a factor from the node's tensor metadata.""" # Try the node's own output first. - out = _get_primary_tensor(node.meta.get("val")) - if out is not None: - for d in result_dims: - if d is not None and d < len(out.shape): - return out.shape[d] + val = node.meta.get("val") + if val is not None: + if isinstance(val, torch.Tensor): + for d in result_dims: + if d is not None and d < len(val.shape): + return val.shape[d] + elif isinstance(val, (tuple, list)): + # Multi-output: result_dims[ri] corresponds to output tensor ri. + for ri, d in enumerate(result_dims): + if d is not None and ri < len(val): + v = val[ri] + if isinstance(v, torch.Tensor) and d < len(v.shape): + return v.shape[d] # Fall back to operand shapes. for arg_idx, d in enumerate(operand_dims): if d is None or arg_idx >= len(node.args): @@ -233,32 +242,48 @@ def extract_factors_from_strategy( first_spec = op_strategy.strategies[0] num_operands = len(first_spec.input_specs) if first_spec.input_specs else 0 + # Determine number of result tensors. + first_out = first_spec.output_specs + if isinstance(first_out, DTensorSpec): + num_results = 1 + elif isinstance(first_out, (list, tuple)): + num_results = len(first_out) + else: + num_results = 1 + # Collect unique 1-D "atoms" by looking at mesh dim 0. 
seen: dict[str, tuple] = {} for spec in op_strategy.strategies: out_specs = spec.output_specs - out_p = ( - out_specs.placements[0] - if isinstance(out_specs, DTensorSpec) - else out_specs[0].placements[0] - ) + if isinstance(out_specs, DTensorSpec): + out_ps = (out_specs.placements[0],) + elif isinstance(out_specs, (list, tuple)): + out_ps = tuple( + s.placements[0] if isinstance(s, DTensorSpec) else None + for s in out_specs + ) + else: + out_ps = (Replicate(),) in_ps = tuple( s.placements[0] for s in (spec.input_specs or []) if isinstance(s, DTensorSpec) ) - all_ps = (out_p,) + in_ps - if all(isinstance(p, Replicate) for p in all_ps): + all_ps = out_ps + in_ps + if all(p is None or isinstance(p, Replicate) for p in all_ps): continue # skip the all-replicate atom key = str(all_ps) if key not in seen: - seen[key] = (out_p, in_ps) + seen[key] = (out_ps, in_ps) # Each atom → one Factor. factors: list[Factor] = [] - for factor_id, (out_p, in_ps) in enumerate(seen.values()): - is_reduction = _is_partial(out_p) - result_dims = [out_p.dim if isinstance(out_p, Shard) else None] + for factor_id, (out_ps, in_ps) in enumerate(seen.values()): + is_reduction = any(_is_partial(p) for p in out_ps if p is not None) + result_dims = [ + p.dim if p is not None and isinstance(p, Shard) else None + for p in out_ps + ] operand_dims = [p.dim if isinstance(p, Shard) else None for p in in_ps] size = _infer_factor_size(node, operand_dims, result_dims) factors.append( @@ -271,7 +296,7 @@ def extract_factors_from_strategy( ) ) - return FactorRule(factors=factors, num_operands=num_operands, num_results=1) + return FactorRule(factors=factors, num_operands=num_operands, num_results=num_results) def _placeholder_factor_rule(node: torch.fx.Node) -> FactorRule: @@ -445,6 +470,17 @@ def _build_factor_graph(self) -> None: continue pidx = self.node_map[arg] + # For getitem nodes, the consumer's operand 0 corresponds + # to a specific result index of the multi-output producer. 
+ result_idx = 0 + if ( + node.target is operator.getitem + and arg_pos == 0 + and len(node.args) > 1 + and isinstance(node.args[1], int) + ): + result_idx = node.args[1] + # Match: consumer operand dim == producer result dim on the # same positional dimension → same logical factor. for c_li, c_fac in enumerate(consumer_rule.factors): @@ -456,7 +492,9 @@ def _build_factor_graph(self) -> None: for p_li, p_fac in enumerate(producer_rule.factors): if not p_fac.result_dims: continue - p_dim = p_fac.result_dims[0] + if result_idx >= len(p_fac.result_dims): + continue + p_dim = p_fac.result_dims[result_idx] if p_dim is not None and p_dim == c_dim: pk = self.factor_keys.get((pidx, p_li)) ck = self.factor_keys.get((cidx, c_li)) @@ -609,24 +647,29 @@ def _add_tensor_exclusion(self) -> None: continue nidx = self.node_map[node] - # — result tensor — - for m in range(self.mesh.ndim): - vs = [] - seen_roots: set[int] = set() - for li, fac in enumerate(rule.factors): - # Include both spatial and reduction factors: a - # tensor can only be Shard(d) OR Partial on each - # mesh dim, never both simultaneously. - if (fac.result_dims and fac.result_dims[0] is not None) or fac.is_reduction: - gid = self.factor_keys.get((nidx, li)) - if gid is not None: - root = self.uf.find(gid) - if root not in seen_roots: - seen_roots.add(root) - vs.append(self.y_vars[(root, m)]) - if len(vs) > 1: - self.prob += pulp.lpSum(vs) <= 1, f"tex_r_{cid}" - cid += 1 + # — result tensors (one exclusion set per result) — + for ri in range(rule.num_results): + for m in range(self.mesh.ndim): + vs = [] + seen_roots: set[int] = set() + for li, fac in enumerate(rule.factors): + # Include both spatial and reduction factors: a + # tensor can only be Shard(d) OR Partial on each + # mesh dim, never both simultaneously. 
+ has_spatial = ( + ri < len(fac.result_dims) + and fac.result_dims[ri] is not None + ) + if has_spatial or fac.is_reduction: + gid = self.factor_keys.get((nidx, li)) + if gid is not None: + root = self.uf.find(gid) + if root not in seen_roots: + seen_roots.add(root) + vs.append(self.y_vars[(root, m)]) + if len(vs) > 1: + self.prob += pulp.lpSum(vs) <= 1, f"tex_r_{cid}" + cid += 1 # — operand tensors — for oi in range(rule.num_operands): @@ -817,7 +860,7 @@ def _compute_redistribution_bytes( if fac.is_reduction: # Partial exits here → reduce-scatter rs_bytes[root] += self._output_bytes(node) - elif fac.result_dims and fac.result_dims[0] is not None: + elif any(d is not None for d in fac.result_dims): # Shard exits here → all-gather ag_bytes[root] += self._output_bytes(node) @@ -859,7 +902,7 @@ def _get_spatial_roots_at_node(self, nidx: int) -> set[int]: return set() roots: set[int] = set() for li, fac in enumerate(rule.factors): - if not fac.is_reduction and fac.result_dims and fac.result_dims[0] is not None: + if not fac.is_reduction and any(d is not None for d in fac.result_dims): gid = self.factor_keys.get((nidx, li)) if gid is not None: roots.add(self.uf.find(gid)) @@ -883,7 +926,7 @@ def add_node_constraint( for m, p in enumerate(placement): if isinstance(p, Shard): for li, fac in enumerate(rule.factors): - if fac.result_dims and fac.result_dims[0] == p.dim: + if any(d == p.dim for d in fac.result_dims): gid = self.factor_keys.get((nidx, li)) if gid is not None: root = self.uf.find(gid) @@ -895,7 +938,7 @@ def add_node_constraint( elif isinstance(p, Replicate): seen_roots: set[int] = set() for li, fac in enumerate(rule.factors): - if fac.result_dims and fac.result_dims[0] is not None: + if any(d is not None for d in fac.result_dims): gid = self.factor_keys.get((nidx, li)) if gid is not None: root = self.uf.find(gid) @@ -1006,8 +1049,10 @@ def get_solution(self, verbose: bool = False) -> dict[torch.fx.Node, DTensorSpec if fac.is_reduction: # Reduction factor → 
output is Partial on this mesh dim. placements[m] = Partial() - elif fac.result_dims and fac.result_dims[0] is not None: - placements[m] = Shard(fac.result_dims[0]) + else: + # Use result_dims[0] for single-output nodes. + if fac.result_dims and fac.result_dims[0] is not None: + placements[m] = Shard(fac.result_dims[0]) val = node.meta.get("val") if val is not None and isinstance(val, torch.Tensor): @@ -1016,19 +1061,31 @@ def get_solution(self, verbose: bool = False) -> dict[torch.fx.Node, DTensorSpec self.mesh, tuple(placements), tensor_meta=tensor_meta ) elif val is not None and isinstance(val, (tuple, list)): - # Multi-output op (e.g. SDPA). The factors describe the - # primary (first) output. For each output tensor, keep only - # Shard placements whose dim is in range for that tensor. + # Multi-output op (e.g. SDPA). Build per-output placements + # using the corresponding result_dims index for each output. specs = [] - for v in val: + for ri, v in enumerate(val): if isinstance(v, torch.Tensor): - plc = tuple( - p if not isinstance(p, Shard) or p.dim < len(v.shape) - else Replicate() - for p in placements - ) + plc_list = [Replicate()] * self.mesh.ndim + for li, fac in enumerate(rule.factors): + gid = self.factor_keys.get((nidx, li)) + if gid is None: + continue + root = self.uf.find(gid) + m_assigned = assignment.get(root) + if m_assigned is None: + continue + if fac.is_reduction: + plc_list[m_assigned] = Partial() + elif ( + ri < len(fac.result_dims) + and fac.result_dims[ri] is not None + ): + plc_list[m_assigned] = Shard(fac.result_dims[ri]) tm = TensorMeta(v.shape, v.stride(), v.dtype) - specs.append(DTensorSpec(self.mesh, plc, tensor_meta=tm)) + specs.append( + DTensorSpec(self.mesh, tuple(plc_list), tensor_meta=tm) + ) else: specs.append(None) result[node] = tuple(specs) From 93bba5ca5718a02759e25096027bd5e24d17417a Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 17 Feb 2026 13:47:38 +0000 Subject: [PATCH 09/22] Changes for OpSpec return type 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Imports (lines 63, 65): - Added OpSpec to the _op_schema import - Added tree_flatten to the _pytree import get_solution (lines 1012-1103): - Return type changed from dict[torch.fx.Node, DTensorSpec] to dict[torch.fx.Node, OpSpec] - Output spec construction is unchanged (same logic, now stored in output_specs local var) - After building output_specs, calls the new _build_input_specs helper - Returns OpSpec(output_specs=..., input_specs=...) per node - For nodes with no operands (placeholders, get_attr), input_specs is None _build_input_specs (lines 1105-1176) — new method: - For each operand oi in range(rule.num_operands): - Starts with [Replicate()] * mesh.ndim - For each factor assigned to mesh dim m: - operand_dims[oi] not None → Shard(dim) (spatial sharding) - operand_dims[oi] is None + is_reduction → Partial (reduction pass-through) - Otherwise → stays Replicate - Gets TensorMeta from the corresponding input node's meta["val"] - For getitem nodes consuming a multi-output producer, uses the getitem index to pick the correct tensor from the tuple for TensorMeta This matches the contract that apply_sharding.py expects: sharding_placement[node].output_specs for the node's output, and sharding_placement[node].input_specs[c] for the c-th tensor input (in tree_flatten(node.args) order, filtering for Nodes). Note that example_autoparallel_factor.py will need its comparison code updated — it currently accesses the factor solution directly as DTensorSpec/tuples, but now it should access .output_specs on the returned OpSpec, consistent with how it already handles the original optimizer's solution. 
--- autoparallel/optimize_sharding_new.py | 131 +++++++++++++++++++++----- 1 file changed, 107 insertions(+), 24 deletions(-) diff --git a/autoparallel/optimize_sharding_new.py b/autoparallel/optimize_sharding_new.py index 170e6849..50fcbca0 100644 --- a/autoparallel/optimize_sharding_new.py +++ b/autoparallel/optimize_sharding_new.py @@ -60,9 +60,9 @@ ) from torch.distributed._tensor.placement_types import TensorMeta from torch.distributed.tensor._dtensor_spec import DTensorSpec -from torch.distributed.tensor._op_schema import OpStrategy +from torch.distributed.tensor._op_schema import OpSpec, OpStrategy from torch.distributed.tensor.placement_types import Partial, Placement, Replicate, Shard -from torch.utils._pytree import tree_map_only +from torch.utils._pytree import tree_flatten, tree_map_only from .cost_models.compute_estimation import estimate_strategy_runtime_cost from .shardings.placement_options import get_placement_options @@ -1009,8 +1009,8 @@ def add_output_constraints( # Solve # ----------------------------------------------------------------- - def get_solution(self, verbose: bool = False) -> dict[torch.fx.Node, DTensorSpec]: - """Solve the factor ILP and reconstruct per-node DTensorSpecs.""" + def get_solution(self, verbose: bool = False) -> dict[torch.fx.Node, OpSpec]: + """Solve the factor ILP and reconstruct per-node OpSpecs.""" solver = pulp.PULP_CBC_CMD(msg=verbose) self.prob.solve(solver) @@ -1028,7 +1028,7 @@ def get_solution(self, verbose: bool = False) -> dict[torch.fx.Node, DTensorSpec assignment[root] = m # Reconstruct per-node placements. 
- result: dict[torch.fx.Node, DTensorSpec] = {} + result: dict[torch.fx.Node, OpSpec] = {} for node in self.graph.nodes: if node.op == "output": continue @@ -1036,28 +1036,28 @@ def get_solution(self, verbose: bool = False) -> dict[torch.fx.Node, DTensorSpec if rule is None: continue nidx = self.node_map[node] - placements: list[Placement] = [Replicate()] * self.mesh.ndim - - for li, fac in enumerate(rule.factors): - gid = self.factor_keys.get((nidx, li)) - if gid is None: - continue - root = self.uf.find(gid) - m = assignment.get(root) - if m is None: - continue - if fac.is_reduction: - # Reduction factor → output is Partial on this mesh dim. - placements[m] = Partial() - else: - # Use result_dims[0] for single-output nodes. - if fac.result_dims and fac.result_dims[0] is not None: - placements[m] = Shard(fac.result_dims[0]) + # --- Build output_specs --- + output_specs = None val = node.meta.get("val") + if val is not None and isinstance(val, torch.Tensor): + placements: list[Placement] = [Replicate()] * self.mesh.ndim + for li, fac in enumerate(rule.factors): + gid = self.factor_keys.get((nidx, li)) + if gid is None: + continue + root = self.uf.find(gid) + m = assignment.get(root) + if m is None: + continue + if fac.is_reduction: + placements[m] = Partial() + else: + if fac.result_dims and fac.result_dims[0] is not None: + placements[m] = Shard(fac.result_dims[0]) tensor_meta = TensorMeta(val.shape, val.stride(), val.dtype) - result[node] = DTensorSpec( + output_specs = DTensorSpec( self.mesh, tuple(placements), tensor_meta=tensor_meta ) elif val is not None and isinstance(val, (tuple, list)): @@ -1088,10 +1088,93 @@ def get_solution(self, verbose: bool = False) -> dict[torch.fx.Node, DTensorSpec ) else: specs.append(None) - result[node] = tuple(specs) + output_specs = tuple(specs) + + if output_specs is None: + continue + + # --- Build input_specs --- + input_specs = self._build_input_specs(node, rule, nidx, assignment) + + result[node] = OpSpec( + 
output_specs=output_specs, input_specs=input_specs or None + ) return result + def _build_input_specs( + self, + node: torch.fx.Node, + rule: FactorRule, + nidx: int, + assignment: dict[int, int], + ) -> list[DTensorSpec]: + """Reconstruct input DTensorSpecs from factor assignments. + + For each operand, the placement on each mesh dim is derived from the + factor assigned to that mesh dim: + + - ``operand_dims[oi]`` is not None → ``Shard(dim)`` + - ``operand_dims[oi]`` is None and ``is_reduction`` → ``Partial`` + - otherwise → ``Replicate`` + """ + if rule.num_operands == 0: + return [] + + # Tensor input nodes in tree_flatten order (matches operand indexing). + flat_args, _ = tree_flatten(node.args) + tensor_args = [a for a in flat_args if isinstance(a, torch.fx.Node)] + + input_specs: list[DTensorSpec] = [] + for oi in range(rule.num_operands): + inp_placements: list[Placement] = [Replicate()] * self.mesh.ndim + for li, fac in enumerate(rule.factors): + gid = self.factor_keys.get((nidx, li)) + if gid is None: + continue + root = self.uf.find(gid) + m_assigned = assignment.get(root) + if m_assigned is None: + continue + if oi < len(fac.operand_dims): + od = fac.operand_dims[oi] + if od is not None: + inp_placements[m_assigned] = Shard(od) + elif fac.is_reduction: + inp_placements[m_assigned] = Partial() + + # Get TensorMeta from the corresponding input node. + inp_tm = None + if oi < len(tensor_args): + arg_val = tensor_args[oi].meta.get("val") + if isinstance(arg_val, torch.Tensor): + inp_tm = TensorMeta( + arg_val.shape, arg_val.stride(), arg_val.dtype + ) + elif isinstance(arg_val, (tuple, list)): + # Multi-output producer (e.g. getitem consuming SDPA). + # Use the getitem index to find the correct tensor. 
+ if ( + node.target is operator.getitem + and oi == 0 + and len(node.args) > 1 + and isinstance(node.args[1], int) + ): + idx = node.args[1] + if idx < len(arg_val) and isinstance( + arg_val[idx], torch.Tensor + ): + v = arg_val[idx] + inp_tm = TensorMeta(v.shape, v.stride(), v.dtype) + + input_specs.append( + DTensorSpec( + self.mesh, tuple(inp_placements), tensor_meta=inp_tm + ) + ) + + return input_specs + # ----------------------------------------------------------------- # Diagnostics # ----------------------------------------------------------------- From 45eed33fdc3a7173416759ba7d0fc5db6cfc1fc8 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 17 Feb 2026 13:52:03 +0000 Subject: [PATCH 10/22] Placeholder and get_attr nodes now get input_specs=[output_specs] (mirroring their output spec) instead of None. --- autoparallel/optimize_sharding_new.py | 13 +++++++++++-- examples/example_autoparallel_factor.py | 13 ++++++++----- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/autoparallel/optimize_sharding_new.py b/autoparallel/optimize_sharding_new.py index 50fcbca0..3e8c0ccb 100644 --- a/autoparallel/optimize_sharding_new.py +++ b/autoparallel/optimize_sharding_new.py @@ -1094,10 +1094,19 @@ def get_solution(self, verbose: bool = False) -> dict[torch.fx.Node, OpSpec]: continue # --- Build input_specs --- - input_specs = self._build_input_specs(node, rule, nidx, assignment) + if node.op in ("placeholder", "get_attr"): + # Convention: placeholders use output_specs as input_specs. 
+ if isinstance(output_specs, DTensorSpec): + input_specs = [output_specs] + else: + input_specs = None + else: + input_specs = self._build_input_specs( + node, rule, nidx, assignment + ) or None result[node] = OpSpec( - output_specs=output_specs, input_specs=input_specs or None + output_specs=output_specs, input_specs=input_specs ) return result diff --git a/examples/example_autoparallel_factor.py b/examples/example_autoparallel_factor.py index 310dcd77..7a47c4b9 100644 --- a/examples/example_autoparallel_factor.py +++ b/examples/example_autoparallel_factor.py @@ -155,6 +155,8 @@ def input_fn(): print(f" Solve time: {t_factor:.2f}s") print(factor_opt.get_log(verbose=True)) + # parallel_mod = autop.apply_placement(factor_solution) + # ------------------------------------------------------------------ # 3. Comparison # ------------------------------------------------------------------ @@ -198,11 +200,12 @@ def input_fn(): orig_plc = "?" else: orig_plc = "?" - if factor_spec is not None: - if isinstance(factor_spec, DTensorSpec): - factor_plc = tuple(factor_spec.placements) - elif isinstance(factor_spec, (list, tuple)) and factor_spec: - factor_plc = tuple(factor_spec[0].placements) + if factor_spec is not None and hasattr(factor_spec, "output_specs"): + os = factor_spec.output_specs + if isinstance(os, DTensorSpec): + factor_plc = tuple(os.placements) + elif isinstance(os, (list, tuple)) and os: + factor_plc = tuple(os[0].placements) else: factor_plc = "?" else: From c2331fc5d60a76272861d48ea20c475f1e068fda Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 17 Feb 2026 14:11:53 +0000 Subject: [PATCH 11/22] Add add_grad_param_constraints MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. Iterates over all (param, grad) pairs from get_param_and_grad_nodes 2. For each spatial factor of the param (matched by result_dims[0]), finds the corresponding factor in the grad with the same dimension index 3. 
If they have different roots (not already unified by the factor graph), adds equality constraints y[param_root, m] == y[grad_root, m] for all mesh dims — ensuring the same sharding decision applies to both --- autoparallel/optimize_sharding_new.py | 47 +++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/autoparallel/optimize_sharding_new.py b/autoparallel/optimize_sharding_new.py index 3e8c0ccb..f63fcb32 100644 --- a/autoparallel/optimize_sharding_new.py +++ b/autoparallel/optimize_sharding_new.py @@ -55,6 +55,7 @@ import torch.fx from torch._functorch._aot_autograd.descriptors import PlainAOTInput, PlainAOTOutput from torch._functorch._aot_autograd.fx_utils import ( + get_param_and_grad_nodes, get_plain_input_and_grad_nodes, get_plain_output_and_tangent_nodes, ) @@ -949,6 +950,52 @@ def add_node_constraint( f"rep_{nidx}_r{root}_m{m}", ) + def add_grad_param_constraints(self) -> None: + """Ensure parameters and their gradients have matching placements. + + For each (param, grad) pair, constrains every spatial factor to have + the same mesh-dim assignment: ``y[param_root, m] == y[grad_root, m]`` + for all mesh dims ``m``. + """ + for param, grad in get_param_and_grad_nodes(self.graph).values(): + if grad is None: + continue + param_rule = self.factor_rules.get(param) + grad_rule = self.factor_rules.get(grad) + if param_rule is None or grad_rule is None: + continue + pidx = self.node_map[param] + gidx = self.node_map[grad] + + for p_li, p_fac in enumerate(param_rule.factors): + if not p_fac.result_dims or p_fac.result_dims[0] is None: + continue + p_dim = p_fac.result_dims[0] + pk = self.factor_keys.get((pidx, p_li)) + if pk is None: + continue + p_root = self.uf.find(pk) + + # Find the matching factor in the grad (same result dim). 
+ for g_li, g_fac in enumerate(grad_rule.factors): + if not g_fac.result_dims or g_fac.result_dims[0] != p_dim: + continue + gk = self.factor_keys.get((gidx, g_li)) + if gk is None: + continue + g_root = self.uf.find(gk) + + if p_root != g_root: + for m in range(self.mesh.ndim): + pv = self.y_vars.get((p_root, m)) + gv = self.y_vars.get((g_root, m)) + if pv is not None and gv is not None: + self.prob += ( + pv == gv, + f"grad_param_{pidx}_{gidx}_d{p_dim}_m{m}", + ) + break + def add_input_constraints( self, input_placements: list[tuple[Placement, ...] | None] | None = None ) -> None: From 46121177d02225931e3a8f9c54e6d50a3c8910c4 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 17 Feb 2026 14:51:45 +0000 Subject: [PATCH 12/22] Add parameter memory constraint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Problem: A parameter's memory ratio is 1 / product(mesh_shape[m] for sharded mesh dims m) — a product of binary decisions, which is nonlinear in the y variables. Linearization: For each parameter, introduce a binary indicator variable b_S for each possible subset S of mesh dims (2^k variables where k = mesh.ndim). The constraints enforce: 1. Exactly one subset active (sum_S b_S = 1): the parameter is sharded on exactly one combination of mesh dims 2. Consistency with factor assignments (sum_{S: m∈S} b_S = s_m): subset S is active iff the parameter is sharded on exactly those mesh dims, where s_m = sum_fi y[root_fi, m] (already 0 or 1 due to tensor exclusion) 3. Memory contribution: each b_S contributes a precomputed ratio 1 / prod(mesh_shape[m] for m in S) — fully linear The final constraint matches the original: low * N <= sum_p ratio_p <= high * N, where N is the number of eligible parameters (those large enough to be fully sharded). For a 2D mesh, this adds 4 binary variables and 3 constraints per parameter. For a 3D mesh, 8 variables and 4 constraints. 
With ~10-20 parameters in a typical model, the overhead is negligible. --- autoparallel/optimize_sharding_new.py | 111 ++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) diff --git a/autoparallel/optimize_sharding_new.py b/autoparallel/optimize_sharding_new.py index f63fcb32..38390b07 100644 --- a/autoparallel/optimize_sharding_new.py +++ b/autoparallel/optimize_sharding_new.py @@ -56,6 +56,7 @@ from torch._functorch._aot_autograd.descriptors import PlainAOTInput, PlainAOTOutput from torch._functorch._aot_autograd.fx_utils import ( get_param_and_grad_nodes, + get_param_nodes, get_plain_input_and_grad_nodes, get_plain_output_and_tangent_nodes, ) @@ -996,6 +997,116 @@ def add_grad_param_constraints(self) -> None: ) break + def add_parameter_memory_constraint( + self, memory_factor_low: float, memory_factor_high: float + ) -> None: + """Constrain total parameter memory to stay within specified bounds. + + Matches the semantics of + :meth:`ShardingOptimizer.add_parameter_memory_constraint`: the + constraint is on the sum of per-parameter memory ratios (1.0 = + replicated, 1/world_size = fully sharded), bounded by + ``[low * N, high * N]`` where *N* is the number of eligible + parameters. + + Because a parameter's memory ratio is a *product* over mesh dims + (nonlinear in the ``y`` variables), we linearize by introducing + auxiliary indicator variables ``b_S`` for each possible subset *S* + of mesh dims on which the parameter could be sharded. For typical + mesh dimensions (k = 2–4) this adds O(2^k) variables per parameter + — negligible in practice. + """ + import math + from itertools import combinations + + param_nodes: list[torch.fx.Node] = get_param_nodes(self.graph) + world_size: int = math.prod(self.mesh.shape) + k = self.mesh.ndim + + # Precompute all 2^k subsets of mesh dims. 
+ all_subsets: list[frozenset[int]] = [] + for r in range(k + 1): + for subset in combinations(range(k), r): + all_subsets.append(frozenset(subset)) + + memory_terms: list[Any] = [] + num_params_to_consider: int = 0 + cid = 0 + + for node in param_nodes: + can_be_fully_sharded: bool = node.meta["val"].numel() >= world_size + num_params_to_consider += int(can_be_fully_sharded) + if not can_be_fully_sharded: + continue + + nidx = self.node_map[node] + rule = self.factor_rules.get(node) + if rule is None: + continue + + # s_m = sum_fi y[root_fi, m]: is param sharded on mesh dim m? + # (0 or 1 due to tensor-exclusion constraints.) + s_exprs: dict[int, list] = {} + for m in range(k): + terms: list = [] + seen_roots: set[int] = set() + for li, fac in enumerate(rule.factors): + if not fac.result_dims or fac.result_dims[0] is None: + continue + gid = self.factor_keys.get((nidx, li)) + if gid is None: + continue + root = self.uf.find(gid) + if root in seen_roots: + continue + seen_roots.add(root) + var = self.y_vars.get((root, m)) + if var is not None: + terms.append(var) + s_exprs[m] = terms + + # Indicator variable b_S for each subset S ⊆ {0, …, k-1}. + b_vars: dict[frozenset[int], pulp.LpVariable] = {} + for S in all_subsets: + tag = "".join(str(m) for m in sorted(S)) if S else "R" + b_vars[S] = pulp.LpVariable( + f"mem_{nidx}_{tag}", cat="Binary" + ) + + # Exactly one subset is active per parameter. + self.prob += ( + pulp.lpSum(b_vars.values()) == 1, + f"mem_one_{cid}", + ) + cid += 1 + + # Link b to the y variables: for each mesh dim m, + # sum_{S : m ∈ S} b_S = s_m + for m in range(k): + lhs = pulp.lpSum(b_vars[S] for S in all_subsets if m in S) + rhs = pulp.lpSum(s_exprs[m]) + self.prob += (lhs == rhs, f"mem_link_{cid}") + cid += 1 + + # Memory ratio contribution: b_S * (1 / prod_{m ∈ S} M_m). 
+ for S in all_subsets: + shard_div = math.prod(self.mesh.shape[m] for m in S) if S else 1 + memory_terms.append(b_vars[S] * (1.0 / shard_div)) + + if not memory_terms or num_params_to_consider == 0: + return + + target_low = memory_factor_low * num_params_to_consider + target_high = memory_factor_high * num_params_to_consider + self.prob += ( + pulp.lpSum(memory_terms) >= target_low, + "memory_constraint_low", + ) + self.prob += ( + pulp.lpSum(memory_terms) <= target_high, + "memory_constraint_high", + ) + def add_input_constraints( self, input_placements: list[tuple[Placement, ...] | None] | None = None ) -> None: From 23bdcdde163d6c68a54b92cceec22c5b3ccdda62 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 17 Feb 2026 18:35:01 +0000 Subject: [PATCH 13/22] Remove merging of producer-consumer edges and allow for redistribution MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary of Changes File: autoparallel/optimize_sharding_new.py Added - FactorEdge dataclass — records producer→consumer factor relationships with producer_gid, consumer_gid, node indices, producer node reference, and kind ("spatial"/"reduction") - _add_reduction_propagation_constraints() — for each reduction edge (pk → ck): y[ck, m] <= y[pk, m], preventing Partial from being created from non-Partial input - _add_disabled_reduction_constraints() — forces disabled reduction gids to 0 across all mesh dims Modified - __init__ — replaced self.uf, self.factor_ops, _collect_factor_metadata() with self._factor_edges and self._disabled_reduction_gids - _build_factor_graph() — removed UF make_set/union calls; spatial matches now append FactorEdge objects; calls _record_reduction_edges() instead of _merge_reduction_factors() - _merge_reduction_factors() → _record_reduction_edges() — appends FactorEdge(kind="reduction") instead of UF union; tracks _disabled_reduction_gids when all_valid fails with null operand dims - _build_ilp() — uses all_gids (all 
factor keys) instead of UF roots; adds reduction propagation and disabled reduction constraints - _add_tensor_exclusion() — uses gid directly instead of self.uf.find(gid) - _add_objective() — rewritten with three components: (A) per-node compute benefit, (B) per-edge disagreement costs with linearized z-variables, (C) uncovered reduction exit costs - _get_spatial_roots_at_node() → _get_spatial_gids_at_node() — returns gids directly - add_node_constraint() — uses gid directly instead of UF roots - add_grad_param_constraints() — uses pk/gk directly; pk != gk always true without UF - add_parameter_memory_constraint() — uses gid directly - get_solution() — assignment is now dict[int, list[int]] (gid → list of mesh dims), supporting multi-dim assignments like (Shard(0), Shard(0)) - _build_input_specs() — iterates assignment.get(gid, []) instead of single mesh dim - get_stats() — reports num_edges; uses len(set(self.factor_keys.values())) for unique factors - get_log() — reports edge count; builds reverse lookup for verbose mode instead of using self.factor_ops Removed - _collect_factor_metadata(), _unique_roots(), _compute_redistribution_bytes(), _compute_reduction_exit_info() Kept (unused but harmless) - UnionFind class definition --- autoparallel/optimize_sharding_new.py | 593 +++++++++++++------------- 1 file changed, 307 insertions(+), 286 deletions(-) diff --git a/autoparallel/optimize_sharding_new.py b/autoparallel/optimize_sharding_new.py index 38390b07..87a294f4 100644 --- a/autoparallel/optimize_sharding_new.py +++ b/autoparallel/optimize_sharding_new.py @@ -129,6 +129,37 @@ class FactorRule: num_results: int +@dataclass +class FactorEdge: + """An edge between two factor instances across a producer→consumer dataflow edge. + + Instead of merging factors via union-find, each edge records the relationship + so the ILP can penalize disagreement (redistribution cost) independently. 
+ + Attributes + ---------- + producer_gid : int + Global factor ID at the producer node. + consumer_gid : int + Global factor ID at the consumer node. + producer_nidx : int + Node index of the producer. + consumer_nidx : int + Node index of the consumer. + producer_node : torch.fx.Node + The producer FX node (used for output bytes estimation). + kind : str + ``"spatial"`` or ``"reduction"``. + """ + + producer_gid: int + consumer_gid: int + producer_nidx: int + consumer_nidx: int + producer_node: torch.fx.Node + kind: str # "spatial" or "reduction" + + # --------------------------------------------------------------------------- # Union-Find # --------------------------------------------------------------------------- @@ -364,20 +395,14 @@ def __init__( self.factor_rules: dict[torch.fx.Node, FactorRule] = {} self._extract_all_factor_rules() - # -- Step 3: merge factors across dataflow edges (union-find). - self.uf = UnionFind() + # -- Step 3: build factor edge graph (no merging; edges track relationships). self.factor_keys: dict[tuple[int, int], int] = {} # (node_idx, local) → gid self._next_gid = 0 + self._factor_edges: list[FactorEdge] = [] + self._disabled_reduction_gids: set[int] = set() self._build_factor_graph() - # -- Step 4: collect per-root metadata for cost model. - # root → [(node, factor, local_idx)] - self.factor_ops: dict[ - int, list[tuple[torch.fx.Node, Factor, int]] - ] = defaultdict(list) - self._collect_factor_metadata() - - # -- Step 5: build ILP. + # -- Step 4: build ILP. 
self._cost_cache: dict[torch.fx.Node, float] = {} self.prob = pulp.LpProblem("AutoParallel_Factor", pulp.LpMinimize) self.y_vars: dict[tuple[int, int], pulp.LpVariable] = {} @@ -435,7 +460,7 @@ def _extract_all_factor_rules(self) -> None: ) # ----------------------------------------------------------------- - # Step 3 — factor graph (union-find merging across edges) + # Step 3 — factor graph (edge recording across dataflow edges) # ----------------------------------------------------------------- def _alloc_gid(self) -> int: @@ -453,9 +478,8 @@ def _build_factor_graph(self) -> None: for li, _ in enumerate(rule.factors): gid = self._alloc_gid() self.factor_keys[(nidx, li)] = gid - self.uf.make_set(gid) - # Merge spatial factors across producer → consumer edges. + # Record spatial factor edges across producer → consumer dataflow. for node in self.graph.nodes: if node.op != "call_function": continue @@ -501,14 +525,15 @@ def _build_factor_graph(self) -> None: pk = self.factor_keys.get((pidx, p_li)) ck = self.factor_keys.get((cidx, c_li)) if pk is not None and ck is not None: - self.uf.union(pk, ck) + self._factor_edges.append( + FactorEdge(pk, ck, pidx, cidx, arg, "spatial") + ) - # Merge reduction factors (Partial → Partial pass-through) in a - # separate pass, after spatial merging is complete. - self._merge_reduction_factors() + # Record reduction factor edges in a separate pass. + self._record_reduction_edges() - def _merge_reduction_factors(self) -> None: - """Merge reduction factors across edges only when ALL operands agree. + def _record_reduction_edges(self) -> None: + """Record reduction factor edges across dataflow edges. For data-preserving ops (view, permute, alias, …), a Partial can propagate through the op because the strategy has an atom @@ -517,9 +542,10 @@ def _merge_reduction_factors(self) -> None: For multi-operand ops like add/mul, the reduction factor has ``operand_dims=[None, None]`` — ALL operands must be Partial. 
- We must only merge when every operand that maps to ``None`` in the - factor can provide a Partial from its producer. Otherwise, the - resulting placement (e.g. ``add(Partial, Shard)``) is invalid. + We record edges only when every operand that maps to ``None`` in the + factor can provide a Partial from its producer. Otherwise, if the + factor has null operand dims (propagation type), we disable the + consumer's reduction factor. """ for node in self.graph.nodes: if node.op != "call_function": @@ -538,13 +564,17 @@ def _merge_reduction_factors(self) -> None: # Collect producer reduction keys for each operand that # needs Partial, and validate that ALL can provide it. - merge_pairs: list[tuple[int, int]] = [] + # Track (pk, ck, p_nidx, arg) for edge recording. + merge_pairs: list[tuple[int, int, int, torch.fx.Node]] = [] all_valid = True + has_null_operand_dims = False for arg_pos, c_od in enumerate(c_fac.operand_dims): if c_od is not None: continue # Spatial dim on this operand, not Partial + has_null_operand_dims = True + # This operand must be Partial for the factor to propagate. 
if arg_pos >= len(node.args): all_valid = False @@ -565,7 +595,7 @@ def _merge_reduction_factors(self) -> None: if p_fac.is_reduction: pk = self.factor_keys.get((pidx, p_li)) if pk is not None: - merge_pairs.append((pk, ck)) + merge_pairs.append((pk, ck, pidx, arg)) found = True break if not found: @@ -573,40 +603,25 @@ def _merge_reduction_factors(self) -> None: break if all_valid and merge_pairs: - for pk, ck_val in merge_pairs: - self.uf.union(pk, ck_val) - - # ----------------------------------------------------------------- - # Step 4 — metadata collection - # ----------------------------------------------------------------- - - def _collect_factor_metadata(self) -> None: - for node in self.graph.nodes: - rule = self.factor_rules.get(node) - if rule is None: - continue - nidx = self.node_map[node] - for li, fac in enumerate(rule.factors): - gid = self.factor_keys.get((nidx, li)) - if gid is None: - continue - root = self.uf.find(gid) - self.factor_ops[root].append((node, fac, li)) - - def _unique_roots(self) -> set[int]: - return {self.uf.find(gid) for gid in self.factor_keys.values()} + for pk, ck_val, p_nidx, arg in merge_pairs: + self._factor_edges.append( + FactorEdge(pk, ck_val, p_nidx, cidx, arg, "reduction") + ) + elif not all_valid and has_null_operand_dims: + # Can't propagate Partial through this consumer factor. 
+ self._disabled_reduction_gids.add(ck) # ----------------------------------------------------------------- - # Step 5 — ILP construction + # Step 4 — ILP construction # ----------------------------------------------------------------- def _build_ilp(self) -> None: - roots = self._unique_roots() + all_gids = set(self.factor_keys.values()) - # --- Variables: y[root, mesh_dim] ∈ {0, 1} --- - for r in roots: + # --- Variables: y[gid, mesh_dim] ∈ {0, 1} --- + for gid in all_gids: for m in range(self.mesh.ndim): - self.y_vars[(r, m)] = pulp.LpVariable(f"y_{r}_m{m}", cat="Binary") + self.y_vars[(gid, m)] = pulp.LpVariable(f"y_{gid}_m{m}", cat="Binary") # --- Constraints --- # NOTE: we intentionally omit a factor-uniqueness constraint here. @@ -616,18 +631,20 @@ def _build_ilp(self) -> None: # exclusion constraint already prevents invalid combos (two different # factors claiming the same tensor dim on the same mesh dim). self._add_tensor_exclusion() + self._add_reduction_propagation_constraints() + self._add_disabled_reduction_constraints() # --- Objective --- - self._add_objective(roots) + self._add_objective() # ---- constraints ------------------------------------------------ - def _add_factor_uniqueness(self, roots: set[int]) -> None: + def _add_factor_uniqueness(self, gids: set[int]) -> None: """Each factor is assigned to *at most one* mesh dimension.""" - for r in roots: + for gid in gids: self.prob += ( - pulp.lpSum(self.y_vars[(r, m)] for m in range(self.mesh.ndim)) <= 1, - f"fac_uniq_{r}", + pulp.lpSum(self.y_vars[(gid, m)] for m in range(self.mesh.ndim)) <= 1, + f"fac_uniq_{gid}", ) def _add_tensor_exclusion(self) -> None: @@ -636,11 +653,8 @@ def _add_tensor_exclusion(self) -> None: This encodes the DTensor invariant: a tensor dimension can only appear as ``Shard(d)`` for a single ``d`` on each mesh dimension. - Important: multiple factors at the same node may share a root (after - union-find merging, e.g. 
nheads and head_dim from unflatten both map to - the hidden input dimension). We must deduplicate by root to avoid - counting the same ILP variable twice, which would turn ``sum <= 1`` - into ``2*y <= 1`` and incorrectly force that variable to 0. + Without union-find, gids are unique per (node, factor), so dedup is + technically unnecessary but harmless to keep. """ cid = 0 for node in self.graph.nodes: @@ -653,7 +667,7 @@ def _add_tensor_exclusion(self) -> None: for ri in range(rule.num_results): for m in range(self.mesh.ndim): vs = [] - seen_roots: set[int] = set() + seen_gids: set[int] = set() for li, fac in enumerate(rule.factors): # Include both spatial and reduction factors: a # tensor can only be Shard(d) OR Partial on each @@ -665,10 +679,9 @@ def _add_tensor_exclusion(self) -> None: if has_spatial or fac.is_reduction: gid = self.factor_keys.get((nidx, li)) if gid is not None: - root = self.uf.find(gid) - if root not in seen_roots: - seen_roots.add(root) - vs.append(self.y_vars[(root, m)]) + if gid not in seen_gids: + seen_gids.add(gid) + vs.append(self.y_vars[(gid, m)]) if len(vs) > 1: self.prob += pulp.lpSum(vs) <= 1, f"tex_r_{cid}" cid += 1 @@ -677,7 +690,7 @@ def _add_tensor_exclusion(self) -> None: for oi in range(rule.num_operands): for m in range(self.mesh.ndim): vs = [] - seen_roots: set[int] = set() + seen_gids: set[int] = set() for li, fac in enumerate(rule.factors): if ( oi < len(fac.operand_dims) @@ -685,113 +698,188 @@ def _add_tensor_exclusion(self) -> None: ): gid = self.factor_keys.get((nidx, li)) if gid is not None: - root = self.uf.find(gid) - if root not in seen_roots: - seen_roots.add(root) - vs.append(self.y_vars[(root, m)]) + if gid not in seen_gids: + seen_gids.add(gid) + vs.append(self.y_vars[(gid, m)]) if len(vs) > 1: self.prob += pulp.lpSum(vs) <= 1, f"tex_o_{cid}" cid += 1 + def _add_reduction_propagation_constraints(self) -> None: + """Consumer reduction can only be active if producer reduction is. 
+ + For each reduction edge (pk → ck): y[ck, m] <= y[pk, m]. + This prevents creating Partial from non-Partial input. + """ + cid = 0 + for edge in self._factor_edges: + if edge.kind != "reduction": + continue + for m in range(self.mesh.ndim): + self.prob += ( + self.y_vars[(edge.consumer_gid, m)] + <= self.y_vars[(edge.producer_gid, m)], + f"red_prop_{cid}", + ) + cid += 1 + + def _add_disabled_reduction_constraints(self) -> None: + """Force disabled reduction factors to 0 (can't propagate Partial).""" + cid = 0 + for gid in self._disabled_reduction_gids: + for m in range(self.mesh.ndim): + if (gid, m) in self.y_vars: + self.prob += self.y_vars[(gid, m)] == 0, f"dis_red_{cid}" + cid += 1 + # ---- objective -------------------------------------------------- - def _add_objective(self, roots: set[int]) -> None: + def _add_objective(self) -> None: """Build the cost function. - For each factor *f* assigned to mesh dim *m* the cost coefficient - includes three components: + Three components: - 1. **Compute benefit** (all factors): sharding any dimension divides - work by ``mesh.shape[m]``. - 2. **Redistribution penalty** (reduction factors at "exit" edges): - when a ``Partial`` output reaches a consumer that doesn't share - the reduction root, redistribution is needed. The exact cost - depends on the consumer's placement on that mesh dim: + 1. **Compute benefit** (per node/factor/mesh_dim): sharding any + dimension divides work by ``mesh.shape[m]``. + 2. **Edge disagreement costs** (per FactorEdge per mesh_dim): - - **Partial → Shard** (reduce-scatter): B·(n-1)/n - - **Partial → Replicate** (all-reduce): 2B·(n-1)/n + - *Spatial edges*: all-gather when producer shards but consumer + doesn't. Cost = B·(n-1)/n. + - *Reduction edges*: reduce-scatter/all-reduce when producer is + Partial but consumer isn't. - This is captured exactly via auxiliary continuous variables that - linearize the product ``y[r,m] · (1 - any_consumer_spatial_on_m)``. - 3. 
**All-gather penalty** (spatial factors at "exit" edges): when a - producer is ``Shard(d)`` on mesh dim *m* but a consumer doesn't - share that factor (via union-find), an all-gather is needed. - Cost ≈ B·(n-1)/n. - - All three are linear in the ``y`` and ``z`` variables, keeping the - ILP linear. + 3. **Uncovered reduction exit costs**: for reduction factors with NO + outgoing reduction edge to a consumer, add Partial→{Shard,Replicate} + linearized cost directly. """ - ag_bytes, _rs_bytes = self._compute_redistribution_bytes() - exit_info = self._compute_reduction_exit_info() terms: list[Any] = [] - for r in roots: - refs = self.factor_ops.get(r, []) - for m in range(self.mesh.ndim): - mesh_size = self.mesh.shape[m] - var = self.y_vars[(r, m)] - cost = 0.0 - - for node, fac, _ in refs: - if node.op != "call_function": - continue + # --- A) Compute benefit --- + for node in self.graph.nodes: + if node.op != "call_function": + continue + rule = self.factor_rules.get(node) + if rule is None: + continue + nidx = self.node_map[node] + compute = self._compute_cost(node) + for li, fac in enumerate(rule.factors): + gid = self.factor_keys.get((nidx, li)) + if gid is None: + continue + for m in range(self.mesh.ndim): + benefit = compute * (1.0 - 1.0 / self.mesh.shape[m]) + if benefit > 0: + terms.append(-benefit * self.y_vars[(gid, m)]) - # Compute benefit: work is divided by mesh_size - # regardless of whether the factor is a reduction - # or spatial dimension. - compute = self._compute_cost(node) - benefit = compute * (1.0 - 1.0 / mesh_size) - cost -= benefit - - # All-gather penalty at spatial exit edges. - if r in ag_bytes: - ag_comm = ag_bytes[r] * (mesh_size - 1) / mesh_size - cost += ag_comm / self._BW * 1e6 - - if cost != 0.0: - terms.append(cost * var) - - # Reduction exit edges: linearized Partial → {Shard, Replicate} cost. 
- # - # For each (reduction_root r, consumer u) exit edge, on mesh dim m: - # - # base cost = B·(n-1)/n · y[r,m] (reduce-scatter) - # extra cost = B·(n-1)/n · z (upgrade to all-reduce) - # - # where z is a continuous auxiliary variable satisfying: - # z ≥ y[r,m] − Σ_s y[s,m] for consumer's spatial roots s - # z ≥ 0 (implicit from lowBound=0) - # - # Since z has a positive coefficient and we minimize, the solver - # sets z = max(0, y[r,m] − Σ y[s,m]). - # - # • Consumer has spatial factor on m (Σ≥1) → z=0, total = B (reduce-scatter) - # • Consumer fully replicated on m (Σ=0) → z=y, total = 2B (all-reduce) + # --- B) Edge disagreement costs --- z_id = 0 - for (r, uidx), bytes_val in exit_info.items(): - consumer_spatial = self._get_spatial_roots_at_node(uidx) + for edge in self._factor_edges: + bytes_val = self._output_bytes(edge.producer_node) + if bytes_val <= 0: + continue + for m in range(self.mesh.ndim): mesh_size = self.mesh.shape[m] - y_r_m = self.y_vars[(r, m)] - - # Base reduce-scatter cost (always incurred when y[r,m]=1). comm_unit = bytes_val * (mesh_size - 1) / mesh_size / self._BW * 1e6 - terms.append(comm_unit * y_r_m) - - # Extra cost for Partial → Replicate (linearized). - valid_roots = [s for s in consumer_spatial if (s, m) in self.y_vars] - if valid_roots: - z = pulp.LpVariable(f"z_pr_{z_id}", lowBound=0) - spatial_sum = pulp.lpSum( - self.y_vars[(s, m)] for s in valid_roots - ) - self.prob += z >= y_r_m - spatial_sum, f"z_pr_lb_{z_id}" + + y_prod = self.y_vars[(edge.producer_gid, m)] + y_cons = self.y_vars[(edge.consumer_gid, m)] + + if edge.kind == "spatial": + # All-gather when producer shards but consumer doesn't: + # z >= y[producer, m] - y[consumer, m]; z >= 0 + z = pulp.LpVariable(f"z_ag_{z_id}", lowBound=0) + self.prob += z >= y_prod - y_cons, f"z_ag_lb_{z_id}" terms.append(comm_unit * z) z_id += 1 else: - # No spatial factors at consumer → always all-reduce. - # Extra cost = B·(n-1)/n · y[r,m] (doubling the base). 
- terms.append(comm_unit * y_r_m) + # Reduction edge: Partial exit when producer is Partial + # but consumer isn't. + # z_exit >= y[producer, m] - y[consumer, m]; z_exit >= 0 + z_exit = pulp.LpVariable(f"z_re_{z_id}", lowBound=0) + self.prob += z_exit >= y_prod - y_cons, f"z_re_lb_{z_id}" + # Base reduce-scatter cost. + terms.append(comm_unit * z_exit) + + # Upgrade to all-reduce when consumer has no spatial + # factor on this mesh dim: + # z_ar >= z_exit - Σ_s y[s, m]; z_ar >= 0 + consumer_spatial = self._get_spatial_gids_at_node(edge.consumer_nidx) + valid_gids = [s for s in consumer_spatial if (s, m) in self.y_vars] + if valid_gids: + z_ar = pulp.LpVariable(f"z_ar_{z_id}", lowBound=0) + spatial_sum = pulp.lpSum( + self.y_vars[(s, m)] for s in valid_gids + ) + self.prob += z_ar >= z_exit - spatial_sum, f"z_ar_lb_{z_id}" + terms.append(comm_unit * z_ar) + else: + # No spatial factors → always all-reduce (double cost). + terms.append(comm_unit * z_exit) + z_id += 1 + + # --- C) Uncovered reduction exit costs --- + # For reduction factors at nodes whose Partial output has NO outgoing + # reduction edge to a consumer, model the exit cost directly. 
+ covered_reduction_pairs: set[tuple[int, int]] = set() + for edge in self._factor_edges: + if edge.kind == "reduction": + covered_reduction_pairs.add( + (edge.producer_nidx, edge.consumer_nidx) + ) + + for node in self.graph.nodes: + rule = self.factor_rules.get(node) + if rule is None: + continue + nidx = self.node_map[node] + for li, fac in enumerate(rule.factors): + if not fac.is_reduction: + continue + gid = self.factor_keys.get((nidx, li)) + if gid is None: + continue + + for user in node.users: + if user.op != "call_function": + continue + uidx = self.node_map.get(user) + if uidx is None: + continue + if (nidx, uidx) in covered_reduction_pairs: + continue # handled by edge disagreement above + + bytes_val = self._output_bytes(node) + if bytes_val <= 0: + continue + + consumer_spatial = self._get_spatial_gids_at_node(uidx) + for m in range(self.mesh.ndim): + mesh_size = self.mesh.shape[m] + y_r_m = self.y_vars[(gid, m)] + comm_unit = ( + bytes_val * (mesh_size - 1) / mesh_size / self._BW * 1e6 + ) + + # Base reduce-scatter cost. + terms.append(comm_unit * y_r_m) + + # Extra cost for Partial → Replicate (linearized). + valid_gids = [ + s for s in consumer_spatial if (s, m) in self.y_vars + ] + if valid_gids: + z = pulp.LpVariable(f"z_ure_{z_id}", lowBound=0) + spatial_sum = pulp.lpSum( + self.y_vars[(s, m)] for s in valid_gids + ) + self.prob += z >= y_r_m - spatial_sum, f"z_ure_lb_{z_id}" + terms.append(comm_unit * z) + z_id += 1 + else: + # No spatial factors → always all-reduce. + terms.append(comm_unit * y_r_m) self._num_z_vars = z_id @@ -828,87 +916,19 @@ def _compute_cost(self, node: torch.fx.Node) -> float: self._cost_cache[node] = cost return cost - def _compute_redistribution_bytes( - self, - ) -> tuple[dict[int, float], dict[int, float]]: - """For each factor root, total output bytes at "exit" edges. 
- - Returns ``(ag_bytes, rs_bytes)``: - - * **ag_bytes** — for *spatial* factors: bytes needing an all-gather - at edges where the consumer doesn't share the root. - * **rs_bytes** — for *reduction* factors: bytes needing a - reduce-scatter at edges where the ``Partial`` doesn't propagate - to the consumer. - """ - # node_idx → set of factor roots at that node - node_roots: dict[int, set[int]] = defaultdict(set) - for (nidx, _li), gid in self.factor_keys.items(): - node_roots[nidx].add(self.uf.find(gid)) - - ag_bytes: dict[int, float] = defaultdict(float) - rs_bytes: dict[int, float] = defaultdict(float) - for root, refs in self.factor_ops.items(): - for node, fac, _li in refs: - for user in node.users: - if user.op != "call_function": - continue - uidx = self.node_map.get(user) - if uidx is None: - continue - if root in node_roots.get(uidx, set()): - continue # factor propagates — no redistribution - - if fac.is_reduction: - # Partial exits here → reduce-scatter - rs_bytes[root] += self._output_bytes(node) - elif any(d is not None for d in fac.result_dims): - # Shard exits here → all-gather - ag_bytes[root] += self._output_bytes(node) - - return dict(ag_bytes), dict(rs_bytes) - - def _compute_reduction_exit_info(self) -> dict[tuple[int, int], float]: - """For each (reduction_root, consumer_nidx) pair, total bytes at exits. - - Used by the linearized Partial → Replicate cost model to distinguish - reduce-scatter (consumer is Shard) from all-reduce (consumer is - Replicate) on each mesh dimension. 
- """ - node_roots: dict[int, set[int]] = defaultdict(set) - for (nidx, _li), gid in self.factor_keys.items(): - node_roots[nidx].add(self.uf.find(gid)) - - exit_info: dict[tuple[int, int], float] = defaultdict(float) - for root, refs in self.factor_ops.items(): - for node, fac, _li in refs: - if not fac.is_reduction: - continue - for user in node.users: - if user.op != "call_function": - continue - uidx = self.node_map.get(user) - if uidx is None: - continue - if root in node_roots.get(uidx, set()): - continue # factor propagates — no redistribution - exit_info[(root, uidx)] += self._output_bytes(node) - - return dict(exit_info) - - def _get_spatial_roots_at_node(self, nidx: int) -> set[int]: - """Get unique roots for spatial (non-reduction) result factors at a node.""" + def _get_spatial_gids_at_node(self, nidx: int) -> set[int]: + """Get gids for spatial (non-reduction) result factors at a node.""" node = self.nodes[nidx] rule = self.factor_rules.get(node) if rule is None: return set() - roots: set[int] = set() + gids: set[int] = set() for li, fac in enumerate(rule.factors): if not fac.is_reduction and any(d is not None for d in fac.result_dims): gid = self.factor_keys.get((nidx, li)) if gid is not None: - roots.add(self.uf.find(gid)) - return roots + gids.add(gid) + return gids # ----------------------------------------------------------------- # User constraints @@ -931,31 +951,29 @@ def add_node_constraint( if any(d == p.dim for d in fac.result_dims): gid = self.factor_keys.get((nidx, li)) if gid is not None: - root = self.uf.find(gid) self.prob += ( - self.y_vars[(root, m)] == 1, + self.y_vars[(gid, m)] == 1, f"pin_{nidx}_f{li}_m{m}", ) break elif isinstance(p, Replicate): - seen_roots: set[int] = set() + seen_gids: set[int] = set() for li, fac in enumerate(rule.factors): if any(d is not None for d in fac.result_dims): gid = self.factor_keys.get((nidx, li)) if gid is not None: - root = self.uf.find(gid) - if root not in seen_roots: - seen_roots.add(root) + 
if gid not in seen_gids: + seen_gids.add(gid) self.prob += ( - self.y_vars[(root, m)] == 0, - f"rep_{nidx}_r{root}_m{m}", + self.y_vars[(gid, m)] == 0, + f"rep_{nidx}_g{gid}_m{m}", ) def add_grad_param_constraints(self) -> None: """Ensure parameters and their gradients have matching placements. For each (param, grad) pair, constrains every spatial factor to have - the same mesh-dim assignment: ``y[param_root, m] == y[grad_root, m]`` + the same mesh-dim assignment: ``y[pk, m] == y[gk, m]`` for all mesh dims ``m``. """ for param, grad in get_param_and_grad_nodes(self.graph).values(): @@ -975,7 +993,6 @@ def add_grad_param_constraints(self) -> None: pk = self.factor_keys.get((pidx, p_li)) if pk is None: continue - p_root = self.uf.find(pk) # Find the matching factor in the grad (same result dim). for g_li, g_fac in enumerate(grad_rule.factors): @@ -984,12 +1001,11 @@ def add_grad_param_constraints(self) -> None: gk = self.factor_keys.get((gidx, g_li)) if gk is None: continue - g_root = self.uf.find(gk) - if p_root != g_root: + if pk != gk: for m in range(self.mesh.ndim): - pv = self.y_vars.get((p_root, m)) - gv = self.y_vars.get((g_root, m)) + pv = self.y_vars.get((pk, m)) + gv = self.y_vars.get((gk, m)) if pv is not None and gv is not None: self.prob += ( pv == gv, @@ -1044,23 +1060,22 @@ def add_parameter_memory_constraint( if rule is None: continue - # s_m = sum_fi y[root_fi, m]: is param sharded on mesh dim m? + # s_m = sum_fi y[gid_fi, m]: is param sharded on mesh dim m? # (0 or 1 due to tensor-exclusion constraints.) 
s_exprs: dict[int, list] = {} for m in range(k): terms: list = [] - seen_roots: set[int] = set() + seen_gids: set[int] = set() for li, fac in enumerate(rule.factors): if not fac.result_dims or fac.result_dims[0] is None: continue gid = self.factor_keys.get((nidx, li)) if gid is None: continue - root = self.uf.find(gid) - if root in seen_roots: + if gid in seen_gids: continue - seen_roots.add(root) - var = self.y_vars.get((root, m)) + seen_gids.add(gid) + var = self.y_vars.get((gid, m)) if var is not None: terms.append(var) s_exprs[m] = terms @@ -1180,10 +1195,11 @@ def get_solution(self, verbose: bool = False) -> dict[torch.fx.Node, OpSpec]: ) # Extract factor → mesh-dim assignments. - assignment: dict[int, int] = {} # root → mesh_dim - for (root, m), var in self.y_vars.items(): + # A gid can map to multiple mesh dims (e.g. (Shard(0), Shard(0))). + assignment: dict[int, list[int]] = defaultdict(list) + for (gid, m), var in self.y_vars.items(): if var.varValue is not None and var.varValue > 0.5: - assignment[root] = m + assignment[gid].append(m) # Reconstruct per-node placements. 
result: dict[torch.fx.Node, OpSpec] = {} @@ -1205,15 +1221,12 @@ def get_solution(self, verbose: bool = False) -> dict[torch.fx.Node, OpSpec]: gid = self.factor_keys.get((nidx, li)) if gid is None: continue - root = self.uf.find(gid) - m = assignment.get(root) - if m is None: - continue - if fac.is_reduction: - placements[m] = Partial() - else: - if fac.result_dims and fac.result_dims[0] is not None: - placements[m] = Shard(fac.result_dims[0]) + for m in assignment.get(gid, []): + if fac.is_reduction: + placements[m] = Partial() + else: + if fac.result_dims and fac.result_dims[0] is not None: + placements[m] = Shard(fac.result_dims[0]) tensor_meta = TensorMeta(val.shape, val.stride(), val.dtype) output_specs = DTensorSpec( self.mesh, tuple(placements), tensor_meta=tensor_meta @@ -1229,17 +1242,14 @@ def get_solution(self, verbose: bool = False) -> dict[torch.fx.Node, OpSpec]: gid = self.factor_keys.get((nidx, li)) if gid is None: continue - root = self.uf.find(gid) - m_assigned = assignment.get(root) - if m_assigned is None: - continue - if fac.is_reduction: - plc_list[m_assigned] = Partial() - elif ( - ri < len(fac.result_dims) - and fac.result_dims[ri] is not None - ): - plc_list[m_assigned] = Shard(fac.result_dims[ri]) + for m_assigned in assignment.get(gid, []): + if fac.is_reduction: + plc_list[m_assigned] = Partial() + elif ( + ri < len(fac.result_dims) + and fac.result_dims[ri] is not None + ): + plc_list[m_assigned] = Shard(fac.result_dims[ri]) tm = TensorMeta(v.shape, v.stride(), v.dtype) specs.append( DTensorSpec(self.mesh, tuple(plc_list), tensor_meta=tm) @@ -1274,7 +1284,7 @@ def _build_input_specs( node: torch.fx.Node, rule: FactorRule, nidx: int, - assignment: dict[int, int], + assignment: dict[int, list[int]], ) -> list[DTensorSpec]: """Reconstruct input DTensorSpecs from factor assignments. 
@@ -1299,16 +1309,13 @@ def _build_input_specs( gid = self.factor_keys.get((nidx, li)) if gid is None: continue - root = self.uf.find(gid) - m_assigned = assignment.get(root) - if m_assigned is None: - continue - if oi < len(fac.operand_dims): - od = fac.operand_dims[oi] - if od is not None: - inp_placements[m_assigned] = Shard(od) - elif fac.is_reduction: - inp_placements[m_assigned] = Partial() + for m_assigned in assignment.get(gid, []): + if oi < len(fac.operand_dims): + od = fac.operand_dims[oi] + if od is not None: + inp_placements[m_assigned] = Shard(od) + elif fac.is_reduction: + inp_placements[m_assigned] = Partial() # Get TensorMeta from the corresponding input node. inp_tm = None @@ -1384,7 +1391,7 @@ def _infeasibility_diagnostics(self) -> str: def get_stats(self) -> dict[str, Any]: """Return ILP size statistics (useful for comparing with original).""" - roots = self._unique_roots() + num_unique_factors = len(set(self.factor_keys.values())) # Estimate original variable count. 
orig_vars = 0 @@ -1400,7 +1407,8 @@ def get_stats(self) -> dict[str, Any]: n_aux_vars = getattr(self, "_num_z_vars", 0) return { "num_graph_nodes": len(self.nodes), - "num_unique_factors": len(roots), + "num_unique_factors": num_unique_factors, + "num_edges": len(self._factor_edges), "num_factor_ilp_variables": n_factor_vars + n_aux_vars, "num_factor_y_variables": n_factor_vars, "num_factor_z_variables": n_aux_vars, @@ -1416,6 +1424,7 @@ def get_log(self, verbose: bool = False) -> str: lines.append(f"Factor ILP status: {pulp.LpStatus[self.prob.status]}") s = self.get_stats() lines.append(f"Unique factors: {s['num_unique_factors']}") + lines.append(f"Factor edges: {s['num_edges']}") lines.append(f"Factor ILP variables: {s['num_factor_ilp_variables']} ({s['num_factor_y_variables']} y + {s['num_factor_z_variables']} z)") lines.append(f"Factor ILP constraints: {s['num_factor_ilp_constraints']}") lines.append( @@ -1424,16 +1433,28 @@ def get_log(self, verbose: bool = False) -> str: lines.append(f"Variable reduction: {s['variable_reduction_ratio']:.1f}x") if verbose and self.prob.status == 1: + # Build reverse lookup: gid → (node, factor, local_idx) + gid_to_info: dict[int, tuple[torch.fx.Node, Factor, int]] = {} + for node in self.graph.nodes: + rule = self.factor_rules.get(node) + if rule is None: + continue + nidx = self.node_map[node] + for li, fac in enumerate(rule.factors): + gid = self.factor_keys.get((nidx, li)) + if gid is not None and gid not in gid_to_info: + gid_to_info[gid] = (node, fac, li) + lines.append("") lines.append("Factor assignments:") - for (root, m), var in sorted(self.y_vars.items()): + for (gid, m), var in sorted(self.y_vars.items()): if var.varValue is not None and var.varValue > 0.5: - refs = self.factor_ops.get(root, []) + info = gid_to_info.get(gid) desc = "" - if refs: - _, fac, _ = refs[0] + if info is not None: + _, fac, _ = info kind = "reduction" if fac.is_reduction else "spatial" desc = f" ({kind}, size={fac.size})" - lines.append(f" 
Factor {root} → mesh dim {m}{desc}") + lines.append(f" Factor {gid} → mesh dim {m}{desc}") return "\n".join(lines) From 82a91d8a9e097484ccd20775e4425c7942f2b6a3 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Tue, 17 Feb 2026 20:10:35 +0000 Subject: [PATCH 14/22] Fixes - The factor uniqueness constraint is the real fix for the original view error ((S(0), S(0)) can no longer be chosen by the solver). - The _build_input_specs must use the consumer's own factor assignments, not the producer's output_specs. This is because apply_sharding.py relies on input_specs to tell it what placement the op needs its inputs in, then redistributes from curr_spec (producer output) to tgt_spec (consumer input). When I was copying producer specs, redistribution was skipped, leaving inputs with incompatible shardings for ops like mul in RMS norm. --- autoparallel/optimize_sharding_new.py | 33 ++++++++++++++++++++------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/autoparallel/optimize_sharding_new.py b/autoparallel/optimize_sharding_new.py index 87a294f4..40df59a8 100644 --- a/autoparallel/optimize_sharding_new.py +++ b/autoparallel/optimize_sharding_new.py @@ -624,12 +624,12 @@ def _build_ilp(self) -> None: self.y_vars[(gid, m)] = pulp.LpVariable(f"y_{gid}_m{m}", cat="Binary") # --- Constraints --- - # NOTE: we intentionally omit a factor-uniqueness constraint here. - # A factor MAY be assigned to multiple mesh dims simultaneously, - # which corresponds to placements like (Shard(0), Shard(0)) where a - # single tensor dim is sharded across several mesh dims. The tensor - # exclusion constraint already prevents invalid combos (two different - # factors claiming the same tensor dim on the same mesh dim). + # Factor uniqueness: each factor is assigned to at most one mesh dim. + # This prevents the solver from choosing multi-dim assignments like + # (Shard(0), Shard(0)) on its own, which can produce invalid placements. 
+ # When the user explicitly pins (S(0), S(0)) via add_node_constraint, + # the uniqueness constraint for that factor is relaxed. + self._add_factor_uniqueness(all_gids) self._add_tensor_exclusion() self._add_reduction_propagation_constraints() self._add_disabled_reduction_constraints() @@ -945,6 +945,11 @@ def add_node_constraint( return nidx = self.node_map[node] + # Track which gids are pinned to 1 on which mesh dims, so we can + # relax the factor uniqueness constraint for multi-dim pins like + # (Shard(0), Shard(0)). + pinned_gids: dict[int, list[int]] = defaultdict(list) + for m, p in enumerate(placement): if isinstance(p, Shard): for li, fac in enumerate(rule.factors): @@ -955,6 +960,7 @@ def add_node_constraint( self.y_vars[(gid, m)] == 1, f"pin_{nidx}_f{li}_m{m}", ) + pinned_gids[gid].append(m) break elif isinstance(p, Replicate): seen_gids: set[int] = set() @@ -969,6 +975,13 @@ def add_node_constraint( f"rep_{nidx}_g{gid}_m{m}", ) + # Relax factor uniqueness for gids pinned to multiple mesh dims. + for gid, mesh_dims in pinned_gids.items(): + if len(mesh_dims) > 1: + constraint_name = f"fac_uniq_{gid}" + if constraint_name in self.prob.constraints: + del self.prob.constraints[constraint_name] + def add_grad_param_constraints(self) -> None: """Ensure parameters and their gradients have matching placements. @@ -1286,14 +1299,18 @@ def _build_input_specs( nidx: int, assignment: dict[int, list[int]], ) -> list[DTensorSpec]: - """Reconstruct input DTensorSpecs from factor assignments. + """Reconstruct input DTensorSpecs from the consumer's factor assignments. For each operand, the placement on each mesh dim is derived from the - factor assigned to that mesh dim: + consumer node's factor assigned to that mesh dim: - ``operand_dims[oi]`` is not None → ``Shard(dim)`` - ``operand_dims[oi]`` is None and ``is_reduction`` → ``Partial`` - otherwise → ``Replicate`` + + The resulting input_specs tell ``apply_sharding`` what placement this + op needs its inputs in. 
If a producer's output_specs differ, + ``apply_sharding`` will redistribute automatically. """ if rule.num_operands == 0: return [] From 46d466929b0b4c47fda39805d8f2c9c8a53975c4 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 18 Feb 2026 08:11:26 +0000 Subject: [PATCH 15/22] Use PyTorch's communication cost MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Here's a summary of the changes made: 1. Added imports from torch.distributed.tensor._collective_utils: MeshTopoInfo, allgather_cost, allreduce_cost, reduce_scatter_cost 2. Added self.mesh_topo in __init__: MeshTopoInfo.build_from_mesh(mesh) — built once, reused throughout 3. Replaced all comm_unit = bytes_val * (mesh_size - 1) / mesh_size / self._BW * 1e6 with the appropriate collective cost function: - Spatial edges: allgather_cost(bytes_gb, mesh_topo, m) - Reduction edges: reduce_scatter_cost(bytes_gb, mesh_topo, m) as base, with allreduce_cost(...) - reduce_scatter_cost(...) as the upgrade delta - Uncovered reduction exits: same pattern 4. Removed _BW: float = 50e9 class attribute — no longer needed The key improvement: instead of a flat bandwidth constant across all mesh dims, the cost model now uses per-mesh-dim bandwidth and latency from MeshTopoInfo. This means the solver can properly distinguish between e.g. NVLink within a node (high bandwidth, low latency) vs. InfiniBand across nodes (lower bandwidth, higher latency), and make better sharding decisions accordingly. 
--- autoparallel/optimize_sharding_new.py | 78 ++++++++++++++++----------- 1 file changed, 47 insertions(+), 31 deletions(-) diff --git a/autoparallel/optimize_sharding_new.py b/autoparallel/optimize_sharding_new.py index 40df59a8..ac643020 100644 --- a/autoparallel/optimize_sharding_new.py +++ b/autoparallel/optimize_sharding_new.py @@ -66,6 +66,13 @@ from torch.distributed.tensor.placement_types import Partial, Placement, Replicate, Shard from torch.utils._pytree import tree_flatten, tree_map_only +from torch.distributed.tensor._collective_utils import ( + MeshTopoInfo, + allgather_cost, + allreduce_cost, + reduce_scatter_cost, +) + from .cost_models.compute_estimation import estimate_strategy_runtime_cost from .shardings.placement_options import get_placement_options from .shardings.propagation_rules import _create_all_options @@ -379,6 +386,7 @@ def __init__( self.graph = gm.graph self.nodes = list(self.graph.nodes) self.mesh = mesh + self.mesh_topo = MeshTopoInfo.build_from_mesh(mesh) self.node_map: dict[torch.fx.Node, int] = { n: i for i, n in enumerate(self.nodes) } @@ -744,15 +752,20 @@ def _add_objective(self) -> None: 2. **Edge disagreement costs** (per FactorEdge per mesh_dim): - *Spatial edges*: all-gather when producer shards but consumer - doesn't. Cost = B·(n-1)/n. + doesn't. - *Reduction edges*: reduce-scatter/all-reduce when producer is Partial but consumer isn't. 3. **Uncovered reduction exit costs**: for reduction factors with NO outgoing reduction edge to a consumer, add Partial→{Shard,Replicate} linearized cost directly. + + Communication costs use per-mesh-dim bandwidth/latency from + ``MeshTopoInfo`` via ``allgather_cost``, ``reduce_scatter_cost``, + and ``allreduce_cost``. 
""" terms: list[Any] = [] + mesh_topo = self.mesh_topo # --- A) Compute benefit --- for node in self.graph.nodes: @@ -778,45 +791,50 @@ def _add_objective(self) -> None: bytes_val = self._output_bytes(edge.producer_node) if bytes_val <= 0: continue + bytes_gb = bytes_val / 1024 / 1024 / 1024 for m in range(self.mesh.ndim): - mesh_size = self.mesh.shape[m] - comm_unit = bytes_val * (mesh_size - 1) / mesh_size / self._BW * 1e6 - y_prod = self.y_vars[(edge.producer_gid, m)] y_cons = self.y_vars[(edge.consumer_gid, m)] if edge.kind == "spatial": # All-gather when producer shards but consumer doesn't: # z >= y[producer, m] - y[consumer, m]; z >= 0 + ag_cost = allgather_cost(bytes_gb, mesh_topo, m) z = pulp.LpVariable(f"z_ag_{z_id}", lowBound=0) self.prob += z >= y_prod - y_cons, f"z_ag_lb_{z_id}" - terms.append(comm_unit * z) + terms.append(ag_cost * z) z_id += 1 else: # Reduction edge: Partial exit when producer is Partial # but consumer isn't. # z_exit >= y[producer, m] - y[consumer, m]; z_exit >= 0 + rs_cost = reduce_scatter_cost(bytes_gb, mesh_topo, m) z_exit = pulp.LpVariable(f"z_re_{z_id}", lowBound=0) self.prob += z_exit >= y_prod - y_cons, f"z_re_lb_{z_id}" # Base reduce-scatter cost. 
- terms.append(comm_unit * z_exit) + terms.append(rs_cost * z_exit) # Upgrade to all-reduce when consumer has no spatial # factor on this mesh dim: # z_ar >= z_exit - Σ_s y[s, m]; z_ar >= 0 + ar_cost = allreduce_cost(bytes_gb, mesh_topo, m) + ar_extra = ar_cost - rs_cost consumer_spatial = self._get_spatial_gids_at_node(edge.consumer_nidx) valid_gids = [s for s in consumer_spatial if (s, m) in self.y_vars] if valid_gids: - z_ar = pulp.LpVariable(f"z_ar_{z_id}", lowBound=0) - spatial_sum = pulp.lpSum( - self.y_vars[(s, m)] for s in valid_gids - ) - self.prob += z_ar >= z_exit - spatial_sum, f"z_ar_lb_{z_id}" - terms.append(comm_unit * z_ar) + if ar_extra > 0: + z_ar = pulp.LpVariable(f"z_ar_{z_id}", lowBound=0) + spatial_sum = pulp.lpSum( + self.y_vars[(s, m)] for s in valid_gids + ) + self.prob += z_ar >= z_exit - spatial_sum, f"z_ar_lb_{z_id}" + terms.append(ar_extra * z_ar) else: - # No spatial factors → always all-reduce (double cost). - terms.append(comm_unit * z_exit) + # No spatial factors → always all-reduce (extra cost + # on top of reduce-scatter already added). + if ar_extra > 0: + terms.append(ar_extra * z_exit) z_id += 1 # --- C) Uncovered reduction exit costs --- @@ -853,33 +871,35 @@ def _add_objective(self) -> None: bytes_val = self._output_bytes(node) if bytes_val <= 0: continue + bytes_gb = bytes_val / 1024 / 1024 / 1024 consumer_spatial = self._get_spatial_gids_at_node(uidx) for m in range(self.mesh.ndim): - mesh_size = self.mesh.shape[m] y_r_m = self.y_vars[(gid, m)] - comm_unit = ( - bytes_val * (mesh_size - 1) / mesh_size / self._BW * 1e6 - ) + rs_cost = reduce_scatter_cost(bytes_gb, mesh_topo, m) # Base reduce-scatter cost. - terms.append(comm_unit * y_r_m) + terms.append(rs_cost * y_r_m) # Extra cost for Partial → Replicate (linearized). 
+ ar_cost = allreduce_cost(bytes_gb, mesh_topo, m) + ar_extra = ar_cost - rs_cost valid_gids = [ s for s in consumer_spatial if (s, m) in self.y_vars ] if valid_gids: - z = pulp.LpVariable(f"z_ure_{z_id}", lowBound=0) - spatial_sum = pulp.lpSum( - self.y_vars[(s, m)] for s in valid_gids - ) - self.prob += z >= y_r_m - spatial_sum, f"z_ure_lb_{z_id}" - terms.append(comm_unit * z) - z_id += 1 + if ar_extra > 0: + z = pulp.LpVariable(f"z_ure_{z_id}", lowBound=0) + spatial_sum = pulp.lpSum( + self.y_vars[(s, m)] for s in valid_gids + ) + self.prob += z >= y_r_m - spatial_sum, f"z_ure_lb_{z_id}" + terms.append(ar_extra * z) + z_id += 1 else: # No spatial factors → always all-reduce. - terms.append(comm_unit * y_r_m) + if ar_extra > 0: + terms.append(ar_extra * y_r_m) self._num_z_vars = z_id @@ -888,10 +908,6 @@ def _add_objective(self) -> None: # ---- cost helpers ----------------------------------------------- - # Rough inter-node bandwidth (bytes/s). 50 GB/s is a reasonable - # default for NVLink / high-end InfiniBand. 
- _BW: float = 50e9 - @staticmethod def _output_bytes(node: torch.fx.Node) -> float: val = _get_primary_tensor(node.meta.get("val")) From 9c88c82a2a89160f5b165f26bd56a4753c63a4cc Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 18 Feb 2026 08:23:54 +0000 Subject: [PATCH 16/22] Rename functions --- autoparallel/api.py | 4 ++-- autoparallel/optimize_sharding.py | 4 ++-- examples/example_autoparallel_factor.py | 5 +++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/autoparallel/api.py b/autoparallel/api.py index 83e86564..6e262f29 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -425,7 +425,7 @@ def add_input_constraints(self, constraints): self._assert_entered() assert self.input_constraints is None, "Input constraints have already been set" - self.sharding_optimizer.add_sharded_input_constraint(constraints) + self.sharding_optimizer.add_input_constraints(constraints) self.input_constraints = constraints def add_output_constraints(self, constraints): @@ -435,7 +435,7 @@ def add_output_constraints(self, constraints): self.output_constraints is None ), "Output constraints have already been set" # forces sharding of fwd output to be S(0) on first dimension and R on others - self.sharding_optimizer.add_sharded_output_constraint(constraints) + self.sharding_optimizer.add_output_constraints(constraints) self.output_constraints = constraints def optimize_placement(self, verbose=True): diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py index 14bd9613..dff0374c 100644 --- a/autoparallel/optimize_sharding.py +++ b/autoparallel/optimize_sharding.py @@ -751,7 +751,7 @@ def add_node_constraint(self, node, placement=None, constraint_name=None): constraint_name=constraint_name, ) - def add_sharded_input_constraint( + def add_input_constraints( self, input_placements: Optional[list[Optional[tuple[Placement, ...]]]] = None ): """ @@ -797,7 +797,7 @@ def add_sharded_input_constraint( "the inputs before tracing to 
remove aliasing." ) - def add_sharded_output_constraint(self, output_placements=None): + def add_output_constraints(self, output_placements=None): """ USER CONSTRAINTS (Category 6a): Output placement constraints. Force specific placements for output nodes and their corresponding gradient outputs. diff --git a/examples/example_autoparallel_factor.py b/examples/example_autoparallel_factor.py index 7a47c4b9..7ee66b88 100644 --- a/examples/example_autoparallel_factor.py +++ b/examples/example_autoparallel_factor.py @@ -128,8 +128,8 @@ def input_fn(): t0 = time.perf_counter() orig_opt = ShardingOptimizer(gm, mesh) orig_opt.add_grad_param_constraints() - orig_opt.add_sharded_input_constraint([x_sharding]) - orig_opt.add_sharded_output_constraint([x_sharding]) + orig_opt.add_input_constraints([x_sharding]) + orig_opt.add_output_constraints([x_sharding]) orig_solution = orig_opt.get_solution(verbose=False) t_orig = time.perf_counter() - t0 @@ -147,6 +147,7 @@ def input_fn(): t0 = time.perf_counter() factor_opt = FactorShardingOptimizer(gm, mesh) + factor_opt.add_grad_param_constraints() factor_opt.add_input_constraints([x_sharding]) factor_opt.add_output_constraints([x_sharding]) factor_solution = factor_opt.get_solution(verbose=False) From 2fe53b6f102dd6fcc328b91be3c9bc6b317c35e5 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 18 Feb 2026 10:04:43 +0000 Subject: [PATCH 17/22] Add get_stats for ILP problem statistics MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit autoparallel/optimize_sharding.py — get_stats: - Removed print statements - Returns a dict with keys: num_graph_nodes, num_ilp_variables, num_ilp_constraints, mesh_shape autoparallel/optimize_sharding_new.py — get_stats: - Removed print statements - Removed estimated_original_ilp_variables and variable_reduction_ratio (these belong in comparison code, not in the optimizer itself) - Renamed keys to align with the old optimizer: num_ilp_variables, 
num_ilp_constraints (shared), plus factor-specific: num_unique_factors, num_edges, num_y_variables, num_z_variables - Updated get_log to use the renamed keys examples/example_autoparallel_factor.py: - Updated to use orig_opt.get_stats() instead of accessing orig_opt.ds and orig_opt.prob.constraints directly - Updated key references to match the new naming --- autoparallel/optimize_sharding.py | 11 +++++++++ autoparallel/optimize_sharding_new.py | 31 +++++++------------------ examples/example_autoparallel_factor.py | 11 +++++---- 3 files changed, 25 insertions(+), 28 deletions(-) diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py index dff0374c..7f50c139 100644 --- a/autoparallel/optimize_sharding.py +++ b/autoparallel/optimize_sharding.py @@ -841,6 +841,17 @@ def add_output_constraints(self, output_placements=None): "them from the graph to avoid aliasing." ) + def get_stats(self) -> dict: + """Return ILP size statistics.""" + num_vars = len(self.ds) + num_constraints = len(self.prob.constraints) + return { + "num_graph_nodes": len(list(self.graph.nodes)), + "num_ilp_variables": num_vars, + "num_ilp_constraints": num_constraints, + "mesh_shape": tuple(self.mesh.shape), + } + def validate(self): for node in self.graph.nodes: if node.op != "call_function": diff --git a/autoparallel/optimize_sharding_new.py b/autoparallel/optimize_sharding_new.py index ac643020..2ff2b5a6 100644 --- a/autoparallel/optimize_sharding_new.py +++ b/autoparallel/optimize_sharding_new.py @@ -1423,32 +1423,21 @@ def _infeasibility_diagnostics(self) -> str: ) def get_stats(self) -> dict[str, Any]: - """Return ILP size statistics (useful for comparing with original).""" + """Return ILP size statistics.""" num_unique_factors = len(set(self.factor_keys.values())) - # Estimate original variable count. 
- orig_vars = 0 - for node, strat in self.strats.items(): - if not strat.strategies: - continue - n_out = len(strat.strategies) - first = strat.strategies[0] - n_args = len(first.input_specs) if first.input_specs else 0 - orig_vars += max(n_args, 1) * n_out * n_out - n_factor_vars = len(self.y_vars) n_aux_vars = getattr(self, "_num_z_vars", 0) + return { "num_graph_nodes": len(self.nodes), "num_unique_factors": num_unique_factors, "num_edges": len(self._factor_edges), - "num_factor_ilp_variables": n_factor_vars + n_aux_vars, - "num_factor_y_variables": n_factor_vars, - "num_factor_z_variables": n_aux_vars, - "num_factor_ilp_constraints": len(self.prob.constraints), + "num_ilp_variables": n_factor_vars + n_aux_vars, + "num_y_variables": n_factor_vars, + "num_z_variables": n_aux_vars, + "num_ilp_constraints": len(self.prob.constraints), "mesh_shape": tuple(self.mesh.shape), - "estimated_original_ilp_variables": orig_vars, - "variable_reduction_ratio": orig_vars / max(n_factor_vars, 1), } def get_log(self, verbose: bool = False) -> str: @@ -1458,12 +1447,8 @@ def get_log(self, verbose: bool = False) -> str: s = self.get_stats() lines.append(f"Unique factors: {s['num_unique_factors']}") lines.append(f"Factor edges: {s['num_edges']}") - lines.append(f"Factor ILP variables: {s['num_factor_ilp_variables']} ({s['num_factor_y_variables']} y + {s['num_factor_z_variables']} z)") - lines.append(f"Factor ILP constraints: {s['num_factor_ilp_constraints']}") - lines.append( - f"Est. 
original ILP vars: {s['estimated_original_ilp_variables']}" - ) - lines.append(f"Variable reduction: {s['variable_reduction_ratio']:.1f}x") + lines.append(f"ILP variables: {s['num_ilp_variables']} ({s['num_y_variables']} y + {s['num_z_variables']} z)") + lines.append(f"ILP constraints: {s['num_ilp_constraints']}") if verbose and self.prob.status == 1: # Build reverse lookup: gid → (node, factor, local_idx) diff --git a/examples/example_autoparallel_factor.py b/examples/example_autoparallel_factor.py index 7ee66b88..78053bde 100644 --- a/examples/example_autoparallel_factor.py +++ b/examples/example_autoparallel_factor.py @@ -163,18 +163,19 @@ def input_fn(): # ------------------------------------------------------------------ stats = factor_opt.get_stats() + orig_stats = orig_opt.get_stats() + print("\n" + "=" * 70) print("COMPARISON") print("=" * 70) print(f" Mesh shape: {tuple(mesh.shape)}") print(f" Graph nodes: {stats['num_graph_nodes']}") print() - print(f" Original ILP variables: {len(orig_opt.ds):,}") - print(f" Factor ILP variables: {stats['num_factor_ilp_variables']:,}") - print(f" Variable reduction: {stats['variable_reduction_ratio']:.1f}x") + print(f" Original ILP variables: {orig_stats['num_ilp_variables']:,}") + print(f" Factor ILP variables: {stats['num_ilp_variables']:,}") print() - print(f" Original ILP constraints:{len(orig_opt.prob.constraints):,}") - print(f" Factor ILP constraints: {stats['num_factor_ilp_constraints']:,}") + print(f" Original ILP constraints:{orig_stats['num_ilp_constraints']:,}") + print(f" Factor ILP constraints: {stats['num_ilp_constraints']:,}") print() print(f" Unique factors: {stats['num_unique_factors']}") From db30c95499dbe66caf269172f54829489b3d3ec8 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 18 Feb 2026 10:40:01 +0000 Subject: [PATCH 18/22] Use HiGHS solver At least 10x faster. 
HiGHS has a much better branch-and-bound implementation than CBC, and it's also better at exploiting LP relaxation tightening (presolve, cuts, etc.), which is exactly where the z-variable formulation was struggling --- autoparallel/optimize_sharding_new.py | 63 ++++++++++++++------------- 1 file changed, 33 insertions(+), 30 deletions(-) diff --git a/autoparallel/optimize_sharding_new.py b/autoparallel/optimize_sharding_new.py index 2ff2b5a6..ced79b5f 100644 --- a/autoparallel/optimize_sharding_new.py +++ b/autoparallel/optimize_sharding_new.py @@ -61,17 +61,21 @@ get_plain_output_and_tangent_nodes, ) from torch.distributed._tensor.placement_types import TensorMeta -from torch.distributed.tensor._dtensor_spec import DTensorSpec -from torch.distributed.tensor._op_schema import OpSpec, OpStrategy -from torch.distributed.tensor.placement_types import Partial, Placement, Replicate, Shard -from torch.utils._pytree import tree_flatten, tree_map_only - from torch.distributed.tensor._collective_utils import ( MeshTopoInfo, allgather_cost, allreduce_cost, reduce_scatter_cost, ) +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import OpSpec, OpStrategy +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) +from torch.utils._pytree import tree_flatten, tree_map_only from .cost_models.compute_estimation import estimate_strategy_runtime_cost from .shardings.placement_options import get_placement_options @@ -321,8 +325,7 @@ def extract_factors_from_strategy( for factor_id, (out_ps, in_ps) in enumerate(seen.values()): is_reduction = any(_is_partial(p) for p in out_ps if p is not None) result_dims = [ - p.dim if p is not None and isinstance(p, Shard) else None - for p in out_ps + p.dim if p is not None and isinstance(p, Shard) else None for p in out_ps ] operand_dims = [p.dim if isinstance(p, Shard) else None for p in in_ps] size = _infer_factor_size(node, operand_dims, 
result_dims) @@ -336,7 +339,9 @@ def extract_factors_from_strategy( ) ) - return FactorRule(factors=factors, num_operands=num_operands, num_results=num_results) + return FactorRule( + factors=factors, num_operands=num_operands, num_results=num_results + ) def _placeholder_factor_rule(node: torch.fx.Node) -> FactorRule: @@ -820,7 +825,9 @@ def _add_objective(self) -> None: # z_ar >= z_exit - Σ_s y[s, m]; z_ar >= 0 ar_cost = allreduce_cost(bytes_gb, mesh_topo, m) ar_extra = ar_cost - rs_cost - consumer_spatial = self._get_spatial_gids_at_node(edge.consumer_nidx) + consumer_spatial = self._get_spatial_gids_at_node( + edge.consumer_nidx + ) valid_gids = [s for s in consumer_spatial if (s, m) in self.y_vars] if valid_gids: if ar_extra > 0: @@ -843,9 +850,7 @@ def _add_objective(self) -> None: covered_reduction_pairs: set[tuple[int, int]] = set() for edge in self._factor_edges: if edge.kind == "reduction": - covered_reduction_pairs.add( - (edge.producer_nidx, edge.consumer_nidx) - ) + covered_reduction_pairs.add((edge.producer_nidx, edge.consumer_nidx)) for node in self.graph.nodes: rule = self.factor_rules.get(node) @@ -893,7 +898,10 @@ def _add_objective(self) -> None: spatial_sum = pulp.lpSum( self.y_vars[(s, m)] for s in valid_gids ) - self.prob += z >= y_r_m - spatial_sum, f"z_ure_lb_{z_id}" + self.prob += ( + z >= y_r_m - spatial_sum, + f"z_ure_lb_{z_id}", + ) terms.append(ar_extra * z) z_id += 1 else: @@ -1113,9 +1121,7 @@ def add_parameter_memory_constraint( b_vars: dict[frozenset[int], pulp.LpVariable] = {} for S in all_subsets: tag = "".join(str(m) for m in sorted(S)) if S else "R" - b_vars[S] = pulp.LpVariable( - f"mem_{nidx}_{tag}", cat="Binary" - ) + b_vars[S] = pulp.LpVariable(f"mem_{nidx}_{tag}", cat="Binary") # Exactly one subset is active per parameter. 
self.prob += ( @@ -1213,7 +1219,8 @@ def add_output_constraints( def get_solution(self, verbose: bool = False) -> dict[torch.fx.Node, OpSpec]: """Solve the factor ILP and reconstruct per-node OpSpecs.""" - solver = pulp.PULP_CBC_CMD(msg=verbose) + solver = pulp.HiGHS(msg=verbose) + # solver = pulp.PULP_CBC_CMD(msg=verbose) self.prob.solve(solver) if self.prob.status == -1: @@ -1298,13 +1305,11 @@ def get_solution(self, verbose: bool = False) -> dict[torch.fx.Node, OpSpec]: else: input_specs = None else: - input_specs = self._build_input_specs( - node, rule, nidx, assignment - ) or None + input_specs = ( + self._build_input_specs(node, rule, nidx, assignment) or None + ) - result[node] = OpSpec( - output_specs=output_specs, input_specs=input_specs - ) + result[node] = OpSpec(output_specs=output_specs, input_specs=input_specs) return result @@ -1355,9 +1360,7 @@ def _build_input_specs( if oi < len(tensor_args): arg_val = tensor_args[oi].meta.get("val") if isinstance(arg_val, torch.Tensor): - inp_tm = TensorMeta( - arg_val.shape, arg_val.stride(), arg_val.dtype - ) + inp_tm = TensorMeta(arg_val.shape, arg_val.stride(), arg_val.dtype) elif isinstance(arg_val, (tuple, list)): # Multi-output producer (e.g. getitem consuming SDPA). # Use the getitem index to find the correct tensor. 
@@ -1375,9 +1378,7 @@ def _build_input_specs( inp_tm = TensorMeta(v.shape, v.stride(), v.dtype) input_specs.append( - DTensorSpec( - self.mesh, tuple(inp_placements), tensor_meta=inp_tm - ) + DTensorSpec(self.mesh, tuple(inp_placements), tensor_meta=inp_tm) ) return input_specs @@ -1447,7 +1448,9 @@ def get_log(self, verbose: bool = False) -> str: s = self.get_stats() lines.append(f"Unique factors: {s['num_unique_factors']}") lines.append(f"Factor edges: {s['num_edges']}") - lines.append(f"ILP variables: {s['num_ilp_variables']} ({s['num_y_variables']} y + {s['num_z_variables']} z)") + lines.append( + f"ILP variables: {s['num_ilp_variables']} ({s['num_y_variables']} y + {s['num_z_variables']} z)" + ) lines.append(f"ILP constraints: {s['num_ilp_constraints']}") if verbose and self.prob.status == 1: From 6658641b8c92bda26b8697f23d93bdb6a8fa17a6 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 18 Feb 2026 15:32:26 +0000 Subject: [PATCH 19/22] Add timing to different parts of the solver --- autoparallel/optimize_sharding_new.py | 35 ++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/autoparallel/optimize_sharding_new.py b/autoparallel/optimize_sharding_new.py index ced79b5f..b3662623 100644 --- a/autoparallel/optimize_sharding_new.py +++ b/autoparallel/optimize_sharding_new.py @@ -387,6 +387,8 @@ def __init__( mesh: Any, rescale_grad_comm_cost_for_mp: float = 1.0, ) -> None: + import time as _time + self.gm = gm self.graph = gm.graph self.nodes = list(self.graph.nodes) @@ -395,6 +397,7 @@ def __init__( self.node_map: dict[torch.fx.Node, int] = { n: i for i, n in enumerate(self.nodes) } + self._timings: dict[str, float] = {} # -- Step 1: build multi-dim strategies (reuses existing DTensor rules) # NOTE: in a production implementation you would build strategies for a @@ -402,24 +405,32 @@ def __init__( # flat_mesh = mesh._flatten("flat") # For this POC we reuse the real mesh so that all existing op rules work # unchanged. 
The key savings come from the ILP reformulation below. + _t0 = _time.perf_counter() self.strats = self._build_sharding_metadata() + self._timings["build_strategies"] = _time.perf_counter() - _t0 # -- Step 2: extract factor rules from strategies. self.factor_rules: dict[torch.fx.Node, FactorRule] = {} + _t0 = _time.perf_counter() self._extract_all_factor_rules() + self._timings["extract_factor_rules"] = _time.perf_counter() - _t0 # -- Step 3: build factor edge graph (no merging; edges track relationships). self.factor_keys: dict[tuple[int, int], int] = {} # (node_idx, local) → gid self._next_gid = 0 self._factor_edges: list[FactorEdge] = [] self._disabled_reduction_gids: set[int] = set() + _t0 = _time.perf_counter() self._build_factor_graph() + self._timings["build_factor_graph"] = _time.perf_counter() - _t0 # -- Step 4: build ILP. self._cost_cache: dict[torch.fx.Node, float] = {} self.prob = pulp.LpProblem("AutoParallel_Factor", pulp.LpMinimize) self.y_vars: dict[tuple[int, int], pulp.LpVariable] = {} + _t0 = _time.perf_counter() self._build_ilp() + self._timings["build_ilp"] = _time.perf_counter() - _t0 # ----------------------------------------------------------------- # Step 1 — strategy building (mirrors ShardingOptimizer) @@ -1219,9 +1230,18 @@ def add_output_constraints( def get_solution(self, verbose: bool = False) -> dict[torch.fx.Node, OpSpec]: """Solve the factor ILP and reconstruct per-node OpSpecs.""" - solver = pulp.HiGHS(msg=verbose) - # solver = pulp.PULP_CBC_CMD(msg=verbose) + import time as _time + + solver = pulp.HiGHS( + msg=verbose, + options=[ + ("mip_rel_gap", 0.01), + ("mip_heuristic_effort", 0.2), + ], + ) + _t0 = _time.perf_counter() self.prob.solve(solver) + self._timings["solve"] = _time.perf_counter() - _t0 if self.prob.status == -1: diag = self._infeasibility_diagnostics() @@ -1430,7 +1450,7 @@ def get_stats(self) -> dict[str, Any]: n_factor_vars = len(self.y_vars) n_aux_vars = getattr(self, "_num_z_vars", 0) - return { + stats = { 
"num_graph_nodes": len(self.nodes), "num_unique_factors": num_unique_factors, "num_edges": len(self._factor_edges), @@ -1440,6 +1460,8 @@ def get_stats(self) -> dict[str, Any]: "num_ilp_constraints": len(self.prob.constraints), "mesh_shape": tuple(self.mesh.shape), } + stats["timings"] = dict(self._timings) + return stats def get_log(self, verbose: bool = False) -> str: """Human-readable summary.""" @@ -1453,6 +1475,13 @@ def get_log(self, verbose: bool = False) -> str: ) lines.append(f"ILP constraints: {s['num_ilp_constraints']}") + timings = s.get("timings", {}) + if timings: + lines.append("") + lines.append("Timings:") + for step, dt in timings.items(): + lines.append(f" {step:30s} {dt:.3f}s") + if verbose and self.prob.status == 1: # Build reverse lookup: gid → (node, factor, local_idx) gid_to_info: dict[int, tuple[torch.fx.Node, Factor, int]] = {} From 5e96d1cdb3419da975cb5c19b842b7fa6c36450c Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 18 Feb 2026 16:04:34 +0000 Subject: [PATCH 20/22] Add per-mesh-dim independent optimizer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Performance: The independent optimizer is the fastest at 0.48s (vs 3.26s original, 1.32s factor-based) — a 6.8x speedup over the original. ILP size: 3,724 variables / 1,782 constraints (vs 23,594 / 5,113 original). Each 1D sub-problem has ~1,862 vars and 891 constraints. Placement differences: The independent solver overwhelmingly chooses Shard(dim=0) on both mesh dimensions (i.e., 128-way data parallelism) rather than the expected DP+TP split. This is the expected behavior of the independent decomposition: each 1D solver independently sees that batch-dim sharding is cheapest (zero communication), without knowing the other solver already handles batch parallelism. From a pure per-dim cost perspective, Shard(0) on tp is strictly cheaper than Shard(1) or Shard(2) because it avoids all-gather/reduce-scatter at matmul boundaries. 
This is the fundamental tradeoff of the independent approach — cross-dim interactions (like "dp already handles batch, so tp should handle hidden dims") are lost. Two files were created/modified: - New: autoparallel/optimize_sharding_independent.py — IndependentShardingOptimizer class - Modified: examples/example_autoparallel_factor.py — added third optimizer comparison --- autoparallel/optimize_sharding_independent.py | 485 ++++++++++++++++++ examples/example_autoparallel_factor.py | 101 ++-- 2 files changed, 545 insertions(+), 41 deletions(-) create mode 100644 autoparallel/optimize_sharding_independent.py diff --git a/autoparallel/optimize_sharding_independent.py b/autoparallel/optimize_sharding_independent.py new file mode 100644 index 00000000..c21b7091 --- /dev/null +++ b/autoparallel/optimize_sharding_independent.py @@ -0,0 +1,485 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +""" +Independent per-mesh-dim sharding optimization. + +This module solves the sharding optimization problem by running the original +enumeration-based ILP (from :class:`ShardingOptimizer`) **independently on each +mesh dimension** using 1D sub-meshes, then combines the per-dim solutions into +multi-dimensional placements. + +Key advantages over the joint ILP: +- Strategy count per node is O(d+1) on a 1D mesh instead of O((d+1)^k) on a + k-dimensional mesh (where d = tensor dims, k = mesh dims). +- Each 1D ILP has exact redistribution costs and tight LP relaxation via + one-hot encoding (unlike the factor-based ILP which uses z-variables). +- Total solve time is roughly k × (time for 1D ILP), which is dramatically + faster than the joint formulation for k ≥ 2. + +Limitation: +- Mesh dimensions are assumed independent — cross-mesh-dim interactions (e.g. + joint memory constraints) are approximated per-dim. 
+""" + +from __future__ import annotations + +import operator +import time +from typing import Any, Optional + +import torch +import torch.fx +from torch.distributed._tensor.placement_types import TensorMeta +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import OpSpec +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) +from torch.utils._pytree import tree_flatten + +from .optimize_sharding import ShardingOptimizer + + +class IndependentShardingOptimizer: + """Sharding optimizer that solves per mesh dimension independently. + + Runs k independent :class:`ShardingOptimizer` instances on 1D sub-meshes + (one per mesh dimension), then combines the 1D solutions into multi-dim + placements on the full mesh. + + Public API mirrors :class:`ShardingOptimizer` and + :class:`FactorShardingOptimizer` so it can be used as a drop-in replacement. + + Parameters + ---------- + gm : torch.fx.GraphModule + Traced FX graph (joint forward + backward). + mesh : DeviceMesh + Target device mesh (may be multi-dimensional). + rescale_grad_comm_cost_for_mp : float + Scaling factor for gradient communication costs (mixed precision). + """ + + def __init__( + self, + gm: torch.fx.GraphModule, + mesh: Any, + rescale_grad_comm_cost_for_mp: float = 1.0, + ) -> None: + self.gm = gm + self.graph = gm.graph + self.mesh = mesh + self.rescale_grad_comm_cost_for_mp = rescale_grad_comm_cost_for_mp + self._timings: dict[str, float] = {} + self._solved = False + + # Create k 1D sub-solvers, one per mesh dimension. 
+ self._sub_solvers: list[ShardingOptimizer] = [] + for m in range(mesh.ndim): + sub_mesh = self._create_1d_mesh(m) + t0 = time.perf_counter() + sub_solver = ShardingOptimizer( + gm, sub_mesh, rescale_grad_comm_cost_for_mp + ) + self._timings[f"build_dim{m}"] = time.perf_counter() - t0 + self._sub_solvers.append(sub_solver) + + # ----------------------------------------------------------------- + # Sub-mesh creation + # ----------------------------------------------------------------- + + def _create_1d_mesh(self, mesh_dim: int) -> Any: + """Create a 1D DeviceMesh for the given mesh dimension. + + Uses the parent mesh's subscript operator to extract a proper 1D + sub-mesh that reuses existing process groups. + """ + if hasattr(self.mesh, "mesh_dim_names") and self.mesh.mesh_dim_names: + dim_name = self.mesh.mesh_dim_names[mesh_dim] + return self.mesh[dim_name] + + # Fallback: construct a 1D mesh directly with sequential device IDs. + dim_size = self.mesh.shape[mesh_dim] + return torch.distributed.device_mesh.init_device_mesh( + self.mesh.device_type, + (dim_size,), + ) + + # ----------------------------------------------------------------- + # Constraint methods (project multi-dim → 1D per sub-solver) + # ----------------------------------------------------------------- + + def add_input_constraints( + self, input_placements: Optional[list[Optional[tuple[Placement, ...]]]] = None + ) -> None: + """Add input constraints, projecting multi-dim placements to 1D per dim.""" + for m, solver in enumerate(self._sub_solvers): + if input_placements is None: + solver.add_input_constraints(None) + else: + projected = [ + (p[m],) if p is not None else None for p in input_placements + ] + solver.add_input_constraints(projected) + + def add_output_constraints( + self, output_placements: Optional[list[Optional[tuple[Placement, ...]]]] = None + ) -> None: + """Add output constraints, projecting multi-dim placements to 1D per dim.""" + for m, solver in enumerate(self._sub_solvers): 
+ if output_placements is None: + solver.add_output_constraints(None) + else: + projected = [ + (p[m],) if p is not None else None for p in output_placements + ] + solver.add_output_constraints(projected) + + def add_grad_param_constraints(self) -> None: + """Ensure parameters and their gradients have matching placements.""" + for solver in self._sub_solvers: + solver.add_grad_param_constraints() + + def add_parameter_memory_constraint( + self, memory_factor_low: float, memory_factor_high: float + ) -> None: + """Add parameter memory constraints per sub-solver. + + NOTE: This is an approximation — the true memory constraint depends on + the joint sharding across all mesh dims (product of per-dim shard + factors). Here we apply the constraint independently to each 1D + sub-solver, which may over- or under-constrain memory. + """ + # TODO: Implement a more accurate joint memory constraint by + # post-processing the combined solution or using a Lagrangian approach. + for solver in self._sub_solvers: + solver.add_parameter_memory_constraint( + memory_factor_low, memory_factor_high + ) + + def add_node_constraint( + self, + node: torch.fx.Node, + placement: Optional[tuple[Placement, ...]] = None, + constraint_name: Optional[str] = None, + ) -> None: + """Pin a node to a specific multi-dim placement.""" + if placement is None: + for solver in self._sub_solvers: + solver.add_node_constraint(node, None, constraint_name) + else: + for m, solver in enumerate(self._sub_solvers): + solver.add_node_constraint(node, (placement[m],), constraint_name) + + # ----------------------------------------------------------------- + # Solve + # ----------------------------------------------------------------- + + def get_solution(self, verbose: bool = False) -> dict[torch.fx.Node, OpSpec]: + """Solve each 1D sub-problem and combine into multi-dim OpSpecs.""" + # 1. Solve each 1D sub-problem. 
+ sub_solutions: list[dict[torch.fx.Node, OpSpec]] = [] + for m, solver in enumerate(self._sub_solvers): + t0 = time.perf_counter() + sol = solver.get_solution(verbose=verbose) + self._timings[f"solve_dim{m}"] = time.perf_counter() - t0 + sub_solutions.append(sol) + + self._solved = True + + # 2. Combine into multi-dim OpSpecs on the full mesh. + result: dict[torch.fx.Node, OpSpec] = {} + for node in self.graph.nodes: + if node.op == "output": + continue + + per_dim_specs = [ + sub_solutions[m].get(node) for m in range(self.mesh.ndim) + ] + if all(s is None for s in per_dim_specs): + continue + + combined = self._combine_specs(node, per_dim_specs) + if combined is not None: + result[node] = combined + + return result + + # ----------------------------------------------------------------- + # Solution combination + # ----------------------------------------------------------------- + + def _extract_1d_placements( + self, spec: OpSpec | None, kind: str = "output" + ) -> Placement: + """Extract the single 1D placement from a 1D OpSpec. + + Parameters + ---------- + spec : OpSpec or None + A 1D OpSpec (single placement per DTensorSpec). + kind : str + "output" to extract from output_specs, or "input_N" to extract + from input_specs[N]. + + Returns + ------- + Placement + The 1D placement (Shard, Replicate, or Partial). 
+ """ + if spec is None: + return Replicate() + + if kind == "output": + out = spec.output_specs + if isinstance(out, DTensorSpec): + return out.placements[0] + elif isinstance(out, (tuple, list)): + # Multi-output: return first non-None spec's placement + for s in out: + if isinstance(s, DTensorSpec): + return s.placements[0] + return Replicate() + elif kind.startswith("input_"): + idx = int(kind.split("_")[1]) + if spec.input_specs is not None and idx < len(spec.input_specs): + inp = spec.input_specs[idx] + if isinstance(inp, DTensorSpec): + return inp.placements[0] + return Replicate() + return Replicate() + + def _extract_1d_output_placements_tuple( + self, spec: OpSpec | None + ) -> tuple | Placement: + """Extract all 1D output placements, handling multi-output ops. + + Returns a single Placement for single-output ops, or a tuple of + Placements for multi-output ops. + """ + if spec is None: + return Replicate() + + out = spec.output_specs + if isinstance(out, DTensorSpec): + return out.placements[0] + elif isinstance(out, (tuple, list)): + return tuple( + s.placements[0] if isinstance(s, DTensorSpec) else None + for s in out + ) + return Replicate() + + def _combine_specs( + self, + node: torch.fx.Node, + per_dim_specs: list[OpSpec | None], + ) -> OpSpec | None: + """Combine k 1D OpSpecs into one multi-dim OpSpec on the full mesh.""" + val = node.meta.get("val") + k = self.mesh.ndim + + # --- Build output_specs --- + output_specs = self._build_combined_output_specs(node, val, per_dim_specs, k) + if output_specs is None: + return None + + # --- Build input_specs --- + input_specs = self._build_combined_input_specs(node, per_dim_specs, k) + + return OpSpec(output_specs=output_specs, input_specs=input_specs) + + def _build_combined_output_specs( + self, + node: torch.fx.Node, + val: Any, + per_dim_specs: list[OpSpec | None], + k: int, + ) -> DTensorSpec | tuple | None: + """Build combined multi-dim output specs from per-dim 1D specs.""" + if isinstance(val, 
torch.Tensor): + placements: list[Placement] = [] + for m in range(k): + p = self._extract_1d_placements(per_dim_specs[m], "output") + placements.append(p) + tensor_meta = TensorMeta(val.shape, val.stride(), val.dtype) + return DTensorSpec( + self.mesh, tuple(placements), tensor_meta=tensor_meta + ) + elif isinstance(val, (tuple, list)): + # Multi-output op: combine per-output placements across dims. + per_dim_out_placements = [] + for m in range(k): + p = self._extract_1d_output_placements_tuple(per_dim_specs[m]) + per_dim_out_placements.append(p) + + # Determine number of outputs from val + num_outputs = len(val) + specs = [] + for ri in range(num_outputs): + v = val[ri] + if isinstance(v, torch.Tensor): + plc_list: list[Placement] = [] + for m in range(k): + pdp = per_dim_out_placements[m] + if isinstance(pdp, tuple) and ri < len(pdp): + p = pdp[ri] if pdp[ri] is not None else Replicate() + elif isinstance(pdp, Placement): + p = pdp + else: + p = Replicate() + plc_list.append(p) + tm = TensorMeta(v.shape, v.stride(), v.dtype) + specs.append( + DTensorSpec(self.mesh, tuple(plc_list), tensor_meta=tm) + ) + else: + specs.append(None) + return tuple(specs) + else: + return None + + def _build_combined_input_specs( + self, + node: torch.fx.Node, + per_dim_specs: list[OpSpec | None], + k: int, + ) -> list[DTensorSpec] | None: + """Build combined multi-dim input specs from per-dim 1D specs.""" + # Determine number of inputs from any non-None sub-solver spec. 
+ num_inputs = 0 + for spec in per_dim_specs: + if spec is not None and spec.input_specs is not None: + num_inputs = len(spec.input_specs) + break + + if num_inputs == 0: + # Placeholders and get_attr: use output_specs as input_specs + if node.op in ("placeholder", "get_attr"): + out = self._build_combined_output_specs( + node, node.meta.get("val"), per_dim_specs, k + ) + if isinstance(out, DTensorSpec): + return [out] + return None + + # Get tensor input nodes for TensorMeta + flat_args, _ = tree_flatten(node.args) + tensor_args = [a for a in flat_args if isinstance(a, torch.fx.Node)] + + input_specs: list[DTensorSpec] = [] + for inp_idx in range(num_inputs): + placements: list[Placement] = [] + for m in range(k): + spec = per_dim_specs[m] + if ( + spec is not None + and spec.input_specs is not None + and inp_idx < len(spec.input_specs) + ): + inp = spec.input_specs[inp_idx] + if isinstance(inp, DTensorSpec): + placements.append(inp.placements[0]) + else: + placements.append(Replicate()) + else: + placements.append(Replicate()) + + # Get TensorMeta from corresponding input node. 
+ inp_tm = None + if inp_idx < len(tensor_args): + arg_val = tensor_args[inp_idx].meta.get("val") + if isinstance(arg_val, torch.Tensor): + inp_tm = TensorMeta( + arg_val.shape, arg_val.stride(), arg_val.dtype + ) + elif isinstance(arg_val, (tuple, list)): + if ( + node.target is operator.getitem + and inp_idx == 0 + and len(node.args) > 1 + and isinstance(node.args[1], int) + ): + idx = node.args[1] + if idx < len(arg_val) and isinstance( + arg_val[idx], torch.Tensor + ): + v = arg_val[idx] + inp_tm = TensorMeta(v.shape, v.stride(), v.dtype) + + input_specs.append( + DTensorSpec(self.mesh, tuple(placements), tensor_meta=inp_tm) + ) + + return input_specs + + # ----------------------------------------------------------------- + # Stats & logging + # ----------------------------------------------------------------- + + def get_stats(self) -> dict[str, Any]: + """Return aggregated ILP size statistics across all sub-solvers.""" + total_vars = sum(len(s.ds) for s in self._sub_solvers) + total_constraints = sum( + len(s.prob.constraints) for s in self._sub_solvers + ) + + stats: dict[str, Any] = { + "num_graph_nodes": len(list(self.graph.nodes)), + "num_ilp_variables": total_vars, + "num_ilp_constraints": total_constraints, + "mesh_shape": tuple(self.mesh.shape), + "num_sub_problems": self.mesh.ndim, + } + + per_dim: list[dict[str, Any]] = [] + for m, solver in enumerate(self._sub_solvers): + per_dim.append( + { + "mesh_dim": m, + "mesh_size": self.mesh.shape[m], + "num_ilp_variables": len(solver.ds), + "num_ilp_constraints": len(solver.prob.constraints), + } + ) + stats["per_dim"] = per_dim + stats["timings"] = dict(self._timings) + return stats + + def get_log(self, verbose: bool = False) -> str: + """Human-readable summary of the independent optimizer.""" + lines: list[str] = [] + lines.append("Independent per-mesh-dim ILP optimizer") + lines.append(f" Mesh shape: {tuple(self.mesh.shape)}") + lines.append(f" Sub-problems: {self.mesh.ndim}") + + stats = 
self.get_stats() + lines.append( + f" Total ILP variables: {stats['num_ilp_variables']:,}" + ) + lines.append( + f" Total ILP constraints: {stats['num_ilp_constraints']:,}" + ) + + for dim_stats in stats["per_dim"]: + m = dim_stats["mesh_dim"] + lines.append( + f" Dim {m} (size={dim_stats['mesh_size']}): " + f"{dim_stats['num_ilp_variables']:,} vars, " + f"{dim_stats['num_ilp_constraints']:,} constraints" + ) + + timings = stats.get("timings", {}) + if timings: + lines.append("") + lines.append(" Timings:") + for step, dt in timings.items(): + lines.append(f" {step:30s} {dt:.3f}s") + + return "\n".join(lines) diff --git a/examples/example_autoparallel_factor.py b/examples/example_autoparallel_factor.py index 78053bde..a2e16c7f 100644 --- a/examples/example_autoparallel_factor.py +++ b/examples/example_autoparallel_factor.py @@ -4,14 +4,14 @@ # LICENSE file in the root directory of this source tree. """ -Comparison of the original enumeration-based ILP vs the factor-based ILP. +Comparison of three sharding optimizers: + 1. Original enumeration-based ILP (ShardingOptimizer) + 2. Factor-based ILP (FactorShardingOptimizer) + 3. Independent per-mesh-dim ILP (IndependentShardingOptimizer) Uses the same transformer Block model as example_autoparallel.py, but instead -of running the full AutoParallel pipeline, it: - 1. Traces the model to obtain the FX graph. - 2. Runs the *original* ShardingOptimizer (enumeration-based). - 3. Runs the *factor-based* FactorShardingOptimizer on the same graph. - 4. Prints a side-by-side comparison of ILP sizes and solutions. +of running the full AutoParallel pipeline, it traces the model and runs all +three optimizers on the same FX graph for comparison. 
Usage: python examples/example_autoparallel_factor.py @@ -29,6 +29,7 @@ from autoparallel.api import AutoParallel from autoparallel.optimize_sharding import ShardingOptimizer from autoparallel.optimize_sharding_new import FactorShardingOptimizer +from autoparallel.optimize_sharding_independent import IndependentShardingOptimizer # --------------------------------------------------------------------------- # Model (same as example_autoparallel.py, minus activation checkpointing for @@ -75,7 +76,7 @@ def forward(self, x): # Setup # --------------------------------------------------------------------------- -world_size = 64 +world_size = 128 # 64 fake_store = FakeStore() torch.distributed.init_process_group( @@ -159,11 +160,29 @@ def input_fn(): # parallel_mod = autop.apply_placement(factor_solution) # ------------------------------------------------------------------ - # 3. Comparison + # 3. Independent per-mesh-dim optimizer # ------------------------------------------------------------------ - stats = factor_opt.get_stats() + print("\n" + "=" * 70) + print("Running INDEPENDENT per-mesh-dim IndependentShardingOptimizer") + print("=" * 70) + + t0 = time.perf_counter() + indep_opt = IndependentShardingOptimizer(gm, mesh) + indep_opt.add_grad_param_constraints() + indep_opt.add_input_constraints([x_sharding]) + indep_opt.add_output_constraints([x_sharding]) + indep_solution = indep_opt.get_solution(verbose=False) + t_indep = time.perf_counter() - t0 + print(f" Solve time: {t_indep:.2f}s") + print(indep_opt.get_log(verbose=True)) + + # ------------------------------------------------------------------ + # 4. 
Comparison + # ------------------------------------------------------------------ + stats = factor_opt.get_stats() orig_stats = orig_opt.get_stats() + indep_stats = indep_opt.get_stats() print("\n" + "=" * 70) print("COMPARISON") @@ -171,52 +190,52 @@ def input_fn(): print(f" Mesh shape: {tuple(mesh.shape)}") print(f" Graph nodes: {stats['num_graph_nodes']}") print() - print(f" Original ILP variables: {orig_stats['num_ilp_variables']:,}") - print(f" Factor ILP variables: {stats['num_ilp_variables']:,}") + print(f" Original ILP variables: {orig_stats['num_ilp_variables']:,}") + print(f" Factor ILP variables: {stats['num_ilp_variables']:,}") + print(f" Independent ILP variables: {indep_stats['num_ilp_variables']:,}") print() - print(f" Original ILP constraints:{orig_stats['num_ilp_constraints']:,}") - print(f" Factor ILP constraints: {stats['num_ilp_constraints']:,}") + print(f" Original ILP constraints: {orig_stats['num_ilp_constraints']:,}") + print(f" Factor ILP constraints: {stats['num_ilp_constraints']:,}") + print(f" Independent ILP constraints:{indep_stats['num_ilp_constraints']:,}") print() print(f" Unique factors: {stats['num_unique_factors']}") + print() + print(f" Original solve time: {t_orig:.2f}s") + print(f" Factor solve time: {t_factor:.2f}s") + print(f" Independent solve time: {t_indep:.2f}s") # ------------------------------------------------------------------ - # 4. Show per-node placement comparison + # 5. Show per-node placement comparison # ------------------------------------------------------------------ n_show = 100 print("\n" + "=" * 70) print(f"PER-NODE PLACEMENT COMPARISON (first {n_show} call_function nodes)") print("=" * 70) + def _get_placements(spec): + if spec is not None and hasattr(spec, "output_specs"): + os = spec.output_specs + if isinstance(os, DTensorSpec): + return tuple(os.placements) + elif isinstance(os, (list, tuple)) and os: + return tuple(os[0].placements) + return "?" 
+ call_fn_nodes = [n for n in gm.graph.nodes if n.op == "call_function"] for node in call_fn_nodes[:n_show]: - orig_spec = orig_solution.get(node) - factor_spec = factor_solution.get(node) + orig_plc = _get_placements(orig_solution.get(node)) + factor_plc = _get_placements(factor_solution.get(node)) + indep_plc = _get_placements(indep_solution.get(node)) - if orig_spec is not None and hasattr(orig_spec, "output_specs"): - os = orig_spec.output_specs - if isinstance(os, DTensorSpec): - orig_plc = tuple(os.placements) - elif isinstance(os, (list, tuple)) and os: - orig_plc = tuple(os[0].placements) - else: - orig_plc = "?" - else: - orig_plc = "?" - if factor_spec is not None and hasattr(factor_spec, "output_specs"): - os = factor_spec.output_specs - if isinstance(os, DTensorSpec): - factor_plc = tuple(os.placements) - elif isinstance(os, (list, tuple)) and os: - factor_plc = tuple(os[0].placements) - else: - factor_plc = "?" - else: - factor_plc = "?" - match = "OK" if str(orig_plc) == str(factor_plc) else "DIFF" + match_f = "OK" if str(orig_plc) == str(factor_plc) else "DIFF" + match_i = "OK" if str(orig_plc) == str(indep_plc) else "DIFF" op_name = str(node) - # Truncate long op names - if len(op_name) > 40: - op_name = op_name[:37] + "..." - print(f" [{match:4s}] {op_name:42s} orig={orig_plc} factor={factor_plc}") + if len(op_name) > 35: + op_name = op_name[:32] + "..." 
+ print( + f" {op_name:37s} orig={str(orig_plc):30s} " + f"factor[{match_f:4s}]={str(factor_plc):30s} " + f"indep[{match_i:4s}]={str(indep_plc)}" + ) print("\nDone.") From 01e3eaa7beeaeb0d76eab4056505d6c6b3f52bf9 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Thu, 5 Mar 2026 10:27:18 +0000 Subject: [PATCH 21/22] Switch between different optimizers --- autoparallel/api.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/autoparallel/api.py b/autoparallel/api.py index 6e262f29..6a170ccb 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -39,6 +39,7 @@ ) from .init_weights import hook_params_setters from .optimize_sharding import ShardingOptimizer +from .optimize_sharding_new import FactorShardingOptimizer from .shardings.placement_options import ( NumericsLogger, _get_device_from_mesh, @@ -300,12 +301,22 @@ def __enter__(self): # Tiebreak, favoring performing the comms in the largest # dtype rescale_grad_comm_cost_for_mp *= 1.1 - sharding_optimizer = ShardingOptimizer( - self.gm, - self.mesh, - rescale_grad_comm_cost_for_mp, - repeated_subgraphs=self.kwargs.get("repeated_subgraphs", False), - ) + + optim_type = 2 + match optim_type: + case 0: + sharding_optimizer = ShardingOptimizer( + self.gm, + self.mesh, + rescale_grad_comm_cost_for_mp, + repeated_subgraphs=self.kwargs.get("repeated_subgraphs", False), + ) + case 1: + sharding_optimizer = FactorShardingOptimizer(self.gm, self.mesh) + case 2: + from .optimize_sharding_independent import IndependentShardingOptimizer + sharding_optimizer = IndependentShardingOptimizer(self.gm, self.mesh) + # makes sharding of params and gradients the same sharding_optimizer.add_grad_param_constraints() From 21a75f2e9f1548228cc523d85309e6da610e41e8 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Thu, 5 Mar 2026 10:44:57 +0000 Subject: [PATCH 22/22] Simplify format_sharding_log and remove dead prob.status check --- autoparallel/api.py | 15 +++++++++++---- 
autoparallel/log_formatting.py | 22 +++++++++++++++++++++- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/autoparallel/api.py b/autoparallel/api.py index 6a170ccb..1081ef46 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -315,8 +315,8 @@ def __enter__(self): sharding_optimizer = FactorShardingOptimizer(self.gm, self.mesh) case 2: from .optimize_sharding_independent import IndependentShardingOptimizer - sharding_optimizer = IndependentShardingOptimizer(self.gm, self.mesh) + sharding_optimizer = IndependentShardingOptimizer(self.gm, self.mesh) # makes sharding of params and gradients the same sharding_optimizer.add_grad_param_constraints() @@ -464,6 +464,16 @@ def optimize_placement(self, verbose=True): if verbose: print(self.sharding_optimizer.get_log(verbose=True)) + from autoparallel.log_formatting import format_sharding_log + + print( + format_sharding_log( + graph=self.gm.graph, + sharding_placement=self.sharding_placement, + colored=False, + verbose=verbose, + ) + ) trace_structured( "artifact", @@ -474,9 +484,6 @@ def optimize_placement(self, verbose=True): payload_fn=lambda: self.sharding_optimizer.get_log(colored=False), ) - if self.sharding_optimizer.prob.status == -1: - raise RuntimeError("Didn't find solution") - return self.sharding_placement def _apply_placement_common(self, sharding_placement): diff --git a/autoparallel/log_formatting.py b/autoparallel/log_formatting.py index 64ce2c96..1dd4bcc5 100644 --- a/autoparallel/log_formatting.py +++ b/autoparallel/log_formatting.py @@ -18,10 +18,11 @@ def format_sharding_log( graph: torch.fx.Graph, - opt: dict[torch.fx.Node, list[dict[str, Any]]], + opt: dict[torch.fx.Node, list[dict[str, Any]]] | None = None, colored: bool = False, verbose: bool = False, violated_constraints_log: str = "", + sharding_placement: dict[torch.fx.Node, Any] | None = None, ) -> str: """ Format the sharding optimization results as annotated Python code. 
@@ -42,10 +43,29 @@ def format_sharding_log( colored: Whether to use ANSI color codes in the output. verbose: Whether to include verbose information (shapes, stack traces). violated_constraints_log: Optional string with violated constraints info. + sharding_placement: Dictionary mapping nodes to their OpSpec placements. + Used as an alternative to ``opt`` when detailed cost info is not + available. Costs will be reported as zero. Returns: A string containing the annotated Python code representation of the graph. """ + if opt is None and sharding_placement is None: + raise ValueError("Either 'opt' or 'sharding_placement' must be provided") + if opt is None: + assert sharding_placement is not None + opt = { + k: [ + { + "full_strat": v, + "comm_cost": 0, + "compute_cost": 0, + "sharding_transition_cost": 0, + "cost": 0, + } + ] + for k, v in sharding_placement.items() + } from torch.fx.graph import _color_fns, _identity nodes = list(graph.nodes)