diff --git a/src/array_partition.jl b/src/array_partition.jl index 89261f4b..b9b436bc 100644 --- a/src/array_partition.jl +++ b/src/array_partition.jl @@ -336,9 +336,10 @@ function Broadcast.BroadcastStyle(::ArrayPartitionStyle{Style}, } ArrayPartitionStyle{Style}() end -function Broadcast.BroadcastStyle(::ArrayPartitionStyle, - ::Broadcast.DefaultArrayStyle{N}) where {N} - Broadcast.DefaultArrayStyle{N}() +function Broadcast.BroadcastStyle(::ArrayPartitionStyle{AStyle}, + ::Broadcast.DefaultArrayStyle{N}) where {AStyle, N} + pick = Broadcast.BroadcastStyle(AStyle(), Broadcast.DefaultArrayStyle{N}()) + ArrayPartitionStyle(pick, Val(N)) end combine_styles(::Type{Tuple{}}) = Broadcast.DefaultArrayStyle{0}() diff --git a/test/adjoints.jl b/test/adjoints.jl index a390c33a..bc537216 100644 --- a/test/adjoints.jl +++ b/test/adjoints.jl @@ -98,3 +98,7 @@ voa_gs, = Zygote.gradient(voa) do x sum(sum.(x.u)) end @test voa_gs isa RecursiveArrayTools.VectorOfArray + +x = ArrayPartition(ArrayPartition(rand(3,4), rand(3,4)), rand(2)) +g = Zygote.gradient(norm, x)[1] +@test g isa typeof(x) diff --git a/test/basic_indexing.jl b/test/basic_indexing.jl index 442e761f..f2e0dc59 100644 --- a/test/basic_indexing.jl +++ b/test/basic_indexing.jl @@ -280,3 +280,7 @@ x = VectorOfArray(StructArray{SVector{1, Float64}}(ntuple(_ -> [1.0, 2.0], 1))) y = 2 * x @. x = y @test all(all.(y .== x)) + + +x_ap = ArrayPartition(ArrayPartition(rand(3,4), rand(3,4)), rand(2)) +@test (x_ap .* 1.2) isa typeof(x_ap)