Skip to content

Commit

Permalink
feat: add PETScSNES
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 25, 2024
1 parent 21b02bd commit 8364408
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 23 deletions.
5 changes: 5 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@ FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176"
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
NLSolvers = "337daf1e-9722-11e9-073e-8b9effe078ba"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
PETSc = "ace2c81b-2b5f-4b1e-a30d-d662738edfe0"
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
Expand All @@ -55,6 +57,7 @@ NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
NonlinearSolveMINPACKExt = "MINPACK"
NonlinearSolveNLSolversExt = "NLSolvers"
NonlinearSolveNLsolveExt = ["NLsolve", "LineSearches"]
NonlinearSolvePETScExt = ["PETSc", "MPI"]
NonlinearSolveSIAMFANLEquationsExt = "SIAMFANLEquations"
NonlinearSolveSpeedMappingExt = "SpeedMapping"
NonlinearSolveSundialsExt = "Sundials"
Expand Down Expand Up @@ -86,13 +89,15 @@ LineSearches = "7.3"
LinearAlgebra = "1.10"
LinearSolve = "2.35"
MINPACK = "1.2"
MPI = "0.20.22"
MaybeInplace = "0.1.4"
NLSolvers = "0.5"
NLsolve = "4.5"
NaNMath = "1"
NonlinearProblemLibrary = "0.1.2"
NonlinearSolveBase = "1"
OrdinaryDiffEqTsit5 = "1.1.0"
PETSc = "0.2"
Pkg = "1.10"
PrecompileTools = "1.2"
Preferences = "1.4"
Expand Down
17 changes: 17 additions & 0 deletions docs/src/api/petsc.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# PETSc.jl

This is a extension for importing solvers from PETSc.jl SNES into the SciML interface. Note
that these solvers do not come by default, and thus one needs to install the package before
using these solvers:

```julia
using Pkg
Pkg.add("PETSc")
using PETSc, NonlinearSolve
```

## Solver API

```@docs
PETScSNES
```
9 changes: 9 additions & 0 deletions docs/src/solvers/nonlinear_system_solvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,12 @@ This is a wrapper package for importing solvers from NLSolvers.jl into the SciML
[NLSolvers.jl](https://github.com/JuliaNLSolvers/NLSolvers.jl)

For a list of possible solvers see the [NLSolvers.jl documentation](https://julianlsolvers.github.io/NLSolvers.jl/)

### PETSc.jl

This is a wrapper package for importing solvers from PETSc.jl into the SciML interface.

- [`PETScSNES()`](@ref): A wrapper for
[PETSc.jl](https://github.com/JuliaParallel/PETSc.jl)

For a list of possible solvers see the [PETSc.jl documentation](https://petsc.org/release/manual/snes/)
85 changes: 85 additions & 0 deletions ext/NonlinearSolvePETScExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
module NonlinearSolvePETScExt

using FastClosures: @closure
using MPI: MPI
using NonlinearSolveBase: NonlinearSolveBase, get_tolerance
using NonlinearSolve: NonlinearSolve, PETScSNES
using PETSc: PETSc
using SciMLBase: SciMLBase, NonlinearProblem, ReturnCode

function SciMLBase.__solve(
prob::NonlinearProblem, alg::PETScSNES, args...; abstol = nothing, reltol = nothing,
maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing,
show_trace::Val{ShT} = Val(false), kwargs...) where {ShT}
termination_condition === nothing ||
error("`PETScSNES` does not support termination conditions!")

_f!, u0, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0)
T = eltype(prob.u0)

if alg.petsclib === missing
petsclibidx = findfirst(PETSc.petsclibs) do petsclib
petsclib isa PETSc.PetscLibType{T}
end

if petsclibidx === nothing
error("No compatible PETSc library found for element type $(T). Pass in a \
custom `petsclib` via `PETScSNES(; petsclib = <petsclib>, ....)`.")
end
petsclib = PETSc.petsclibs[petsclibidx]
else
petsclib = alg.petsclib
end
PETSc.initialized(petsclib) || PETSc.initialize(petsclib)

abstol = get_tolerance(abstol, T)
reltol = get_tolerance(reltol, T)

f! = @closure (cfx, cx, user_ctx) -> begin
fx = cfx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cfx; read = false) : cfx
x = cx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cx; write = false) : cx
_f!(fx, x)
Base.finalize(fx)
Base.finalize(x)
return
end

snes = PETSc.SNES{T}(petsclib,
alg.mpi_comm === missing ? MPI.COMM_SELF : alg.mpi_comm;
alg.snes_options..., snes_monitor = ShT, snes_rtol = reltol,
snes_atol = abstol, snes_max_it = maxiters)

if alg.autodiff === missing && prob.f.jac === nothing
_jac! = nothing
else
autodiff = alg.autodiff === missing ? nothing : alg.autodiff
_jac! = NonlinearSolve.__construct_extension_jac(prob, alg, u0, resid; autodiff)
end

PETSc.setfunction!(snes, f!, PETSc.VecSeq(zero(u0)))

if _jac! !== nothing # XXX: Sparsity Handling???
PJ = PETSc.MatSeqDense(zeros(T, length(resid), length(u0)))
jac! = @closure (cx, J, _, user_ctx) -> begin
x = cx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cx; write = false) : cx
_jac!(J, x)
Base.finalize(x)
PETSc.assemble(J)
return
end
PETSc.setjacobian!(snes, jac!, PJ, PJ)
end

res = PETSc.solve!(u0, snes)

_f!(resid, res)
u_ = prob.u0 isa Number ? res[1] : res
resid_ = prob.u0 isa Number ? resid[1] : resid

objective = maximum(abs, resid)
# XXX: Return Code from PETSc
retcode = ifelse(objective abstol, ReturnCode.Success, ReturnCode.Failure)
return SciMLBase.build_solution(prob, alg, u_, resid_; retcode, original = snes)
end

end
12 changes: 11 additions & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,15 @@ include("algorithms/extension_algs.jl")
include("utils.jl")
include("default.jl")

const ALL_SOLVER_TYPES = [
Nothing, AbstractNonlinearSolveAlgorithm, GeneralizedDFSane,
GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm,
LeastSquaresOptimJL, FastLevenbergMarquardtJL, NLsolveJL, NLSolversJL,
SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL,
CMINPACK, PETScSNES,
NonlinearSolvePolyAlgorithm{:NLLS, <:Any}, NonlinearSolvePolyAlgorithm{:NLS, <:Any}
]

include("internal/forward_diff.jl") # we need to define after the algorithms

@setup_workload begin
Expand Down Expand Up @@ -171,8 +180,9 @@ export NonlinearSolvePolyAlgorithm, RobustMultiNewton, FastShortcutNonlinearPoly
FastShortcutNLLSPolyalg

# Extension Algorithms
export LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, NLSolversJL,
export LeastSquaresOptimJL, FastLevenbergMarquardtJL, NLsolveJL, NLSolversJL,
FixedPointAccelerationJL, SpeedMappingJL, SIAMFANLEquationsJL
export PETScSNES, CMINPACK

# Advanced Algorithms -- Without Bells and Whistles
export GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm, GeneralizedDFSane
Expand Down
65 changes: 58 additions & 7 deletions src/algorithms/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ for solving `NonlinearLeastSquaresProblem`.
!!! note
This algorithm is only available if `LeastSquaresOptim.jl` is installed.
This algorithm is only available if `LeastSquaresOptim.jl` is installed and loaded.
"""
struct LeastSquaresOptimJL{alg, linsolve} <: AbstractNonlinearSolveExtensionAlgorithm
autodiff
Expand Down Expand Up @@ -65,7 +65,7 @@ see the documentation for `FastLevenbergMarquardt.jl`.
!!! note
This algorithm is only available if `FastLevenbergMarquardt.jl` is installed.
This algorithm is only available if `FastLevenbergMarquardt.jl` is installed and loaded.
"""
@concrete struct FastLevenbergMarquardtJL{linsolve} <:
AbstractNonlinearSolveExtensionAlgorithm
Expand Down Expand Up @@ -139,7 +139,7 @@ NonlinearLeastSquaresProblem.
!!! note
This algorithm is only available if `MINPACK.jl` is installed.
This algorithm is only available if `MINPACK.jl` is installed and loaded.
"""
@concrete struct CMINPACK <: AbstractNonlinearSolveExtensionAlgorithm
method::Symbol
Expand Down Expand Up @@ -199,7 +199,7 @@ For more information on these arguments, consult the
!!! note
This algorithm is only available if `NLsolve.jl` is installed.
This algorithm is only available if `NLsolve.jl` is installed and loaded.
"""
@concrete struct NLsolveJL <: AbstractNonlinearSolveExtensionAlgorithm
method::Symbol
Expand Down Expand Up @@ -281,7 +281,7 @@ Fixed Point Problems. We allow using this algorithm to solve root finding proble
!!! note
This algorithm is only available if `SpeedMapping.jl` is installed.
This algorithm is only available if `SpeedMapping.jl` is installed and loaded.
"""
@concrete struct SpeedMappingJL <: AbstractNonlinearSolveExtensionAlgorithm
σ_min
Expand Down Expand Up @@ -324,7 +324,7 @@ problems as well.
!!! note
This algorithm is only available if `FixedPointAcceleration.jl` is installed.
This algorithm is only available if `FixedPointAcceleration.jl` is installed and loaded.
"""
@concrete struct FixedPointAccelerationJL <: AbstractNonlinearSolveExtensionAlgorithm
algorithm::Symbol
Expand Down Expand Up @@ -402,7 +402,7 @@ end
!!! note
This algorithm is only available if `SIAMFANLEquations.jl` is installed.
This algorithm is only available if `SIAMFANLEquations.jl` is installed and loaded.
"""
@concrete struct SIAMFANLEquationsJL{L <: Union{Symbol, Nothing}} <:
AbstractNonlinearSolveExtensionAlgorithm
Expand All @@ -421,3 +421,54 @@ function SIAMFANLEquationsJL(; method = :newton, delta = 1e-3, linsolve = nothin
end
return SIAMFANLEquationsJL(method, delta, linsolve, m, beta, autodiff)
end

"""
PETScSNES(; petsclib = missing, autodiff = nothing, mpi_comm = missing, kwargs...)
Wrapper over [PETSc.jl](https://github.com/JuliaParallel/PETSc.jl) SNES solvers.
### Keyword Arguments
- `petsclib`: PETSc library to use. If set to `missing`, then we will use the first
available PETSc library in `PETSc.petsclibs` based on the problem element type.
- `autodiff`: the choice of method for generating the Jacobian. Defaults to `nothing`
which means that a default is selected according to the problem specification. Can be
any valid ADTypes.jl autodiff type (conditional on that backend being supported in
NonlinearSolve.jl). If set to `missing`, then PETSc computes the Jacobian using finite
differences.
- `mpi_comm`: MPI communicator to use. If set to `missing`, then we will use
`MPI.COMM_SELF`.
- `kwargs`: Keyword arguments to be passed to the PETSc SNES solver. See [PETSc SNES
documentation](https://petsc.org/release/manual/snes/) and
[SNESSetFromOptions](https://petsc.org/release/manualpages/SNES/SNESSetFromOptions)
for more information.
### Options via `CommonSolve.solve`
These options are forwarded from `solve` to the PETSc SNES solver. If these are provided to
`kwargs`, then they will be ignored.
| `solve` option | PETSc SNES option |
|:-------------- |:----------------- |
| `atol` | `snes_atol` |
| `rtol` | `snes_rtol` |
| `maxiters` | `snes_max_it` |
| `show_trace` | `snes_monitor` |
!!! note
This algorithm is only available if `PETSc.jl` is installed and loaded.
"""
@concrete struct PETScSNES <: AbstractNonlinearSolveExtensionAlgorithm
petsclib
mpi_comm
autodiff <: Union{Missing, Nothing, ADTypes.AbstractADType}
snes_options
end

function PETScSNES(; petsclib = missing, autodiff = nothing, mpi_comm = missing, kwargs...)
if Base.get_extension(@__MODULE__, :NonlinearSolvePETScExt) === nothing
error("PETScSNES requires PETSc.jl to be loaded")
end
return PETScSNES(petsclib, mpi_comm, autodiff, kwargs)
end
17 changes: 2 additions & 15 deletions src/internal/forward_diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,7 @@ const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
const DualAbstractNonlinearProblem = Union{
DualNonlinearProblem, DualNonlinearLeastSquaresProblem}

for algType in (
Nothing, AbstractNonlinearSolveAlgorithm, GeneralizedDFSane,
GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm,
LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, NLSolversJL,
SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL,
NonlinearSolvePolyAlgorithm{:NLLS, <:Any}, NonlinearSolvePolyAlgorithm{:NLS, <:Any}
)
for algType in ALL_SOLVER_TYPES
@eval function SciMLBase.__solve(
prob::DualNonlinearProblem, alg::$(algType), args...; kwargs...)
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
Expand Down Expand Up @@ -43,14 +37,7 @@ function reinit_cache!(cache::NonlinearSolveForwardDiffCache;
return cache
end

for algType in (
Nothing, AbstractNonlinearSolveAlgorithm, GeneralizedDFSane,
SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm,
GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm,
LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, NLSolversJL,
SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL,
NonlinearSolvePolyAlgorithm{:NLLS, <:Any}, NonlinearSolvePolyAlgorithm{:NLS, <:Any}
)
for algType in ALL_SOLVER_TYPES
@eval function SciMLBase.__init(
prob::DualNonlinearProblem, alg::$(algType), args...; kwargs...)
p = __value(prob.p)
Expand Down

0 comments on commit 8364408

Please sign in to comment.