Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit b8cc83d

Browse files
committed
Add tests
1 parent 1033c15 commit b8cc83d

File tree

5 files changed

+31
-15
lines changed

5 files changed

+31
-15
lines changed

ext/SimpleNonlinearSolvePolyesterForwardDiffExt.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,10 @@ using SimpleNonlinearSolve, PolyesterForwardDiff
1010
return J
1111
end
1212

13-
end
13+
@inline function SimpleNonlinearSolve.__polyester_forwarddiff_jacobian!(f::F, J, x,
14+
chunksize) where {F}
15+
PolyesterForwardDiff.threaded_jacobian!(f, J, x, chunksize)
16+
return J
17+
end
18+
19+
end

src/nlsolve/halley.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@ A low-overhead implementation of Halley's Method.
1212
1313
### Keyword Arguments
1414
15-
- `autodiff`: determines the backend used for the Hessian. Defaults to
16-
`AutoForwardDiff()`. Valid choices are `AutoForwardDiff()` or `AutoFiniteDiff()`.
15+
- `autodiff`: determines the backend used for the Hessian. Defaults to `nothing`. Valid
16+
choices are `AutoForwardDiff()` or `AutoFiniteDiff()`.
1717
1818
!!! warning
1919
2020
Inplace Problems are currently not supported by this method.
2121
"""
2222
@kwdef @concrete struct SimpleHalley <: AbstractNewtonAlgorithm
23-
autodiff = AutoForwardDiff()
23+
autodiff = nothing
2424
end
2525

2626
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
@@ -33,6 +33,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
3333
fx = _get_fx(prob, x)
3434
T = eltype(x)
3535

36+
autodiff = __get_concrete_autodiff(prob, alg.autodiff; polyester = Val(false))
3637
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x,
3738
termination_condition)
3839

@@ -50,7 +51,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
5051

5152
for i in 1:maxiters
5253
# Hessian Computation is unfortunately type unstable
53-
fx, dfx, d2fx = compute_jacobian_and_hessian(alg.autodiff, prob, fx, x)
54+
fx, dfx, d2fx = compute_jacobian_and_hessian(autodiff, prob, fx, x)
5455
setindex_trait(x) === CannotSetindex() && (A = dfx)
5556

5657
# Factorize Once and Reuse

src/utils.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,11 @@ function value_and_jacobian(ad, f::F, y, x::Number, p, cache; J = nothing) where
119119
T = typeof(__standard_tag(ad.tag, x))
120120
out = f(ForwardDiff.Dual{T}(x, one(x)), p)
121121
return ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
122+
elseif ad isa AutoPolyesterForwardDiff
123+
# Just use ForwardDiff
124+
T = typeof(__standard_tag(nothing, x))
125+
out = f(ForwardDiff.Dual{T}(x, one(x)), p)
126+
return ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
122127
elseif ad isa AutoFiniteDiff
123128
_f = Base.Fix2(f, p)
124129
return _f(x), FiniteDiff.finite_difference_derivative(_f, x, ad.fdtype)
@@ -153,7 +158,7 @@ function jacobian_cache(ad, f::F, y, x::X, p) where {F, X <: AbstractArray}
153158
J = ArrayInterface.can_setindex(x) ? similar(y, length(y), length(x)) : nothing
154159
return J, __get_jacobian_config(ad, _f, x)
155160
elseif ad isa AutoPolyesterForwardDiff
156-
@assert ArrayInterface.can_setindex(x) "PolyesterForwardDiff requires mutable inputs."
161+
@assert ArrayInterface.can_setindex(x) "PolyesterForwardDiff requires mutable inputs. Use AutoForwardDiff instead."
157162
J = similar(y, length(y), length(x))
158163
return J, __get_jacobian_config(ad, _f, x)
159164
elseif ad isa AutoFiniteDiff
@@ -362,16 +367,17 @@ end
362367
end
363368

364369
# Decide which AD backend to use
365-
@inline __get_concrete_autodiff(prob, ad::ADTypes.AbstractADType) = ad
366-
@inline function __get_concrete_autodiff(prob, ::Nothing)
370+
@inline __get_concrete_autodiff(prob, ad::ADTypes.AbstractADType; kwargs...) = ad
371+
@inline function __get_concrete_autodiff(prob, ::Nothing; polyester::Val{P} = Val(true),
372+
kwargs...) where {P}
367373
if ForwardDiff.can_dual(eltype(prob.u0))
368-
if __is_extension_loaded(Val(:PolyesterForwardDiff)) && !(prob.u0 isa Number) &&
369-
ArrayInterface.can_setindex(prob.u0)
374+
if P && __is_extension_loaded(Val(:PolyesterForwardDiff)) &&
375+
!(prob.u0 isa Number) && ArrayInterface.can_setindex(prob.u0)
370376
return AutoPolyesterForwardDiff()
371377
else
372378
return AutoForwardDiff()
373379
end
374380
else
375381
return AutoFiniteDiff()
376382
end
377-
end
383+
end

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
77
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
88
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
99
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
10+
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
1011
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1112
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1213
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

test/basictests.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using AllocCheck, BenchmarkTools, LinearSolve, SimpleNonlinearSolve, StaticArrays, Random,
22
LinearAlgebra, Test, ForwardDiff, DiffEqBase
3+
import PolyesterForwardDiff
34

45
_nameof(x) = applicable(nameof, x) ? nameof(x) : _nameof(typeof(x))
56

@@ -29,20 +30,21 @@ const TERMINATION_CONDITIONS = [
2930
@testset "$(alg)" for alg in (SimpleNewtonRaphson, SimpleTrustRegion)
3031
# Eval else the alg is type unstable
3132
@eval begin
32-
function benchmark_nlsolve_oop(f, u0, p = 2.0; autodiff = AutoForwardDiff())
33+
function benchmark_nlsolve_oop(f, u0, p = 2.0; autodiff = nothing)
3334
prob = NonlinearProblem{false}(f, u0, p)
3435
return solve(prob, $(alg)(; autodiff), abstol = 1e-9)
3536
end
3637

37-
function benchmark_nlsolve_iip(f, u0, p = 2.0; autodiff = AutoForwardDiff())
38+
function benchmark_nlsolve_iip(f, u0, p = 2.0; autodiff = nothing)
3839
prob = NonlinearProblem{true}(f, u0, p)
3940
return solve(prob, $(alg)(; autodiff), abstol = 1e-9)
4041
end
4142
end
4243

4344
@testset "AutoDiff: $(_nameof(autodiff))" for autodiff in (AutoFiniteDiff(),
44-
AutoForwardDiff())
45+
AutoForwardDiff(), AutoPolyesterForwardDiff())
4546
@testset "[OOP] u0: $(typeof(u0))" for u0 in ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)
47+
u0 isa SVector && autodiff isa AutoPolyesterForwardDiff && continue
4648
sol = benchmark_nlsolve_oop(quadratic_f, u0; autodiff)
4749
@test SciMLBase.successful_retcode(sol)
4850
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)
@@ -103,7 +105,7 @@ end
103105
# --- SimpleHalley tests ---
104106

105107
@testset "SimpleHalley" begin
106-
function benchmark_nlsolve_oop(f, u0, p = 2.0; autodiff = AutoForwardDiff())
108+
function benchmark_nlsolve_oop(f, u0, p = 2.0; autodiff = nothing)
107109
prob = NonlinearProblem{false}(f, u0, p)
108110
return solve(prob, SimpleHalley(; autodiff), abstol = 1e-9)
109111
end

0 commit comments

Comments
 (0)