@@ -44,79 +44,107 @@ $(FIELDS)
44
44
eq_options:: SymbolicRegression.Options = SymbolicRegression. Options ()
45
45
end
46
46
47
- struct SRResult{H, P, T, TE} <: AbstractDataDrivenResult
47
"""
Result of a single symbolic regression run.

Stores the recovered basis, the raw search output (hall of fame and Pareto
frontier) and the summary statistics that the `StatsBase` accessors defined
for this type read back.
"""
struct SRResult{H, P, T} <: AbstractDataDrivenResult
    "The resulting basis"
    basis::Basis
    "The hall of fame"
    halloffame::H
    """The Pareto frontier"""
    paretofrontier::P
    # StatsBase results
    """Residual sum of squares"""
    rss::T
    """Log-likelihood"""
    loglikelihood::T
    """Null log-likelihood"""
    nullloglikelihood::T
    """Degrees of freedom"""
    dof::Int
    """Number of observations"""
    nobs::Int
    """Return code"""
    retcode::DDReturnCode
end
54
68
55
- is_success (k:: SRResult ) = getfield (k, :retcode ) == DDReturnCode (1 )
56
- l2error (k:: SRResult ) = is_success (k) ? getfield (k, :testerror ) : Inf
57
- function l2error (k:: SRResult{<:Any, <:Any, <:Any, Nothing} )
58
- is_success (k) ? getfield (k, :traineerror ) : Inf
69
"""
    SRResult(prob, hof, paretos)

Build an `SRResult` from the internal problem `prob`, the hall of fame `hof`
and the Pareto frontiers `paretos`.

The Pareto frontiers are converted into a basis, the wrapped problem is remade
with the parameter values of that basis, and the summary statistics are computed
on the implicit data of the remade problem:

- `rss`: sum of squared differences between the data and the basis prediction,
- `dof`: number of parameters of the basis,
- `nobs`: total number of entries in the data,
- log-likelihood and null log-likelihood (the latter uses the row-wise mean of
  the data as the prediction).

The return code is always `DDReturnCode(1)`.
"""
function SRResult(prob, hof, paretos)
    @unpack problem = prob
    bs = convert_to_basis(paretos, prob)
    ps = get_parameter_values(bs)
    problem = DataDrivenDiffEq.remake_problem(problem, p = ps)
    y = DataDrivenDiffEq.get_implicit_data(problem)
    rss = sum(abs2, y .- bs(problem))
    dof = length(ps)
    nobs = prod(size(y))
    # A perfect fit has an infinite log-likelihood; skip the log of zero.
    ll = iszero(rss) ? convert(eltype(rss), Inf) : -nobs / 2 * log(rss / nobs)
    ll0 = -nobs / 2 * log(sum(abs2, y .- mean(y, dims = 2)[:, 1]) / nobs)
    # The struct stores all three statistics under a single type parameter, but
    # the `-nobs / 2` factor is a Float64 and can widen `ll`/`ll0` relative to
    # `rss`. Promote them to a common type so the field constructor matches.
    rss, ll, ll0 = promote(rss, ll, ll0)
    return SRResult(bs, hof, paretos,
                    rss, ll, ll0, dof, nobs,
                    DDReturnCode(1))
end
60
84
61
- # apply the algorithm on each dataset
62
- function (x:: EQSearch )(ps:: InternalDataDrivenProblem{EQSearch} , X, Y)
63
- @unpack problem, testdata, options = ps
64
- @unpack maxiters, abstol = options
65
- @unpack weights, eq_options, numprocs, procs, parallelism, runtests = x
85
# A result counts as successful when its stored return code equals `DDReturnCode(1)`.
is_success(k::SRResult) = getfield(k, :retcode) == DDReturnCode(1)
66
86
67
- hofs = SymbolicRegression. EquationSearch (X, Y,
68
- niterations = maxiters,
69
- weights = weights,
70
- options = eq_options,
71
- numprocs = numprocs,
72
- procs = procs, parallelism = parallelism,
73
- runtests = runtests)
87
# StatsBase Overload

# `SRResult` has no `:k` field, so the previous `getfield(x, :k)` always threw.
# The coefficients of a symbolic regression result are taken to be the parameter
# values of the recovered basis (the same quantity whose length defines `dof`).
# NOTE(review): confirm this is the intended meaning of `coef` for this result type.
StatsBase.coef(x::SRResult) = get_parameter_values(getfield(x, :basis))
74
89
75
- # We always want something which is a vector or tuple
76
- hofs = ! isa (hofs, AbstractVector) ? [hofs] : hofs
90
# Return the stored residual sum of squares.
StatsBase.rss(x::SRResult) = getfield(x, :rss)
77
91
78
- # Evaluate over the full training data
79
- paretos = map (enumerate (hofs)) do (i, hof)
80
- SymbolicRegression. calculate_pareto_frontier (X, Y[i, :], hof, eq_options)
92
# Return the stored degrees of freedom.
StatsBase.dof(x::SRResult) = getfield(x, :dof)

# Return the stored number of observations.
StatsBase.nobs(x::SRResult) = getfield(x, :nobs)

# Return the stored log-likelihood.
StatsBase.loglikelihood(x::SRResult) = getfield(x, :loglikelihood)

# Return the stored null log-likelihood.
StatsBase.nullloglikelihood(x::SRResult) = getfield(x, :nullloglikelihood)

# Single-argument `r2` forwards to the two-argument method with the `:CoxSnell` variant.
StatsBase.r2(x::SRResult) = r2(x, :CoxSnell)
101
+
102
"""
    collect_numerical_parameters(eq, options = DataDrivenCommonOptions())

Rewrite every expression in `eq` with `_collect_numerical_parameters!` and
return a tuple `(rewritten, parameters)`, where `parameters` is the list of
symbolic parameters gathered while rewriting. `options` is passed through
unchanged to the per-expression helper.
"""
function collect_numerical_parameters(eq, options = DataDrivenCommonOptions())
    # Shared accumulator: the helper pushes each newly created parameter here.
    collected = Any[]
    rewritten = map(term -> _collect_numerical_parameters!(collected, term, options), eq)
    return rewritten, collected
end
82
109
83
- # Trainingerror
84
- trainerror = mean (x -> x[end ]. loss, paretos)
85
- # Testerror
86
- X̃, Ỹ = testdata
87
- if ! isempty (X̃)
88
- testerror = mean (map (enumerate (hofs)) do (i, hof)
89
- doms = SymbolicRegression. calculate_pareto_frontier (X̃,
90
- Ỹ[i, :],
91
- hof,
92
- eq_options)
93
- doms[end ]. loss
94
- end )
95
- retcode = testerror <= abstol ? DDReturnCode (1 ) : DDReturnCode (5 )
110
"""
    _collect_numerical_parameters!(ps, eq, options)

Recursively walk the expression `eq` and replace numeric literals by symbolic
parameters, appending every created parameter to `ps`.

- Expression trees are rebuilt by applying their operation to the rewritten
  arguments.
- Numbers are rounded using `options.roundingmode` and `options.digits`. A value
  that rounds to zero is returned as zero, a value with magnitude approximately
  one is returned as a signed one, and neither is added to `ps`. Any other value
  becomes a new parameter `p` indexed by `length(ps) + 1` whose default is the
  rounded value.
- Everything else is returned unchanged.
"""
function _collect_numerical_parameters!(ps::AbstractVector, eq, options)
    if Symbolics.istree(eq)
        args_ = map(Symbolics.arguments(eq)) do eqi
            _collect_numerical_parameters!(ps, eqi, options)
        end
        return Symbolics.operation(eq)(args_...)
    elseif isa(eq, Number)
        pval = round(eq, options.roundingmode, digits = options.digits)
        # We do not collect zeros or ones
        iszero(pval) && return zero(eltype(pval))
        # Must short-circuit: the bitwise `&` evaluates its right operand
        # unconditionally, which made every non-zero number return here.
        (abs(pval) ≈ 1) && return sign(pval) * one(eltype(pval))
        p_ = Symbolics.variable(:p, length(ps) + 1)
        p_ = Symbolics.setdefaultval(p_, pval)
        p_ = ModelingToolkit.toparam(p_)
        push!(ps, p_)
        return p_
    else
        return eq
    end
end
103
130
104
- function convert_to_basis (res:: SRResult , prob)
105
- @unpack paretofrontier = res
131
+ function convert_to_basis (paretofrontier, prob)
106
132
@unpack alg, basis, problem, options = prob
107
133
@unpack eq_options = alg
108
134
@unpack maxiters, eval_expresssion, generate_symbolic_parameters, digits, roundingmode = options
109
135
110
- eqs_ = Num .( map (paretofrontier) do dom
111
- node_to_symbolic (dom[end ]. tree, eq_options)
112
- end )
136
+ eqs_ = map (paretofrontier) do dom
137
+ node_to_symbolic (dom[end ]. tree, eq_options)
138
+ end
113
139
114
140
# Substitute with the basis elements
115
141
atoms = map (xi -> xi. rhs, equations (basis))
116
142
117
143
subs = Dict ([SymbolicUtils. Sym {LiteralReal} (Symbol (" x$(i) " )) => x
118
144
for (i, x) in enumerate (atoms)]. .. )
119
- eqs = map (Base. Fix2 (substitute, subs), eqs_)
145
+
146
+ eqs, ps = collect_numerical_parameters (eqs_)
147
+ eqs = map (Base. Fix2 (substitute, subs), eqs)
120
148
121
149
# Get the lhs
122
150
causality, dt = DataDrivenDiffEq. assert_lhs (problem)
@@ -135,40 +163,61 @@ function convert_to_basis(res::SRResult, prob)
135
163
eqs = [phi[i] ~ eq for (i, eq) in enumerate (eqs)]
136
164
end
137
165
138
- ps = parameters (basis)
166
+ ps_ = parameters (basis)
139
167
@unpack p = problem
140
168
141
169
p_new = map (eachindex (p)) do i
142
- DataDrivenDiffEq. _set_default_val (Num (ps [i]), p[i])
170
+ DataDrivenDiffEq. _set_default_val (Num (ps_ [i]), p[i])
143
171
end
144
172
145
173
Basis (eqs, states (basis),
146
- parameters = p_new, iv = get_iv (basis),
174
+ parameters = [ p_new; ps] , iv = get_iv (basis),
147
175
controls = controls (basis), observed = observed (basis),
148
176
implicits = implicit_variables (basis),
149
177
name = gensym (:Basis ),
150
178
eval_expression = eval_expresssion)
151
179
end
152
180
181
+ # apply the algorithm on each dataset
182
# Run the equation search for one dataset `(X, Y)` and wrap the outcome in an
# `SRResult`. Search settings come from the algorithm `x`; the iteration budget
# comes from the options stored in `ps`.
function (x::EQSearch)(ps::InternalDataDrivenProblem{EQSearch}, X, Y)
    @unpack problem, testdata, options = ps
    @unpack maxiters, abstol = options
    @unpack weights, eq_options, numprocs, procs, parallelism, runtests = x

    halls = SymbolicRegression.EquationSearch(X, Y,
                                              niterations = maxiters,
                                              weights = weights,
                                              options = eq_options,
                                              numprocs = numprocs,
                                              procs = procs,
                                              parallelism = parallelism,
                                              runtests = runtests)

    # Downstream code iterates over the result, so wrap a lone value in a vector.
    if !(halls isa AbstractVector)
        halls = [halls]
    end

    # One Pareto frontier per entry, evaluated on the full training data with
    # the matching row of `Y`.
    frontiers = [SymbolicRegression.calculate_pareto_frontier(X, Y[row, :], hall,
                                                              eq_options)
                 for (row, hall) in enumerate(halls)]

    return SRResult(ps, halls, frontiers)
end
205
+
153
206
# Solve the internal problem: run the algorithm on every training dataset, score
# each result with the `selector` from the options, and build the solution from
# the result with the smallest score.
function CommonSolve.solve!(ps::InternalDataDrivenProblem{EQSearch})
    @unpack alg, basis, testdata, traindata, kwargs, options, problem = ps
    @unpack weights, numprocs, procs, addprocs_function, parallelism, runtests, eq_options = alg
    @unpack maxiters, eval_expresssion, generate_symbolic_parameters, digits, roundingmode, selector = options

    # One result per (X, Y) pair of training data.
    results = map(((X, Y),) -> alg(ps, X, Y), traindata)

    # Pick the result that minimises the selector.
    scores = map(selector, results)
    best_res = results[argmin(scores)]

    DataDrivenSolution(best_res.basis, problem, alg, results, ps, best_res.retcode)
end
173
222
174
223
export EQSearch
0 commit comments