Skip to content
This repository was archived by the owner on Aug 25, 2025. It is now read-only.

Commit 94c638a

Browse files
explicit tuple
1 parent 8a4d5b9 commit 94c638a

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

ext/OptimizationEnzymeExt.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,8 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true}, x,
134134

135135
if cons !== nothing && f.cons_j === nothing
136136
cons_j = function (J, θ, args...)
137-
return Enzyme.autodiff(Enzyme.Forward, cons_oop, BatchDuplicated(θ, (J[i, :] for i in 1:num_cons)),
137+
return Enzyme.autodiff(Enzyme.Forward, cons_oop,
138+
BatchDuplicated(θ, Tuple(J[i, :] for i in 1:num_cons)),
138139
Const(f.cons), Const(p), Const.(args)...)
139140
end
140141
else
@@ -240,7 +241,8 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{true},
240241

241242
if cons !== nothing && f.cons_j === nothing
242243
cons_j = function (J, θ, args...)
243-
Enzyme.autodiff(Enzyme.Forward, cons_oop, BatchDuplicated(θ, (J[i, :] for i in 1:num_cons)),
244+
Enzyme.autodiff(Enzyme.Forward, cons_oop,
245+
BatchDuplicated(θ, Tuple(J[i, :] for i in 1:num_cons)),
244246
Const(f.cons), Const(p), Const.(args)...)
245247
end
246248
else
@@ -347,7 +349,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false}, x
347349

348350
if f.cons !== nothing && f.cons_j === nothing
349351
cons_j = function (θ, args...)
350-
J = (zeros(eltype(θ), length(θ)) for i in 1:num_cons)
352+
J = Tuple(zeros(eltype(θ), length(θ)) for i in 1:num_cons)
351353
Enzyme.autodiff(
352354
Enzyme.Forward, f.cons, BatchDuplicated(θ, J), Const(p), Const.(args)...)
353355
return reduce(vcat, reshape.(J, Ref(1), Ref(length(θ))))
@@ -459,7 +461,7 @@ function OptimizationBase.instantiate_function(f::OptimizationFunction{false},
459461

460462
if f.cons !== nothing && f.cons_j === nothing
461463
cons_j = function (θ, args...)
462-
J = (zeros(eltype(θ), length(θ)) for i in 1:num_cons)
464+
J = Tuple(zeros(eltype(θ), length(θ)) for i in 1:num_cons)
463465
Enzyme.autodiff(
464466
Enzyme.Forward, f.cons, BatchDuplicated(θ, J), Const(p), Const.(args)...)
465467
return reduce(vcat, reshape.(J, Ref(1), Ref(length(θ))))

0 commit comments

Comments
 (0)