Skip to content

Commit 59a20b6

Browse files
authored
type stability of left division (#1475)
Closes #930
1 parent 47d02b7 commit 59a20b6

File tree

2 files changed

+39
-10
lines changed

2 files changed

+39
-10
lines changed

src/generic.jl

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1169,18 +1169,17 @@ end
11691169
inv(A::Adjoint) = adjoint(inv(parent(A)))
11701170
inv(A::Transpose) = transpose(inv(parent(A)))
11711171

1172-
pinv(v::AbstractVector{T}, tol::Real = real(zero(T))) where {T<:Real} = _vectorpinv(transpose, v, tol)
1173-
pinv(v::AbstractVector{T}, tol::Real = real(zero(T))) where {T<:Complex} = _vectorpinv(adjoint, v, tol)
1174-
pinv(v::AbstractVector{T}, tol::Real = real(zero(T))) where {T} = _vectorpinv(adjoint, v, tol)
1175-
function _vectorpinv(dualfn::Tf, v::AbstractVector{Tv}, tol) where {Tv,Tf}
1176-
res = dualfn(similar(v, typeof(zero(Tv) / (abs2(one(Tv)) + abs2(one(Tv))))))
1172+
_pinvadjoint(v::AbstractVector{T}) where {T<:Real} = transpose(v)
1173+
_pinvadjoint(v::AbstractVector) = adjoint(v)
1174+
function pinv(v::AbstractVector{T}, tol::Real = real(zero(T))) where {T}
1175+
res = _pinvadjoint(similar(v, typeof(zero(T) / (abs2(one(T)) + abs2(one(T))))))
11771176
den = sum(abs2, v)
11781177
# as tol is the threshold relative to the maximum singular value, for a vector with
11791178
# single singular value σ=√den, σ ≦ tol*σ is equivalent to den=0 ∨ tol≥1
11801179
if iszero(den) || tol >= one(tol)
11811180
fill!(res, zero(eltype(res)))
11821181
else
1183-
res .= dualfn(v) ./ den
1182+
res .= _pinvadjoint(v) ./ den
11841183
end
11851184
return res
11861185
end
@@ -1224,6 +1223,7 @@ true
12241223
function (\)(A::AbstractMatrix, B::AbstractVecOrMat)
12251224
require_one_based_indexing(A, B)
12261225
m, n = size(A)
1226+
T = promote_op(\, eltype(A), eltype(B))
12271227
if m == n
12281228
if istril(A)
12291229
if istriu(A)
@@ -1235,12 +1235,16 @@ function (\)(A::AbstractMatrix, B::AbstractVecOrMat)
12351235
if istriu(A)
12361236
return UpperTriangular(A) \ B
12371237
end
1238-
return lu(A) \ B
1238+
return lu(convert(AbstractArray{T}, A)) \ B
12391239
end
1240-
return qr(A, ColumnNorm()) \ B
1240+
return qr(convert(AbstractArray{T}, A), ColumnNorm()) \ B
12411241
end
12421242

1243-
(\)(a::AbstractVector, b::AbstractArray) = pinv(a) * b
1243+
function (\)(a::AbstractVector, b::AbstractArray)
1244+
den = sum(abs2, a)
1245+
goodden = den == 0 ? one(den) : den
1246+
return _pinvadjoint(a) * b / goodden
1247+
end
12441248
"""
12451249
A / B
12461250
@@ -1271,7 +1275,11 @@ function (/)(A::AbstractVecOrMat, B::AbstractVecOrMat)
12711275
end
12721276
# \(A::StridedMatrix,x::Number) = inv(A)*x Should be added at some point when the old elementwise version has been deprecated long enough
12731277
# /(x::Number,A::StridedMatrix) = x*inv(A)
1274-
/(x::Number, v::AbstractVector) = x*pinv(v)
1278+
function (/)(x::Number, v::AbstractVector)
1279+
den = sum(abs2, v)
1280+
goodden = den == 0 ? one(den) : den
1281+
return (x / goodden) * _pinvadjoint(v)
1282+
end
12751283

12761284
cond(x::Number) = iszero(x) ? Inf : 1.0
12771285
cond(x::Number, p) = cond(x)

test/generic.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -954,4 +954,25 @@ end
954954
@test Int[] Int[]
955955
end
956956

957+
@testset "issue 930" begin
958+
A = rand(Int, 2, 2)
959+
B = rand(Int, 2, 3)
960+
C = rand(Int, 2)
961+
for T (Float32, BigFloat)
962+
v = randn(T, 2)
963+
x = @inferred C \ v
964+
@test eltype(x) <: T
965+
x = @inferred zero(C) \ v
966+
@test eltype(x) <: T
967+
x = @inferred T(1) / C
968+
@test eltype(x) <: T
969+
x = @inferred T(1) / zero(C)
970+
@test eltype(x) <: T
971+
for M (A, B)
972+
x = @inferred M \ v
973+
@test eltype(x) <: T
974+
end
975+
end
976+
end
977+
957978
end # module TestGeneric

0 commit comments

Comments
 (0)