From e8397a966450d0a64114f698dedf2f97a9cd7df8 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 1 Dec 2023 18:44:40 -0500 Subject: [PATCH] standardize the cache --- src/jacobian.jl | 7 ++++-- src/trustRegion.jl | 57 +++++++++++++++++++++++----------------------- 2 files changed, 33 insertions(+), 31 deletions(-) diff --git a/src/jacobian.jl b/src/jacobian.jl index 328169165..feac0eb2e 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -139,11 +139,14 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f::F, u::Number, kwargs...) where {needsJᵀJ, F} # NOTE: Scalar `u` assumes scalar output from `f` uf = SciMLBase.JacobianWrapper{false}(f, p) - needsJᵀJ && return uf, nothing, u, nothing, nothing, u, u, u - return uf, FakeLinearSolveJLCache(u, u), u, nothing, nothing, u + return uf, FakeLinearSolveJLCache(u, u), u, nothing, nothing, u, u, u end # Linear Solve Cache +function linsolve_caches(A::Number, b, u, p, alg; linsolve_kwargs = (;)) + return FakeLinearSolveJLCache(A, b) +end + function linsolve_caches(A, b, u, p, alg; linsolve_kwargs = (;)) if alg.linsolve === nothing && A isa SMatrix && linsolve_kwargs === (;) # Default handling for SArrays in LinearSolve is not great. Some parts are patched diff --git a/src/trustRegion.jl b/src/trustRegion.jl index 1a8d0c57c..132d5dfb1 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -186,15 +186,21 @@ end AbstractNonlinearSolveCache{iip} f alg - u_prev u - fu_prev + u_cache + u_cache_2 fu - fu2 + fu_cache + fu_cache_2 + du + u_gauss_newton + u_cauchy + J + JᵀJ + Jᵀf p uf linsolve - J jac_cache force_stop::Bool maxiters::Int @@ -213,14 +219,7 @@ end expand_factor::trustType loss::floatType loss_new::floatType - H - g shrink_counter::Int - du - u_tmp - u_gauss_newton - u_cauchy - fu_new make_new_J::Bool r::floatType p1::floatType @@ -240,24 +239,22 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, linsolve_kwargs = (;), kwargs...) where {uType, iip} alg = get_concrete_algorithm(alg_, prob) @unpack f, u0, p = prob - u = alias_u0 ? u0 : deepcopy(u0) - u_prev = zero(u) - fu1 = evaluate_f(prob, u) - fu_prev = zero(fu1) + u = __mabe_unaliased(u0, alias_u0) + @bb u_cache = similar(u) + @bb u_cache_2 = similar(u) + fu = evaluate_f(prob, u) + @bb fu_cache_2 = similar(fu) - loss = __get_trust_region_loss(fu1) - uf, _, J, fu2, jac_cache, du, H, g = jacobian_caches(alg, f, u, p, Val(iip); - linsolve_kwargs, linsolve_with_JᵀJ = Val(true), lininit = Val(false)) - g = _restructure(fu1, g) - linsolve = u isa Number ? nothing : linsolve_caches(J, fu2, du, p, alg) + loss = __get_trust_region_loss(fu) + uf, _, J, fu_cache, jac_cache, du, JᵀJ, Jᵀf = jacobian_caches(alg, f, u, p, Val(iip); + linsolve_with_JᵀJ = Val(true), lininit = Val(false)) + linsolve = linsolve_caches(J, fu, du, p, alg) - u_tmp = zero(u) - u_cauchy = zero(u) - u_gauss_newton = _mutable_zero(u) + @bb u_cauchy = similar(u) + @bb u_gauss_newton = similar(u) loss_new = loss shrink_counter = 0 - fu_new = zero(fu1) make_new_J = true r = loss @@ -342,12 +339,13 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg_::TrustRegion, termination_condition) trace = init_nonlinearsolve_trace(alg, u, fu1, ApplyArray(__zero, J), du; kwargs...) - return TrustRegionCache{iip}(f, alg, u_prev, u, fu_prev, fu1, fu2, p, uf, linsolve, J, - jac_cache, false, maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob, + return TrustRegionCache{iip}(f, alg, u, u_cache, u_cache_2, fu, fu_cache, fu_cache_2, + du, u_gauss_newton, u_cauchy, J, JᵀJ, Jᵀf, p, uf, linsolve, jac_cache, false, + maxiters, internalnorm, ReturnCode.Default, abstol, reltol, prob, radius_update_scheme, initial_trust_radius, max_trust_radius, step_threshold, shrink_threshold, expand_threshold, shrink_factor, expand_factor, loss, loss_new, - H, g, shrink_counter, du, u_tmp, u_gauss_newton, u_cauchy, fu_new, make_new_J, r, - p1, p2, p3, p4, ϵ, NLStats(1, 0, 0, 0, 0), tc_cache, trace) + shrink_counter, make_new_J, r, p1, p2, p3, p4, ϵ, NLStats(1, 0, 0, 0, 0), tc_cache, + trace) end function perform_step!(cache::TrustRegionCache{iip}) where {iip} @@ -458,7 +456,8 @@ function trust_region_step!(cache::TrustRegionCache) cache.shrink_counter += 1 else cache.shrink_counter = 0 - if r ≥ cache.expand_threshold && 2 * cache.internalnorm(cache.du) > cache.trust_r + if r ≥ cache.expand_threshold && + 2 * cache.internalnorm(cache.du) > cache.trust_r cache.p1 = cache.p3 * cache.p1 end end