Skip to content

Commit e95f3e8

Browse files
committed
take 1
1 parent 74c4bf6 commit e95f3e8

File tree

4 files changed

+78
-1
lines changed

4 files changed

+78
-1
lines changed

src/Unitful.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,5 +69,6 @@ include("logarithm.jl")
6969
include("complex.jl")
7070
include("pkgdefaults.jl")
7171
include("dates.jl")
72+
include("linearalgebra.jl")
7273

7374
end

src/linearalgebra.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
using LinearAlgebra
2+
3+
# This function is re-defined during testing, to check we hit the fast path:
4+
linearalgebra_count() = nothing
5+
6+
function LinearAlgebra.mul!(C::StridedVecOrMat{<:AbstractQuantity{T}},
7+
A::StridedMatrix{<:AbstractQuantity{T}},
8+
B::StridedVecOrMat{<:AbstractQuantity{T}},
9+
alpha::Bool, beta::Bool) where {T<:Base.HWNumber}
10+
# This is exactly how A * B creates C = similar(B, T, ...)
11+
eltype(C) == Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B)) || error("bad eltypes")
12+
C0 = ustrip(C)
13+
A0 = ustrip(A)
14+
B0 = ustrip(B)
15+
mul!(C0, A0, B0)
16+
linearalgebra_count()
17+
return C
18+
end
19+
20+
function LinearAlgebra.mul!(C::StridedVecOrMat{<:AbstractQuantity{T}},
21+
A::LinearAlgebra.AdjOrTransAbsMat{<:AbstractQuantity{T}, <:StridedMatrix},
22+
B::StridedVecOrMat{<:AbstractQuantity{T}},
23+
alpha::Bool, beta::Bool) where {T<:Base.HWNumber}
24+
25+
eltype(C) == Base.promote_op(LinearAlgebra.matprod, eltype(A), eltype(B)) || error("bad eltypes")
26+
C0 = ustrip(C)
27+
A0 = A isa Adjoint ? adjoint(ustrip(parent(A))) : transpose(ustrip(parent(A)))
28+
B0 = ustrip(B)
29+
mul!(C0, A0, B0)
30+
linearalgebra_count()
31+
return C
32+
end
33+
34+
function LinearAlgebra.dot(A::StridedArray{<:AbstractQuantity{T}},
35+
B::StridedArray{<:AbstractQuantity{T}}) where {T<:Base.HWNumber}
36+
A0 = ustrip(A)
37+
B0 = ustrip(B)
38+
C0 = dot(A0, B0)
39+
linearalgebra_count()
40+
C = C0 * oneunit(eltype(A)) * oneunit(eltype(B)) # surely there is an official way
41+
return C
42+
end

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ julia> a[1] = 3u"m"; b
7575
2
7676
```
7777
"""
78-
@inline ustrip(A::Array{Q}) where {Q <: Quantity} = reinterpret(numtype(Q), A)
78+
@inline ustrip(A::StridedArray{Q}) where {Q <: Quantity} = reinterpret(numtype(Q), A)
7979

8080
@deprecate(ustrip(A::AbstractArray{T}) where {T<:Number}, ustrip.(A))
8181

test/runtests.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,40 @@ end
9494
@test ConstructionBase.constructorof(typeof(1.0m))(2) === 2m
9595
end
9696

97+
@testset "LinearAlgebra functions" begin
98+
CNT = Ref(0)
99+
Unitful.linearalgebra_count() = (CNT[] += 1; nothing)
100+
@testset "> Matrix multiplication: *" begin
101+
M = rand(3,3) .* u"m"
102+
M_ = view(M,:,1:3)
103+
v = rand(3) .* u"V"
104+
v_ = view(v, 1:3)
105+
106+
CNT[] = 0
107+
108+
@test unit(first(M * M)) == u"m*m"
109+
@test M * M == M_ * M == M * M_ == M_ * M_
110+
111+
@test unit(first(M * v)) == u"m*V"
112+
@test M * v == M_ * v == M * v_ == M_ * v_
113+
114+
@test CNT[] == 10
115+
116+
@test unit(first(v' * M)) == u"m*V"
117+
@test v' * M == v_' * M == v_' * M == v_' * M_
118+
119+
@test CNT[] == 15
120+
121+
@test unit(v' * v) == u"V*V"
122+
@test v' * v == v_' * v == v_' * v == v_' * v_
123+
124+
@test CNT[] == 20
125+
end
126+
@testset "> Matrix multiplication: mul!" begin
127+
128+
end
129+
end
130+
97131
@testset "Types" begin
98132
@test Base.complex(Quantity{Float64,NoDims,NoUnits}) ==
99133
Quantity{Complex{Float64},NoDims,NoUnits}

0 commit comments

Comments
 (0)