Skip to content
This repository was archived by the owner on Aug 25, 2025. It is now read-only.

Commit b45eebc

Browse files
update tests
1 parent d0a3d32 commit b45eebc

File tree

4 files changed

+43
-34
lines changed

4 files changed

+43
-34
lines changed

src/OptimizationBase.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ end
1010

1111
using ArrayInterface, Base.Iterators, SparseArrays, LinearAlgebra
1212
using SymbolicIndexingInterface
13-
using SymbolicAnalysis: propagate_sign, propagate_curvature, propagate_gcurvature
13+
using SymbolicAnalysis: propagate_sign, propagate_curvature, propagate_gcurvature,
14+
getcurvature, getgcurvature, getsign
1415
import Symbolics
1516
import Manifolds
1617
import Symbolics: variable, Equation, Inequality, unwrap, @variables

src/cache.jl

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,24 +22,27 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt, data = DEFA
2222
abstol::Union{Number, Nothing} = nothing,
2323
reltol::Union{Number, Nothing} = nothing,
2424
progress = false,
25-
structural_analysis = true,
25+
structural_analysis = false,
26+
manifold = nothing,
2627
kwargs...)
2728
reinit_cache = OptimizationBase.ReInitCache(prob.u0, prob.p)
2829
num_cons = prob.ucons === nothing ? 0 : length(prob.ucons)
2930
f = OptimizationBase.instantiate_function(prob.f, reinit_cache, prob.f.adtype, num_cons)
3031

31-
if (f.sys === nothing || f.sys isa SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}) && structural_analysis
32+
if (f.sys === nothing ||
33+
f.sys isa SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}) &&
34+
structural_analysis
3235
try
33-
vars =
34-
if prob.u0 isa Matrix
36+
vars = if prob.u0 isa Matrix
3537
@variables X[1:size(prob.u0, 1), 1:size(prob.u0, 2)]
3638
else
37-
ArrayInterface.restructure(prob.u0, [variable(:x, i) for i in eachindex(prob.u0)])
39+
ArrayInterface.restructure(
40+
prob.u0, [variable(:x, i) for i in eachindex(prob.u0)])
3841
end
3942
params = if prob.p isa SciMLBase.NullParameters
4043
[]
41-
# elseif prob.p isa MTK.MTKParameters
42-
# [variable(:α, i) for i in eachindex(vcat(p...))]
44+
elseif prob.p isa MTK.MTKParameters
45+
[variable(, i) for i in eachindex(vcat(p...))]
4346
else
4447
ArrayInterface.restructure(p, [variable(, i) for i in eachindex(p)])
4548
end
@@ -87,7 +90,7 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt, data = DEFA
8790
cons_expr = nothing
8891
end
8992
catch err
90-
throw(ArgumentError("Automatic symbolic expression generation with ModelingToolkit failed with error: $err.
93+
throw(ArgumentError("Automatic symbolic expression generation with failed with error: $err.
9194
Try by setting `structural_analysis = false` instead if the solver doesn't require symbolic expressions."))
9295
end
9396
else
@@ -96,25 +99,28 @@ function OptimizationCache(prob::SciMLBase.OptimizationProblem, opt, data = DEFA
9699
obj_expr = f.expr
97100
cons_expr = f.cons_expr
98101
end
99-
try
100-
obj_expr = obj_expr |> Symbolics.unwrap
101-
obj_expr = propagate_curvature(propagate_sign(obj_expr))
102-
@info "Objective Euclidean curvature: $(SymbolicAnalysis.getcurvature(obj_expr))"
103-
catch
104-
@info "No euclidean atom available"
105-
end
106102

107-
try
108-
obj_expr = SymbolicAnalysis.propagate_gcurvature(propagate_sign(obj_expr), prob.kwargs[1])
109-
@info "Objective Geodesic curvature: $(SymbolicAnalysis.getgcurvature(obj_expr))"
110-
catch e
111-
@show e
103+
if obj_expr !== nothing
104+
try
105+
obj_expr = obj_expr |> Symbolics.unwrap
106+
obj_expr = propagate_curvature(propagate_sign(obj_expr))
107+
@info "Objective Euclidean curvature: $(getcurvature(obj_expr))"
108+
catch
109+
@info "No euclidean atom available"
110+
end
111+
112+
try
113+
obj_expr = propagate_gcurvature(propagate_sign(obj_expr), manifold)
114+
@info "Objective Geodesic curvature: $(getgcurvature(obj_expr))"
115+
catch
116+
@info "No geodesic atom available"
117+
end
112118
end
113119

114-
if !isnothing(cons_expr)
120+
if cons_expr !== nothing
115121
cons_expr = cons_expr .|> Symbolics.unwrap
116122
cons_expr = propagate_curvature.(propagate_sign.(cons_expr))
117-
@info "Constraints Euclidean curvature: $(SymbolicAnalysis.getcurvature.(cons_expr))"
123+
@info "Constraints Euclidean curvature: $(getcurvature.(cons_expr))"
118124
end
119125

120126
return OptimizationCache(f, reinit_cache, prob.lb, prob.ub, prob.lcons,

test/adtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ optprob.hess(H2, x0)
521521
@test optprob.hess(x0) == H1
522522
@test optprob.cons(x0) == [0.0, 0.0]
523523
@test optprob.cons_j([5.0, 3.0])[10.0 6.0; -0.149013 -0.958924] rtol=1e-6
524-
@test_broken optprob.cons_h(x0) == [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]]
524+
@test optprob.cons_h(x0) == [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]]
525525

526526
cons = (x, p) -> [x[1]^2 + x[2]^2]
527527
optf = OptimizationFunction{false}(rosenbrock,

test/cvxtest.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
using Optimization, OptimizationBase, ForwardDiff, SymbolicAnalysis, LinearAlgebra, Manifolds, OptimizationManopt
1+
using Optimization, OptimizationBase, ForwardDiff, SymbolicAnalysis, LinearAlgebra,
2+
Manifolds, OptimizationManopt
23

34
function f(x, p = nothing)
45
return exp(x[1]) + x[1]^2
56
end
67

78
optf = OptimizationFunction(f, Optimization.AutoForwardDiff())
8-
prob = OptimizationProblem(optf, [0.4])
9+
prob = OptimizationProblem(optf, [0.4], structural_analysis = true)
910

1011
@time sol = solve(prob, Optimization.LBFGS(), maxiters = 1000)
1112

@@ -14,28 +15,29 @@ rosenbrock(x, p = nothing) = (1 - x[1])^2 + 100 * (x[2] - x[1]^2)^2
1415
l1 = rosenbrock(x0)
1516

1617
optf = OptimizationFunction(rosenbrock, AutoEnzyme())
17-
prob = OptimizationProblem(optf, x0)
18+
prob = OptimizationProblem(optf, x0, structural_analysis = true)
1819
@time res = solve(prob, Optimization.LBFGS(), maxiters = 100)
1920

2021
function con2_c(res, x, p)
21-
res .= [x[1]^2 + x[2]^2, (x[2] * sin(x[1]) + x[1])-5]
22+
res .= [x[1]^2 + x[2]^2, (x[2] * sin(x[1]) + x[1]) - 5]
2223
end
2324

2425
optf = OptimizationFunction(rosenbrock, AutoZygote(), cons = con2_c)
25-
prob = OptimizationProblem(optf, x0, lcons = [1.0, -Inf], ucons = [1.0, 0.0], lb = [-1.0, -1.0], ub = [1.0, 1.0])
26+
prob = OptimizationProblem(optf, x0, lcons = [1.0, -Inf], ucons = [1.0, 0.0],
27+
lb = [-1.0, -1.0], ub = [1.0, 1.0], structural_analysis = true)
2628
@time res = solve(prob, Optimization.LBFGS(), maxiters = 100)
2729

2830
m = 100
2931
σ = 0.005
3032
q = Matrix{Float64}(LinearAlgebra.I(5)) .+ 2.0
3133

3234
M = SymmetricPositiveDefinite(5)
33-
data2 = [exp(M, q, σ * rand(M; vector_at=q)) for i in 1:m];
35+
data2 = [exp(M, q, σ * rand(M; vector_at = q)) for i in 1:m];
3436

3537
f(x, p = nothing) = sum(SymbolicAnalysis.distance(M, data2[i], x)^2 for i in 1:5)
36-
optf = OptimizationFunction(f, Optimization.AutoZygote())
37-
prob = OptimizationProblem(optf, data2[1]; manifold = M)
38+
optf = OptimizationFunction(f, Optimization.AutoForwardDiff())
39+
prob = OptimizationProblem(optf, data2[1]; manifold = M, structural_analysis = true)
3840

3941
opt = OptimizationManopt.GradientDescentOptimizer()
40-
@time sol = solve(prob, opt, maxiters = 100)
41-
@test sol.minimizer < 1e-1
42+
@time sol = solve(prob, Optimization.LBFGS(), maxiters = 100)
43+
@test sol.minimizer < 1e-1

0 commit comments

Comments
 (0)