Skip to content

Commit 1544f43

Browse files
Andrew Grebenisan authored and facebook-github-bot committed
Update some remove ops passes to correctly set their modified bit
Summary: Updated the following passes to correctly set their modified bit: RemoveNopSelectOpPass and RemovePermutesAroundElementwiseOps. Differential Revision: D87895473
1 parent 83f8914 commit 1544f43

File tree

2 files changed

+236
-93
lines changed

2 files changed

+236
-93
lines changed

backends/cadence/aot/remove_ops.py

Lines changed: 60 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,10 @@
2727
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
2828
from executorch.exir.dialects._ops import ops as exir_ops
2929
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
30-
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
30+
from executorch.exir.pass_base import ExportPass, PassResult
3131
from executorch.exir.pass_manager import PassManager, PassType
3232
from executorch.exir.passes import dead_code_elimination_pass
33-
from executorch.exir.passes.spec_prop_pass import SpecPropPass
34-
from torch.fx.node import Argument, Node
33+
from torch.fx.node import Node
3534

3635

3736
@register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -246,7 +245,7 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
246245

247246

248247
@register_cadence_pass(CadencePassAttribute(opt_level=1))
249-
class RemoveNopSelectOpPass(ExportPass):
248+
class RemoveNopSelectOpPass(RemoveOrReplacePassInterface):
250249
"""
251250
A select op that selects from a dimension that is size 1 can be eliminated
252251
in a few cases. For example,
@@ -273,87 +272,57 @@ class RemoveNopSelectOpPass(ExportPass):
273272
exir_ops.edge.aten.div.Tensor,
274273
}
275274

276-
def __init__(self) -> None:
277-
super().__init__()
278-
self.op_sizes: dict[str, tuple[torch.Size, torch.Size]] = {}
275+
@property
276+
def targets(self) -> list[EdgeOpOverload]:
277+
return [exir_ops.edge.aten.select_copy.int]
279278

280-
# For select, view, or any op in binary_broadcast_ops, record the shapes of
281-
# input and output tensors.
282-
def call_operator(
283-
self,
284-
op, # pyre-ignore
285-
args: tuple[Argument, ...],
286-
kwargs: dict[str, Argument],
287-
meta: NodeMetadata,
288-
) -> ProxyValue:
289-
res = super().call_operator(op, args, kwargs, meta)
290-
# Unary ops: input and output
291-
if op in {
292-
exir_ops.edge.aten.select_copy.int,
293-
exir_ops.edge.aten.view_copy.default,
294-
}:
295-
arg0 = cast(ProxyValue, args[0])
296-
self.op_sizes[res.node.name] = (arg0.to_tensor().shape, meta["val"].shape)
297-
# Binary ops: two inputs, output shape can be inferred
298-
elif op in self.binary_broadcast_ops:
299-
arg0 = cast(ProxyValue, args[0])
300-
arg1 = cast(ProxyValue, args[1])
301-
self.op_sizes[res.node.name] = (
302-
arg0.to_tensor().shape,
303-
arg1.to_tensor().shape,
304-
)
305-
return res
306-
307-
# Eliminate nop select ops. We begin by inspecting the binary_broadcast_ops,
308-
# and check if their arg is a select op.
309-
def eliminate_nop_select_op(self, graph_module: torch.fx.GraphModule) -> None:
310-
for sel_node in graph_module.graph.nodes:
311-
# We are only interested in select ops
312-
if sel_node.target != exir_ops.edge.aten.select_copy.int:
313-
continue
314-
# The shape of the input/output operands for this select op should
315-
# have been precomputed.
316-
assert sel_node.name in self.op_sizes
317-
(sel_in_shape, sel_out_shape) = self.op_sizes[sel_node.name]
318-
# Get the select dimension
319-
sel_dim = (
320-
sel_node.args[1]
321-
if sel_node.args[1] >= 0
322-
else sel_node.args[1] + len(sel_in_shape)
323-
)
324-
# If the input size along select dimension is not 1, bail.
325-
if sel_in_shape[sel_dim] != 1:
326-
continue
279+
def maybe_remove_or_replace(self, node: Node) -> bool:
280+
# Get the select input node and shapes
281+
sel_in_node = node.args[0]
282+
assert isinstance(sel_in_node, Node)
327283

328-
# Get all the users of the select op that are either view, or
329-
# binary_broadcast_ops.
330-
users = [x for x in list(sel_node.users.keys()) if x.name in self.op_sizes]
331-
sel_in = sel_node.args[0]
332-
333-
# Iterate over the users of select op, and remove the use of the
334-
# select op in the user if feasible.
335-
for node in users:
336-
args = list(node.args)
337-
for idx, sel_arg in enumerate(args):
338-
# Check if the arg is the select op
339-
if sel_arg != sel_node:
340-
continue
341-
# If the input of select has the same shape as the other arg
342-
# of the binary op, the select op can be bypassed.
343-
if sel_in_shape == self.op_sizes[node.name][(idx + 1) % 2]:
344-
args[idx] = sel_in
345-
# update the node's args
346-
node.args = tuple(args)
347-
348-
graph_module.recompile()
349-
graph_module.graph.eliminate_dead_code()
284+
sel_in_shape = sel_in_node.meta["val"].shape
350285

351-
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
352-
result = SpecPropPass()(graph_module)
353-
assert result is not None
354-
result = super().call(result.graph_module)
355-
self.eliminate_nop_select_op(result.graph_module)
356-
return result
286+
# Get the select dimension
287+
sel_dim = cast(int, node.args[1])
288+
if sel_dim < 0:
289+
sel_dim += len(sel_in_shape)
290+
291+
# If the input size along select dimension is not 1, bail.
292+
if sel_in_shape[sel_dim] != 1:
293+
return False
294+
295+
# Check if ALL users of the select op can be bypassed.
296+
# A user can be bypassed if:
297+
# 1. It's a view_copy op, OR
298+
# 2. It's a binary_broadcast_op and the other operand has the same shape as sel_in
299+
for user_node in node.users.keys():
300+
can_bypass = False
301+
302+
# View ops can always bypass the select
303+
if user_node.target == exir_ops.edge.aten.view_copy.default:
304+
can_bypass = True
305+
# For binary ops, check if the other operand has the same shape
306+
elif user_node.target in self.binary_broadcast_ops:
307+
# Find which argument is the select node
308+
for idx, arg in enumerate(user_node.args):
309+
if arg == node:
310+
# Get the other argument
311+
other_idx = (idx + 1) % 2
312+
other_arg = user_node.args[other_idx]
313+
if isinstance(other_arg, Node):
314+
other_shape = other_arg.meta["val"].shape
315+
if sel_in_shape == other_shape:
316+
can_bypass = True
317+
break
318+
319+
# If any user cannot be bypassed, we can't remove this select
320+
if not can_bypass:
321+
return False
322+
323+
# All users can be bypassed, so replace the select node with its input
324+
node.replace_all_uses_with(sel_in_node)
325+
return True
357326

358327

359328
@register_cadence_pass(CadencePassAttribute(opt_level=1))
@@ -547,10 +516,7 @@ class Subgraph:
547516
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
548517
subgraphs_found: list[RemovePermutesAroundElementwiseOps.Subgraph] = []
549518
processed_nodes: set[torch.fx.Node] = set()
550-
for node in graph_module.graph.nodes:
551-
if node.target != exir_ops.edge.aten.permute_copy.default:
552-
continue
553-
519+
for node in graph_module.graph.find_nodes(op="call_function", target=exir_ops.edge.aten.permute_copy.default):
554520
start_permute = self.get_permutation(node)
555521
# Expected end permutation for the subgraph.
556522
end_permute = [start_permute.index(i) for i in range(len(start_permute))]
@@ -566,13 +532,18 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
566532
for node in subgraph.nodes:
567533
processed_nodes.add(node)
568534

535+
modified = False
569536
for subgraph in subgraphs_found:
570537
self.permute_subgraph(subgraph)
538+
modified = True
571539

572-
graph_module.graph.eliminate_dead_code()
573-
graph_module.recompile()
540+
if modified:
541+
graph_module.graph.eliminate_dead_code()
542+
graph_module.recompile()
574543

575-
return super().call(graph_module)
544+
return super().call(graph_module)
545+
546+
return PassResult(graph_module, False)
576547

577548
def visit(
578549
self,

0 commit comments

Comments (0)