Skip to content

Commit 650f403

Browse files
Reuse integrators in nlstep and NonlinearSolveAlg integrations
This is a pretty major performance boost
1 parent 7393265 commit 650f403

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

lib/OrdinaryDiffEqNonlinearSolve/src/newton.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,14 @@ function initialize!(nlsolver::NLSolver{<:NonlinearSolveAlg, true},
7676
nlstep_data.set_outer_tmp(nlstep_data.nlprob, zero(z))
7777
end
7878
nlstep_data.nlprob.u0 .= @view z[nlstep_data.u0perm]
79-
cache.cache = init(nlstep_data.nlprob, alg.alg)
79+
SciMLBase.reinit!(cache.cache, nlstep_data.nlprob.u0, p=nlstep_data.nlprob.p)
8080
else
8181
if f isa DAEFunction
8282
nlp_params = (tmp, ztmp, ustep, γ, α, tstep, k, invγdt, p, dt, f)
8383
else
8484
nlp_params = (tmp, ustep, γ, α, tstep, k, invγdt, method, p, dt, f)
8585
end
86-
new_prob = remake(cache.prob, p = nlp_params, u0 = z)
87-
cache.cache = init(new_prob, alg.alg)
86+
SciMLBase.reinit!(cache.cache, z, p=nlp_params)
8887
end
8988
nothing
9089
end

test/modelingtoolkit/nlstep_tests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ using Test
1010
eqs = [D(y₁) ~ -k₁ * y₁ + k₃ * y₂ * y₃,
1111
D(y₂) ~ k₁ * y₁ - k₂ * y₂^2 - k₃ * y₂ * y₃,
1212
D(y₃) ~ k₂ * y₂^2]
13-
@mtkbuild rober = ODESystem(eqs, t)
13+
@mtkcompile rober = ODESystem(eqs, t)
1414
prob = ODEProblem(rober, [[y₁, y₂, y₃] .=> [1.0; 0.0; 0.0]; [k₁, k₂, k₃] .=> (0.04, 3e7, 1e4)], (0.0, 1e5), jac = true)
1515
prob2 = ODEProblem(rober, [[y₁, y₂, y₃] .=> [1.0; 0.0; 0.0]; [k₁, k₂, k₃] .=> (0.04, 3e7, 1e4)], (0.0, 1e5), jac = true, nlstep = true)
1616

@@ -29,7 +29,7 @@ sol1 = solve(prob, TRBDF2(autodiff=AutoFiniteDiff(), nlsolve = nlalg));
2929
sol2 = solve(prob2, TRBDF2(autodiff=AutoFiniteDiff(), nlsolve = nlalg));
3030

3131
@test sol1.t != sol2.t
32-
@test sol1 != sol2
32+
@test sol1.u != sol2.u
3333
@test sol1(sol1.t) sol2(sol1.t) atol=1e-3
3434

3535
testprob = ODEProblem(rober, [[y₁, y₂, y₃] .=> [1.0; 0.0; 0.0]; [k₁, k₂, k₃] .=> (0.04, 3e7, 1e4)], (0.0, 1.0), nlstep = true)
@@ -60,15 +60,15 @@ sim = analyticless_test_convergence(dts, testprob, FBDF(autodiff=AutoFiniteDiff(
6060
eqs_nonaut = [D(y₁) ~ -k₁ * y₁ + (t+1) * k₃ * y₂ * y₃,
6161
D(y₂) ~ k₁ * y₁ - (t+1) * k₂ * y₂^2 - (t+1) * k₃ * y₂ * y₃,
6262
D(y₃) ~ (t+1) * k₂ * y₂^2]
63-
@mtkbuild rober_nonaut = ODESystem(eqs_nonaut, t)
63+
@mtkcompile rober_nonaut = ODESystem(eqs_nonaut, t)
6464
prob = ODEProblem(rober_nonaut, [[y₁, y₂, y₃] .=> [1.0; 0.0; 0.0]; [k₁, k₂, k₃] .=> (0.04, 3e7, 1e4)], (0.0, 1e5), jac = true)
6565
prob2 = ODEProblem(rober_nonaut, [[y₁, y₂, y₃] .=> [1.0; 0.0; 0.0]; [k₁, k₂, k₃] .=> (0.04, 3e7, 1e4)], (0.0, 1e5), jac = true, nlstep = true)
6666

6767
sol1 = solve(prob, FBDF(autodiff=AutoFiniteDiff(), nlsolve = nlalg));
6868
sol2 = solve(prob2, FBDF(autodiff=AutoFiniteDiff(), nlsolve = nlalg));
6969

7070
@test sol1.t != sol2.t
71-
@test sol1 != sol2
71+
@test sol1.u != sol2.u
7272
@test sol1(sol1.t) sol2(sol1.t) atol=1e-3
7373

7474
sol1 = solve(prob, TRBDF2(autodiff=AutoFiniteDiff(), nlsolve = nlalg));

0 commit comments

Comments
 (0)