@@ -6,7 +6,7 @@ import OptimizationBase.ADTypes: AutoZygote
isdefined(Base, :get_extension) ? (using Zygote, Zygote.ForwardDiff) :
(using ..Zygote, ..Zygote.ForwardDiff)

-function OptimizationBase.instantiate_function(f, x, adtype::AutoZygote, p,
+function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x, adtype::AutoZygote, p,
        num_cons = 0)
    _f = (θ, args...) -> f(θ, p, args...)[1]
    if f.grad === nothing
@@ -83,7 +83,7 @@ function OptimizationBase.instantiate_function(f, x, adtype::AutoZygote, p,
        lag_h, f.lag_hess_prototype)
end

-function OptimizationBase.instantiate_function(f, cache::OptimizationBase.ReInitCache,
+function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, cache::OptimizationBase.ReInitCache,
        adtype::AutoZygote, num_cons = 0)
    _f = (θ, args...) -> f(θ, cache.p, args...)[1]
    if f.grad === nothing
@@ -160,4 +160,167 @@ function OptimizationBase.instantiate_function(f, cache::OptimizationBase.ReInitCache,
        lag_h, f.lag_hess_prototype)
end

+
+function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x, adtype::AutoZygote, p,
+        num_cons = 0)
+    _f = (θ, args...) -> f(θ, p, args...)[1]
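+    # Zygote.gradient returns `nothing` when θ does not affect the output;
+    # normalize that to an explicit zero of θ's type.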
+    if f.grad === nothing
+        grad = function (θ, args...)
+            val = Zygote.gradient(x -> _f(x, args...), θ)[1]
+            if val === nothing
+                return zero(typeof(θ))
+            else
+                return val
+            end
+        end
+    else
+        grad = (θ, args...) -> f.grad(θ, p, args...)
+    end
+
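+    # Forward-over-reverse Hessian: ForwardDiff.jacobian of the Zygote gradient.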
+    if f.hess === nothing
+        hess = function (θ, args...)
+            return ForwardDiff.jacobian(θ) do θ
+                return Zygote.gradient(x -> _f(x, args...), θ)[1]
+            end
+        end
+    else
+        hess = (θ, args...) -> f.hess(θ, p, args...)
+    end
+
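+    # Hessian-vector product without materializing H: seed Dual numbers with v,
+    # take the gradient, and read the product off the first partials.
+    # (The `H` argument is unused on this out-of-place path.)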
+    if f.hv === nothing
+        hv = function (H, θ, v, args...)
+            _θ = ForwardDiff.Dual.(θ, v)
+            res = grad(_θ, args...)
+            return getindex.(ForwardDiff.partials.(res), 1)
+        end
+    else
+        hv = f.hv
+    end
+
+    if f.cons === nothing
+        cons = nothing
+    else
+        cons = (θ) -> f.cons(θ, p)
+        cons_oop = cons
+    end
+
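+    # Zygote.jacobian returns a tuple of Jacobians (one per argument); take the
+    # first and, for a single constraint, flatten the 1×n row to a vector.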
+    if cons !== nothing && f.cons_j === nothing
+        cons_j = function (θ)
+            if num_cons > 1
+                return first(Zygote.jacobian(cons_oop, θ))
+            else
+                return first(Zygote.jacobian(cons_oop, θ))[1, :]
+            end
+        end
+    else
+        cons_j = (θ) -> f.cons_j(θ, p)
+    end
+
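+    # One dense Hessian per scalar constraint, via Zygote.hessian on each component.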
+    if cons !== nothing && f.cons_h === nothing
+        fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
+        cons_h = function (θ)
+            return map(1:num_cons) do i
+                Zygote.hessian(fncs[i], θ)
+            end
+        end
+    else
+        cons_h = (θ) -> f.cons_h(θ, p)
+    end
+
+    if f.lag_h === nothing
+        lag_h = nothing # Consider implementing this
+    else
+        lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p)
+    end
+
+    return OptimizationFunction{false}(f.f, adtype; grad = grad, hess = hess, hv = hv,
+        cons = cons, cons_j = cons_j, cons_h = cons_h,
+        hess_prototype = f.hess_prototype,
+        cons_jac_prototype = f.cons_jac_prototype,
+        cons_hess_prototype = f.cons_hess_prototype,
+        lag_h, f.lag_hess_prototype)
+end
+
+function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, cache::OptimizationBase.ReInitCache,
+        adtype::AutoZygote, num_cons = 0)
+    _f = (θ, args...) -> f(θ, cache.p, args...)[1]
+    p = cache.p
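+    # Same construction as the `{false}` method above, with the parameter
+    # object read from the re-init cache.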
+
+    if f.grad === nothing
+        grad = function (θ, args...)
+            val = Zygote.gradient(x -> _f(x, args...), θ)[1]
+            if val === nothing
+                return zero(typeof(θ))
+            else
+                return val
+            end
+        end
+    else
+        grad = (θ, args...) -> f.grad(θ, p, args...)
+    end
+
+    if f.hess === nothing
+        hess = function (θ, args...)
+            return ForwardDiff.jacobian(θ) do θ
+                Zygote.gradient(x -> _f(x, args...), θ)[1]
+            end
+        end
+    else
+        hess = (θ, args...) -> f.hess(θ, p, args...)
+    end
+
+    if f.hv === nothing
+        hv = function (H, θ, v, args...)
+            _θ = ForwardDiff.Dual.(θ, v)
+            res = grad(_θ, args...)
+            return getindex.(ForwardDiff.partials.(res), 1)
+        end
+    else
+        hv = f.hv
+    end
+
+    if f.cons === nothing
+        cons = nothing
+    else
+        cons = (θ) -> f.cons(θ, p)
+        cons_oop = cons
+    end
+
+    if cons !== nothing && f.cons_j === nothing
+        cons_j = function (θ)
+            if num_cons > 1
+                return first(Zygote.jacobian(cons_oop, θ))
+            else
+                return first(Zygote.jacobian(cons_oop, θ))[1, :]
+            end
+        end
+    else
+        cons_j = (θ) -> f.cons_j(θ, p)
+    end
+
+    if cons !== nothing && f.cons_h === nothing
+        fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
+        cons_h = function (θ)
+            return map(1:num_cons) do i
+                Zygote.hessian(fncs[i], θ)
+            end
+        end
+    else
+        cons_h = (θ) -> f.cons_h(θ, p)
+    end
+
+    if f.lag_h === nothing
+        lag_h = nothing # Consider implementing this
+    else
+        lag_h = (θ, σ, μ) -> f.lag_h(θ, σ, μ, p)
+    end
+
+    return OptimizationFunction{false}(f.f, adtype; grad = grad, hess = hess, hv = hv,
+        cons = cons, cons_j = cons_j, cons_h = cons_h,
+        hess_prototype = f.hess_prototype,
+        cons_jac_prototype = f.cons_jac_prototype,
+        cons_hess_prototype = f.cons_hess_prototype,
+        lag_h, f.lag_hess_prototype)
+end
+
end
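
Below is a minimal usage sketch of the new out-of-place path. The objective, values, and solver choice are illustrative and not part of this commit; it assumes Optimization.jl's `OptimizationFunction`/`OptimizationProblem`/`solve` API with Optim.jl's `BFGS`, and that Zygote is loaded so this extension activates:

```julia
using Optimization, OptimizationOptimJL, Zygote

# Hypothetical objective: Rosenbrock, with its constants carried in p.
rosenbrock(u, p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2

# iip = false selects the OptimizationFunction{false} methods added above;
# AutoZygote() routes derivative construction through this extension.
optf = OptimizationFunction{false}(rosenbrock, Optimization.AutoZygote())
prob = OptimizationProblem(optf, zeros(2), [1.0, 100.0])
sol = solve(prob, BFGS())  # gradient/Hessian supplied by instantiate_function
```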