From f20c9bcc974c7a4d353c2d46ea9fd41e790309cb Mon Sep 17 00:00:00 2001 From: Utkarsh Date: Thu, 26 Oct 2023 10:44:30 -0400 Subject: [PATCH] Move other algorithms to use termination conditions --- src/broyden.jl | 53 ++++++++++++++++++++++++++++++++++++++++--------- src/klement.jl | 51 ++++++++++++++++++++++++++++++++++++++--------- src/lbroyden.jl | 42 +++++++++++++++++++++++++++++++-------- src/raphson.jl | 3 +-- 4 files changed, 121 insertions(+), 28 deletions(-) diff --git a/src/broyden.jl b/src/broyden.jl index 6be29a77b..22a1a9cc8 100644 --- a/src/broyden.jl +++ b/src/broyden.jl @@ -31,6 +31,7 @@ end f alg u + u_prev du fu fu2 @@ -46,17 +47,21 @@ end internalnorm retcode::ReturnCode.T abstol + reltol reset_tolerance reset_check prob stats::NLStats lscache + termination_condition + tc_storage end get_fu(cache::GeneralBroydenCache) = cache.fu function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyden, args...; - alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM, + alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, + termination_condition = nothing, internalnorm = DEFAULT_NORM, kwargs...) where {uType, iip} @unpack f, u0, p = prob u = alias_u0 ? u0 : deepcopy(u0) @@ -65,15 +70,29 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(eltype(u))) : alg.reset_tolerance reset_check = x -> abs(x) ≤ reset_tolerance - return GeneralBroydenCache{iip}(f, alg, u, _mutable_zero(u), fu, zero(fu), + + abstol, reltol, termination_condition = _init_termination_elements(abstol, + reltol, + termination_condition, + eltype(u)) + + mode = DiffEqBase.get_termination_mode(termination_condition) + + storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() : + nothing + return GeneralBroydenCache{iip}(f, alg, u, zero(u), _mutable_zero(u), fu, zero(fu), zero(fu), p, J⁻¹, zero(_reshape(fu, 1, :)), _mutable_zero(u), false, 0, - alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reset_tolerance, + alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, + reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0), - init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip))) + init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), termination_condition, + storage) end function perform_step!(cache::GeneralBroydenCache{true}) - @unpack f, p, du, fu, fu2, dfu, u, J⁻¹, J⁻¹df, J⁻¹₂ = cache + @unpack f, p, du, fu, fu2, dfu, u, u_prev, J⁻¹, J⁻¹df, J⁻¹₂, tc_storage = cache + + termination_condition = cache.termination_condition(tc_storage) T = eltype(u) mul!(_vec(du), J⁻¹, -_vec(fu)) @@ -81,7 +100,8 @@ function perform_step!(cache::GeneralBroydenCache{true}) _axpy!(α, du, u) f(fu2, u, p) - cache.internalnorm(fu2) < cache.abstol && (cache.force_stop = true) + termination_condition(fu2, u, u_prev, cache.abstol, cache.reltol) && + (cache.force_stop = true) cache.stats.nf += 1 cache.force_stop && return nothing @@ -106,12 +126,16 @@ function perform_step!(cache::GeneralBroydenCache{true}) mul!(J⁻¹, _vec(du), J⁻¹₂, 1, 1) end fu .= fu2 + @. u_prev = u return nothing end function perform_step!(cache::GeneralBroydenCache{false}) - @unpack f, p = cache + @unpack f, p, tc_storage = cache + + termination_condition = cache.termination_condition(tc_storage) + T = eltype(cache.u) cache.du = _restructure(cache.du, cache.J⁻¹ * -_vec(cache.fu)) @@ -119,7 +143,8 @@ function perform_step!(cache::GeneralBroydenCache{false}) cache.u = cache.u .+ α * cache.du cache.fu2 = f(cache.u, p) - cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true) + termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) && + (cache.force_stop = true) cache.stats.nf += 1 cache.force_stop && return nothing @@ -142,12 +167,15 @@ function perform_step!(cache::GeneralBroydenCache{false}) cache.J⁻¹ = cache.J⁻¹ .+ _vec(cache.du) * cache.J⁻¹₂ end cache.fu = cache.fu2 + cache.u_prev = @. cache.u return nothing end function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = cache.p, - abstol = cache.abstol, maxiters = cache.maxiters) where {iip} + abstol = cache.abstol, reltol = cache.reltol, + termination_condition = cache.termination_condition, + maxiters = cache.maxiters) where {iip} cache.p = p if iip recursivecopy!(cache.u, u0) @@ -157,7 +185,14 @@ function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = ca cache.u = u0 cache.fu = cache.f(cache.u, p) end + termination_condition = _get_reinit_termination_condition(cache, + abstol, + reltol, + termination_condition) + cache.abstol = abstol + cache.reltol = reltol + cache.termination_condition = termination_condition cache.maxiters = maxiters cache.stats.nf = 1 cache.stats.nsteps = 1 diff --git a/src/klement.jl b/src/klement.jl index a16ed2873..32aec15d5 100644 --- a/src/klement.jl +++ b/src/klement.jl @@ -41,6 +41,7 @@ end f alg u + u_prev fu fu2 du @@ -65,7 +66,8 @@ end get_fu(cache::GeneralKlementCache) = cache.fu function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKlement, args...; - alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM, + alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, + termination_condition = nothing, internalnorm = DEFAULT_NORM, linsolve_kwargs = (;), kwargs...) where {uType, iip} @unpack f, u0, p = prob u = alias_u0 ? u0 : deepcopy(u0) @@ -84,16 +86,30 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKleme linsolve = __setup_linsolve(J, _vec(fu), _vec(du), p, alg) end - return GeneralKlementCache{iip}(f, alg, u, fu, zero(fu), du, p, linsolve, + abstol, reltol, termination_condition = _init_termination_elements(abstol, + reltol, + termination_condition, + eltype(u)) + + mode = DiffEqBase.get_termination_mode(termination_condition) + + storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() : + nothing + + return GeneralKlementCache{iip}(f, alg, u, zero(u), fu, zero(fu), du, p, linsolve, J, zero(J), zero(J), _vec(zero(fu)), _vec(zero(fu)), 0, false, - maxiters, internalnorm, ReturnCode.Default, abstol, prob, NLStats(1, 0, 0, 0, 0), - init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip))) + maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob, + NLStats(1, 0, 0, 0, 0), + init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), termination_condition, + storage) end function perform_step!(cache::GeneralKlementCache{true}) - @unpack u, fu, f, p, alg, J, linsolve, du = cache + @unpack u, u_prev, fu, f, p, alg, J, linsolve, du, tc_storage = cache T = eltype(J) + termination_condition = cache.termination_condition(tc_storage) + singular, fact_done = _try_factorize_and_check_singular!(linsolve, J) if singular @@ -118,7 +134,8 @@ function perform_step!(cache::GeneralKlementCache{true}) _axpy!(α, du, u) f(cache.fu2, u, p) - cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true) + termination_condition(cache.fu2, u, u_prev, cache.abstol, cache.reltol) && + (cache.force_stop = true) cache.stats.nf += 1 cache.stats.nsolve += 1 cache.stats.nfactors += 1 @@ -138,13 +155,17 @@ function perform_step!(cache::GeneralKlementCache{true}) mul!(cache.J_cache2, cache.J_cache, J) J .+= cache.J_cache2 + @. u_prev = u cache.fu .= cache.fu2 return nothing end function perform_step!(cache::GeneralKlementCache{false}) - @unpack fu, f, p, alg, J, linsolve = cache + @unpack fu, f, p, alg, J, linsolve, tc_storage = cache + + termination_condition = cache.termination_condition(tc_storage) + T = eltype(J) singular, fact_done = _try_factorize_and_check_singular!(linsolve, J) @@ -174,7 +195,10 @@ function perform_step!(cache::GeneralKlementCache{false}) cache.u = @. cache.u + α * cache.du # `u` might not support mutation cache.fu2 = f(cache.u, p) - cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true) + termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) && + (cache.force_stop = true) + + cache.u_prev = @. cache.u cache.stats.nf += 1 cache.stats.nsolve += 1 cache.stats.nfactors += 1 @@ -198,7 +222,9 @@ function perform_step!(cache::GeneralKlementCache{false}) end function SciMLBase.reinit!(cache::GeneralKlementCache{iip}, u0 = cache.u; p = cache.p, - abstol = cache.abstol, maxiters = cache.maxiters) where {iip} + abstol = cache.abstol, reltol = cache.reltol, + termination_condition = cache.termination_condition, + maxiters = cache.maxiters) where {iip} cache.p = p if iip recursivecopy!(cache.u, u0) @@ -208,7 +234,14 @@ function SciMLBase.reinit!(cache::GeneralKlementCache{iip}, u0 = cache.u; p = ca cache.u = u0 cache.fu = cache.f(cache.u, p) end + + termination_condition = _get_reinit_termination_condition(cache, + abstol, + reltol, + termination_condition) cache.abstol = abstol + cache.reltol = reltol + cache.termination_condition = termination_condition cache.maxiters = maxiters cache.stats.nf = 1 cache.stats.nsteps = 1 diff --git a/src/lbroyden.jl b/src/lbroyden.jl index d045d0b20..db4353b41 100644 --- a/src/lbroyden.jl +++ b/src/lbroyden.jl @@ -34,6 +34,7 @@ end f alg u + u_prev du fu fu2 @@ -53,17 +54,21 @@ end internalnorm retcode::ReturnCode.T abstol + reltol reset_tolerance reset_check prob stats::NLStats lscache + termination_condition + tc_storage end get_fu(cache::LimitedMemoryBroydenCache) = cache.fu function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LimitedMemoryBroyden, - args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM, + args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, + termination_condition = nothing, internalnorm = DEFAULT_NORM, kwargs...) where {uType, iip} @unpack f, u0, p = prob u = alias_u0 ? u0 : deepcopy(u0) @@ -80,23 +85,38 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LimitedMemory reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(eltype(u))) : alg.reset_tolerance reset_check = x -> abs(x) ≤ reset_tolerance - return LimitedMemoryBroydenCache{iip}(f, alg, u, du, fu, zero(fu), + + abstol, reltol, termination_condition = _init_termination_elements(abstol, + reltol, + termination_condition, + eltype(u)) + + mode = DiffEqBase.get_termination_mode(termination_condition) + + storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() : + nothing + + return LimitedMemoryBroydenCache{iip}(f, alg, u, zero(u), du, fu, zero(fu), zero(fu), p, U, Vᵀ, similar(u, threshold), similar(u, 1, threshold), zero(u), zero(u), false, 0, 0, alg.max_resets, maxiters, internalnorm, - ReturnCode.Default, abstol, reset_tolerance, reset_check, prob, + ReturnCode.Default, abstol, reltol, reset_tolerance, reset_check, prob, NLStats(1, 0, 0, 0, 0), - init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip))) + init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), termination_condition, + storage) end function perform_step!(cache::LimitedMemoryBroydenCache{true}) - @unpack f, p, du, u = cache + @unpack f, p, du, u, tc_storage = cache T = eltype(u) + termination_condition = cache.termination_condition(tc_storage) + α = perform_linesearch!(cache.lscache, u, du) _axpy!(α, du, u) f(cache.fu2, u, p) - cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true) + termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) && + (cache.force_stop = true) cache.stats.nf += 1 cache.force_stop && return nothing @@ -138,20 +158,25 @@ function perform_step!(cache::LimitedMemoryBroydenCache{true}) cache.iterations_since_reset += 1 end + cache.u_prev .= cache.u cache.fu .= cache.fu2 return nothing end function perform_step!(cache::LimitedMemoryBroydenCache{false}) - @unpack f, p = cache + @unpack f, p, tc_storage = cache + + termination_condition = cache.termination_condition(tc_storage) + T = eltype(cache.u) α = perform_linesearch!(cache.lscache, cache.u, cache.du) cache.u = cache.u .+ α * cache.du cache.fu2 = f(cache.u, p) - cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true) + termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) && + (cache.force_stop = true) cache.stats.nf += 1 cache.force_stop && return nothing @@ -194,6 +219,7 @@ function perform_step!(cache::LimitedMemoryBroydenCache{false}) cache.iterations_since_reset += 1 end + cache.u_prev = @. cache.u cache.fu = cache.fu2 return nothing diff --git a/src/raphson.jl b/src/raphson.jl index eef64ec7b..a34d860ce 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -135,9 +135,8 @@ function perform_step!(cache::NewtonRaphsonCache{true}) end function perform_step!(cache::NewtonRaphsonCache{false}) - @unpack u, u_prev, fu1, f, p, alg, linsolve = cache + @unpack u, u_prev, fu1, f, p, alg, linsolve, tc_storage = cache - tc_storage = cache.tc_storage termination_condition = cache.termination_condition(tc_storage) cache.J = jacobian!!(cache.J, cache)