From b120e934138d94aca719eb8a9a7e8e67e96b15f7 Mon Sep 17 00:00:00 2001
From: Elena Zhelezina
Date: Mon, 14 Jul 2025 18:38:03 +0100
Subject: [PATCH] Arm backend: Add function extract_io_quant_params

Signed-off-by: Elena Zhelezina
Change-Id: I74d7292252c6daf1486ef28bbb2dcffdd34cd7b7
---
 exir/passes/quantize_io_pass.py            | 98 +++++++++++++++++++++-
 exir/tests/test_extract_io_quant_params.py | 93 ++++++++++++++++++++
 2 files changed, 190 insertions(+), 1 deletion(-)
 create mode 100644 exir/tests/test_extract_io_quant_params.py

diff --git a/exir/passes/quantize_io_pass.py b/exir/passes/quantize_io_pass.py
index 836a7376f7d..2ff2dccf99b 100644
--- a/exir/passes/quantize_io_pass.py
+++ b/exir/passes/quantize_io_pass.py
@@ -1,15 +1,21 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
 
 import logging
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Sequence, Union
 
 import numpy as np
 
 import torch
+import torch.fx as fx
 
 from executorch.exir import EdgeProgramManager, ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -316,3 +322,93 @@ def call(self, graph_module: torch.fx.GraphModule):
                 self.edge_manager_update_quant_config_method(i, self.dequant_args[i])
 
         return PassResult(graph_module, True)
+
+
+def extract_io_quant_params(
+    edge_prog: EdgeProgramManager,
+    *,
+    input_idxs: Sequence[int] = (0,),
+    output_idxs: Sequence[int] = (0,),
+) -> Dict[str, Dict[str, Dict[str, Any]]]:
+    """
+    Returns quantization parameters such as scale/zero_point:
+      {
+        "inputs": {
+           <input_name>: {"scale": float, "zero_point": int}
+        },
+        "outputs": {
+           <output_name>: {"scale": float, "zero_point": int}
+        }
+      }
+
+    Note that this function strips out the IO quantize/dequantize ops as it
+    records their parameters, so if you need to preserve the original graph,
+    make a copy with copy.deepcopy beforehand.
+
+    Note that `to_edge_transform_and_lower` should be called before this function.
+    """
+    # Use the IO quantization passes
+    passes = []
+    for idx in input_idxs:
+        passes.append(QuantizeInputs(edge_prog, [idx]))
+    for idx in output_idxs:
+        passes.append(QuantizeOutputs(edge_prog, [idx]))
+
+    # Apply them
+    edge_prog = edge_prog.transform(passes)
+
+    cfg = getattr(edge_prog, "_config_methods", {}) or {}
+
+    # We need the GraphModule to find node names
+    gm = edge_prog.exported_program().graph_module
+
+    input_names = _gather_io_names(gm, side="input")
+    output_names = _gather_io_names(gm, side="output")
+
+    # Build the result dict
+    result = {"inputs": {}, "outputs": {}}
+    for key, val in cfg.items():
+        if key.startswith("input"):
+            prefix, section, names = "input", "inputs", input_names
+        elif key.startswith("output"):
+            prefix, section, names = "output", "outputs", output_names
+        else:
+            continue
+
+        idx_str, param = key[len(prefix) :].split("_", 1)
+        idx = int(idx_str)
+        name = names[idx]
+        # Map 'zp' to 'zero_point'
+        out_param = "zero_point" if param in ("zp", "zero_point") else param
+        result[section].setdefault(name, {})[out_param] = val
+
+    return result
+
+
+def _gather_io_names(gm: fx.GraphModule, side: str):
+    """
+    For 'input', returns placeholder names in graph order.
+    For 'output', returns names of output nodes.
+    """
+    if side == "input":
+        return [n.name for n in gm.graph.nodes if n.op == "placeholder"]
+
+    if side == "output":
+
+        def _flatten(args):
+            out = []
+
+            def rec(x):
+                if isinstance(x, (tuple, list)):
+                    for y in x:
+                        rec(y)
+                elif isinstance(x, fx.Node):
+                    out.append(x)
+
+            rec(args)
+            return out
+
+        output_node = next(n for n in gm.graph.nodes if n.op == "output")
+        return [n.name for n in _flatten(output_node.args)]
+
+    raise ValueError(f"Unknown side: {side}")
diff --git a/exir/tests/test_extract_io_quant_params.py b/exir/tests/test_extract_io_quant_params.py
new file mode 100644
index 00000000000..84da01c673d
--- /dev/null
+++ b/exir/tests/test_extract_io_quant_params.py
@@ -0,0 +1,93 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+import unittest
+
+import torch
+from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
+from executorch.exir import to_edge_transform_and_lower
+from executorch.exir.passes.quantize_io_pass import extract_io_quant_params
+
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+
+
+class SimpleAdd(torch.nn.Module):
+    def forward(self, x, y):
+        return x + y
+
+
+class TestExtractIOQuantParamsPT2E(unittest.TestCase):
+    def setUp(self):
+        self.example_inputs = (
+            torch.ones(1, 5),
+            torch.full(
+                (
+                    1,
+                    5,
+                ),
+                2.0,
+            ),
+        )
+        self.mod = SimpleAdd().eval()
+
+        # Setup XNNPACK quantizer for example
+        self.quantizer = XNNPACKQuantizer()
+        operator_config = get_symmetric_quantization_config()
+        self.quantizer.set_global(operator_config)
+
+        exported = torch.export.export_for_training(
+            self.mod,
+            copy.deepcopy(self.example_inputs),
+            strict=True,
+        )
+        prepared = prepare_pt2e(exported.module(), self.quantizer)
+
+        # Call observers to calibrate
+        _ = prepared(*self.example_inputs)
+
+        converted = convert_pt2e(prepared)
+
+        # Export again with quant parameters
+        final_export = torch.export.export_for_training(
+            converted,
+            self.example_inputs,
+            strict=True,
+        )
+
+        # Lower to EdgeProgramManager
+        self.edge_prog = to_edge_transform_and_lower(final_export)
+
+    def test_roundtrip_extracts_io_params(self):
+        # Get dict with quant parameters
+        q = extract_io_quant_params(
+            self.edge_prog,
+            input_idxs=(0, 1),
+            output_idxs=(0,),
+        )
+
+        # Validate structure
+        self.assertIn("inputs", q)
+        self.assertIn("outputs", q)
+        self.assertEqual(len(q["inputs"]), 2)
+        self.assertEqual(len(q["outputs"]), 1)
+
+        # Each entry must have a float 'scale' and int 'zero_point'
+        for name, params in q["inputs"].items():
+            self.assertIsInstance(name, str)
+            self.assertIsInstance(params["scale"], float)
+            self.assertIsInstance(params["zero_point"], int)
+
+        out_name, out_params = next(iter(q["outputs"].items()))
+        self.assertIsInstance(out_name, str)
+        self.assertIsInstance(out_params["scale"], float)
+        self.assertIsInstance(out_params["zero_point"], int)
+
+
+if __name__ == "__main__":
+    unittest.main()
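Reviewer note (not part of the patch): a minimal usage sketch of the new helper, mirroring the flow in test_extract_io_quant_params.py above. The name `quantized_export` is illustrative and stands in for a PT2E-quantized program produced with prepare_pt2e/convert_pt2e and re-exported via torch.export, as done in setUp().

    from executorch.exir import to_edge_transform_and_lower
    from executorch.exir.passes.quantize_io_pass import extract_io_quant_params

    # `quantized_export` is assumed to be the re-exported, PT2E-converted program.
    edge_prog = to_edge_transform_and_lower(quantized_export)

    # Strips the IO (de)quantize ops and returns their scale/zero_point per IO name.
    io_params = extract_io_quant_params(edge_prog, input_idxs=(0, 1), output_idxs=(0,))
    for name, params in io_params["inputs"].items():
        print(name, params["scale"], params["zero_point"])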