Skip to content

Commit c5475a6

Browse files
committed
Temporarily pick NVIDIA over AMD if both are present in the same system
1 parent 78ab1db commit c5475a6

File tree

2 files changed

+20
-8
lines changed

2 files changed

+20
-8
lines changed

tests/test_platform_detection.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,23 @@ def test_intel_gpu_mac(monkeypatch):
140140
get_torch_platform(gpu_infos)
141141

142142

143-
def test_multiple_gpu_vendors(monkeypatch):
143+
def test_multiple_gpu_vendors_with_NVIDIA(monkeypatch):
144144
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows")
145145
monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64")
146146
gpu_infos = [
147147
GPU(AMD, "AMD", 0x1234, "Radeon", True),
148148
GPU(NVIDIA, "NVIDIA", 0x5678, "GeForce", True),
149149
]
150+
assert get_torch_platform(gpu_infos) == "cu124"
151+
152+
153+
def test_multiple_gpu_vendors_without_NVIDIA(monkeypatch):
154+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows")
155+
monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64")
156+
gpu_infos = [
157+
GPU(AMD, "AMD", 0x1234, "Radeon", True),
158+
GPU(INTEL, "Intel", 0x5678, "Iris", True),
159+
]
150160
with pytest.raises(NotImplementedError):
151161
get_torch_platform(gpu_infos)
152162

@@ -376,6 +386,4 @@ def test_mixed_multiple_discrete_and_integrated(monkeypatch):
376386
GPU(NVIDIA, "NVIDIA", "2504", "RTX 3060", True), # discrete NVIDIA
377387
GPU(AMD, "AMD", "73f0", "Navi 33", True), # discrete AMD
378388
]
379-
# Should raise NotImplementedError due to multiple discrete GPU vendors
380-
with pytest.raises(NotImplementedError):
381-
get_torch_platform(gpu_infos)
389+
assert get_torch_platform(gpu_infos) == "cu124"

torchruntime/platform_detection.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,14 @@ def _get_platform_for_discrete(gpu_infos):
6161
vendor_ids = set(gpu.vendor_id for gpu in gpu_infos)
6262

6363
if len(vendor_ids) > 1:
64-
device_names = list(gpu.vendor_name + " " + gpu.device_name for gpu in gpu_infos)
65-
raise NotImplementedError(
66-
f"torchruntime does not currently support multiple graphics card manufacturers on the same computer: {device_names}! Please contact torchruntime at {CONTACT_LINK} with details about your hardware."
67-
)
64+
if NVIDIA in vendor_ids: # temp hack to pick NVIDIA over everything else, pending a better fix
65+
gpu_infos = [gpu for gpu in gpu_infos if gpu.vendor_id == NVIDIA]
66+
vendor_ids = set(gpu.vendor_id for gpu in gpu_infos)
67+
else:
68+
device_names = list(gpu.vendor_name + " " + gpu.device_name for gpu in gpu_infos)
69+
raise NotImplementedError(
70+
f"torchruntime does not currently support multiple graphics card manufacturers on the same computer: {device_names}! Please contact torchruntime at {CONTACT_LINK} with details about your hardware."
71+
)
6872

6973
vendor_id = vendor_ids.pop()
7074
if vendor_id == AMD:

0 commit comments

Comments
 (0)