@@ -40,12 +40,6 @@ function hv_f2_alloc(x, f, p, args...)
40
40
return dx
41
41
end
42
42
43
- function cons_oop (x, f, p, args... )
44
- res = zeros (eltype (x), size (x, 1 ))
45
- f (res, x, p, args... )
46
- return res
47
- end
48
-
49
43
function inner_cons (x, fcons:: Function , p:: Union{SciMLBase.NullParameters, Nothing} ,
50
44
num_cons:: Int , i:: Int , args:: Vararg{Any, N} ) where {N}
51
45
res = zeros (eltype (x), num_cons)
@@ -133,10 +127,23 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
133
127
end
134
128
135
129
if cons != = nothing && f. cons_j === nothing
130
+ seeds = Tuple ((Array (r) for r in eachrow (I (length (x)) * one (eltype (x)))))
131
+ Jaccache = Tuple (zeros (eltype (x), num_cons) for i in 1 : length (x))
132
+ y = zeros (eltype (x), num_cons)
136
133
cons_j = function (J, θ, args... )
137
- return Enzyme. autodiff (Enzyme. Forward, cons_oop,
138
- BatchDuplicated (θ, Tuple (J[i, :] for i in 1 : num_cons)),
139
- Const (f. cons), Const (p), Const .(args)... )
134
+ for i in 1 : num_cons
135
+ Enzyme. make_zero! (Jaccache[i])
136
+ end
137
+ Enzyme. make_zero! (y)
138
+ Enzyme. autodiff (Enzyme. Forward, f. cons, BatchDuplicated (y, Jaccache),
139
+ BatchDuplicated (θ, seeds), Const (p), Const .(args)... )[1 ]
140
+ for i in 1 : length (θ)
141
+ if J isa Vector
142
+ J[i] = Jaccache[i][1 ]
143
+ else
144
+ copyto! (@view (J[:, i]), Jaccache[i])
145
+ end
146
+ end
140
147
end
141
148
else
142
149
cons_j = (J, θ, args... ) -> f. cons_j (J, θ, p, args... )
@@ -240,10 +247,24 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true},
240
247
end
241
248
242
249
if cons != = nothing && f. cons_j === nothing
250
+ seeds = Tuple ((Array (r)
251
+ for r in eachrow (I (length (cache. u0)) * one (eltype (cache. u0)))))
252
+ Jaccache = Tuple (zeros (eltype (cache. u0), num_cons) for i in 1 : length (cache. u0))
253
+ y = zeros (eltype (cache. u0), num_cons)
243
254
cons_j = function (J, θ, args... )
244
- Enzyme. autodiff (Enzyme. Forward, cons_oop,
245
- BatchDuplicated (θ, Tuple (J[i, :] for i in 1 : num_cons)),
246
- Const (f. cons), Const (p), Const .(args)... )
255
+ for i in 1 : num_cons
256
+ Enzyme. make_zero! (Jaccache[i])
257
+ end
258
+ Enzyme. make_zero! (y)
259
+ Enzyme. autodiff (Enzyme. Forward, f. cons, BatchDuplicated (y, Jaccache),
260
+ BatchDuplicated (θ, seeds), Const (p), Const .(args)... )[1 ]
261
+ for i in 1 : length (θ)
262
+ if J isa Vector
263
+ J[i] = Jaccache[i][1 ]
264
+ else
265
+ copyto! (@view (J[:, i]), Jaccache[i])
266
+ end
267
+ end
247
268
end
248
269
else
249
270
cons_j = (J, θ, args... ) -> f. cons_j (J, θ, p, args... )
@@ -348,11 +369,15 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
348
369
end
349
370
350
371
if f. cons != = nothing && f. cons_j === nothing
372
+ seeds = Tuple ((Array (r) for r in eachrow (I (length (x)) * one (eltype (x)))))
351
373
cons_j = function (θ, args... )
352
- J = Tuple (zeros (eltype (θ), length (θ)) for i in 1 : num_cons)
353
- Enzyme. autodiff (
354
- Enzyme. Forward, f. cons, BatchDuplicated (θ, J), Const (p), Const .(args)... )
355
- return reduce (vcat, reshape .(J, Ref (1 ), Ref (length (θ))))
374
+ J = Enzyme. autodiff (
375
+ Enzyme. Forward, f. cons, BatchDuplicated (θ, seeds), Const (p), Const .(args)... )[1 ]
376
+ if num_cons == 1
377
+ return reduce (vcat, J)
378
+ else
379
+ return reduce (hcat, J)
380
+ end
356
381
end
357
382
else
358
383
cons_j = (θ) -> f. cons_j (θ, p)
@@ -460,8 +485,11 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false},
460
485
end
461
486
462
487
if f. cons != = nothing && f. cons_j === nothing
488
+ J = Tuple (zeros (eltype (cache. u0), length (cache. u0)) for i in 1 : num_cons)
463
489
cons_j = function (θ, args... )
464
- J = Tuple (zeros (eltype (θ), length (θ)) for i in 1 : num_cons)
490
+ for i in 1 : num_cons
491
+ Enzyme. make_zero! (J[i])
492
+ end
465
493
Enzyme. autodiff (
466
494
Enzyme. Forward, f. cons, BatchDuplicated (θ, J), Const (p), Const .(args)... )
467
495
return reduce (vcat, reshape .(J, Ref (1 ), Ref (length (θ))))
0 commit comments