Skip to content

Commit 4a90b1d

Browse files
chore: add literal_getproperty adjoint for VOfA
1 parent a7db825 commit 4a90b1d

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,17 @@ end
9999
end
100100
end
101101

102+
Zygote.@adjoint function Zygote.literal_getproperty(A::RecursiveArrayTools.VectorOfArray, ::Val{:u})
103+
function literal_VectorOfArray_x_adjoint(d)
104+
m = map(enumerate(d)) do (idx, d_i)
105+
isnothing(d_i) && return zero(A.u[idx])
106+
d_i
107+
end
108+
(VectorOfArray(m),nothing)
109+
end
110+
A.u, literal_VectorOfArray_x_adjoint
111+
end
112+
102113
@adjoint function literal_getproperty(A::ArrayPartition, ::Val{:x})
103114
function literal_ArrayPartition_x_adjoint(d)
104115
(ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...),)

0 commit comments

Comments
 (0)