Skip to content

Commit

Permalink
Update all algorithms to use termination condition
Browse files Browse the repository at this point in the history
  • Loading branch information
utkarsh530 committed Oct 19, 2023
1 parent 37b7015 commit efb8c5e
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 40 deletions.
2 changes: 1 addition & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode}

abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
abstract type AbstractNewtonAlgorithm{CJ, AD} <: AbstractNonlinearSolveAlgorithm end
abstract type AbstractNewtonAlgorithm{CJ, AD, TC} <: AbstractNonlinearSolveAlgorithm end

abstract type AbstractNonlinearSolveCache{iip} end

Expand Down
58 changes: 47 additions & 11 deletions src/levenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ numerically-difficult nonlinear systems.
[this paper](https://arxiv.org/abs/1201.5885) to use a minimum value of the elements in
`DᵀD` to prevent the damping from being too small. Defaults to `1e-8`.
"""
@concrete struct LevenbergMarquardt{CJ, AD, T} <: AbstractNewtonAlgorithm{CJ, AD}
@concrete struct LevenbergMarquardt{CJ, AD, T, TC <: NLSolveTerminationCondition} <:
AbstractNewtonAlgorithm{CJ, AD, TC}
ad::AD
linsolve
precs
Expand All @@ -85,6 +86,7 @@ numerically-difficult nonlinear systems.
α_geodesic::T
b_uphill::T
min_damping_D::T
termination_condition::TC
end

function set_ad(alg::LevenbergMarquardt{CJ}, ad) where {CJ}
Expand All @@ -97,17 +99,22 @@ function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
precs = DEFAULT_PRECS, damping_initial::Real = 1.0, damping_increase_factor::Real = 2.0,
damping_decrease_factor::Real = 3.0, finite_diff_step_geodesic::Real = 0.1,
α_geodesic::Real = 0.75, b_uphill::Real = 1.0, min_damping_D::AbstractFloat = 1e-8,
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing,
reltol = nothing),
adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
return LevenbergMarquardt{_unwrap_val(concrete_jac)}(ad, linsolve, precs,
damping_initial, damping_increase_factor, damping_decrease_factor,
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D)
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D,
termination_condition)
end

@concrete mutable struct LevenbergMarquardtCache{iip} <: AbstractNonlinearSolveCache{iip}
f
alg
u
u_prev
fu1
fu2
du
Expand All @@ -121,6 +128,7 @@ end
internalnorm
retcode::ReturnCode.T
abstol
reltol
prob
DᵀD
JᵀJ
Expand All @@ -145,11 +153,13 @@ end
Jv
mat_tmp
stats::NLStats
tc_storage
end

function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
NonlinearLeastSquaresProblem{uType, iip}}, alg_::LevenbergMarquardt,
args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
NonlinearLeastSquaresProblem{uType, iip}}, alg_::LevenbergMarquardt,
args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
internalnorm = DEFAULT_NORM,
linsolve_kwargs = (;), kwargs...) where {uType, iip}
alg = get_concrete_algorithm(alg_, prob)
@unpack f, u0, p = prob
Expand Down Expand Up @@ -184,15 +194,30 @@ function SciMLBase.__init(prob::Union{NonlinearProblem{uType, iip},
fu_tmp = zero(fu1)
mat_tmp = zero(JᵀJ)

return LevenbergMarquardtCache{iip}(f, alg, u, fu1, fu2, du, p, uf, linsolve, J,
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob, DᵀD,
tc = alg.termination_condition
mode = DiffEqBase.get_termination_mode(tc)

atol = _get_tolerance(abstol, tc.abstol, eltype(u))
rtol = _get_tolerance(reltol, tc.reltol, eltype(u))

storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
nothing

return LevenbergMarquardtCache{iip}(f, alg, u, copy(u), fu1, fu2, du, p, uf, linsolve,
J,
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, atol, rtol, prob, DᵀD,
JᵀJ, λ, λ_factor, damping_increase_factor, damping_decrease_factor, h, α_geodesic,
b_uphill, min_damping_D, v, a, tmp_vec, v_old, loss, δ, loss, make_new_J, fu_tmp,
zero(u), zero(fu1), mat_tmp, NLStats(1, 0, 0, 0, 0))
zero(u), zero(fu1), mat_tmp, NLStats(1, 0, 0, 0, 0), storage)

end

function perform_step!(cache::LevenbergMarquardtCache{true})
@unpack fu1, f, make_new_J = cache

tc_storage = cache.tc_storage
termination_condition = cache.alg.termination_condition(tc_storage)

if iszero(fu1)
cache.force_stop = true
return nothing
Expand All @@ -205,7 +230,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
cache.make_new_J = false
cache.stats.njacs += 1
end
@unpack u, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache
@unpack u, u_prev, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache

# Usual Levenberg-Marquardt step ("velocity").
# The following lines do: cache.v = -cache.mat_tmp \ cache.u_tmp
Expand Down Expand Up @@ -246,7 +271,11 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
if (1 - β)^b_uphill * loss loss_old
# Accept step.
cache.u .+= δ
if loss < cache.abstol
if termination_condition(cache.fu_tmp,
cache.u,
u_prev,
cache.abstol,
cache.reltol)
cache.force_stop = true
return nothing
end
Expand All @@ -258,13 +287,18 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
cache.make_new_J = true
end
end
@. u_prev = u
cache.λ *= cache.λ_factor
cache.λ_factor = cache.damping_increase_factor
return nothing
end

function perform_step!(cache::LevenbergMarquardtCache{false})
@unpack fu1, f, make_new_J = cache

tc_storage = cache.tc_storage
termination_condition = cache.alg.termination_condition(tc_storage)

if iszero(fu1)
cache.force_stop = true
return nothing
Expand All @@ -281,7 +315,8 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
cache.make_new_J = false
cache.stats.njacs += 1
end
@unpack u, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache

@unpack u, u_prev, p, λ, JᵀJ, DᵀD, J, linsolve, alg = cache

cache.mat_tmp = JᵀJ + λ * DᵀD
# Usual Levenberg-Marquardt step ("velocity").
Expand Down Expand Up @@ -322,7 +357,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
if (1 - β)^b_uphill * loss loss_old
# Accept step.
cache.u += δ
if loss < cache.abstol
if termination_condition(fu_new, cache.u, u_prev, cache.abstol, cache.reltol)
cache.force_stop = true
return nothing
end
Expand All @@ -334,6 +369,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
cache.make_new_J = true
end
end
cache.u_prev = @. cache.u
cache.λ *= cache.λ_factor
cache.λ_factor = cache.damping_increase_factor
return nothing
Expand Down
40 changes: 24 additions & 16 deletions src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ for large-scale and numerically-difficult nonlinear systems.
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
used here directly, and they will be converted to the correct `LineSearch`.
"""
@concrete struct NewtonRaphson{CJ, AD, TC <: NLSolveTerminationCondition} <: AbstractNewtonAlgorithm{CJ, AD}
@concrete struct NewtonRaphson{CJ, AD, TC <: NLSolveTerminationCondition} <:
AbstractNewtonAlgorithm{CJ, AD, TC}
ad::AD
linsolve
precs
Expand All @@ -43,19 +44,24 @@ function set_ad(alg::NewtonRaphson{CJ}, ad) where {CJ}
end

function NewtonRaphson(; concrete_jac = nothing, linsolve = nothing,
linesearch = LineSearch(), precs = DEFAULT_PRECS, termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing,
reltol = nothing), adkwargs...)
linesearch = LineSearch(), precs = DEFAULT_PRECS,
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing,
reltol = nothing), adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
return NewtonRaphson{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch, termination_condition)
return NewtonRaphson{_unwrap_val(concrete_jac)}(ad,
linsolve,
precs,
linesearch,
termination_condition)
end

@concrete mutable struct NewtonRaphsonCache{iip} <: AbstractNonlinearSolveCache{iip}
f
alg
u
uprev
u_prev
fu1
fu2
du
Expand All @@ -77,7 +83,8 @@ end
end

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphson, args...;
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
internalnorm = DEFAULT_NORM,
linsolve_kwargs = (;), kwargs...) where {uType, iip}
alg = get_concrete_algorithm(alg_, prob)
@unpack f, u0, p = prob
Expand All @@ -86,7 +93,6 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphso
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
linsolve_kwargs)


tc = alg.termination_condition
mode = DiffEqBase.get_termination_mode(tc)

Expand All @@ -98,11 +104,12 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphso

return NewtonRaphsonCache{iip}(f, alg, u, copy(u), fu1, fu2, du, p, uf, linsolve, J,
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, atol, rtol, prob,
NLStats(1, 0, 0, 0, 0), LineSearchCache(alg.linesearch, f, u, p, fu1, Val(iip)), storage)
NLStats(1, 0, 0, 0, 0), LineSearchCache(alg.linesearch, f, u, p, fu1, Val(iip)),
storage)
end

function perform_step!(cache::NewtonRaphsonCache{true})
@unpack u, uprev, fu1, f, p, alg, J, linsolve, du = cache
@unpack u, u_prev, fu1, f, p, alg, J, linsolve, du = cache
jacobian!!(J, cache)

tc_storage = cache.tc_storage
Expand All @@ -118,9 +125,10 @@ function perform_step!(cache::NewtonRaphsonCache{true})
@. u = u - α * du
f(cache.fu1, u, p)

termination_condition(cache.fu1, u, uprev, cache.abstol, cache.reltol) && (cache.force_stop = true)
termination_condition(cache.fu1, u, u_prev, cache.abstol, cache.reltol) &&
(cache.force_stop = true)

@. uprev = u
@. u_prev = u
cache.stats.nf += 1
cache.stats.njacs += 1
cache.stats.nsolve += 1
Expand All @@ -129,12 +137,11 @@ function perform_step!(cache::NewtonRaphsonCache{true})
end

function perform_step!(cache::NewtonRaphsonCache{false})
@unpack u, uprev, fu1, f, p, alg, linsolve = cache
@unpack u, u_prev, fu1, f, p, alg, linsolve = cache

tc_storage = cache.tc_storage
termination_condition = cache.alg.termination_condition(tc_storage)


cache.J = jacobian!!(cache.J, cache)
# u = u - J \ fu
if linsolve === nothing
Expand All @@ -150,9 +157,10 @@ function perform_step!(cache::NewtonRaphsonCache{false})
cache.u = @. u - α * cache.du # `u` might not support mutation
cache.fu1 = f(cache.u, p)

termination_condition(cache.fu1, cache.u, uprev, cache.abstol, cache.reltol) && (cache.force_stop = true)
termination_condition(cache.fu1, cache.u, u_prev, cache.abstol, cache.reltol) &&
(cache.force_stop = true)

cache.uprev = cache.u
cache.u_prev = @. cache.u
cache.stats.nf += 1
cache.stats.njacs += 1
cache.stats.nsolve += 1
Expand Down
Loading

0 comments on commit efb8c5e

Please sign in to comment.