From 0c2a193dfa5df0e71e3c09d0ca17968bcf91a21e Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Fri, 31 Dec 2021 11:11:41 -0500 Subject: [PATCH] do AbstractMatrix conversion before factorization --- Project.toml | 2 +- src/factorization.jl | 64 +++++++++++++------------------------------- 2 files changed, 19 insertions(+), 47 deletions(-) diff --git a/Project.toml b/Project.toml index 224814760..c0c2111f0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LinearSolve" uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" authors = ["SciML"] -version = "1.2.4" +version = "1.2.5" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/src/factorization.jl b/src/factorization.jl index 9d5d48008..339dde9f3 100644 --- a/src/factorization.jl +++ b/src/factorization.jl @@ -17,7 +17,7 @@ end # Bad fallback: will fail if `A` is just a stand-in # This should instead just create the factorization type. -init_cacheval(alg::AbstractFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = do_factorization(alg, A, b, u) +init_cacheval(alg::AbstractFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = do_factorization(alg, convert(AbstractMatrix,A), b, u) ## LU Factorizations @@ -35,21 +35,16 @@ function LUFactorization() end function do_factorization(alg::LUFactorization, A, b, u) - A isa Union{AbstractMatrix,AbstractDiffEqOperator} || - error("LU is not defined for $(typeof(A))") - - if A isa DiffEqArrayOperator - A = A.A - end + A = convert(AbstractMatrix,A) if A isa SparseMatrixCSC - fact = lu(A, alg.pivot) + return lu(A) else fact = lu!(A, alg.pivot) end return fact end -init_cacheval(alg::LUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A) +init_cacheval(alg::LUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A)) # This could be a GenericFactorization perhaps? Base.@kwdef struct UMFPACKFactorization <: AbstractFactorization @@ -57,6 +52,7 @@ Base.@kwdef struct UMFPACKFactorization <: AbstractFactorization end function init_cacheval(alg::UMFPACKFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) + A = convert(AbstractMatrix,A) zerobased = SparseArrays.getcolptr(A)[1] == 0 res = SuiteSparse.UMFPACK.UmfpackLU(C_NULL, C_NULL, size(A, 1), size(A, 2), zerobased ? copy(SparseArrays.getcolptr(A)) : SuiteSparse.decrement(SparseArrays.getcolptr(A)), @@ -67,9 +63,7 @@ function init_cacheval(alg::UMFPACKFactorization, A, b, u, Pl, Pr, maxiters, abs end function do_factorization(::UMFPACKFactorization, A, b, u) - if A isa DiffEqArrayOperator - A = A.A - end + A = convert(AbstractMatrix,A) if A isa SparseMatrixCSC return lu(A) else @@ -79,9 +73,7 @@ end function SciMLBase.solve(cache::LinearCache, alg::UMFPACKFactorization) A = cache.A - if A isa DiffEqArrayOperator - A = A.A - end + A = convert(AbstractMatrix,A) if cache.isfresh if cache.cacheval !== nothing && alg.reuse_symbolic # If we have a cacheval already, run umfpack_symbolic to ensure the symbolic factorization exists @@ -103,13 +95,11 @@ Base.@kwdef struct KLUFactorization <: AbstractFactorization end function init_cacheval(alg::KLUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) - return KLU.KLUFactorization(A) # this takes care of the copy internally. + return KLU.KLUFactorization(convert(AbstractMatrix,A)) # this takes care of the copy internally. end function do_factorization(::KLUFactorization, A, b, u) - if A isa DiffEqArrayOperator - A = A.A - end + A = convert(AbstractMatrix,A) if A isa SparseMatrixCSC return klu(A) else @@ -119,9 +109,7 @@ end function SciMLBase.solve(cache::LinearCache, alg::KLUFactorization) A = cache.A - if A isa DiffEqArrayOperator - A = A.A - end + A = convert(AbstractMatrix,A) if cache.isfresh if cache.cacheval !== nothing && alg.reuse_symbolic # If we have a cacheval already, run umfpack_symbolic to ensure the symbolic factorization exists @@ -159,12 +147,7 @@ function QRFactorization(inplace = true) end function do_factorization(alg::QRFactorization, A, b, u) - A isa Union{AbstractMatrix,AbstractDiffEqOperator} || - error("QR is not defined for $(typeof(A))") - - if A isa DiffEqArrayOperator - A = A.A - end + A = convert(AbstractMatrix,A) if alg.inplace fact = qr!(A, alg.pivot) else @@ -183,13 +166,7 @@ end SVDFactorization() = SVDFactorization(false, LinearAlgebra.DivideAndConquer()) function do_factorization(alg::SVDFactorization, A, b, u) - A isa Union{AbstractMatrix,AbstractDiffEqOperator} || - error("SVD is not defined for $(typeof(A))") - - if A isa DiffEqArrayOperator - A = A.A - end - + A = convert(AbstractMatrix,A) fact = svd!(A; full = alg.full, alg = alg.alg) return fact end @@ -204,18 +181,13 @@ GenericFactorization(;fact_alg = LinearAlgebra.factorize) = GenericFactorization(fact_alg) function do_factorization(alg::GenericFactorization, A, b, u) - A isa Union{AbstractMatrix,AbstractDiffEqOperator} || - error("GenericFactorization is not defined for $(typeof(A))") - - if A isa DiffEqArrayOperator - A = A.A - end + A = convert(AbstractMatrix,A) fact = alg.fact_alg(A) return fact end -init_cacheval(alg::GenericFactorization{typeof(lu)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A) -init_cacheval(alg::GenericFactorization{typeof(lu!)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A) +init_cacheval(alg::GenericFactorization{typeof(lu)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A)) +init_cacheval(alg::GenericFactorization{typeof(lu!)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A)) init_cacheval(alg::GenericFactorization{typeof(lu)}, A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A) init_cacheval(alg::GenericFactorization{typeof(lu!)}, A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A) @@ -245,7 +217,7 @@ end # Fallback, tries to make nonsingular and just factorizes # Try to never use it. function init_cacheval(alg::Union{QRFactorization,SVDFactorization,GenericFactorization}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) - newA = copy(A) + newA = copy(convert(AbstractMatrix,A)) fill!(newA,true) do_factorization(alg, newA, b, u) end @@ -253,5 +225,5 @@ end ## RFLUFactorization RFLUFactorization() = GenericFactorization(;fact_alg=RecursiveFactorization.lu!) -init_cacheval(alg::GenericFactorization{typeof(RecursiveFactorization.lu!)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A) -init_cacheval(alg::GenericFactorization{typeof(RecursiveFactorization.lu!)}, A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A) +init_cacheval(alg::GenericFactorization{typeof(RecursiveFactorization.lu!)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A)) +init_cacheval(alg::GenericFactorization{typeof(RecursiveFactorization.lu!)}, A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A))