Skip to content
This repository was archived by the owner on Aug 25, 2025. It is now read-only.

Commit 44fb589

Browse files
Add Zygote
1 parent deec17f commit 44fb589

File tree

2 files changed

+194
-5
lines changed

2 files changed

+194
-5
lines changed

ext/OptimizationZygoteExt.jl

Lines changed: 165 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import OptimizationBase.ADTypes: AutoZygote
66
isdefined(Base, :get_extension) ? (using Zygote, Zygote.ForwardDiff) :
77
(using ..Zygote, ..Zygote.ForwardDiff)
88

9-
function OptimizationBase.instantiate_function(f, x, adtype::AutoZygote, p,
9+
function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::AutoZygote, p,
1010
num_cons = 0)
1111
_f = (θ, args...) -> f(θ, p, args...)[1]
1212
if f.grad === nothing
@@ -83,7 +83,7 @@ function OptimizationBase.instantiate_function(f, x, adtype::AutoZygote, p,
8383
lag_h, f.lag_hess_prototype)
8484
end
8585

86-
function OptimizationBase.instantiate_function(f, cache::OptimizationBase.ReInitCache,
86+
function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
8787
adtype::AutoZygote, num_cons = 0)
8888
_f = (θ, args...) -> f(θ, cache.p, args...)[1]
8989
if f.grad === nothing
@@ -160,4 +160,167 @@ function OptimizationBase.instantiate_function(f, cache::OptimizationBase.ReInit
160160
lag_h, f.lag_hess_prototype)
161161
end
162162

163+
164+
function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, adtype::AutoZygote, p,
165+
num_cons = 0)
166+
_f = (θ, args...) -> f(θ, p, args...)[1]
167+
if f.grad === nothing
168+
grad = function (θ, args...)
169+
val = Zygote.gradient(x -> _f(x, args...), θ)[1]
170+
if val === nothing
171+
return zero(typeof(θ))
172+
else
173+
return val
174+
end
175+
end
176+
else
177+
grad = (θ, args...) -> f.grad(θ, p, args...)
178+
end
179+
180+
if f.hess === nothing
181+
hess = function (θ, args...)
182+
return ForwardDiff.jacobian(θ) do θ
183+
return Zygote.gradient(x -> _f(x, args...), θ)[1]
184+
end
185+
end
186+
else
187+
hess = (θ, args...) -> f.hess(θ, p, args...)
188+
end
189+
190+
if f.hv === nothing
191+
hv = function (H, θ, v, args...)
192+
_θ = ForwardDiff.Dual.(θ, v)
193+
res = grad(_θ, args...)
194+
return getindex.(ForwardDiff.partials.(res), 1)
195+
end
196+
else
197+
hv = f.hv
198+
end
199+
200+
if f.cons === nothing
201+
cons = nothing
202+
else
203+
cons = (θ) -> f.cons(θ, p)
204+
cons_oop = cons
205+
end
206+
207+
if cons !== nothing && f.cons_j === nothing
208+
cons_j = function (θ)
209+
if num_cons > 1
210+
return first(Zygote.jacobian(cons_oop, θ))
211+
else
212+
return first(Zygote.jacobian(cons_oop, θ))[1, :]
213+
end
214+
end
215+
else
216+
cons_j = (θ) -> f.cons_j(θ, p)
217+
end
218+
219+
if cons !== nothing && f.cons_h === nothing
220+
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
221+
cons_h = function (θ)
222+
return map(1:num_cons) do i
223+
Zygote.hessian(fncs[i], θ)
224+
end
225+
end
226+
else
227+
cons_h = (θ) -> f.cons_h(θ, p)
228+
end
229+
230+
if f.lag_h === nothing
231+
lag_h = nothing # Consider implementing this
232+
else
233+
lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p)
234+
end
235+
236+
return OptimizationFunction{false}(f.f, adtype; grad = grad, hess = hess, hv = hv,
237+
cons = cons, cons_j = cons_j, cons_h = cons_h,
238+
hess_prototype = f.hess_prototype,
239+
cons_jac_prototype = f.cons_jac_prototype,
240+
cons_hess_prototype = f.cons_hess_prototype,
241+
lag_h, f.lag_hess_prototype)
242+
end
243+
244+
function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache,
245+
adtype::AutoZygote, num_cons = 0)
246+
_f = (θ, args...) -> f(θ, cache.p, args...)[1]
247+
p = cache.p
248+
249+
if f.grad === nothing
250+
grad = function (θ, args...)
251+
val = Zygote.gradient(x -> _f(x, args...), θ)[1]
252+
if val === nothing
253+
return zero(typeof(θ))
254+
else
255+
return val
256+
end
257+
end
258+
else
259+
grad = (θ, args...) -> f.grad(θ, p, args...)
260+
end
261+
262+
if f.hess === nothing
263+
hess = function (θ, args...)
264+
return ForwardDiff.jacobian(θ) do θ
265+
Zygote.gradient(x -> _f(x, args...), θ)[1]
266+
end
267+
end
268+
else
269+
hess = (θ, args...) -> f.hess(θ, p, args...)
270+
end
271+
272+
if f.hv === nothing
273+
hv = function (H, θ, v, args...)
274+
_θ = ForwardDiff.Dual.(θ, v)
275+
res = grad(_θ, args...)
276+
return getindex.(ForwardDiff.partials.(res), 1)
277+
end
278+
else
279+
hv = f.hv
280+
end
281+
282+
if f.cons === nothing
283+
cons = nothing
284+
else
285+
cons = (θ) -> f.cons(θ, p)
286+
cons_oop = cons
287+
end
288+
289+
if cons !== nothing && f.cons_j === nothing
290+
cons_j = function (θ)
291+
if num_cons > 1
292+
return first(Zygote.jacobian(cons_oop, θ))
293+
else
294+
return first(Zygote.jacobian(cons_oop, θ))[1, :]
295+
end
296+
end
297+
else
298+
cons_j = (θ) -> f.cons_j(θ, p)
299+
end
300+
301+
if cons !== nothing && f.cons_h === nothing
302+
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
303+
cons_h = function (θ)
304+
return map(1:num_cons) do i
305+
Zygote.hessian(fncs[i], θ)
306+
end
307+
end
308+
else
309+
cons_h = (θ) -> f.cons_h(θ, p)
310+
end
311+
312+
if f.lag_h === nothing
313+
lag_h = nothing # Consider implementing this
314+
else
315+
lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p)
316+
end
317+
318+
return OptimizationFunction{false}(f.f, adtype; grad = grad, hess = hess, hv = hv,
319+
cons = cons, cons_j = cons_j, cons_h = cons_h,
320+
hess_prototype = f.hess_prototype,
321+
cons_jac_prototype = f.cons_jac_prototype,
322+
cons_hess_prototype = f.cons_hess_prototype,
323+
lag_h, f.lag_hess_prototype)
324+
end
325+
163326
end

test/adtests.jl

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,7 @@ optprob.hess(H2, x0)
645645
@test optprob.cons_j([5.0, 3.0]) == [10.0, 6.0]
646646

647647
@test optprob.cons_h(x0) == [[2.0 0.0; 0.0 2.0]]
648-
648+
649649
cons = (x, p) -> [x[1]^2 + x[2]^2, x[2] * sin(x[1]) - x[1]]
650650
optf = OptimizationFunction{false}(rosenbrock, OptimizationBase.AutoSparseReverseDiff(), cons = cons)
651651
optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoSparseReverseDiff(),
@@ -664,13 +664,39 @@ optprob.hess(H2, x0)
664664

665665
@test optprob.grad(x0) == G1
666666
@test optprob.hess(x0) == H1
667+
@test optprob.cons(x0) == [0.0]
668+
669+
@test optprob.cons_j([5.0, 3.0]) == [10.0, 6.0]
670+
671+
@test optprob.cons_h(x0) == [[2.0 0.0; 0.0 2.0]]
672+
673+
cons = (x, p) -> [x[1]^2 + x[2]^2, x[2] * sin(x[1]) - x[1]]
674+
optf = OptimizationFunction{false}(rosenbrock, OptimizationBase.AutoSparseReverseDiff(true), cons = cons)
675+
optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoSparseReverseDiff(true),
676+
nothing, 2)
677+
678+
@test optprob.grad(x0) == G1
679+
@test Array(optprob.hess(x0)) ≈ H1
667680
@test optprob.cons(x0) == [0.0, 0.0]
668681
@test optprob.cons_j([5.0, 3.0]) ≈ [10.0 6.0; -0.149013 -0.958924] rtol = 1e-6
669682
@test Array.(optprob.cons_h(x0)) ≈ [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]]
670683

684+
cons = (x, p) -> [x[1]^2 + x[2]^2]
685+
optf = OptimizationFunction{false}(rosenbrock, OptimizationBase.AutoZygote(), cons = cons)
686+
optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoZygote(),
687+
nothing, 1)
688+
689+
@test optprob.grad(x0) == G1
690+
@test optprob.hess(x0) == H1
691+
@test optprob.cons(x0) == [0.0]
692+
693+
@test optprob.cons_j([5.0, 3.0]) == [10.0, 6.0]
694+
695+
@test optprob.cons_h(x0) == [[2.0 0.0; 0.0 2.0]]
696+
671697
cons = (x, p) -> [x[1]^2 + x[2]^2, x[2] * sin(x[1]) - x[1]]
672-
optf = OptimizationFunction{false}(rosenbrock, OptimizationBase.AutoSparseReverseDiff(true), cons = cons)
673-
optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoSparseReverseDiff(true),
698+
optf = OptimizationFunction{false}(rosenbrock, OptimizationBase.AutoZygote(), cons = cons)
699+
optprob = OptimizationBase.instantiate_function(optf, x0, OptimizationBase.AutoZygote(),
674700
nothing, 2)
675701

676702
@test optprob.grad(x0) == G1

0 commit comments

Comments (0)