Commit 0c781eb
Update base for Update on "[ET-VK] 5/n Split dispatches between multiple command buffers. Track previously submitted command buffers in context and add function to execute all previous command buffers."
This diff stores command buffers that were submitted with `final_use` set to false. Storing these buffers is necessary for the `execute()` function: since `encode_execute()` is typically called once but `execute()` can be called multiple times, a `submit_all_non_final_cmds` function is added so that all recorded command buffers with `final_use = false` can be re-submitted on each call to `execute()`.

#### Key Changes

* Added a flag `execute_pending_first_submission` to the `ComputeGraph` class to track whether execute nodes have been freshly encoded and need to be submitted first.
* Added a new function `submit_all_non_final_cmds` to the `Context` class, which submits all non-final command buffers to the GPU.
* Modified the `submit_cmd_to_gpu` function to add the submitted command buffer to the `non_final_cmds_` list if it is not marked as final use.
* Updated the `execute` function in `ComputeGraph` to submit all non-final command buffers before executing the graph.

Differential Revision: [D78360038](https://our.internmc.facebook.com/intern/diff/D78360038/)

[ghstack-poisoned]
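Conceptually, the flow described above boils down to the sketch below — a minimal C++ model, not the actual ExecuTorch sources. `Context`, `submit_cmd_to_gpu`, `submit_all_non_final_cmds`, and `non_final_cmds_` are named in the commit message; the method signatures, the raw `vkQueueSubmit` usage, and the `queue_submit` helper are illustrative assumptions. In `ComputeGraph::execute()`, the `execute_pending_first_submission` flag then decides between submitting the freshly encoded buffers and replaying the tracked ones via `submit_all_non_final_cmds()`.

```cpp
// Minimal sketch (NOT the actual ExecuTorch sources) of the bookkeeping
// described in the commit message: command buffers submitted with
// final_use == false are remembered so they can be re-submitted on every
// call to execute(). Signatures and the Vulkan plumbing are assumptions.
#include <vector>
#include <vulkan/vulkan.h>

class Context {
 public:
  // Submit one command buffer; remember it if it may be replayed later.
  void submit_cmd_to_gpu(VkCommandBuffer cmd, bool final_use) {
    queue_submit(cmd);
    if (!final_use) {
      // Track for re-submission on subsequent execute() calls.
      non_final_cmds_.push_back(cmd);
    }
  }

  // Re-submit every previously recorded non-final command buffer. Called
  // from ComputeGraph::execute(), which can run many times after a single
  // encode_execute().
  void submit_all_non_final_cmds() {
    for (VkCommandBuffer cmd : non_final_cmds_) {
      queue_submit(cmd);
    }
  }

 private:
  void queue_submit(VkCommandBuffer cmd) {
    VkSubmitInfo submit_info{};
    submit_info.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
    submit_info.commandBufferCount = 1;
    submit_info.pCommandBuffers = &cmd;
    vkQueueSubmit(queue_, /*submitCount=*/1, &submit_info, VK_NULL_HANDLE);
  }

  VkQueue queue_ = VK_NULL_HANDLE;
  std::vector<VkCommandBuffer> non_final_cmds_;
};
```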
2 parents: 6839a4f + 97f7610

File tree

40 files changed: +2333 −890 lines changed


CMakeLists.txt

Lines changed: 7 additions & 1 deletion

```diff
@@ -45,7 +45,9 @@
 # ~~~
 #

-cmake_minimum_required(VERSION 3.24)
+# TODO Lower to 3.24 when XNNPACK dependency is updated to include
+# https://github.com/google/XNNPACK/commit/c690daa67f883e1b627aadf7684c06797e9a0684
+cmake_minimum_required(VERSION 3.29)
 project(executorch)

 include(${PROJECT_SOURCE_DIR}/tools/cmake/common/preset.cmake)
@@ -560,6 +562,10 @@ if(EXECUTORCH_BUILD_EXTENSION_LLM)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/tokenizers)
 endif()

+if(EXECUTORCH_BUILD_EXTENSION_LLM_APPLE)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/apple)
+endif()
+
 if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/runner)
 endif()
```

CMakePresets.json

Lines changed: 2 additions & 1 deletion

```diff
@@ -15,7 +15,8 @@
       "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/third-party/ios-cmake/ios.toolchain.cmake",
       "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/macos.cmake",
       "PLATFORM": "MAC_ARM64",
-      "DEPLOYMENT_TARGET": "12.0"
+      "DEPLOYMENT_TARGET": "12.0",
+      "CMAKE_MACOSX_BUNDLE": "OFF"
     },
     "condition": {
       "lhs": "${hostSystemName}",
```

backends/apple/coreml/compiler/coreml_preprocess.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -365,7 +365,7 @@ def preprocess_model(

         match model_type:
             case CoreMLBackend.MODEL_TYPE.COMPILED_MODEL:
-                shutil.rmtree(str(model_path.resolve()))
+                shutil.rmtree(str(model_path.resolve()), ignore_errors=True)
                 model_path = model_dir_path / MODEL_PATHS.COMPILED_MODEL.value
                 compiled_model_path = mlmodel.get_compiled_model_path()
                 shutil.move(
@@ -396,7 +396,7 @@ def preprocess_model(
         for key, value in model_debug_info.debugSymbolToHandles.items()
     }

-    shutil.rmtree(str(dir_path.resolve()))
+    shutil.rmtree(str(dir_path.resolve()), ignore_errors=True)
    return PreprocessResult(
        processed_bytes=processed_bytes,
        debug_handle_map=debug_handle_map,
```

backends/apple/coreml/partition/coreml_partitioner.py

Lines changed: 92 additions & 11 deletions

```diff
@@ -28,12 +28,21 @@

 class OperatorsSupportedForCoreMLBackend(OperatorSupportBase):
     def __init__(
-        self, skip_ops_for_coreml_delegation: Optional[List[str]] = None
+        self,
+        skip_ops_for_coreml_delegation: Optional[List[str]] = None,
+        lower_full_graph: bool = False,
     ) -> None:
         if skip_ops_for_coreml_delegation is None:
             skip_ops_for_coreml_delegation = []
         super().__init__()
         self.skip_ops_for_coreml_delegation = skip_ops_for_coreml_delegation
+        self.lower_full_graph = lower_full_graph
+        self._logged_msgs = set()
+
+    def log_once(self, msg: str) -> None:
+        if msg not in self._logged_msgs:
+            logging.info(msg)
+            self._logged_msgs.add(msg)

     def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
         # get_attr node can always be supported on any backend
@@ -44,14 +53,63 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             # skip ops if specified by user
             node_target_name = getattr(node.target, "__name__", "").lower()
             if node_target_name in (self.skip_ops_for_coreml_delegation or []):
+                self.log_once(
+                    "Skipping op for CoreML delegation because it is in skip_ops_for_coreml_delegation: "
+                    + node_target_name
+                )
+                assert (
+                    not self.lower_full_graph
+                ), "Cannot have skip_ops_for_coreml_delegation when lower_full_graph is True"
                 return False
+
+            # TODO: enable this after bugs in ExecuTorch's partitioner are fixed
+            # # If lower_full_graph=False, do not partition nodes with symbolic args because it can result in symbolic args
+            # # in the placeholders due to partitioning, which CoreML does not support
+            # if not self.lower_full_graph and any(
+            #     isinstance(arg, torch.fx.Node)
+            #     and isinstance(
+            #         arg.meta.get("val", None),
+            #         (torch.SymInt, torch.SymBool, torch.SymFloat),
+            #     )
+            #     for arg in node.args
+            # ):
+            #     self.log_once(
+            #         "Skipping op for CoreML delegation because it contains symbolic args: "
+            #         + node_target_name
+            #     )
+            #     assert not self.lower_full_graph
+            #     return False
+
             # query coremltools to see if node is supported
-            return ct.converters.mil.frontend.torch.is_torch_fx_node_supported(node)
+            is_supported = ct.converters.mil.frontend.torch.is_torch_fx_node_supported(
+                node
+            )
+            if not is_supported:
+                if self.lower_full_graph:
+                    raise NotImplementedError(
+                        f"""CoreML does not support the op {node_target_name}, but you have set lower_full_graph=True in the CoreMLPartitioner.
+
+Please set lower_full_graph=False in the CoreMLPartitioner to allow running unsupported ops outside of CoreML. Note that setting lower_full_graph=False may affect performance of CoreML and the available features.
+As an alternative to setting lower_full_graph=False, you can try rewriting your model to avoid using this op.
+
+Also consider filing an issue with Apple's coremltools repo to request support for the op: https://github.com/apple/coremltools/issues
+Do not file an issue with ExecuTorch for op support.
+"""
+                    )
+                self.log_once(
+                    "Skipping op for CoreML delegation because it is not supported by CoreML: "
+                    + node_target_name
+                )
+            return is_supported
         # cowardly refuse to support all other types of node:
         # 1. placeholder / output nodes should not be tagged
         #    reference: https://github.com/pytorch/executorch/pull/1398
         # 2. call_module / call_method should have been replaced with call_function?
         else:
+            self.log_once(
+                "Skipping op for CoreML delegation because it is not get_attr or call_function: "
+                + node.op
+            )
             return False
@@ -62,6 +120,8 @@ def __init__(
         skip_ops_for_coreml_delegation: Optional[List[str]] = None,
         compile_specs: Optional[List[CompileSpec]] = None,
         take_over_mutable_buffer: Optional[bool] = True,
+        lower_full_graph: bool = False,
+        take_over_constant_data: bool = True,
     ) -> None:
         if skip_ops_for_coreml_delegation is None:
             skip_ops_for_coreml_delegation = []
@@ -71,6 +131,20 @@ def __init__(
             compile_specs=compile_specs if compile_specs is not None else [],
         )
         self.take_over_mutable_buffer = take_over_mutable_buffer
+        self.lower_full_graph = lower_full_graph
+        self.take_over_constant_data = take_over_constant_data
+        self._logged_msgs = set()
+
+        if self.lower_full_graph:
+            assert (
+                len(self.skip_ops_for_coreml_delegation) == 0
+            ), "When lower_full_graph=True, you cannot set skip_ops_for_coreml_delegation"
+            assert (
+                self.take_over_constant_data
+            ), "When lower_full_graph=True, you must set take_over_constant_data=True"
+            assert (
+                self.take_over_mutable_buffer
+            ), "When lower_full_graph=True, you must set take_over_mutable_buffer=True"

     def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         # Run the CapabilityBasedPartitioner to return the largest possible
@@ -80,7 +154,9 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:

         capability_partitioner = CapabilityBasedPartitioner(
             exported_program.graph_module,
-            OperatorsSupportedForCoreMLBackend(self.skip_ops_for_coreml_delegation),
+            OperatorsSupportedForCoreMLBackend(
+                self.skip_ops_for_coreml_delegation, self.lower_full_graph
+            ),
             allows_single_node_partition=True,
         )
         partition_list = capability_partitioner.propose_partitions()
@@ -90,7 +166,8 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
                 node.meta["delegation_tag"] = tag
                 partition_tags[tag] = self.delegation_spec

-        tag_constant_data(exported_program)
+        if self.take_over_constant_data:
+            tag_constant_data(exported_program)
         if self.take_over_mutable_buffer:
             logger.info(
                 "Core ML partitioner will take over torch mutable buffer as Core ML state, "
@@ -105,12 +182,18 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
             tagged_exported_program=exported_program, partition_tags=partition_tags
         )

+    def log_once(self, msg: str) -> None:
+        if msg not in self._logged_msgs:
+            logging.info(msg)
+            self._logged_msgs.add(msg)
+
     def ops_to_not_decompose(
         self, ep: ExportedProgram
     ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
         do_not_decompose = []
-        op_support = OperatorsSupportedForCoreMLBackend()
-        _logged_warnings = set()
+        op_support = OperatorsSupportedForCoreMLBackend(
+            self.skip_ops_for_coreml_delegation, self.lower_full_graph
+        )

         # CoreML prevents certain ops (like triu) from lowering to CoreML when put in the ExecuTorch op namespace
         # TODO: upstream fixes, but pending ET consuming a new published version of coremltools with the
@@ -134,9 +217,7 @@ def ops_to_not_decompose(
             except Exception as e:
                 # CoreML's op_support.is_node_supported will sometimes throw
                 # for unsupported ops, rather than returning False
-                warn_str = f"Encountered exception when checking node support: {e}"
-                if warn_str not in _logged_warnings:
-                    logger.warning(warn_str)
-                    _logged_warnings.add(warn_str)
-
+                self.log_once(
+                    f"Encountered exception when checking node support, treating node as unsupported: {e}"
+                )
         return do_not_decompose, None
```

backends/apple/coreml/test/test_coreml_partitioner.py

Lines changed: 134 additions & 0 deletions

```diff
@@ -2,6 +2,8 @@
 #
 # Please refer to the license found in the LICENSE file in the root directory of the source tree.

+import copy
+import sys
 import unittest

 import coremltools as ct
@@ -14,6 +16,28 @@
 from executorch.backends.apple.coreml.compiler import CoreMLBackend
 from executorch.backends.apple.coreml.partition import CoreMLPartitioner
 from executorch.exir.backend.utils import format_delegated_graph
+from executorch.runtime import Runtime
+
+
+@torch.library.custom_op("unsupported::linear", mutates_args=())
+def _(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    b: torch.Tensor,
+) -> torch.Tensor:
+    return torch.ops.aten.linear.default(x, w, b)
+
+
+@torch.library.register_fake("unsupported::linear")
+def _(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    b: torch.Tensor,
+) -> torch.Tensor:
+    return torch.ops.aten.linear.default(x, w, b)
+
+
+_TEST_RUNTIME = sys.platform == "darwin"


 class TestCoreMLPartitioner(unittest.TestCase):
@@ -200,10 +224,120 @@ def forward(self, q, k_val, input_pos):
             "getitem",
         ]

+    def test_lower_full_graph(self):
+        class Model(torch.nn.Module):
+            def forward(self, a, x, b):
+                out = torch.ops.aten.linear.default(a, x, b)
+                out2 = torch.ops.unsupported.linear.default(out, x, b)
+                return out2
+
+        model = Model()
+        model.eval()
+
+        example_inputs = (torch.randn(2, 2), torch.randn(2, 2), torch.randn(2, 2))
+        exir_program_aten = torch.export.export(model, example_inputs, strict=True)
+        edge_program_manager = executorch.exir.to_edge(exir_program_aten)
+        edge_program_manager2 = copy.deepcopy(edge_program_manager)
+
+        delegated_program_manager = edge_program_manager.to_backend(CoreMLPartitioner())
+
+        for node in delegated_program_manager.exported_program().graph.nodes:
+            if node.op == "call_function":
+                assert node.target.__name__ in [
+                    "unsupported.linear.default",
+                    "executorch_call_delegate",
+                    "getitem",
+                ], node.target.__name__
+
+        with self.assertRaises(NotImplementedError):
+            edge_program_manager2.to_backend(CoreMLPartitioner(lower_full_graph=True))
+
+    # TODO: enable this after bugs are fixed in ExecuTorch's partitioner
+    # def test_symint_arg(self):
+    #     class Model(torch.nn.Module):
+    #         def forward(self, x, w, b, y):
+    #             val = y.item()
+    #             torch._check(val >= 0)
+    #             torch._check(val < 2)
+    #             out = torch.ops.aten.linear.default(x, w, b)
+    #             out2 = out.relu()[val]
+    #             return out2
+
+    #     model = Model()
+    #     model.eval()
+    #     example_inputs = (
+    #         torch.randn(2, 2),
+    #         torch.randn(2, 2),
+    #         torch.randn(2, 2),
+    #         torch.tensor(2),
+    #     )
+    #     exir_program_aten = torch.export.export(model, example_inputs)
+
+    #     edge_program_manager = executorch.exir.to_edge(exir_program_aten)
+
+    #     delegated_program_manager = edge_program_manager.to_backend(CoreMLPartitioner(skip_ops_for_coreml_delegation=["aten.scalar_tensor.default"]))
+
+    #     # This op has symbolic args
+    #     assert (
+    #         "torch.ops.aten._assert_scalar.default"
+    #         in delegated_program_manager.exported_program().graph_module.code
+    #     )
+
+    #     if _TEST_RUNTIME:
+    #         et_prog = delegated_program_manager.to_executorch()
+    #         runtime = Runtime.get()
+    #         program = runtime.load_program(et_prog.buffer)
+    #         method = program.load_method("forward")
+    #         et_outputs = method.execute(*example_inputs)[0]
+    #         eager_outputs = model(*example_inputs)
+    #         self.assertTrue(torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02))
+
+    def test_take_over_constant_data_false(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(50, 100)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        model = Model()
+        model.eval()
+        example_inputs = (torch.randn(2, 50),)
+        exir_program_aten = torch.export.export(model, example_inputs)
+
+        edge_program_manager = executorch.exir.to_edge_transform_and_lower(
+            exir_program_aten,
+            partitioner=[CoreMLPartitioner(take_over_constant_data=False)],
+        )
+        for node in edge_program_manager.exported_program().graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target.__name__ == "executorch_call_delegate"
+            ):
+                break
+
+        # lowered_module_0, x, p_linear_weight, p_linear_bias
+        assert len(node.args) == 4
+
+        if _TEST_RUNTIME:
+            et_prog = edge_program_manager.to_executorch()
+            runtime = Runtime.get()
+            program = runtime.load_program(et_prog.buffer)
+            method = program.load_method("forward")
+            et_outputs = method.execute(*example_inputs)[0]
+            eager_outputs = model(*example_inputs)
+            self.assertTrue(
+                torch.allclose(et_outputs, eager_outputs, atol=1e-02, rtol=1e-02)
+            )
+

 if __name__ == "__main__":
     test_runner = TestCoreMLPartitioner()
     test_runner.test_add_sub_skip_mm()
     test_runner.test_vit_skip_conv()
     test_runner.test_ops_to_not_decompose()
     test_runner.test_buffer()
+    test_runner.test_lower_full_graph()
+    # test_runner.test_symint_arg()
+    test_runner.test_take_over_constant_data_false()
```
