Skip to content

Commit

Permalink
feat: SimpleNewtonRaphson
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 4, 2024
1 parent 1d7df8d commit fc2d60c
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 11 deletions.
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ AlgebraicMultigrid = "0.5, 0.6"
ArrayInterface = "6, 7"
BenchmarkTools = "1"
DiffEqBase = "6.136"
DifferentiationInterface = "0.6"
DifferentiationInterface = "0.6.1"
Documenter = "1"
DocumenterCitations = "1"
DocumenterInterLinks = "1.0.0"
Expand Down
17 changes: 10 additions & 7 deletions lib/NonlinearSolveBase/src/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,21 @@

# Ordering is important here. We want to select the first one that is compatible with the
# problem.
const ReverseADs = [
const ReverseADs = (
ADTypes.AutoEnzyme(; mode = EnzymeCore.Reverse),
ADTypes.AutoZygote(),
ADTypes.AutoTracker(),
ADTypes.AutoReverseDiff(; compile = true),
ADTypes.AutoReverseDiff(),
ADTypes.AutoFiniteDiff()
]
)

const ForwardADs = [
const ForwardADs = (
ADTypes.AutoEnzyme(; mode = EnzymeCore.Forward),
ADTypes.AutoPolyesterForwardDiff(),
ADTypes.AutoForwardDiff(),
ADTypes.AutoFiniteDiff()
]
)

# TODO: Handle Sparsity

Expand All @@ -28,7 +28,8 @@ function select_forward_mode_autodiff(
end
if incompatible_backend_and_problem(prob, ad)
adₙ = select_forward_mode_autodiff(prob, nothing; warn_check_mode)
@warn "The chosen AD backend `$(ad)` does not support the chosen problem. After \
@warn "The chosen AD backend `$(ad)` does not support the chosen problem. This \
could be because the backend package for the choosen AD isn't loaded. After \

Check warning on line 32 in lib/NonlinearSolveBase/src/autodiff.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"choosen" should be "chosen".
running autodiff selection detected `$(adₙ)` as a potential forward mode \
backend."
return adₙ
Expand Down Expand Up @@ -57,7 +58,8 @@ function select_reverse_mode_autodiff(
end
if incompatible_backend_and_problem(prob, ad)
adₙ = select_reverse_mode_autodiff(prob, nothing; warn_check_mode)
@warn "The chosen AD backend `$(ad)` does not support the chosen problem. After \
@warn "The chosen AD backend `$(ad)` does not support the chosen problem. This \
could be because the backend package for the choosen AD isn't loaded. After \

Check warning on line 62 in lib/NonlinearSolveBase/src/autodiff.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"choosen" should be "chosen".
running autodiff selection detected `$(adₙ)` as a potential reverse mode \
backend."
return adₙ
Expand All @@ -77,7 +79,8 @@ end
function select_jacobian_autodiff(prob::AbstractNonlinearProblem, ad::AbstractADType)
if incompatible_backend_and_problem(prob, ad)
adₙ = select_jacobian_autodiff(prob, nothing)
@warn "The chosen AD backend `$(ad)` does not support the chosen problem. After \
@warn "The chosen AD backend `$(ad)` does not support the chosen problem. This \
could be because the backend package for the choosen AD isn't loaded. After \

Check warning on line 83 in lib/NonlinearSolveBase/src/autodiff.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"choosen" should be "chosen".
running autodiff selection detected `$(adₙ)` as a potential jacobian \
backend."
return adₙ
Expand Down
7 changes: 6 additions & 1 deletion lib/SimpleNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
BracketingNonlinearSolve = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
Expand Down Expand Up @@ -38,11 +39,13 @@ ArrayInterface = "7.16"
BracketingNonlinearSolve = "1"
ChainRulesCore = "1.24"
CommonSolve = "0.2.4"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.155"
DifferentiationInterface = "0.5.17"
DifferentiationInterface = "0.6.1"
FastClosures = "0.3.2"
FiniteDiff = "2.24.0"
ForwardDiff = "0.10.36"
InteractiveUtils = "<0.0.1, 1"
LinearAlgebra = "1.10"
MaybeInplace = "0.1.4"
NonlinearSolveBase = "1"
Expand All @@ -51,6 +54,8 @@ Reexport = "1.2"
ReverseDiff = "1.15"
SciMLBase = "2.50"
StaticArraysCore = "1.4.3"
Test = "1.10"
TestItemRunner = "1"
Tracker = "0.2.35"
julia = "1.10"

Expand Down
10 changes: 9 additions & 1 deletion lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module SimpleNonlinearSolve

using CommonSolve: CommonSolve, solve
using ConcreteStructs: @concrete
using FastClosures: @closure
using MaybeInplace: @bb
using PrecompileTools: @compile_workload, @setup_workload
Expand All @@ -27,6 +28,7 @@ is_extension_loaded(::Val) = false
include("utils.jl")

include("klement.jl")
include("raphson.jl")

# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms
function CommonSolve.solve(prob::NonlinearProblem,
Expand Down Expand Up @@ -69,7 +71,10 @@ function solve_adjoint_internal end
prob_iip = NonlinearProblem{true}((du, u, p) -> du .= u .* u .- p, ones(T, 3), T(2))
prob_oop = NonlinearProblem{false}((u, p) -> u .* u .- p, ones(T, 3), T(2))

algs = [SimpleKlement()]
algs = [
SimpleKlement(),
SimpleNewtonRaphson()
]
algs_no_iip = []

@compile_workload begin
Expand All @@ -87,4 +92,7 @@ export AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff

export Alefeld, Bisection, Brent, Falsi, ITP, Ridder

export SimpleKlement
export SimpleGaussNewton, SimpleNewtonRaphson

end
62 changes: 62 additions & 0 deletions lib/SimpleNonlinearSolve/src/raphson.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""
SimpleNewtonRaphson(autodiff)
SimpleNewtonRaphson(; autodiff = nothing)
A low-overhead implementation of Newton-Raphson. This method is non-allocating on scalar
and static array problems.
!!! note
As part of the decreased overhead, this method omits some of the higher level error
catching of the other methods. Thus, to see better error messages, use one of the other
methods like `NewtonRaphson`.
### Keyword Arguments
- `autodiff`: determines the backend used for the Jacobian. Defaults to `nothing` (i.e.
automatic backend selection). Valid choices include jacobian backends from
`DifferentiationInterface.jl`.
"""
@kwdef @concrete struct SimpleNewtonRaphson <: AbstractSimpleNonlinearSolveAlgorithm
autodiff = nothing
end

const SimpleGaussNewton = SimpleNewtonRaphson

function SciMLBase.__solve(
prob::ImmutableNonlinearProblem, alg::SimpleNewtonRaphson, args...;
abstol = nothing, reltol = nothing, maxiters = 1000,
alias_u0 = false, termination_condition = nothing, kwargs...)
x = Utils.maybe_unaliased(prob.u0, alias_u0)
fx = Utils.get_fx(prob, x)
fx = Utils.eval_f(prob, fx, x)

iszero(fx) &&
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)

abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache(
prob, abstol, reltol, fx, x, termination_condition, Val(:simple))

autodiff = SciMLBase.has_jac(prob.f) ? alg.autodiff :
NonlinearSolveBase.select_jacobian_autodiff(prob, alg.autodiff)

@bb xo = similar(x)
fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ? similar(fx) :
nothing
jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x)
J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache)

for _ in 1:maxiters
@bb copyto!(xo, x)
δx = Utils.restructure(x, J \ Utils.safe_vec(fx))
@bb x .-= δx

solved, retcode, fx_sol, x_sol = Utils.check_termination(tc_cache, fx, x, xo, prob)
solved && return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)

fx = Utils.eval_f(prob, fx, x)
J = Utils.compute_jacobian!!(J, prob, autodiff, fx_cache, x, jac_cache)
end

return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
end
70 changes: 69 additions & 1 deletion lib/SimpleNonlinearSolve/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module Utils

using ADTypes: AbstractADType, AutoForwardDiff, AutoFiniteDiff, AutoPolyesterForwardDiff
using ArrayInterface: ArrayInterface
using DifferentiationInterface: DifferentiationInterface
using DifferentiationInterface: DifferentiationInterface, Constant
using FastClosures: @closure
using LinearAlgebra: LinearAlgebra, I, diagind
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem,
Expand Down Expand Up @@ -132,4 +132,72 @@ function check_termination(
return false, ReturnCode.Default, fx, x
end

restructure(y, x) = ArrayInterface.restructure(y, x)
restructure(::Number, x::Number) = x

safe_vec(x::AbstractArray) = vec(x)
safe_vec(x::Number) = x

function prepare_jacobian(prob, autodiff, _, x::Number)
if SciMLBase.has_jac(prob.f) || SciMLBase.has_vjp(prob.f) || SciMLBase.has_jvp(prob.f)
return nothing
end
return DI.prepare_derivative(prob.f, autodiff, x, Constant(prob.p))
end
function prepare_jacobian(prob, autodiff, fx, x)
if SciMLBase.has_jac(prob.f)
return nothing
end
if SciMLBase.isinplace(prob.f)
return DI.prepare_jacobian(prob.f, fx, autodiff, x, Constant(prob.p))
else
return DI.prepare_jacobian(prob.f, autodiff, x, Constant(prob.p))
end
end

function compute_jacobian!!(_, prob, autodiff, fx, x::Number, extras)
if extras === nothing
if SciMLBase.has_jac(prob.f)
return prob.f.jac(x, prob.p)
elseif SciMLBase.has_vjp(prob.f)
return prob.f.vjp(one(x), x, prob.p)
elseif SciMLBase.has_jvp(prob.f)
return prob.f.jvp(one(x), x, prob.p)
end
end
return DI.derivative(prob.f, extras, autodiff, x, Constant(prob.p))
end
function compute_jacobian!!(J, prob, autodiff, fx, x, extras)
if J === nothing
if extras === nothing
if SciMLBase.isinplace(prob.f)
J = similar(fx, length(fx), length(x))
prob.f.jac(J, x, prob.p)
return J
else
return prob.f.jac(x, prob.p)
end
end
if SciMLBase.isinplace(prob)
return DI.jacobian(prob.f, fx, extras, autodiff, x, Constant(prob.p))
else
return DI.jacobian(prob.f, extras, autodiff, x, Constant(prob.p))
end
end
if extras === nothing
if SciMLBase.isinplace(prob)
prob.jac(J, x, prob.p)
return J
else
return prob.jac(x, prob.p)
end
end
if SciMLBase.isinplace(prob)
DI.jacobian!(prob.f, J, fx, extras, autodiff, x, Constant(prob.p))
else
DI.jacobian!(prob.f, J, extras, autodiff, x, Constant(prob.p))
end
return J
end

end

0 comments on commit fc2d60c

Please sign in to comment.