Skip to content

Commit 1dd8127

Browse files
authored
Sparse mapping refactor (#63)
1 parent 00cac74 commit 1dd8127

File tree

9 files changed

+191
-71
lines changed

9 files changed

+191
-71
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SparseArraysBase"
22
uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.5.11"
4+
version = "0.6.0"
55

66
[deps]
77
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
88
Dictionaries = "0.4.4"
99
Documenter = "1.8.1"
1010
Literate = "2.20.1"
11-
SparseArraysBase = "0.5.0"
11+
SparseArraysBase = "0.6.0"

examples/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
55

66
[compat]
77
Dictionaries = "0.4.4"
8-
SparseArraysBase = "0.5.0"
8+
SparseArraysBase = "0.6.0"
99
Test = "<0.0.1, 1"

src/SparseArraysBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ export SparseArrayDOK,
2020
include("abstractsparsearrayinterface.jl")
2121
include("sparsearrayinterface.jl")
2222
include("indexing.jl")
23+
include("map.jl")
2324
include("wrappers.jl")
2425
include("abstractsparsearray.jl")
2526
include("sparsearraydok.jl")

src/abstractsparsearrayinterface.jl

Lines changed: 3 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -92,39 +92,6 @@ end
9292
# TODO: Define `default_similartype` or something like that?
9393
return SparseArrayDOK{T}(undef, size)
9494
end
95-
96-
# map over a specified subset of indices of the inputs.
97-
function map_indices! end
98-
99-
@interface interface::AbstractArrayInterface function map_indices!(
100-
indices, f, a_dest::AbstractArray, as::AbstractArray...
101-
)
102-
for I in indices
103-
a_dest[I] = f(map(a -> a[I], as)...)
104-
end
105-
return a_dest
106-
end
107-
108-
# Only map the stored values of the inputs.
109-
function map_stored! end
110-
111-
@interface interface::AbstractArrayInterface function map_stored!(
112-
f, a_dest::AbstractArray, as::AbstractArray...
113-
)
114-
@interface interface map_indices!(eachstoredindex(as...), f, a_dest, as...)
115-
return a_dest
116-
end
117-
118-
# Only map all values, not just the stored ones.
119-
function map_all! end
120-
121-
@interface interface::AbstractArrayInterface function map_all!(
122-
f, a_dest::AbstractArray, as::AbstractArray...
123-
)
124-
@interface interface map_indices!(eachindex(as...), f, a_dest, as...)
125-
return a_dest
126-
end
127-
12895
using DerivableInterfaces: DerivableInterfaces, zero!
12996

13097
# `zero!` isn't defined in `Base`, but it is defined in `ArrayLayouts`
@@ -137,37 +104,10 @@ using DerivableInterfaces: DerivableInterfaces, zero!
137104
# More generally, this codepath could be taking if `zero(eltype(a))`
138105
# is defined and the elements are immutable.
139106
f = eltype(a) <: Number ? Returns(zero(eltype(a))) : zero!
140-
return @interface interface map_stored!(f, a, a)
141-
end
142-
143-
# Determines if a function preserves the stored values
144-
# of the destination sparse array.
145-
# The current code may be inefficient since it actually
146-
# accesses an unstored element, which in the case of a
147-
# sparse array of arrays can allocate an array.
148-
# Sparse arrays could be expected to define a cheap
149-
# unstored element allocator, for example
150-
# `get_prototypical_unstored(a::AbstractArray)`.
151-
function preserves_unstored(f, a_dest::AbstractArray, as::AbstractArray...)
152-
I = first(eachindex(as...))
153-
return iszero(f(map(a -> getunstoredindex(a, I), as)...))
154-
end
155-
156-
@interface interface::AbstractSparseArrayInterface function Base.map!(
157-
f, a_dest::AbstractArray, as::AbstractArray...
158-
)
159-
isempty(a_dest) && return a_dest # special case to avoid trying to access empty array
160-
indices = if !preserves_unstored(f, a_dest, as...)
161-
eachindex(a_dest)
162-
elseif any(a -> a_dest !== a, as)
163-
as = map(a -> Base.unalias(a_dest, a), as)
164-
@interface interface zero!(a_dest)
165-
eachstoredindex(as...)
166-
else
167-
eachstoredindex(a_dest)
107+
@inbounds for I in eachstoredindex(a)
108+
a[I] = f(a[I])
168109
end
169-
@interface interface map_indices!(indices, f, a_dest, as...)
170-
return a_dest
110+
return a
171111
end
172112

173113
# `f::typeof(norm)`, `op::typeof(max)` used by `norm`.

src/indexing.jl

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,10 +308,27 @@ end
308308
end
309309
end
310310

311-
# required:
312-
@interface ::AbstractSparseArrayInterface eachstoredindex(style::IndexStyle, A::AbstractArray) = throw(
313-
MethodError(eachstoredindex, Tuple{typeof(style),typeof(A)})
311+
@noinline function error_if_canonical_eachstoredindex(style::IndexStyle, A::AbstractArray)
312+
style === IndexStyle(A) && throw(Base.CanonicalIndexError("eachstoredindex", typeof(A)))
313+
return nothing
314+
end
315+
316+
# required: one implementation for canonical index style
317+
@interface ::AbstractSparseArrayInterface function eachstoredindex(
318+
style::IndexStyle, A::AbstractArray
314319
)
320+
error_if_canonical_eachstoredindex(style, A)
321+
inds = eachstoredindex(A)
322+
if style === IndexCartesian()
323+
eltype(inds) === CartesianIndex{ndims(A)} && return inds
324+
return map(Base.Fix1(Base.getindex, CartesianIndices(A)), inds)
325+
elseif style === IndexLinear()
326+
eltype(inds) === Int && return inds
327+
return map(Base.Fix1(Base.getindex, LinearIndices(A)), inds)
328+
else
329+
error(lazy"unkown index style $style")
330+
end
331+
end
315332

316333
# derived but may be specialized:
317334
@interface ::AbstractSparseArrayInterface function eachstoredindex(

src/map.jl

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# zero-preserving Traits
2+
# ----------------------
3+
"""
4+
abstract type ZeroPreserving <: Function end
5+
6+
Holy Trait to indicate how a function interacts with abstract zero values:
7+
8+
- `StrongPreserving` : output is guaranteed to be zero if **any** input is.
9+
- `WeakPreserving` : output is guaranteed to be zero if **all** inputs are.
10+
- `NonPreserving` : no guarantees on output.
11+
12+
To attempt to automatically determine this, either `ZeroPreserving(f, A::AbstractArray...)` or
13+
`ZeroPreserving(f, T::Type...)` can be used/overloaded.
14+
15+
!!! warning
16+
incorrectly registering a function to be zero-preserving will lead to silently wrong results.
17+
"""
18+
abstract type ZeroPreserving <: Function end
19+
20+
struct StrongPreserving{F} <: ZeroPreserving
21+
f::F
22+
end
23+
struct WeakPreserving{F} <: ZeroPreserving
24+
f::F
25+
end
26+
struct NonPreserving{F} <: ZeroPreserving
27+
f::F
28+
end
29+
30+
# Backport: remove in 1.12
31+
@static if !isdefined(Base, :haszero)
32+
_haszero(T::Type) = false
33+
_haszero(::Type{<:Number}) = true
34+
else
35+
_haszero = Base.haszero
36+
end
37+
38+
# warning: cannot automatically detect WeakPreserving since this would mean checking all values
39+
function ZeroPreserving(f, A::AbstractArray, Bs::AbstractArray...)
40+
return ZeroPreserving(f, eltype(A), eltype.(Bs)...)
41+
end
42+
# TODO: the following might not properly specialize on the types
43+
# TODO: non-concrete element types
44+
function ZeroPreserving(f, T::Type, Ts::Type...)
45+
if all(_haszero, (T, Ts...))
46+
return iszero(f(zero(T), zero.(Ts)...)) ? WeakPreserving(f) : NonPreserving(f)
47+
else
48+
return NonPreserving(f)
49+
end
50+
end
51+
52+
const _WEAK_FUNCTIONS = (:+, :-)
53+
for f in _WEAK_FUNCTIONS
54+
@eval begin
55+
ZeroPreserving(::typeof($f), ::Type{<:Number}, ::Type{<:Number}...) = WeakPreserving($f)
56+
end
57+
end
58+
59+
const _STRONG_FUNCTIONS = (:*,)
60+
for f in _STRONG_FUNCTIONS
61+
@eval begin
62+
ZeroPreserving(::typeof($f), ::Type{<:Number}, ::Type{<:Number}...) = StrongPreserving(
63+
$f
64+
)
65+
end
66+
end
67+
68+
# map(!)
69+
# ------
70+
@interface I::AbstractSparseArrayInterface function Base.map(
71+
f, A::AbstractArray, Bs::AbstractArray...
72+
)
73+
f_pres = ZeroPreserving(f, A, Bs...)
74+
return @interface I map(f_pres, A, Bs...)
75+
end
76+
@interface I::AbstractSparseArrayInterface function Base.map(
77+
f::ZeroPreserving, A::AbstractArray, Bs::AbstractArray...
78+
)
79+
T = Base.Broadcast.combine_eltypes(f.f, (A, Bs...))
80+
C = similar(I, T, size(A))
81+
return @interface I map!(f, C, A, Bs...)
82+
end
83+
84+
@interface I::AbstractSparseArrayInterface function Base.map!(
85+
f, C::AbstractArray, A::AbstractArray, Bs::AbstractArray...
86+
)
87+
f_pres = ZeroPreserving(f, A, Bs...)
88+
return @interface I map!(f_pres, C, A, Bs...)
89+
end
90+
91+
@interface ::AbstractSparseArrayInterface function Base.map!(
92+
f::ZeroPreserving, C::AbstractArray, A::AbstractArray, Bs::AbstractArray...
93+
)
94+
checkshape(C, A, Bs...)
95+
unaliased = map(Base.Fix1(Base.unalias, C), (A, Bs...))
96+
97+
if f isa StrongPreserving
98+
style = IndexStyle(C, unaliased...)
99+
inds = intersect(eachstoredindex.(Ref(style), unaliased)...)
100+
zero!(C)
101+
elseif f isa WeakPreserving
102+
style = IndexStyle(C, unaliased...)
103+
inds = union(eachstoredindex.(Ref(style), unaliased)...)
104+
zero!(C)
105+
elseif f isa NonPreserving
106+
inds = eachindex(C, unaliased...)
107+
else
108+
error(lazy"unknown zero-preserving type $(typeof(f))")
109+
end
110+
111+
@inbounds for I in inds
112+
C[I] = f.f(ith_all(I, unaliased)...)
113+
end
114+
115+
return C
116+
end
117+
118+
# Derived functions
119+
# -----------------
120+
@interface I::AbstractSparseArrayInterface Base.copyto!(C::AbstractArray, A::AbstractArray) = @interface I map!(
121+
identity, C, A
122+
)
123+
124+
# Only map the stored values of the inputs.
125+
function map_stored! end
126+
127+
@interface interface::AbstractArrayInterface function map_stored!(
128+
f, a_dest::AbstractArray, as::AbstractArray...
129+
)
130+
@interface interface map!(WeakPreserving(f), a_dest, as...)
131+
return a_dest
132+
end
133+
134+
# Only map all values, not just the stored ones.
135+
function map_all! end
136+
137+
@interface interface::AbstractArrayInterface function map_all!(
138+
f, a_dest::AbstractArray, as::AbstractArray...
139+
)
140+
@interface interface map!(NonPreserving(f), a_dest, as...)
141+
return a_dest
142+
end
143+
144+
# Utility functions
145+
# -----------------
146+
# shape check similar to checkbounds
147+
checkshape(::Type{Bool}, A::AbstractArray) = true
148+
checkshape(::Type{Bool}, A::AbstractArray, B::AbstractArray) = size(A) == size(B)
149+
function checkshape(::Type{Bool}, A::AbstractArray, Bs::AbstractArray...)
150+
return allequal(size, (A, Bs...))
151+
end
152+
153+
function checkshape(A::AbstractArray, Bs::AbstractArray...)
154+
return checkshape(Bool, A, Bs...) ||
155+
throw(DimensionMismatch("argument shapes must match"))
156+
end
157+
158+
@inline ith_all(i, ::Tuple{}) = ()
159+
function ith_all(i, as)
160+
@_propagate_inbounds_meta
161+
return (as[1][i], ith_all(i, Base.tail(as))...)
162+
end

src/oneelementarray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ storedindex(a::OneElementArray) = getfield(a, :index)
287287
function isstored(a::OneElementArray, I::Int...)
288288
return I == storedindex(a)
289289
end
290-
function eachstoredindex(a::OneElementArray)
290+
function eachstoredindex(::IndexCartesian, a::OneElementArray)
291291
return Fill(CartesianIndex(storedindex(a)), 1)
292292
end
293293

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ JLArrays = "0.2.0"
2121
LinearAlgebra = "<0.0.1, 1"
2222
Random = "<0.0.1, 1"
2323
SafeTestsets = "0.1.0"
24-
SparseArraysBase = "0.5.0"
24+
SparseArraysBase = "0.6.0"
2525
StableRNGs = "1.0.2"
2626
Suppressor = "0.2.8"
2727
Test = "<0.0.1, 1"

0 commit comments

Comments
 (0)