diff --git a/Project.toml b/Project.toml
index 84a79b45f..ed8af6b17 100644
--- a/Project.toml
+++ b/Project.toml
@@ -59,12 +59,13 @@ LinearSolvePardisoExt = "Pardiso"
 LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools"
 
 [compat]
+AllocCheck = "0.1"
 Aqua = "0.8"
 ArrayInterface = "7.4.11"
 BandedMatrices = "1"
 BlockDiagonals = "0.1"
-ConcreteStructs = "0.2"
 CUDA = "5"
+ConcreteStructs = "0.2"
 DocStringExtensions = "0.9"
 EnumX = "1"
 Enzyme = "0.11"
@@ -77,15 +78,15 @@ GPUArraysCore = "0.1"
 HYPRE = "1.4.0"
 InteractiveUtils = "1.6"
 IterativeSolvers = "0.9.3"
-Libdl = "1.6"
-LinearAlgebra = "1.9"
 JET = "0.8"
 KLU = "0.3.0, 0.4"
 KernelAbstractions = "0.9"
 Krylov = "0.9"
 KrylovKit = "0.6"
-Metal = "0.5"
+Libdl = "1.6"
+LinearAlgebra = "1.9"
 MPI = "0.20"
+Metal = "0.5"
 MultiFloats = "1"
 Pardiso = "0.5"
 Pkg = "1"
@@ -102,13 +103,14 @@ SciMLOperators = "0.3"
 Setfield = "1"
 SparseArrays = "1.9"
 Sparspak = "0.3.6"
-StaticArraysCore = "1"
 StaticArrays = "1"
+StaticArraysCore = "1"
 Test = "1"
 UnPack = "1"
 julia = "1.9"
 
 [extras]
+AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
 Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
 BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
 BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
@@ -133,4 +135,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays"]
+test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck"]
diff --git a/src/common.jl b/src/common.jl
index 791ab91c8..b206598d5 100644
--- a/src/common.jl
+++ b/src/common.jl
@@ -119,6 +119,13 @@ default_alias_b(::Any, ::Any, ::Any) = false
 default_alias_A(::AbstractKrylovSubspaceMethod, ::Any, ::Any) = true
 default_alias_b(::AbstractKrylovSubspaceMethod, ::Any, ::Any) = true
 
+function __init_u0_from_Ab(A, b)
+    u0 = similar(b, size(A, 2))
+    fill!(u0, false)
+    return u0
+end
+__init_u0_from_Ab(::SMatrix{S1, S2}, b) where {S1, S2} = zeros(SVector{S2, eltype(b)})
+
 function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
     args...;
     alias_A = default_alias_A(alg, prob.A, prob.b),
@@ -133,7 +140,7 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
     kwargs...)
     @unpack A, b, u0, p = prob
 
-    A = if alias_A
+    A = if alias_A || A isa SMatrix
         A
     elseif A isa Array || A isa SparseMatrixCSC
         copy(A)
@@ -143,7 +150,7 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
 
     b = if b isa SparseArrays.AbstractSparseArray && !(A isa Diagonal)
         Array(b) # the solution to a linear solve will always be dense!
-    elseif alias_b
+    elseif alias_b || b isa SVector
         b
     elseif b isa Array || b isa SparseMatrixCSC
         copy(b)
@@ -151,47 +158,20 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
         deepcopy(b)
     end
 
-    u0 = if u0 !== nothing
-        u0
-    else
-        u0 = similar(b, size(A, 2))
-        fill!(u0, false)
-    end
+    u0_ = u0 !== nothing ? u0 : __init_u0_from_Ab(A, b)
 
     # Guard against type mismatch for user-specified reltol/abstol
     reltol = real(eltype(prob.b))(reltol)
     abstol = real(eltype(prob.b))(abstol)
 
-    cacheval = init_cacheval(alg, A, b, u0, Pl, Pr, maxiters, abstol, reltol, verbose,
+    cacheval = init_cacheval(alg, A, b, u0_, Pl, Pr, maxiters, abstol, reltol, verbose,
         assumptions)
     isfresh = true
     Tc = typeof(cacheval)
 
-    cache = LinearCache{
-        typeof(A),
-        typeof(b),
-        typeof(u0),
-        typeof(p),
-        typeof(alg),
-        Tc,
-        typeof(Pl),
-        typeof(Pr),
-        typeof(reltol),
-        typeof(assumptions.issq),
-    }(A,
-        b,
-        u0,
-        p,
-        alg,
-        cacheval,
-        isfresh,
-        Pl,
-        Pr,
-        abstol,
-        reltol,
-        maxiters,
-        verbose,
-        assumptions)
+    cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc,
+        typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq)}(A, b, u0_,
+        p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, maxiters, verbose, assumptions)
 
     return cache
 end
@@ -208,3 +188,33 @@ end
 function SciMLBase.solve!(cache::LinearCache, args...; kwargs...)
     solve!(cache, cache.alg, args...; kwargs...)
 end
+
+# Special Case for StaticArrays
+const StaticLinearProblem = LinearProblem{uType, iip, <:SMatrix,
+    <:Union{<:SMatrix, <:SVector}} where {uType, iip}
+
+function SciMLBase.solve(prob::StaticLinearProblem, args...; kwargs...)
+    return SciMLBase.solve(prob, nothing, args...; kwargs...)
+end
+
+function SciMLBase.solve(prob::StaticLinearProblem,
+    alg::Union{Nothing, SciMLLinearSolveAlgorithm}, args...; kwargs...)
+    if alg === nothing || alg isa DirectLdiv!
+        u = prob.A \ prob.b
+    elseif alg isa LUFactorization
+        u = lu(prob.A) \ prob.b
+    elseif alg isa QRFactorization
+        u = qr(prob.A) \ prob.b
+    elseif alg isa CholeskyFactorization
+        u = cholesky(prob.A) \ prob.b
+    elseif alg isa NormalCholeskyFactorization
+        u = cholesky(Symmetric(prob.A' * prob.A)) \ (prob.A' * prob.b)
+    elseif alg isa SVDFactorization
+        u = svd(prob.A) \ prob.b
+    else
+        # Slower Path but handles all cases
+        cache = init(prob, alg, args...; kwargs...)
+        return solve!(cache)
+    end
+    return SciMLBase.build_linear_solution(alg, u, nothing, prob)
+end
diff --git a/src/default.jl b/src/default.jl
index 2fd2f5ddf..4c4b43a1f 100644
--- a/src/default.jl
+++ b/src/default.jl
@@ -36,6 +36,14 @@ function defaultalg(A, b, assump::OperatorAssumptions{Nothing})
     defaultalg(A, b, OperatorAssumptions(issq, assump.condition))
 end
 
+function defaultalg(A::SMatrix{S1, S2}, b, assump::OperatorAssumptions{Bool}) where {S1, S2}
+    if S1 == S2
+        return LUFactorization()
+    else
+        return SVDFactorization() # QR(...) \ b is not defined currently
+    end
+end
+
 function defaultalg(A::Tridiagonal, b, assump::OperatorAssumptions{Bool})
     if assump.issq
         DefaultLinearSolver(DefaultAlgorithmChoice.LUFactorization)
@@ -175,10 +183,6 @@ function defaultalg(A, b, assump::OperatorAssumptions{Bool})
                 DefaultAlgorithmChoice.LUFactorization
             end
 
-            # For static arrays GMRES allocates a lot. Use factorization
-        elseif A isa StaticArray
-            DefaultAlgorithmChoice.LUFactorization
-
             # This catches the cases where a factorization overload could exist
             # For example, BlockBandedMatrix
         elseif A !== nothing && ArrayInterface.isstructured(A)
@@ -190,9 +194,6 @@ function defaultalg(A, b, assump::OperatorAssumptions{Bool})
         end
     elseif assump.condition === OperatorCondition.WellConditioned
         DefaultAlgorithmChoice.NormalCholeskyFactorization
-    elseif A isa StaticArray
-        # Static Array doesn't have QR() \ b defined
-        DefaultAlgorithmChoice.SVDFactorization
     elseif assump.condition === OperatorCondition.IllConditioned
         if is_underdetermined(A)
             # Underdetermined
@@ -269,8 +270,7 @@ function SciMLBase.init(prob::LinearProblem, alg::Nothing,
     args...;
     assumptions = OperatorAssumptions(issquare(prob.A)),
     kwargs...)
-    alg = defaultalg(prob.A, prob.b, assumptions)
-    SciMLBase.init(prob, alg, args...; assumptions, kwargs...)
+    SciMLBase.init(prob, defaultalg(prob.A, prob.b, assumptions), args...; assumptions, kwargs...)
 end
 
 function SciMLBase.solve!(cache::LinearCache, alg::Nothing,
diff --git a/src/factorization.jl b/src/factorization.jl
index 5ff403155..59ade8266 100644
--- a/src/factorization.jl
+++ b/src/factorization.jl
@@ -215,10 +215,6 @@ end
 function init_cacheval(alg::CholeskyFactorization, A::SMatrix{S1, S2}, b, u, Pl, Pr,
     maxiters::Int, abstol, reltol, verbose::Bool,
     assumptions::OperatorAssumptions) where {S1, S2}
-    # StaticArrays doesn't have the pivot argument. Prevent generic fallback.
-    # CholeskyFactorization is part of DefaultLinearSolver, so it is possible that `A` is
-    # not Hermitian.
-    (!issquare(A) || !ishermitian(A)) && return nothing
     cholesky(A)
 end
 
@@ -979,11 +975,17 @@ function init_cacheval(alg::NormalCholeskyFactorization,
     ArrayInterface.cholesky_instance(convert(AbstractMatrix, A))
 end
 
+function init_cacheval(alg::NormalCholeskyFactorization, A::SMatrix, b, u, Pl, Pr,
+    maxiters::Int, abstol, reltol, verbose::Bool,
+    assumptions::OperatorAssumptions)
+    return cholesky(Symmetric((A)' * A))
+end
+
 function init_cacheval(alg::NormalCholeskyFactorization, A, b, u, Pl, Pr,
     maxiters::Int, abstol, reltol, verbose::Bool,
     assumptions::OperatorAssumptions)
     A_ = convert(AbstractMatrix, A)
-    ArrayInterface.cholesky_instance(Symmetric((A)' * A, :L), alg.pivot)
+    return ArrayInterface.cholesky_instance(Symmetric((A)' * A), alg.pivot)
 end
 
 function init_cacheval(alg::NormalCholeskyFactorization,
@@ -997,10 +999,10 @@ function SciMLBase.solve!(cache::LinearCache, alg::NormalCholeskyFactorization;
     A = cache.A
     A = convert(AbstractMatrix, A)
     if cache.isfresh
-        if A isa SparseMatrixCSC || A isa GPUArraysCore.AbstractGPUArray
-            fact = cholesky(Symmetric((A)' * A, :L); check = false)
+        if A isa SparseMatrixCSC || A isa GPUArraysCore.AbstractGPUArray || A isa SMatrix
+            fact = cholesky(Symmetric((A)' * A); check = false)
         else
-            fact = cholesky(Symmetric((A)' * A, :L), alg.pivot; check = false)
+            fact = cholesky(Symmetric((A)' * A), alg.pivot; check = false)
         end
         cache.cacheval = fact
         cache.isfresh = false
@@ -1008,6 +1010,9 @@ function SciMLBase.solve!(cache::LinearCache, alg::NormalCholeskyFactorization;
     if A isa SparseMatrixCSC
         cache.u .= @get_cacheval(cache, :NormalCholeskyFactorization) \ (A' * cache.b)
         y = cache.u
+    elseif A isa StaticArray
+        cache.u = @get_cacheval(cache, :NormalCholeskyFactorization) \ (A' * cache.b)
+        y = cache.u
     else
         y = ldiv!(cache.u, @get_cacheval(cache, :NormalCholeskyFactorization), A' * cache.b)
     end
diff --git a/src/iterative_wrappers.jl b/src/iterative_wrappers.jl
index b37571cb5..294cfe7f1 100644
--- a/src/iterative_wrappers.jl
+++ b/src/iterative_wrappers.jl
@@ -284,8 +284,10 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...)
 
     # Copy the solution to the allocated output vector
     cacheval = @get_cacheval(cache, :KrylovJL_GMRES)
-    if cache.u !== cacheval.x
+    if cache.u !== cacheval.x && ArrayInterface.can_setindex(cache.u)
         cache.u .= cacheval.x
+    else
+        cache.u = convert(typeof(cache.u), cacheval.x)
     end
 
     return SciMLBase.build_linear_solution(alg, cache.u, resid, cache;
diff --git a/test/static_arrays.jl b/test/static_arrays.jl
index 5c6ccf252..72676d482 100644
--- a/test/static_arrays.jl
+++ b/test/static_arrays.jl
@@ -1,11 +1,29 @@
-using LinearSolve, StaticArrays, LinearAlgebra
+using LinearSolve, StaticArrays, LinearAlgebra, Test
+using AllocCheck
 
 A = SMatrix{5, 5}(Hermitian(rand(5, 5) + I))
 b = SVector{5}(rand(5))
 
+@check_allocs __solve_no_alloc(A, b, alg) = solve(LinearProblem(A, b), alg)
+
+function __non_native_static_array_alg(alg)
+    return alg isa SVDFactorization || alg isa KrylovJL
+end
+
 for alg in (nothing, LUFactorization(), SVDFactorization(), CholeskyFactorization(),
-    KrylovJL_GMRES())
+    NormalCholeskyFactorization(), KrylovJL_GMRES())
     sol = solve(LinearProblem(A, b), alg)
+    @inferred solve(LinearProblem(A, b), alg)
+    @test norm(A * sol .- b) < 1e-10
+
+    if __non_native_static_array_alg(alg)
+        @test_broken __solve_no_alloc(A, b, alg)
+    else
+        @test_nowarn __solve_no_alloc(A, b, alg)
+    end
+
+    cache = init(LinearProblem(A, b), alg)
+    sol = solve!(cache)
     @test norm(A * sol .- b) < 1e-10
 end
 
@@ -13,6 +31,7 @@ A = SMatrix{7, 5}(rand(7, 5))
 b = SVector{7}(rand(7))
 
 for alg in (nothing, SVDFactorization(), KrylovJL_LSMR())
+    @inferred solve(LinearProblem(A, b), alg)
     @test_nowarn solve(LinearProblem(A, b), alg)
 end
 
@@ -20,5 +39,6 @@ A = SMatrix{5, 7}(rand(5, 7))
 b = SVector{5}(rand(5))
 
 for alg in (nothing, SVDFactorization(), KrylovJL_LSMR())
+    @inferred solve(LinearProblem(A, b), alg)
    @test_nowarn solve(LinearProblem(A, b), alg)
 end
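Usage sketch (illustrative, not part of the diff): assuming the StaticLinearProblem fast path added in src/common.jl above is in place, solving an SMatrix/SVector problem dispatches straight to the factorization branch (e.g. `lu(prob.A) \ prob.b`) instead of building a LinearCache, so the solve should return a static solution without heap allocation. The 4x4 size, variable names, and tolerance below are arbitrary examples, not values from this PR.

using LinearSolve, StaticArrays, LinearAlgebra

# Illustrative problem; any SMatrix/SVector pair hits the same specialized method.
A = SMatrix{4, 4}(Hermitian(rand(4, 4) + I))
b = SVector{4}(rand(4))

# Dispatches to SciMLBase.solve(::StaticLinearProblem, ::LUFactorization), which
# calls `lu(prob.A) \ prob.b` directly rather than going through init/solve!.
sol = solve(LinearProblem(A, b), LUFactorization())
@assert norm(A * sol.u - b) < 1e-10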