2727from executorch .backends .transforms .remove_clone_ops import RemoveCloneOpsTransform
2828from executorch .exir .dialects ._ops import ops as exir_ops
2929from executorch .exir .dialects .edge ._ops import EdgeOpOverload , EdgeOpOverloadPacket
30- from executorch .exir .pass_base import ExportPass , NodeMetadata , PassResult , ProxyValue
30+ from executorch .exir .pass_base import ExportPass , PassResult
3131from executorch .exir .pass_manager import PassManager , PassType
3232from executorch .exir .passes import dead_code_elimination_pass
33- from executorch .exir .passes .spec_prop_pass import SpecPropPass
34- from torch .fx .node import Argument , Node
33+ from torch .fx .node import Node
3534
3635
3736@register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
@@ -246,7 +245,7 @@ def maybe_remove_or_replace(self, node: Node) -> bool:
246245
247246
248247@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
249- class RemoveNopSelectOpPass (ExportPass ):
248+ class RemoveNopSelectOpPass (RemoveOrReplacePassInterface ):
250249 """
251250 A select op that selects from a dimension that is size 1 can be eliminated
252251 in a few cases. For example,
@@ -273,87 +272,57 @@ class RemoveNopSelectOpPass(ExportPass):
273272 exir_ops .edge .aten .div .Tensor ,
274273 }
275274
276- def __init__ ( self ) -> None :
277- super (). __init__ ()
278- self . op_sizes : dict [ str , tuple [ torch . Size , torch . Size ]] = {}
275+ @ property
276+ def targets ( self ) -> list [ EdgeOpOverload ]:
277+ return [ exir_ops . edge . aten . select_copy . int ]
279278
280- # For select, view, or any op in binary_broadcast_ops, record the shapes of
281- # input and output tensors.
282- def call_operator (
283- self ,
284- op , # pyre-ignore
285- args : tuple [Argument , ...],
286- kwargs : dict [str , Argument ],
287- meta : NodeMetadata ,
288- ) -> ProxyValue :
289- res = super ().call_operator (op , args , kwargs , meta )
290- # Unary ops: input and output
291- if op in {
292- exir_ops .edge .aten .select_copy .int ,
293- exir_ops .edge .aten .view_copy .default ,
294- }:
295- arg0 = cast (ProxyValue , args [0 ])
296- self .op_sizes [res .node .name ] = (arg0 .to_tensor ().shape , meta ["val" ].shape )
297- # Binary ops: two inputs, output shape can be inferred
298- elif op in self .binary_broadcast_ops :
299- arg0 = cast (ProxyValue , args [0 ])
300- arg1 = cast (ProxyValue , args [1 ])
301- self .op_sizes [res .node .name ] = (
302- arg0 .to_tensor ().shape ,
303- arg1 .to_tensor ().shape ,
304- )
305- return res
306-
307- # Eliminate nop select ops. We begin by inspecting the binary_broadcast_ops,
308- # and check if their arg is a select op.
309- def eliminate_nop_select_op (self , graph_module : torch .fx .GraphModule ) -> None :
310- for sel_node in graph_module .graph .nodes :
311- # We are only interested in select ops
312- if sel_node .target != exir_ops .edge .aten .select_copy .int :
313- continue
314- # The shape of the input/output operands for this select op should
315- # have been precomputed.
316- assert sel_node .name in self .op_sizes
317- (sel_in_shape , sel_out_shape ) = self .op_sizes [sel_node .name ]
318- # Get the select dimension
319- sel_dim = (
320- sel_node .args [1 ]
321- if sel_node .args [1 ] >= 0
322- else sel_node .args [1 ] + len (sel_in_shape )
323- )
324- # If the input size along select dimension is not 1, bail.
325- if sel_in_shape [sel_dim ] != 1 :
326- continue
279+ def maybe_remove_or_replace (self , node : Node ) -> bool :
280+ # Get the select input node and shapes
281+ sel_in_node = node .args [0 ]
282+ assert isinstance (sel_in_node , Node )
327283
328- # Get all the users of the select op that are either view, or
329- # binary_broadcast_ops.
330- users = [x for x in list (sel_node .users .keys ()) if x .name in self .op_sizes ]
331- sel_in = sel_node .args [0 ]
332-
333- # Iterate over the users of select op, and remove the use of the
334- # select op in the user if feasible.
335- for node in users :
336- args = list (node .args )
337- for idx , sel_arg in enumerate (args ):
338- # Check if the arg is the select op
339- if sel_arg != sel_node :
340- continue
341- # If the input of select has the same shape as the other arg
342- # of the binary op, the select op can be bypassed.
343- if sel_in_shape == self .op_sizes [node .name ][(idx + 1 ) % 2 ]:
344- args [idx ] = sel_in
345- # update the node's args
346- node .args = tuple (args )
347-
348- graph_module .recompile ()
349- graph_module .graph .eliminate_dead_code ()
284+ sel_in_shape = sel_in_node .meta ["val" ].shape
350285
351- def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
352- result = SpecPropPass ()(graph_module )
353- assert result is not None
354- result = super ().call (result .graph_module )
355- self .eliminate_nop_select_op (result .graph_module )
356- return result
286+ # Get the select dimension
287+ sel_dim = cast (int , node .args [1 ])
288+ if sel_dim < 0 :
289+ sel_dim += len (sel_in_shape )
290+
291+ # If the input size along select dimension is not 1, bail.
292+ if sel_in_shape [sel_dim ] != 1 :
293+ return False
294+
295+ # Check if ALL users of the select op can be bypassed.
296+ # A user can be bypassed if:
297+ # 1. It's a view_copy op, OR
298+ # 2. It's a binary_broadcast_op and the other operand has the same shape as sel_in
299+ for user_node in node .users .keys ():
300+ can_bypass = False
301+
302+ # View ops can always bypass the select
303+ if user_node .target == exir_ops .edge .aten .view_copy .default :
304+ can_bypass = True
305+ # For binary ops, check if the other operand has the same shape
306+ elif user_node .target in self .binary_broadcast_ops :
307+ # Find which argument is the select node
308+ for idx , arg in enumerate (user_node .args ):
309+ if arg == node :
310+ # Get the other argument
311+ other_idx = (idx + 1 ) % 2
312+ other_arg = user_node .args [other_idx ]
313+ if isinstance (other_arg , Node ):
314+ other_shape = other_arg .meta ["val" ].shape
315+ if sel_in_shape == other_shape :
316+ can_bypass = True
317+ break
318+
319+ # If any user cannot be bypassed, we can't remove this select
320+ if not can_bypass :
321+ return False
322+
323+ # All users can be bypassed, so replace the select node with its input
324+ node .replace_all_uses_with (sel_in_node )
325+ return True
357326
358327
359328@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
@@ -547,10 +516,7 @@ class Subgraph:
547516 def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
548517 subgraphs_found : list [RemovePermutesAroundElementwiseOps .Subgraph ] = []
549518 processed_nodes : set [torch .fx .Node ] = set ()
550- for node in graph_module .graph .nodes :
551- if node .target != exir_ops .edge .aten .permute_copy .default :
552- continue
553-
519+ for node in graph_module .graph .find_nodes (op = "call_function" , target = exir_ops .edge .aten .permute_copy .default ):
554520 start_permute = self .get_permutation (node )
555521 # Expected end permutation for the subgraph.
556522 end_permute = [start_permute .index (i ) for i in range (len (start_permute ))]
@@ -566,13 +532,18 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
566532 for node in subgraph .nodes :
567533 processed_nodes .add (node )
568534
535+ modified = False
569536 for subgraph in subgraphs_found :
570537 self .permute_subgraph (subgraph )
538+ modified = True
571539
572- graph_module .graph .eliminate_dead_code ()
573- graph_module .recompile ()
540+ if modified :
541+ graph_module .graph .eliminate_dead_code ()
542+ graph_module .recompile ()
574543
575- return super ().call (graph_module )
544+ return super ().call (graph_module )
545+
546+ return PassResult (graph_module , False )
576547
577548 def visit (
578549 self ,
0 commit comments