
Commit 2ff245a

Gasoonjia authored and facebook-github-bot committed

cuda export supported (#14478)
Summary: This diff introduces the CUDA backend, which compiles the partitioned model graph to run on CUDA devices. It uses the AOTInductor compiler to generate optimized, libtorch-free CUDA kernels for the model's operators. The compiled model can be executed on CUDA devices using the ExecuTorch runtime.

Reviewed By: angelayi, larryliu0820

Differential Revision: D82987410
1 parent bfb502d commit 2ff245a
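
For orientation, a minimal export sketch assuming the usual ExecuTorch lowering APIs (the toy model, example inputs, and output path below are hypothetical; the real flow is exercised by the test_cuda_export target added in this diff):

    import torch
    from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
    from executorch.exir import to_edge_transform_and_lower

    class TinyModel(torch.nn.Module):  # hypothetical toy model
        def forward(self, x):
            return torch.relu(x + 1.0)

    example_inputs = (torch.randn(2, 8),)
    exported = torch.export.export(TinyModel().eval(), example_inputs)

    # CudaBackend.preprocess moves the partitioned graph to CUDA, AOTInductor-compiles it,
    # and stores the resulting shared object as the "so_blob" entry of the
    # "aoti_cuda_blob" named data segment.
    et_program = to_edge_transform_and_lower(
        exported, partitioner=[CudaPartitioner([])]
    ).to_executorch()

    with open("model_cuda.pte", "wb") as f:
        f.write(et_program.buffer)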

File tree

6 files changed: +457 -4 lines changed

backends/cuda/TARGETS

Lines changed: 16 additions & 0 deletions

@@ -2,6 +2,22 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 
 oncall("executorch")
 
+runtime.python_library(
+    name = "cuda_backend",
+    srcs = [
+        "cuda_backend.py",
+    ],
+    visibility = [
+        "//executorch/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir/_serialize:lib",
+        "//executorch/exir/backend:backend_details",
+        "//executorch/exir/backend:compile_spec_schema",
+    ],
+)
+
 runtime.python_library(
     name = "cuda_partitioner",
     srcs = [

backends/cuda/cuda_backend.py

Lines changed: 167 additions & 0 deletions

@@ -0,0 +1,167 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import contextlib
+import os
+import typing
+
+from typing import Any, Dict, final, List, Optional, Set
+
+import torch
+from executorch.exir._serialize._named_data_store import NamedDataStore
+from executorch.exir.backend.backend_details import (
+    BackendDetails,
+    ExportedProgram,
+    PreprocessResult,
+)
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
+from torch.export.passes import move_to_device_pass
+
+
+# Fallback operators that already exist in the et namespace
+supported_fallback_kernels: Dict[str, Any] = {}
+
+# Fallback kernels that are required but not supported
+missing_fallback_kernels: Set[str] = set()
+
+
+# Context manager for a no-fallback guarantee: it collects fallback kernels
+# generated during AOTI compile so an exception can be raised for unsupported ones.
+@contextlib.contextmanager
+def collect_unsupported_fallback_kernels():
+    original_generate_c_shim_extern_kernel_call = (
+        CppWrapperCpu.generate_c_shim_extern_kernel_call
+    )
+    original_generate_fallback_kernel_with_runtime_lookup_aot = (
+        CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot
+    )
+
+    def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels(
+        self,
+        kernel: str,
+        args: list[str],
+        device: str,
+        *,
+        debug_args: Optional[list[str]] = None,
+    ):
+        if kernel not in supported_fallback_kernels:
+            missing_fallback_kernels.add(kernel)
+
+        original_generate_c_shim_extern_kernel_call(
+            self, kernel, args, device, debug_args=debug_args
+        )
+
+    def generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels(
+        self,
+        op_overload,
+        raw_args,
+        output_args,
+        raw_outputs,
+    ):
+        # Extract kernel name for collection
+        kernel_name = getattr(op_overload, "_name", str(op_overload))
+        if kernel_name not in supported_fallback_kernels:
+            missing_fallback_kernels.add(kernel_name)
+
+        original_generate_fallback_kernel_with_runtime_lookup_aot(
+            self, op_overload, raw_args, output_args, raw_outputs
+        )
+
+    CppWrapperCpu.generate_c_shim_extern_kernel_call = (
+        generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels
+    )
+    CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = (
+        generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels
+    )
+    try:
+        yield
+    finally:
+        CppWrapperCpu.generate_c_shim_extern_kernel_call = (
+            original_generate_c_shim_extern_kernel_call
+        )
+        CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = (
+            original_generate_fallback_kernel_with_runtime_lookup_aot
+        )
+
+
+@final
+class CudaBackend(BackendDetails):
+    """
+    CudaBackend is a backend that compiles a model to run on CUDA devices. It uses the AOTInductor compiler to generate
+    optimized, libtorch-free CUDA kernels for the model's operators. The compiled model can be executed on CUDA devices
+    using the ExecuTorch runtime.
+    """
+
+    @staticmethod
+    def preprocess(
+        edge_program: ExportedProgram,
+        compile_specs: List[CompileSpec],
+    ) -> PreprocessResult:
+        # Move the edge_program from CPU to CUDA for aoti compile
+        cuda_edge_program = move_to_device_pass(edge_program, "cuda")
+
+        edge_program_module = cuda_edge_program.module()
+
+        # Grab all input placeholders from the graph
+        user_input_names = cuda_edge_program.graph_signature.user_inputs
+        user_input_placeholders = []
+        for node in cuda_edge_program.graph.nodes:
+            if node.op == "placeholder" and node.name in user_input_names:
+                user_input_placeholders.append(node.meta["val"])
+
+        # Create pseudo user inputs using torch.randn and metadata from input placeholders
+        faked_user_inputs = []
+        for placeholder in user_input_placeholders:
+            if isinstance(placeholder, torch.Tensor):
+                # Generate fake input with same shape and dtype, on CUDA
+                fake_input = torch.randn(
+                    placeholder.shape, dtype=placeholder.dtype, device="cuda"
+                )
+                faked_user_inputs.append(fake_input)
+
+        faked_user_inputs = tuple(faked_user_inputs)
+
+        options: dict[str, typing.Any] = {
+            # Embed CUDA kernel binaries directly into the compiled shared object
+            "aot_inductor.embed_kernel_binary": True,
+            # Do not link against the full PyTorch/libtorch library
+            "aot_inductor.link_libtorch": False,
+            # Package model constants and other generated files directly in the shared object (.so) file
+            "aot_inductor.package_constants_in_so": True,
+            # Enable maximum automatic tuning for optimal performance
+            "max_autotune": True,
+            # Tune GEMM (General Matrix Multiply) operations with the TRITON backend only, to avoid operators from libtorch
+            "max_autotune_gemm_backends": "TRITON",
+            # Tune convolution operations with the TRITON backend only, to avoid operators from libtorch
+            "max_autotune_conv_backends": "TRITON",
+        }
+
+        with collect_unsupported_fallback_kernels():
+            so_path = torch._inductor.aot_compile(edge_program_module, faked_user_inputs, options=options)  # type: ignore[arg-type]
+            if len(missing_fallback_kernels) > 0:
+                formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
+                raise RuntimeError(
+                    f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n"
+                    "Please add them to the AOTI backend."
+                )
+
+        # pyre-ignorep[6]: Incompatible parameter type
+        with open(so_path, "rb") as f:
+            so_data = f.read()
+
+        named_data_store = NamedDataStore()
+        named_data_store.add_named_data("so_blob", so_data, 1, "aoti_cuda_blob")
+
+        # Clean up the generated .so file; it has been packaged into the NamedDataStore
+        # pyre-ignorep[6]: Incompatible parameter type
+        os.remove(so_path)
+
+        return PreprocessResult(
+            processed_bytes=b"",
+            debug_handle_map={},
+            data_store_output=named_data_store.get_named_data_store_output(),
+        )
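
The collect_unsupported_fallback_kernels context manager above works by temporarily monkey-patching two CppWrapperCpu methods and restoring them in a finally block. A self-contained sketch of that pattern, with an illustrative Emitter class standing in for CppWrapperCpu (nothing here is part of the backend itself):

    import contextlib

    class Emitter:
        # Stand-in for CppWrapperCpu: the method whose calls we want to observe.
        def emit_kernel_call(self, kernel: str) -> str:
            return f"call {kernel}"

    seen_kernels = set()

    @contextlib.contextmanager
    def record_kernel_calls():
        original = Emitter.emit_kernel_call

        def recording_emit(self, kernel: str) -> str:
            seen_kernels.add(kernel)       # collect every kernel name that gets emitted
            return original(self, kernel)  # then defer to the original behavior

        Emitter.emit_kernel_call = recording_emit
        try:
            yield
        finally:
            # Always restore the original method, even if the body raised.
            Emitter.emit_kernel_call = original

    with record_kernel_calls():
        Emitter().emit_kernel_call("aten.nonzero.default")

    assert seen_kernels == {"aten.nonzero.default"}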

backends/cuda/cuda_partitioner.py

Lines changed: 2 additions & 1 deletion

@@ -7,6 +7,7 @@
 from typing import Callable, Dict, final, List, Optional, Tuple
 
 import torch
+from executorch.backends.cuda.cuda_backend import CudaBackend  # usort: skip
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.backend.partitioner import (
     DelegationSpec,
@@ -31,7 +32,7 @@ class CudaPartitioner(Partitioner):
     """
 
     def __init__(self, compile_spec: List[CompileSpec]) -> None:
-        self.delegation_spec = DelegationSpec("CudaBackend", compile_spec)
+        self.delegation_spec = DelegationSpec(CudaBackend.__name__, compile_spec)
 
     def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         """

backends/cuda/tests/TARGETS

Lines changed: 21 additions & 0 deletions

@@ -1,8 +1,28 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
+load("@fbcode_macros//build_defs:python_unittest_remote_gpu.bzl", "python_unittest_remote_gpu")
 
 oncall("executorch")
 
+python_unittest_remote_gpu(
+    name = "test_cuda_export",
+    srcs = [
+        "test_cuda_export.py",
+    ],
+    visibility = [
+        "//executorch/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/cuda:cuda_backend",
+        "//executorch/backends/cuda:cuda_partitioner",
+        "//executorch/exir:lib",
+        "//executorch/exir/backend:backend_api",
+        "//executorch/exir/backend:compile_spec_schema",
+    ],
+    keep_gpu_sections = True,
+)
+
 python_unittest(
     name = "test_cuda_partitioner",
     srcs = [
@@ -14,6 +34,7 @@ python_unittest(
     deps = [
         "//caffe2:torch",
         "//executorch/backends/cuda:cuda_partitioner",
+        "//executorch/backends/cuda:cuda_backend",
         "//executorch/exir:lib",
         "//executorch/exir/backend:compile_spec_schema",
     ],
