diff --git a/Project.toml b/Project.toml
index 90e7278af..7aac416a5 100644
--- a/Project.toml
+++ b/Project.toml
@@ -17,6 +17,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
 RecursiveFactorization = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
+SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
 Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
 SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
 SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -43,7 +44,8 @@ KrylovKit = "0.5, 0.6"
 Preferences = "1"
 RecursiveFactorization = "0.2.8"
 Reexport = "1"
-SciMLBase = "1.68"
+SciMLBase = "1.82"
+SciMLOperators = "0.1.19"
 Setfield = "0.7, 0.8, 1"
 SnoopPrecompile = "1"
 Sparspak = "0.3.6"
diff --git a/ext/LinearSolveHYPRE.jl b/ext/LinearSolveHYPRE.jl
index f8bea7d7d..53ffe2d1f 100644
--- a/ext/LinearSolveHYPRE.jl
+++ b/ext/LinearSolveHYPRE.jl
@@ -4,7 +4,7 @@ using HYPRE.LibHYPRE: HYPRE_Complex
 using HYPRE: HYPRE, HYPREMatrix, HYPRESolver, HYPREVector
 using IterativeSolvers: Identity
 using LinearSolve: HYPREAlgorithm, LinearCache, LinearProblem, LinearSolve,
-                   OperatorAssumptions, default_tol, init_cacheval, issquare, set_cacheval
+                   OperatorAssumptions, default_tol, init_cacheval, __issquare, set_cacheval
 using SciMLBase: LinearProblem, SciMLBase
 using UnPack: @unpack
 using Setfield: @set!
@@ -82,7 +82,7 @@ function SciMLBase.init(prob::LinearProblem, alg::HYPREAlgorithm,
     cache = LinearCache{
                         typeof(A), typeof(b), typeof(u0), typeof(p), typeof(alg), Tc,
-                        typeof(Pl), typeof(Pr), typeof(reltol), issquare(assumptions)
+                        typeof(Pl), typeof(Pr), typeof(reltol), __issquare(assumptions)
                         }(A, b, u0, p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol,
                           maxiters, verbose, assumptions)
diff --git a/lib/LinearSolveCUDA/src/LinearSolveCUDA.jl b/lib/LinearSolveCUDA/src/LinearSolveCUDA.jl
index b228b22f9..9c8ba3da6 100644
--- a/lib/LinearSolveCUDA/src/LinearSolveCUDA.jl
+++ b/lib/LinearSolveCUDA/src/LinearSolveCUDA.jl
@@ -1,6 +1,7 @@
 module LinearSolveCUDA
 
 using CUDA, LinearAlgebra, LinearSolve, SciMLBase
+using SciMLBase: AbstractSciMLOperator
 
 struct CudaOffloadFactorization <: LinearSolve.AbstractFactorization end
 
@@ -17,12 +18,13 @@ function SciMLBase.solve(cache::LinearSolve.LinearCache, alg::CudaOffloadFactori
 end
 
 function LinearSolve.do_factorization(alg::CudaOffloadFactorization, A, b, u)
-    A isa Union{AbstractMatrix, SciMLBase.AbstractDiffEqOperator} ||
+    A isa Union{AbstractMatrix, AbstractSciMLOperator} ||
         error("LU is not defined for $(typeof(A))")
 
-    if A isa SciMLBase.AbstractDiffEqOperator
+    if A isa Union{MatrixOperator, DiffEqArrayOperator}
         A = A.A
     end
+
     fact = qr(CUDA.CuArray(A))
     return fact
 end
diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl
index 719ee5c7d..badd2a1ed 100644
--- a/src/LinearSolve.jl
+++ b/src/LinearSolve.jl
@@ -6,11 +6,12 @@ end
 using ArrayInterfaceCore
 using RecursiveFactorization
 using Base: cache_dependencies, Bool
-import Base: eltype, adjoint, inv
 using LinearAlgebra
 using IterativeSolvers: Identity
 using SparseArrays
-using SciMLBase: AbstractDiffEqOperator, AbstractLinearAlgorithm
+using SciMLBase: AbstractLinearAlgorithm
+using SciMLOperators
+using SciMLOperators: AbstractSciMLOperator, IdentityOperator
 using Setfield
 using UnPack
 using SuiteSparse
@@ -41,6 +42,15 @@ needs_concrete_A(alg::AbstractFactorization) = true
 needs_concrete_A(alg::AbstractKrylovSubspaceMethod) = false
 needs_concrete_A(alg::AbstractSolveFunction) = false
 
+# Util
+
+_isidentity_struct(A) = false
+_isidentity_struct(λ::Number) = isone(λ)
+_isidentity_struct(A::UniformScaling) = isone(A.λ)
+_isidentity_struct(::IterativeSolvers.Identity) = true
+_isidentity_struct(::SciMLOperators.IdentityOperator) = true
+_isidentity_struct(::SciMLBase.DiffEqIdentity) = true
+
 # Code
 const INCLUDE_SPARSE = Preferences.@load_preference("include_sparse", Base.USE_GPL_LIBS)
@@ -97,7 +107,7 @@ export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
        UMFPACKFactorization, KLUFactorization, FastLUFactorization, FastQRFactorization,
        SparspakFactorization, DiagonalFactorization
 
-export LinearSolveFunction
+export LinearSolveFunction, DirectLdiv!
 
 export KrylovJL, KrylovJL_CG, KrylovJL_MINRES, KrylovJL_GMRES,
        KrylovJL_BICGSTAB, KrylovJL_LSMR, KrylovJL_CRAIGMR,
diff --git a/src/common.jl b/src/common.jl
index 9114895a7..9d8e5a886 100644
--- a/src/common.jl
+++ b/src/common.jl
@@ -1,11 +1,11 @@
-struct OperatorAssumptions{issquare} end
+struct OperatorAssumptions{issq} end
 
 function OperatorAssumptions(issquare = nothing)
-    issquare = something(_unwrap_val(issquare), Nothing)
-    OperatorAssumptions{issquare}()
+    issq = something(_unwrap_val(issquare), Nothing)
+    OperatorAssumptions{issq}()
 end
-issquare(::OperatorAssumptions{issq}) where {issq} = issq
+__issquare(::OperatorAssumptions{issq}) where {issq} = issq
 
-struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issquare}
+struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issq}
     A::TA
     b::Tb
     u::Tu
@@ -19,7 +19,7 @@ struct LinearCache{TA, Tb, Tu, Tp, Talg, Tc, Tl, Tr, Ttol, issquare}
     reltol::Ttol
     maxiters::Int
     verbose::Bool
-    assumptions::OperatorAssumptions{issquare}
+    assumptions::OperatorAssumptions{issq}
 end
 
 """
@@ -92,9 +92,9 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
                         reltol = default_tol(eltype(prob.A)),
                         maxiters::Int = length(prob.b),
                         verbose::Bool = false,
-                        Pl = Identity(),
-                        Pr = Identity(),
-                        assumptions = OperatorAssumptions(),
+                        Pl = IdentityOperator{size(prob.A, 1)}(),
+                        Pr = IdentityOperator{size(prob.A, 2)}(),
+                        assumptions = OperatorAssumptions(Val(issquare(prob.A))),
                         kwargs...)
     @unpack A, b, u0, p = prob
 
@@ -129,7 +129,7 @@ function SciMLBase.init(prob::LinearProblem, alg::Union{SciMLLinearSolveAlgorith
                        typeof(Pl),
                        typeof(Pr),
                        typeof(reltol),
-                       issquare(assumptions)
+                       __issquare(assumptions)
                        }(A,
                          b,
                          u0,
diff --git a/src/default.jl b/src/default.jl
index b98bf177c..ca529cd89 100644
--- a/src/default.jl
+++ b/src/default.jl
@@ -2,26 +2,30 @@ # For SciML algorithms already using `defaultalg`, all assume square matrix.
 defaultalg(A, b) = defaultalg(A, b, OperatorAssumptions(Val(true)))
 
-function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions)
+function defaultalg(A::Union{DiffEqArrayOperator, MatrixOperator}, b,
+                    assumptions::OperatorAssumptions)
     defaultalg(A.A, b, assumptions)
 end
 
 # Ambiguity handling
-function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions{nothing})
+function defaultalg(A::Union{DiffEqArrayOperator, MatrixOperator}, b,
+                    assumptions::OperatorAssumptions{nothing})
     defaultalg(A.A, b, assumptions)
 end
-function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions{false})
+function defaultalg(A::Union{DiffEqArrayOperator, MatrixOperator}, b,
+                    assumptions::OperatorAssumptions{false})
     defaultalg(A.A, b, assumptions)
 end
-function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions{true})
+function defaultalg(A::Union{DiffEqArrayOperator, MatrixOperator}, b,
+                    assumptions::OperatorAssumptions{true})
    defaultalg(A.A, b, assumptions)
 end
 
 function defaultalg(A, b, ::OperatorAssumptions{Nothing})
-    issquare = size(A, 1) == size(A, 2)
-    defaultalg(A, b, OperatorAssumptions(Val(issquare)))
+    issq = issquare(A)
+    defaultalg(A, b, OperatorAssumptions(Val(issq)))
 end
 
 function defaultalg(A::Tridiagonal, b, ::OperatorAssumptions{true})
@@ -33,10 +37,13 @@ end
 function defaultalg(A::SymTridiagonal, b, ::OperatorAssumptions{true})
     GenericFactorization(; fact_alg = ldlt!)
 end
-function defaultalg(A::Diagonal, b, ::OperatorAssumptions{true})
-    DiagonalFactorization()
+function defaultalg(A::Bidiagonal, b, ::OperatorAssumptions{true})
+    DirectLdiv!()
+end
+function defaultalg(A::Factorization, b, ::OperatorAssumptions{true})
+    DirectLdiv!()
 end
-function defaultalg(A::Diagonal, b, ::OperatorAssumptions{false})
+function defaultalg(A::Diagonal, b, ::OperatorAssumptions{true})
     DiagonalFactorization()
 end
 function defaultalg(A::Diagonal, b, ::OperatorAssumptions{Nothing})
@@ -75,18 +82,26 @@ function defaultalg(A, b::GPUArraysCore.AbstractGPUArray, ::OperatorAssumptions{
     end
 end
 
-function defaultalg(A::SciMLBase.AbstractDiffEqOperator, b,
-                    assumptions::OperatorAssumptions)
+function defaultalg(A::SciMLBase.AbstractSciMLOperator, b,
+                    assumptions::OperatorAssumptions{true})
+    if has_ldiv!(A)
+        return DirectLdiv!()
+    end
+
     KrylovJL_GMRES()
 end
 
 # Ambiguity handling
-function defaultalg(A::SciMLBase.AbstractDiffEqOperator, b,
+function defaultalg(A::SciMLBase.AbstractSciMLOperator, b,
                     assumptions::OperatorAssumptions{Nothing})
+    if has_ldiv!(A)
+        return DirectLdiv!()
+    end
+
     KrylovJL_GMRES()
 end
-function defaultalg(A::SciMLBase.AbstractDiffEqOperator, b,
+function defaultalg(A::SciMLBase.AbstractSciMLOperator, b,
                     assumptions::OperatorAssumptions{false})
     m, n = size(A)
     if m < n
diff --git a/src/iterative_wrappers.jl b/src/iterative_wrappers.jl
index 7386e37b7..a2fa642c3 100644
--- a/src/iterative_wrappers.jl
+++ b/src/iterative_wrappers.jl
@@ -150,8 +150,9 @@ function SciMLBase.solve(cache::LinearCache, alg::KrylovJL; kwargs...)
     M = cache.Pl
     N = cache.Pr
 
-    M = (M === Identity()) ? I : InvPreconditioner(M)
-    N = (N === Identity()) ? I : InvPreconditioner(N)
+    # use no-op preconditioner for Krylov.jl (LinearAlgebra.I) when M/N is identity
+    M = _isidentity_struct(M) ? I : M
+    N = _isidentity_struct(N) ? I : N
 
     atol = float(cache.abstol)
     rtol = float(cache.reltol)
@@ -160,7 +161,7 @@
     args = (cache.cacheval, cache.A, cache.b)
     kwargs = (atol = atol, rtol = rtol, itmax = itmax, verbose = verbose,
-              history = true, alg.kwargs...)
+              ldiv = true, history = true, alg.kwargs...)
 
     if cache.cacheval isa Krylov.CgSolver
         N !== I &&
@@ -234,7 +235,7 @@ function init_cacheval(alg::IterativeSolversJL, A, b, u, Pl, Pr, maxiters::Int,
               alg.kwargs...)
 
     iterable = if alg.generate_iterator === IterativeSolvers.cg_iterator!
-        Pr !== Identity() &&
+        !_isidentity_struct(Pr) &&
            @warn "$(alg.generate_iterator) doesn't support right preconditioning"
         alg.generate_iterator(u, A, b, Pl;
                               kwargs...)
@@ -242,7 +243,7 @@ function init_cacheval(alg::IterativeSolversJL, A, b, u, Pl, Pr, maxiters::Int,
        alg.generate_iterator(u, A, b; Pl = Pl, Pr = Pr, restart = restart,
                              kwargs...)
    elseif alg.generate_iterator === IterativeSolvers.bicgstabl_iterator!
-        Pr !== Identity() &&
+        !_isidentity_struct(Pr) &&
            @warn "$(alg.generate_iterator) doesn't support right preconditioning"
        alg.generate_iterator(u, A, b, alg.args...; Pl = Pl, abstol = abstol,
                              reltol = reltol,
diff --git a/src/solve_function.jl b/src/solve_function.jl
index 3e3402fa1..7db502acf 100644
--- a/src/solve_function.jl
+++ b/src/solve_function.jl
@@ -13,3 +13,12 @@ function SciMLBase.solve(cache::LinearCache, alg::LinearSolveFunction,
 
     return SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
 end
+
+struct DirectLdiv! <: AbstractSolveFunction end
+
+function SciMLBase.solve(cache::LinearCache, alg::DirectLdiv!, args...; kwargs...)
+    @unpack A, b, u = cache
+    ldiv!(u, A, b)
+
+    return SciMLBase.build_linear_solution(alg, cache.u, nothing, cache)
+end
diff --git a/test/basictests.jl b/test/basictests.jl
index dfe6cb707..1eef66306 100644
--- a/test/basictests.jl
+++ b/test/basictests.jl
@@ -1,4 +1,5 @@
 using LinearSolve, LinearAlgebra, SparseArrays, MultiFloats, ForwardDiff
+using SciMLOperators
 using Test
 import Random
@@ -27,20 +28,20 @@ function test_interface(alg, prob1, prob2)
     b2 = prob2.b
     x2 = prob2.u0
 
-    y = solve(prob1, alg; cache_kwargs...)
-    @test A1 * y ≈ b1
+    sol = solve(prob1, alg; cache_kwargs...)
+    @test A1 * sol.u ≈ b1
 
     cache = SciMLBase.init(prob1, alg; cache_kwargs...) # initialize cache
-    y = solve(cache)
-    @test A1 * y ≈ b1
+    sol = solve(cache)
+    @test A1 * sol.u ≈ b1
 
-    cache = LinearSolve.set_A(cache, copy(A2))
-    y = solve(cache; cache_kwargs...)
-    @test A2 * y ≈ b1
+    cache = LinearSolve.set_A(cache, deepcopy(A2))
+    sol = solve(cache; cache_kwargs...)
+    @test A2 * sol.u ≈ b1
 
     cache = LinearSolve.set_b(cache, b2)
-    y = solve(cache; cache_kwargs...)
-    @test A2 * y ≈ b2
+    sol = solve(cache; cache_kwargs...)
+    @test A2 * sol.u ≈ b2
 
     return
 end
@@ -271,12 +272,14 @@ end
 @testset "Preconditioners" begin
     @testset "Vector Diagonal Preconditioner" begin
-        s = rand(n)
-        Pl, Pr = Diagonal(s), LinearSolve.InvPreconditioner(Diagonal(s))
-
         x = rand(n, n)
         y = rand(n, n)
 
+        s = rand(n)
+        Pl = Diagonal(s) |> MatrixOperator
+        Pr = Diagonal(s) |> MatrixOperator |> inv
+        Pr = cache_operator(Pr, x)
+
         mul!(y, Pl, x)
         @test y ≈ s .* x
         mul!(y, Pr, x)
@@ -353,26 +356,58 @@ end
     b2 = rand(n)
     x2 = zero(b1)
 
-    function sol_func(A, b, u, p, newA, Pl, Pr, solverdata; verbose = true, kwargs...)
-        if verbose == true
-            println("out-of-place solve")
+    @testset "LinearSolveFunction" begin
+        function sol_func(A, b, u, p, newA, Pl, Pr, solverdata; verbose = true,
+                          kwargs...)
+            if verbose == true
+                println("out-of-place solve")
+            end
+            u = A \ b
         end
-        u = A \ b
-    end
 
-    function sol_func!(A, b, u, p, newA, Pl, Pr, solverdata; verbose = true, kwargs...)
-        if verbose == true
-            println("in-place solve")
+        function sol_func!(A, b, u, p, newA, Pl, Pr, solverdata; verbose = true,
+                           kwargs...)
+            if verbose == true
+                println("in-place solve")
+            end
+            ldiv!(u, A, b)
+        end
+
+        prob1 = LinearProblem(A1, b1; u0 = x1)
+        prob2 = LinearProblem(A1, b1; u0 = x1)
+
+        for alg in (LinearSolveFunction(sol_func),
+                    LinearSolveFunction(sol_func!))
+            test_interface(alg, prob1, prob2)
         end
-        ldiv!(u, A, b)
     end
-
-    prob1 = LinearProblem(A1, b1; u0 = x1)
-    prob2 = LinearProblem(A1, b1; u0 = x1)
+    @testset "DirectLdiv!" begin
+        function get_operator(A, u)
+            function f(du, u, p, t)
+                println("using FunctionOperator mul!")
+                mul!(du, A, u)
+            end
-    for alg in (LinearSolveFunction(sol_func),
-                LinearSolveFunction(sol_func!))
-        test_interface(alg, prob1, prob2)
+            function fi(du, u, p, t)
+                println("using FunctionOperator ldiv!")
+                ldiv!(du, A, u)
+            end
+
+            FunctionOperator(f, u, u; isinplace = true, op_inverse = fi)
+        end
+
+        op1 = get_operator(A1, x1 * 0)
+        op2 = get_operator(A2, x2 * 0)
+
+        prob1 = LinearProblem(op1, b1; u0 = x1)
+        prob2 = LinearProblem(op2, b2; u0 = x2)
+
+        @test LinearSolve.defaultalg(op1, x1) isa DirectLdiv!
+        @test LinearSolve.defaultalg(op2, x2) isa DirectLdiv!
+
+        test_interface(DirectLdiv!(), prob1, prob2)
+        test_interface(nothing, prob1, prob2)
+    end
 end
 end # testset
diff --git a/test/runtests.jl b/test/runtests.jl
index b6d4a1d67..3a47522c9 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -26,6 +26,7 @@ if GROUP == "All" || GROUP == "Core"
     @time @safetestset "Non-Square Tests" begin include("nonsquare.jl") end
     @time @safetestset "SparseVector b Tests" begin include("sparse_vector.jl") end
     @time @safetestset "Default Alg Tests" begin include("default_algs.jl") end
+    @time @safetestset "Traits" begin include("traits.jl") end
 end
 
 if GROUP == "LinearSolveCUDA"
diff --git a/test/traits.jl b/test/traits.jl
new file mode 100644
index 000000000..fb9c8b63a
--- /dev/null
+++ b/test/traits.jl
@@ -0,0 +1,15 @@
+#
+using LinearSolve, LinearAlgebra, Test
+using LinearSolve: _isidentity_struct
+
+N = 4
+
+@testset "Traits" begin
+    @test _isidentity_struct(I)
+    @test _isidentity_struct(1.0 * I)
+    @test _isidentity_struct(SciMLBase.IdentityOperator{N}())
+    @test _isidentity_struct(SciMLBase.DiffEqIdentity(rand(4)))
+    @test !_isidentity_struct(2.0 * I)
+    @test !_isidentity_struct(rand(N, N))
+    @test !_isidentity_struct(Matrix(I, N, N))
+end
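
Usage sketches for review, illustrating the behavior introduced above; these are not part of the patch.

First, the new DirectLdiv! default for operators that carry their own inverse. The sketch mirrors the "DirectLdiv!" testset and assumes the SciMLOperators 0.1 FunctionOperator keyword API (isinplace, op_inverse) used there; the names apply_A!, solve_A!, op, prob, and sol are illustrative only.

using LinearSolve, LinearAlgebra, SciMLOperators

n = 4
A = Diagonal(rand(n) .+ 1.0)
b = rand(n)

# forward action and inverse action of the wrapped operator, both in-place
apply_A!(du, u, p, t) = mul!(du, A, u)    # du = A * u
solve_A!(du, u, p, t) = ldiv!(du, A, u)   # du = A \ u

op = FunctionOperator(apply_A!, zeros(n), zeros(n);
                      isinplace = true, op_inverse = solve_A!)

prob = LinearProblem(op, b)
sol = solve(prob)                    # defaultalg should pick DirectLdiv!() since has_ldiv!(op) is true
# sol = solve(prob, DirectLdiv!())   # equivalent, requested explicitly
A * sol.u ≈ b                        # exact up to the accuracy of solve_A!

The point of DirectLdiv! is that no factorization or Krylov iteration is set up at all: the cached u is overwritten by one ldiv! call on the operator itself.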
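Second, the preconditioner pass-through. With the Identity() defaults replaced by IdentityOperator and `ldiv = true` now forwarded to Krylov.jl, a user-supplied Pl/Pr is handed to Krylov.jl unwrapped and applied through ldiv!, so it should be an object whose ldiv! performs the preconditioner solve (a Diagonal, a factorization, a MatrixOperator, ...). The sketch below is an assumption-laden example, not library documentation: it uses a simple Jacobi left preconditioner with KrylovJL_GMRES.

using LinearSolve, LinearAlgebra

n = 100
A = 0.01 * rand(n, n) + I    # well-conditioned test matrix
b = rand(n)

# Jacobi (diagonal) left preconditioner; ldiv!(y, Pl, x) computes the
# preconditioner solve, which is what Krylov.jl expects with ldiv = true.
Pl = Diagonal(diag(A))

prob = LinearProblem(A, b)
sol = solve(prob, KrylovJL_GMRES(); Pl = Pl)
isapprox(A * sol.u, b; rtol = 1e-6)   # true up to the Krylov tolerance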
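Finally, a small sketch of the two helpers this PR leans on: `_isidentity_struct`, the structural identity check used by the Krylov wrapper, and the squareness assumption now derived from the operator via `issquare`. Both are internal (not exported), so they are accessed through the LinearSolve module here, as test/traits.jl does; the non-square call at the end is illustrative and assumes the existing generic OperatorAssumptions{false} fallback not shown in this diff.

using LinearSolve, LinearAlgebra
using LinearSolve: _isidentity_struct, OperatorAssumptions

# structural identity checks (cf. test/traits.jl)
_isidentity_struct(I)                # true: UniformScaling with λ == 1
_isidentity_struct(2.0 * I)          # false
_isidentity_struct(Matrix(I, 4, 4))  # false: only structural identities count

# squareness is read off the operator when no assumption is supplied
A = rand(4, 6)
b = rand(4)
LinearSolve.defaultalg(A, b, OperatorAssumptions(nothing))  # routes to the {false} (non-square) method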