Skip to content

Commit 45a51b2

Browse files
committed
Use nightly/cu128 for nvidia 50xx
1 parent e8017b6 commit 45a51b2

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

tests/test_platform_detection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def test_nvidia_5xxx_gpu_windows(monkeypatch):
120120
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows")
121121
monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64")
122122
gpu_infos = [GPU(NVIDIA, "NVIDIA", "2c02", "GB203 [GeForce RTX 5080]", True)]
123-
assert get_torch_platform(gpu_infos) == "nightly/cu124"
123+
assert get_torch_platform(gpu_infos) == "nightly/cu128"
124124

125125

126126
def test_intel_gpu_windows(monkeypatch):
@@ -143,7 +143,7 @@ def test_nvidia_5xxx_gpu_linux(monkeypatch):
143143
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
144144
monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")
145145
gpu_infos = [GPU(NVIDIA, "NVIDIA", "2c02", "GB203 [GeForce RTX 5080]", True)]
146-
assert get_torch_platform(gpu_infos) == "nightly/cu124"
146+
assert get_torch_platform(gpu_infos) == "nightly/cu128"
147147

148148

149149
def test_intel_gpu_mac(monkeypatch):

torchruntime/platform_detection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def _get_platform_for_discrete(gpu_infos):
111111
if os_name in ("Windows", "Linux"):
112112
device_names = set(gpu.device_name for gpu in gpu_infos)
113113
if any(BLACKWELL_DEVICES.search(device_name) for device_name in device_names):
114-
return "nightly/cu124"
114+
return "nightly/cu128"
115115

116116
return "cu124"
117117
elif os_name == "Darwin":

0 commit comments

Comments
 (0)