Skip to content

Commit 684d1cd

Browse files
committed
Refine code style
1 parent bf7a620 commit 684d1cd

File tree

1 file changed

+23
-17
lines changed

1 file changed

+23
-17
lines changed

src/ATen/native/xpu/mkl/BatchLinearAlgebra.cpp

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,18 @@ static oneapi::mkl::transpose to_blas_(TransposeType trans) {
3838
TORCH_INTERNAL_ASSERT(false, "Invalid transpose type");
3939
}
4040

41-
void error_handle(int32_t* infos, const oneapi::mkl::lapack::batch_error& be) {
41+
void error_handle(
42+
int32_t* info_cpu,
43+
const oneapi::mkl::lapack::batch_error& be) {
4244
auto errs = be.exceptions();
4345
auto ids = be.ids();
4446

4547
if (!errs.size()) {
46-
TORCH_WARN("Caught lapack exception:\nWhat: ", be.what(), "\nInfo: ", be.info());
48+
TORCH_WARN(
49+
"Caught lapack exception:\nWhat: ", be.what(), "\nInfo: ", be.info());
4750
for (auto& i : ids) {
4851
TORCH_WARN("Error in matrix #", i);
49-
infos[i] = 1;
52+
info_cpu[i] = 1;
5053
}
5154
return;
5255
}
@@ -55,14 +58,17 @@ void error_handle(int32_t* infos, const oneapi::mkl::lapack::batch_error& be) {
5558
try {
5659
std::rethrow_exception(errs[i]);
5760
} catch (const oneapi::mkl::lapack::exception& e) {
58-
std::cout << "Cathed lapack exception:"
59-
<< "\nWhat: " << e.what() << "\nInfo: " << e.info()
60-
<< "\nDetail: " << e.detail() << std::endl;
61-
infos[i] = e.info();
61+
TORCH_WARN(
62+
"Caught lapack exception:\nWhat: ",
63+
e.what(),
64+
"\nInfo: ",
65+
e.info(),
66+
"\nDetail: ",
67+
e.detail());
68+
info_cpu[i] = e.info();
6269
} catch (const sycl::exception& e) {
63-
std::cout << "Catched SYCL exception:"
64-
<< "\nWhat: " << e.what() << "\nInfo: -1" << std::endl;
65-
infos[i] = -1;
70+
TORCH_WARN("Caught SYCL exception:\nWhat: ", e.what(), "\nInfo: -1");
71+
info_cpu[i] = -1;
6672
}
6773
}
6874
}
@@ -383,7 +389,7 @@ template <typename scalar_t>
383389
static void apply_lu_xpu_(
384390
const Tensor& self_,
385391
Tensor& pivots_,
386-
int32_t* infos_) {
392+
int32_t* info_data) {
387393
// do nothing if empty input.
388394
if (self_.numel() == 0)
389395
return;
@@ -414,7 +420,7 @@ static void apply_lu_xpu_(
414420
(scalar_t*)(scratchpad_at.data_ptr()),
415421
scratchpadsize);
416422
} catch (const oneapi::mkl::lapack::batch_error& be) {
417-
error_handle(infos_, be);
423+
error_handle(info_data, be);
418424
}
419425
}
420426

@@ -447,8 +453,8 @@ static void apply_lu_solve_xpu_(
447453
int64_t* ipiv = pivots.data_ptr<int64_t>();
448454
scalar_t* b = b_.data_ptr<scalar_t>();
449455

450-
std::vector<int32_t> infos(batch_size, 0);
451-
int32_t* infos_ = infos.data();
456+
std::vector<int32_t> info_vec(batch_size, 0);
457+
int32_t* info_data = info_vec.data();
452458

453459
auto execute_mkl_getrs =
454460
[&](scalar_t* a, scalar_t* b, int64_t* ipiv, int64_t batch_size) {
@@ -482,7 +488,7 @@ static void apply_lu_solve_xpu_(
482488
scratchpad_at.data_ptr<scalar_t>(),
483489
scratchpad_size);
484490
} catch (oneapi::mkl::lapack::batch_error be) {
485-
error_handle(infos_, be);
491+
error_handle(info_data, be);
486492
}
487493
};
488494

@@ -541,13 +547,13 @@ void lu_factor_mkl(
541547

542548
// handle the info
543549
Tensor info_ = at::zeros_like(info, Device(at::kCPU));
544-
int32_t* infos_data = info_.data_ptr<int32_t>();
550+
int32_t* info_data = info_.data_ptr<int32_t>();
545551

546552
// oneMKL requires Long for pivots but PyTorch provides Int
547553
Tensor pivots_ = at::empty(pivots.sizes(), pivots.options().dtype(kLong));
548554

549555
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(LU.scalar_type(), "lu_xpu", [&] {
550-
apply_lu_xpu_<scalar_t>(LU, pivots_, infos_data);
556+
apply_lu_xpu_<scalar_t>(LU, pivots_, info_data);
551557
});
552558

553559
// Copy to original info and pivots tensor

0 commit comments

Comments
 (0)