From a4fa9b5d11fca3a5bb512a323d85fe5116cf7b0b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 9 Jan 2024 03:08:50 -0500 Subject: [PATCH 1/3] Implement Stalling and Use ReturnCode --- Project.toml | 4 +- src/termination_conditions.jl | 145 ++++++++++++++++++++++++++++------ 2 files changed, 122 insertions(+), 27 deletions(-) diff --git a/Project.toml b/Project.toml index 15f993706..c224cfe50 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "DiffEqBase" uuid = "2b5f629d-d688-5b77-993f-72d75c75574e" authors = ["Chris Rackauckas "] -version = "6.145.6" +version = "6.146.0" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -79,7 +79,7 @@ PrecompileTools = "1" Printf = "1.9" RecursiveArrayTools = "2, 3" Reexport = "1.0" -SciMLBase = "2.12.0" +SciMLBase = "2.18.0" SciMLOperators = "0.2, 0.3" Setfield = "0.8, 1" SparseArrays = "1.9" diff --git a/src/termination_conditions.jl b/src/termination_conditions.jl index d3761ab05..9bad48751 100644 --- a/src/termination_conditions.jl +++ b/src/termination_conditions.jl @@ -2,6 +2,9 @@ NonlinearSafeTerminationReturnCode Return Codes for the safe nonlinear termination conditions. + +These return codes have been deprecated. Termination Conditions will return +`SciMLBase.Retcode.T` starting from v7. """ @enumx NonlinearSafeTerminationReturnCode begin """ @@ -116,15 +119,16 @@ for the last `patience_steps` + terminate if the solution blows up (diverges). ```julia RelSafeTerminationMode(; protective_threshold = nothing, patience_steps = 100, - patience_objective_multiplier = 3, min_max_factor = 1.3) + patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = 20) ``` """ -Base.@kwdef struct RelSafeTerminationMode{T1, T2, T3} <: +Base.@kwdef struct RelSafeTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int}} <: AbstractSafeNonlinearTerminationMode protective_threshold::T1 = nothing patience_steps::Int = 100 patience_objective_multiplier::T2 = 3 min_max_factor::T3 = 1.3 + max_stalled_steps::T4 = nothing end @doc doc""" @@ -137,15 +141,16 @@ for the last `patience_steps` + terminate if the solution blows up (diverges). ```julia AbsSafeTerminationMode(; protective_threshold = nothing, patience_steps = 100, - patience_objective_multiplier = 3, min_max_factor = 1.3) + patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = 20) ``` """ -Base.@kwdef struct AbsSafeTerminationMode{T1, T2, T3} <: +Base.@kwdef struct AbsSafeTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int}} <: AbstractSafeNonlinearTerminationMode protective_threshold::T1 = nothing patience_steps::Int = 100 patience_objective_multiplier::T2 = 3 min_max_factor::T3 = 1.3 + max_stalled_steps::T4 = nothing end @doc doc""" @@ -157,15 +162,16 @@ Essentially [`RelSafeTerminationMode`](@ref), but caches the best solution found ```julia RelSafeBestTerminationMode(; protective_threshold = nothing, patience_steps = 100, - patience_objective_multiplier = 3, min_max_factor = 1.3) + patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = 20) ``` """ -Base.@kwdef struct RelSafeBestTerminationMode{T1, T2, T3} <: +Base.@kwdef struct RelSafeBestTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int}} <: AbstractSafeBestNonlinearTerminationMode protective_threshold::T1 = nothing patience_steps::Int = 100 patience_objective_multiplier::T2 = 3 min_max_factor::T3 = 1.3 + max_stalled_steps::T4 = nothing end @doc doc""" @@ -177,21 +183,23 @@ Essentially [`AbsSafeTerminationMode`](@ref), but caches the best solution found ```julia AbsSafeBestTerminationMode(; protective_threshold = nothing, patience_steps = 100, - patience_objective_multiplier = 3, min_max_factor = 1.3) + patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = 20) ``` """ -Base.@kwdef struct AbsSafeBestTerminationMode{T1, T2, T3} <: +Base.@kwdef struct AbsSafeBestTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int}} <: AbstractSafeBestNonlinearTerminationMode protective_threshold::T1 = nothing patience_steps::Int = 100 patience_objective_multiplier::T2 = 3 min_max_factor::T3 = 1.3 + max_stalled_steps::T4 = nothing end -mutable struct NonlinearTerminationModeCache{uType, T, - M <: AbstractNonlinearTerminationMode, I, OT, SV} +mutable struct NonlinearTerminationModeCache{uType, T, dep_retcode, + M <: AbstractNonlinearTerminationMode, I, OT, SV, + R <: Union{NonlinearSafeTerminationReturnCode.T, ReturnCode.T}, UN, ST, MSS} u::uType - retcode::NonlinearSafeTerminationReturnCode.T + retcode::R abstol::T reltol::T best_objective_value::T @@ -200,6 +208,10 @@ mutable struct NonlinearTerminationModeCache{uType, T, objectives_trace::OT nsteps::Int saved_values::SV + u0_norm::UN + step_norm_trace::ST + max_stalled_steps::MSS + u_diff_cache::uType end get_termination_mode(cache::NonlinearTerminationModeCache) = cache.mode @@ -227,7 +239,8 @@ end function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T}, T}, mode::AbstractNonlinearTerminationMode, saved_value_prototype...; - abstol = nothing, reltol = nothing, kwargs...) where {T <: Number} + use_deprecated_retcodes::Val{D} = Val(true), # Remove in v8, warn in v7 + abstol = nothing, reltol = nothing, kwargs...) where {D, T <: Number} abstol = _get_tolerance(abstol, T) reltol = _get_tolerance(reltol, T) TT = typeof(abstol) @@ -236,25 +249,74 @@ function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T if mode isa AbstractSafeNonlinearTerminationMode if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode initial_objective = maximum(abs, du) + u0_norm = nothing else initial_objective = maximum(abs, du) / (maximum(abs, du .+ u) + eps(TT)) + u0_norm = mode.max_stalled_steps === nothing ? nothing : norm(u, 2) end objectives_trace = Vector{TT}(undef, mode.patience_steps) + step_norm_trace = mode.max_stalled_steps === nothing ? nothing : + Vector{TT}(undef, mode.max_stalled_steps) best_value = initial_objective + max_stalled_steps = mode.max_stalled_steps + if ArrayInterface.can_setindex(u_) && step_norm_trace !== nothing + u_diff_cache = similar(u_) + else + u_diff_cache = u_ + end else initial_objective = nothing objectives_trace = nothing + u0_norm = nothing + step_norm_trace = nothing best_value = __cvt_real(T, Inf) + max_stalled_steps = nothing + u_diff_cache = u_ end length(saved_value_prototype) == 0 && (saved_value_prototype = nothing) - return NonlinearTerminationModeCache{typeof(u_), TT, typeof(mode), - typeof(initial_objective), typeof(objectives_trace), - typeof(saved_value_prototype)}(u_, NonlinearSafeTerminationReturnCode.Default, - abstol, reltol, best_value, mode, initial_objective, objectives_trace, 0, - saved_value_prototype) -end + retcode = ifelse(D, NonlinearSafeTerminationReturnCode.Default, ReturnCode.Default) + + return NonlinearTerminationModeCache{typeof(u_), TT, D, typeof(mode), + typeof(initial_objective), typeof(objectives_trace), typeof(saved_value_prototype), + typeof(retcode), typeof(u0_norm), typeof(step_norm_trace), + typeof(max_stalled_steps)}(u_, retcode, abstol, reltol, best_value, mode, + initial_objective, objectives_trace, 0, saved_value_prototype, u0_norm, + step_norm_trace, max_stalled_steps, u_diff_cache) +end + +# function SciMLBase.reinit!(cache::NonlinearTerminationModeCache{uType, T, dep_retcode}, du, +# u, saved_value_prototype...; abstol = nothing, reltol = nothing, +# kwargs...) where {uType, T, dep_retcode} +# length(saved_value_prototype) != 0 && (cache.saved_values = saved_value_prototype) + +# u_ = cache.mode isa AbstractSafeBestNonlinearTerminationMode ? +# (ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing +# cache.u = u_ +# cache.retcode = ifelse(dep_retcode, NonlinearSafeTerminationReturnCode.Default, +# ReturnCode.Default) + +# cache.abstol = _get_tolerance(abstol, T) +# cache.reltol = _get_tolerance(reltol, T) +# cache.nsteps = 0 + +# if mode isa AbstractSafeNonlinearTerminationMode +# if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode +# initial_objective = maximum(abs, du) +# else +# initial_objective = maximum(abs, du) / (maximum(abs, du .+ u) + eps(TT)) +# end +# best_value = initial_objective +# else +# initial_objective = nothing +# objectives_trace = nothing +# best_value = __cvt_real(T, Inf) +# end +# cache.best_objective_value = best_value +# cache.initial_objective = initial_objective +# return cache +# end # This dispatch is needed based on how Terminating Callback works! # This intentially drops the `abstol` and `reltol` arguments @@ -273,8 +335,8 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractNonlinearTerminati return check_convergence(mode, du, u, uprev, cache.abstol, cache.reltol) end -function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTerminationMode, - du, u, uprev, args...) +function (cache::NonlinearTerminationModeCache{uType, TT, dep_retcode})(mode::AbstractSafeNonlinearTerminationMode, + du, u, uprev, args...) where {uType, TT, dep_retcode} if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode objective = maximum(abs, du) criteria = cache.abstol @@ -285,13 +347,15 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi # Protective Break if isinf(objective) || isnan(objective) - cache.retcode = NonlinearSafeTerminationReturnCode.ProtectiveTermination + cache.retcode = ifelse(dep_retcode, + NonlinearSafeTerminationReturnCode.ProtectiveTermination, ReturnCode.Unstable) return true end ## By default we turn this off since it has the potential for false positives if cache.mode.protective_threshold !== nothing && (objective > cache.initial_objective * cache.mode.protective_threshold * length(du)) - cache.retcode = NonlinearSafeTerminationReturnCode.ProtectiveTermination + cache.retcode = ifelse(dep_retcode, + NonlinearSafeTerminationReturnCode.ProtectiveTermination, ReturnCode.Unstable) return true end @@ -307,7 +371,8 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi # Main Termination Condition if objective ≤ criteria - cache.retcode = NonlinearSafeTerminationReturnCode.Success + cache.retcode = ifelse(dep_retcode, + NonlinearSafeTerminationReturnCode.Success, ReturnCode.Success) return true end @@ -324,13 +389,43 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractSafeNonlinearTermi min_obj, max_obj = extrema(cache.objectives_trace) end if min_obj < cache.mode.min_max_factor * max_obj - cache.retcode = NonlinearSafeTerminationReturnCode.PatienceTermination + cache.retcode = ifelse(dep_retcode, + NonlinearSafeTerminationReturnCode.PatienceTermination, + ReturnCode.Stalled) + return true + end + end + end + + # Test for stalling if that is not disabled + if cache.step_norm_trace !== nothing + if ArrayInterface.can_setindex(cache.u_diff_cache) + @. cache.u_diff_cache = u - uprev + else + cache.u_diff_cache = u .- uprev + end + du_norm = norm(cache.u_diff_cache, 2) + cache.step_norm_trace[mod1(cache.nsteps, length(cache.step_norm_trace))] = du_norm + if cache.nsteps ≥ cache.mode.max_stalled_steps + max_step_norm = maximum(cache.step_norm_trace) + if cache.mode isa AbsSafeTerminationMode || + cache.mode isa AbsSafeBestTerminationMode + stalled_step = max_step_norm ≤ cache.abstol + else + stalled_step = max_step_norm ≤ + cache.reltol * (max_step_norm + cache.u0_norm) + end + if stalled_step + cache.retcode = ifelse(dep_retcode, + NonlinearSafeTerminationReturnCode.PatienceTermination, + ReturnCode.Stalled) return true end end end - cache.retcode = NonlinearSafeTerminationReturnCode.Failure + cache.retcode = ifelse(dep_retcode, + NonlinearSafeTerminationReturnCode.Failure, ReturnCode.Failure) return false end From 0a1f5fc3dcfd39c7840cec5c1a107a3d3942b683 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 9 Jan 2024 03:21:40 -0500 Subject: [PATCH 2/3] Add a reinit function --- src/termination_conditions.jl | 68 ++++++++++++++++++----------------- 1 file changed, 35 insertions(+), 33 deletions(-) diff --git a/src/termination_conditions.jl b/src/termination_conditions.jl index 9bad48751..148fe4990 100644 --- a/src/termination_conditions.jl +++ b/src/termination_conditions.jl @@ -259,7 +259,7 @@ function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T Vector{TT}(undef, mode.max_stalled_steps) best_value = initial_objective max_stalled_steps = mode.max_stalled_steps - if ArrayInterface.can_setindex(u_) && step_norm_trace !== nothing + if ArrayInterface.can_setindex(u_) && !(u_ isa Number) && step_norm_trace !== nothing u_diff_cache = similar(u_) else u_diff_cache = u_ @@ -286,37 +286,39 @@ function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T step_norm_trace, max_stalled_steps, u_diff_cache) end -# function SciMLBase.reinit!(cache::NonlinearTerminationModeCache{uType, T, dep_retcode}, du, -# u, saved_value_prototype...; abstol = nothing, reltol = nothing, -# kwargs...) where {uType, T, dep_retcode} -# length(saved_value_prototype) != 0 && (cache.saved_values = saved_value_prototype) - -# u_ = cache.mode isa AbstractSafeBestNonlinearTerminationMode ? -# (ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing -# cache.u = u_ -# cache.retcode = ifelse(dep_retcode, NonlinearSafeTerminationReturnCode.Default, -# ReturnCode.Default) - -# cache.abstol = _get_tolerance(abstol, T) -# cache.reltol = _get_tolerance(reltol, T) -# cache.nsteps = 0 - -# if mode isa AbstractSafeNonlinearTerminationMode -# if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode -# initial_objective = maximum(abs, du) -# else -# initial_objective = maximum(abs, du) / (maximum(abs, du .+ u) + eps(TT)) -# end -# best_value = initial_objective -# else -# initial_objective = nothing -# objectives_trace = nothing -# best_value = __cvt_real(T, Inf) -# end -# cache.best_objective_value = best_value -# cache.initial_objective = initial_objective -# return cache -# end +function SciMLBase.reinit!(cache::NonlinearTerminationModeCache{uType, T, dep_retcode}, du, + u, saved_value_prototype...; abstol = nothing, reltol = nothing, + kwargs...) where {uType, T, dep_retcode} + length(saved_value_prototype) != 0 && (cache.saved_values = saved_value_prototype) + + u_ = cache.mode isa AbstractSafeBestNonlinearTerminationMode ? + (ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing + cache.u = u_ + cache.retcode = ifelse(dep_retcode, NonlinearSafeTerminationReturnCode.Default, + ReturnCode.Default) + + cache.abstol = _get_tolerance(abstol, T) + cache.reltol = _get_tolerance(reltol, T) + cache.nsteps = 0 + + mode = get_termination_mode(cache) + if mode isa AbstractSafeNonlinearTerminationMode + if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode + initial_objective = maximum(abs, du) + else + initial_objective = maximum(abs, du) / (maximum(abs, du .+ u) + eps(TT)) + cache.max_stalled_steps !== nothing && (cache.u0_norm = norm(u_, 2)) + end + best_value = initial_objective + else + initial_objective = nothing + objectives_trace = nothing + best_value = __cvt_real(T, Inf) + end + cache.best_objective_value = best_value + cache.initial_objective = initial_objective + return cache +end # This dispatch is needed based on how Terminating Callback works! # This intentially drops the `abstol` and `reltol` arguments @@ -399,7 +401,7 @@ function (cache::NonlinearTerminationModeCache{uType, TT, dep_retcode})(mode::Ab # Test for stalling if that is not disabled if cache.step_norm_trace !== nothing - if ArrayInterface.can_setindex(cache.u_diff_cache) + if ArrayInterface.can_setindex(cache.u_diff_cache) && !(u isa Number) @. cache.u_diff_cache = u - uprev else cache.u_diff_cache = u .- uprev From a0cf82738e7781bef14f24e8aeaf948fcf3f8a5f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 14 Jan 2024 06:38:03 -0500 Subject: [PATCH 3/3] up the compat --- Project.toml | 2 +- src/termination_conditions.jl | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index c224cfe50..d6e3fa802 100644 --- a/Project.toml +++ b/Project.toml @@ -79,7 +79,7 @@ PrecompileTools = "1" Printf = "1.9" RecursiveArrayTools = "2, 3" Reexport = "1.0" -SciMLBase = "2.18.0" +SciMLBase = "2.19.0" SciMLOperators = "0.2, 0.3" Setfield = "0.8, 1" SparseArrays = "1.9" diff --git a/src/termination_conditions.jl b/src/termination_conditions.jl index 148fe4990..c71beebe3 100644 --- a/src/termination_conditions.jl +++ b/src/termination_conditions.jl @@ -119,7 +119,7 @@ for the last `patience_steps` + terminate if the solution blows up (diverges). ```julia RelSafeTerminationMode(; protective_threshold = nothing, patience_steps = 100, - patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = 20) + patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = nothing) ``` """ Base.@kwdef struct RelSafeTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int}} <: @@ -141,7 +141,7 @@ for the last `patience_steps` + terminate if the solution blows up (diverges). ```julia AbsSafeTerminationMode(; protective_threshold = nothing, patience_steps = 100, - patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = 20) + patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = nothing) ``` """ Base.@kwdef struct AbsSafeTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int}} <: @@ -162,7 +162,7 @@ Essentially [`RelSafeTerminationMode`](@ref), but caches the best solution found ```julia RelSafeBestTerminationMode(; protective_threshold = nothing, patience_steps = 100, - patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = 20) + patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = nothing) ``` """ Base.@kwdef struct RelSafeBestTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int}} <: @@ -183,7 +183,7 @@ Essentially [`AbsSafeTerminationMode`](@ref), but caches the best solution found ```julia AbsSafeBestTerminationMode(; protective_threshold = nothing, patience_steps = 100, - patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = 20) + patience_objective_multiplier = 3, min_max_factor = 1.3, max_stalled_steps = nothing) ``` """ Base.@kwdef struct AbsSafeBestTerminationMode{T1, T2, T3, T4 <: Union{Nothing, Int}} <: