Skip to content

Commit 8874af3

Browse files
committed
reduce allocation NonLinMPC
1 parent c00365f commit 8874af3

File tree

2 files changed

+47
-46
lines changed

2 files changed

+47
-46
lines changed

src/controller/execute.jl

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,11 @@ julia> round.(getinfo(mpc)[:Ŷ], digits=3)
104104
"""
105105
function getinfo(mpc::PredictiveController{NT}) where NT<:Real
106106
info = Dict{Symbol, Union{JuMP._SolutionSummary, Vector{NT}, NT}}()
107-
Ŷ, x̂, u0 = similar(mpc.Ŷop), similar(mpc.estim.x̂), similar(mpc.estim.lastu0)
108-
Ŷ, x̂end = predict!(Ŷ, x̂, u0, mpc, mpc.estim.model, mpc.ΔŨ)
107+
Ŷ, x̂, u = similar(mpc.Ŷop), similar(mpc.estim.x̂), similar(mpc.estim.lastu0)
108+
Ŷ, x̂end = predict!(Ŷ, x̂, u, mpc, mpc.estim.model, mpc.ΔŨ)
109109
U = mpc.*mpc.ΔŨ + mpc.T*(mpc.estim.lastu0 + mpc.estim.model.uop)
110110
Ȳ, Ū = similar(Ŷ), similar(U)
111-
J = obj_nonlinprog!(Ȳ, Ū, mpc, mpc.estim.model, Ŷ, mpc.ΔŨ)
111+
J = obj_nonlinprog!(Ȳ, Ū, u, mpc, mpc.estim.model, Ŷ, mpc.ΔŨ)
112112
info[:ΔU] = mpc.ΔŨ[1:mpc.Hc*mpc.estim.model.nu]
113113
info[] = isinf(mpc.C) ? NaN : mpc.ΔŨ[end]
114114
info[:J] = J
@@ -284,20 +284,20 @@ function predict!(
284284
end
285285

286286
@doc raw"""
287-
predict!(Ŷ, x̂, u0, mpc::PredictiveController, model::SimModel, ΔŨ) -> Ŷ, x̂end
287+
predict!(Ŷ, x̂, u, mpc::PredictiveController, model::SimModel, ΔŨ) -> Ŷ, x̂end
288288
289289
Compute both vectors if `model` is not a [`LinModel`](@ref).
290290
291-
The method mutates `Ŷ`, `x̂` and `u0` arguments. The latter is the manipulated input without
292-
the operating points ``\mathbf{u_0}(k) = \mathbf{u}(k) - \mathbf{u_{op}}(k)``.
291+
The method mutates `Ŷ`, `x̂` and `u` arguments.
293292
"""
294293
function predict!(
295-
Ŷ, x̂, u0, mpc::PredictiveController, model::SimModel, ΔŨ::Vector{NT}
294+
Ŷ, x̂, u, mpc::PredictiveController, model::SimModel, ΔŨ::Vector{NT}
296295
) where {NT<:Real}
297296
nu, ny, nd, Hp, Hc = model.nu, model.ny, model.nd, mpc.Hp, mpc.Hc
297+
u0 = u
298298
x̂ .= mpc.estim.
299299
u0 .= mpc.estim.lastu0
300-
d0 = @views mpc.d0[1:end]
300+
d0 = @views mpc.d0[1:end]
301301
for j=1:Hp
302302
if j Hc
303303
u0 .+= @views ΔŨ[(1 + nu*(j-1)):(nu*j)]
@@ -312,21 +312,23 @@ function predict!(
312312
end
313313

314314
"""
315-
obj_nonlinprog!(_ , _ , mpc::PredictiveController, model::LinModel, Ŷ, ΔŨ)
315+
obj_nonlinprog!( _ , _ , u , mpc::PredictiveController, model::LinModel, Ŷ, ΔŨ)
316316
317317
Nonlinear programming objective function when `model` is a [`LinModel`](@ref).
318318
319319
The function is called by the nonlinear optimizer of [`NonLinMPC`](@ref) controllers. It can
320320
also be called on any [`PredictiveController`](@ref)s to evaluate the objective function `J`
321-
at specific input increments `ΔŨ` and predictions `Ŷ` values. This method does not mutate
322-
its argument.
321+
at specific input increments `ΔŨ` and predictions `Ŷ` values. This method mutate `u`
322+
argument.
323323
"""
324324
function obj_nonlinprog!(
325-
_ , _ , mpc::PredictiveController, model::LinModel, Ŷ, ΔŨ::Vector{NT}
325+
_ , _ , u , mpc::PredictiveController, model::LinModel, Ŷ, ΔŨ::Vector{NT}
326326
) where {NT<:Real}
327327
J = obj_quadprog(ΔŨ, mpc.H̃, mpc.q̃) + mpc.p[]
328328
if !iszero(mpc.E)
329-
U = mpc.*ΔŨ + mpc.T*(mpc.estim.lastu0 + model.uop)
329+
lastu = u
330+
lastu .= mpc.estim.lastu0 .+ model.uop
331+
U = mpc.*ΔŨ + mpc.T*lastu
330332
UE = [U; U[(end - model.nu + 1):end]]
331333
ŶE = [mpc.ŷ; Ŷ]
332334
J += mpc.E*mpc.JE(UE, ŶE, mpc.D̂E)
@@ -335,14 +337,14 @@ function obj_nonlinprog!(
335337
end
336338

337339
"""
338-
obj_nonlinprog!(Ȳ, Ū. mpc::PredictiveController, model::SimModel, Ŷ, ΔŨ)
340+
obj_nonlinprog!(Ȳ, Ū, u, mpc::PredictiveController, model::SimModel, Ŷ, ΔŨ)
339341
340342
Nonlinear programming objective function when `model` is not a [`LinModel`](@ref). The
341343
function `dot(x, A, x)` is a performant way of calculating `x'*A*x`. This method mutates
342344
`Ȳ` and `Ū` vector arguments (output and input setpoint tracking error, respectively).
343345
"""
344346
function obj_nonlinprog!(
345-
Ȳ, Ū, mpc::PredictiveController, model::SimModel, Ŷ, ΔŨ::Vector{NT}
347+
Ȳ, Ū, u, mpc::PredictiveController, model::SimModel, Ŷ, ΔŨ::Vector{NT}
346348
) where {NT<:Real}
347349
# --- output setpoint tracking term ---
348350
Ȳ .= mpc.R̂y .-
@@ -351,7 +353,9 @@ function obj_nonlinprog!(
351353
JΔŨ = dot(ΔŨ, mpc.Ñ_Hc, ΔŨ)
352354
# --- input over prediction horizon ---
353355
if !mpc.noR̂u || !iszero(mpc.E)
354-
U = mpc.*ΔŨ + mpc.T*(mpc.estim.lastu0 + model.uop)
356+
lastu = u
357+
lastu .= mpc.estim.lastu0 .+ model.uop
358+
U = mpc.*ΔŨ + mpc.T*lastu
355359
end
356360
# --- input setpoint tracking term ---
357361
if !mpc.noR̂u

src/controller/nonlinmpc.jl

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -301,61 +301,58 @@ function init_optimization!(mpc::NonLinMPC, optim::JuMP.GenericModel{JNT}) where
301301
# inspired from https://jump.dev/JuMP.jl/stable/tutorials/nonlinear/tips_and_tricks/#User-defined-operators-with-vector-outputs
302302
Jfunc, gfunc = let mpc=mpc, model=model, ng=ng, nΔŨ=nΔŨ, nŶ=Hp*ny, nx̂=nx̂, nu=nu, nU=Hp*nu
303303
last_ΔŨtup_float, last_ΔŨtup_dual = nothing, nothing
304-
Ŷ_cache ::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nŶ), nΔŨ + 3)
305-
g_cache ::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, ng), nΔŨ + 3)
306-
x̂_cache ::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nx̂), nΔŨ + 3)
307-
u0_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nu), nΔŨ + 3)
308-
Ȳ_cache ::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nŶ), nΔŨ + 3)
309-
Ū_cache ::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nU), nΔŨ + 3)
304+
Ŷ_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nŶ), nΔŨ + 3)
305+
g_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, ng), nΔŨ + 3)
306+
x̂_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nx̂), nΔŨ + 3)
307+
u_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nu), nΔŨ + 3)
308+
Ȳ_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nŶ), nΔŨ + 3)
309+
Ū_cache::DiffCache{Vector{JNT}, Vector{JNT}} = DiffCache(zeros(JNT, nU), nΔŨ + 3)
310310
function Jfunc(ΔŨtup::JNT...)
311-
ΔŨtud1 = ΔŨtup[begin]
312-
= get_tmp(Ŷ_cache, ΔŨtud1)
311+
ΔŨ1 = ΔŨtup[begin]
312+
, u = get_tmp(Ŷ_cache, ΔŨ1), get_tmp(u_cache, ΔŨ1)
313313
ΔŨ = collect(ΔŨtup)
314314
if ΔŨtup !== last_ΔŨtup_float
315-
x̂, u0 = get_tmp(x̂_cache, ΔŨtud1), get_tmp(u0_cache, ΔŨtud1)
316-
g = get_tmp(g_cache, ΔŨtup[1])
317-
Ŷ, x̂end = predict!(Ŷ, x̂, u0, mpc, model, ΔŨ)
315+
x̂, g = get_tmp(x̂_cache, ΔŨ1), get_tmp(g_cache, ΔŨ1)
316+
Ŷ, x̂end = predict!(Ŷ, x̂, u, mpc, model, ΔŨ)
318317
g = con_nonlinprog!(g, mpc, model, x̂end, Ŷ, ΔŨ)
319318
last_ΔŨtup_float = ΔŨtup
320319
end
321-
Ȳ, Ū = get_tmp(Ȳ_cache, ΔŨtud1), get_tmp(Ū_cache, ΔŨtud1)
322-
return obj_nonlinprog!(Ȳ, Ū, mpc, model, Ŷ, ΔŨ)
320+
Ȳ, Ū = get_tmp(Ȳ_cache, ΔŨ1), get_tmp(Ū_cache, ΔŨ1)
321+
return obj_nonlinprog!(Ȳ, Ū, u, mpc, model, Ŷ, ΔŨ)
323322
end
324323
function Jfunc(ΔŨtup::ForwardDiff.Dual...)
325-
ΔŨtud1 = ΔŨtup[begin]
326-
= get_tmp(Ŷ_cache, ΔŨtud1)
324+
ΔŨ1 = ΔŨtup[begin]
325+
, u = get_tmp(Ŷ_cache, ΔŨ1), get_tmp(u_cache, ΔŨ1)
327326
ΔŨ = collect(ΔŨtup)
328327
if ΔŨtup !== last_ΔŨtup_dual
329-
x̂, u0 = get_tmp(x̂_cache, ΔŨtud1), get_tmp(u0_cache, ΔŨtud1)
330-
g = get_tmp(g_cache, ΔŨtud1)
331-
Ŷ, x̂end = predict!(Ŷ, x̂, u0, mpc, model, ΔŨ)
328+
x̂, g = get_tmp(x̂_cache, ΔŨ1), get_tmp(g_cache, ΔŨ1)
329+
g = get_tmp(g_cache, ΔŨ1)
330+
Ŷ, x̂end = predict!(Ŷ, x̂, u, mpc, model, ΔŨ)
332331
g = con_nonlinprog!(g, mpc, model, x̂end, Ŷ, ΔŨ)
333332
last_ΔŨtup_dual = ΔŨtup
334333
end
335-
Ȳ, Ū = get_tmp(Ȳ_cache, ΔŨtud1), get_tmp(Ū_cache, ΔŨtud1)
336-
return obj_nonlinprog!(Ȳ, Ū, mpc, model, Ŷ, ΔŨ)
334+
Ȳ, Ū = get_tmp(Ȳ_cache, ΔŨ1), get_tmp(Ū_cache, ΔŨ1)
335+
return obj_nonlinprog!(Ȳ, Ū, u, mpc, model, Ŷ, ΔŨ)
337336
end
338337
function gfunc_i(i, ΔŨtup::NTuple{N, JNT}) where N
339-
ΔŨtud1 = ΔŨtup[begin]
340-
g = get_tmp(g_cache, ΔŨtud1)
338+
ΔŨ1 = ΔŨtup[begin]
339+
g = get_tmp(g_cache, ΔŨ1)
341340
if ΔŨtup !== last_ΔŨtup_float
342-
x̂, u0 = get_tmp(x̂_cache, ΔŨtud1), get_tmp(u0_cache, ΔŨtud1)
343-
= get_tmp(Ŷ_cache, ΔŨtud1)
341+
Ŷ, u, x̂ = get_tmp(Ŷ_cache, ΔŨ1), get_tmp(u_cache, ΔŨ1), get_tmp(x̂_cache, ΔŨ1)
344342
ΔŨ = collect(ΔŨtup)
345-
Ŷ, x̂end = predict!(Ŷ, x̂, u0, mpc, model, ΔŨ)
343+
Ŷ, x̂end = predict!(Ŷ, x̂, u, mpc, model, ΔŨ)
346344
g = con_nonlinprog!(g, mpc, model, x̂end, Ŷ, ΔŨ)
347345
last_ΔŨtup_float = ΔŨtup
348346
end
349347
return g[i]
350348
end
351349
function gfunc_i(i, ΔŨtup::NTuple{N, ForwardDiff.Dual}) where N
352-
ΔŨtud1 = ΔŨtup[begin]
353-
g = get_tmp(g_cache, ΔŨtud1)
350+
ΔŨ1 = ΔŨtup[begin]
351+
g = get_tmp(g_cache, ΔŨ1)
354352
if ΔŨtup !== last_ΔŨtup_dual
355-
x̂, u0 = get_tmp(x̂_cache, ΔŨtud1), get_tmp(u0_cache, ΔŨtud1)
356-
= get_tmp(Ŷ_cache, ΔŨtud1)
353+
Ŷ, u, x̂ = get_tmp(Ŷ_cache, ΔŨ1), get_tmp(u_cache, ΔŨ1), get_tmp(x̂_cache, ΔŨ1)
357354
ΔŨ = collect(ΔŨtup)
358-
Ŷ, x̂end = predict!(Ŷ, x̂, u0, mpc, model, ΔŨ)
355+
Ŷ, x̂end = predict!(Ŷ, x̂, u, mpc, model, ΔŨ)
359356
g = con_nonlinprog!(g, mpc, model, x̂end, Ŷ, ΔŨ)
360357
last_ΔŨtup_dual = ΔŨtup
361358
end

0 commit comments

Comments
 (0)