From f51cd4dcaebe0f9697a295df3c3ee8c69ebb9abf Mon Sep 17 00:00:00 2001
From: Pian Pawakapan
Date: Fri, 27 Feb 2026 17:58:42 -0800
Subject: [PATCH] try smaller decomp table

---
 autoparallel/api.py                    | 63 +++++++------
 .../cost_models/compute_estimation.py  |  5 ++
 2 files changed, 28 insertions(+), 40 deletions(-)

diff --git a/autoparallel/api.py b/autoparallel/api.py
index 83e86564..dc348bbd 100644
--- a/autoparallel/api.py
+++ b/autoparallel/api.py
@@ -12,13 +12,10 @@
 import torch
 from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
-from torch._functorch.aot_autograd import (
-    aot_compile_joint_with_descriptors,
-    aot_export_joint_with_descriptors,
-    boxed_nop_preserve_node_meta,
-)
+from torch._functorch.aot_autograd import (aot_compile_joint_with_descriptors,
+                                           aot_export_joint_with_descriptors,
+                                           boxed_nop_preserve_node_meta)
 from torch._inductor.compile_fx import compile_fx_inner
-from torch._inductor.decomposition import select_decomp_table
 from torch._logging import trace_structured
 from torch._subclasses import FakeTensorMode
 from torch.distributed.fsdp import MixedPrecisionPolicy
 
@@ -28,22 +25,19 @@
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
 
 from .apply_sharding import apply_sharding_to_model
-from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast
+from .cast_parametrization import (apply_dtype_cast, canonicalize_mp,
+                                   set_dtype_cast)
 from .graph_passes.activation_checkpointing import ac_joint_pass
-from .graph_passes.graph_utils import (
-    _add_alias,
-    _replace_view_mm_view_with_einsum,
-    assert_has_no_collectives,
-    cleanup_graph,
-    update_joint_with_descriptors,
-)
+from .graph_passes.graph_utils import (_add_alias,
+                                       _replace_view_mm_view_with_einsum,
+                                       assert_has_no_collectives,
+                                       cleanup_graph,
+                                       update_joint_with_descriptors)
 from .init_weights import hook_params_setters
 from .optimize_sharding import ShardingOptimizer
-from .shardings.placement_options import (
-    NumericsLogger,
-    _get_device_from_mesh,
-    debug_boxed_nop_preserve_node_meta,
-)
+from .shardings.placement_options import (NumericsLogger,
+                                          _get_device_from_mesh,
+                                          debug_boxed_nop_preserve_node_meta)
 
 _APPLY_VIEW_MM_VIEW_PATTERN = False
 
@@ -111,28 +105,18 @@ def _assign_attr(
 
 
 def _get_decomp_table():
-    decomp_table = copy.copy(select_decomp_table())
-    # TODO: removing those as they cause missing DTensor propagation rules
-    decomp_table.pop(torch.ops.aten.full_like.default)
-    decomp_table.pop(torch.ops.aten.empty_like.default)
-    decomp_table.pop(torch.ops.aten.threshold_backward.default)
-    decomp_table.pop(torch.ops.aten.native_layer_norm.default)
-    decomp_table.pop(torch.ops.aten.embedding_dense_backward.default)
-    decomp_table.pop(torch.ops.aten.native_layer_norm_backward.default)
-    decomp_table.pop(torch.ops.aten._softmax_backward_data.default)
-    decomp_table.pop(torch.ops.aten._softmax.default)
-    decomp_table.pop(torch.ops.aten.stack.default)
-
-    # decompose addmm to allow for TP on mm
-    decomp_table.pop(torch.ops.aten.addmm.default)
-
     def addmm_decomp(self, mat1, mat2, beta=1, alpha=1):
         return self + mat1 @ mat2
 
-    decomp_table[torch.ops.aten.addmm.default] = addmm_decomp
-    # decomp_table = None
+    def detach_decomp(x):
+        return x
 
-    return decomp_table
+    return {
+        # Decompose addmm into mm + add to allow for TP on the mm.
+        torch.ops.aten.addmm.default: addmm_decomp,
+        # Remove detach nodes inserted by aot_autograd on saved tensors.
+        torch.ops.aten.detach.default: detach_decomp,
+    }
 
 
 def move_to_fake(model: torch.nn.Module, mode: FakeTensorMode, device: torch.device):
@@ -173,9 +157,8 @@ def _move_to_fake(module, k, device, parameter=True):
 
 @contextmanager
 def enable_local_map_wrapping():
-    from torch._dynamo.variables.higher_order_ops import (
-        LocalMapWrappedHigherOrderVariable as vt_cls,
-    )
+    from torch._dynamo.variables.higher_order_ops import \
+        LocalMapWrappedHigherOrderVariable as vt_cls
    from torch._higher_order_ops import local_map as local_map_module
 
     with vt_cls.enable(), local_map_module.defer_inlining():
diff --git a/autoparallel/cost_models/compute_estimation.py b/autoparallel/cost_models/compute_estimation.py
index d0b77bbb..fdb843d6 100644
--- a/autoparallel/cost_models/compute_estimation.py
+++ b/autoparallel/cost_models/compute_estimation.py
@@ -344,6 +344,11 @@ def _has_zero_cost(node):
     if node.target.is_view:
         return True
 
+    # _unsafe_view is not tagged as a view op but is semantically a reshape;
+    # its non-tensor shape args become invalid after sharding so skip costing.
+    if node.target == torch.ops.aten._unsafe_view.default:
+        return True
+
     return False
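
A note on the one decomposition the smaller table keeps: addmm_decomp drops
beta and alpha, so it matches aten.addmm only at the defaults beta=1, alpha=1
(non-default scaling would change numerics). A minimal eager-mode sanity
check, standalone and not part of the patch:

    import torch

    bias = torch.randn(4)
    mat1, mat2 = torch.randn(3, 5), torch.randn(5, 4)

    # With beta=1 and alpha=1, addmm(self, mat1, mat2) == self + mat1 @ mat2,
    # which is exactly what addmm_decomp returns.
    torch.testing.assert_close(
        torch.ops.aten.addmm.default(bias, mat1, mat2),
        bias + mat1 @ mat2,
    )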
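On the compute_estimation.py hunk: the existing early return keys off the
is_view property of the node's OpOverload target, and aten._unsafe_view is
deliberately not registered as a view op (it does not alias its input for
autograd purposes), which is why it needs its own case. A small illustration,
assuming a PyTorch build where OpOverload exposes is_view (the same property
the surrounding code already relies on):

    import torch

    # aten.view is registered as an aliasing (view) op, so the existing
    # `node.target.is_view` check already returns True for it.
    assert torch.ops.aten.view.default.is_view

    # aten._unsafe_view computes the same reshape but carries no view/alias
    # annotation, so _has_zero_cost needs the explicit equality check.
    assert not torch.ops.aten._unsafe_view.default.is_view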