Skip to content
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SparseArraysBase"
uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.5.11"
version = "0.6.0"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
Dictionaries = "0.4.4"
Documenter = "1.8.1"
Literate = "2.20.1"
SparseArraysBase = "0.5.0"
SparseArraysBase = "0.6.0"
2 changes: 1 addition & 1 deletion examples/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Dictionaries = "0.4.4"
SparseArraysBase = "0.5.0"
SparseArraysBase = "0.6.0"
Test = "<0.0.1, 1"
1 change: 1 addition & 0 deletions src/SparseArraysBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ export SparseArrayDOK,
include("abstractsparsearrayinterface.jl")
include("sparsearrayinterface.jl")
include("indexing.jl")
include("map.jl")
include("wrappers.jl")
include("abstractsparsearray.jl")
include("sparsearraydok.jl")
Expand Down
66 changes: 3 additions & 63 deletions src/abstractsparsearrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,39 +92,6 @@ end
# TODO: Define `default_similartype` or something like that?
return SparseArrayDOK{T}(undef, size)
end

# map over a specified subset of indices of the inputs.
function map_indices! end

@interface interface::AbstractArrayInterface function map_indices!(
indices, f, a_dest::AbstractArray, as::AbstractArray...
)
for I in indices
a_dest[I] = f(map(a -> a[I], as)...)
end
return a_dest
end

# Only map the stored values of the inputs.
function map_stored! end

@interface interface::AbstractArrayInterface function map_stored!(
f, a_dest::AbstractArray, as::AbstractArray...
)
@interface interface map_indices!(eachstoredindex(as...), f, a_dest, as...)
return a_dest
end

# Only map all values, not just the stored ones.
function map_all! end

@interface interface::AbstractArrayInterface function map_all!(
f, a_dest::AbstractArray, as::AbstractArray...
)
@interface interface map_indices!(eachindex(as...), f, a_dest, as...)
return a_dest
end

using DerivableInterfaces: DerivableInterfaces, zero!

# `zero!` isn't defined in `Base`, but it is defined in `ArrayLayouts`
Expand All @@ -137,37 +104,10 @@ using DerivableInterfaces: DerivableInterfaces, zero!
# More generally, this codepath could be taking if `zero(eltype(a))`
# is defined and the elements are immutable.
f = eltype(a) <: Number ? Returns(zero(eltype(a))) : zero!
return @interface interface map_stored!(f, a, a)
end

# Determines if a function preserves the stored values
# of the destination sparse array.
# The current code may be inefficient since it actually
# accesses an unstored element, which in the case of a
# sparse array of arrays can allocate an array.
# Sparse arrays could be expected to define a cheap
# unstored element allocator, for example
# `get_prototypical_unstored(a::AbstractArray)`.
function preserves_unstored(f, a_dest::AbstractArray, as::AbstractArray...)
I = first(eachindex(as...))
return iszero(f(map(a -> getunstoredindex(a, I), as)...))
end

@interface interface::AbstractSparseArrayInterface function Base.map!(
f, a_dest::AbstractArray, as::AbstractArray...
)
isempty(a_dest) && return a_dest # special case to avoid trying to access empty array
indices = if !preserves_unstored(f, a_dest, as...)
eachindex(a_dest)
elseif any(a -> a_dest !== a, as)
as = map(a -> Base.unalias(a_dest, a), as)
@interface interface zero!(a_dest)
eachstoredindex(as...)
else
eachstoredindex(a_dest)
@inbounds for I in eachstoredindex(a)
a[I] = f(a[I])
end
@interface interface map_indices!(indices, f, a_dest, as...)
return a_dest
return a
end

# `f::typeof(norm)`, `op::typeof(max)` used by `norm`.
Expand Down
23 changes: 20 additions & 3 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,27 @@ end
end
end

# required:
@interface ::AbstractSparseArrayInterface eachstoredindex(style::IndexStyle, A::AbstractArray) = throw(
MethodError(eachstoredindex, Tuple{typeof(style),typeof(A)})
@noinline function error_if_canonical_eachstoredindex(style::IndexStyle, A::AbstractArray)
style === IndexStyle(A) && throw(Base.CanonicalIndexError("eachstoredindex", typeof(A)))
return nothing
end

# required: one implementation for canonical index style
@interface ::AbstractSparseArrayInterface function eachstoredindex(
style::IndexStyle, A::AbstractArray
)
error_if_canonical_eachstoredindex(style, A)
inds = eachstoredindex(A)
if style === IndexCartesian()
eltype(inds) === CartesianIndex{ndims(A)} && return inds
return map(Base.Fix1(Base.getindex, CartesianIndices(A)), inds)
elseif style === IndexLinear()
eltype(inds) === Int && return inds
return map(Base.Fix1(Base.getindex, LinearIndices(A)), inds)
else
error(lazy"unkown index style $style")
end
end

# derived but may be specialized:
@interface ::AbstractSparseArrayInterface function eachstoredindex(
Expand Down
162 changes: 162 additions & 0 deletions src/map.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# zero-preserving Traits
# ----------------------
"""
abstract type ZeroPreserving <: Function end

Holy Trait to indicate how a function interacts with abstract zero values:

- `StrongPreserving` : output is guaranteed to be zero if **any** input is.
- `WeakPreserving` : output is guaranteed to be zero if **all** inputs are.
- `NonPreserving` : no guarantees on output.

To attempt to automatically determine this, either `ZeroPreserving(f, A::AbstractArray...)` or
`ZeroPreserving(f, T::Type...)` can be used/overloaded.

!!! warning
incorrectly registering a function to be zero-preserving will lead to silently wrong results.
"""
abstract type ZeroPreserving <: Function end

struct StrongPreserving{F} <: ZeroPreserving
f::F
end
struct WeakPreserving{F} <: ZeroPreserving
f::F
end
struct NonPreserving{F} <: ZeroPreserving
f::F
end

# Backport: remove in 1.12
@static if !isdefined(Base, :haszero)
_haszero(T::Type) = false
_haszero(::Type{<:Number}) = true
else
_haszero = Base.haszero
end

# warning: cannot automatically detect WeakPreserving since this would mean checking all values
function ZeroPreserving(f, A::AbstractArray, Bs::AbstractArray...)
return ZeroPreserving(f, eltype(A), eltype.(Bs)...)
end
# TODO: the following might not properly specialize on the types
# TODO: non-concrete element types
function ZeroPreserving(f, T::Type, Ts::Type...)
if all(_haszero, (T, Ts...))
return iszero(f(zero(T), zero.(Ts)...)) ? WeakPreserving(f) : NonPreserving(f)
else
return NonPreserving(f)
end
end

const _WEAK_FUNCTIONS = (:+, :-)
for f in _WEAK_FUNCTIONS
@eval begin
ZeroPreserving(::typeof($f), ::Type{<:Number}, ::Type{<:Number}...) = WeakPreserving($f)
end
end

const _STRONG_FUNCTIONS = (:*,)
for f in _STRONG_FUNCTIONS
@eval begin
ZeroPreserving(::typeof($f), ::Type{<:Number}, ::Type{<:Number}...) = StrongPreserving(
$f
)
end
end

# map(!)
# ------
@interface I::AbstractSparseArrayInterface function Base.map(
f, A::AbstractArray, Bs::AbstractArray...
)
f_pres = ZeroPreserving(f, A, Bs...)
return @interface I map(f_pres, A, Bs...)
end
@interface I::AbstractSparseArrayInterface function Base.map(
f::ZeroPreserving, A::AbstractArray, Bs::AbstractArray...
)
T = Base.Broadcast.combine_eltypes(f.f, (A, Bs...))
C = similar(I, T, size(A))
return @interface I map!(f, C, A, Bs...)
end

@interface I::AbstractSparseArrayInterface function Base.map!(
f, C::AbstractArray, A::AbstractArray, Bs::AbstractArray...
)
f_pres = ZeroPreserving(f, A, Bs...)
return @interface I map!(f_pres, C, A, Bs...)
end

@interface ::AbstractSparseArrayInterface function Base.map!(
f::ZeroPreserving, C::AbstractArray, A::AbstractArray, Bs::AbstractArray...
)
checkshape(C, A, Bs...)
unaliased = map(Base.Fix1(Base.unalias, C), (A, Bs...))

if f isa StrongPreserving
style = IndexStyle(C, unaliased...)
inds = intersect(eachstoredindex.(Ref(style), unaliased)...)
zero!(C)
elseif f isa WeakPreserving
style = IndexStyle(C, unaliased...)
inds = union(eachstoredindex.(Ref(style), unaliased)...)
zero!(C)
elseif f isa NonPreserving
inds = eachindex(C, unaliased...)
else
error(lazy"unknown zero-preserving type $(typeof(f))")
end

@inbounds for I in inds
C[I] = f.f(ith_all(I, unaliased)...)
end

return C
end

# Derived functions
# -----------------
@interface I::AbstractSparseArrayInterface Base.copyto!(C::AbstractArray, A::AbstractArray) = @interface I map!(
identity, C, A
)

# Only map the stored values of the inputs.
function map_stored! end

@interface interface::AbstractArrayInterface function map_stored!(
f, a_dest::AbstractArray, as::AbstractArray...
)
@interface interface map!(WeakPreserving(f), a_dest, as...)
return a_dest
end

# Only map all values, not just the stored ones.
function map_all! end

@interface interface::AbstractArrayInterface function map_all!(
f, a_dest::AbstractArray, as::AbstractArray...
)
@interface interface map!(NonPreserving(f), a_dest, as...)
return a_dest
end

# Utility functions
# -----------------
# shape check similar to checkbounds
checkshape(::Type{Bool}, A::AbstractArray) = true
checkshape(::Type{Bool}, A::AbstractArray, B::AbstractArray) = size(A) == size(B)
function checkshape(::Type{Bool}, A::AbstractArray, Bs::AbstractArray...)
return allequal(size, (A, Bs...))
end

function checkshape(A::AbstractArray, Bs::AbstractArray...)
return checkshape(Bool, A, Bs...) ||
throw(DimensionMismatch("argument shapes must match"))
end

@inline ith_all(i, ::Tuple{}) = ()
function ith_all(i, as)
@_propagate_inbounds_meta
return (as[1][i], ith_all(i, Base.tail(as))...)
end
2 changes: 1 addition & 1 deletion src/oneelementarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ storedindex(a::OneElementArray) = getfield(a, :index)
function isstored(a::OneElementArray, I::Int...)
return I == storedindex(a)
end
function eachstoredindex(a::OneElementArray)
function eachstoredindex(::IndexCartesian, a::OneElementArray)
return Fill(CartesianIndex(storedindex(a)), 1)
end

Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ JLArrays = "0.2.0"
LinearAlgebra = "<0.0.1, 1"
Random = "<0.0.1, 1"
SafeTestsets = "0.1.0"
SparseArraysBase = "0.5.0"
SparseArraysBase = "0.6.0"
StableRNGs = "1.0.2"
Suppressor = "0.2.8"
Test = "<0.0.1, 1"
Loading