-
-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: move LinearSolve wrapper into NonlinearSolveBase
- Loading branch information
Showing
19 changed files
with
363 additions
and
292 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,11 @@ | ||
name = "NonlinearSolveBase" | ||
uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0" | ||
authors = ["Avik Pal <[email protected]> and contributors"] | ||
version = "1.0.0" | ||
version = "1.1.0" | ||
|
||
[deps] | ||
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" | ||
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" | ||
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" | ||
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" | ||
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" | ||
|
@@ -15,22 +16,27 @@ FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" | |
FunctionProperties = "f62d2435-5019-4c03-9749-2d4c77af0cbc" | ||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | ||
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" | ||
MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb" | ||
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" | ||
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" | ||
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" | ||
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" | ||
|
||
[weakdeps] | ||
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" | ||
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" | ||
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" | ||
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" | ||
|
||
[extensions] | ||
NonlinearSolveBaseDiffEqBaseExt = "DiffEqBase" | ||
NonlinearSolveBaseForwardDiffExt = "ForwardDiff" | ||
NonlinearSolveBaseLinearSolveExt = "LinearSolve" | ||
NonlinearSolveBaseSparseArraysExt = "SparseArrays" | ||
|
||
[compat] | ||
ADTypes = "1.9" | ||
Adapt = "4.1.0" | ||
Aqua = "0.8.7" | ||
ArrayInterface = "7.9" | ||
CommonSolve = "0.2.4" | ||
|
@@ -45,9 +51,12 @@ ForwardDiff = "0.10.36" | |
FunctionProperties = "0.1.2" | ||
InteractiveUtils = "<0.0.1, 1" | ||
LinearAlgebra = "1.10" | ||
LinearSolve = "2.36.1" | ||
Markdown = "1.10" | ||
MaybeInplace = "0.1.4" | ||
RecursiveArrayTools = "3" | ||
SciMLBase = "2.50" | ||
SciMLOperators = "0.3.10" | ||
SparseArrays = "1.10" | ||
StaticArraysCore = "1.4" | ||
Test = "1.10" | ||
|
127 changes: 127 additions & 0 deletions
127
lib/NonlinearSolveBase/ext/NonlinearSolveBaseLinearSolveExt.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
module NonlinearSolveBaseLinearSolveExt | ||
|
||
using ArrayInterface: ArrayInterface | ||
using CommonSolve: CommonSolve, init, solve! | ||
using LinearAlgebra: ColumnNorm | ||
using LinearSolve: LinearSolve, QRFactorization | ||
using NonlinearSolveBase: NonlinearSolveBase, LinearSolveJLCache, LinearSolveResult, Utils | ||
using SciMLBase: ReturnCode, LinearProblem | ||
|
||
function (cache::LinearSolveJLCache)(; | ||
A = nothing, b = nothing, linu = nothing, du = nothing, p = nothing, | ||
cachedata = nothing, reuse_A_if_factorization = false, verbose = true, kwargs...) | ||
cache.stats.nsolve += 1 | ||
|
||
update_A!(cache, A, reuse_A_if_factorization) | ||
b !== nothing && setproperty!(cache.lincache, :b, b) | ||
linu !== nothing && NonlinearSolveBase.set_lincache_u!(cache, linu) | ||
|
||
Plprev = cache.lincache.Pl | ||
Prprev = cache.lincache.Pr | ||
|
||
if cache.precs === nothing | ||
Pl, Pr = nothing, nothing | ||
else | ||
Pl, Pr = cache.precs(cache.lincache.A, du, linu, p, nothing, | ||
A !== nothing, Plprev, Prprev, cachedata) | ||
end | ||
|
||
if Pl !== nothing || Pr !== nothing | ||
Pl, Pr = NonlinearSolveBase.wrap_preconditioners(Pl, Pr, linu) | ||
cache.lincache.Pl = Pl | ||
cache.lincache.Pr = Pr | ||
end | ||
|
||
linres = solve!(cache.lincache) | ||
cache.lincache = linres.cache | ||
# Unfortunately LinearSolve.jl doesn't have the most uniform ReturnCode handling | ||
if linres.retcode === ReturnCode.Failure | ||
structured_mat = ArrayInterface.isstructured(cache.lincache.A) | ||
is_gpuarray = ArrayInterface.device(cache.lincache.A) isa ArrayInterface.GPU | ||
|
||
if !(cache.linsolve isa QRFactorization{ColumnNorm}) && !is_gpuarray && | ||
!structured_mat | ||
if verbose | ||
@warn "Potential Rank Deficient Matrix Detected. Attempting to solve using \ | ||
Pivoted QR Factorization." | ||
end | ||
@assert (A !== nothing)&&(b !== nothing) "This case is not yet supported. \ | ||
Please open an issue at \ | ||
https://github.com/SciML/NonlinearSolve.jl" | ||
if cache.additional_lincache === nothing # First time | ||
linprob = LinearProblem(A, b; u0 = linres.u) | ||
cache.additional_lincache = init( | ||
linprob, QRFactorization(ColumnNorm()); alias_u0 = false, | ||
alias_A = false, alias_b = false, cache.lincache.Pl, cache.lincache.Pr) | ||
else | ||
cache.additional_lincache.A = A | ||
cache.additional_lincache.b = b | ||
cache.additional_lincache.Pl = cache.lincache.Pl | ||
cache.additional_lincache.Pr = cache.lincache.Pr | ||
end | ||
linres = solve!(cache.additional_lincache) | ||
cache.additional_lincache = linres.cache | ||
linres.retcode === ReturnCode.Failure && | ||
return LinearSolveResult(; linres.u, success = false) | ||
return LinearSolveResult(; linres.u) | ||
elseif !(cache.linsolve isa QRFactorization{ColumnNorm}) | ||
if verbose | ||
if structured_mat || is_gpuarray | ||
mat_desc = structured_mat ? "Structured" : "GPU" | ||
@warn "Potential Rank Deficient Matrix Detected. But Matrix is \ | ||
$(mat_desc). Currently, we don't attempt to solve Rank Deficient \ | ||
$(mat_desc) Matrices. Please open an issue at \ | ||
https://github.com/SciML/NonlinearSolve.jl" | ||
end | ||
end | ||
end | ||
return LinearSolveResult(; linres.u, success = false) | ||
end | ||
|
||
return LinearSolveResult(; linres.u) | ||
end | ||
|
||
NonlinearSolveBase.needs_square_A(linsolve, ::Any) = LinearSolve.needs_square_A(linsolve) | ||
|
||
update_A!(cache::LinearSolveJLCache, ::Nothing, reuse) = cache | ||
function update_A!(cache::LinearSolveJLCache, A, reuse) | ||
return update_A!(cache, Utils.safe_getproperty(cache.linsolve, Val(:alg)), A, reuse) | ||
end | ||
|
||
function update_A!(cache::LinearSolveJLCache, alg, A, reuse) | ||
# Not a Factorization Algorithm so don't update `nfactors` | ||
set_lincache_A!(cache.lincache, A) | ||
return cache | ||
end | ||
function update_A!(cache::LinearSolveJLCache, ::LinearSolve.AbstractFactorization, A, reuse) | ||
reuse && return cache | ||
set_lincache_A!(cache.lincache, A) | ||
cache.stats.nfactors += 1 | ||
return cache | ||
end | ||
function update_A!( | ||
cache::LinearSolveJLCache, alg::LinearSolve.DefaultLinearSolver, A, reuse) | ||
if alg == | ||
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES) | ||
# Force a reset of the cache. This is not properly handled in LinearSolve.jl | ||
set_lincache_A!(cache.lincache, A) | ||
return cache | ||
end | ||
reuse && return cache | ||
set_lincache_A!(cache.lincache, A) | ||
cache.stats.nfactors += 1 | ||
return cache | ||
end | ||
|
||
function set_lincache_A!(lincache, new_A) | ||
if !LinearSolve.default_alias_A(lincache.alg, new_A, lincache.b) && | ||
ArrayInterface.can_setindex(lincache.A) | ||
copyto!(lincache.A, new_A) | ||
lincache.A = lincache.A # important!! triggers special code in `setproperty!` | ||
return | ||
end | ||
lincache.A = new_A | ||
return | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
module InternalAPI | ||
|
||
function init end | ||
function solve! end | ||
|
||
end | ||
|
||
abstract type AbstractDescentDirection end | ||
|
||
supports_line_search(::AbstractDescentDirection) = false | ||
supports_trust_region(::AbstractDescentDirection) = false | ||
|
||
function get_linear_solver(alg::AbstractDescentDirection) | ||
return Utils.safe_getproperty(alg, Val(:linsolve)) | ||
end | ||
|
||
abstract type AbstractDescentCache end | ||
|
||
SciMLBase.get_du(cache::AbstractDescentCache) = cache.δu | ||
SciMLBase.get_du(cache::AbstractDescentCache, ::Val{1}) = SciMLBase.get_du(cache) | ||
SciMLBase.get_du(cache::AbstractDescentCache, ::Val{N}) where {N} = cache.δus[N - 1] | ||
set_du!(cache::AbstractDescentCache, δu) = (cache.δu = δu) | ||
set_du!(cache::AbstractDescentCache, δu, ::Val{1}) = set_du!(cache, δu) | ||
set_du!(cache::AbstractDescentCache, δu, ::Val{N}) where {N} = (cache.δus[N - 1] = δu) | ||
|
||
function last_step_accepted(cache::AbstractDescentCache) | ||
hasfield(typeof(cache), :last_step_accepted) && return cache.last_step_accepted | ||
return true | ||
end | ||
|
||
abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end | ||
|
||
get_name(alg::AbstractNonlinearSolveAlgorithm) = Utils.safe_getproperty(alg, Val(:name)) | ||
|
||
function concrete_jac(alg::AbstractNonlinearSolveAlgorithm) | ||
return concrete_jac(Utils.safe_getproperty(alg, Val(:concrete_jac))) | ||
end | ||
concrete_jac(::Missing) = missing | ||
concrete_jac(v::Bool) = v | ||
concrete_jac(::Val{false}) = false | ||
concrete_jac(::Val{true}) = true | ||
|
||
abstract type AbstractNonlinearSolveCache end | ||
|
||
abstract type AbstractLinearSolverCache end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
|
Oops, something went wrong.