Arm backend: Add function extract_io_quant_params #12481

Merged (6 commits) on Jul 24, 2025
98 changes: 97 additions & 1 deletion exir/passes/quantize_io_pass.py
@@ -1,15 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Any, Dict, List, Optional, Sequence, Union

import numpy as np

import torch
import torch.fx as fx

from executorch.exir import EdgeProgramManager, ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
@@ -316,3 +322,93 @@ def call(self, graph_module: torch.fx.GraphModule):
self.edge_manager_update_quant_config_method(i, self.dequant_args[i])

return PassResult(graph_module, True)


def extract_io_quant_params(
edge_prog: EdgeProgramManager,
*,
input_idxs: Sequence[int] = (0,),
output_idxs: Sequence[int] = (0,),
) -> Dict[str, Dict[str, Dict[str, Any]]]:
"""
    Returns the IO quantization parameters (scale and zero_point) in the form:
{
"inputs": {
<placeholder_name>: {"scale": float, "zero_point": int}
},
"outputs": {
<node_name>: {"scale": float, "zero_point": int}
}
}

    Note that this function strips out the IO quantize/dequantize ops as it
    records their parameters, so if you need to preserve the original graph,
    make a copy with copy.deepcopy beforehand.

    `to_edge_transform_and_lower` must be called on the program before using
    this function.
"""
    # Build the input/output quantization passes
passes = []
for idx in input_idxs:
passes.append(QuantizeInputs(edge_prog, [idx]))
for idx in output_idxs:
passes.append(QuantizeOutputs(edge_prog, [idx]))

# Apply them
edge_prog = edge_prog.transform(passes)

cfg = getattr(edge_prog, "_config_methods", {}) or {}

    # Use the GraphModule to map IO indices to node names
gm = edge_prog.exported_program().graph_module

input_names = _gather_io_names(gm, side="input")
output_names = _gather_io_names(gm, side="output")

# Build the result dict
result = {"inputs": {}, "outputs": {}}
for key, val in cfg.items():
if key.startswith("input"):
prefix, section, names = "input", "inputs", input_names
elif key.startswith("output"):
prefix, section, names = "output", "outputs", output_names
else:
continue

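        # Config keys look like "input0_scale" or "output0_zp": strip the
        # "input"/"output" prefix, then split off the IO index and parameter name.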
idx_str, param = key[len(prefix) :].split("_", 1)
idx = int(idx_str)
name = names[idx]
        # Normalize the short 'zp' key to 'zero_point'
out_param = "zero_point" if param in ("zp", "zero_point") else param
result[section].setdefault(name, {})[out_param] = val

return result


def _gather_io_names(gm: fx.GraphModule, side: str):
"""
For 'input', returns placeholder names in graph order.
    For 'output', returns the names of the nodes feeding the graph output, flattened in order.
"""
if side == "input":
return [n.name for n in gm.graph.nodes if n.op == "placeholder"]

if side == "output":

def _flatten(args):
out = []

def rec(x):
if isinstance(x, (tuple, list)):
for y in x:
rec(y)
elif isinstance(x, fx.Node):
out.append(x)

rec(args)
return out

output_node = next(n for n in gm.graph.nodes if n.op == "output")
return [n.name for n in _flatten(output_node.args)]

raise ValueError(f"Unknown side: {side}")
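For reference, a minimal usage sketch (`exported_quantized_program` is a hypothetical placeholder for an already-quantized export; the test file below shows the full end-to-end flow):

    from executorch.exir import to_edge_transform_and_lower
    from executorch.exir.passes.quantize_io_pass import extract_io_quant_params

    # Lower first, then extract the IO scale/zero_point values keyed by node name.
    edge_prog = to_edge_transform_and_lower(exported_quantized_program)
    q = extract_io_quant_params(edge_prog, input_idxs=(0,), output_idxs=(0,))
    for name, params in q["inputs"].items():
        print(name, params["scale"], params["zero_point"])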
93 changes: 93 additions & 0 deletions exir/tests/test_extract_io_quant_params.py
@@ -0,0 +1,93 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import unittest

import torch
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from executorch.exir import to_edge_transform_and_lower
from executorch.exir.passes.quantize_io_pass import extract_io_quant_params

from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


class SimpleAdd(torch.nn.Module):
def forward(self, x, y):
return x + y


class TestExtractIOQuantParamsPT2E(unittest.TestCase):
def setUp(self):
self.example_inputs = (
torch.ones(1, 5),
            torch.full((1, 5), 2.0),
)
self.mod = SimpleAdd().eval()

        # Set up the XNNPACK quantizer for this example
self.quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config()
self.quantizer.set_global(operator_config)

exported = torch.export.export_for_training(
self.mod,
copy.deepcopy(self.example_inputs),
strict=True,
)
prepared = prepare_pt2e(exported.module(), self.quantizer)

        # Run the prepared module once so the observers calibrate
_ = prepared(*self.example_inputs)

converted = convert_pt2e(prepared)

        # Export again with the quantization parameters baked in
final_export = torch.export.export_for_training(
converted,
self.example_inputs,
strict=True,
)

# Lower to EdgeProgramManager
self.edge_prog = to_edge_transform_and_lower(final_export)

def test_roundtrip_extracts_io_params(self):
# Get dict with quant parameters
q = extract_io_quant_params(
self.edge_prog,
input_idxs=(0, 1),
output_idxs=(0,),
)

# Validate structure
self.assertIn("inputs", q)
self.assertIn("outputs", q)
self.assertEqual(len(q["inputs"]), 2)
self.assertEqual(len(q["outputs"]), 1)

# Each entry must have a float 'scale' and int 'zero_point'
for name, params in q["inputs"].items():
self.assertIsInstance(name, str)
self.assertIsInstance(params["scale"], float)
self.assertIsInstance(params["zero_point"], int)

out_name, out_params = next(iter(q["outputs"].items()))
self.assertIsInstance(out_name, str)
self.assertIsInstance(out_params["scale"], float)
self.assertIsInstance(out_params["zero_point"], int)


if __name__ == "__main__":
unittest.main()
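Since the module ends in a `unittest.main()` guard, the test can be run directly, e.g. `python exir/tests/test_extract_io_quant_params.py`, or via `python -m unittest`.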