Skip to content

Commit a73e269

Browse files
committed
up
1 parent 5d60767 commit a73e269

File tree

1 file changed

+49
-31
lines changed

1 file changed

+49
-31
lines changed

exir/program/_program.py

Lines changed: 49 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,14 +1089,31 @@ def _can_skip_using_EDGE_DO_NOT_DECOMP(
10891089
for name, program in aten_programs.items():
10901090
if partitioner is not None:
10911091
for curr_partitioner in partitioner.get(name, []):
1092-
curr_ops_no_decomp, check_op_support = (
1093-
curr_partitioner.ops_to_not_decompose(program)
1094-
)
1092+
(
1093+
curr_ops_no_decomp,
1094+
check_op_support,
1095+
) = curr_partitioner.ops_to_not_decompose(program)
10951096
if check_op_support is not None:
10961097
can_skip_using_EDGE_DO_NOT_DECOMP = False
10971098
return can_skip_using_EDGE_DO_NOT_DECOMP
10981099

10991100

1101+
def _replace_view_with_view_copy(program: ExportedProgram) -> ExportedProgram:
1102+
program = program.run_decompositions({})
1103+
new_gm = ReplaceViewOpsWithViewCopyOpsPass()(program.graph_module).graph_module
1104+
program = ExportedProgram(
1105+
root=new_gm,
1106+
graph=new_gm.graph,
1107+
graph_signature=_get_updated_graph_signature(program.graph_signature, new_gm),
1108+
state_dict=program.state_dict,
1109+
range_constraints=program.range_constraints,
1110+
module_call_graph=program.module_call_graph,
1111+
example_inputs=program.example_inputs,
1112+
constants=program.constants,
1113+
)
1114+
return program
1115+
1116+
11001117
def _gen_edge_manager_for_partitioners(
11011118
partitioner: Dict[str, List[Partitioner]],
11021119
aten_programs: Dict[str, ExportedProgram],
@@ -1116,58 +1133,55 @@ def _gen_edge_manager_for_partitioners(
11161133
on nodes with preserved aten targets. They are then replaces with transformed ops to
11171134
keep them through the second pass of decompositions
11181135
"""
1136+
can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP(
1137+
partitioner, aten_programs
1138+
)
11191139
ops_set_to_not_decompose_by_program = {}
11201140
edge_programs: Dict[str, ExportedProgram] = {}
11211141
for name, program in aten_programs.items():
1122-
# Functionalize program without doing any decompositions
1123-
program = program.run_decompositions({})
1124-
ReplaceViewOpsWithViewCopyOpsPass()(program.graph_module)
1125-
1126-
print(program)
1127-
11281142
if partitioner is not None:
11291143
# preserve all ops listed by all partitioners first
11301144
all_ops_no_decomp = set()
1145+
all_ops_no_decomp_needing_preservation = []
11311146
for curr_partitioner in partitioner.get(name, []):
11321147
curr_ops_no_decomp, _ = curr_partitioner.ops_to_not_decompose(program)
1133-
<<<<<<< HEAD
1134-
curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
1135-
curr_ops_no_decomp
1136-
)
1137-
=======
1138-
>>>>>>> ec44f8478 (updates)
11391148
all_ops_no_decomp |= set(curr_ops_no_decomp)
1140-
1149+
11411150
# If not using the can_skip_using_EDGE_DO_NOT_DECOMP path, we need to remove invalid ops
1142-
# Otherwise there will be issues
1151+
# Otherwise there will be issues
11431152
if not can_skip_using_EDGE_DO_NOT_DECOMP:
1144-
all_ops_no_decomp = _remove_invalid_ops_for_not_decompose(list(all_ops_no_decomp))
1153+
all_ops_no_decomp = _remove_invalid_ops_for_not_decompose(
1154+
list(all_ops_no_decomp)
1155+
)
11451156
all_ops_no_decomp = set(all_ops_no_decomp)
11461157

11471158
# Run default decompositions, except for those in all_ops_no_decomp
11481159
table = _default_decomposition_table()
11491160
for op in all_ops_no_decomp:
1150-
<<<<<<< HEAD
1151-
table.pop(op, None)
1152-
1153-
=======
11541161
if table.pop(op, None) is not None:
11551162
all_ops_no_decomp_needing_preservation.append(op)
1156-
>>>>>>> ec44f8478 (updates)
11571163
program = program.run_decompositions(table)
11581164

11591165
# Among all the preserved aten ops, use the check_op_fn to do an additional
11601166
# check on which ops need to be preserved and which ops need to be decomposed
11611167
# Those which are truly preserved will be replaced with transformed ops
1162-
ops_set_to_not_decompose_by_program[name] = (
1163-
_replace_aten_ops_with_transformed_ops(name, program, partitioner) or []
1164-
)
1165-
program = program.run_decompositions(_default_decomposition_table())
1168+
if can_skip_using_EDGE_DO_NOT_DECOMP:
1169+
ops_set_to_not_decompose_by_program[
1170+
name
1171+
] = all_ops_no_decomp_needing_preservation
1172+
else:
1173+
ops_set_to_not_decompose_by_program[name] = (
1174+
_replace_aten_ops_with_transformed_ops(name, program, partitioner)
1175+
or []
1176+
)
11661177

1167-
_restore_transformed_ops_to_aten_ops(program)
1178+
if not can_skip_using_EDGE_DO_NOT_DECOMP:
1179+
program = program.run_decompositions(_default_decomposition_table())
1180+
_restore_transformed_ops_to_aten_ops(program)
11681181

1182+
# Edge will complain if there are view ops requested for preservation, so we replace them with view_copy
1183+
program = _replace_view_with_view_copy(program)
11691184
edge_programs[name] = program
1170-
11711185
edge_programs[name] = _generate_edge_program(
11721186
config,
11731187
program,
@@ -1211,7 +1225,7 @@ def collect_named_data_store_outputs(
12111225

12121226

12131227
@et_logger("to_edge_transform_and_lower")
1214-
def to_edge_transform_and_lower(
1228+
def to_edge_transform_and_lower( # noqa: C901
12151229
programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
12161230
transform_passes: Optional[
12171231
Union[Sequence[PassType], Dict[str, Sequence[PassType]]]
@@ -1276,6 +1290,9 @@ def to_edge_transform_and_lower(
12761290
elif partitioner is None:
12771291
partitioner = {name: [] for name in aten_programs.keys()}
12781292

1293+
can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP(
1294+
partitioner, aten_programs
1295+
)
12791296
edge_manager = _gen_edge_manager_for_partitioners(
12801297
partitioner, aten_programs, config, constant_methods
12811298
)
@@ -1301,7 +1318,8 @@ def to_edge_transform_and_lower(
13011318
curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose(
13021319
program
13031320
)
1304-
curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set)
1321+
if not can_skip_using_EDGE_DO_NOT_DECOMP:
1322+
curr_op_set = _remove_invalid_ops_for_not_decompose(curr_op_set)
13051323
ops_set_to_not_decompose = ops_set_to_not_decompose.union(curr_op_set)
13061324
_sanity_check_graph_for_non_decomp_ops(
13071325
name,

0 commit comments

Comments
 (0)