Skip to content

Commit e72b040

Browse files
fix: ArrayPartition arithmetic type-stability
1 parent ef55922 commit e72b040

File tree

3 files changed

+30
-19
lines changed

3 files changed

+30
-19
lines changed

src/array_partition.jl

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,11 @@ for op in (:*, :/)
153153
end
154154

155155
function Base.:*(A::Number, B::ArrayPartition)
156-
ArrayPartition(map(y -> Base.broadcast(*, A, y), B.x))
156+
ArrayPartition(map(y -> A .* y, B.x))
157157
end
158158

159159
function Base.:\(A::Number, B::ArrayPartition)
160-
ArrayPartition(map(y -> Base.broadcast(/, y, A), B.x))
160+
B / A
161161
end
162162

163163
Base.:(==)(A::ArrayPartition, B::ArrayPartition) = A.x == B.x
@@ -284,7 +284,7 @@ recursive_eltype(A::ArrayPartition) = recursive_eltype(first(A.x))
284284
Base.iterate(A::ArrayPartition) = iterate(Chain(A.x))
285285
Base.iterate(A::ArrayPartition, state) = iterate(Chain(A.x), state)
286286

287-
Base.length(A::ArrayPartition) = sum((length(x) for x in A.x))
287+
Base.length(A::ArrayPartition) = sum(broadcast(length, A.x))
288288
Base.size(A::ArrayPartition) = (length(A),)
289289

290290
# redefine first and last to avoid slow and not type-stable indexing
@@ -323,21 +323,12 @@ function Broadcast.BroadcastStyle(::ArrayPartitionStyle,
323323
Broadcast.DefaultArrayStyle{N}()
324324
end
325325

326-
combine_styles(args::Tuple{}) = Broadcast.DefaultArrayStyle{0}()
327-
@inline function combine_styles(args::Tuple{Any})
328-
Broadcast.result_style(Broadcast.BroadcastStyle(args[1]))
329-
end
330-
@inline function combine_styles(args::Tuple{Any, Any})
331-
Broadcast.result_style(Broadcast.BroadcastStyle(args[1]),
332-
Broadcast.BroadcastStyle(args[2]))
333-
end
334-
@inline function combine_styles(args::Tuple)
335-
Broadcast.result_style(Broadcast.BroadcastStyle(args[1]),
336-
combine_styles(Base.tail(args)))
337-
end
326+
combine_styles(::Type{Tuple{}}) = Broadcast.DefaultArrayStyle{0}()
327+
combine_styles(::Type{T}) where {T} = Broadcast.result_style(Broadcast.BroadcastStyle(T.parameters[1]), combine_styles(Tuple{Base.tail((T.parameters...,))...}))
328+
338329

339330
function Broadcast.BroadcastStyle(::Type{ArrayPartition{T, S}}) where {T, S}
340-
Style = combine_styles((S.parameters...,))
331+
Style = combine_styles(S)
341332
ArrayPartitionStyle(Style)
342333
end
343334

test/interface_tests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ end
110110
@test stack(testva; dims = 1) == [1 2 3; 4 5 6; 7 8 9]
111111

112112
testva = VectorOfArray([VectorOfArray([ones(2,2), 2ones(2, 2)]), 3ones(2, 2, 2)])
113-
@test stack(testva) == [1.0 1.0; 1.0 1.0;;; 2.0 2.0; 2.0 2.0;;;; 3.0 3.0; 3.0;;; 3.0 3.0; 3.0 3.0]
113+
@test stack(testva) == [1.0 1.0; 1.0 1.0;;; 2.0 2.0; 2.0 2.0;;;; 3.0 3.0; 3.0 3.0;;; 3.0 3.0; 3.0 3.0]
114114

115115
# convert array from VectorOfArray/DiffEqArray
116116
t = 1:8

test/partitions_test.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ copyto!(p, c)
5858
## inference tests
5959

6060
x = ArrayPartition([1, 2], [3.0, 4.0])
61+
y = ArrayPartition(ArrayPartition([1], [2.0]), ArrayPartition([3], [4.0]))
6162
@test x[:, 1] == (1, 3.0)
6263

6364
# similar partitions
@@ -69,6 +70,13 @@ x = ArrayPartition([1, 2], [3.0, 4.0])
6970
@test (@inferred similar(x, Int, (2, 2))) isa AbstractMatrix{Int}
7071
# @inferred similar(x, Int, Float64)
7172

73+
@inferred similar(y)
74+
@test similar(y, (4,)) isa ArrayPartition{Float64}
75+
@test (@inferred similar(y, (2, 2))) isa AbstractMatrix{Float64}
76+
@inferred similar(y, Int)
77+
@test similar(y, Int, (4,)) isa ArrayPartition{Int}
78+
@test (@inferred similar(y, Int, (2, 2))) isa AbstractMatrix{Int}
79+
7280
# Copy
7381
@inferred copy(x)
7482
@inferred copy(ArrayPartition(x, x))
@@ -95,6 +103,17 @@ x = ArrayPartition([1, 2], [3.0, 4.0])
95103
@inferred x + x
96104
@inferred x - x
97105

106+
@inferred y + 5
107+
@inferred 5 + y
108+
@inferred y - 5
109+
@inferred 5 - y
110+
@inferred y * 5
111+
@inferred 5 * y
112+
@inferred y / 5
113+
@inferred 5 \ y
114+
@inferred y + y
115+
@inferred y - y
116+
98117
# indexing
99118
@inferred first(x)
100119
@inferred last(x)
@@ -118,6 +137,7 @@ _scalar_op(y) = y + 1
118137
_broadcast_wrapper(y) = _scalar_op.(y)
119138
# Issue #8
120139
@inferred _broadcast_wrapper(x)
140+
@test_broken @inferred _broadcast_wrapper(y)
121141

122142
# Testing map
123143
@test map(x -> x^2, x) == ArrayPartition(x.x[1] .^ 2, x.x[2] .^ 2)
@@ -151,11 +171,11 @@ function foo(y, x)
151171
end
152172
foo(xcde0, xce0)
153173
#@test 0 == @allocated foo(xcde0, xce0)
154-
function foo(y, x)
174+
function foo2(y, x)
155175
y .= y .+ 2 .* x
156176
nothing
157177
end
158-
foo(xcde0, xce0)
178+
foo2(xcde0, xce0)
159179
#@test 0 == @allocated foo(xcde0, xce0)
160180

161181
# Custom AbstractArray types broadcasting

0 commit comments

Comments
 (0)