Proper handling of static arrays #444

Merged · 3 commits · Dec 13, 2023
14 changes: 8 additions & 6 deletions Project.toml
@@ -59,12 +59,13 @@ LinearSolvePardisoExt = "Pardiso"
LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools"

[compat]
AllocCheck = "0.1"
Aqua = "0.8"
ArrayInterface = "7.4.11"
BandedMatrices = "1"
BlockDiagonals = "0.1"
ConcreteStructs = "0.2"
CUDA = "5"
ConcreteStructs = "0.2"
DocStringExtensions = "0.9"
EnumX = "1"
Enzyme = "0.11"
@@ -77,15 +78,15 @@ GPUArraysCore = "0.1"
HYPRE = "1.4.0"
InteractiveUtils = "1.6"
IterativeSolvers = "0.9.3"
Libdl = "1.6"
LinearAlgebra = "1.9"
JET = "0.8"
KLU = "0.3.0, 0.4"
KernelAbstractions = "0.9"
Krylov = "0.9"
KrylovKit = "0.6"
Metal = "0.5"
Libdl = "1.6"
LinearAlgebra = "1.9"
MPI = "0.20"
Metal = "0.5"
MultiFloats = "1"
Pardiso = "0.5"
Pkg = "1"
@@ -102,13 +103,14 @@ SciMLOperators = "0.3"
Setfield = "1"
SparseArrays = "1.9"
Sparspak = "0.3.6"
StaticArraysCore = "1"
StaticArrays = "1"
StaticArraysCore = "1"
Test = "1"
UnPack = "1"
julia = "1.9"

[extras]
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
@@ -133,4 +135,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays"]
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck"]
78 changes: 44 additions & 34 deletions src/common.jl
@@ -119,6 +119,13 @@
default_alias_A(::AbstractKrylovSubspaceMethod, ::Any, ::Any) = true
default_alias_b(::AbstractKrylovSubspaceMethod, ::Any, ::Any) = true

function __init_u0_from_Ab(A, b)
u0 = similar(b, size(A, 2))
fill!(u0, false)
return u0
end
__init_u0_from_Ab(::SMatrix{S1, S2}, b) where {S1, S2} = zeros(SVector{S2, eltype(b)})

function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
args...;
alias_A = default_alias_A(alg, prob.A, prob.b),
@@ -133,7 +140,7 @@
kwargs...)
@unpack A, b, u0, p = prob

A = if alias_A
A = if alias_A || A isa SMatrix
A
elseif A isa Array || A isa SparseMatrixCSC
copy(A)
@@ -143,55 +150,28 @@

b = if b isa SparseArrays.AbstractSparseArray && !(A isa Diagonal)
Array(b) # the solution to a linear solve will always be dense!
elseif alias_b
elseif alias_b || b isa SVector
b
elseif b isa Array || b isa SparseMatrixCSC
copy(b)
else
deepcopy(b)
end

u0 = if u0 !== nothing
u0
else
u0 = similar(b, size(A, 2))
fill!(u0, false)
end
u0_ = u0 !== nothing ? u0 : __init_u0_from_Ab(A, b)

# Guard against type mismatch for user-specified reltol/abstol
reltol = real(eltype(prob.b))(reltol)
abstol = real(eltype(prob.b))(abstol)

cacheval = init_cacheval(alg, A, b, u0, Pl, Pr, maxiters, abstol, reltol, verbose,
cacheval = init_cacheval(alg, A, b, u0_, Pl, Pr, maxiters, abstol, reltol, verbose,
assumptions)
isfresh = true
Tc = typeof(cacheval)

cache = LinearCache{
typeof(A),
typeof(b),
typeof(u0),
typeof(p),
typeof(alg),
Tc,
typeof(Pl),
typeof(Pr),
typeof(reltol),
typeof(assumptions.issq),
}(A,
b,
u0,
p,
alg,
cacheval,
isfresh,
Pl,
Pr,
abstol,
reltol,
maxiters,
verbose,
assumptions)
cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc,
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq)}(A, b, u0_,
p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, maxiters, verbose, assumptions)
return cache
end

@@ -208,3 +188,33 @@
function SciMLBase.solve!(cache::LinearCache, args...; kwargs...)
solve!(cache, cache.alg, args...; kwargs...)
end

# Special Case for StaticArrays
const StaticLinearProblem = LinearProblem{uType, iip, <:SMatrix,
<:Union{<:SMatrix, <:SVector}} where {uType, iip}

function SciMLBase.solve(prob::StaticLinearProblem, args...; kwargs...)
return SciMLBase.solve(prob, nothing, args...; kwargs...)
end

function SciMLBase.solve(prob::StaticLinearProblem,
alg::Union{Nothing, SciMLLinearSolveAlgorithm}, args...; kwargs...)
if alg === nothing || alg isa DirectLdiv!
u = prob.A \ prob.b
elseif alg isa LUFactorization
u = lu(prob.A) \ prob.b
elseif alg isa QRFactorization
u = qr(prob.A) \ prob.b
elseif alg isa CholeskyFactorization
u = cholesky(prob.A) \ prob.b
elseif alg isa NormalCholeskyFactorization
u = cholesky(Symmetric(prob.A' * prob.A)) \ (prob.A' * prob.b)
elseif alg isa SVDFactorization
u = svd(prob.A) \ prob.b
else
# Slower Path but handles all cases
cache = init(prob, alg, args...; kwargs...)
return solve!(cache)
end
return SciMLBase.build_linear_solution(alg, u, nothing, prob)
end
18 changes: 9 additions & 9 deletions src/default.jl
@@ -36,6 +36,14 @@
defaultalg(A, b, OperatorAssumptions(issq, assump.condition))
end

function defaultalg(A::SMatrix{S1, S2}, b, assump::OperatorAssumptions{Bool}) where {S1, S2}
if S1 == S2
return LUFactorization()
else
return SVDFactorization() # QR(...) \ b is not defined currently
end
end

function defaultalg(A::Tridiagonal, b, assump::OperatorAssumptions{Bool})
if assump.issq
DefaultLinearSolver(DefaultAlgorithmChoice.LUFactorization)
@@ -175,10 +183,6 @@
DefaultAlgorithmChoice.LUFactorization
end

# For static arrays GMRES allocates a lot. Use factorization
elseif A isa StaticArray
DefaultAlgorithmChoice.LUFactorization

# This catches the cases where a factorization overload could exist
# For example, BlockBandedMatrix
elseif A !== nothing && ArrayInterface.isstructured(A)
@@ -190,9 +194,6 @@
end
elseif assump.condition === OperatorCondition.WellConditioned
DefaultAlgorithmChoice.NormalCholeskyFactorization
elseif A isa StaticArray
# Static Array doesn't have QR() \ b defined
DefaultAlgorithmChoice.SVDFactorization
elseif assump.condition === OperatorCondition.IllConditioned
if is_underdetermined(A)
# Underdetermined
@@ -269,8 +270,7 @@
args...;
assumptions = OperatorAssumptions(issquare(prob.A)),
kwargs...)
alg = defaultalg(prob.A, prob.b, assumptions)
SciMLBase.init(prob, alg, args...; assumptions, kwargs...)
SciMLBase.init(prob, defaultalg(prob.A, prob.b, assumptions), args...; assumptions, kwargs...)
end

function SciMLBase.solve!(cache::LinearCache, alg::Nothing,
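
A sketch of how the new SMatrix method of defaultalg (an internal helper shown in the hunk above) is expected to choose algorithms. The OperatorAssumptions constructor follows the call in SciMLBase.init above; the return values in the comments are assumptions read off this diff, not verified output.

using LinearSolve, StaticArrays

Asq   = SMatrix{5, 5}(rand(5, 5))   # square: S1 == S2
Arect = SMatrix{7, 5}(rand(7, 5))   # non-square: S1 != S2
b5 = SVector{5}(rand(5))
b7 = SVector{7}(rand(7))

# Square static matrices resolve to LU; non-square ones to SVD, since
# QR(...) \ b is not currently defined for SMatrix (see the comment above).
LinearSolve.defaultalg(Asq, b5, OperatorAssumptions(true))    # LUFactorization()
LinearSolve.defaultalg(Arect, b7, OperatorAssumptions(false)) # SVDFactorization()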
21 changes: 13 additions & 8 deletions src/factorization.jl
@@ -215,10 +215,6 @@
function init_cacheval(alg::CholeskyFactorization, A::SMatrix{S1, S2}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions) where {S1, S2}
# StaticArrays doesn't have the pivot argument. Prevent generic fallback.
# CholeskyFactorization is part of DefaultLinearSolver, so it is possible that `A` is
# not Hermitian.
(!issquare(A) || !ishermitian(A)) && return nothing
cholesky(A)
end

@@ -979,11 +975,17 @@
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A))
end

function init_cacheval(alg::NormalCholeskyFactorization, A::SMatrix, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
return cholesky(Symmetric((A)' * A))
end

function init_cacheval(alg::NormalCholeskyFactorization, A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
A_ = convert(AbstractMatrix, A)
ArrayInterface.cholesky_instance(Symmetric((A)' * A, :L), alg.pivot)
return ArrayInterface.cholesky_instance(Symmetric((A)' * A), alg.pivot)
end

function init_cacheval(alg::NormalCholeskyFactorization,
@@ -997,17 +999,20 @@
A = cache.A
A = convert(AbstractMatrix, A)
if cache.isfresh
if A isa SparseMatrixCSC || A isa GPUArraysCore.AbstractGPUArray
fact = cholesky(Symmetric((A)' * A, :L); check = false)
if A isa SparseMatrixCSC || A isa GPUArraysCore.AbstractGPUArray || A isa SMatrix
fact = cholesky(Symmetric((A)' * A); check = false)
else
fact = cholesky(Symmetric((A)' * A, :L), alg.pivot; check = false)
fact = cholesky(Symmetric((A)' * A), alg.pivot; check = false)
end
cache.cacheval = fact
cache.isfresh = false
end
if A isa SparseMatrixCSC
cache.u .= @get_cacheval(cache, :NormalCholeskyFactorization) \ (A' * cache.b)
y = cache.u
elseif A isa StaticArray
cache.u = @get_cacheval(cache, :NormalCholeskyFactorization) \ (A' * cache.b)
y = cache.u
else
y = ldiv!(cache.u, @get_cacheval(cache, :NormalCholeskyFactorization), A' * cache.b)
end
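
For reference, the normal-equations computation that the NormalCholeskyFactorization branches above perform for static arrays, written as a standalone sketch. It mirrors the cholesky(Symmetric(A' * A)) \ (A' * b) expressions in the diff; the problem data here is made up.

using StaticArrays, LinearAlgebra

A = SMatrix{7, 5}(rand(7, 5))   # overdetermined least-squares problem
b = SVector{7}(rand(7))

# Solve the normal equations (A'A) x = A'b through a Cholesky factorization;
# this is the same computation the allocation-free solve path performs.
x = cholesky(Symmetric(A' * A)) \ (A' * b)
norm(A' * (A * x - b))          # normal-equations residual, near machine epsilon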
4 changes: 3 additions & 1 deletion src/iterative_wrappers.jl
@@ -284,8 +284,10 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...)

# Copy the solution to the allocated output vector
cacheval = @get_cacheval(cache, :KrylovJL_GMRES)
if cache.u !== cacheval.x
if cache.u !== cacheval.x && ArrayInterface.can_setindex(cache.u)
cache.u .= cacheval.x
else
cache.u = convert(typeof(cache.u), cacheval.x)
end

return SciMLBase.build_linear_solution(alg, cache.u, resid, cache;
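
A small illustration of why the branch above checks ArrayInterface.can_setindex: static vectors are immutable, so the Krylov result cannot be broadcast into cache.u and has to be converted instead. The can_setindex return values noted below are assumptions about the ArrayInterface/StaticArraysCore integration, not verified here.

using StaticArrays, ArrayInterface

u = SVector{3}(0.0, 0.0, 0.0)
x = [1.0, 2.0, 3.0]

ArrayInterface.can_setindex(u)   # expected false: `u .= x` would error for an SVector
ArrayInterface.can_setindex(x)   # expected true: in-place copy is fine for a Vector

u = convert(typeof(u), x)        # rebind u to an SVector holding the solution values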
24 changes: 22 additions & 2 deletions test/static_arrays.jl
@@ -1,24 +1,44 @@
using LinearSolve, StaticArrays, LinearAlgebra
using LinearSolve, StaticArrays, LinearAlgebra, Test
using AllocCheck

A = SMatrix{5, 5}(Hermitian(rand(5, 5) + I))
b = SVector{5}(rand(5))

@check_allocs __solve_no_alloc(A, b, alg) = solve(LinearProblem(A, b), alg)

function __non_native_static_array_alg(alg)
return alg isa SVDFactorization || alg isa KrylovJL
end

for alg in (nothing, LUFactorization(), SVDFactorization(), CholeskyFactorization(),
KrylovJL_GMRES())
NormalCholeskyFactorization(), KrylovJL_GMRES())
sol = solve(LinearProblem(A, b), alg)
@inferred solve(LinearProblem(A, b), alg)
@test norm(A * sol .- b) < 1e-10

if __non_native_static_array_alg(alg)
@test_broken __solve_no_alloc(A, b, alg)
else
@test_nowarn __solve_no_alloc(A, b, alg)
end

cache = init(LinearProblem(A, b), alg)
sol = solve!(cache)
@test norm(A * sol .- b) < 1e-10
end

A = SMatrix{7, 5}(rand(7, 5))
b = SVector{7}(rand(7))

for alg in (nothing, SVDFactorization(), KrylovJL_LSMR())
@inferred solve(LinearProblem(A, b), alg)
@test_nowarn solve(LinearProblem(A, b), alg)
end

A = SMatrix{5, 7}(rand(5, 7))
b = SVector{5}(rand(5))

for alg in (nothing, SVDFactorization(), KrylovJL_LSMR())
@inferred solve(LinearProblem(A, b), alg)
@test_nowarn solve(LinearProblem(A, b), alg)
end