Skip to content

Commit

Permalink
Rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
tansongchen committed Nov 1, 2024
1 parent 7364d90 commit 689b13c
Showing 1 changed file with 19 additions and 11 deletions.
30 changes: 19 additions & 11 deletions lib/NonlinearSolveBase/src/descent/halley.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@ end

@internal_caches HalleyDescentCache :lincache

function __internal_init(
prob::NonlinearProblem, alg::HalleyDescent, J, fu, u; shared::Val{N} = Val(1),
pre_inverted::Val{INV} = False, linsolve_kwargs = (;), abstol = nothing,
reltol = nothing, timer = get_timer_output(), kwargs...) where {INV, N}
function __internal_init(prob::NonlinearProblem, alg::HalleyDescent, J, fu, u; stats,
shared::Val{N} = Val(1), pre_inverted::Val{INV} = False,
linsolve_kwargs = (;), abstol = nothing, reltol = nothing,
timer = get_timer_output(), kwargs...) where {INV, N}
@bb δu = similar(u)
@bb b = similar(u)
@bb fu = similar(fu)
Expand All @@ -48,23 +48,27 @@ function __internal_init(
end
INV && return HalleyDescentCache{true}(prob.f, prob.p, δu, δus, b, nothing, timer)
lincache = LinearSolverCache(
alg, alg.linsolve, J, _vec(fu), _vec(u); abstol, reltol, linsolve_kwargs...)
alg, alg.linsolve, J, _vec(fu), _vec(u); stats, abstol, reltol, linsolve_kwargs...)
return HalleyDescentCache{false}(prob.f, prob.p, δu, δus, b, fu, lincache, timer)
end

function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val = Val(1);
skip_solve::Bool = false, new_jacobian::Bool = true, kwargs...) where {INV}
δu = get_du(cache, idx)
skip_solve && return δu, true, (;)
skip_solve && return DescentResult(; δu)
if INV
@assert J!==nothing "`J` must be provided when `pre_inverted = Val(true)`."
@bb δu = J × vec(fu)
else
@static_timeit cache.timer "linear solve 1" begin
δu = cache.lincache(;
linres = cache.lincache(;
A = J, b = _vec(fu), kwargs..., linu = _vec(δu), du = _vec(δu),
reuse_A_if_factorization = !new_jacobian || (idx !== Val(1)))
δu = _restructure(get_du(cache, idx), δu)
δu = _restructure(get_du(cache, idx), linres.u)
if !linres.success
set_du!(cache, δu, idx)
return DescentResult(; δu, success = false, linsolve_success = false)
end
end
end
b = cache.b
Expand All @@ -75,15 +79,19 @@ function __internal_solve!(cache::HalleyDescentCache{INV}, J, fu, u, idx::Val =
@bb b = J × vec(hvvp)
else
@static_timeit cache.timer "linear solve 2" begin
b = cache.lincache(; A = J, b = _vec(hvvp), kwargs..., linu = _vec(b),
linres = cache.lincache(; A = J, b = _vec(hvvp), kwargs..., linu = _vec(b),
du = _vec(b), reuse_A_if_factorization = true)
b = _restructure(cache.b, b)
b = _restructure(cache.b, linres.u)
if !linres.success
set_du!(cache, δu, idx)
return DescentResult(; δu, success = false, linsolve_success = false)
end
end
end
@bb @. δu = δu * δu / (b / 2 - δu)
set_du!(cache, δu, idx)
cache.b = b
return δu, true, (;)
return DescentResult(; δu)
end

function evaluate_hvvp(
Expand Down

0 comments on commit 689b13c

Please sign in to comment.