@@ -423,16 +423,21 @@ ffi::Error Syevd64Impl(int64_t batch, int64_t n, gpuStream_t stream,
       params_cleanup(
           params, [](gpusolverDnParams_t p) { gpusolverDnDestroyParams(p); });
 
+  int64_t matrix_size = n * n * ffi::ByteWidth(dataType);
+  int64_t batch_step = std::numeric_limits<int>::max() / matrix_size;
+  if (batch_step >= 32 * 1024) {
+    batch_step = 32 * 1024;
+  }
   size_t workspaceInBytesOnDevice, workspaceInBytesOnHost;
-  JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevd_bufferSize(
+  JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevBatched_bufferSize(
       handle.get(), params, jobz, uplo, n, aType, /*a=*/nullptr, n, wType,
       /*w=*/nullptr, aType, &workspaceInBytesOnDevice,
-      &workspaceInBytesOnHost));
+      &workspaceInBytesOnHost, std::min(batch, batch_step)));
 
   auto maybe_workspace = scratch.Allocate(workspaceInBytesOnDevice);
   if (!maybe_workspace.has_value()) {
     return ffi::Error(ffi::ErrorCode::kResourceExhausted,
-                      "Unable to allocate device workspace for syevd");
+                      "Unable to allocate device workspace for syevBatched");
   }
   auto workspaceOnDevice = maybe_workspace.value();
   auto workspaceOnHost =
@@ -447,17 +452,18 @@ ffi::Error Syevd64Impl(int64_t batch, int64_t n, gpuStream_t stream,
         out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream));
   }
 
-  size_t out_step = n * n * ffi::ByteWidth(dataType);
-  size_t w_step = n * ffi::ByteWidth(ffi::ToReal(dataType));
+  size_t out_step = n * n * ffi::ByteWidth(dataType) * batch_step;
+  size_t w_step = n * ffi::ByteWidth(ffi::ToReal(dataType)) * batch_step;
 
-  for (auto i = 0; i < batch; ++i) {
-    JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevd(
+  for (int64_t i = 0; i < batch; i += batch_step) {
+    size_t batch_size = static_cast<size_t>(std::min(batch_step, batch - i));
+    JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXsyevBatched(
         handle.get(), params, jobz, uplo, n, aType, out_data, n, wType, w_data,
         aType, workspaceOnDevice, workspaceInBytesOnDevice,
-        workspaceOnHost.get(), workspaceInBytesOnHost, info_data));
+        workspaceOnHost.get(), workspaceInBytesOnHost, info_data, batch_size));
     out_data += out_step;
     w_data += w_step;
-    ++info_data;
+    info_data += batch_step;
   }
 
   return ffi::Error::Success();
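
For reference, a minimal standalone sketch of the chunking arithmetic the Syevd64Impl loop above relies on. The values of n, batch, and the element byte width are made up for illustration, and a printf stands in for the solver call: batch_step is derived from INT_MAX divided by the per-matrix byte count, capped at 32 * 1024 matrices, so each gpusolverDnXsyevBatched call stays within 32-bit-safe sizes.

// Sketch only: mirrors the batch_step computation and chunked loop above.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <limits>

int main() {
  const int64_t n = 512;         // matrix dimension (illustrative)
  const int64_t batch = 4096;    // total number of matrices (illustrative)
  const int64_t byte_width = 8;  // element size in bytes (illustrative)

  // Largest chunk whose matrix data stays below INT_MAX bytes,
  // capped at 32 * 1024 matrices per solver call.
  const int64_t matrix_size = n * n * byte_width;
  int64_t batch_step = std::numeric_limits<int>::max() / matrix_size;
  batch_step = std::min<int64_t>(batch_step, 32 * 1024);

  for (int64_t i = 0; i < batch; i += batch_step) {
    const int64_t chunk = std::min(batch_step, batch - i);
    // In the kernel, gpusolverDnXsyevBatched runs on `chunk` matrices here,
    // and the out/w/info pointers then advance by batch_step-sized strides.
    std::printf("matrices [%lld, %lld)\n", static_cast<long long>(i),
                static_cast<long long>(i + chunk));
  }
  return 0;
}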
@@ -576,18 +582,23 @@ ffi::Error SyevdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch,
       CheckShape(out->dimensions(), {batch, rows, cols}, "out", "syevd"));
   FFI_RETURN_IF_ERROR(CheckShape(w->dimensions(), {batch, cols}, "w", "syevd"));
   FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "syevd"));
-  if (algorithm == SyevdAlgorithm::kJacobi ||
-      (algorithm == SyevdAlgorithm::kDefault && cols <= 32)) {
+#if JAX_GPU_HAVE_64_BIT
+  if (algorithm == SyevdAlgorithm::kJacobi) {
     SOLVER_DISPATCH_IMPL(SyevdjImpl, batch, cols, stream, scratch, lower, a,
-                         out, w, info);
+                         out, w, info);
   } else {
-#if JAX_GPU_HAVE_64_BIT
     return Syevd64Impl(batch, cols, stream, scratch, lower, a, out, w, info);
+  }
 #else
+  if (algorithm == SyevdAlgorithm::kJacobi ||
+      (algorithm == SyevdAlgorithm::kDefault && cols <= 32)) {
+    SOLVER_DISPATCH_IMPL(SyevdjImpl, batch, cols, stream, scratch, lower, a,
+                         out, w, info);
+  } else {
     SOLVER_DISPATCH_IMPL(SyevdImpl, batch, cols, stream, scratch, lower, a, out,
                          w, info);
-#endif
   }
+#endif
   return ffi::Error::InvalidArgument(absl::StrFormat(
       "Unsupported dtype %s in syevd", absl::FormatStreamed(dataType)));
 }
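
And a simplified, self-contained sketch of the dispatch structure introduced in SyevdDispatch: when the 64-bit cuSOLVER API is available, only an explicit Jacobi request takes the SyevdjImpl path and everything else goes through the batched 64-bit path; otherwise the old size-based heuristic remains. The enum values and printf calls are stand-ins for the real SOLVER_DISPATCH_IMPL machinery.

// Sketch only: the real code dispatches to SyevdjImpl / Syevd64Impl /
// SyevdImpl; printf stands in for those calls.
#include <cstdio>

enum class SyevdAlgorithm { kDefault, kJacobi };

#define JAX_GPU_HAVE_64_BIT 1  // flip to 0 to exercise the legacy branch

void Dispatch(SyevdAlgorithm algorithm, int cols) {
#if JAX_GPU_HAVE_64_BIT
  (void)cols;  // the 64-bit path no longer special-cases small matrices
  if (algorithm == SyevdAlgorithm::kJacobi) {
    std::printf("SyevdjImpl (Jacobi)\n");
  } else {
    std::printf("Syevd64Impl (batched 64-bit API)\n");
  }
#else
  if (algorithm == SyevdAlgorithm::kJacobi ||
      (algorithm == SyevdAlgorithm::kDefault && cols <= 32)) {
    std::printf("SyevdjImpl (Jacobi)\n");
  } else {
    std::printf("SyevdImpl (32-bit API)\n");
  }
#endif
}

int main() {
  Dispatch(SyevdAlgorithm::kDefault, 16);   // 64-bit path when available
  Dispatch(SyevdAlgorithm::kJacobi, 1024);  // explicit Jacobi always wins
  return 0;
}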