Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add needs_square_A trait #400

Merged
merged 4 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LinearSolve"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
authors = ["SciML"]
version = "2.11.1"
version = "2.12.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
29 changes: 29 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ needs_concrete_A(alg::AbstractKrylovSubspaceMethod) = false
needs_concrete_A(alg::AbstractSolveFunction) = false

# Util
is_underdetermined(x) = false
is_underdetermined(A::AbstractMatrix) = size(A, 1) < size(A, 2)
is_underdetermined(A::AbstractSciMLOperator) = size(A, 1) < size(A, 2)

_isidentity_struct(A) = false
_isidentity_struct(λ::Number) = isone(λ)
Expand Down Expand Up @@ -96,6 +99,7 @@ EnumX.@enumx DefaultAlgorithmChoice begin
NormalCholeskyFactorization
AppleAccelerateLUFactorization
MKLLUFactorization
QRFactorizationPivoted
end

struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm
Expand Down Expand Up @@ -143,6 +147,31 @@ end
include("factorization_sparse.jl")
end

# Solver Specific Traits
## Needs Square Matrix
"""
needs_square_A(alg)

Returns `true` if the algorithm requires a square matrix.
"""
needs_square_A(::Nothing) = false # Linear Solve automatically will use a correct alg!
needs_square_A(alg::SciMLLinearSolveAlgorithm) = true
for alg in (:QRFactorization, :FastQRFactorization, :NormalCholeskyFactorization,
:NormalBunchKaufmanFactorization)
@eval needs_square_A(::$(alg)) = false
end
for kralg in (Krylov.lsmr!, Krylov.craigmr!)
@eval needs_square_A(::KrylovJL{$(typeof(kralg))}) = false
end
for alg in (:LUFactorization, :FastLUFactorization, :SVDFactorization,
:GenericFactorization, :GenericLUFactorization, :SimpleLUFactorization,
:RFLUFactorization, :UMFPACKFactorization, :KLUFactorization, :SparspakFactorization,
:DiagonalFactorization, :CholeskyFactorization, :BunchKaufmanFactorization,
:CHOLMODFactorization, :LDLtFactorization, :AppleAccelerateLUFactorization,
:MKLLUFactorization, :MetalLUFactorization)
@eval needs_square_A(::$(alg)) = true
end

const IS_OPENBLAS = Ref(true)
isopenblas() = IS_OPENBLAS[]

Expand Down
33 changes: 28 additions & 5 deletions src/default.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
needs_concrete_A(alg::DefaultLinearSolver) = true
mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12,
T13, T14, T15, T16, T17, T18}
T13, T14, T15, T16, T17, T18, T19}
LUFactorization::T1
QRFactorization::T2
DiagonalFactorization::T3
Expand All @@ -19,6 +19,7 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
NormalCholeskyFactorization::T16
AppleAccelerateLUFactorization::T17
MKLLUFactorization::T18
QRFactorizationPivoted::T19
end

# Legacy fallback
Expand Down Expand Up @@ -168,8 +169,8 @@ function defaultalg(A, b, assump::OperatorAssumptions)
(A === nothing ? eltype(b) <: Union{Float32, Float64} :
eltype(A) <: Union{Float32, Float64})
DefaultAlgorithmChoice.RFLUFactorization
#elseif A === nothing || A isa Matrix
# alg = FastLUFactorization()
#elseif A === nothing || A isa Matrix
# alg = FastLUFactorization()
elseif usemkl && (A === nothing ? eltype(b) <: Union{Float32, Float64} :
eltype(A) <: Union{Float32, Float64})
DefaultAlgorithmChoice.MKLLUFactorization
Expand Down Expand Up @@ -199,9 +200,19 @@ function defaultalg(A, b, assump::OperatorAssumptions)
elseif assump.condition === OperatorCondition.WellConditioned
DefaultAlgorithmChoice.NormalCholeskyFactorization
elseif assump.condition === OperatorCondition.IllConditioned
DefaultAlgorithmChoice.QRFactorization
if is_underdetermined(A)
# Underdetermined
DefaultAlgorithmChoice.QRFactorizationPivoted
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QRFactorization is already pivoted?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh I see. @YingboMa what do you think of just always defaulting to pivoted QR?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, pivoted QR handles degenerate systems too.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But it's quite a bit slower though.

else
DefaultAlgorithmChoice.QRFactorization
end
elseif assump.condition === OperatorCondition.VeryIllConditioned
DefaultAlgorithmChoice.QRFactorization
if is_underdetermined(A)
# Underdetermined
DefaultAlgorithmChoice.QRFactorizationPivoted
else
DefaultAlgorithmChoice.QRFactorization
end
elseif assump.condition === OperatorCondition.SuperIllConditioned
DefaultAlgorithmChoice.SVDFactorization
else
Expand Down Expand Up @@ -247,6 +258,12 @@ function algchoice_to_alg(alg::Symbol)
NormalCholeskyFactorization()
elseif alg === :AppleAccelerateLUFactorization
AppleAccelerateLUFactorization()
elseif alg === :QRFactorizationPivoted
@static if VERSION ≥ v"1.7beta"
QRFactorization(ColumnNorm())
else
QRFactorization(Val(true))
end
else
error("Algorithm choice symbol $alg not allowed in the default")
end
Expand Down Expand Up @@ -311,6 +328,12 @@ function defaultalg_symbol(::Type{T}) where {T}
end
defaultalg_symbol(::Type{<:GenericFactorization{typeof(ldlt!)}}) = :LDLtFactorization

@static if VERSION >= v"1.7"
defaultalg_symbol(::Type{<:QRFactorization{ColumnNorm}}) = :QRFactorizationPivoted
else
defaultalg_symbol(::Type{<:QRFactorization{Val{true}}}) = :QRFactorizationPivoted
end

"""
if alg.alg === DefaultAlgorithmChoice.LUFactorization
SciMLBase.solve!(cache, LUFactorization(), args...; kwargs...))
Expand Down
10 changes: 10 additions & 0 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,16 @@ function QRFactorization(inplace = true)
QRFactorization(pivot, 16, inplace)
end

@static if VERSION ≥ v"1.7beta"
function QRFactorization(pivot::LinearAlgebra.PivotingStrategy, inplace::Bool = true)
QRFactorization(pivot, 16, inplace)
end
else
function QRFactorization(pivot::Val, inplace::Bool = true)
QRFactorization(pivot, 16, inplace)
end
end

function do_factorization(alg::QRFactorization, A, b, u)
A = convert(AbstractMatrix, A)
if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray)
Expand Down
16 changes: 16 additions & 0 deletions test/nonsquare.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ b = rand(m)
prob = LinearProblem(A, b)
res = A \ b
@test solve(prob).u ≈ res
@test !LinearSolve.needs_square_A(QRFactorization())
@test solve(prob, QRFactorization()) ≈ res
@test !LinearSolve.needs_square_A(FastQRFactorization())
@test solve(prob, FastQRFactorization()) ≈ res
@test !LinearSolve.needs_square_A(KrylovJL_LSMR())
@test solve(prob, KrylovJL_LSMR()) ≈ res

A = sprand(m, n, 0.5)
Expand All @@ -23,6 +27,7 @@ A = sprand(n, m, 0.5)
b = rand(n)
prob = LinearProblem(A, b)
res = Matrix(A) \ b
@test !LinearSolve.needs_square_A(KrylovJL_CRAIGMR())
@test solve(prob, KrylovJL_CRAIGMR()) ≈ res

A = sprandn(1000, 100, 0.1)
Expand All @@ -35,7 +40,9 @@ A = randn(1000, 100)
b = randn(1000)
@test isapprox(solve(LinearProblem(A, b)).u, Symmetric(A' * A) \ (A' * b))
solve(LinearProblem(A, b)).u;
@test !LinearSolve.needs_square_A(NormalCholeskyFactorization())
solve(LinearProblem(A, b), (LinearSolve.NormalCholeskyFactorization())).u;
@test !LinearSolve.needs_square_A(NormalBunchKaufmanFactorization())
solve(LinearProblem(A, b), (LinearSolve.NormalBunchKaufmanFactorization())).u;
solve(LinearProblem(A, b),
assumptions = (OperatorAssumptions(false;
Expand All @@ -49,3 +56,12 @@ solve(LinearProblem(A, b), (LinearSolve.NormalCholeskyFactorization())).u;
solve(LinearProblem(A, b),
assumptions = (OperatorAssumptions(false;
condition = OperatorCondition.WellConditioned))).u;

# Underdetermined
m, n = 2, 3

A = rand(m, n)
b = rand(m)
prob = LinearProblem(A, b)
res = A \ b
@test solve(prob).u ≈ res