From 72bc0143e6e45bd48cae04ea937b1eb674295308 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 29 Nov 2023 18:07:53 -0500 Subject: [PATCH 1/4] Support StaticArrays Properly --- Project.toml | 6 ++++-- src/LinearSolve.jl | 2 ++ src/default.jl | 7 +++++++ src/factorization.jl | 28 +++++++++++++++++++++++----- 4 files changed, 36 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 12d3aab9d..fba07d804 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LinearSolve" uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" authors = ["SciML"] -version = "2.20.0" +version = "2.20.1" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" @@ -26,6 +26,7 @@ SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac" +StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [weakdeps] @@ -48,6 +49,7 @@ LinearSolveBandedMatricesExt = "BandedMatrices" LinearSolveBlockDiagonalsExt = "BlockDiagonals" LinearSolveCUDAExt = "CUDA" LinearSolveEnzymeExt = ["Enzyme", "EnzymeCore"] +LinearSolveFastAlmostBandedMatricesExt = ["FastAlmostBandedMatrices"] LinearSolveHYPREExt = "HYPRE" LinearSolveIterativeSolversExt = "IterativeSolvers" LinearSolveKernelAbstractionsExt = "KernelAbstractions" @@ -55,7 +57,6 @@ LinearSolveKrylovKitExt = "KrylovKit" LinearSolveMetalExt = "Metal" LinearSolvePardisoExt = "Pardiso" LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools" -LinearSolveFastAlmostBandedMatricesExt = ["FastAlmostBandedMatrices"] [compat] Aqua = "0.8" @@ -102,6 +103,7 @@ Setfield = "1" SparseArrays = "1.9" Sparspak = "0.3.6" Test = "1" +StaticArraysCore = "1" UnPack = "1" julia = "1.9" diff --git a/src/LinearSolve.jl b/src/LinearSolve.jl index 0d70f0123..33a1dc4f5 100644 --- a/src/LinearSolve.jl +++ b/src/LinearSolve.jl @@ -26,6 +26,8 @@ PrecompileTools.@recompile_invalidations begin using Requires import InteractiveUtils + import StaticArraysCore: StaticArray, SVector, MVector, SMatrix, MMatrix + using LinearAlgebra: BlasInt, LU using LinearAlgebra.LAPACK: require_one_based_indexing, chkfinite, chkstride1, diff --git a/src/default.jl b/src/default.jl index 574d81885..8aa3f3713 100644 --- a/src/default.jl +++ b/src/default.jl @@ -175,6 +175,10 @@ function defaultalg(A, b, assump::OperatorAssumptions) DefaultAlgorithmChoice.LUFactorization end + # For static arrays GMRES allocates a lot. Use factorization + elseif A isa StaticArray + DefaultAlgorithmChoice.LUFactorization + # This catches the cases where a factorization overload could exist # For example, BlockBandedMatrix elseif A !== nothing && ArrayInterface.isstructured(A) @@ -186,6 +190,9 @@ function defaultalg(A, b, assump::OperatorAssumptions) end elseif assump.condition === OperatorCondition.WellConditioned DefaultAlgorithmChoice.NormalCholeskyFactorization + elseif A isa StaticArray + # Static Array doesn't have QR() \ b defined + return DefaultAlgorithmChoice.SVDFactorization elseif assump.condition === OperatorCondition.IllConditioned if is_underdetermined(A) # Underdetermined diff --git a/src/factorization.jl b/src/factorization.jl index 786b202c3..5eba6ffb2 100644 --- a/src/factorization.jl +++ b/src/factorization.jl @@ -10,6 +10,9 @@ end _ldiv!(x, A, b) = ldiv!(x, A, b) +_ldiv!(x::MVector, A, b::SVector) = (x .= A \ b) +_ldiv!(::SVector, A, b::SVector) = (A \ b) + function _ldiv!(x::Vector, A::Factorization, b::Vector) # workaround https://github.com/JuliaLang/julia/issues/43507 # Fallback if working with non-square matrices @@ -88,6 +91,8 @@ function do_factorization(alg::LUFactorization, A, b, u) if A isa AbstractSparseMatrixCSC return lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)), check = false) + elseif !ArrayInterface.can_setindex(typeof(A)) + fact = lu(A, alg.pivot, check = false) else fact = lu!(A, alg.pivot, check = false) end @@ -172,10 +177,14 @@ end function do_factorization(alg::QRFactorization, A, b, u) A = convert(AbstractMatrix, A) - if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray) - fact = qr!(A, alg.pivot) + if ArrayInterface.can_setindex(typeof(A)) + if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray) + fact = qr!(A, alg.pivot) + else + fact = qr(A) # CUDA.jl does not allow other args! + end else - fact = qr(A) # CUDA.jl does not allow other args! + fact = qr(A, alg.pivot) end return fact end @@ -372,11 +381,15 @@ SVDFactorization() = SVDFactorization(false, LinearAlgebra.DivideAndConquer()) function do_factorization(alg::SVDFactorization, A, b, u) A = convert(AbstractMatrix, A) - fact = svd!(A; full = alg.full, alg = alg.alg) + if ArrayInterface.can_setindex(typeof(A)) + fact = svd!(A; alg.full, alg.alg) + else + fact = svd(A; alg.full) + end return fact end -function init_cacheval(alg::SVDFactorization, A::Matrix, b, u, Pl, Pr, +function init_cacheval(alg::SVDFactorization, A::Union{Matrix, SMatrix}, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) ArrayInterface.svd_instance(convert(AbstractMatrix, A)) @@ -1354,6 +1367,11 @@ function init_cacheval(::SparspakFactorization, A, b, u, Pl, Pr, maxiters::Int, end end +function init_cacheval(::SparspakFactorization, ::StaticArray, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) + nothing +end + function SciMLBase.solve!(cache::LinearCache, alg::SparspakFactorization; kwargs...) A = cache.A if cache.isfresh From 12fbcefb334025e0ec93d0e69bacec1658cdce96 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 4 Dec 2023 16:57:08 -0500 Subject: [PATCH 2/4] Handle cholesky --- Project.toml | 2 +- src/factorization.jl | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index fba07d804..b8aec3744 100644 --- a/Project.toml +++ b/Project.toml @@ -102,8 +102,8 @@ SciMLOperators = "0.3" Setfield = "1" SparseArrays = "1.9" Sparspak = "0.3.6" -Test = "1" StaticArraysCore = "1" +Test = "1" UnPack = "1" julia = "1.9" diff --git a/src/factorization.jl b/src/factorization.jl index 5eba6ffb2..c66e6ed94 100644 --- a/src/factorization.jl +++ b/src/factorization.jl @@ -10,7 +10,7 @@ end _ldiv!(x, A, b) = ldiv!(x, A, b) -_ldiv!(x::MVector, A, b::SVector) = (x .= A \ b) +_ldiv!(x, A, b::SVector) = (x .= A \ b) _ldiv!(::SVector, A, b::SVector) = (A \ b) function _ldiv!(x::Vector, A::Factorization, b::Vector) @@ -285,6 +285,12 @@ else end end +function init_cacheval(alg::CholeskyFactorization, A::SMatrix, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, + assumptions::OperatorAssumptions) + cholesky(A) # StaticArrays doesn't have the pivot argument. Prevent generic fallback. +end + function init_cacheval(alg::CholeskyFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) From 326059994b28451b954844b43706f7f3e1804215 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 5 Dec 2023 12:35:16 -0500 Subject: [PATCH 3/4] Add tests for StaticArrays --- src/default.jl | 2 +- src/factorization.jl | 6 ++++-- test/runtests.jl | 1 + test/static_arrays.jl | 25 +++++++++++++++++++++++++ 4 files changed, 31 insertions(+), 3 deletions(-) create mode 100644 test/static_arrays.jl diff --git a/src/default.jl b/src/default.jl index 8aa3f3713..231f81361 100644 --- a/src/default.jl +++ b/src/default.jl @@ -192,7 +192,7 @@ function defaultalg(A, b, assump::OperatorAssumptions) DefaultAlgorithmChoice.NormalCholeskyFactorization elseif A isa StaticArray # Static Array doesn't have QR() \ b defined - return DefaultAlgorithmChoice.SVDFactorization + DefaultAlgorithmChoice.SVDFactorization elseif assump.condition === OperatorCondition.IllConditioned if is_underdetermined(A) # Underdetermined diff --git a/src/factorization.jl b/src/factorization.jl index c66e6ed94..2c9f4aad6 100644 --- a/src/factorization.jl +++ b/src/factorization.jl @@ -12,6 +12,7 @@ _ldiv!(x, A, b) = ldiv!(x, A, b) _ldiv!(x, A, b::SVector) = (x .= A \ b) _ldiv!(::SVector, A, b::SVector) = (A \ b) +_ldiv!(::SVector, A, b) = (A \ b) function _ldiv!(x::Vector, A::Factorization, b::Vector) # workaround https://github.com/JuliaLang/julia/issues/43507 @@ -288,7 +289,7 @@ end function init_cacheval(alg::CholeskyFactorization, A::SMatrix, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) - cholesky(A) # StaticArrays doesn't have the pivot argument. Prevent generic fallback. + cholesky(SMatrix{1, 1}(one(eltype(A)))) # StaticArrays doesn't have the pivot argument. Prevent generic fallback. end function init_cacheval(alg::CholeskyFactorization, A, b, u, Pl, Pr, @@ -1060,7 +1061,8 @@ end function init_cacheval(alg::NormalCholeskyFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) - ArrayInterface.cholesky_instance(convert(AbstractMatrix, A), alg.pivot) + A_ = convert(AbstractMatrix, A) + ArrayInterface.cholesky_instance(Symmetric((A)' * A, :L), alg.pivot) end function init_cacheval(alg::NormalCholeskyFactorization, diff --git a/test/runtests.jl b/test/runtests.jl index 4ca877a76..2cfa51650 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,6 +17,7 @@ if GROUP == "All" || GROUP == "Core" VERSION >= v"1.9" && @time @safetestset "Enzyme Derivative Rules" include("enzyme.jl") @time @safetestset "Traits" include("traits.jl") VERSION >= v"1.9" && @time @safetestset "BandedMatrices" include("banded.jl") + @time @safetestset "Static Arrays" include("static_arrays.jl") end if GROUP == "LinearSolveCUDA" diff --git a/test/static_arrays.jl b/test/static_arrays.jl new file mode 100644 index 000000000..c77147992 --- /dev/null +++ b/test/static_arrays.jl @@ -0,0 +1,25 @@ +using LinearSolve, StaticArrays, LinearAlgebra + +A = SMatrix{5, 5}(Hermitian(rand(5, 5) + I)) +b = SVector{5}(rand(5)) + +for alg in (nothing, LUFactorization(), SVDFactorization(), CholeskyFactorization(), + KrylovJL_GMRES()) + sol = solve(LinearProblem(A, b), alg) + @show norm(A * sol .- b) + @test norm(A * sol .- b) < 1e-10 +end + +A = SMatrix{7, 5}(rand(7, 5)) +b = SVector{7}(rand(7)) + +for alg in (nothing, SVDFactorization(), KrylovJL_LSMR()) + @test_nowarn solve(LinearProblem(A, b), alg) +end + +A = SMatrix{5, 7}(rand(5, 7)) +b = SVector{5}(rand(5)) + +for alg in (nothing, SVDFactorization(), KrylovJL_LSMR()) + @test_nowarn solve(LinearProblem(A, b), alg) +end From b68ebc30abc613536e5d4cb0d65b42d8bcf87598 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 5 Dec 2023 13:17:46 -0500 Subject: [PATCH 4/4] Fix tests --- Project.toml | 4 +++- src/factorization.jl | 10 +++++++--- test/default_algs.jl | 2 +- test/static_arrays.jl | 1 - 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index b8aec3744..fce114b23 100644 --- a/Project.toml +++ b/Project.toml @@ -103,6 +103,7 @@ Setfield = "1" SparseArrays = "1.9" Sparspak = "0.3.6" StaticArraysCore = "1" +StaticArrays = "1" Test = "1" UnPack = "1" julia = "1.9" @@ -128,7 +129,8 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices"] +test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays"] diff --git a/src/factorization.jl b/src/factorization.jl index 2c9f4aad6..831065690 100644 --- a/src/factorization.jl +++ b/src/factorization.jl @@ -286,10 +286,14 @@ else end end -function init_cacheval(alg::CholeskyFactorization, A::SMatrix, b, u, Pl, Pr, +function init_cacheval(alg::CholeskyFactorization, A::SMatrix{S1, S2}, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, - assumptions::OperatorAssumptions) - cholesky(SMatrix{1, 1}(one(eltype(A)))) # StaticArrays doesn't have the pivot argument. Prevent generic fallback. + assumptions::OperatorAssumptions) where {S1, S2} + # StaticArrays doesn't have the pivot argument. Prevent generic fallback. + # CholeskyFactorization is part of DefaultLinearSolver, so it is possible that `A` is + # not Hermitian. + (!issquare(A) || !ishermitian(A)) && return nothing + cholesky(A) end function init_cacheval(alg::CholeskyFactorization, A, b, u, Pl, Pr, diff --git a/test/default_algs.jl b/test/default_algs.jl index 3435b2f19..431e2d155 100644 --- a/test/default_algs.jl +++ b/test/default_algs.jl @@ -50,7 +50,7 @@ solve(prob) A = rand(4, 4) b = rand(4) prob = LinearProblem(A, b) - JET.@test_opt init(prob, nothing) + VERSION ≥ v"1.10-" && JET.@test_opt init(prob, nothing) JET.@test_opt solve(prob, LUFactorization()) JET.@test_opt solve(prob, GenericLUFactorization()) @test_skip JET.@test_opt solve(prob, QRFactorization()) diff --git a/test/static_arrays.jl b/test/static_arrays.jl index c77147992..5c6ccf252 100644 --- a/test/static_arrays.jl +++ b/test/static_arrays.jl @@ -6,7 +6,6 @@ b = SVector{5}(rand(5)) for alg in (nothing, LUFactorization(), SVDFactorization(), CholeskyFactorization(), KrylovJL_GMRES()) sol = solve(LinearProblem(A, b), alg) - @show norm(A * sol .- b) @test norm(A * sol .- b) < 1e-10 end