diff --git a/src/termination_conditions.jl b/src/termination_conditions.jl index 9bad48751..a91f733ac 100644 --- a/src/termination_conditions.jl +++ b/src/termination_conditions.jl @@ -259,7 +259,7 @@ function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T Vector{TT}(undef, mode.max_stalled_steps) best_value = initial_objective max_stalled_steps = mode.max_stalled_steps - if ArrayInterface.can_setindex(u_) && step_norm_trace !== nothing + if ArrayInterface.can_setindex(u_) && !(u_ isa Number) && step_norm_trace !== nothing u_diff_cache = similar(u_) else u_diff_cache = u_ @@ -286,37 +286,38 @@ function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T step_norm_trace, max_stalled_steps, u_diff_cache) end -# function SciMLBase.reinit!(cache::NonlinearTerminationModeCache{uType, T, dep_retcode}, du, -# u, saved_value_prototype...; abstol = nothing, reltol = nothing, -# kwargs...) where {uType, T, dep_retcode} -# length(saved_value_prototype) != 0 && (cache.saved_values = saved_value_prototype) - -# u_ = cache.mode isa AbstractSafeBestNonlinearTerminationMode ? -# (ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing -# cache.u = u_ -# cache.retcode = ifelse(dep_retcode, NonlinearSafeTerminationReturnCode.Default, -# ReturnCode.Default) - -# cache.abstol = _get_tolerance(abstol, T) -# cache.reltol = _get_tolerance(reltol, T) -# cache.nsteps = 0 - -# if mode isa AbstractSafeNonlinearTerminationMode -# if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode -# initial_objective = maximum(abs, du) -# else -# initial_objective = maximum(abs, du) / (maximum(abs, du .+ u) + eps(TT)) -# end -# best_value = initial_objective -# else -# initial_objective = nothing -# objectives_trace = nothing -# best_value = __cvt_real(T, Inf) -# end -# cache.best_objective_value = best_value -# cache.initial_objective = initial_objective -# return cache -# end +function SciMLBase.reinit!(cache::NonlinearTerminationModeCache{uType, T, dep_retcode}, du, + u, saved_value_prototype...; abstol = nothing, reltol = nothing, + kwargs...) where {uType, T, dep_retcode} + length(saved_value_prototype) != 0 && (cache.saved_values = saved_value_prototype) + + u_ = cache.mode isa AbstractSafeBestNonlinearTerminationMode ? + (ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing + cache.u = u_ + cache.retcode = ifelse(dep_retcode, NonlinearSafeTerminationReturnCode.Default, + ReturnCode.Default) + + cache.abstol = _get_tolerance(abstol, T) + cache.reltol = _get_tolerance(reltol, T) + cache.nsteps = 0 + + if mode isa AbstractSafeNonlinearTerminationMode + if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode + initial_objective = maximum(abs, du) + else + initial_objective = maximum(abs, du) / (maximum(abs, du .+ u) + eps(TT)) + cache.max_stalled_steps !== nothing && (cache.u0_norm = norm(u_, 2)) + end + best_value = initial_objective + else + initial_objective = nothing + objectives_trace = nothing + best_value = __cvt_real(T, Inf) + end + cache.best_objective_value = best_value + cache.initial_objective = initial_objective + return cache +end # This dispatch is needed based on how Terminating Callback works! # This intentially drops the `abstol` and `reltol` arguments @@ -399,7 +400,7 @@ function (cache::NonlinearTerminationModeCache{uType, TT, dep_retcode})(mode::Ab # Test for stalling if that is not disabled if cache.step_norm_trace !== nothing - if ArrayInterface.can_setindex(cache.u_diff_cache) + if ArrayInterface.can_setindex(cache.u_diff_cache) && !(u isa Number) @. cache.u_diff_cache = u - uprev else cache.u_diff_cache = u .- uprev