Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/systems/callbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,14 @@ end
end)
end

@generated function _generated_readback(integ, getters::NamedTuple{NS1, <:Tuple}) where {NS1}
getter_exprs = []
for name in NS1
push!(getter_exprs, :($name = getters.$name(integ)))
end
return :((; $(getter_exprs...)))
end

function check_assignable(sys, sym)
if symbolic_type(sym) == ScalarSymbolic()
is_variable(sys, sym) || is_parameter(sys, sym)
Expand Down
17 changes: 5 additions & 12 deletions src/systems/imperative_affect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,18 +189,13 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs.
else
zeros(sz)
end
obs_fun = build_explicit_observed_function(
sys, Symbolics.scalarize.(obs_exprs);
mkarray = (es, _) -> MakeTuple(es))
obs_sym_tuple = (obs_syms...,)
geto_funs = NamedTuple{(obs_syms...,)}((getsym.((sys,), obs_exprs)...,))

# okay so now to generate the stuff to assign it back into the system
getm_funs = NamedTuple{(mod_syms...,)}((getsym.((sys,), mod_exprs)...,))

mod_pairs = mod_exprs .=> mod_syms
mod_names = (mod_syms...,)
mod_og_val_fun = build_explicit_observed_function(
sys, Symbolics.scalarize.(first.(mod_pairs));
mkarray = (es, _) -> MakeTuple(es))

upd_funs = NamedTuple{mod_names}((setu.((sys,), first.(mod_pairs))...,))

if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
Expand All @@ -212,12 +207,10 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs.
let user_affect = func(affect), ctx = context(affect)
function (integ)
# update the to-be-mutated values; this ensures that if you do a no-op then nothing happens
modvals = mod_og_val_fun(integ.u, integ.p, integ.t)
upd_component_array = NamedTuple{mod_names}(modvals)
upd_component_array = _generated_readback(integ, getm_funs)

# update the observed values
obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun(
integ.u, integ.p, integ.t))
obs_component_array = _generated_readback(integ, geto_funs)

# let the user do their thing
upd_vals = user_affect(upd_component_array, obs_component_array, ctx, integ)
Expand Down
23 changes: 23 additions & 0 deletions test/symbolic_events.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1461,3 +1461,26 @@ end
sys = structural_simplify(sys)
sol = solve(ODEProblem(sys, [], (0.0, 1.0)), Tsit5())
end

@testset "Tuples in ImperativeAffect arguments" begin
@mtkmodel ImperativeAffectTupleMWE begin
@parameters begin
y(t) = 1.0
end
@variables begin
x(t) = 0.0
end
@equations begin
D(x) ~ y
end
@continuous_events begin
(x ~ 0.5) => ModelingToolkit.ImperativeAffect(
observed = (; mypars = (x, 2 * x)), modified = (; y)) do m, o, c, i
return (; y = 2 * o.mypars[1] + o.mypars[2])
end
end
end
@mtkbuild sys = ImperativeAffectTupleMWE()
prob = ODEProblem(sys, [], (0.0, 1.0))
sol = solve(prob, Tsit5())
end
Loading