Skip to content

Commit

Permalink
Move other algorithms to use termination conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
utkarsh530 committed Oct 26, 2023
1 parent 416e656 commit f20c9bc
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 28 deletions.
53 changes: 44 additions & 9 deletions src/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ end
f
alg
u
u_prev
du
fu
fu2
Expand All @@ -46,17 +47,21 @@ end
internalnorm
retcode::ReturnCode.T
abstol
reltol
reset_tolerance
reset_check
prob
stats::NLStats
lscache
termination_condition
tc_storage
end

get_fu(cache::GeneralBroydenCache) = cache.fu

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyden, args...;
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
termination_condition = nothing, internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
Expand All @@ -65,23 +70,38 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::GeneralBroyde
reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(eltype(u))) :
alg.reset_tolerance
reset_check = x -> abs(x) reset_tolerance
return GeneralBroydenCache{iip}(f, alg, u, _mutable_zero(u), fu, zero(fu),

abstol, reltol, termination_condition = _init_termination_elements(abstol,
reltol,
termination_condition,
eltype(u))

mode = DiffEqBase.get_termination_mode(termination_condition)

storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
nothing
return GeneralBroydenCache{iip}(f, alg, u, zero(u), _mutable_zero(u), fu, zero(fu),
zero(fu), p, J⁻¹, zero(_reshape(fu, 1, :)), _mutable_zero(u), false, 0,
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reset_tolerance,
alg.max_resets, maxiters, internalnorm, ReturnCode.Default, abstol, reltol,
reset_tolerance,
reset_check, prob, NLStats(1, 0, 0, 0, 0),
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)))
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), termination_condition,
storage)
end

function perform_step!(cache::GeneralBroydenCache{true})
@unpack f, p, du, fu, fu2, dfu, u, J⁻¹, J⁻¹df, J⁻¹₂ = cache
@unpack f, p, du, fu, fu2, dfu, u, u_prev, J⁻¹, J⁻¹df, J⁻¹₂, tc_storage = cache

termination_condition = cache.termination_condition(tc_storage)
T = eltype(u)

mul!(_vec(du), J⁻¹, -_vec(fu))
α = perform_linesearch!(cache.lscache, u, du)
_axpy!(α, du, u)
f(fu2, u, p)

cache.internalnorm(fu2) < cache.abstol && (cache.force_stop = true)
termination_condition(fu2, u, u_prev, cache.abstol, cache.reltol) &&
(cache.force_stop = true)
cache.stats.nf += 1

cache.force_stop && return nothing
Expand All @@ -106,20 +126,25 @@ function perform_step!(cache::GeneralBroydenCache{true})
mul!(J⁻¹, _vec(du), J⁻¹₂, 1, 1)
end
fu .= fu2
@. u_prev = u

return nothing
end

function perform_step!(cache::GeneralBroydenCache{false})
@unpack f, p = cache
@unpack f, p, tc_storage = cache

termination_condition = cache.termination_condition(tc_storage)

T = eltype(cache.u)

cache.du = _restructure(cache.du, cache.J⁻¹ * -_vec(cache.fu))
α = perform_linesearch!(cache.lscache, cache.u, cache.du)
cache.u = cache.u .+ α * cache.du
cache.fu2 = f(cache.u, p)

cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)
termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) &&
(cache.force_stop = true)
cache.stats.nf += 1

cache.force_stop && return nothing
Expand All @@ -142,12 +167,15 @@ function perform_step!(cache::GeneralBroydenCache{false})
cache.J⁻¹ = cache.J⁻¹ .+ _vec(cache.du) * cache.J⁻¹₂
end
cache.fu = cache.fu2
cache.u_prev = @. cache.u

return nothing
end

function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = cache.p,
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
abstol = cache.abstol, reltol = cache.reltol,
termination_condition = cache.termination_condition,
maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
Expand All @@ -157,7 +185,14 @@ function SciMLBase.reinit!(cache::GeneralBroydenCache{iip}, u0 = cache.u; p = ca
cache.u = u0
cache.fu = cache.f(cache.u, p)
end
termination_condition = _get_reinit_termination_condition(cache,
abstol,
reltol,
termination_condition)

cache.abstol = abstol
cache.reltol = reltol
cache.termination_condition = termination_condition
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
Expand Down
51 changes: 42 additions & 9 deletions src/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ end
f
alg
u
u_prev
fu
fu2
du
Expand All @@ -65,7 +66,8 @@ end
get_fu(cache::GeneralKlementCache) = cache.fu

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKlement, args...;
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
termination_condition = nothing, internalnorm = DEFAULT_NORM,
linsolve_kwargs = (;), kwargs...) where {uType, iip}
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
Expand All @@ -84,16 +86,30 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::GeneralKleme
linsolve = __setup_linsolve(J, _vec(fu), _vec(du), p, alg)
end

return GeneralKlementCache{iip}(f, alg, u, fu, zero(fu), du, p, linsolve,
abstol, reltol, termination_condition = _init_termination_elements(abstol,
reltol,
termination_condition,
eltype(u))

mode = DiffEqBase.get_termination_mode(termination_condition)

storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
nothing

return GeneralKlementCache{iip}(f, alg, u, zero(u), fu, zero(fu), du, p, linsolve,
J, zero(J), zero(J), _vec(zero(fu)), _vec(zero(fu)), 0, false,
maxiters, internalnorm, ReturnCode.Default, abstol, prob, NLStats(1, 0, 0, 0, 0),
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)))
maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob,
NLStats(1, 0, 0, 0, 0),
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), termination_condition,
storage)
end

function perform_step!(cache::GeneralKlementCache{true})
@unpack u, fu, f, p, alg, J, linsolve, du = cache
@unpack u, u_prev, fu, f, p, alg, J, linsolve, du, tc_storage = cache
T = eltype(J)

termination_condition = cache.termination_condition(tc_storage)

singular, fact_done = _try_factorize_and_check_singular!(linsolve, J)

if singular
Expand All @@ -118,7 +134,8 @@ function perform_step!(cache::GeneralKlementCache{true})
_axpy!(α, du, u)
f(cache.fu2, u, p)

cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)
termination_condition(cache.fu2, u, u_prev, cache.abstol, cache.reltol) &&
(cache.force_stop = true)
cache.stats.nf += 1
cache.stats.nsolve += 1
cache.stats.nfactors += 1
Expand All @@ -138,13 +155,17 @@ function perform_step!(cache::GeneralKlementCache{true})
mul!(cache.J_cache2, cache.J_cache, J)
J .+= cache.J_cache2

@. u_prev = u
cache.fu .= cache.fu2

return nothing
end

function perform_step!(cache::GeneralKlementCache{false})
@unpack fu, f, p, alg, J, linsolve = cache
@unpack fu, f, p, alg, J, linsolve, tc_storage = cache

termination_condition = cache.termination_condition(tc_storage)

T = eltype(J)

singular, fact_done = _try_factorize_and_check_singular!(linsolve, J)
Expand Down Expand Up @@ -174,7 +195,10 @@ function perform_step!(cache::GeneralKlementCache{false})
cache.u = @. cache.u + α * cache.du # `u` might not support mutation
cache.fu2 = f(cache.u, p)

cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)
termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) &&
(cache.force_stop = true)

cache.u_prev = @. cache.u
cache.stats.nf += 1
cache.stats.nsolve += 1
cache.stats.nfactors += 1
Expand All @@ -198,7 +222,9 @@ function perform_step!(cache::GeneralKlementCache{false})
end

function SciMLBase.reinit!(cache::GeneralKlementCache{iip}, u0 = cache.u; p = cache.p,
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
abstol = cache.abstol, reltol = cache.reltol,
termination_condition = cache.termination_condition,
maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.u, u0)
Expand All @@ -208,7 +234,14 @@ function SciMLBase.reinit!(cache::GeneralKlementCache{iip}, u0 = cache.u; p = ca
cache.u = u0
cache.fu = cache.f(cache.u, p)
end

termination_condition = _get_reinit_termination_condition(cache,
abstol,
reltol,
termination_condition)
cache.abstol = abstol
cache.reltol = reltol
cache.termination_condition = termination_condition
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
Expand Down
42 changes: 34 additions & 8 deletions src/lbroyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ end
f
alg
u
u_prev
du
fu
fu2
Expand All @@ -53,17 +54,21 @@ end
internalnorm
retcode::ReturnCode.T
abstol
reltol
reset_tolerance
reset_check
prob
stats::NLStats
lscache
termination_condition
tc_storage
end

get_fu(cache::LimitedMemoryBroydenCache) = cache.fu

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LimitedMemoryBroyden,
args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
termination_condition = nothing, internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
Expand All @@ -80,23 +85,38 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LimitedMemory
reset_tolerance = alg.reset_tolerance === nothing ? sqrt(eps(eltype(u))) :
alg.reset_tolerance
reset_check = x -> abs(x) reset_tolerance
return LimitedMemoryBroydenCache{iip}(f, alg, u, du, fu, zero(fu),

abstol, reltol, termination_condition = _init_termination_elements(abstol,
reltol,
termination_condition,
eltype(u))

mode = DiffEqBase.get_termination_mode(termination_condition)

storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
nothing

return LimitedMemoryBroydenCache{iip}(f, alg, u, zero(u), du, fu, zero(fu),
zero(fu), p, U, Vᵀ, similar(u, threshold), similar(u, 1, threshold),
zero(u), zero(u), false, 0, 0, alg.max_resets, maxiters, internalnorm,
ReturnCode.Default, abstol, reset_tolerance, reset_check, prob,
ReturnCode.Default, abstol, reltol, reset_tolerance, reset_check, prob,
NLStats(1, 0, 0, 0, 0),
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)))
init_linesearch_cache(alg.linesearch, f, u, p, fu, Val(iip)), termination_condition,
storage)
end

function perform_step!(cache::LimitedMemoryBroydenCache{true})
@unpack f, p, du, u = cache
@unpack f, p, du, u, tc_storage = cache
T = eltype(u)

termination_condition = cache.termination_condition(tc_storage)

α = perform_linesearch!(cache.lscache, u, du)
_axpy!(α, du, u)
f(cache.fu2, u, p)

cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)
termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) &&
(cache.force_stop = true)
cache.stats.nf += 1

cache.force_stop && return nothing
Expand Down Expand Up @@ -138,20 +158,25 @@ function perform_step!(cache::LimitedMemoryBroydenCache{true})
cache.iterations_since_reset += 1
end

cache.u_prev .= cache.u
cache.fu .= cache.fu2

return nothing
end

function perform_step!(cache::LimitedMemoryBroydenCache{false})
@unpack f, p = cache
@unpack f, p, tc_storage = cache

termination_condition = cache.termination_condition(tc_storage)

T = eltype(cache.u)

α = perform_linesearch!(cache.lscache, cache.u, cache.du)
cache.u = cache.u .+ α * cache.du
cache.fu2 = f(cache.u, p)

cache.internalnorm(cache.fu2) < cache.abstol && (cache.force_stop = true)
termination_condition(cache.fu2, cache.u, cache.u_prev, cache.abstol, cache.reltol) &&
(cache.force_stop = true)
cache.stats.nf += 1

cache.force_stop && return nothing
Expand Down Expand Up @@ -194,6 +219,7 @@ function perform_step!(cache::LimitedMemoryBroydenCache{false})
cache.iterations_since_reset += 1
end

cache.u_prev = @. cache.u
cache.fu = cache.fu2

return nothing
Expand Down
3 changes: 1 addition & 2 deletions src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,8 @@ function perform_step!(cache::NewtonRaphsonCache{true})
end

function perform_step!(cache::NewtonRaphsonCache{false})
@unpack u, u_prev, fu1, f, p, alg, linsolve = cache
@unpack u, u_prev, fu1, f, p, alg, linsolve, tc_storage = cache

tc_storage = cache.tc_storage
termination_condition = cache.termination_condition(tc_storage)

cache.J = jacobian!!(cache.J, cache)
Expand Down

0 comments on commit f20c9bc

Please sign in to comment.