Skip to content

Commit daff140

Browse files
Merge pull request #347 from AayushSabharwal/as/forwarddiff
fix: implement ForwarDiff.extract_derivative for AbstractVectorOfArray
2 parents c976b29 + 194e276 commit daff140

File tree

3 files changed

+17
-0
lines changed

3 files changed

+17
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1919

2020
[weakdeps]
2121
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
22+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2223
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
2324
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
2425
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
@@ -27,6 +28,7 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2728

2829
[extensions]
2930
RecursiveArrayToolsFastBroadcastExt = "FastBroadcast"
31+
RecursiveArrayToolsForwardDiffExt = "ForwardDiff"
3032
RecursiveArrayToolsMeasurementsExt = "Measurements"
3133
RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements"
3234
RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"]
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
module RecursiveArrayToolsForwardDiffExt
2+
3+
using RecursiveArrayTools
4+
using ForwardDiff
5+
6+
function ForwardDiff.extract_derivative(::Type{T}, y::AbstractVectorOfArray) where {T}
7+
ForwardDiff.extract_derivative.(T, y)
8+
end
9+
10+
end

test/adjoints.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ function loss8(x)
6262
return sum(abs2, res)
6363
end
6464

65+
function loss9(x)
66+
return VectorOfArray([collect(3i:3i+3) .* x for i in 1:5])
67+
end
68+
6569
x = float.(6:10)
6670
loss(x)
6771
@test Zygote.gradient(loss, x)[1] == ForwardDiff.gradient(loss, x)
@@ -72,3 +76,4 @@ loss(x)
7276
@test Zygote.gradient(loss6, x)[1] == ForwardDiff.gradient(loss6, x)
7377
@test Zygote.gradient(loss7, x)[1] == ForwardDiff.gradient(loss7, x)
7478
@test Zygote.gradient(loss8, x)[1] == ForwardDiff.gradient(loss8, x)
79+
@test ForwardDiff.derivative(loss9, 0.0) == VectorOfArray([collect(3i:3i+3) for i in 1:5])

0 commit comments

Comments
 (0)