Skip to content

Commit

Permalink
add TerminationCondition support to NewtonRaphson (and eventually all…
Browse files Browse the repository at this point in the history
… the other solvers)
  • Loading branch information
oscardssmith committed Apr 17, 2023
1 parent c1267c8 commit 4f668e8
Showing 1 changed file with 46 additions and 25 deletions.
71 changes: 46 additions & 25 deletions src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
```julia
NewtonRaphson(; chunk_size = Val{0}(), autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS)
diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS,
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;)
```
An advanced NewtonRaphson implementation with support for efficient handling of sparse
Expand Down Expand Up @@ -48,26 +49,30 @@ for large-scale and numerically-difficult nonlinear systems.
Currently, the linear solver and chunk size choice only applies to in-place defined
`NonlinearProblem`s. That is expected to change in the future.
"""
struct NewtonRaphson{CS, AD, FDT, L, P, ST, CJ} <:
struct NewtonRaphson{CS, AD, FDT, L, P, ST, CJ, TC<:NLSolveTerminationCondition} <:
AbstractNewtonAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::L
precs::P
termination_condition::TC
end

function NewtonRaphson(; chunk_size = Val{0}(), autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS)
diff_type = Val{:forward}, linsolve = nothing, precs = DEFAULT_PRECS,
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol=nothing, reltol=nothing))
NewtonRaphson{_unwrap_val(chunk_size), _unwrap_val(autodiff), diff_type,
typeof(linsolve), typeof(precs), _unwrap_val(standardtag),
_unwrap_val(concrete_jac)}(linsolve, precs)
_unwrap_val(concrete_jac), typeof(termination_condition)}(linsolve, precs, termination_condition)
end

mutable struct NewtonRaphsonCache{iip, fType, algType, uType, duType, resType, pType,
INType, tolType,
probType, ufType, L, jType, JC}
probType, ufType, L, jType, JC, TC}
f::fType
alg::algType
u::uType
uprev::uType
fu::resType
p::pType
uf::ufType
Expand All @@ -76,29 +81,30 @@ mutable struct NewtonRaphsonCache{iip, fType, algType, uType, duType, resType, p
du1::duType
jac_config::JC
iter::Int
force_stop::Bool
maxiters::Int
internalnorm::INType
retcode::SciMLBase.ReturnCode.T
abstol::tolType
reltol::tolType
termination_condition::TC
prob::probType

function NewtonRaphsonCache{iip}(f::fType, alg::algType, u::uType, fu::resType,
function NewtonRaphsonCache{iip}(f::fType, alg::algType, u::uType, uprev::uType, fu::resType,
p::pType,
uf::ufType, linsolve::L, J::jType, du1::duType,
jac_config::JC, iter::Int,
force_stop::Bool, maxiters::Int, internalnorm::INType,
maxiters::Int, internalnorm::INType,
retcode::SciMLBase.ReturnCode.T, abstol::tolType,
prob::probType) where {
reltol::tolType, termination_condition::TC, prob::probType) where {
iip, fType, algType, uType,
duType, resType, pType, INType,
tolType,
probType, ufType, L, jType, JC}
probType, ufType, L, jType, TC, JC}
new{iip, fType, algType, uType, duType, resType, pType, INType, tolType,
probType, ufType, L, jType, JC}(f, alg, u, fu, p,
probType, ufType, L, jType, JC, TC}(f, alg, u, uprev, fu, p,
uf, linsolve, J, du1, jac_config, iter,
force_stop, maxiters, internalnorm,
retcode, abstol, prob)
maxiters, internalnorm,
retcode, abstol, reltol, termination_condition, prob)
end
end

Expand Down Expand Up @@ -131,9 +137,23 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::NewtonRaphson
args...;
alias_u0 = false,
maxiters = 1000,
abstol = 1e-6,
abstol = nothing,
reltol = nothing,
internalnorm = DEFAULT_NORM,
kwargs...) where {uType, iip}
tc = alg.termination_condition
ueltype = eltype(uType)
abstol = !isnothing(abstol) ? abstol :
(!isnothing(tc.abstol) ? tc.abstol :
real(oneunit(ueltype)) * (eps(real(one(ueltype))))^(4 // 5))

reltol = !isnothing(reltol) ? reltol :
(!isnothing(tc.reltol) ? tc.reltol : eps(real(one(ueltype)))^(4 // 5))

mode = DiffEqBase.get_termination_mode(tc)
storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? Dict() : nothing
termination_condition = tc(storage)

if alias_u0
u = prob.u0
else
Expand All @@ -144,31 +164,31 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::NewtonRaphson
if iip
fu = zero(u)
f(fu, u, p)
uprev = u
else
fu = f(u, p)
uprev = deepcopy(u)
end
uf, linsolve, J, du1, jac_config = jacobian_caches(alg, f, u, p, Val(iip))

return NewtonRaphsonCache{iip}(f, alg, u, fu, p, uf, linsolve, J, du1, jac_config,
1, false, maxiters, internalnorm,
ReturnCode.Default, abstol, prob)
return NewtonRaphsonCache{iip}(f, alg, u, uprev, fu, p, uf, linsolve, J, du1, jac_config,
1, maxiters, internalnorm,
ReturnCode.Default, abstol, reltol,
termination_condition, prob)
end

function perform_step!(cache::NewtonRaphsonCache{true})
@unpack u, fu, f, p, alg = cache
@unpack J, linsolve, du1 = cache
jacobian!(J, cache)
cache.uprev .= u

# u = u - J \ fu
linres = dolinsolve(alg.precs, linsolve, A = J, b = _vec(fu), linu = _vec(du1),
p = p, reltol = cache.abstol)
cache.linsolve = linres.cache
@. u = u - du1
f(fu, u, p)

if cache.internalnorm(cache.fu) < cache.abstol
cache.force_stop = true
end
return nothing
end

Expand All @@ -177,16 +197,17 @@ function perform_step!(cache::NewtonRaphsonCache{false})
J = jacobian(cache, f)
cache.u = u - J \ fu
cache.fu = f(cache.u, p)
if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol
cache.force_stop = true
end
cache.uprev = u
return nothing
end

function SciMLBase.solve!(cache::NewtonRaphsonCache)
while !cache.force_stop && cache.iter < cache.maxiters
while cache.iter < cache.maxiters
perform_step!(cache)
cache.iter += 1
if cache.termination_condition(cache.fu, cache.u, cache.uprev, cache.abstol, cache.reltol)
break
end
end

if cache.iter == cache.maxiters
Expand Down

0 comments on commit 4f668e8

Please sign in to comment.