Skip to content

Commit 73a5f49

Browse files
committed
cross compile guard for windows
1 parent 5e38c04 commit 73a5f49

File tree

3 files changed

+63
-2
lines changed

3 files changed

+63
-2
lines changed

py/torch_tensorrt/_features.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
from collections import namedtuple
44
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar
55

6-
from torch_tensorrt._utils import sanitized_torch_version
6+
from torch_tensorrt._utils import (
7+
check_cross_compile_trt_win_lib,
8+
sanitized_torch_version,
9+
)
710

811
from packaging import version
912

@@ -15,6 +18,7 @@
1518
"dynamo_frontend",
1619
"fx_frontend",
1720
"refit",
21+
"windows_cross_compile",
1822
],
1923
)
2024

@@ -38,9 +42,15 @@
3842
_DYNAMO_FE_AVAIL = version.parse(sanitized_torch_version()) >= version.parse("2.1.dev")
3943
_FX_FE_AVAIL = True
4044
_REFIT_AVAIL = version.parse(sys.version.split()[0]) < version.parse("3.13")
45+
_WINDOWS_CROSS_COMPILE = check_cross_compile_trt_win_lib
4146

4247
ENABLED_FEATURES = FeatureSet(
43-
_TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL, _REFIT_AVAIL
48+
_TS_FE_AVAIL,
49+
_TORCHTRT_RT_AVAIL,
50+
_DYNAMO_FE_AVAIL,
51+
_FX_FE_AVAIL,
52+
_REFIT_AVAIL,
53+
_WINDOWS_CROSS_COMPILE,
4454
)
4555

4656

@@ -80,6 +90,22 @@ def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
8090
return wrapper
8191

8292

93+
def needs_cross_compile(f: Callable[..., Any]) -> Callable[..., Any]:
94+
def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
95+
if ENABLED_FEATURES.windows_cross_compile:
96+
return f(*args, **kwargs)
97+
else:
98+
99+
def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
100+
raise NotImplementedError(
101+
"Windows cross compilation feature is not available"
102+
)
103+
104+
return not_implemented(*args, **kwargs)
105+
106+
return wrapper
107+
108+
83109
T = TypeVar("T")
84110

85111

py/torch_tensorrt/_utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
import ctypes
2+
import platform
3+
import sys
14
from typing import Any
25

36
import torch
7+
from torch_tensorrt import _find_lib
48

59

610
def sanitized_torch_version() -> Any:
@@ -9,3 +13,27 @@ def sanitized_torch_version() -> Any:
913
if ".nv" not in torch.__version__
1014
else torch.__version__.split(".nv")[0]
1115
)
16+
17+
18+
def check_cross_compile_trt_win_lib(lib_name):
19+
if sys.platform.startswith("linux"):
20+
LINUX_PATHS = ["/usr/local/cuda-12.8/lib64", "/usr/lib", "/usr/lib64"]
21+
22+
if platform.uname().processor == "x86_64":
23+
LINUX_PATHS += [
24+
"/usr/lib/x86_64-linux-gnu",
25+
]
26+
elif platform.uname().processor == "aarch64":
27+
LINUX_PATHS += ["/usr/lib/aarch64-linux-gnu"]
28+
29+
LINUX_LIBS = [
30+
f"libnvinfer_builder_resource_win.so.*",
31+
]
32+
33+
for lib in LINUX_LIBS:
34+
try:
35+
ctypes.CDLL(_find_lib(lib, LINUX_PATHS))
36+
return True
37+
except:
38+
continue
39+
return False

tests/py/dynamo/runtime/test_003_cross_compile_for_windows.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88
import torch_tensorrt
99
from torch.testing._internal.common_utils import TestCase
10+
from torch_tensorrt._utils import check_cross_compile_trt_win_lib
1011

1112
from ..testing_utilities import DECIMALS_OF_AGREEMENT
1213

@@ -16,6 +17,9 @@ class TestCrossCompileSaveForWindows(TestCase):
1617
platform.system() != "Linux" or platform.architecture()[0] != "64bit",
1718
"Cross compile for windows can only be enabled on linux x86-64 platform",
1819
)
20+
@unittest.skipIf(
21+
not (check_cross_compile_trt_win_lib()),
22+
)
1923
@pytest.mark.unit
2024
def test_cross_compile_for_windows(self):
2125
class Add(torch.nn.Module):
@@ -40,6 +44,9 @@ def forward(self, a, b):
4044
platform.system() != "Linux" or platform.architecture()[0] != "64bit",
4145
"Cross compile for windows can only be enabled on linux x86-64 platform",
4246
)
47+
@unittest.skipIf(
48+
not (check_cross_compile_trt_win_lib()),
49+
)
4350
@pytest.mark.unit
4451
def test_dynamo_cross_compile_for_windows(self):
4552
class Add(torch.nn.Module):

0 commit comments

Comments
 (0)