diff --git a/lib/NonlinearSolveBase/src/descent/halley.jl b/lib/NonlinearSolveBase/src/descent/halley.jl index 0d98efbe4..596d78f19 100644 --- a/lib/NonlinearSolveBase/src/descent/halley.jl +++ b/lib/NonlinearSolveBase/src/descent/halley.jl @@ -36,10 +36,10 @@ end @internal_caches HalleyDescentCache :lincache -function __internal_init( - prob::NonlinearProblem, alg::HalleyDescent, J, fu, u; shared::Val{N} = Val(1), - pre_inverted::Val{INV} = False, linsolve_kwargs = (;), abstol = nothing, - reltol = nothing, timer = get_timer_output(), kwargs...) where {INV, N} +function __internal_init(prob::NonlinearProblem, alg::HalleyDescent, J, fu, u; stats, + shared::Val{N} = Val(1), pre_inverted::Val{INV} = False, + linsolve_kwargs = (;), abstol = nothing, reltol = nothing, + timer = get_timer_output(), kwargs...) where {INV, N} @bb δu = similar(u) @bb b = similar(u) @bb fu = similar(fu) @@ -48,23 +48,27 @@ function __internal_init( end INV && return HalleyDescentCache{true}(prob.f, prob.p, δu, δus, b, nothing, timer) lincache = LinearSolverCache( - alg, alg.linsolve, J, _vec(fu), _vec(u); abstol, reltol, linsolve_kwargs...) + alg, alg.linsolve, J, _vec(fu), _vec(u); stats, abstol, reltol, linsolve_kwargs...) return HalleyDescentCache{false}(prob.f, prob.p, δu, δus, b, fu, lincache, timer) end function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val = Val(1); skip_solve::Bool = false, new_jacobian::Bool = true, kwargs...) where {INV} δu = get_du(cache, idx) - skip_solve && return δu, true, (;) + skip_solve && return DescentResult(; δu) if INV @assert J!==nothing "`J` must be provided when `pre_inverted = Val(true)`." @bb δu = J × vec(fu) else @static_timeit cache.timer "linear solve 1" begin - δu = cache.lincache(; + linres = cache.lincache(; A = J, b = _vec(fu), kwargs..., linu = _vec(δu), du = _vec(δu), reuse_A_if_factorization = !new_jacobian || (idx !== Val(1))) - δu = _restructure(get_du(cache, idx), δu) + δu = _restructure(get_du(cache, idx), linres.u) + if !linres.success + set_du!(cache, δu, idx) + return DescentResult(; δu, success = false, linsolve_success = false) + end end end b = cache.b @@ -75,15 +79,19 @@ function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val = @bb b = J × vec(hvvp) else @static_timeit cache.timer "linear solve 2" begin - b = cache.lincache(; A = J, b = _vec(hvvp), kwargs..., linu = _vec(b), + linres = cache.lincache(; A = J, b = _vec(hvvp), kwargs..., linu = _vec(b), du = _vec(b), reuse_A_if_factorization = true) - b = _restructure(cache.b, b) + b = _restructure(cache.b, linres.u) + if !linres.success + set_du!(cache, δu, idx) + return DescentResult(; δu, success = false, linsolve_success = false) + end end end @bb @. δu = δu * δu / (b / 2 - δu) set_du!(cache, δu, idx) cache.b = b - return δu, true, (;) + return DescentResult(; δu) end function evaluate_hvvp(