Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support SciMLOperators in LinearSolve #270

Merged
merged 26 commits into from
Feb 5, 2023
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
0bb62a5
added scimloperators
vpuri3 Feb 2, 2023
2190beb
extended SciMLOperators.issquare
vpuri3 Feb 2, 2023
63396af
support abstractscimloperator, matrixoperator
vpuri3 Feb 3, 2023
d500b2c
support scimloperators in linearsolvecuda
vpuri3 Feb 3, 2023
1c4d194
rm issquare comment
vpuri3 Feb 3, 2023
9d8bb18
_isidentity_struct
vpuri3 Feb 3, 2023
49132bc
deprecate preconditioner interface and use scimloperators
vpuri3 Feb 3, 2023
5674434
use scimloperators.identityoperator in place of iterativesolves.identity
vpuri3 Feb 3, 2023
6f3ae84
direcldiv-bang
vpuri3 Feb 4, 2023
6d302c2
add methods defaultalg to support scimlops
vpuri3 Feb 4, 2023
14fe2a5
comments
vpuri3 Feb 4, 2023
a5cd365
comments
vpuri3 Feb 4, 2023
99a01fc
https://github.com/JuliaSmoothOptimizers/Krylov.jl/issues/612#issueco…
vpuri3 Feb 4, 2023
34f035d
directldiv-bang for bidiagonal, factorization
vpuri3 Feb 4, 2023
f32ddcc
rm defaultalg(A::Diagonal, b, ::OperatorAssumptions{false}) because d…
vpuri3 Feb 5, 2023
76ffd29
update compat
vpuri3 Feb 5, 2023
7828889
cleaning
vpuri3 Feb 5, 2023
adb2f08
format
vpuri3 Feb 5, 2023
a2a394b
put invpreconditioner back as it is used upstream. deprecate later
vpuri3 Feb 5, 2023
616e21c
lu(diagonal) not working in 1.6
vpuri3 Feb 5, 2023
c1c683a
fix test
vpuri3 Feb 5, 2023
4b8a773
readd IterativeSolvers.Identity. its being used downstream in ordinar…
vpuri3 Feb 5, 2023
16b76e9
__issquare
vpuri3 Feb 5, 2023
28f1781
__issquare fix in LinearSolveHYPRE
vpuri3 Feb 5, 2023
fc427b9
fix DirectLdiv! test
vpuri3 Feb 5, 2023
c763a17
scimlbase compat
vpuri3 Feb 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand All @@ -43,7 +44,8 @@ KrylovKit = "0.5, 0.6"
Preferences = "1"
RecursiveFactorization = "0.2.8"
Reexport = "1"
SciMLBase = "1.68"
SciMLBase = "1.81.1"
SciMLOperators = "0.1.19"
Setfield = "0.7, 0.8, 1"
SnoopPrecompile = "1"
Sparspak = "0.3.6"
Expand Down
4 changes: 2 additions & 2 deletions ext/LinearSolveHYPRE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using HYPRE.LibHYPRE: HYPRE_Complex
using HYPRE: HYPRE, HYPREMatrix, HYPRESolver, HYPREVector
using IterativeSolvers: Identity
using LinearSolve: HYPREAlgorithm, LinearCache, LinearProblem, LinearSolve,
OperatorAssumptions, default_tol, init_cacheval, issquare, set_cacheval
OperatorAssumptions, default_tol, init_cacheval, __issquare, set_cacheval
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we need to be careful with this release.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes. good we have CI test for 1.9

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tests passing with lts, release, beta locally.

using SciMLBase: LinearProblem, SciMLBase
using UnPack: @unpack
using Setfield: @set!
Expand Down Expand Up @@ -82,7 +82,7 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm,

cache = LinearCache{
typeof(A), typeof(b), typeof(u0), typeof(p), typeof(alg), Tc,
typeof(Pl), typeof(Pr), typeof(reltol), issquare(assumptions)
typeof(Pl), typeof(Pr), typeof(reltol), __issquare(assumptions)
}(A, b, u0, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
maxiters,
verbose, assumptions)
Expand Down
6 changes: 4 additions & 2 deletions lib/LinearSolveCUDA/src/LinearSolveCUDA.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module LinearSolveCUDA

using CUDA, LinearAlgebra, LinearSolve, SciMLBase
using SciMLBase: AbstractSciMLOperator

struct CudaOffloadFactorization <: LinearSolve.AbstractFactorization end

Expand All @@ -17,12 +18,13 @@ function SciMLBase.solve(cache::LinearSolve.LinearCache, alg::CudaOffloadFactori
end

function LinearSolve.do_factorization(alg::CudaOffloadFactorization, A, b, u)
A isa Union{AbstractMatrix, SciMLBase.AbstractDiffEqOperator} ||
A isa Union{AbstractMatrix, AbstractSciMLOperator} ||
error("LU is not defined for $(typeof(A))")

if A isa SciMLBase.AbstractDiffEqOperator
if A isa Union{MatrixOperator, DiffEqArrayOperator}
A = A.A
end

fact = qr(CUDA.CuArray(A))
return fact
end
Expand Down
16 changes: 13 additions & 3 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ end
using ArrayInterfaceCore
using RecursiveFactorization
using Base: cache_dependencies, Bool
import Base: eltype, adjoint, inv
using LinearAlgebra
using IterativeSolvers: Identity
using SparseArrays
using SciMLBase: AbstractDiffEqOperator, AbstractLinearAlgorithm
using SciMLBase: AbstractLinearAlgorithm
using SciMLOperators
using SciMLOperators: AbstractSciMLOperator, IdentityOperator
using Setfield
using UnPack
using SuiteSparse
Expand Down Expand Up @@ -41,6 +42,15 @@ needs_concrete_A(alg::AbstractFactorization) = true
needs_concrete_A(alg::AbstractKrylovSubspaceMethod) = false
needs_concrete_A(alg::AbstractSolveFunction) = false

# Util

# Structural identity check used to decide when a preconditioner can be
# replaced by the no-op `LinearAlgebra.I`. Returns `true` only when the
# value/type structurally guarantees identity behavior; the fallback is
# conservative and returns `false` for anything unrecognized.
_isidentity_struct(A) = false
# A scalar acts as the identity exactly when it is one.
_isidentity_struct(λ::Number) = isone(λ)
# `UniformScaling` is the identity when its scale factor is one.
_isidentity_struct(A::UniformScaling) = isone(A.λ)
# Dedicated identity types from the supported ecosystems are identity by type.
_isidentity_struct(::IterativeSolvers.Identity) = true
_isidentity_struct(::SciMLOperators.IdentityOperator) = true
_isidentity_struct(::SciMLBase.DiffEqIdentity) = true

# Code

const INCLUDE_SPARSE = Preferences.@load_preference("include_sparse", Base.USE_GPL_LIBS)
Expand Down Expand Up @@ -97,7 +107,7 @@ export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
UMFPACKFactorization, KLUFactorization, FastLUFactorization, FastQRFactorization,
SparspakFactorization, DiagonalFactorization

export LinearSolveFunction
export LinearSolveFunction, DirectLdiv!

export KrylovJL, KrylovJL_CG, KrylovJL_MINRES, KrylovJL_GMRES,
KrylovJL_BICGSTAB, KrylovJL_LSMR, KrylovJL_CRAIGMR,
Expand Down
20 changes: 10 additions & 10 deletions src/common.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
struct OperatorAssumptions{issquare} end
struct OperatorAssumptions{issq} end
function OperatorAssumptions(issquare = nothing)
issquare = something(_unwrap_val(issquare), Nothing)
OperatorAssumptions{issquare}()
issq = something(_unwrap_val(issquare), Nothing)
OperatorAssumptions{issq}()
end
issquare(::OperatorAssumptions{issq}) where {issq} = issq
__issquare(::OperatorAssumptions{issq}) where {issq} = issq

struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issquare}
struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
A::TA
b::Tb
u::Tu
Expand All @@ -19,7 +19,7 @@ struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issquare}
reltol::Ttol
maxiters::Int
verbose::Bool
assumptions::OperatorAssumptions{issquare}
assumptions::OperatorAssumptions{issq}
end

"""
Expand Down Expand Up @@ -92,9 +92,9 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
reltol = default_tol(eltype(prob.A)),
maxiters::Int = length(prob.b),
verbose::Bool = false,
Pl = Identity(),
Pr = Identity(),
assumptions = OperatorAssumptions(),
Pl = IdentityOperator{size(prob.A, 1)}(),
Pr = IdentityOperator{size(prob.A, 2)}(),
assumptions = OperatorAssumptions(Val(issquare(prob.A))),
kwargs...)
@unpack A, b, u0, p = prob

Expand Down Expand Up @@ -129,7 +129,7 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
typeof(Pl),
typeof(Pr),
typeof(reltol),
issquare(assumptions)
__issquare(assumptions)
}(A,
b,
u0,
Expand Down
41 changes: 28 additions & 13 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,30 @@
# For SciML algorithms already using `defaultalg`, all assume square matrix.
defaultalg(A, b) = defaultalg(A, b, OperatorAssumptions(Val(true)))

function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions)
function defaultalg(A::Union{DiffEqArrayOperator, MatrixOperator}, b,
assumptions::OperatorAssumptions)
defaultalg(A.A, b, assumptions)
end

# Ambiguity handling
function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions{nothing})
function defaultalg(A::Union{DiffEqArrayOperator, MatrixOperator}, b,
assumptions::OperatorAssumptions{nothing})
defaultalg(A.A, b, assumptions)
end

function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions{false})
function defaultalg(A::Union{DiffEqArrayOperator, MatrixOperator}, b,
assumptions::OperatorAssumptions{false})
defaultalg(A.A, b, assumptions)
end

function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions{true})
function defaultalg(A::Union{DiffEqArrayOperator, MatrixOperator}, b,
assumptions::OperatorAssumptions{true})
defaultalg(A.A, b, assumptions)
end

function defaultalg(A, b, ::OperatorAssumptions{Nothing})
issquare = size(A, 1) == size(A, 2)
defaultalg(A, b, OperatorAssumptions(Val(issquare)))
issq = issquare(A)
defaultalg(A, b, OperatorAssumptions(Val(issq)))
end

function defaultalg(A::Tridiagonal, b, ::OperatorAssumptions{true})
Expand All @@ -33,10 +37,13 @@ end
function defaultalg(A::SymTridiagonal, b, ::OperatorAssumptions{true})
GenericFactorization(; fact_alg = ldlt!)
end
function defaultalg(A::Diagonal, b, ::OperatorAssumptions{true})
DiagonalFactorization()
function defaultalg(A::Bidiagonal, b, ::OperatorAssumptions{true})
DirectLdiv!()
end
function defaultalg(A::Factorization, b, ::OperatorAssumptions{true})
DirectLdiv!()
end
function defaultalg(A::Diagonal, b, ::OperatorAssumptions{false})
function defaultalg(A::Diagonal, b, ::OperatorAssumptions{true})
DiagonalFactorization()
end
function defaultalg(A::Diagonal, b, ::OperatorAssumptions{Nothing})
Expand Down Expand Up @@ -75,18 +82,26 @@ function defaultalg(A, b::GPUArraysCore.AbstractGPUArray, ::OperatorAssumptions{
end
end

function defaultalg(A::SciMLBase.AbstractDiffEqOperator, b,
assumptions::OperatorAssumptions)
function defaultalg(A::SciMLBase.AbstractSciMLOperator, b,
assumptions::OperatorAssumptions{true})
if has_ldiv!(A)
return DirectLdiv!()
end

KrylovJL_GMRES()
end

# Ambiguity handling
function defaultalg(A::SciMLBase.AbstractDiffEqOperator, b,
function defaultalg(A::SciMLBase.AbstractSciMLOperator, b,
assumptions::OperatorAssumptions{Nothing})
if has_ldiv!(A)
return DirectLdiv!()
end

KrylovJL_GMRES()
end

function defaultalg(A::SciMLBase.AbstractDiffEqOperator, b,
function defaultalg(A::SciMLBase.AbstractSciMLOperator, b,
assumptions::OperatorAssumptions{false})
m, n = size(A)
if m < n
Expand Down
11 changes: 6 additions & 5 deletions src/iterative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,9 @@ function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
M = cache.Pl
N = cache.Pr

M = (M === Identity()) ? I : InvPreconditioner(M)
N = (N === Identity()) ? I : InvPreconditioner(N)
# use no-op preconditioner for Krylov.jl (LinearAlgebra.I) when M/N is identity
M = _isidentity_struct(M) ? I : M
N = _isidentity_struct(N) ? I : N

atol = float(cache.abstol)
rtol = float(cache.reltol)
Expand All @@ -160,7 +161,7 @@ function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)

args = (cache.cacheval, cache.A, cache.b)
kwargs = (atol = atol, rtol = rtol, itmax = itmax, verbose = verbose,
history = true, alg.kwargs...)
ldiv = true, history = true, alg.kwargs...)

if cache.cacheval isa Krylov.CgSolver
N !== I &&
Expand Down Expand Up @@ -234,15 +235,15 @@ function init_cacheval(alg::IterativeSolversJL, A, b, u, Pl, Pr, maxiters::Int,
alg.kwargs...)

iterable = if alg.generate_iterator === IterativeSolvers.cg_iterator!
Pr !== Identity() &&
!_isidentity_struct(Pr) &&
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
alg.generate_iterator(u, A, b, Pl;
kwargs...)
elseif alg.generate_iterator === IterativeSolvers.gmres_iterable!
alg.generate_iterator(u, A, b; Pl = Pl, Pr = Pr, restart = restart,
kwargs...)
elseif alg.generate_iterator === IterativeSolvers.bicgstabl_iterator!
Pr !== Identity() &&
!_isidentity_struct(Pr) &&
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
alg.generate_iterator(u, A, b, alg.args...; Pl = Pl,
abstol = abstol, reltol = reltol,
Expand Down
9 changes: 9 additions & 0 deletions src/solve_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,12 @@ function SciMLBase.solve(cache::LinearCache, alg::LinearSolveFunction,

return SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
end

# Solver algorithm that applies the operator's own left-division directly:
# solves `A * u = b` via the in-place 3-argument `ldiv!(u, A, b)` with no
# factorization step. Intended for operators/factorizations that already
# support `ldiv!` efficiently.
struct DirectLdiv! <: AbstractSolveFunction end

# Solve the cached linear problem in place with `ldiv!`.
# NOTE(review): assumes `cache.A` implements `ldiv!(u, A, b)`; no fallback
# is attempted here if the method is missing.
function SciMLBase.solve(cache::LinearCache, alg::DirectLdiv!, args...; kwargs...)
@unpack A, b, u = cache
ldiv!(u, A, b)

return SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
end
Loading