Skip to content

Commit

Permalink
Merge pull request #83 from SciML/abstractmatrix
Browse files Browse the repository at this point in the history
do AbstractMatrix conversion before factorization
  • Loading branch information
ChrisRackauckas authored Dec 31, 2021
2 parents 8723422 + 0c2a193 commit dacd618
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 47 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LinearSolve"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
authors = ["SciML"]
version = "1.2.4"
version = "1.2.5"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
64 changes: 18 additions & 46 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ end

# Bad fallback: will fail if `A` is just a stand-in
# This should instead just create the factorization type.
init_cacheval(alg::AbstractFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = do_factorization(alg, A, b, u)
init_cacheval(alg::AbstractFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = do_factorization(alg, convert(AbstractMatrix,A), b, u)

## LU Factorizations

Expand All @@ -35,28 +35,24 @@ function LUFactorization()
end

function do_factorization(alg::LUFactorization, A, b, u)
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
error("LU is not defined for $(typeof(A))")

if A isa DiffEqArrayOperator
A = A.A
end
A = convert(AbstractMatrix,A)
if A isa SparseMatrixCSC
fact = lu(A, alg.pivot)
return lu(A)
else
fact = lu!(A, alg.pivot)
end
return fact
end

init_cacheval(alg::LUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A)
init_cacheval(alg::LUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A))

# This could be a GenericFactorization perhaps?
Base.@kwdef struct UMFPACKFactorization <: AbstractFactorization
reuse_symbolic::Bool = true
end

function init_cacheval(alg::UMFPACKFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
A = convert(AbstractMatrix,A)
zerobased = SparseArrays.getcolptr(A)[1] == 0
res = SuiteSparse.UMFPACK.UmfpackLU(C_NULL, C_NULL, size(A, 1), size(A, 2),
zerobased ? copy(SparseArrays.getcolptr(A)) : SuiteSparse.decrement(SparseArrays.getcolptr(A)),
Expand All @@ -67,9 +63,7 @@ function init_cacheval(alg::UMFPACKFactorization, A, b, u, Pl, Pr, maxiters, abs
end

function do_factorization(::UMFPACKFactorization, A, b, u)
if A isa DiffEqArrayOperator
A = A.A
end
A = convert(AbstractMatrix,A)
if A isa SparseMatrixCSC
return lu(A)
else
Expand All @@ -79,9 +73,7 @@ end

function SciMLBase.solve(cache::LinearCache, alg::UMFPACKFactorization)
A = cache.A
if A isa DiffEqArrayOperator
A = A.A
end
A = convert(AbstractMatrix,A)
if cache.isfresh
if cache.cacheval !== nothing && alg.reuse_symbolic
# If we have a cacheval already, run umfpack_symbolic to ensure the symbolic factorization exists
Expand All @@ -103,13 +95,11 @@ Base.@kwdef struct KLUFactorization <: AbstractFactorization
end

function init_cacheval(alg::KLUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
return KLU.KLUFactorization(A) # this takes care of the copy internally.
return KLU.KLUFactorization(convert(AbstractMatrix,A)) # this takes care of the copy internally.
end

function do_factorization(::KLUFactorization, A, b, u)
if A isa DiffEqArrayOperator
A = A.A
end
A = convert(AbstractMatrix,A)
if A isa SparseMatrixCSC
return klu(A)
else
Expand All @@ -119,9 +109,7 @@ end

function SciMLBase.solve(cache::LinearCache, alg::KLUFactorization)
A = cache.A
if A isa DiffEqArrayOperator
A = A.A
end
A = convert(AbstractMatrix,A)
if cache.isfresh
if cache.cacheval !== nothing && alg.reuse_symbolic
# If we have a cacheval already, run umfpack_symbolic to ensure the symbolic factorization exists
Expand Down Expand Up @@ -159,12 +147,7 @@ function QRFactorization(inplace = true)
end

function do_factorization(alg::QRFactorization, A, b, u)
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
error("QR is not defined for $(typeof(A))")

if A isa DiffEqArrayOperator
A = A.A
end
A = convert(AbstractMatrix,A)
if alg.inplace
fact = qr!(A, alg.pivot)
else
Expand All @@ -183,13 +166,7 @@ end
SVDFactorization() = SVDFactorization(false, LinearAlgebra.DivideAndConquer())

function do_factorization(alg::SVDFactorization, A, b, u)
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
error("SVD is not defined for $(typeof(A))")

if A isa DiffEqArrayOperator
A = A.A
end

A = convert(AbstractMatrix,A)
fact = svd!(A; full = alg.full, alg = alg.alg)
return fact
end
Expand All @@ -204,18 +181,13 @@ GenericFactorization(;fact_alg = LinearAlgebra.factorize) =
GenericFactorization(fact_alg)

function do_factorization(alg::GenericFactorization, A, b, u)
A isa Union{AbstractMatrix,AbstractDiffEqOperator} ||
error("GenericFactorization is not defined for $(typeof(A))")

if A isa DiffEqArrayOperator
A = A.A
end
A = convert(AbstractMatrix,A)
fact = alg.fact_alg(A)
return fact
end

init_cacheval(alg::GenericFactorization{typeof(lu)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A)
init_cacheval(alg::GenericFactorization{typeof(lu!)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A)
init_cacheval(alg::GenericFactorization{typeof(lu)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A))
init_cacheval(alg::GenericFactorization{typeof(lu!)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A))

init_cacheval(alg::GenericFactorization{typeof(lu)}, A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A)
init_cacheval(alg::GenericFactorization{typeof(lu!)}, A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A)
Expand Down Expand Up @@ -245,13 +217,13 @@ end
# Fallback, tries to make nonsingular and just factorizes
# Try to never use it.
function init_cacheval(alg::Union{QRFactorization,SVDFactorization,GenericFactorization}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
newA = copy(A)
newA = copy(convert(AbstractMatrix,A))
fill!(newA,true)
do_factorization(alg, newA, b, u)
end

## RFLUFactorization

RFLUFactorization() = GenericFactorization(;fact_alg=RecursiveFactorization.lu!)
init_cacheval(alg::GenericFactorization{typeof(RecursiveFactorization.lu!)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A)
init_cacheval(alg::GenericFactorization{typeof(RecursiveFactorization.lu!)}, A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(A)
init_cacheval(alg::GenericFactorization{typeof(RecursiveFactorization.lu!)}, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A))
init_cacheval(alg::GenericFactorization{typeof(RecursiveFactorization.lu!)}, A::StridedMatrix{<:LinearAlgebra.BlasFloat}, b, u, Pl, Pr, maxiters, abstol, reltol, verbose) = ArrayInterface.lu_instance(convert(AbstractMatrix,A))

0 comments on commit dacd618

Please sign in to comment.