@@ -58,19 +58,37 @@ def forward(self, x_raw, h, c):
58
58
input_h = torch .ones ([1 , 32 ])
59
59
input_c = torch .ones ([1 , 32 ])
60
60
61
+ pattern_lstm_conv_lifted = (
62
+ exir .capture (
63
+ LSTMConvPattern (),
64
+ (input_x , input_h , input_c ),
65
+ exir .CaptureConfig (pt2_mode = True , enable_aot = True ),
66
+ )
67
+ .to_edge ()
68
+ .exported_program .graph_module
69
+ )
61
70
pattern_lstm_conv = (
62
71
exir .capture (
63
72
LSTMConvPattern (),
64
73
(input_x , input_h , input_c ),
65
74
exir .CaptureConfig (pt2_mode = True ),
66
75
)
67
- .to_edge (exir . EdgeCompileConfig ( _check_ir_validity = False ) )
76
+ .to_edge ()
68
77
.exported_program .graph_module
69
78
)
70
79
71
80
def sub (x , y ):
72
81
return torch .sub (x , y )
73
82
83
+ pattern_sub_lifted = (
84
+ exir .capture (
85
+ sub ,
86
+ (input_x , input_h ),
87
+ exir .CaptureConfig (pt2_mode = True , enable_aot = True , _unlift = False ),
88
+ )
89
+ .to_edge (exir .EdgeCompileConfig (_use_edge_ops = True ))
90
+ .exported_program .graph_module
91
+ )
74
92
pattern_sub = (
75
93
exir .capture (
76
94
sub ,
@@ -80,7 +98,12 @@ def sub(x, y):
80
98
.to_edge ()
81
99
.exported_program .graph_module
82
100
)
83
- self .patterns = [pattern_lstm_conv .graph , pattern_sub .graph ]
101
+ self .patterns = [
102
+ pattern_lstm_conv_lifted .graph ,
103
+ pattern_lstm_conv .graph ,
104
+ pattern_sub_lifted .graph ,
105
+ pattern_sub .graph ,
106
+ ]
84
107
85
108
backend_id = QnnBackend .__name__
86
109
self .delegation_spec = DelegationSpec (backend_id , [])
@@ -145,28 +168,18 @@ def generate_partition_list(self, graph_module) -> List[Partition]:
145
168
]
146
169
147
170
"""
148
- partitions_from_all_pattern = [
149
- generate_pattern_op_partitions (graph_module , patterns = [pattern ])
150
- for pattern in self .patterns
151
- ]
152
-
153
- # Check if all partitions are exclusive, this partitions don't support inclusive partitions.
154
- is_exclusive = self .is_exclusive (partitions_from_all_pattern )
155
-
156
- assert (
157
- is_exclusive
158
- ), "There exists inclusive partitions. Currently the fuse method only handle exclusive partitions."
171
+ partitions_from_all_pattern = generate_pattern_op_partitions (
172
+ graph_module , self .patterns
173
+ )
159
174
160
175
# Assign a unique id for each partition
161
176
partition_id = 0
162
177
163
- # If want to support inclusive partitions, the logic can be done here to merge partitions etc.
164
178
flat_proposed_partitions_with_unique_id = []
165
- for partitions_from_one_pattern in partitions_from_all_pattern :
166
- for partition in partitions_from_one_pattern :
167
- partition .id = partition_id
168
- flat_proposed_partitions_with_unique_id .append (partition )
169
- partition_id += 1
179
+ for partition in partitions_from_all_pattern :
180
+ partition .id = partition_id
181
+ flat_proposed_partitions_with_unique_id .append (partition )
182
+ partition_id += 1
170
183
171
184
return flat_proposed_partitions_with_unique_id
172
185
@@ -213,16 +226,28 @@ def forward(self, x_raw, h, c):
213
226
input_h = torch .ones ([1 , 32 ])
214
227
input_c = torch .ones ([1 , 32 ])
215
228
216
- pattern_lstm_conv = (
229
+ pattern_lstm_conv_lifted = (
217
230
exir .capture (
218
231
LSTMConvPattern (),
219
232
(input_x , input_h , input_c ),
220
- exir .CaptureConfig (pt2_mode = True , enable_aot = True , _unlift = False ),
233
+ exir .CaptureConfig (pt2_mode = True , enable_aot = True ),
234
+ )
235
+ .to_edge ()
236
+ .exported_program .graph_module
237
+ )
238
+ pattern_lstm_conv_unlifted = (
239
+ exir .capture (
240
+ LSTMConvPattern (),
241
+ (input_x , input_h , input_c ),
242
+ exir .CaptureConfig (pt2_mode = True ),
221
243
)
222
- .to_edge (exir . EdgeCompileConfig ( _check_ir_validity = False ) )
244
+ .to_edge ()
223
245
.exported_program .graph_module
224
246
)
225
- self .patterns = [pattern_lstm_conv .graph ]
247
+ self .patterns = [
248
+ pattern_lstm_conv_lifted .graph ,
249
+ pattern_lstm_conv_unlifted .graph ,
250
+ ]
226
251
# Only (lstm + conv) pattern is lowerable
227
252
228
253
backend_id = QnnBackend .__name__
0 commit comments