Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ add_library(
src/larf.cc
src/larfb.cc
src/larfg.cc
src/larfg_gpu.cc
src/larfgp.cc
src/larft.cc
src/larfx.cc
Expand Down Expand Up @@ -489,8 +490,11 @@ add_library(
src/tpmlqt.cc
src/tpmqrt.cc
src/tpqrt.cc
src/tpqrt_gpu.cc
src/tpqrt2.cc
src/tpqrt2_gpu.cc
src/tprfb.cc
src/tprfb_gpu.cc
src/tprfs.cc
src/tptri.cc
src/tptrs.cc
Expand Down
35 changes: 35 additions & 0 deletions include/lapack/device.hh
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,41 @@ void heevd(
void* host_work, size_t host_work_size,
device_info_int* dev_info, lapack::Queue& queue );

template <typename scalar_t>
void larfg(
int64_t n,
scalar_t* alpha,
scalar_t* dx, int64_t incdx,
scalar_t* tau,
lapack::Queue& queue );

template <typename scalar_t>
int64_t tpqrt(
int64_t m, int64_t n, int64_t l, int64_t nb,
scalar_t* dA, int64_t ldda,
scalar_t* dB, int64_t lddb,
scalar_t* dT, int64_t lddt,
lapack::Queue& queue );

template <typename scalar_t>
int64_t tpqrt2(
int64_t m, int64_t n, int64_t l,
scalar_t* dA, int64_t ldda,
scalar_t* dB, int64_t lddb,
scalar_t* dT, int64_t lddt,
lapack::Queue& queue );

template <typename scalar_t>
void tprfb(
lapack::Side side, lapack::Op trans,
lapack::Direction direction, lapack::StoreV storev,
int64_t m, int64_t n, int64_t k, int64_t l,
scalar_t const* dV, int64_t ldv,
scalar_t const* dT, int64_t ldt,
scalar_t* dA, int64_t lda,
scalar_t* dB, int64_t ldb,
lapack::Queue& queue );

} // namespace lapack

#endif // LAPACK_DEVICE_HH
114 changes: 114 additions & 0 deletions src/larfg_gpu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright (c) 2017-2025, University of Tennessee. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
// This program is free software: you can redistribute it and/or modify it under
// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.

#include "lapack.hh"
#include "lapack_internal.hh"
#include "lapack/device.hh"
#include <cmath>

namespace lapack {

using std::real, std::imag, std::hypot;

template <typename real_t>
real_t hypot( std::complex<real_t> x, real_t y )
{
return hypot( x.real(), x.imag(), y );
}

// -----------------------------------------------------------------------------
/// @ingroup larfg
template <typename scalar_t>
void larfg(
int64_t n,
scalar_t* alpha,
scalar_t* dx, int64_t incdx,
scalar_t* tau,
lapack::Queue& queue )
{
using real_t = blas::real_type< scalar_t >;

const scalar_t one = 1.0;

// Quick return if n <= 0
if (n <= 0) {
blas::device_memset( tau, 0, 1, queue );
return;
}

scalar_t alpha_;
blas::device_memcpy( &alpha_, alpha, 1, queue );
queue.sync();

real_t xnorm;
blas::nrm2( n-1, dx, incdx, &xnorm, queue );

if (xnorm == 0 && imag(alpha_) == 0) {
// h = i
blas::device_memset( tau, 0, 1, queue );
return;
}

// general case
real_t beta = -copysign( hypot( alpha_, xnorm ), real(alpha_) );
real_t safmin = std::numeric_limits<real_t>::min();
real_t rsafmn = 1.0 / safmin;

int64_t knt = 0;
if (abs( beta ) < safmin) {
// XNORM, BETA may be inaccurate; scale X and recompute them
do {
knt += 1;
blas::scal( n-1, rsafmn, dx, incdx, queue );
beta *= rsafmn;
alpha_ *= rsafmn;
} while (abs(beta) < safmin && knt < 20);

blas::nrm2( n-1, dx, incdx, &xnorm, queue );
beta = -copysign( hypot( alpha_, xnorm ), real(alpha_) );
}

scalar_t tau_ = (beta - alpha_) / beta;
blas::device_memcpy( tau, &tau_, 1, queue );
blas::scal( n-1, one / (alpha_ - beta), dx, incdx, queue );

// If alpha is subnormal, it may lose relative accuracy

for (int j = 0; j < knt; j++) {
beta *= safmin;
}
alpha_ = beta;
blas::device_memcpy( alpha, &alpha_, 1, queue );
}

template void larfg(
int64_t n,
float* alpha,
float* dx, int64_t incdx,
float* tau,
lapack::Queue& queue );

template void larfg(
int64_t n,
double* alpha,
double* dx, int64_t incdx,
double* tau,
lapack::Queue& queue );

template void larfg(
int64_t n,
std::complex<float>* alpha,
std::complex<float>* dx, int64_t incdx,
std::complex<float>* tau,
lapack::Queue& queue );

template void larfg(
int64_t n,
std::complex<double>* alpha,
std::complex<double>* dx, int64_t incdx,
std::complex<double>* tau,
lapack::Queue& queue );

} // namespace lapack
168 changes: 168 additions & 0 deletions src/tpqrt2_gpu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
// Copyright (c) 2017-2025, University of Tennessee. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause
// This program is free software: you can redistribute it and/or modify it under
// the terms of the BSD 3-Clause license. See the accompanying LICENSE file.

#include "lapack.hh"
#include "lapack_internal.hh"
#include "lapack/device.hh"

namespace lapack {

using blas::max, blas::min;
using blas::conj;

template <typename scalar_t>
int64_t tpqrt2(
int64_t m, int64_t n, int64_t l,
scalar_t* dA, int64_t ldda,
scalar_t* dB, int64_t lddb,
scalar_t* dT, int64_t lddt,
lapack::Queue& queue )
{
#define dA(i_, j_) ( dA + i_ + (j_)*ldda )
#define dB(i_, j_) ( dB + i_ + (j_)*lddb )
#define dT(i_, j_) ( dT + i_ + (j_)*lddt )

scalar_t one = 1.0;
scalar_t zero = 0.0;

int64_t info = 0;
if (m < 0)
info = -1;
else if (n < 0)
info = -2;
else if (l < 0 || l > min( m, n ))
info = -3;
else if (ldda < max( 1, n ))
info = -5;
else if (lddb < max( 1, m ))
info = -7;
else if (lddt < max( 1, n ))
info = -9;

if (info != 0)
return info;

// Quick return if possible
if (m == 0 || n == 0)
return 0;

Op op_trans = (blas::is_complex_v< scalar_t > ? Op::ConjTrans : Op::Trans);

for (int i = 0; i < n; ++i) {
scalar_t* tau = dT(i, 0); // tau calculated from larfg
scalar_t* A_row = dA(i, i+1); // Remaining elements of current row of A to be transformed
scalar_t* v = dB(0, i); // Block reflector from larfg
scalar_t* B_rem = dB(0, i+1); // Remaining block of B to be transformed
scalar_t* work = dT(0, n-1); // Temporary workspace

// Generate elementary reflector H(i) to annihilate B(:, i)
int64_t p = m-l+min( l, i+1 );
lapack::larfg( p+1, dA(i, i), v, 1, tau, queue );

if (i < n-1) {
// Apply block reflector to C

// Compute v^H C_i where C_i = [ A_row ; B_rem ],
// (work = A_row^H + B^H v)
blas::conj( n-i-1, A_row, ldda, work, 1, queue );
blas::gemv( Layout::ColMajor, op_trans, p, n-i-1,
one, B_rem, lddb, v, 1,
one, work, 1, queue );

// Apply H to A_row (A_row = A_row - tau * work^H)
// alpha = -conj( tau )
scalar_t alpha;
blas::device_memcpy( &alpha, tau, 1, queue );
queue.sync();
alpha = -conj( alpha );
// A_row += alpha * work^H for j = [0, n-i-1)
// Allocate intermediate temp vector
scalar_t* temp = blas::device_malloc< scalar_t >( n-i-1, queue );
blas::conj( n-i-1, work, 1, temp, 1, queue );
blas::axpy( n-i-1, alpha, temp, 1, A_row, ldda, queue );
queue.sync();
blas::device_free( temp, queue ); // Free temp vector

// Apply H to B
// B_rem = B_rem + alpha*v*work^H
blas::ger( Layout::ColMajor, p, n-i-1, alpha, v, 1, work, 1,
B_rem, lddb, queue );
}
}
for (int i = 1; i < n; ++i) {
// Get T matrix

// T(1:I-1,I) := C(I:M,1:I-1)^H * (alpha * C(I:M,I))

// alpha = -dT(i, 0)
scalar_t alpha;
blas::device_memcpy( &alpha, dT(i, 0), 1, queue );
queue.sync();
alpha = -alpha;
// dT(j, i) = zero for j = [0, i)
blas::device_memset( dT(0, i), 0, i, queue );

int64_t p = min( i, l );
int64_t mp = min( m-l, m-1 );
int64_t np = min( p, n-1 );

// Triangular part of B2
// T(j, i) = alpha * B(m-l, i)
blas::device_memcpy( dT(0, i), dB(m-l, i), p, queue );
blas::scal( p, alpha, dT(0, i), 1, queue );
blas::trmv( Layout::ColMajor, Uplo::Upper, op_trans, Diag::NonUnit, p,
dB(mp, 0), lddb, dT(0, i), 1, queue );

// Rectangular part of B2
blas::gemv( Layout::ColMajor, op_trans, l, i-p,
alpha, dB(mp, np), lddb, dB(mp, i), 1,
zero, dT(np, i), 1, queue );

// B1
blas::gemv( Layout::ColMajor, op_trans, m-l, i,
alpha, dB, lddb, dB(0, i), 1,
one, dT(0, i), 1, queue );

// T(1:I-1,I) := T(1:I-1,1:I-1) * T(1:I-1,I)
blas::trmv( Layout::ColMajor, Uplo::Upper, Op::NoTrans, Diag::NonUnit,
i, dT, lddt, dT(0, i), 1, queue );

// T(i, i) = tau
blas::device_memcpy( dT(i, i), dT(i, 0), 1, queue );
blas::device_memcpy( dT(i, 0), &zero, 1, queue );
}

return info;
}

template int64_t tpqrt2(
int64_t m, int64_t n, int64_t l,
float* dA, int64_t ldda,
float* dB, int64_t lddb,
float* dT, int64_t lddt,
lapack::Queue& queue );

template int64_t tpqrt2(
int64_t m, int64_t n, int64_t l,
double* dA, int64_t ldda,
double* dB, int64_t lddb,
double* dT, int64_t lddt,
lapack::Queue& queue );

template int64_t tpqrt2(
int64_t m, int64_t n, int64_t l,
std::complex<float>* dA, int64_t ldda,
std::complex<float>* dB, int64_t lddb,
std::complex<float>* dT, int64_t lddt,
lapack::Queue& queue );

template int64_t tpqrt2(
int64_t m, int64_t n, int64_t l,
std::complex<double>* dA, int64_t ldda,
std::complex<double>* dB, int64_t lddb,
std::complex<double>* dT, int64_t lddt,
lapack::Queue& queue );

} // namespace lapack
Loading
Loading