149 changes: 60 additions & 89 deletions backends/cadence/aot/remove_ops.py
@@ -27,11 +27,10 @@
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.pass_manager import PassManager, PassType
from executorch.exir.passes import dead_code_elimination_pass
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from torch.fx.node import Argument, Node
from torch.fx.node import Node


@register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -246,7 +245,7 @@ def maybe_remove_or_replace(self, node: Node) -> bool:


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class RemoveNopSelectOpPass(ExportPass):
class RemoveNopSelectOpPass(RemoveOrReplacePassInterface):
"""
A select op that selects from a dimension that is size 1 can be eliminated
in a few cases. For example,
@@ -273,87 +272,57 @@ class RemoveNopSelectOpPass(ExportPass):
exir_ops.edge.aten.div.Tensor,
}
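# A minimal standalone sketch in plain PyTorch (not part of this pass; shapes
# are made up for illustration) of the case the docstring above alludes to:
# a select along a size-1 dimension that feeds a broadcasting binary op is a
# no-op, because broadcasting restores the dropped singleton dimension.
import torch

x = torch.randn(1, 4)            # select input: size 1 along dim 0
y = torch.randn(1, 4)            # other operand of the binary op, same shape as x

with_select = x.select(0, 0) + y     # (4,) broadcast against (1, 4) -> (1, 4)
without_select = x + y               # (1, 4) + (1, 4) -> (1, 4)

assert with_select.shape == without_select.shape == torch.Size([1, 4])
assert torch.equal(with_select, without_select)   # identical values either way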

def __init__(self) -> None:
super().__init__()
self.op_sizes: dict[str, tuple[torch.Size, torch.Size]] = {}
@property
def targets(self) -> list[EdgeOpOverload]:
return [exir_ops.edge.aten.select_copy.int]

# For select, view, or any op in binary_broadcast_ops, record the shapes of
# input and output tensors.
def call_operator(
self,
op, # pyre-ignore
args: tuple[Argument, ...],
kwargs: dict[str, Argument],
meta: NodeMetadata,
) -> ProxyValue:
res = super().call_operator(op, args, kwargs, meta)
# Unary ops: input and output
if op in {
exir_ops.edge.aten.select_copy.int,
exir_ops.edge.aten.view_copy.default,
}:
arg0 = cast(ProxyValue, args[0])
self.op_sizes[res.node.name] = (arg0.to_tensor().shape, meta["val"].shape)
# Binary ops: two inputs, output shape can be inferred
elif op in self.binary_broadcast_ops:
arg0 = cast(ProxyValue, args[0])
arg1 = cast(ProxyValue, args[1])
self.op_sizes[res.node.name] = (
arg0.to_tensor().shape,
arg1.to_tensor().shape,
)
return res

# Eliminate nop select ops. We begin by inspecting the binary_broadcast_ops,
# and check if their arg is a select op.
def eliminate_nop_select_op(self, graph_module: torch.fx.GraphModule) -> None:
for sel_node in graph_module.graph.nodes:
# We are only interested in select ops
if sel_node.target != exir_ops.edge.aten.select_copy.int:
continue
# The shape of the input/output operands for this select op should
# have been precomputed.
assert sel_node.name in self.op_sizes
(sel_in_shape, sel_out_shape) = self.op_sizes[sel_node.name]
# Get the select dimension
sel_dim = (
sel_node.args[1]
if sel_node.args[1] >= 0
else sel_node.args[1] + len(sel_in_shape)
)
# If the input size along select dimension is not 1, bail.
if sel_in_shape[sel_dim] != 1:
continue
def maybe_remove_or_replace(self, node: Node) -> bool:
# Get the select input node and shapes
sel_in_node = node.args[0]
assert isinstance(sel_in_node, Node)

# Get all the users of the select op that are either view, or
# binary_broadcast_ops.
users = [x for x in list(sel_node.users.keys()) if x.name in self.op_sizes]
sel_in = sel_node.args[0]

# Iterate over the users of select op, and remove the use of the
# select op in the user if feasible.
for node in users:
args = list(node.args)
for idx, sel_arg in enumerate(args):
# Check if the arg is the select op
if sel_arg != sel_node:
continue
# If the input of select has the same shape as the other arg
# of the binary op, the select op can be bypassed.
if sel_in_shape == self.op_sizes[node.name][(idx + 1) % 2]:
args[idx] = sel_in
# update the node's args
node.args = tuple(args)

graph_module.recompile()
graph_module.graph.eliminate_dead_code()
sel_in_shape = sel_in_node.meta["val"].shape

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
result = SpecPropPass()(graph_module)
assert result is not None
result = super().call(result.graph_module)
self.eliminate_nop_select_op(result.graph_module)
return result
# Get the select dimension
sel_dim = cast(int, node.args[1])
if sel_dim < 0:
sel_dim += len(sel_in_shape)

# If the input size along select dimension is not 1, bail.
if sel_in_shape[sel_dim] != 1:
return False

# Check if ALL users of the select op can be bypassed.
# A user can be bypassed if:
# 1. It's a view_copy op, OR
# 2. It's a binary_broadcast_op and the other operand has the same shape as sel_in
for user_node in node.users.keys():
can_bypass = False

# View ops can always bypass the select
if user_node.target == exir_ops.edge.aten.view_copy.default:
can_bypass = True
# For binary ops, check if the other operand has the same shape
elif user_node.target in self.binary_broadcast_ops:
# Find which argument is the select node
for idx, arg in enumerate(user_node.args):
if arg == node:
# Get the other argument
other_idx = (idx + 1) % 2
other_arg = user_node.args[other_idx]
if isinstance(other_arg, Node):
other_shape = other_arg.meta["val"].shape
if sel_in_shape == other_shape:
can_bypass = True
break

# If any user cannot be bypassed, we can't remove this select
if not can_bypass:
return False

# All users can be bypassed, so replace the select node with its input
node.replace_all_uses_with(sel_in_node)
return True
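# A quick numerical sanity check for the view_copy case above (standalone
# PyTorch, independent of the pass; shapes assumed for illustration): the
# select only drops a singleton dimension, so the element count is unchanged
# and the view's target shape stays valid when the select is bypassed.
import torch

x = torch.randn(1, 2, 3)                  # select input, shape (1, 2, 3)
via_select = x.select(0, 0).view(6)       # (2, 3) -> (6,)
bypassed = x.view(6)                      # (1, 2, 3) -> (6,), same six elements
assert torch.equal(via_select, bypassed)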


@register_cadence_pass(CadencePassAttribute(opt_level=1))
@@ -547,10 +516,7 @@ class Subgraph:
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
subgraphs_found: list[RemovePermutesAroundElementwiseOps.Subgraph] = []
processed_nodes: set[torch.fx.Node] = set()
for node in graph_module.graph.nodes:
if node.target != exir_ops.edge.aten.permute_copy.default:
continue

for node in graph_module.graph.find_nodes(op="call_function", target=exir_ops.edge.aten.permute_copy.default):
start_permute = self.get_permutation(node)
# Expected end permutation for the subgraph.
end_permute = [start_permute.index(i) for i in range(len(start_permute))]
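# A quick check of the inverse-permutation expression above (standalone sketch
# with made-up values, not code from the PR): end_permute is the inverse of
# start_permute, so applying one after the other restores the original order.
start_permute = [2, 0, 1]
end_permute = [start_permute.index(i) for i in range(len(start_permute))]
assert end_permute == [1, 2, 0]

dims = ["N", "C", "H"]
permuted = [dims[i] for i in start_permute]    # ['H', 'N', 'C']
restored = [permuted[i] for i in end_permute]  # back to ['N', 'C', 'H']
assert restored == dims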
@@ -566,13 +532,18 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
for node in subgraph.nodes:
processed_nodes.add(node)

modified = False
for subgraph in subgraphs_found:
self.permute_subgraph(subgraph)
modified = True

graph_module.graph.eliminate_dead_code()
graph_module.recompile()
if modified:
graph_module.graph.eliminate_dead_code()
graph_module.recompile()

return super().call(graph_module)
return super().call(graph_module)

return PassResult(graph_module, False)
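# The rewritten loop above uses torch.fx.Graph.find_nodes to fetch only the
# permute_copy call_function nodes instead of scanning every node. A standalone
# toy sketch of that API (assumes a PyTorch version that provides
# Graph.find_nodes; the Toy module is made up for illustration):
import torch
from torch import fx

class Toy(torch.nn.Module):
    def forward(self, x):
        y = torch.add(x, 1)
        return torch.mul(y, 2)

gm = fx.symbolic_trace(Toy())
adds = gm.graph.find_nodes(op="call_function", target=torch.add)
assert [n.target for n in adds] == [torch.add]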

def visit(
self,