Skip to content
This repository was archived by the owner on Aug 25, 2025. It is now read-only.

Commit 7898ad1

Browse files
Fix jacobians
1 parent 94c638a commit 7898ad1

File tree

1 file changed

+45
-17
lines changed

1 file changed

+45
-17
lines changed

ext/OptimizationEnzymeExt.jl

Lines changed: 45 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,6 @@ function hv_f2_alloc(x, f, p, args...)
4040
return dx
4141
end
4242

43-
function cons_oop(x, f, p, args...)
44-
res = zeros(eltype(x), size(x, 1))
45-
f(res, x, p, args...)
46-
return res
47-
end
48-
4943
function inner_cons(x, fcons::Function, p::Union{SciMLBase.NullParameters, Nothing},
5044
num_cons::Int, i::Int, args::Vararg{Any, N}) where {N}
5145
res = zeros(eltype(x), num_cons)
@@ -133,10 +127,23 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
133127
end
134128

135129
if cons !== nothing && f.cons_j === nothing
130+
seeds = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x)))))
131+
Jaccache = Tuple(zeros(eltype(x), num_cons) for i in 1:length(x))
132+
y = zeros(eltype(x), num_cons)
136133
cons_j = function (J, θ, args...)
137-
return Enzyme.autodiff(Enzyme.Forward, cons_oop,
138-
BatchDuplicated(θ, Tuple(J[i, :] for i in 1:num_cons)),
139-
Const(f.cons), Const(p), Const.(args)...)
134+
for i in 1:num_cons
135+
Enzyme.make_zero!(Jaccache[i])
136+
end
137+
Enzyme.make_zero!(y)
138+
Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache),
139+
BatchDuplicated(θ, seeds), Const(p), Const.(args)...)[1]
140+
for i in 1:length(θ)
141+
if J isa Vector
142+
J[i] = Jaccache[i][1]
143+
else
144+
copyto!(@view(J[:, i]), Jaccache[i])
145+
end
146+
end
140147
end
141148
else
142149
cons_j = (J, θ, args...) -> f.cons_j(J, θ, p, args...)
@@ -240,10 +247,24 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true},
240247
end
241248

242249
if cons !== nothing && f.cons_j === nothing
250+
seeds = Tuple((Array(r)
251+
for r in eachrow(I(length(cache.u0)) * one(eltype(cache.u0)))))
252+
Jaccache = Tuple(zeros(eltype(cache.u0), num_cons) for i in 1:length(cache.u0))
253+
y = zeros(eltype(cache.u0), num_cons)
243254
cons_j = function (J, θ, args...)
244-
Enzyme.autodiff(Enzyme.Forward, cons_oop,
245-
BatchDuplicated(θ, Tuple(J[i, :] for i in 1:num_cons)),
246-
Const(f.cons), Const(p), Const.(args)...)
255+
for i in 1:num_cons
256+
Enzyme.make_zero!(Jaccache[i])
257+
end
258+
Enzyme.make_zero!(y)
259+
Enzyme.autodiff(Enzyme.Forward, f.cons, BatchDuplicated(y, Jaccache),
260+
BatchDuplicated(θ, seeds), Const(p), Const.(args)...)[1]
261+
for i in 1:length(θ)
262+
if J isa Vector
263+
J[i] = Jaccache[i][1]
264+
else
265+
copyto!(@view(J[:, i]), Jaccache[i])
266+
end
267+
end
247268
end
248269
else
249270
cons_j = (J, θ, args...) -> f.cons_j(J, θ, p, args...)
@@ -348,11 +369,15 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
348369
end
349370

350371
if f.cons !== nothing && f.cons_j === nothing
372+
seeds = Tuple((Array(r) for r in eachrow(I(length(x)) * one(eltype(x)))))
351373
cons_j = function (θ, args...)
352-
J = Tuple(zeros(eltype(θ), length(θ)) for i in 1:num_cons)
353-
Enzyme.autodiff(
354-
Enzyme.Forward, f.cons, BatchDuplicated(θ, J), Const(p), Const.(args)...)
355-
return reduce(vcat, reshape.(J, Ref(1), Ref(length(θ))))
374+
J = Enzyme.autodiff(
375+
Enzyme.Forward, f.cons, BatchDuplicated(θ, seeds), Const(p), Const.(args)...)[1]
376+
if num_cons == 1
377+
return reduce(vcat, J)
378+
else
379+
return reduce(hcat, J)
380+
end
356381
end
357382
else
358383
cons_j = (θ) -> f.cons_j(θ, p)
@@ -460,8 +485,11 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false},
460485
end
461486

462487
if f.cons !== nothing && f.cons_j === nothing
488+
J = Tuple(zeros(eltype(cache.u0), length(cache.u0)) for i in 1:num_cons)
463489
cons_j = function (θ, args...)
464-
J = Tuple(zeros(eltype(θ), length(θ)) for i in 1:num_cons)
490+
for i in 1:num_cons
491+
Enzyme.make_zero!(J[i])
492+
end
465493
Enzyme.autodiff(
466494
Enzyme.Forward, f.cons, BatchDuplicated(θ, J), Const(p), Const.(args)...)
467495
return reduce(vcat, reshape.(J, Ref(1), Ref(length(θ))))

0 commit comments

Comments
 (0)