File tree Expand file tree Collapse file tree 1 file changed +2
-4
lines changed Expand file tree Collapse file tree 1 file changed +2
-4
lines changed Original file line number Diff line number Diff line change @@ -255,17 +255,15 @@ def get_extensions():
255255 print (
256256 "PyTorch GPU support is not available. Skipping compilation of CUDA extensions"
257257 )
258- if (CUDA_HOME is None and ROCM_HOME is None ) and torch .cuda . is_available () :
258+ if (CUDA_HOME is None and ROCM_HOME is None ) and torch .version . cuda :
259259 print (
260260 "CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions"
261261 )
262262 print (
263263 "If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit"
264264 )
265265
266- use_cuda = torch .cuda .is_available () and (
267- CUDA_HOME is not None or ROCM_HOME is not None
268- )
266+ use_cuda = torch .version .cuda and (CUDA_HOME is not None or ROCM_HOME is not None )
269267 extension = CUDAExtension if use_cuda else CppExtension
270268
271269 extra_link_args = []
You can’t perform that action at this time.
0 commit comments