Skip to content

Added OptimizationState to OptimizationBase.jl #155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,21 @@ version = "2.10.0"
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"

[weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Expand All @@ -42,23 +47,28 @@ OptimizationZygoteExt = "Zygote"
[compat]
ADTypes = "1.9"
ArrayInterface = "7.6"
ConsoleProgressMonitor = "0.1.2"
DifferentiationInterface = "0.7"
DocStringExtensions = "0.9"
Enzyme = "0.13.2"
FastClosures = "0.3"
FiniteDiff = "2.12"
ForwardDiff = "0.10.26, 1"
LinearAlgebra = "1.9, 1.10"
Logging = "1.11.0"
LoggingExtras = "1.1.0"
MLDataDevices = "1"
MLUtils = "0.4"
ModelingToolkit = "9, 10"
PDMats = "0.11"
ProgressLogging = "0.1.5"
Reexport = "1.2"
ReverseDiff = "1.14"
SciMLBase = "2"
SparseConnectivityTracer = "0.6, 1"
SparseMatrixColorings = "0.4"
SymbolicAnalysis = "0.3"
TerminalLoggers = "0.1.7"
Zygote = "0.6.67, 0.7"
julia = "1.10"

Expand Down
4 changes: 3 additions & 1 deletion src/OptimizationBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module OptimizationBase
using DocStringExtensions
using Reexport
@reexport using SciMLBase, ADTypes

using Logging, ProgressLogging, ConsoleProgressMonitor, TerminalLoggers, LoggingExtras
using ArrayInterface, Base.Iterators, SparseArrays, LinearAlgebra
import SciMLBase: OptimizationProblem,
OptimizationFunction, ObjSense,
Expand All @@ -24,6 +24,8 @@ Base.length(::NullData) = 0
include("adtypes.jl")
include("symify.jl")
include("cache.jl")
include("state.jl")
include("utils.jl")
include("OptimizationDIExt.jl")
include("OptimizationDISparseExt.jl")
include("function.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/augmented_lagrangian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function generate_auglag(θ)
cache.f.cons(cons_tmp, θ)
cons_tmp[eq_inds] .= cons_tmp[eq_inds] - cache.lcons[eq_inds]
cons_tmp[ineq_inds] .= cons_tmp[ineq_inds] .- cache.ucons[ineq_inds]
opt_state = Optimization.OptimizationState(u = θ, objective = x[1])
opt_state = OptimizationBase.OptimizationState(u = θ, objective = x[1])
if cache.callback(opt_state, x...)
error("Optimization halted by callback.")
end
Expand Down
30 changes: 30 additions & 0 deletions src/state.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""
$(TYPEDEF)

Stores the optimization run's state at the current iteration
and is passed to the callback function as the first argument.

## Fields

- `iter`: current iteration
- `u`: current solution
- `objective`: current objective value
- `gradient`: current gradient
- `hessian`: current hessian
- `original`: if the solver has its own state object then it is stored here
- `p`: optimization parameters
"""
struct OptimizationState{X, O, G, H, S, P}
iter::Int
u::X
objective::O
grad::G
hess::H
original::S
p::P
end

function OptimizationState(; iter = 0, u = nothing, objective = nothing,
grad = nothing, hess = nothing, original = nothing, p = nothing)
OptimizationState(iter, u, objective, grad, hess, original, p)
end
132 changes: 132 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
function get_maxiters(data)
Iterators.IteratorSize(typeof(DEFAULT_DATA)) isa Iterators.IsInfinite ||
Iterators.IteratorSize(typeof(DEFAULT_DATA)) isa Iterators.SizeUnknown ?
typemax(Int) : length(data)
end

maybe_with_logger(f, logger) = logger === nothing ? f() : Logging.with_logger(f, logger)

function default_logger(logger)
Logging.min_enabled_level(logger) ≤ ProgressLogging.ProgressLevel && return nothing
if Sys.iswindows() || (isdefined(Main, :IJulia) && Main.IJulia.inited)
progresslogger = ConsoleProgressMonitor.ProgressLogger()
else
progresslogger = TerminalLoggers.TerminalLogger()
end
logger1 = LoggingExtras.EarlyFilteredLogger(progresslogger) do log
log.level == ProgressLogging.ProgressLevel
end
logger2 = LoggingExtras.EarlyFilteredLogger(logger) do log
log.level != ProgressLogging.ProgressLevel
end
LoggingExtras.TeeLogger(logger1, logger2)
end

macro withprogress(progress, exprs...)
quote
if $progress
$maybe_with_logger($default_logger($Logging.current_logger())) do
$ProgressLogging.@withprogress $(exprs...)
end
else
$(exprs[end])
end
end |> esc
end

decompose_trace(trace) = trace

function _check_and_convert_maxiters(maxiters)
if !(isnothing(maxiters)) && maxiters <= 0.0
error("The number of maxiters has to be a non-negative and non-zero number.")
elseif !(isnothing(maxiters))
return convert(Int, round(maxiters))
end
end

function _check_and_convert_maxtime(maxtime)
if !(isnothing(maxtime)) && maxtime <= 0.0
error("The maximum time has to be a non-negative and non-zero number.")
elseif !(isnothing(maxtime))
return convert(Float32, maxtime)
end
end

# RetCode handling for BBO and others.
using SciMLBase: ReturnCode

# Define a dictionary to map regular expressions to ReturnCode values
const STOP_REASON_MAP = Dict(
r"Delta fitness .* below tolerance .*" => ReturnCode.Success,
r"Fitness .* within tolerance .* of optimum" => ReturnCode.Success,
r"CONVERGENCE: NORM_OF_PROJECTED_GRADIENT_<=_PGTOL" => ReturnCode.Success,
r"^CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR\*EPSMCH\s*$" => ReturnCode.Success,
r"Terminated" => ReturnCode.Terminated,
r"MaxIters|MAXITERS_EXCEED|Max number of steps .* reached" => ReturnCode.MaxIters,
r"MaxTime|TIME_LIMIT" => ReturnCode.MaxTime,
r"Max time" => ReturnCode.MaxTime,
r"DtLessThanMin" => ReturnCode.DtLessThanMin,
r"Unstable" => ReturnCode.Unstable,
r"InitialFailure" => ReturnCode.InitialFailure,
r"ConvergenceFailure|ITERATION_LIMIT" => ReturnCode.ConvergenceFailure,
r"Infeasible|INFEASIBLE|DUAL_INFEASIBLE|LOCALLY_INFEASIBLE|INFEASIBLE_OR_UNBOUNDED" => ReturnCode.Infeasible,
r"TOTAL NO. of ITERATIONS REACHED LIMIT" => ReturnCode.MaxIters,
r"TOTAL NO. of f AND g EVALUATIONS EXCEEDS LIMIT" => ReturnCode.MaxIters,
r"ABNORMAL_TERMINATION_IN_LNSRCH" => ReturnCode.Unstable,
r"ERROR INPUT DATA" => ReturnCode.InitialFailure,
r"FTOL.TOO.SMALL" => ReturnCode.ConvergenceFailure,
r"GTOL.TOO.SMALL" => ReturnCode.ConvergenceFailure,
r"XTOL.TOO.SMALL" => ReturnCode.ConvergenceFailure,
r"STOP: TERMINATION" => ReturnCode.Terminated,
r"Optimization completed" => ReturnCode.Success,
r"Convergence achieved" => ReturnCode.Success,
r"ROUNDOFF_LIMITED" => ReturnCode.Success
)

# Function to deduce ReturnCode from a stop_reason string using the dictionary
function deduce_retcode(stop_reason::String)
for (pattern, retcode) in STOP_REASON_MAP
if occursin(pattern, stop_reason)
return retcode
end
end
@warn "Unrecognized stop reason: $stop_reason. Defaulting to ReturnCode.Default."
return ReturnCode.Default
end

# Function to deduce ReturnCode from a Symbol
function deduce_retcode(retcode::Symbol)
if retcode == :Default || retcode == :DEFAULT
return ReturnCode.Default
elseif retcode == :Success || retcode == :EXACT_SOLUTION_LEFT ||
retcode == :FLOATING_POINT_LIMIT || retcode == :true || retcode == :OPTIMAL ||
retcode == :LOCALLY_SOLVED || retcode == :ROUNDOFF_LIMITED ||
retcode == :SUCCESS ||
retcode == :STOPVAL_REACHED || retcode == :FTOL_REACHED ||
retcode == :XTOL_REACHED
return ReturnCode.Success
elseif retcode == :Terminated
return ReturnCode.Terminated
elseif retcode == :MaxIters || retcode == :MAXITERS_EXCEED ||
retcode == :MAXEVAL_REACHED
return ReturnCode.MaxIters
elseif retcode == :MaxTime || retcode == :TIME_LIMIT || retcode == :MAXTIME_REACHED
return ReturnCode.MaxTime
elseif retcode == :DtLessThanMin
return ReturnCode.DtLessThanMin
elseif retcode == :Unstable
return ReturnCode.Unstable
elseif retcode == :InitialFailure
return ReturnCode.InitialFailure
elseif retcode == :ConvergenceFailure || retcode == :ITERATION_LIMIT
return ReturnCode.ConvergenceFailure
elseif retcode == :Failure || retcode == :false
return ReturnCode.Failure
elseif retcode == :Infeasible || retcode == :INFEASIBLE ||
retcode == :DUAL_INFEASIBLE || retcode == :LOCALLY_INFEASIBLE ||
retcode == :INFEASIBLE_OR_UNBOUNDED
return ReturnCode.Infeasible
else
return ReturnCode.Failure
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ using Test
include("adtests.jl")
include("cvxtest.jl")
include("matrixvalued.jl")
include("utilstest.jl")
end
Loading