
Commit 1f88956

Merge pull request #437 from SciML/CJMselectionrework
Adapt selection
2 parents 25e09f4 + a632989 commit 1f88956

File tree

8 files changed, +173 -101 lines changed

docs/src/solvers/common.md

Lines changed: 12 additions & 0 deletions

````diff
@@ -47,4 +47,16 @@ As we can see above, the use of a [`Basis`](@ref) is optional to invoke the esti
 
 The [`DataDrivenSolution`](@ref) `res` contains a `result` which is the inferred system and a [`Basis`](@ref).
 
+## Model Selection
 
+Most estimation and model inference algorithms require hyperparameters, e.g., the sparsity-controlling penalty or the train-test split. To account for this, the keyword `selector` can be passed to the [`DataDrivenCommonOptions`](@ref). This allows the user to control the selection criterion; the result that **minimizes** the selector is returned.
+
+Common choices for `selector` are `rss`, `bic`, `aic`, `aicc`, and `r2`. Given that each subresult of the algorithm extends the `StatsBase` API, we can also use custom schemes like:
+
+```julia
+options = DataDrivenCommonOptions(
+    selector = (x) -> rss(x) / nobs(x)
+)
+```
+
+which results in the mean squared error of the system.
````
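Since each result type implements the `StatsBase` accessors, the criteria listed above can be passed directly. A minimal sketch, assuming `StatsBase` is loaded alongside `DataDrivenDiffEq` (the formula in the comment is the standard `StatsBase` definition of `bic`):

```julia
using DataDrivenDiffEq, StatsBase

# Select the candidate result minimizing the Bayesian information criterion.
# StatsBase derives it from the overloads added in this PR:
#   bic(x) = dof(x) * log(nobs(x)) - 2 * loglikelihood(x)
options = DataDrivenCommonOptions(selector = bic)
```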

lib/DataDrivenDMD/src/result.jl

Lines changed: 40 additions & 12 deletions

```diff
@@ -1,4 +1,4 @@
-struct KoopmanResult{K, B, C, Q, P, T, TE} <: AbstractDataDrivenResult
+struct KoopmanResult{K, B, C, Q, P, T} <: AbstractDataDrivenResult
     """Matrix representation of the operator / generator"""
     k::K
     """Matrix representation of the inputs mapping"""
@@ -9,26 +9,54 @@ struct KoopmanResult{K, B, C, Q, P, T, TE} <: AbstractDataDrivenResult
     q::Q
     """Internal matrix used for updating"""
     p::P
-    """L2 norm error of the testing dataset"""
-    testerror::T
-    """L2 norm error of the training dataset"""
-    trainerror::TE
+    # StatsBase results
+    """Residual sum of squares"""
+    rss::T
+    """Loglikelihood"""
+    loglikelihood::T
+    """Nullloglikelihood"""
+    nullloglikelihood::T
+    """Degrees of freedom"""
+    dof::Int
+    """Number of observations"""
+    nobs::Int
+
     """Returncode"""
     retcode::DDReturnCode
-end
 
-is_success(k::KoopmanResult) = getfield(k, :retcode) == DDReturnCode(1)
-l2error(k::KoopmanResult) = is_success(k) ? getfield(k, :testerror) : Inf
+    function KoopmanResult(k_::K, b::B, c::C, q::Q, p::P, X::AbstractMatrix{T},
+                           Y::AbstractMatrix{T}, U::AbstractMatrix) where {K, B, C, Q, P, T}
+        k = Matrix(k_)
+        rss = isempty(b) ? sum(abs2, Y .- c * k * X) : sum(abs2, Y .- c * (k * X .+ b * U))
+        dof = sum(!iszero, k)
+        dof += isempty(b) ? 0 : sum(!iszero, b)
+        nobs = prod(size(Y))
+        ll = -nobs / 2 * log(rss / nobs)
+        nll = -nobs / 2 * log(mean(abs2, Y .- vec(mean(Y, dims = 2))))
 
-function l2error(k::KoopmanResult{<:Any, <:Any, <:Any, <:Any, <:Any, Nothing})
-    is_success(k) ? getfield(k, :traineerror) : Inf
+        new{K, B, C, Q, P, T}(k_, b, c, q, p, rss, ll, nll, dof, nobs, DDReturnCode(1))
+    end
 end
 
+is_success(k::KoopmanResult) = getfield(k, :retcode) == DDReturnCode(1)
+
 get_operator(k::KoopmanResult) = getfield(k, :k)
 get_generator(k::KoopmanResult) = getfield(k, :k)
 
 get_inputmap(k::KoopmanResult) = getfield(k, :b)
 get_outputmap(k::KoopmanResult) = getfield(k, :c)
 
-get_trainerror(k::KoopmanResult) = getfield(k, :trainerror)
-get_testerror(k::KoopmanResult) = getfield(k, :testerror)
+# StatsBase Overload
+StatsBase.coef(x::KoopmanResult) = getfield(x, :k)
+
+StatsBase.rss(x::KoopmanResult) = getfield(x, :rss)
+
+StatsBase.dof(x::KoopmanResult) = getfield(x, :dof)
+
+StatsBase.nobs(x::KoopmanResult) = getfield(x, :nobs)
+
+StatsBase.loglikelihood(x::KoopmanResult) = getfield(x, :loglikelihood)
+
+StatsBase.nullloglikelihood(x::KoopmanResult) = getfield(x, :nullloglikelihood)
+
+StatsBase.r2(x::KoopmanResult) = r2(x, :CoxSnell)
```
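For orientation: the constructor stores the concentrated Gaussian log-likelihood `ll = -n/2 * log(rss / n)` (additive constants dropped), from which the information criteria used by a `selector` follow. A standalone sketch with made-up numbers:

```julia
# Hypothetical values standing in for the fields the KoopmanResult
# constructor computes above.
rss, dof, nobs = 0.42, 6, 200

# Concentrated Gaussian log-likelihood, matching the constructor's
# `ll = -nobs / 2 * log(rss / nobs)`.
ll = -nobs / 2 * log(rss / nobs)

# Standard definitions of the criteria built on top of it:
aic = 2dof - 2ll
bic = dof * log(nobs) - 2ll
```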

lib/DataDrivenDMD/src/solve.jl

Lines changed: 5 additions & 20 deletions

```diff
@@ -64,14 +64,14 @@ function CommonSolve.solve!(prob::InternalDataDrivenProblem{A}) where {
                                                                        AbstractKoopmanAlgorithm
                                                                        }
     @unpack alg, basis, testdata, traindata, control_idx, options, problem, kwargs = prob
-
+    @unpack selector = options
     # Check for
     results = alg(prob; kwargs...)
 
-    # Get the best result based on test error, if applicable else use testerror
-    sort!(results, by = l2error)
+    # Get the best result based on selector
+    idx = argmin(map(selector, results))
+    best_res = results[idx]
     # Convert to basis
-    best_res = first(results)
     new_basis = convert_to_basis(best_res, basis, problem, options, control_idx)
     # Build DataDrivenResult
     DataDrivenSolution(new_basis, problem, alg, results, prob, best_res.retcode)
@@ -93,13 +93,6 @@ function convert_to_basis(res::KoopmanResult, basis::Basis, prob, options, contr
     DataDrivenDiffEq.__construct_basis(Θ, basis, prob, options)
 end
 
-function __compute_rss(Z, C, K, B, X, U)
-    begin
-        (isempty(U) || isempty(B)) && return sum(abs2, Z .- C * (K * X))
-        return sum(abs2, Z .- C * (K * X + B * U))
-    end
-end
-
 function (algorithm::AbstractKoopmanAlgorithm)(prob::InternalDataDrivenProblem;
                                                control_input = nothing, kwargs...)
     @unpack traindata, testdata, control_idx, options = prob
@@ -127,14 +120,6 @@ function (algorithm::AbstractKoopmanAlgorithm)(prob::InternalDataDrivenProblem;
         Q = Y_ * X'
         P = X * X'
         C = Z / Y_
-        trainerror = __compute_rss(Z, C, Matrix(K), B, X_, U_)
-        if !isempty(X̃)
-            testerror = __compute_rss(Z̃, C, Matrix(K), B, X̃, Ũ)
-            retcode = testerror <= abstol ? DDReturnCode(1) : DDReturnCode(5)
-        else
-            testerror = nothing
-            retcode = trainerror <= abstol ? DDReturnCode(1) : DDReturnCode(5)
-        end
-        KoopmanResult(K, B, C, Q, P, testerror, trainerror, retcode)
+        KoopmanResult(K, B, C, Q, P, X_, Z, U_)
     end
 end
```
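The replacement for the old sort-by-`l2error` logic is a plain argmin over selector scores; a self-contained sketch of the same pattern with stand-in results:

```julia
# Stand-ins for the algorithm results and the selector from the options.
results = [(rss = 0.9,), (rss = 0.2,), (rss = 0.5,)]
selector = x -> x.rss

idx = argmin(map(selector, results))  # index of the minimal score
best_res = results[idx]               # picks (rss = 0.2,)
```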

lib/DataDrivenDMD/test/linear_autonomous.jl

Lines changed: 3 additions & 6 deletions

```diff
@@ -31,8 +31,7 @@ rng = StableRNG(42)
             @test Matrix(get_operator(operator_res)) ≈ A
             @test isempty(get_inputmap(operator_res))
             @test get_outputmap(operator_res) ≈ I(2)
-            @test get_trainerror(operator_res) <= 1e-10
-            @test isnothing(get_testerror(operator_res))
+            @test rss(operator_res) <= 1e-10
         end
     end
 end
@@ -52,8 +51,7 @@ rng = StableRNG(42)
             @test Matrix(get_operator(operator_res)) ≈ A atol=1e-2
             @test isempty(get_inputmap(operator_res))
             @test get_outputmap(operator_res) ≈ I(2)
-            @test get_trainerror(operator_res) <= 1e-2
-            @test isnothing(get_testerror(operator_res))
+            @test rss(operator_res) <= 1e-2
         end
     end
 end
@@ -80,8 +78,7 @@ end
             @test Matrix(get_operator(operator_res)) ≈ A
             @test isempty(get_inputmap(operator_res))
             @test get_outputmap(operator_res) ≈ I(2)
-            @test get_trainerror(operator_res) <= 1e-10
-            @test isnothing(get_testerror(operator_res))
+            @test rss(operator_res) <= 1e-10
         end
     end
 end
```

lib/DataDrivenSR/src/DataDrivenSR.jl

Lines changed: 107 additions & 58 deletions

```diff
@@ -44,79 +44,107 @@ $(FIELDS)
     eq_options::SymbolicRegression.Options = SymbolicRegression.Options()
 end
 
-struct SRResult{H, P, T, TE} <: AbstractDataDrivenResult
+struct SRResult{H, P, T} <: AbstractDataDrivenResult
+    "The resulting basis"
+    basis::Basis
+    "The Hall of Fame"
     halloffame::H
+    """The Pareto frontier"""
     paretofrontier::P
-    testerror::T
-    trainerror::TE
+    # StatsBase results
+    """Residual sum of squares"""
+    rss::T
+    """Loglikelihood"""
+    loglikelihood::T
+    """Nullloglikelihood"""
+    nullloglikelihood::T
+    """Degrees of freedom"""
+    dof::Int
+    """Number of observations"""
+    nobs::Int
+    """Returncode"""
     retcode::DDReturnCode
 end
 
-is_success(k::SRResult) = getfield(k, :retcode) == DDReturnCode(1)
-l2error(k::SRResult) = is_success(k) ? getfield(k, :testerror) : Inf
-function l2error(k::SRResult{<:Any, <:Any, <:Any, Nothing})
-    is_success(k) ? getfield(k, :traineerror) : Inf
+function SRResult(prob, hof, paretos)
+    @unpack basis, problem = prob
+    bs = convert_to_basis(paretos, prob)
+    ps = get_parameter_values(bs)
+    problem = DataDrivenDiffEq.remake_problem(problem, p = ps)
+    y = DataDrivenDiffEq.get_implicit_data(problem)
+    rss = sum(abs2, y .- bs(problem))
+    dof = length(ps)
+    nobs = prod(size(y))
+    ll = iszero(rss) ? convert(eltype(rss), Inf) : -nobs / 2 * log(rss / nobs)
+    ll0 = -nobs / 2 * log.(sum(abs2, y .- mean(y, dims = 2)[:, 1]) / nobs)
+    return SRResult(bs, hof, paretos,
+                    rss, ll, ll0, dof, nobs,
+                    DDReturnCode(1))
 end
 
-# apply the algorithm on each dataset
-function (x::EQSearch)(ps::InternalDataDrivenProblem{EQSearch}, X, Y)
-    @unpack problem, testdata, options = ps
-    @unpack maxiters, abstol = options
-    @unpack weights, eq_options, numprocs, procs, parallelism, runtests = x
+is_success(k::SRResult) = getfield(k, :retcode) == DDReturnCode(1)
 
-    hofs = SymbolicRegression.EquationSearch(X, Y,
-                                             niterations = maxiters,
-                                             weights = weights,
-                                             options = eq_options,
-                                             numprocs = numprocs,
-                                             procs = procs, parallelism = parallelism,
-                                             runtests = runtests)
+# StatsBase Overload
+StatsBase.coef(x::SRResult) = getfield(x, :k)
 
-    # We always want something which is a vector or tuple
-    hofs = !isa(hofs, AbstractVector) ? [hofs] : hofs
+StatsBase.rss(x::SRResult) = getfield(x, :rss)
 
-    # Evaluate over the full training data
-    paretos = map(enumerate(hofs)) do (i, hof)
-        SymbolicRegression.calculate_pareto_frontier(X, Y[i, :], hof, eq_options)
+StatsBase.dof(x::SRResult) = getfield(x, :dof)
+
+StatsBase.nobs(x::SRResult) = getfield(x, :nobs)
+
+StatsBase.loglikelihood(x::SRResult) = getfield(x, :loglikelihood)
+
+StatsBase.nullloglikelihood(x::SRResult) = getfield(x, :nullloglikelihood)
+
+StatsBase.r2(x::SRResult) = r2(x, :CoxSnell)
+
+function collect_numerical_parameters(eq, options = DataDrivenCommonOptions())
+    ps = Any[]
+    eqs = map(eq) do eqi
+        _collect_numerical_parameters!(ps, eqi, options)
     end
+    return eqs, ps
+end
 
-    # Trainingerror
-    trainerror = mean(x -> x[end].loss, paretos)
-    # Testerror
-    X̃, Ỹ = testdata
-    if !isempty(X̃)
-        testerror = mean(map(enumerate(hofs)) do (i, hof)
-            doms = SymbolicRegression.calculate_pareto_frontier(X̃,
-                                                                Ỹ[i, :],
-                                                                hof,
-                                                                eq_options)
-            doms[end].loss
-        end)
-        retcode = testerror <= abstol ? DDReturnCode(1) : DDReturnCode(5)
+function _collect_numerical_parameters!(ps::AbstractVector, eq, options)
+    if Symbolics.istree(eq)
+        args_ = map(Symbolics.arguments(eq)) do (eqi)
+            _collect_numerical_parameters!(ps, eqi, options)
+        end
+        return Symbolics.operation(eq)(args_...)
+    elseif isa(eq, Number)
+        pval = round(eq, options.roundingmode, digits = options.digits)
+        # We do not collect zeros or ones
+        iszero(pval) && return zero(eltype(pval))
+        (abs(pval) ≈ 1) && return sign(pval) * one(eltype(pval))
+        p_ = Symbolics.variable(:p, length(ps) + 1)
+        p_ = Symbolics.setdefaultval(p_, pval)
+        p_ = ModelingToolkit.toparam(p_)
+        push!(ps, p_)
+        return p_
     else
-        testerror = nothing
-        retcode = trainerror <= abstol ? DDReturnCode(1) : DDReturnCode(5)
+        return eq
     end
-
-    return SRResult(hofs, paretos, testerror, trainerror, retcode)
 end
 
-function convert_to_basis(res::SRResult, prob)
-    @unpack paretofrontier = res
+function convert_to_basis(paretofrontier, prob)
     @unpack alg, basis, problem, options = prob
     @unpack eq_options = alg
     @unpack maxiters, eval_expresssion, generate_symbolic_parameters, digits, roundingmode = options
 
-    eqs_ = Num.(map(paretofrontier) do dom
-        node_to_symbolic(dom[end].tree, eq_options)
-    end)
+    eqs_ = map(paretofrontier) do dom
+        node_to_symbolic(dom[end].tree, eq_options)
+    end
 
     # Substitute with the basis elements
     atoms = map(xi -> xi.rhs, equations(basis))
 
     subs = Dict([SymbolicUtils.Sym{LiteralReal}(Symbol("x$(i)")) => x
                  for (i, x) in enumerate(atoms)]...)
-    eqs = map(Base.Fix2(substitute, subs), eqs_)
+
+    eqs, ps = collect_numerical_parameters(eqs_)
+    eqs = map(Base.Fix2(substitute, subs), eqs)
 
     # Get the lhs
     causality, dt = DataDrivenDiffEq.assert_lhs(problem)
@@ -135,40 +163,61 @@ function convert_to_basis(res::SRResult, prob)
         eqs = [phi[i] ~ eq for (i, eq) in enumerate(eqs)]
     end
 
-    ps = parameters(basis)
+    ps_ = parameters(basis)
     @unpack p = problem
 
     p_new = map(eachindex(p)) do i
-        DataDrivenDiffEq._set_default_val(Num(ps[i]), p[i])
+        DataDrivenDiffEq._set_default_val(Num(ps_[i]), p[i])
     end
 
     Basis(eqs, states(basis),
-          parameters = p_new, iv = get_iv(basis),
+          parameters = [p_new; ps], iv = get_iv(basis),
          controls = controls(basis), observed = observed(basis),
          implicits = implicit_variables(basis),
         name = gensym(:Basis),
         eval_expression = eval_expresssion)
 end
 
+# apply the algorithm on each dataset
+function (x::EQSearch)(ps::InternalDataDrivenProblem{EQSearch}, X, Y)
+    @unpack problem, testdata, options = ps
+    @unpack maxiters, abstol = options
+    @unpack weights, eq_options, numprocs, procs, parallelism, runtests = x
+
+    hofs = SymbolicRegression.EquationSearch(X, Y,
+                                             niterations = maxiters,
+                                             weights = weights,
+                                             options = eq_options,
+                                             numprocs = numprocs,
+                                             procs = procs, parallelism = parallelism,
+                                             runtests = runtests)
+
+    # We always want something which is a vector or tuple
+    hofs = !isa(hofs, AbstractVector) ? [hofs] : hofs
+
+    # Evaluate over the full training data
+    paretos = map(enumerate(hofs)) do (i, hof)
+        SymbolicRegression.calculate_pareto_frontier(X, Y[i, :], hof, eq_options)
+    end
+
+    return SRResult(ps, hofs, paretos)
+end
+
 function CommonSolve.solve!(ps::InternalDataDrivenProblem{EQSearch})
     @unpack alg, basis, testdata, traindata, kwargs = ps
     @unpack weights, numprocs, procs, addprocs_function, parallelism, runtests, eq_options = alg
     @unpack traindata, testdata, basis, options = ps
-    @unpack maxiters, eval_expresssion, generate_symbolic_parameters, digits, roundingmode = options
+    @unpack maxiters, eval_expresssion, generate_symbolic_parameters, digits, roundingmode, selector = options
     @unpack problem = ps
 
     results = map(traindata) do (X, Y)
        alg(ps, X, Y)
     end
 
-    # Get the best result based on test error, if applicable else use testerror
-    sort!(results, by = l2error)
-    # Convert to basis
-    best_res = first(results)
-
-    new_basis = convert_to_basis(best_res, ps)
+    idx = argmin(map(selector, results))
+    best_res = results[idx]
 
-    DataDrivenSolution(new_basis, problem, alg, results, ps, best_res.retcode)
+    DataDrivenSolution(best_res.basis, problem, alg, results, ps, best_res.retcode)
 end
 
 export EQSearch
```
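To illustrate the new parameter collection: `_collect_numerical_parameters!` walks the expression tree, rounds each numeric leaf, drops exact zeros and (signed) ones, and lifts the remaining constants into parameters with the rounded value as default. A hedged sketch of the intended behavior; the input expression is made up, `digits` is assumed to be a keyword of `DataDrivenCommonOptions` since the diff unpacks it from the options, and the PR feeds this function unwrapped `SymbolicUtils` expressions from `node_to_symbolic`, so `Num` inputs are unwrapped first here:

```julia
using Symbolics, ModelingToolkit
using DataDrivenDiffEq

@variables x
# A raw expression as it might come back from symbolic regression,
# unwrapped to match what node_to_symbolic produces in the PR:
raw_eqs = Symbolics.unwrap.([2.4983 * x + 1.0001 * x^2])

eqs, ps = collect_numerical_parameters(raw_eqs, DataDrivenCommonOptions(digits = 2))
# Roughly: 2.4983 rounds to 2.5 and is lifted into parameter p₁ (default 2.5),
# while 1.0001 rounds to 1.0 and is dropped, leaving eqs ≈ [p₁ * x + x^2].
```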
