From d0d3db4df6c8a44edaed6cb8ac2fea03a14ece5c Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Thu, 21 Sep 2023 15:09:02 -0400
Subject: [PATCH] Delegate common operations to an abstract type

---
Reviewer notes (text in this area is not part of the commit message):

* Adds `AbstractNonlinearSolveCache{iip}` and makes `LevenbergMarquardtCache`,
  `NewtonRaphsonCache` and `TrustRegionCache` subtypes of it. The three per-cache
  `isinplace` methods are replaced by a single method on the abstract type.
* The `solve!` loop that used to be defined for `LevenbergMarquardtCache` now lives on
  the abstract type.
* Changed during review: the generic `solve!` derives the return code from
  `cache.force_stop` instead of `cache.stats.nsteps == cache.maxiters`. The loop exits
  when `force_stop` is set or when `nsteps` reaches `maxiters`; the old test reported
  `MaxIters` when a step requested the stop on the last permitted iteration, and
  `Success` when `nsteps` was already above `maxiters` without any stop requested.
  The removed LevenbergMarquardt lines below are left untouched because they must
  match the pre-image file.
* NOTE(review): `SciMLBase.reinit!(::NewtonRaphsonCache)` and
  `concrete_jac(::NewtonRaphson)` are deleted here and nothing in this patch replaces
  them; presumably equivalents exist elsewhere in the package. Confirm before merging.
* NOTE(review): the line structure of this patch was rebuilt from the hunk counts. The
  indentation of wrapped context lines is a best-effort guess, and the commit id and
  the post-image blob id of src/NonlinearSolve.jl are carried over from the original
  submission. Verify with `git apply --check`.

 src/NonlinearSolve.jl | 22 ++++++++++++++++++++++
 src/levenberg.jl      | 21 ++-------------------
 src/raphson.jl        | 26 +-------------------------
 src/trustRegion.jl    |  5 ++---
 4 files changed, 27 insertions(+), 47 deletions(-)

diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl
index 615f96c03..8a6b60901 100644
--- a/src/NonlinearSolve.jl
+++ b/src/NonlinearSolve.jl
@@ -28,12 +28,34 @@ const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
 abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
 abstract type AbstractNewtonAlgorithm{CJ, AD} <: AbstractNonlinearSolveAlgorithm end
 
+# Supertype of all solver caches; `iip` is the flag returned by `isinplace`.
+abstract type AbstractNonlinearSolveCache{iip} end
+
+isinplace(::AbstractNonlinearSolveCache{iip}) where {iip} = iip
+
 function SciMLBase.__solve(prob::NonlinearProblem, alg::AbstractNonlinearSolveAlgorithm,
     args...; kwargs...)
     cache = init(prob, alg, args...; kwargs...)
     return solve!(cache)
 end
 
+# Generic iteration driver shared by every `AbstractNonlinearSolveCache`. Concrete caches
+# only need to provide `perform_step!` and the fields read below.
+function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
+    while !cache.force_stop && cache.stats.nsteps < cache.maxiters
+        perform_step!(cache)
+        cache.stats.nsteps += 1
+    end
+
+    # The loop exits either because a step set `force_stop` or because the iteration
+    # budget ran out. Decide on `force_stop` rather than on `nsteps == maxiters`, so a
+    # stop requested on the very last permitted step is not reported as `MaxIters`.
+    cache.retcode = cache.force_stop ? ReturnCode.Success : ReturnCode.MaxIters
+
+    return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu1;
+        cache.retcode, cache.stats)
+end
+
 include("utils.jl")
 include("linesearch.jl")
 include("raphson.jl")
diff --git a/src/levenberg.jl b/src/levenberg.jl
index 7264c127f..c4b6924d1 100644
--- a/src/levenberg.jl
+++ b/src/levenberg.jl
@@ -97,7 +97,8 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
         finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D)
 end
 
-@concrete mutable struct LevenbergMarquardtCache{iip, uType, jType, λType, lossType}
+@concrete mutable struct LevenbergMarquardtCache{iip, uType, jType, λType, lossType} <:
+                         AbstractNonlinearSolveCache{iip}
     f
     alg
     u::uType
@@ -138,8 +139,6 @@ end
     stats::NLStats
 end
 
-isinplace(::LevenbergMarquardtCache{iip}) where {iip} = iip
-
 function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LevenbergMarquardt,
     args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
     linsolve_kwargs = (;), kwargs...) where {uType, iip}
@@ -313,19 +312,3 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
     cache.λ_factor = cache.damping_increase_factor
     return nothing
 end
-
-function SciMLBase.solve!(cache::LevenbergMarquardtCache)
-    while !cache.force_stop && cache.stats.nsteps < cache.maxiters
-        perform_step!(cache)
-        cache.stats.nsteps += 1
-    end
-
-    if cache.stats.nsteps == cache.maxiters
-        cache.retcode = ReturnCode.MaxIters
-    else
-        cache.retcode = ReturnCode.Success
-    end
-
-    return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu1;
-        cache.retcode, cache.stats)
-end
diff --git a/src/raphson.jl b/src/raphson.jl
index 16fdad0c3..e06e89635 100644
--- a/src/raphson.jl
+++ b/src/raphson.jl
@@ -36,8 +36,6 @@ for large-scale and numerically-difficult nonlinear systems.
     linesearch
 end
 
-concrete_jac(::NewtonRaphson{CJ}) where {CJ} = CJ
-
 function NewtonRaphson(; concrete_jac = nothing, linsolve = nothing,
     linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...)
     ad = default_adargs_to_adtype(; adkwargs...)
@@ -45,7 +43,7 @@ function NewtonRaphson(; concrete_jac = nothing, linsolve = nothing,
     return NewtonRaphson{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch)
 end
 
-@concrete mutable struct NewtonRaphsonCache{iip}
+@concrete mutable struct NewtonRaphsonCache{iip} <: AbstractNonlinearSolveCache{iip}
     f
     alg
     u
@@ -67,8 +65,6 @@ end
     lscache
 end
 
-isinplace(::NewtonRaphsonCache{iip}) where {iip} = iip
-
 function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::NewtonRaphson, args...;
     alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
     linsolve_kwargs = (;), kwargs...) where {uType, iip}
@@ -146,23 +142,3 @@ function SciMLBase.solve!(cache::NewtonRaphsonCache)
     return SciMLBase.build_solution(cache.prob, cache.alg, cache.u, cache.fu1;
         cache.retcode, cache.stats)
 end
-
-function SciMLBase.reinit!(cache::NewtonRaphsonCache{iip}, u0 = cache.u; p = cache.p,
-    abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
-    cache.p = p
-    if iip
-        recursivecopy!(cache.u, u0)
-        cache.f(cache.fu1, cache.u, p)
-    else
-        # don't have alias_u0 but cache.u is never mutated for OOP problems so it doesn't matter
-        cache.u = u0
-        cache.fu1 = cache.f(cache.u, p)
-    end
-    cache.abstol = abstol
-    cache.maxiters = maxiters
-    cache.stats.nf = 1
-    cache.stats.nsteps = 1
-    cache.force_stop = false
-    cache.retcode = ReturnCode.Default
-    return cache
-end
diff --git a/src/trustRegion.jl b/src/trustRegion.jl
index 20577ef67..62af3279f 100644
--- a/src/trustRegion.jl
+++ b/src/trustRegion.jl
@@ -155,7 +155,8 @@ function TrustRegion(; concrete_jac = nothing, linsolve = nothing, precs = DEFAU
         expand_threshold, shrink_factor, expand_factor, max_shrink_times)
 end
 
-@concrete mutable struct TrustRegionCache{iip, trustType, floatType}
+@concrete mutable struct TrustRegionCache{iip, trustType, floatType} <:
+                         AbstractNonlinearSolveCache{iip}
     f
     alg
     u_prev
@@ -299,8 +300,6 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
         NLStats(1, 0, 0, 0, 0))
 end
 
-isinplace(::TrustRegionCache{iip}) where {iip} = iip
-
 function perform_step!(cache::TrustRegionCache{true})
     @unpack make_new_J, J, fu, f, u, p, u_tmp, alg, linsolve = cache
     if cache.make_new_J