Skip to content

Commit 90fb542

Browse files
Refactored Rosenbrock32 and Rosenbrock23
1 parent f99cca7 commit 90fb542

File tree

5 files changed

+86
-274
lines changed

5 files changed

+86
-274
lines changed

lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@ function SciMLBase.interp_summary(::Type{cacheType},
33
cacheType <:
44
Union{Rosenbrock23ConstantCache,
55
Rosenbrock32ConstantCache,
6-
Rosenbrock23Cache,
7-
Rosenbrock32Cache}}
6+
RosenbrockCombinedCache}}
87
dense ? "specialized 2nd order \"free\" stiffness-aware interpolation" :
98
"1st order linear"
109
end

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl

Lines changed: 8 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ struct RosenbrockCombinedConstantCache{TF, UF, Tab, JType, WType, F, AD} <:
6464
interp_order::Int
6565
end
6666

67-
@cache mutable struct Rosenbrock23Cache{uType, rateType, uNoUnitsType, JType, WType,
67+
@cache mutable struct RosenbrockCombinedCache{uType, rateType, uNoUnitsType, JType, WType,
6868
TabType, TFType, UFType, F, JCType, GCType,
6969
RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
7070
u::uType
@@ -97,44 +97,7 @@ end
9797
stage_limiter!::StageLimiter
9898
end
9999

100-
@cache mutable struct Rosenbrock32Cache{uType, rateType, uNoUnitsType, JType, WType,
101-
TabType, TFType, UFType, F, JCType, GCType,
102-
RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
103-
u::uType
104-
uprev::uType
105-
k₁::rateType
106-
k₂::rateType
107-
k₃::rateType
108-
du1::rateType
109-
du2::rateType
110-
f₁::rateType
111-
fsalfirst::rateType
112-
fsallast::rateType
113-
dT::rateType
114-
J::JType
115-
W::WType
116-
tmp::rateType
117-
atmp::uNoUnitsType
118-
weight::uNoUnitsType
119-
tab::TabType
120-
tf::TFType
121-
uf::UFType
122-
linsolve_tmp::rateType
123-
linsolve::F
124-
jac_config::JCType
125-
grad_config::GCType
126-
reltol::RTolType
127-
alg::A
128-
algebraic_vars::AV
129-
step_limiter!::StepLimiter
130-
stage_limiter!::StageLimiter
131-
end
132-
133-
function get_fsalfirstlast(cache::Union{Rosenbrock23Cache, Rosenbrock32Cache}, u)
134-
(cache.fsalfirst, cache.fsallast)
135-
end
136-
137-
function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
100+
function alg_cache(alg::Union{Rosenbrock23, Rosenbrock32}, u, rate_prototype, ::Type{uEltypeNoUnits},
138101
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
139102
dt, reltol, p, calck,
140103
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
@@ -153,7 +116,7 @@ function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
153116
recursivefill!(atmp, false)
154117
weight = similar(u, uEltypeNoUnits)
155118
recursivefill!(weight, false)
156-
tab = Rosenbrock23Tableau(constvalue(uBottomEltypeNoUnits))
119+
tab = RosenbrockCombinedTableau(constvalue(uBottomEltypeNoUnits))
157120
tf = TimeGradientWrapper(f, uprev, p)
158121
uf = UJacobianWrapper(f, t, p)
159122
linsolve_tmp = zero(rate_prototype)
@@ -176,61 +139,13 @@ function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
176139
algebraic_vars = f.mass_matrix === I ? nothing :
177140
[all(iszero, x) for x in eachcol(f.mass_matrix)]
178141

179-
Rosenbrock23Cache(u, uprev, k₁, k₂, k₃, du1, du2, f₁,
142+
RosenbrockCombinedCache(u, uprev, k₁, k₂, k₃, du1, du2, f₁,
180143
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
181144
linsolve_tmp,
182145
linsolve, jac_config, grad_config, reltol, alg, algebraic_vars, alg.step_limiter!,
183146
alg.stage_limiter!)
184147
end
185148

186-
function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},
187-
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
188-
dt, reltol, p, calck,
189-
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
190-
k₁ = zero(rate_prototype)
191-
k₂ = zero(rate_prototype)
192-
k₃ = zero(rate_prototype)
193-
du1 = zero(rate_prototype)
194-
du2 = zero(rate_prototype)
195-
# f₀ = zero(u) fsalfirst
196-
f₁ = zero(rate_prototype)
197-
fsalfirst = zero(rate_prototype)
198-
fsallast = zero(rate_prototype)
199-
dT = zero(rate_prototype)
200-
tmp = zero(rate_prototype)
201-
atmp = similar(u, uEltypeNoUnits)
202-
recursivefill!(atmp, false)
203-
weight = similar(u, uEltypeNoUnits)
204-
recursivefill!(weight, false)
205-
tab = Rosenbrock32Tableau(constvalue(uBottomEltypeNoUnits))
206-
207-
tf = TimeGradientWrapper(f, uprev, p)
208-
uf = UJacobianWrapper(f, t, p)
209-
linsolve_tmp = zero(rate_prototype)
210-
211-
grad_config = build_grad_config(alg, f, tf, du1, t)
212-
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
213-
214-
J, W = build_J_W(alg, u, uprev, p, t, dt, f, jac_config, uEltypeNoUnits, Val(true))
215-
216-
linprob = LinearProblem(W, _vec(linsolve_tmp); u0 = _vec(tmp))
217-
218-
Pl, Pr = wrapprecs(
219-
alg.precs(W, nothing, u, p, t, nothing, nothing, nothing,
220-
nothing)..., weight, tmp)
221-
linsolve = init(
222-
linprob, alg.linsolve, alias = LinearAliasSpecifier(alias_A = true, alias_b = true),
223-
Pl = Pl, Pr = Pr,
224-
assumptions = LinearSolve.OperatorAssumptions(true))
225-
226-
algebraic_vars = f.mass_matrix === I ? nothing :
227-
[all(iszero, x) for x in eachcol(f.mass_matrix)]
228-
229-
Rosenbrock32Cache(u, uprev, k₁, k₂, k₃, du1, du2, f₁, fsalfirst, fsallast, dT, J, W,
230-
tmp, atmp, weight, tab, tf, uf, linsolve_tmp, linsolve, jac_config,
231-
grad_config, reltol, alg, algebraic_vars, alg.step_limiter!, alg.stage_limiter!)
232-
end
233-
234149
struct Rosenbrock23ConstantCache{T, TF, UF, JType, WType, F, AD} <:
235150
RosenbrockConstantCache
236151
c₃₂::T
@@ -244,7 +159,7 @@ struct Rosenbrock23ConstantCache{T, TF, UF, JType, WType, F, AD} <:
244159
end
245160

246161
function Rosenbrock23ConstantCache(::Type{T}, tf, uf, J, W, linsolve, autodiff) where {T}
247-
tab = Rosenbrock23Tableau(T)
162+
tab = RosenbrockCombinedTableau(T)
248163
Rosenbrock23ConstantCache(tab.c₃₂, tab.d, tf, uf, J, W, linsolve, autodiff)
249164
end
250165

@@ -274,7 +189,7 @@ struct Rosenbrock32ConstantCache{T, TF, UF, JType, WType, F, AD} <:
274189
end
275190

276191
function Rosenbrock32ConstantCache(::Type{T}, tf, uf, J, W, linsolve, autodiff) where {T}
277-
tab = Rosenbrock32Tableau(T)
192+
tab = RosenbrockCombinedTableau(T)
278193
Rosenbrock32ConstantCache(tab.c₃₂, tab.d, tf, uf, J, W, linsolve, autodiff)
279194
end
280195

@@ -837,7 +752,8 @@ function alg_cache(
837752
end
838753

839754
function get_fsalfirstlast(
840-
cache::Union{RosenbrockCache,
755+
cache::Union{RosenbrockCombinedCache, Rosenbrock33Cache,
756+
Rosenbrock34Cache,
841757
Rosenbrock4Cache},
842758
u)
843759
(cache.fsalfirst, cache.fsallast)

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_interpolants.jl

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
### Fallbacks to capture
2-
ROSENBROCKS_WITH_INTERPOLATIONS = Union{Rosenbrock23ConstantCache, Rosenbrock23Cache,
3-
Rosenbrock32ConstantCache, Rosenbrock32Cache,
2+
ROSENBROCKS_WITH_INTERPOLATIONS = Union{Rosenbrock23ConstantCache, RosenbrockCombinedCache,
3+
Rosenbrock32ConstantCache,
4+
Rodas23WConstantCache, Rodas3PConstantCache,
5+
Rodas23WCache, Rodas3PCache,
46
RosenbrockCombinedConstantCache,
57
RosenbrockCache}
68

@@ -42,24 +44,24 @@ end
4244
end
4345

4446
@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
45-
cache::Union{Rosenbrock23Cache, Rosenbrock32Cache},
47+
cache::RosenbrockCombinedCache,
4648
idxs::Nothing, T::Type{Val{0}}, differential_vars)
4749
@rosenbrock2332pre0
4850
@inbounds @.. y₀ + dt * (c1 * k[1] + c2 * k[2])
4951
end
5052

5153
@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
52-
cache::Union{Rosenbrock23ConstantCache, Rosenbrock23Cache,
53-
Rosenbrock32ConstantCache, Rosenbrock32Cache
54+
cache::Union{Rosenbrock23ConstantCache, RosenbrockCombinedCache,
55+
Rosenbrock32ConstantCache
5456
}, idxs, T::Type{Val{0}}, differential_vars)
5557
@rosenbrock2332pre0
5658
@.. y₀[idxs] + dt * (c1 * k[1][idxs] + c2 * k[2][idxs])
5759
end
5860

5961
@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
6062
cache::Union{Rosenbrock23ConstantCache,
61-
Rosenbrock23Cache,
62-
Rosenbrock32ConstantCache, Rosenbrock32Cache
63+
RosenbrockCombinedCache,
64+
Rosenbrock32ConstantCache
6365
}, idxs::Nothing, T::Type{Val{0}}, differential_vars)
6466
@rosenbrock2332pre0
6567
@inbounds @.. out = y₀ + dt * (c1 * k[1] + c2 * k[2])
@@ -68,8 +70,8 @@ end
6870

6971
@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
7072
cache::Union{Rosenbrock23ConstantCache,
71-
Rosenbrock23Cache,
72-
Rosenbrock32ConstantCache, Rosenbrock32Cache
73+
RosenbrockCombinedCache,
74+
Rosenbrock32ConstantCache
7375
}, idxs, T::Type{Val{0}}, differential_vars)
7476
@rosenbrock2332pre0
7577
@views @.. out = y₀[idxs] + dt * (c1 * k[1][idxs] + c2 * k[2][idxs])
@@ -84,25 +86,25 @@ end
8486
end
8587

8688
@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
87-
cache::Union{Rosenbrock23ConstantCache, Rosenbrock23Cache,
88-
Rosenbrock32ConstantCache, Rosenbrock32Cache
89+
cache::Union{Rosenbrock23ConstantCache, RosenbrockCombinedCache,
90+
Rosenbrock32ConstantCache
8991
}, idxs::Nothing, T::Type{Val{1}}, differential_vars)
9092
@rosenbrock2332pre1
9193
@.. c1diff * k[1] + c2diff * k[2]
9294
end
9395

9496
@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
95-
cache::Union{Rosenbrock23ConstantCache, Rosenbrock23Cache,
96-
Rosenbrock32ConstantCache, Rosenbrock32Cache
97+
cache::Union{Rosenbrock23ConstantCache, RosenbrockCombinedCache,
98+
Rosenbrock32ConstantCache
9799
}, idxs, T::Type{Val{1}}, differential_vars)
98100
@rosenbrock2332pre1
99101
@.. c1diff * k[1][idxs] + c2diff * k[2][idxs]
100102
end
101103

102104
@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
103105
cache::Union{Rosenbrock23ConstantCache,
104-
Rosenbrock23Cache,
105-
Rosenbrock32ConstantCache, Rosenbrock32Cache
106+
RosenbrockCombinedCache,
107+
Rosenbrock32ConstantCache
106108
}, idxs::Nothing, T::Type{Val{1}}, differential_vars)
107109
@rosenbrock2332pre1
108110
@.. out = c1diff * k[1] + c2diff * k[2]
@@ -111,8 +113,8 @@ end
111113

112114
@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
113115
cache::Union{Rosenbrock23ConstantCache,
114-
Rosenbrock23Cache,
115-
Rosenbrock32ConstantCache, Rosenbrock32Cache
116+
RosenbrockCombinedCache,
117+
Rosenbrock32ConstantCache
116118
}, idxs, T::Type{Val{1}}, differential_vars)
117119
@rosenbrock2332pre1
118120
@views @.. out = c1diff * k[1][idxs] + c2diff * k[2][idxs]

0 commit comments

Comments
 (0)