Skip to content

Commit

Permalink
Better initial objective
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 30, 2023
1 parent 9b2890f commit 89c784a
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions src/termination_conditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ get_reltol(cache::NonlinearTerminationModeCache) = cache.reltol

function __update_u!!(cache::NonlinearTerminationModeCache, u)
cache.u === nothing && return
if ArrayInterface.can_setindex(cache.u)
if cache.u isa AbstractArray && ArrayInterface.can_setindex(cache.u)
copyto!(cache.u, u)
else
cache.u = u
Expand All @@ -77,21 +77,27 @@ function _get_tolerance(::Nothing, ::Type{T}) where {T}
return _get_tolerance(η, T)
end

function SciMLBase.init(u::Union{AbstractArray{T}, T},
function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T}, T},
mode::AbstractNonlinearTerminationMode; abstol = nothing, reltol = nothing,
kwargs...) where {T <: Number}
abstol = _get_tolerance(abstol, T)
reltol = _get_tolerance(reltol, T)
best_value = __cvt_real(T, Inf)
TT = typeof(abstol)
u_ = mode isa AbstractSafeBestNonlinearTerminationMode ?
(ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing
if mode isa AbstractSafeNonlinearTerminationMode
initial_objective = TT(0)
if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
initial_objective = NONLINEARSOLVE_DEFAULT_NORM(du)
else
initial_objective = NONLINEARSOLVE_DEFAULT_NORM(du) /
(NONLINEARSOLVE_DEFAULT_NORM(du .+ u) + eps(TT))
end
objectives_trace = Vector{TT}(undef, mode.patience_steps)
best_value = initial_objective
else
initial_objective = nothing
objectives_trace = nothing
best_value = __cvt_real(T, Inf)
end
return NonlinearTerminationModeCache{typeof(u_), TT, typeof(mode),
typeof(initial_objective), typeof(objectives_trace)}(u_,
Expand Down Expand Up @@ -122,6 +128,13 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi
criteria = cache.reltol
end

# Protective Break
if isinf(objective) || isnan(objective) ||
(objective cache.initial_objective * cache.mode.protective_threshold * length(du))
cache.retcode = NonlinearSafeTerminationReturnCode.ProtectiveTermination
return true
end

# Check if best solution
if mode isa AbstractSafeBestNonlinearTerminationMode &&
objective < cache.best_objective_value
Expand Down Expand Up @@ -154,12 +167,6 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi
end
end

# Protective Break
if objective cache.initial_objective * cache.mode.protective_threshold * length(du)
cache.retcode = NonlinearSafeTerminationReturnCode.ProtectiveTermination
return true
end

cache.retcode = NonlinearSafeTerminationReturnCode.Failure
return false
end
Expand Down Expand Up @@ -238,9 +245,10 @@ function NLSolveSafeTerminationResult(u = nothing; best_objective_value = Inf64,
best_objective_value_iteration = 0,
return_code = NLSolveSafeTerminationReturnCode.Failure)
u = u !== nothing ? copy(u) : u
Base.depwarn("NLSolveSafeTerminationResult has been deprecated in favor of the new dispatch based termination conditions. Please use the new termination conditions API!",
:NLSolveSafeTerminationResult)
return NLSolveSafeTerminationResult{typeof(best_objective_value), typeof(u)}(u,
best_objective_value,
best_objective_value_iteration, return_code)
best_objective_value, best_objective_value_iteration, return_code)
end

const BASIC_TERMINATION_MODES = (NLSolveTerminationMode.SteadyStateDefault,
Expand Down Expand Up @@ -296,6 +304,8 @@ Define the termination criteria for the NonlinearProblem or SteadyStateProblem.
* `protective_threshold`: If the objective value increased by this factor wrt initial objective terminate immediately.
* `patience_steps`: If objective is within `patience_objective_multiplier` factor of the criteria and no improvement within `min_max_factor` has happened then terminate.
!!! warning
This has been deprecated and will be removed in the next major release. Please use the new dispatch based termination conditions API.
"""
struct NLSolveTerminationCondition{mode, T,
S <: Union{<:NLSolveSafeTerminationOptions, Nothing}}
Expand Down Expand Up @@ -323,6 +333,8 @@ function NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6,
protective_threshold = 1e3, patience_steps::Int = 30,
patience_objective_multiplier = 3,
min_max_factor = 1.3) where {T}
Base.depwarn("NLSolveTerminationCondition has been deprecated in favor of the new dispatch based termination conditions. Please use the new termination conditions API!",
:NLSolveTerminationCondition)
@assert mode instances(NLSolveTerminationMode.T)
options = if mode SAFE_TERMINATION_MODES
NLSolveSafeTerminationOptions(protective_threshold, patience_steps,
Expand Down

0 comments on commit 89c784a

Please sign in to comment.