Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support SciMLOperators in LinearSolve #270

Merged
merged 26 commits into from
Feb 5, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
0bb62a5
added scimloperators
vpuri3 Feb 2, 2023
2190beb
extended SciMLOperators.issquare
vpuri3 Feb 2, 2023
63396af
support abstractscimloperator, matrixoperator
vpuri3 Feb 3, 2023
d500b2c
support scimloperators in linearsolvecuda
vpuri3 Feb 3, 2023
1c4d194
rm issquare comment
vpuri3 Feb 3, 2023
9d8bb18
_isidentity_struct
vpuri3 Feb 3, 2023
49132bc
deprecate preconditioner interface and use scimloperators
vpuri3 Feb 3, 2023
5674434
use scimloperators.identityoperator in place of iterativesolves.identity
vpuri3 Feb 3, 2023
6f3ae84
direcldiv-bang
vpuri3 Feb 4, 2023
6d302c2
add methods defaultalg to support scimlops
vpuri3 Feb 4, 2023
14fe2a5
comments
vpuri3 Feb 4, 2023
a5cd365
comments
vpuri3 Feb 4, 2023
99a01fc
https://github.com/JuliaSmoothOptimizers/Krylov.jl/issues/612#issueco…
vpuri3 Feb 4, 2023
34f035d
directldiv-bang for bidiagonal, factorization
vpuri3 Feb 4, 2023
f32ddcc
rm defaultalg(A::Diagonal, b, ::OperatorAssumptions{false}) because d…
vpuri3 Feb 5, 2023
76ffd29
update compat
vpuri3 Feb 5, 2023
7828889
cleaning
vpuri3 Feb 5, 2023
adb2f08
format
vpuri3 Feb 5, 2023
a2a394b
put invpreconditioner back as it is used upstream. deprecate later
vpuri3 Feb 5, 2023
616e21c
lu(diagonal) not working in 1.6
vpuri3 Feb 5, 2023
c1c683a
fix test
vpuri3 Feb 5, 2023
4b8a773
readd IterativeSolvers.Identity. its being used downstream in ordinar…
vpuri3 Feb 5, 2023
16b76e9
__issquare
vpuri3 Feb 5, 2023
28f1781
__issquare fix in LinearSolveHYPRE
vpuri3 Feb 5, 2023
fc427b9
fix DirectLdiv<bang> test
vpuri3 Feb 5, 2023
c763a17
scimlbase compat
vpuri3 Feb 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down
6 changes: 4 additions & 2 deletions lib/LinearSolveCUDA/src/LinearSolveCUDA.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module LinearSolveCUDA

using CUDA, LinearAlgebra, LinearSolve, SciMLBase
using SciMLBase: AbstractSciMLOperator

struct CudaOffloadFactorization <: LinearSolve.AbstractFactorization end

Expand All @@ -17,12 +18,13 @@ function SciMLBase.solve(cache::LinearSolve.LinearCache, alg::CudaOffloadFactori
end

function LinearSolve.do_factorization(alg::CudaOffloadFactorization, A, b, u)
A isa Union{AbstractMatrix, SciMLBase.AbstractDiffEqOperator} ||
A isa Union{AbstractMatrix, AbstractSciMLOperator} ||
error("LU is not defined for $(typeof(A))")

if A isa SciMLBase.AbstractDiffEqOperator
if A isa Union{MatrixOperator, DiffEqArrayOperator}
A = A.A
end

fact = qr(CUDA.CuArray(A))
return fact
end
Expand Down
15 changes: 13 additions & 2 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import Base: eltype, adjoint, inv
using LinearAlgebra
using IterativeSolvers: Identity
using SparseArrays
using SciMLBase: AbstractDiffEqOperator, AbstractLinearAlgorithm
using SciMLBase: AbstractLinearAlgorithm, DiffEqIdentity
using SciMLOperators: AbstractSciMLOperator, IdentityOperator,
InvertedOperator, ComposedOperator
using Setfield
using UnPack
using SuiteSparse
Expand All @@ -20,6 +22,7 @@ using FastLapackInterface
using DocStringExtensions
import GPUArraysCore
import Preferences
import SciMLOperators: issquare

# wrap
import Krylov
Expand All @@ -41,6 +44,14 @@ needs_concrete_A(alg::AbstractFactorization) = true
needs_concrete_A(alg::AbstractKrylovSubspaceMethod) = false
needs_concrete_A(alg::AbstractSolveFunction) = false

# Util
_isidentity_struct(A) = false
_isidentity_struct(λ::Number) = isone(λ)
_isidentity_struct(A::UniformScaling) = isone(A.λ)
_isidentity_struct(::IterativeSolvers.Identity) = true
_isidentity_struct(::SciMLBase.IdentityOperator) = true
_isidentity_struct(::SciMLBase.DiffEqIdentity) = true

# Code

const INCLUDE_SPARSE = Preferences.@load_preference("include_sparse", Base.USE_GPL_LIBS)
Expand Down Expand Up @@ -97,7 +108,7 @@ export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
UMFPACKFactorization, KLUFactorization, FastLUFactorization, FastQRFactorization,
SparspakFactorization, DiagonalFactorization

export LinearSolveFunction
export LinearSolveFunction, DirectLdiv!

export KrylovJL, KrylovJL_CG, KrylovJL_MINRES, KrylovJL_GMRES,
KrylovJL_BICGSTAB, KrylovJL_LSMR, KrylovJL_CRAIGMR,
Expand Down
20 changes: 12 additions & 8 deletions src/common.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
struct OperatorAssumptions{issquare} end
struct OperatorAssumptions{issq} end
# TODO - in defaultalg selection, OperatorAssumptions{nothing} behaves
# exactly like OperatorAssumptions{true}. So let's remove the option to
# put in nothing.
function OperatorAssumptions(issquare = nothing)
issquare = something(_unwrap_val(issquare), Nothing)
OperatorAssumptions{issquare}()
issq = something(_unwrap_val(issquare), Nothing)
OperatorAssumptions{issq}()
end
issquare(::OperatorAssumptions{issq}) where {issq} = issq

struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issquare}
struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
A::TA
b::Tb
u::Tu
Expand All @@ -19,12 +22,13 @@ struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issquare}
reltol::Ttol
maxiters::Int
verbose::Bool
assumptions::OperatorAssumptions{issquare}
assumptions::OperatorAssumptions{issq}
end

"""
$(SIGNATURES)
"""
# TODO - this should modify OperatorAssumption??
function set_A(cache::LinearCache, A)
@set! cache.A = A
@set! cache.isfresh = true
Expand Down Expand Up @@ -92,9 +96,9 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
reltol = default_tol(eltype(prob.A)),
maxiters::Int = length(prob.b),
verbose::Bool = false,
Pl = Identity(),
Pr = Identity(),
assumptions = OperatorAssumptions(),
Pl = IdentityOperator{size(prob.A, 1)}(),
Pr = IdentityOperator{size(prob.A, 2)}(),
assumptions = OperatorAssumptions(issquare(prob.A)),
kwargs...)
@unpack A, b, u0, p = prob

Expand Down
35 changes: 25 additions & 10 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,26 @@
# For SciML algorithms already using `defaultalg`, all assume square matrix.
defaultalg(A, b) = defaultalg(A, b, OperatorAssumptions(Val(true)))

function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions)
function defaultalg(A::Union{DiffEqArrayOperator,MatrixOperator}, b, assumptions::OperatorAssumptions)
defaultalg(A.A, b, assumptions)
end

# Ambiguity handling
function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions{nothing})
function defaultalg(A::Union{DiffEqArrayOperator,MatrixOperator}, b, assumptions::OperatorAssumptions{nothing})
defaultalg(A.A, b, assumptions)
end

function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions{false})
function defaultalg(A::Union{DiffEqArrayOperator,MatrixOperator}, b, assumptions::OperatorAssumptions{false})
defaultalg(A.A, b, assumptions)
end

function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions{true})
function defaultalg(A::Union{DiffEqArrayOperator,MatrixOperator}, b, assumptions::OperatorAssumptions{true})
defaultalg(A.A, b, assumptions)
end

function defaultalg(A, b, ::OperatorAssumptions{Nothing})
issquare = size(A, 1) == size(A, 2)
defaultalg(A, b, OperatorAssumptions(Val(issquare)))
issq = issquare(A)
defaultalg(A, b, OperatorAssumptions(Val(issq)))
end

function defaultalg(A::Tridiagonal, b, ::OperatorAssumptions{true})
Expand All @@ -33,9 +33,16 @@ end
function defaultalg(A::SymTridiagonal, b, ::OperatorAssumptions{true})
GenericFactorization(; fact_alg = ldlt!)
end
function defaultalg(A::Bidiagonal, b, ::OperatorAssumptions{true})
DirectLdiv!()
end
function defaultalg(A::Factorization, b, ::OperatorAssumptions{true})
DirectLdiv!()
end
function defaultalg(A::Diagonal, b, ::OperatorAssumptions{true})
DiagonalFactorization()
end
# TODO - Diagonal matrices are always square
function defaultalg(A::Diagonal, b, ::OperatorAssumptions{false})
DiagonalFactorization()
end
Expand Down Expand Up @@ -75,18 +82,26 @@ function defaultalg(A, b::GPUArraysCore.AbstractGPUArray, ::OperatorAssumptions{
end
end

function defaultalg(A::SciMLBase.AbstractDiffEqOperator, b,
assumptions::OperatorAssumptions)
function defaultalg(A::SciMLBase.AbstractSciMLOperator, b,
assumptions::OperatorAssumptions{true})
if has_ldiv!(A)
return DirectLdiv!()
end

KrylovJL_GMRES()
end

# Ambiguity handling
function defaultalg(A::SciMLBase.AbstractDiffEqOperator, b,
function defaultalg(A::SciMLBase.AbstractSciMLOperator, b,
assumptions::OperatorAssumptions{Nothing})
if has_ldiv!(A)
return DirectLdiv!()
end

KrylovJL_GMRES()
end

function defaultalg(A::SciMLBase.AbstractDiffEqOperator, b,
function defaultalg(A::SciMLBase.AbstractSciMLOperator, b,
assumptions::OperatorAssumptions{false})
m, n = size(A)
if m < n
Expand Down
2 changes: 2 additions & 0 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,8 @@ struct DiagonalFactorization <: AbstractFactorization end

function init_cacheval(alg::DiagonalFactorization, A, b, u, Pl, Pr, maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
# TODO - DiagonalFactorization should preinvert, and mul!
#Diagonal(inv.(A.diag))
nothing
end

Expand Down
11 changes: 6 additions & 5 deletions src/iterative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,9 @@ function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
M = cache.Pl
N = cache.Pr

M = (M === Identity()) ? I : InvPreconditioner(M)
N = (N === Identity()) ? I : InvPreconditioner(N)
# use no-op preconditioner for Krylov.jl (LinearAlgebra.I) when M/N is identity
M = _isidentity_struct(M) ? I : M
N = _isidentity_struct(M) ? I : N

atol = float(cache.abstol)
rtol = float(cache.reltol)
Expand All @@ -160,7 +161,7 @@ function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)

args = (cache.cacheval, cache.A, cache.b)
kwargs = (atol = atol, rtol = rtol, itmax = itmax, verbose = verbose,
history = true, alg.kwargs...)
ldiv = true, history = true, alg.kwargs...)

if cache.cacheval isa Krylov.CgSolver
N !== I &&
Expand Down Expand Up @@ -234,15 +235,15 @@ function init_cacheval(alg::IterativeSolversJL, A, b, u, Pl, Pr, maxiters::Int,
alg.kwargs...)

iterable = if alg.generate_iterator === IterativeSolvers.cg_iterator!
Pr !== Identity() &&
! _isidentity_struct(Pr) &&
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
alg.generate_iterator(u, A, b, Pl;
kwargs...)
elseif alg.generate_iterator === IterativeSolvers.gmres_iterable!
alg.generate_iterator(u, A, b; Pl = Pl, Pr = Pr, restart = restart,
kwargs...)
elseif alg.generate_iterator === IterativeSolvers.bicgstabl_iterator!
Pr !== Identity() &&
! _isidentity_struct(Pr) &&
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
alg.generate_iterator(u, A, b, alg.args...; Pl = Pl,
abstol = abstol, reltol = reltol,
Expand Down
17 changes: 9 additions & 8 deletions src/preconditioners.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
# Tooling Preconditioners
#
# TODO - update Preconditioner docs
# TODO - replace ComposePreconditoner with ComposedOperator after
# ComposePreconditioner is deprecated in OrdinaryDiffEq

#const ComposePreconditioner = ComposedOperator
#@deprecate ComposePreconditioner ComposedOperator
vpuri3 marked this conversation as resolved.
Show resolved Hide resolved

struct ComposePreconditioner{Ti, To}
inner::Ti
Expand All @@ -21,11 +28,5 @@ function LinearAlgebra.ldiv!(y, A::ComposePreconditioner, x)
ldiv!(outer, y)
end

struct InvPreconditioner{T}
P::T
end

Base.eltype(A::InvPreconditioner) = Base.eltype(A.P)
LinearAlgebra.ldiv!(A::InvPreconditioner, x) = mul!(x, A.P, x)
LinearAlgebra.ldiv!(y, A::InvPreconditioner, x) = mul!(y, A.P, x)
LinearAlgebra.mul!(y, A::InvPreconditioner, x) = ldiv!(y, A.P, x)
#const InvPreconditioner = InvertedOperator
#@deprecate InvPreconditioner InvertedOperator
vpuri3 marked this conversation as resolved.
Show resolved Hide resolved
9 changes: 9 additions & 0 deletions src/solve_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,12 @@ function SciMLBase.solve(cache::LinearCache, alg::LinearSolveFunction,

return SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
end

struct DirectLdiv! <: AbstractSolveFunction end

function SciMLBase.solve(cache::LinearCache, alg::DirectLdiv!, args...; kwargs...)
@unpack A, b, u = cache
ldiv!(u, A, b)

return SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
end
Loading