Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 55 additions & 13 deletions jaxlib/gpu/solver_kernels_ffi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,16 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch,

#if JAX_GPU_HAVE_64_BIT

absl::StatusOr<bool> IsSyevBatchedSupported() {
int version;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpusolverGetVersion(&version)));
// According to
// https://docs.nvidia.com/cuda/archive/12.6.2/cuda-toolkit-release-notes/index.html
// syevBatched is supported since CUDA 12.6.2, where CUSOLVER
// version is 11.7.1.
return version >= 11701;
}

ffi::Error Syevd64Impl(int64_t batch, int64_t n, gpuStream_t stream,
ffi::ScratchAllocator& scratch, bool lower,
ffi::AnyBuffer a, ffi::Result<ffi::AnyBuffer> out,
Expand All @@ -423,16 +433,36 @@ ffi::Error Syevd64Impl(int64_t batch, int64_t n, gpuStream_t stream,
params_cleanup(
params, [](gpusolverDnParams_t p) { gpusolverDnDestroyParams(p); });

int64_t batch_step = 1;
FFI_ASSIGN_OR_RETURN(bool is_batched_syev_supported, IsSyevBatchedSupported());
if (is_batched_syev_supported) {
int64_t matrix_size = n * n * ffi::ByteWidth(dataType);
batch_step = std::numeric_limits<int>::max() / matrix_size;
if (batch_step >= 32 * 1024) {
batch_step = 32 * 1024;
}
}
size_t workspaceInBytesOnDevice, workspaceInBytesOnHost;
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevd_bufferSize(
handle.get(), params, jobz, uplo, n, aType, /*a=*/nullptr, n, wType,
/*w=*/nullptr, aType, &workspaceInBytesOnDevice,
&workspaceInBytesOnHost));
if (is_batched_syev_supported) {
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevBatched_bufferSize(
handle.get(), params, jobz, uplo, n, aType, /*a=*/nullptr, n, wType,
/*w=*/nullptr, aType, &workspaceInBytesOnDevice,
&workspaceInBytesOnHost, std::min(batch, batch_step)));
} else {
if (batch_step != 1) {
return ffi::Error(ffi::ErrorCode::kInternal,
"Syevd64Impl: batch_step != 1 but batched syev is not supported");
}
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevd_bufferSize(
handle.get(), params, jobz, uplo, n, aType, /*a=*/nullptr, n, wType,
/*w=*/nullptr, aType, &workspaceInBytesOnDevice,
&workspaceInBytesOnHost));
}

auto maybe_workspace = scratch.Allocate(workspaceInBytesOnDevice);
if (!maybe_workspace.has_value()) {
return ffi::Error(ffi::ErrorCode::kResourceExhausted,
"Unable to allocate device workspace for syevd");
"Unable to allocate device workspace for syevBatched");
}
auto workspaceOnDevice = maybe_workspace.value();
auto workspaceOnHost =
Expand All @@ -447,17 +477,29 @@ ffi::Error Syevd64Impl(int64_t batch, int64_t n, gpuStream_t stream,
out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
}

size_t out_step = n * n * ffi::ByteWidth(dataType);
size_t w_step = n * ffi::ByteWidth(ffi::ToReal(dataType));
size_t out_step = n * n * ffi::ByteWidth(dataType) * batch_step;
size_t w_step = n * ffi::ByteWidth(ffi::ToReal(dataType)) * batch_step;

for (auto i = 0; i < batch; ++i) {
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevd(
handle.get(), params, jobz, uplo, n, aType, out_data, n, wType, w_data,
aType, workspaceOnDevice, workspaceInBytesOnDevice,
workspaceOnHost.get(), workspaceInBytesOnHost, info_data));
for (int64_t i = 0; i < batch; i += batch_step) {
size_t batch_size = static_cast<size_t>(std::min(batch_step, batch - i));
if (is_batched_syev_supported) {
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevBatched(
handle.get(), params, jobz, uplo, n, aType, out_data, n, wType, w_data,
aType, workspaceOnDevice, workspaceInBytesOnDevice,
workspaceOnHost.get(), workspaceInBytesOnHost, info_data, batch_size));
} else {
if (batch_step != 1) {
return ffi::Error(ffi::ErrorCode::kInternal,
"Syevd64Impl: batch_step != 1 but batched syev is not supported");
}
JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevd(
handle.get(), params, jobz, uplo, n, aType, out_data, n, wType, w_data,
aType, workspaceOnDevice, workspaceInBytesOnDevice,
workspaceOnHost.get(), workspaceInBytesOnHost, info_data));
}
out_data += out_step;
w_data += w_step;
++info_data;
info_data += batch_step;
}

return ffi::Error::Success();
Expand Down
4 changes: 4 additions & 0 deletions jaxlib/gpu/vendor.h
Original file line number Diff line number Diff line change
Expand Up @@ -410,8 +410,12 @@ typedef cusolverDnParams_t gpusolverDnParams_t;
#define gpusolverDnCreateParams cusolverDnCreateParams
#define gpusolverDnDestroyParams cusolverDnDestroyParams

#define gpusolverGetVersion cusolverGetVersion

#define gpusolverDnXsyevd_bufferSize cusolverDnXsyevd_bufferSize
#define gpusolverDnXsyevd cusolverDnXsyevd
#define gpusolverDnXsyevBatched_bufferSize cusolverDnXsyevBatched_bufferSize
#define gpusolverDnXsyevBatched cusolverDnXsyevBatched
#define gpusolverDnXgesvd_bufferSize cusolverDnXgesvd_bufferSize
#define gpusolverDnXgesvd cusolverDnXgesvd

Expand Down
Loading