Skip to content
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SparseArraysBase"
uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.5.11"
version = "0.5.12"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
1 change: 1 addition & 0 deletions src/SparseArraysBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ export SparseArrayDOK,
include("abstractsparsearrayinterface.jl")
include("sparsearrayinterface.jl")
include("indexing.jl")
include("map.jl")
include("wrappers.jl")
include("abstractsparsearray.jl")
include("sparsearraydok.jl")
Expand Down
32 changes: 16 additions & 16 deletions src/abstractsparsearrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,22 +153,22 @@ function preserves_unstored(f, a_dest::AbstractArray, as::AbstractArray...)
return iszero(f(map(a -> getunstoredindex(a, I), as)...))
end

@interface interface::AbstractSparseArrayInterface function Base.map!(
f, a_dest::AbstractArray, as::AbstractArray...
)
isempty(a_dest) && return a_dest # special case to avoid trying to access empty array
indices = if !preserves_unstored(f, a_dest, as...)
eachindex(a_dest)
elseif any(a -> a_dest !== a, as)
as = map(a -> Base.unalias(a_dest, a), as)
@interface interface zero!(a_dest)
eachstoredindex(as...)
else
eachstoredindex(a_dest)
end
@interface interface map_indices!(indices, f, a_dest, as...)
return a_dest
end
# @interface interface::AbstractSparseArrayInterface function Base.map!(
# f, a_dest::AbstractArray, as::AbstractArray...
# )
# isempty(a_dest) && return a_dest # special case to avoid trying to access empty array
# indices = if !preserves_unstored(f, a_dest, as...)
# eachindex(a_dest)
# elseif any(a -> a_dest !== a, as)
# as = map(a -> Base.unalias(a_dest, a), as)
# @interface interface zero!(a_dest)
# eachstoredindex(as...)
# else
# eachstoredindex(a_dest)
# end
# @interface interface map_indices!(indices, f, a_dest, as...)
# return a_dest
# end

# `f::typeof(norm)`, `op::typeof(max)` used by `norm`.
function reduce_init(f, op, as...)
Expand Down
16 changes: 13 additions & 3 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,20 @@ end
end
end

# required:
@interface ::AbstractSparseArrayInterface eachstoredindex(style::IndexStyle, A::AbstractArray) = throw(
MethodError(eachstoredindex, Tuple{typeof(style),typeof(A)})
# required: one implementation for canonical index style
@interface ::AbstractSparseArrayInterface function eachstoredindex(
style::IndexStyle, A::AbstractArray
)
if style == IndexStyle(A)
throw(MethodError(eachstoredindex, Tuple{typeof(style),typeof(A)}))
elseif style == IndexCartesian()
return map(Base.Fix1(Base.getindex, CartesianIndices(A)), eachindex(A))
elseif style == IndexLinear()
return map(Base.Fix1(Base.getindex, LinearIndices(A)), eachindex(A))
else
throw(ArgumentError(lazy"unknown index style $style"))
end
end

# derived but may be specialized:
@interface ::AbstractSparseArrayInterface function eachstoredindex(
Expand Down
142 changes: 142 additions & 0 deletions src/map.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# zero-preserving Traits
# ----------------------
"""
abstract type ZeroPreserving <: Function end

Holy Trait to indicate how a function interacts with abstract zero values:

- `StrongPreserving` : output is guaranteed to be zero if **any** input is.
- `WeakPreserving` : output is guaranteed to be zero if **all** inputs are.
- `NonPreserving` : no guarantees on output.

To attempt to automatically determine this, either `ZeroPreserving(f, A::AbstractArray...)` or
`ZeroPreserving(f, T::Type...)` can be used/overloaded.

!!! warning
incorrectly registering a function to be zero-preserving will lead to silently wrong results.
"""
abstract type ZeroPreserving <: Function end

struct StrongPreserving{F} <: ZeroPreserving
f::F
end
struct WeakPreserving{F} <: ZeroPreserving
f::F
end
struct NonPreserving{F} <: ZeroPreserving
f::F
end

# Backport: remove in 1.12
@static if !isdefined(Base, :haszero)
_haszero(T::Type) = false
_haszero(::Type{<:Number}) = true
else
_haszero = Base.haszero
end

# warning: cannot automatically detect WeakPreserving since this would mean checking all values
function ZeroPreserving(f, A::AbstractArray, Bs::AbstractArray...)
return ZeroPreserving(f, eltype(A), eltype.(Bs)...)
end
# TODO: the following might not properly specialize on the types
# TODO: non-concrete element types
function ZeroPreserving(f, T::Type, Ts::Type...)
if all(_haszero, (T, Ts...))
return iszero(f(zero(T), zero.(Ts)...)) ? WeakPreserving(f) : NonPreserving(f)
else
return NonPreserving(f)
end
end

const _WEAK_FUNCTIONS = (:+, :-)
for f in _WEAK_FUNCTIONS
@eval begin
ZeroPreserving(::typeof($f), ::Type{<:Number}, ::Type{<:Number}...) = WeakPreserving($f)
end
end

const _STRONG_FUNCTIONS = (:*,)
for f in _STRONG_FUNCTIONS
@eval begin
ZeroPreserving(::typeof($f), ::Type{<:Number}, ::Type{<:Number}...) = StrongPreserving(
$f
)
end
end

# map(!)
# ------
@interface I::AbstractSparseArrayInterface function Base.map(
f, A::AbstractArray, Bs::AbstractArray...
)
f_pres = ZeroPreserving(f, A, Bs...)
return @interface I map(f_pres, A, Bs...)
end
@interface I::AbstractSparseArrayInterface function Base.map(
f::ZeroPreserving, A::AbstractArray, Bs::AbstractArray...
)
T = Base.Broadcast.combine_eltypes(f.f, (A, Bs...))
C = similar(I, T, size(A))
return @interface I map!(f, C, A, Bs...)
end

@interface I::AbstractSparseArrayInterface function Base.map!(
f, C::AbstractArray, A::AbstractArray, Bs::AbstractArray...
)
f_pres = ZeroPreserving(f, A, Bs...)
return @interface I map!(f_pres, C, A, Bs...)
end

@interface ::AbstractSparseArrayInterface function Base.map!(
f::ZeroPreserving, C::AbstractArray, A::AbstractArray, Bs::AbstractArray...
)
checkshape(C, A, Bs...)
unaliased = map(Base.Fix1(Base.unalias, C), (A, Bs...))

if f isa StrongPreserving
style = IndexStyle(C, unaliased...)
inds = intersect(eachstoredindex.(Ref(style), unaliased)...)
zero!(C)
elseif f isa WeakPreserving
style = IndexStyle(C, unaliased...)
inds = union(eachstoredindex.(Ref(style), unaliased)...)
zero!(C)
elseif f isa NonPreserving
inds = eachindex(C, unaliased...)
else
error(lazy"unknown zero-preserving type $(typeof(f))")
end

@inbounds for I in inds
C[I] = f.f(ith_all(I, unaliased)...)
end

return C
end

# Derived functions
# -----------------
@interface I::AbstractSparseArrayInterface Base.copyto!(C::AbstractArray, A::AbstractArray) = @interface I map!(
identity, C, A
)

# Utility functions
# -----------------
# shape check similar to checkbounds
checkshape(::Type{Bool}, A::AbstractArray) = true
checkshape(::Type{Bool}, A::AbstractArray, B::AbstractArray) = size(A) == size(B)
function checkshape(::Type{Bool}, A::AbstractArray, Bs::AbstractArray...)
return allequal(size, (A, Bs...))
end

function checkshape(A::AbstractArray, Bs::AbstractArray...)
return checkshape(Bool, A, Bs...) ||
throw(DimensionMismatch("argument shapes must match"))
end

@inline ith_all(i, ::Tuple{}) = ()
function ith_all(i, as)
@_propagate_inbounds_meta
return (as[1][i], ith_all(i, Base.tail(as))...)
end
2 changes: 1 addition & 1 deletion src/oneelementarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ storedindex(a::OneElementArray) = getfield(a, :index)
function isstored(a::OneElementArray, I::Int...)
return I == storedindex(a)
end
function eachstoredindex(a::OneElementArray)
function eachstoredindex(::IndexCartesian, a::OneElementArray)
return Fill(CartesianIndex(storedindex(a)), 1)
end

Expand Down
Loading