From cc65daaf8ef14b67a6be133affd7617eb69dd7fe Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 11 Dec 2023 00:01:22 -0500 Subject: [PATCH 1/3] Proper handling of static arrays --- Project.toml | 10 +++++----- src/common.jl | 26 +++++++++++++++----------- src/default.jl | 18 +++++++++--------- test/static_arrays.jl | 2 +- 4 files changed, 30 insertions(+), 26 deletions(-) diff --git a/Project.toml b/Project.toml index 36749c026..2fc2bae47 100644 --- a/Project.toml +++ b/Project.toml @@ -63,8 +63,8 @@ Aqua = "0.8" ArrayInterface = "7.4.11" BandedMatrices = "1" BlockDiagonals = "0.1" -ConcreteStructs = "0.2" CUDA = "5" +ConcreteStructs = "0.2" DocStringExtensions = "0.9" EnumX = "1" Enzyme = "0.11" @@ -77,15 +77,15 @@ GPUArraysCore = "0.1" HYPRE = "1.4.0" InteractiveUtils = "1.6" IterativeSolvers = "0.9.3" -Libdl = "1.6" -LinearAlgebra = "1.9" JET = "0.8" KLU = "0.3.0, 0.4" KernelAbstractions = "0.9" Krylov = "0.9" KrylovKit = "0.6" -Metal = "0.5" +Libdl = "1.6" +LinearAlgebra = "1.9" MPI = "0.20" +Metal = "0.5" MultiFloats = "1" Pardiso = "0.5" Pkg = "1" @@ -102,8 +102,8 @@ SciMLOperators = "0.3" Setfield = "1" SparseArrays = "1.9" Sparspak = "0.3.6" -StaticArraysCore = "1" StaticArrays = "1" +StaticArraysCore = "1" Test = "1" UnPack = "1" julia = "1.9" diff --git a/src/common.jl b/src/common.jl index 791ab91c8..765382f81 100644 --- a/src/common.jl +++ b/src/common.jl @@ -119,6 +119,15 @@ default_alias_b(::Any, ::Any, ::Any) = false default_alias_A(::AbstractKrylovSubspaceMethod, ::Any, ::Any) = true default_alias_b(::AbstractKrylovSubspaceMethod, ::Any, ::Any) = true +function __init_u0_from_Ab(A, b) + u0 = similar(b, size(A, 2)) + fill!(u0, false) + return u0 +end +function __init_u0_from_Ab(A::SMatrix{S1, S2}, b) where {S1, S2} + return zeros(SVector{S2, eltype(b)}) +end + function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, args...; alias_A = default_alias_A(alg, prob.A, prob.b), @@ -133,7 +142,7 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, kwargs...) @unpack A, b, u0, p = prob - A = if alias_A + A = if alias_A || A isa SMatrix A elseif A isa Array || A isa SparseMatrixCSC copy(A) @@ -143,7 +152,7 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, b = if b isa SparseArrays.AbstractSparseArray && !(A isa Diagonal) Array(b) # the solution to a linear solve will always be dense! - elseif alias_b + elseif alias_b || b isa SVector b elseif b isa Array || b isa SparseMatrixCSC copy(b) @@ -151,18 +160,13 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, deepcopy(b) end - u0 = if u0 !== nothing - u0 - else - u0 = similar(b, size(A, 2)) - fill!(u0, false) - end + u0_ = u0 !== nothing ? u0 : __init_u0_from_Ab(A, b) # Guard against type mismatch for user-specified reltol/abstol reltol = real(eltype(prob.b))(reltol) abstol = real(eltype(prob.b))(abstol) - cacheval = init_cacheval(alg, A, b, u0, Pl, Pr, maxiters, abstol, reltol, verbose, + cacheval = init_cacheval(alg, A, b, u0_, Pl, Pr, maxiters, abstol, reltol, verbose, assumptions) isfresh = true Tc = typeof(cacheval) @@ -170,7 +174,7 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, cache = LinearCache{ typeof(A), typeof(b), - typeof(u0), + typeof(u0_), typeof(p), typeof(alg), Tc, @@ -180,7 +184,7 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, typeof(assumptions.issq), }(A, b, - u0, + u0_, p, alg, cacheval, diff --git a/src/default.jl b/src/default.jl index 2fd2f5ddf..4c4b43a1f 100644 --- a/src/default.jl +++ b/src/default.jl @@ -36,6 +36,14 @@ function defaultalg(A, b, assump::OperatorAssumptions{Nothing}) defaultalg(A, b, OperatorAssumptions(issq, assump.condition)) end +function defaultalg(A::SMatrix{S1, S2}, b, assump::OperatorAssumptions{Bool}) where {S1, S2} + if S1 == S2 + return LUFactorization() + else + return SVDFactorization() # QR(...) \ b is not defined currently + end +end + function defaultalg(A::Tridiagonal, b, assump::OperatorAssumptions{Bool}) if assump.issq DefaultLinearSolver(DefaultAlgorithmChoice.LUFactorization) @@ -175,10 +183,6 @@ function defaultalg(A, b, assump::OperatorAssumptions{Bool}) DefaultAlgorithmChoice.LUFactorization end - # For static arrays GMRES allocates a lot. Use factorization - elseif A isa StaticArray - DefaultAlgorithmChoice.LUFactorization - # This catches the cases where a factorization overload could exist # For example, BlockBandedMatrix elseif A !== nothing && ArrayInterface.isstructured(A) @@ -190,9 +194,6 @@ function defaultalg(A, b, assump::OperatorAssumptions{Bool}) end elseif assump.condition === OperatorCondition.WellConditioned DefaultAlgorithmChoice.NormalCholeskyFactorization - elseif A isa StaticArray - # Static Array doesn't have QR() \ b defined - DefaultAlgorithmChoice.SVDFactorization elseif assump.condition === OperatorCondition.IllConditioned if is_underdetermined(A) # Underdetermined @@ -269,8 +270,7 @@ function SciMLBase.init(prob::LinearProblem, alg::Nothing, args...; assumptions = OperatorAssumptions(issquare(prob.A)), kwargs...) - alg = defaultalg(prob.A, prob.b, assumptions) - SciMLBase.init(prob, alg, args...; assumptions, kwargs...) + SciMLBase.init(prob, defaultalg(prob.A, prob.b, assumptions), args...; assumptions, kwargs...) end function SciMLBase.solve!(cache::LinearCache, alg::Nothing, diff --git a/test/static_arrays.jl b/test/static_arrays.jl index 5c6ccf252..6fc614e67 100644 --- a/test/static_arrays.jl +++ b/test/static_arrays.jl @@ -1,4 +1,4 @@ -using LinearSolve, StaticArrays, LinearAlgebra +using LinearSolve, StaticArrays, LinearAlgebra, Test A = SMatrix{5, 5}(Hermitian(rand(5, 5) + I)) b = SVector{5}(rand(5)) From 3c56ab422a0268d47ac9f9dc58de2a6372bdd99e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 11 Dec 2023 10:04:19 -0500 Subject: [PATCH 2/3] Add inference check --- src/common.jl | 4 +--- src/factorization.jl | 4 ---- src/iterative_wrappers.jl | 4 +++- test/static_arrays.jl | 3 +++ 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/common.jl b/src/common.jl index 765382f81..3f59b3da2 100644 --- a/src/common.jl +++ b/src/common.jl @@ -124,9 +124,7 @@ function __init_u0_from_Ab(A, b) fill!(u0, false) return u0 end -function __init_u0_from_Ab(A::SMatrix{S1, S2}, b) where {S1, S2} - return zeros(SVector{S2, eltype(b)}) -end +__init_u0_from_Ab(::SMatrix{S1, S2}, b) where {S1, S2} = zeros(SVector{S2, eltype(b)}) function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, args...; diff --git a/src/factorization.jl b/src/factorization.jl index 5ff403155..2077d0639 100644 --- a/src/factorization.jl +++ b/src/factorization.jl @@ -215,10 +215,6 @@ end function init_cacheval(alg::CholeskyFactorization, A::SMatrix{S1, S2}, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) where {S1, S2} - # StaticArrays doesn't have the pivot argument. Prevent generic fallback. - # CholeskyFactorization is part of DefaultLinearSolver, so it is possible that `A` is - # not Hermitian. - (!issquare(A) || !ishermitian(A)) && return nothing cholesky(A) end diff --git a/src/iterative_wrappers.jl b/src/iterative_wrappers.jl index b37571cb5..294cfe7f1 100644 --- a/src/iterative_wrappers.jl +++ b/src/iterative_wrappers.jl @@ -284,8 +284,10 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...) # Copy the solution to the allocated output vector cacheval = @get_cacheval(cache, :KrylovJL_GMRES) - if cache.u !== cacheval.x + if cache.u !== cacheval.x && ArrayInterface.can_setindex(cache.u) cache.u .= cacheval.x + else + cache.u = convert(typeof(cache.u), cacheval.x) end return SciMLBase.build_linear_solution(alg, cache.u, resid, cache; diff --git a/test/static_arrays.jl b/test/static_arrays.jl index 6fc614e67..55158947e 100644 --- a/test/static_arrays.jl +++ b/test/static_arrays.jl @@ -6,6 +6,7 @@ b = SVector{5}(rand(5)) for alg in (nothing, LUFactorization(), SVDFactorization(), CholeskyFactorization(), KrylovJL_GMRES()) sol = solve(LinearProblem(A, b), alg) + @inferred solve(LinearProblem(A, b), alg) @test norm(A * sol .- b) < 1e-10 end @@ -13,6 +14,7 @@ A = SMatrix{7, 5}(rand(7, 5)) b = SVector{7}(rand(7)) for alg in (nothing, SVDFactorization(), KrylovJL_LSMR()) + @inferred solve(LinearProblem(A, b), alg) @test_nowarn solve(LinearProblem(A, b), alg) end @@ -20,5 +22,6 @@ A = SMatrix{5, 7}(rand(5, 7)) b = SVector{5}(rand(5)) for alg in (nothing, SVDFactorization(), KrylovJL_LSMR()) + @inferred solve(LinearProblem(A, b), alg) @test_nowarn solve(LinearProblem(A, b), alg) end From a31f99fb0714d35d982aa5144d9a5b7b72100431 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 12 Dec 2023 21:40:20 -0500 Subject: [PATCH 3/3] Special handling for staticarrays --- Project.toml | 4 ++- src/common.jl | 58 ++++++++++++++++++++++++------------------- src/factorization.jl | 17 ++++++++++--- test/static_arrays.jl | 19 +++++++++++++- 4 files changed, 67 insertions(+), 31 deletions(-) diff --git a/Project.toml b/Project.toml index 2fc2bae47..9d524be7c 100644 --- a/Project.toml +++ b/Project.toml @@ -59,6 +59,7 @@ LinearSolvePardisoExt = "Pardiso" LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools" [compat] +AllocCheck = "0.1" Aqua = "0.8" ArrayInterface = "7.4.11" BandedMatrices = "1" @@ -109,6 +110,7 @@ UnPack = "1" julia = "1.9" [extras] +AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" @@ -133,4 +135,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays"] +test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck"] diff --git a/src/common.jl b/src/common.jl index 3f59b3da2..b206598d5 100644 --- a/src/common.jl +++ b/src/common.jl @@ -169,31 +169,9 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm, isfresh = true Tc = typeof(cacheval) - cache = LinearCache{ - typeof(A), - typeof(b), - typeof(u0_), - typeof(p), - typeof(alg), - Tc, - typeof(Pl), - typeof(Pr), - typeof(reltol), - typeof(assumptions.issq), - }(A, - b, - u0_, - p, - alg, - cacheval, - isfresh, - Pl, - Pr, - abstol, - reltol, - maxiters, - verbose, - assumptions) + cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc, + typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq)}(A, b, u0_, + p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, maxiters, verbose, assumptions) return cache end @@ -210,3 +188,33 @@ end function SciMLBase.solve!(cache::LinearCache, args...; kwargs...) solve!(cache, cache.alg, args...; kwargs...) end + +# Special Case for StaticArrays +const StaticLinearProblem = LinearProblem{uType, iip, <:SMatrix, + <:Union{<:SMatrix, <:SVector}} where {uType, iip} + +function SciMLBase.solve(prob::StaticLinearProblem, args...; kwargs...) + return SciMLBase.solve(prob, nothing, args...; kwargs...) +end + +function SciMLBase.solve(prob::StaticLinearProblem, + alg::Union{Nothing, SciMLLinearSolveAlgorithm}, args...; kwargs...) + if alg === nothing || alg isa DirectLdiv! + u = prob.A \ prob.b + elseif alg isa LUFactorization + u = lu(prob.A) \ prob.b + elseif alg isa QRFactorization + u = qr(prob.A) \ prob.b + elseif alg isa CholeskyFactorization + u = cholesky(prob.A) \ prob.b + elseif alg isa NormalCholeskyFactorization + u = cholesky(Symmetric(prob.A' * prob.A)) \ (prob.A' * prob.b) + elseif alg isa SVDFactorization + u = svd(prob.A) \ prob.b + else + # Slower Path but handles all cases + cache = init(prob, alg, args...; kwargs...) + return solve!(cache) + end + return SciMLBase.build_linear_solution(alg, u, nothing, prob) +end diff --git a/src/factorization.jl b/src/factorization.jl index 2077d0639..59ade8266 100644 --- a/src/factorization.jl +++ b/src/factorization.jl @@ -975,11 +975,17 @@ function init_cacheval(alg::NormalCholeskyFactorization, ArrayInterface.cholesky_instance(convert(AbstractMatrix, A)) end +function init_cacheval(alg::NormalCholeskyFactorization, A::SMatrix, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, + assumptions::OperatorAssumptions) + return cholesky(Symmetric((A)' * A)) +end + function init_cacheval(alg::NormalCholeskyFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) A_ = convert(AbstractMatrix, A) - ArrayInterface.cholesky_instance(Symmetric((A)' * A, :L), alg.pivot) + return ArrayInterface.cholesky_instance(Symmetric((A)' * A), alg.pivot) end function init_cacheval(alg::NormalCholeskyFactorization, @@ -993,10 +999,10 @@ function SciMLBase.solve!(cache::LinearCache, alg::NormalCholeskyFactorization; A = cache.A A = convert(AbstractMatrix, A) if cache.isfresh - if A isa SparseMatrixCSC || A isa GPUArraysCore.AbstractGPUArray - fact = cholesky(Symmetric((A)' * A, :L); check = false) + if A isa SparseMatrixCSC || A isa GPUArraysCore.AbstractGPUArray || A isa SMatrix + fact = cholesky(Symmetric((A)' * A); check = false) else - fact = cholesky(Symmetric((A)' * A, :L), alg.pivot; check = false) + fact = cholesky(Symmetric((A)' * A), alg.pivot; check = false) end cache.cacheval = fact cache.isfresh = false @@ -1004,6 +1010,9 @@ function SciMLBase.solve!(cache::LinearCache, alg::NormalCholeskyFactorization; if A isa SparseMatrixCSC cache.u .= @get_cacheval(cache, :NormalCholeskyFactorization) \ (A' * cache.b) y = cache.u + elseif A isa StaticArray + cache.u = @get_cacheval(cache, :NormalCholeskyFactorization) \ (A' * cache.b) + y = cache.u else y = ldiv!(cache.u, @get_cacheval(cache, :NormalCholeskyFactorization), A' * cache.b) end diff --git a/test/static_arrays.jl b/test/static_arrays.jl index 55158947e..72676d482 100644 --- a/test/static_arrays.jl +++ b/test/static_arrays.jl @@ -1,13 +1,30 @@ using LinearSolve, StaticArrays, LinearAlgebra, Test +using AllocCheck A = SMatrix{5, 5}(Hermitian(rand(5, 5) + I)) b = SVector{5}(rand(5)) +@check_allocs __solve_no_alloc(A, b, alg) = solve(LinearProblem(A, b), alg) + +function __non_native_static_array_alg(alg) + return alg isa SVDFactorization || alg isa KrylovJL +end + for alg in (nothing, LUFactorization(), SVDFactorization(), CholeskyFactorization(), - KrylovJL_GMRES()) + NormalCholeskyFactorization(), KrylovJL_GMRES()) sol = solve(LinearProblem(A, b), alg) @inferred solve(LinearProblem(A, b), alg) @test norm(A * sol .- b) < 1e-10 + + if __non_native_static_array_alg(alg) + @test_broken __solve_no_alloc(A, b, alg) + else + @test_nowarn __solve_no_alloc(A, b, alg) + end + + cache = init(LinearProblem(A, b), alg) + sol = solve!(cache) + @test norm(A * sol .- b) < 1e-10 end A = SMatrix{7, 5}(rand(7, 5))