Skip to content

Commit 6ca192d

Browse files
authored
Support rmul/lmul for mixed eltypes (#2858)
1 parent cf4b06d commit 6ca192d

File tree

2 files changed

+44
-8
lines changed

2 files changed

+44
-8
lines changed

lib/cublas/linalg.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,24 @@ function LinearAlgebra.rmul!(A::CuMatrix{T}, B::Diagonal{T,<:CuVector{T}}) where
421421
return dgmm!('R', A, B.diag, A)
422422
end
423423

424+
# eltypes do not match
425+
function LinearAlgebra.lmul!(A::Diagonal{T,<:CuVector{T}}, B::CuMatrix) where {T<:CublasFloat}
426+
@. B = A.diag * B
427+
return B
428+
end
429+
function LinearAlgebra.lmul!(A::Diagonal{Td,<:CuVector{Td}}, B::Transpose{Tt, <:CuMatrix{Tt}}) where {Td<:CublasFloat, Tt<:CublasFloat}
430+
@. B = A.diag * B
431+
return B
432+
end
433+
function LinearAlgebra.lmul!(A::Diagonal{Td,<:CuVector{Td}}, B::Adjoint{Tt, <:CuMatrix{Tt}}) where {Td<:CublasFloat, Tt<:CublasFloat}
434+
@. B = A.diag * B
435+
return B
436+
end
437+
# eltypes do not match
438+
LinearAlgebra.rmul!(A::CuMatrix, B::Diagonal{T,<:CuVector{T}}) where {T<:CublasFloat} = lmul!(B, transpose(A))
439+
LinearAlgebra.rmul!(A::Transpose{Tt, <:CuMatrix{Tt}}, B::Diagonal{Td,<:CuVector{Td}}) where {Td<:CublasFloat, Tt<:CublasFloat} = lmul!(B, A)
440+
LinearAlgebra.rmul!(A::Adjoint{Tt, <:CuMatrix{Tt}}, B::Diagonal{Td,<:CuVector{Td}}) where {Td<:CublasFloat, Tt<:CublasFloat} = conj!(lmul!(B, conj!(A)))
441+
424442
# diagm
425443

426444
LinearAlgebra.diagm(kv::Pair{<:Integer,<:CuVector}...) = _cuda_diagm(nothing, kv...)

test/libraries/cublas/extensions.jl

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -536,49 +536,49 @@ k = 13
536536
d_XA = CuArray(XA)
537537
d_X = Diagonal(d_x)
538538
mul!(d_XA, d_X, d_A)
539-
Array(d_XA) Diagonal(x) * A
539+
@test Array(d_XA) Diagonal(x) * A
540540

541541
XA = rand(elty,m,n)
542542
d_XA = CuArray(XA)
543543
d_X = Diagonal(d_x)
544544
lmul!(d_X, d_XA)
545-
Array(d_XA) Diagonal(x) * XA
545+
@test Array(d_XA) Diagonal(x) * XA
546546

547547
AY = rand(elty,m,n)
548548
d_AY = CuArray(AY)
549549
d_Y = Diagonal(d_y)
550550
mul!(d_AY, d_A, d_Y)
551-
Array(d_AY) A * Diagonal(y)
551+
@test Array(d_AY) A * Diagonal(y)
552552

553553
AY = rand(elty,m,n)
554554
d_AY = CuArray(AY)
555555
d_Y = Diagonal(d_y)
556556
rmul!(d_AY, d_Y)
557-
Array(d_AY) AY * Diagonal(y)
557+
@test Array(d_AY) AY * Diagonal(y)
558558

559559
YA = rand(elty,n,m)
560560
d_YA = CuArray(YA)
561561
d_Y = Diagonal(d_y)
562562
mul!(d_YA, d_Y, transpose(d_A))
563-
Array(d_YA) Diagonal(y) * transpose(A)
563+
@test Array(d_YA) Diagonal(y) * transpose(A)
564564

565565
AX = rand(elty,n,m)
566566
d_AX = CuArray(AX)
567567
d_X = Diagonal(d_x)
568568
mul!(d_AX, transpose(d_A), d_X)
569-
Array(d_AX) transpose(A) * Diagonal(x)
569+
@test Array(d_AX) transpose(A) * Diagonal(x)
570570

571571
YA = rand(elty,n,m)
572572
d_YA = CuArray(YA)
573573
d_Y = Diagonal(d_y)
574574
mul!(d_YA, d_Y, d_A')
575-
Array(d_YA) Diagonal(y) * A'
575+
@test Array(d_YA) Diagonal(y) * A'
576576

577577
AX = rand(elty,n,m)
578578
d_AX = CuArray(AX)
579579
d_X = Diagonal(d_x)
580580
mul!(d_AX, d_A', d_X)
581-
Array(d_AX) A' * Diagonal(x)
581+
@test Array(d_AX) A' * Diagonal(x)
582582

583583
@test Array(d_X) == Diagonal(Array(d_x))
584584
end
@@ -592,3 +592,21 @@ k = 13
592592
@test Array(C) hC
593593
end
594594
end # elty
595+
596+
@testset "rmul/lmul with mixed eltypes ($Tr, $Tc)" for (Tr, Tc) in ((Float32, ComplexF32), (Float64, ComplexF64))
597+
x = rand(Tr,m)
598+
d_x = CuArray(x)
599+
XA = rand(Tc,m,n)
600+
d_XA = CuArray(XA)
601+
d_X = Diagonal(d_x)
602+
lmul!(d_X, d_XA)
603+
@test Array(d_XA) Diagonal(x) * XA
604+
605+
y = rand(Tr,n)
606+
d_y = CuArray(y)
607+
AY = rand(Tc,m,n)
608+
d_AY = CuArray(AY)
609+
d_Y = Diagonal(d_y)
610+
rmul!(d_AY, d_Y)
611+
@test Array(d_AY) AY * Diagonal(y)
612+
end

0 commit comments

Comments
 (0)