Skip to content

Commit

Permalink
Update all algorithms to use termination condition
Browse files Browse the repository at this point in the history
  • Loading branch information
utkarsh530 committed Oct 9, 2023
1 parent 34fefaf commit d3e99fd
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 48 deletions.
2 changes: 1 addition & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode}

abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
abstract type AbstractNewtonAlgorithm{CJ, AD} <: AbstractNonlinearSolveAlgorithm end
abstract type AbstractNewtonAlgorithm{CJ, AD, TC} <: AbstractNonlinearSolveAlgorithm end

function SciMLBase.__solve(prob::NonlinearProblem, alg::AbstractNonlinearSolveAlgorithm,
args...; kwargs...)
Expand Down
54 changes: 44 additions & 10 deletions src/levenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ numerically-difficult nonlinear systems.
[this paper](https://arxiv.org/abs/1201.5885) to use a minimum value of the elements in
`DᵀD` to prevent the damping from being too small. Defaults to `1e-8`.
"""
@concrete struct LevenbergMarquardt{CJ, AD, T} <: AbstractNewtonAlgorithm{CJ, AD}
@concrete struct LevenbergMarquardt{CJ, AD, T, TC <: NLSolveTerminationCondition} <:
AbstractNewtonAlgorithm{CJ, AD, TC}
ad::AD
linsolve
precs
Expand All @@ -84,23 +85,29 @@ numerically-difficult nonlinear systems.
α_geodesic::T
b_uphill::T
min_damping_D::T
termination_condition::TC
end

function LevenbergMarquardt(; concrete_jac = nothing, linsolve = nothing,
precs = DEFAULT_PRECS, damping_initial::Real = 1.0, damping_increase_factor::Real = 2.0,
damping_decrease_factor::Real = 3.0, finite_diff_step_geodesic::Real = 0.1,
α_geodesic::Real = 0.75, b_uphill::Real = 1.0, min_damping_D::AbstractFloat = 1e-8,
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing,
reltol = nothing),
adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
return LevenbergMarquardt{_unwrap_val(concrete_jac)}(ad, linsolve, precs,
damping_initial, damping_increase_factor, damping_decrease_factor,
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D)
finite_diff_step_geodesic, α_geodesic, b_uphill, min_damping_D,
termination_condition)
end

@concrete mutable struct LevenbergMarquardtCache{iip, uType, jType, λType, lossType}
f
alg
u::uType
u_prev::uType
fu1
fu2
du
Expand All @@ -114,6 +121,7 @@ end
internalnorm
retcode::ReturnCode.T
abstol
reltol
prob
DᵀD
JᵀJ::jType
Expand All @@ -136,12 +144,14 @@ end
fu_tmp
mat_tmp::jType
stats::NLStats
tc_storage
end

isinplace(::LevenbergMarquardtCache{iip}) where {iip} = iip

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LevenbergMarquardt,
args...; alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
internalnorm = DEFAULT_NORM,
linsolve_kwargs = (;), kwargs...) where {uType, iip}
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
Expand Down Expand Up @@ -177,15 +187,29 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::LevenbergMarq
fu_tmp = zero(fu1)
mat_tmp = zero(J)

return LevenbergMarquardtCache{iip}(f, alg, u, fu1, fu2, du, p, uf, linsolve, J,
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob, DᵀD,
tc = alg.termination_condition
mode = DiffEqBase.get_termination_mode(tc)

atol = _get_tolerance(abstol, tc.abstol, eltype(u))
rtol = _get_tolerance(reltol, tc.reltol, eltype(u))

storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
nothing

return LevenbergMarquardtCache{iip}(f, alg, u, copy(u), fu1, fu2, du, p, uf, linsolve,
J,
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, atol, rtol, prob, DᵀD,
JᵀJ, λ, λ_factor, damping_increase_factor, damping_decrease_factor, h, α_geodesic,
b_uphill, min_damping_D, v, a, tmp_vec, v_old, loss, δ, loss, make_new_J, fu_tmp,
mat_tmp, NLStats(1, 0, 0, 0, 0))
mat_tmp, NLStats(1, 0, 0, 0, 0), storage)
end

function perform_step!(cache::LevenbergMarquardtCache{true})
@unpack fu1, f, make_new_J = cache

tc_storage = cache.tc_storage
termination_condition = cache.alg.termination_condition(tc_storage)

if _iszero(fu1)
cache.force_stop = true
return nothing
Expand All @@ -198,7 +222,7 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
cache.make_new_J = false
cache.stats.njacs += 1
end
@unpack u, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache
@unpack u, u_prev, p, λ, JᵀJ, DᵀD, J, alg, linsolve = cache

# Usual Levenberg-Marquardt step ("velocity").
# The following lines do: cache.v = -cache.mat_tmp \ cache.fu_tmp
Expand Down Expand Up @@ -237,7 +261,11 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
if (1 - β)^b_uphill * loss loss_old
# Accept step.
cache.u .+= δ
if loss < cache.abstol
if termination_condition(cache.fu_tmp,
cache.u,
u_prev,
cache.abstol,
cache.reltol)
cache.force_stop = true
return nothing
end
Expand All @@ -249,13 +277,18 @@ function perform_step!(cache::LevenbergMarquardtCache{true})
cache.make_new_J = true
end
end
@. u_prev = u
cache.λ *= cache.λ_factor
cache.λ_factor = cache.damping_increase_factor
return nothing
end

function perform_step!(cache::LevenbergMarquardtCache{false})
@unpack fu1, f, make_new_J = cache

tc_storage = cache.tc_storage
termination_condition = cache.alg.termination_condition(tc_storage)

if _iszero(fu1)
cache.force_stop = true
return nothing
Expand All @@ -272,7 +305,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
cache.make_new_J = false
cache.stats.njacs += 1
end
@unpack u, p, λ, JᵀJ, DᵀD, J = cache
@unpack u, u_prev, p, λ, JᵀJ, DᵀD, J = cache

cache.mat_tmp = JᵀJ + λ * DᵀD
# Usual Levenberg-Marquardt step ("velocity").
Expand All @@ -298,7 +331,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
if (1 - β)^b_uphill * loss loss_old
# Accept step.
cache.u += δ
if loss < cache.abstol
if termination_condition(fu_new, cache.u, u_prev, cache.abstol, cache.reltol)
cache.force_stop = true
return nothing
end
Expand All @@ -310,6 +343,7 @@ function perform_step!(cache::LevenbergMarquardtCache{false})
cache.make_new_J = true
end
end
cache.u_prev = @. cache.u
cache.λ *= cache.λ_factor
cache.λ_factor = cache.damping_increase_factor
return nothing
Expand Down
40 changes: 24 additions & 16 deletions src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ for large-scale and numerically-difficult nonlinear systems.
which means that no line search is performed. Algorithms from `LineSearches.jl` can be
used here directly, and they will be converted to the correct `LineSearch`.
"""
@concrete struct NewtonRaphson{CJ, AD, TC <: NLSolveTerminationCondition} <: AbstractNewtonAlgorithm{CJ, AD}
@concrete struct NewtonRaphson{CJ, AD, TC <: NLSolveTerminationCondition} <:
AbstractNewtonAlgorithm{CJ, AD, TC}
ad::AD
linsolve
precs
Expand All @@ -40,19 +41,24 @@ end
concrete_jac(::NewtonRaphson{CJ}) where {CJ} = CJ

function NewtonRaphson(; concrete_jac = nothing, linsolve = nothing,
linesearch = LineSearch(), precs = DEFAULT_PRECS, termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing,
reltol = nothing), adkwargs...)
linesearch = LineSearch(), precs = DEFAULT_PRECS,
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing,
reltol = nothing), adkwargs...)
ad = default_adargs_to_adtype(; adkwargs...)
linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch)
return NewtonRaphson{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch, termination_condition)
return NewtonRaphson{_unwrap_val(concrete_jac)}(ad,
linsolve,
precs,
linesearch,
termination_condition)
end

@concrete mutable struct NewtonRaphsonCache{iip}
f
alg
u
uprev
u_prev
fu1
fu2
du
Expand All @@ -76,15 +82,15 @@ end
isinplace(::NewtonRaphsonCache{iip}) where {iip} = iip

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::NewtonRaphson, args...;
alias_u0 = false, maxiters = 1000, abstol = 1e-6, internalnorm = DEFAULT_NORM,
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
internalnorm = DEFAULT_NORM,
linsolve_kwargs = (;), kwargs...) where {uType, iip}
@unpack f, u0, p = prob
u = alias_u0 ? u0 : deepcopy(u0)
fu1 = evaluate_f(prob, u)
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
linsolve_kwargs)


tc = alg.termination_condition
mode = DiffEqBase.get_termination_mode(tc)

Expand All @@ -96,11 +102,12 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::NewtonRaphson

return NewtonRaphsonCache{iip}(f, alg, u, copy(u), fu1, fu2, du, p, uf, linsolve, J,
jac_cache, false, maxiters, internalnorm, ReturnCode.Default, atol, rtol, prob,
NLStats(1, 0, 0, 0, 0), LineSearchCache(alg.linesearch, f, u, p, fu1, Val(iip)), storage)
NLStats(1, 0, 0, 0, 0), LineSearchCache(alg.linesearch, f, u, p, fu1, Val(iip)),
storage)
end

function perform_step!(cache::NewtonRaphsonCache{true})
@unpack u, uprev, fu1, f, p, alg, J, linsolve, du = cache
@unpack u, u_prev, fu1, f, p, alg, J, linsolve, du = cache
jacobian!!(J, cache)

tc_storage = cache.tc_storage
Expand All @@ -116,9 +123,10 @@ function perform_step!(cache::NewtonRaphsonCache{true})
@. u = u - α * du
f(cache.fu1, u, p)

termination_condition(cache.fu1, u, uprev, cache.abstol, cache.reltol) && (cache.force_stop = true)
termination_condition(cache.fu1, u, u_prev, cache.abstol, cache.reltol) &&
(cache.force_stop = true)

@. uprev = u
@. u_prev = u
cache.stats.nf += 1
cache.stats.njacs += 1
cache.stats.nsolve += 1
Expand All @@ -127,12 +135,11 @@ function perform_step!(cache::NewtonRaphsonCache{true})
end

function perform_step!(cache::NewtonRaphsonCache{false})
@unpack u, uprev, fu1, f, p, alg, linsolve = cache
@unpack u, u_prev, fu1, f, p, alg, linsolve = cache

tc_storage = cache.tc_storage
termination_condition = cache.alg.termination_condition(tc_storage)


cache.J = jacobian!!(cache.J, cache)
# u = u - J \ fu
if linsolve === nothing
Expand All @@ -148,9 +155,10 @@ function perform_step!(cache::NewtonRaphsonCache{false})
cache.u = @. u - α * cache.du # `u` might not support mutation
cache.fu1 = f(cache.u, p)

termination_condition(cache.fu1, cache.u, uprev, cache.abstol, cache.reltol) && (cache.force_stop = true)
termination_condition(cache.fu1, cache.u, u_prev, cache.abstol, cache.reltol) &&
(cache.force_stop = true)

cache.uprev = cache.u
cache.u_prev = @. cache.u
cache.stats.nf += 1
cache.stats.njacs += 1
cache.stats.nsolve += 1
Expand Down
Loading

0 comments on commit d3e99fd

Please sign in to comment.