Skip to content

Commit c35069d

Browse files
Merge pull request #78 from SciML/revert-77-revert-66-sy/interp_elim_var_access
Revert "Revert "Add symbol based indexing for interpolated solutions""
2 parents 9fb2ba9 + 0b725ac commit c35069d

File tree

6 files changed

+184
-63
lines changed

6 files changed

+184
-63
lines changed

src/interpolation.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,9 @@ times t (sorted), with values u and derivatives ks
248248
end
249249
end
250250

251-
251+
@inline function interpolant(Θ,id::AbstractDiffEqInterpolation,dt,y₀,y₁,dy₀,dy₁,idxs,::Type{Val{D}}) where D
252+
error("$(string(typeof(id))) for $(D)th order not implemented")
253+
end
252254
##################### Hermite Interpolants
253255

254256
"""

src/solutions/ode_solutions.jl

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,53 @@ struct ODESolution{T,N,uType,uType2,DType,tType,rateType,P,A,IType,DE} <: Abstra
1515
destats::DE
1616
retcode::Symbol
1717
end
18-
(sol::ODESolution)(t,deriv::Type=Val{0};idxs=nothing,continuity=:left) = sol.interp(t,idxs,deriv,sol.prob.p,continuity)
18+
(sol::ODESolution)(t,deriv::Type=Val{0};idxs=nothing,continuity=:left) = sol(t,deriv,idxs,continuity)
1919
(sol::ODESolution)(v,t,deriv::Type=Val{0};idxs=nothing,continuity=:left) = sol.interp(v,t,idxs,deriv,sol.prob.p,continuity)
2020

21+
function (sol::ODESolution)(t::Real,deriv,idxs::Nothing,continuity)
22+
sol.interp(t,idxs,deriv,sol.prob.p,continuity)
23+
end
24+
25+
function (sol::ODESolution)(t::AbstractVector{<:Real},deriv,idxs::Nothing,continuity)
26+
augment(sol.interp(t,idxs,deriv,sol.prob.p,continuity), sol)
27+
end
28+
29+
function (sol::ODESolution)(t::Real,deriv,idxs::Integer,continuity)
30+
sol.interp(t,idxs,deriv,sol.prob.p,continuity)
31+
end
32+
function (sol::ODESolution)(t::Real,deriv,idxs::AbstractVector{<:Integer},continuity)
33+
sol.interp(t,idxs,deriv,sol.prob.p,continuity)
34+
end
35+
function (sol::ODESolution)(t::AbstractVector{<:Real},deriv,idxs::Integer,continuity)
36+
sol.interp(t,idxs,deriv,sol.prob.p,continuity)
37+
end
38+
function (sol::ODESolution)(t::AbstractVector{<:Real},deriv,idxs::AbstractVector{<:Integer},continuity)
39+
sol.interp(t,idxs,deriv,sol.prob.p,continuity)
40+
end
41+
42+
function (sol::ODESolution)(t::Real,deriv,idxs,continuity)
43+
issymbollike(idxs) || error("Incorrect specification of `idxs`")
44+
augment(sol.interp([t],nothing,deriv,sol.prob.p,continuity), sol)[idxs][1]
45+
end
46+
47+
function (sol::ODESolution)(t::Real,deriv,idxs::AbstractVector,continuity)
48+
all(issymbollike.(idxs)) || error("Incorrect specification of `idxs`")
49+
interp_sol = augment(sol.interp([t],nothing,deriv,sol.prob.p,continuity), sol)
50+
[first(interp_sol[idx]) for idx in idxs]
51+
end
52+
53+
function (sol::ODESolution)(t::AbstractVector{<:Real},deriv,idxs,continuity)
54+
issymbollike(idxs) || error("Incorrect specification of `idxs`")
55+
interp_sol = augment(sol.interp(t,nothing,deriv,sol.prob.p,continuity), sol)
56+
DiffEqArray(interp_sol[idxs], t)
57+
end
58+
59+
function (sol::ODESolution)(t::AbstractVector{<:Real},deriv,idxs::AbstractVector,continuity)
60+
all(issymbollike.(idxs)) || error("Incorrect specification of `idxs`")
61+
interp_sol = augment(sol.interp(t,nothing,deriv,sol.prob.p,continuity), sol)
62+
DiffEqArray([[interp_sol[idx][i] for idx in idxs] for i in 1:length(t)], t)
63+
end
64+
2165
function build_solution(
2266
prob::Union{AbstractODEProblem,AbstractDDEProblem},
2367
alg,t,u;timeseries_errors=length(u)>2,

src/solutions/solution_interface.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ Base.size(A::AbstractNoTimeSolution) = size(A.u)
1414

1515
Base.show(io::IO, m::MIME"text/plain", A::AbstractNoTimeSolution) = (print(io,"u: ");show(io,m,A.u))
1616

17+
# For augmenting system information to enable symbol based indexing of interpolated solutions
18+
function augment(A::DiffEqArray, sol::AbstractODESolution)
19+
observed = has_observed(sol.prob.f) ? sol.prob.f.observed : DEFAULT_OBSERVED
20+
DiffEqArray(A.u, A.t, sol.prob.f.syms,getindepsym(sol),observed,sol.prob.p)
21+
end
22+
1723
# Symbol Handling
1824

1925
# For handling ambiguities
@@ -294,13 +300,15 @@ end
294300

295301
sym_to_index(sym,sol::SciMLSolution) = sym_to_index(sym,getsyms(sol))
296302
sym_to_index(sym,syms) = findfirst(isequal(Symbol(sym)),syms)
297-
issymbollike(x) = x isa Symbol ||
303+
function issymbollike(x)
304+
x isa Symbol ||
298305
x isa AllObserved ||
299-
Symbol(parameterless_type(typeof(x))) == :Operation ||
300-
Symbol(parameterless_type(typeof(x))) == :Variable ||
301-
Symbol(parameterless_type(typeof(x))) == :Sym ||
302-
Symbol(parameterless_type(typeof(x))) == :Num ||
303-
Symbol(parameterless_type(typeof(x))) == :Term
306+
Symbol(parameterless_type(typeof(x))) == :Operation || Symbol(parameterless_type(typeof(x))) == Symbol("Symbolics.Operation") ||
307+
Symbol(parameterless_type(typeof(x))) == :Variable || Symbol(parameterless_type(typeof(x))) == Symbol("Symbolics.Variable") ||
308+
Symbol(parameterless_type(typeof(x))) == :Sym || Symbol(parameterless_type(typeof(x))) == Symbol("Symbolics.Sym") ||
309+
Symbol(parameterless_type(typeof(x))) == :Num || Symbol(parameterless_type(typeof(x))) == Symbol("Symbolics.Num") ||
310+
Symbol(parameterless_type(typeof(x))) == :Term || Symbol(parameterless_type(typeof(x))) == Symbol("Symbolics.Term")
311+
end
304312

305313
function diffeq_to_arrays(sol,plot_analytic,denseplot,plotdensity,tspan,axis_safety,vars,int_vars,tscale,strs)
306314
if tspan === nothing

test/downstream/observables.jl

Lines changed: 0 additions & 55 deletions
This file was deleted.

test/downstream/symbol_indexing.jl

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
using ModelingToolkit, OrdinaryDiffEq, Test
2+
3+
@parameters t σ ρ β
4+
@variables x(t) y(t) z(t)
5+
D = Differential(t)
6+
7+
eqs = [D(x) ~ σ*(y-x),
8+
D(y) ~ x*-z)-y,
9+
D(z) ~ x*y - β*z]
10+
11+
lorenz1 = ODESystem(eqs,name=:lorenz1)
12+
lorenz2 = ODESystem(eqs,name=:lorenz2)
13+
14+
@parameters γ
15+
@variables a(t), α(t)
16+
connections = [0 ~ lorenz1.x + lorenz2.y + a*γ,
17+
α ~ 2lorenz1.x + a*γ]
18+
sys = ODESystem(connections,t,[a,α],[γ],systems=[lorenz1,lorenz2])
19+
sys_simplified = structural_simplify(sys)
20+
21+
u0 = [lorenz1.x => 1.0,
22+
lorenz1.y => 0.0,
23+
lorenz1.z => 0.0,
24+
lorenz2.x => 0.0,
25+
lorenz2.y => 1.0,
26+
lorenz2.z => 0.0,
27+
a => 2.0]
28+
29+
p = [lorenz1.σ => 10.0,
30+
lorenz1.ρ => 28.0,
31+
lorenz1.β => 8/3,
32+
lorenz2.σ => 10.0,
33+
lorenz2.ρ => 28.0,
34+
lorenz2.β => 8/3,
35+
γ => 2.0]
36+
37+
tspan = (0.0,100.0)
38+
prob = ODEProblem(sys_simplified,u0,tspan,p)
39+
sol = solve(prob,Rodas4())
40+
41+
@test sol[lorenz1.x] isa Vector
42+
@test sol[lorenz1.x,2] isa Float64
43+
@test sol[lorenz1.x,:] isa Vector
44+
@test length(sol[lorenz1.x,1:5]) == 5
45+
@test sol[α] isa Vector
46+
@test sol[α,3] isa Float64
47+
@test length(sol[α,5:10]) == 6
48+
49+
# Check if indexing using variable names from interpolated solution works
50+
interpolated_sol = sol(0.0:1.0:10.0)
51+
@test interpolated_sol[α] isa Vector
52+
@test interpolated_sol[α,:] isa Vector
53+
@test interpolated_sol[α,2] isa Float64
54+
@test length(interpolated_sol[α,1:5]) == 5
55+
@test interpolated_sol[α] 2interpolated_sol[lorenz1.x] .+ interpolated_sol[a].*2.0
56+
57+
58+
sol1 = sol(0.0:1.0:10.0)
59+
@test sol1.u isa Vector
60+
@test first(sol1.u) isa Vector
61+
@test length(sol1.u) == 11
62+
@test length(sol1.t) == 11
63+
64+
sol2 = sol(0.1)
65+
@test sol2 isa Vector
66+
@test length(sol2) == length(states(sys_simplified))
67+
@test first(sol2) isa Real
68+
69+
sol3 = sol(0.0:1.0:10.0, idxs=[lorenz1.x, lorenz2.x])
70+
@test sol3.u isa Vector
71+
@test first(sol3.u) isa Vector
72+
@test length(sol3.u) == 11
73+
@test length(sol3.t) == 11
74+
@test_throws ErrorException sol(0.0:1.0:10.0, idxs=[lorenz1.x, 1])
75+
76+
sol4 = sol(0.1, idxs=[lorenz1.x, lorenz2.x])
77+
@test sol4 isa Vector
78+
@test length(sol4) == 2
79+
@test first(sol4) isa Real
80+
@test_throws ErrorException sol(0.1, idxs=[lorenz1.x, 1])
81+
82+
sol5 = sol(0.0:1.0:10.0, idxs=lorenz1.x)
83+
@test sol5.u isa Vector
84+
@test first(sol5.u) isa Real
85+
@test length(sol5.u) == 11
86+
@test length(sol5.t) == 11
87+
@test_throws ErrorException sol(0.0:1.0:10.0, idxs=1.2)
88+
89+
sol6 = sol(0.1, idxs=lorenz1.x)
90+
@test sol6 isa Real
91+
@test_throws ErrorException sol(0.1, idxs=1.2)
92+
93+
sol7 = sol(0.0:1.0:10.0, idxs=[2,1])
94+
@test sol7.u isa Vector
95+
@test first(sol7.u) isa Vector
96+
@test length(sol7.u) == 11
97+
@test length(sol7.t) == 11
98+
99+
sol8 = sol(0.1, idxs=[2,1])
100+
@test sol8 isa Vector
101+
@test length(sol8) == 2
102+
@test first(sol8) isa Real
103+
104+
sol9 = sol(0.0:1.0:10.0, idxs=2)
105+
@test sol9.u isa Vector
106+
@test first(sol9.u) isa Real
107+
@test length(sol9.u) == 11
108+
@test length(sol9.t) == 11
109+
110+
sol10 = sol(0.1, idxs=2)
111+
@test sol10 isa Real
112+
113+
114+
#=
115+
using Plots
116+
plot(sol,vars=(lorenz2.x,lorenz2.z))
117+
plot(sol,vars=(α,lorenz2.z))
118+
plot(sol,vars=(lorenz2.x,α))
119+
plot(sol,vars=α)
120+
plot(sol,vars=(t,α))
121+
=#

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,6 @@ end
2727
if !is_APPVEYOR && GROUP == "Downstream"
2828
activate_downstream_env()
2929
@time @safetestset "Ensembles of Zero Length Solutions" begin include("downstream/ensemble_zero_length.jl") end
30+
@time @safetestset "Symbol and integer based indexing of interpolated solutions" begin include("downstream/symbol_indexing.jl") end
3031
end
3132
end

0 commit comments

Comments
 (0)