Commit 3eea912

StrycekSimon, robert-kalmar, and roman-janik-nxp authored
NXP backend: Add support for aten.add.Tensor, aten.mean.dim, aten._adaptive_avg_pool2d, aten.clone and aten.abs operators (#12585)
### Summary

Add quantization and conversion support for the following operators:

- aten.add.Tensor
- aten.mean.dim
- aten._adaptive_avg_pool2d
- aten.clone
- aten.abs

### Test plan

All newly supported operators should be covered by unit tests.

---------

Co-authored-by: Robert Kalmar <[email protected]>
Co-authored-by: Roman Janik <[email protected]>
1 parent b562f36 commit 3eea912

25 files changed (+1363 / -185 lines)
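For context, here is a minimal sketch (not part of this commit) of a module whose exported graph contains the newly supported operators. Only standard `torch.export` is used; the NXP-specific quantization and lowering steps are omitted, and the module and input shapes are my own illustration.

```python
import torch


class NewOpsModule(torch.nn.Module):
    """Toy module that exercises the operators added in this commit."""

    def forward(self, x, y):
        z = torch.abs(x) + y  # aten.abs, aten.add.Tensor (same shapes, no broadcasting)
        z = torch.nn.functional.adaptive_avg_pool2d(z, (4, 4))  # aten._adaptive_avg_pool2d
        z = z.clone()  # aten.clone
        return z.mean(dim=(2, 3))  # aten.mean.dim


example_inputs = (torch.randn(1, 3, 16, 16), torch.randn(1, 3, 16, 16))
exported_program = torch.export.export(NewOpsModule(), example_inputs)
print(exported_program.graph_module.graph)  # inspect the aten ops listed above as graph nodes
```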

backends/nxp/backend/edge_program_converter.py

Lines changed: 6 additions & 1 deletion
@@ -23,15 +23,20 @@
 
 # noinspection PyProtectedMember
 functions_converters = {
+    exir_ops.edge.aten.abs.default: AbsConverter,  # noqa F405
+    exir_ops.edge.aten._adaptive_avg_pool2d.default: AdaptiveAvgPool2dConverter,  # noqa F405
     exir_ops.edge.aten.addmm.default: AddMMConverter,  # noqa F405
+    exir_ops.edge.aten.add.Tensor: AddTensorConverter,  # noqa F405
     exir_ops.edge.aten.avg_pool2d.default: AvgPool2dConverter,  # noqa F405
+    exir_ops.edge.aten.clone.default: CloneConverter,  # noqa F405
     exir_ops.edge.aten.constant_pad_nd.default: ConstantPadNDConverter,  # noqa F405
     exir_ops.edge.aten.convolution.default: ConvolutionConverter,  # noqa F405
+    exir_ops.edge.aten.hardtanh.default: HardTanhConverter,  # noqa F405
     exir_ops.edge.aten.max_pool2d.default: MaxPool2dConverter,  # noqa F405
+    exir_ops.edge.aten.mean.dim: MeanDimConverter,  # noqa F405
     exir_ops.edge.aten.mm.default: MMConverter,  # noqa F405
     exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter,  # noqa F405
     exir_ops.edge.aten.relu.default: ReLUConverter,  # noqa F405
-    exir_ops.edge.aten.hardtanh.default: HardTanhConverter,  # noqa F405
     exir_ops.edge.aten._softmax.default: SoftmaxConverter,  # noqa F405
     exir_ops.edge.aten.view_copy.default: ViewCopyConverter,  # noqa F405
 }

backends/nxp/backend/ir/converter/conversion/common.py

Lines changed: 29 additions & 0 deletions
@@ -22,6 +22,7 @@
     max_pool_2d_options,
     transpose_conv_options,
 )
+from torch.fx import Node
 
 
 def exactly_one_is_none(obj1: Optional, obj2: Optional) -> bool:
@@ -166,6 +167,34 @@ def uses_shape_broadcasting(t_op: tflite_model.Operator) -> bool:
     )
 
 
+def node_uses_shape_broadcasting(node: Node) -> bool:
+    """Determine if the given PyTorch fx Node uses shape broadcasting for its input nodes or not.
+
+    :param node: PyTorch fx Node with 'all_input_nodes' initialized.
+    :return: True, if the node uses shape broadcasting for its input nodes.
+             False otherwise.
+    """
+
+    if node.all_input_nodes is None:
+        logger.e(
+            logger.Code.INTERNAL_ERROR,
+            "common.node_uses_shape_broadcasting(): 'all_input_nodes' are None!",
+        )
+
+    if len(node.all_input_nodes) == 0:
+        logger.e(
+            logger.Code.INTERNAL_ERROR,
+            "common.node_uses_shape_broadcasting(): Operator has no inputs!",
+        )
+
+    first_input_shape = node.all_input_nodes[0].meta["val"].shape
+
+    return any(
+        input_tensor.meta["val"].shape != first_input_shape
+        for input_tensor in node.all_input_nodes[1:]
+    )
+
+
 def uses_multiple_input_types(t_op: tflite_model.Operator) -> bool:
     """Determine if the input tensors of given TFLite operator use different data types or not.
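As a rough standalone sketch (my illustration, not code from the commit) of the new `node_uses_shape_broadcasting` helper: it flags broadcasting whenever any input's shape differs from the first input's shape.

```python
import torch


def shapes_use_broadcasting(shapes: list[torch.Size]) -> bool:
    # Mirrors the helper above, but over plain shapes instead of fx Nodes.
    first_input_shape = shapes[0]
    return any(shape != first_input_shape for shape in shapes[1:])


print(shapes_use_broadcasting([torch.Size([1, 3, 8, 8]), torch.Size([1, 3, 8, 8])]))  # False
print(shapes_use_broadcasting([torch.Size([1, 3, 8, 8]), torch.Size([1, 1, 8, 8])]))  # True
```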

backends/nxp/backend/ir/converter/node_converter.py

Lines changed: 2 additions & 1 deletion
@@ -173,7 +173,8 @@ def _create_tflite_op_with_io_tensors(self, node: Node) -> tflite_model.Operator
 
         # Initialize node's inputs
         t_operator.inputs = tflite_model.OperatorInputs()
-        for ancestor_node in node.all_input_nodes:
+        input_nodes = [arg for arg in node.args if isinstance(arg, Node)]
+        for ancestor_node in input_nodes:
             assert self.context.tflite_builder.tensor_exists(ancestor_node.name)
             t_operator.tmp_inputs.append(
                 self.context.tflite_builder.tensor_for_name(ancestor_node.name)
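A hedged illustration (not from the commit) of why the input collection changed: for an op like `aten.mean.dim`, `node.args` mixes the tensor-producing fx `Node` with plain Python values (the dims list and the keepdim flag), and only the `Node` args should become TFLite input tensors. The module and inspection loop below are my own sketch against a standard `torch.export` graph.

```python
import torch
from torch.fx import Node


class MeanModule(torch.nn.Module):
    def forward(self, x):
        return torch.mean(x, dim=(2, 3), keepdim=True)


ep = torch.export.export(MeanModule(), (torch.randn(1, 3, 8, 8),))
for node in ep.graph_module.graph.nodes:
    if node.op == "call_function" and "mean" in str(node.target):
        print(node.args)  # e.g. (input_node, [2, 3], True) -- only the first entry is a Node
        print([arg for arg in node.args if isinstance(arg, Node)])  # the tensor input only
```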

backends/nxp/backend/ir/converter/node_converters/ops_converters/__init__.py

Lines changed: 20 additions & 0 deletions
@@ -1,9 +1,21 @@
+from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.abs_converter import (
+    AbsConverter,
+)
+from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.adaptive_avg_pool_2d_converter import (
+    AdaptiveAvgPool2dConverter,
+)
+from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.add_tensor_converter import (
+    AddTensorConverter,
+)
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.addmm_converter import (
     AddMMConverter,
 )
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.avg_pool_2d_converter import (
     AvgPool2dConverter,
 )
+from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.clone_converter import (
+    CloneConverter,
+)
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.constant_pad_nd_converter import (
     ConstantPadNDConverter,
 )
@@ -16,6 +28,9 @@
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.max_pool_2d_converter import (
     MaxPool2dConverter,
 )
+from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.mean_dim_converter import (
+    MeanDimConverter,
+)
 from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.mm_converter import (
     MMConverter,
 )
@@ -49,7 +64,12 @@
     "QDQQuantizeConverter",
     "ConstantPadNDConverter",
     "ReLUConverter",
+    "MeanDimConverter",
     "MaxPool2dConverter",
     "AvgPool2dConverter",
+    "AddTensorConverter",
+    "CloneConverter",
+    "AbsConverter",
+    "AdaptiveAvgPool2dConverter",
     "HardTanhConverter",
 ]
backends/nxp/backend/ir/converter/node_converters/ops_converters/abs_converter.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+# Copyright (c) 2025 NXP
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from executorch.backends.nxp.backend.ir.converter.node_converter import (
+    NodeConverter,
+    Target,
+)
+from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
+    abs_options,
+)
+from torch.fx import Node
+from torch.nn import Parameter
+
+
+class AbsConverter(NodeConverter):
+    supported_targets = [Target.RT700]
+
+    @staticmethod
+    def _is_supported_in_IR(
+        node: Node, parameters_mapping: dict[str, Parameter]
+    ) -> bool:
+        return True
+
+    def convert(self, node: Node):
+        """Convert 'aten::abs' operator to TFLite 'Abs'."""
+        self.assert_convertible(node)
+
+        t_op = self._create_tflite_op_with_io_tensors(node)
+
+        t_op.builtin_options = abs_options.Abs()
+        self.builder.append_operators([t_op])
backends/nxp/backend/ir/converter/node_converters/ops_converters/adaptive_avg_pool_2d_converter.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+# Copyright (c) 2025 NXP
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import executorch.backends.nxp.backend.ir.lib.tflite.Padding as tflPadding
+from executorch.backends.nxp.backend.ir.converter.conversion import common
+from executorch.backends.nxp.backend.ir.converter.node_converter import (
+    NodeConverter,
+    Target,
+)
+from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
+from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
+    average_pool_2d_options,
+)
+from torch import Size
+from torch.fx import Node
+from torch.nn import Parameter
+
+
+class AdaptiveAvgPool2dConverter(NodeConverter):
+    supported_targets = [Target.RT700]
+
+    @staticmethod
+    def _is_supported_in_IR(
+        node: Node, parameters_mapping: dict[str, Parameter]
+    ) -> bool:
+        input_size = node.args[0].meta["val"].shape
+        output_size = node.args[1]
+
+        if (input_size[-1] % output_size[-1] != 0) or (
+            input_size[-2] % output_size[-2] != 0
+        ):
+            return False
+
+        if not NodeConverter._has_shared_q_params_if_quantized(node):
+            return False
+
+        return True
+
+    # noinspection PyMethodMayBeStatic
+    def _convert_adaptive_avg_pool_2d(
+        self, input_size: Size, output_size: list[int], t_op: tflite_model.Operator
+    ):
+        t_op.builtin_options = average_pool_2d_options.AveragePool2D()
+        stride = [input_size[-2] // output_size[-2], input_size[-1] // output_size[-1]]
+        common.assign_2d_strides(t_op.builtin_options, stride)
+        t_op.builtin_options.filter_h = (
+            input_size[-2] - (output_size[-2] - 1) * stride[-2]
+        )
+        t_op.builtin_options.filter_w = (
+            input_size[-1] - (output_size[-1] - 1) * stride[-1]
+        )
+        t_op.builtin_options.padding = tflPadding.Padding.VALID
+
+    # AdaptiveAvgPool2d Node format: (Tensor self, SymInt[2] output_size)
+    def convert(self, node: Node):
+        """Convert '_adaptive_avg_pool2d' operator to TFLite 'AveragePool2D'."""
+        self.assert_convertible(node)
+
+        input_size = node.args[0].meta["val"].shape
+        output_size = node.args[1]
+
+        t_op = self._create_tflite_op_with_io_tensors(node)
+
+        self._convert_adaptive_avg_pool_2d(input_size, output_size, t_op)
+        self.builder.append_operators([t_op])
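A small worked example (my addition, not part of the commit) of the arithmetic in `_convert_adaptive_avg_pool_2d`: under the divisibility restriction enforced by `_is_supported_in_IR`, the stride along each spatial dimension is `input // output`, and the filter size works out to the same value, so the op becomes a plain fixed-window `AveragePool2D` with VALID padding.

```python
def adaptive_pool_params(in_size: int, out_size: int) -> tuple[int, int]:
    # Same formulas as the converter, along one spatial dimension.
    # Only meaningful when in_size % out_size == 0 (the supported case).
    stride = in_size // out_size
    filter_size = in_size - (out_size - 1) * stride
    return stride, filter_size


print(adaptive_pool_params(16, 4))  # (4, 4): a 4x4 window moved with stride 4
print(adaptive_pool_params(8, 2))   # (4, 4)
print(adaptive_pool_params(6, 3))   # (2, 2)
```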
backends/nxp/backend/ir/converter/node_converters/ops_converters/add_tensor_converter.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+# Copyright (c) 2025 NXP
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from executorch.backends.nxp.backend.ir.converter.conversion.common import (
+    node_uses_shape_broadcasting,
+)
+from executorch.backends.nxp.backend.ir.converter.node_converter import (
+    NodeConverter,
+    Target,
+)
+from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
+    add_options,
+)
+from torch.fx import Node
+from torch.nn import Parameter
+
+
+class AddTensorConverter(NodeConverter):
+    supported_targets = [Target.RT700]
+
+    @staticmethod
+    def _is_supported_in_IR(
+        node: Node, parameters_mapping: dict[str, Parameter]
+    ) -> bool:
+        if len(node.args) != 2:
+            return False
+
+        if hasattr(node.kwargs, "alpha"):
+            return False
+
+        # Don't convert if broadcasting input tensors
+        if node_uses_shape_broadcasting(node):
+            return False
+
+        return True
+
+    # add.Tensor Node format: (Tensor self, Tensor other, *, Scalar alpha=1)
+    def convert(self, node: Node):
+        """Convert 'add_tensor' operator to TFLite 'add'."""
+        self.assert_convertible(node)
+
+        t_op = self._create_tflite_op_with_io_tensors(node)
+
+        t_op.builtin_options = add_options.Add()
+        self.builder.append_operators([t_op])
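A brief sketch (my illustration, not from the commit) of which `aten.add.Tensor` nodes pass the checks above: exactly two positional arguments, no `alpha`, and inputs of identical shape, since `node_uses_shape_broadcasting` rejects broadcasting additions.

```python
import torch

a = torch.randn(1, 3, 8, 8)
b = torch.randn(1, 3, 8, 8)
c = torch.randn(1, 1, 8, 8)

a + b                     # identical shapes -> convertible to a TFLite Add
a + c                     # broadcasts over the channel dim -> rejected by node_uses_shape_broadcasting
torch.add(a, b, alpha=2)  # carries an explicit alpha kwarg, the case the alpha check above targets
```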
backends/nxp/backend/ir/converter/node_converters/ops_converters/clone_converter.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+# Copyright 2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.nxp.backend.ir.converter.node_converter import (
+    NodeConverter,
+    Target,
+)
+from torch.fx import Node
+from torch.nn import Parameter
+
+
+def _has_supported_memory_format(node: Node) -> bool:
+    if "memory_format" in node.kwargs.keys():
+        return node.kwargs["memory_format"] == torch.preserve_format
+
+    return True
+
+
+class CloneConverter(NodeConverter):
+    supported_targets = [Target.RT700]
+
+    @staticmethod
+    def _is_supported_in_IR(
+        node: Node, parameters_mapping: dict[str, Parameter]
+    ) -> bool:
+        return _has_supported_memory_format(node)
+
+    def convert(self, node: Node):
+        """Skip `aten.clone` operator if it has no `memory_format` specified."""
+        self.assert_convertible(node)
+
+        t_op = self._create_tflite_op_with_io_tensors(node)
+
+        self.builder.turn_operator_to_identity(t_op)
+        self.builder.append_operators([t_op])
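Finally, a short sketch (my illustration, not part of the commit) of the memory-format rule above: a bare `clone()` or one with `memory_format=torch.preserve_format` is accepted and lowered as an identity op, while any other memory format makes `_is_supported_in_IR` return False.

```python
import torch

x = torch.randn(1, 3, 8, 8)

x.clone()                                     # no memory_format kwarg -> supported, becomes identity
x.clone(memory_format=torch.preserve_format)  # explicitly preserved layout -> supported
x.clone(memory_format=torch.channels_last)    # layout change requested -> not supported
```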
