Skip to content

Commit

Permalink
Rework the Termination Condition API to be type stable
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 30, 2023
1 parent e9316ae commit d63e00b
Showing 1 changed file with 191 additions and 0 deletions.
191 changes: 191 additions & 0 deletions src/termination_conditions.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,194 @@
@enumx NonlinearSafeTerminationReturnCode begin
Success
Default
PatienceTermination
ProtectiveTermination
Failure
end

abstract type AbstractNonlinearTerminationMode end
abstract type AbstractSafeNonlinearTerminationMode <: AbstractNonlinearTerminationMode end
abstract type AbstractSafeBestNonlinearTerminationMode <:
AbstractSafeNonlinearTerminationMode end

# TODO: Add a mode where the user can pass in custom termination criteria function
for mode in (:SteadyStateDiffEqTerminationMode, :SimpleNonlinearSolveTerminationMode,
:NormTerminationMode, :RelTerminationMode, :RelNormTerminationMode, :AbsTerminationMode,
:AbsNormTerminationMode)
@eval begin
struct $(mode) <: AbstractNonlinearTerminationMode end
end
end

for mode in (:RelSafeTerminationMode, :AbsSafeTerminationMode)
@eval begin
Base.@kwdef struct $(mode){T1, T2, T3} <: AbstractSafeNonlinearTerminationMode
protective_threshold::T1 = 1000
patience_steps::Int = 30
patience_objective_multiplier::T2 = 3
min_max_factor::T3 = 1.3
end
end
end

for mode in (:RelSafeBestTerminationMode, :AbsSafeBestTerminationMode)
@eval begin
Base.@kwdef struct $(mode){T1, T2, T3} <: AbstractSafeBestNonlinearTerminationMode
protective_threshold::T1 = 1000
patience_steps::Int = 30
patience_objective_multiplier::T2 = 3
min_max_factor::T3 = 1.3
end
end
end

mutable struct NonlinearTerminationModeCache{uType, T,
M <: AbstractNonlinearTerminationMode, I, OT}
u::uType
retcode::NonlinearSafeTerminationReturnCode.T
abstol::T
reltol::T
best_objective_value::T
mode::M
initial_objective::I
objectives_trace::OT
nsteps::Int
end

function __update_u!!(cache::NonlinearTerminationModeCache, u)
cache.u === nothing && return
if ArrayInterface.can_setindex(cache.u)
copyto!(cache.u, u)

Check warning on line 61 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L58-L61

Added lines #L58 - L61 were not covered by tests
else
cache.u = u

Check warning on line 63 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L63

Added line #L63 was not covered by tests
end
end

__cvt_real(::Type{T}, ::Nothing) where {T} = nothing
__cvt_real(::Type{T}, x) where {T} = real(T(x))

Check warning on line 68 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L67-L68

Added lines #L67 - L68 were not covered by tests

_get_tolerance(η, ::Type{T}) where {T} = __cvt_real(T, η)
function _get_tolerance(::Nothing, ::Type{T}) where {T}
η = real(oneunit(T)) * (eps(real(one(T))))^(4 // 5)
return _get_tolerance(η, T)

Check warning on line 73 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L70-L73

Added lines #L70 - L73 were not covered by tests
end

function SciMLBase.init(u::AbstractArray{T}, mode::AbstractNonlinearTerminationMode;

Check warning on line 76 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L76

Added line #L76 was not covered by tests
abstol = nothing, reltol = nothing, kwargs...) where {T}
abstol = _get_tolerance(abstol, T)
reltol = _get_tolerance(reltol, T)
best_value = __cvt_real(T, Inf)
TT = typeof(abstol)
u_ = mode isa AbstractSafeBestNonlinearTerminationMode ?

Check warning on line 82 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L78-L82

Added lines #L78 - L82 were not covered by tests
(ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing
if mode isa AbstractSafeNonlinearTerminationMode
initial_objective = TT(0)
objectives_trace = Vector{TT}(undef, mode.patience_steps)

Check warning on line 86 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L84-L86

Added lines #L84 - L86 were not covered by tests
else
initial_objective = nothing
objectives_trace = nothing

Check warning on line 89 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L88-L89

Added lines #L88 - L89 were not covered by tests
end
return NonlinearTerminationModeCache{typeof(u_), TT, typeof(mode),

Check warning on line 91 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L91

Added line #L91 was not covered by tests
typeof(initial_objective), typeof(objectives_trace)}(u_,
NonlinearSafeTerminationReturnCode.Default, abstol, reltol, best_value, mode,
initial_objective, objectives_trace, 0)
end

# This dispatch is needed based on how Terminating Callback works!
# This intentially drops the `abstol` and `reltol` arguments
function (cache::NonlinearTerminationModeCache)(integrator, _, _, min_t)
return cache(cache.mode, get_du(integrator), integrator.u, integrator.uprev)

Check warning on line 100 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L99-L100

Added lines #L99 - L100 were not covered by tests
end
(cache::NonlinearTerminationModeCache)(du, u, uprev) = cache(cache.mode, du, u, uprev)

Check warning on line 102 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L102

Added line #L102 was not covered by tests

function (cache::NonlinearTerminationModeCache)(mode::AbstractNonlinearTerminationMode, du,

Check warning on line 104 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L104

Added line #L104 was not covered by tests
u, uprev)
return check_convergence(mode, du, u, uprev, cache.abstol, cache.reltol)

Check warning on line 106 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L106

Added line #L106 was not covered by tests
end

function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTerminationMode,

Check warning on line 109 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L109

Added line #L109 was not covered by tests
du, u, uprev)
if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
objective = NLSOLVE_DEFAULT_NORM(du)
criteria = cache.abstol

Check warning on line 113 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L111-L113

Added lines #L111 - L113 were not covered by tests
else
objective = NLSOLVE_DEFAULT_NORM(du) /

Check warning on line 115 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L115

Added line #L115 was not covered by tests
(NLSOLVE_DEFAULT_NORM(du .+ u) + eps(cache.abstol))
criteria = cache.reltol

Check warning on line 117 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L117

Added line #L117 was not covered by tests
end

# Check if best solution
if mode isa AbstractSafeBestNonlinearTerminationMode &&

Check warning on line 121 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L121

Added line #L121 was not covered by tests
objective < cache.best_objective_value
cache.best_objective_value = objective
__update_u!!(cache, u)

Check warning on line 124 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L123-L124

Added lines #L123 - L124 were not covered by tests
end

# Main Termination Condition
if objective criteria
cache.retcode = NonlinearSafeTerminationReturnCode.Success
return true

Check warning on line 130 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L128-L130

Added lines #L128 - L130 were not covered by tests
end

# Terminate if there has been no improvement for the last `patience_steps`
cache.nsteps += 1
cache.nsteps == 1 && (cache.initial_objective = objective)
cache.objectives_trace[mod1(cache.nsteps, length(cache.objectives_trace))] = objective

Check warning on line 136 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L134-L136

Added lines #L134 - L136 were not covered by tests

if objective cache.mode.patience_objective_multiplier * criteria
if cache.nsteps cache.mode.patience_steps
if cache.nsteps < length(cache.objectives_trace)
min_obj, max_obj = extrema(@view(cache.objectives_trace[1:cache.nsteps]))

Check warning on line 141 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L138-L141

Added lines #L138 - L141 were not covered by tests
else
min_obj, max_obj = extrema(cache.objectives_trace)

Check warning on line 143 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L143

Added line #L143 was not covered by tests
end
if min_obj < cache.mode.min_max_factor * max_obj
cache.retcode = NonlinearSafeTerminationReturnCode.PatienceTermination
return true

Check warning on line 147 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L145-L147

Added lines #L145 - L147 were not covered by tests
end
end
end

# Protective Break
if objective cache.initial_objective * cache.mode.protective_threshold * length(du)
cache.retcode = NonlinearSafeTerminationReturnCode.ProtectiveTermination
return true

Check warning on line 155 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L153-L155

Added lines #L153 - L155 were not covered by tests
end

cache.retcode = NonlinearSafeTerminationReturnCode.Failure
return false

Check warning on line 159 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L158-L159

Added lines #L158 - L159 were not covered by tests
end

function check_convergence(::SteadyStateDiffEqTerminationMode, duₙ, uₙ, uₙ₋₁, abstol,

Check warning on line 162 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L162

Added line #L162 was not covered by tests
reltol)
return all((abs.(duₙ) .≤ abstol) .| (abs.(duₙ) .≤ reltol .* abs.(uₙ)))

Check warning on line 164 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L164

Added line #L164 was not covered by tests
end
function check_convergence(::SimpleNonlinearSolveTerminationMode, duₙ, uₙ, uₙ₋₁, abstol,

Check warning on line 166 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L166

Added line #L166 was not covered by tests
reltol)
return all((abs.(duₙ) .≤ abstol) .| (abs.(duₙ) .≤ reltol .* abs.(uₙ))) ||

Check warning on line 168 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L168

Added line #L168 was not covered by tests
isapprox(uₙ, uₙ₋₁; atol = abstol, rtol = reltol)
end
function check_convergence(::NormTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol)
du_norm = NLSOLVE_DEFAULT_NORM(duₙ)
return du_norm abstol || du_norm reltol * NLSOLVE_DEFAULT_NORM(duₙ .+ uₙ)

Check warning on line 173 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L171-L173

Added lines #L171 - L173 were not covered by tests
end
function check_convergence(::RelNormTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol)
return all(abs.(duₙ) .≤ reltol .* abs.(uₙ))

Check warning on line 176 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L175-L176

Added lines #L175 - L176 were not covered by tests
end
function check_convergence(::Union{RelNormTerminationMode, RelSafeTerminationMode,

Check warning on line 178 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L178

Added line #L178 was not covered by tests
RelSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol)
return NLSOLVE_DEFAULT_NORM(duₙ) reltol * NLSOLVE_DEFAULT_NORM(duₙ .+ uₙ)

Check warning on line 180 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L180

Added line #L180 was not covered by tests
end
function check_convergence(::AbsTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol)
return all(abs.(duₙ) .≤ abstol)

Check warning on line 183 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L182-L183

Added lines #L182 - L183 were not covered by tests
end
function check_convergence(::Union{AbsNormTerminationMode, AbsSafeTerminationMode,

Check warning on line 185 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L185

Added line #L185 was not covered by tests
AbsSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol)
return NLSOLVE_DEFAULT_NORM(duₙ) abstol

Check warning on line 187 in src/termination_conditions.jl

View check run for this annotation

Codecov / codecov/patch

src/termination_conditions.jl#L187

Added line #L187 was not covered by tests
end

# NOTE: Deprecate the following API eventually. This API leads to quite a bit of type
# instability
@enumx NLSolveSafeTerminationReturnCode begin
Success
PatienceTermination
Expand Down

0 comments on commit d63e00b

Please sign in to comment.