Skip to content

Commit 39c6e3e

Browse files
committed
Fix device detection regex for NVIDIA GPUs. Missed GM20x (Maxwell) like Tesla M40
1 parent 9593399 commit 39c6e3e

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

tests/test_platform_detection.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,17 @@ def test_multiple_gpu_NVIDIA(monkeypatch):
254254
assert get_torch_platform(gpu_infos) == expected
255255

256256

257+
def test_multiple_gpu_NVIDIA_maxwell(monkeypatch):
258+
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
259+
monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")
260+
gpu_infos = [
261+
GPU(NVIDIA, "NVIDIA", "17fd", "GM200GL [Tesla M40]", True),
262+
GPU(NVIDIA, "NVIDIA", "1401", "GM206 [GeForce GTX 960]", True),
263+
]
264+
expected = "cu124"
265+
assert get_torch_platform(gpu_infos) == expected
266+
267+
257268
def test_multiple_gpu_AMD_Navi3_Navi2(monkeypatch, capsys):
258269
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
259270
monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")

torchruntime/platform_detection.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010
py_version = sys.version_info
1111

1212
# https://www.techpowerup.com/gpu-specs/?architecture=Kepler&sort=generation and so on (change the arch field)
13-
KEPLER_DEVICES = re.compile(r"\b(gk1\d{2}\w*)\b", re.IGNORECASE) # sm3.7
14-
MAXWELL_DEVICES = re.compile(r"\b(gm10\d\w*)\b", re.IGNORECASE) # sm5
15-
PASCAL_DEVICES = re.compile(r"\b(gp10\d\w*)\b", re.IGNORECASE) # sm6
16-
VOLTA_DEVICES = re.compile(r"\b(gv100\w*)\b", re.IGNORECASE) # sm7
17-
TURING_DEVICES = re.compile(r"\b(tu1\d{2}\w*)\b", re.IGNORECASE) # sm7.5
18-
AMPERE_DEVICES = re.compile(r"\b(ga10\d\w*)\b", re.IGNORECASE) # sm8.6
19-
ADA_LOVELACE_DEVICES = re.compile(r"\b(ad10\d\w*)\b", re.IGNORECASE) # sm8.9
13+
KEPLER_DEVICES = re.compile(r"\b(gk\d+\w*)\b", re.IGNORECASE) # sm3.7
14+
MAXWELL_DEVICES = re.compile(r"\b(gm\d+\w*)\b", re.IGNORECASE) # sm5
15+
PASCAL_DEVICES = re.compile(r"\b(gp\d+\w*)\b", re.IGNORECASE) # sm6
16+
VOLTA_DEVICES = re.compile(r"\b(gv\d+\w*)\b", re.IGNORECASE) # sm7
17+
TURING_DEVICES = re.compile(r"\b(tu\d+\w*)\b", re.IGNORECASE) # sm7.5
18+
AMPERE_DEVICES = re.compile(r"\b(ga\d+\w*)\b", re.IGNORECASE) # sm8.6
19+
ADA_LOVELACE_DEVICES = re.compile(r"\b(ad\d+\w*)\b", re.IGNORECASE) # sm8.9
2020
BLACKWELL_DEVICES = re.compile(r"\b(?:5060|5070|5080|5090)\b", re.IGNORECASE) # sm10, sm12
2121

2222
NVIDIA_ARCH_MAP = {

0 commit comments

Comments
 (0)