Skip to content

Commit

Permalink
Use Pivoted QR for Underdetermined Systems
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 24, 2023
1 parent 86beacf commit 988e3f7
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 5 deletions.
4 changes: 4 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ needs_concrete_A(alg::AbstractKrylovSubspaceMethod) = false
needs_concrete_A(alg::AbstractSolveFunction) = false

# Util
is_underdetermined(x) = false
is_underdetermined(A::AbstractMatrix) = size(A, 1) < size(A, 2)
is_underdetermined(A::AbstractSciMLOperator) = size(A, 1) < size(A, 2)

Check warning on line 67 in src/LinearSolve.jl

View check run for this annotation

Codecov / codecov/patch

src/LinearSolve.jl#L67

Added line #L67 was not covered by tests

_isidentity_struct(A) = false
_isidentity_struct::Number) = isone(λ)
Expand Down Expand Up @@ -96,6 +99,7 @@ EnumX.@enumx DefaultAlgorithmChoice begin
NormalCholeskyFactorization
AppleAccelerateLUFactorization
MKLLUFactorization
QRFactorizationPivoted
end

struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm
Expand Down
28 changes: 23 additions & 5 deletions src/default.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
needs_concrete_A(alg::DefaultLinearSolver) = true
mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12,
T13, T14, T15, T16, T17, T18}
T13, T14, T15, T16, T17, T18, T19}
LUFactorization::T1
QRFactorization::T2
DiagonalFactorization::T3
Expand All @@ -19,6 +19,7 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
NormalCholeskyFactorization::T16
AppleAccelerateLUFactorization::T17
MKLLUFactorization::T18
QRFactorizationPivoted::T19
end

# Legacy fallback
Expand Down Expand Up @@ -168,8 +169,8 @@ function defaultalg(A, b, assump::OperatorAssumptions)
(A === nothing ? eltype(b) <: Union{Float32, Float64} :
eltype(A) <: Union{Float32, Float64})
DefaultAlgorithmChoice.RFLUFactorization
#elseif A === nothing || A isa Matrix
# alg = FastLUFactorization()
#elseif A === nothing || A isa Matrix
# alg = FastLUFactorization()
elseif usemkl && (A === nothing ? eltype(b) <: Union{Float32, Float64} :
eltype(A) <: Union{Float32, Float64})
DefaultAlgorithmChoice.MKLLUFactorization
Expand Down Expand Up @@ -199,9 +200,19 @@ function defaultalg(A, b, assump::OperatorAssumptions)
elseif assump.condition === OperatorCondition.WellConditioned
DefaultAlgorithmChoice.NormalCholeskyFactorization
elseif assump.condition === OperatorCondition.IllConditioned
DefaultAlgorithmChoice.QRFactorization
if is_underdetermined(A)
# Underdetermined
DefaultAlgorithmChoice.QRFactorizationPivoted
else
DefaultAlgorithmChoice.QRFactorization
end
elseif assump.condition === OperatorCondition.VeryIllConditioned
DefaultAlgorithmChoice.QRFactorization
if is_underdetermined(A)

Check warning on line 210 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L210

Added line #L210 was not covered by tests
# Underdetermined
DefaultAlgorithmChoice.QRFactorizationPivoted

Check warning on line 212 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L212

Added line #L212 was not covered by tests
else
DefaultAlgorithmChoice.QRFactorization

Check warning on line 214 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L214

Added line #L214 was not covered by tests
end
elseif assump.condition === OperatorCondition.SuperIllConditioned
DefaultAlgorithmChoice.SVDFactorization
else
Expand Down Expand Up @@ -247,6 +258,12 @@ function algchoice_to_alg(alg::Symbol)
NormalCholeskyFactorization()
elseif alg === :AppleAccelerateLUFactorization
AppleAccelerateLUFactorization()
elseif alg === :QRFactorizationPivoted
@static if VERSION v"1.7beta"

Check warning on line 262 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L262

Added line #L262 was not covered by tests
QRFactorization(ColumnNorm())
else
QRFactorization(Val(true))

Check warning on line 265 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L265

Added line #L265 was not covered by tests
end
else
error("Algorithm choice symbol $alg not allowed in the default")
end
Expand Down Expand Up @@ -310,6 +327,7 @@ function defaultalg_symbol(::Type{T}) where {T}
Symbol(split(string(SciMLBase.parameterless_type(T)), ".")[end])
end
defaultalg_symbol(::Type{<:GenericFactorization{typeof(ldlt!)}}) = :LDLtFactorization
defaultalg_symbol(::Type{<:QRFactorization{ColumnNorm}}) = :QRFactorizationPivoted

"""
if alg.alg === DefaultAlgorithmChoice.LUFactorization
Expand Down
10 changes: 10 additions & 0 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,16 @@ function QRFactorization(inplace = true)
QRFactorization(pivot, 16, inplace)
end

@static if VERSION v"1.7beta"
function QRFactorization(pivot::LinearAlgebra.PivotingStrategy, inplace::Bool = true)
QRFactorization(pivot, 16, inplace)
end
else
function QRFactorization(pivot::Val, inplace::Bool = true)
QRFactorization(pivot, 16, inplace)

Check warning on line 167 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L166-L167

Added lines #L166 - L167 were not covered by tests
end
end

function do_factorization(alg::QRFactorization, A, b, u)
A = convert(AbstractMatrix, A)
if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray)
Expand Down
9 changes: 9 additions & 0 deletions test/nonsquare.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,12 @@ solve(LinearProblem(A, b), (LinearSolve.NormalCholeskyFactorization())).u;
solve(LinearProblem(A, b),
assumptions = (OperatorAssumptions(false;
condition = OperatorCondition.WellConditioned))).u;

# Underdetermined
m, n = 2, 3

A = rand(m, n)
b = rand(m)
prob = LinearProblem(A, b)
res = A \ b
@test solve(prob).u res

0 comments on commit 988e3f7

Please sign in to comment.