Skip to content

Commit

Permalink
Merge pull request #444 from avik-pal/ap/static_arrays
Browse files Browse the repository at this point in the history
Proper handling of static arrays
  • Loading branch information
ChrisRackauckas authored Dec 13, 2023
2 parents 9e22a72 + a31f99f commit caaedab
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 60 deletions.
14 changes: 8 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,13 @@ LinearSolvePardisoExt = "Pardiso"
LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools"

[compat]
AllocCheck = "0.1"
Aqua = "0.8"
ArrayInterface = "7.4.11"
BandedMatrices = "1"
BlockDiagonals = "0.1"
ConcreteStructs = "0.2"
CUDA = "5"
ConcreteStructs = "0.2"
DocStringExtensions = "0.9"
EnumX = "1"
Enzyme = "0.11"
Expand All @@ -77,15 +78,15 @@ GPUArraysCore = "0.1"
HYPRE = "1.4.0"
InteractiveUtils = "1.6"
IterativeSolvers = "0.9.3"
Libdl = "1.6"
LinearAlgebra = "1.9"
JET = "0.8"
KLU = "0.3.0, 0.4"
KernelAbstractions = "0.9"
Krylov = "0.9"
KrylovKit = "0.6"
Metal = "0.5"
Libdl = "1.6"
LinearAlgebra = "1.9"
MPI = "0.20"
Metal = "0.5"
MultiFloats = "1"
Pardiso = "0.5"
Pkg = "1"
Expand All @@ -102,13 +103,14 @@ SciMLOperators = "0.3"
Setfield = "1"
SparseArrays = "1.9"
Sparspak = "0.3.6"
StaticArraysCore = "1"
StaticArrays = "1"
StaticArraysCore = "1"
Test = "1"
UnPack = "1"
julia = "1.9"

[extras]
AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
Expand All @@ -133,4 +135,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays"]
test = ["Aqua", "Test", "IterativeSolvers", "InteractiveUtils", "JET", "KrylovKit", "Pkg", "Random", "SafeTestsets", "MultiFloats", "ForwardDiff", "HYPRE", "MPI", "BlockDiagonals", "Enzyme", "FiniteDiff", "BandedMatrices", "FastAlmostBandedMatrices", "StaticArrays", "AllocCheck"]
78 changes: 44 additions & 34 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,13 @@ default_alias_b(::Any, ::Any, ::Any) = false
default_alias_A(::AbstractKrylovSubspaceMethod, ::Any, ::Any) = true
default_alias_b(::AbstractKrylovSubspaceMethod, ::Any, ::Any) = true

# Build the default initial guess when the problem supplies no `u0`:
# an array created by `similar(b, n)` with `n = size(A, 2)` (the number of
# columns of `A`), with every entry set to `false`, i.e. zero in the array's
# element type.
function __init_u0_from_Ab(A, b)
    ncols = size(A, 2)
    return fill!(similar(b, ncols), false)
end
# Static-matrix method: returns an all-zero `SVector` whose length is the second
# size parameter (`S2`) of the `SMatrix` and whose element type is `eltype(b)`.
function __init_u0_from_Ab(::SMatrix{S1, S2}, b) where {S1, S2}
    return zeros(SVector{S2, eltype(b)})
end

function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
args...;
alias_A = default_alias_A(alg, prob.A, prob.b),
Expand All @@ -133,7 +140,7 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
kwargs...)
@unpack A, b, u0, p = prob

A = if alias_A
A = if alias_A || A isa SMatrix
A
elseif A isa Array || A isa SparseMatrixCSC
copy(A)
Expand All @@ -143,55 +150,28 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,

b = if b isa SparseArrays.AbstractSparseArray && !(A isa Diagonal)
Array(b) # the solution to a linear solve will always be dense!
elseif alias_b
elseif alias_b || b isa SVector
b
elseif b isa Array || b isa SparseMatrixCSC
copy(b)
else
deepcopy(b)
end

u0 = if u0 !== nothing
u0
else
u0 = similar(b, size(A, 2))
fill!(u0, false)
end
u0_ = u0 !== nothing ? u0 : __init_u0_from_Ab(A, b)

# Guard against type mismatch for user-specified reltol/abstol
reltol = real(eltype(prob.b))(reltol)
abstol = real(eltype(prob.b))(abstol)

cacheval = init_cacheval(alg, A, b, u0, Pl, Pr, maxiters, abstol, reltol, verbose,
cacheval = init_cacheval(alg, A, b, u0_, Pl, Pr, maxiters, abstol, reltol, verbose,
assumptions)
isfresh = true
Tc = typeof(cacheval)

cache = LinearCache{
typeof(A),
typeof(b),
typeof(u0),
typeof(p),
typeof(alg),
Tc,
typeof(Pl),
typeof(Pr),
typeof(reltol),
typeof(assumptions.issq),
}(A,
b,
u0,
p,
alg,
cacheval,
isfresh,
Pl,
Pr,
abstol,
reltol,
maxiters,
verbose,
assumptions)
cache = LinearCache{typeof(A), typeof(b), typeof(u0_), typeof(p), typeof(alg), Tc,
typeof(Pl), typeof(Pr), typeof(reltol), typeof(assumptions.issq)}(A, b, u0_,
p, alg, cacheval, isfresh, Pl, Pr, abstol, reltol, maxiters, verbose, assumptions)
return cache
end

Expand All @@ -208,3 +188,33 @@ end
# Solve with the algorithm stored in the cache: forwards to the
# `solve!(cache, alg, args...; kwargs...)` method chosen by the type of `cache.alg`.
function SciMLBase.solve!(cache::LinearCache, args...; kwargs...)
    alg = cache.alg
    return solve!(cache, alg, args...; kwargs...)
end

# Special case for StaticArrays.
# Alias matching any `LinearProblem` whose third type parameter is an `SMatrix` and
# whose fourth type parameter is an `SMatrix` or an `SVector` (presumably the types
# of `A` and `b` respectively — confirm against the `LinearProblem` definition).
# The `SciMLBase.solve` methods defined for this alias dispatch on it.
const StaticLinearProblem = LinearProblem{uType, iip, <:SMatrix,
<:Union{<:SMatrix, <:SVector}} where {uType, iip}

# Entry point for static problems without an explicit algorithm: re-dispatches
# with `nothing` in the algorithm slot, forwarding all other arguments and keywords.
SciMLBase.solve(prob::StaticLinearProblem, args...; kwargs...) =
    SciMLBase.solve(prob, nothing, args...; kwargs...)

# Direct solve for fully static problems.
#
# For `nothing`, `DirectLdiv!` and the listed factorization algorithms the solution
# is computed straight from `prob.A` and `prob.b` without building a cache; on these
# paths `args` and `kwargs` are not used. Any other algorithm goes through the
# regular `init`/`solve!` route, which receives `args` and `kwargs`.
function SciMLBase.solve(prob::StaticLinearProblem,
        alg::Union{Nothing, SciMLLinearSolveAlgorithm}, args...; kwargs...)
    A, b = prob.A, prob.b
    u = if alg === nothing || alg isa DirectLdiv!
        A \ b
    elseif alg isa LUFactorization
        lu(A) \ b
    elseif alg isa QRFactorization
        qr(A) \ b
    elseif alg isa CholeskyFactorization
        cholesky(A) \ b
    elseif alg isa NormalCholeskyFactorization
        # Normal equations: (A'A) u = A'b
        cholesky(Symmetric(A' * A)) \ (A' * b)
    elseif alg isa SVDFactorization
        svd(A) \ b
    else
        # Slower path but handles all cases
        return solve!(init(prob, alg, args...; kwargs...))
    end
    return SciMLBase.build_linear_solution(alg, u, nothing, prob)
end
18 changes: 9 additions & 9 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ function defaultalg(A, b, assump::OperatorAssumptions{Nothing})
defaultalg(A, b, OperatorAssumptions(issq, assump.condition))
end

# Default algorithm for static matrices, decided from the size parameters alone:
# LU when the two size parameters are equal (square), SVD otherwise.
function defaultalg(A::SMatrix{S1, S2}, b,
        assump::OperatorAssumptions{Bool}) where {S1, S2}
    # QR(...) \ b is not defined currently, hence SVD for the non-square case
    return S1 == S2 ? LUFactorization() : SVDFactorization()
end

function defaultalg(A::Tridiagonal, b, assump::OperatorAssumptions{Bool})
if assump.issq
DefaultLinearSolver(DefaultAlgorithmChoice.LUFactorization)
Expand Down Expand Up @@ -175,10 +183,6 @@ function defaultalg(A, b, assump::OperatorAssumptions{Bool})
DefaultAlgorithmChoice.LUFactorization
end

# For static arrays GMRES allocates a lot. Use factorization
elseif A isa StaticArray
DefaultAlgorithmChoice.LUFactorization

# This catches the cases where a factorization overload could exist
# For example, BlockBandedMatrix
elseif A !== nothing && ArrayInterface.isstructured(A)
Expand All @@ -190,9 +194,6 @@ function defaultalg(A, b, assump::OperatorAssumptions{Bool})
end
elseif assump.condition === OperatorCondition.WellConditioned
DefaultAlgorithmChoice.NormalCholeskyFactorization
elseif A isa StaticArray
# Static Array doesn't have QR() \ b defined
DefaultAlgorithmChoice.SVDFactorization
elseif assump.condition === OperatorCondition.IllConditioned
if is_underdetermined(A)
# Underdetermined
Expand Down Expand Up @@ -269,8 +270,7 @@ function SciMLBase.init(prob::LinearProblem, alg::Nothing,
args...;
assumptions = OperatorAssumptions(issquare(prob.A)),
kwargs...)
alg = defaultalg(prob.A, prob.b, assumptions)
SciMLBase.init(prob, alg, args...; assumptions, kwargs...)
SciMLBase.init(prob, defaultalg(prob.A, prob.b, assumptions), args...; assumptions, kwargs...)
end

function SciMLBase.solve!(cache::LinearCache, alg::Nothing,
Expand Down
21 changes: 13 additions & 8 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,6 @@ end
# Cache value for `CholeskyFactorization` on a static matrix.
#
# This method exists to prevent the generic fallback, because StaticArrays does not
# take the pivot argument. `CholeskyFactorization` is part of `DefaultLinearSolver`,
# so `A` may not be Hermitian: in that case (or when `A` is not square) `nothing`
# is returned instead of a factorization.
function init_cacheval(alg::CholeskyFactorization, A::SMatrix{S1, S2}, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions) where {S1, S2}
    if issquare(A) && ishermitian(A)
        return cholesky(A)
    end
    return nothing
end

Expand Down Expand Up @@ -979,11 +975,17 @@ function init_cacheval(alg::NormalCholeskyFactorization,
ArrayInterface.cholesky_instance(convert(AbstractMatrix, A))
end

# Cache value for `NormalCholeskyFactorization` on a static matrix: the Cholesky
# factorization of the normal matrix `A' * A`, wrapped in `Symmetric`.
function init_cacheval(alg::NormalCholeskyFactorization, A::SMatrix, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    normal_matrix = Symmetric(A' * A)
    return cholesky(normal_matrix)
end

# Generic cache value for `NormalCholeskyFactorization`.
#
# `A` is first converted to an `AbstractMatrix`, and the normal matrix is then
# formed from the converted value, matching what the corresponding `solve!` does
# before it factorizes. The result of `ArrayInterface.cholesky_instance` for
# `Symmetric(A_' * A_)` with `alg.pivot` is returned.
function init_cacheval(alg::NormalCholeskyFactorization, A, b, u, Pl, Pr,
        maxiters::Int, abstol, reltol, verbose::Bool,
        assumptions::OperatorAssumptions)
    A_ = convert(AbstractMatrix, A)
    # Use the converted matrix; previously `A_` was computed and then ignored, and
    # a first `cholesky_instance` result was built only to be discarded.
    return ArrayInterface.cholesky_instance(Symmetric(A_' * A_), alg.pivot)
end

function init_cacheval(alg::NormalCholeskyFactorization,
Expand All @@ -997,17 +999,20 @@ function SciMLBase.solve!(cache::LinearCache, alg::NormalCholeskyFactorization;
A = cache.A
A = convert(AbstractMatrix, A)
if cache.isfresh
if A isa SparseMatrixCSC || A isa GPUArraysCore.AbstractGPUArray
fact = cholesky(Symmetric((A)' * A, :L); check = false)
if A isa SparseMatrixCSC || A isa GPUArraysCore.AbstractGPUArray || A isa SMatrix
fact = cholesky(Symmetric((A)' * A); check = false)
else
fact = cholesky(Symmetric((A)' * A, :L), alg.pivot; check = false)
fact = cholesky(Symmetric((A)' * A), alg.pivot; check = false)
end
cache.cacheval = fact
cache.isfresh = false
end
if A isa SparseMatrixCSC
cache.u .= @get_cacheval(cache, :NormalCholeskyFactorization) \ (A' * cache.b)
y = cache.u
elseif A isa StaticArray
cache.u = @get_cacheval(cache, :NormalCholeskyFactorization) \ (A' * cache.b)
y = cache.u
else
y = ldiv!(cache.u, @get_cacheval(cache, :NormalCholeskyFactorization), A' * cache.b)
end
Expand Down
4 changes: 3 additions & 1 deletion src/iterative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,10 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...)

# Copy the solution to the allocated output vector
cacheval = @get_cacheval(cache, :KrylovJL_GMRES)
if cache.u !== cacheval.x
if cache.u !== cacheval.x && ArrayInterface.can_setindex(cache.u)
cache.u .= cacheval.x
else
cache.u = convert(typeof(cache.u), cacheval.x)
end

return SciMLBase.build_linear_solution(alg, cache.u, resid, cache;
Expand Down
24 changes: 22 additions & 2 deletions test/static_arrays.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,44 @@
using LinearSolve, StaticArrays, LinearAlgebra
using LinearSolve, StaticArrays, LinearAlgebra, Test
using AllocCheck

A = SMatrix{5, 5}(Hermitian(rand(5, 5) + I))
b = SVector{5}(rand(5))

# Solve wrapped in AllocCheck's `@check_allocs`; the loop below treats a warning-free
# call as the passing case for algorithms expected to be allocation-free.
@check_allocs function __solve_no_alloc(A, b, alg)
    return solve(LinearProblem(A, b), alg)
end

# True for `SVDFactorization` and `KrylovJL` algorithms; the test loop uses this to
# mark the allocation check for those algorithms as `@test_broken`.
__non_native_static_array_alg(alg) = alg isa Union{SVDFactorization, KrylovJL}

for alg in (nothing, LUFactorization(), SVDFactorization(), CholeskyFactorization(),
KrylovJL_GMRES())
NormalCholeskyFactorization(), KrylovJL_GMRES())
sol = solve(LinearProblem(A, b), alg)
@inferred solve(LinearProblem(A, b), alg)
@test norm(A * sol .- b) < 1e-10

if __non_native_static_array_alg(alg)
@test_broken __solve_no_alloc(A, b, alg)
else
@test_nowarn __solve_no_alloc(A, b, alg)
end

cache = init(LinearProblem(A, b), alg)
sol = solve!(cache)
@test norm(A * sol .- b) < 1e-10
end

# Rectangular system with more rows than columns (7×5).
A = SMatrix{7, 5}(rand(7, 5))
b = SVector{7}(rand(7))

# Each algorithm must infer a concrete return type and solve without warnings.
for alg in (nothing, SVDFactorization(), KrylovJL_LSMR())
@inferred solve(LinearProblem(A, b), alg)
@test_nowarn solve(LinearProblem(A, b), alg)
end

# Rectangular system with more columns than rows (5×7).
A = SMatrix{5, 7}(rand(5, 7))
b = SVector{5}(rand(5))

# Same inference and no-warning checks as for the 7×5 case.
for alg in (nothing, SVDFactorization(), KrylovJL_LSMR())
@inferred solve(LinearProblem(A, b), alg)
@test_nowarn solve(LinearProblem(A, b), alg)
end

0 comments on commit caaedab

Please sign in to comment.