Skip to content

Commit 3ab7063

Browse files
authored
Arm backend: Add decomposition and test for masked_fill.Scalar (#12746)
Add decomposition and test for masked_fill.Scalar-operator Signed-off-by: Emma Kujala <[email protected]>
1 parent 3eea912 commit 3ab7063

File tree

6 files changed

+206
-1
lines changed

6 files changed

+206
-1
lines changed

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa
4141
from .decompose_linalg_vector_norm_pass import DecomposeLinearVectorNormPass # noqa
4242
from .decompose_linear_pass import DecomposeLinearPass # noqa
43+
from .decompose_masked_fill import DecomposeMaskedFill # noqa
4344
from .decompose_maxpool2d_with_dilation import DecomposeMaxPool2DPass # noqa
4445
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
4546
from .decompose_ne_pass import DecomposeNotEqualPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
DecomposeLeakyReLUPass,
4646
DecomposeLinearPass,
4747
DecomposeLinearVectorNormPass,
48+
DecomposeMaskedFill,
4849
DecomposeMaxPool2DPass,
4950
DecomposeMeanDimPass,
5051
DecomposeNotEqualPass,
@@ -113,6 +114,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
113114
self.add_pass(
114115
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
115116
)
117+
116118
self.add_pass(ConvertFullLikeToFullPass())
117119
self.add_pass(ConvertToClampPass())
118120
self.add_pass(ConvertMinMaxPass())
@@ -146,6 +148,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
146148
self.add_pass(DecomposeMaxPool2DPass())
147149
self.add_pass(SizeAdjustInputPass())
148150
self.add_pass(DecomposeSelectPass())
151+
149152
self.add_pass(ConvertSqueezesToViewPass())
150153

151154
self.add_pass(FuseViewCopyTransform())
@@ -160,6 +163,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
160163
return self._transform(exported_program.graph_module)
161164

162165
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
166+
self.add_pass(DecomposeMaskedFill())
163167
self.add_pass(DecomposeRoundPass())
164168
self.add_pass(DecomposeAcoshPass())
165169
self.add_pass(DecomposeAsinPass())
@@ -285,4 +289,8 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
285289
self.add_pass(ReplaceInfValues())
286290
self.add_pass(DecomposeSumPass())
287291

292+
if not self.tosa_spec.is_U55_subset:
293+
# Uses where which is not supported on Ethos-U55
294+
self.add_pass(DecomposeMaskedFill())
295+
288296
return self._transform(graph_module)
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
9+
import torch
10+
11+
from executorch.backends.arm._passes import ArmPass
12+
from executorch.exir.dialects._ops import ops as exir_ops
13+
14+
15+
edge_ops = (exir_ops.edge.aten.masked_fill.Scalar,)
16+
aten_ops = (torch.ops.aten.masked_fill.Scalar,)
17+
18+
19+
def _get_decomposition(op) -> tuple:
20+
if op in edge_ops:
21+
return (
22+
exir_ops.edge.aten.where.self,
23+
exir_ops.edge.aten.full_like.default,
24+
)
25+
if op in aten_ops:
26+
return (
27+
torch.ops.aten.where.self,
28+
torch.ops.aten.full_like.default,
29+
)
30+
raise RuntimeError(f"Unable to get decomposition for op {op}")
31+
32+
33+
class DecomposeMaskedFill(ArmPass):
34+
"""
35+
Masked fill takes in a boolean mask, a tensor and a scalar value.
36+
Fills the tensor with the scalar value according to the boolean mask.
37+
Decomposed to a where and a full_like operator.
38+
"""
39+
40+
def call_operator(self, op, args, kwargs, meta, updated=False):
41+
if op not in (edge_ops + aten_ops):
42+
return super().call_operator(op, args, kwargs, meta, updated)
43+
44+
x, mask, scalar = args
45+
46+
where_op, full_like_op = _get_decomposition(op)
47+
48+
scalar_tensor = super().call_operator(full_like_op, (x, scalar), {}, meta, True)
49+
50+
return super().call_operator(
51+
where_op, (mask, scalar_tensor, x), kwargs, meta, True
52+
)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ def is_node_supported(
254254
exir_ops.edge.aten.asin.default,
255255
exir_ops.edge.aten.atanh.default,
256256
exir_ops.edge.aten.addmm.default,
257+
exir_ops.edge.aten.masked_fill.Scalar,
257258
]
258259

259260
return supported

backends/arm/quantizer/quantization_annotator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,6 @@ def any_or_hardtanh_min_zero(n: Node):
500500
elif node.target in [operator.getitem]:
501501
if not is_output_annotated(node.args[0]): # type: ignore[attr-defined, arg-type]
502502
return None
503-
504503
shared_qspec = SharedQuantizationSpec(node.args[0]) # type: ignore[arg-type]
505504
quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)] # type: ignore[arg-type]
506505
quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
10+
from executorch.backends.arm.test import common
11+
12+
from executorch.backends.arm.test.tester.test_pipeline import (
13+
EthosU85PipelineBI,
14+
OpNotSupportedPipeline,
15+
TosaPipelineBI,
16+
TosaPipelineMI,
17+
)
18+
19+
20+
aten_op = "torch.aten.ops.masked_fill.Scalar"
21+
exir_op = "executorch_exir_dialects_edge__ops_aten_masked_fill_scalar"
22+
23+
input_t = Tuple[torch.Tensor, torch.Tensor, float]
24+
25+
26+
class MaskedFill(torch.nn.Module):
27+
def forward(
28+
self, x: torch.Tensor, mask: torch.Tensor, value: float
29+
) -> torch.Tensor:
30+
return torch.masked_fill(x, mask, value)
31+
32+
33+
test_modules = {
34+
"masked_fill_1": lambda: (
35+
MaskedFill(),
36+
(
37+
torch.rand(1, 3, 4, 5),
38+
(torch.rand(1, 3, 4, 5) < 0.5), # boolean mask
39+
-1.0,
40+
),
41+
),
42+
"masked_fill_2": lambda: (
43+
MaskedFill(),
44+
(
45+
torch.rand(1, 10, 10, 10),
46+
(torch.rand(1, 10, 10, 10) > 0.75),
47+
3.14,
48+
),
49+
),
50+
"masked_fill_3_zero_fill": lambda: (
51+
MaskedFill(),
52+
(
53+
torch.rand(1, 3, 4, 5),
54+
torch.rand(1, 3, 4, 5) < 0.2,
55+
0.0,
56+
),
57+
),
58+
"masked_fill_4_full_mask": lambda: (
59+
MaskedFill(),
60+
(
61+
torch.rand(1, 3, 4, 5),
62+
torch.ones(1, 3, 4, 5, dtype=torch.bool),
63+
7.0,
64+
),
65+
),
66+
"masked_fill_5_no_mask": lambda: (
67+
MaskedFill(),
68+
(
69+
torch.rand(1, 3, 4, 5),
70+
torch.zeros(1, 3, 4, 5, dtype=torch.bool),
71+
-3.0,
72+
),
73+
),
74+
"masked_fill_6_scalar_broadcast": lambda: (
75+
MaskedFill(),
76+
(
77+
torch.rand(1, 1, 1, 1),
78+
torch.tensor([[[[True]]]]),
79+
42.0,
80+
),
81+
),
82+
"masked_fill_7_large_tensor": lambda: (
83+
MaskedFill(),
84+
(
85+
torch.rand(1, 8, 8, 8),
86+
torch.rand(1, 8, 8, 8) > 0.5,
87+
-127.0,
88+
),
89+
),
90+
"masked_fill_8_extreme_scalar_inf": lambda: (
91+
MaskedFill(),
92+
(
93+
torch.rand(1, 3, 7, 5),
94+
torch.rand(1, 3, 7, 5) > 0.5,
95+
float("inf"),
96+
),
97+
),
98+
}
99+
100+
101+
@common.parametrize("test_module", test_modules)
102+
def test_masked_fill_scalar_tosa_MI(test_module):
103+
module, inputs = test_module()
104+
pipeline = TosaPipelineMI[input_t](module, inputs, aten_op=[])
105+
pipeline.run()
106+
107+
108+
@common.parametrize("test_module", test_modules)
109+
def test_masked_fill_scalar_tosa_BI(test_module):
110+
module, inputs = test_module()
111+
pipeline = TosaPipelineBI[input_t](
112+
module,
113+
inputs,
114+
aten_op=[],
115+
)
116+
pipeline.run()
117+
118+
119+
@common.parametrize("test_module", test_modules)
120+
@common.XfailIfNoCorstone300
121+
def test_masked_fill_scalar_u55_BI(test_module):
122+
module, inputs = test_module()
123+
pipeline = OpNotSupportedPipeline[input_t](
124+
module,
125+
inputs,
126+
{exir_op: 0, "executorch_exir_dialects_edge__ops_aten_where_self": 1},
127+
n_expected_delegates=0,
128+
quantize=True,
129+
u55_subset=True,
130+
)
131+
pipeline.run()
132+
133+
134+
@common.parametrize("test_module", test_modules)
135+
@common.XfailIfNoCorstone320
136+
def test_masked_fill_scalar_u85_BI(test_module):
137+
module, inputs = test_module()
138+
pipeline = EthosU85PipelineBI[input_t](
139+
module,
140+
inputs,
141+
aten_ops=[],
142+
exir_ops=exir_op,
143+
)
144+
pipeline.run()

0 commit comments

Comments
 (0)