Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,18 +174,14 @@ def _tosa_pipeline(
self.add_passes(
[
FuseQuantizedActivationPass(),
RemoveGetItemPass(),
ConvertToClampPass(),
DecomposeInt32ClampPass(),
DecomposeGroupNormPass(),
DecomposeLayerNormPass(),
DecomposeBatchNormNoStatsPass(),
DecomposeVarPass(),
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec),
AnnotateDecomposedMatmulPass(),
ConvertELUParamsPass(),
ConvertSplitToSlicePass(),
QuantizeClampArgumentsPass(),
]
)

Expand All @@ -207,6 +203,10 @@ def _tosa_pipeline(
# Node transformation passes (post q/dq folding)
self.add_passes(
[
ConvertSplitToSlicePass(),
QuantizeClampArgumentsPass(),
RemoveGetItemPass(),
DecomposeBatchNormNoStatsPass(),
DecomposeLogitPass(),
DecomposeMaskedFillPass(),
DecomposeRoundPass(),
Expand Down
41 changes: 39 additions & 2 deletions backends/arm/_passes/convert_split_to_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,48 @@ def call(self, graph_module: torch.fx.GraphModule):
graph,
self.slice,
(input_node, dim, starts[index], ends[index]),
from_node=node,
)
slice_node.meta = _copy_user_node_qparams(
split_node, output_node, index
)
slice_node.meta = split_node.meta.copy()
slice_node.meta["val"] = slice_node.meta["val"][index]
output_node.replace_all_uses_with(slice_node)
graph.eliminate_dead_code()
graph_module.recompile()
graph_module = super().call(graph_module).graph_module
return PassResult(graph_module, True)


def _copy_user_node_qparams(
split_node: torch.fx.Node, output_node: torch.fx.Node, index: int
) -> dict:
"""
Construct metadata for the slice node that will replace the split output.

Note that output quantization parameters are copied from the user nodes
of the split node. The split node itself does not have output quantization
parameters.

Args:
split_node: The split node being replaced.
output_node: The getitem node that is user of the split node.
index: The index of the output being processed.
Returns:
Updated metadata dictionary for the slice node.
"""

def _select_index(value):
if isinstance(value, (list, tuple)):
return value[index]
return value

meta = split_node.meta.copy()
if "val" in meta:
meta["val"] = _select_index(meta["val"])
if "tensor_meta" in meta:
meta["tensor_meta"] = _select_index(meta["tensor_meta"])
if "input_qparams" in meta:
meta["input_qparams"] = dict(meta["input_qparams"])
if "output_qparams" in meta:
meta["output_qparams"] = dict(output_node.meta["output_qparams"])
return meta
18 changes: 13 additions & 5 deletions backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ class QuantizeClampArgumentsPass(ArmPass):
- Makes sure the min and max values to clamp.default are quantized, if it's a quantized operator.
"""

_passes_required_after: Set[Type[ExportPass]] = {FoldAndAnnotateQParamsPass}
_passes_required_after: Set[Type[ExportPass]] = set()

def call(self, graph_module: GraphModule) -> PassResult:
modified = False
Expand All @@ -346,12 +346,15 @@ def call(self, graph_module: GraphModule) -> PassResult:
}:
continue

# Make sure we have a quantized operator
user = list(n.users)[0]
if user.target not in Q_OPS:
try:
output_qparams = get_output_qparams(n)
except ValueError:
continue
if len(output_qparams) == 0:
continue

qargs = QuantArgs.from_operator(user.target, user.args)
# Qparams are stored per user index; use the first entry.
qargs = next(iter(output_qparams.values()))

if n.target == exir_ops.edge.aten.clamp.default:
# Quantize the min and max arguments of clamp, if they are not None
Expand All @@ -368,4 +371,9 @@ def call(self, graph_module: GraphModule) -> PassResult:

modified = True

if modified:
# Retrace to refresh fake tensor metadata after updating clamp min/max.
graph_module = super().call(graph_module).graph_module
graph_module.recompile()

return PassResult(graph_module, modified)
18 changes: 17 additions & 1 deletion backends/transforms/remove_getitem_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy

import torch
from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.pass_base import ExportPass, PassResult, PROTECTED_KEYS


class RemoveGetItemPass(ExportPass):
Expand Down Expand Up @@ -81,6 +83,8 @@ def call(self, graph_module: torch.fx.GraphModule):
new_max_wd.meta = node.meta.copy()
new_max_wd.meta["val"] = new_max_wd.meta["val"][0]

_copy_node_metadata(node, new_max_wd)

getitem_node.replace_all_uses_with(new_max_wd)

mdule.graph.erase_node(getitem_node)
Expand All @@ -91,3 +95,15 @@ def call(self, graph_module: torch.fx.GraphModule):
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, True)


def _copy_node_metadata(node: torch.fx.Node, new_max_wd: torch.fx.Node):
    """Copy metadata from original node to new node.

    Entries whose keys are in ``PROTECTED_KEYS`` are left untouched on the
    new node. Each remaining value is deep-copied when possible; values that
    refuse to deep-copy are shared by reference as a best-effort fallback.
    """

    for key, value in node.meta.items():
        if key in PROTECTED_KEYS:
            continue
        try:
            cloned = copy.deepcopy(value)
        except Exception:
            # Some metadata objects (e.g. graph-attached handles) are not
            # deep-copyable; fall back to aliasing the original value.
            cloned = value
        new_max_wd.meta[key] = cloned
Loading