diff --git a/src/raphson.jl b/src/raphson.jl index a1fa1a1c5..ad64cab41 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -29,24 +29,28 @@ for large-scale and numerically-difficult nonlinear systems. which means that no line search is performed. Algorithms from `LineSearches.jl` can be used here directly, and they will be converted to the correct `LineSearch`. """ -@concrete struct NewtonRaphson{CJ, AD} <: AbstractNewtonAlgorithm{CJ, AD} +@concrete struct NewtonRaphson{CJ, AD, TC <: NLSolveTerminationCondition} <: AbstractNewtonAlgorithm{CJ, AD} ad::AD linsolve precs linesearch + termination_condition::TC end function NewtonRaphson(; concrete_jac = nothing, linsolve = nothing, - linesearch = LineSearch(), precs = DEFAULT_PRECS, adkwargs...) + linesearch = LineSearch(), precs = DEFAULT_PRECS, termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault; + abstol = nothing, + reltol = nothing), adkwargs...) ad = default_adargs_to_adtype(; adkwargs...) linesearch = linesearch isa LineSearch ? linesearch : LineSearch(; method = linesearch) - return NewtonRaphson{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch) + return NewtonRaphson{_unwrap_val(concrete_jac)}(ad, linsolve, precs, linesearch, termination_condition) end @concrete mutable struct NewtonRaphsonCache{iip} <: AbstractNonlinearSolveCache{iip} f alg u + uprev fu1 fu2 du @@ -60,9 +64,11 @@ end internalnorm retcode::ReturnCode.T abstol + reltol prob stats::NLStats lscache + tc_storage end function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::NewtonRaphson, args...; @@ -74,15 +80,28 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::NewtonRaphson uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip); linsolve_kwargs) - return NewtonRaphsonCache{iip}(f, alg, u, fu1, fu2, du, p, uf, linsolve, J, - jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, prob, - NLStats(1, 0, 0, 0, 0), LineSearchCache(alg.linesearch, f, u, p, fu1, Val(iip))) + + tc = alg.termination_condition + mode = DiffEqBase.get_termination_mode(tc) + + atol = _get_tolerance(abstol, tc.abstol, eltype(u)) + rtol = _get_tolerance(reltol, tc.reltol, eltype(u)) + + storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() : + nothing + + return NewtonRaphsonCache{iip}(f, alg, u, copy(u), fu1, fu2, du, p, uf, linsolve, J, + jac_cache, false, maxiters, internalnorm, ReturnCode.Default, atol, rtol, prob, + NLStats(1, 0, 0, 0, 0), LineSearchCache(alg.linesearch, f, u, p, fu1, Val(iip)), storage) end function perform_step!(cache::NewtonRaphsonCache{true}) - @unpack u, fu1, f, p, alg, J, linsolve, du = cache + @unpack u, uprev, fu1, f, p, alg, J, linsolve, du = cache jacobian!!(J, cache) + tc_storage = cache.tc_storage + termination_condition = cache.alg.termination_condition(tc_storage) + # u = u - J \ fu linres = dolinsolve(alg.precs, linsolve; A = J, b = _vec(fu1), linu = _vec(du), p, reltol = cache.abstol) @@ -93,7 +112,9 @@ function perform_step!(cache::NewtonRaphsonCache{true}) @. u = u - α * du f(cache.fu1, u, p) - cache.internalnorm(fu1) < cache.abstol && (cache.force_stop = true) + termination_condition(cache.fu1, u, uprev, cache.abstol, cache.reltol) && (cache.force_stop = true) + + @. uprev = u cache.stats.nf += 1 cache.stats.njacs += 1 cache.stats.nsolve += 1 @@ -102,7 +123,11 @@ function perform_step!(cache::NewtonRaphsonCache{true}) end function perform_step!(cache::NewtonRaphsonCache{false}) - @unpack u, fu1, f, p, alg, linsolve = cache + @unpack u, uprev, fu1, f, p, alg, linsolve = cache + + tc_storage = cache.tc_storage + termination_condition = cache.alg.termination_condition(tc_storage) + cache.J = jacobian!!(cache.J, cache) # u = u - J \ fu @@ -119,7 +144,9 @@ function perform_step!(cache::NewtonRaphsonCache{false}) cache.u = @. u - α * cache.du # `u` might not support mutation cache.fu1 = f(cache.u, p) - cache.internalnorm(fu1) < cache.abstol && (cache.force_stop = true) + termination_condition(cache.fu1, cache.u, uprev, cache.abstol, cache.reltol) && (cache.force_stop = true) + + cache.uprev = cache.u cache.stats.nf += 1 cache.stats.njacs += 1 cache.stats.nsolve += 1 diff --git a/src/utils.jl b/src/utils.jl index bd6a5a036..e7fe631e0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -163,3 +163,8 @@ function evaluate_f(f, u, p, ::Val{iip}; fu = nothing) where {iip} return f(u, p) end end + +function _get_tolerance(η, tc_η, ::Type{T}) where {T} + @show fallback_η + return ifelse(η !== nothing, η, ifelse(tc_η !== nothing, tc_η, fallback_η)) +end