From efb8c5ef324322ebde0b74fc7af7da9d116e99e0 Mon Sep 17 00:00:00 2001 From: Utkarsh Date: Mon, 25 Sep 2023 21:18:59 -0400 Subject: [PATCH] Update all algorithms to use termination condition --- src/NonlinearSolve.jl | 2 +- src/levenberg.jl | 58 +++++++++++++++++++++++++++++------- src/raphson.jl | 40 +++++++++++++++---------- src/trustRegion.jl | 68 +++++++++++++++++++++++++++++++++++-------- test/basictests.jl | 42 ++++++++++++++++++++++++++ 5 files changed, 170 insertions(+), 40 deletions(-) diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 3a38989d4..b0fb28fb1 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -26,7 +26,7 @@ const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences, ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode} abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end -abstract type AbstractNewtonAlgorithm{CJ, AD} <: AbstractNonlinearSolveAlgorithm end +abstract type AbstractNewtonAlgorithm{CJ, AD, TC} <: AbstractNonlinearSolveAlgorithm end abstract type AbstractNonlinearSolveCache{iip} end diff --git a/src/levenberg.jl b/src/levenberg.jl index f35f35cb2..460377245 100644 --- a/src/levenberg.jl +++ b/src/levenberg.jl @@ -74,7 +74,8 @@ numerically-difficult nonlinear systems. [this paper](https://arxiv.org/abs/1201.5885) to use a minimum value of the elements in `DᵀD` to prevent the damping from being too small. Defaults to `1e-8`. """ -@concrete struct LevenbergMarquardt{CJ, AD, T} <: AbstractNewtonAlgorithm{CJ, AD} +@concrete struct LevenbergMarquardt{CJ, AD, T, TC <: NLSolveTerminationCondition} <: + AbstractNewtonAlgorithm{CJ, AD, TC} ad::AD linsolve precs @@ -85,6 +86,7 @@ numerically-difficult nonlinear systems. 
α_geodesic::T b_uphill::T min_damping_D::T + termination_condition::TC end function set_ad(alg::LevenbergMarquardt{CJ}, ad) where {CJ} @@ -97,17 +99,22 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS, damping_initial::Real = 1.0, damping_increase_factor::Real = 2.0, damping_decrease_factor::Real = 3.0, finite_diff_step_geodesic::Real = 0.1, α_geodesic::Real = 0.75, b_uphill::Real = 1.0, min_damping_D::AbstractFloat = 1e-8, + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing), adkwargs...) ad = default_adargs_to_adtype(; adkwargs...) return LevenbergMarquardt{_unwrap_val(concrete_jac)}(ad, linsolve, precs, damping_initial, damping_increase_factor, damping_decrease_factor, - finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D) + finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D, + termination_condition) end @concrete mutable struct LevenbergMarquardtCache{iip} <: AbstractNonlinearSolveCache{iip} f alg u + u_prev fu1 fu2 du @@ -121,6 +128,7 @@ end internalnorm retcode::ReturnCode.T abstol + reltol prob DᵀD JᵀJ @@ -145,11 +153,13 @@ end Jv mat_tmp stats::NLStats + tc_storage end function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip}, - NonlinearLeastSquaresProblem{uType, iip}}, alg_::LevenbergMarquardt, - args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM, + NonlinearLeastSquaresProblem{uType, iip}}, alg_::LevenbergMarquardt, + args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, + internalnorm = DEFAULT_NORM, linsolve_kwargs = (;), kwargs...) 
where {uType, iip} alg = get_concrete_algorithm(alg_, prob) @unpack f, u0, p = prob @@ -184,15 +194,30 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip}, fu_tmp = zero(fu1) mat_tmp = zero(JᵀJ) - return LevenbergMarquardtCache{iip}(f, alg, u, fu1, fu2, du, p, uf, linsolve, J, - jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob, DᵀD, + tc = alg.termination_condition + mode = DiffEqBase.get_termination_mode(tc) + + atol = _get_tolerance(abstol, tc.abstol, eltype(u)) + rtol = _get_tolerance(reltol, tc.reltol, eltype(u)) + + storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() : + nothing + + return LevenbergMarquardtCache{iip}(f, alg, u, copy(u), fu1, fu2, du, p, uf, linsolve, + J, + jac_cache, false, maxiters, internalnorm, ReturnCode.Default, atol, rtol, prob, DᵀD, JᵀJ, λ, λ_factor, damping_increase_factor, damping_decrease_factor, h, α_geodesic, b_uphill, min_damping_D, v, a, tmp_vec, v_old, loss, δ, loss, make_new_J, fu_tmp, - zero(u), zero(fu1), mat_tmp, NLStats(1, 0, 0, 0, 0)) + zero(u), zero(fu1), mat_tmp, NLStats(1, 0, 0, 0, 0), storage) + end function perform_step!(cache::LevenbergMarquardtCache{true}) @unpack fu1, f, make_new_J = cache + + tc_storage = cache.tc_storage + termination_condition = cache.alg.termination_condition(tc_storage) + if iszero(fu1) cache.force_stop = true return nothing @@ -205,7 +230,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true}) cache.make_new_J = false cache.stats.njacs += 1 end - @unpack u, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache + @unpack u, u_prev, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache # Usual Levenberg-Marquardt step ("velocity"). # The following lines do: cache.v = -cache.mat_tmp \ cache.u_tmp @@ -246,7 +271,11 @@ function perform_step!(cache::LevenbergMarquardtCache{true}) if (1 - β)^b_uphill * loss ≤ loss_old # Accept step. 
cache.u .+= δ - if loss < cache.abstol + if termination_condition(cache.fu_tmp, + cache.u, + u_prev, + cache.abstol, + cache.reltol) cache.force_stop = true return nothing end @@ -258,6 +287,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true}) cache.make_new_J = true end end + @. u_prev = u cache.λ *= cache.λ_factor cache.λ_factor = cache.damping_increase_factor return nothing @@ -265,6 +295,10 @@ end function perform_step!(cache::LevenbergMarquardtCache{false}) @unpack fu1, f, make_new_J = cache + + tc_storage = cache.tc_storage + termination_condition = cache.alg.termination_condition(tc_storage) + if iszero(fu1) cache.force_stop = true return nothing @@ -281,7 +315,8 @@ function perform_step!(cache::LevenbergMarquardtCache{false}) cache.make_new_J = false cache.stats.njacs += 1 end - @unpack u, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache + + @unpack u, u_prev, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache cache.mat_tmp = JᵀJ + λ * DᵀD # Usual Levenberg-Marquardt step ("velocity"). @@ -322,7 +357,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false}) if (1 - β)^b_uphill * loss ≤ loss_old # Accept step. cache.u += δ - if loss < cache.abstol + if termination_condition(fu_new, cache.u, u_prev, cache.abstol, cache.reltol) cache.force_stop = true return nothing end @@ -334,6 +369,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false}) cache.make_new_J = true end end + cache.u_prev = @. cache.u cache.λ *= cache.λ_factor cache.λ_factor = cache.damping_increase_factor return nothing diff --git a/src/raphson.jl b/src/raphson.jl index a0bb2eb24..57547baca 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -30,7 +30,8 @@ for large-scale and numerically-difficult nonlinear systems. which means that no line search is performed. Algorithms from `LineSearches.jl` can be used here directly, and they will be converted to the correct `LineSearch`. 
""" -@concrete struct NewtonRaphson{CJ, AD, TC <: NLSolveTerminationCondition} <: AbstractNewtonAlgorithm{CJ, AD} +@concrete struct NewtonRaphson{CJ, AD, TC <: NLSolveTerminationCondition} <: + AbstractNewtonAlgorithm{CJ, AD, TC} ad::AD linsolve precs @@ -43,19 +44,24 @@ function set_ad(alg::NewtonRaphson{CJ}, ad) where {CJ} end function NewtonRaphson(; concrete_jac = nothing, linsolve = nothing, - linesearch = LineSearch(), precs = DEFAULT_PRECS, termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; - abstol = nothing, - reltol = nothing), adkwargs...) + linesearch = LineSearch(), precs = DEFAULT_PRECS, + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing), adkwargs...) ad = default_adargs_to_adtype(; adkwargs...) linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch) - return NewtonRaphson{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch, termination_condition) + return NewtonRaphson{_unwrap_val(concrete_jac)}(ad, + linsolve, + precs, + linesearch, + termination_condition) end @concrete mutable struct NewtonRaphsonCache{iip} <: AbstractNonlinearSolveCache{iip} f alg u - uprev + u_prev fu1 fu2 du @@ -77,7 +83,8 @@ end end function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphson, args...; - alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM, + alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, + internalnorm = DEFAULT_NORM, linsolve_kwargs = (;), kwargs...) 
where {uType, iip} alg = get_concrete_algorithm(alg_, prob) @unpack f, u0, p = prob @@ -86,7 +93,6 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphso uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip); linsolve_kwargs) - tc = alg.termination_condition mode = DiffEqBase.get_termination_mode(tc) @@ -98,11 +104,12 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphso return NewtonRaphsonCache{iip}(f, alg, u, copy(u), fu1, fu2, du, p, uf, linsolve, J, jac_cache, false, maxiters, internalnorm, ReturnCode.Default, atol, rtol, prob, - NLStats(1, 0, 0, 0, 0), LineSearchCache(alg.linesearch, f, u, p, fu1, Val(iip)), storage) + NLStats(1, 0, 0, 0, 0), LineSearchCache(alg.linesearch, f, u, p, fu1, Val(iip)), + storage) end function perform_step!(cache::NewtonRaphsonCache{true}) - @unpack u, uprev, fu1, f, p, alg, J, linsolve, du = cache + @unpack u, u_prev, fu1, f, p, alg, J, linsolve, du = cache jacobian!!(J, cache) tc_storage = cache.tc_storage @@ -118,9 +125,10 @@ function perform_step!(cache::NewtonRaphsonCache{true}) @. u = u - α * du f(cache.fu1, u, p) - termination_condition(cache.fu1, u, uprev, cache.abstol, cache.reltol) && (cache.force_stop = true) + termination_condition(cache.fu1, u, u_prev, cache.abstol, cache.reltol) && + (cache.force_stop = true) - @. uprev = u + @. u_prev = u cache.stats.nf += 1 cache.stats.njacs += 1 cache.stats.nsolve += 1 @@ -129,12 +137,11 @@ function perform_step!(cache::NewtonRaphsonCache{true}) end function perform_step!(cache::NewtonRaphsonCache{false}) - @unpack u, uprev, fu1, f, p, alg, linsolve = cache + @unpack u, u_prev, fu1, f, p, alg, linsolve = cache tc_storage = cache.tc_storage termination_condition = cache.alg.termination_condition(tc_storage) - cache.J = jacobian!!(cache.J, cache) # u = u - J \ fu if linsolve === nothing @@ -150,9 +157,10 @@ function perform_step!(cache::NewtonRaphsonCache{false}) cache.u = @. 
u - α * cache.du # `u` might not support mutation cache.fu1 = f(cache.u, p) - termination_condition(cache.fu1, cache.u, uprev, cache.abstol, cache.reltol) && (cache.force_stop = true) + termination_condition(cache.fu1, cache.u, u_prev, cache.abstol, cache.reltol) && + (cache.force_stop = true) - cache.uprev = cache.u + cache.u_prev = @. cache.u cache.stats.nf += 1 cache.stats.njacs += 1 cache.stats.nsolve += 1 diff --git a/src/trustRegion.jl b/src/trustRegion.jl index 769f5e75c..bfdc2a557 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -148,7 +148,8 @@ for large-scale and numerically-difficult nonlinear systems. `linsolve` and `precs` are used exclusively for the inplace version of the algorithm. Support for the OOP version is planned! """ -@concrete struct TrustRegion{CJ, AD, MTR} <: AbstractNewtonAlgorithm{CJ, AD} +@concrete struct TrustRegion{CJ, AD, MTR, TC <: NLSolveTerminationCondition} <: + AbstractNewtonAlgorithm{CJ, AD, TC} ad::AD linsolve precs @@ -161,6 +162,7 @@ for large-scale and numerically-difficult nonlinear systems. shrink_factor::MTR expand_factor::MTR max_shrink_times::Int + termination_condition::TC end function set_ad(alg::TrustRegion{CJ}, ad) where {CJ} @@ -175,11 +177,15 @@ function TrustRegion(; concrete_jac = nothing, linsolve = nothing, precs = DEFAU max_trust_radius::Real = 0 // 1, initial_trust_radius::Real = 0 // 1, step_threshold::Real = 1 // 10000, shrink_threshold::Real = 1 // 4, expand_threshold::Real = 3 // 4, shrink_factor::Real = 1 // 4, - expand_factor::Real = 2 // 1, max_shrink_times::Int = 32, adkwargs...) + expand_factor::Real = 2 // 1, max_shrink_times::Int = 32, + termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing), adkwargs...) ad = default_adargs_to_adtype(; adkwargs...) 
return TrustRegion{_unwrap_val(concrete_jac)}(ad, linsolve, precs, radius_update_scheme, max_trust_radius, initial_trust_radius, step_threshold, shrink_threshold, - expand_threshold, shrink_factor, expand_factor, max_shrink_times) + expand_threshold, shrink_factor, expand_factor, max_shrink_times, + termination_condition) end @concrete mutable struct TrustRegionCache{iip, trustType, floatType} <: @@ -201,6 +207,7 @@ end internalnorm retcode::ReturnCode.T abstol + reltol prob radius_update_scheme::RadiusUpdateSchemes.T trust_r::trustType @@ -228,10 +235,12 @@ end p4::floatType ϵ::floatType stats::NLStats + tc_storage end function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, args...; - alias_u0 = false, maxiters = 1000, abstol = 1e-8, internalnorm = DEFAULT_NORM, + alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, + internalnorm = DEFAULT_NORM, linsolve_kwargs = (;), kwargs...) where {uType, iip} alg = get_concrete_algorithm(alg_, prob) @unpack f, u0, p = prob @@ -333,13 +342,22 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, initial_trust_radius = convert(trustType, 1.0) end + tc = alg.termination_condition + mode = DiffEqBase.get_termination_mode(tc) + + atol = _get_tolerance(abstol, tc.abstol, eltype(u)) + rtol = _get_tolerance(reltol, tc.reltol, eltype(u)) + + storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? 
NLSolveSafeTerminationResult() : + nothing + return TrustRegionCache{iip}(f, alg, u_prev, u, fu_prev, fu1, fu2, p, uf, linsolve, J, - jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob, + jac_cache, false, maxiters, internalnorm, ReturnCode.Default, atol, rtol, prob, radius_update_scheme, initial_trust_radius, max_trust_radius, step_threshold, shrink_threshold, expand_threshold, shrink_factor, expand_factor, loss, loss_new, H, g, shrink_counter, du, u_tmp, u_gauss_newton, u_cauchy, fu_new, make_new_J, r, p1, p2, p3, p4, ϵ, - NLStats(1, 0, 0, 0, 0)) + NLStats(1, 0, 0, 0, 0), storage) end function perform_step!(cache::TrustRegionCache{true}) @@ -416,6 +434,10 @@ end function trust_region_step!(cache::TrustRegionCache) @unpack fu_new, du, g, H, loss, max_trust_r, radius_update_scheme = cache + + tc_storage = cache.tc_storage + termination_condition = cache.alg.termination_condition(tc_storage) + cache.loss_new = get_loss(fu_new) # Compute the ratio of the actual reduction to the predicted reduction. 
@@ -444,8 +466,11 @@ function trust_region_step!(cache::TrustRegionCache) # No need to make a new J, no step was taken, so we try again with a smaller trust_r cache.make_new_J = false end - - if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol + if iszero(cache.fu) || termination_condition(cache.fu, + cache.u, + cache.u_prev, + cache.abstol, + cache.reltol) cache.force_stop = true end @@ -513,7 +538,12 @@ function trust_region_step!(cache::TrustRegionCache) cache.trust_r = rfunc(r, shrink_threshold, p1, p3, p4, p2) * cache.internalnorm(du) - if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || + if iszero(cache.fu) || + termination_condition(cache.fu, + cache.u, + cache.u_prev, + cache.abstol, + cache.reltol) || cache.internalnorm(g) < cache.ϵ cache.force_stop = true end @@ -538,7 +568,12 @@ function trust_region_step!(cache::TrustRegionCache) @unpack p1 = cache cache.trust_r = p1 * cache.internalnorm(jvp!(cache)) - if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || + if iszero(cache.fu) || + termination_condition(cache.fu, + cache.u, + cache.u_prev, + cache.abstol, + cache.reltol) || cache.internalnorm(g) < cache.ϵ cache.force_stop = true end @@ -562,7 +597,12 @@ function trust_region_step!(cache::TrustRegionCache) @unpack p1 = cache cache.trust_r = p1 * (cache.internalnorm(cache.fu)^0.99) - if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol || + if iszero(cache.fu) || + termination_condition(cache.fu, + cache.u, + cache.u_prev, + cache.abstol, + cache.reltol) || cache.internalnorm(g) < cache.ϵ cache.force_stop = true end @@ -580,7 +620,11 @@ function trust_region_step!(cache::TrustRegionCache) cache.trust_r *= cache.p2 cache.shrink_counter += 1 end - if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol + if iszero(cache.fu) || termination_condition(cache.fu, + cache.u, + cache.u_prev, + cache.abstol, + cache.reltol) cache.force_stop = true end end diff --git a/test/basictests.jl 
b/test/basictests.jl index 06cfc103d..250a2e8b4 100644 --- a/test/basictests.jl +++ b/test/basictests.jl @@ -123,6 +123,20 @@ end probN = NonlinearProblem(quadratic_f, u0, 2.0) @test all(solve(probN, NewtonRaphson(; autodiff)).u .≈ sqrt(2.0)) end + + @testset "Termination condition: $(mode) u0: $(_nameof(u0))" for mode in instances(NLSolveTerminationMode.T), + u0 in (1.0, [1.0, 1.0]) + + if mode ∈ + (NLSolveTerminationMode.SteadyStateDefault, NLSolveTerminationMode.RelSafeBest, + NLSolveTerminationMode.AbsSafeBest) + continue + end + termination_condition = NLSolveTerminationCondition(mode; abstol = nothing, + reltol = nothing) + probN = NonlinearProblem(quadratic_f, u0, 2.0) + @test all(solve(probN, NewtonRaphson(; termination_condition)).u .≈ sqrt(2.0)) + end end # --- TrustRegion tests --- @@ -281,6 +295,20 @@ end @test sol_iip.u ≈ sol_oop.u end end + + @testset "Termination condition: $(mode) u0: $(_nameof(u0))" for mode in instances(NLSolveTerminationMode.T), + u0 in (1.0, [1.0, 1.0]) + + if mode ∈ + (NLSolveTerminationMode.SteadyStateDefault, NLSolveTerminationMode.RelSafeBest, + NLSolveTerminationMode.AbsSafeBest) + continue + end + termination_condition = NLSolveTerminationCondition(mode; abstol = nothing, + reltol = nothing) + probN = NonlinearProblem(quadratic_f, u0, 2.0) + @test all(solve(probN, TrustRegion(; termination_condition)).u .≈ sqrt(2.0)) + end end # --- LevenbergMarquardt tests --- @@ -390,6 +418,20 @@ end @test all(abs.(quadratic_f(sol.u, 2.0)) .< 1e-10) end end + + @testset "Termination condition: $(mode) u0: $(_nameof(u0))" for mode in instances(NLSolveTerminationMode.T), + u0 in (1.0, [1.0, 1.0]) + + if mode ∈ + (NLSolveTerminationMode.SteadyStateDefault, NLSolveTerminationMode.RelSafeBest, + NLSolveTerminationMode.AbsSafeBest) + continue + end + termination_condition = NLSolveTerminationCondition(mode; abstol = nothing, + reltol = nothing) + probN = NonlinearProblem(quadratic_f, u0, 2.0) + @test all(solve(probN, LevenbergMarquardt(; 
termination_condition)).u .≈ sqrt(2.0)) + end end # --- DFSane tests ---