Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ function EnzymeCore.EnzymeRules.inactive_noinl(
true
end
function EnzymeCore.EnzymeRules.inactive_noinl(
::typeof(OrdinaryDiffEqCore.fixed_t_for_floatingpoint_error!), args...)
::typeof(OrdinaryDiffEqCore.fixed_t_for_tstop_error!), args...)
true
end
function EnzymeCore.EnzymeRules.inactive_noinl(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{
Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{
typeof(OrdinaryDiffEqCore.SciMLBase.check_error), Any}
Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{
typeof(OrdinaryDiffEqCore.fixed_t_for_floatingpoint_error!), Any, Any}
typeof(OrdinaryDiffEqCore.fixed_t_for_tstop_error!), Any, Any}
Mooncake.@zero_adjoint Mooncake.MinimalCtx Tuple{
typeof(OrdinaryDiffEqCore.final_progress), Any}

Expand Down
85 changes: 62 additions & 23 deletions lib/OrdinaryDiffEqCore/src/integrators/integrator_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,58 @@ function modify_dt_for_tstops!(integrator)
if has_tstop(integrator)
tdir_t = integrator.tdir * integrator.t
tdir_tstop = first_tstop(integrator)
distance_to_tstop = abs(tdir_tstop - tdir_t)

# Store the original dt to check if it gets significantly reduced
original_dt = abs(integrator.dt)

if integrator.opts.adaptive
integrator.dt = integrator.tdir *
min(abs(integrator.dt), abs(tdir_tstop - tdir_t)) # step! to the end
integrator.dtpropose = original_dt
if original_dt < distance_to_tstop
# Normal step, no tstop interference
integrator.next_step_tstop = false
else
# Distance is smaller, entering tstop snap mode
integrator.next_step_tstop = true
integrator.tstop_target = integrator.tdir * tdir_tstop
end
integrator.dt = integrator.tdir * min(original_dt, distance_to_tstop)
elseif iszero(integrator.dtcache) && integrator.dtchangeable
integrator.dt = integrator.tdir * abs(tdir_tstop - tdir_t)
integrator.dt = integrator.tdir * distance_to_tstop
integrator.next_step_tstop = true
integrator.tstop_target = integrator.tdir * tdir_tstop
elseif integrator.dtchangeable && !integrator.force_stepfail
# always try to step! with dtcache, but lower if a tstop
# however, if force_stepfail then don't set to dtcache, and no tstop worry
integrator.dt = integrator.tdir *
min(abs(integrator.dtcache), abs(tdir_tstop - tdir_t)) # step! to the end
if abs(integrator.dtcache) < distance_to_tstop
# Normal step with dtcache, no tstop interference
integrator.next_step_tstop = false
else
# Distance is smaller, entering tstop snap mode
integrator.next_step_tstop = true
integrator.tstop_target = integrator.tdir * tdir_tstop
end
integrator.dt = integrator.tdir * min(abs(integrator.dtcache), distance_to_tstop)
else
integrator.next_step_tstop = false
end
else
integrator.next_step_tstop = false
end
end

function handle_tstop_step!(integrator)
if integrator.t isa AbstractFloat && abs(integrator.dt) < eps(abs(integrator.t))
# Skip perform_step! entirely for tiny dt
integrator.accept_step = true
else
# Normal step
perform_step!(integrator, integrator.cache)
end

# Flag will be reset in fixed_t_for_tstop_error! when t is updated
end

# Want to extend savevalues! for DDEIntegrator
function savevalues!(integrator::ODEIntegrator, force_save = false, reduce_size = true)
_savevalues!(integrator, force_save, reduce_size)
Expand Down Expand Up @@ -149,7 +187,7 @@ function _savevalues!(integrator, force_save, reduce_size)::Tuple{Bool, Bool}
end
if force_save || (integrator.opts.save_everystep &&
(isempty(integrator.sol.t) ||
(integrator.t !== integrator.sol.t[end]) &&
(integrator.t !== integrator.sol.t[end] || iszero(integrator.dt)) &&
(integrator.opts.save_end || integrator.t !== integrator.sol.prob.tspan[2])
))
integrator.saveiter += 1
Expand Down Expand Up @@ -274,12 +312,20 @@ function _loopfooter!(integrator)
if integrator.accept_step # Accept
increment_accept!(integrator.stats)
integrator.last_stepfail = false
integrator.tprev = integrator.t

if integrator.next_step_tstop
# Step controller dt is overly pessimistic, since dt = time to tstop
# For example, if super dense time, dt = eps(t), so the next step is tiny
# Thus if snap to tstop, let the step controller assume dt was the pre-fixed version
integrator.dt = integrator.dtpropose
end
integrator.t = fixed_t_for_tstop_error!(integrator, ttmp)

dtnew = DiffEqBase.value(step_accept_controller!(integrator,
integrator.alg,
q)) *
oneunit(integrator.dt)
integrator.tprev = integrator.t
integrator.t = fixed_t_for_floatingpoint_error!(integrator, ttmp)
calc_dt_propose!(integrator, dtnew)
handle_callbacks!(integrator)
else # Reject
Expand All @@ -288,7 +334,7 @@ function _loopfooter!(integrator)
elseif !integrator.opts.adaptive #Not adaptive
increment_accept!(integrator.stats)
integrator.tprev = integrator.t
integrator.t = fixed_t_for_floatingpoint_error!(integrator, ttmp)
integrator.t = fixed_t_for_tstop_error!(integrator, ttmp)
integrator.last_stepfail = false
integrator.accept_step = true
integrator.dtpropose = integrator.dt
Expand Down Expand Up @@ -327,16 +373,12 @@ function log_step!(progress_name, progress_id, progress_message, dt, u, p, t, ts
progress=(t-t1)/(t2-t1))
end

function fixed_t_for_floatingpoint_error!(integrator, ttmp)
if has_tstop(integrator)
tstop = integrator.tdir * first_tstop(integrator)
if abs(ttmp - tstop) <
100eps(float(max(integrator.t, tstop) / oneunit(integrator.t))) *
oneunit(integrator.t)
tstop
else
ttmp
end
function fixed_t_for_tstop_error!(integrator, ttmp)
# If we're in tstop snap mode, use exact tstop target
if integrator.next_step_tstop
# Reset the flag now that we're snapping to tstop
integrator.next_step_tstop = false
return integrator.tstop_target
else
ttmp
end
Expand Down Expand Up @@ -468,10 +510,7 @@ function handle_tstop!(integrator)
tdir_t = integrator.tdir * integrator.t
tdir_tstop = first_tstop(integrator)
if tdir_t == tdir_tstop
while tdir_t == tdir_tstop #remove all redundant copies
res = pop_tstop!(integrator)
has_tstop(integrator) ? (tdir_tstop = first_tstop(integrator)) : break
end
res = pop_tstop!(integrator)
integrator.just_hit_tstop = true
elseif tdir_t > tdir_tstop
if !integrator.dtchangeable
Expand Down
2 changes: 2 additions & 0 deletions lib/OrdinaryDiffEqCore/src/integrators/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ mutable struct ODEIntegrator{algType <: Union{OrdinaryDiffEqAlgorithm, DAEAlgori
force_stepfail::Bool
last_stepfail::Bool
just_hit_tstop::Bool
next_step_tstop::Bool
tstop_target::tType
do_error_check::Bool
event_last_time::Int
vector_event_last_time::Int
Expand Down
26 changes: 19 additions & 7 deletions lib/OrdinaryDiffEqCore/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -503,8 +503,10 @@ function SciMLBase.__init(
u_modified = false
EEst = EEstT(1)
just_hit_tstop = false
next_step_tstop = false
tstop_target = zero(t)
isout = false
accept_step = false
accept_step = true
force_stepfail = false
last_stepfail = false
do_error_check = true
Expand Down Expand Up @@ -541,7 +543,7 @@ function SciMLBase.__init(
callback_cache,
kshortsize, force_stepfail,
last_stepfail,
just_hit_tstop, do_error_check,
just_hit_tstop, next_step_tstop, tstop_target, do_error_check,
event_last_time,
vector_event_last_time,
last_event_error, accept_step,
Expand Down Expand Up @@ -603,14 +605,24 @@ end

function SciMLBase.solve!(integrator::ODEIntegrator)
@inbounds while !isempty(integrator.opts.tstops)
while integrator.tdir * integrator.t < first(integrator.opts.tstops)
first_tstop = first(integrator.opts.tstops)
while integrator.tdir * integrator.t <= first_tstop
loopheader!(integrator)
if integrator.do_error_check && check_error!(integrator) != ReturnCode.Success
return integrator.sol
end
perform_step!(integrator, integrator.cache)

# Use special tstop handling if flag is set, otherwise normal stepping
if integrator.next_step_tstop
handle_tstop_step!(integrator)
else
perform_step!(integrator, integrator.cache)
end

should_exit = integrator.next_step_tstop

loopfooter!(integrator)
if isempty(integrator.opts.tstops)
if isempty(integrator.opts.tstops) || should_exit
break
end
end
Expand Down Expand Up @@ -662,11 +674,11 @@ end

for t in tstops
tdir_t = tdir * t
tdir_t0 < tdir_t tdir_tf && push!(tstops_internal, tdir_t)
tdir_t0 < tdir_t < tdir_tf && push!(tstops_internal, tdir_t)
end
for t in d_discontinuities
tdir_t = tdir * t
tdir_t0 < tdir_t tdir_tf && push!(tstops_internal, tdir_t)
tdir_t0 < tdir_t < tdir_tf && push!(tstops_internal, tdir_t)
end
push!(tstops_internal, tdir_tf)

Expand Down
Loading
Loading