Skip to content

Commit

Permalink
refactor: move LinearSolve wrapper into NonlinearSolveBase
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 26, 2024
1 parent ecded14 commit 0cd904f
Show file tree
Hide file tree
Showing 17 changed files with 357 additions and 290 deletions.
11 changes: 10 additions & 1 deletion lib/NonlinearSolveBase/Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "NonlinearSolveBase"
uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.0.0"
version = "1.1.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand All @@ -15,22 +16,27 @@ FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
FunctionProperties = "f62d2435-5019-4c03-9749-2d4c77af0cbc"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[weakdeps]
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[extensions]
NonlinearSolveBaseDiffEqBaseExt = "DiffEqBase"
NonlinearSolveBaseForwardDiffExt = "ForwardDiff"
NonlinearSolveBaseLinearSolveExt = "LinearSolve"
NonlinearSolveBaseSparseArraysExt = "SparseArrays"

[compat]
ADTypes = "1.9"
Adapt = "4.1.0"
Aqua = "0.8.7"
ArrayInterface = "7.9"
CommonSolve = "0.2.4"
Expand All @@ -45,9 +51,12 @@ ForwardDiff = "0.10.36"
FunctionProperties = "0.1.2"
InteractiveUtils = "<0.0.1, 1"
LinearAlgebra = "1.10"
LinearSolve = "2.36.1"
Markdown = "1.10"
MaybeInplace = "0.1.4"
RecursiveArrayTools = "3"
SciMLBase = "2.50"
SciMLOperators = "0.3.10"
SparseArrays = "1.10"
StaticArraysCore = "1.4"
Test = "1.10"
Expand Down
124 changes: 124 additions & 0 deletions lib/NonlinearSolveBase/ext/NonlinearSolveBaseLinearSolveExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
module NonlinearSolveBaseLinearSolveExt

using ArrayInterface: ArrayInterface
using CommonSolve: CommonSolve, init, solve!
using LinearAlgebra: ColumnNorm
using LinearSolve: LinearSolve
using NonlinearSolveBase: NonlinearSolveBase, LinearSolveJLCache, LinearSolveResult, Utils
using SciMLBase: ReturnCode

function (cache::LinearSolveJLCache)(;
A = nothing, b = nothing, linu = nothing, du = nothing, p = nothing,
cachedata = nothing, reuse_A_if_factorization = false, verbose = true, kwargs...)
cache.stats.nsolve += 1

update_A!(cache, A, reuse_A_if_factorization)
b !== nothing && setproperty!(cache.lincache, :b, b)
linu !== nothing && NonlinearSolveBase.set_lincache_u!(cache, linu)

Plprev = cache.lincache.Pl
Prprev = cache.lincache.Pr

if cache.precs === nothing
Pl, Pr = nothing, nothing
else
Pl, Pr = cache.precs(cache.lincache.A, du, linu, p, nothing,
A !== nothing, Plprev, Prprev, cachedata)
end

if Pl !== nothing || Pr !== nothing
Pl, Pr = NonlinearSolveBase.wrap_preconditioners(Pl, Pr, linu)
cache.lincache.Pl = Pl
cache.lincache.Pr = Pr
end

linres = solve!(cache.lincache)
cache.lincache = linres.cache
# Unfortunately LinearSolve.jl doesn't have the most uniform ReturnCode handling
if linres.retcode === ReturnCode.Failure
structured_mat = ArrayInterface.isstructured(cache.lincache.A)
is_gpuarray = ArrayInterface.device(cache.lincache.A) isa ArrayInterface.GPU

if !(cache.linsolve isa LinearSolve.QRFactorization{ColumnNorm}) && !is_gpuarray &&
!structured_mat
if verbose
@warn "Potential Rank Deficient Matrix Detected. Attempting to solve using \
Pivoted QR Factorization."
end
@assert (A !== nothing)&&(b !== nothing) "This case is not yet supported. \
Please open an issue at \
https://github.com/SciML/NonlinearSolve.jl"
if cache.additional_lincache === nothing # First time
linprob = LinearProblem(A, b; u0 = linres.u)
cache.additional_lincache = init(
linprob, LinearSolve.QRFactorization(ColumnNorm()); alias_u0 = false,
alias_A = false, alias_b = false, cache.lincache.Pl, cache.lincache.Pr)
else
cache.additional_lincache.A = A
cache.additional_lincache.b = b
cache.additional_lincache.Pl = cache.lincache.Pl
cache.additional_lincache.Pr = cache.lincache.Pr
end
linres = solve!(cache.additional_lincache)
cache.additional_lincache = linres.cache
linres.retcode === ReturnCode.Failure &&
return LinearSolveResult(; linres.u, success = false)
return LinearSolveResult(; linres.u)
elseif !(cache.linsolve isa LinearSolve.QRFactorization{ColumnNorm})
if verbose
if structured_mat || is_gpuarray
mat_desc = structured_mat ? "Structured" : "GPU"
@warn "Potential Rank Deficient Matrix Detected. But Matrix is \
$(mat_desc). Currently, we don't attempt to solve Rank Deficient \
$(mat_desc) Matrices. Please open an issue at \
https://github.com/SciML/NonlinearSolve.jl"
end
end
end
return LinearSolveResult(; linres.u, success = false)
end

return LinearSolveResult(; linres.u)
end

NonlinearSolveBase.needs_square_A(linsolve, ::Any) = LinearSolve.needs_square_A(linsolve)

update_A!(cache::LinearSolveJLCache, ::Nothing, reuse) = cache
function update_A!(cache::LinearSolveJLCache, A, reuse)
return update_A!(cache, Utils.safe_getproperty(cache.linsolve, Val(:alg)), A, reuse)
end

function update_A!(cache::LinearSolveJLCache, alg, A, reuse)
# Not a Factorization Algorithm so don't update `nfactors`
set_lincache_A!(cache.lincache, A)
return cache
end
function update_A!(cache::LinearSolveJLCache, ::LinearSolve.AbstractFactorization, A, reuse)
reuse && return cache
set_lincache_A!(cache.lincache, A)
cache.stats.nfactors += 1
return cache
end
function update_A!(
cache::LinearSolveJLCache, alg::LinearSolve.DefaultLinearSolver, A, reuse)
if alg ==
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES)
# Force a reset of the cache. This is not properly handled in LinearSolve.jl
set_lincache_A!(cache.lincache, A)
return cache
end
reuse && return cache
set_lincache_A!(cache.lincache, A)
cache.stats.nfactors += 1
return cache
end

function set_lincache_A!(lincache, new_A)
if !LinearSolve.default_alias_A(lincache.alg, new_A, lincache.b) &&
ArrayInterface.can_setindex(lincache.A)
copyto!(lincache.A, new_A)
end
lincache.A = new_A # important!! triggers special code in `setproperty!`
end

end
24 changes: 18 additions & 6 deletions lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
Original file line number Diff line number Diff line change
@@ -1,33 +1,41 @@
module NonlinearSolveBase

using ADTypes: ADTypes, AbstractADType
using Adapt: WrappedArray
using ArrayInterface: ArrayInterface
using CommonSolve: CommonSolve
using CommonSolve: CommonSolve, init
using Compat: @compat
using ConcreteStructs: @concrete
using DifferentiationInterface: DifferentiationInterface
using EnzymeCore: EnzymeCore
using FastClosures: @closure
using FunctionProperties: hasbranching
using LinearAlgebra: norm
using LinearAlgebra: Diagonal, norm, ldiv!
using Markdown: @doc_str
using MaybeInplace: @bb
using RecursiveArrayTools: AbstractVectorOfArray, ArrayPartition
using SciMLBase: SciMLBase, ReturnCode, AbstractODEIntegrator, AbstractNonlinearProblem,
NonlinearProblem, NonlinearLeastSquaresProblem, AbstractNonlinearFunction,
@add_kwonly, StandardNonlinearProblem, NullParameters, isinplace,
warn_paramtype
using StaticArraysCore: StaticArray
AbstractNonlinearAlgorithm, AbstractNonlinearFunction,
NonlinearProblem, NonlinearLeastSquaresProblem, StandardNonlinearProblem,
NullParameters, NLStats, LinearProblem, isinplace, warn_paramtype,
@add_kwonly
using SciMLOperators: AbstractSciMLOperator, IdentityOperator
using StaticArraysCore: StaticArray, SMatrix, SArray, MArray

const DI = DifferentiationInterface

include("public.jl")
include("utils.jl")

include("abstract_types.jl")

include("immutable_problem.jl")
include("common_defaults.jl")
include("termination_conditions.jl")

include("autodiff.jl")
include("jacobian.jl")
include("linear_solve.jl")

# Unexported Public API
@compat(public, (L2_NORM, Linf_NORM, NAN_CHECK, UNITLESS_ABS2, get_tolerance))
Expand All @@ -36,6 +44,10 @@ include("autodiff.jl")
(select_forward_mode_autodiff, select_reverse_mode_autodiff,
select_jacobian_autodiff))

# public for NonlinearSolve.jl to use
@compat(public, (InternalAPI, supports_line_search, supports_trust_region, set_du!))
@compat(public, (construct_linear_solver, needs_square_A))

export RelTerminationMode, AbsTerminationMode,
NormTerminationMode, RelNormTerminationMode, AbsNormTerminationMode,
RelNormSafeTerminationMode, AbsNormSafeTerminationMode,
Expand Down
45 changes: 45 additions & 0 deletions lib/NonlinearSolveBase/src/abstract_types.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
module InternalAPI

function init end
function solve! end

end

abstract type AbstractDescentDirection end

supports_line_search(::AbstractDescentDirection) = false
supports_trust_region(::AbstractDescentDirection) = false

function get_linear_solver(alg::AbstractDescentDirection)
return Utils.safe_getproperty(alg, Val(:linsolve))
end

abstract type AbstractDescentCache end

SciMLBase.get_du(cache::AbstractDescentCache) = cache.δu
SciMLBase.get_du(cache::AbstractDescentCache, ::Val{1}) = SciMLBase.get_du(cache)
SciMLBase.get_du(cache::AbstractDescentCache, ::Val{N}) where {N} = cache.δus[N - 1]
set_du!(cache::AbstractDescentCache, δu) = (cache.δu = δu)
set_du!(cache::AbstractDescentCache, δu, ::Val{1}) = set_du!(cache, δu)
set_du!(cache::AbstractDescentCache, δu, ::Val{N}) where {N} = (cache.δus[N - 1] = δu)

function last_step_accepted(cache::AbstractDescentCache)
hasfield(typeof(cache), :last_step_accepted) && return cache.last_step_accepted
return true
end

abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end

get_name(alg::AbstractNonlinearSolveAlgorithm) = Utils.safe_getproperty(alg, Val(:name))

function concrete_jac(alg::AbstractNonlinearSolveAlgorithm)
return concrete_jac(Utils.safe_getproperty(alg, Val(:concrete_jac)))
end
concrete_jac(::Missing) = missing
concrete_jac(v::Bool) = v
concrete_jac(::Val{false}) = false
concrete_jac(::Val{true}) = true

abstract type AbstractNonlinearSolveCache end

abstract type AbstractLinearSolverCache end
Empty file.
Loading

0 comments on commit 0cd904f

Please sign in to comment.