From 4a90b1da02544e725e91b9e9adc27f63f63e382e Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Tue, 12 Aug 2025 16:24:31 +0530 Subject: [PATCH 01/10] chore: add literal_getproperty adjoint for VOfA --- ext/RecursiveArrayToolsZygoteExt.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index e1987dd8..065b7f0d 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -99,6 +99,17 @@ end end end +Zygote.@adjoint function Zygote.literal_getproperty(A::RecursiveArrayTools.VectorOfArray, ::Val{:u}) + function literal_VectorOfArray_x_adjoint(d) + m = map(enumerate(d)) do (idx, d_i) + isnothing(d_i) && return zero(A.u[idx]) + d_i + end + (VectorOfArray(m),nothing) + end + A.u, literal_VectorOfArray_x_adjoint +end + @adjoint function literal_getproperty(A::ArrayPartition, ::Val{:x}) function literal_ArrayPartition_x_adjoint(d) (ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...),) From 93d2c481926ebaa4b893d57f0319e726ef8d311d Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Tue, 12 Aug 2025 16:35:19 +0530 Subject: [PATCH 02/10] test: check getproperty(, :u) returns VectorOfArray --- test/adjoints.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/adjoints.jl b/test/adjoints.jl index af2abd42..a390c33a 100644 --- a/test/adjoints.jl +++ b/test/adjoints.jl @@ -92,3 +92,9 @@ loss(x) VectorOfArray([collect((3i):(3i + 3)) for i in 1:5]) @test Zygote.gradient(loss10, x)[1] == ForwardDiff.gradient(loss10, x) @test Zygote.gradient(loss11, x)[1] == ForwardDiff.gradient(loss11, x) + +voa = RecursiveArrayTools.VectorOfArray(fill(rand(3), 3)) +voa_gs, = Zygote.gradient(voa) do x + sum(sum.(x.u)) +end +@test voa_gs isa RecursiveArrayTools.VectorOfArray From dbe9f423708a04f4633eed5011ce338888ad3587 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Wed, 13 Aug 2025 10:48:22 +0530 Subject: [PATCH 03/10] chore: formatting --- ext/RecursiveArrayToolsZygoteExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index 065b7f0d..448625e6 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -105,7 +105,7 @@ Zygote.@adjoint function Zygote.literal_getproperty(A::RecursiveArrayTools.Vecto isnothing(d_i) && return zero(A.u[idx]) d_i end - (VectorOfArray(m),nothing) + (VectorOfArray(m), nothing) end A.u, literal_VectorOfArray_x_adjoint end From b9442e15014b7549fcac3d4fe0582484985c5fe4 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 13 Aug 2025 22:21:22 +0530 Subject: [PATCH 04/10] Update ext/RecursiveArrayToolsZygoteExt.jl --- ext/RecursiveArrayToolsZygoteExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index 448625e6..184d5d70 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -99,7 +99,7 @@ end end end -Zygote.@adjoint function Zygote.literal_getproperty(A::RecursiveArrayTools.VectorOfArray, ::Val{:u}) +Zygote.@adjoint function Zygote.literal_getproperty(A::RecursiveArrayTools.AbstractVectorOfArray, ::Val{:u}) function literal_VectorOfArray_x_adjoint(d) m = map(enumerate(d)) do (idx, d_i) isnothing(d_i) && return zero(A.u[idx]) From 2e24855c00809f48b04574c03cd593ee663463c3 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Thu, 14 Aug 2025 10:38:46 +0530 Subject: [PATCH 05/10] feat: allow reverse pass over DiffEqArray as well --- ext/RecursiveArrayToolsZygoteExt.jl | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index 184d5d70..317c4844 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -100,16 +100,29 @@ end end Zygote.@adjoint function Zygote.literal_getproperty(A::RecursiveArrayTools.AbstractVectorOfArray, ::Val{:u}) - function literal_VectorOfArray_x_adjoint(d) - m = map(enumerate(d)) do (idx, d_i) - isnothing(d_i) && return zero(A.u[idx]) - d_i - end - (VectorOfArray(m), nothing) + function literal_AbstractVofA_u_adjoint(d) + dA = vofa_u_adjoint(d, A) + (dA, nothing) end A.u, literal_VectorOfArray_x_adjoint end +function vofa_u_adjoint(d, A::RecursiveArrayTools.VectorOfArray) + m = map(enumerate(d)) do (idx, d_i) + isnothing(d_i) && return zero(A.u[idx]) + d_i + end + VectorOfArray(m) +end + +function vofa_u_adjoint(d, A::RecursiveArrayTools.DiffEqArray) + m = map(enumerate(d)) do (idx, d_i) + isnothing(d_i) && return zero(A.u[idx]) + d_i + end + DiffEqArray(m, A.t) +end + @adjoint function literal_getproperty(A::ArrayPartition, ::Val{:x}) function literal_ArrayPartition_x_adjoint(d) (ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...),) From 67e87d163a793e6500e02331773aa3ac5e6fc107 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Thu, 14 Aug 2025 11:30:32 +0530 Subject: [PATCH 06/10] chore: typo --- ext/RecursiveArrayToolsZygoteExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index 317c4844..ed8377e2 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -104,7 +104,7 @@ Zygote.@adjoint function Zygote.literal_getproperty(A::RecursiveArrayTools.Abstr dA = vofa_u_adjoint(d, A) (dA, nothing) end - A.u, literal_VectorOfArray_x_adjoint + A.u, literal_VectorOfArray_u_adjoint end function vofa_u_adjoint(d, A::RecursiveArrayTools.VectorOfArray) From 2ed1893a79e6f2eb42d49bafa6161c94e07b46ec Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Thu, 14 Aug 2025 12:27:15 +0530 Subject: [PATCH 07/10] chore: typo --- ext/RecursiveArrayToolsZygoteExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index ed8377e2..b940cc00 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -104,7 +104,7 @@ Zygote.@adjoint function Zygote.literal_getproperty(A::RecursiveArrayTools.Abstr dA = vofa_u_adjoint(d, A) (dA, nothing) end - A.u, literal_VectorOfArray_u_adjoint + A.u, literal_AbstractVofA_u_adjoint end function vofa_u_adjoint(d, A::RecursiveArrayTools.VectorOfArray) From 956aaa5c9448c3ba383c9f32ff1a8d27832c3bfc Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Thu, 14 Aug 2025 14:51:18 +0530 Subject: [PATCH 08/10] chore: try to catch ODESolutions --- ext/RecursiveArrayToolsZygoteExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index b940cc00..0b4c97fb 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -115,7 +115,7 @@ function vofa_u_adjoint(d, A::RecursiveArrayTools.VectorOfArray) VectorOfArray(m) end -function vofa_u_adjoint(d, A::RecursiveArrayTools.DiffEqArray) +function vofa_u_adjoint(d, A::RecursiveArrayTools.AbstractDiffEqArray) m = map(enumerate(d)) do (idx, d_i) isnothing(d_i) && return zero(A.u[idx]) d_i From a8ad1d8345e4b9aebadd8515446a39b8b85806b1 Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Thu, 14 Aug 2025 16:00:39 +0530 Subject: [PATCH 09/10] chore: try to catch EnsembleSolutions --- ext/RecursiveArrayToolsZygoteExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index 0b4c97fb..9a900d2b 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -107,7 +107,7 @@ Zygote.@adjoint function Zygote.literal_getproperty(A::RecursiveArrayTools.Abstr A.u, literal_AbstractVofA_u_adjoint end -function vofa_u_adjoint(d, A::RecursiveArrayTools.VectorOfArray) +function vofa_u_adjoint(d, A::RecursiveArrayTools.AbstractVectorOfArray) m = map(enumerate(d)) do (idx, d_i) isnothing(d_i) && return zero(A.u[idx]) d_i From 43adec043e6672b59139336fd4c3556deaf3a7db Mon Sep 17 00:00:00 2001 From: DhairyaLGandhi Date: Thu, 14 Aug 2025 20:32:32 +0530 Subject: [PATCH 10/10] chore: format --- ext/RecursiveArrayToolsZygoteExt.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/RecursiveArrayToolsZygoteExt.jl b/ext/RecursiveArrayToolsZygoteExt.jl index 9a900d2b..b8ec612e 100644 --- a/ext/RecursiveArrayToolsZygoteExt.jl +++ b/ext/RecursiveArrayToolsZygoteExt.jl @@ -110,8 +110,8 @@ end function vofa_u_adjoint(d, A::RecursiveArrayTools.AbstractVectorOfArray) m = map(enumerate(d)) do (idx, d_i) isnothing(d_i) && return zero(A.u[idx]) - d_i - end + d_i + end VectorOfArray(m) end