Skip to content

Commit

Permalink
Docs
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 30, 2023
1 parent 26b99eb commit 8fccdcc
Showing 1 changed file with 188 additions and 42 deletions.
230 changes: 188 additions & 42 deletions src/termination_conditions.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,39 @@
"""
NonlinearSafeTerminationReturnCode
Return Codes for the safe nonlinear termination conditions.
"""
@enumx NonlinearSafeTerminationReturnCode begin
"""
NonlinearSafeTerminationReturnCode.Success
Termination Condition was satisfied!
"""
Success
"""
NonlinearSafeTerminationReturnCode.Default
Default Return Code. Used for type stability and conveys no additional information!
"""
Default
"""
NonlinearSafeTerminationReturnCode.PatienceTermination
Terminate if there has been no improvement for the last `patience_steps`.
"""
PatienceTermination
"""
NonlinearSafeTerminationReturnCode.ProtectiveTermination
Terminate if the objective value increased by this factor wrt initial objective or the
value diverged.
"""
ProtectiveTermination
"""
NonlinearSafeTerminationReturnCode.Failure
Termination Condition was not satisfied!
"""
Failure
end

Expand All @@ -12,34 +43,149 @@ abstract type AbstractSafeBestNonlinearTerminationMode <:
AbstractSafeNonlinearTerminationMode end

# TODO: Add a mode where the user can pass in custom termination criteria function
for mode in (:SteadyStateDiffEqTerminationMode, :SimpleNonlinearSolveTerminationMode,
:NormTerminationMode, :RelTerminationMode, :RelNormTerminationMode, :AbsTerminationMode,
:AbsNormTerminationMode)
@eval begin
struct $(mode) <: AbstractNonlinearTerminationMode end
end

"""
SteadyStateDiffEqTerminationMode <: AbstractNonlinearTerminationMode
Check if all values of the derivative is close to zero wrt both relative and absolute
tolerance.
The default used in SteadyStateDiffEq.jl! Not recommended for large problems, since the
convergence criteria is very strict and never reliably satisfied for most problems.
"""
struct SteadyStateDiffEqTerminationMode <: AbstractNonlinearTerminationMode end

"""
SimpleNonlinearSolveTerminationMode <: AbstractNonlinearTerminationMode
Check if all values of the derivative is close to zero wrt both relative and absolute
tolerance. Or check that the value of the current and previous state is within the specified
tolerances.
The default used in SimpleNonlinearSolve.jl! Not recommended for large problems, since the
convergence criteria is very strict and never reliably satisfied for most problems.
"""
struct SimpleNonlinearSolveTerminationMode <: AbstractNonlinearTerminationMode end

@doc doc"""
NormTerminationMode <: AbstractNonlinearTerminationMode
Terminates if
``\| \frac{\partial u}{\partial t} \| \leq reltol \times \| \frac{\partial u}{\partial t} + u \|``
or ``\| \frac{\partial u}{\partial t} \| \leq abstol``
"""
struct NormTerminationMode <: AbstractNonlinearTerminationMode end

@doc doc"""
RelTerminationMode <: AbstractNonlinearTerminationMode
Terminates if
``all \left(| \frac{\partial u}{\partial t} | \leq reltol \times | u | \right)``.
"""
struct RelTerminationMode <: AbstractNonlinearTerminationMode end

@doc doc"""
RelNormTerminationMode <: AbstractNonlinearTerminationMode
Terminates if
``\| \frac{\partial u}{\partial t} \| \leq reltol \times \| \frac{\partial u}{\partial t} + u \|``
"""
struct RelNormTerminationMode <: AbstractNonlinearTerminationMode end

@doc doc"""
AbsTerminationMode <: AbstractNonlinearTerminationMode
Terminates if ``all \left( | \frac{\partial u}{\partial t} | \leq abstol \right)``.
"""
struct AbsTerminationMode <: AbstractNonlinearTerminationMode end

@doc doc"""
AbsNormTerminationMode <: AbstractNonlinearTerminationMode
Terminates if ``\| \frac{\partial u}{\partial t} \| \leq abstol``.
"""
struct AbsNormTerminationMode <: AbstractNonlinearTerminationMode end

@doc doc"""
RelSafeTerminationMode <: AbstractSafeNonlinearTerminationMode
Essentially [`RelNormTerminationMode`](@ref) + terminate if there has been no improvement
for the last `patience_steps` + terminate if the solution blows up (diverges).
## Constructor
```julia
RelSafeTerminationMode(; protective_threshold = 1e3, patience_steps = 100,
patience_objective_multiplier = 3, min_max_factor = 1.3)
```
"""
Base.@kwdef struct RelSafeTerminationMode{T1, T2, T3} <:
AbstractSafeNonlinearTerminationMode
protective_threshold::T1 = 1000
patience_steps::Int = 100
patience_objective_multiplier::T2 = 3
min_max_factor::T3 = 1.3
end

for mode in (:RelSafeTerminationMode, :AbsSafeTerminationMode)
@eval begin
Base.@kwdef struct $(mode){T1, T2, T3} <: AbstractSafeNonlinearTerminationMode
protective_threshold::T1 = 1000
patience_steps::Int = 30
patience_objective_multiplier::T2 = 3
min_max_factor::T3 = 1.3
end
end
@doc doc"""
AbsSafeTerminationMode <: AbstractSafeNonlinearTerminationMode
Essentially [`AbsNormTerminationMode`](@ref) + terminate if there has been no improvement
for the last `patience_steps` + terminate if the solution blows up (diverges).
## Constructor
```julia
AbsSafeTerminationMode(; protective_threshold = 1e3, patience_steps = 100,
patience_objective_multiplier = 3, min_max_factor = 1.3)
```
"""
Base.@kwdef struct AbsSafeTerminationMode{T1, T2, T3} <:
AbstractSafeNonlinearTerminationMode
protective_threshold::T1 = 1000
patience_steps::Int = 100
patience_objective_multiplier::T2 = 3
min_max_factor::T3 = 1.3
end

for mode in (:RelSafeBestTerminationMode, :AbsSafeBestTerminationMode)
@eval begin
Base.@kwdef struct $(mode){T1, T2, T3} <: AbstractSafeBestNonlinearTerminationMode
protective_threshold::T1 = 1000
patience_steps::Int = 100
patience_objective_multiplier::T2 = 3
min_max_factor::T3 = 1.3
end
end
@doc doc"""
RelSafeBestTerminationMode <: AbstractSafeBestNonlinearTerminationMode
Essentially [`RelSafeTerminationMode`](@ref), but caches the best solution found so far.
## Constructor
```julia
RelSafeBestTerminationMode(; protective_threshold = 1e3, patience_steps = 100,
patience_objective_multiplier = 3, min_max_factor = 1.3)
```
"""
Base.@kwdef struct RelSafeBestTerminationMode{T1, T2, T3} <:
AbstractSafeNonlinearTerminationMode
protective_threshold::T1 = 1000
patience_steps::Int = 100
patience_objective_multiplier::T2 = 3
min_max_factor::T3 = 1.3
end

@doc doc"""
AbsSafeBestTerminationMode <: AbstractSafeBestNonlinearTerminationMode
Essentially [`AbsSafeTerminationMode`](@ref), but caches the best solution found so far.
## Constructor
```julia
AbsSafeBestTerminationMode(; protective_threshold = 1e3, patience_steps = 100,
patience_objective_multiplier = 3, min_max_factor = 1.3)
```
"""
Base.@kwdef struct AbsSafeBestTerminationMode{T1, T2, T3} <:
AbstractSafeNonlinearTerminationMode
protective_threshold::T1 = 1000
patience_steps::Int = 100
patience_objective_multiplier::T2 = 3
min_max_factor::T3 = 1.3
end

mutable struct NonlinearTerminationModeCache{uType, T,
Expand Down Expand Up @@ -78,8 +224,8 @@ function _get_tolerance(::Nothing, ::Type{T}) where {T}
end

function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T}, T},
mode::AbstractNonlinearTerminationMode; abstol = nothing, reltol = nothing,
kwargs...) where {T <: Number}
mode::AbstractNonlinearTerminationMode; abstol = nothing, reltol = nothing,
kwargs...) where {T <: Number}
abstol = _get_tolerance(abstol, T)
reltol = _get_tolerance(reltol, T)
TT = typeof(abstol)
Expand Down Expand Up @@ -113,12 +259,12 @@ end
(cache::NonlinearTerminationModeCache)(du, u, uprev) = cache(cache.mode, du, u, uprev)

function (cache::NonlinearTerminationModeCache)(mode::AbstractNonlinearTerminationMode, du,
u, uprev)
u, uprev)
return check_convergence(mode, du, u, uprev, cache.abstol, cache.reltol)
end

function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTerminationMode,
du, u, uprev)
du, u, uprev)
if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
objective = NONLINEARSOLVE_DEFAULT_NORM(du)
criteria = cache.abstol
Expand All @@ -130,7 +276,7 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi

# Protective Break
if isinf(objective) || isnan(objective) ||
(objective cache.initial_objective * cache.mode.protective_threshold * length(du))
(objective cache.initial_objective * cache.mode.protective_threshold * length(du))
cache.retcode = NonlinearSafeTerminationReturnCode.ProtectiveTermination
return true
end
Expand Down Expand Up @@ -172,11 +318,11 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi
end

function check_convergence(::SteadyStateDiffEqTerminationMode, duₙ, uₙ, uₙ₋₁, abstol,
reltol)
reltol)
return all((abs.(duₙ) .≤ abstol) .| (abs.(duₙ) .≤ reltol .* abs.(uₙ)))
end
function check_convergence(::SimpleNonlinearSolveTerminationMode, duₙ, uₙ, uₙ₋₁, abstol,
reltol)
reltol)
return all((abs.(duₙ) .≤ abstol) .| (abs.(duₙ) .≤ reltol .* abs.(uₙ))) ||
isapprox(uₙ, uₙ₋₁; atol = abstol, rtol = reltol)
end
Expand All @@ -188,15 +334,15 @@ function check_convergence(::RelTerminationMode, duₙ, uₙ, uₙ₋₁, abstol
return all(abs.(duₙ) .≤ reltol .* abs.(uₙ))
end
function check_convergence(::Union{RelNormTerminationMode, RelSafeTerminationMode,
RelSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol)
RelSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol)
return NONLINEARSOLVE_DEFAULT_NORM(duₙ)
reltol * NONLINEARSOLVE_DEFAULT_NORM(duₙ .+ uₙ)
end
function check_convergence(::AbsTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol)
return all(abs.(duₙ) .≤ abstol)
end
function check_convergence(::Union{AbsNormTerminationMode, AbsSafeTerminationMode,
AbsSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol)
AbsSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol)
return NONLINEARSOLVE_DEFAULT_NORM(duₙ) abstol
end

Expand Down Expand Up @@ -242,8 +388,8 @@ mutable struct NLSolveSafeTerminationResult{T, uType}
end

function NLSolveSafeTerminationResult(u = nothing; best_objective_value = Inf64,
best_objective_value_iteration = 0,
return_code = NLSolveSafeTerminationReturnCode.Failure)
best_objective_value_iteration = 0,
return_code = NLSolveSafeTerminationReturnCode.Failure)
u = u !== nothing ? copy(u) : u
Base.depwarn("NLSolveSafeTerminationResult has been deprecated in favor of the new dispatch based termination conditions. Please use the new termination conditions API!",
:NLSolveSafeTerminationResult)
Expand Down Expand Up @@ -330,9 +476,9 @@ get_termination_mode(::NLSolveTerminationCondition{mode}) where {mode} = mode

# Don't specify `mode` since the defaults would depend on the package
function NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6,
protective_threshold = 1e3, patience_steps::Int = 30,
patience_objective_multiplier = 3,
min_max_factor = 1.3) where {T}
protective_threshold = 1e3, patience_steps::Int = 30,
patience_objective_multiplier = 3,
min_max_factor = 1.3) where {T}
Base.depwarn("NLSolveTerminationCondition has been deprecated in favor of the new dispatch based termination conditions. Please use the new termination conditions API!",
:NLSolveTerminationCondition)
@assert mode instances(NLSolveTerminationMode.T)
Expand All @@ -346,9 +492,9 @@ function NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6,
end

function (cond::NLSolveTerminationCondition)(storage::Union{
NLSolveSafeTerminationResult,
Nothing,
})
NLSolveSafeTerminationResult,
Nothing,
})
mode = get_termination_mode(cond)
# We need both the dispatches to support solvers that don't use the integrator
# interface like SimpleNonlinearSolve
Expand Down Expand Up @@ -438,7 +584,7 @@ end

# Convergence Criteria
@inline function _has_converged(du, u, uprev, cond::NLSolveTerminationCondition{mode},
abstol = cond.abstol, reltol = cond.reltol) where {mode}
abstol = cond.abstol, reltol = cond.reltol) where {mode}
return _has_converged(du, u, uprev, mode, abstol, reltol)
end

Expand Down

0 comments on commit 8fccdcc

Please sign in to comment.