From 5549c4238338d08f02436aef6b113aecb50aabe4 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 25 Oct 2024 22:29:09 -0400 Subject: [PATCH] refactor: move LinearSolve wrapper into NonlinearSolveBase --- lib/NonlinearSolveBase/Project.toml | 11 +- .../ext/NonlinearSolveBaseLinearSolveExt.jl | 127 +++++++++ .../src/NonlinearSolveBase.jl | 24 +- lib/NonlinearSolveBase/src/abstract_types.jl | 45 ++++ lib/NonlinearSolveBase/src/jacobian.jl | 1 + lib/NonlinearSolveBase/src/linear_solve.jl | 140 ++++++++++ .../src/termination_conditions.jl | 2 +- lib/NonlinearSolveBase/src/utils.jl | 5 + src/NonlinearSolve.jl | 9 +- src/abstract_types.jl | 23 -- src/algorithms/extension_algs.jl | 2 +- src/core/approximate_jacobian.jl | 2 +- src/core/generalized_first_order.jl | 2 +- src/descent/damped_newton.jl | 4 +- src/descent/dogleg.jl | 2 +- src/descent/newton.jl | 6 +- src/descent/steepest.jl | 3 +- src/internal/linear_solve.jl | 245 ------------------ src/utils.jl | 2 - 19 files changed, 363 insertions(+), 292 deletions(-) create mode 100644 lib/NonlinearSolveBase/ext/NonlinearSolveBaseLinearSolveExt.jl create mode 100644 lib/NonlinearSolveBase/src/abstract_types.jl create mode 100644 lib/NonlinearSolveBase/src/jacobian.jl create mode 100644 lib/NonlinearSolveBase/src/linear_solve.jl delete mode 100644 src/internal/linear_solve.jl diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index 70c0f97d4..a902e5594 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -1,10 +1,11 @@ name = "NonlinearSolveBase" uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" authors = ["Avik Pal and contributors"] -version = "1.0.0" +version = "1.1.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" @@ -15,22 +16,27 @@ FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" FunctionProperties = "f62d2435-5019-4c03-9749-2d4c77af0cbc" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" +MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [weakdeps] DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [extensions] NonlinearSolveBaseDiffEqBaseExt = "DiffEqBase" NonlinearSolveBaseForwardDiffExt = "ForwardDiff" +NonlinearSolveBaseLinearSolveExt = "LinearSolve" NonlinearSolveBaseSparseArraysExt = "SparseArrays" [compat] ADTypes = "1.9" +Adapt = "4.1.0" Aqua = "0.8.7" ArrayInterface = "7.9" CommonSolve = "0.2.4" @@ -45,9 +51,12 @@ ForwardDiff = "0.10.36" FunctionProperties = "0.1.2" InteractiveUtils = "<0.0.1, 1" LinearAlgebra = "1.10" +LinearSolve = "2.36.1" Markdown = "1.10" +MaybeInplace = "0.1.4" RecursiveArrayTools = "3" SciMLBase = "2.50" +SciMLOperators = "0.3.10" SparseArrays = "1.10" StaticArraysCore = "1.4" Test = "1.10" diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseLinearSolveExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseLinearSolveExt.jl new file mode 100644 index 000000000..294a53de9 --- /dev/null +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseLinearSolveExt.jl @@ -0,0 +1,127 @@ +module NonlinearSolveBaseLinearSolveExt + +using ArrayInterface: ArrayInterface +using CommonSolve: CommonSolve, init, solve! +using LinearAlgebra: ColumnNorm +using LinearSolve: LinearSolve, QRFactorization +using NonlinearSolveBase: NonlinearSolveBase, LinearSolveJLCache, LinearSolveResult, Utils +using SciMLBase: ReturnCode, LinearProblem + +function (cache::LinearSolveJLCache)(; + A = nothing, b = nothing, linu = nothing, du = nothing, p = nothing, + cachedata = nothing, reuse_A_if_factorization = false, verbose = true, kwargs...) + cache.stats.nsolve += 1 + + update_A!(cache, A, reuse_A_if_factorization) + b !== nothing && setproperty!(cache.lincache, :b, b) + linu !== nothing && NonlinearSolveBase.set_lincache_u!(cache, linu) + + Plprev = cache.lincache.Pl + Prprev = cache.lincache.Pr + + if cache.precs === nothing + Pl, Pr = nothing, nothing + else + Pl, Pr = cache.precs(cache.lincache.A, du, linu, p, nothing, + A !== nothing, Plprev, Prprev, cachedata) + end + + if Pl !== nothing || Pr !== nothing + Pl, Pr = NonlinearSolveBase.wrap_preconditioners(Pl, Pr, linu) + cache.lincache.Pl = Pl + cache.lincache.Pr = Pr + end + + linres = solve!(cache.lincache) + cache.lincache = linres.cache + # Unfortunately LinearSolve.jl doesn't have the most uniform ReturnCode handling + if linres.retcode === ReturnCode.Failure + structured_mat = ArrayInterface.isstructured(cache.lincache.A) + is_gpuarray = ArrayInterface.device(cache.lincache.A) isa ArrayInterface.GPU + + if !(cache.linsolve isa QRFactorization{ColumnNorm}) && !is_gpuarray && + !structured_mat + if verbose + @warn "Potential Rank Deficient Matrix Detected. Attempting to solve using \ + Pivoted QR Factorization." + end + @assert (A !== nothing)&&(b !== nothing) "This case is not yet supported. \ + Please open an issue at \ + https://github.com/SciML/NonlinearSolve.jl" + if cache.additional_lincache === nothing # First time + linprob = LinearProblem(A, b; u0 = linres.u) + cache.additional_lincache = init( + linprob, QRFactorization(ColumnNorm()); alias_u0 = false, + alias_A = false, alias_b = false, cache.lincache.Pl, cache.lincache.Pr) + else + cache.additional_lincache.A = A + cache.additional_lincache.b = b + cache.additional_lincache.Pl = cache.lincache.Pl + cache.additional_lincache.Pr = cache.lincache.Pr + end + linres = solve!(cache.additional_lincache) + cache.additional_lincache = linres.cache + linres.retcode === ReturnCode.Failure && + return LinearSolveResult(; linres.u, success = false) + return LinearSolveResult(; linres.u) + elseif !(cache.linsolve isa QRFactorization{ColumnNorm}) + if verbose + if structured_mat || is_gpuarray + mat_desc = structured_mat ? "Structured" : "GPU" + @warn "Potential Rank Deficient Matrix Detected. But Matrix is \ + $(mat_desc). Currently, we don't attempt to solve Rank Deficient \ + $(mat_desc) Matrices. Please open an issue at \ + https://github.com/SciML/NonlinearSolve.jl" + end + end + end + return LinearSolveResult(; linres.u, success = false) + end + + return LinearSolveResult(; linres.u) +end + +NonlinearSolveBase.needs_square_A(linsolve, ::Any) = LinearSolve.needs_square_A(linsolve) + +update_A!(cache::LinearSolveJLCache, ::Nothing, reuse) = cache +function update_A!(cache::LinearSolveJLCache, A, reuse) + return update_A!(cache, Utils.safe_getproperty(cache.linsolve, Val(:alg)), A, reuse) +end + +function update_A!(cache::LinearSolveJLCache, alg, A, reuse) + # Not a Factorization Algorithm so don't update `nfactors` + set_lincache_A!(cache.lincache, A) + return cache +end +function update_A!(cache::LinearSolveJLCache, ::LinearSolve.AbstractFactorization, A, reuse) + reuse && return cache + set_lincache_A!(cache.lincache, A) + cache.stats.nfactors += 1 + return cache +end +function update_A!( + cache::LinearSolveJLCache, alg::LinearSolve.DefaultLinearSolver, A, reuse) + if alg == + LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES) + # Force a reset of the cache. This is not properly handled in LinearSolve.jl + set_lincache_A!(cache.lincache, A) + return cache + end + reuse && return cache + set_lincache_A!(cache.lincache, A) + cache.stats.nfactors += 1 + return cache +end + +function set_lincache_A!(lincache, new_A) + if !LinearSolve.default_alias_A(lincache.alg, new_A, lincache.b) && + ArrayInterface.can_setindex(lincache.A) + copyto!(lincache.A, new_A) + lincache.A = lincache.A # important!! triggers special code in `setproperty!` + return + end + lincache.A = new_A + return +end + +end diff --git a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl index b07b4b168..c035f9f46 100644 --- a/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl +++ b/lib/NonlinearSolveBase/src/NonlinearSolveBase.jl @@ -1,33 +1,41 @@ module NonlinearSolveBase using ADTypes: ADTypes, AbstractADType +using Adapt: WrappedArray using ArrayInterface: ArrayInterface -using CommonSolve: CommonSolve +using CommonSolve: CommonSolve, init using Compat: @compat using ConcreteStructs: @concrete using DifferentiationInterface: DifferentiationInterface using EnzymeCore: EnzymeCore using FastClosures: @closure using FunctionProperties: hasbranching -using LinearAlgebra: norm +using LinearAlgebra: Diagonal, norm, ldiv! using Markdown: @doc_str +using MaybeInplace: @bb using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinearProblem, - NonlinearProblem, NonlinearLeastSquaresProblem, AbstractNonlinearFunction, - @add_kwonly, StandardNonlinearProblem, NullParameters, isinplace, - warn_paramtype -using StaticArraysCore: StaticArray + AbstractNonlinearAlgorithm, AbstractNonlinearFunction, + NonlinearProblem, NonlinearLeastSquaresProblem, StandardNonlinearProblem, + NullParameters, NLStats, LinearProblem, isinplace, warn_paramtype, + @add_kwonly +using SciMLOperators: AbstractSciMLOperator, IdentityOperator +using StaticArraysCore: StaticArray, SMatrix, SArray, MArray const DI = DifferentiationInterface include("public.jl") include("utils.jl") +include("abstract_types.jl") + include("immutable_problem.jl") include("common_defaults.jl") include("termination_conditions.jl") include("autodiff.jl") +include("jacobian.jl") +include("linear_solve.jl") # Unexported Public API @compat(public, (L2_NORM, Linf_NORM, NAN_CHECK, UNITLESS_ABS2, get_tolerance)) @@ -36,6 +44,10 @@ include("autodiff.jl") (select_forward_mode_autodiff, select_reverse_mode_autodiff, select_jacobian_autodiff)) +# public for NonlinearSolve.jl to use +@compat(public, (InternalAPI, supports_line_search, supports_trust_region, set_du!)) +@compat(public, (construct_linear_solver, needs_square_A)) + export RelTerminationMode, AbsTerminationMode, NormTerminationMode, RelNormTerminationMode, AbsNormTerminationMode, RelNormSafeTerminationMode, AbsNormSafeTerminationMode, diff --git a/lib/NonlinearSolveBase/src/abstract_types.jl b/lib/NonlinearSolveBase/src/abstract_types.jl new file mode 100644 index 000000000..bd7b3a86b --- /dev/null +++ b/lib/NonlinearSolveBase/src/abstract_types.jl @@ -0,0 +1,45 @@ +module InternalAPI + +function init end +function solve! end + +end + +abstract type AbstractDescentDirection end + +supports_line_search(::AbstractDescentDirection) = false +supports_trust_region(::AbstractDescentDirection) = false + +function get_linear_solver(alg::AbstractDescentDirection) + return Utils.safe_getproperty(alg, Val(:linsolve)) +end + +abstract type AbstractDescentCache end + +SciMLBase.get_du(cache::AbstractDescentCache) = cache.δu +SciMLBase.get_du(cache::AbstractDescentCache, ::Val{1}) = SciMLBase.get_du(cache) +SciMLBase.get_du(cache::AbstractDescentCache, ::Val{N}) where {N} = cache.δus[N - 1] +set_du!(cache::AbstractDescentCache, δu) = (cache.δu = δu) +set_du!(cache::AbstractDescentCache, δu, ::Val{1}) = set_du!(cache, δu) +set_du!(cache::AbstractDescentCache, δu, ::Val{N}) where {N} = (cache.δus[N - 1] = δu) + +function last_step_accepted(cache::AbstractDescentCache) + hasfield(typeof(cache), :last_step_accepted) && return cache.last_step_accepted + return true +end + +abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end + +get_name(alg::AbstractNonlinearSolveAlgorithm) = Utils.safe_getproperty(alg, Val(:name)) + +function concrete_jac(alg::AbstractNonlinearSolveAlgorithm) + return concrete_jac(Utils.safe_getproperty(alg, Val(:concrete_jac))) +end +concrete_jac(::Missing) = missing +concrete_jac(v::Bool) = v +concrete_jac(::Val{false}) = false +concrete_jac(::Val{true}) = true + +abstract type AbstractNonlinearSolveCache end + +abstract type AbstractLinearSolverCache end diff --git a/lib/NonlinearSolveBase/src/jacobian.jl b/lib/NonlinearSolveBase/src/jacobian.jl new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/lib/NonlinearSolveBase/src/jacobian.jl @@ -0,0 +1 @@ + diff --git a/lib/NonlinearSolveBase/src/linear_solve.jl b/lib/NonlinearSolveBase/src/linear_solve.jl new file mode 100644 index 000000000..6e7bceabb --- /dev/null +++ b/lib/NonlinearSolveBase/src/linear_solve.jl @@ -0,0 +1,140 @@ +@kwdef @concrete struct LinearSolveResult + u + success::Bool = true +end + +@concrete mutable struct LinearSolveJLCache <: AbstractLinearSolverCache + lincache + linsolve + additional_lincache::Any + precs + stats::NLStats +end + +@concrete mutable struct NativeJLLinearSolveCache <: AbstractLinearSolverCache + A + b + stats::NLStats +end + +""" + construct_linear_solver(alg, linsolve, A, b, u; stats, kwargs...) + +Construct a cache for solving linear systems of the form `A * u = b`. Following cases are +handled: + + 1. `A` is Number, then we solve it with `u = b / A` + 2. `A` is `SMatrix`, then we solve it with `u = A \\ b` (using the defaults from base + Julia) (unless a preconditioner is specified) + 3. If `linsolve` is `\\`, then we solve it with directly using `ldiv!(u, A, b)` + 4. In all other cases, we use `alg` to solve the linear system using + [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl) + +### Solving the System + +```julia +(cache::LinearSolverCache)(; + A = nothing, b = nothing, linu = nothing, du = nothing, p = nothing, + weight = nothing, cachedata = nothing, reuse_A_if_factorization = false, kwargs...) +``` + +Returns the solution of the system `u` and stores the updated cache in `cache.lincache`. + +#### Special Handling for Rank-deficient Matrix `A` + +If we detect a failure in the linear solve (mostly due to using an algorithm that doesn't +support rank-deficient matrices), we emit a warning and attempt to solve the problem using +Pivoted QR factorization. This is quite efficient if there are only a few rank-deficient +that originate in the problem. However, if these are quite frequent for the main nonlinear +system, then it is recommended to use a different linear solver that supports rank-deficient +matrices. + +#### Keyword Arguments + + - `reuse_A_if_factorization`: If `true`, then the factorization of `A` is reused if + possible. This is useful when solving the same system with different `b` values. + If the algorithm is an iterative solver, then we reset the internal linear solve cache. + +One distinct feature of this compared to the cache from LinearSolve is that it respects the +aliasing arguments even after cache construction, i.e., if we passed in an `A` that `A` is +not mutated, we do this by copying over `A` to a preconstructed cache. +""" +function construct_linear_solver(alg, linsolve, A, b, u; stats, kwargs...) + no_preconditioner = !hasfield(typeof(alg), :precs) || alg.precs === nothing + + if (A isa Number && b isa Number) || (A isa Diagonal) + return NativeJLLinearSolveCache(A, b, stats) + elseif linsolve isa typeof(\) + !no_preconditioner && + error("Default Julia Backsolve Operator `\\` doesn't support Preconditioners") + return NativeJLLinearSolveCache(A, b, stats) + elseif no_preconditioner && linsolve === nothing + # Non-allocating linear solve exists in StaticArrays.jl + if (A isa SMatrix || A isa WrappedArray{<:Any, <:SMatrix}) && + Core.Compiler.return_type(\, Tuple{typeof(A), typeof(b)}) <: SArray + return NativeJLLinearSolveCache(A, b, stats) + end + end + + u_fixed = fix_incompatible_linsolve_arguments(A, b, u) + @bb u_cache = copy(u_fixed) + linprob = LinearProblem(A, b; u0 = u_cache, kwargs...) + + if no_preconditioner + precs, Pl, Pr = nothing, nothing, nothing + else + precs = alg.precs + Pl, Pr = precs(A, nothing, u, ntuple(Returns(nothing), 6)...) + end + Pl, Pr = wrap_preconditioners(Pl, Pr, u) + + # unlias here, we will later use these as caches + lincache = init(linprob, linsolve; alias_A = false, alias_b = false, Pl, Pr) + return LinearSolveJLCache(lincache, linsolve, nothing, precs, stats) +end + +function (cache::NativeJLLinearSolveCache)(; + A = nothing, b = nothing, linu = nothing, kwargs...) + cache.stats.nsolve += 1 + cache.stats.nfactors += 1 + + A === nothing || (cache.A = A) + b === nothing || (cache.b = b) + + if linu !== nothing && ArrayInterface.can_setindex(linu) && + applicable(ldiv!, linu, cache.A, cache.b) + ldiv!(linu, cache.A, cache.b) + res = linu + else + res = cache.A \ cache.b + end + return LinearSolveResult(; u = res) +end + +fix_incompatible_linsolve_arguments(A, b, u) = u +fix_incompatible_linsolve_arguments(::SArray, ::SArray, ::SArray) = u +function fix_incompatible_linsolve_arguments(A, b, u::SArray) + (Core.Compiler.return_type(\, Tuple{typeof(A), typeof(b)}) <: typeof(u)) && return u + @warn "Solving Linear System A::$(typeof(A)) x::$(typeof(u)) = b::$(typeof(u)) is not \ + properly supported. Converting `x` to a mutable array. Check the return type \ + of the nonlinear function provided for optimal performance." maxlog=1 + return MArray(u) +end + +set_lincache_u!(cache, u) = setproperty!(cache.lincache, :u, u) +function set_lincache_u!(cache, u::SArray) + cache.lincache.u isa MArray && return set_lincache_u!(cache, MArray(u)) + cache.lincache.u = u +end + +function wrap_preconditioners(Pl, Pr, u) + Pl = Pl === nothing ? IdentityOperator(length(u)) : Pl + Pr = Pr === nothing ? IdentityOperator(length(u)) : Pr + return Pl, Pr +end + +needs_square_A(::Any, ::Number) = false +needs_square_A(::Nothing, ::Number) = false +needs_square_A(::Nothing, ::Any) = false +needs_square_A(::typeof(\), ::Number) = false +needs_square_A(::typeof(\), ::Any) = false diff --git a/lib/NonlinearSolveBase/src/termination_conditions.jl b/lib/NonlinearSolveBase/src/termination_conditions.jl index 9f20e46bc..c2d768c17 100644 --- a/lib/NonlinearSolveBase/src/termination_conditions.jl +++ b/lib/NonlinearSolveBase/src/termination_conditions.jl @@ -278,6 +278,6 @@ function init_termination_cache(prob::AbstractNonlinearProblem, abstol, reltol, T = promote_type(eltype(du), eltype(u)) abstol = get_tolerance(u, abstol, T) reltol = get_tolerance(u, reltol, T) - cache = CommonSolve.init(prob, tc, du, u; abstol, reltol) + cache = init(prob, tc, du, u; abstol, reltol) return abstol, reltol, cache end diff --git a/lib/NonlinearSolveBase/src/utils.jl b/lib/NonlinearSolveBase/src/utils.jl index 825f63767..06f53aff4 100644 --- a/lib/NonlinearSolveBase/src/utils.jl +++ b/lib/NonlinearSolveBase/src/utils.jl @@ -90,4 +90,9 @@ end safe_reshape(x::Number, args...) = x safe_reshape(x, args...) = reshape(x, args...) +@generated function safe_getproperty(s::S, ::Val{X}) where {S, X} + hasfield(S, X) && return :(getproperty(s, $(X))) + return :(missing) +end + end diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 60e2f0663..cea9d1fbd 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -24,7 +24,7 @@ using NonlinearSolveBase: NonlinearSolveBase, nonlinearsolve_forwarddiff_solve, AbstractNonlinearTerminationMode, AbstractSafeBestNonlinearTerminationMode, select_forward_mode_autodiff, select_reverse_mode_autodiff, - select_jacobian_autodiff + select_jacobian_autodiff, construct_linear_solver using Printf: @printf using Preferences: Preferences, @load_preference, @set_preferences! using RecursiveArrayTools: recursivecopy! @@ -71,7 +71,6 @@ include("descent/damped_newton.jl") include("descent/geodesic_acceleration.jl") include("internal/jacobian.jl") -include("internal/linear_solve.jl") include("internal/termination.jl") include("internal/tracing.jl") include("internal/approximate_initialization.jl") @@ -102,8 +101,10 @@ include("default.jl") include("internal/forward_diff.jl") # we need to define after the algorithms @setup_workload begin - nlfuncs = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1), - (NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1])) + nlfuncs = ( + (NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1), + (NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1]) + ) probs_nls = NonlinearProblem[] for (fn, u0) in nlfuncs push!(probs_nls, NonlinearProblem(fn, u0, 2.0)) diff --git a/src/abstract_types.jl b/src/abstract_types.jl index 255c5e541..0876e2036 100644 --- a/src/abstract_types.jl +++ b/src/abstract_types.jl @@ -106,29 +106,6 @@ function last_step_accepted(cache::AbstractDescentCache) return true end -""" - AbstractNonlinearSolveLineSearchCache - -Abstract Type for all Line Search Caches used in NonlinearSolve.jl. - -### `__internal_solve!` specification - -```julia -__internal_solve!(cache::AbstractNonlinearSolveLineSearchCache, u, du; kwargs...) -``` - -Returns 2 values: - - - `unsuccessful`: If `true` it means that the Line Search Failed. - - `alpha`: The step size. -""" -abstract type AbstractNonlinearSolveLineSearchCache end - -function reinit_cache!( - cache::AbstractNonlinearSolveLineSearchCache, args...; p = cache.p, kwargs...) - cache.p = p -end - """ AbstractNonlinearSolveAlgorithm{name} <: AbstractNonlinearAlgorithm diff --git a/src/algorithms/extension_algs.jl b/src/algorithms/extension_algs.jl index 8f2c0c54c..9458f5f84 100644 --- a/src/algorithms/extension_algs.jl +++ b/src/algorithms/extension_algs.jl @@ -362,7 +362,7 @@ function FixedPointAccelerationJL(; end end if extrapolation_period === missing - extrapolation_period = algorithm === :SEA || algorithm === :VEA ? 6 : 7 + extrapolation_period = algorithm === :SEA || algorithm === :VEA ? 6 : 7 else if (algorithm === :SEA || algorithm === :VEA) && extrapolation_period % 2 != 0 error("`extrapolation_period` must be multiples of 2 for SEA and VEA") diff --git a/src/core/approximate_jacobian.jl b/src/core/approximate_jacobian.jl index b8f479a6d..757be1b04 100644 --- a/src/core/approximate_jacobian.jl +++ b/src/core/approximate_jacobian.jl @@ -294,7 +294,7 @@ function __step!(cache::ApproximateJacobianSolveCache{INV, GB, iip}; # Extremely pathological case. Jacobian was just reset and linear solve # failed. Should ideally never happen in practice unless true jacobian init # is used. - cache.retcode = LinearSolveFailureCode + cache.retcode = ReturnCode.InternalLinearSolveFailed cache.force_stop = true return else diff --git a/src/core/generalized_first_order.jl b/src/core/generalized_first_order.jl index d5ee3eeab..40a11756f 100644 --- a/src/core/generalized_first_order.jl +++ b/src/core/generalized_first_order.jl @@ -245,7 +245,7 @@ function __step!(cache::GeneralizedFirstOrderAlgorithmCache{iip, GB}; if !descent_result.linsolve_success if new_jacobian # Jacobian Information is current and linear solve failed terminate the solve - cache.retcode = LinearSolveFailureCode + cache.retcode = ReturnCode.InternalLinearSolveFailed cache.force_stop = true return else diff --git a/src/descent/damped_newton.jl b/src/descent/damped_newton.jl index ba3e1d028..f1d2996c9 100644 --- a/src/descent/damped_newton.jl +++ b/src/descent/damped_newton.jl @@ -63,7 +63,7 @@ function __internal_init(prob::AbstractNonlinearProblem, alg::DampedNewtonDescen end normal_form_damping = returns_norm_form_damping(alg.damping_fn) - normal_form_linsolve = __needs_square_A(alg.linsolve, u) + normal_form_linsolve = NonlinearSolveBase.needs_square_A(alg.linsolve, u) if u isa Number mode = :simple elseif prob isa NonlinearProblem @@ -124,7 +124,7 @@ function __internal_init(prob::AbstractNonlinearProblem, alg::DampedNewtonDescen rhs_cache = nothing end - lincache = LinearSolverCache( + lincache = construct_linear_solver( alg, alg.linsolve, A, b, _vec(u); stats, abstol, reltol, linsolve_kwargs...) return DampedNewtonDescentCache{INV, mode}( diff --git a/src/descent/dogleg.jl b/src/descent/dogleg.jl index 4c96c98f6..679f33c26 100644 --- a/src/descent/dogleg.jl +++ b/src/descent/dogleg.jl @@ -66,7 +66,7 @@ function __internal_init(prob::AbstractNonlinearProblem, alg::Dogleg, J, fu, u; T = promote_type(eltype(u), eltype(fu)) normal_form = prob isa NonlinearLeastSquaresProblem && - __needs_square_A(alg.newton_descent.linsolve, u) + NonlinearSolveBase.needs_square_A(alg.newton_descent.linsolve, u) JᵀJ_cache = !normal_form ? J * _vec(δu) : nothing # TODO: Rename return DoglegCache{INV, normal_form}(δu, δus, newton_cache, cauchy_cache, internalnorm, diff --git a/src/descent/newton.jl b/src/descent/newton.jl index 2fe7abf9a..f0964c7dd 100644 --- a/src/descent/newton.jl +++ b/src/descent/newton.jl @@ -41,7 +41,7 @@ function __internal_init(prob::NonlinearProblem, alg::NewtonDescent, J, fu, u; s @bb δu_ = similar(u) end INV && return NewtonDescentCache{true, false}(δu, δus, nothing, nothing, nothing, timer) - lincache = LinearSolverCache( + lincache = construct_linear_solver( alg, alg.linsolve, J, _vec(fu), _vec(u); stats, abstol, reltol, linsolve_kwargs...) return NewtonDescentCache{false, false}(δu, δus, lincache, nothing, nothing, timer) end @@ -53,7 +53,7 @@ function __internal_init(prob::NonlinearLeastSquaresProblem, alg::NewtonDescent, length(fu) != length(u) && @assert !INV "Precomputed Inverse for Non-Square Jacobian doesn't make sense." - normal_form = __needs_square_A(alg.linsolve, u) + normal_form = NonlinearSolveBase.needs_square_A(alg.linsolve, u) if normal_form JᵀJ = transpose(J) * J Jᵀfu = transpose(J) * _vec(fu) @@ -62,7 +62,7 @@ function __internal_init(prob::NonlinearLeastSquaresProblem, alg::NewtonDescent, JᵀJ, Jᵀfu = nothing, nothing A, b = J, _vec(fu) end - lincache = LinearSolverCache( + lincache = construct_linear_solver( alg, alg.linsolve, A, b, _vec(u); stats, abstol, reltol, linsolve_kwargs...) @bb δu = similar(u) δus = N ≤ 1 ? nothing : map(2:N) do i diff --git a/src/descent/steepest.jl b/src/descent/steepest.jl index cc2f4d128..82b02552d 100644 --- a/src/descent/steepest.jl +++ b/src/descent/steepest.jl @@ -40,7 +40,8 @@ end @bb δu_ = similar(u) end if INV - lincache = LinearSolverCache(alg, alg.linsolve, transpose(J), _vec(fu), _vec(u); + lincache = construct_linear_solver( + alg, alg.linsolve, transpose(J), _vec(fu), _vec(u); stats, abstol, reltol, linsolve_kwargs...) else lincache = nothing diff --git a/src/internal/linear_solve.jl b/src/internal/linear_solve.jl deleted file mode 100644 index 707790ff3..000000000 --- a/src/internal/linear_solve.jl +++ /dev/null @@ -1,245 +0,0 @@ -const LinearSolveFailureCode = isdefined(ReturnCode, :InternalLinearSolveFailure) ? - ReturnCode.InternalLinearSolveFailure : ReturnCode.Failure - -""" - LinearSolverCache(alg, linsolve, A, b, u; stats, kwargs...) - -Construct a cache for solving linear systems of the form `A * u = b`. Following cases are -handled: - - 1. `A` is Number, then we solve it with `u = b / A` - 2. `A` is `SMatrix`, then we solve it with `u = A \\ b` (using the defaults from base - Julia) - 3. `A` is `Diagonal`, then we solve it with `u = b ./ A.diag` - 4. In all other cases, we use `alg` to solve the linear system using - [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl). - -### Solving the System - -```julia -(cache::LinearSolverCache)(; - A = nothing, b = nothing, linu = nothing, du = nothing, p = nothing, - weight = nothing, cachedata = nothing, reuse_A_if_factorization = false, kwargs...) -``` - -Returns the solution of the system `u` and stores the updated cache in `cache.lincache`. - -#### Special Handling for Rank-deficient Matrix `A` - -If we detect a failure in the linear solve (mostly due to using an algorithm that doesn't -support rank-deficient matrices), we emit a warning and attempt to solve the problem using -Pivoted QR factorization. This is quite efficient if there are only a few rank-deficient -that originate in the problem. However, if these are quite frequent for the main nonlinear -system, then it is recommended to use a different linear solver that supports rank-deficient -matrices. - -#### Keyword Arguments - - - `reuse_A_if_factorization`: If `true`, then the factorization of `A` is reused if - possible. This is useful when solving the same system with different `b` values. - If the algorithm is an iterative solver, then we reset the internal linear solve cache. - -One distinct feature of this compared to the cache from LinearSolve is that it respects the -aliasing arguments even after cache construction, i.e., if we passed in an `A` that `A` is -not mutated, we do this by copying over `A` to a preconstructed cache. -""" -@concrete mutable struct LinearSolverCache <: AbstractLinearSolverCache - lincache - linsolve - additional_lincache::Any - A - b - precs - stats::NLStats -end - -@inline __fix_strange_type_combination(A, b, u) = u -@inline function __fix_strange_type_combination(A, b, u::SArray) - A isa SArray && b isa SArray && return u - @warn "Solving Linear System A::$(typeof(A)) x::$(typeof(u)) = b::$(typeof(u)) is not \ - properly supported. Converting `x` to a mutable array. Check the return type \ - of the nonlinear function provided for optimal performance." maxlog=1 - return MArray(u) -end - -@inline __set_lincache_u!(cache, u) = (cache.lincache.u = u) -@inline function __set_lincache_u!(cache, u::SArray) - cache.lincache.u isa MArray && return __set_lincache_u!(cache, MArray(u)) - cache.lincache.u = u -end - -function LinearSolverCache(alg, linsolve, A, b, u; stats, kwargs...) - u_fixed = __fix_strange_type_combination(A, b, u) - - if (A isa Number && b isa Number) || - (linsolve === nothing && A isa SMatrix) || - (A isa Diagonal) || - (linsolve isa typeof(\)) - return LinearSolverCache(nothing, nothing, nothing, A, b, nothing, stats) - end - @bb u_ = copy(u_fixed) - linprob = LinearProblem(A, b; u0 = u_, kwargs...) - - if __hasfield(alg, Val(:precs)) - precs = alg.precs - Pl_, Pr_ = precs(A, nothing, u, ntuple(Returns(nothing), 6)...) - else - precs, Pl_, Pr_ = nothing, nothing, nothing - end - Pl, Pr = __wrapprecs(Pl_, Pr_, u) - - # Unalias here, we will later use these as caches - lincache = init(linprob, linsolve; alias_A = false, alias_b = false, Pl, Pr) - - return LinearSolverCache(lincache, linsolve, nothing, nothing, nothing, precs, stats) -end - -@kwdef @concrete struct LinearSolveResult - u - success::Bool = true -end - -# Direct Linear Solve Case without Caching -function (cache::LinearSolverCache{Nothing})(; - A = nothing, b = nothing, linu = nothing, kwargs...) - cache.stats.nsolve += 1 - cache.stats.nfactors += 1 - A === nothing || (cache.A = A) - b === nothing || (cache.b = b) - if A isa Diagonal - _diag = _restructure(cache.b, cache.A.diag) - @bb @. linu = cache.b / _diag - res = linu - else - res = cache.A \ cache.b - end - return LinearSolveResult(; u = res) -end - -# Use LinearSolve.jl -function (cache::LinearSolverCache)(; - A = nothing, b = nothing, linu = nothing, du = nothing, - p = nothing, weight = nothing, cachedata = nothing, - reuse_A_if_factorization = false, verbose = true, kwargs...) - cache.stats.nsolve += 1 - - __update_A!(cache, A, reuse_A_if_factorization) - b !== nothing && (cache.lincache.b = b) - linu !== nothing && __set_lincache_u!(cache, linu) - - Plprev = cache.lincache.Pl - Prprev = cache.lincache.Pr - - if cache.precs === nothing - _Pl, _Pr = nothing, nothing - else - _Pl, _Pr = cache.precs(cache.lincache.A, du, linu, p, nothing, - A !== nothing, Plprev, Prprev, cachedata) - end - - if (_Pl !== nothing || _Pr !== nothing) - Pl, Pr = __wrapprecs(_Pl, _Pr, linu) - cache.lincache.Pl = Pl - cache.lincache.Pr = Pr - end - - linres = solve!(cache.lincache) - cache.lincache = linres.cache - # Unfortunately LinearSolve.jl doesn't have the most uniform ReturnCode handling - if linres.retcode === ReturnCode.Failure - structured_mat = ArrayInterface.isstructured(cache.lincache.A) - is_gpuarray = ArrayInterface.device(cache.lincache.A) isa ArrayInterface.GPU - if !(cache.linsolve isa QRFactorization{ColumnNorm}) && !is_gpuarray && - !structured_mat - if verbose - @warn "Potential Rank Deficient Matrix Detected. Attempting to solve using \ - Pivoted QR Factorization." - end - @assert (A !== nothing)&&(b !== nothing) "This case is not yet supported. \ - Please open an issue at \ - https://github.com/SciML/NonlinearSolve.jl" - if cache.additional_lincache === nothing # First time - linprob = LinearProblem(A, b; u0 = linres.u) - cache.additional_lincache = init( - linprob, QRFactorization(ColumnNorm()); alias_u0 = false, - alias_A = false, alias_b = false, cache.lincache.Pl, cache.lincache.Pr) - else - cache.additional_lincache.A = A - cache.additional_lincache.b = b - cache.additional_lincache.Pl = cache.lincache.Pl - cache.additional_lincache.Pr = cache.lincache.Pr - end - linres = solve!(cache.additional_lincache) - cache.additional_lincache = linres.cache - linres.retcode === ReturnCode.Failure && - return LinearSolveResult(; u = linres.u, success = false) - return LinearSolveResult(; u = linres.u) - elseif !(cache.linsolve isa QRFactorization{ColumnNorm}) - if verbose - if structured_mat || is_gpuarray - mat_desc = structured_mat ? "Structured" : "GPU" - @warn "Potential Rank Deficient Matrix Detected. But Matrix is \ - $(mat_desc). Currently, we don't attempt to solve Rank Deficient \ - $(mat_desc) Matrices. Please open an issue at \ - https://github.com/SciML/NonlinearSolve.jl" - end - end - end - return LinearSolveResult(; u = linres.u, success = false) - end - - return LinearSolveResult(; u = linres.u) -end - -@inline __update_A!(cache::LinearSolverCache, ::Nothing, reuse) = cache -@inline function __update_A!(cache::LinearSolverCache, A, reuse) - return __update_A!(cache, __getproperty(cache.lincache, Val(:alg)), A, reuse) -end -@inline function __update_A!(cache, alg, A, reuse) - # Not a Factorization Algorithm so don't update `nfactors` - __set_lincache_A(cache.lincache, A) - return cache -end -@inline function __update_A!(cache, ::AbstractFactorization, A, reuse) - reuse && return cache - __set_lincache_A(cache.lincache, A) - cache.stats.nfactors += 1 - return cache -end -@inline function __update_A!(cache, alg::DefaultLinearSolver, A, reuse) - if alg == DefaultLinearSolver(DefaultAlgorithmChoice.KrylovJL_GMRES) - # Force a reset of the cache. This is not properly handled in LinearSolve.jl - __set_lincache_A(cache.lincache, A) - return cache - end - reuse && return cache - __set_lincache_A(cache.lincache, A) - cache.stats.nfactors += 1 - return cache -end - -function __set_lincache_A(lincache, new_A) - if LinearSolve.default_alias_A(lincache.alg, new_A, lincache.b) - lincache.A = new_A - else - if can_setindex(lincache.A) - copyto!(lincache.A, new_A) - lincache.A = lincache.A - else - lincache.A = new_A - end - end -end - -function __wrapprecs(_Pl, _Pr, u) - Pl = _Pl !== nothing ? _Pl : IdentityOperator(length(u)) - Pr = _Pr !== nothing ? _Pr : IdentityOperator(length(u)) - return Pl, Pr -end - -@inline __needs_square_A(_, ::Number) = false -@inline __needs_square_A(::Nothing, ::Number) = false -@inline __needs_square_A(::Nothing, _) = false -@inline __needs_square_A(linsolve, _) = LinearSolve.needs_square_A(linsolve) -@inline __needs_square_A(::typeof(\), _) = false -@inline __needs_square_A(::typeof(\), ::Number) = false # Ambiguity Fix diff --git a/src/utils.jl b/src/utils.jl index 069fde86e..a080bd124 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,8 +2,6 @@ @inline DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, cachedata) = nothing, nothing # Helper Functions -@inline __hasfield(::T, ::Val{field}) where {T, field} = hasfield(T, field) - @generated function __getproperty(s::S, ::Val{X}) where {S, X} hasfield(S, X) && return :(s.$X) return :(missing)