Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqSDIRK/src/alg_utils.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
alg_extrapolates(alg::ImplicitEuler) = true
alg_extrapolates(alg::Trapezoid) = true
alg_extrapolates(alg::SDIRK22) = true
alg_extrapolates(alg::SDIRK22) = false

alg_order(alg::Trapezoid) = 2
alg_order(alg::ImplicitEuler) = 1
Expand Down
6 changes: 4 additions & 2 deletions lib/OrdinaryDiffEqSDIRK/src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ struct SDIRK22{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <:
linsolve::F
nlsolve::F2
precs::P
smooth_est::Bool
extrapolant::Symbol
controller::Symbol
step_limiter!::StepLimiter
Expand All @@ -342,15 +343,16 @@ function SDIRK22(;
chunk_size = Val{0}(), autodiff = AutoForwardDiff(), standardtag = Val{true}(),
concrete_jac = nothing, diff_type = Val{:forward}(),
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(),
extrapolant = :linear,
smooth_est = false, extrapolant = :linear,
controller = :PI, step_limiter! = trivial_limiter!)
AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type)

Trapezoid{_unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve),
SDIRK22{_unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve),
typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag),
_unwrap_val(concrete_jac), typeof(step_limiter!)}(linsolve,
nlsolve,
precs,
smooth_est,
extrapolant,
controller,
step_limiter!,
Expand Down
33 changes: 17 additions & 16 deletions lib/OrdinaryDiffEqSDIRK/src/sdirk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,7 @@ function alg_cache(alg::SDIRK2, u, rate_prototype, ::Type{uEltypeNoUnits},
SDIRK2Cache(u, uprev, fsalfirst, z₁, z₂, atmp, nlsolver, alg.step_limiter!)
end

struct SDIRK22ConstantCache{uType, tType, N, Tab} <: SDIRKConstantCache
uprev3::uType
tprev2::tType
mutable struct SDIRK22ConstantCache{N, Tab} <: SDIRKConstantCache
nlsolver::N
tab::Tab
end
Expand All @@ -238,26 +236,29 @@ function alg_cache(alg::SDIRK22, u, rate_prototype, ::Type{uEltypeNoUnits},
uprev, uprev2, f, t, dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
tab = SDIRK22Tableau(constvalue(uBottomEltypeNoUnits))
uprev3 = u
tprev2 = t
γ, c = 1, 1

# Want to solve nonlinear problems of the from
# z = dt ⋅ f(tmp + γ ⋅ z, p, t + c ⋅ dt)
#
# 1st stage of SDIRK22:
# z = dt ⋅ f(u + γ ⋅ z, p, t + γ ⋅ dt)
γ = tab.γ
c = γ

nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits,
uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false))

SDIRK22ConstantCache(uprev3, tprev2, nlsolver)
SDIRK22ConstantCache(nlsolver, tab)
end

@cache mutable struct SDIRK22Cache{
uType, rateType, uNoUnitsType, tType, N, Tab, StepLimiter} <:
uType, rateType, uNoUnitsType, N, Tab, StepLimiter} <:
SDIRKMutableCache
u::uType
uprev::uType
uprev2::uType
fsalfirst::rateType
z1::uType
atmp::uNoUnitsType
uprev3::uType
tprev2::tType
nlsolver::N
tab::Tab
step_limiter!::StepLimiter
Expand All @@ -268,18 +269,18 @@ function alg_cache(alg::SDIRK22, u, rate_prototype, ::Type{uEltypeNoUnits},
dt, reltol, p, calck,
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
tab = SDIRK22Tableau(constvalue(uBottomEltypeNoUnits))
γ, c = 1, 1
γ = tab.γ
c = γ

nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits,
uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true))
fsalfirst = zero(rate_prototype)

uprev3 = zero(u)
tprev2 = t
atmp = similar(u, uEltypeNoUnits)
recursivefill!(atmp, false)
z1 = zero(u)

SDIRK22Cache(
u, uprev, uprev2, fsalfirst, atmp, uprev3, tprev2, nlsolver, tab, alg.step_limiter!) # shouldn't this be SDIRK22Cache instead of SDIRK22?
SDIRK22Cache(u, uprev, fsalfirst, z1, atmp, nlsolver, tab, alg.step_limiter!)
end

mutable struct SSPSDIRK2ConstantCache{N} <: SDIRKConstantCache
Expand Down
168 changes: 62 additions & 106 deletions lib/OrdinaryDiffEqSDIRK/src/sdirk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -635,152 +635,108 @@ end

@muladd function perform_step!(integrator, cache::SDIRK22ConstantCache, repeat_step = false)
@unpack t, dt, uprev, u, f, p = integrator
@unpack a, α, β = cache.tab
@unpack γ, a21, bhat1, bhat2, btilde1, btilde2 = cache.tab
nlsolver = cache.nlsolver
alg = unwrap_alg(integrator, true)

# precalculations
γ = a * dt
γdt = γ * dt
markfirststage!(nlsolver)

# initial guess
zprev = dt * integrator.fsalfirst
nlsolver.z = zprev
# Want to solve nonlinear problems of the from
# z = dt ⋅ f(tmp + γ ⋅ z, p, t + c ⋅ dt)
# 1st stage of SDIRK22:
# z1 = dt ⋅ f(uprev + γ ⋅ z1, p, t + γ ⋅ dt)
nlsolver.c = γ
nlsolver.tmp = uprev
# The same nlsolver.γ is used in all stages

# first stage
nlsolver.tmp = uprev + γdt * integrator.fsalfirst
z = nlsolve!(nlsolver, integrator, cache, repeat_step)
# Initial guess (FSAL)
zprev = dt * integrator.fsalfirst
nlsolver.z = zprev

z1 = nlsolve!(nlsolver, integrator, cache, repeat_step)
nlsolvefail(nlsolver) && return
uprev = α * nlsolver.tmp + β * z

# final stage
γ = dt
γdt = γ * dt
markfirststage!(nlsolver)
nlsolver.tmp = uprev + γdt * integrator.fsalfirst
# 2nd stage of SDIRK22:
# z = dt ⋅ f(uprev + a21 * z1 + γ ⋅ z, p, t + dt)
nlsolver.c = 1
nlsolver.tmp = uprev + a21 * z1

# Initial guess
nlsolver.z = z1

z = nlsolve!(nlsolver, integrator, cache, repeat_step)
nlsolvefail(nlsolver) && return
u = nlsolver.tmp

if integrator.opts.adaptive
if integrator.iter > 2
# local truncation error (LTE) bound by dt^3/12*max|y'''(t)|
# use 3rd divided differences (DD) a la SPICE and Shampine

# TODO: check numerical stability
uprev2 = integrator.uprev2
tprev = integrator.tprev
uprev3 = cache.uprev3
tprev2 = cache.tprev2
# u = uprev + (1-γ) * z1 + γ * z
u = nlsolver.tmp + γ * z

dt1 = dt * (t + dt - tprev)
dt2 = (t - tprev) * (t + dt - tprev)
dt3 = (t - tprev) * (t - tprev2)
dt4 = (tprev - tprev2) * (t - tprev2)
dt5 = t + dt - tprev2
c = 7 / 12 # default correction factor in SPICE (LTE overestimated by DD)
r = c * dt^3 / 2 # by mean value theorem 3rd DD equals y'''(s)/6 for some s
################################### Finalize

DD31 = (u - uprev) / dt1 - (uprev - uprev2) / dt2
DD30 = (uprev - uprev2) / dt3 - (uprev2 - uprev3) / dt4
tmp = r * abs((DD31 - DD30) / dt5)
atmp = calculate_residuals(tmp, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm,
t)
integrator.EEst = integrator.opts.internalnorm(atmp, t)
if integrator.EEst <= 1
cache.uprev3 = uprev2
cache.tprev2 = tprev
end
elseif integrator.success_iter > 0
integrator.EEst = 1
cache.uprev3 = integrator.uprev2
cache.tprev2 = integrator.tprev
if integrator.opts.adaptive
tmp = btilde1 * z1 + btilde2 * z
if isnewton(nlsolver) && alg.smooth_est # From Shampine
integrator.stats.nsolve += 1
est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp))
else
integrator.EEst = 1
est = tmp
end
atmp = calculate_residuals(est, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
integrator.EEst = integrator.opts.internalnorm(atmp, t)
end

OrdinaryDiffEqCore.increment_nf!(integrator.stats, 2)
integrator.fsallast = z ./ dt
integrator.k[1] = integrator.fsalfirst
integrator.k[2] = integrator.fsallast
integrator.u = u
end

@muladd function perform_step!(integrator, cache::SDIRK22Cache, repeat_step = false)
@unpack t, dt, uprev, u, f, p = integrator
@unpack atmp, nlsolver, step_limiter! = cache
@unpack z1, atmp, nlsolver, step_limiter! = cache
@unpack z, tmp = nlsolver
@unpack a, α, β = cache.tab
W = isnewton(nlsolver) ? get_W(nlsolver) : nothing
b = nlsolver.ztmp
@unpack γ, a21, bhat1, bhat2, btilde1, btilde2 = cache.tab
alg = unwrap_alg(integrator, true)
mass_matrix = integrator.f.mass_matrix

# precalculations
γ = a * dt
γdt = γ * dt
markfirststage!(nlsolver)

# first stage
# See in-place version for details
@.. broadcast=false z=dt * integrator.fsalfirst
@.. broadcast=false tmp=uprev + γdt * integrator.fsalfirst
z = nlsolve!(nlsolver, integrator, cache, repeat_step)
@.. broadcast=false tmp=uprev
nlsolver.c = γ
z1 .= nlsolve!(nlsolver, integrator, cache, repeat_step)
nlsolvefail(nlsolver) && return
@.. broadcast=false u=α * tmp + β * z

# final stage
γ = dt
γdt = γ * dt
markfirststage!(nlsolver)
@.. broadcast=false tmp=uprev + γdt * integrator.fsalfirst
z = nlsolve!(nlsolver, integrator, cache, repeat_step)
@.. broadcast=false z=z1
@.. broadcast=false tmp=uprev + a21 * z
nlsolver.c = 1
isnewton(nlsolver) && set_new_W!(nlsolver, false)
z .= nlsolve!(nlsolver, integrator, cache, repeat_step)
nlsolvefail(nlsolver) && return
@.. broadcast=false u=nlsolver.tmp

step_limiter!(u, integrator, p, t + dt)
@.. broadcast=false u=tmp + γ * z

if integrator.opts.adaptive
if integrator.iter > 2
# local truncation error (LTE) bound by dt^3/12*max|y'''(t)|
# use 3rd divided differences (DD) a la SPICE and Shampine
step_limiter!(u, integrator, p, t + dt)

# TODO: check numerical stability
uprev2 = integrator.uprev2
tprev = integrator.tprev
uprev3 = cache.uprev3
tprev2 = cache.tprev2
################################### Finalize

dt1 = dt * (t + dt - tprev)
dt2 = (t - tprev) * (t + dt - tprev)
dt3 = (t - tprev) * (t - tprev2)
dt4 = (tprev - tprev2) * (t - tprev2)
dt5 = t + dt - tprev2
c = 7 / 12 # default correction factor in SPICE (LTE overestimated by DD)
r = c * dt^3 / 2 # by mean value theorem 3rd DD equals y'''(s)/6 for some s
if integrator.opts.adaptive
@.. broadcast=false tmp=btilde1 * z1 + btilde2 * z
if alg.smooth_est && isnewton(nlsolver) # From Shampine
est = nlsolver.cache.dz
linres = dolinsolve(integrator, nlsolver.cache.linsolve; b = _vec(tmp),
linu = _vec(est))

@inbounds for i in eachindex(u)
DD31 = (u[i] - uprev[i]) / dt1 - (uprev[i] - uprev2[i]) / dt2
DD30 = (uprev[i] - uprev2[i]) / dt3 - (uprev2[i] - uprev3[i]) / dt4
tmp[i] = r * abs((DD31 - DD30) / dt5)
end
calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
integrator.EEst = integrator.opts.internalnorm(atmp, t)
if integrator.EEst <= 1
copyto!(cache.uprev3, uprev2)
cache.tprev2 = tprev
end
elseif integrator.success_iter > 0
integrator.EEst = 1
copyto!(cache.uprev3, integrator.uprev2)
cache.tprev2 = integrator.tprev
integrator.stats.nsolve += 1
else
integrator.EEst = 1
est = tmp
end
calculate_residuals!(atmp, est, uprev, u, integrator.opts.abstol,
integrator.opts.reltol, integrator.opts.internalnorm, t)
integrator.EEst = integrator.opts.internalnorm(atmp, t)
end

OrdinaryDiffEqCore.increment_nf!(integrator.stats, 2)
f(integrator.fsallast, u, p, t + dt)
@.. broadcast=false integrator.fsallast=z / dt
end

@muladd function perform_step!(integrator, cache::SSPSDIRK2ConstantCache,
Expand Down
37 changes: 30 additions & 7 deletions lib/OrdinaryDiffEqSDIRK/src/sdirk_tableaus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2210,16 +2210,39 @@ function ESDIRK659L2SATableau(T, T2)
end

struct SDIRK22Tableau{T}
a::T
α::T
β::T
γ::T
a21::T
bhat1::T
bhat2::T
btilde1::T
btilde2::T
end
#=
Tableau:
γ=1-1/√2

c = [γ; 1]
A = [ γ 0
1-γ γ]
b = [1-γ; γ]

Embedded scheme: 1st order, A- and L-stable
bhat = [1-γ+γ^2; γ-γ^2]

Error estimation:
btilde = bhat-b = [btilde1; btilde2] = [γ/(1+γ) - 1 + γ; 1/(1+γ) - γ].
=#
function SDIRK22Tableau(T)
a = convert(T, 1 - 1 / sqrt(2))
α = convert(T, -sqrt(2))
β = convert(T, 1 + sqrt(2))
SDIRK22Tableau(a, α, β)
γ = convert(T, 1 - 1 / sqrt(2))
a21 = convert(T, 1 - γ)
#bhat1 = convert(T, γ / (1 + γ))
#bhat2 = convert(T, 1 / (1 + γ))
bhat1 = convert(T, 1 - γ + γ^2)
bhat2 = convert(T, γ - γ^2)
btilde1 = convert(T, bhat1 - 1 + γ)
btilde2 = convert(T, bhat2 - γ)

SDIRK22Tableau(γ, a21, bhat1, bhat2, btilde1, btilde2)
end

struct KenCarp47Tableau{T, T2}
Expand Down