Skip to content

Commit 0e24548

Browse files
authored
Add safeguards for CUDA kernel load in Deformable DETR (#19037)
1 parent 31be02f commit 0e24548

File tree

3 files changed

+12
-3
lines changed

3 files changed

+12
-3
lines changed

src/transformers/models/deformable_detr/modeling_deformable_detr.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,21 @@
4141
)
4242
from ...modeling_outputs import BaseModelOutput
4343
from ...modeling_utils import PreTrainedModel
44-
from ...utils import logging
44+
from ...utils import is_ninja_available, logging
4545
from .configuration_deformable_detr import DeformableDetrConfig
4646
from .load_custom import load_cuda_kernels
4747

4848

4949
logger = logging.get_logger(__name__)
5050

5151
# Move this to not compile only when importing, this needs to happen later, like in __init__.
52-
if is_torch_cuda_available():
52+
if is_torch_cuda_available() and is_ninja_available():
5353
logger.info("Loading custom CUDA kernels...")
54-
MultiScaleDeformableAttention = load_cuda_kernels()
54+
try:
55+
MultiScaleDeformableAttention = load_cuda_kernels()
56+
except Exception as e:
57+
logger.warning(f"Could not load the custom kernel for multi-scale deformable attention: {e}")
58+
MultiScaleDeformableAttention = None
5559
else:
5660
MultiScaleDeformableAttention = None
5761

src/transformers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@
9898
is_in_notebook,
9999
is_ipex_available,
100100
is_librosa_available,
101+
is_ninja_available,
101102
is_onnx_available,
102103
is_pandas_available,
103104
is_phonemizer_available,

src/transformers/utils/import_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,10 @@ def is_apex_available():
471471
return importlib.util.find_spec("apex") is not None
472472

473473

474+
def is_ninja_available():
475+
return importlib.util.find_spec("ninja") is not None
476+
477+
474478
def is_ipex_available():
475479
def get_major_and_minor_from_version(full_version):
476480
return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)

0 commit comments

Comments
 (0)