Skip to content

Commit 62db20b

Browse files
committed
Arm backend: Add function to return quant params for lowered graph
Signed-off-by: Elena Zhelezina <[email protected]> Change-Id: I09de39c603d68d5ac5de4614a35eb7e3fc9ba518
1 parent aaf0a4c commit 62db20b

File tree

2 files changed

+193
-0
lines changed

2 files changed

+193
-0
lines changed

exir/backend/io_quant_params.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Any, Dict, Sequence
7+
8+
import torch.fx as fx
9+
from executorch.exir import EdgeProgramManager
10+
from executorch.exir.passes.quantize_io_pass import QuantizeInputs, QuantizeOutputs
11+
12+
13+
def extract_io_quant_params(
    edge_prog: EdgeProgramManager,
    *,
    input_idxs: Sequence[int] = (0,),
    output_idxs: Sequence[int] = (0,),
) -> Dict[str, Dict[str, Dict[str, Any]]]:
    """
    Returns quantization parameters such as scale/zero_point:
    {
        "inputs": {
            <placeholder_name>: {"scale": float, "zero_point": int}
        },
        "outputs": {
            <node_name>: {"scale": float, "zero_point": int}
        }
    }

    Note that this function will strip out the IO quantize/dequantize ops as
    it records their parameters, so if you need to preserve the original graph
    you need to make a copy with copy.deepcopy before.

    Note that `to_edge_transform_and_lower` should be called before.

    Args:
        edge_prog: Lowered edge program manager to transform and inspect.
        input_idxs: Indices of the graph inputs whose parameters are wanted.
        output_idxs: Indices of the graph outputs whose parameters are wanted.

    Returns:
        A dict with "inputs" and "outputs" sections, each mapping a graph node
        name to the parameters recorded for it.

    Raises:
        IndexError: If a recorded key refers to an input/output index that
            does not exist in the transformed graph.
    """
    # One pass per requested index, inputs first, then outputs.
    passes = [QuantizeInputs(edge_prog, [idx]) for idx in input_idxs]
    passes += [QuantizeOutputs(edge_prog, [idx]) for idx in output_idxs]

    # Apply them
    edge_prog = edge_prog.transform(passes)

    # The recorded parameters are read from the manager's config methods; the
    # attribute is looked up defensively because it may be missing or None.
    cfg = getattr(edge_prog, "_config_methods", {}) or {}

    # We need GraphModule to find node names
    gm = edge_prog.exported_program().graph_module

    # Key prefix -> (result section, node names indexed by IO position).
    sections = {
        "input": ("inputs", _gather_io_names(gm, side="input")),
        "output": ("outputs", _gather_io_names(gm, side="output")),
    }

    # Build the result dict
    result: Dict[str, Dict[str, Dict[str, Any]]] = {"inputs": {}, "outputs": {}}
    for key, val in cfg.items():
        prefix = next((p for p in sections if key.startswith(p)), None)
        if prefix is None:
            continue

        # Expected key shape is "<prefix><index>_<param>", e.g. "input0_scale".
        idx_str, sep, param = key[len(prefix) :].partition("_")
        if not sep or not idx_str.isdecimal():
            # Some other config method that merely starts with "input" or
            # "output" (e.g. "input_shape"); it is not a quant parameter, so
            # skip it instead of failing on int("").
            continue

        section, names = sections[prefix]
        name = names[int(idx_str)]
        # We need to map 'zp' to 'zero_point'
        out_param = "zero_point" if param in ("zp", "zero_point") else param
        result[section].setdefault(name, {})[out_param] = val

    return result
72+
73+
74+
def _gather_io_names(gm: fx.GraphModule, side: str):
75+
"""
76+
For 'input', returns placeholder names in graph order.
77+
For 'output', returns names of output nodes.
78+
"""
79+
if side == "input":
80+
return [n.name for n in gm.graph.nodes if n.op == "placeholder"]
81+
82+
if side == "output":
83+
84+
def _flatten(args):
85+
out = []
86+
87+
def rec(x):
88+
if isinstance(x, (tuple, list)):
89+
for y in x:
90+
rec(y)
91+
elif isinstance(x, fx.Node):
92+
out.append(x)
93+
94+
rec(args)
95+
return out
96+
97+
output_node = next(n for n in gm.graph.nodes if n.op == "output")
98+
return [n.name for n in _flatten(output_node.args)]
99+
100+
raise ValueError(f"Unknown side: {side}")
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import copy
7+
import unittest
8+
9+
import torch
10+
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
11+
get_symmetric_quantization_config,
12+
XNNPACKQuantizer,
13+
)
14+
from executorch.exir import to_edge_transform_and_lower
15+
from executorch.exir.backend.io_quant_params import extract_io_quant_params
16+
17+
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
18+
19+
20+
class SimpleAdd(torch.nn.Module):
    """Minimal two-input module used as the model under test."""

    def forward(self, x, y):
        """Return the element-wise sum of `x` and `y`."""
        total = x + y
        return total
23+
24+
25+
class TestExtractIOQuantParamsPT2E(unittest.TestCase):
    """Runs extract_io_quant_params on a PT2E-quantized SimpleAdd model."""

    def setUp(self):
        """Quantize SimpleAdd with the XNNPACK quantizer and lower it to edge."""
        ones = torch.ones(1, 5)
        twos = torch.full((1, 5), 2.0)
        self.example_inputs = (ones, twos)
        self.mod = SimpleAdd().eval()

        # Setup XNNPACK quantizer for example
        self.quantizer = XNNPACKQuantizer()
        self.quantizer.set_global(get_symmetric_quantization_config())

        # The first export works on a copy of the example inputs.
        float_export = torch.export.export_for_training(
            self.mod,
            copy.deepcopy(self.example_inputs),
            strict=True,
        )
        observed = prepare_pt2e(float_export.module(), self.quantizer)

        # Call observers to calibrate
        observed(*self.example_inputs)

        quantized = convert_pt2e(observed)

        # Export again with quant parameters
        quantized_export = torch.export.export_for_training(
            quantized,
            self.example_inputs,
            strict=True,
        )

        # Lower to EdgeProgramManager
        self.edge_prog = to_edge_transform_and_lower(quantized_export)

    def _assert_qparam_entries(self, entries):
        """Each entry must have a str name, float 'scale' and int 'zero_point'."""
        for node_name, qparams in entries.items():
            self.assertIsInstance(node_name, str)
            self.assertIsInstance(qparams["scale"], float)
            self.assertIsInstance(qparams["zero_point"], int)

    def test_roundtrip_extracts_io_params(self):
        """Both inputs and the single output report scale and zero_point."""
        # Get dict with quant parameters
        q = extract_io_quant_params(
            self.edge_prog,
            input_idxs=(0, 1),
            output_idxs=(0,),
        )

        # Validate structure
        self.assertIn("inputs", q)
        self.assertIn("outputs", q)
        self.assertEqual(len(q["inputs"]), 2)
        self.assertEqual(len(q["outputs"]), 1)

        # Validate the value types of every recorded entry
        self._assert_qparam_entries(q["inputs"])
        self._assert_qparam_entries(q["outputs"])
90+
91+
92+
if __name__ == "__main__":
    # Allow running this test file directly as a script.
    unittest.main()

0 commit comments

Comments
 (0)