From 9c50e5fe30dad9a4e833e2aa1998156bc36501c6 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 27 Oct 2023 12:23:30 -0400 Subject: [PATCH] Proper handling of complex numbers and failures --- src/broyden.jl | 4 ++-- src/default.jl | 10 +++++----- src/klement.jl | 4 ++-- src/lbroyden.jl | 4 ++-- src/raphson.jl | 7 ++----- src/trustRegion.jl | 35 ++++++++++++++++++++--------------- src/utils.jl | 14 +++++++++++++- 7 files changed, 46 insertions(+), 32 deletions(-) diff --git a/src/broyden.jl b/src/broyden.jl index 5fcbd3d51..ce0d10930 100644 --- a/src/broyden.jl +++ b/src/broyden.jl @@ -111,7 +111,7 @@ function perform_step!(cache::GeneralBroydenCache{true}) if all(cache.reset_check, du) || all(cache.reset_check, dfu) if cache.resets ≥ cache.max_resets - cache.retcode = ReturnCode.Unstable + cache.retcode = ReturnCode.ConvergenceFailure cache.force_stop = true return nothing end @@ -153,7 +153,7 @@ function perform_step!(cache::GeneralBroydenCache{false}) cache.dfu = cache.fu2 .- cache.fu if all(cache.reset_check, cache.du) || all(cache.reset_check, cache.dfu) if cache.resets ≥ cache.max_resets - cache.retcode = ReturnCode.Unstable + cache.retcode = ReturnCode.ConvergenceFailure cache.force_stop = true return nothing end diff --git a/src/default.jl b/src/default.jl index 3b4f8ef23..a163903bb 100644 --- a/src/default.jl +++ b/src/default.jl @@ -128,8 +128,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, @unpack adkwargs, linsolve, precs = alg algs = ( - # Klement(), - # Broyden(), + GeneralKlement(; linsolve, precs), + GeneralBroyden(), NewtonRaphson(; linsolve, precs, adkwargs...), NewtonRaphson(; linsolve, precs, linesearch = BackTracking(), adkwargs...), TrustRegion(; linsolve, precs, adkwargs...), @@ -159,7 +159,7 @@ end ] else [ - :(GeneralKlement()), + :(GeneralKlement(; linsolve, precs)), :(GeneralBroyden()), :(NewtonRaphson(; linsolve, precs, adkwargs...)), :(NewtonRaphson(; linsolve, precs, linesearch = BackTracking(), adkwargs...)), @@ 
-191,7 +191,7 @@ end push!(calls, quote resids = tuple($(Tuple(resids)...)) - minfu, idx = findmin(DEFAULT_NORM, resids) + minfu, idx = __findmin(DEFAULT_NORM, resids) end) for i in 1:length(algs) @@ -249,7 +249,7 @@ end retcode = ReturnCode.MaxIters fus = tuple($(Tuple(resids)...)) - minfu, idx = findmin(cache.caches[1].internalnorm, fus) + minfu, idx = __findmin(cache.caches[1].internalnorm, fus) stats = cache.caches[idx].stats u = cache.caches[idx].u diff --git a/src/klement.jl b/src/klement.jl index e60aeee9b..435bdf52c 100644 --- a/src/klement.jl +++ b/src/klement.jl @@ -118,7 +118,7 @@ function perform_step!(cache::GeneralKlementCache{true}) if singular if cache.resets == alg.max_resets cache.force_stop = true - cache.retcode = ReturnCode.Unstable + cache.retcode = ReturnCode.ConvergenceFailure return nothing end fact_done = false @@ -176,7 +176,7 @@ function perform_step!(cache::GeneralKlementCache{false}) if singular if cache.resets == alg.max_resets cache.force_stop = true - cache.retcode = ReturnCode.Unstable + cache.retcode = ReturnCode.ConvergenceFailure return nothing end fact_done = false diff --git a/src/lbroyden.jl b/src/lbroyden.jl index db4353b41..eecc4b712 100644 --- a/src/lbroyden.jl +++ b/src/lbroyden.jl @@ -128,7 +128,7 @@ function perform_step!(cache::LimitedMemoryBroydenCache{true}) if cache.iterations_since_reset > size(cache.U, 1) && (all(cache.reset_check, du) || all(cache.reset_check, cache.dfu)) if cache.resets ≥ cache.max_resets - cache.retcode = ReturnCode.Unstable + cache.retcode = ReturnCode.ConvergenceFailure cache.force_stop = true return nothing end @@ -188,7 +188,7 @@ function perform_step!(cache::LimitedMemoryBroydenCache{false}) if cache.iterations_since_reset > size(cache.U, 1) && (all(cache.reset_check, cache.du) || all(cache.reset_check, cache.dfu)) if cache.resets ≥ cache.max_resets - cache.retcode = ReturnCode.Unstable + cache.retcode = ReturnCode.ConvergenceFailure cache.force_stop = true return nothing end diff --git 
a/src/raphson.jl b/src/raphson.jl index 6e2a502bb..4c5dcfb99 100644 --- a/src/raphson.jl +++ b/src/raphson.jl @@ -80,8 +80,7 @@ end function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphson, args...; alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing, - termination_condition = nothing, - internalnorm = DEFAULT_NORM, + termination_condition = nothing, internalnorm = DEFAULT_NORM, linsolve_kwargs = (;), kwargs...) where {uType, iip} alg = get_concrete_algorithm(alg_, prob) @unpack f, u0, p = prob @@ -91,9 +90,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::NewtonRaphso linsolve_kwargs) abstol, reltol, termination_condition = _init_termination_elements(abstol, - reltol, - termination_condition, - eltype(u)) + reltol, termination_condition, eltype(u)) mode = DiffEqBase.get_termination_mode(termination_condition) diff --git a/src/trustRegion.jl b/src/trustRegion.jl index cf9f41af0..4339b4739 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -141,11 +141,6 @@ for large-scale and numerically-difficult nonlinear systems. `expand_threshold < r` (with `r` defined in `shrink_threshold`). Defaults to `2.0`. - `max_shrink_times`: the maximum number of times to shrink the trust region radius in a row, `max_shrink_times` is exceeded, the algorithm returns. Defaults to `32`. - -!!! warning - - `linsolve` and `precs` are used exclusively for the inplace version of the algorithm. - Support for the OOP version is planned! 
""" @concrete struct TrustRegion{CJ, AD, MTR} <: AbstractNewtonAlgorithm{CJ, AD} @@ -250,7 +245,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, linsolve_kwargs) u_tmp = zero(u) u_cauchy = zero(u) - u_gauss_newton = zero(u) + u_gauss_newton = _mutable_zero(u) loss_new = loss H = zero(J' * J) @@ -338,10 +333,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, initial_trust_radius = convert(trustType, 1.0) end - abstol, reltol, termination_condition = _init_termination_elements(abstol, - reltol, - termination_condition, - eltype(u)) + abstol, reltol, termination_condition = _init_termination_elements(abstol, reltol, + termination_condition, eltype(u)) mode = DiffEqBase.get_termination_mode(termination_condition) @@ -368,8 +361,7 @@ function perform_step!(cache::TrustRegionCache{true}) # do not use A = cache.H, b = _vec(cache.g) since it is equivalent # to A = cache.J, b = _vec(fu) as long as the Jacobian is non-singular linres = dolinsolve(alg.precs, linsolve, A = J, b = _vec(fu), - linu = _vec(u_gauss_newton), - p = p, reltol = cache.abstol) + linu = _vec(u_gauss_newton), p = p, reltol = cache.abstol) cache.linsolve = linres.cache @. cache.u_gauss_newton = -1 * u_gauss_newton end @@ -395,7 +387,12 @@ function perform_step!(cache::TrustRegionCache{false}) cache.H = J' * J cache.g = _restructure(fu, J' * _vec(fu)) cache.stats.njacs += 1 - cache.u_gauss_newton = -1 .* _restructure(cache.g, cache.H \ _vec(cache.g)) + + # do not use A = cache.H, b = _vec(cache.g) since it is equivalent + # to A = cache.J, b = _vec(fu) as long as the Jacobian is non-singular + linres = dolinsolve(cache.alg.precs, cache.linsolve, A = cache.J, b = -_vec(fu), + linu = _vec(cache.u_gauss_newton), p = p, reltol = cache.abstol) + cache.linsolve = linres.cache end # Compute the Newton step. 
@@ -718,8 +715,16 @@ function jvp!(cache::TrustRegionCache{true}) end function not_terminated(cache::TrustRegionCache) - return !cache.force_stop && cache.stats.nsteps < cache.maxiters && - cache.shrink_counter < cache.alg.max_shrink_times + non_shrink_terminated = cache.force_stop || cache.stats.nsteps ≥ cache.maxiters + # Terminated due to convergence or maxiters + non_shrink_terminated && return false + # Terminated due to too many consecutive shrinks + shrink_terminated = cache.shrink_counter ≥ cache.alg.max_shrink_times + if shrink_terminated + cache.retcode = ReturnCode.ConvergenceFailure + return false + end + return true end get_fu(cache::TrustRegionCache) = cache.fu diff --git a/src/utils.jl b/src/utils.jl index 87a80e4ed..8da6bc748 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -13,6 +13,14 @@ end @inline DEFAULT_NORM(u::AbstractArray) = sqrt(real(sum(UNITLESS_ABS2, u)) / length(u)) @inline DEFAULT_NORM(u) = norm(u) +# Like `findmin(f, x)`, but a NaN value of `f` is treated as `Inf` +function __findmin(f, x) + return findmin(x) do xᵢ + fx = f(xᵢ) + return isnan(fx) ? Inf : fx + end +end + """ default_adargs_to_adtype(; chunk_size = Val{0}(), autodiff = Val{true}(), standardtag = Val{true}(), diff_type = Val{:forward}) @@ -210,9 +218,13 @@ function __get_concrete_algorithm(alg, prob) return set_ad(alg, ad) end +__cvt_real(::Type{T}, ::Nothing) where {T} = nothing +__cvt_real(::Type{T}, x) where {T} = real(T(x)) + function _get_tolerance(η, tc_η, ::Type{T}) where {T} fallback_η = real(oneunit(T)) * (eps(real(one(T))))^(4 // 5) - return T(ifelse(η !== nothing, η, ifelse(tc_η !== nothing, tc_η, fallback_η))) + return ifelse(η !== nothing, __cvt_real(T, η), + ifelse(tc_η !== nothing, __cvt_real(T, tc_η), fallback_η)) end function _init_termination_elements(abstol, reltol, termination_condition,