Skip to content

Commit 067944a

Browse files
Merge pull request #337 from AayushSabharwal/as/debug-adjoint
build: add SciMLSensitivity tests to downstream CI
2 parents c486e62 + 4d6845f commit 067944a

File tree

6 files changed

+65
-15
lines changed

6 files changed

+65
-15
lines changed

.github/workflows/Downstream.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ jobs:
2525
- {user: SciML, repo: OrdinaryDiffEq.jl, group: Core}
2626
- {user: SciML, repo: OrdinaryDiffEq.jl, group: Interface}
2727
- {user: SciML, repo: DelayDiffEq.jl, group: Interface}
28+
- {user: SciML, repo: SciMLSensitivity.jl, group: Core1}
29+
- {user: SciML, repo: SciMLSensitivity.jl, group: Core2}
30+
- {user: SciML, repo: SciMLSensitivity.jl, group: Core3}
31+
- {user: SciML, repo: SciMLSensitivity.jl, group: Core4}
32+
- {user: SciML, repo: SciMLSensitivity.jl, group: Core5}
33+
- {user: SciML, repo: SciMLSensitivity.jl, group: Core6}
2834
steps:
2935
- uses: actions/checkout@v4
3036
- uses: julia-actions/setup-julia@v1

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@ Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
2323
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
2424
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2525
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
26+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2627

2728
[extensions]
2829
RecursiveArrayToolsFastBroadcastExt = "FastBroadcast"
2930
RecursiveArrayToolsMeasurementsExt = "Measurements"
3031
RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements"
32+
RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"]
3133
RecursiveArrayToolsTrackerExt = "Tracker"
3234
RecursiveArrayToolsZygoteExt = "Zygote"
3335

@@ -49,6 +51,7 @@ OrdinaryDiffEq = "6.62"
4951
Pkg = "1"
5052
Random = "1"
5153
RecipesBase = "1.1"
54+
ReverseDiff = "1.15"
5255
SafeTestsets = "0.1"
5356
SparseArrays = "1.10"
5457
StaticArrays = "1.6"
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
module RecursiveArrayToolsReverseDiffExt
2+
3+
using RecursiveArrayTools
4+
using ReverseDiff
5+
using Zygote: @adjoint
6+
7+
function trackedarraycopyto!(dest, src)
8+
for (i, slice) in zip(eachindex(dest.u), eachslice(src, dims=ndims(src)))
9+
if dest.u[i] isa AbstractArray
10+
dest.u[i] = reshape(reduce(vcat, slice), size(dest.u[i]))
11+
else
12+
trackedarraycopyto!(dest.u[i], slice)
13+
end
14+
end
15+
end
16+
17+
@adjoint function Array(VA::AbstractVectorOfArray{<:ReverseDiff.TrackedReal})
18+
function Array_adjoint(y)
19+
VA = recursivecopy(VA)
20+
trackedarraycopyto!(VA, y)
21+
return (VA,)
22+
end
23+
return Array(VA), Array_adjoint
24+
end
25+
end # module

ext/RecursiveArrayToolsZygoteExt.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -110,23 +110,29 @@ end
110110
@adjoint function Base.Array(VA::AbstractVectorOfArray)
111111
adj = let VA=VA
112112
function Array_adjoint(y)
113-
VA = copy(VA)
113+
VA = recursivecopy(VA)
114114
copyto!(VA, y)
115115
return (VA,)
116116
end
117117
end
118118
Array(VA), adj
119119
end
120120

121+
@adjoint function Base.view(A::AbstractVectorOfArray, I::Colon...)
122+
function adjoint(y)
123+
(recursivecopy(parent(y)), map(_ -> nothing, I)...)
124+
end
125+
return view(A, I...), adjoint
126+
end
127+
121128
@adjoint function Base.view(A::AbstractVectorOfArray, I...)
122-
adj = let A = A, I = I
123-
function view_adjoint(y)
124-
A = zero(A)
125-
view(A, I...) .= y
126-
return (A, map(_ -> nothing, I)...)
127-
end
129+
function view_adjoint(y)
130+
A = recursivecopy(parent(y))
131+
recursivefill!(A, zero(eltype(A)))
132+
A[I...] .= y
133+
return (A, map(_ -> nothing, I)...)
128134
end
129-
view(A, I...), adj
135+
view(A, I...), view_adjoint
130136
end
131137

132138
ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a)))

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ end
2828

2929
function recursivecopy(a::AbstractVectorOfArray)
3030
b = copy(a)
31-
b.u = recursivecopy.(a.u)
31+
b.u .= recursivecopy.(a.u)
3232
return b
3333
end
3434

src/vector_of_array.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -585,16 +585,26 @@ end
585585
function Base.checkbounds(VA::AbstractVectorOfArray, idx...)
586586
checkbounds(Bool, VA, idx...) || throw(BoundsError(VA, idx))
587587
end
588-
function Base.copyto!(dest::AbstractVectorOfArray{T,N}, src::AbstractVectorOfArray{T,N}) where {T,N}
589-
copyto!.(dest.u, src.u)
588+
function Base.copyto!(dest::AbstractVectorOfArray{T,N}, src::AbstractVectorOfArray{T2,N}) where {T, T2, N}
589+
for (i, j) in zip(eachindex(dest.u), eachindex(src.u))
590+
if ArrayInterface.ismutable(dest.u[i]) || dest.u[i] isa AbstractVectorOfArray
591+
copyto!(dest.u[i], src.u[j])
592+
else
593+
dest.u[i] = StaticArraysCore.similar_type(dest.u[i])(src.u[j])
594+
end
595+
end
590596
end
591-
function Base.copyto!(dest::AbstractVectorOfArray{T, N}, src::AbstractArray{T, N}) where {T, N}
592-
for (i, slice) in enumerate(eachslice(src, dims = ndims(src)))
593-
copyto!(dest.u[i], slice)
597+
function Base.copyto!(dest::AbstractVectorOfArray{T, N}, src::AbstractArray{T2, N}) where {T, T2, N}
598+
for (i, slice) in zip(eachindex(dest.u), eachslice(src, dims = ndims(src)))
599+
if ArrayInterface.ismutable(dest.u[i]) || dest.u[i] isa AbstractVectorOfArray
600+
copyto!(dest.u[i], slice)
601+
else
602+
dest.u[i] = StaticArraysCore.similar_type(dest.u[i])(slice)
603+
end
594604
end
595605
dest
596606
end
597-
function Base.copyto!(dest::AbstractVectorOfArray{T, N, <:AbstractVector{T}}, src::AbstractVector{T}) where {T, N}
607+
function Base.copyto!(dest::AbstractVectorOfArray{T, N, <:AbstractVector{T}}, src::AbstractVector{T2}) where {T, T2, N}
598608
copyto!(dest.u, src)
599609
dest
600610
end

0 commit comments

Comments
 (0)