From c004af31c18b215afeb45b95cd1f2d45c088ef63 Mon Sep 17 00:00:00 2001 From: KlausC Date: Thu, 2 Mar 2023 18:12:26 +0100 Subject: [PATCH] Use LinearAlgebra traited matrix operations --- .gitignore | 9 +++++ src/SparseArrays.jl | 6 ++++ src/sparseconvert.jl | 78 ++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 86 insertions(+), 7 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..cf887d59 --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +# ignore org +*.org +.vscode + +# ignore newly added matrices +Manifest.toml + +*.mem +*.cov diff --git a/src/SparseArrays.jl b/src/SparseArrays.jl index 2d2a3dce..0667adba 100644 --- a/src/SparseArrays.jl +++ b/src/SparseArrays.jl @@ -64,6 +64,11 @@ TransposeFact = isdefined(LinearAlgebra, :TransposeFactorization) ? LinearAlgebra.TransposeFactorization : Transpose +struct SparseStorage{T} <: LinearAlgebra.AbstractStorageTrait{T} + data::T + SparseStorage(x::T) where T = new{T}(x) +end + include("readonly.jl") include("abstractsparse.jl") include("sparsematrix.jl") @@ -73,6 +78,7 @@ include("higherorderfns.jl") include("linalg.jl") include("deprecated.jl") +LinearAlgebra.storage_trait(::Type{<:AbstractSparseArray}) = SparseStorage # Convert from 0-based to 1-based indices diff --git a/src/sparseconvert.jl b/src/sparseconvert.jl index e235d332..eac7dc8a 100644 --- a/src/sparseconvert.jl +++ b/src/sparseconvert.jl @@ -1,6 +1,64 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license -import LinearAlgebra: AbstractTriangular +import LinearAlgebra: AbstractTriangular, DenseStorage + +# direct to appropriate sparse versions +function (*)(ta::SparseStorage, tb::DenseStorage) + A = unwrap1a(ta.data) + B = tb.data + TS = promote_op(matprod, eltype(A), eltype(B)) + mul!(similar(B, TS, (size(A, 1), size(B, 2))), A, B, true, false) +end +function (*)(ta::DenseStorage, tb::SparseStorage) + A = ta.data + B = unwrap1a(tb.data) + TS = promote_op(matprod, eltype(A), eltype(B)) + mul!(similar(A, TS, (size(A, 1), size(B, 2))), A, B, true, false) +end +function (*)(ta::SparseStorage, tb::SparseStorage) + A = unwrap1(ta.data) + B = unwrap1(tb.data) + if A === ta.data && B === tb.data + DenseStorage(A) * DenseStorage(B) + else + A * B + end +end + +function (\)(ta::SparseStorage, B::AbstractVecOrMat) + A = unwrap1a(ta.data) + if A === ta.data + DenseStorage(A) \ B + else + A \ B + end +end + +unwrap1(A::AbstractArray) = _sparsewrap(A) +unwrap1a(A::Union{Transpose,Adjoint}) = _sparsewrap(A) +unwrap1a(A::AbstractArray) = unwrap(A) +unwrap2a(A::AbstractTriangular{<:Any,<:Union{Transpose,Adjoint}}) = _sparsewrap(A, 2) +unwrap2a(A::AbstractArray) = _sparsewrap(A) + +# as long as these are not defined in "linalg.jl" fall back to generic algo +import LinearAlgebra: _lmul!, _ldiv!, _rmul!, _rdiv! + +function lmul!(ta::SparseStorage, B::AbstractVecOrMat) + A = unwrap2a(ta.data) + _lmul!(A, B) +end +function ldiv!(ta::SparseStorage, B::AbstractVecOrMat) +A = unwrap1(ta.data) +_ldiv!(A, B) +end +function rmul!(A::AbstractMatrix, tb::SparseStorage) + B = unwrap2a(tb.data) + _rmul!(A, B) +end +function rdiv!(A::AbstractMatrix, tb::SparseStorage) +B = unwrap2a(tb.data) +_rdiv!(A, B) +end """ SparseMatrixCSCSymmHerm @@ -53,19 +111,26 @@ for wr in (Symmetric, Hermitian, end # convert parent and re-wrap in same wrapper -_sparsewrap(A::Symmetric) = Symmetric(_sparsem(parent(A)), A.uplo == 'U' ? :U : :L) -_sparsewrap(A::Hermitian) = Hermitian(_sparsem(parent(A)), A.uplo == 'U' ? :U : :L) -_sparsewrap(A::SubArray) = SubArray(_sparsem(parent(A)), A.indices) +_sparsewrap(A) = _sparsewrap(A, 1) +_sparsewrap(A, i::Int) = i <= 0 ? _sparsem(A) : _sparsewr(A, i) +_sparsewr(A::AbstractArray, _::Int) = _sparsem(A) +for ty in ( Symmetric, Hermitian) + @eval function _sparsewr(A::$ty, i::Int) + $ty(_sparsewrap(parent(A), i - 1), sym_uplo(A.uplo)) + end +end +_sparsewr(A::SubArray, i::Int) = SubArray(_sparsewrap(parent(A), i - 1), A.indices) for ty in ( LowerTriangular, UnitLowerTriangular, UpperTriangular, UnitUpperTriangular, Transpose, Adjoint) - @eval _sparsewrap(A::$ty) = $ty(_sparsem(parent(A))) + @eval _sparsewr(A::$ty, i::Int) = $ty(_sparsewrap(parent(A), i - 1)) end -function _sparsewrap(A::Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}) +function _sparsewr(A::Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}, _::Int) dropzeros!(sparse(A)) end + """ unwrap(A::AbstractMatrix) @@ -283,4 +348,3 @@ function _sparse_gen(m, n, newcolptr, newrowval, newnzval) newcolptr[1] = 1 SparseMatrixCSC(m, n, newcolptr, newrowval, newnzval) end -