Skip to content

Commit a538331

Browse files
committed
[gpu] Use syevBatched for eigh
Use cusolverDnXsyevBatched for eigh, processing the batch in chunks instead of calling cusolverDnXsyevd once per matrix.
1 parent 193c4ae commit a538331

File tree

2 files changed: 27 additions and 14 deletions.

jaxlib/gpu/solver_kernels_ffi.cc

Lines changed: 25 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -423,16 +423,21 @@ ffi::Error Syevd64Impl(int64_t batch, int64_t n, gpuStream_t stream,
423423
params_cleanup(
424424
params, [](gpusolverDnParams_t p) { gpusolverDnDestroyParams(p); });
425425

426+
int64_t matrix_size = n * n * ffi::ByteWidth(dataType);
427+
int64_t batch_step = std::numeric_limits<int>::max() / matrix_size;
428+
if (batch_step >= 32 * 1024) {
429+
batch_step = 32 * 1024;
430+
}
426431
size_t workspaceInBytesOnDevice, workspaceInBytesOnHost;
427-
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevd_bufferSize(
432+
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevBatched_bufferSize(
428433
handle.get(), params, jobz, uplo, n, aType, /*a=*/nullptr, n, wType,
429434
/*w=*/nullptr, aType, &workspaceInBytesOnDevice,
430-
&workspaceInBytesOnHost));
435+
&workspaceInBytesOnHost, std::min(batch, batch_step)));
431436

432437
auto maybe_workspace = scratch.Allocate(workspaceInBytesOnDevice);
433438
if (!maybe_workspace.has_value()) {
434439
return ffi::Error(ffi::ErrorCode::kResourceExhausted,
435-
"Unable to allocate device workspace for syevd");
440+
"Unable to allocate device workspace for syevBatched");
436441
}
437442
auto workspaceOnDevice = maybe_workspace.value();
438443
auto workspaceOnHost =
@@ -447,17 +452,18 @@ ffi::Error Syevd64Impl(int64_t batch, int64_t n, gpuStream_t stream,
447452
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
448453
}
449454

450-
size_t out_step = n * n * ffi::ByteWidth(dataType);
451-
size_t w_step = n * ffi::ByteWidth(ffi::ToReal(dataType));
455+
size_t out_step = n * n * ffi::ByteWidth(dataType) * batch_step;
456+
size_t w_step = n * ffi::ByteWidth(ffi::ToReal(dataType)) * batch_step;
452457

453-
for (auto i = 0; i < batch; ++i) {
454-
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevd(
458+
for (int64_t i = 0; i < batch; i += batch_step) {
459+
size_t batch_size = static_cast<size_t>(std::min(batch_step, batch - i));
460+
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevBatched(
455461
handle.get(), params, jobz, uplo, n, aType, out_data, n, wType, w_data,
456462
aType, workspaceOnDevice, workspaceInBytesOnDevice,
457-
workspaceOnHost.get(), workspaceInBytesOnHost, info_data));
463+
workspaceOnHost.get(), workspaceInBytesOnHost, info_data, batch_size));
458464
out_data += out_step;
459465
w_data += w_step;
460-
++info_data;
466+
info_data += batch_step;
461467
}
462468

463469
return ffi::Error::Success();
@@ -576,18 +582,23 @@ ffi::Error SyevdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
576582
CheckShape(out->dimensions(), {batch, rows, cols}, "out", "syevd"));
577583
FFI_RETURN_IF_ERROR(CheckShape(w->dimensions(), {batch, cols}, "w", "syevd"));
578584
FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "syevd"));
579-
if (algorithm == SyevdAlgorithm::kJacobi ||
580-
(algorithm == SyevdAlgorithm::kDefault && cols <= 32)) {
585+
#if JAX_GPU_HAVE_64_BIT
586+
if (algorithm == SyevdAlgorithm::kJacobi) {
581587
SOLVER_DISPATCH_IMPL(SyevdjImpl, batch, cols, stream, scratch, lower, a,
582-
out, w, info);
588+
out, w, info);
583589
} else {
584-
#if JAX_GPU_HAVE_64_BIT
585590
return Syevd64Impl(batch, cols, stream, scratch, lower, a, out, w, info);
591+
}
586592
#else
593+
if (algorithm == SyevdAlgorithm::kJacobi ||
594+
(algorithm == SyevdAlgorithm::kDefault && cols <= 32)) {
595+
SOLVER_DISPATCH_IMPL(SyevdjImpl, batch, cols, stream, scratch, lower, a,
596+
out, w, info);
597+
} else {
587598
SOLVER_DISPATCH_IMPL(SyevdImpl, batch, cols, stream, scratch, lower, a, out,
588599
w, info);
589-
#endif
590600
}
601+
#endif
591602
return ffi::Error::InvalidArgument(absl::StrFormat(
592603
"Unsupported dtype %s in syevd", absl::FormatStreamed(dataType)));
593604
}

jaxlib/gpu/vendor.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -412,6 +412,8 @@ typedef cusolverDnParams_t gpusolverDnParams_t;
412412

413413
#define gpusolverDnXsyevd_bufferSize cusolverDnXsyevd_bufferSize
414414
#define gpusolverDnXsyevd cusolverDnXsyevd
415+
#define gpusolverDnXsyevBatched_bufferSize cusolverDnXsyevBatched_bufferSize
416+
#define gpusolverDnXsyevBatched cusolverDnXsyevBatched
415417
#define gpusolverDnXgesvd_bufferSize cusolverDnXgesvd_bufferSize
416418
#define gpusolverDnXgesvd cusolverDnXgesvd
417419

0 commit comments

Comments
 (0)