Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 23 additions & 40 deletions autoparallel/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,10 @@

import torch
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._functorch.aot_autograd import (
aot_compile_joint_with_descriptors,
aot_export_joint_with_descriptors,
boxed_nop_preserve_node_meta,
)
from torch._functorch.aot_autograd import (aot_compile_joint_with_descriptors,
aot_export_joint_with_descriptors,
boxed_nop_preserve_node_meta)
from torch._inductor.compile_fx import compile_fx_inner
from torch._inductor.decomposition import select_decomp_table
from torch._logging import trace_structured
from torch._subclasses import FakeTensorMode
from torch.distributed.fsdp import MixedPrecisionPolicy
Expand All @@ -28,22 +25,19 @@
from torch.fx.experimental.symbolic_shapes import ShapeEnv

from .apply_sharding import apply_sharding_to_model
from .cast_parametrization import apply_dtype_cast, canonicalize_mp, set_dtype_cast
from .cast_parametrization import (apply_dtype_cast, canonicalize_mp,
set_dtype_cast)
from .graph_passes.activation_checkpointing import ac_joint_pass
from .graph_passes.graph_utils import (
_add_alias,
_replace_view_mm_view_with_einsum,
assert_has_no_collectives,
cleanup_graph,
update_joint_with_descriptors,
)
from .graph_passes.graph_utils import (_add_alias,
_replace_view_mm_view_with_einsum,
assert_has_no_collectives,
cleanup_graph,
update_joint_with_descriptors)
from .init_weights import hook_params_setters
from .optimize_sharding import ShardingOptimizer
from .shardings.placement_options import (
NumericsLogger,
_get_device_from_mesh,
debug_boxed_nop_preserve_node_meta,
)
from .shardings.placement_options import (NumericsLogger,
_get_device_from_mesh,
debug_boxed_nop_preserve_node_meta)

_APPLY_VIEW_MM_VIEW_PATTERN = False

Expand Down Expand Up @@ -111,28 +105,18 @@ def _assign_attr(


def _get_decomp_table():
decomp_table = copy.copy(select_decomp_table())
# TODO: removing those as they cause missing DTensor propagation rules
decomp_table.pop(torch.ops.aten.full_like.default)
decomp_table.pop(torch.ops.aten.empty_like.default)
decomp_table.pop(torch.ops.aten.threshold_backward.default)
decomp_table.pop(torch.ops.aten.native_layer_norm.default)
decomp_table.pop(torch.ops.aten.embedding_dense_backward.default)
decomp_table.pop(torch.ops.aten.native_layer_norm_backward.default)
decomp_table.pop(torch.ops.aten._softmax_backward_data.default)
decomp_table.pop(torch.ops.aten._softmax.default)
decomp_table.pop(torch.ops.aten.stack.default)

# decompose addmm to allow for TP on mm
decomp_table.pop(torch.ops.aten.addmm.default)

def addmm_decomp(self, mat1, mat2, beta=1, alpha=1):
return self + mat1 @ mat2

decomp_table[torch.ops.aten.addmm.default] = addmm_decomp
# decomp_table = None
def detach_decomp(x):
return x

return decomp_table
return {
# Decompose addmm into mm + add to allow for TP on the mm.
torch.ops.aten.addmm.default: addmm_decomp,
# Remove detach nodes inserted by aot_autograd on saved tensors.
torch.ops.aten.detach.default: detach_decomp,
}


def move_to_fake(model: torch.nn.Module, mode: FakeTensorMode, device: torch.device):
Expand Down Expand Up @@ -173,9 +157,8 @@ def _move_to_fake(module, k, device, parameter=True):

@contextmanager
def enable_local_map_wrapping():
from torch._dynamo.variables.higher_order_ops import (
LocalMapWrappedHigherOrderVariable as vt_cls,
)
from torch._dynamo.variables.higher_order_ops import \
LocalMapWrappedHigherOrderVariable as vt_cls
from torch._higher_order_ops import local_map as local_map_module

with vt_cls.enable(), local_map_module.defer_inlining():
Expand Down
5 changes: 5 additions & 0 deletions autoparallel/cost_models/compute_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,11 @@ def _has_zero_cost(node):
if node.target.is_view:
return True

# _unsafe_view is not tagged as a view op but is semantically a reshape;
# its non-tensor shape args become invalid after sharding so skip costing.
if node.target == torch.ops.aten._unsafe_view.default:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be fixed in PyTorch? I'd expect `_unsafe_view` to have the `is_view` tag set.

Copy link
Contributor Author

@pianpwk pianpwk Mar 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, this is intentional? This flag prevents AOTAutograd from doing all the view replay needed for functionalization.

From https://github.com/pytorch/pytorch/blob/960d2b693a031966158a7d7f5f6324cf15361f8f/aten/src/ATen/native/TensorShape.cpp#L4080-L4090:

// NOTE [ Unsafe View ]
// _unsafe_view() differs from view() in that the returned tensor isn't treated
// as a view for the purposes of automatic differentiation. (It's not listed in
// VIEW_FUNCTIONS in gen_inplace_or_view_type.py).  It's only safe to use if the
// `self` tensor is temporary. For example, the viewed tensor here (a + b) is
// discarded immediately after viewing:
//
//  res = at::_unsafe_view(a + b, size);
//
// This is a hack because in-place operations on tensors treated like views
// can be much more expensive than the same operations on non-view tensors.

But I think for autoparallel you want the "is_view" (zero-cost) semantics.

return True

return False


Expand Down
Loading