Skip to content

Commit 82fd347

Browse files
Merge pull request #935 from mxpoch/master
added callbacks to pycma wrapper
2 parents d62ff13 + 053f6e5 commit 82fd347

File tree

3 files changed

+39
-19
lines changed

3 files changed

+39
-19
lines changed

docs/src/optimization_packages/pycma.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Pkg.add("OptimizationPyCMA")
1313

1414
## Methods
1515

16-
`PyCMAOpt` supports the usual keyword arguments `maxiters`, `maxtime`, `abstol`, `reltol` in addition to any PyCMA-specific options (passed verbatim via keyword arguments to `solve`).
16+
`PyCMAOpt` supports the usual keyword arguments `maxiters`, `maxtime`, `abstol`, `reltol`, `callback` in addition to any PyCMA-specific options (passed verbatim via keyword arguments to `solve`).
1717

1818
## Example
1919

lib/OptimizationPyCMA/src/OptimizationPyCMA.jl

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ end
2020
# Defining the SciMLBase interface for PyCMAOpt
2121
SciMLBase.allowsbounds(::PyCMAOpt) = true
2222
SciMLBase.supports_opt_cache_interface(opt::PyCMAOpt) = true
23-
SciMLBase.allowscallback(::PyCMAOpt) = false
23+
SciMLBase.allowscallback(::PyCMAOpt) = true
2424
SciMLBase.requiresgradient(::PyCMAOpt) = false
2525
SciMLBase.requireshessian(::PyCMAOpt) = false
2626
SciMLBase.requiresconsjac(::PyCMAOpt) = false
@@ -47,7 +47,7 @@ function __map_optimizer_args(prob::OptimizationCache, opt::PyCMAOpt;
4747

4848
# mapping Optimization.jl args
4949
mapped_args["bounds"] = (prob.lb, prob.ub)
50-
50+
5151
if !("verbose" keys(mapped_args))
5252
mapped_args["verbose"] = -1
5353
end
@@ -116,46 +116,57 @@ function SciMLBase.__solve(cache::OptimizationCache{
116116
}
117117
local x
118118

119-
# doing conversions
120-
maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
121-
maxtime = Optimization._check_and_convert_maxtime(cache.solver_args.maxtime)
122-
123119
# wrapping the objective function
124120
_loss = function (θ)
125121
x = cache.f(θ, cache.p)
126122
return first(x)
127123
end
128124

125+
_cb = function(es)
126+
opt_state = Optimization.OptimizationState(; iter = pyconvert(Int, es.countiter),
127+
u = pyconvert(Vector{Float64}, es.best.x),
128+
objective = pyconvert(Float64, es.best.f),
129+
original = es)
130+
131+
cb_call = cache.callback(opt_state, x...)
132+
if !(cb_call isa Bool)
133+
error("The callback should return a boolean `halt` for whether to stop the optimization process.")
134+
end
135+
if cb_call
136+
es.opts.set(Dict("termination_callback" => es -> true))
137+
end
138+
end
139+
140+
# doing conversions
141+
maxiters = Optimization._check_and_convert_maxiters(cache.solver_args.maxiters)
142+
maxtime = Optimization._check_and_convert_maxtime(cache.solver_args.maxtime)
143+
129144
# converting the Optimization.jl Args to PyCMA format
130145
opt_args = __map_optimizer_args(cache, cache.opt; cache.solver_args...,
131146
maxiters = maxiters,
132147
maxtime = maxtime)
133148

134149
# init the CMAopt class
135150
es = get_cma().CMAEvolutionStrategy(cache.u0, 1, pydict(opt_args))
136-
logger = es.logger
137151

138152
# running the optimization
139153
t0 = time()
140-
opt_res = es.optimize(_loss)
154+
opt_res = es.optimize(_loss, callback = _cb)
141155
t1 = time()
142-
143-
# loading logged files from disk
144-
logger.load()
145-
156+
146157
# reading the results
147158
opt_ret_dict = opt_res.stop()
148159
retcode = __map_pycma_retcode(pyconvert(Dict{String, Any}, opt_ret_dict))
149160

150161
# logging and returning results of the optimization
151162
stats = Optimization.OptimizationStats(;
152-
iterations = length(logger.xmean),
163+
iterations = pyconvert(Int, es.countiter),
153164
time = t1 - t0,
154-
fevals = length(logger.xmean))
165+
fevals = pyconvert(Int, es.countevals))
155166

156167
SciMLBase.build_solution(cache, cache.opt,
157-
pyconvert(Float64, logger.xrecent[-1][-1]),
158-
pyconvert(Float64, logger.f[-1][-1]); original = opt_res,
168+
pyconvert(Vector{Float64}, opt_res.result.xbest),
169+
pyconvert(Float64, opt_res.result.fbest); original = opt_res,
159170
retcode = retcode,
160171
stats = stats)
161172
end

lib/OptimizationPyCMA/test/runtests.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,15 @@ using OptimizationPyCMA, Test
88
f = OptimizationFunction(rosenbrock)
99
prob = OptimizationProblem(f, x0, _p, lb = [-1.0, -1.0], ub = [0.8, 0.8])
1010
sol = solve(prob, PyCMAOpt())
11-
@test 10 * sol.objective < l1
12-
sol = solve(prob, PyCMAOpt(), maxiters = 100, verbose=-1, seed=42)
11+
@test 10 * sol.objective < l1
12+
13+
# test callback function
14+
callback = function (state, l)
15+
if state.iter > 10
16+
return true
17+
end
18+
return false
19+
end
20+
21+
sol = solve(prob, PyCMAOpt(), callback = callback, maxiters = 25, verbose=-1, seed=42)
1322
end

0 commit comments

Comments
 (0)