Commit cf2f170

NXP backend: Add model input and output quantization (#12586)

### Summary
With this change, the NeutronConverter can quantize the model's input and output tensors (i.e. the input and output placeholder nodes). A pass is also added to subsequently remove the Q/DQ nodes for those placeholders, making the model fully quantized.

### Test plan
Unit tests were updated to cover the newly introduced changes.

Co-authored-by: Lukas Sztefek <[email protected]>
1 parent 8d0053c commit cf2f170
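
For orientation, a minimal end-to-end sketch of how the new option is exercised, assembled from the pipeline and test changes in this commit (the model and input shape are taken from the tests below; any module the Neutron quantizer supports would do):

    from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
    from executorch.backends.nxp.tests.models import Conv2dReLUModule
    from executorch.exir import ExecutorchBackendConfig

    model = Conv2dReLUModule()
    model.eval()

    # With remove_quant_io_ops=True, the new RemoveIOQuantOpsPass strips the
    # input/output Q/DQ nodes, so the final program consumes and produces int8
    # tensors directly; the I/O scales and zero-points are preserved as config
    # methods on the executorch program.
    edge_program_manager = to_quantized_edge_program(
        model, (1, 4, 32, 32), remove_quant_io_ops=True
    )
    exec_prog = edge_program_manager.to_executorch(
        config=ExecutorchBackendConfig(extract_delegate_segments=False)
    )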

File tree: 12 files changed, +325 −15 lines

backends/nxp/backend/ir/converter/node_converters/ops_converters/constant_pad_nd_converter.py

Lines changed: 15 additions & 4 deletions

@@ -9,7 +9,6 @@
 import numpy as np
 
 from executorch.backends.nxp.backend.edge_helper import input_rank
-from executorch.backends.nxp.backend.ir.converter.conversion.common import OpsList
 from executorch.backends.nxp.backend.ir.converter.conversion.translator import (
     apply_permutation_to,
     create_channels_first_to_channels_last_permutation,
@@ -24,6 +23,7 @@
 )
 from executorch.backends.nxp.backend.ir.tflite_generator import tflite_model
 from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
+    pad_options,
     pad_v2_options,
 )
 from torch.fx import Node
@@ -50,6 +50,10 @@ def _is_supported_in_IR(
         if not NodeConverter._has_shared_q_params_if_quantized(node):
             return False
 
+        if len(paddings) > 4 and paddings[4:6] != [0, 0]:
+            # Attempting to pad the channels dimension -> currently not supported.
+            return False
+
         return True
 
     # noinspection PyMethodMayBeStatic
@@ -101,6 +105,15 @@ def convert(self, node: Node):
             np.asarray(paddings, "int32"), "paddings"
         )
 
+        if constant == 0.0:
+            # We're padding with zeros, so we can use the traditional Pad op.
+            t_op.tmp_inputs = [x, paddings_tensor]
+            t_op.tmp_outputs = [y]
+            t_op.builtin_options = pad_options.Pad()
+
+            self.builder.append_operators([t_op])
+            return
+
         if x.quantization is None:
             constant_tensor = self.builder.create_tensor_for_data(
                 np.array([constant], tf_lite_type_to_numpy(x.type)), "constant"
@@ -124,6 +137,4 @@ def convert(self, node: Node):
         t_op.tmp_outputs = [y]
         t_op.builtin_options = pad_v2_options.PadV2()
 
-        ops_to_add = OpsList(middle_op=t_op)
-
-        self.builder.append_operators(ops_to_add.flatten())
+        self.builder.append_operators([t_op])

backends/nxp/backend/ir/edge_passes/__init__.py

Whitespace-only changes.
backends/nxp/backend/ir/edge_passes/remove_io_quant_ops_pass.py

Lines changed: 79 additions & 0 deletions

@@ -0,0 +1,79 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from executorch.exir import EdgeProgramManager
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
from executorch.exir.passes.quantize_io_pass import QuantizeInputs, QuantizeOutputs
from torch.fx.passes.infra.pass_base import PassResult


class RemoveIOQuantOpsPass(ExportPass):

    def __init__(self, edge_program_manager: EdgeProgramManager):
        super().__init__()
        self._edge_program_manager = edge_program_manager

    def _get_quantizable_input_indices(self):
        exported_program = self._edge_program_manager.exported_program()

        graph = exported_program.graph_module.graph
        user_inputs = exported_program.graph_signature.user_inputs

        inputs_to_quantization = []

        for input_index, user_input in enumerate(user_inputs):
            placeholders = [
                n for n in graph.nodes if n.op == "placeholder" and n.name == user_input
            ]
            assert placeholders
            target_placeholder = placeholders[0]

            if len(target_placeholder.users) != 1:
                raise ValueError(f"Input {input_index} has more than one user")

            quantize = next(iter(target_placeholder.users))
            if (
                quantize.target
                != exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
            ):
                continue

            inputs_to_quantization.append(input_index)

        return inputs_to_quantization

    def _get_quantizable_output_indices(self):
        exported_program = self._edge_program_manager.exported_program()

        graph = exported_program.graph_module.graph
        outputs = [n for n in graph.nodes if n.op == "output"]
        if len(outputs) != 1:
            raise NotImplementedError("Only 1 output node is supported.")

        outputs_to_quantization = []

        user_outputs = list(outputs[0].args[0])
        for output_index, user_output in enumerate(user_outputs):
            if (
                user_output.target
                != exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
            ):
                continue

            outputs_to_quantization.append(output_index)

        return outputs_to_quantization

    def call(self, graph_module: torch.fx.GraphModule):
        input_indices = self._get_quantizable_input_indices()
        output_indices = self._get_quantizable_output_indices()

        QuantizeInputs(self._edge_program_manager, input_indices).call(graph_module)
        QuantizeOutputs(self._edge_program_manager, output_indices).call(graph_module)

        return PassResult(graph_module, True)
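
A usage note: the pass takes the EdgeProgramManager it runs under, since QuantizeInputs/QuantizeOutputs operate on that manager's exported program and record the I/O scales and zero-points as config methods. A minimal sketch of applying the pass directly, mirroring the executorch_pipeline.py change below (assumes edge_program_manager already holds a quantized edge program):

    from executorch.backends.nxp.backend.ir.edge_passes.remove_io_quant_ops_pass import (
        RemoveIOQuantOpsPass,
    )

    # Assumes the program's inputs feed quantize_per_tensor nodes and its
    # outputs come from dequantize_per_tensor nodes; only those Q/DQ pairs
    # are folded into the program's I/O.
    edge_program_manager = edge_program_manager.transform(
        [RemoveIOQuantOpsPass(edge_program_manager=edge_program_manager)]
    )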

backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 23 additions & 0 deletions

@@ -42,6 +42,7 @@
     no_outside_users,
 )
 from torch import fx
+from torch.ao.quantization.quantizer.utils import _annotate_output_qspec
 from torchao.quantization.pt2e import HistogramObserver, MinMaxObserver
 from torchao.quantization.pt2e.quantizer import (
     ComposableQuantizer,
@@ -239,6 +240,8 @@ def transform_for_annotation(
         return pass_runner(model).graph_module
 
     def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+        self._annotate_inputs(model)
+
         nodes = list(model.graph.nodes)
         for node in nodes:
             if (
@@ -254,5 +257,25 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
 
         return model
 
+    def _is_input_annotated(self, node: fx.Node) -> bool:
+        return (
+            "quantization_annotation" in node.meta
+            and node.meta["quantization_annotation"]._annotated
+        )
+
+    def _mark_input_node_as_annotated(self, node: fx.Node) -> None:
+        if "quantization_annotation" not in node.meta:
+            node.meta["quantization_annotation"] = QuantizationAnnotation()
+        node.meta["quantization_annotation"]._annotated = True
+
+    def _annotate_inputs(self, model: fx.GraphModule):
+        for node in model.graph.nodes:
+            if self._is_input_annotated(node):
+                continue
+
+            if node.op == "placeholder" and len(node.users) > 0:
+                _annotate_output_qspec(node, act_qspec)
+                self._mark_input_node_as_annotated(node)
+
     def validate(self, model: torch.fx.GraphModule) -> None:
         return super().validate(model)

backends/nxp/run_unittests.sh

Lines changed: 1 addition & 1 deletion

@@ -11,4 +11,4 @@ EXECUTORCH_DIR=$(dirname $(dirname $SCRIPT_DIR))
 cd $EXECUTORCH_DIR
 
 # '-c /dev/null' is used to ignore root level pytest.ini.
-PYTHONPATH=`cd ..; pwd` pytest -c /dev/null backends/nxp/tests/
+pytest -c /dev/null backends/nxp/tests/

backends/nxp/tests/executorch_pipeline.py

Lines changed: 9 additions & 0 deletions

@@ -6,6 +6,9 @@
 import torch
 
 from executorch import exir
+from executorch.backends.nxp.backend.ir.edge_passes.remove_io_quant_ops_pass import (
+    RemoveIOQuantOpsPass,
+)
 from executorch.backends.nxp.neutron_partitioner import NeutronPartitioner
 from executorch.backends.nxp.nxp_backend import generate_neutron_compile_spec
 from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
@@ -37,6 +40,7 @@ def to_quantized_edge_program(
     operators_not_to_delegate: list[str] = None,
     target="imxrt700",
     neutron_converter_flavor="SDK_25_03",
+    remove_quant_io_ops=False,
 ) -> EdgeProgramManager:
     if isinstance(input_shapes, list):
         assert all(isinstance(input_shape, tuple) for input_shape in input_shapes), (
@@ -77,6 +81,11 @@ def to_quantized_edge_program(
         compile_config=EdgeCompileConfig(_check_ir_validity=False),
     )
 
+    if remove_quant_io_ops:
+        edge_program_manager = edge_program_manager.transform(
+            [RemoveIOQuantOpsPass(edge_program_manager=edge_program_manager)]
+        )
+
     return edge_program_manager
backends/nxp/tests/ir/converter/node_converter/test_constant_pad_nd_converter.py

Lines changed: 3 additions & 8 deletions

@@ -63,16 +63,10 @@ def test_constant_pad_nd_conversion__default_constant():
         pytest.param((2, 4), tuple(range(4)), id="2D, padding N, H"),
         pytest.param((2, 4, 6), tuple(range(2)), id="3D, padding H"),
         pytest.param((2, 4, 6), tuple(range(4)), id="3D, padding C, H"),
-        pytest.param((2, 4, 6), list(range(6)), id="3D, padding N, C, H"),
         pytest.param((2, 4, 6, 8), tuple(range(2)), id="4D, padding W"),
         pytest.param((2, 4, 6, 8), tuple(range(4)), id="4D, padding H, W"),
-        pytest.param((2, 4, 6, 8), list(range(6)), id="4D, padding C, H, W"),
-        pytest.param((2, 4, 6, 8), list(range(8)), id="4D, padding N, C, H, W"),
-        pytest.param((1, 2, 3, 4, 5), list(range(2)), id="5D, padding D"),
+        pytest.param((1, 2, 3, 4, 5), tuple(range(2)), id="5D, padding D"),
         pytest.param((1, 2, 3, 4, 5), tuple(range(4)), id="5D, padding W, D"),
-        pytest.param((1, 2, 3, 4, 5), list(range(6)), id="5D, padding H, W, D"),
-        pytest.param((1, 2, 3, 4, 5), tuple(range(8)), id="5D, padding C, H, W, D"),
-        pytest.param((1, 2, 3, 4, 5), list(range(10)), id="5D, padding N, C, H, W, D"),
     ],
 )
 def test_constant_pad_nd_conversion__format_less(input_shape, paddings):
@@ -93,8 +87,9 @@ def test_constant_pad_nd_conversion__format_less(input_shape, paddings):
     ],
 )
 def test_constant_pad_nd_conversion__channels_first(input_shape, paddings):
+    model = ConstantPadNDConvModule(paddings)
     edge_program = to_edge_program(
-        ConstantPadNDConvModule(paddings), input_shape
+        model, input_shape
     ).exported_program()  # Extra `Conv` after the padding.
 
     input_data = np.random.random(input_shape).astype(np.float32)
Lines changed: 122 additions & 0 deletions

@@ -0,0 +1,122 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import itertools

import executorch.kernels.quantized  # noqa F401
import torch
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
from executorch.backends.nxp.tests.models import Conv2dReLUModule
from executorch.examples.nxp.experimental.cifar_net.cifar_net import CifarNet
from executorch.exir import ExecutorchBackendConfig
from executorch.exir.passes.quantize_io_pass import get_config_method_name


def test_remove_io_quant_ops_pass__conv_relu():
    model = Conv2dReLUModule()
    model.eval()

    input_shape = (1, 4, 32, 32)
    edge_program_manager = to_quantized_edge_program(
        model, input_shape, remove_quant_io_ops=True
    )

    exec_prog = edge_program_manager.to_executorch(
        config=ExecutorchBackendConfig(extract_delegate_segments=False)
    )

    nodes = list(exec_prog.exported_program().graph.nodes)
    assert (
        nodes[0].meta["val"].dtype == torch.int8
    ), "Input tensor doesn't have type INT8."
    assert nodes[2].name == "executorch_call_delegate"
    assert (
        nodes[4].meta["val"][0].dtype == torch.int8
    ), "Output tensor doesn't have type INT8."

    assert (
        get_config_method_name(None, "input", 0, "scale") in exec_prog._config_methods
    )
    assert get_config_method_name(None, "input", 0, "zp") in exec_prog._config_methods
    assert (
        get_config_method_name(None, "output", 0, "scale") in exec_prog._config_methods
    )
    assert get_config_method_name(None, "output", 0, "zp") in exec_prog._config_methods


def test_remove_io_quant_ops_pass__cifarnet():
    model = CifarNet().get_eager_model()
    input_shape = (1, 3, 32, 32)
    edge_program_manager = to_quantized_edge_program(
        model, input_shape, remove_quant_io_ops=True
    )

    exec_prog = edge_program_manager.to_executorch(
        config=ExecutorchBackendConfig(extract_delegate_segments=False)
    )

    nodes = list(exec_prog.exported_program().graph.nodes)
    assert len(nodes) == 17
    assert (
        nodes[0].meta["val"].dtype == torch.int8
    ), "Input tensor doesn't have type INT8."
    assert (
        nodes[16].meta["val"][0].dtype == torch.int8
    ), "Output tensor doesn't have type INT8."

    assert (
        get_config_method_name(None, "input", 0, "scale") in exec_prog._config_methods
    )
    assert get_config_method_name(None, "input", 0, "zp") in exec_prog._config_methods
    assert (
        get_config_method_name(None, "output", 0, "scale") in exec_prog._config_methods
    )
    assert get_config_method_name(None, "output", 0, "zp") in exec_prog._config_methods


class MultiInputOutputModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.conv = torch.nn.Conv2d(4, 64, 2, bias=False)
        self.relu = torch.nn.ReLU()

    def forward(self, x, y):
        z = self.relu(x)
        x = self.conv(z)
        return x + y, z


def test_multiple_inputs__multiple_outputs():
    model = MultiInputOutputModule()
    model.eval()

    input_shape = [(1, 4, 32, 32), (1, 1, 1, 31)]
    edge_program_manager = to_quantized_edge_program(
        model, input_shape, remove_quant_io_ops=True
    )

    exec_prog = edge_program_manager.to_executorch(
        config=ExecutorchBackendConfig(extract_delegate_segments=False)
    )

    nodes = list(exec_prog.exported_program().graph.nodes)
    print(nodes)
    assert (
        nodes[0].meta["val"].dtype == torch.int8
    ), "Input tensor doesn't have type INT8."
    assert nodes[3].name == "executorch_call_delegate"
    assert (
        nodes[-1].meta["val"][0].dtype == torch.int8
    ), "Output tensor doesn't have type INT8."

    quant_method_variants = itertools.product(
        ["input", "output"], [0, 1], ["scale", "zp"]
    )

    expected_methods = [
        get_config_method_name(None, arg_type, index, key)
        for arg_type, index, key in quant_method_variants
    ]
    assert all(method in exec_prog._config_methods for method in expected_methods)
