Skip to content

Commit b631819

Browse files
Merge branch 'dg/nest' of github.com:DhairyaLGandhi/RecursiveArrayTools.jl into dg/nest
2 parents 1b3238f + bef99b0 commit b631819

File tree

3 files changed

+32
-2
lines changed

3 files changed

+32
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RecursiveArrayTools"
22
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "3.37.0"
4+
version = "3.37.1"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,30 @@ end
9999
end
100100
end
101101

102+
Zygote.@adjoint function Zygote.literal_getproperty(A::RecursiveArrayTools.AbstractVectorOfArray, ::Val{:u})
103+
function literal_AbstractVofA_u_adjoint(d)
104+
dA = vofa_u_adjoint(d, A)
105+
(dA, nothing)
106+
end
107+
A.u, literal_AbstractVofA_u_adjoint
108+
end
109+
110+
function vofa_u_adjoint(d, A::RecursiveArrayTools.AbstractVectorOfArray)
111+
m = map(enumerate(d)) do (idx, d_i)
112+
isnothing(d_i) && return zero(A.u[idx])
113+
d_i
114+
end
115+
VectorOfArray(m)
116+
end
117+
118+
function vofa_u_adjoint(d, A::RecursiveArrayTools.AbstractDiffEqArray)
119+
m = map(enumerate(d)) do (idx, d_i)
120+
isnothing(d_i) && return zero(A.u[idx])
121+
d_i
122+
end
123+
DiffEqArray(m, A.t)
124+
end
125+
102126
@adjoint function literal_getproperty(A::ArrayPartition, ::Val{:x})
103127
function literal_ArrayPartition_x_adjoint(d)
104128
(ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...),)

test/adjoints.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,12 @@ loss(x)
9393
@test Zygote.gradient(loss10, x)[1] == ForwardDiff.gradient(loss10, x)
9494
@test Zygote.gradient(loss11, x)[1] == ForwardDiff.gradient(loss11, x)
9595

96+
voa = RecursiveArrayTools.VectorOfArray(fill(rand(3), 3))
97+
voa_gs, = Zygote.gradient(voa) do x
98+
sum(sum.(x.u))
99+
end
100+
@test voa_gs isa RecursiveArrayTools.VectorOfArray
101+
96102
x = ArrayPartition(ArrayPartition(rand(3,4), rand(3,4)), rand(2))
97103
g = Zygote.gradient(norm, x)[1]
98-
@test g isa typeof(x)
104+
@test g isa typeof(x)

0 commit comments

Comments
 (0)