@@ -38,15 +38,18 @@ static oneapi::mkl::transpose to_blas_(TransposeType trans) {
   TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
 }
 
-void error_handle(int32_t* infos, const oneapi::mkl::lapack::batch_error& be) {
+void error_handle(
+    int32_t* info_cpu,
+    const oneapi::mkl::lapack::batch_error& be) {
   auto errs = be.exceptions();
   auto ids = be.ids();
 
   if (!errs.size()) {
-    TORCH_WARN("Caught lapack exception:\nWhat: ", be.what(), "\nInfo: ", be.info());
+    TORCH_WARN(
+        "Caught lapack exception:\nWhat: ", be.what(), "\nInfo: ", be.info());
     for (auto& i : ids) {
       TORCH_WARN("Error in matrix #", i);
-      infos[i] = 1;
+      info_cpu[i] = 1;
     }
     return;
   }
@@ -55,14 +58,17 @@ void error_handle(int32_t* infos, const oneapi::mkl::lapack::batch_error& be) {
     try {
       std::rethrow_exception(errs[i]);
     } catch (const oneapi::mkl::lapack::exception& e) {
-      std::cout << "Cathed lapack exception:"
-                << "\nWhat: " << e.what() << "\nInfo: " << e.info()
-                << "\nDetail: " << e.detail() << std::endl;
-      infos[i] = e.info();
+      TORCH_WARN(
+          "Caught lapack exception:\nWhat: ",
+          e.what(),
+          "\nInfo: ",
+          e.info(),
+          "\nDetail: ",
+          e.detail());
+      info_cpu[i] = e.info();
     } catch (const sycl::exception& e) {
-      std::cout << "Catched SYCL exception:"
-                << "\nWhat: " << e.what() << "\nInfo: -1" << std::endl;
-      infos[i] = -1;
+      TORCH_WARN("Caught SYCL exception:\nWhat: ", e.what(), "\nInfo: -1");
+      info_cpu[i] = -1;
     }
   }
 }
@@ -383,7 +389,7 @@ template <typename scalar_t>
 static void apply_lu_xpu_(
     const Tensor& self_,
     Tensor& pivots_,
-    int32_t* infos_) {
+    int32_t* info_data) {
   // do nothing if empty input.
   if (self_.numel() == 0)
     return;
@@ -414,7 +420,7 @@ static void apply_lu_xpu_(
         (scalar_t*)(scratchpad_at.data_ptr()),
         scratchpadsize);
   } catch (const oneapi::mkl::lapack::batch_error& be) {
-    error_handle(infos_, be);
+    error_handle(info_data, be);
   }
 }
 
@@ -447,8 +453,8 @@ static void apply_lu_solve_xpu_(
   int64_t* ipiv = pivots.data_ptr<int64_t>();
   scalar_t* b = b_.data_ptr<scalar_t>();
 
-  std::vector<int32_t> infos(batch_size, 0);
-  int32_t* infos_ = infos.data();
+  std::vector<int32_t> info_vec(batch_size, 0);
+  int32_t* info_data = info_vec.data();
 
   auto execute_mkl_getrs =
       [&](scalar_t* a, scalar_t* b, int64_t* ipiv, int64_t batch_size) {
@@ -482,7 +488,7 @@ static void apply_lu_solve_xpu_(
             scratchpad_at.data_ptr<scalar_t>(),
             scratchpad_size);
       } catch (oneapi::mkl::lapack::batch_error be) {
-        error_handle(infos_, be);
+        error_handle(info_data, be);
       }
     };
 
@@ -541,13 +547,13 @@ void lu_factor_mkl(
 
   // handle the info
   Tensor info_ = at::zeros_like(info, Device(at::kCPU));
-  int32_t* infos_data = info_.data_ptr<int32_t>();
+  int32_t* info_data = info_.data_ptr<int32_t>();
 
   // oneMKL requires Long for pivots but PyTorch provides Int
   Tensor pivots_ = at::empty(pivots.sizes(), pivots.options().dtype(kLong));
 
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(LU.scalar_type(), "lu_xpu", [&] {
-    apply_lu_xpu_<scalar_t>(LU, pivots_, infos_data);
+    apply_lu_xpu_<scalar_t>(LU, pivots_, info_data);
   });
 
   // Copy to original info and pivots tensor