From f2330a6ada7b8e7682c39be00821cadb618b86dd Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 16 Oct 2023 11:52:54 -0400 Subject: [PATCH] Defaults for Banded Matrices --- Project.toml | 5 +++- ext/LinearSolveBandedMatricesExt.jl | 38 +++++++++++++++++++++++++++++ src/common.jl | 2 +- src/default.jl | 2 +- 4 files changed, 44 insertions(+), 3 deletions(-) create mode 100644 ext/LinearSolveBandedMatricesExt.jl diff --git a/Project.toml b/Project.toml index ff2c96017..b9b30c1ef 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LinearSolve" uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" authors = ["SciML"] -version = "2.10.0" +version = "2.11.0" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -31,6 +31,7 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [weakdeps] +BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" @@ -42,6 +43,7 @@ Metal = "dde4c033-4e86-420c-a63e-0dd931031962" Pardiso = "46dd5b70-b6fb-5a00-ae2d-e8fea33afaf2" [extensions] +LinearSolveBandedMatricesExt = "BandedMatrices" LinearSolveBlockDiagonalsExt = "BlockDiagonals" LinearSolveCUDAExt = "CUDA" LinearSolveEnzymeExt = "Enzyme" @@ -54,6 +56,7 @@ LinearSolvePardisoExt = "Pardiso" [compat] ArrayInterface = "7.4.11" +BandedMatrices = "1" BlockDiagonals = "0.1" ConcreteStructs = "0.2" DocStringExtensions = "0.8, 0.9" diff --git a/ext/LinearSolveBandedMatricesExt.jl b/ext/LinearSolveBandedMatricesExt.jl new file mode 100644 index 000000000..fc441fcf7 --- /dev/null +++ b/ext/LinearSolveBandedMatricesExt.jl @@ -0,0 +1,38 @@ +module LinearSolveBandedMatricesExt + +using BandedMatrices, LinearAlgebra, LinearSolve +import LinearSolve: defaultalg, + do_factorization, init_cacheval, DefaultLinearSolver, DefaultAlgorithmChoice + +# Defaults for BandedMatrices +function defaultalg(A::BandedMatrix, b, ::OperatorAssumptions) + return DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization) +end + +function defaultalg(A::Symmetric{<:Number, <:BandedMatrix}, b, ::OperatorAssumptions) + return DefaultLinearSolver(DefaultAlgorithmChoice.CholeskyFactorization) +end + +# BandedMatrices `qr` doesn't allow other args without causing an ambiguity +do_factorization(alg::QRFactorization, A::BandedMatrix, b, u) = alg.inplace ? qr!(A) : qr(A) + +function do_factorization(alg::LUFactorization, A::BandedMatrix, b, u) + _pivot = alg.pivot isa NoPivot ? Val(false) : Val(true) + return lu!(A, _pivot; check = false) +end + +# Only init for `qr`, `ldlt` and `cholesky` +for alg in (:SVDFactorization, :LUFactorization, :MKLLUFactorization, + :DiagonalFactorization, :SparspakFactorization, :KLUFactorization, + :UMFPACKFactorization, :GenericLUFactorization, :RFLUFactorization, + :BunchKaufmanFactorization, :CHOLMODFactorization, :NormalCholeskyFactorization, + :AppleAccelerateLUFactorization) + @eval begin + function init_cacheval(::$(alg), ::BandedMatrix, b, u, Pl, Pr, maxiters::Int, + abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) + return nothing + end + end +end + +end diff --git a/src/common.jl b/src/common.jl index 953a82d3c..34160647c 100644 --- a/src/common.jl +++ b/src/common.jl @@ -14,7 +14,7 @@ However, in practice this computation is very expensive and thus not possible fo Therefore, OperatorCondition lets one share to LinearSolve the expected conditioning. The higher the expected condition number, the safer the algorithm needs to be and thus there is a trade-off between numerical performance and stability. By default the method assumes the operator may be ill-conditioned -for the standard linear solvers to converge (such as LU-factorization), though more extreme +for the standard linear solvers to converge (such as LU-factorization), though more extreme ill-conditioning or well-conditioning could be the case and specified through this assumption. """ EnumX.@enumx OperatorCondition begin diff --git a/src/default.jl b/src/default.jl index dfab2242d..04d9e0229 100644 --- a/src/default.jl +++ b/src/default.jl @@ -163,7 +163,7 @@ function defaultalg(A, b, assump::OperatorAssumptions) DefaultAlgorithmChoice.GenericLUFactorization elseif VERSION >= v"1.8" && appleaccelerate_isavailable() DefaultAlgorithmChoice.AppleAccelerateLUFactorization - elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500) || + elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500) || (usemkl && length(b) <= 200)) && (A === nothing ? eltype(b) <: Union{Float32, Float64} : eltype(A) <: Union{Float32, Float64})