Skip to content

Commit 09c084b

Browse files
feat: subset variables appropriately in clock inference
1 parent 08633e9 commit 09c084b

File tree

2 files changed

+23
-11
lines changed

2 files changed

+23
-11
lines changed

src/systems/clock_inference.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,8 @@ function split_system(ci::ClockInference{S}) where {S}
199199

200200
# breaks the system up into a continous and 0 or more discrete systems
201201
tss = similar(cid_to_eq, S)
202-
for (id, ieqs) in enumerate(cid_to_eq)
203-
ts_i = system_subset(ts, ieqs)
202+
for (id, (ieqs, ivars)) in enumerate(zip(cid_to_eq, cid_to_var))
203+
ts_i = system_subset(ts, ieqs, ivars)
204204
if id != continuous_id
205205
ts_i = shift_discrete_system(ts_i)
206206
@set! ts_i.structure.only_discrete = true

src/systems/systemstructure.jl

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -213,11 +213,11 @@ mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
213213
end
214214

215215
TransformationState(sys::AbstractSystem) = TearingState(sys)
216-
function system_subset(ts::TearingState, ieqs::Vector{Int})
216+
function system_subset(ts::TearingState, ieqs::Vector{Int}, ivars::Vector{Int})
217217
eqs = equations(ts)
218218
@set! ts.original_eqs = ts.original_eqs[ieqs]
219219
@set! ts.sys.eqs = eqs[ieqs]
220-
@set! ts.structure = system_subset(ts.structure, ieqs)
220+
@set! ts.structure = system_subset(ts.structure, ieqs, ivars)
221221
if all(eq -> eq.rhs isa StateMachineOperator, get_eqs(ts.sys))
222222
names = Symbol[]
223223
for eq in get_eqs(ts.sys)
@@ -234,22 +234,33 @@ function system_subset(ts::TearingState, ieqs::Vector{Int})
234234
else
235235
@set! ts.statemachines = eltype(ts.statemachines)[]
236236
end
237+
@set! ts.fullvars = ts.fullvars[ivars]
237238
ts
238239
end
239240

240-
function system_subset(structure::SystemStructure, ieqs::Vector{Int})
241-
@unpack graph, eq_to_diff = structure
241+
function system_subset(structure::SystemStructure, ieqs::Vector{Int}, ivars::Vector{Int})
242+
@unpack graph = structure
242243
fadj = Vector{Int}[]
243244
eq_to_diff = DiffGraph(length(ieqs))
245+
var_to_diff = DiffGraph(length(ivars))
246+
244247
ne = 0
248+
old_to_new_var = zeros(Int, ndsts(graph))
249+
for (i, iv) in enumerate(ivars)
250+
old_to_new_var[iv] = i
251+
structure.var_to_diff[iv] === nothing && continue
252+
var_to_diff[i] = old_to_new_var[structure.var_to_diff[iv]]
253+
end
245254
for (j, eq_i) in enumerate(ieqs)
246-
ivars = copy(graph.fadjlist[eq_i])
247-
ne += length(ivars)
248-
push!(fadj, ivars)
255+
var_adj = [old_to_new_var[i] for i in graph.fadjlist[eq_i]]
256+
@assert all(!iszero, var_adj)
257+
ne += length(var_adj)
258+
push!(fadj, var_adj)
249259
eq_to_diff[j] = structure.eq_to_diff[eq_i]
250260
end
251-
@set! structure.graph = complete(BipartiteGraph(ne, fadj, ndsts(graph)))
261+
@set! structure.graph = complete(BipartiteGraph(ne, fadj, length(ivars)))
252262
@set! structure.eq_to_diff = eq_to_diff
263+
@set! structure.var_to_diff = complete(var_to_diff)
253264
structure
254265
end
255266

@@ -433,7 +444,8 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
433444
isdelay(v, iv) && continue
434445

435446
if !symbolic_contains(v, dvs)
436-
isvalid = iscall(v) && (operation(v) isa Shift || is_transparent_operator(operation(v)))
447+
isvalid = iscall(v) &&
448+
(operation(v) isa Shift || is_transparent_operator(operation(v)))
437449
v′ = v
438450
while !isvalid && iscall(v′) && operation(v′) isa Union{Differential, Shift}
439451
v′ = arguments(v′)[1]

0 commit comments

Comments
 (0)