Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support SciMLOperators in LinearSolve #270

Merged
merged 26 commits into from
Feb 5, 2023
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
0bb62a5
added scimloperators
vpuri3 Feb 2, 2023
2190beb
extended SciMLOperators.issquare
vpuri3 Feb 2, 2023
63396af
support abstractscimloperator, matrixoperator
vpuri3 Feb 3, 2023
d500b2c
support scimloperators in linearsolvecuda
vpuri3 Feb 3, 2023
1c4d194
rm issquare comment
vpuri3 Feb 3, 2023
9d8bb18
_isidentity_struct
vpuri3 Feb 3, 2023
49132bc
deprecate preconditioner interface and use scimloperators
vpuri3 Feb 3, 2023
5674434
use scimloperators.identityoperator in place of iterativesolves.identity
vpuri3 Feb 3, 2023
6f3ae84
direcldiv-bang
vpuri3 Feb 4, 2023
6d302c2
add methods defaultalg to support scimlops
vpuri3 Feb 4, 2023
14fe2a5
comments
vpuri3 Feb 4, 2023
a5cd365
comments
vpuri3 Feb 4, 2023
99a01fc
https://github.com/JuliaSmoothOptimizers/Krylov.jl/issues/612#issueco…
vpuri3 Feb 4, 2023
34f035d
directldiv-bang for bidiagonal, factorization
vpuri3 Feb 4, 2023
f32ddcc
rm defaultalg(A::Diagonal, b, ::OperatorAssumptions{false}) because d…
vpuri3 Feb 5, 2023
76ffd29
update compat
vpuri3 Feb 5, 2023
7828889
cleaning
vpuri3 Feb 5, 2023
adb2f08
format
vpuri3 Feb 5, 2023
a2a394b
put invpreconditioner back as it is used upstream. deprecate later
vpuri3 Feb 5, 2023
616e21c
lu(diagonal) not working in 1.6
vpuri3 Feb 5, 2023
c1c683a
fix test
vpuri3 Feb 5, 2023
4b8a773
readd IterativeSolvers.Identity. its being used downstream in ordinar…
vpuri3 Feb 5, 2023
16b76e9
__issquare
vpuri3 Feb 5, 2023
28f1781
__issquare fix in LinearSolveHYPRE
vpuri3 Feb 5, 2023
fc427b9
fix DirectLdiv<bang> test
vpuri3 Feb 5, 2023
c763a17
scimlbase compat
vpuri3 Feb 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand All @@ -43,7 +44,8 @@ KrylovKit = "0.5, 0.6"
Preferences = "1"
RecursiveFactorization = "0.2.8"
Reexport = "1"
SciMLBase = "1.68"
SciMLBase = "1.81.1"
SciMLOperators = "0.1.19"
Setfield = "0.7, 0.8, 1"
SnoopPrecompile = "1"
Sparspak = "0.3.6"
Expand Down
6 changes: 4 additions & 2 deletions lib/LinearSolveCUDA/src/LinearSolveCUDA.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module LinearSolveCUDA

using CUDA, LinearAlgebra, LinearSolve, SciMLBase
using SciMLBase: AbstractSciMLOperator

struct CudaOffloadFactorization <: LinearSolve.AbstractFactorization end

Expand All @@ -17,12 +18,13 @@ function SciMLBase.solve(cache::LinearSolve.LinearCache, alg::CudaOffloadFactori
end

function LinearSolve.do_factorization(alg::CudaOffloadFactorization, A, b, u)
A isa Union{AbstractMatrix, SciMLBase.AbstractDiffEqOperator} ||
A isa Union{AbstractMatrix, AbstractSciMLOperator} ||
error("LU is not defined for $(typeof(A))")

if A isa SciMLBase.AbstractDiffEqOperator
if A isa Union{MatrixOperator, DiffEqArrayOperator}
A = A.A
end

fact = qr(CUDA.CuArray(A))
return fact
end
Expand Down
17 changes: 13 additions & 4 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ end
using ArrayInterfaceCore
using RecursiveFactorization
using Base: cache_dependencies, Bool
import Base: eltype, adjoint, inv
using LinearAlgebra
using IterativeSolvers: Identity
using SparseArrays
using SciMLBase: AbstractDiffEqOperator, AbstractLinearAlgorithm
using SciMLBase: AbstractLinearAlgorithm
using SciMLOperators
using SciMLOperators: AbstractSciMLOperator, IdentityOperator
using Setfield
using UnPack
using SuiteSparse
Expand Down Expand Up @@ -41,6 +41,15 @@ needs_concrete_A(alg::AbstractFactorization) = true
needs_concrete_A(alg::AbstractKrylovSubspaceMethod) = false
needs_concrete_A(alg::AbstractSolveFunction) = false

# Util

_isidentity_struct(A) = false
_isidentity_struct(λ::Number) = isone(λ)
_isidentity_struct(A::UniformScaling) = isone(A.λ)
_isidentity_struct(::IterativeSolvers.Identity) = true
_isidentity_struct(::SciMLOperators.IdentityOperator) = true
_isidentity_struct(::SciMLBase.DiffEqIdentity) = true

# Code

const INCLUDE_SPARSE = Preferences.@load_preference("include_sparse", Base.USE_GPL_LIBS)
Expand Down Expand Up @@ -97,7 +106,7 @@ export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
UMFPACKFactorization, KLUFactorization, FastLUFactorization, FastQRFactorization,
SparspakFactorization, DiagonalFactorization

export LinearSolveFunction
export LinearSolveFunction, DirectLdiv!

export KrylovJL, KrylovJL_CG, KrylovJL_MINRES, KrylovJL_GMRES,
KrylovJL_BICGSTAB, KrylovJL_LSMR, KrylovJL_CRAIGMR,
Expand Down
18 changes: 9 additions & 9 deletions src/common.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
struct OperatorAssumptions{issquare} end
struct OperatorAssumptions{issq} end
function OperatorAssumptions(issquare = nothing)
issquare = something(_unwrap_val(issquare), Nothing)
OperatorAssumptions{issquare}()
issq = something(_unwrap_val(issquare), Nothing)
OperatorAssumptions{issq}()
end
issquare(::OperatorAssumptions{issq}) where {issq} = issq
SciMLOperators.issquare(::OperatorAssumptions{issq}) where {issq} = issq
vpuri3 marked this conversation as resolved.
Show resolved Hide resolved

struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issquare}
struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
A::TA
b::Tb
u::Tu
Expand All @@ -19,7 +19,7 @@ struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issquare}
reltol::Ttol
maxiters::Int
verbose::Bool
assumptions::OperatorAssumptions{issquare}
assumptions::OperatorAssumptions{issq}
end

"""
Expand Down Expand Up @@ -92,9 +92,9 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
reltol = default_tol(eltype(prob.A)),
maxiters::Int = length(prob.b),
verbose::Bool = false,
Pl = Identity(),
Pr = Identity(),
assumptions = OperatorAssumptions(),
Pl = IdentityOperator{size(prob.A, 1)}(),
Pr = IdentityOperator{size(prob.A, 2)}(),
assumptions = OperatorAssumptions(issquare(prob.A)),
kwargs...)
@unpack A, b, u0, p = prob

Expand Down
41 changes: 28 additions & 13 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,30 @@
# For SciML algorithms already using `defaultalg`, all assume square matrix.
defaultalg(A, b) = defaultalg(A, b, OperatorAssumptions(Val(true)))

function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions)
function defaultalg(A::Union{DiffEqArrayOperator, MatrixOperator}, b,
assumptions::OperatorAssumptions)
defaultalg(A.A, b, assumptions)
end

# Ambiguity handling
function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions{nothing})
function defaultalg(A::Union{DiffEqArrayOperator, MatrixOperator}, b,
assumptions::OperatorAssumptions{nothing})
defaultalg(A.A, b, assumptions)
end

function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions{false})
function defaultalg(A::Union{DiffEqArrayOperator, MatrixOperator}, b,
assumptions::OperatorAssumptions{false})
defaultalg(A.A, b, assumptions)
end

function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions{true})
function defaultalg(A::Union{DiffEqArrayOperator, MatrixOperator}, b,
assumptions::OperatorAssumptions{true})
defaultalg(A.A, b, assumptions)
end

function defaultalg(A, b, ::OperatorAssumptions{Nothing})
issquare = size(A, 1) == size(A, 2)
defaultalg(A, b, OperatorAssumptions(Val(issquare)))
issq = issquare(A)
defaultalg(A, b, OperatorAssumptions(Val(issq)))
end

function defaultalg(A::Tridiagonal, b, ::OperatorAssumptions{true})
Expand All @@ -33,10 +37,13 @@ end
function defaultalg(A::SymTridiagonal, b, ::OperatorAssumptions{true})
GenericFactorization(; fact_alg = ldlt!)
end
function defaultalg(A::Diagonal, b, ::OperatorAssumptions{true})
DiagonalFactorization()
function defaultalg(A::Bidiagonal, b, ::OperatorAssumptions{true})
DirectLdiv!()
end
function defaultalg(A::Factorization, b, ::OperatorAssumptions{true})
DirectLdiv!()
end
function defaultalg(A::Diagonal, b, ::OperatorAssumptions{false})
function defaultalg(A::Diagonal, b, ::OperatorAssumptions{true})
DiagonalFactorization()
end
function defaultalg(A::Diagonal, b, ::OperatorAssumptions{Nothing})
Expand Down Expand Up @@ -75,18 +82,26 @@ function defaultalg(A, b::GPUArraysCore.AbstractGPUArray, ::OperatorAssumptions{
end
end

function defaultalg(A::SciMLBase.AbstractDiffEqOperator, b,
assumptions::OperatorAssumptions)
function defaultalg(A::SciMLBase.AbstractSciMLOperator, b,
assumptions::OperatorAssumptions{true})
if has_ldiv!(A)
return DirectLdiv!()
end

KrylovJL_GMRES()
end

# Ambiguity handling
function defaultalg(A::SciMLBase.AbstractDiffEqOperator, b,
function defaultalg(A::SciMLBase.AbstractSciMLOperator, b,
assumptions::OperatorAssumptions{Nothing})
if has_ldiv!(A)
return DirectLdiv!()
end

KrylovJL_GMRES()
end

function defaultalg(A::SciMLBase.AbstractDiffEqOperator, b,
function defaultalg(A::SciMLBase.AbstractSciMLOperator, b,
assumptions::OperatorAssumptions{false})
m, n = size(A)
if m < n
Expand Down
11 changes: 6 additions & 5 deletions src/iterative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,9 @@ function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
M = cache.Pl
N = cache.Pr

M = (M === Identity()) ? I : InvPreconditioner(M)
N = (N === Identity()) ? I : InvPreconditioner(N)
# use no-op preconditioner for Krylov.jl (LinearAlgebra.I) when M/N is identity
M = _isidentity_struct(M) ? I : M
N = _isidentity_struct(M) ? I : N

atol = float(cache.abstol)
rtol = float(cache.reltol)
Expand All @@ -160,7 +161,7 @@ function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)

args = (cache.cacheval, cache.A, cache.b)
kwargs = (atol = atol, rtol = rtol, itmax = itmax, verbose = verbose,
history = true, alg.kwargs...)
ldiv = true, history = true, alg.kwargs...)

if cache.cacheval isa Krylov.CgSolver
N !== I &&
Expand Down Expand Up @@ -234,15 +235,15 @@ function init_cacheval(alg::IterativeSolversJL, A, b, u, Pl, Pr, maxiters::Int,
alg.kwargs...)

iterable = if alg.generate_iterator === IterativeSolvers.cg_iterator!
Pr !== Identity() &&
!_isidentity_struct(Pr) &&
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
alg.generate_iterator(u, A, b, Pl;
kwargs...)
elseif alg.generate_iterator === IterativeSolvers.gmres_iterable!
alg.generate_iterator(u, A, b; Pl = Pl, Pr = Pr, restart = restart,
kwargs...)
elseif alg.generate_iterator === IterativeSolvers.bicgstabl_iterator!
Pr !== Identity() &&
!_isidentity_struct(Pr) &&
@warn "$(alg.generate_iterator) doesn't support right preconditioning"
alg.generate_iterator(u, A, b, alg.args...; Pl = Pl,
abstol = abstol, reltol = reltol,
Expand Down
9 changes: 9 additions & 0 deletions src/solve_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,12 @@ function SciMLBase.solve(cache::LinearCache, alg::LinearSolveFunction,

return SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
end

struct DirectLdiv! <: AbstractSolveFunction end

function SciMLBase.solve(cache::LinearCache, alg::DirectLdiv!, args...; kwargs...)
@unpack A, b, u = cache
ldiv!(u, A, b)

return SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
end
93 changes: 65 additions & 28 deletions test/basictests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LinearSolve, LinearAlgebra, SparseArrays, MultiFloats, ForwardDiff
using SciMLOperators
using Test
import Random

Expand Down Expand Up @@ -27,20 +28,20 @@ function test_interface(alg, prob1, prob2)
b2 = prob2.b
x2 = prob2.u0

y = solve(prob1, alg; cache_kwargs...)
@test A1 * y ≈ b1
sol = solve(prob1, alg; cache_kwargs...)
@test A1 * sol.u ≈ b1

cache = SciMLBase.init(prob1, alg; cache_kwargs...) # initialize cache
y = solve(cache)
@test A1 * y ≈ b1
sol = solve(cache)
@test A1 * sol.u ≈ b1

cache = LinearSolve.set_A(cache, copy(A2))
y = solve(cache; cache_kwargs...)
@test A2 * y ≈ b1
cache = LinearSolve.set_A(cache, deepcopy(A2))
vpuri3 marked this conversation as resolved.
Show resolved Hide resolved
sol = solve(cache; cache_kwargs...)
@test A2 * sol.u ≈ b1

cache = LinearSolve.set_b(cache, b2)
y = solve(cache; cache_kwargs...)
@test A2 * y ≈ b2
sol = solve(cache; cache_kwargs...)
@test A2 * sol.u ≈ b2

return
end
Expand Down Expand Up @@ -271,12 +272,14 @@ end

@testset "Preconditioners" begin
@testset "Vector Diagonal Preconditioner" begin
s = rand(n)
Pl, Pr = Diagonal(s), LinearSolve.InvPreconditioner(Diagonal(s))

x = rand(n, n)
y = rand(n, n)

s = rand(n)
Pl = Diagonal(s) |> MatrixOperator
Pr = Diagonal(s) |> MatrixOperator |> inv
Pr = cache_operator(Pr, x)

mul!(y, Pl, x)
@test y ≈ s .* x
mul!(y, Pr, x)
Expand Down Expand Up @@ -346,33 +349,67 @@ end
end

@testset "Solve Function" begin
A1 = rand(n) |> Diagonal
A1 = rand(n) |> Diagonal |> Array
b1 = rand(n)
x1 = zero(b1)
A2 = rand(n) |> Diagonal
A2 = rand(n) |> Diagonal |> Array
vpuri3 marked this conversation as resolved.
Show resolved Hide resolved
b2 = rand(n)
x2 = zero(b1)

function sol_func(A, b, u, p, newA, Pl, Pr, solverdata; verbose = true, kwargs...)
if verbose == true
println("out-of-place solve")
@testset "LinearSolveFunction" begin
function sol_func(A, b, u, p, newA, Pl, Pr, solverdata; verbose = true,
kwargs...)
if verbose == true
println("out-of-place solve")
end
u = A \ b
end
u = A \ b
end

function sol_func!(A, b, u, p, newA, Pl, Pr, solverdata; verbose = true, kwargs...)
if verbose == true
println("in-place solve")
function sol_func!(A, b, u, p, newA, Pl, Pr, solverdata; verbose = true,
kwargs...)
if verbose == true
println("in-place solve")
end
ldiv!(u, A, b)
end

prob1 = LinearProblem(A1, b1; u0 = x1)
prob2 = LinearProblem(A1, b1; u0 = x1)

for alg in (LinearSolveFunction(sol_func),
LinearSolveFunction(sol_func!))
test_interface(alg, prob1, prob2)
end
ldiv!(u, A, b)
end

prob1 = LinearProblem(A1, b1; u0 = x1)
prob2 = LinearProblem(A1, b1; u0 = x1)
@testset "DirectLdiv!" begin
function get_operator(A, u)
F = lu(A)

for alg in (LinearSolveFunction(sol_func),
LinearSolveFunction(sol_func!))
test_interface(alg, prob1, prob2)
function f(du, u, p, t)
println("using FunctionOperator mul!")
mul!(du, A, u)
end

function fi(du, u, p, t)
println("using FunctionOperator ldiv!")
ldiv!(du, F, u)
end

FunctionOperator(f, u, u; isinplace = true, op_inverse = fi)
end

op1 = get_operator(A1, x1 * 0)
op2 = get_operator(A2, x2 * 0)

prob1 = LinearProblem(op1, b1; u0 = x1)
prob2 = LinearProblem(op2, b2; u0 = x2)

@test LinearSolve.defaultalg(op1, x1) isa DirectLdiv!
@test LinearSolve.defaultalg(op2, x2) isa DirectLdiv!

test_interface(DirectLdiv!(), prob1, prob2)
test_interface(nothing, prob1, prob2)
end
end
end # testset
Loading