diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index 25fe50023..1ba7a0cc3 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -8,16 +8,17 @@ using LinearAlgebra: norm using Markdown: @doc_str using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinearProblem, - AbstractNonlinearFunction, @add_kwonly, StandardNonlinearProblem, - NullParameters, NonlinearProblem, isinplace + NonlinearProblem, NonlinearLeastSquaresProblem, AbstractNonlinearFunction, + @add_kwonly, StandardNonlinearProblem, NullParameters, NonlinearProblem, + isinplace using StaticArraysCore: StaticArray include("public.jl") include("utils.jl") +include("immutable_problem.jl") include("common_defaults.jl") include("termination_conditions.jl") -include("immutable_problem.jl") # Unexported Public API @compat(public, (L2_NORM, Linf_NORM, NAN_CHECK, UNITLESS_ABS2, get_tolerance)) diff --git a/lib/NonlinearSolveBase/src/termination_conditions.jl b/lib/NonlinearSolveBase/src/termination_conditions.jl index 50af54a57..a278861a8 100644 --- a/lib/NonlinearSolveBase/src/termination_conditions.jl +++ b/lib/NonlinearSolveBase/src/termination_conditions.jl @@ -245,3 +245,37 @@ end function check_convergence(mode::AbsNormModes, duₙ, _, __, abstol, ___) return Utils.apply_norm(mode.internalnorm, duₙ) ≤ abstol end + +# High-Level API with defaults. +## This is mostly for internal usage in NonlinearSolve and SimpleNonlinearSolve +function default_termination_mode( + ::Union{ImmutableNonlinearProblem, NonlinearProblem}, ::Val{:simple}) + return AbsNormTerminationMode(Base.Fix1(maximum, abs)) +end +function default_termination_mode(::NonlinearLeastSquaresProblem, ::Val{:simple}) + return AbsNormTerminationMode(Base.Fix2(norm, 2)) +end + +function default_termination_mode( + ::Union{ImmutableNonlinearProblem, NonlinearProblem}, ::Val{:regular}) + return AbsNormSafeBestTerminationMode(Base.Fix1(maximum, abs); max_stalled_steps = 32) +end + +function default_termination_mode(::NonlinearLeastSquaresProblem, ::Val{:regular}) + return AbsNormSafeBestTerminationMode(Base.Fix2(norm, 2); max_stalled_steps = 32) +end + +function init_termination_cache( + prob::AbstractNonlinearProblem, abstol, reltol, du, u, ::Nothing, callee::Val) + return init_termination_cache( + prob, abstol, reltol, du, u, default_termination_mode(prob, callee), callee) +end + +function init_termination_cache(::AbstractNonlinearProblem, abstol, reltol, du, + u, tc::AbstractNonlinearTerminationMode, ::Val) + T = promote_type(eltype(du), eltype(u)) + abstol = get_tolerance(abstol, T) + reltol = get_tolerance(reltol, T) + cache = init(du, u, tc; abstol, reltol) + return abstol, reltol, cache +end