Skip to content

Commit ea4b7e0

Browse files
anijain2305 authored and pytorchmergebot committed
1 parent 5c0f474 commit ea4b7e0

File tree

2 files changed

+74
-22
lines changed

2 files changed

+74
-22
lines changed

torch/_inductor/codegen/wrapper.py

Lines changed: 73 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -2884,10 +2884,12 @@ def codegen_partition_call(
28842884
def set_all_partition_names(self, num_partitions: int):
28852885
self.all_partition_names = [f"partition_{idx}" for idx in range(num_partitions)]
28862886

2887-
def codegen_subgraph_call(self, subgraph, outer_inputs, outer_outputs):
2887+
def codegen_subgraph_call_with_flattened_outputs(
2888+
self, subgraph, outer_inputs, outer_flattened_outputs
2889+
):
28882890
# Get the input and output names of the subgraph
2889-
outer_output_names = ", ".join(outer_outputs) + (
2890-
"," if len(outer_outputs) == 1 else ""
2891+
outer_output_names = ", ".join(outer_flattened_outputs) + (
2892+
"," if len(outer_flattened_outputs) == 1 else ""
28912893
)
28922894
outer_input_names = ", ".join(outer_inputs) + (
28932895
"," if len(outer_inputs) == 1 else ""
@@ -2900,13 +2902,20 @@ def codegen_subgraph_call(self, subgraph, outer_inputs, outer_outputs):
29002902
f"({outer_output_names}) = {subgraph.graph.name}({subgraph.graph.name}_args)"
29012903
)
29022904

2903-
def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs):
2904-
# Codegen subgraph by recursively calling the codegen for the subgraph.
2905-
# This lifts the subgraph as a function in the output code.
2906-
if V.graph.aot_mode:
2907-
self.codegen_subgraph_by_inlining(subgraph, outer_inputs, outer_outputs)
2908-
return
2905+
def codegen_subgraph_call(self, subgraph, outer_inputs, outer_buffer_name):
2906+
# Get the input and output names of the subgraph
2907+
outer_input_names = ", ".join(outer_inputs) + (
2908+
"," if len(outer_inputs) == 1 else ""
2909+
)
2910+
2911+
self.writeline(f"{subgraph.graph.name}_args = [{outer_input_names}]")
29092912

2913+
# Call the subgraph launcher function
2914+
self.writeline(
2915+
f"{outer_buffer_name} = {subgraph.graph.name}({subgraph.graph.name}_args)"
2916+
)
2917+
2918+
def codegen_subgraph_common(self, subgraph):
29102919
self.push_codegened_graph(subgraph.graph)
29112920
self.writeline("")
29122921
self.writeline(f"{self.comment} subgraph: {subgraph.name}")
@@ -2925,21 +2934,40 @@ def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs):
29252934
self.already_codegened_subgraphs.add(subgraph.graph.name)
29262935
self.define_subgraph_launcher_fn(subgraph_code.value)
29272936

2928-
self.codegen_subgraph_call(subgraph, outer_inputs, outer_outputs)
2937+
def codegen_subgraph_with_flattened_outputs(
2938+
self, subgraph, outer_inputs, outer_flattened_outputs
2939+
):
2940+
self.codegen_subgraph_common(subgraph)
2941+
self.codegen_subgraph_call_with_flattened_outputs(
2942+
subgraph, outer_inputs, outer_flattened_outputs
2943+
)
2944+
2945+
def codegen_subgraph(self, subgraph, outer_inputs, outer_buffer_name):
2946+
# Codegen subgraph by recursively calling the codegen for the subgraph.
2947+
# This lifts the subgraph as a function in the output code.
2948+
self.codegen_subgraph_common(subgraph)
2949+
self.codegen_subgraph_call(subgraph, outer_inputs, outer_buffer_name)
29292950

29302951
def codegen_invoke_subgraph(self, invoke_subgraph):
29312952
name = invoke_subgraph.get_name()
29322953

29332954
self.writeline(f"{name} = [None] * {len(invoke_subgraph.outputs)}")
29342955
outer_inputs = [buf.codegen_reference() for buf in invoke_subgraph.inputs]
2935-
outer_outputs = [f"{name}[{i}]" for i in range(len(invoke_subgraph.outputs))]
2936-
self.codegen_subgraph(invoke_subgraph.subgraph, outer_inputs, outer_outputs)
2956+
2957+
if V.graph.aot_mode:
2958+
outer_outputs = [
2959+
f"{name}[{i}]" for i in range(len(invoke_subgraph.outputs))
2960+
]
2961+
self.codegen_subgraph_by_inlining(
2962+
invoke_subgraph.subgraph, outer_inputs, outer_outputs
2963+
)
2964+
else:
2965+
self.codegen_subgraph(invoke_subgraph.subgraph, outer_inputs, name)
29372966

29382967
def codegen_conditional(self, conditional):
29392968
name = conditional.get_name()
29402969

29412970
outer_inputs = [buf.codegen_reference() for buf in conditional.operands]
2942-
outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))]
29432971

29442972
predicate = conditional.predicate.codegen_reference()
29452973
if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer):
@@ -2949,11 +2977,24 @@ def codegen_conditional(self, conditional):
29492977
self.writeline(f"{name} = [None] * {len(conditional.outputs)}")
29502978
self.writeline(f"if {predicate}:")
29512979
self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph))
2952-
self.codegen_subgraph(conditional.true_subgraph, outer_inputs, outer_outputs)
2980+
if V.graph.aot_mode:
2981+
outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))]
2982+
self.codegen_subgraph_by_inlining(
2983+
conditional.true_subgraph, outer_inputs, outer_outputs
2984+
)
2985+
else:
2986+
self.codegen_subgraph(conditional.true_subgraph, outer_inputs, name)
2987+
29532988
self.writeline(ExitSubgraphLine(self))
29542989
self.writeline("else:")
29552990
self.writeline(EnterSubgraphLine(self, conditional.false_subgraph.graph))
2956-
self.codegen_subgraph(conditional.false_subgraph, outer_inputs, outer_outputs)
2991+
if V.graph.aot_mode:
2992+
outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))]
2993+
self.codegen_subgraph_by_inlining(
2994+
conditional.false_subgraph, outer_inputs, outer_outputs
2995+
)
2996+
else:
2997+
self.codegen_subgraph(conditional.false_subgraph, outer_inputs, name)
29572998
self.writeline(ExitSubgraphLine(self))
29582999

29593000
def codegen_while_loop(self, while_loop):
@@ -2985,17 +3026,28 @@ def codegen_while_loop(self, while_loop):
29853026

29863027
self.writeline("while True:")
29873028
self.writeline(EnterSubgraphLine(self, while_loop.cond_subgraph.graph))
2988-
self.codegen_subgraph(
2989-
while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
2990-
)
3029+
3030+
if V.graph.aot_mode:
3031+
self.codegen_subgraph_by_inlining(
3032+
while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
3033+
)
3034+
else:
3035+
self.codegen_subgraph_with_flattened_outputs(
3036+
while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs
3037+
)
29913038
self.writeline(
29923039
f"if not {cond_outer_outputs[0]}: break"
29933040
) # condition doesn't hold
29943041
self.writeline(ExitSubgraphLine(self))
29953042
self.writeline(EnterSubgraphLine(self, while_loop.body_subgraph.graph))
2996-
self.codegen_subgraph(
2997-
while_loop.body_subgraph, body_outer_inputs, body_outer_outputs
2998-
)
3043+
if V.graph.aot_mode:
3044+
self.codegen_subgraph_by_inlining(
3045+
while_loop.body_subgraph, body_outer_inputs, body_outer_outputs
3046+
)
3047+
else:
3048+
self.codegen_subgraph_with_flattened_outputs(
3049+
while_loop.body_subgraph, body_outer_inputs, body_outer_outputs
3050+
)
29993051
self.writeline(ExitSubgraphLine(self))
30003052

30013053
@staticmethod

torch/_inductor/ir.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6033,7 +6033,7 @@ def __init__(self, graph: GraphLowering):
60336033
self.graph = graph
60346034
self.name = graph.name
60356035

6036-
wrapper.codegen_subgraph(
6036+
wrapper.codegen_subgraph_with_flattened_outputs(
60376037
CodegenGraph(self.subgraph),
60386038
[*[buffer.get_name() for buffer in self.inputs]],
60396039
[self.name],

0 commit comments

Comments
 (0)