Skip to content

Commit 0a7271f

Browse files
eigen-kfacebook-github-bot
authored andcommitted
Support eliminate_quant_dequant_pairs flag (pytorch#16029)
Summary: This change adds support for the `eliminate_quant_dequant_pairs` flag in the remove clone ops transform. This flag allows users to control whether quantization-dequantization pairs should be eliminated during the clone removal optimization pass. We need this to control the functionality of this pass in the subsequent diffs. Differential Revision: D88092217
1 parent 93bf861 commit 0a7271f

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

backends/cadence/aot/tests/test_remove_ops_passes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
RemoveAliasCopyOpPass,
2222
RemoveBranchedQuantDequant,
2323
RemoveCatFromSliceCopyPass,
24-
RemoveCloneOpPass,
24+
RemoveCloneOpsTransformImported,
2525
RemoveContiguousOpPass,
2626
RemoveDetachCopyPass,
2727
RemoveNopAddOpPass,
@@ -241,7 +241,7 @@ def test_remove_clone(self) -> None:
241241
clone = builder.call_operator(op=exir_ops.edge.aten.clone.default, args=(x,))
242242
builder.output([clone])
243243
original = builder.get_graph_module()
244-
p = RemoveCloneOpPass()
244+
p = RemoveCloneOpsTransformImported()
245245
graph_after_passes = cast(PassResult, p(original)).graph_module
246246
self.assertEqual(
247247
count_node(graph_after_passes, torch.ops.aten.clone.default), 0

backends/transforms/remove_clone_ops.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,18 @@ class RemoveCloneOpsTransform(ExportPass):
2525
exir_ops.edge.dim_order_ops._clone_dim_order.default,
2626
}
2727

28-
def __init__(self, preserve_input_output_copies: bool = False) -> None:
28+
def __init__(
29+
self,
30+
preserve_input_output_copies: bool = False,
31+
eliminate_quant_dequant_pairs: bool = True,
32+
) -> None:
2933
super().__init__()
3034
self._preserve_input_output_copies = preserve_input_output_copies
35+
self._eliminate_quant_dequant_pairs = eliminate_quant_dequant_pairs
3136

32-
def _remove(self, graph_module: torch.fx.GraphModule) -> None:
37+
def _remove(self, graph_module: torch.fx.GraphModule) -> bool:
3338
dequant_nodes = []
39+
modified = False
3440

3541
for n in graph_module.graph.nodes:
3642
if n.target not in self.clone_ops:
@@ -44,20 +50,26 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None:
4450
if self._is_input_output_copy(n) and self._preserve_input_output_copies:
4551
continue
4652

53+
modified = True
4754
to_be_removed = n
4855
for user_n in list(n.users.keys()):
4956
user_n.replace_input_with(n, n.args[0])
5057
if n.args[0].target in _DEQUANT_OPS:
5158
dequant_nodes += [n.args[0]]
5259
graph_module.graph.erase_node(to_be_removed)
5360

54-
eliminate_dq_q(graph_module, dequant_nodes)
61+
if self._eliminate_quant_dequant_pairs:
62+
eliminate_dq_q(graph_module, dequant_nodes)
63+
64+
return modified
5565

5666
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
57-
self._remove(graph_module)
58-
graph_module.recompile()
59-
dead_code_elimination_pass(graph_module)
60-
return PassResult(graph_module, True)
67+
if self._remove(graph_module):
68+
graph_module.recompile()
69+
dead_code_elimination_pass(graph_module)
70+
return PassResult(graph_module, True)
71+
else:
72+
return PassResult(graph_module, False)
6173

6274
def _is_non_identity_clone(self, node: torch.fx.Node) -> bool:
6375
"""Return True if clone has modified memory layout or dim order."""

0 commit comments

Comments
 (0)