Skip to content

Commit 2fbd961

Browse files
authored
Merge pull request #120 from SciML/gg
Allow for a GenearlizedGenerated-free basis generation setup
2 parents f2516a7 + 92d300f commit 2fbd961

File tree

1 file changed

+31
-7
lines changed

1 file changed

+31
-7
lines changed

src/basis.jl

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ is_independent(o::Operation) = isempty(o.args)
3434

3535

3636
"""
37-
Basis(f, u; p, iv)
37+
Basis(f, u; p, iv, eval_expression)
3838
3939
A basis over the variables `u` with parameters `p` and independent variable `iv`.
4040
`f` can either be a Julia function which is able to use ModelingToolkit variables or
@@ -53,8 +53,20 @@ using DataDrivenDiffEq
5353
5454
Ψ = Basis([u; sin.(w.*u)], u, parameters = p, iv = t)
5555
```
56+
57+
## Note
58+
59+
The keyword argument `eval_expression` controls the function creation
60+
behavior. `eval_expression=true` means that `eval` is used, so normal
61+
world-age behavior applies (i.e. the functions cannot be called from
62+
the function that generates them). If `eval_expression=false`,
63+
then construction via GeneralizedGenerated.jl is utilized to allow for
64+
same world-age evaluation. However, this can cause Julia to segfault
65+
on sufficiently large basis functions. By default eval_expression=false.
66+
5667
"""
57-
function Basis(basis::AbstractArray{Operation}, variables::AbstractArray{Operation}; parameters::AbstractArray = Operation[], iv = nothing)
68+
function Basis(basis::AbstractArray{Operation}, variables::AbstractArray{Operation};
69+
parameters::AbstractArray = Operation[], iv = nothing, eval_expression = false)
5870
@assert all(is_independent.(variables)) "Please provide independent variables for basis."
5971

6072
bs = unique(basis)
@@ -67,7 +79,11 @@ function Basis(basis::AbstractArray{Operation}, variables::AbstractArray{Operati
6779
vs = [ModelingToolkit.Variable(Symbol(i)) for i in variables]
6880
ps = [ModelingToolkit.Variable(Symbol(i)) for i in parameters]
6981

70-
f_oop, f_iip = ModelingToolkit.build_function(bs, vs, ps, [iv], expression = Val{false})
82+
if eval_expression
83+
f_oop, f_iip = eval.(ModelingToolkit.build_function(bs, vs, ps, [iv], expression = Val{true}))
84+
else
85+
f_oop, f_iip = ModelingToolkit.build_function(bs, vs, ps, [iv], expression = Val{false})
86+
end
7187

7288
f_(u, p, t) = f_oop(u, p, t)
7389
f_(du, u, p, t) = f_iip(du, u, p, t)
@@ -93,12 +109,16 @@ function Basis(basis::Function, variables::AbstractArray{Operation}; parameters
93109
end
94110

95111

96-
function update!(basis::Basis)
112+
function update!(basis::Basis,eval_expression = false)
97113

98114
vs = [ModelingToolkit.Variable(Symbol(i))(basis.iv) for i in variables(basis)]
99115
ps = [ModelingToolkit.Variable(Symbol(i)) for i in parameters(basis)]
100116

101-
f_oop, f_iip = ModelingToolkit.build_function(basis.basis, vs, ps, [basis.iv], expression = Val{false})
117+
if eval_expression
118+
f_oop, f_iip = eval.(ModelingToolkit.build_function(basis.basis, vs, ps, [basis.iv], expression = Val{false}))
119+
else
120+
f_oop, f_iip = ModelingToolkit.build_function(basis.basis, vs, ps, [basis.iv], expression = Val{false})
121+
end
102122

103123
f_(u, p, t) = f_oop(u, p, t)
104124
f_(du, u, p, t) = f_iip(du, u, p, t)
@@ -260,14 +280,18 @@ ModelingToolkit.independent_variable(b::Basis) = b.iv
260280
Returns a function representing the jacobian matrix / gradient of the `Basis` with respect to the
261281
dependent variables as a function with the common signature `f(u,p,t)` for out of place and `f(du, u, p, t)` for in place computation.
262282
"""
263-
function jacobian(basis::Basis)
283+
function jacobian(basis::Basis, eval_expression = false)
264284

265285
vs = [ModelingToolkit.Variable(Symbol(i))(independent_variable(basis)) for i in variables(basis)]
266286
ps = [ModelingToolkit.Variable(Symbol(i)) for i in parameters(basis)]
267287

268288
j = ModelingToolkit.jacobian(basis.basis, variables(basis))
269289

270-
f_oop, f_iip = ModelingToolkit.build_function(expand_derivatives.(j), vs, ps, [basis.iv], expression = Val{false})
290+
if eval_expression
291+
f_oop, f_iip = eval.(ModelingToolkit.build_function(expand_derivatives.(j), vs, ps, [basis.iv], expression = Val{true}))
292+
else
293+
f_oop, f_iip = ModelingToolkit.build_function(expand_derivatives.(j), vs, ps, [basis.iv], expression = Val{false})
294+
end
271295

272296
f_(u, p, t) = f_oop(u, p, t)
273297
f_(du, u, p, t) = f_iip(du, u, p, t)

0 commit comments

Comments
 (0)