From d80fadedba79a9c29dd1021681cb16f26ac5f9af Mon Sep 17 00:00:00 2001 From: Arno Strouwen Date: Sun, 3 Dec 2023 18:09:36 +0100 Subject: [PATCH] start fixing ambiguities --- ext/LinearSolveBandedMatricesExt.jl | 4 +- ext/LinearSolveFastAlmostBandedMatricesExt.jl | 2 +- src/default.jl | 32 ++--- src/factorization.jl | 133 ++++++++++++++---- 4 files changed, 127 insertions(+), 44 deletions(-) diff --git a/ext/LinearSolveBandedMatricesExt.jl b/ext/LinearSolveBandedMatricesExt.jl index 1eaf0b712..c6d851894 100644 --- a/ext/LinearSolveBandedMatricesExt.jl +++ b/ext/LinearSolveBandedMatricesExt.jl @@ -5,7 +5,7 @@ import LinearSolve: defaultalg, do_factorization, init_cacheval, DefaultLinearSolver, DefaultAlgorithmChoice # Defaults for BandedMatrices -function defaultalg(A::BandedMatrix, b, oa::OperatorAssumptions) +function defaultalg(A::BandedMatrix, b, oa::OperatorAssumptions{Bool}) if oa.issq return DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!) elseif LinearSolve.is_underdetermined(A) @@ -15,7 +15,7 @@ function defaultalg(A::BandedMatrix, b, oa::OperatorAssumptions) end end -function defaultalg(A::Symmetric{<:Number, <:BandedMatrix}, b, ::OperatorAssumptions) +function defaultalg(A::Symmetric{<:Number, <:BandedMatrix}, b, ::OperatorAssumptions{Bool}) return DefaultLinearSolver(DefaultAlgorithmChoice.CholeskyFactorization) end diff --git a/ext/LinearSolveFastAlmostBandedMatricesExt.jl b/ext/LinearSolveFastAlmostBandedMatricesExt.jl index 187f2b454..080ec30f5 100644 --- a/ext/LinearSolveFastAlmostBandedMatricesExt.jl +++ b/ext/LinearSolveFastAlmostBandedMatricesExt.jl @@ -4,7 +4,7 @@ using FastAlmostBandedMatrices, LinearAlgebra, LinearSolve import LinearSolve: defaultalg, do_factorization, init_cacheval, DefaultLinearSolver, DefaultAlgorithmChoice -function defaultalg(A::AlmostBandedMatrix, b, oa::OperatorAssumptions) +function defaultalg(A::AlmostBandedMatrix, b, oa::OperatorAssumptions{Bool}) if oa.issq return DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!) else diff --git a/src/default.jl b/src/default.jl index 574d81885..caf0ed701 100644 --- a/src/default.jl +++ b/src/default.jl @@ -27,7 +27,7 @@ end defaultalg(A, b) = defaultalg(A, b, OperatorAssumptions(true)) function defaultalg(A::Union{DiffEqArrayOperator, MatrixOperator}, b, - assump::OperatorAssumptions) + assump::OperatorAssumptions{Bool}) defaultalg(A.A, b, assump) end @@ -36,7 +36,7 @@ function defaultalg(A, b, assump::OperatorAssumptions{Nothing}) defaultalg(A, b, OperatorAssumptions(issq, assump.condition)) end -function defaultalg(A::Tridiagonal, b, assump::OperatorAssumptions) +function defaultalg(A::Tridiagonal, b, assump::OperatorAssumptions{Bool}) if assump.issq DefaultLinearSolver(DefaultAlgorithmChoice.LUFactorization) else @@ -44,33 +44,33 @@ function defaultalg(A::Tridiagonal, b, assump::OperatorAssumptions) end end -function defaultalg(A::SymTridiagonal, b, ::OperatorAssumptions) +function defaultalg(A::SymTridiagonal, b, ::OperatorAssumptions{Bool}) DefaultLinearSolver(DefaultAlgorithmChoice.LDLtFactorization) end -function defaultalg(A::Bidiagonal, b, ::OperatorAssumptions) +function defaultalg(A::Bidiagonal, b, ::OperatorAssumptions{Bool}) DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!) end -function defaultalg(A::Factorization, b, ::OperatorAssumptions) +function defaultalg(A::Factorization, b, ::OperatorAssumptions{Bool}) DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!) end -function defaultalg(A::Diagonal, b, ::OperatorAssumptions) +function defaultalg(A::Diagonal, b, ::OperatorAssumptions{Bool}) DefaultLinearSolver(DefaultAlgorithmChoice.DiagonalFactorization) end -function defaultalg(A::Hermitian, b, ::OperatorAssumptions) +function defaultalg(A::Hermitian, b, ::OperatorAssumptions{Bool}) DefaultLinearSolver(DefaultAlgorithmChoice.CholeskyFactorization) end -function defaultalg(A::Symmetric{<:Number, <:Array}, b, ::OperatorAssumptions) +function defaultalg(A::Symmetric{<:Number, <:Array}, b, ::OperatorAssumptions{Bool}) DefaultLinearSolver(DefaultAlgorithmChoice.BunchKaufmanFactorization) end -function defaultalg(A::Symmetric{<:Number, <:SparseMatrixCSC}, b, ::OperatorAssumptions) +function defaultalg(A::Symmetric{<:Number, <:SparseMatrixCSC}, b, ::OperatorAssumptions{Bool}) DefaultLinearSolver(DefaultAlgorithmChoice.CHOLMODFactorization) end function defaultalg(A::AbstractSparseMatrixCSC{Tv, Ti}, b, - assump::OperatorAssumptions) where {Tv, Ti} + assump::OperatorAssumptions{Bool}) where {Tv, Ti} if assump.issq DefaultLinearSolver(DefaultAlgorithmChoice.SparspakFactorization) else @@ -80,7 +80,7 @@ end @static if INCLUDE_SPARSE function defaultalg(A::AbstractSparseMatrixCSC{<:Union{Float64, ComplexF64}, Ti}, b, - assump::OperatorAssumptions) where {Ti} + assump::OperatorAssumptions{Bool}) where {Ti} if assump.issq if length(b) <= 10_000 && length(nonzeros(A)) / length(A) < 2e-4 DefaultLinearSolver(DefaultAlgorithmChoice.KLUFactorization) @@ -93,7 +93,7 @@ end end end -function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, assump::OperatorAssumptions) +function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, assump::OperatorAssumptions{Bool}) if assump.condition === OperatorCondition.IllConditioned || !assump.issq DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization) else @@ -102,7 +102,7 @@ function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, assump::OperatorAssump end # A === nothing case -function defaultalg(A, b::GPUArraysCore.AbstractGPUArray, assump::OperatorAssumptions) +function defaultalg(A::Nothing, b::GPUArraysCore.AbstractGPUArray, assump::OperatorAssumptions{Bool}) if assump.condition === OperatorCondition.IllConditioned || !assump.issq DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization) else @@ -112,7 +112,7 @@ end # Ambiguity handling function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.AbstractGPUArray, - assump::OperatorAssumptions) + assump::OperatorAssumptions{Bool}) if assump.condition === OperatorCondition.IllConditioned || !assump.issq DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization) else @@ -121,7 +121,7 @@ function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.Abstract end function defaultalg(A::SciMLBase.AbstractSciMLOperator, b, - assump::OperatorAssumptions) + assump::OperatorAssumptions{Bool}) if has_ldiv!(A) return DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!) elseif !assump.issq @@ -137,7 +137,7 @@ function defaultalg(A::SciMLBase.AbstractSciMLOperator, b, end # Allows A === nothing as a stand-in for dense matrix -function defaultalg(A, b, assump::OperatorAssumptions) +function defaultalg(A, b, assump::OperatorAssumptions{Bool}) alg = if assump.issq # Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when # it makes sense according to the benchmarks, which is dependent on diff --git a/src/factorization.jl b/src/factorization.jl index 786b202c3..a65e1a062 100644 --- a/src/factorization.jl +++ b/src/factorization.jl @@ -483,15 +483,15 @@ function do_factorization(alg::GenericFactorization, A, b, u) return fact end -function init_cacheval(alg::GenericFactorization{typeof(lu)}, A, b, u, Pl, Pr, +function init_cacheval(alg::GenericFactorization{typeof(lu)}, A::AbstractMatrix, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) - ArrayInterface.lu_instance(convert(AbstractMatrix, A)) + ArrayInterface.lu_instance(A) end -function init_cacheval(alg::GenericFactorization{typeof(lu!)}, A, b, u, Pl, Pr, +function init_cacheval(alg::GenericFactorization{typeof(lu!)}, A::AbstractMatrix, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) - ArrayInterface.lu_instance(convert(AbstractMatrix, A)) + ArrayInterface.lu_instance(A) end function init_cacheval(alg::GenericFactorization{typeof(lu)}, @@ -526,16 +526,36 @@ function init_cacheval(alg::GenericFactorization{typeof(lu!)}, A::Tridiagonal, b assumptions::OperatorAssumptions) ArrayInterface.lu_instance(A) end +function init_cacheval(alg::GenericFactorization{typeof(lu!)}, A::SymTridiagonal{T, V}, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, + assumptions::OperatorAssumptions) where {T, V} + LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A) +end +function init_cacheval(alg::GenericFactorization{typeof(lu)}, A::SymTridiagonal{T, V}, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, + assumptions::OperatorAssumptions) where {T, V} + LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A) +end -function init_cacheval(alg::GenericFactorization{typeof(qr)}, A, b, u, Pl, Pr, +function init_cacheval(alg::GenericFactorization{typeof(qr)}, A::AbstractMatrix, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) - ArrayInterface.qr_instance(convert(AbstractMatrix, A)) + ArrayInterface.qr_instance(A) end -function init_cacheval(alg::GenericFactorization{typeof(qr!)}, A, b, u, Pl, Pr, +function init_cacheval(alg::GenericFactorization{typeof(qr!)}, A::AbstractMatrix, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) - ArrayInterface.qr_instance(convert(AbstractMatrix, A)) + ArrayInterface.qr_instance(A) +end +function init_cacheval(alg::GenericFactorization{typeof(qr)}, A::SymTridiagonal{T, V}, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, + assumptions::OperatorAssumptions) where {T, V} + LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A) +end +function init_cacheval(alg::GenericFactorization{typeof(qr!)}, A::SymTridiagonal{T, V}, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, + assumptions::OperatorAssumptions) where {T, V} + LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A) end function init_cacheval(alg::GenericFactorization{typeof(qr)}, @@ -571,15 +591,15 @@ function init_cacheval(alg::GenericFactorization{typeof(qr!)}, A::Tridiagonal, b ArrayInterface.qr_instance(A) end -function init_cacheval(alg::GenericFactorization{typeof(svd)}, A, b, u, Pl, Pr, +function init_cacheval(alg::GenericFactorization{typeof(svd)}, A::AbstractMatrix, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) - ArrayInterface.svd_instance(convert(AbstractMatrix, A)) + ArrayInterface.svd_instance(A) end -function init_cacheval(alg::GenericFactorization{typeof(svd!)}, A, b, u, Pl, Pr, +function init_cacheval(alg::GenericFactorization{typeof(svd!)}, A::AbstractMatrix, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) - ArrayInterface.svd_instance(convert(AbstractMatrix, A)) + ArrayInterface.svd_instance(A) end function init_cacheval(alg::GenericFactorization{typeof(svd)}, @@ -615,6 +635,16 @@ function init_cacheval(alg::GenericFactorization{typeof(svd!)}, A::Tridiagonal, assumptions::OperatorAssumptions) ArrayInterface.svd_instance(A) end +function init_cacheval(alg::GenericFactorization{typeof(svd!)}, A::SymTridiagonal{T, V}, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, + assumptions::OperatorAssumptions) where {T, V} + LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A) +end +function init_cacheval(alg::GenericFactorization{typeof(svd)}, A::SymTridiagonal{T, V}, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, + assumptions::OperatorAssumptions) where {T, V} + LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A) +end function init_cacheval(alg::GenericFactorization, A::Diagonal, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) @@ -630,6 +660,18 @@ function init_cacheval(alg::GenericFactorization, A::SymTridiagonal{T, V}, b, u, assumptions::OperatorAssumptions) where {T, V} LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A) end +function init_cacheval(alg::GenericFactorization, A, b, u, Pl, Pr, + maxiters::Int, + abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) + init_cacheval(alg, convert(AbstractMatrix, A), b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, + assumptions::OperatorAssumptions) +end +function init_cacheval(alg::GenericFactorization, A::AbstractMatrix, b, u, Pl, Pr, + maxiters::Int, + abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) + do_factorization(alg, A, b, u) +end function init_cacheval(alg::Union{GenericFactorization{typeof(bunchkaufman!)}, GenericFactorization{typeof(bunchkaufman)}}, @@ -654,15 +696,49 @@ end # Try to never use it. # Cholesky needs the posdef matrix, for GenericFactorization assume structure is needed -function init_cacheval(alg::Union{GenericFactorization{typeof(cholesky)}, - GenericFactorization{typeof(cholesky!)}}, A, b, u, Pl, Pr, - maxiters::Int, abstol, reltol, verbose::Bool, +function init_cacheval(alg::GenericFactorization{typeof(cholesky)}, A::AbstractMatrix, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, + assumptions::OperatorAssumptions) + newA = copy(convert(AbstractMatrix, A)) + do_factorization(alg, newA, b, u) +end +function init_cacheval(alg::GenericFactorization{typeof(cholesky!)}, A::AbstractMatrix, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) newA = copy(convert(AbstractMatrix, A)) do_factorization(alg, newA, b, u) end +function init_cacheval(alg::GenericFactorization{typeof(cholesky!)}, A::Diagonal, b, u, Pl, Pr, maxiters::Int, + abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) + Diagonal(inv.(A.diag)) +end +function init_cacheval(alg::GenericFactorization{typeof(cholesky!)}, A::Tridiagonal, b, u, Pl, Pr, + maxiters::Int, + abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) + ArrayInterface.lu_instance(A) +end +function init_cacheval(alg::GenericFactorization{typeof(cholesky!)}, A::SymTridiagonal{T, V}, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, + assumptions::OperatorAssumptions) where {T, V} + LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A) +end +function init_cacheval(alg::GenericFactorization{typeof(cholesky)}, A::Diagonal, b, u, Pl, Pr, maxiters::Int, + abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) + Diagonal(inv.(A.diag)) +end +function init_cacheval(alg::GenericFactorization{typeof(cholesky)}, A::Tridiagonal, b, u, Pl, Pr, + maxiters::Int, + abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) + ArrayInterface.lu_instance(A) +end +function init_cacheval(alg::GenericFactorization{typeof(cholesky)}, A::SymTridiagonal{T, V}, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, + assumptions::OperatorAssumptions) where {T, V} + LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A) +end + -function init_cacheval(alg::Union{GenericFactorization}, +function init_cacheval(alg::GenericFactorization, A::Union{Hermitian{T, <:SparseMatrixCSC}, Symmetric{T, <:SparseMatrixCSC}}, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, @@ -1242,37 +1318,44 @@ function FastQRFactorization() end @static if VERSION < v"1.7beta" - function init_cacheval(alg::FastQRFactorization{Val{false}}, A, b, u, Pl, Pr, + function init_cacheval(alg::FastQRFactorization{Val{false}}, A::AbstractMatrix, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) ws = QRWYWs(A; blocksize = alg.blocksize) return WorkspaceAndFactors(ws, - ArrayInterface.qr_instance(convert(AbstractMatrix, A))) + ArrayInterface.qr_instance(A)) end - function init_cacheval(::FastQRFactorization{Val{true}}, A, b, u, Pl, Pr, + function init_cacheval(::FastQRFactorization{Val{true}}, A::AbstractMatrix, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) ws = QRpWs(A) return WorkspaceAndFactors(ws, - ArrayInterface.qr_instance(convert(AbstractMatrix, A))) + ArrayInterface.qr_instance(A)) end else - function init_cacheval(alg::FastQRFactorization{NoPivot}, A, b, u, Pl, Pr, + function init_cacheval(alg::FastQRFactorization{NoPivot}, A::AbstractMatrix, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) ws = QRWYWs(A; blocksize = alg.blocksize) return WorkspaceAndFactors(ws, - ArrayInterface.qr_instance(convert(AbstractMatrix, A))) + ArrayInterface.qr_instance(A)) end - function init_cacheval(::FastQRFactorization{ColumnNorm}, A, b, u, Pl, Pr, + function init_cacheval(::FastQRFactorization{ColumnNorm}, A::AbstractMatrix, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) ws = QRpWs(A) return WorkspaceAndFactors(ws, - ArrayInterface.qr_instance(convert(AbstractMatrix, A))) + ArrayInterface.qr_instance(A)) end end +function init_cacheval(alg::FastQRFactorization, A, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, + assumptions::OperatorAssumptions) + return init_cacheval(alg, convert(AbstractMatrix, A), b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, + assumptions::OperatorAssumptions) +end function SciMLBase.solve!(cache::LinearCache, alg::FastQRFactorization{P}; kwargs...) where {P} @@ -1380,4 +1463,4 @@ for alg in InteractiveUtils.subtypes(AbstractFactorization) maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) end -end +end