Skip to content

Commit b23fcbd

Browse files
authored
Arm backend: Move down a few passes to after fold Q/DQ pass (#16035)
This patch moves `ConvertSplitToSlicePass`, `QuantizeClampArgumentsPass`, `RemoveGetItemPass`, and `DecomposeBatchNormNoStatsPass` to after the Q/DQ folding pass. This allows the passes to work on node metadata instead of Q/DQ nodes. Signed-off-by: Martin Lindström <[email protected]>
1 parent ca91e74 commit b23fcbd

File tree

4 files changed

+73
-12
lines changed

4 files changed

+73
-12
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,18 +175,14 @@ def _tosa_pipeline(
175175
self.add_passes(
176176
[
177177
FuseQuantizedActivationPass(),
178-
RemoveGetItemPass(),
179178
ConvertToClampPass(),
180179
DecomposeInt32ClampPass(),
181180
DecomposeGroupNormPass(),
182181
DecomposeLayerNormPass(),
183-
DecomposeBatchNormNoStatsPass(),
184182
DecomposeVarPass(),
185183
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec),
186184
AnnotateDecomposedMatmulPass(),
187185
ConvertELUParamsPass(),
188-
ConvertSplitToSlicePass(),
189-
QuantizeClampArgumentsPass(),
190186
]
191187
)
192188

@@ -208,6 +204,10 @@ def _tosa_pipeline(
208204
# Node transformation passes (post q/dq folding)
209205
self.add_passes(
210206
[
207+
ConvertSplitToSlicePass(),
208+
QuantizeClampArgumentsPass(),
209+
RemoveGetItemPass(),
210+
DecomposeBatchNormNoStatsPass(),
211211
DecomposeLogitPass(),
212212
DecomposeMaskedFillPass(),
213213
DecomposeRoundPass(),

backends/arm/_passes/convert_split_to_slice.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,48 @@ def call(self, graph_module: torch.fx.GraphModule):
8585
graph,
8686
self.slice,
8787
(input_node, dim, starts[index], ends[index]),
88+
from_node=node,
89+
)
90+
slice_node.meta = _copy_user_node_qparams(
91+
split_node, output_node, index
8892
)
89-
slice_node.meta = split_node.meta.copy()
90-
slice_node.meta["val"] = slice_node.meta["val"][index]
9193
output_node.replace_all_uses_with(slice_node)
9294
graph.eliminate_dead_code()
9395
graph_module.recompile()
9496
graph_module = super().call(graph_module).graph_module
9597
return PassResult(graph_module, True)
98+
99+
100+
def _copy_user_node_qparams(
101+
split_node: torch.fx.Node, output_node: torch.fx.Node, index: int
102+
) -> dict:
103+
"""
104+
Construct metadata for the slice node that will replace the split output.
105+
106+
Note that output quantization parameters are copied from the user nodes
107+
of the split node. The split node itself does not have output quantization
108+
parameters.
109+
110+
Args:
111+
split_node: The split node being replaced.
112+
output_node: The getitem node that is user of the split node.
113+
index: The index of the output being processed.
114+
Returns:
115+
Updated metadata dictionary for the slice node.
116+
"""
117+
118+
def _select_index(value):
119+
if isinstance(value, (list, tuple)):
120+
return value[index]
121+
return value
122+
123+
meta = split_node.meta.copy()
124+
if "val" in meta:
125+
meta["val"] = _select_index(meta["val"])
126+
if "tensor_meta" in meta:
127+
meta["tensor_meta"] = _select_index(meta["tensor_meta"])
128+
if "input_qparams" in meta:
129+
meta["input_qparams"] = dict(meta["input_qparams"])
130+
if "output_qparams" in meta:
131+
meta["output_qparams"] = dict(output_node.meta["output_qparams"])
132+
return meta

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ class QuantizeClampArgumentsPass(ArmPass):
334334
- Makes sure the min and max values to clamp.default are quantized, if it's a quantized operator.
335335
"""
336336

337-
_passes_required_after: Set[Type[ExportPass]] = {FoldAndAnnotateQParamsPass}
337+
_passes_required_after: Set[Type[ExportPass]] = set()
338338

339339
def call(self, graph_module: GraphModule) -> PassResult:
340340
modified = False
@@ -346,12 +346,15 @@ def call(self, graph_module: GraphModule) -> PassResult:
346346
}:
347347
continue
348348

349-
# Make sure we have a quantized operator
350-
user = list(n.users)[0]
351-
if user.target not in Q_OPS:
349+
try:
350+
output_qparams = get_output_qparams(n)
351+
except ValueError:
352+
continue
353+
if len(output_qparams) == 0:
352354
continue
353355

354-
qargs = QuantArgs.from_operator(user.target, user.args)
356+
# Qparams are stored per user index; use the first entry.
357+
qargs = next(iter(output_qparams.values()))
355358

356359
if n.target == exir_ops.edge.aten.clamp.default:
357360
# Quantize the min and max arguments of clamp, if they are not None
@@ -368,4 +371,9 @@ def call(self, graph_module: GraphModule) -> PassResult:
368371

369372
modified = True
370373

374+
if modified:
375+
# Retrace to refresh fake tensor metadata after updating clamp min/max.
376+
graph_module = super().call(graph_module).graph_module
377+
graph_module.recompile()
378+
371379
return PassResult(graph_module, modified)

backends/transforms/remove_getitem_op.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
77

8+
import copy
9+
810
import torch
911
from executorch.exir.dialects._ops import ops as exir_ops
1012

11-
from executorch.exir.pass_base import ExportPass, PassResult
13+
from executorch.exir.pass_base import ExportPass, PassResult, PROTECTED_KEYS
1214

1315

1416
class RemoveGetItemPass(ExportPass):
@@ -81,6 +83,8 @@ def call(self, graph_module: torch.fx.GraphModule):
8183
new_max_wd.meta = node.meta.copy()
8284
new_max_wd.meta["val"] = new_max_wd.meta["val"][0]
8385

86+
_copy_node_metadata(node, new_max_wd)
87+
8488
getitem_node.replace_all_uses_with(new_max_wd)
8589

8690
mdule.graph.erase_node(getitem_node)
@@ -91,3 +95,15 @@ def call(self, graph_module: torch.fx.GraphModule):
9195
graph_module = super().call(graph_module).graph_module
9296

9397
return PassResult(graph_module, True)
98+
99+
100+
def _copy_node_metadata(node: torch.fx.Node, new_max_wd: torch.fx.Node):
101+
"""Copy metadata from original node to new node."""
102+
103+
for key, value in node.meta.items():
104+
if key in PROTECTED_KEYS:
105+
continue
106+
try:
107+
new_max_wd.meta[key] = copy.deepcopy(value)
108+
except Exception:
109+
new_max_wd.meta[key] = value

0 commit comments

Comments
 (0)