From 86beacf921c089cd1215bc23c1146419005f2590 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 24 Oct 2023 12:23:12 -0400 Subject: [PATCH 1/4] Add `needs_square_A` trait --- Project.toml | 2 +- src/LinearSolve.jl | 40 ++++++++++++++++++++++++++++++++++++++++ test/nonsquare.jl | 7 +++++++ 3 files changed, 48 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2d1808607..6bbf26462 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LinearSolve" uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" authors = ["SciML"] -version = "2.11.1" +version = "2.12.0" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index e240a790b..661a0a859 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -143,6 +143,46 @@ end include("factorization_sparse.jl") end +# Solver Specific Traits +## Needs Square Matrix +""" + needs_square_A(alg) + +Returns `true` if the algorithm requires a square matrix. + +Note that this checks if the implementation of the algorithm needs a square matrix by +trying to solve an underdetermined system. It is recommended to add a dispatch to this +function for custom algorithms! +""" +needs_square_A(::Nothing) = false # Linear Solve automatically will use a correct alg! +function needs_square_A(alg::SciMLLinearSolveAlgorithm) + try + A = [1.0 2.0; + 3.0 4.0; + 5.0 6.0] + b = ones(Float64, 3) + solve(LinearProblem(A, b), alg) + return false + catch err + return true + end +end +for alg in (:QRFactorization, :FastQRFactorization, :NormalCholeskyFactorization, + :NormalBunchKaufmanFactorization) + @eval needs_square_A(::$(alg)) = false +end +for kralg in (Krylov.lsmr!, Krylov.craigmr!) + @eval needs_square_A(::KrylovJL{$(typeof(kralg))}) = false +end +for alg in (:LUFactorization, :FastLUFactorization, :SVDFactorization, + :GenericFactorization, :GenericLUFactorization, :SimpleLUFactorization, + :RFLUFactorization, :UMFPACKFactorization, :KLUFactorization, :SparspakFactorization, + :DiagonalFactorization, :CholeskyFactorization, :BunchKaufmanFactorization, + :CHOLMODFactorization, :LDLtFactorization, :AppleAccelerateLUFactorization, + :MKLLUFactorization, :MetalLUFactorization) + @eval needs_square_A(::$(alg)) = true +end + const IS_OPENBLAS = Ref(true) isopenblas() = IS_OPENBLAS[] diff --git a/test/nonsquare.jl b/test/nonsquare.jl index c678c17b3..e0e817b9e 100644 --- a/test/nonsquare.jl +++ b/test/nonsquare.jl @@ -8,7 +8,11 @@ b = rand(m) prob = LinearProblem(A, b) res = A \ b @test solve(prob).u ≈ res +@test !LinearSolve.needs_square_A(QRFactorization()) @test solve(prob, QRFactorization()) ≈ res +@test !LinearSolve.needs_square_A(FastQRFactorization()) +@test solve(prob, FastQRFactorization()) ≈ res +@test !LinearSolve.needs_square_A(KrylovJL_LSMR()) @test solve(prob, KrylovJL_LSMR()) ≈ res A = sprand(m, n, 0.5) @@ -23,6 +27,7 @@ A = sprand(n, m, 0.5) b = rand(n) prob = LinearProblem(A, b) res = Matrix(A) \ b +@test !LinearSolve.needs_square_A(KrylovJL_CRAIGMR()) @test solve(prob, KrylovJL_CRAIGMR()) ≈ res A = sprandn(1000, 100, 0.1) @@ -35,7 +40,9 @@ A = randn(1000, 100) b = randn(1000) @test isapprox(solve(LinearProblem(A, b)).u, Symmetric(A' * A) \ (A' * b)) solve(LinearProblem(A, b)).u; +@test !LinearSolve.needs_square_A(NormalCholeskyFactorization()) solve(LinearProblem(A, b), (LinearSolve.NormalCholeskyFactorization())).u; +@test !LinearSolve.needs_square_A(NormalBunchKaufmanFactorization()) solve(LinearProblem(A, b), (LinearSolve.NormalBunchKaufmanFactorization())).u; solve(LinearProblem(A, b), assumptions = (OperatorAssumptions(false; From 988e3f76e3130f8443ea275071823e2b6cab216d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 24 Oct 2023 15:34:17 -0400 Subject: [PATCH 2/4] Use Pivoted QR for Underdetermined Systems --- src/LinearSolve.jl | 4 ++++ src/default.jl | 28 +++++++++++++++++++++++----- src/factorization.jl | 10 ++++++++++ test/nonsquare.jl | 9 +++++++++ 4 files changed, 46 insertions(+), 5 deletions(-) diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index 661a0a859..3ae907fab 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -62,6 +62,9 @@ needs_concrete_A(alg::AbstractKrylovSubspaceMethod) = false needs_concrete_A(alg::AbstractSolveFunction) = false # Util +is_underdetermined(x) = false +is_underdetermined(A::AbstractMatrix) = size(A, 1) < size(A, 2) +is_underdetermined(A::AbstractSciMLOperator) = size(A, 1) < size(A, 2) _isidentity_struct(A) = false _isidentity_struct(λ::Number) = isone(λ) @@ -96,6 +99,7 @@ EnumX.@enumx DefaultAlgorithmChoice begin NormalCholeskyFactorization AppleAccelerateLUFactorization MKLLUFactorization + QRFactorizationPivoted end struct DefaultLinearSolver <: SciMLLinearSolveAlgorithm diff --git a/src/default.jl b/src/default.jl index 509a14571..830770802 100644 --- a/src/default.jl +++ b/src/default.jl @@ -1,6 +1,6 @@ needs_concrete_A(alg::DefaultLinearSolver) = true mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, - T13, T14, T15, T16, T17, T18} + T13, T14, T15, T16, T17, T18, T19} LUFactorization::T1 QRFactorization::T2 DiagonalFactorization::T3 @@ -19,6 +19,7 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, NormalCholeskyFactorization::T16 AppleAccelerateLUFactorization::T17 MKLLUFactorization::T18 + QRFactorizationPivoted::T19 end # Legacy fallback @@ -168,8 +169,8 @@ function defaultalg(A, b, assump::OperatorAssumptions) (A === nothing ? eltype(b) <: Union{Float32, Float64} : eltype(A) <: Union{Float32, Float64}) DefaultAlgorithmChoice.RFLUFactorization - #elseif A === nothing || A isa Matrix - # alg = FastLUFactorization() + #elseif A === nothing || A isa Matrix + # alg = FastLUFactorization() elseif usemkl && (A === nothing ? eltype(b) <: Union{Float32, Float64} : eltype(A) <: Union{Float32, Float64}) DefaultAlgorithmChoice.MKLLUFactorization @@ -199,9 +200,19 @@ function defaultalg(A, b, assump::OperatorAssumptions) elseif assump.condition === OperatorCondition.WellConditioned DefaultAlgorithmChoice.NormalCholeskyFactorization elseif assump.condition === OperatorCondition.IllConditioned - DefaultAlgorithmChoice.QRFactorization + if is_underdetermined(A) + # Underdetermined + DefaultAlgorithmChoice.QRFactorizationPivoted + else + DefaultAlgorithmChoice.QRFactorization + end elseif assump.condition === OperatorCondition.VeryIllConditioned - DefaultAlgorithmChoice.QRFactorization + if is_underdetermined(A) + # Underdetermined + DefaultAlgorithmChoice.QRFactorizationPivoted + else + DefaultAlgorithmChoice.QRFactorization + end elseif assump.condition === OperatorCondition.SuperIllConditioned DefaultAlgorithmChoice.SVDFactorization else @@ -247,6 +258,12 @@ function algchoice_to_alg(alg::Symbol) NormalCholeskyFactorization() elseif alg === :AppleAccelerateLUFactorization AppleAccelerateLUFactorization() + elseif alg === :QRFactorizationPivoted + @static if VERSION ≥ v"1.7beta" + QRFactorization(ColumnNorm()) + else + QRFactorization(Val(true)) + end else error("Algorithm choice symbol $alg not allowed in the default") end @@ -310,6 +327,7 @@ function defaultalg_symbol(::Type{T}) where {T} Symbol(split(string(SciMLBase.parameterless_type(T)), ".")[end]) end defaultalg_symbol(::Type{<:GenericFactorization{typeof(ldlt!)}}) = :LDLtFactorization +defaultalg_symbol(::Type{<:QRFactorization{ColumnNorm}}) = :QRFactorizationPivoted """ if alg.alg === DefaultAlgorithmChoice.LUFactorization diff --git a/src/factorization.jl b/src/factorization.jl index 7156991fb..2917580a3 100644 --- a/src/factorization.jl +++ b/src/factorization.jl @@ -158,6 +158,16 @@ function QRFactorization(inplace = true) QRFactorization(pivot, 16, inplace) end +@static if VERSION ≥ v"1.7beta" + function QRFactorization(pivot::LinearAlgebra.PivotingStrategy, inplace::Bool = true) + QRFactorization(pivot, 16, inplace) + end +else + function QRFactorization(pivot::Val, inplace::Bool = true) + QRFactorization(pivot, 16, inplace) + end +end + function do_factorization(alg::QRFactorization, A, b, u) A = convert(AbstractMatrix, A) if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray) diff --git a/test/nonsquare.jl b/test/nonsquare.jl index e0e817b9e..694bf4dfd 100644 --- a/test/nonsquare.jl +++ b/test/nonsquare.jl @@ -56,3 +56,12 @@ solve(LinearProblem(A, b), (LinearSolve.NormalCholeskyFactorization())).u; solve(LinearProblem(A, b), assumptions = (OperatorAssumptions(false; condition = OperatorCondition.WellConditioned))).u; + +# Underdetermined +m, n = 2, 3 + +A = rand(m, n) +b = rand(m) +prob = LinearProblem(A, b) +res = A \ b +@test solve(prob).u ≈ res From 45adbd08537438d13e0d73323773ccddf75f50ee Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 24 Oct 2023 19:58:13 -0400 Subject: [PATCH 3/4] Remove fallback --- src/LinearSolve.jl | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index 3ae907fab..b0511080a 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -153,24 +153,9 @@ end needs_square_A(alg) Returns `true` if the algorithm requires a square matrix. - -Note that this checks if the implementation of the algorithm needs a square matrix by -trying to solve an underdetermined system. It is recommended to add a dispatch to this -function for custom algorithms! """ needs_square_A(::Nothing) = false # Linear Solve automatically will use a correct alg! -function needs_square_A(alg::SciMLLinearSolveAlgorithm) - try - A = [1.0 2.0; - 3.0 4.0; - 5.0 6.0] - b = ones(Float64, 3) - solve(LinearProblem(A, b), alg) - return false - catch err - return true - end -end +needs_square_A(alg::SciMLLinearSolveAlgorithm) = true for alg in (:QRFactorization, :FastQRFactorization, :NormalCholeskyFactorization, :NormalBunchKaufmanFactorization) @eval needs_square_A(::$(alg)) = false From 12c7ed9616de1d4f63d83a6634545331ba59ec7e Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 25 Oct 2023 02:45:33 +0200 Subject: [PATCH 4/4] Update default.jl --- src/default.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/default.jl b/src/default.jl index 830770802..82166a247 100644 --- a/src/default.jl +++ b/src/default.jl @@ -327,7 +327,12 @@ function defaultalg_symbol(::Type{T}) where {T} Symbol(split(string(SciMLBase.parameterless_type(T)), ".")[end]) end defaultalg_symbol(::Type{<:GenericFactorization{typeof(ldlt!)}}) = :LDLtFactorization -defaultalg_symbol(::Type{<:QRFactorization{ColumnNorm}}) = :QRFactorizationPivoted + +@static if VERSION >= v"1.7" + defaultalg_symbol(::Type{<:QRFactorization{ColumnNorm}}) = :QRFactorizationPivoted +else + defaultalg_symbol(::Type{<:QRFactorization{Val{true}}}) = :QRFactorizationPivoted +end """ if alg.alg === DefaultAlgorithmChoice.LUFactorization