From 62db20bc48dab9f0ea0b838498a1aeb1268cd351 Mon Sep 17 00:00:00 2001
From: Elena Zhelezina
Date: Wed, 9 Jul 2025 17:30:09 +0100
Subject: [PATCH] Arm backend: Add function to return quant params for lowered graph

Signed-off-by: Elena Zhelezina
Change-Id: I09de39c603d68d5ac5de4614a35eb7e3fc9ba518
---
 exir/backend/io_quant_params.py           | 100 ++++++++++++++++++++++
 exir/backend/test/test_io_quant_params.py |  93 ++++++++++++++++++++
 2 files changed, 193 insertions(+)
 create mode 100644 exir/backend/io_quant_params.py
 create mode 100644 exir/backend/test/test_io_quant_params.py

diff --git a/exir/backend/io_quant_params.py b/exir/backend/io_quant_params.py
new file mode 100644
index 00000000000..160aef78c86
--- /dev/null
+++ b/exir/backend/io_quant_params.py
@@ -0,0 +1,100 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Dict, Sequence
+
+import torch.fx as fx
+from executorch.exir import EdgeProgramManager
+from executorch.exir.passes.quantize_io_pass import QuantizeInputs, QuantizeOutputs
+
+
+def extract_io_quant_params(
+    edge_prog: EdgeProgramManager,
+    *,
+    input_idxs: Sequence[int] = (0,),
+    output_idxs: Sequence[int] = (0,),
+) -> Dict[str, Dict[str, Dict[str, Any]]]:
+    """
+    Returns quantization parameters such as scale/zero_point:
+      {
+        "inputs": {
+          <placeholder_name>: {"scale": float, "zero_point": int}
+        },
+        "outputs": {
+          <node_name>: {"scale": float, "zero_point": int}
+        }
+      }
+
+    Note that this function will strip out the IO quantize/dequantize ops as
+    it records their parameters, so if you need to preserve the original graph
+    you need to make a copy with copy.deepcopy before.
+
+    Note that `to_edge_transform_and_lower` should be called before.
+    """
+    # Use IO passes
+    passes = []
+    for idx in input_idxs:
+        passes.append(QuantizeInputs(edge_prog, [idx]))
+    for idx in output_idxs:
+        passes.append(QuantizeOutputs(edge_prog, [idx]))
+
+    # Apply them
+    edge_prog = edge_prog.transform(passes)
+
+    cfg = getattr(edge_prog, "_config_methods", {}) or {}
+
+    # We need GraphModule to find node names
+    gm = edge_prog.exported_program().graph_module
+
+    input_names = _gather_io_names(gm, side="input")
+    output_names = _gather_io_names(gm, side="output")
+
+    # Build the result dict
+    result = {"inputs": {}, "outputs": {}}
+    for key, val in cfg.items():
+        if key.startswith("input"):
+            prefix, section, names = "input", "inputs", input_names
+        elif key.startswith("output"):
+            prefix, section, names = "output", "outputs", output_names
+        else:
+            continue
+
+        idx_str, param = key[len(prefix) :].split("_", 1)
+        idx = int(idx_str)
+        name = names[idx]
+        # We need to map 'zp' to 'zero_point'
+        out_param = "zero_point" if param in ("zp", "zero_point") else param
+        result[section].setdefault(name, {})[out_param] = val
+
+    return result
+
+
+def _gather_io_names(gm: fx.GraphModule, side: str):
+    """
+    For 'input', returns placeholder names in graph order.
+    For 'output', returns names of the nodes held in the output node's args.
+    """
+    if side == "input":
+        return [n.name for n in gm.graph.nodes if n.op == "placeholder"]
+
+    if side == "output":
+
+        def _flatten(args):
+            out = []
+
+            def rec(x):
+                if isinstance(x, (tuple, list)):
+                    for y in x:
+                        rec(y)
+                elif isinstance(x, fx.Node):
+                    out.append(x)
+
+            rec(args)
+            return out
+
+        output_node = next(n for n in gm.graph.nodes if n.op == "output")
+        return [n.name for n in _flatten(output_node.args)]
+
+    raise ValueError(f"Unknown side: {side}")
diff --git a/exir/backend/test/test_io_quant_params.py b/exir/backend/test/test_io_quant_params.py
new file mode 100644
index 00000000000..689eedc099f
--- /dev/null
+++ b/exir/backend/test/test_io_quant_params.py
@@ -0,0 +1,93 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+import unittest
+
+import torch
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
+from executorch.exir import to_edge_transform_and_lower
+from executorch.exir.backend.io_quant_params import extract_io_quant_params
+
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+
+class SimpleAdd(torch.nn.Module):
+    def forward(self, x, y):
+        return x + y
+
+
+class TestExtractIOQuantParamsPT2E(unittest.TestCase):
+    def setUp(self):
+        self.example_inputs = (
+            torch.ones(1, 5),
+            torch.full(
+                (
+                    1,
+                    5,
+                ),
+                2.0,
+            ),
+        )
+        self.mod = SimpleAdd().eval()
+
+        # Setup XNNPACK quantizer for example
+        self.quantizer = XNNPACKQuantizer()
+        operator_config = get_symmetric_quantization_config()
+        self.quantizer.set_global(operator_config)
+
+        exported = torch.export.export_for_training(
+            self.mod,
+            copy.deepcopy(self.example_inputs),
+            strict=True,
+        )
+        prepared = prepare_pt2e(exported.module(), self.quantizer)
+
+        # Call observers to calibrate
+        _ = prepared(*self.example_inputs)
+
+        converted = convert_pt2e(prepared)
+
+        # Export again with quant parameters
+        final_export = torch.export.export_for_training(
+            converted,
+            self.example_inputs,
+            strict=True,
+        )
+
+        # Lower to EdgeProgramManager
+        self.edge_prog = to_edge_transform_and_lower(final_export)
+
+    def test_roundtrip_extracts_io_params(self):
+        # Get dict with quant parameters
+        q = extract_io_quant_params(
+            self.edge_prog,
+            input_idxs=(0, 1),
+            output_idxs=(0,),
+        )
+
+        # Validate structure
+        self.assertIn("inputs", q)
+        self.assertIn("outputs", q)
+        self.assertEqual(len(q["inputs"]), 2)
+        self.assertEqual(len(q["outputs"]), 1)
+
+        # Each entry must have a float 'scale' and int 'zero_point'
+        for name, params in q["inputs"].items():
+            self.assertIsInstance(name, str)
+            self.assertIsInstance(params["scale"], float)
+            self.assertIsInstance(params["zero_point"], int)
+
+        out_name, out_params = next(iter(q["outputs"].items()))
+        self.assertIsInstance(out_name, str)
+        self.assertIsInstance(out_params["scale"], float)
+        self.assertIsInstance(out_params["zero_point"], int)
+
+
+if __name__ == "__main__":
+    unittest.main()