Skip to content

Commit ba5e519

Browse files
Merge pull request #103 from SciML/optjlintegration
Single closure definition to avoid overwriting in manual supplied derivatives cases
2 parents b2cbc1a + f78872b commit ba5e519

File tree

5 files changed

+38
-128
lines changed

5 files changed

+38
-128
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "OptimizationBase"
22
uuid = "bca83a33-5cc9-4baa-983d-23429ab6bcbb"
33
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
4-
version = "2.0.2"
4+
version = "2.0.3"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

ext/OptimizationEnzymeExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
563563
return res_vjp
564564
end
565565
elseif cons_vjp == true && cons !== nothing
566-
cons_vjp! = (θ, σ) -> f.cons_vjp(θ, σ, p)
566+
cons_vjp! = (θ, v) -> f.cons_vjp(θ, v, p)
567567
else
568568
cons_vjp! = nothing
569569
end

ext/OptimizationZygoteExt.jl

Lines changed: 12 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,7 @@ function OptimizationBase.instantiate_function(
4343
end
4444
end
4545
elseif g == true
46-
grad = (G, θ) -> f.grad(G, θ, p)
47-
if p !== SciMLBase.NullParameters() && p !== nothing
48-
grad = (G, θ, p) -> f.grad(G, θ, p)
49-
end
46+
grad = (G, θ, p = p) -> f.grad(G, θ, p)
5047
else
5148
grad = nothing
5249
end
@@ -67,10 +64,7 @@ function OptimizationBase.instantiate_function(
6764
end
6865
end
6966
elseif fg == true
70-
fg! = (G, θ) -> f.fg(G, θ, p)
71-
if p !== SciMLBase.NullParameters() && p !== nothing
72-
fg! = (G, θ, p) -> f.fg(G, θ, p)
73-
end
67+
fg! = (G, θ, p = p) -> f.fg(G, θ, p)
7468
else
7569
fg! = nothing
7670
end
@@ -89,10 +83,7 @@ function OptimizationBase.instantiate_function(
8983
end
9084
end
9185
elseif h == true
92-
hess = (H, θ) -> f.hess(H, θ, p)
93-
if p !== SciMLBase.NullParameters() && p !== nothing
94-
hess = (H, θ, p) -> f.hess(H, θ, p)
95-
end
86+
hess = (H, θ, p = p) -> f.hess(H, θ, p)
9687
else
9788
hess = nothing
9889
end
@@ -110,10 +101,7 @@ function OptimizationBase.instantiate_function(
110101
end
111102
end
112103
elseif fgh == true
113-
fgh! = (G, H, θ) -> f.fgh(G, H, θ, p)
114-
if p !== SciMLBase.NullParameters() && p !== nothing
115-
fgh! = (G, H, θ, p) -> f.fgh(G, H, θ, p)
116-
end
104+
fgh! = (G, H, θ, p = p) -> f.fgh(G, H, θ, p)
117105
else
118106
fgh! = nothing
119107
end
@@ -130,10 +118,7 @@ function OptimizationBase.instantiate_function(
130118
end
131119
end
132120
elseif hv == true
133-
hv! = (H, θ, v) -> f.hv(H, θ, v, p)
134-
if p !== SciMLBase.NullParameters() && p !== nothing
135-
hv! = (H, θ, v, p) -> f.hv(H, θ, v, p)
136-
end
121+
hv! = (H, θ, v, p = p) -> f.hv(H, θ, v, p)
137122
else
138123
hv! = nothing
139124
end
@@ -268,7 +253,7 @@ function OptimizationBase.instantiate_function(
268253
end
269254
end
270255
elseif cons !== nothing && lag_h == true
271-
lag_h! = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p)
256+
lag_h! = (res, θ, σ, μ, p = p) -> f.lag_h(res, θ, σ, μ, p)
272257
else
273258
lag_h! = nothing
274259
end
@@ -324,10 +309,7 @@ function OptimizationBase.instantiate_function(
324309
end
325310
end
326311
elseif g == true
327-
grad = (G, θ) -> f.grad(G, θ, p)
328-
if p !== SciMLBase.NullParameters() && p !== nothing
329-
grad = (G, θ, p) -> f.grad(G, θ, p)
330-
end
312+
grad = (G, θ, p = p) -> f.grad(G, θ, p)
331313
else
332314
grad = nothing
333315
end
@@ -348,10 +330,7 @@ function OptimizationBase.instantiate_function(
348330
end
349331
end
350332
elseif fg == true
351-
fg! = (G, θ) -> f.fg(G, θ, p)
352-
if p !== SciMLBase.NullParameters() && p !== nothing
353-
fg! = (G, θ, p) -> f.fg(G, θ, p)
354-
end
333+
fg! = (G, θ, p = p) -> f.fg(G, θ, p)
355334
else
356335
fg! = nothing
357336
end
@@ -373,10 +352,7 @@ function OptimizationBase.instantiate_function(
373352
end
374353
end
375354
elseif h == true
376-
hess = (H, θ) -> f.hess(H, θ, p)
377-
if p !== SciMLBase.NullParameters() && p !== nothing
378-
hess = (H, θ, p) -> f.hess(H, θ, p)
379-
end
355+
hess = (H, θ, p = p) -> f.hess(H, θ, p)
380356
else
381357
hess = nothing
382358
end
@@ -395,10 +371,7 @@ function OptimizationBase.instantiate_function(
395371
end
396372
end
397373
elseif fgh == true
398-
fgh! = (G, H, θ) -> f.fgh(G, H, θ, p)
399-
if p !== SciMLBase.NullParameters() && p !== nothing
400-
fgh! = (G, H, θ, p) -> f.fgh(G, H, θ, p)
401-
end
374+
fgh!(G, H, θ, p = p) = f.fgh(G, H, θ, p)
402375
else
403376
fgh! = nothing
404377
end
@@ -415,10 +388,7 @@ function OptimizationBase.instantiate_function(
415388
end
416389
end
417390
elseif hv == true
418-
hv! = (H, θ, v) -> f.hv(H, θ, v, p)
419-
if p !== SciMLBase.NullParameters() && p !== nothing
420-
hv! = (H, θ, v, p) -> f.hv(H, θ, v, p)
421-
end
391+
hv! = (H, θ, v, p = p) -> f.hv(H, θ, v, p)
422392
else
423393
hv! = nothing
424394
end
@@ -564,7 +534,7 @@ function OptimizationBase.instantiate_function(
564534
end
565535
end
566536
elseif cons !== nothing && cons_h == true
567-
lag_h! = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p)
537+
lag_h! = (res, θ, σ, μ, p = p) -> f.lag_h(res, θ, σ, μ, p)
568538
else
569539
lag_h! = nothing
570540
end

src/OptimizationDIExt.jl

Lines changed: 12 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,7 @@ function instantiate_function(
5050
end
5151
end
5252
elseif g == true
53-
grad = (G, θ) -> f.grad(G, θ, p)
54-
if p !== SciMLBase.NullParameters() && p !== nothing
55-
grad = (G, θ, p) -> f.grad(G, θ, p)
56-
end
53+
grad = (G, θ, p = p) -> f.grad(G, θ, p)
5754
else
5855
grad = nothing
5956
end
@@ -74,10 +71,7 @@ function instantiate_function(
7471
end
7572
end
7673
elseif fg == true
77-
fg! = (G, θ) -> f.fg(G, θ, p)
78-
if p !== SciMLBase.NullParameters()
79-
fg! = (G, θ, p) -> f.fg(G, θ, p)
80-
end
74+
fg! = (G, θ, p = p) -> f.fg(G, θ, p)
8175
else
8276
fg! = nothing
8377
end
@@ -96,10 +90,7 @@ function instantiate_function(
9690
end
9791
end
9892
elseif h == true
99-
hess = (H, θ) -> f.hess(H, θ, p)
100-
if p !== SciMLBase.NullParameters() && p !== nothing
101-
hess = (H, θ, p) -> f.hess(H, θ, p)
102-
end
93+
hess = (H, θ, p = p) -> f.hess(H, θ, p)
10394
else
10495
hess = nothing
10596
end
@@ -119,10 +110,7 @@ function instantiate_function(
119110
end
120111
end
121112
elseif fgh == true
122-
fgh! = (G, H, θ) -> f.fgh(G, H, θ, p)
123-
if p !== SciMLBase.NullParameters() && p !== nothing
124-
fgh! = (G, H, θ, p) -> f.fgh(G, H, θ, p)
125-
end
113+
fgh! = (G, H, θ, p = p) -> f.fgh(G, H, θ, p)
126114
else
127115
fgh! = nothing
128116
end
@@ -139,10 +127,7 @@ function instantiate_function(
139127
end
140128
end
141129
elseif hv == true
142-
hv! = (H, θ, v) -> f.hv(H, θ, v, p)
143-
if p !== SciMLBase.NullParameters() && p !== nothing
144-
hv! = (H, θ, v, p) -> f.hv(H, θ, v, p)
145-
end
130+
hv! = (H, θ, v, p = p) -> f.hv(H, θ, v, p)
146131
else
147132
hv! = nothing
148133
end
@@ -277,7 +262,7 @@ function instantiate_function(
277262
end
278263
end
279264
elseif lag_h == true && cons !== nothing
280-
lag_h! = (res, θ, σ, μ) -> f.lag_h(res, θ, σ, μ, p)
265+
lag_h! = (res, θ, σ, μ, p = p) -> f.lag_h(res, θ, σ, μ, p)
281266
else
282267
lag_h! = nothing
283268
end
@@ -334,10 +319,7 @@ function instantiate_function(
334319
end
335320
end
336321
elseif g == true
337-
grad = (θ) -> f.grad(θ, p)
338-
if p !== SciMLBase.NullParameters() && p !== nothing
339-
grad = (θ, p) -> f.grad(θ, p)
340-
end
322+
grad = (θ, p = p) -> f.grad(θ, p)
341323
else
342324
grad = nothing
343325
end
@@ -358,10 +340,7 @@ function instantiate_function(
358340
end
359341
end
360342
elseif fg == true
361-
fg! = (θ) -> f.fg(θ, p)
362-
if p !== SciMLBase.NullParameters() && p !== nothing
363-
fg! = (θ, p) -> f.fg(θ, p)
364-
end
343+
fg! = (θ, p = p) -> f.fg(θ, p)
365344
else
366345
fg! = nothing
367346
end
@@ -380,10 +359,7 @@ function instantiate_function(
380359
end
381360
end
382361
elseif h == true
383-
hess = (θ) -> f.hess(θ, p)
384-
if p !== SciMLBase.NullParameters() && p !== nothing
385-
hess = (θ, p) -> f.hess(θ, p)
386-
end
362+
hess = (θ, p = p) -> f.hess(θ, p)
387363
else
388364
hess = nothing
389365
end
@@ -401,10 +377,7 @@ function instantiate_function(
401377
end
402378
end
403379
elseif fgh == true
404-
fgh! = (θ) -> f.fgh(θ, p)
405-
if p !== SciMLBase.NullParameters() && p !== nothing
406-
fgh! = (θ, p) -> f.fgh(θ, p)
407-
end
380+
fgh! = (θ, p = p) -> f.fgh(θ, p)
408381
else
409382
fgh! = nothing
410383
end
@@ -421,10 +394,7 @@ function instantiate_function(
421394
end
422395
end
423396
elseif hv == true
424-
hv! = (θ, v) -> f.hv(θ, v, p)
425-
if p !== SciMLBase.NullParameters() && p !== nothing
426-
hv! = (θ, v, p) -> f.hv(θ, v, p)
427-
end
397+
hv! = (θ, v, p = p) -> f.hv(θ, v, p)
428398
else
429399
hv! = nothing
430400
end
@@ -530,7 +500,7 @@ function instantiate_function(
530500
end
531501
end
532502
elseif lag_h == true && cons !== nothing
533-
lag_h! = (θ, σ, λ) -> f.lag_h(θ, σ, λ, p)
503+
lag_h! = (θ, σ, λ, p = p) -> f.lag_h(θ, σ, λ, p)
534504
else
535505
lag_h! = nothing
536506
end

0 commit comments

Comments
 (0)