Skip to content

Commit

Permalink
Merge pull request #437 from ArnoStrouwen/amb
Browse files Browse the repository at this point in the history
start fixing ambiguities
  • Loading branch information
ChrisRackauckas authored Dec 11, 2023
2 parents f87583e + 3099e72 commit bda160c
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 38 deletions.
4 changes: 2 additions & 2 deletions ext/LinearSolveBandedMatricesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import LinearSolve: defaultalg,
do_factorization, init_cacheval, DefaultLinearSolver, DefaultAlgorithmChoice

# Defaults for BandedMatrices
function defaultalg(A::BandedMatrix, b, oa::OperatorAssumptions)
function defaultalg(A::BandedMatrix, b, oa::OperatorAssumptions{Bool})
if oa.issq
return DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!)
elseif LinearSolve.is_underdetermined(A)
Expand All @@ -15,7 +15,7 @@ function defaultalg(A::BandedMatrix, b, oa::OperatorAssumptions)
end
end

function defaultalg(A::Symmetric{<:Number, <:BandedMatrix}, b, ::OperatorAssumptions)
function defaultalg(A::Symmetric{<:Number, <:BandedMatrix}, b, ::OperatorAssumptions{Bool})
return DefaultLinearSolver(DefaultAlgorithmChoice.CholeskyFactorization)
end

Expand Down
2 changes: 1 addition & 1 deletion ext/LinearSolveFastAlmostBandedMatricesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using FastAlmostBandedMatrices, LinearAlgebra, LinearSolve
import LinearSolve: defaultalg,
do_factorization, init_cacheval, DefaultLinearSolver, DefaultAlgorithmChoice

function defaultalg(A::AlmostBandedMatrix, b, oa::OperatorAssumptions)
function defaultalg(A::AlmostBandedMatrix, b, oa::OperatorAssumptions{Bool})
if oa.issq
return DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!)
else
Expand Down
32 changes: 16 additions & 16 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ end
defaultalg(A, b) = defaultalg(A, b, OperatorAssumptions(true))

function defaultalg(A::Union{DiffEqArrayOperator, MatrixOperator}, b,
assump::OperatorAssumptions)
assump::OperatorAssumptions{Bool})
defaultalg(A.A, b, assump)
end

Expand All @@ -36,41 +36,41 @@ function defaultalg(A, b, assump::OperatorAssumptions{Nothing})
defaultalg(A, b, OperatorAssumptions(issq, assump.condition))
end

function defaultalg(A::Tridiagonal, b, assump::OperatorAssumptions)
function defaultalg(A::Tridiagonal, b, assump::OperatorAssumptions{Bool})
if assump.issq
DefaultLinearSolver(DefaultAlgorithmChoice.LUFactorization)
else
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
end
end

function defaultalg(A::SymTridiagonal, b, ::OperatorAssumptions)
function defaultalg(A::SymTridiagonal, b, ::OperatorAssumptions{Bool})
DefaultLinearSolver(DefaultAlgorithmChoice.LDLtFactorization)
end
function defaultalg(A::Bidiagonal, b, ::OperatorAssumptions)
function defaultalg(A::Bidiagonal, b, ::OperatorAssumptions{Bool})
DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!)
end
function defaultalg(A::Factorization, b, ::OperatorAssumptions)
function defaultalg(A::Factorization, b, ::OperatorAssumptions{Bool})
DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!)
end
function defaultalg(A::Diagonal, b, ::OperatorAssumptions)
function defaultalg(A::Diagonal, b, ::OperatorAssumptions{Bool})
DefaultLinearSolver(DefaultAlgorithmChoice.DiagonalFactorization)
end

function defaultalg(A::Hermitian, b, ::OperatorAssumptions)
function defaultalg(A::Hermitian, b, ::OperatorAssumptions{Bool})
DefaultLinearSolver(DefaultAlgorithmChoice.CholeskyFactorization)
end

function defaultalg(A::Symmetric{<:Number, <:Array}, b, ::OperatorAssumptions)
function defaultalg(A::Symmetric{<:Number, <:Array}, b, ::OperatorAssumptions{Bool})
DefaultLinearSolver(DefaultAlgorithmChoice.BunchKaufmanFactorization)
end

function defaultalg(A::Symmetric{<:Number, <:SparseMatrixCSC}, b, ::OperatorAssumptions)
function defaultalg(A::Symmetric{<:Number, <:SparseMatrixCSC}, b, ::OperatorAssumptions{Bool})
DefaultLinearSolver(DefaultAlgorithmChoice.CHOLMODFactorization)
end

function defaultalg(A::AbstractSparseMatrixCSC{Tv, Ti}, b,
assump::OperatorAssumptions) where {Tv, Ti}
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
if assump.issq
DefaultLinearSolver(DefaultAlgorithmChoice.SparspakFactorization)
else
Expand All @@ -80,7 +80,7 @@ end

@static if INCLUDE_SPARSE
function defaultalg(A::AbstractSparseMatrixCSC{<:Union{Float64, ComplexF64}, Ti}, b,
assump::OperatorAssumptions) where {Ti}
assump::OperatorAssumptions{Bool}) where {Ti}
if assump.issq
if length(b) <= 10_000 && length(nonzeros(A)) / length(A) < 2e-4
DefaultLinearSolver(DefaultAlgorithmChoice.KLUFactorization)
Expand All @@ -93,7 +93,7 @@ end
end
end

function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, assump::OperatorAssumptions)
function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, assump::OperatorAssumptions{Bool})
if assump.condition === OperatorCondition.IllConditioned || !assump.issq
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
else
Expand All @@ -102,7 +102,7 @@ function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, assump::OperatorAssump
end

# A === nothing case
function defaultalg(A, b::GPUArraysCore.AbstractGPUArray, assump::OperatorAssumptions)
function defaultalg(A::Nothing, b::GPUArraysCore.AbstractGPUArray, assump::OperatorAssumptions{Bool})
if assump.condition === OperatorCondition.IllConditioned || !assump.issq
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
else
Expand All @@ -112,7 +112,7 @@ end

# Ambiguity handling
function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.AbstractGPUArray,
assump::OperatorAssumptions)
assump::OperatorAssumptions{Bool})
if assump.condition === OperatorCondition.IllConditioned || !assump.issq
DefaultLinearSolver(DefaultAlgorithmChoice.QRFactorization)
else
Expand All @@ -121,7 +121,7 @@ function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.Abstract
end

function defaultalg(A::SciMLBase.AbstractSciMLOperator, b,
assump::OperatorAssumptions)
assump::OperatorAssumptions{Bool})
if has_ldiv!(A)
return DefaultLinearSolver(DefaultAlgorithmChoice.DirectLdiv!)
elseif !assump.issq
Expand All @@ -137,7 +137,7 @@ function defaultalg(A::SciMLBase.AbstractSciMLOperator, b,
end

# Allows A === nothing as a stand-in for dense matrix
function defaultalg(A, b, assump::OperatorAssumptions)
function defaultalg(A, b, assump::OperatorAssumptions{Bool})
alg = if assump.issq
# Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
# it makes sense according to the benchmarks, which is dependent on
Expand Down
122 changes: 103 additions & 19 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -402,15 +402,15 @@ function do_factorization(alg::GenericFactorization, A, b, u)
return fact
end

function init_cacheval(alg::GenericFactorization{typeof(lu)}, A, b, u, Pl, Pr,
function init_cacheval(alg::GenericFactorization{typeof(lu)}, A::AbstractMatrix, b, u, Pl, Pr,
maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
ArrayInterface.lu_instance(convert(AbstractMatrix, A))
ArrayInterface.lu_instance(A)
end
function init_cacheval(alg::GenericFactorization{typeof(lu!)}, A, b, u, Pl, Pr,
function init_cacheval(alg::GenericFactorization{typeof(lu!)}, A::AbstractMatrix, b, u, Pl, Pr,
maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
ArrayInterface.lu_instance(convert(AbstractMatrix, A))
ArrayInterface.lu_instance(A)
end

function init_cacheval(alg::GenericFactorization{typeof(lu)},
Expand Down Expand Up @@ -445,16 +445,36 @@ function init_cacheval(alg::GenericFactorization{typeof(lu!)}, A::Tridiagonal, b
assumptions::OperatorAssumptions)
ArrayInterface.lu_instance(A)
end
function init_cacheval(alg::GenericFactorization{typeof(lu!)}, A::SymTridiagonal{T, V}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions) where {T, V}
LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A)
end
function init_cacheval(alg::GenericFactorization{typeof(lu)}, A::SymTridiagonal{T, V}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions) where {T, V}
LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A)
end

function init_cacheval(alg::GenericFactorization{typeof(qr)}, A, b, u, Pl, Pr,
function init_cacheval(alg::GenericFactorization{typeof(qr)}, A::AbstractMatrix, b, u, Pl, Pr,
maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
ArrayInterface.qr_instance(convert(AbstractMatrix, A))
ArrayInterface.qr_instance(A)
end
function init_cacheval(alg::GenericFactorization{typeof(qr!)}, A, b, u, Pl, Pr,
function init_cacheval(alg::GenericFactorization{typeof(qr!)}, A::AbstractMatrix, b, u, Pl, Pr,
maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
ArrayInterface.qr_instance(convert(AbstractMatrix, A))
ArrayInterface.qr_instance(A)
end
function init_cacheval(alg::GenericFactorization{typeof(qr)}, A::SymTridiagonal{T, V}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions) where {T, V}
LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A)
end
function init_cacheval(alg::GenericFactorization{typeof(qr!)}, A::SymTridiagonal{T, V}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions) where {T, V}
LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A)
end

function init_cacheval(alg::GenericFactorization{typeof(qr)},
Expand Down Expand Up @@ -490,15 +510,15 @@ function init_cacheval(alg::GenericFactorization{typeof(qr!)}, A::Tridiagonal, b
ArrayInterface.qr_instance(A)
end

function init_cacheval(alg::GenericFactorization{typeof(svd)}, A, b, u, Pl, Pr,
function init_cacheval(alg::GenericFactorization{typeof(svd)}, A::AbstractMatrix, b, u, Pl, Pr,
maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
ArrayInterface.svd_instance(convert(AbstractMatrix, A))
ArrayInterface.svd_instance(A)
end
function init_cacheval(alg::GenericFactorization{typeof(svd!)}, A, b, u, Pl, Pr,
function init_cacheval(alg::GenericFactorization{typeof(svd!)}, A::AbstractMatrix, b, u, Pl, Pr,
maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
ArrayInterface.svd_instance(convert(AbstractMatrix, A))
ArrayInterface.svd_instance(A)
end

function init_cacheval(alg::GenericFactorization{typeof(svd)},
Expand Down Expand Up @@ -534,6 +554,16 @@ function init_cacheval(alg::GenericFactorization{typeof(svd!)}, A::Tridiagonal,
assumptions::OperatorAssumptions)
ArrayInterface.svd_instance(A)
end
function init_cacheval(alg::GenericFactorization{typeof(svd!)}, A::SymTridiagonal{T, V}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions) where {T, V}
LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A)
end
function init_cacheval(alg::GenericFactorization{typeof(svd)}, A::SymTridiagonal{T, V}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions) where {T, V}
LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A)
end

function init_cacheval(alg::GenericFactorization, A::Diagonal, b, u, Pl, Pr, maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
Expand All @@ -549,6 +579,18 @@ function init_cacheval(alg::GenericFactorization, A::SymTridiagonal{T, V}, b, u,
assumptions::OperatorAssumptions) where {T, V}
LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A)
end
function init_cacheval(alg::GenericFactorization, A, b, u, Pl, Pr,
maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
init_cacheval(alg, convert(AbstractMatrix, A), b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
end
function init_cacheval(alg::GenericFactorization, A::AbstractMatrix, b, u, Pl, Pr,
maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
do_factorization(alg, A, b, u)
end

function init_cacheval(alg::Union{GenericFactorization{typeof(bunchkaufman!)},
GenericFactorization{typeof(bunchkaufman)}},
Expand All @@ -573,15 +615,49 @@ end
# Try to never use it.

# Cholesky needs the posdef matrix, for GenericFactorization assume structure is needed
function init_cacheval(alg::Union{GenericFactorization{typeof(cholesky)},
GenericFactorization{typeof(cholesky!)}}, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
function init_cacheval(alg::GenericFactorization{typeof(cholesky)}, A::AbstractMatrix, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
newA = copy(convert(AbstractMatrix, A))
do_factorization(alg, newA, b, u)
end
function init_cacheval(alg::GenericFactorization{typeof(cholesky!)}, A::AbstractMatrix, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
newA = copy(convert(AbstractMatrix, A))
do_factorization(alg, newA, b, u)
end
function init_cacheval(alg::GenericFactorization{typeof(cholesky!)}, A::Diagonal, b, u, Pl, Pr, maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
Diagonal(inv.(A.diag))
end
function init_cacheval(alg::GenericFactorization{typeof(cholesky!)}, A::Tridiagonal, b, u, Pl, Pr,
maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
ArrayInterface.lu_instance(A)
end
function init_cacheval(alg::GenericFactorization{typeof(cholesky!)}, A::SymTridiagonal{T, V}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions) where {T, V}
LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A)
end
function init_cacheval(alg::GenericFactorization{typeof(cholesky)}, A::Diagonal, b, u, Pl, Pr, maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
Diagonal(inv.(A.diag))
end
function init_cacheval(alg::GenericFactorization{typeof(cholesky)}, A::Tridiagonal, b, u, Pl, Pr,
maxiters::Int,
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
ArrayInterface.lu_instance(A)
end
function init_cacheval(alg::GenericFactorization{typeof(cholesky)}, A::SymTridiagonal{T, V}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions) where {T, V}
LinearAlgebra.LDLt{T, SymTridiagonal{T, V}}(A)
end


function init_cacheval(alg::Union{GenericFactorization},
function init_cacheval(alg::GenericFactorization,
A::Union{Hermitian{T, <:SparseMatrixCSC},
Symmetric{T, <:SparseMatrixCSC}}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
Expand Down Expand Up @@ -1063,21 +1139,29 @@ end
# but QRFactorization uses 16.
FastQRFactorization() = FastQRFactorization(NoPivot(), 36)

function init_cacheval(alg::FastQRFactorization{NoPivot}, A, b, u, Pl, Pr,
function init_cacheval(alg::FastQRFactorization{NoPivot}, A::AbstractMatrix, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
ws = QRWYWs(A; blocksize = alg.blocksize)
return WorkspaceAndFactors(ws,
ArrayInterface.qr_instance(convert(AbstractMatrix, A)))
end
function init_cacheval(::FastQRFactorization{ColumnNorm}, A, b, u, Pl, Pr,
function init_cacheval(::FastQRFactorization{ColumnNorm}, A::AbstractMatrix, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
ws = QRpWs(A)
return WorkspaceAndFactors(ws,
ArrayInterface.qr_instance(convert(AbstractMatrix, A)))
end

function init_cacheval(alg::FastQRFactorization, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
return init_cacheval(alg, convert(AbstractMatrix, A), b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
end

function SciMLBase.solve!(cache::LinearCache, alg::FastQRFactorization{P};
kwargs...) where {P}
A = cache.A
Expand Down Expand Up @@ -1184,4 +1268,4 @@ for alg in InteractiveUtils.subtypes(AbstractFactorization)
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
end
end
end

0 comments on commit bda160c

Please sign in to comment.