-
Notifications
You must be signed in to change notification settings - Fork 81
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce Abstract types for sparse arrays
- Loading branch information
1 parent
61193e1
commit 285b64c
Showing
9 changed files
with
156 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
export JLSparseMatrixCSC | ||
|
||
struct JLSparseMatrixCSC{Tv,Ti<:Integer} <: AbstractGPUSparseMatrixCSC{Tv,Ti} | ||
m::Int # Number of rows | ||
n::Int # Number of columns | ||
colptr::JLVector{Ti} # Column i is in colptr[i]:(colptr[i+1]-1) | ||
rowval::JLVector{Ti} # Row indices of stored values | ||
nzval::JLVector{Tv} # Stored values, typically nonzeros | ||
|
||
function JLSparseMatrixCSC{Tv,Ti}(m::Integer, n::Integer, colptr::JLVector{Ti}, | ||
rowval::JLVector{Ti}, nzval::JLVector{Tv}) where {Tv,Ti<:Integer} | ||
SparseArrays.sparse_check_Ti(m, n, Ti) | ||
GPUArrays._goodbuffers_csc(m, n, colptr, rowval, nzval) || | ||
throw(ArgumentError("Invalid buffers for JLSparseMatrixCSC construction n=$n, colptr=$(summary(colptr)), rowval=$(summary(rowval)), nzval=$(summary(nzval))")) | ||
new(Int(m), Int(n), colptr, rowval, nzval) | ||
end | ||
end | ||
function JLSparseMatrixCSC(m::Integer, n::Integer, colptr::JLVector, rowval::JLVector, nzval::JLVector) | ||
Tv = eltype(nzval) | ||
Ti = promote_type(eltype(colptr), eltype(rowval)) | ||
SparseArrays.sparse_check_Ti(m, n, Ti) | ||
# SparseArrays.sparse_check(n, colptr, rowval, nzval) # TODO: this uses scalar indexing | ||
# silently shorten rowval and nzval to usable index positions. | ||
maxlen = abs(widemul(m, n)) | ||
isbitstype(Ti) && (maxlen = min(maxlen, typemax(Ti) - 1)) | ||
length(rowval) > maxlen && resize!(rowval, maxlen) | ||
length(nzval) > maxlen && resize!(nzval, maxlen) | ||
JLSparseMatrixCSC{Tv,Ti}(m, n, colptr, rowval, nzval) | ||
end | ||
|
||
JLSparseMatrixCSC(A::SparseMatrixCSC) = JLSparseMatrixCSC(A.m, A.n, JLVector(A.colptr), JLVector(A.rowval), JLVector(A.nzval)) | ||
|
||
Base.copy(A::JLSparseMatrixCSC) = JLSparseMatrixCSC(A.m, A.n, copy(A.colptr), copy(A.rowval), copy(A.nzval)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
## General Sparse Matrix | ||
|
||
Base.size(A::AbstractGPUSparseMatrix) = (A.m, A.n) | ||
|
||
SparseArrays.nonzeros(A::AbstractGPUSparseMatrix) = A.nzval | ||
SparseArrays.getnzval(A::AbstractGPUSparseMatrix) = nonzeros(A) | ||
SparseArrays.nnz(A::AbstractGPUSparseMatrix) = length(nzval(A)) | ||
|
||
function LinearAlgebra.rmul!(A::AbstractGPUSparseMatrix, x::Number) | ||
rmul!(SparseArrays.getnzval(A), x) | ||
return A | ||
end | ||
|
||
function LinearAlgebra.lmul!(x::Number, A::AbstractGPUSparseMatrix) | ||
lmul!(x, SparseArrays.getnzval(A)) | ||
return A | ||
end | ||
|
||
## CSC Matrix | ||
|
||
SparseArrays.getcolptr(A::AbstractGPUSparseMatrixCSC) = A.colptr | ||
SparseArrays.rowvals(A::AbstractGPUSparseMatrixCSC) = A.rowval | ||
SparseArrays.getrowval(A::AbstractGPUSparseMatrixCSC) = rowvals(A) | ||
# SparseArrays.nzrange(A::AbstractGPUSparseMatrixCSC, col::Integer) = getcolptr(A)[col]:(getcolptr(A)[col+1]-1) # TODO: this uses scalar indexing | ||
|
||
function _goodbuffers_csc(m, n, colptr, rowval, nzval) | ||
return (length(colptr) == n + 1 && length(rowval) == length(nzval)) | ||
# TODO: also add the condition that colptr[end] - 1 == length(nzval) (allowscalar?) | ||
end | ||
|
||
@inline function LinearAlgebra.mul!(C::AbstractGPUVector, A::AnyGPUSparseMatrixCSC, B::AbstractGPUVector, α::Number, β::Number) | ||
return LinearAlgebra.generic_matvecmul!(C, LinearAlgebra.wrapper_char(A), LinearAlgebra._unwrap(A), B, LinearAlgebra.MulAddMul(α, β)) | ||
end | ||
|
||
@inline function LinearAlgebra.generic_matvecmul!(C::AbstractGPUVector, tA, A::AbstractGPUSparseMatrixCSC, B::AbstractGPUVector, _add::LinearAlgebra.MulAddMul) | ||
return SparseArrays.spdensemul!(C, tA, 'N', A, B, _add) | ||
end | ||
|
||
Base.@constprop :aggressive function SparseArrays.spdensemul!(C::AbstractGPUVecOrMat, tA, tB, A::AbstractGPUSparseMatrixCSC, B::AbstractGPUVecOrMat, _add::LinearAlgebra.MulAddMul) | ||
if tA == 'N' | ||
return _spmatmul!(C, A, wrap(B, tB), _add.alpha, _add.beta) | ||
else | ||
throw(ArgumentError("tA different from 'N' not yet supported")) | ||
end | ||
end | ||
|
||
function _spmatmul!(C::AbstractGPUVecOrMat, A::AbstractGPUSparseMatrixCSC, B::AbstractGPUVecOrMat, α::Number, β::Number) | ||
size(A, 2) == size(B, 1) || | ||
throw(DimensionMismatch("second dimension of A, $(size(A,2)), does not match the first dimension of B, $(size(B,1))")) | ||
size(A, 1) == size(C, 1) || | ||
throw(DimensionMismatch("first dimension of A, $(size(A,1)), does not match the first dimension of C, $(size(C,1))")) | ||
size(B, 2) == size(C, 2) || | ||
throw(DimensionMismatch("second dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))")) | ||
|
||
A_colptr = getcolptr(A) | ||
A_rowval = rowvals(A) | ||
A_nzval = getnzval(A) | ||
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β) | ||
|
||
@kernel function kernel_spmatmul!(C, @Const(A_colptr), @Const(A_rowval), @Const(A_nzval), @Const(B)) | ||
k, col = @index(Global, NTuple) | ||
|
||
@inbounds axj = B[col, k] * α | ||
@inbounds for j in A_colptr[col]:(A_colptr[col+1]-1) # nzrange(A, col) | ||
KernelAbstractions.@atomic C[A_rowval[j], k] += A_nzval[j] * axj | ||
end | ||
end | ||
|
||
backend_C = KernelAbstractions.get_backend(C) | ||
backend_A = KernelAbstractions.get_backend(A_nzval) | ||
backend_B = KernelAbstractions.get_backend(B) | ||
|
||
backend_A == backend_B == backend_C || throw(ArgumentError("All arrays must be on the same backend")) | ||
|
||
kernel! = kernel_spmatmul!(backend_A) | ||
kernel!(C, A_colptr, A_rowval, A_nzval, B; ndrange=(size(C, 2), size(A, 2))) | ||
|
||
return C | ||
end |