Skip to content

Commit 8c87d85

Browse files
Merge pull request #309 from AayushSabharwal/as/broadcast-from-array
feat: add ability to set VectorOfArray with Array using broadcast
2 parents dd5c756 + 300f692 commit 8c87d85

File tree

2 files changed

+41
-2
lines changed

2 files changed

+41
-2
lines changed

src/vector_of_array.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -663,10 +663,25 @@ end
663663
bc = Broadcast.flatten(bc)
664664
N = narrays(bc)
665665
@inbounds for i in 1:N
666-
if dest[:, i] isa AbstractArray && !isa(dest[:, i], StaticArraysCore.SArray)
666+
if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i])
667667
copyto!(dest[:, i], unpack_voa(bc, i))
668668
else
669-
dest[:, i] = copy(unpack_voa(bc, i))
669+
unpacked = unpack_voa(bc, i)
670+
dest[:, i] = unpacked.f(unpacked.args...)
671+
end
672+
end
673+
dest
674+
end
675+
676+
@inline function Base.copyto!(dest::AbstractVectorOfArray,
677+
bc::Broadcast.Broadcasted{<:Broadcast.DefaultArrayStyle})
678+
bc = Broadcast.flatten(bc)
679+
@inbounds for i in 1:length(dest.u)
680+
if dest[:, i] isa AbstractArray && ArrayInterface.ismutable(dest[:, i])
681+
copyto!(dest[:, i], unpack_voa(bc, i))
682+
else
683+
unpacked = unpack_voa(bc, i)
684+
dest[:, i] = unpacked.f(unpacked.args...)
670685
end
671686
end
672687
dest

test/interface_tests.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,27 @@ z = VectorOfArray([zeros(SVector{2, Float64}), zeros(SVector{2, Float64})])
125125
z .= x .+ y
126126

127127
@test z == VectorOfArray([fill(4, SVector{2, Float64}), fill(2, SVector{2, Float64})])
128+
129+
u1 = VectorOfArray([fill(2, SVector{2, Float64}), ones(SVector{2, Float64})])
130+
u2 = VectorOfArray([fill(4, SVector{2, Float64}), 2 .* ones(SVector{2, Float64})])
131+
u3 = VectorOfArray([fill(4, SVector{2, Float64}), 2 .* ones(SVector{2, Float64})])
132+
133+
function f(u1,u2,u3)
134+
u3 .= u1 .+ u2
135+
end
136+
f(u1,u2,u3)
137+
@test (@allocated f(u1,u2,u3)) == 0
138+
139+
yy = [2.0 1.0; 2.0 1.0]
140+
zz = x .+ yy
141+
@test zz == [4.0 2.0; 4.0 2.0]
142+
143+
z = VectorOfArray([zeros(SVector{2, Float64}), zeros(SVector{2, Float64})])
144+
z .= zz
145+
@test z == VectorOfArray([fill(4, SVector{2, Float64}), fill(2, SVector{2, Float64})])
146+
147+
function f!(z,zz)
148+
z .= zz
149+
end
150+
f!(z,zz)
151+
@test (@allocated f!(z,zz)) == 0

0 commit comments

Comments
 (0)