Skip to content

Commit 7bdf24d

Browse files
Merge pull request #329 from AayushSabharwal/as/fixes
fix: VoA stack, ArrayPartition arithmetic type stability
2 parents 0e877e3 + e72b040 commit 7bdf24d

File tree

4 files changed

+34
-20
lines changed

4 files changed

+34
-20
lines changed

src/array_partition.jl

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,11 @@ for op in (:*, :/)
153153
end
154154

155155
function Base.:*(A::Number, B::ArrayPartition)
156-
ArrayPartition(map(y -> Base.broadcast(*, A, y), B.x))
156+
ArrayPartition(map(y -> A .* y, B.x))
157157
end
158158

159159
function Base.:\(A::Number, B::ArrayPartition)
160-
ArrayPartition(map(y -> Base.broadcast(/, y, A), B.x))
160+
B / A
161161
end
162162

163163
Base.:(==)(A::ArrayPartition, B::ArrayPartition) = A.x == B.x
@@ -284,7 +284,7 @@ recursive_eltype(A::ArrayPartition) = recursive_eltype(first(A.x))
284284
Base.iterate(A::ArrayPartition) = iterate(Chain(A.x))
285285
Base.iterate(A::ArrayPartition, state) = iterate(Chain(A.x), state)
286286

287-
Base.length(A::ArrayPartition) = sum((length(x) for x in A.x))
287+
Base.length(A::ArrayPartition) = sum(broadcast(length, A.x))
288288
Base.size(A::ArrayPartition) = (length(A),)
289289

290290
# redefine first and last to avoid slow and not type-stable indexing
@@ -323,21 +323,12 @@ function Broadcast.BroadcastStyle(::ArrayPartitionStyle,
323323
Broadcast.DefaultArrayStyle{N}()
324324
end
325325

326-
combine_styles(args::Tuple{}) = Broadcast.DefaultArrayStyle{0}()
327-
@inline function combine_styles(args::Tuple{Any})
328-
Broadcast.result_style(Broadcast.BroadcastStyle(args[1]))
329-
end
330-
@inline function combine_styles(args::Tuple{Any, Any})
331-
Broadcast.result_style(Broadcast.BroadcastStyle(args[1]),
332-
Broadcast.BroadcastStyle(args[2]))
333-
end
334-
@inline function combine_styles(args::Tuple)
335-
Broadcast.result_style(Broadcast.BroadcastStyle(args[1]),
336-
combine_styles(Base.tail(args)))
337-
end
326+
combine_styles(::Type{Tuple{}}) = Broadcast.DefaultArrayStyle{0}()
327+
combine_styles(::Type{T}) where {T} = Broadcast.result_style(Broadcast.BroadcastStyle(T.parameters[1]), combine_styles(Tuple{Base.tail((T.parameters...,))...}))
328+
338329

339330
function Broadcast.BroadcastStyle(::Type{ArrayPartition{T, S}}) where {T, S}
340-
Style = combine_styles((S.parameters...,))
331+
Style = combine_styles(S)
341332
ArrayPartitionStyle(Style)
342333
end
343334

src/vector_of_array.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ function Base.append!(VA::AbstractVectorOfArray{T, N},
499499
end
500500

501501
function Base.stack(VA::AbstractVectorOfArray; dims = :)
502-
stack(VA.u; dims)
502+
stack(stack.(VA.u); dims)
503503
end
504504

505505
# AbstractArray methods
@@ -633,7 +633,7 @@ function Base.convert(::Type{Array}, VA::AbstractVectorOfArray)
633633
if !allequal(size.(VA.u))
634634
error("Can only convert non-ragged VectorOfArray to Array")
635635
end
636-
return stack(VA.u)
636+
return stack(VA)
637637
end
638638

639639
# statistics

test/interface_tests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,9 @@ end
109109
@test stack(testva) == [1 4 7; 2 5 8; 3 6 9]
110110
@test stack(testva; dims = 1) == [1 2 3; 4 5 6; 7 8 9]
111111

112+
testva = VectorOfArray([VectorOfArray([ones(2,2), 2ones(2, 2)]), 3ones(2, 2, 2)])
113+
@test stack(testva) == [1.0 1.0; 1.0 1.0;;; 2.0 2.0; 2.0 2.0;;;; 3.0 3.0; 3.0 3.0;;; 3.0 3.0; 3.0 3.0]
114+
112115
# convert array from VectorOfArray/DiffEqArray
113116
t = 1:8
114117
recs = [rand(10, 7) for i in 1:8]

test/partitions_test.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ copyto!(p, c)
5858
## inference tests
5959

6060
x = ArrayPartition([1, 2], [3.0, 4.0])
61+
y = ArrayPartition(ArrayPartition([1], [2.0]), ArrayPartition([3], [4.0]))
6162
@test x[:, 1] == (1, 3.0)
6263

6364
# similar partitions
@@ -69,6 +70,13 @@ x = ArrayPartition([1, 2], [3.0, 4.0])
6970
@test (@inferred similar(x, Int, (2, 2))) isa AbstractMatrix{Int}
7071
# @inferred similar(x, Int, Float64)
7172

73+
@inferred similar(y)
74+
@test similar(y, (4,)) isa ArrayPartition{Float64}
75+
@test (@inferred similar(y, (2, 2))) isa AbstractMatrix{Float64}
76+
@inferred similar(y, Int)
77+
@test similar(y, Int, (4,)) isa ArrayPartition{Int}
78+
@test (@inferred similar(y, Int, (2, 2))) isa AbstractMatrix{Int}
79+
7280
# Copy
7381
@inferred copy(x)
7482
@inferred copy(ArrayPartition(x, x))
@@ -95,6 +103,17 @@ x = ArrayPartition([1, 2], [3.0, 4.0])
95103
@inferred x + x
96104
@inferred x - x
97105

106+
@inferred y + 5
107+
@inferred 5 + y
108+
@inferred y - 5
109+
@inferred 5 - y
110+
@inferred y * 5
111+
@inferred 5 * y
112+
@inferred y / 5
113+
@inferred 5 \ y
114+
@inferred y + y
115+
@inferred y - y
116+
98117
# indexing
99118
@inferred first(x)
100119
@inferred last(x)
@@ -118,6 +137,7 @@ _scalar_op(y) = y + 1
118137
_broadcast_wrapper(y) = _scalar_op.(y)
119138
# Issue #8
120139
@inferred _broadcast_wrapper(x)
140+
@test_broken @inferred _broadcast_wrapper(y)
121141

122142
# Testing map
123143
@test map(x -> x^2, x) == ArrayPartition(x.x[1] .^ 2, x.x[2] .^ 2)
@@ -151,11 +171,11 @@ function foo(y, x)
151171
end
152172
foo(xcde0, xce0)
153173
#@test 0 == @allocated foo(xcde0, xce0)
154-
function foo(y, x)
174+
function foo2(y, x)
155175
y .= y .+ 2 .* x
156176
nothing
157177
end
158-
foo(xcde0, xce0)
178+
foo2(xcde0, xce0)
159179
#@test 0 == @allocated foo(xcde0, xce0)
160180

161181
# Custom AbstractArray types broadcasting

0 commit comments

Comments
 (0)