Skip to content

Commit 2e24855

Browse files
feat: allow reverse pass over DiffEqArray as well
1 parent b9442e1 commit 2e24855

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,16 +100,29 @@ end
100100
end
101101

102102
Zygote.@adjoint function Zygote.literal_getproperty(A::RecursiveArrayTools.AbstractVectorOfArray, ::Val{:u})
103-
function literal_VectorOfArray_x_adjoint(d)
104-
m = map(enumerate(d)) do (idx, d_i)
105-
isnothing(d_i) && return zero(A.u[idx])
106-
d_i
107-
end
108-
(VectorOfArray(m), nothing)
103+
function literal_AbstractVofA_u_adjoint(d)
104+
dA = vofa_u_adjoint(d, A)
105+
(dA, nothing)
109106
end
110107
A.u, literal_VectorOfArray_x_adjoint
111108
end
112109

110+
function vofa_u_adjoint(d, A::RecursiveArrayTools.VectorOfArray)
111+
m = map(enumerate(d)) do (idx, d_i)
112+
isnothing(d_i) && return zero(A.u[idx])
113+
d_i
114+
end
115+
VectorOfArray(m)
116+
end
117+
118+
function vofa_u_adjoint(d, A::RecursiveArrayTools.DiffEqArray)
119+
m = map(enumerate(d)) do (idx, d_i)
120+
isnothing(d_i) && return zero(A.u[idx])
121+
d_i
122+
end
123+
DiffEqArray(m, A.t)
124+
end
125+
113126
@adjoint function literal_getproperty(A::ArrayPartition, ::Val{:x})
114127
function literal_ArrayPartition_x_adjoint(d)
115128
(ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...),)

0 commit comments

Comments
 (0)