From 72bc0143e6e45bd48cae04ea937b1eb674295308 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 29 Nov 2023 18:07:53 -0500 Subject: [PATCH] Support StaticArrays Properly --- Project.toml | 6 ++++-- src/LinearSolve.jl | 2 ++ src/default.jl | 7 +++++++ src/factorization.jl | 28 +++++++++++++++++++++++----- 4 files changed, 36 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 12d3aab9d..fba07d804 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LinearSolve" uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" authors = ["SciML"] -version = "2.20.0" +version = "2.20.1" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -26,6 +26,7 @@ SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac" +StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [weakdeps] @@ -48,6 +49,7 @@ LinearSolveBandedMatricesExt = "BandedMatrices" LinearSolveBlockDiagonalsExt = "BlockDiagonals" LinearSolveCUDAExt = "CUDA" LinearSolveEnzymeExt = ["Enzyme", "EnzymeCore"] +LinearSolveFastAlmostBandedMatricesExt = ["FastAlmostBandedMatrices"] LinearSolveHYPREExt = "HYPRE" LinearSolveIterativeSolversExt = "IterativeSolvers" LinearSolveKernelAbstractionsExt = "KernelAbstractions" @@ -55,7 +57,6 @@ LinearSolveKrylovKitExt = "KrylovKit" LinearSolveMetalExt = "Metal" LinearSolvePardisoExt = "Pardiso" LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools" -LinearSolveFastAlmostBandedMatricesExt = ["FastAlmostBandedMatrices"] [compat] Aqua = "0.8" @@ -102,6 +103,7 @@ Setfield = "1" SparseArrays = "1.9" Sparspak = "0.3.6" Test = "1" +StaticArraysCore = "1" UnPack = "1" julia = "1.9" diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index 0d70f0123..33a1dc4f5 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -26,6 +26,8 @@ PrecompileTools.@recompile_invalidations begin using Requires import InteractiveUtils + import StaticArraysCore: StaticArray, SVector, MVector, SMatrix, MMatrix + using LinearAlgebra: BlasInt, LU using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1, diff --git a/src/default.jl b/src/default.jl index 574d81885..8aa3f3713 100644 --- a/src/default.jl +++ b/src/default.jl @@ -175,6 +175,10 @@ function defaultalg(A, b, assump::OperatorAssumptions) DefaultAlgorithmChoice.LUFactorization end + # For static arrays GMRES allocates a lot. Use factorization + elseif A isa StaticArray + DefaultAlgorithmChoice.LUFactorization + # This catches the cases where a factorization overload could exist # For example, BlockBandedMatrix elseif A !== nothing && ArrayInterface.isstructured(A) @@ -186,6 +190,9 @@ function defaultalg(A, b, assump::OperatorAssumptions) end elseif assump.condition === OperatorCondition.WellConditioned DefaultAlgorithmChoice.NormalCholeskyFactorization + elseif A isa StaticArray + # Static Array doesn't have QR() \ b defined + return DefaultAlgorithmChoice.SVDFactorization elseif assump.condition === OperatorCondition.IllConditioned if is_underdetermined(A) # Underdetermined diff --git a/src/factorization.jl b/src/factorization.jl index 786b202c3..5eba6ffb2 100644 --- a/src/factorization.jl +++ b/src/factorization.jl @@ -10,6 +10,9 @@ end _ldiv!(x, A, b) = ldiv!(x, A, b) +_ldiv!(x::MVector, A, b::SVector) = (x .= A \ b) +_ldiv!(::SVector, A, b::SVector) = (A \ b) + function _ldiv!(x::Vector, A::Factorization, b::Vector) # workaround https://github.com/JuliaLang/julia/issues/43507 # Fallback if working with non-square matrices @@ -88,6 +91,8 @@ function do_factorization(alg::LUFactorization, A, b, u) if A isa AbstractSparseMatrixCSC return lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)), check = false) + elseif !ArrayInterface.can_setindex(typeof(A)) + fact = lu(A, alg.pivot, check = false) else fact = lu!(A, alg.pivot, check = false) end @@ -172,10 +177,14 @@ end function do_factorization(alg::QRFactorization, A, b, u) A = convert(AbstractMatrix, A) - if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray) - fact = qr!(A, alg.pivot) + if ArrayInterface.can_setindex(typeof(A)) + if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray) + fact = qr!(A, alg.pivot) + else + fact = qr(A) # CUDA.jl does not allow other args! + end else - fact = qr(A) # CUDA.jl does not allow other args! + fact = qr(A, alg.pivot) end return fact end @@ -372,11 +381,15 @@ SVDFactorization() = SVDFactorization(false, LinearAlgebra.DivideAndConquer()) function do_factorization(alg::SVDFactorization, A, b, u) A = convert(AbstractMatrix, A) - fact = svd!(A; full = alg.full, alg = alg.alg) + if ArrayInterface.can_setindex(typeof(A)) + fact = svd!(A; alg.full, alg.alg) + else + fact = svd(A; alg.full) + end return fact end -function init_cacheval(alg::SVDFactorization, A::Matrix, b, u, Pl, Pr, +function init_cacheval(alg::SVDFactorization, A::Union{Matrix, SMatrix}, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) ArrayInterface.svd_instance(convert(AbstractMatrix, A)) @@ -1354,6 +1367,11 @@ function init_cacheval(::SparspakFactorization, A, b, u, Pl, Pr, maxiters::Int, end end +function init_cacheval(::SparspakFactorization, ::StaticArray, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) + nothing +end + function SciMLBase.solve!(cache::LinearCache, alg::SparspakFactorization; kwargs...) A = cache.A if cache.isfresh