From 94a13c9a6523c741518d40015f7957758a793df7 Mon Sep 17 00:00:00 2001 From: Stefan Kopecz Date: Mon, 29 Sep 2025 17:45:58 +0200 Subject: [PATCH] implemented SDIRK22 --- lib/OrdinaryDiffEqSDIRK/src/alg_utils.jl | 2 +- lib/OrdinaryDiffEqSDIRK/src/algorithms.jl | 6 +- lib/OrdinaryDiffEqSDIRK/src/sdirk_caches.jl | 33 ++-- .../src/sdirk_perform_step.jl | 168 +++++++----------- lib/OrdinaryDiffEqSDIRK/src/sdirk_tableaus.jl | 37 +++- 5 files changed, 114 insertions(+), 132 deletions(-) diff --git a/lib/OrdinaryDiffEqSDIRK/src/alg_utils.jl b/lib/OrdinaryDiffEqSDIRK/src/alg_utils.jl index a4ac5fee51..754fdb1485 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/alg_utils.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/alg_utils.jl @@ -1,6 +1,6 @@ alg_extrapolates(alg::ImplicitEuler) = true alg_extrapolates(alg::Trapezoid) = true -alg_extrapolates(alg::SDIRK22) = true +alg_extrapolates(alg::SDIRK22) = false alg_order(alg::Trapezoid) = 2 alg_order(alg::ImplicitEuler) = 1 diff --git a/lib/OrdinaryDiffEqSDIRK/src/algorithms.jl b/lib/OrdinaryDiffEqSDIRK/src/algorithms.jl index e1f74a4d57..98447237fb 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/algorithms.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/algorithms.jl @@ -332,6 +332,7 @@ struct SDIRK22{CS, AD, F, F2, P, FDT, ST, CJ, StepLimiter} <: linsolve::F nlsolve::F2 precs::P + smooth_est::Bool extrapolant::Symbol controller::Symbol step_limiter!::StepLimiter @@ -342,15 +343,16 @@ function SDIRK22(; chunk_size = Val{0}(), autodiff = AutoForwardDiff(), standardtag = Val{true}(), concrete_jac = nothing, diff_type = Val{:forward}(), linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), - extrapolant = :linear, + smooth_est = false, extrapolant = :linear, controller = :PI, step_limiter! = trivial_limiter!) AD_choice, chunk_size, diff_type = _process_AD_choice(autodiff, chunk_size, diff_type) - Trapezoid{_unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve), + SDIRK22{_unwrap_val(chunk_size), typeof(AD_choice), typeof(linsolve), typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac), typeof(step_limiter!)}(linsolve, nlsolve, precs, + smooth_est, extrapolant, controller, step_limiter!, diff --git a/lib/OrdinaryDiffEqSDIRK/src/sdirk_caches.jl b/lib/OrdinaryDiffEqSDIRK/src/sdirk_caches.jl index 75a6453fcf..1cccfa9953 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/sdirk_caches.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/sdirk_caches.jl @@ -226,9 +226,7 @@ function alg_cache(alg::SDIRK2, u, rate_prototype, ::Type{uEltypeNoUnits}, SDIRK2Cache(u, uprev, fsalfirst, z₁, z₂, atmp, nlsolver, alg.step_limiter!) end -struct SDIRK22ConstantCache{uType, tType, N, Tab} <: SDIRKConstantCache - uprev3::uType - tprev2::tType +mutable struct SDIRK22ConstantCache{N, Tab} <: SDIRKConstantCache nlsolver::N tab::Tab end @@ -238,26 +236,29 @@ function alg_cache(alg::SDIRK22, u, rate_prototype, ::Type{uEltypeNoUnits}, uprev, uprev2, f, t, dt, reltol, p, calck, ::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} tab = SDIRK22Tableau(constvalue(uBottomEltypeNoUnits)) - uprev3 = u - tprev2 = t - γ, c = 1, 1 + + # Want to solve nonlinear problems of the from + # z = dt ⋅ f(tmp + γ ⋅ z, p, t + c ⋅ dt) + # + # 1st stage of SDIRK22: + # z = dt ⋅ f(u + γ ⋅ z, p, t + γ ⋅ dt) + γ = tab.γ + c = γ nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(false)) - SDIRK22ConstantCache(uprev3, tprev2, nlsolver) + SDIRK22ConstantCache(nlsolver, tab) end @cache mutable struct SDIRK22Cache{ - uType, rateType, uNoUnitsType, tType, N, Tab, StepLimiter} <: + uType, rateType, uNoUnitsType, N, Tab, StepLimiter} <: SDIRKMutableCache u::uType uprev::uType - uprev2::uType fsalfirst::rateType + z1::uType atmp::uNoUnitsType - uprev3::uType - tprev2::tType nlsolver::N tab::Tab step_limiter!::StepLimiter @@ -268,18 +269,18 @@ function alg_cache(alg::SDIRK22, u, rate_prototype, ::Type{uEltypeNoUnits}, dt, reltol, p, calck, ::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits} tab = SDIRK22Tableau(constvalue(uBottomEltypeNoUnits)) - γ, c = 1, 1 + γ = tab.γ + c = γ + nlsolver = build_nlsolver(alg, u, uprev, p, t, dt, f, rate_prototype, uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits, γ, c, Val(true)) fsalfirst = zero(rate_prototype) - uprev3 = zero(u) - tprev2 = t atmp = similar(u, uEltypeNoUnits) recursivefill!(atmp, false) + z1 = zero(u) - SDIRK22Cache( - u, uprev, uprev2, fsalfirst, atmp, uprev3, tprev2, nlsolver, tab, alg.step_limiter!) # shouldn't this be SDIRK22Cache instead of SDIRK22? + SDIRK22Cache(u, uprev, fsalfirst, z1, atmp, nlsolver, tab, alg.step_limiter!) end mutable struct SSPSDIRK2ConstantCache{N} <: SDIRKConstantCache diff --git a/lib/OrdinaryDiffEqSDIRK/src/sdirk_perform_step.jl b/lib/OrdinaryDiffEqSDIRK/src/sdirk_perform_step.jl index a0f33020dc..5815c1e3b4 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/sdirk_perform_step.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/sdirk_perform_step.jl @@ -635,74 +635,56 @@ end @muladd function perform_step!(integrator, cache::SDIRK22ConstantCache, repeat_step = false) @unpack t, dt, uprev, u, f, p = integrator - @unpack a, α, β = cache.tab + @unpack γ, a21, bhat1, bhat2, btilde1, btilde2 = cache.tab nlsolver = cache.nlsolver alg = unwrap_alg(integrator, true) - - # precalculations - γ = a * dt - γdt = γ * dt markfirststage!(nlsolver) - # initial guess - zprev = dt * integrator.fsalfirst - nlsolver.z = zprev + # Want to solve nonlinear problems of the from + # z = dt ⋅ f(tmp + γ ⋅ z, p, t + c ⋅ dt) + # 1st stage of SDIRK22: + # z1 = dt ⋅ f(uprev + γ ⋅ z1, p, t + γ ⋅ dt) + nlsolver.c = γ + nlsolver.tmp = uprev + # The same nlsolver.γ is used in all stages - # first stage - nlsolver.tmp = uprev + γdt * integrator.fsalfirst - z = nlsolve!(nlsolver, integrator, cache, repeat_step) + # Initial guess (FSAL) + zprev = dt * integrator.fsalfirst + nlsolver.z = zprev + + z1 = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return - uprev = α * nlsolver.tmp + β * z - # final stage - γ = dt - γdt = γ * dt - markfirststage!(nlsolver) - nlsolver.tmp = uprev + γdt * integrator.fsalfirst + # 2nd stage of SDIRK22: + # z = dt ⋅ f(uprev + a21 * z1 + γ ⋅ z, p, t + dt) + nlsolver.c = 1 + nlsolver.tmp = uprev + a21 * z1 + + # Initial guess + nlsolver.z = z1 + z = nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return - u = nlsolver.tmp - if integrator.opts.adaptive - if integrator.iter > 2 - # local truncation error (LTE) bound by dt^3/12*max|y'''(t)| - # use 3rd divided differences (DD) a la SPICE and Shampine - - # TODO: check numerical stability - uprev2 = integrator.uprev2 - tprev = integrator.tprev - uprev3 = cache.uprev3 - tprev2 = cache.tprev2 + # u = uprev + (1-γ) * z1 + γ * z + u = nlsolver.tmp + γ * z - dt1 = dt * (t + dt - tprev) - dt2 = (t - tprev) * (t + dt - tprev) - dt3 = (t - tprev) * (t - tprev2) - dt4 = (tprev - tprev2) * (t - tprev2) - dt5 = t + dt - tprev2 - c = 7 / 12 # default correction factor in SPICE (LTE overestimated by DD) - r = c * dt^3 / 2 # by mean value theorem 3rd DD equals y'''(s)/6 for some s + ################################### Finalize - DD31 = (u - uprev) / dt1 - (uprev - uprev2) / dt2 - DD30 = (uprev - uprev2) / dt3 - (uprev2 - uprev3) / dt4 - tmp = r * abs((DD31 - DD30) / dt5) - atmp = calculate_residuals(tmp, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, - t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - if integrator.EEst <= 1 - cache.uprev3 = uprev2 - cache.tprev2 = tprev - end - elseif integrator.success_iter > 0 - integrator.EEst = 1 - cache.uprev3 = integrator.uprev2 - cache.tprev2 = integrator.tprev + if integrator.opts.adaptive + tmp = btilde1 * z1 + btilde2 * z + if isnewton(nlsolver) && alg.smooth_est # From Shampine + integrator.stats.nsolve += 1 + est = _reshape(get_W(nlsolver) \ _vec(tmp), axes(tmp)) else - integrator.EEst = 1 + est = tmp end + atmp = calculate_residuals(est, uprev, u, integrator.opts.abstol, + integrator.opts.reltol, integrator.opts.internalnorm, t) + integrator.EEst = integrator.opts.internalnorm(atmp, t) end - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 2) + integrator.fsallast = z ./ dt integrator.k[1] = integrator.fsalfirst integrator.k[2] = integrator.fsallast integrator.u = u @@ -710,77 +692,51 @@ end @muladd function perform_step!(integrator, cache::SDIRK22Cache, repeat_step = false) @unpack t, dt, uprev, u, f, p = integrator - @unpack atmp, nlsolver, step_limiter! = cache + @unpack z1, atmp, nlsolver, step_limiter! = cache @unpack z, tmp = nlsolver - @unpack a, α, β = cache.tab + W = isnewton(nlsolver) ? get_W(nlsolver) : nothing + b = nlsolver.ztmp + @unpack γ, a21, bhat1, bhat2, btilde1, btilde2 = cache.tab alg = unwrap_alg(integrator, true) - mass_matrix = integrator.f.mass_matrix - - # precalculations - γ = a * dt - γdt = γ * dt markfirststage!(nlsolver) - # first stage + # See in-place version for details @.. broadcast=false z=dt * integrator.fsalfirst - @.. broadcast=false tmp=uprev + γdt * integrator.fsalfirst - z = nlsolve!(nlsolver, integrator, cache, repeat_step) + @.. broadcast=false tmp=uprev + nlsolver.c = γ + z1 .= nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return - @.. broadcast=false u=α * tmp + β * z - # final stage - γ = dt - γdt = γ * dt - markfirststage!(nlsolver) - @.. broadcast=false tmp=uprev + γdt * integrator.fsalfirst - z = nlsolve!(nlsolver, integrator, cache, repeat_step) + @.. broadcast=false z=z1 + @.. broadcast=false tmp=uprev + a21 * z + nlsolver.c = 1 + isnewton(nlsolver) && set_new_W!(nlsolver, false) + z .= nlsolve!(nlsolver, integrator, cache, repeat_step) nlsolvefail(nlsolver) && return - @.. broadcast=false u=nlsolver.tmp - step_limiter!(u, integrator, p, t + dt) + @.. broadcast=false u=tmp + γ * z - if integrator.opts.adaptive - if integrator.iter > 2 - # local truncation error (LTE) bound by dt^3/12*max|y'''(t)| - # use 3rd divided differences (DD) a la SPICE and Shampine + step_limiter!(u, integrator, p, t + dt) - # TODO: check numerical stability - uprev2 = integrator.uprev2 - tprev = integrator.tprev - uprev3 = cache.uprev3 - tprev2 = cache.tprev2 + ################################### Finalize - dt1 = dt * (t + dt - tprev) - dt2 = (t - tprev) * (t + dt - tprev) - dt3 = (t - tprev) * (t - tprev2) - dt4 = (tprev - tprev2) * (t - tprev2) - dt5 = t + dt - tprev2 - c = 7 / 12 # default correction factor in SPICE (LTE overestimated by DD) - r = c * dt^3 / 2 # by mean value theorem 3rd DD equals y'''(s)/6 for some s + if integrator.opts.adaptive + @.. broadcast=false tmp=btilde1 * z1 + btilde2 * z + if alg.smooth_est && isnewton(nlsolver) # From Shampine + est = nlsolver.cache.dz + linres = dolinsolve(integrator, nlsolver.cache.linsolve; b = _vec(tmp), + linu = _vec(est)) - @inbounds for i in eachindex(u) - DD31 = (u[i] - uprev[i]) / dt1 - (uprev[i] - uprev2[i]) / dt2 - DD30 = (uprev[i] - uprev2[i]) / dt3 - (uprev2[i] - uprev3[i]) / dt4 - tmp[i] = r * abs((DD31 - DD30) / dt5) - end - calculate_residuals!(atmp, tmp, uprev, u, integrator.opts.abstol, - integrator.opts.reltol, integrator.opts.internalnorm, t) - integrator.EEst = integrator.opts.internalnorm(atmp, t) - if integrator.EEst <= 1 - copyto!(cache.uprev3, uprev2) - cache.tprev2 = tprev - end - elseif integrator.success_iter > 0 - integrator.EEst = 1 - copyto!(cache.uprev3, integrator.uprev2) - cache.tprev2 = integrator.tprev + integrator.stats.nsolve += 1 else - integrator.EEst = 1 + est = tmp end + calculate_residuals!(atmp, est, uprev, u, integrator.opts.abstol, + integrator.opts.reltol, integrator.opts.internalnorm, t) + integrator.EEst = integrator.opts.internalnorm(atmp, t) end - OrdinaryDiffEqCore.increment_nf!(integrator.stats, 2) - f(integrator.fsallast, u, p, t + dt) + @.. broadcast=false integrator.fsallast=z / dt end @muladd function perform_step!(integrator, cache::SSPSDIRK2ConstantCache, diff --git a/lib/OrdinaryDiffEqSDIRK/src/sdirk_tableaus.jl b/lib/OrdinaryDiffEqSDIRK/src/sdirk_tableaus.jl index ed95cef957..27512bf655 100644 --- a/lib/OrdinaryDiffEqSDIRK/src/sdirk_tableaus.jl +++ b/lib/OrdinaryDiffEqSDIRK/src/sdirk_tableaus.jl @@ -2210,16 +2210,39 @@ function ESDIRK659L2SATableau(T, T2) end struct SDIRK22Tableau{T} - a::T - α::T - β::T + γ::T + a21::T + bhat1::T + bhat2::T + btilde1::T + btilde2::T end +#= +Tableau: +γ=1-1/√2 + +c = [γ; 1] +A = [ γ 0 + 1-γ γ] +b = [1-γ; γ] +Embedded scheme: 1st order, A- and L-stable +bhat = [1-γ+γ^2; γ-γ^2] + +Error estimation: +btilde = bhat-b = [btilde1; btilde2] = [γ/(1+γ) - 1 + γ; 1/(1+γ) - γ]. +=# function SDIRK22Tableau(T) - a = convert(T, 1 - 1 / sqrt(2)) - α = convert(T, -sqrt(2)) - β = convert(T, 1 + sqrt(2)) - SDIRK22Tableau(a, α, β) + γ = convert(T, 1 - 1 / sqrt(2)) + a21 = convert(T, 1 - γ) + #bhat1 = convert(T, γ / (1 + γ)) + #bhat2 = convert(T, 1 / (1 + γ)) + bhat1 = convert(T, 1 - γ + γ^2) + bhat2 = convert(T, γ - γ^2) + btilde1 = convert(T, bhat1 - 1 + γ) + btilde2 = convert(T, bhat2 - γ) + + SDIRK22Tableau(γ, a21, bhat1, bhat2, btilde1, btilde2) end struct KenCarp47Tableau{T, T2}