From fc2d60c2e3f49b904a5ec5fb877b10ff9ecd4b02 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 4 Oct 2024 12:51:43 -0400 Subject: [PATCH] feat: SimpleNewtonRaphson --- docs/Project.toml | 2 +- lib/NonlinearSolveBase/src/autodiff.jl | 17 +++-- lib/SimpleNonlinearSolve/Project.toml | 7 +- .../src/SimpleNonlinearSolve.jl | 10 ++- lib/SimpleNonlinearSolve/src/raphson.jl | 62 ++++++++++++++++ lib/SimpleNonlinearSolve/src/utils.jl | 70 ++++++++++++++++++- 6 files changed, 157 insertions(+), 11 deletions(-) create mode 100644 lib/SimpleNonlinearSolve/src/raphson.jl diff --git a/docs/Project.toml b/docs/Project.toml index ab35d42ae..4ad265246 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -30,7 +30,7 @@ AlgebraicMultigrid = "0.5, 0.6" ArrayInterface = "6, 7" BenchmarkTools = "1" DiffEqBase = "6.136" -DifferentiationInterface = "0.6" +DifferentiationInterface = "0.6.1" Documenter = "1" DocumenterCitations = "1" DocumenterInterLinks = "1.0.0" diff --git a/lib/NonlinearSolveBase/src/autodiff.jl b/lib/NonlinearSolveBase/src/autodiff.jl index f81ce7039..4c057d25e 100644 --- a/lib/NonlinearSolveBase/src/autodiff.jl +++ b/lib/NonlinearSolveBase/src/autodiff.jl @@ -3,21 +3,21 @@ # Ordering is important here. We want to select the first one that is compatible with the # problem. -const ReverseADs = [ +const ReverseADs = ( ADTypes.AutoEnzyme(; mode = EnzymeCore.Reverse), ADTypes.AutoZygote(), ADTypes.AutoTracker(), ADTypes.AutoReverseDiff(; compile = true), ADTypes.AutoReverseDiff(), ADTypes.AutoFiniteDiff() -] +) -const ForwardADs = [ +const ForwardADs = ( ADTypes.AutoEnzyme(; mode = EnzymeCore.Forward), ADTypes.AutoPolyesterForwardDiff(), ADTypes.AutoForwardDiff(), ADTypes.AutoFiniteDiff() -] +) # TODO: Handle Sparsity @@ -28,7 +28,8 @@ function select_forward_mode_autodiff( end if incompatible_backend_and_problem(prob, ad) adₙ = select_forward_mode_autodiff(prob, nothing; warn_check_mode) - @warn "The chosen AD backend `$(ad)` does not support the chosen problem. 
After \ + @warn "The chosen AD backend `$(ad)` does not support the chosen problem. This \ + could be because the backend package for the choosen AD isn't loaded. After \ running autodiff selection detected `$(adₙ)` as a potential forward mode \ backend." return adₙ @@ -57,7 +58,8 @@ function select_reverse_mode_autodiff( end if incompatible_backend_and_problem(prob, ad) adₙ = select_reverse_mode_autodiff(prob, nothing; warn_check_mode) - @warn "The chosen AD backend `$(ad)` does not support the chosen problem. After \ + @warn "The chosen AD backend `$(ad)` does not support the chosen problem. This \ + could be because the backend package for the choosen AD isn't loaded. After \ running autodiff selection detected `$(adₙ)` as a potential reverse mode \ backend." return adₙ @@ -77,7 +79,8 @@ end function select_jacobian_autodiff(prob::AbstractNonlinearProblem, ad::AbstractADType) if incompatible_backend_and_problem(prob, ad) adₙ = select_jacobian_autodiff(prob, nothing) - @warn "The chosen AD backend `$(ad)` does not support the chosen problem. After \ + @warn "The chosen AD backend `$(ad)` does not support the chosen problem. This \ + could be because the backend package for the choosen AD isn't loaded. After \ running autodiff selection detected `$(adₙ)` as a potential jacobian \ backend." 
return adₙ diff --git a/lib/SimpleNonlinearSolve/Project.toml b/lib/SimpleNonlinearSolve/Project.toml index 7ab2416be..e21eb1c5f 100644 --- a/lib/SimpleNonlinearSolve/Project.toml +++ b/lib/SimpleNonlinearSolve/Project.toml @@ -8,6 +8,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" BracketingNonlinearSolve = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" +ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" @@ -38,11 +39,13 @@ ArrayInterface = "7.16" BracketingNonlinearSolve = "1" ChainRulesCore = "1.24" CommonSolve = "0.2.4" +ConcreteStructs = "0.2.3" DiffEqBase = "6.155" -DifferentiationInterface = "0.5.17" +DifferentiationInterface = "0.6.1" FastClosures = "0.3.2" FiniteDiff = "2.24.0" ForwardDiff = "0.10.36" +InteractiveUtils = "<0.0.1, 1" LinearAlgebra = "1.10" MaybeInplace = "0.1.4" NonlinearSolveBase = "1" @@ -51,6 +54,8 @@ Reexport = "1.2" ReverseDiff = "1.15" SciMLBase = "2.50" StaticArraysCore = "1.4.3" +Test = "1.10" +TestItemRunner = "1" Tracker = "0.2.35" julia = "1.10" diff --git a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl index 0bb65181a..1baa1b50c 100644 --- a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl +++ b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl @@ -1,6 +1,7 @@ module SimpleNonlinearSolve using CommonSolve: CommonSolve, solve +using ConcreteStructs: @concrete using FastClosures: @closure using MaybeInplace: @bb using PrecompileTools: @compile_workload, @setup_workload @@ -27,6 +28,7 @@ is_extension_loaded(::Val) = false include("utils.jl") include("klement.jl") +include("raphson.jl") # By Pass the highlevel checks for NonlinearProblem for Simple Algorithms function 
CommonSolve.solve(prob::NonlinearProblem, @@ -69,7 +71,10 @@ function solve_adjoint_internal end prob_iip = NonlinearProblem{true}((du, u, p) -> du .= u .* u .- p, ones(T, 3), T(2)) prob_oop = NonlinearProblem{false}((u, p) -> u .* u .- p, ones(T, 3), T(2)) - algs = [SimpleKlement()] + algs = [ + SimpleKlement(), + SimpleNewtonRaphson() + ] algs_no_iip = [] @compile_workload begin @@ -87,4 +92,7 @@ export AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff export Alefeld, Bisection, Brent, Falsi, ITP, Ridder +export SimpleKlement +export SimpleGaussNewton, SimpleNewtonRaphson + end diff --git a/lib/SimpleNonlinearSolve/src/raphson.jl b/lib/SimpleNonlinearSolve/src/raphson.jl new file mode 100644 index 000000000..2af3a825a --- /dev/null +++ b/lib/SimpleNonlinearSolve/src/raphson.jl @@ -0,0 +1,62 @@ +""" + SimpleNewtonRaphson(autodiff) + SimpleNewtonRaphson(; autodiff = nothing) + +A low-overhead implementation of Newton-Raphson. This method is non-allocating on scalar +and static array problems. + +!!! note + + As part of the decreased overhead, this method omits some of the higher level error + catching of the other methods. Thus, to see better error messages, use one of the other + methods like `NewtonRaphson`. + +### Keyword Arguments + + - `autodiff`: determines the backend used for the Jacobian. Defaults to `nothing` (i.e. + automatic backend selection). Valid choices include jacobian backends from + `DifferentiationInterface.jl`. +""" +@kwdef @concrete struct SimpleNewtonRaphson <: AbstractSimpleNonlinearSolveAlgorithm + autodiff = nothing +end + +const SimpleGaussNewton = SimpleNewtonRaphson + +function SciMLBase.__solve( + prob::ImmutableNonlinearProblem, alg::SimpleNewtonRaphson, args...; + abstol = nothing, reltol = nothing, maxiters = 1000, + alias_u0 = false, termination_condition = nothing, kwargs...) 
+    x = Utils.maybe_unaliased(prob.u0, alias_u0)
+    fx = Utils.get_fx(prob, x)
+    fx = Utils.eval_f(prob, fx, x)
+
+    iszero(fx) &&
+        return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
+
+    abstol, reltol, tc_cache = NonlinearSolveBase.init_termination_cache(
+        prob, abstol, reltol, fx, x, termination_condition, Val(:simple))
+
+    autodiff = SciMLBase.has_jac(prob.f) ? alg.autodiff :
+               NonlinearSolveBase.select_jacobian_autodiff(prob, alg.autodiff)
+
+    @bb xo = similar(x)
+    fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ? similar(fx) :
+               fx  # not `nothing`: the in-place user-jac path sizes `J` via `similar(fx, ...)`
+    jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x)
+    J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache)
+
+    for _ in 1:maxiters
+        @bb copyto!(xo, x)
+        δx = Utils.restructure(x, J \ Utils.safe_vec(fx))
+        @bb x .-= δx
+
+        solved, retcode, fx_sol, x_sol = Utils.check_termination(tc_cache, fx, x, xo, prob)
+        solved && return SciMLBase.build_solution(prob, alg, x_sol, fx_sol; retcode)
+
+        fx = Utils.eval_f(prob, fx, x)
+        J = Utils.compute_jacobian!!(J, prob, autodiff, fx_cache, x, jac_cache)
+    end
+
+    return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
+end
diff --git a/lib/SimpleNonlinearSolve/src/utils.jl b/lib/SimpleNonlinearSolve/src/utils.jl
index 16cf5142d..012fc277a 100644
--- a/lib/SimpleNonlinearSolve/src/utils.jl
+++ b/lib/SimpleNonlinearSolve/src/utils.jl
@@ -2,7 +2,7 @@ module Utils
 
 using ADTypes: AbstractADType, AutoForwardDiff, AutoFiniteDiff, AutoPolyesterForwardDiff
 using ArrayInterface: ArrayInterface
-using DifferentiationInterface: DifferentiationInterface
+using DifferentiationInterface: DifferentiationInterface, Constant
 using FastClosures: @closure
 using LinearAlgebra: LinearAlgebra, I, diagind
 using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem,
@@ -132,4 +132,72 @@ function check_termination(
     return false, ReturnCode.Default, fx, x
 end
 
+restructure(y, x) = 
ArrayInterface.restructure(y, x)
+restructure(::Number, x::Number) = x
+
+safe_vec(x::AbstractArray) = vec(x)
+safe_vec(x::Number) = x
+
+function prepare_jacobian(prob, autodiff, _, x::Number)  # scalar `x`: derivative prep
+    if SciMLBase.has_jac(prob.f) || SciMLBase.has_vjp(prob.f) || SciMLBase.has_jvp(prob.f)
+        return nothing
+    end
+    return DI.prepare_derivative(prob.f, autodiff, x, Constant(prob.p))
+end
+function prepare_jacobian(prob, autodiff, fx, x)  # returns `nothing` when `prob.f` has a jac
+    if SciMLBase.has_jac(prob.f)
+        return nothing
+    end
+    if SciMLBase.isinplace(prob.f)
+        return DI.prepare_jacobian(prob.f, fx, autodiff, x, Constant(prob.p))
+    else
+        return DI.prepare_jacobian(prob.f, autodiff, x, Constant(prob.p))
+    end
+end
+
+function compute_jacobian!!(_, prob, autodiff, fx, x::Number, extras)  # scalar `x`
+    if extras === nothing
+        if SciMLBase.has_jac(prob.f)
+            return prob.f.jac(x, prob.p)
+        elseif SciMLBase.has_vjp(prob.f)
+            return prob.f.vjp(one(x), x, prob.p)
+        elseif SciMLBase.has_jvp(prob.f)
+            return prob.f.jvp(one(x), x, prob.p)
+        end
+    end
+    return DI.derivative(prob.f, extras, autodiff, x, Constant(prob.p))
+end
+function compute_jacobian!!(J, prob, autodiff, fx, x, extras)  # `J === nothing`: first call
+    if J === nothing
+        if extras === nothing
+            if SciMLBase.isinplace(prob.f)
+                J = similar(fx, length(fx), length(x))
+                prob.f.jac(J, x, prob.p)
+                return J
+            else
+                return prob.f.jac(x, prob.p)
+            end
+        end
+        if SciMLBase.isinplace(prob)
+            return DI.jacobian(prob.f, fx, extras, autodiff, x, Constant(prob.p))
+        else
+            return DI.jacobian(prob.f, extras, autodiff, x, Constant(prob.p))
+        end
+    end
+    if extras === nothing
+        if SciMLBase.isinplace(prob)
+            prob.f.jac(J, x, prob.p)  # user Jacobian lives on `prob.f`, as in the branch above
+            return J
+        else
+            return prob.f.jac(x, prob.p)
+        end
+    end
+    if SciMLBase.isinplace(prob)
+        DI.jacobian!(prob.f, fx, J, extras, autodiff, x, Constant(prob.p))  # (f!, y, jac, ...)
+    else
+        DI.jacobian!(prob.f, J, extras, autodiff, x, Constant(prob.p))
+    end
+    return J
+end
+
 end