From 4988ee10da32d305e3c74b9cdc0888ff68dcb9e8 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Mon, 17 Feb 2025 17:51:46 -0800 Subject: [PATCH 01/20] Early work on the new discrete backend for MTK --- src/systems/clock_inference.jl | 11 ++++++++- src/systems/systems.jl | 7 +++++- src/systems/systemstructure.jl | 43 +++++++++++++++++++++++++--------- 3 files changed, 48 insertions(+), 13 deletions(-) diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index 42fe28f7c7..86fe7e85ac 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -100,7 +100,7 @@ function infer_clocks!(ci::ClockInference) c = BitSet(c′) idxs = intersect(c, inferred) isempty(idxs) && continue - if !allequal(var_domain[i] for i in idxs) + if !allequal(iscontinuous(var_domain[i]) for i in idxs) display(fullvars[c′]) throw(ClockInferenceException("Clocks are not consistent in connected component $(fullvars[c′])")) end @@ -155,6 +155,9 @@ function split_system(ci::ClockInference{S}) where {S} cid_to_var = Vector{Int}[] # cid_counter = number of clocks cid_counter = Ref(0) + + # populates clock_to_id and id_to_clock + # checks if there is a continuous_id (for some reason? clock to id does this too) for (i, d) in enumerate(eq_domain) cid = let cid_counter = cid_counter, id_to_clock = id_to_clock, continuous_id = continuous_id @@ -174,9 +177,13 @@ function split_system(ci::ClockInference{S}) where {S} resize_or_push!(cid_to_eq, i, cid) end continuous_id = continuous_id[] + # for each clock partition what are the input (indexes/vars) input_idxs = map(_ -> Int[], 1:cid_counter[]) inputs = map(_ -> Any[], 1:cid_counter[]) + # var_domain corresponds to fullvars/all variables in the system nvv = length(var_domain) + # put variables into the right clock partition + # keep track of inputs to each partition for i in 1:nvv d = var_domain[i] cid = get(clock_to_id, d, 0) @@ -190,6 +197,7 @@ function split_system(ci::ClockInference{S}) where {S} resize_or_push!(cid_to_var, i, cid) end + # breaks the system up into a continous and 0 or more discrete systems tss = similar(cid_to_eq, S) for (id, ieqs) in enumerate(cid_to_eq) ts_i = system_subset(ts, ieqs) @@ -199,6 +207,7 @@ function split_system(ci::ClockInference{S}) where {S} end tss[id] = ts_i end + # put the continous system at the back if continuous_id != 0 tss[continuous_id], tss[end] = tss[end], tss[continuous_id] inputs[continuous_id], inputs[end] = inputs[end], inputs[continuous_id] diff --git a/src/systems/systems.jl b/src/systems/systems.jl index ff455fb811..c6b9d78f97 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -36,7 +36,7 @@ function mtkcompile( isscheduled(sys) && throw(RepeatedStructuralSimplificationError()) newsys′ = __mtkcompile(sys; simplify, allow_symbolic, allow_parameter, conservative, fully_determined, - inputs, outputs, disturbance_inputs, + inputs, outputs, disturbance_inputs, additional_passes, kwargs...) if newsys′ isa Tuple @assert length(newsys′) == 2 @@ -292,3 +292,8 @@ function map_variables_to_equations(sys::AbstractSystem; rename_dummy_derivative return mapping end + +""" +Mark whether an extra pass `p` can support compiling discrete systems. +""" +discrete_compile_pass(p) = false diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 5dfd36a6fc..a5401c9f9b 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -820,19 +820,40 @@ function mtkcompile!(state::TearingState; simplify = false, time_domains = merge(Dict(state.fullvars .=> ci.var_domain), Dict(default_toterm.(state.fullvars) .=> ci.var_domain)) tss, clocked_inputs, continuous_id, id_to_clock = ModelingToolkit.split_system(ci) + if continuous_id == 0 + # do a trait check here - handle fully discrete system + additional_passes = get(kwargs, :additional_passes, nothing) + if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes) + # take the first discrete compilation pass given for now + discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes) + discrete_compile = additional_passes[discrete_pass_idx] + deleteat!(additional_passes, discrete_pass_idx) + return discrete_compile(tss, clocked_inputs) + end + throw(HybridSystemNotSupportedException(""" + Discrete systems with multiple clocks are not supported with the standard \ + MTK compiler. + """)) + end if length(tss) > 1 - if continuous_id == 0 - throw(HybridSystemNotSupportedException(""" - Discrete systems with multiple clocks are not supported with the standard \ - MTK compiler. - """)) - else - throw(HybridSystemNotSupportedException(""" - Hybrid continuous-discrete systems are currently not supported with \ - the standard MTK compiler. This system requires JuliaSimCompiler.jl, \ - see https://help.juliahub.com/juliasimcompiler/stable/ - """)) + # simplify as normal + sys = _mtkcompile!(tss[continuous_id]; simplify, + inputs = [inputs; clocked_inputs[continuous_id]], outputs, disturbance_inputs, + check_consistency, fully_determined, + kwargs...) + if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes) + discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes) + discrete_compile = additional_passes[discrete_pass_idx] + deleteat!(additional_passes, discrete_pass_idx) + # in the case of a hybrid system, the discrete_compile pass should take the currents of sys.discrete_subsystems + # and modifies discrete_subsystems to bea tuple of the io and anything else, while adding or manipulating the rest of sys as needed + return discrete_compile(sys, tss[2:end], inputs) end + throw(HybridSystemNotSupportedException(""" + Hybrid continuous-discrete systems are currently not supported with \ + the standard MTK compiler. This system requires JuliaSimCompiler.jl, \ + see https://help.juliahub.com/juliasimcompiler/stable/ + """)) end if get_is_discrete(state.sys) || continuous_id == 1 && any(Base.Fix2(isoperator, Shift), state.fullvars) From ed0612b33c3d112498095a75f08b9ccb6e9d2cef Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 20 Feb 2025 14:52:14 +0530 Subject: [PATCH 02/20] feat: retain original equations of the system in `TearingState` --- src/systems/systems.jl | 4 ++-- src/systems/systemstructure.jl | 21 ++++++++++++++++++++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/src/systems/systems.jl b/src/systems/systems.jl index c6b9d78f97..bdca6ff71a 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -80,7 +80,6 @@ function __mtkcompile(sys::AbstractSystem; simplify = false, @unpack structure, fullvars = state @unpack graph, var_to_diff, var_types = structure - eqs = equations(state) brown_vars = Int[] new_idxs = zeros(Int, length(var_types)) idx = 0 @@ -98,7 +97,8 @@ function __mtkcompile(sys::AbstractSystem; simplify = false, Is = Int[] Js = Int[] vals = Num[] - new_eqs = copy(eqs) + make_eqs_zero_equals!(state) + new_eqs = copy(equations(state)) dvar2eq = Dict{Any, Int}() for (v, dv) in enumerate(var_to_diff) dv === nothing && continue diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index a5401c9f9b..c3ff82478e 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -203,6 +203,7 @@ end mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T} """The system of equations.""" sys::T + original_eqs::Vector{Equation} """The set of variables of the system.""" fullvars::Vector{BasicSymbolic} structure::SystemStructure @@ -219,6 +220,7 @@ end TransformationState(sys::AbstractSystem) = TearingState(sys) function system_subset(ts::TearingState, ieqs::Vector{Int}) eqs = equations(ts) + @set! ts.original_eqs = ts.original_eqs[ieqs] @set! ts.sys.eqs = eqs[ieqs] @set! ts.original_eqs = ts.original_eqs[ieqs] @set! ts.structure = system_subset(ts.structure, ieqs) @@ -524,7 +526,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) eq_to_diff = DiffGraph(nsrcs(graph)) - ts = TearingState(sys, fullvars, + ts = TearingState(sys, original_eqs, fullvars, SystemStructure(complete(var_to_diff), complete(eq_to_diff), complete(graph), nothing, var_types, false), Any[], param_derivative_map, original_eqs, Equation[]) @@ -810,6 +812,22 @@ function Base.show(io::IO, mime::MIME"text/plain", ms::MatchedSystemStructure) printstyled(io, " SelectedState") end +function make_eqs_zero_equals!(ts::TearingState) + neweqs = map(enumerate(get_eqs(ts.sys))) do kvp + i, eq = kvp + isalgeq = true + for j in 𝑠neighbors(ts.structure.graph, i) + isalgeq &= invview(ts.structure.var_to_diff)[j] === nothing + end + if isalgeq + return 0 ~ eq.rhs - eq.lhs + else + return eq + end + end + copyto!(get_eqs(ts.sys), neweqs) +end + function mtkcompile!(state::TearingState; simplify = false, check_consistency = true, fully_determined = true, warn_initialize_determined = true, inputs = Any[], outputs = Any[], @@ -836,6 +854,7 @@ function mtkcompile!(state::TearingState; simplify = false, """)) end if length(tss) > 1 + make_eqs_zero_equals!(tss[continuous_id]) # simplify as normal sys = _mtkcompile!(tss[continuous_id]; simplify, inputs = [inputs; clocked_inputs[continuous_id]], outputs, disturbance_inputs, From 7f8b8f260efd0ea7c067baea118c16f40263c091 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 14 Mar 2025 14:50:31 +0530 Subject: [PATCH 03/20] feat: allow namespacing statemachine equations --- src/systems/abstractsystem.jl | 1 + src/systems/state_machines.jl | 33 +++++++++++++++++++++++++++++++++ src/utils.jl | 6 ++++++ 3 files changed, 40 insertions(+) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 70e7b06bfe..34dfb5283e 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1228,6 +1228,7 @@ function namespace_expr( O end end + _nonum(@nospecialize x) = x isa Num ? x.val : x """ diff --git a/src/systems/state_machines.jl b/src/systems/state_machines.jl index 347f92e6f8..ea65981804 100644 --- a/src/systems/state_machines.jl +++ b/src/systems/state_machines.jl @@ -153,3 +153,36 @@ entry When used in a finite state machine, this operator returns `true` if the queried state is active and false otherwise. """ activeState + +function vars!(vars, O::Transition; op = Differential) + vars!(vars, O.from) + vars!(vars, O.to) + vars!(vars, O.cond; op) + return vars +end +function vars!(vars, O::InitialState; op = Differential) + vars!(vars, O.s; op) + return vars +end +function vars!(vars, O::StateMachineOperator; op = Differential) + error("Unhandled state machine operator") +end + +function namespace_expr( + O::Transition, sys, n = nameof(sys); ivs = independent_variables(sys)) + return Transition( + O.from === nothing ? O.from : renamespace(sys, O.from), + O.to === nothing ? O.to : renamespace(sys, O.to), + O.cond === nothing ? O.cond : namespace_expr(O.cond, sys), + O.immediate, O.reset, O.synchronize, O.priority + ) +end + +function namespace_expr( + O::InitialState, sys, n = nameof(sys); ivs = independent_variables(sys)) + return InitialState(O.s === nothing ? O.s : renamespace(sys, O.s)) +end + +function namespace_expr(O::StateMachineOperator, sys, n = nameof(sys); kwargs...) + error("Unhandled state machine operator") +end diff --git a/src/utils.jl b/src/utils.jl index e96f31f533..d028d4ed18 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -391,6 +391,12 @@ vars(eq::Equation; op = Differential) = vars!(Set(), eq; op = op) function vars!(vars, eq::Equation; op = Differential) (vars!(vars, eq.lhs; op = op); vars!(vars, eq.rhs; op = op); vars) end +function vars!(vars, O::AbstractSystem; op = Differential) + for eq in equations(O) + vars!(vars, eq; op) + end + return vars +end function vars!(vars, O; op = Differential) if isvariable(O) if iscall(O) && operation(O) === getindex && iscalledparameter(first(arguments(O))) From 7cf774d0bd4d355b9f696def2246988b49265d87 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 14 Mar 2025 14:56:00 +0530 Subject: [PATCH 04/20] feat: propagate state machines in structural simplification --- src/systems/systems.jl | 4 +- src/systems/systemstructure.jl | 81 ++++++++++++++++++++++++++++++---- 2 files changed, 76 insertions(+), 9 deletions(-) diff --git a/src/systems/systems.jl b/src/systems/systems.jl index bdca6ff71a..9769d42e96 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -75,8 +75,10 @@ function __mtkcompile(sys::AbstractSystem; simplify = false, return simplify_optimization_system(sys; kwargs..., sort_eqs, simplify) end + sys, statemachines = extract_top_level_statemachines(sys) sys = expand_connections(sys) - state = TearingState(sys; sort_eqs) + state = TearingState(sys) + append!(state.statemachines, statemachines) @unpack structure, fullvars = state @unpack graph, var_to_diff, var_types = structure diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index c3ff82478e..f072dc3f52 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -203,7 +203,6 @@ end mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T} """The system of equations.""" sys::T - original_eqs::Vector{Equation} """The set of variables of the system.""" fullvars::Vector{BasicSymbolic} structure::SystemStructure @@ -215,6 +214,7 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T} are not used in the rest of the system. """ additional_observed::Vector{Equation} + statemachines::Vector{T} end TransformationState(sys::AbstractSystem) = TearingState(sys) @@ -224,6 +224,22 @@ function system_subset(ts::TearingState, ieqs::Vector{Int}) @set! ts.sys.eqs = eqs[ieqs] @set! ts.original_eqs = ts.original_eqs[ieqs] @set! ts.structure = system_subset(ts.structure, ieqs) + if all(eq -> eq.rhs isa StateMachineOperator, get_eqs(ts.sys)) + names = Symbol[] + for eq in get_eqs(ts.sys) + if eq.lhs isa Transition + push!(names, first(namespace_hierarchy(nameof(eq.rhs.from)))) + push!(names, first(namespace_hierarchy(nameof(eq.rhs.to)))) + elseif eq.lhs isa InitialState + push!(names, first(namespace_hierarchy(nameof(eq.rhs.s)))) + else + error("Unhandled state machine operator") + end + end + @set! ts.statemachines = filter(x -> nameof(x) in names, ts.statemachines) + else + @set! ts.statemachines = eltype(ts.statemachines)[] + end ts end @@ -277,6 +293,49 @@ function symbolic_contains(var, set) all(x -> x in set, Symbolics.scalarize(var)) end +""" + $(TYPEDSIGNATURES) + +Descend through the system hierarchy and look for statemachines. Remove equations from +the inner statemachine systems. Return the new `sys` and an array of top-level +statemachines. +""" +function extract_top_level_statemachines(sys::AbstractSystem) + eqs = get_eqs(sys) + + if !isempty(eqs) && all(eq -> eq.lhs isa StateMachineOperator, eqs) + # top-level statemachine + with_removed = @set sys.systems = map(remove_child_equations, get_systems(sys)) + return with_removed, [sys] + elseif !isempty(eqs) && any(eq -> eq.lhs isa StateMachineOperator, eqs) + # error: can't mix + error("Mixing statemachine equations and standard equations in a top-level statemachine is not allowed.") + else + # descend + subsystems = get_systems(sys) + newsubsystems = eltype(subsystems)[] + statemachines = eltype(subsystems)[] + for subsys in subsystems + newsubsys, sub_statemachines = extract_top_level_statemachines(subsys) + push!(newsubsystems, newsubsys) + append!(statemachines, sub_statemachines) + end + @set! sys.systems = newsubsystems + return sys, statemachines + end +end + +""" + $(TYPEDSIGNATURES) + +Return `sys` with all equations (including those in subsystems) removed. +""" +function remove_child_equations(sys::AbstractSystem) + @set! sys.eqs = eltype(get_eqs(sys))[] + @set! sys.systems = map(remove_child_equations, get_systems(sys)) + return sys +end + function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) # flatten system sys = flatten(sys) @@ -342,9 +401,16 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) # change the equation if the RHS is `missing` so the rest of this loop works eq = 0.0 ~ coalesce(eq.rhs, 0.0) end - rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs - if !_iszero(eq.lhs) + is_statemachine_equation = false + if eq.lhs isa StateMachineOperator + is_statemachine_equation = true + eq = eq + rhs = eq.rhs + elseif _iszero(eq.lhs) + rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs + else lhs = quick_cancel ? quick_cancel_expr(eq.lhs) : eq.lhs + rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs eq = 0 ~ rhs - lhs end empty!(varsbuf) @@ -408,8 +474,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) addvar!(v, VARIABLE) end end - - if isalgeq + if isalgeq || is_statemachine_equation eqs[i] = eq else eqs[i] = eqs[i].lhs ~ rhs @@ -526,11 +591,10 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) eq_to_diff = DiffGraph(nsrcs(graph)) - ts = TearingState(sys, original_eqs, fullvars, + ts = TearingState(sys, fullvars, SystemStructure(complete(var_to_diff), complete(eq_to_diff), complete(graph), nothing, var_types, false), - Any[], param_derivative_map, original_eqs, Equation[]) - + Any[], param_derivative_map, original_eqs, Equation[], typeof(sys)[]) return ts end @@ -860,6 +924,7 @@ function mtkcompile!(state::TearingState; simplify = false, inputs = [inputs; clocked_inputs[continuous_id]], outputs, disturbance_inputs, check_consistency, fully_determined, kwargs...) + additional_passes = get(kwargs, :additional_passes, nothing) if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes) discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes) discrete_compile = additional_passes[discrete_pass_idx] From c081a34b30549b7f2bff9ae5feff11a57067a4ca Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Fri, 14 Mar 2025 18:24:11 -0700 Subject: [PATCH 05/20] Handle nothing updates better --- src/systems/imperative_affect.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/systems/imperative_affect.jl b/src/systems/imperative_affect.jl index 7b1a9fb286..f3d45e258a 100644 --- a/src/systems/imperative_affect.jl +++ b/src/systems/imperative_affect.jl @@ -262,7 +262,9 @@ function compile_functional_affect( upd_vals = user_affect(upd_component_array, obs_component_array, ctx, integ) # write the new values back to the integrator - _generated_writeback(integ, upd_funs, upd_vals) + if !isnothing(upd_vals) + _generated_writeback(integ, upd_funs, upd_vals) + end reset_jumps && reset_aggregated_jumps!(integ) end From b60be7975ee056f2458340e75332b0f01eb89229 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Fri, 14 Mar 2025 18:24:24 -0700 Subject: [PATCH 06/20] Redefine the discrete_compile interface a bit --- src/systems/systemstructure.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index f072dc3f52..06f8ccc639 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -897,6 +897,9 @@ function mtkcompile!(state::TearingState; simplify = false, inputs = Any[], outputs = Any[], disturbance_inputs = Any[], kwargs...) + # split_system returns one or two systems and the inputs for each + # mod clock inference to be binary + # if it's continous keep going, if not then error unless given trait impl in additional passes ci = ModelingToolkit.ClockInference(state) ci = ModelingToolkit.infer_clocks!(ci) time_domains = merge(Dict(state.fullvars .=> ci.var_domain), @@ -910,7 +913,7 @@ function mtkcompile!(state::TearingState; simplify = false, discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes) discrete_compile = additional_passes[discrete_pass_idx] deleteat!(additional_passes, discrete_pass_idx) - return discrete_compile(tss, clocked_inputs) + return discrete_compile(tss, clocked_inputs, ci) end throw(HybridSystemNotSupportedException(""" Discrete systems with multiple clocks are not supported with the standard \ @@ -931,7 +934,7 @@ function mtkcompile!(state::TearingState; simplify = false, deleteat!(additional_passes, discrete_pass_idx) # in the case of a hybrid system, the discrete_compile pass should take the currents of sys.discrete_subsystems # and modifies discrete_subsystems to bea tuple of the io and anything else, while adding or manipulating the rest of sys as needed - return discrete_compile(sys, tss[2:end], inputs) + return discrete_compile(sys, tss[[i for i in eachindex(tss) if i != continuous_id]], clocked_inputs, ci) end throw(HybridSystemNotSupportedException(""" Hybrid continuous-discrete systems are currently not supported with \ From 405aafabfba1af39bee6f81d6506f004c943dff4 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Wed, 14 May 2025 13:43:00 -0700 Subject: [PATCH 07/20] Change the external synchronous signature to include the id/clock map --- src/systems/systemstructure.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 06f8ccc639..66f3a4b6f9 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -934,7 +934,9 @@ function mtkcompile!(state::TearingState; simplify = false, deleteat!(additional_passes, discrete_pass_idx) # in the case of a hybrid system, the discrete_compile pass should take the currents of sys.discrete_subsystems # and modifies discrete_subsystems to bea tuple of the io and anything else, while adding or manipulating the rest of sys as needed - return discrete_compile(sys, tss[[i for i in eachindex(tss) if i != continuous_id]], clocked_inputs, ci) + return discrete_compile( + sys, tss[[i for i in eachindex(tss) if i != continuous_id]], + clocked_inputs, ci, id_to_clock) end throw(HybridSystemNotSupportedException(""" Hybrid continuous-discrete systems are currently not supported with \ From 865523b33df9a36698879f85e7e68d32eb659623 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 20 Jun 2025 17:15:07 +0530 Subject: [PATCH 08/20] feat: add `zero_crossing_id` to `SymbolicContinuousCallback` --- src/systems/callbacks.jl | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index a4f39243d9..f05e455cbc 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -165,6 +165,7 @@ struct SymbolicContinuousCallback <: AbstractCallback finalize::Union{Affect, Nothing} rootfind::Union{Nothing, SciMLBase.RootfindOpt} reinitializealg::SciMLBase.DAEInitializationAlgorithm + zero_crossing_id::Symbol function SymbolicContinuousCallback( conditions::Union{Equation, Vector{Equation}}, @@ -174,6 +175,7 @@ struct SymbolicContinuousCallback <: AbstractCallback finalize = nothing, rootfind = SciMLBase.LeftRootFind, reinitializealg = nothing, + zero_crossing_id = gensym(), kwargs...) conditions = (conditions isa AbstractVector) ? conditions : [conditions] @@ -190,7 +192,7 @@ struct SymbolicContinuousCallback <: AbstractCallback make_affect(affect_neg; kwargs...), make_affect(initialize; kwargs...), make_affect( finalize; kwargs...), - rootfind, reinitializealg) + rootfind, reinitializealg, zero_crossing_id) end # Default affect to nothing end @@ -466,7 +468,8 @@ function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuo affect_neg = namespace_affects(affect_negs(cb), s), initialize = namespace_affects(initialize_affects(cb), s), finalize = namespace_affects(finalize_affects(cb), s), - rootfind = cb.rootfind, reinitializealg = cb.reinitializealg) + rootfind = cb.rootfind, reinitializealg = cb.reinitializealg, + zero_crossing_id = cb.zero_crossing_id) end function namespace_conditions(condition, s) @@ -490,6 +493,8 @@ function Base.hash(cb::AbstractCallback, s::UInt) s = hash(finalize_affects(cb), s) !is_discrete(cb) && (s = hash(cb.rootfind, s)) hash(cb.reinitializealg, s) + !is_discrete(cb) && (s = hash(cb.zero_crossing_id, s)) + return s end ########################### @@ -524,13 +529,16 @@ function finalize_affects(cbs::Vector{<:AbstractCallback}) end function Base.:(==)(e1::AbstractCallback, e2::AbstractCallback) - (is_discrete(e1) === is_discrete(e2)) || return false - (isequal(e1.conditions, e2.conditions) && isequal(e1.affect, e2.affect) && - isequal(e1.initialize, e2.initialize) && isequal(e1.finalize, e2.finalize)) && - isequal(e1.reinitializealg, e2.reinitializealg) || - return false - is_discrete(e1) || - (isequal(e1.affect_neg, e2.affect_neg) && isequal(e1.rootfind, e2.rootfind)) + is_discrete(e1) === is_discrete(e2) || return false + isequal(e1.conditions, e2.conditions) && isequal(e1.affect, e2.affect) || return false + isequal(e1.initialize, e2.initialize) || return false + isequal(e1.finalize, e2.finalize) || return false + isequal(e1.reinitializealg, e2.reinitializealg) || return false + if !is_discrete(e1) + isequal(e1.affect_neg, e2.affect_neg) || return false + isequal(e1.rootfind, e2.rootfind) || return false + isequal(e1.zero_crossing_id, e2.zero_crossing_id) || return false + end end Base.isempty(cb::AbstractCallback) = isempty(cb.conditions) From aeefc8abcfd6c00170352459d55f132b6ecfba1a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 20 Jun 2025 17:25:53 +0530 Subject: [PATCH 09/20] feat: add `ZeroCrossing` and `EventClock` from zero crossing --- src/discretedomain.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/discretedomain.jl b/src/discretedomain.jl index da8417de4e..370d93d894 100644 --- a/src/discretedomain.jl +++ b/src/discretedomain.jl @@ -365,3 +365,14 @@ function input_timedomain(x) throw(ArgumentError("$x of type $(typeof(x)) is not an operator expression")) end end + +function ZeroCrossing(expr; name = gensym(), up = true, down = true, kwargs...) + return SymbolicContinuousCallback( + [expr ~ 0], up ? ImperativeAffect(Returns(nothing)) : nothing; + affect_neg = down ? ImperativeAffect(Returns(nothing)) : nothing, + kwargs..., zero_crossing_id = name) +end + +function SciMLBase.Clocks.EventClock(cb::SymbolicContinuousCallback) + return SciMLBase.Clocks.EventClock(cb.zero_crossing_id) +end From fbef1a8bd5f27080b8d40c1100c5eb30764b50e0 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 27 Jun 2025 13:10:21 +0530 Subject: [PATCH 10/20] feat: subset variables appropriately in clock inference --- src/systems/clock_inference.jl | 4 ++-- src/systems/systemstructure.jl | 30 +++++++++++++++++++++--------- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index 86fe7e85ac..ff2d77f19b 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -199,8 +199,8 @@ function split_system(ci::ClockInference{S}) where {S} # breaks the system up into a continous and 0 or more discrete systems tss = similar(cid_to_eq, S) - for (id, ieqs) in enumerate(cid_to_eq) - ts_i = system_subset(ts, ieqs) + for (id, (ieqs, ivars)) in enumerate(zip(cid_to_eq, cid_to_var)) + ts_i = system_subset(ts, ieqs, ivars) if id != continuous_id ts_i = shift_discrete_system(ts_i) @set! ts_i.structure.only_discrete = true diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 66f3a4b6f9..3a0e0584d3 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -218,12 +218,12 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T} end TransformationState(sys::AbstractSystem) = TearingState(sys) -function system_subset(ts::TearingState, ieqs::Vector{Int}) +function system_subset(ts::TearingState, ieqs::Vector{Int}, ivars::Vector{Int}) eqs = equations(ts) @set! ts.original_eqs = ts.original_eqs[ieqs] @set! ts.sys.eqs = eqs[ieqs] @set! ts.original_eqs = ts.original_eqs[ieqs] - @set! ts.structure = system_subset(ts.structure, ieqs) + @set! ts.structure = system_subset(ts.structure, ieqs, ivars) if all(eq -> eq.rhs isa StateMachineOperator, get_eqs(ts.sys)) names = Symbol[] for eq in get_eqs(ts.sys) @@ -240,22 +240,33 @@ function system_subset(ts::TearingState, ieqs::Vector{Int}) else @set! ts.statemachines = eltype(ts.statemachines)[] end + @set! ts.fullvars = ts.fullvars[ivars] ts end -function system_subset(structure::SystemStructure, ieqs::Vector{Int}) - @unpack graph, eq_to_diff = structure +function system_subset(structure::SystemStructure, ieqs::Vector{Int}, ivars::Vector{Int}) + @unpack graph = structure fadj = Vector{Int}[] eq_to_diff = DiffGraph(length(ieqs)) + var_to_diff = DiffGraph(length(ivars)) + ne = 0 + old_to_new_var = zeros(Int, ndsts(graph)) + for (i, iv) in enumerate(ivars) + old_to_new_var[iv] = i + structure.var_to_diff[iv] === nothing && continue + var_to_diff[i] = old_to_new_var[structure.var_to_diff[iv]] + end for (j, eq_i) in enumerate(ieqs) - ivars = copy(graph.fadjlist[eq_i]) - ne += length(ivars) - push!(fadj, ivars) + var_adj = [old_to_new_var[i] for i in graph.fadjlist[eq_i]] + @assert all(!iszero, var_adj) + ne += length(var_adj) + push!(fadj, var_adj) eq_to_diff[j] = structure.eq_to_diff[eq_i] end - @set! structure.graph = complete(BipartiteGraph(ne, fadj, ndsts(graph))) + @set! structure.graph = complete(BipartiteGraph(ne, fadj, length(ivars))) @set! structure.eq_to_diff = eq_to_diff + @set! structure.var_to_diff = complete(var_to_diff) structure end @@ -440,7 +451,8 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) isdelay(v, iv) && continue if !symbolic_contains(v, dvs) - isvalid = iscall(v) && (operation(v) isa Shift || is_transparent_operator(operation(v))) + isvalid = iscall(v) && + (operation(v) isa Shift || is_transparent_operator(operation(v))) v′ = v while !isvalid && iscall(v′) && operation(v′) isa Union{Differential, Shift} v′ = arguments(v′)[1] From ea917fc060bc6beb695c1a649302162565efa978 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 27 Jun 2025 14:24:24 +0530 Subject: [PATCH 11/20] feat: add hook during problem construction --- src/systems/problem_utils.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index ecba542fd0..a35b3b1663 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -1272,6 +1272,8 @@ function get_p_constructor(p_constructor, pType::Type, floatT::Type) end end +abstract type ProblemConstructionHook end + """ $(TYPEDSIGNATURES) @@ -1324,6 +1326,8 @@ function process_SciMLProblem( check_inputmap_keys(sys, op) + op = getmetadata(sys, ProblemConstructionHook, identity)(op) + defs = add_toterms(recursive_unwrap(defaults(sys)); replace = is_discrete_system(sys)) kwargs = NamedTuple(kwargs) From 66d5b0760e382aebb7a8a04fcbbfca8449ce6c77 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 7 Jul 2025 15:40:07 +0530 Subject: [PATCH 12/20] fixup! feat: retain original equations of the system in `TearingState` --- src/systems/systemstructure.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 3a0e0584d3..1a76fc88c0 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -220,7 +220,6 @@ end TransformationState(sys::AbstractSystem) = TearingState(sys) function system_subset(ts::TearingState, ieqs::Vector{Int}, ivars::Vector{Int}) eqs = equations(ts) - @set! ts.original_eqs = ts.original_eqs[ieqs] @set! ts.sys.eqs = eqs[ieqs] @set! ts.original_eqs = ts.original_eqs[ieqs] @set! ts.structure = system_subset(ts.structure, ieqs, ivars) From 6e6138087614d0341c9b4be9c7f2237ddd89ff2f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Jun 2025 15:22:42 +0530 Subject: [PATCH 13/20] fix: fix `get_mtkparameters_reconstructor` handling of nonnumerics --- src/systems/problem_utils.jl | 40 ++++++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index a35b3b1663..1f5c021a3b 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -711,7 +711,8 @@ end $(TYPEDEF) A callable struct which applies `p_constructor` to possibly nested arrays. It also -ensures that views (including nested ones) are concretized. +ensures that views (including nested ones) are concretized. This is implemented manually +of using `narrow_buffer_type` to preserve type-stability. """ struct PConstructorApplicator{F} p_constructor::F @@ -721,10 +722,18 @@ function (pca::PConstructorApplicator)(x::AbstractArray) pca.p_constructor(x) end +function (pca::PConstructorApplicator)(x::AbstractArray{Bool}) + pca.p_constructor(BitArray(x)) +end + function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray) collect(x) end +function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray{Bool}) + BitArray(x) +end + function (pca::PConstructorApplicator{typeof(identity)})(x::SubArray{<:AbstractArray}) collect(pca.(x)) end @@ -749,6 +758,7 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns """ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::AbstractSystem; initials = false, unwrap_initials = false, p_constructor = identity) + _p_constructor = p_constructor p_constructor = PConstructorApplicator(p_constructor) # if we call `getu` on this (and it were able to handle empty tuples) we get the # fields of `MTKParameters` except caches. @@ -802,14 +812,24 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac Base.Fix1(broadcast, p_constructor) ∘ getu(srcsys, syms[3]) end - rest_getters = map(Base.tail(Base.tail(Base.tail(syms)))) do buf - if buf == () - return Returns(()) - else - return Base.Fix1(broadcast, p_constructor) ∘ getu(srcsys, buf) - end + const_getter = if syms[4] == () + Returns(()) + else + Base.Fix1(broadcast, p_constructor) ∘ getu(srcsys, syms[4]) end - getters = (tunable_getter, initials_getter, discs_getter, rest_getters...) + nonnumeric_getter = if syms[5] == () + Returns(()) + else + ic = get_index_cache(dstsys) + buftypes = Tuple(map(ic.nonnumeric_buffer_sizes) do bufsize + Vector{bufsize.type} + end) + # nonnumerics retain the assigned buffer type without narrowing + Base.Fix1(broadcast, _p_constructor) ∘ + Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) ∘ getu(srcsys, syms[5]) + end + getters = ( + tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter) getter = let getters = getters function _getter(valp, initprob) oldcache = parameter_values(initprob).caches @@ -822,6 +842,10 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac return getter end +function call(f, args...) + f(args...) +end + """ $(TYPEDSIGNATURES) From bcd21af0562582cdb8495453d8790fb902afb345 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Jun 2025 14:23:30 +0530 Subject: [PATCH 14/20] test: test nonnumerics aren't narrowed in `ODEProblem` and `init` --- test/initializationsystem.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index 3be3e400c3..8f37ca905b 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -1670,3 +1670,23 @@ end prob = ODEProblem(sys, [x[1] => nothing, x[2] => 1], (0.0, 1.0)) @test SciMLBase.initialization_status(prob) == SciMLBase.FULLY_DETERMINED end + +@testset "Nonnumerics aren't narrowed" begin + @mtkmodel Foo begin + @variables begin + x(t) = 1.0 + end + @parameters begin + p::AbstractString + r = 1.0 + end + @equations begin + D(x) ~ r * x + end + end + @mtkbuild sys = Foo(p = "a") + prob = ODEProblem(sys, [], (0.0, 1.0)) + @test prob.p.nonnumeric[1] isa Vector{AbstractString} + integ = init(prob) + @test integ.p.nonnumeric[1] isa Vector{AbstractString} +end From cd9829695c0d65961b6bc6bf342783b429f79a20 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 8 Jul 2025 17:29:19 +0530 Subject: [PATCH 15/20] fix: handle `Union` types in `BufferTemplate` --- src/systems/index_cache.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 2ce1c7cffa..75128554bc 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -1,5 +1,5 @@ struct BufferTemplate - type::Union{DataType, UnionAll} + type::Union{DataType, UnionAll, Union} length::Int end From c4cc348227ef9ef7995976d55346f123343d8351 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 10 Jul 2025 00:21:02 +0530 Subject: [PATCH 16/20] feat: rewrite clock inference to support polyadic synchronous operators and other cool stuff too --- src/clock.jl | 6 +- src/discretedomain.jl | 21 +++- src/systems/clock_inference.jl | 187 +++++++++++++++++++++++++++------ src/systems/connectiongraph.jl | 25 +++-- 4 files changed, 189 insertions(+), 50 deletions(-) diff --git a/src/clock.jl b/src/clock.jl index 1c9ed89128..a230c54210 100644 --- a/src/clock.jl +++ b/src/clock.jl @@ -1,11 +1,15 @@ @data InferredClock begin Inferred - InferredDiscrete + InferredDiscrete(Int) end const InferredTimeDomain = InferredClock.Type using .InferredClock: Inferred, InferredDiscrete +function InferredClock.InferredDiscrete() + return InferredDiscrete(0) +end + Base.Broadcast.broadcastable(x::InferredTimeDomain) = Ref(x) struct VariableTimeDomain end diff --git a/src/discretedomain.jl b/src/discretedomain.jl index 370d93d894..e33cf1e429 100644 --- a/src/discretedomain.jl +++ b/src/discretedomain.jl @@ -10,6 +10,15 @@ are not transparent but `Sample` and `Hold` are. Defaults to `false` if not impl is_transparent_operator(x) = is_transparent_operator(typeof(x)) is_transparent_operator(::Type) = false +""" + $(TYPEDSIGNATURES) + +Trait to be implemented for operators which determines whether they are synchronous operators. +Synchronous operators must implement `input_timedomain` and `output_timedomain`. +""" +is_synchronous_operator(x) = is_synchronous_operator(typeof(x)) +is_synchronous_operator(::Type) = false + """ function SampleTime() @@ -52,6 +61,7 @@ struct Shift <: Operator end Shift(steps::Int) = new(nothing, steps) normalize_to_differential(s::Shift) = Differential(s.t)^s.steps +is_synchronous_operator(::Type{Shift}) = true Base.nameof(::Shift) = :Shift SymbolicUtils.isbinop(::Shift) = false @@ -138,6 +148,7 @@ struct Sample <: Operator Sample(clock::Union{TimeDomain, InferredTimeDomain} = InferredDiscrete()) = new(clock) end +is_synchronous_operator(::Type{Sample}) = true is_transparent_operator(::Type{Sample}) = true function Sample(arg::Real) @@ -193,6 +204,7 @@ struct Hold <: Operator end is_transparent_operator(::Type{Hold}) = true +is_synchronous_operator(::Type{Hold}) = true (D::Hold)(x) = Term{symtype(x)}(D, Any[x]) (D::Hold)(x::Num) = Num(D(value(x))) @@ -314,12 +326,13 @@ Base.:-(k::ShiftIndex, i::Int) = k + (-i) input_timedomain(op::Operator) Return the time-domain type (`ContinuousClock()` or `InferredDiscrete()`) that `op` operates on. +Should return a tuple containing the time domain type for each argument to the operator. """ function input_timedomain(s::Shift, arg = nothing) if has_time_domain(arg) return get_time_domain(arg) end - InferredDiscrete() + (InferredDiscrete(),) end """ @@ -334,22 +347,20 @@ function output_timedomain(s::Shift, arg = nothing) InferredDiscrete() end -input_timedomain(::Sample, _ = nothing) = ContinuousClock() +input_timedomain(::Sample, _ = nothing) = (ContinuousClock(),) output_timedomain(s::Sample, _ = nothing) = s.clock function input_timedomain(h::Hold, arg = nothing) if has_time_domain(arg) return get_time_domain(arg) end - InferredDiscrete() # the Hold accepts any discrete + (InferredDiscrete(),) # the Hold accepts any discrete end output_timedomain(::Hold, _ = nothing) = ContinuousClock() sampletime(op::Sample, _ = nothing) = sampletime(op.clock) sampletime(op::ShiftIndex, _ = nothing) = sampletime(op.clock) -changes_domain(op) = isoperator(op, Union{Sample, Hold}) - function output_timedomain(x) if isoperator(x, Operator) return output_timedomain(operation(x), arguments(x)[]) diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index ff2d77f19b..24ac8c11ee 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -1,3 +1,9 @@ +@data ClockVertex begin + Variable(Int) + Equation(Int) + Clock(SciMLBase.AbstractClock) +end + struct ClockInference{S} """Tearing state.""" ts::S @@ -5,6 +11,7 @@ struct ClockInference{S} eq_domain::Vector{TimeDomain} """The output time domain (discrete clock, continuous) of each variable.""" var_domain::Vector{TimeDomain} + inference_graph::HyperGraph{ClockVertex.Type} """The set of variables with concrete domains.""" inferred::BitSet end @@ -22,7 +29,21 @@ function ClockInference(ts::TransformationState) var_domain[i] = d end end - ClockInference(ts, eq_domain, var_domain, inferred) + inference_graph = HyperGraph{ClockVertex.Type}() + for i in 1:nsrcs(graph) + add_vertex!(inference_graph, ClockVertex.Equation(i)) + end + for i in 1:ndsts(graph) + varvert = ClockVertex.Variable(i) + add_vertex!(inference_graph, varvert) + v = ts.fullvars[i] + d = get_time_domain(v) + is_concrete_time_domain(d) || continue + dvert = ClockVertex.Clock(d) + add_vertex!(inference_graph, dvert) + add_edge!(inference_graph, (varvert, dvert)) + end + ClockInference(ts, eq_domain, var_domain, inference_graph, inferred) end struct NotInferredTimeDomain end @@ -75,47 +96,147 @@ end Update the equation-to-time domain mapping by inferring the time domain from the variables. """ function infer_clocks!(ci::ClockInference) - @unpack ts, eq_domain, var_domain, inferred = ci + @unpack ts, eq_domain, var_domain, inferred, inference_graph = ci @unpack var_to_diff, graph = ts.structure fullvars = get_fullvars(ts) isempty(inferred) && return ci - # TODO: add a graph type to do this lazily - var_graph = SimpleGraph(ndsts(graph)) - for eq in 𝑠vertices(graph) - vvs = 𝑠neighbors(graph, eq) - if !isempty(vvs) - fv, vs = Iterators.peel(vvs) - for v in vs - add_edge!(var_graph, fv, v) - end - end + + var_to_idx = Dict(fullvars .=> eachindex(fullvars)) + + # all shifted variables have the same clock as the unshifted variant + for (i, v) in enumerate(fullvars) + iscall(v) || continue + operation(v) isa Shift || continue + unshifted = only(arguments(v)) + add_edge!(inference_graph, (ClockVertex.Variable(i), ClockVertex.Variable(var_to_idx[unshifted]))) end - for v in vertices(var_to_diff) - if (v′ = var_to_diff[v]) !== nothing - add_edge!(var_graph, v, v′) + + # preallocated buffers: + # variables in each equation + varsbuf = Set() + # variables in each argument to an operator + arg_varsbuf = Set() + # hyperedge for each equation + hyperedge = Set{ClockVertex.Type}() + # hyperedge for each argument to an operator + arg_hyperedge = Set{ClockVertex.Type}() + # mapping from `i` in `InferredDiscrete(i)` to the vertices in that inferred partition + relative_hyperedges = Dict{Int, Set{ClockVertex.Type}}() + + for (ieq, eq) in enumerate(equations(ts)) + empty!(varsbuf) + empty!(hyperedge) + # get variables in equation + vars!(varsbuf, eq; op = Symbolics.Operator) + # add the equation to the hyperedge + push!(hyperedge, ClockVertex.Equation(ieq)) + for var in varsbuf + idx = get(var_to_idx, var, nothing) + # if this is just a single variable, add it to the hyperedge + if idx isa Int + push!(hyperedge, ClockVertex.Variable(idx)) + # we don't immediately `continue` here because this variable might be a + # `Sample` or similar and we want the clock information from it if it is. + end + # now we only care about synchronous operators + iscall(var) || continue + op = operation(var) + is_synchronous_operator(op) || continue + + # arguments and corresponding time domains + args = arguments(var) + tdomains = input_timedomain(op) + nargs = length(args) + ndoms = length(tdomains) + if nargs != ndoms + throw(ArgumentError(""" + Operator $op applied to $nargs arguments $args but only returns $ndoms \ + domains $tdomains from `input_timedomain`. + """)) + end + + # each relative clock mapping is only valid per operator application + empty!(relative_hyperedges) + for (arg, domain) in zip(args, tdomains) + empty!(arg_varsbuf) + empty!(arg_hyperedge) + # get variables in argument + vars!(arg_varsbuf, arg; op = Union{Differential, Shift}) + # get hyperedge for involved variables + for v in arg_varsbuf + vidx = get(var_to_idx, v, nothing) + vidx === nothing && continue + push!(arg_hyperedge, ClockVertex.Variable(vidx)) + end + + Moshi.Match.@match domain begin + # If the time domain for this argument is a clock, then all variables in this edge have that clock. + x::SciMLBase.AbstractClock => begin + # add the clock to the edge + push!(arg_hyperedge, ClockVertex.Clock(x)) + # add the edge to the graph + add_edge!(inference_graph, arg_hyperedge) + end + # We only know that this time domain is inferred. Treat it as a unique domain, all we know is that the + # involved variables have the same clock. + InferredClock.Inferred() => add_edge!(inference_graph, arg_hyperedge) + # All `InferredDiscrete` with the same `i` have the same clock (including output domain) so we don't + # add the edge, and instead add this to the `relative_hyperedges` mapping. + InferredClock.InferredDiscrete(i) => begin + relative_edge = get!(() -> Set{ClockVertex.Type}(), relative_hyperedges, i) + union!(relative_edge, arg_hyperedge) + end + end + end + + outdomain = output_timedomain(op) + Moshi.Match.@match outdomain begin + x::SciMLBase.AbstractClock => begin + push!(hyperedge, ClockVertex.Clock(x)) + end + InferredClock.Inferred() => nothing + InferredClock.InferredDiscrete(i) => begin + buffer = get(relative_hyperedges, i, nothing) + if buffer !== nothing + union!(hyperedge, buffer) + delete!(relative_hyperedges, i) + end + end + end + + for (_, relative_edge) in relative_hyperedges + add_edge!(inference_graph, relative_edge) + end end + + add_edge!(inference_graph, hyperedge) end - cc = connected_components(var_graph) - for c′ in cc - c = BitSet(c′) - idxs = intersect(c, inferred) - isempty(idxs) && continue - if !allequal(iscontinuous(var_domain[i]) for i in idxs) - display(fullvars[c′]) - throw(ClockInferenceException("Clocks are not consistent in connected component $(fullvars[c′])")) + + clock_partitions = connectionsets(inference_graph) + for partition in clock_partitions + clockidxs = findall(vert -> Moshi.Data.isa_variant(vert, ClockVertex.Clock), partition) + if isempty(clockidxs) + vidxs = Int[vert.:1 for vert in partition if Moshi.Data.isa_variant(vert, ClockVertex.Variable)] + throw(ArgumentError(""" + Found clock partion with no associated clock. Involved variables: $(fullvars[vidxs]). + """)) end - vd = var_domain[first(idxs)] - for v in c′ - var_domain[v] = vd + if length(clockidxs) > 1 + vidxs = Int[vert.:1 for vert in partition if Moshi.Data.isa_variant(vert, ClockVertex.Variable)] + clks = [vert.:1 for vert in view(partition, clockidxs)] + throw(ArgumentError(""" + Found clock partition with multiple associated clocks. Involved variables: \ + $(fullvars[vidxs]). Involved clocks: $(clks). + """)) end - end - for v in 𝑑vertices(graph) - vd = var_domain[v] - eqs = 𝑑neighbors(graph, v) - isempty(eqs) && continue - for eq in eqs - eq_domain[eq] = vd + clock = partition[only(clockidxs)].:1 + for vert in partition + Moshi.Match.@match vert begin + ClockVertex.Variable(i) => (var_domain[i] = clock) + ClockVertex.Equation(i) => (eq_domain[i] = clock) + ClockVertex.Clock(_) => nothing + end end end diff --git a/src/systems/connectiongraph.jl b/src/systems/connectiongraph.jl index 5c5e8716c6..99110e37e9 100644 --- a/src/systems/connectiongraph.jl +++ b/src/systems/connectiongraph.jl @@ -119,15 +119,15 @@ connection sets. $(TYPEDFIELDS) """ -struct ConnectionGraph +struct HyperGraph{V} """ Mapping from vertices to their integer ID. """ - labels::Dict{ConnectionVertex, Int} + labels::Dict{V, Int} """ Reverse mapping from integer ID to vertices. """ - invmap::Vector{ConnectionVertex} + invmap::Vector{V} """ Core data structure for storing the hypergraph. Each hyperedge is a source vertex and has bipartite edges to the connection vertices it is incident on. @@ -135,14 +135,16 @@ struct ConnectionGraph graph::BipartiteGraph{Int, Nothing} end +const ConnectionGraph = HyperGraph{ConnectionVertex} + """ $(TYPEDSIGNATURES) Create an empty `ConnectionGraph`. """ -function ConnectionGraph() +function HyperGraph{V}() where {V} graph = BipartiteGraph(0, 0, Val(true)) - return ConnectionGraph(Dict{ConnectionVertex, Int}(), ConnectionVertex[], graph) + return HyperGraph{V}(Dict{V, Int}(), V[], graph) end function Base.show(io::IO, graph::ConnectionGraph) @@ -178,7 +180,7 @@ end Add the given vertex to the connection graph. Return the integer ID of the added vertex. No-op if the vertex already exists. """ -function Graphs.add_vertex!(graph::ConnectionGraph, dst::ConnectionVertex) +function Graphs.add_vertex!(graph::HyperGraph{V}, dst::V) where {V} j = get(graph.labels, dst, 0) iszero(j) || return j j = Graphs.add_vertex!(graph.graph, DST) @@ -188,7 +190,8 @@ function Graphs.add_vertex!(graph::ConnectionGraph, dst::ConnectionVertex) return j end -const ConnectionGraphEdge = Union{Vector{ConnectionVertex}, Tuple{Vararg{ConnectionVertex}}} +const HyperGraphEdge{V} = Union{Vector{V}, Tuple{Vararg{V}}, Set{V}} +const ConnectionGraphEdge = HyperGraphEdge{ConnectionVertex} """ $(TYPEDSIGNATURES) @@ -196,7 +199,7 @@ const ConnectionGraphEdge = Union{Vector{ConnectionVertex}, Tuple{Vararg{Connect Add the given hyperedge to the connection graph. Adds all vertices in the given edge if they do not exist. Returns the integer ID of the added edge. """ -function Graphs.add_edge!(graph::ConnectionGraph, src::ConnectionGraphEdge) +function Graphs.add_edge!(graph::HyperGraph{V}, src::HyperGraphEdge{V}) where {V} i = Graphs.add_vertex!(graph.graph, SRC) for vert in src j = Graphs.add_vertex!(graph, vert) @@ -447,7 +450,7 @@ end Return the merged connection sets in `graph` as a `Vector{Vector{ConnectionVertex}}`. These are equivalent to the connected components of `graph`. """ -function connectionsets(graph::ConnectionGraph) +function connectionsets(graph::HyperGraph{V}) where {V} bigraph = graph.graph invmap = graph.invmap @@ -465,11 +468,11 @@ function connectionsets(graph::ConnectionGraph) # maps the root of a vertex in `disjoint_sets` to the index of the corresponding set # in `vertex_sets` root_to_set = Dict{Int, Int}() - vertex_sets = Vector{ConnectionVertex}[] + vertex_sets = Vector{V}[] for (vert_i, vert) in enumerate(invmap) root = find_root!(disjoint_sets, vert_i) set_i = get!(root_to_set, root) do - push!(vertex_sets, ConnectionVertex[]) + push!(vertex_sets, V[]) return length(vertex_sets) end push!(vertex_sets[set_i], vert) From 48763c3a4385436ac282835b3f25af8cb0a88006 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Thu, 10 Jul 2025 17:08:34 -0700 Subject: [PATCH 17/20] Better support for multi-adic operators --- src/clock.jl | 1 + src/discretedomain.jl | 6 ++++-- src/systems/clock_inference.jl | 7 ++++++- src/systems/systemstructure.jl | 5 +++-- 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/clock.jl b/src/clock.jl index a230c54210..8dc4e293f7 100644 --- a/src/clock.jl +++ b/src/clock.jl @@ -101,6 +101,7 @@ function is_discrete_domain(x) end sampletime(c) = Moshi.Match.@match c begin + x::SciMLBase.AbstractClock => nothing PeriodicClock(dt) => dt _ => nothing end diff --git a/src/discretedomain.jl b/src/discretedomain.jl index e33cf1e429..30795cf655 100644 --- a/src/discretedomain.jl +++ b/src/discretedomain.jl @@ -363,7 +363,8 @@ sampletime(op::ShiftIndex, _ = nothing) = sampletime(op.clock) function output_timedomain(x) if isoperator(x, Operator) - return output_timedomain(operation(x), arguments(x)[]) + args = arguments(x) + return output_timedomain(operation(x), if length(args) == 1 args[] else args end) else throw(ArgumentError("$x of type $(typeof(x)) is not an operator expression")) end @@ -371,7 +372,8 @@ end function input_timedomain(x) if isoperator(x, Operator) - return input_timedomain(operation(x), arguments(x)[]) + args = arguments(x) + return input_timedomain(operation(x), if length(args) == 1 args[] else args end) else throw(ArgumentError("$x of type $(typeof(x)) is not an operator expression")) end diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index 24ac8c11ee..c4f39d15c8 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -141,11 +141,16 @@ function infer_clocks!(ci::ClockInference) # now we only care about synchronous operators iscall(var) || continue op = operation(var) - is_synchronous_operator(op) || continue + if (!is_synchronous_operator(op)) && !(op isa Differential) + continue + end # arguments and corresponding time domains args = arguments(var) tdomains = input_timedomain(op) + if !(tdomains isa AbstractArray || tdomains isa Tuple) + tdomains = [tdomains] + end nargs = length(args) ndoms = length(tdomains) if nargs != ndoms diff --git a/src/systems/systemstructure.jl b/src/systems/systemstructure.jl index 1a76fc88c0..fb515ea436 100644 --- a/src/systems/systemstructure.jl +++ b/src/systems/systemstructure.jl @@ -467,8 +467,9 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true) addvar!(v, VARIABLE) if iscall(v) && operation(v) isa Symbolics.Operator && !isdifferential(v) && (it = input_timedomain(v)) !== nothing - v′ = only(arguments(v)) - addvar!(setmetadata(v′, VariableTimeDomain, it), VARIABLE) + for v′ ∈ arguments(v) + addvar!(setmetadata(v′, VariableTimeDomain, it), VARIABLE) + end end end From f03c251e22eaf2c65d3132c495fef0778fe2d8ed Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 14 Jul 2025 12:33:49 +0530 Subject: [PATCH 18/20] refactor: replace `is_synchronous_operator` with `is_timevarying_operator` --- src/discretedomain.jl | 16 +++++++++------- src/systems/abstractsystem.jl | 3 +-- src/systems/callbacks.jl | 3 +-- src/systems/clock_inference.jl | 4 +--- 4 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/discretedomain.jl b/src/discretedomain.jl index 30795cf655..e1fbfaa236 100644 --- a/src/discretedomain.jl +++ b/src/discretedomain.jl @@ -13,11 +13,16 @@ is_transparent_operator(::Type) = false """ $(TYPEDSIGNATURES) -Trait to be implemented for operators which determines whether they are synchronous operators. -Synchronous operators must implement `input_timedomain` and `output_timedomain`. +Trait to be implemented for operators which determines whether the operator is applied to +a time-varying quantity and results in a time-varying quantity. For example, `Initial` and +`Pre` are not time-varying since while they are applied to variables, the application +results in a non-discrete-time parameter. `Differential`, `Shift`, `Sample` and `Hold` are +all time-varying operators. All time-varying operators must implement `input_timedomain` and +`output_timedomain`. """ -is_synchronous_operator(x) = is_synchronous_operator(typeof(x)) -is_synchronous_operator(::Type) = false +is_timevarying_operator(x) = is_timevarying_operator(typeof(x)) +is_timevarying_operator(::Type{<:Symbolics.Operator}) = true +is_timevarying_operator(::Type) = false """ function SampleTime() @@ -61,7 +66,6 @@ struct Shift <: Operator end Shift(steps::Int) = new(nothing, steps) normalize_to_differential(s::Shift) = Differential(s.t)^s.steps -is_synchronous_operator(::Type{Shift}) = true Base.nameof(::Shift) = :Shift SymbolicUtils.isbinop(::Shift) = false @@ -148,7 +152,6 @@ struct Sample <: Operator Sample(clock::Union{TimeDomain, InferredTimeDomain} = InferredDiscrete()) = new(clock) end -is_synchronous_operator(::Type{Sample}) = true is_transparent_operator(::Type{Sample}) = true function Sample(arg::Real) @@ -204,7 +207,6 @@ struct Hold <: Operator end is_transparent_operator(::Type{Hold}) = true -is_synchronous_operator(::Type{Hold}) = true (D::Hold)(x) = Term{symtype(x)}(D, Any[x]) (D::Hold)(x::Num) = Num(D(value(x))) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 34dfb5283e..cb5d9911e7 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -486,13 +486,12 @@ The `Initial` operator. Used by initialization to store constant constraints on of a system. See the documentation section on initialization for more information. """ struct Initial <: Symbolics.Operator end +is_timevarying_operator(::Type{Initial}) = false Initial(x) = Initial()(x) SymbolicUtils.promote_symtype(::Type{Initial}, T) = T SymbolicUtils.isbinop(::Initial) = false Base.nameof(::Initial) = :Initial Base.show(io::IO, x::Initial) = print(io, "Initial") -input_timedomain(::Initial, _ = nothing) = ContinuousClock() -output_timedomain(::Initial, _ = nothing) = ContinuousClock() function (f::Initial)(x) # wrap output if wrapped input diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index f05e455cbc..0ddc3d10cd 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -56,12 +56,11 @@ before the callback is triggered. """ struct Pre <: Symbolics.Operator end Pre(x) = Pre()(x) +is_timevarying_operator(::Type{Pre}) = false SymbolicUtils.promote_symtype(::Type{Pre}, T) = T SymbolicUtils.isbinop(::Pre) = false Base.nameof(::Pre) = :Pre Base.show(io::IO, x::Pre) = print(io, "Pre") -input_timedomain(::Pre, _ = nothing) = ContinuousClock() -output_timedomain(::Pre, _ = nothing) = ContinuousClock() unPre(x::Num) = unPre(unwrap(x)) unPre(x::Symbolics.Arr) = unPre(unwrap(x)) unPre(x::Symbolic) = (iscall(x) && operation(x) isa Pre) ? only(arguments(x)) : x diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index c4f39d15c8..3fef25993e 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -141,9 +141,7 @@ function infer_clocks!(ci::ClockInference) # now we only care about synchronous operators iscall(var) || continue op = operation(var) - if (!is_synchronous_operator(op)) && !(op isa Differential) - continue - end + is_timevarying_operator(op) || continue # arguments and corresponding time domains args = arguments(var) From cdea28e52495af513abfe40d5c1c5114ce37b952 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 14 Jul 2025 12:34:06 +0530 Subject: [PATCH 19/20] fix: fix `is_time_domain_conversion` for new `input_timedomain` --- src/systems/clock_inference.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index 3fef25993e..ef8f75a230 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -259,8 +259,14 @@ function resize_or_push!(v, val, idx) end function is_time_domain_conversion(v) - iscall(v) && (o = operation(v)) isa Operator && - input_timedomain(o) != output_timedomain(o) + iscall(v) || return false + o = operation(v) + o isa Operator || return false + itd = input_timedomain(o) + allequal(itd) || return true + otd = output_timedomain(o) + itd[1] == otd || return true + return false end """ From 0b470915035c1140ef276431d17397c13da35150 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 14 Jul 2025 12:34:19 +0530 Subject: [PATCH 20/20] fix: fix `input_timedomain` implementation for `Differential` --- src/clock.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/clock.jl b/src/clock.jl index 8dc4e293f7..df3b6f4b47 100644 --- a/src/clock.jl +++ b/src/clock.jl @@ -54,7 +54,7 @@ has_time_domain(x::Num) = has_time_domain(value(x)) has_time_domain(x) = false for op in [Differential] - @eval input_timedomain(::$op, arg = nothing) = ContinuousClock() + @eval input_timedomain(::$op, arg = nothing) = (ContinuousClock(),) @eval output_timedomain(::$op, arg = nothing) = ContinuousClock() end