Skip to content

Commit c9b9244

Browse files
fix: use improved discrete saving API
1 parent f51778d commit c9b9244

File tree

1 file changed

+28
-39
lines changed

1 file changed

+28
-39
lines changed

src/systems/callbacks.jl

Lines changed: 28 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -697,12 +697,18 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
697697
return generate_callback(cbs[cb_ind], sys; kwargs...)
698698
end
699699

700+
if is_split(sys)
701+
ic = get_index_cache(sys)
702+
else
703+
ic = nothing
704+
end
700705
trigger = compile_condition(
701706
cbs, sys, unknowns(sys), parameters(sys; initial_parameters = true); kwargs...)
702707
affects = []
703708
affect_negs = []
704709
inits = []
705710
finals = []
711+
discrete_save_idxs = Vector{Int}[]
706712
for cb in cbs
707713
affect = compile_affect(cb.affect, cb, sys; default = EMPTY_AFFECT, kwargs...)
708714
push!(affects, affect)
@@ -712,8 +718,12 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
712718
push!(affect_negs, affect_neg)
713719
push!(inits,
714720
compile_affect(
715-
cb.initialize, cb, sys; default = nothing, is_init = true, kwargs...))
721+
cb.initialize, cb, sys; default = nothing, kwargs...))
716722
push!(finals, compile_affect(cb.finalize, cb, sys; default = nothing, kwargs...))
723+
724+
if ic !== nothing
725+
push!(discrete_save_idxs, get(ic.callback_to_clocks, cb, Int[]))
726+
end
717727
end
718728

719729
# Since there may be different number of conditions and affects,
@@ -739,7 +749,8 @@ function generate_callback(cbs::Vector{SymbolicContinuousCallback}, sys; kwargs.
739749

740750
return VectorContinuousCallback(
741751
trigger, affect, affect_neg, length(eqs); initialize, finalize,
742-
rootfind = cbs[1].rootfind, initializealg = cbs[1].reinitializealg)
752+
rootfind = cbs[1].rootfind, initializealg = cbs[1].reinitializealg,
753+
discrete_save_idxs)
743754
end
744755

745756
function generate_callback(cb, sys; kwargs...)
@@ -756,27 +767,33 @@ function generate_callback(cb, sys; kwargs...)
756767
compile_affect(cb.affect_neg, cb, sys; default = EMPTY_AFFECT, kwargs...)
757768
end
758769
init = compile_affect(cb.initialize, cb, sys; default = SciMLBase.INITIALIZE_DEFAULT,
759-
is_init = true, kwargs...)
770+
kwargs...)
760771
final = compile_affect(
761772
cb.finalize, cb, sys; default = SciMLBase.FINALIZE_DEFAULT, kwargs...)
762773

763774
initialize = isnothing(cb.initialize) ? init : ((c, u, t, i) -> init(i))
764775
finalize = isnothing(cb.finalize) ? final : ((c, u, t, i) -> final(i))
765776

777+
discrete_save_idxs = if is_split(sys)
778+
get(get_index_cache(sys).callback_to_clocks, cb, ())
779+
else
780+
()
781+
end
766782
if is_discrete(cb)
767783
if is_timed && conditions(cb) isa AbstractVector
768784
return PresetTimeCallback(trigger, affect; initialize,
769-
finalize, initializealg = cb.reinitializealg)
785+
finalize, initializealg = cb.reinitializealg, discrete_save_idxs)
770786
elseif is_timed
771787
return PeriodicCallback(
772-
affect, trigger; initialize, finalize, initializealg = cb.reinitializealg)
788+
affect, trigger; initialize, finalize, initializealg = cb.reinitializealg,
789+
discrete_save_idxs)
773790
else
774791
return DiscreteCallback(trigger, affect; initialize,
775-
finalize, initializealg = cb.reinitializealg)
792+
finalize, initializealg = cb.reinitializealg, discrete_save_idxs)
776793
end
777794
else
778795
return ContinuousCallback(trigger, affect, affect_neg; initialize, finalize,
779-
rootfind = cb.rootfind, initializealg = cb.reinitializealg)
796+
rootfind = cb.rootfind, initializealg = cb.reinitializealg, discrete_save_idxs)
780797
end
781798
end
782799

@@ -791,41 +808,13 @@ Notes
791808
"""
792809
function compile_affect(
793810
aff::Union{Nothing, Affect}, cb::AbstractCallback, sys::AbstractSystem;
794-
default = nothing, is_init = false, kwargs...)
795-
save_idxs = if !(has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing)
796-
Int[]
797-
else
798-
get(ic.callback_to_clocks, cb, Int[])
799-
end
800-
811+
default = nothing, kwargs...)
801812
if isnothing(aff)
802-
is_init ? wrap_save_discretes(default, save_idxs) : default
813+
default
803814
elseif aff isa AffectSystem
804-
f = compile_equational_affect(aff, sys; kwargs...)
805-
wrap_save_discretes(f, save_idxs)
815+
compile_equational_affect(aff, sys; kwargs...)
806816
elseif aff isa ImperativeAffect
807-
f = compile_functional_affect(aff, sys; kwargs...)
808-
wrap_save_discretes(f, save_idxs)
809-
end
810-
end
811-
812-
function wrap_save_discretes(f, save_idxs)
813-
let save_idxs = save_idxs, f = f
814-
if f === SciMLBase.INITIALIZE_DEFAULT
815-
(c, u, t, i) -> begin
816-
f(c, u, t, i)
817-
for idx in save_idxs
818-
SciMLBase.save_discretes!(i, idx)
819-
end
820-
end
821-
else
822-
(i) -> begin
823-
isnothing(f) || f(i)
824-
for idx in save_idxs
825-
SciMLBase.save_discretes!(i, idx)
826-
end
827-
end
828-
end
817+
compile_functional_affect(aff, sys; kwargs...)
829818
end
830819
end
831820

0 commit comments

Comments
 (0)