From d63e00b88bf41473979c0d500664d35e04c0b5bb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 29 Oct 2023 22:42:29 -0400 Subject: [PATCH] Rework the Termination Condition API to be type stable --- src/termination_conditions.jl | 191 ++++++++++++++++++++++++++++++++++ 1 file changed, 191 insertions(+) diff --git a/src/termination_conditions.jl b/src/termination_conditions.jl index ce310a490..a83cb51d9 100644 --- a/src/termination_conditions.jl +++ b/src/termination_conditions.jl @@ -1,3 +1,194 @@ +@enumx NonlinearSafeTerminationReturnCode begin + Success + Default + PatienceTermination + ProtectiveTermination + Failure +end + +abstract type AbstractNonlinearTerminationMode end +abstract type AbstractSafeNonlinearTerminationMode <: AbstractNonlinearTerminationMode end +abstract type AbstractSafeBestNonlinearTerminationMode <: + AbstractSafeNonlinearTerminationMode end + +# TODO: Add a mode where the user can pass in custom termination criteria function +for mode in (:SteadyStateDiffEqTerminationMode, :SimpleNonlinearSolveTerminationMode, + :NormTerminationMode, :RelTerminationMode, :RelNormTerminationMode, :AbsTerminationMode, + :AbsNormTerminationMode) + @eval begin + struct $(mode) <: AbstractNonlinearTerminationMode end + end +end + +for mode in (:RelSafeTerminationMode, :AbsSafeTerminationMode) + @eval begin + Base.@kwdef struct $(mode){T1, T2, T3} <: AbstractSafeNonlinearTerminationMode + protective_threshold::T1 = 1000 + patience_steps::Int = 30 + patience_objective_multiplier::T2 = 3 + min_max_factor::T3 = 1.3 + end + end +end + +for mode in (:RelSafeBestTerminationMode, :AbsSafeBestTerminationMode) + @eval begin + Base.@kwdef struct $(mode){T1, T2, T3} <: AbstractSafeBestNonlinearTerminationMode + protective_threshold::T1 = 1000 + patience_steps::Int = 30 + patience_objective_multiplier::T2 = 3 + min_max_factor::T3 = 1.3 + end + end +end + +mutable struct NonlinearTerminationModeCache{uType, T, + M <: AbstractNonlinearTerminationMode, I, OT} + u::uType + retcode::NonlinearSafeTerminationReturnCode.T + abstol::T + reltol::T + best_objective_value::T + mode::M + initial_objective::I + objectives_trace::OT + nsteps::Int +end + +function __update_u!!(cache::NonlinearTerminationModeCache, u) + cache.u === nothing && return + if ArrayInterface.can_setindex(cache.u) + copyto!(cache.u, u) + else + cache.u = u + end +end + +__cvt_real(::Type{T}, ::Nothing) where {T} = nothing +__cvt_real(::Type{T}, x) where {T} = real(T(x)) + +_get_tolerance(η, ::Type{T}) where {T} = __cvt_real(T, η) +function _get_tolerance(::Nothing, ::Type{T}) where {T} + η = real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) + return _get_tolerance(η, T) +end + +function SciMLBase.init(u::AbstractArray{T}, mode::AbstractNonlinearTerminationMode; + abstol = nothing, reltol = nothing, kwargs...) where {T} + abstol = _get_tolerance(abstol, T) + reltol = _get_tolerance(reltol, T) + best_value = __cvt_real(T, Inf) + TT = typeof(abstol) + u_ = mode isa AbstractSafeBestNonlinearTerminationMode ? + (ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing + if mode isa AbstractSafeNonlinearTerminationMode + initial_objective = TT(0) + objectives_trace = Vector{TT}(undef, mode.patience_steps) + else + initial_objective = nothing + objectives_trace = nothing + end + return NonlinearTerminationModeCache{typeof(u_), TT, typeof(mode), + typeof(initial_objective), typeof(objectives_trace)}(u_, + NonlinearSafeTerminationReturnCode.Default, abstol, reltol, best_value, mode, + initial_objective, objectives_trace, 0) +end + +# This dispatch is needed based on how Terminating Callback works! +# This intentially drops the `abstol` and `reltol` arguments +function (cache::NonlinearTerminationModeCache)(integrator, _, _, min_t) + return cache(cache.mode, get_du(integrator), integrator.u, integrator.uprev) +end +(cache::NonlinearTerminationModeCache)(du, u, uprev) = cache(cache.mode, du, u, uprev) + +function (cache::NonlinearTerminationModeCache)(mode::AbstractNonlinearTerminationMode, du, + u, uprev) + return check_convergence(mode, du, u, uprev, cache.abstol, cache.reltol) +end + +function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTerminationMode, + du, u, uprev) + if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode + objective = NLSOLVE_DEFAULT_NORM(du) + criteria = cache.abstol + else + objective = NLSOLVE_DEFAULT_NORM(du) / + (NLSOLVE_DEFAULT_NORM(du .+ u) + eps(cache.abstol)) + criteria = cache.reltol + end + + # Check if best solution + if mode isa AbstractSafeBestNonlinearTerminationMode && + objective < cache.best_objective_value + cache.best_objective_value = objective + __update_u!!(cache, u) + end + + # Main Termination Condition + if objective ≤ criteria + cache.retcode = NonlinearSafeTerminationReturnCode.Success + return true + end + + # Terminate if there has been no improvement for the last `patience_steps` + cache.nsteps += 1 + cache.nsteps == 1 && (cache.initial_objective = objective) + cache.objectives_trace[mod1(cache.nsteps, length(cache.objectives_trace))] = objective + + if objective ≤ cache.mode.patience_objective_multiplier * criteria + if cache.nsteps ≥ cache.mode.patience_steps + if cache.nsteps < length(cache.objectives_trace) + min_obj, max_obj = extrema(@view(cache.objectives_trace[1:cache.nsteps])) + else + min_obj, max_obj = extrema(cache.objectives_trace) + end + if min_obj < cache.mode.min_max_factor * max_obj + cache.retcode = NonlinearSafeTerminationReturnCode.PatienceTermination + return true + end + end + end + + # Protective Break + if objective ≥ cache.initial_objective * cache.mode.protective_threshold * length(du) + cache.retcode = NonlinearSafeTerminationReturnCode.ProtectiveTermination + return true + end + + cache.retcode = NonlinearSafeTerminationReturnCode.Failure + return false +end + +function check_convergence(::SteadyStateDiffEqTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, + reltol) + return all((abs.(duₙ) .≤ abstol) .| (abs.(duₙ) .≤ reltol .* abs.(uₙ))) +end +function check_convergence(::SimpleNonlinearSolveTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, + reltol) + return all((abs.(duₙ) .≤ abstol) .| (abs.(duₙ) .≤ reltol .* abs.(uₙ))) || + isapprox(uₙ, uₙ₋₁; atol = abstol, rtol = reltol) +end +function check_convergence(::NormTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol) + du_norm = NLSOLVE_DEFAULT_NORM(duₙ) + return du_norm ≤ abstol || du_norm ≤ reltol * NLSOLVE_DEFAULT_NORM(duₙ .+ uₙ) +end +function check_convergence(::RelNormTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol) + return all(abs.(duₙ) .≤ reltol .* abs.(uₙ)) +end +function check_convergence(::Union{RelNormTerminationMode, RelSafeTerminationMode, + RelSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol) + return NLSOLVE_DEFAULT_NORM(duₙ) ≤ reltol * NLSOLVE_DEFAULT_NORM(duₙ .+ uₙ) +end +function check_convergence(::AbsTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol) + return all(abs.(duₙ) .≤ abstol) +end +function check_convergence(::Union{AbsNormTerminationMode, AbsSafeTerminationMode, + AbsSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol) + return NLSOLVE_DEFAULT_NORM(duₙ) ≤ abstol +end + +# NOTE: Deprecate the following API eventually. This API leads to quite a bit of type +# instability @enumx NLSolveSafeTerminationReturnCode begin Success PatienceTermination