Skip to content

Commit

Permalink
Change the definition of the methods
Browse files Browse the repository at this point in the history
  • Loading branch information
albertomercurio committed Dec 16, 2024
1 parent 4d59b39 commit b204d84
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 19 deletions.
13 changes: 13 additions & 0 deletions lib/JLArrays/src/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ SparseVector(V::JLSparseVector) = SparseVector(V.n, Vector(V.nzind), Vector(V.nz

Base.copy(V::JLSparseVector) = JLSparseVector(V.n, copy(V.nzind), copy(V.nzval))

Base.length(V::JLSparseVector) = V.n
Base.size(V::JLSparseVector) = (V.n,)

SparseArrays.nonzeros(V::JLSparseVector) = V.nzval
SparseArrays.nonzeroinds(V::JLSparseVector) = V.nzind

## SparseMatrixCSC

struct JLSparseMatrixCSC{Tv,Ti<:Integer} <: AbstractGPUSparseMatrixCSC{Tv,Ti}
Expand Down Expand Up @@ -57,3 +63,10 @@ JLSparseMatrixCSC(A::SparseMatrixCSC) = JLSparseMatrixCSC(A.m, A.n, JLVector(A.c
SparseMatrixCSC(A::JLSparseMatrixCSC) = SparseMatrixCSC(A.m, A.n, Vector(A.colptr), Vector(A.rowval), Vector(A.nzval))

Base.copy(A::JLSparseMatrixCSC) = JLSparseMatrixCSC(A.m, A.n, copy(A.colptr), copy(A.rowval), copy(A.nzval))

Base.size(A::JLSparseMatrixCSC) = (A.m, A.n)
Base.length(A::JLSparseMatrixCSC) = A.m * A.n

SparseArrays.nonzeros(A::JLSparseMatrixCSC) = A.nzval
SparseArrays.getcolptr(A::JLSparseMatrixCSC) = A.colptr
SparseArrays.rowvals(A::JLSparseMatrixCSC) = A.rowval
27 changes: 8 additions & 19 deletions src/host/sparse.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
## Sparse Vector

Base.length(V::AbstractGPUSparseVector) = V.n
Base.size(V::AbstractGPUSparseVector) = (V.n,)

SparseArrays.nonzeros(V::AbstractGPUSparseVector) = V.nzval
SparseArrays.getnzval(V::AbstractGPUSparseVector) = nonzeros(V)
SparseArrays.nnz(V::AbstractGPUSparseVector) = length(nzval(V))
SparseArrays.nonzeroinds(V::AbstractGPUSparseVector) = V.nzind

function Base.sizehint!(V::AbstractGPUSparseVector, newlen::Integer)
sizehint!(nonzeroinds(V), newlen)
Expand All @@ -29,26 +24,23 @@ LinearAlgebra.dot(x::AbstractGPUVector{T}, y::AbstractGPUSparseVector{T}) where

## General Sparse Matrix

Base.size(A::AbstractGPUSparseMatrix) = (A.m, A.n)
KernelAbstractions.get_backend(A::AbstractGPUSparseMatrix) = KernelAbstractions.get_backend(getnzval(A))

SparseArrays.nonzeros(A::AbstractGPUSparseMatrix) = A.nzval
SparseArrays.getnzval(A::AbstractGPUSparseMatrix) = nonzeros(A)
SparseArrays.nnz(A::AbstractGPUSparseMatrix) = length(nzval(A))
SparseArrays.nnz(A::AbstractGPUSparseMatrix) = length(getnzval(A))

function LinearAlgebra.rmul!(A::AbstractGPUSparseMatrix, x::Number)
rmul!(SparseArrays.getnzval(A), x)
rmul!(getnzval(A), x)
return A
end

function LinearAlgebra.lmul!(x::Number, A::AbstractGPUSparseMatrix)
lmul!(x, SparseArrays.getnzval(A))
lmul!(x, getnzval(A))
return A
end

## CSC Matrix

SparseArrays.getcolptr(A::AbstractGPUSparseMatrixCSC) = A.colptr
SparseArrays.rowvals(A::AbstractGPUSparseMatrixCSC) = A.rowval
SparseArrays.getrowval(A::AbstractGPUSparseMatrixCSC) = rowvals(A)
# SparseArrays.nzrange(A::AbstractGPUSparseMatrixCSC, col::Integer) = getcolptr(A)[col]:(getcolptr(A)[col+1]-1) # TODO: this uses scalar indexing

Expand Down Expand Up @@ -81,22 +73,19 @@ function _spmatmul!(C::AbstractGPUVecOrMat, A::AbstractGPUSparseMatrixCSC, B::Ab
size(B, 2) == size(C, 2) ||
throw(DimensionMismatch("second dimension of B, $(size(B,2)), does not match the second dimension of C, $(size(C,2))"))

A_colptr = getcolptr(A)
A_rowval = rowvals(A)
A_nzval = getnzval(A)
β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)

@kernel function kernel_spmatmul!(C, @Const(A_colptr), @Const(A_rowval), @Const(A_nzval), @Const(B))
@kernel function kernel_spmatmul!(C, @Const(A), @Const(B))
k, col = @index(Global, NTuple)

@inbounds axj = B[col, k] * α
@inbounds for j in A_colptr[col]:(A_colptr[col+1]-1) # nzrange(A, col)
KernelAbstractions.@atomic C[A_rowval[j], k] += A_nzval[j] * axj
@inbounds for j in getcolptr(A)[col]:(getcolptr(A)[col+1]-1) # nzrange(A, col)
KernelAbstractions.@atomic C[rowvals(A)[j], k] += getnzval(A)[j] * axj
end
end

backend_C = KernelAbstractions.get_backend(C)
backend_A = KernelAbstractions.get_backend(A_nzval)
backend_A = KernelAbstractions.get_backend(A)
backend_B = KernelAbstractions.get_backend(B)

backend_A == backend_B == backend_C || throw(ArgumentError("All arrays must be on the same backend"))
Expand Down

0 comments on commit b204d84

Please sign in to comment.