@@ -2884,10 +2884,12 @@ def codegen_partition_call(
28842884 def set_all_partition_names (self , num_partitions : int ):
28852885 self .all_partition_names = [f"partition_{ idx } " for idx in range (num_partitions )]
28862886
2887- def codegen_subgraph_call (self , subgraph , outer_inputs , outer_outputs ):
2887+ def codegen_subgraph_call_with_flattened_outputs (
2888+ self , subgraph , outer_inputs , outer_flattened_outputs
2889+ ):
28882890 # Get the input and output names of the subgraph
2889- outer_output_names = ", " .join (outer_outputs ) + (
2890- "," if len (outer_outputs ) == 1 else ""
2891+ outer_output_names = ", " .join (outer_flattened_outputs ) + (
2892+ "," if len (outer_flattened_outputs ) == 1 else ""
28912893 )
28922894 outer_input_names = ", " .join (outer_inputs ) + (
28932895 "," if len (outer_inputs ) == 1 else ""
@@ -2900,13 +2902,20 @@ def codegen_subgraph_call(self, subgraph, outer_inputs, outer_outputs):
29002902 f"({ outer_output_names } ) = { subgraph .graph .name } ({ subgraph .graph .name } _args)"
29012903 )
29022904
2903- def codegen_subgraph (self , subgraph , outer_inputs , outer_outputs ):
2904- # Codegen subgraph by recursively calling the codegen for the subgraph.
2905- # This lifts the subgraph as a function in the output code.
2906- if V .graph .aot_mode :
2907- self .codegen_subgraph_by_inlining (subgraph , outer_inputs , outer_outputs )
2908- return
2905+ def codegen_subgraph_call (self , subgraph , outer_inputs , outer_buffer_name ):
2906+ # Get the input and output names of the subgraph
2907+ outer_input_names = ", " .join (outer_inputs ) + (
2908+ "," if len (outer_inputs ) == 1 else ""
2909+ )
2910+
2911+ self .writeline (f"{ subgraph .graph .name } _args = [{ outer_input_names } ]" )
29092912
2913+ # Call the subgraph launcher function
2914+ self .writeline (
2915+ f"{ outer_buffer_name } = { subgraph .graph .name } ({ subgraph .graph .name } _args)"
2916+ )
2917+
2918+ def codegen_subgraph_common (self , subgraph ):
29102919 self .push_codegened_graph (subgraph .graph )
29112920 self .writeline ("" )
29122921 self .writeline (f"{ self .comment } subgraph: { subgraph .name } " )
@@ -2925,21 +2934,40 @@ def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs):
29252934 self .already_codegened_subgraphs .add (subgraph .graph .name )
29262935 self .define_subgraph_launcher_fn (subgraph_code .value )
29272936
2928- self .codegen_subgraph_call (subgraph , outer_inputs , outer_outputs )
2937+ def codegen_subgraph_with_flattened_outputs (
2938+ self , subgraph , outer_inputs , outer_flattened_outputs
2939+ ):
2940+ self .codegen_subgraph_common (subgraph )
2941+ self .codegen_subgraph_call_with_flattened_outputs (
2942+ subgraph , outer_inputs , outer_flattened_outputs
2943+ )
2944+
2945+ def codegen_subgraph (self , subgraph , outer_inputs , outer_buffer_name ):
2946+ # Codegen subgraph by recursively calling the codegen for the subgraph.
2947+ # This lifts the subgraph as a function in the output code.
2948+ self .codegen_subgraph_common (subgraph )
2949+ self .codegen_subgraph_call (subgraph , outer_inputs , outer_buffer_name )
29292950
29302951 def codegen_invoke_subgraph (self , invoke_subgraph ):
29312952 name = invoke_subgraph .get_name ()
29322953
29332954 self .writeline (f"{ name } = [None] * { len (invoke_subgraph .outputs )} " )
29342955 outer_inputs = [buf .codegen_reference () for buf in invoke_subgraph .inputs ]
2935- outer_outputs = [f"{ name } [{ i } ]" for i in range (len (invoke_subgraph .outputs ))]
2936- self .codegen_subgraph (invoke_subgraph .subgraph , outer_inputs , outer_outputs )
2956+
2957+ if V .graph .aot_mode :
2958+ outer_outputs = [
2959+ f"{ name } [{ i } ]" for i in range (len (invoke_subgraph .outputs ))
2960+ ]
2961+ self .codegen_subgraph_by_inlining (
2962+ invoke_subgraph .subgraph , outer_inputs , outer_outputs
2963+ )
2964+ else :
2965+ self .codegen_subgraph (invoke_subgraph .subgraph , outer_inputs , name )
29372966
29382967 def codegen_conditional (self , conditional ):
29392968 name = conditional .get_name ()
29402969
29412970 outer_inputs = [buf .codegen_reference () for buf in conditional .operands ]
2942- outer_outputs = [f"{ name } [{ i } ]" for i in range (len (conditional .outputs ))]
29432971
29442972 predicate = conditional .predicate .codegen_reference ()
29452973 if not isinstance (conditional .predicate , ir .ShapeAsConstantBuffer ):
@@ -2949,11 +2977,24 @@ def codegen_conditional(self, conditional):
29492977 self .writeline (f"{ name } = [None] * { len (conditional .outputs )} " )
29502978 self .writeline (f"if { predicate } :" )
29512979 self .writeline (EnterSubgraphLine (self , conditional .true_subgraph .graph ))
2952- self .codegen_subgraph (conditional .true_subgraph , outer_inputs , outer_outputs )
2980+ if V .graph .aot_mode :
2981+ outer_outputs = [f"{ name } [{ i } ]" for i in range (len (conditional .outputs ))]
2982+ self .codegen_subgraph_by_inlining (
2983+ conditional .true_subgraph , outer_inputs , outer_outputs
2984+ )
2985+ else :
2986+ self .codegen_subgraph (conditional .true_subgraph , outer_inputs , name )
2987+
29532988 self .writeline (ExitSubgraphLine (self ))
29542989 self .writeline ("else:" )
29552990 self .writeline (EnterSubgraphLine (self , conditional .false_subgraph .graph ))
2956- self .codegen_subgraph (conditional .false_subgraph , outer_inputs , outer_outputs )
2991+ if V .graph .aot_mode :
2992+ outer_outputs = [f"{ name } [{ i } ]" for i in range (len (conditional .outputs ))]
2993+ self .codegen_subgraph_by_inlining (
2994+ conditional .false_subgraph , outer_inputs , outer_outputs
2995+ )
2996+ else :
2997+ self .codegen_subgraph (conditional .false_subgraph , outer_inputs , name )
29572998 self .writeline (ExitSubgraphLine (self ))
29582999
29593000 def codegen_while_loop (self , while_loop ):
@@ -2985,17 +3026,28 @@ def codegen_while_loop(self, while_loop):
29853026
29863027 self .writeline ("while True:" )
29873028 self .writeline (EnterSubgraphLine (self , while_loop .cond_subgraph .graph ))
2988- self .codegen_subgraph (
2989- while_loop .cond_subgraph , cond_outer_inputs , cond_outer_outputs
2990- )
3029+
3030+ if V .graph .aot_mode :
3031+ self .codegen_subgraph_by_inlining (
3032+ while_loop .cond_subgraph , cond_outer_inputs , cond_outer_outputs
3033+ )
3034+ else :
3035+ self .codegen_subgraph_with_flattened_outputs (
3036+ while_loop .cond_subgraph , cond_outer_inputs , cond_outer_outputs
3037+ )
29913038 self .writeline (
29923039 f"if not { cond_outer_outputs [0 ]} : break"
29933040 ) # condition doesn't hold
29943041 self .writeline (ExitSubgraphLine (self ))
29953042 self .writeline (EnterSubgraphLine (self , while_loop .body_subgraph .graph ))
2996- self .codegen_subgraph (
2997- while_loop .body_subgraph , body_outer_inputs , body_outer_outputs
2998- )
3043+ if V .graph .aot_mode :
3044+ self .codegen_subgraph_by_inlining (
3045+ while_loop .body_subgraph , body_outer_inputs , body_outer_outputs
3046+ )
3047+ else :
3048+ self .codegen_subgraph_with_flattened_outputs (
3049+ while_loop .body_subgraph , body_outer_inputs , body_outer_outputs
3050+ )
29993051 self .writeline (ExitSubgraphLine (self ))
30003052
30013053 @staticmethod
0 commit comments