@@ -1089,14 +1089,31 @@ def _can_skip_using_EDGE_DO_NOT_DECOMP(
1089
1089
for name , program in aten_programs .items ():
1090
1090
if partitioner is not None :
1091
1091
for curr_partitioner in partitioner .get (name , []):
1092
- curr_ops_no_decomp , check_op_support = (
1093
- curr_partitioner .ops_to_not_decompose (program )
1094
- )
1092
+ (
1093
+ curr_ops_no_decomp ,
1094
+ check_op_support ,
1095
+ ) = curr_partitioner .ops_to_not_decompose (program )
1095
1096
if check_op_support is not None :
1096
1097
can_skip_using_EDGE_DO_NOT_DECOMP = False
1097
1098
return can_skip_using_EDGE_DO_NOT_DECOMP
1098
1099
1099
1100
1101
+ def _replace_view_with_view_copy (program : ExportedProgram ) -> ExportedProgram :
1102
+ program = program .run_decompositions ({})
1103
+ new_gm = ReplaceViewOpsWithViewCopyOpsPass ()(program .graph_module ).graph_module
1104
+ program = ExportedProgram (
1105
+ root = new_gm ,
1106
+ graph = new_gm .graph ,
1107
+ graph_signature = _get_updated_graph_signature (program .graph_signature , new_gm ),
1108
+ state_dict = program .state_dict ,
1109
+ range_constraints = program .range_constraints ,
1110
+ module_call_graph = program .module_call_graph ,
1111
+ example_inputs = program .example_inputs ,
1112
+ constants = program .constants ,
1113
+ )
1114
+ return program
1115
+
1116
+
1100
1117
def _gen_edge_manager_for_partitioners (
1101
1118
partitioner : Dict [str , List [Partitioner ]],
1102
1119
aten_programs : Dict [str , ExportedProgram ],
@@ -1116,58 +1133,55 @@ def _gen_edge_manager_for_partitioners(
1116
1133
on nodes with preserved aten targets. They are then replaces with transformed ops to
1117
1134
keep them through the second pass of decompositions
1118
1135
"""
1136
+ can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP (
1137
+ partitioner , aten_programs
1138
+ )
1119
1139
ops_set_to_not_decompose_by_program = {}
1120
1140
edge_programs : Dict [str , ExportedProgram ] = {}
1121
1141
for name , program in aten_programs .items ():
1122
- # Functionalize program without doing any decompositions
1123
- program = program .run_decompositions ({})
1124
- ReplaceViewOpsWithViewCopyOpsPass ()(program .graph_module )
1125
-
1126
- print (program )
1127
-
1128
1142
if partitioner is not None :
1129
1143
# preserve all ops listed by all partitioners first
1130
1144
all_ops_no_decomp = set ()
1145
+ all_ops_no_decomp_needing_preservation = []
1131
1146
for curr_partitioner in partitioner .get (name , []):
1132
1147
curr_ops_no_decomp , _ = curr_partitioner .ops_to_not_decompose (program )
1133
- < << << << HEAD
1134
- curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose (
1135
- curr_ops_no_decomp
1136
- )
1137
- == == == =
1138
- >> >> >> > ec44f8478 (updates )
1139
1148
all_ops_no_decomp |= set (curr_ops_no_decomp )
1140
-
1149
+
1141
1150
# If not using the can_skip_using_EDGE_DO_NOT_DECOMP path, we need to remove invalid ops
1142
- # Otherwise there will be issues
1151
+ # Otherwise there will be issues
1143
1152
if not can_skip_using_EDGE_DO_NOT_DECOMP :
1144
- all_ops_no_decomp = _remove_invalid_ops_for_not_decompose (list (all_ops_no_decomp ))
1153
+ all_ops_no_decomp = _remove_invalid_ops_for_not_decompose (
1154
+ list (all_ops_no_decomp )
1155
+ )
1145
1156
all_ops_no_decomp = set (all_ops_no_decomp )
1146
1157
1147
1158
# Run default decompositions, except for those in all_ops_no_decomp
1148
1159
table = _default_decomposition_table ()
1149
1160
for op in all_ops_no_decomp :
1150
- < << << << HEAD
1151
- table .pop (op , None )
1152
-
1153
- == == == =
1154
1161
if table .pop (op , None ) is not None :
1155
1162
all_ops_no_decomp_needing_preservation .append (op )
1156
- > >> >> >> ec44f8478 (updates )
1157
1163
program = program .run_decompositions (table )
1158
1164
1159
1165
# Among all the preserved aten ops, use the check_op_fn to do an additional
1160
1166
# check on which ops need to be preserved and which ops need to be decomposed
1161
1167
# Those which are truly preserved will be replaced with transformed ops
1162
- ops_set_to_not_decompose_by_program [name ] = (
1163
- _replace_aten_ops_with_transformed_ops (name , program , partitioner ) or []
1164
- )
1165
- program = program .run_decompositions (_default_decomposition_table ())
1168
+ if can_skip_using_EDGE_DO_NOT_DECOMP :
1169
+ ops_set_to_not_decompose_by_program [
1170
+ name
1171
+ ] = all_ops_no_decomp_needing_preservation
1172
+ else :
1173
+ ops_set_to_not_decompose_by_program [name ] = (
1174
+ _replace_aten_ops_with_transformed_ops (name , program , partitioner )
1175
+ or []
1176
+ )
1166
1177
1167
- _restore_transformed_ops_to_aten_ops (program )
1178
+ if not can_skip_using_EDGE_DO_NOT_DECOMP :
1179
+ program = program .run_decompositions (_default_decomposition_table ())
1180
+ _restore_transformed_ops_to_aten_ops (program )
1168
1181
1182
+ # Edge will complain if there are view ops requested for preservation, so we replace them with view_copy
1183
+ program = _replace_view_with_view_copy (program )
1169
1184
edge_programs [name ] = program
1170
-
1171
1185
edge_programs [name ] = _generate_edge_program (
1172
1186
config ,
1173
1187
program ,
@@ -1211,7 +1225,7 @@ def collect_named_data_store_outputs(
1211
1225
1212
1226
1213
1227
@et_logger ("to_edge_transform_and_lower" )
1214
- def to_edge_transform_and_lower (
1228
+ def to_edge_transform_and_lower ( # noqa: C901
1215
1229
programs : Union [ExportedProgram , Dict [str , ExportedProgram ]],
1216
1230
transform_passes : Optional [
1217
1231
Union [Sequence [PassType ], Dict [str , Sequence [PassType ]]]
@@ -1276,6 +1290,9 @@ def to_edge_transform_and_lower(
1276
1290
elif partitioner is None :
1277
1291
partitioner = {name : [] for name in aten_programs .keys ()}
1278
1292
1293
+ can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP (
1294
+ partitioner , aten_programs
1295
+ )
1279
1296
edge_manager = _gen_edge_manager_for_partitioners (
1280
1297
partitioner , aten_programs , config , constant_methods
1281
1298
)
@@ -1301,7 +1318,8 @@ def to_edge_transform_and_lower(
1301
1318
curr_op_set , check_op_support = curr_partitioner .ops_to_not_decompose (
1302
1319
program
1303
1320
)
1304
- curr_op_set = _remove_invalid_ops_for_not_decompose (curr_op_set )
1321
+ if not can_skip_using_EDGE_DO_NOT_DECOMP :
1322
+ curr_op_set = _remove_invalid_ops_for_not_decompose (curr_op_set )
1305
1323
ops_set_to_not_decompose = ops_set_to_not_decompose .union (curr_op_set )
1306
1324
_sanity_check_graph_for_non_decomp_ops (
1307
1325
name ,
0 commit comments