Skip to content

Commit

Permalink
Add a reinit function
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 9, 2024
1 parent a4fa9b5 commit 1e2db4c
Showing 1 changed file with 34 additions and 33 deletions.
67 changes: 34 additions & 33 deletions src/termination_conditions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T
Vector{TT}(undef, mode.max_stalled_steps)
best_value = initial_objective
max_stalled_steps = mode.max_stalled_steps
if ArrayInterface.can_setindex(u_) && step_norm_trace !== nothing
if ArrayInterface.can_setindex(u_) && !(u_ isa Number) && step_norm_trace !== nothing
u_diff_cache = similar(u_)
else
u_diff_cache = u_
Expand All @@ -286,37 +286,38 @@ function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T
step_norm_trace, max_stalled_steps, u_diff_cache)
end

# function SciMLBase.reinit!(cache::NonlinearTerminationModeCache{uType, T, dep_retcode}, du,
# u, saved_value_prototype...; abstol = nothing, reltol = nothing,
# kwargs...) where {uType, T, dep_retcode}
# length(saved_value_prototype) != 0 && (cache.saved_values = saved_value_prototype)

# u_ = cache.mode isa AbstractSafeBestNonlinearTerminationMode ?
# (ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing
# cache.u = u_
# cache.retcode = ifelse(dep_retcode, NonlinearSafeTerminationReturnCode.Default,
# ReturnCode.Default)

# cache.abstol = _get_tolerance(abstol, T)
# cache.reltol = _get_tolerance(reltol, T)
# cache.nsteps = 0

# if mode isa AbstractSafeNonlinearTerminationMode
# if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
# initial_objective = maximum(abs, du)
# else
# initial_objective = maximum(abs, du) / (maximum(abs, du .+ u) + eps(TT))
# end
# best_value = initial_objective
# else
# initial_objective = nothing
# objectives_trace = nothing
# best_value = __cvt_real(T, Inf)
# end
# cache.best_objective_value = best_value
# cache.initial_objective = initial_objective
# return cache
# end
function SciMLBase.reinit!(cache::NonlinearTerminationModeCache{uType, T, dep_retcode}, du,
u, saved_value_prototype...; abstol = nothing, reltol = nothing,
kwargs...) where {uType, T, dep_retcode}
length(saved_value_prototype) != 0 && (cache.saved_values = saved_value_prototype)

u_ = cache.mode isa AbstractSafeBestNonlinearTerminationMode ?
(ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing
cache.u = u_
cache.retcode = ifelse(dep_retcode, NonlinearSafeTerminationReturnCode.Default,
ReturnCode.Default)

cache.abstol = _get_tolerance(abstol, T)
cache.reltol = _get_tolerance(reltol, T)
cache.nsteps = 0

if mode isa AbstractSafeNonlinearTerminationMode
if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
initial_objective = maximum(abs, du)
else
initial_objective = maximum(abs, du) / (maximum(abs, du .+ u) + eps(TT))
cache.max_stalled_steps !== nothing && (cache.u0_norm = norm(u_, 2))
end
best_value = initial_objective
else
initial_objective = nothing
objectives_trace = nothing
best_value = __cvt_real(T, Inf)
end
cache.best_objective_value = best_value
cache.initial_objective = initial_objective
return cache
end

# This dispatch is needed based on how Terminating Callback works!
# This intentially drops the `abstol` and `reltol` arguments
Expand Down Expand Up @@ -399,7 +400,7 @@ function (cache::NonlinearTerminationModeCache{uType, TT, dep_retcode})(mode::Ab

# Test for stalling if that is not disabled
if cache.step_norm_trace !== nothing
if ArrayInterface.can_setindex(cache.u_diff_cache)
if ArrayInterface.can_setindex(cache.u_diff_cache) && !(u isa Number)
@. cache.u_diff_cache = u - uprev
else
cache.u_diff_cache = u .- uprev
Expand Down

0 comments on commit 1e2db4c

Please sign in to comment.