Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 29 additions & 11 deletions autoparallel/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from .init_weights import hook_params_setters
from .optimize_sharding import ShardingOptimizer
from .optimize_sharding_new import FactorShardingOptimizer
from .shardings.placement_options import (
NumericsLogger,
_get_device_from_mesh,
Expand Down Expand Up @@ -300,12 +301,22 @@ def __enter__(self):
# Tiebreak, favoring performing the comms in the largest
# dtype
rescale_grad_comm_cost_for_mp *= 1.1
sharding_optimizer = ShardingOptimizer(
self.gm,
self.mesh,
rescale_grad_comm_cost_for_mp,
repeated_subgraphs=self.kwargs.get("repeated_subgraphs", False),
)

optim_type = 2
match optim_type:
case 0:
sharding_optimizer = ShardingOptimizer(
self.gm,
self.mesh,
rescale_grad_comm_cost_for_mp,
repeated_subgraphs=self.kwargs.get("repeated_subgraphs", False),
)
case 1:
sharding_optimizer = FactorShardingOptimizer(self.gm, self.mesh)
case 2:
from .optimize_sharding_independent import IndependentShardingOptimizer

sharding_optimizer = IndependentShardingOptimizer(self.gm, self.mesh)

# makes sharding of params and gradients the same
sharding_optimizer.add_grad_param_constraints()
Expand Down Expand Up @@ -425,7 +436,7 @@ def add_input_constraints(self, constraints):
self._assert_entered()

assert self.input_constraints is None, "Input constraints have already been set"
self.sharding_optimizer.add_sharded_input_constraint(constraints)
self.sharding_optimizer.add_input_constraints(constraints)
self.input_constraints = constraints

def add_output_constraints(self, constraints):
Expand All @@ -435,7 +446,7 @@ def add_output_constraints(self, constraints):
self.output_constraints is None
), "Output constraints have already been set"
# forces sharding of fwd output to be S(0) on first dimension and R on others
self.sharding_optimizer.add_sharded_output_constraint(constraints)
self.sharding_optimizer.add_output_constraints(constraints)
self.output_constraints = constraints

def optimize_placement(self, verbose=True):
Expand All @@ -453,6 +464,16 @@ def optimize_placement(self, verbose=True):

if verbose:
print(self.sharding_optimizer.get_log(verbose=True))
from autoparallel.log_formatting import format_sharding_log

print(
format_sharding_log(
graph=self.gm.graph,
sharding_placement=self.sharding_placement,
colored=False,
verbose=verbose,
)
)

trace_structured(
"artifact",
Expand All @@ -463,9 +484,6 @@ def optimize_placement(self, verbose=True):
payload_fn=lambda: self.sharding_optimizer.get_log(colored=False),
)

if self.sharding_optimizer.prob.status == -1:
raise RuntimeError("Didn't find solution")

return self.sharding_placement

def _apply_placement_common(self, sharding_placement):
Expand Down
22 changes: 21 additions & 1 deletion autoparallel/log_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@

def format_sharding_log(
graph: torch.fx.Graph,
opt: dict[torch.fx.Node, list[dict[str, Any]]],
opt: dict[torch.fx.Node, list[dict[str, Any]]] | None = None,
colored: bool = False,
verbose: bool = False,
violated_constraints_log: str = "",
sharding_placement: dict[torch.fx.Node, Any] | None = None,
) -> str:
"""
Format the sharding optimization results as annotated Python code.
Expand All @@ -42,10 +43,29 @@ def format_sharding_log(
colored: Whether to use ANSI color codes in the output.
verbose: Whether to include verbose information (shapes, stack traces).
violated_constraints_log: Optional string with violated constraints info.
sharding_placement: Dictionary mapping nodes to their OpSpec placements.
Used as an alternative to ``opt`` when detailed cost info is not
available. Costs will be reported as zero.

Returns:
A string containing the annotated Python code representation of the graph.
"""
if opt is None and sharding_placement is None:
raise ValueError("Either 'opt' or 'sharding_placement' must be provided")
if opt is None:
assert sharding_placement is not None
opt = {
k: [
{
"full_strat": v,
"comm_cost": 0,
"compute_cost": 0,
"sharding_transition_cost": 0,
"cost": 0,
}
]
for k, v in sharding_placement.items()
}
from torch.fx.graph import _color_fns, _identity

nodes = list(graph.nodes)
Expand Down
15 changes: 13 additions & 2 deletions autoparallel/optimize_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ def add_node_constraint(self, node, placement=None, constraint_name=None):
constraint_name=constraint_name,
)

def add_sharded_input_constraint(
def add_input_constraints(
self, input_placements: Optional[list[Optional[tuple[Placement, ...]]]] = None
):
"""
Expand Down Expand Up @@ -797,7 +797,7 @@ def add_sharded_input_constraint(
"the inputs before tracing to remove aliasing."
)

def add_sharded_output_constraint(self, output_placements=None):
def add_output_constraints(self, output_placements=None):
"""
USER CONSTRAINTS (Category 6a): Output placement constraints.
Force specific placements for output nodes and their corresponding gradient outputs.
Expand Down Expand Up @@ -841,6 +841,17 @@ def add_sharded_output_constraint(self, output_placements=None):
"them from the graph to avoid aliasing."
)

def get_stats(self) -> dict:
    """Summarize the size of the ILP problem as a plain dictionary.

    Returns:
        A dict with the number of FX graph nodes, the number of ILP
        decision variables (one entry per key in ``self.ds``), the number
        of ILP constraints registered on ``self.prob``, and the device
        mesh shape as a tuple.
    """
    # Inline the counts directly into the result; no intermediate state.
    return {
        "num_graph_nodes": len(list(self.graph.nodes)),
        "num_ilp_variables": len(self.ds),
        "num_ilp_constraints": len(self.prob.constraints),
        "mesh_shape": tuple(self.mesh.shape),
    }

def validate(self):
for node in self.graph.nodes:
if node.op != "call_function":
Expand Down
Loading
Loading