Skip to content

Commit 29445ae

Browse files
authored
New design for handling unstored values (#65)
1 parent 1dd8127 commit 29445ae

File tree

10 files changed

+328
-250
lines changed

10 files changed

+328
-250
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SparseArraysBase"
22
uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.6.0"
4+
version = "0.7.0"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
88
Dictionaries = "0.4.4"
99
Documenter = "1.8.1"
1010
Literate = "2.20.1"
11-
SparseArraysBase = "0.6.0"
11+
SparseArraysBase = "0.7.0"

examples/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
55

66
[compat]
77
Dictionaries = "0.4.4"
8-
SparseArraysBase = "0.6.0"
8+
SparseArraysBase = "0.7.0"
99
Test = "<0.0.1, 1"

src/abstractsparsearray.jl

Lines changed: 102 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,54 @@ function DerivableInterfaces.interface(::Type{<:AbstractSparseArray})
1111
return SparseArrayInterface()
1212
end
1313

14+
function Base.copy(a::AnyAbstractSparseArray)
15+
return copyto!(similar(a), a)
16+
end
17+
18+
function similar_sparsearray(a::AnyAbstractSparseArray, unstored::Unstored)
19+
return SparseArrayDOK(unstored)
20+
end
21+
function similar_sparsearray(a::AnyAbstractSparseArray, T::Type, ax::Tuple)
22+
return similar_sparsearray(a, Unstored(unstoredsimilar(unstored(a), T, ax)))
23+
end
24+
function similar_sparsearray(a::AnyAbstractSparseArray, T::Type)
25+
return similar_sparsearray(a, Unstored(unstoredsimilar(unstored(a), T)))
26+
end
27+
function similar_sparsearray(a::AnyAbstractSparseArray, ax::Tuple)
28+
return similar_sparsearray(a, Unstored(unstoredsimilar(unstored(a), ax)))
29+
end
30+
function similar_sparsearray(a::AnyAbstractSparseArray)
31+
return similar_sparsearray(a, Unstored(unstored(a)))
32+
end
33+
34+
function Base.similar(a::AnyAbstractSparseArray, unstored::Unstored)
35+
return similar_sparsearray(a, unstored)
36+
end
37+
function Base.similar(a::AnyAbstractSparseArray)
38+
return similar_sparsearray(a)
39+
end
40+
function Base.similar(a::AnyAbstractSparseArray, T::Type)
41+
return similar_sparsearray(a, T)
42+
end
43+
function Base.similar(a::AnyAbstractSparseArray, ax::Tuple)
44+
return similar_sparsearray(a, ax)
45+
end
46+
function Base.similar(a::AnyAbstractSparseArray, T::Type, ax::Tuple)
47+
return similar_sparsearray(a, T, ax)
48+
end
49+
# Fix ambiguity error.
50+
function Base.similar(a::AnyAbstractSparseArray, T::Type, ax::Tuple{Int,Vararg{Int}})
51+
return similar_sparsearray(a, T, ax)
52+
end
53+
# Fix ambiguity error.
54+
function Base.similar(
55+
a::AnyAbstractSparseArray,
56+
T::Type,
57+
ax::Tuple{Union{Integer,Base.OneTo},Vararg{Union{Integer,Base.OneTo}}},
58+
)
59+
return similar_sparsearray(a, T, ax)
60+
end
61+
1462
using DerivableInterfaces: @derive
1563

1664
# TODO: These need to be loaded since `AbstractArrayOps`
@@ -20,12 +68,30 @@ using DerivableInterfaces: @derive
2068
using ArrayLayouts: ArrayLayouts
2169
using LinearAlgebra: LinearAlgebra
2270

23-
# DerivableInterfaces `Base.getindex`, `Base.setindex!`, etc.
24-
# TODO: Define `AbstractMatrixOps` and overload for
25-
# `AnyAbstractSparseMatrix` and `AnyAbstractSparseVector`,
26-
# which is where matrix multiplication and factorizations
27-
# should go.
28-
@derive AnyAbstractSparseArray AbstractArrayOps
71+
@derive (T=AnyAbstractSparseArray,) begin
72+
Base.getindex(::T, ::Any...)
73+
Base.getindex(::T, ::Int...)
74+
Base.setindex!(::T, ::Any, ::Any...)
75+
Base.setindex!(::T, ::Any, ::Int...)
76+
Base.copy!(::AbstractArray, ::T)
77+
Base.copyto!(::AbstractArray, ::T)
78+
Base.map(::Any, ::T...)
79+
Base.map!(::Any, ::AbstractArray, ::T...)
80+
Base.mapreduce(::Any, ::Any, ::T...; kwargs...)
81+
Base.reduce(::Any, ::T...; kwargs...)
82+
Base.all(::Function, ::T)
83+
Base.all(::T)
84+
Base.iszero(::T)
85+
Base.real(::T)
86+
Base.fill!(::T, ::Any)
87+
DerivableInterfaces.zero!(::T)
88+
Base.zero(::T)
89+
Base.permutedims!(::Any, ::T, ::Any)
90+
Broadcast.BroadcastStyle(::Type{<:T})
91+
Base.copyto!(::T, ::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}})
92+
ArrayLayouts.MemoryLayout(::Type{<:T})
93+
LinearAlgebra.mul!(::AbstractMatrix, ::T, ::T, ::Number, ::Number)
94+
end
2995

3096
using DerivableInterfaces.Concatenate: concatenate
3197
# We overload `Base._cat` instead of `Base.cat` since it
@@ -35,7 +101,12 @@ function Base._cat(dims, a::AnyAbstractSparseArray...)
35101
return concatenate(dims, a...)
36102
end
37103

104+
# TODO: Use `map(WeakPreserving(f), a)` instead.
105+
# Currently that has trouble with type unstable maps, since
106+
# the element type becomes abstract and therefore the zero/unstored
107+
# values are not well defined.
38108
function map_stored(f, a::AnyAbstractSparseArray)
109+
iszero(storedlength(a)) && return a
39110
kvs = storedpairs(a)
40111
# `collect` to convert to `Vector`, since otherwise
41112
# if it stays as `Dictionary` we might hit issues like
@@ -52,6 +123,10 @@ end
52123

53124
using Adapt: adapt
54125
function Base.print_array(io::IO, a::AnyAbstractSparseArray)
126+
# TODO: Use `map(WeakPreserving(adapt(Array)), a)` instead.
127+
# Currently that has trouble with type unstable maps, since
128+
# the element type becomes abstract and therefore the zero/unstored
129+
# values are not well defined.
55130
a′ = map_stored(adapt(Array), a)
56131
return @invoke Base.print_array(io::typeof(io), a′::AbstractArray{<:Any,ndims(a)})
57132
end
@@ -75,27 +150,30 @@ from the input indices.
75150
This constructor does not take ownership of the supplied storage, and will result in an
76151
independent container.
77152
"""
78-
sparse(::Union{AbstractDict,AbstractDictionary}, dims...; kwargs...)
153+
sparse(::Union{AbstractDict,AbstractDictionary}, dims...)
79154

80155
const AbstractDictOrDictionary = Union{AbstractDict,AbstractDictionary}
81156
# checked constructor from data: use `setindex!` to validate/convert input
82-
function sparse(storage::AbstractDictOrDictionary, dims::Dims; kwargs...)
83-
A = SparseArrayDOK{valtype(storage)}(undef, dims; kwargs...)
157+
function sparse(storage::AbstractDictOrDictionary, unstored::AbstractArray)
158+
A = SparseArrayDOK(Unstored(unstored))
84159
for (i, v) in pairs(storage)
85160
A[i] = v
86161
end
87162
return A
88163
end
89-
function sparse(storage::AbstractDictOrDictionary, dims::Int...; kwargs...)
90-
return sparse(storage, dims; kwargs...)
164+
function sparse(storage::AbstractDictOrDictionary, ax::Tuple)
165+
return sparse(storage, Zeros{valtype(storage)}(ax))
166+
end
167+
function sparse(storage::AbstractDictOrDictionary, dims::Int...)
168+
return sparse(storage, dims)
91169
end
92170
# Determine the size automatically.
93-
function sparse(storage::AbstractDictOrDictionary; kwargs...)
171+
function sparse(storage::AbstractDictOrDictionary)
94172
dims = ntuple(Returns(0), length(keytype(storage)))
95173
for I in keys(storage)
96174
dims = map(max, dims, Tuple(I))
97175
end
98-
return sparse(storage, dims; kwargs...)
176+
return sparse(storage, dims)
99177
end
100178

101179
using Random: Random, AbstractRNG, default_rng
@@ -107,12 +185,18 @@ Create an empty size `dims` sparse array.
107185
The optional `T` argument specifies the element type, which defaults to `Float64`.
108186
""" sparsezeros
109187

110-
function sparsezeros(::Type{T}, dims::Dims; kwargs...) where {T}
111-
return SparseArrayDOK{T}(undef, dims; kwargs...)
188+
function sparsezeros(::Type{T}, unstored::AbstractArray{<:Any,N}) where {T,N}
189+
return SparseArrayDOK{T,N}(Unstored(unstored))
190+
end
191+
function sparsezeros(unstored::AbstractArray{T,N}) where {T,N}
192+
return SparseArrayDOK{T,N}(Unstored(unstored))
193+
end
194+
function sparsezeros(::Type{T}, dims::Dims) where {T}
195+
return sparsezeros(T, Zeros{T}(dims))
112196
end
113-
sparsezeros(::Type{T}, dims::Int...; kwargs...) where {T} = sparsezeros(T, dims; kwargs...)
114-
sparsezeros(dims::Dims; kwargs...) = sparsezeros(Float64, dims; kwargs...)
115-
sparsezeros(dims::Int...; kwargs...) = sparsezeros(Float64, dims; kwargs...)
197+
sparsezeros(::Type{T}, dims::Int...) where {T} = sparsezeros(T, dims)
198+
sparsezeros(dims::Dims) = sparsezeros(Float64, dims)
199+
sparsezeros(dims::Int...) = sparsezeros(Float64, dims)
116200

117201
@doc """
118202
sparserand([rng], [T::Type], dims; density::Real=0.5, randfun::Function=rand) -> A::SparseArrayDOK{T}

src/abstractsparsearrayinterface.jl

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
using Base: @_propagate_inbounds_meta
22
using DerivableInterfaces:
33
DerivableInterfaces, @derive, @interface, AbstractArrayInterface, zero!
4+
using FillArrays: Zeros
45

6+
function unstored end
57
function eachstoredindex end
68
function getstoredindex end
79
function getunstoredindex end
@@ -12,9 +14,25 @@ function storedlength end
1214
function storedpairs end
1315
function storedvalues end
1416

15-
# Replace the function for accessing
16-
# unstored values.
17-
function set_getunstoredindex end
17+
# Indicates that the array should be interpreted
18+
# as the unstored values of a sparse array.
19+
struct Unstored{T,N,P<:AbstractArray{T,N}} <: AbstractArray{T,N}
20+
parent::P
21+
end
22+
Base.parent(a::Unstored) = a.parent
23+
24+
unstored(a::AbstractArray) = Zeros{eltype(a)}(axes(a))
25+
26+
function unstoredsimilar(a::AbstractArray, T::Type, ax::Tuple)
27+
return Zeros{T}(ax)
28+
end
29+
function unstoredsimilar(a::AbstractArray, ax::Tuple)
30+
return unstoredsimilar(a, eltype(a), ax)
31+
end
32+
function unstoredsimilar(a::AbstractArray, T::Type)
33+
return AbstractArray{T}(a)
34+
end
35+
unstoredsimilar(a::AbstractArray) = a
1836

1937
# Generic functionality for converting to a
2038
# dense array, trying to preserve information
@@ -84,14 +102,6 @@ Base.size(a::StoredValues) = size(a.storedindices)
84102
return setindex!(a.array, value, a.storedindices[I])
85103
end
86104

87-
# TODO: This may need to be defined in `sparsearraydok.jl`, after `SparseArrayDOK`
88-
# is defined. And/or define `default_type(::SparseArrayStyle, T::Type) = SparseArrayDOK{T}`.
89-
@interface ::AbstractSparseArrayInterface function Base.similar(
90-
a::AbstractArray, T::Type, size::Tuple{Vararg{Int}}
91-
)
92-
# TODO: Define `default_similartype` or something like that?
93-
return SparseArrayDOK{T}(undef, size)
94-
end
95105
using DerivableInterfaces: DerivableInterfaces, zero!
96106

97107
# `zero!` isn't defined in `Base`, but it is defined in `ArrayLayouts`
@@ -136,7 +146,10 @@ end
136146

137147
abstract type AbstractSparseArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end
138148

139-
@derive AbstractSparseArrayStyle AbstractArrayStyleOps
149+
@derive (T=AbstractSparseArrayStyle,) begin
150+
Base.similar(::Broadcast.Broadcasted{<:T}, ::Type, ::Tuple)
151+
Base.copyto!(::AbstractArray, ::Broadcast.Broadcasted{<:T})
152+
end
140153

141154
struct SparseArrayStyle{N} <: AbstractSparseArrayStyle{N} end
142155

src/map.jl

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -48,20 +48,20 @@ function ZeroPreserving(f, T::Type, Ts::Type...)
4848
return NonPreserving(f)
4949
end
5050
end
51+
ZeroPreserving(f::ZeroPreserving, T::Type, Ts::Type...) = f
5152

52-
const _WEAK_FUNCTIONS = (:+, :-)
53-
for f in _WEAK_FUNCTIONS
53+
for F in (:(typeof(+)), :(typeof(-)), :(typeof(identity)))
5454
@eval begin
55-
ZeroPreserving(::typeof($f), ::Type{<:Number}, ::Type{<:Number}...) = WeakPreserving($f)
55+
ZeroPreserving(f::$F, ::Type, ::Type...) = WeakPreserving(f)
5656
end
5757
end
5858

59-
const _STRONG_FUNCTIONS = (:*,)
60-
for f in _STRONG_FUNCTIONS
59+
using MapBroadcast: MapFunction
60+
for F in (:(typeof(*)), :(MapFunction{typeof(*)}))
6161
@eval begin
62-
ZeroPreserving(::typeof($f), ::Type{<:Number}, ::Type{<:Number}...) = StrongPreserving(
63-
$f
64-
)
62+
function ZeroPreserving(f::$F, ::Type, ::Type...)
63+
return StrongPreserving(f)
64+
end
6565
end
6666
end
6767

@@ -71,29 +71,35 @@ end
7171
f, A::AbstractArray, Bs::AbstractArray...
7272
)
7373
f_pres = ZeroPreserving(f, A, Bs...)
74-
return @interface I map(f_pres, A, Bs...)
74+
return map_sparsearray(f_pres, A, Bs...)
7575
end
76-
@interface I::AbstractSparseArrayInterface function Base.map(
77-
f::ZeroPreserving, A::AbstractArray, Bs::AbstractArray...
78-
)
76+
77+
# This isn't an overload of `Base.map` since that leads to ambiguity errors.
78+
function map_sparsearray(f::ZeroPreserving, A::AbstractArray, Bs::AbstractArray...)
7979
T = Base.Broadcast.combine_eltypes(f.f, (A, Bs...))
80-
C = similar(I, T, size(A))
81-
return @interface I map!(f, C, A, Bs...)
80+
C = similar(A, T)
81+
# TODO: Instead use:
82+
# ```julia
83+
# U = map(f.f, map(unstored, (A, Bs...))...)
84+
# C = similar(A, Unstored(U))
85+
# ```
86+
# though right now `map` doesn't preserve `Zeros` or `BlockZeros`.
87+
return map_sparsearray!(f, C, A, Bs...)
8288
end
8389

8490
@interface I::AbstractSparseArrayInterface function Base.map!(
8591
f, C::AbstractArray, A::AbstractArray, Bs::AbstractArray...
8692
)
8793
f_pres = ZeroPreserving(f, A, Bs...)
88-
return @interface I map!(f_pres, C, A, Bs...)
94+
return map_sparsearray!(f_pres, C, A, Bs...)
8995
end
9096

91-
@interface ::AbstractSparseArrayInterface function Base.map!(
97+
# This isn't an overload of `Base.map!` since that leads to ambiguity errors.
98+
function map_sparsearray!(
9299
f::ZeroPreserving, C::AbstractArray, A::AbstractArray, Bs::AbstractArray...
93100
)
94101
checkshape(C, A, Bs...)
95102
unaliased = map(Base.Fix1(Base.unalias, C), (A, Bs...))
96-
97103
if f isa StrongPreserving
98104
style = IndexStyle(C, unaliased...)
99105
inds = intersect(eachstoredindex.(Ref(style), unaliased)...)
@@ -107,19 +113,20 @@ end
107113
else
108114
error(lazy"unknown zero-preserving type $(typeof(f))")
109115
end
110-
111116
@inbounds for I in inds
112117
C[I] = f.f(ith_all(I, unaliased)...)
113118
end
114-
115119
return C
116120
end
117121

118122
# Derived functions
119123
# -----------------
120-
@interface I::AbstractSparseArrayInterface Base.copyto!(C::AbstractArray, A::AbstractArray) = @interface I map!(
121-
identity, C, A
124+
@interface I::AbstractSparseArrayInterface function Base.copyto!(
125+
dest::AbstractArray, src::AbstractArray
122126
)
127+
@interface I map!(identity, dest, src)
128+
return dest
129+
end
123130

124131
# Only map the stored values of the inputs.
125132
function map_stored! end

0 commit comments

Comments
 (0)