diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index dd7f3d02518..2b935172a7f 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -27,6 +27,7 @@ from .decompose_select import DecomposeSelectPass # noqa from .decompose_softmax_pass import DecomposeSoftmaxPass # noqa from .decompose_softmax_unstable_pass import DecomposeSoftmaxUnstablePass # noqa +from .decompose_sqrt_pass import DecomposeSqrtPass # noqa from .decompose_var_pass import DecomposeVarPass # noqa from .fold_qdq_with_annotated_qparams_pass import ( # noqa FoldAndAnnotateQParamsPass, diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 703c6ff214c..85004686ebe 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -32,6 +32,7 @@ DecomposeSelectPass, DecomposeSoftmaxPass, DecomposeSoftmaxUnstablePass, + DecomposeSqrtPass, DecomposeVarPass, FoldAndAnnotateQParamsPass, FuseBatchnorm2DPass, @@ -115,6 +116,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul return self._transform(exported_program.graph_module) def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule: + self.add_pass(DecomposeSqrtPass()) self.add_pass(ReplaceScalarWithTensorArgPassTOSAMI()) self.add_pass(FuseQuantizedActivationPass()) self.add_pass(RemoveGetItemPass()) @@ -181,6 +183,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(DecomposeMeanDimPass()) self.add_pass(DecomposeDivPass()) self.add_pass(DecomposeLeakyReLUPass()) + self.add_pass(DecomposeSqrtPass()) if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset: # Numerically stable softmax uses amax which is not supported on Ethos-U55 diff --git a/backends/arm/_passes/decompose_sqrt_pass.py b/backends/arm/_passes/decompose_sqrt_pass.py new file mode 100644 index 00000000000..d4a678affea --- /dev/null +++ 
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

# sqrt overloads handled by this pass, grouped by the dialect they belong to.
edge_sqrt_ops = (exir_ops.edge.aten.sqrt.default,)
aten_sqrt_ops = (
    torch.ops.aten.sqrt.default,
    torch.ops.aten.sqrt_.default,
)

# Union of the two groups, built once at import time instead of on every
# call_operator invocation.
_all_sqrt_ops = edge_sqrt_ops + aten_sqrt_ops


def get_sqrt_decomposition(op):
    """Return the ``pow.Tensor_Scalar`` overload matching the dialect of ``op``.

    Args:
        op: A sqrt operator; expected to be a member of ``edge_sqrt_ops`` or
            ``aten_sqrt_ops``.

    Returns:
        ``exir_ops.edge.aten.pow.Tensor_Scalar`` when ``op`` is an edge-dialect
        sqrt, ``torch.ops.aten.pow.Tensor_Scalar`` when it is an ATen sqrt.
        A single operator is returned, not a tuple.

    Raises:
        RuntimeError: If ``op`` is in neither sqrt group.
    """
    # TODO : "MLETORCH-863 : Replace current sqrt -> pow.Tensor_Scalar workaround with pow.Tensor_Tensor"
    if op in edge_sqrt_ops:
        return exir_ops.edge.aten.pow.Tensor_Scalar
    if op in aten_sqrt_ops:
        return torch.ops.aten.pow.Tensor_Scalar
    raise RuntimeError(f"Can't get sqrt decomposition for op {op}")


class DecomposeSqrtPass(ExportPass):
    """Graph pass that rewrites sqrt nodes as ``pow`` with exponent 0.5."""

    def call_operator(self, op, args, kwargs, meta):
        """
        Decomposes `sqrt(x)` into `pow(x, 0.5)` for backend support.

        Any operator that is not one of the known sqrt overloads is forwarded
        to the base class unchanged.
        """

        if op not in _all_sqrt_ops:
            return super().call_operator(op, args, kwargs, meta)

        pow_op = get_sqrt_decomposition(op)

        # Only the input tensor is carried over; the replacement node is
        # emitted with empty kwargs and the original node's meta.
        return super().call_operator(pow_op, (args[0], 0.5), {}, meta)
from typing import Dict, Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
    EthosU55PipelineBI,
    EthosU85PipelineBI,
    TosaPipelineBI,
    TosaPipelineMI,
)


class Sqrt(torch.nn.Module):
    """Single-op module wrapping ``torch.sqrt``, bundled with its test inputs."""

    input_t = Tuple[torch.Tensor]

    # Op-name strings handed to the MI pipeline constructor.
    aten_op_MI = "torch.ops.aten.sqrt.default"
    exir_op_MI = "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Tensor"

    # Op-name strings handed to the BI pipeline constructors.
    aten_op_BI = "torch.ops.aten.pow.Tensor_Scalar"
    exir_op_BI = "executorch_exir_dialects_edge__ops_aten_pow_Tensor_Scalar"

    # Named input tuples covering ranks 1 through 4; the values are either all
    # ones or drawn from torch.rand.
    test_data: Dict[str, input_t] = {
        "sqrt_tensor_rank1_ones": (torch.ones(10),),
        "sqrt_tensor_rank2_random": (torch.rand(5, 10),),
        "sqrt_tensor_rank3_ones": (torch.ones(2, 3, 4),),
        "sqrt_tensor_rank4_random": (torch.rand(1, 3, 8, 8),),
        "sqrt_tensor_rank4_multibatch": (torch.rand(2, 3, 4, 4),),
    }

    def forward(self, x):
        """Return the element-wise square root of ``x``."""
        return torch.sqrt(x)


# Cases passed as the xfail mapping to the FVP-backed (U55 / U85) tests.
fvp_xfails = {
    "sqrt_tensor_rank4_multibatch": "MLETORCH-517 : Multiple batches not supported",
}


@common.parametrize("test_data", Sqrt.test_data)
def test_sqrt_tosa_MI(test_data: Sqrt.input_t):
    """Run the sqrt module through the TOSA MI pipeline."""
    TosaPipelineMI[Sqrt.input_t](
        Sqrt(), test_data, Sqrt.aten_op_MI, Sqrt.exir_op_MI
    ).run()


@common.parametrize("test_data", Sqrt.test_data)
def test_sqrt_tosa_BI(test_data: Sqrt.input_t):
    """Run the sqrt module through the TOSA BI pipeline."""
    TosaPipelineBI[Sqrt.input_t](
        Sqrt(), test_data, Sqrt.aten_op_BI, Sqrt.exir_op_BI
    ).run()


@common.parametrize("test_data", Sqrt.test_data, fvp_xfails)
@common.XfailIfNoCorstone300
def test_sqrt_u55_BI(test_data: Sqrt.input_t):
    """Run the sqrt module through the Ethos-U55 BI pipeline with run_on_fvp=True."""
    EthosU55PipelineBI[Sqrt.input_t](
        Sqrt(), test_data, Sqrt.aten_op_BI, Sqrt.exir_op_BI, run_on_fvp=True
    ).run()


@common.parametrize("test_data", Sqrt.test_data, fvp_xfails)
@common.XfailIfNoCorstone320
def test_sqrt_u85_BI(test_data: Sqrt.input_t):
    """Run the sqrt module through the Ethos-U85 BI pipeline with run_on_fvp=True."""
    EthosU85PipelineBI[Sqrt.input_t](
        Sqrt(), test_data, Sqrt.aten_op_BI, Sqrt.exir_op_BI, run_on_fvp=True
    ).run()