From 3cbad614c59ae949c2160260791e97d7eabcf18c Mon Sep 17 00:00:00 2001 From: Sebastian Larsson Date: Thu, 12 Feb 2026 13:32:32 +0100 Subject: [PATCH] Arm backend: Rewrite le/lt to ge/gt and simplify visitors Add a pass that rewrites le/lt to ge/gt with swapped inputs. Move eq/ge/gt to SimpleNodeVisitor and drop comparison visitor. Create test for rewrite_le_lt_to_ge_gt_pass. Change-Id: I6c79fb009ff47a1195ae47f417c2e66fa16147bb Signed-off-by: Sebastian Larsson --- backends/arm/_passes/__init__.py | 1 + backends/arm/_passes/arm_pass_manager.py | 2 + .../arm/_passes/decompose_div_tensor_mode.py | 10 ++-- .../_passes/rewrite_le_lt_to_ge_gt_pass.py | 42 +++++++++++++ backends/arm/operators/__init__.py | 2 - backends/arm/operators/op_eq.py | 59 +++++-------------- backends/arm/operators/op_ge.py | 59 +++++-------------- backends/arm/operators/op_gt.py | 59 +++++-------------- backends/arm/operators/op_le.py | 58 ------------------ backends/arm/operators/op_lt.py | 58 ------------------ backends/arm/operators/simple_node_visitor.py | 6 ++ .../test_rewrite_le_lt_to_ge_gt_pass.py | 51 ++++++++++++++++ 12 files changed, 155 insertions(+), 252 deletions(-) create mode 100644 backends/arm/_passes/rewrite_le_lt_to_ge_gt_pass.py delete mode 100644 backends/arm/operators/op_le.py delete mode 100644 backends/arm/operators/op_lt.py create mode 100644 backends/arm/test/passes/test_rewrite_le_lt_to_ge_gt_pass.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 898073ae4aa..0350af3e6ad 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -132,6 +132,7 @@ ) from .rewrite_conv_pass import RewriteConvPass # noqa from .rewrite_index_put_pass import RewriteIndexPutPass # noqa +from .rewrite_le_lt_to_ge_gt_pass import RewriteLeLtToGeGtPass # noqa from .rewrite_matmul import RewriteMatmulPass # noqa from .rewrite_upsample import RewriteUpsamplePass # noqa from .scalars_to_attribute_pass import ScalarsToAttributePass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 7642f2d6670..574b5be4e1e 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -117,6 +117,7 @@ RewriteBoolToFp32CastViaInt8Pass, RewriteConvPass, RewriteIndexPutPass, + RewriteLeLtToGeGtPass, RewriteMatmulPass, RewriteUpsamplePass, ScalarsToAttributePass, @@ -310,6 +311,7 @@ def _tosa_pipeline( self.add_passes( [ ReplaceScalarWithTensorByProfilePass(), + RewriteLeLtToGeGtPass(), ConvertFullLikeToFullPass(), MatchArgDtypePass(), UnsqueezeScalarPlaceholdersPass(exported_program), diff --git a/backends/arm/_passes/decompose_div_tensor_mode.py b/backends/arm/_passes/decompose_div_tensor_mode.py index 2c16853c27a..774557b816f 100644 --- a/backends/arm/_passes/decompose_div_tensor_mode.py +++ b/backends/arm/_passes/decompose_div_tensor_mode.py @@ -20,7 +20,7 @@ "floor": exir_ops.edge.aten.floor.default, "ceil": exir_ops.edge.aten.ceil.default, "full": exir_ops.edge.aten.full.default, - "lt": exir_ops.edge.aten.lt.Tensor, + "gt": exir_ops.edge.aten.gt.Tensor, "where": exir_ops.edge.aten.where.self, } @@ -29,7 +29,7 @@ "floor": torch.ops.aten.floor.default, "ceil": torch.ops.aten.ceil.default, "full": torch.ops.aten.full.default, - "lt": torch.ops.aten.lt.Tensor, + "gt": torch.ops.aten.gt.Tensor, "where": torch.ops.aten.where.self, } @@ -87,11 +87,13 @@ def call_operator(self, op, args, kwargs, meta): meta=meta, updated=True, ) - lt0 = super().call_operator(opset["lt"], (q, zero), {}, meta, updated=True) + is_neg = super().call_operator( + opset["gt"], (zero, q), {}, meta, updated=True + ) ceilq = super().call_operator(opset["ceil"], (q,), {}, meta, updated=True) floorq = super().call_operator(opset["floor"], (q,), {}, meta, updated=True) return super().call_operator( - opset["where"], (lt0, ceilq, floorq), {}, meta, updated=True + opset["where"], (is_neg, ceilq, floorq), {}, meta, updated=True ) raise RuntimeError( diff --git a/backends/arm/_passes/rewrite_le_lt_to_ge_gt_pass.py b/backends/arm/_passes/rewrite_le_lt_to_ge_gt_pass.py new file mode 100644 index 00000000000..9119567b7aa --- /dev/null +++ b/backends/arm/_passes/rewrite_le_lt_to_ge_gt_pass.py @@ -0,0 +1,42 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Set, Type + +import torch + +from executorch.backends.arm._passes import ArmPass +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + +OP_MAP = { + exir_ops.edge.aten.le.Tensor: exir_ops.edge.aten.ge.Tensor, + exir_ops.edge.aten.lt.Tensor: exir_ops.edge.aten.gt.Tensor, + torch.ops.aten.le.Tensor: torch.ops.aten.ge.Tensor, + torch.ops.aten.lt.Tensor: torch.ops.aten.gt.Tensor, +} + + +class RewriteLeLtToGeGtPass(ArmPass): + """Rewrite le/lt into ge/gt with swapped inputs.""" + + _passes_required_after: Set[Type[ExportPass]] = set() + + def call_operator(self, op, args, kwargs, meta): + if not self.allowed_to_transform(meta): + return super().call_operator(op, args, kwargs, meta) + + target_op = OP_MAP.get(op) + if target_op is None: + return super().call_operator(op, args, kwargs, meta) + + lhs, rhs = args + return super().call_operator( + target_op, + (rhs, lhs), + kwargs, + meta, + updated=True, + ) diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 8ccc315e783..4069ff1dc74 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -32,10 +32,8 @@ op_ge, op_gt, op_index_tensor, - op_le, op_log, op_logical_not, - op_lt, op_max_pool2d, op_maximum, op_minimum, diff --git a/backends/arm/operators/op_eq.py b/backends/arm/operators/op_eq.py index 4f414a796dc..7d5472d07ed 100644 --- a/backends/arm/operators/op_eq.py +++ b/backends/arm/operators/op_eq.py @@ -3,56 +3,27 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - -from typing import Any, List - import tosa_serializer as ts -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, +from executorch.backends.arm.operators.node_visitor import register_node_visitor +from executorch.backends.arm.operators.simple_node_visitor import ( + SimpleNodeVisitor, + SimpleNodeVisitorConfig, ) -from executorch.backends.arm.operators.operator_validation_utils import ( - validate_num_inputs, - validate_same_dtype, - validate_valid_dtype, -) -from executorch.backends.arm.tosa.mapping import TosaArg - -from torch.fx import Node @register_node_visitor -class EqualVisitor(NodeVisitor): +class EqualVisitor(SimpleNodeVisitor): target = "aten.eq.Tensor" - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - validate_num_inputs(self.target, inputs, 2) - validate_same_dtype(self.target, inputs, ts) - validate_valid_dtype( - self.target, - inputs, - [ts.DType.INT32, ts.DType.FP16, ts.DType.FP32, ts.DType.BF16], - self.tosa_spec, - ) - validate_valid_dtype(self.target, output, ts.DType.BOOL, self.tosa_spec) - - attr = ts.TosaSerializerAttribute() - attr.EqualAttribute() - self._serialize_operator( - node, - tosa_graph, - ts.Op.EQUAL, - [inputs[0].name, inputs[1].name], - [output.name], - attr, + @classmethod + def get_config(cls) -> SimpleNodeVisitorConfig: + return SimpleNodeVisitorConfig( + tosa_op=ts.Op.EQUAL, + attr_method="EqualAttribute", + num_inputs=2, + input_dtypes=[ts.DType.INT32, ts.DType.FP16, ts.DType.FP32, ts.DType.BF16], + output_dtypes=[ts.DType.BOOL], + same_dtype_with_output=False, + dtype_check_inputs_only=True, ) diff --git a/backends/arm/operators/op_ge.py b/backends/arm/operators/op_ge.py index 7f0527e8a0c..aec8dc96044 100644 --- a/backends/arm/operators/op_ge.py +++ b/backends/arm/operators/op_ge.py @@ -3,56 +3,29 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - -from typing import Any, List - import tosa_serializer as ts -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.operators.operator_validation_utils import ( - validate_num_inputs, - validate_same_dtype, - validate_valid_dtype, +from executorch.backends.arm.operators.node_visitor import register_node_visitor +from executorch.backends.arm.operators.simple_node_visitor import ( + SimpleNodeVisitor, + SimpleNodeVisitorConfig, ) -from executorch.backends.arm.tosa.mapping import TosaArg -from torch.fx import Node +COMPARE_INPUT_DTYPES = [ts.DType.INT32, ts.DType.FP32, ts.DType.BF16, ts.DType.FP16] @register_node_visitor -class GreaterEqualVisitor(NodeVisitor): +class GreaterEqualVisitor(SimpleNodeVisitor): target = "aten.ge.Tensor" - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - validate_num_inputs(self.target, inputs, 2) - validate_same_dtype(self.target, inputs, ts) - validate_valid_dtype( - self.target, - inputs, - [ts.DType.INT32, ts.DType.FP32, ts.DType.BF16, ts.DType.FP16], - self.tosa_spec, - ) - validate_valid_dtype(self.target, output, ts.DType.BOOL, self.tosa_spec) - - attr = ts.TosaSerializerAttribute() - attr.GreaterEqualAttribute() - self._serialize_operator( - node, - tosa_graph, - ts.Op.GREATER_EQUAL, - [inputs[0].name, inputs[1].name], - [output.name], - attr, + @classmethod + def get_config(cls) -> SimpleNodeVisitorConfig: + return SimpleNodeVisitorConfig( + tosa_op=ts.Op.GREATER_EQUAL, + attr_method="GreaterEqualAttribute", + num_inputs=2, + input_dtypes=COMPARE_INPUT_DTYPES, + output_dtypes=[ts.DType.BOOL], + same_dtype_with_output=False, + dtype_check_inputs_only=True, ) diff --git a/backends/arm/operators/op_gt.py b/backends/arm/operators/op_gt.py index 7c3326b6e8f..f7b05889d1d 100644 --- a/backends/arm/operators/op_gt.py +++ b/backends/arm/operators/op_gt.py @@ -3,56 +3,29 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - -from typing import Any, List - import tosa_serializer as ts -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.operators.operator_validation_utils import ( - validate_num_inputs, - validate_same_dtype, - validate_valid_dtype, +from executorch.backends.arm.operators.node_visitor import register_node_visitor +from executorch.backends.arm.operators.simple_node_visitor import ( + SimpleNodeVisitor, + SimpleNodeVisitorConfig, ) -from executorch.backends.arm.tosa.mapping import TosaArg -from torch.fx import Node +COMPARE_INPUT_DTYPES = [ts.DType.INT32, ts.DType.FP32, ts.DType.BF16, ts.DType.FP16] @register_node_visitor -class GreaterThanVisitor(NodeVisitor): +class GreaterThanVisitor(SimpleNodeVisitor): target = "aten.gt.Tensor" - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - validate_num_inputs(self.target, inputs, 2) - validate_same_dtype(self.target, inputs, ts) - validate_valid_dtype( - self.target, - inputs, - [ts.DType.INT32, ts.DType.FP32, ts.DType.BF16, ts.DType.FP16], - self.tosa_spec, - ) - validate_valid_dtype(self.target, output, ts.DType.BOOL, self.tosa_spec) - - attr = ts.TosaSerializerAttribute() - attr.GreaterAttribute() - self._serialize_operator( - node, - tosa_graph, - ts.Op.GREATER, - [inputs[0].name, inputs[1].name], - [output.name], - attr, + @classmethod + def get_config(cls) -> SimpleNodeVisitorConfig: + return SimpleNodeVisitorConfig( + tosa_op=ts.Op.GREATER, + attr_method="GreaterAttribute", + num_inputs=2, + input_dtypes=COMPARE_INPUT_DTYPES, + output_dtypes=[ts.DType.BOOL], + same_dtype_with_output=False, + dtype_check_inputs_only=True, ) diff --git a/backends/arm/operators/op_le.py b/backends/arm/operators/op_le.py deleted file mode 100644 index ca71b636cda..00000000000 --- a/backends/arm/operators/op_le.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2025-2026 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -from typing import Any, List - -import tosa_serializer as ts - -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.operators.operator_validation_utils import ( - validate_num_inputs, - validate_same_dtype, - validate_valid_dtype, -) -from executorch.backends.arm.tosa.mapping import TosaArg - -from torch.fx import Node - - -@register_node_visitor -class LessEqualVisitor(NodeVisitor): - target = "aten.le.Tensor" - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - validate_num_inputs(self.target, inputs, 2) - validate_same_dtype(self.target, inputs, ts) - validate_valid_dtype( - self.target, - inputs, - [ts.DType.INT32, ts.DType.FP16, ts.DType.FP32], - self.tosa_spec, - ) - validate_valid_dtype(self.target, output, ts.DType.BOOL, self.tosa_spec) - - attr = ts.TosaSerializerAttribute() - attr.GreaterEqualAttribute() - self._serialize_operator( - node, - tosa_graph, - ts.Op.GREATER_EQUAL, - [inputs[1].name, inputs[0].name], - [output.name], - attr, - ) diff --git a/backends/arm/operators/op_lt.py b/backends/arm/operators/op_lt.py deleted file mode 100644 index b3c19aaaece..00000000000 --- a/backends/arm/operators/op_lt.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2025-2026 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -from typing import Any, List - -import tosa_serializer as ts - -from executorch.backends.arm.operators.node_visitor import ( - NodeVisitor, - register_node_visitor, -) -from executorch.backends.arm.operators.operator_validation_utils import ( - validate_num_inputs, - validate_same_dtype, - validate_valid_dtype, -) -from executorch.backends.arm.tosa.mapping import TosaArg - -from torch.fx import Node - - -@register_node_visitor -class LessThanVisitor(NodeVisitor): - target = "aten.lt.Tensor" - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - validate_num_inputs(self.target, inputs, 2) - validate_same_dtype(self.target, inputs, ts) - validate_valid_dtype( - self.target, - inputs, - [ts.DType.INT32, ts.DType.FP16, ts.DType.FP32], - self.tosa_spec, - ) - validate_valid_dtype(self.target, output, ts.DType.BOOL, self.tosa_spec) - - attr = ts.TosaSerializerAttribute() - attr.GreaterAttribute() - self._serialize_operator( - node, - tosa_graph, - ts.Op.GREATER, - [inputs[1].name, inputs[0].name], - [output.name], - attr, - ) diff --git a/backends/arm/operators/simple_node_visitor.py b/backends/arm/operators/simple_node_visitor.py index 825143b2226..ad44ffe3d0a 100644 --- a/backends/arm/operators/simple_node_visitor.py +++ b/backends/arm/operators/simple_node_visitor.py @@ -23,6 +23,9 @@ class SimpleNodeVisitorConfig: num_inputs: int | List[int] input_dtypes: List[Any] attr_kwargs: dict[str, Any] | None = None + output_dtypes: List[Any] | None = None + same_dtype_with_output: bool = True + dtype_check_inputs_only: bool = False class SimpleNodeVisitor(NodeVisitor): @@ -51,6 +54,9 @@ def define_node( output=output, num_inputs=cfg.num_inputs, input_dtypes=cfg.input_dtypes, + output_dtypes=cfg.output_dtypes, + same_dtype_with_output=cfg.same_dtype_with_output, + dtype_check_inputs_only=cfg.dtype_check_inputs_only, ) self.serialize( diff --git a/backends/arm/test/passes/test_rewrite_le_lt_to_ge_gt_pass.py b/backends/arm/test/passes/test_rewrite_le_lt_to_ge_gt_pass.py new file mode 100644 index 00000000000..781bc656be7 --- /dev/null +++ b/backends/arm/test/passes/test_rewrite_le_lt_to_ge_gt_pass.py @@ -0,0 +1,51 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +from executorch.backends.arm._passes.rewrite_le_lt_to_ge_gt_pass import ( + RewriteLeLtToGeGtPass, +) +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline + +input_t = Tuple[torch.Tensor, torch.Tensor] + + +class LtLe(torch.nn.Module): + def get_inputs(self) -> input_t: + return (torch.randn(4, 4), torch.randn(4, 4)) + + def forward( + self, x: torch.Tensor, y: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + return (x < y, x <= y) + + +@common.parametrize("module", {"lt_le": LtLe()}) +def test_rewrite_le_lt_to_ge_gt_no_target(module: LtLe) -> None: + pipeline = PassPipeline[input_t]( + module, + module.get_inputs(), + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten_lt_Tensor": 1, + "executorch_exir_dialects_edge__ops_aten_le_Tensor": 1, + }, + ops_not_before_pass=[ + "executorch_exir_dialects_edge__ops_aten_gt_Tensor", + "executorch_exir_dialects_edge__ops_aten_ge_Tensor", + ], + ops_after_pass={ + "executorch_exir_dialects_edge__ops_aten_gt_Tensor": 1, + "executorch_exir_dialects_edge__ops_aten_ge_Tensor": 1, + }, + ops_not_after_pass=[ + "executorch_exir_dialects_edge__ops_aten_lt_Tensor", + "executorch_exir_dialects_edge__ops_aten_le_Tensor", + ], + pass_list=[RewriteLeLtToGeGtPass], + ) + pipeline.run()