Skip to content

Commit 90bbc2c

Browse files
committed
Fixed #87 Add Feature to Select GPU Device
1 parent 0b5e658 commit 90bbc2c

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

stumpy/gpu_stump.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,12 @@ def _gpu_stump(
338338

339339

340340
def gpu_stump(
341-
T_A, m, T_B=None, ignore_trivial=True, threads_per_block=THREADS_PER_BLOCK
341+
T_A,
342+
m,
343+
T_B=None,
344+
ignore_trivial=True,
345+
threads_per_block=THREADS_PER_BLOCK,
346+
device_id=0,
342347
):
343348
"""
344349
Compute the matrix profile with GPU-STOMP
@@ -364,6 +369,9 @@ def gpu_stump(
364369
The number of GPU threads to use for all kernels. The default value is
365370
set in `THREADS_PER_BLOCK=512`.
366371
372+
device_id : int
373+
The (GPU) device number to use. The defailt value is `0`.
374+
367375
Returns
368376
-------
369377
out : ndarray
@@ -443,6 +451,12 @@ def gpu_stump(
443451
start = 0
444452
stop = l
445453

454+
cuda.select_device(device_id)
455+
if (
456+
cuda.current_context().__class__.__name__ != "FakeCUDAContext"
457+
): # pragma: no cover
458+
cuda.current_context().deallocations.clear()
459+
446460
QT, QT_first = _get_QT(start, T_A, T_B, m)
447461
profile[:], indices[:, :] = _gpu_stump(
448462
T_A,

test.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ check_errs()
1515
}
1616

1717
echo "Testing Numba JIT Compiled Functions"
18+
py.test -rsx -W ignore::RuntimeWarning -W ignore::DeprecationWarning tests/test_gpu_stump.py
1819
py.test -x -W ignore::RuntimeWarning -W ignore::DeprecationWarning tests/test_stump.py tests/test_mstump.py
1920
check_errs $?
2021
py.test -x -W ignore::RuntimeWarning -W ignore::DeprecationWarning tests/test_stumped.py tests/test_mstumped.py

0 commit comments

Comments
 (0)