Skip to content

Commit 596084d

Browse files
authored
Merge pull request #141 from JuliaControl/reduce_alloc_jacobian!
added: `linearize!` is now allocation free
2 parents fc2ce6a + 25865ea commit 596084d

File tree

10 files changed

+151
-65
lines changed

10 files changed

+151
-65
lines changed

src/controller/explicitmpc.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ struct ExplicitMPC{NT<:Real, SE<:StateEstimator} <: PredictiveController{NT}
3131
Yop::Vector{NT}
3232
Dop::Vector{NT}
3333
buffer::PredictiveControllerBuffer{NT}
34-
function ExplicitMPC{NT, SE}(
34+
function ExplicitMPC{NT}(
3535
estim::SE, Hp, Hc, M_Hp, N_Hc, L_Hp
3636
) where {NT<:Real, SE<:StateEstimator}
3737
model = estim.model
@@ -158,7 +158,7 @@ function ExplicitMPC(
158158
@warn("prediction horizon Hp ($Hp) ≤ estimated number of delays in model "*
159159
"($nk), the closed-loop system may be unstable or zero-gain (unresponsive)")
160160
end
161-
return ExplicitMPC{NT, SE}(estim, Hp, Hc, M_Hp, N_Hc, L_Hp)
161+
return ExplicitMPC{NT}(estim, Hp, Hc, M_Hp, N_Hc, L_Hp)
162162
end
163163

164164
setconstraint!(::ExplicitMPC; kwargs...) = error("ExplicitMPC does not support constraints.")

src/controller/linmpc.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ struct LinMPC{
4040
Yop::Vector{NT}
4141
Dop::Vector{NT}
4242
buffer::PredictiveControllerBuffer{NT}
43-
function LinMPC{NT, SE, JM}(
43+
function LinMPC{NT}(
4444
estim::SE, Hp, Hc, M_Hp, N_Hc, L_Hp, Cwt, optim::JM
4545
) where {NT<:Real, SE<:StateEstimator, JM<:JuMP.GenericModel}
4646
model = estim.model
@@ -237,7 +237,7 @@ function LinMPC(
237237
@warn("prediction horizon Hp ($Hp) ≤ estimated number of delays in model "*
238238
"($nk), the closed-loop system may be unstable or zero-gain (unresponsive)")
239239
end
240-
return LinMPC{NT, SE, JM}(estim, Hp, Hc, M_Hp, N_Hc, L_Hp, Cwt, optim)
240+
return LinMPC{NT}(estim, Hp, Hc, M_Hp, N_Hc, L_Hp, Cwt, optim)
241241
end
242242

243243
"""

src/controller/nonlinmpc.jl

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ struct NonLinMPC{
44
NT<:Real,
55
SE<:StateEstimator,
66
JM<:JuMP.GenericModel,
7+
P<:Any,
78
JEfunc<:Function,
8-
GCfunc<:Function,
9-
P<:Any
9+
GCfunc<:Function
1010
} <: PredictiveController{NT}
1111
estim::SE
1212
# note: `NT` and the number type `JNT` in `JuMP.GenericModel{JNT}` can be
@@ -45,16 +45,16 @@ struct NonLinMPC{
4545
Yop::Vector{NT}
4646
Dop::Vector{NT}
4747
buffer::PredictiveControllerBuffer{NT}
48-
function NonLinMPC{NT, SE, JM, JEfunc, GCfunc, P}(
48+
function NonLinMPC{NT}(
4949
estim::SE,
5050
Hp, Hc, M_Hp, N_Hc, L_Hp, Cwt, Ewt, JE::JEfunc, gc!::GCfunc, nc, p::P, optim::JM
5151
) where {
5252
NT<:Real,
5353
SE<:StateEstimator,
54-
JM<:JuMP.GenericModel,
54+
JM<:JuMP.GenericModel,
55+
P<:Any,
5556
JEfunc<:Function,
5657
GCfunc<:Function,
57-
P<:Any
5858
}
5959
model = estim.model
6060
nu, ny, nd, nx̂ = model.nu, model.ny, model.nd, estim.nx̂
@@ -80,7 +80,7 @@ struct NonLinMPC{
8080
nΔŨ = size(Ẽ, 2)
8181
ΔŨ = zeros(NT, nΔŨ)
8282
buffer = PredictiveControllerBuffer{NT}(nu, ny, nd, Hp, Hc, nϵ)
83-
mpc = new{NT, SE, JM, JEfunc, GCfunc, P}(
83+
mpc = new{NT, SE, JM, P, JEfunc, GCfunc}(
8484
estim, optim, con,
8585
ΔŨ, ŷ,
8686
Hp, Hc, nϵ,
@@ -317,7 +317,7 @@ function NonLinMPC(
317317
L_Hp = diagm(repeat(Lwt, Hp)),
318318
Cwt = DEFAULT_CWT,
319319
Ewt = DEFAULT_EWT,
320-
JE ::JEfunc = (_,_,_,_) -> 0.0,
320+
JE ::Function = (_,_,_,_) -> 0.0,
321321
gc!::Function = (_,_,_,_,_,_) -> nothing,
322322
gc ::Function = gc!,
323323
nc = 0,
@@ -327,7 +327,6 @@ function NonLinMPC(
327327
NT<:Real,
328328
SE<:StateEstimator{NT},
329329
JM<:JuMP.GenericModel,
330-
JEfunc<:Function,
331330
P<:Any
332331
}
333332
nk = estimate_delays(estim.model)
@@ -337,8 +336,7 @@ function NonLinMPC(
337336
end
338337
validate_JE(NT, JE)
339338
gc! = get_mutating_gc(NT, gc)
340-
GCfunc = get_type_mutating_gc(gc!)
341-
return NonLinMPC{NT, SE, JM, JEfunc, GCfunc, P}(
339+
return NonLinMPC{NT}(
342340
estim, Hp, Hc, M_Hp, N_Hc, L_Hp, Cwt, Ewt, JE, gc!, nc, p, optim
343341
)
344342
end
@@ -396,9 +394,6 @@ function get_mutating_gc(NT, gc)
396394
return gc!
397395
end
398396

399-
"Get the type of the mutating version of the custom constrain function `gc!`."
400-
get_type_mutating_gc(::GCfunc) where {GCfunc<:Function} = GCfunc
401-
402397
"""
403398
test_custom_functions(NT, model::SimModel, JE, gc!, nc, Uop, Yop, Dop, p)
404399

src/estimator/internal_model.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ struct InternalModel{NT<:Real, SM<:SimModel} <: StateEstimator{NT}
2929
direct::Bool
3030
corrected::Vector{Bool}
3131
buffer::StateEstimatorBuffer{NT}
32-
function InternalModel{NT, SM}(
32+
function InternalModel{NT}(
3333
model::SM, i_ym, Asm, Bsm, Csm, Dsm
3434
) where {NT<:Real, SM<:SimModel}
3535
nu, ny, nd = model.nu, model.ny, model.nd
@@ -117,7 +117,7 @@ function InternalModel(
117117
stoch_ym = c2d(stoch_ym_c, model.Ts, :tustin)
118118
end
119119
end
120-
return InternalModel{NT, SM}(model, i_ym, stoch_ym.A, stoch_ym.B, stoch_ym.C, stoch_ym.D)
120+
return InternalModel{NT}(model, i_ym, stoch_ym.A, stoch_ym.B, stoch_ym.C, stoch_ym.D)
121121
end
122122

123123
"Validate if deterministic `model` and stochastic model `Csm, Dsm` for `InternalModel`s."

src/estimator/kalman.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ struct SteadyKalmanFilter{NT<:Real, SM<:LinModel} <: StateEstimator{NT}
2727
direct::Bool
2828
corrected::Vector{Bool}
2929
buffer::StateEstimatorBuffer{NT}
30-
function SteadyKalmanFilter{NT, SM}(
30+
function SteadyKalmanFilter{NT}(
3131
model::SM, i_ym, nint_u, nint_ym, Q̂, R̂; direct=true
3232
) where {NT<:Real, SM<:LinModel}
3333
nu, ny, nd = model.nu, model.ny, model.nd
@@ -182,7 +182,7 @@ function SteadyKalmanFilter(
182182
# estimated covariances matrices (variance = σ²) :
183183
= Hermitian(diagm(NT[σQ; σQint_u; σQint_ym ].^2), :L)
184184
= Hermitian(diagm(NT[σR;].^2), :L)
185-
return SteadyKalmanFilter{NT, SM}(model, i_ym, nint_u, nint_ym, Q̂, R̂; direct)
185+
return SteadyKalmanFilter{NT}(model, i_ym, nint_u, nint_ym, Q̂, R̂; direct)
186186
end
187187

188188
@doc raw"""
@@ -196,7 +196,7 @@ function SteadyKalmanFilter(
196196
model::SM, i_ym, nint_u, nint_ym, Q̂, R̂; direct=true
197197
) where {NT<:Real, SM<:LinModel{NT}}
198198
Q̂, R̂ = to_mat(Q̂), to_mat(R̂)
199-
return SteadyKalmanFilter{NT, SM}(model, i_ym, nint_u, nint_ym, Q̂, R̂; direct)
199+
return SteadyKalmanFilter{NT}(model, i_ym, nint_u, nint_ym, Q̂, R̂; direct)
200200
end
201201

202202
"Throw an error if `setmodel!` is called on a SteadyKalmanFilter w/o the default values."
@@ -308,7 +308,7 @@ struct KalmanFilter{NT<:Real, SM<:LinModel} <: StateEstimator{NT}
308308
direct::Bool
309309
corrected::Vector{Bool}
310310
buffer::StateEstimatorBuffer{NT}
311-
function KalmanFilter{NT, SM}(
311+
function KalmanFilter{NT}(
312312
model::SM, i_ym, nint_u, nint_ym, P̂_0, Q̂, R̂; direct=true
313313
) where {NT<:Real, SM<:LinModel}
314314
nu, ny, nd = model.nu, model.ny, model.nd
@@ -419,7 +419,7 @@ function KalmanFilter(
419419
P̂_0 = Hermitian(diagm(NT[σP_0; σPint_u_0; σPint_ym_0].^2), :L)
420420
= Hermitian(diagm(NT[σQ; σQint_u; σQint_ym ].^2), :L)
421421
= Hermitian(diagm(NT[σR;].^2), :L)
422-
return KalmanFilter{NT, SM}(model, i_ym, nint_u, nint_ym, P̂_0, Q̂ , R̂; direct)
422+
return KalmanFilter{NT}(model, i_ym, nint_u, nint_ym, P̂_0, Q̂ , R̂; direct)
423423
end
424424

425425
@doc raw"""
@@ -433,7 +433,7 @@ function KalmanFilter(
433433
model::SM, i_ym, nint_u, nint_ym, P̂_0, Q̂, R̂; direct=true
434434
) where {NT<:Real, SM<:LinModel{NT}}
435435
P̂_0, Q̂, R̂ = to_mat(P̂_0), to_mat(Q̂), to_mat(R̂)
436-
return KalmanFilter{NT, SM}(model, i_ym, nint_u, nint_ym, P̂_0, Q̂, R̂; direct)
436+
return KalmanFilter{NT}(model, i_ym, nint_u, nint_ym, P̂_0, Q̂, R̂; direct)
437437
end
438438

439439
@doc raw"""
@@ -530,7 +530,7 @@ struct UnscentedKalmanFilter{NT<:Real, SM<:SimModel} <: StateEstimator{NT}
530530
direct::Bool
531531
corrected::Vector{Bool}
532532
buffer::StateEstimatorBuffer{NT}
533-
function UnscentedKalmanFilter{NT, SM}(
533+
function UnscentedKalmanFilter{NT}(
534534
model::SM, i_ym, nint_u, nint_ym, P̂_0, Q̂, R̂, α, β, κ; direct=true
535535
) where {NT<:Real, SM<:SimModel{NT}}
536536
nu, ny, nd = model.nu, model.ny, model.nd
@@ -681,7 +681,7 @@ function UnscentedKalmanFilter(
681681
P̂_0 = Hermitian(diagm(NT[σP_0; σPint_u_0; σPint_ym_0].^2), :L)
682682
= Hermitian(diagm(NT[σQ; σQint_u; σQint_ym ].^2), :L)
683683
= Hermitian(diagm(NT[σR;].^2), :L)
684-
return UnscentedKalmanFilter{NT, SM}(
684+
return UnscentedKalmanFilter{NT}(
685685
model, i_ym, nint_u, nint_ym, P̂_0, Q̂, R̂, α, β, κ; direct
686686
)
687687
end
@@ -699,7 +699,7 @@ function UnscentedKalmanFilter(
699699
model::SM, i_ym, nint_u, nint_ym, P̂_0, Q̂, R̂, α=1e-3, β=2, κ=0; direct=true
700700
) where {NT<:Real, SM<:SimModel{NT}}
701701
P̂_0, Q̂, R̂ = to_mat(P̂_0), to_mat(Q̂), to_mat(R̂)
702-
return UnscentedKalmanFilter{NT, SM}(
702+
return UnscentedKalmanFilter{NT}(
703703
model, i_ym, nint_u, nint_ym, P̂_0, Q̂ , R̂, α, β, κ; direct
704704
)
705705
end
@@ -905,7 +905,7 @@ struct ExtendedKalmanFilter{NT<:Real, SM<:SimModel} <: StateEstimator{NT}
905905
direct::Bool
906906
corrected::Vector{Bool}
907907
buffer::StateEstimatorBuffer{NT}
908-
function ExtendedKalmanFilter{NT, SM}(
908+
function ExtendedKalmanFilter{NT}(
909909
model::SM, i_ym, nint_u, nint_ym, P̂_0, Q̂, R̂; direct=true
910910
) where {NT<:Real, SM<:SimModel}
911911
nu, ny, nd = model.nu, model.ny, model.nd
@@ -994,7 +994,7 @@ function ExtendedKalmanFilter(
994994
P̂_0 = Hermitian(diagm(NT[σP_0; σPint_u_0; σPint_ym_0].^2), :L)
995995
= Hermitian(diagm(NT[σQ; σQint_u; σQint_ym ].^2), :L)
996996
= Hermitian(diagm(NT[σR;].^2), :L)
997-
return ExtendedKalmanFilter{NT, SM}(model, i_ym, nint_u, nint_ym, P̂_0, Q̂, R̂; direct)
997+
return ExtendedKalmanFilter{NT}(model, i_ym, nint_u, nint_ym, P̂_0, Q̂, R̂; direct)
998998
end
999999

10001000
@doc raw"""
@@ -1008,7 +1008,7 @@ function ExtendedKalmanFilter(
10081008
model::SM, i_ym, nint_u, nint_ym,P̂_0, Q̂, R̂; direct=true
10091009
) where {NT<:Real, SM<:SimModel{NT}}
10101010
P̂_0, Q̂, R̂ = to_mat(P̂_0), to_mat(Q̂), to_mat(R̂)
1011-
return ExtendedKalmanFilter{NT, SM}(model, i_ym, nint_u, nint_ym, P̂_0, Q̂, R̂; direct)
1011+
return ExtendedKalmanFilter{NT}(model, i_ym, nint_u, nint_ym, P̂_0, Q̂, R̂; direct)
10121012
end
10131013

10141014
"""

src/estimator/mhe/construct.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ struct MovingHorizonEstimator{
107107
direct::Bool
108108
corrected::Vector{Bool}
109109
buffer::StateEstimatorBuffer{NT}
110-
function MovingHorizonEstimator{NT, SM, JM, CE}(
110+
function MovingHorizonEstimator{NT}(
111111
model::SM, He, i_ym, nint_u, nint_ym, P̂_0, Q̂, R̂, Cwt, optim::JM, covestim::CE;
112112
direct=true
113113
) where {NT<:Real, SM<:SimModel{NT}, JM<:JuMP.GenericModel, CE<:StateEstimator{NT}}
@@ -406,7 +406,7 @@ function MovingHorizonEstimator(
406406
covestim::CE = default_covestim_mhe(model, i_ym, nint_u, nint_ym, P̂_0, Q̂, R̂; direct)
407407
) where {NT<:Real, SM<:SimModel{NT}, JM<:JuMP.GenericModel, CE<:StateEstimator{NT}}
408408
P̂_0, Q̂, R̂ = to_mat(P̂_0), to_mat(Q̂), to_mat(R̂)
409-
return MovingHorizonEstimator{NT, SM, JM, CE}(
409+
return MovingHorizonEstimator{NT}(
410410
model, He, i_ym, nint_u, nint_ym, P̂_0, Q̂ , R̂, Cwt, optim, covestim; direct
411411
)
412412
end

src/model/linearization.jl

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,31 @@
1+
2+
function jacobianA!(A, jb::JacobianBuffer, model::SimModel, x, u, d)
3+
jb.x .= x; jb.u .= u; jb.d .= d
4+
return ForwardDiff.jacobian!(A, jb.f_x!, jb.xnext, jb.x, jb.f_x_cfg)
5+
end
6+
jacobianA!( _ , _ , model::LinModel, _ , _ , _ ) = model.A
7+
function jacobianBu!(Bu, jb::JacobianBuffer, model::SimModel, x, u, d)
8+
jb.x .= x; jb.u .= u; jb.d .= d
9+
return ForwardDiff.jacobian!(Bu, jb.f_u!, jb.xnext, jb.u, jb.f_u_cfg)
10+
end
11+
jacobianBu!( _ , _ , model::LinModel, _ , _ , _ ) = model.Bu
12+
function jacobianBd!(Bd, jb::JacobianBuffer, model::SimModel, x, u, d)
13+
jb.x .= x; jb.u .= u; jb.d .= d
14+
return ForwardDiff.jacobian!(Bd, jb.f_d!, jb.xnext, jb.d, jb.f_d_cfg)
15+
end
16+
jacobianBd!( _ , _ , model::LinModel, _ , _ , _ ) = model.Bd
17+
function jacobianC!(C, jb::JacobianBuffer, model::SimModel, x, d)
18+
jb.x .= x; jb.d .= d
19+
return ForwardDiff.jacobian!(C, jb.h_x!, jb.y, jb.x, jb.h_x_cfg)
20+
end
21+
jacobianC!( _ , _ , model::LinModel, _ , _ ) = model.C
22+
function jacobianDd!(Dd, jb::JacobianBuffer, model::SimModel, x, d)
23+
jb.x .= x; jb.d .= d
24+
return ForwardDiff.jacobian!(Dd, jb.h_d!, jb.y, jb.d, jb.h_d_cfg)
25+
end
26+
jacobianDd!( _ , _ , model::LinModel, _ , _ ) = model.Dd
27+
28+
129
"""
230
LinModel(model::NonLinModel; x=model.x0+model.xop, u=model.uop, d=model.dop)
331
@@ -118,19 +146,25 @@ function linearize!(
118146
d0 .= d .- nonlinmodel.dop
119147
x0 .= x .- nonlinmodel.xop
120148
# --- compute the Jacobians at linearization points ---
121-
A::Matrix{NT}, Bu::Matrix{NT}, Bd::Matrix{NT} = linmodel.A, linmodel.Bu, linmodel.Bd
122-
C::Matrix{NT}, Dd::Matrix{NT} = linmodel.C, linmodel.Dd
149+
#A::Matrix{NT}, Bu::Matrix{NT}, Bd::Matrix{NT} = linmodel.A, linmodel.Bu, linmodel.Bd
150+
#C::Matrix{NT}, Dd::Matrix{NT} = linmodel.C, linmodel.Dd
123151
xnext0::Vector{NT}, y0::Vector{NT} = linmodel.buffer.x, linmodel.buffer.y
124-
myf_x0!(xnext0, x0) = f!(xnext0, nonlinmodel, x0, u0, d0, model.p)
125-
myf_u0!(xnext0, u0) = f!(xnext0, nonlinmodel, x0, u0, d0, model.p)
126-
myf_d0!(xnext0, d0) = f!(xnext0, nonlinmodel, x0, u0, d0, model.p)
127-
myh_x0!(y0, x0) = h!(y0, nonlinmodel, x0, d0, model.p)
128-
myh_d0!(y0, d0) = h!(y0, nonlinmodel, x0, d0, model.p)
129-
ForwardDiff.jacobian!(A, myf_x0!, xnext0, x0)
130-
ForwardDiff.jacobian!(Bu, myf_u0!, xnext0, u0)
131-
ForwardDiff.jacobian!(Bd, myf_d0!, xnext0, d0)
132-
ForwardDiff.jacobian!(C, myh_x0!, y0, x0)
133-
ForwardDiff.jacobian!(Dd, myh_d0!, y0, d0)
152+
#myf_x0!(xnext0, x0) = f!(xnext0, nonlinmodel, x0, u0, d0, model.p)
153+
#myf_u0!(xnext0, u0) = f!(xnext0, nonlinmodel, x0, u0, d0, model.p)
154+
#myf_d0!(xnext0, d0) = f!(xnext0, nonlinmodel, x0, u0, d0, model.p)
155+
#myh_x0!(y0, x0) = h!(y0, nonlinmodel, x0, d0, model.p)
156+
#myh_d0!(y0, d0) = h!(y0, nonlinmodel, x0, d0, model.p)
157+
#ForwardDiff.jacobian!(A, myf_x0!, xnext0, x0)
158+
#ForwardDiff.jacobian!(Bu, myf_u0!, xnext0, u0)
159+
#ForwardDiff.jacobian!(Bd, myf_d0!, xnext0, d0)
160+
#ForwardDiff.jacobian!(C, myh_x0!, y0, x0)
161+
#ForwardDiff.jacobian!(Dd, myh_d0!, y0, d0)
162+
jb = nonlinmodel.buffer.jacobian
163+
jacobianA!(linmodel.A, jb, nonlinmodel, x0, u0, d0)
164+
jacobianBu!(linmodel.Bu, jb, nonlinmodel, x0, u0, d0)
165+
jacobianBd!(linmodel.Bd, jb, nonlinmodel, x0, u0, d0)
166+
jacobianC!(linmodel.C, jb, nonlinmodel, x0, d0)
167+
jacobianDd!(linmodel.Dd, jb, nonlinmodel, x0, d0)
134168
# --- compute the nonlinear model output at operating points ---
135169
h!(y0, nonlinmodel, x0, d0, model.p)
136170
y = y0

src/model/linmodel.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ struct LinModel{NT<:Real} <: SimModel{NT}
2121
yname::Vector{String}
2222
dname::Vector{String}
2323
xname::Vector{String}
24-
buffer::SimModelBuffer{NT}
24+
buffer::SimModelBuffer{NT, Nothing}
2525
function LinModel{NT}(A, Bu, C, Bd, Dd, Ts) where {NT<:Real}
2626
A, Bu = to_mat(A, 1, 1), to_mat(Bu, 1, 1)
2727
nu, nx = size(Bu, 2), size(A, 2)

src/model/nonlinmodel.jl

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
struct NonLinModel{
2-
NT<:Real, F<:Function, H<:Function, P<:Any, DS<:DiffSolver
2+
NT<:Real, F<:Function, H<:Function, P<:Any, DS<:DiffSolver, SMB<:SimModelBuffer
33
} <: SimModel{NT}
44
x0::Vector{NT}
55
f!::F
@@ -21,10 +21,10 @@ struct NonLinModel{
2121
yname::Vector{String}
2222
dname::Vector{String}
2323
xname::Vector{String}
24-
buffer::SimModelBuffer{NT}
25-
function NonLinModel{NT, F, H, P, DS}(
26-
f!::F, h!::H, Ts, nu, nx, ny, nd, p::P, solver::DS
27-
) where {NT<:Real, F<:Function, H<:Function, P<:Any, DS<:DiffSolver}
24+
buffer::SMB
25+
function NonLinModel{NT}(
26+
f!::F, h!::H, Ts, nu, nx, ny, nd, p::P, solver::DS, buffer::SMB
27+
) where {NT<:Real, F<:Function, H<:Function, P<:Any, DS<:DiffSolver, SMB<:SimModelBuffer}
2828
Ts > 0 || error("Sampling time Ts must be positive")
2929
uop = zeros(NT, nu)
3030
yop = zeros(NT, ny)
@@ -37,8 +37,7 @@ struct NonLinModel{
3737
xname = ["\$x_{$i}\$" for i in 1:nx]
3838
x0 = zeros(NT, nx)
3939
t = zeros(NT, 1)
40-
buffer = SimModelBuffer{NT}(nu, nx, ny, nd)
41-
return new{NT, F, H, P, DS}(
40+
return new{NT, F, H, P, DS, SMB}(
4241
x0,
4342
f!, h!,
4443
p,
@@ -145,8 +144,9 @@ function NonLinModel{NT}(
145144
isnothing(solver) && (solver=EmptySolver())
146145
f!, h! = get_mutating_functions(NT, f, h)
147146
f!, h! = get_solver_functions(NT, solver, f!, h!, Ts, nu, nx, ny, nd)
148-
F, H, P, DS = get_types(f!, h!, p, solver)
149-
return NonLinModel{NT, F, H, P, DS}(f!, h!, Ts, nu, nx, ny, nd, p, solver)
147+
jacobian = JacobianBuffer{NT}(f!, h!, nu, nx, ny, nd, p)
148+
buffer = SimModelBuffer{NT}(nu, nx, ny, nd, jacobian)
149+
return NonLinModel{NT}(f!, h!, Ts, nu, nx, ny, nd, p, solver, buffer)
150150
end
151151

152152
function NonLinModel(
@@ -224,13 +224,6 @@ function validate_h(NT, h)
224224
return ismutating
225225
end
226226

227-
"Get the types of `f!`, `h!` and `solver` to construct a `NonLinModel`."
228-
function get_types(
229-
::F, ::H, ::P, ::DS
230-
) where {F<:Function, H<:Function, P<:Any, DS<:DiffSolver}
231-
return F, H, P, DS
232-
end
233-
234227
"Do nothing if `model` is a [`NonLinModel`](@ref)."
235228
steadystate!(::SimModel, _ , _ ) = nothing
236229

0 commit comments

Comments
 (0)