Skip to content

Commit

Permalink
Support StaticArrays Properly
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 4, 2023
1 parent e712202 commit 72bc014
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 7 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LinearSolve"
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
authors = ["SciML"]
version = "2.20.0"
version = "2.20.1"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand All @@ -26,6 +26,7 @@ SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Sparspak = "e56a9233-b9d6-4f03-8d0f-1825330902ac"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
Expand All @@ -48,14 +49,14 @@ LinearSolveBandedMatricesExt = "BandedMatrices"
LinearSolveBlockDiagonalsExt = "BlockDiagonals"
LinearSolveCUDAExt = "CUDA"
LinearSolveEnzymeExt = ["Enzyme", "EnzymeCore"]
LinearSolveFastAlmostBandedMatricesExt = ["FastAlmostBandedMatrices"]
LinearSolveHYPREExt = "HYPRE"
LinearSolveIterativeSolversExt = "IterativeSolvers"
LinearSolveKernelAbstractionsExt = "KernelAbstractions"
LinearSolveKrylovKitExt = "KrylovKit"
LinearSolveMetalExt = "Metal"
LinearSolvePardisoExt = "Pardiso"
LinearSolveRecursiveArrayToolsExt = "RecursiveArrayTools"
LinearSolveFastAlmostBandedMatricesExt = ["FastAlmostBandedMatrices"]

[compat]
Aqua = "0.8"
Expand Down Expand Up @@ -102,6 +103,7 @@ Setfield = "1"
SparseArrays = "1.9"
Sparspak = "0.3.6"
Test = "1"
StaticArraysCore = "1"
UnPack = "1"
julia = "1.9"

Expand Down
2 changes: 2 additions & 0 deletions src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ PrecompileTools.@recompile_invalidations begin
using Requires
import InteractiveUtils

import StaticArraysCore: StaticArray, SVector, MVector, SMatrix, MMatrix

using LinearAlgebra: BlasInt, LU
using LinearAlgebra.LAPACK: require_one_based_indexing,
chkfinite, chkstride1,
Expand Down
7 changes: 7 additions & 0 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ function defaultalg(A, b, assump::OperatorAssumptions)
DefaultAlgorithmChoice.LUFactorization
end

# For static arrays GMRES allocates a lot. Use factorization
elseif A isa StaticArray
DefaultAlgorithmChoice.LUFactorization

Check warning on line 180 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L179-L180

Added lines #L179 - L180 were not covered by tests

# This catches the cases where a factorization overload could exist
# For example, BlockBandedMatrix
elseif A !== nothing && ArrayInterface.isstructured(A)
Expand All @@ -186,6 +190,9 @@ function defaultalg(A, b, assump::OperatorAssumptions)
end
elseif assump.condition === OperatorCondition.WellConditioned
DefaultAlgorithmChoice.NormalCholeskyFactorization
elseif A isa StaticArray
# Static Array doesn't have QR() \ b defined
return DefaultAlgorithmChoice.SVDFactorization

Check warning on line 195 in src/default.jl

View check run for this annotation

Codecov / codecov/patch

src/default.jl#L195

Added line #L195 was not covered by tests
elseif assump.condition === OperatorCondition.IllConditioned
if is_underdetermined(A)
# Underdetermined
Expand Down
28 changes: 23 additions & 5 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ end

_ldiv!(x, A, b) = ldiv!(x, A, b)

_ldiv!(x::MVector, A, b::SVector) = (x .= A \ b)
_ldiv!(::SVector, A, b::SVector) = (A \ b)

Check warning on line 14 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L13-L14

Added lines #L13 - L14 were not covered by tests

function _ldiv!(x::Vector, A::Factorization, b::Vector)
# workaround https://github.com/JuliaLang/julia/issues/43507
# Fallback if working with non-square matrices
Expand Down Expand Up @@ -88,6 +91,8 @@ function do_factorization(alg::LUFactorization, A, b, u)
if A isa AbstractSparseMatrixCSC
return lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
check = false)
elseif !ArrayInterface.can_setindex(typeof(A))
fact = lu(A, alg.pivot, check = false)

Check warning on line 95 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L95

Added line #L95 was not covered by tests
else
fact = lu!(A, alg.pivot, check = false)
end
Expand Down Expand Up @@ -172,10 +177,14 @@ end

function do_factorization(alg::QRFactorization, A, b, u)
A = convert(AbstractMatrix, A)
if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray)
fact = qr!(A, alg.pivot)
if ArrayInterface.can_setindex(typeof(A))
if alg.inplace && !(A isa SparseMatrixCSC) && !(A isa GPUArraysCore.AbstractGPUArray)
fact = qr!(A, alg.pivot)
else
fact = qr(A) # CUDA.jl does not allow other args!
end
else
fact = qr(A) # CUDA.jl does not allow other args!
fact = qr(A, alg.pivot)

Check warning on line 187 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L187

Added line #L187 was not covered by tests
end
return fact
end
Expand Down Expand Up @@ -372,11 +381,15 @@ SVDFactorization() = SVDFactorization(false, LinearAlgebra.DivideAndConquer())

function do_factorization(alg::SVDFactorization, A, b, u)
A = convert(AbstractMatrix, A)
fact = svd!(A; full = alg.full, alg = alg.alg)
if ArrayInterface.can_setindex(typeof(A))
fact = svd!(A; alg.full, alg.alg)
else
fact = svd(A; alg.full)

Check warning on line 387 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L387

Added line #L387 was not covered by tests
end
return fact
end

function init_cacheval(alg::SVDFactorization, A::Matrix, b, u, Pl, Pr,
function init_cacheval(alg::SVDFactorization, A::Union{Matrix, SMatrix}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
ArrayInterface.svd_instance(convert(AbstractMatrix, A))
Expand Down Expand Up @@ -1354,6 +1367,11 @@ function init_cacheval(::SparspakFactorization, A, b, u, Pl, Pr, maxiters::Int,
end
end

function init_cacheval(::SparspakFactorization, ::StaticArray, b, u, Pl, Pr,

Check warning on line 1370 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L1370

Added line #L1370 was not covered by tests
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
nothing

Check warning on line 1372 in src/factorization.jl

View check run for this annotation

Codecov / codecov/patch

src/factorization.jl#L1372

Added line #L1372 was not covered by tests
end

function SciMLBase.solve!(cache::LinearCache, alg::SparspakFactorization; kwargs...)
A = cache.A
if cache.isfresh
Expand Down

0 comments on commit 72bc014

Please sign in to comment.