From f4a4bdf6c4f88ab7a0f82de53e167fd3711d725f Mon Sep 17 00:00:00 2001 From: Emma Kujala Date: Fri, 11 Jul 2025 12:42:54 +0200 Subject: [PATCH] Arm backend: Add decomposition and test for masked_fill.Scalar Signed-off-by: Emma Kujala Change-Id: I633da58539f93e80e4e2f7484501efc5d36b4bf7 --- backends/arm/_passes/__init__.py | 1 + backends/arm/_passes/arm_pass_manager.py | 8 + backends/arm/_passes/decompose_masked_fill.py | 52 +++++++ .../tosa_supported_operators.py | 1 + .../arm/quantizer/quantization_annotator.py | 1 - backends/arm/test/ops/test_masked_fill.py | 144 ++++++++++++++++++ 6 files changed, 206 insertions(+), 1 deletion(-) create mode 100644 backends/arm/_passes/decompose_masked_fill.py create mode 100644 backends/arm/test/ops/test_masked_fill.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index ee6b3671eee..b2a6c52313a 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -40,6 +40,7 @@ from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa from .decompose_linalg_vector_norm_pass import DecomposeLinearVectorNormPass # noqa from .decompose_linear_pass import DecomposeLinearPass # noqa +from .decompose_masked_fill import DecomposeMaskedFill # noqa from .decompose_maxpool2d_with_dilation import DecomposeMaxPool2DPass # noqa from .decompose_meandim_pass import DecomposeMeanDimPass # noqa from .decompose_ne_pass import DecomposeNotEqualPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index b5af4a09be0..6a25b8b3a8a 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -45,6 +45,7 @@ DecomposeLeakyReLUPass, DecomposeLinearPass, DecomposeLinearVectorNormPass, + DecomposeMaskedFill, DecomposeMaxPool2DPass, DecomposeMeanDimPass, DecomposeNotEqualPass, @@ -113,6 +114,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass( DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec) ) + self.add_pass(ConvertFullLikeToFullPass()) self.add_pass(ConvertToClampPass()) self.add_pass(ConvertMinMaxPass()) @@ -146,6 +148,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(DecomposeMaxPool2DPass()) self.add_pass(SizeAdjustInputPass()) self.add_pass(DecomposeSelectPass()) + self.add_pass(ConvertSqueezesToViewPass()) self.add_pass(FuseViewCopyTransform()) @@ -160,6 +163,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul return self._transform(exported_program.graph_module) def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule: + self.add_pass(DecomposeMaskedFill()) self.add_pass(DecomposeRoundPass()) self.add_pass(DecomposeAcoshPass()) self.add_pass(DecomposeAsinPass()) @@ -285,4 +289,8 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(ReplaceInfValues()) self.add_pass(DecomposeSumPass()) + if not self.tosa_spec.is_U55_subset: + # Uses where which is not supported on Ethos-U55 + self.add_pass(DecomposeMaskedFill()) + return self._transform(graph_module) diff --git a/backends/arm/_passes/decompose_masked_fill.py b/backends/arm/_passes/decompose_masked_fill.py new file mode 100644 index 00000000000..fbf3079c92b --- /dev/null +++ b/backends/arm/_passes/decompose_masked_fill.py @@ -0,0 +1,52 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + + +import torch + +from executorch.backends.arm._passes import ArmPass +from executorch.exir.dialects._ops import ops as exir_ops + + +edge_ops = (exir_ops.edge.aten.masked_fill.Scalar,) +aten_ops = (torch.ops.aten.masked_fill.Scalar,) + + +def _get_decomposition(op) -> tuple: + if op in edge_ops: + return ( + exir_ops.edge.aten.where.self, + exir_ops.edge.aten.full_like.default, + ) + if op in aten_ops: + return ( + torch.ops.aten.where.self, + torch.ops.aten.full_like.default, + ) + raise RuntimeError(f"Unable to get decomposition for op {op}") + + +class DecomposeMaskedFill(ArmPass): + """ + Masked fill takes in a boolean mask, a tensor and a scalar value. + Fills the tensor with the scalar value according to the boolean mask. + Decomposed to a where and a full_like operator. + """ + + def call_operator(self, op, args, kwargs, meta, updated=False): + if op not in (edge_ops + aten_ops): + return super().call_operator(op, args, kwargs, meta, updated) + + x, mask, scalar = args + + where_op, full_like_op = _get_decomposition(op) + + scalar_tensor = super().call_operator(full_like_op, (x, scalar), {}, meta, True) + + return super().call_operator( + where_op, (mask, scalar_tensor, x), kwargs, meta, True + ) diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 6d1b8e66c2f..0a0430b7906 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -254,6 +254,7 @@ def is_node_supported( exir_ops.edge.aten.asin.default, exir_ops.edge.aten.atanh.default, exir_ops.edge.aten.addmm.default, + exir_ops.edge.aten.masked_fill.Scalar, ] return supported diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 191341cc217..80ea569f249 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -500,7 +500,6 @@ def any_or_hardtanh_min_zero(n: Node): elif node.target in [operator.getitem]: if not is_output_annotated(node.args[0]): # type: ignore[attr-defined, arg-type] return None - shared_qspec = SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type] quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)] # type: ignore[arg-type] quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type] diff --git a/backends/arm/test/ops/test_masked_fill.py b/backends/arm/test/ops/test_masked_fill.py new file mode 100644 index 00000000000..bfd5c8857c7 --- /dev/null +++ b/backends/arm/test/ops/test_masked_fill.py @@ -0,0 +1,144 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch + +from executorch.backends.arm.test import common + +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU85PipelineBI, + OpNotSupportedPipeline, + TosaPipelineBI, + TosaPipelineMI, +) + + +aten_op = "torch.aten.ops.masked_fill.Scalar" +exir_op = "executorch_exir_dialects_edge__ops_aten_masked_fill_scalar" + +input_t = Tuple[torch.Tensor, torch.Tensor, float] + + +class MaskedFill(torch.nn.Module): + def forward( + self, x: torch.Tensor, mask: torch.Tensor, value: float + ) -> torch.Tensor: + return torch.masked_fill(x, mask, value) + + +test_modules = { + "masked_fill_1": lambda: ( + MaskedFill(), + ( + torch.rand(1, 3, 4, 5), + (torch.rand(1, 3, 4, 5) < 0.5), # boolean mask + -1.0, + ), + ), + "masked_fill_2": lambda: ( + MaskedFill(), + ( + torch.rand(1, 10, 10, 10), + (torch.rand(1, 10, 10, 10) > 0.75), + 3.14, + ), + ), + "masked_fill_3_zero_fill": lambda: ( + MaskedFill(), + ( + torch.rand(1, 3, 4, 5), + torch.rand(1, 3, 4, 5) < 0.2, + 0.0, + ), + ), + "masked_fill_4_full_mask": lambda: ( + MaskedFill(), + ( + torch.rand(1, 3, 4, 5), + torch.ones(1, 3, 4, 5, dtype=torch.bool), + 7.0, + ), + ), + "masked_fill_5_no_mask": lambda: ( + MaskedFill(), + ( + torch.rand(1, 3, 4, 5), + torch.zeros(1, 3, 4, 5, dtype=torch.bool), + -3.0, + ), + ), + "masked_fill_6_scalar_broadcast": lambda: ( + MaskedFill(), + ( + torch.rand(1, 1, 1, 1), + torch.tensor([[[[True]]]]), + 42.0, + ), + ), + "masked_fill_7_large_tensor": lambda: ( + MaskedFill(), + ( + torch.rand(1, 8, 8, 8), + torch.rand(1, 8, 8, 8) > 0.5, + -127.0, + ), + ), + "masked_fill_8_extreme_scalar_inf": lambda: ( + MaskedFill(), + ( + torch.rand(1, 3, 7, 5), + torch.rand(1, 3, 7, 5) > 0.5, + float("inf"), + ), + ), +} + + +@common.parametrize("test_module", test_modules) +def test_masked_fill_scalar_tosa_MI(test_module): + module, inputs = test_module() + pipeline = TosaPipelineMI[input_t](module, inputs, aten_op=[]) + pipeline.run() + + +@common.parametrize("test_module", test_modules) +def test_masked_fill_scalar_tosa_BI(test_module): + module, inputs = test_module() + pipeline = TosaPipelineBI[input_t]( + module, + inputs, + aten_op=[], + ) + pipeline.run() + + +@common.parametrize("test_module", test_modules) +@common.XfailIfNoCorstone300 +def test_masked_fill_scalar_u55_BI(test_module): + module, inputs = test_module() + pipeline = OpNotSupportedPipeline[input_t]( + module, + inputs, + {exir_op: 0, "executorch_exir_dialects_edge__ops_aten_where_self": 1}, + n_expected_delegates=0, + quantize=True, + u55_subset=True, + ) + pipeline.run() + + +@common.parametrize("test_module", test_modules) +@common.XfailIfNoCorstone320 +def test_masked_fill_scalar_u85_BI(test_module): + module, inputs = test_module() + pipeline = EthosU85PipelineBI[input_t]( + module, + inputs, + aten_ops=[], + exir_ops=exir_op, + ) + pipeline.run()