Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
[weakdeps]
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176"
Gridap = "56d4f2e9-7ea1-5844-9cf6-b9c51ca7ce8e"
GridapPETSc = "bcdc36c2-0c3e-11ea-095a-c9dadae499f1"
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
Expand All @@ -45,9 +47,16 @@ SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"

[sources]
NonlinearSolveBase = {path = "lib/NonlinearSolveBase"}
NonlinearSolveFirstOrder = {path = "lib/NonlinearSolveFirstOrder"}
NonlinearSolveQuasiNewton = {path = "lib/NonlinearSolveQuasiNewton"}
NonlinearSolveSpectralMethods = {path = "lib/NonlinearSolveSpectralMethods"}

[extensions]
NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
NonlinearSolveFixedPointAccelerationExt = "FixedPointAcceleration"
NonlinearSolveGridapPETScExt = ["Gridap", "GridapPETSc"]
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
NonlinearSolveMINPACKExt = "MINPACK"
NonlinearSolveNLSolversExt = "NLSolvers"
Expand Down
123 changes: 123 additions & 0 deletions ext/NonlinearSolveGridapPETScExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
module NonlinearSolveGridapPETScExt

using Gridap: Gridap, Algebra
using GridapPETSc: GridapPETSc

using NonlinearSolveBase: NonlinearSolveBase
using NonlinearSolve: NonlinearSolve, GridapPETScSNES
using SciMLBase: SciMLBase, NonlinearProblem, ReturnCode

using ConcreteStructs: @concrete
using FastClosures: @closure

@concrete struct NonlinearSolveOperator <: Algebra.NonlinearOperator
f!
jac!
initial_guess_cache
resid_prototype
jacobian_prototype
end

function Algebra.residual!(b::AbstractVector, op::NonlinearSolveOperator, x::AbstractVector)
op.f!(b, x)
end

function Algebra.jacobian!(
A::AbstractMatrix, op::NonlinearSolveOperator, x::AbstractVector
)
op.jac!(A, x)
end

function Algebra.zero_initial_guess(op::NonlinearSolveOperator)
fill!(op.initial_guess_cache, 0)
return op.initial_guess_cache
end

function Algebra.allocate_residual(op::NonlinearSolveOperator, ::AbstractVector)
fill!(op.resid_prototype, 0)
return op.resid_prototype
end

function Algebra.allocate_jacobian(op::NonlinearSolveOperator, ::AbstractVector)
fill!(op.jacobian_prototype, 0)
return op.jacobian_prototype
end

# TODO: Later we should just wrap `Gridap` generally and pass in `PETSc` as the solver
function SciMLBase.__solve(
prob::NonlinearProblem, alg::GridapPETScSNES, args...;
abstol = nothing, reltol = nothing,
maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing,
show_trace::Val = Val(false), kwargs...
)
# XXX: https://petsc.org/release/manualpages/SNES/SNESSetConvergenceTest/
NonlinearSolveBase.assert_extension_supported_termination_condition(
termination_condition, alg; abs_norm_supported = false
)

f_wrapped!, u0, resid = NonlinearSolveBase.construct_extension_function_wrapper(
prob; alias_u0
)
T = eltype(u0)

abstol = NonlinearSolveBase.get_tolerance(abstol, T)
reltol = NonlinearSolveBase.get_tolerance(reltol, T)

nf = Ref{Int}(0)

f! = @closure (fx, x) -> begin
nf[] += 1
f_wrapped!(fx, x)
return fx
end

if prob.u0 isa Number
jac! = NonlinearSolveBase.construct_extension_jac(
prob, alg, prob.u0, prob.u0; alg.autodiff
)
J_init = zeros(T, 1, 1)
else
jac!, J_init = NonlinearSolveBase.construct_extension_jac(
prob, alg, u0, resid; alg.autodiff, initial_jacobian = Val(true)
)
end

njac = Ref{Int}(-1)
jac_fn! = @closure (J, x) -> begin
njac[] += 1
jac!(J, x)
return J
end

nop = NonlinearSolveOperator(f!, jac_fn!, u0, resid, J_init)

petsc_args = [
"-snes_rtol", string(reltol), "-snes_atol", string(abstol),
"-snes_max_it", string(maxiters)
]
for (k, v) in pairs(alg.snes_options)
push!(petsc_args, "-$(k)")
push!(petsc_args, string(v))
end
show_trace isa Val{true} && push!(petsc_args, "-snes_monitor")

# TODO: We can reuse the cache returned from this function
sol_u = GridapPETSc.with(args = petsc_args) do
sol_u = copy(u0)
Algebra.solve!(sol_u, GridapPETSc.PETScNonlinearSolver(), nop)
return sol_u
end

f_wrapped!(resid, sol_u)
u_res = prob.u0 isa Number ? sol_u[1] : sol_u
resid_res = prob.u0 isa Number ? resid[1] : resid

objective = maximum(abs, resid)
retcode = ifelse(objective ≤ abstol, ReturnCode.Success, ReturnCode.Failure)
return SciMLBase.build_solution(
prob, alg, u_res, resid_res;
retcode, stats = SciMLBase.NLStats(nf[], njac[], -1, -1, -1)
)
end

end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
end
end

12 changes: 9 additions & 3 deletions ext/NonlinearSolvePETScExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ function SciMLBase.__solve(
maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing,
show_trace::Val = Val(false), kwargs...
)
if !MPI.Initialized()
@warn "MPI not initialized. Initializing MPI with MPI.Init()." maxlog=1
MPI.Init()
end

# XXX: https://petsc.org/release/manualpages/SNES/SNESSetConvergenceTest/
NonlinearSolveBase.assert_extension_supported_termination_condition(
termination_condition, alg; abs_norm_supported = false
Expand Down Expand Up @@ -68,8 +73,10 @@ function SciMLBase.__solve(
PETSc.setfunction!(snes, f!, PETSc.VecSeq(zero(u0)))

njac = Ref{Int}(-1)
if alg.autodiff !== missing || prob.f.jac !== nothing
# `missing` -> let PETSc compute the Jacobian using finite differences
if alg.autodiff !== missing
autodiff = alg.autodiff === missing ? nothing : alg.autodiff

if prob.u0 isa Number
jac! = NonlinearSolveBase.construct_extension_jac(
prob, alg, prob.u0, prob.u0; autodiff
Expand Down Expand Up @@ -125,8 +132,7 @@ function SciMLBase.__solve(
retcode = ifelse(objective ≤ abstol, ReturnCode.Success, ReturnCode.Failure)
return SciMLBase.build_solution(
prob, alg, u_res, resid_res;
retcode, original = snes,
stats = SciMLBase.NLStats(nf[], njac[], -1, -1, -1)
retcode, stats = SciMLBase.NLStats(nf[], njac[], -1, -1, -1)
)
end

Expand Down
3 changes: 3 additions & 0 deletions lib/BracketingNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,6 @@ TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a"

[targets]
test = ["Aqua", "ExplicitImports", "ForwardDiff", "InteractiveUtils", "Test", "TestItemRunner"]

[sources]
NonlinearSolveBase = {path = "../NonlinearSolveBase"}
3 changes: 3 additions & 0 deletions lib/NonlinearSolveBase/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "BandedMatrices", "DiffEqBase", "ExplicitImports", "ForwardDiff", "InteractiveUtils", "LinearAlgebra", "SparseArrays", "Test"]

[sources]
SciMLJacobianOperators = {path = "../SciMLJacobianOperators"}
4 changes: 4 additions & 0 deletions lib/NonlinearSolveFirstOrder/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ForwardDiff", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "LineSearch", "LineSearches", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test", "Zygote"]

[sources]
NonlinearSolveBase = {path = "../NonlinearSolveBase"}
SciMLJacobianOperators = {path = "../SciMLJacobianOperators"}
3 changes: 3 additions & 0 deletions lib/NonlinearSolveHomotopyContinuation/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test", "NonlinearSolve", "Enzyme", "NaNMath"]

[sources]
NonlinearSolveBase = {path = "../NonlinearSolveBase"}
3 changes: 3 additions & 0 deletions lib/NonlinearSolveQuasiNewton/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ADTypes", "Aqua", "BenchmarkTools", "Enzyme", "ExplicitImports", "FiniteDiff", "ForwardDiff", "Hwloc", "InteractiveUtils", "LineSearch", "LineSearches", "NonlinearProblemLibrary", "Pkg", "ReTestItems", "StableRNGs", "StaticArrays", "Test", "Zygote"]

[sources]
NonlinearSolveBase = {path = "../NonlinearSolveBase"}
3 changes: 3 additions & 0 deletions lib/NonlinearSolveSpectralMethods/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "BenchmarkTools", "ExplicitImports", "Hwloc", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "ReTestItems", "StableRNGs", "StaticArrays", "Test"]

[sources]
NonlinearSolveBase = {path = "../NonlinearSolveBase"}
3 changes: 3 additions & 0 deletions lib/SCCNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "BenchmarkTools", "ExplicitImports", "Hwloc", "InteractiveUtils", "NonlinearSolveFirstOrder", "NonlinearProblemLibrary", "Pkg", "ReTestItems", "StableRNGs", "StaticArrays", "Test"]

[sources]
NonlinearSolveFirstOrder = {path = "../NonlinearSolveFirstOrder"}
3 changes: 3 additions & 0 deletions lib/SimpleNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "DiffEqBase", "Enzyme", "ExplicitImports", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReverseDiff", "StaticArrays", "Test", "TestItemRunner", "Tracker", "Zygote"]

[sources]
NonlinearSolveBase = {path = "../NonlinearSolveBase"}
2 changes: 1 addition & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,6 @@ export NonlinearSolvePolyAlgorithm, FastShortcutNonlinearPolyalg, FastShortcutNL
# Extension Algorithms
export LeastSquaresOptimJL, FastLevenbergMarquardtJL, NLsolveJL, NLSolversJL,
FixedPointAccelerationJL, SpeedMappingJL, SIAMFANLEquationsJL
export PETScSNES, CMINPACK
export PETScSNES, GridapPETScSNES, CMINPACK

end
13 changes: 13 additions & 0 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -488,3 +488,16 @@ function PETScSNES(; petsclib = missing, autodiff = nothing, mpi_comm = missing,
end
return PETScSNES(petsclib, mpi_comm, autodiff, kwargs)
end

# TODO: Docs
@concrete struct GridapPETScSNES <: AbstractNonlinearSolveAlgorithm
autodiff
snes_options
end

function GridapPETScSNES(; autodiff = nothing, kwargs...)
if Base.get_extension(@__MODULE__, :NonlinearSolveGridapPETScExt) === nothing
error("`GridapPETScSNES` requires `GridapPETSc.jl` to be loaded")
end
return GridapPETScSNES(autodiff, kwargs)
end
Loading