Skip to content

Add support for an external synchronous compiler to discrete and hybrid systems #3399

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 21 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
4988ee1
Early work on the new discrete backend for MTK
BenChung Feb 18, 2025
ed0612b
feat: retain original equations of the system in `TearingState`
AayushSabharwal Feb 20, 2025
7f8b8f2
feat: allow namespacing statemachine equations
AayushSabharwal Mar 14, 2025
7cf774d
feat: propagate state machines in structural simplification
AayushSabharwal Mar 14, 2025
c081a34
Handle nothing updates better
BenChung Mar 15, 2025
b60be79
Redefine the discrete_compile interface a bit
BenChung Mar 15, 2025
405aafa
Change the external synchronous signature to include the id/clock map
BenChung May 14, 2025
865523b
feat: add `zero_crossing_id` to `SymbolicContinuousCallback`
AayushSabharwal Jun 20, 2025
aeefc8a
feat: add `ZeroCrossing` and `EventClock` from zero crossing
AayushSabharwal Jun 20, 2025
fbef1a8
feat: subset variables appropriately in clock inference
AayushSabharwal Jun 27, 2025
ea917fc
feat: add hook during problem construction
AayushSabharwal Jun 27, 2025
66d5b07
fixup! feat: retain original equations of the system in `TearingState`
AayushSabharwal Jul 7, 2025
6e61380
fix: fix `get_mtkparameters_reconstructor` handling of nonnumerics
AayushSabharwal Jun 6, 2025
bcd21af
test: test nonnumerics aren't narrowed in `ODEProblem` and `init`
AayushSabharwal Jun 6, 2025
cd98296
fix: handle `Union` types in `BufferTemplate`
AayushSabharwal Jul 8, 2025
c4cc348
feat: rewrite clock inference to support polyadic synchronous operators
AayushSabharwal Jul 9, 2025
feebc16
Merge pull request #3808 from SciML/as/clock-inference
BenChung Jul 9, 2025
48763c3
Better support for multi-adic operators
BenChung Jul 11, 2025
f03c251
refactor: replace `is_synchronous_operator` with `is_timevarying_oper…
AayushSabharwal Jul 14, 2025
cdea28e
fix: fix `is_time_domain_conversion` for new `input_timedomain`
AayushSabharwal Jul 14, 2025
0b47091
fix: fix `input_timedomain` implementation for `Differential`
AayushSabharwal Jul 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/clock.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
@data InferredClock begin
Inferred
InferredDiscrete
InferredDiscrete(Int)
end

const InferredTimeDomain = InferredClock.Type
using .InferredClock: Inferred, InferredDiscrete

function InferredClock.InferredDiscrete()
return InferredDiscrete(0)
end

Base.Broadcast.broadcastable(x::InferredTimeDomain) = Ref(x)

struct VariableTimeDomain end
Expand Down Expand Up @@ -50,7 +54,7 @@ has_time_domain(x::Num) = has_time_domain(value(x))
has_time_domain(x) = false

for op in [Differential]
@eval input_timedomain(::$op, arg = nothing) = ContinuousClock()
@eval input_timedomain(::$op, arg = nothing) = (ContinuousClock(),)
@eval output_timedomain(::$op, arg = nothing) = ContinuousClock()
end

Expand Down Expand Up @@ -97,6 +101,7 @@ function is_discrete_domain(x)
end

sampletime(c) = Moshi.Match.@match c begin
x::SciMLBase.AbstractClock => nothing
PeriodicClock(dt) => dt
_ => nothing
end
Expand Down
40 changes: 33 additions & 7 deletions src/discretedomain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,20 @@ are not transparent but `Sample` and `Hold` are. Defaults to `false` if not impl
is_transparent_operator(x) = is_transparent_operator(typeof(x))
is_transparent_operator(::Type) = false

"""
$(TYPEDSIGNATURES)

Trait to be implemented for operators which determines whether the operator is applied to
a time-varying quantity and results in a time-varying quantity. For example, `Initial` and
`Pre` are not time-varying since while they are applied to variables, the application
results in a non-discrete-time parameter. `Differential`, `Shift`, `Sample` and `Hold` are
all time-varying operators. All time-varying operators must implement `input_timedomain` and
`output_timedomain`.
"""
is_timevarying_operator(x) = is_timevarying_operator(typeof(x))
is_timevarying_operator(::Type{<:Symbolics.Operator}) = true
is_timevarying_operator(::Type) = false

"""
function SampleTime()

Expand Down Expand Up @@ -314,12 +328,13 @@ Base.:-(k::ShiftIndex, i::Int) = k + (-i)
input_timedomain(op::Operator)

Return the time-domain type (`ContinuousClock()` or `InferredDiscrete()`) that `op` operates on.
Should return a tuple containing the time domain type for each argument to the operator.
"""
function input_timedomain(s::Shift, arg = nothing)
if has_time_domain(arg)
return get_time_domain(arg)
end
InferredDiscrete()
(InferredDiscrete(),)
end

"""
Expand All @@ -334,34 +349,45 @@ function output_timedomain(s::Shift, arg = nothing)
InferredDiscrete()
end

input_timedomain(::Sample, _ = nothing) = ContinuousClock()
input_timedomain(::Sample, _ = nothing) = (ContinuousClock(),)
output_timedomain(s::Sample, _ = nothing) = s.clock

function input_timedomain(h::Hold, arg = nothing)
if has_time_domain(arg)
return get_time_domain(arg)
end
InferredDiscrete() # the Hold accepts any discrete
(InferredDiscrete(),) # the Hold accepts any discrete
end
output_timedomain(::Hold, _ = nothing) = ContinuousClock()

sampletime(op::Sample, _ = nothing) = sampletime(op.clock)
sampletime(op::ShiftIndex, _ = nothing) = sampletime(op.clock)

changes_domain(op) = isoperator(op, Union{Sample, Hold})

function output_timedomain(x)
if isoperator(x, Operator)
return output_timedomain(operation(x), arguments(x)[])
args = arguments(x)
return output_timedomain(operation(x), if length(args) == 1 args[] else args end)
else
throw(ArgumentError("$x of type $(typeof(x)) is not an operator expression"))
end
end

function input_timedomain(x)
if isoperator(x, Operator)
return input_timedomain(operation(x), arguments(x)[])
args = arguments(x)
return input_timedomain(operation(x), if length(args) == 1 args[] else args end)
else
throw(ArgumentError("$x of type $(typeof(x)) is not an operator expression"))
end
end

function ZeroCrossing(expr; name = gensym(), up = true, down = true, kwargs...)
return SymbolicContinuousCallback(
[expr ~ 0], up ? ImperativeAffect(Returns(nothing)) : nothing;
affect_neg = down ? ImperativeAffect(Returns(nothing)) : nothing,
kwargs..., zero_crossing_id = name)
end

function SciMLBase.Clocks.EventClock(cb::SymbolicContinuousCallback)
return SciMLBase.Clocks.EventClock(cb.zero_crossing_id)
end
4 changes: 2 additions & 2 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -486,13 +486,12 @@ The `Initial` operator. Used by initialization to store constant constraints on
of a system. See the documentation section on initialization for more information.
"""
struct Initial <: Symbolics.Operator end
is_timevarying_operator(::Type{Initial}) = false
Initial(x) = Initial()(x)
SymbolicUtils.promote_symtype(::Type{Initial}, T) = T
SymbolicUtils.isbinop(::Initial) = false
Base.nameof(::Initial) = :Initial
Base.show(io::IO, x::Initial) = print(io, "Initial")
input_timedomain(::Initial, _ = nothing) = ContinuousClock()
output_timedomain(::Initial, _ = nothing) = ContinuousClock()

function (f::Initial)(x)
# wrap output if wrapped input
Expand Down Expand Up @@ -1228,6 +1227,7 @@ function namespace_expr(
O
end
end

_nonum(@nospecialize x) = x isa Num ? x.val : x

"""
Expand Down
29 changes: 18 additions & 11 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,11 @@ before the callback is triggered.
"""
struct Pre <: Symbolics.Operator end
Pre(x) = Pre()(x)
is_timevarying_operator(::Type{Pre}) = false
SymbolicUtils.promote_symtype(::Type{Pre}, T) = T
SymbolicUtils.isbinop(::Pre) = false
Base.nameof(::Pre) = :Pre
Base.show(io::IO, x::Pre) = print(io, "Pre")
input_timedomain(::Pre, _ = nothing) = ContinuousClock()
output_timedomain(::Pre, _ = nothing) = ContinuousClock()
unPre(x::Num) = unPre(unwrap(x))
unPre(x::Symbolics.Arr) = unPre(unwrap(x))
unPre(x::Symbolic) = (iscall(x) && operation(x) isa Pre) ? only(arguments(x)) : x
Expand Down Expand Up @@ -165,6 +164,7 @@ struct SymbolicContinuousCallback <: AbstractCallback
finalize::Union{Affect, Nothing}
rootfind::Union{Nothing, SciMLBase.RootfindOpt}
reinitializealg::SciMLBase.DAEInitializationAlgorithm
zero_crossing_id::Symbol

function SymbolicContinuousCallback(
conditions::Union{Equation, Vector{Equation}},
Expand All @@ -174,6 +174,7 @@ struct SymbolicContinuousCallback <: AbstractCallback
finalize = nothing,
rootfind = SciMLBase.LeftRootFind,
reinitializealg = nothing,
zero_crossing_id = gensym(),
kwargs...)
conditions = (conditions isa AbstractVector) ? conditions : [conditions]

Expand All @@ -190,7 +191,7 @@ struct SymbolicContinuousCallback <: AbstractCallback
make_affect(affect_neg; kwargs...),
make_affect(initialize; kwargs...), make_affect(
finalize; kwargs...),
rootfind, reinitializealg)
rootfind, reinitializealg, zero_crossing_id)
end # Default affect to nothing
end

Expand Down Expand Up @@ -466,7 +467,8 @@ function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuo
affect_neg = namespace_affects(affect_negs(cb), s),
initialize = namespace_affects(initialize_affects(cb), s),
finalize = namespace_affects(finalize_affects(cb), s),
rootfind = cb.rootfind, reinitializealg = cb.reinitializealg)
rootfind = cb.rootfind, reinitializealg = cb.reinitializealg,
zero_crossing_id = cb.zero_crossing_id)
end

function namespace_conditions(condition, s)
Expand All @@ -490,6 +492,8 @@ function Base.hash(cb::AbstractCallback, s::UInt)
s = hash(finalize_affects(cb), s)
!is_discrete(cb) && (s = hash(cb.rootfind, s))
hash(cb.reinitializealg, s)
!is_discrete(cb) && (s = hash(cb.zero_crossing_id, s))
return s
end

###########################
Expand Down Expand Up @@ -524,13 +528,16 @@ function finalize_affects(cbs::Vector{<:AbstractCallback})
end

function Base.:(==)(e1::AbstractCallback, e2::AbstractCallback)
(is_discrete(e1) === is_discrete(e2)) || return false
(isequal(e1.conditions, e2.conditions) && isequal(e1.affect, e2.affect) &&
isequal(e1.initialize, e2.initialize) && isequal(e1.finalize, e2.finalize)) &&
isequal(e1.reinitializealg, e2.reinitializealg) ||
return false
is_discrete(e1) ||
(isequal(e1.affect_neg, e2.affect_neg) && isequal(e1.rootfind, e2.rootfind))
is_discrete(e1) === is_discrete(e2) || return false
isequal(e1.conditions, e2.conditions) && isequal(e1.affect, e2.affect) || return false
isequal(e1.initialize, e2.initialize) || return false
isequal(e1.finalize, e2.finalize) || return false
isequal(e1.reinitializealg, e2.reinitializealg) || return false
if !is_discrete(e1)
isequal(e1.affect_neg, e2.affect_neg) || return false
isequal(e1.rootfind, e2.rootfind) || return false
isequal(e1.zero_crossing_id, e2.zero_crossing_id) || return false
end
end

Base.isempty(cb::AbstractCallback) = isempty(cb.conditions)
Expand Down
Loading