Skip to content

Commit

Permalink
[CUSPARSE] Support CuSparseMatrixBSR in the generic mm! (#2639)
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison authored Feb 6, 2025
1 parent 031d7b9 commit 4d85f27
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 111 deletions.
45 changes: 33 additions & 12 deletions lib/cusparse/generic.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,27 @@
# generic APIs

export gather!, scatter!, axpby!, rot!
export vv!, sv!, sm!, gemv, gemm, gemm!, sddmm!
export vv!, sv!, sm!, mv!, mm!, gemv, gemm, gemm!, sddmm!
export bmm!

"""
mv!(transa::SparseChar, alpha::Number, A::CuSparseMatrix, X::CuVector, beta::Number, Y::CuVector, index::SparseChar)
Performs `Y = alpha * op(A) * X + beta * Y`, where `op` can be nothing (`transa = N`),
transpose (`transa = T`) or conjugate transpose (`transa = C`).
`X` and `Y` are dense vectors.
"""
function mv! end

"""
mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrix, B::CuMatrix, beta::Number, C::CuMatrix, index::SparseChar)
mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuMatrix, B::Union{CuSparseMatrixCSC,CuSparseMatrixCSR,CuSparseMatrixCOO}, beta::Number, C::CuMatrix, index::SparseChar)
Performs `C = alpha * op(A) * op(B) + beta * C`, where each `op` can be nothing (`transa = N` / `transb = N`),
transpose (`transa = T` / `transb = T`) or conjugate transpose (`transa = C` / `transb = C`).
"""
function mm! end

## API functions

function sparsetodense(A::Union{CuSparseMatrixCSC{T},CuSparseMatrixCSR{T},CuSparseMatrixCOO{T}}, index::SparseChar, algo::cusparseSparseToDenseAlg_t=CUSPARSE_SPARSETODENSE_ALG_DEFAULT) where {T}
Expand Down Expand Up @@ -191,9 +209,11 @@ function mv!(transa::SparseChar, alpha::Number, A::Union{CuSparseMatrixCSC{TA},C
return Y
end

function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::Union{CuSparseMatrixCSC{T},CuSparseMatrixCSR{T},CuSparseMatrixCOO{T}},
function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrix{T},
B::DenseCuMatrix{T}, beta::Number, C::DenseCuMatrix{T}, index::SparseChar, algo::cusparseSpMMAlg_t=CUSPARSE_SPMM_ALG_DEFAULT) where {T}

(A isa CuSparseMatrixBSR) && (CUSPARSE.version() < v"12.5.1") && throw(ErrorException("This operation is not supported by the current CUDA version."))

# Support `transa = 'C'` and `transb = 'C'` for real matrices
transa = T <: Real && transa == 'C' ? 'T' : transa
transb = T <: Real && transb == 'C' ? 'T' : transb
Expand Down Expand Up @@ -235,10 +255,10 @@ function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::Union{CuS
# cusparseCsrSetStridedBatch(obj, batchsize, 0, nnz(A))
# end

# Set default buffer for small matrices (10000 chosen arbitrarily)
# Set default buffer for small matrices (1000 chosen arbitrarily)
# Otherwise tries to allocate 120TB of memory (see #2296)
function bufferSize()
out = Ref{Csize_t}(10000)
out = Ref{Csize_t}(1000)
cusparseSpMM_bufferSize(
handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta),
descC, T, algo, out)
Expand Down Expand Up @@ -274,7 +294,6 @@ function bmm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparse
throw(ErrorException("Batched dense-matrix times batched sparse-matrix (bmm!) requires a CUSPARSE version ≥ 11.7.2 (yours: $(CUSPARSE.version()))."))
end


# Support `transa = 'C'` and `transb = 'C'` for real matrices
transa = T <: Real && transa == 'C' ? 'T' : transa
transb = T <: Real && transb == 'C' ? 'T' : transb
Expand Down Expand Up @@ -313,10 +332,10 @@ function bmm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparse
strideC = stride(C, 3)
cusparseDnMatSetStridedBatch(descC, b, strideC)

# Set default buffer for small matrices (10000 chosen arbitrarily)
# Set default buffer for small matrices (1000 chosen arbitrarily)
# Otherwise tries to allocate 120TB of memory (see #2296)
function bufferSize()
out = Ref{Csize_t}(10000)
out = Ref{Csize_t}(1000)
cusparseSpMM_bufferSize(
handle(), transa, transb, Ref{T}(alpha), descA, descB, Ref{T}(beta),
descC, T, algo, out)
Expand All @@ -337,10 +356,11 @@ function bmm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparse
end

function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::DenseCuMatrix{T},
B::Union{CuSparseMatrixCSC{T},CuSparseMatrixCSR{T},CuSparseMatrixCOO{T}},
beta::Number, C::DenseCuMatrix{T}, index::SparseChar, algo::cusparseSpMMAlg_t=CUSPARSE_SPMM_ALG_DEFAULT) where {T}
B::Union{CuSparseMatrixCSC{T},CuSparseMatrixCSR{T},CuSparseMatrixCOO{T}}, beta::Number,
C::DenseCuMatrix{T}, index::SparseChar, algo::cusparseSpMMAlg_t=CUSPARSE_SPMM_ALG_DEFAULT) where {T}

CUSPARSE.version() < v"11.7.4" && throw(ErrorException("This operation is not supported by the current CUDA version."))

# Support `transa = 'C'` and `transb = 'C'` for real matrices
transa = T <: Real && transa == 'C' ? 'T' : transa
transb = T <: Real && transb == 'C' ? 'T' : transb
Expand Down Expand Up @@ -373,10 +393,10 @@ function mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::DenseCuMa
descB = CuSparseMatrixDescriptor(B, index, transposed=true)
descC = CuDenseMatrixDescriptor(C, transposed=true)

# Set default buffer for small matrices (10000 chosen arbitrarily)
# Set default buffer for small matrices (1000 chosen arbitrarily)
# Otherwise tries to allocate 120TB of memory (see #2296)
function bufferSize()
out = Ref{Csize_t}(10000)
out = Ref{Csize_t}(1000)
cusparseSpMM_bufferSize(
handle(), transb, transa, Ref{T}(alpha), descB, descA, Ref{T}(beta),
descC, T, algo, out)
Expand Down Expand Up @@ -736,9 +756,10 @@ function sm!(transa::SparseChar, transb::SparseChar, uplo::SparseChar, diag::Spa
end

function sddmm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::DenseCuMatrix{T}, B::DenseCuMatrix{T},
beta::Number, C::CuSparseMatrixCSR{T}, index::SparseChar, algo::cusparseSDDMMAlg_t=CUSPARSE_SDDMM_ALG_DEFAULT) where {T}
beta::Number, C::Union{CuSparseMatrixCSR{T},CuSparseMatrixBSR{T}}, index::SparseChar, algo::cusparseSDDMMAlg_t=CUSPARSE_SDDMM_ALG_DEFAULT) where {T}

CUSPARSE.version() < v"11.4.1" && throw(ErrorException("This operation is not supported by the current CUDA version."))
(C isa CuSparseMatrixBSR) && (CUSPARSE.version() < v"12.1.0") && throw(ErrorException("This operation is not supported by the current CUDA version."))

# Support `transa = 'C'` and `transb = 'C'` for real matrices
transa = T <: Real && transa == 'C' ? 'T' : transa
Expand Down
11 changes: 1 addition & 10 deletions lib/cusparse/level2.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,7 @@
# sparse linear algebra functions that perform operations between sparse matrices and dense
# vectors

export mv!, sv2!, sv2, gemvi!

"""
mv!(transa::SparseChar, alpha::Number, A::CuSparseMatrix, X::CuVector, beta::Number, Y::CuVector, index::SparseChar)
Performs `Y = alpha * op(A) * X + beta * Y`, where `op` can be nothing (`transa = N`),
transpose (`transa = T`) or conjugate transpose (`transa = C`).
`X` and `Y` are dense vectors.
"""
mv!(transa::SparseChar, alpha::Number, A::CuSparseMatrix, X::CuVector, beta::Number, Y::CuVector, index::SparseChar)
export sv2!, sv2, gemvi!

for (fname,elty) in ((:cusparseSbsrmv, :Float32),
(:cusparseDbsrmv, :Float64),
Expand Down
55 changes: 1 addition & 54 deletions lib/cusparse/level3.jl
Original file line number Diff line number Diff line change
@@ -1,60 +1,7 @@
# sparse linear algebra functions that perform operations between sparse and (usually tall)
# dense matrices

export mm!, sm2!, sm2

"""
mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrix, B::CuMatrix, beta::Number, C::CuMatrix, index::SparseChar)
Performs `C = alpha * op(A) * op(B) + beta * C`, where `op` can be nothing (`transa = N`),
transpose (`transa = T`) or conjugate transpose (`transa = C`).
`B` and `C` are dense matrices.
"""
mm!(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrix, B::CuMatrix, beta::Number, C::CuMatrix, index::SparseChar)

# bsrmm
# Generates one `mm!` method per listed element type for a `CuSparseMatrixBSR` left
# operand and strided dense `B` / `C`; each method forwards to the cuSPARSE routine
# named by `fname` and returns `C`.
for (fname,elty) in ((:cusparseSbsrmm, :Float32),
(:cusparseDbsrmm, :Float64),
(:cusparseCbsrmm, :ComplexF32),
(:cusparseZbsrmm, :ComplexF64))
@eval begin
function mm!(transa::SparseChar,
transb::SparseChar,
alpha::Number,
A::CuSparseMatrixBSR{$elty},
B::StridedCuMatrix{$elty},
beta::Number,
C::StridedCuMatrix{$elty},
index::SparseChar)

# Support `transa = 'C'` and `transb = 'C'` for real matrices
# (for real element types 'C' is rewritten to 'T')
transa = $elty <: Real && transa == 'C' ? 'T' : transa
transb = $elty <: Real && transb == 'C' ? 'T' : transb

# NOTE(review): 'G', 'L', 'N' presumably select a general matrix, lower fill mode and
# non-unit diagonal — confirm against the CuMatrixDescriptor constructor.
desc = CuMatrixDescriptor('G', 'L', 'N', index)
m,k = size(A)
# Number of block rows / block columns of A, rounded up to whole blocks
mb = cld(m, A.blockDim)
kb = cld(k, A.blockDim)
n = size(C)[2]
# Validate the sizes of B and C; the expected sizes passed to `chkmmdims` depend on
# which of `transa` / `transb` request a transposed operand.
if transa == 'N' && transb == 'N'
chkmmdims(B,C,k,n,m,n)
elseif transa == 'N' && transb != 'N'
chkmmdims(B,C,n,k,m,n)
elseif transa != 'N' && transb == 'N'
chkmmdims(B,C,m,n,k,n)
elseif transa != 'N' && transb != 'N'
chkmmdims(B,C,n,m,k,n)
end
# Column strides of B and C, clamped to at least 1, passed as `ldb` / `ldc`
ldb = max(1,stride(B,2))
ldc = max(1,stride(C,2))
$fname(handle(), A.dir,
transa, transb, mb, n, kb, A.nnzb,
alpha, desc, nonzeros(A),A.rowPtr, A.colVal,
A.blockDim, B, ldb, beta, C, ldc)
C
end
end
end
export sm2!, sm2

"""
sm2!(transa::SparseChar, transxy::SparseChar, uplo::SparseChar, diag::SparseChar, alpha::BlasFloat, A::CuSparseMatrix, X::CuMatrix, index::SparseChar)
Expand Down
12 changes: 6 additions & 6 deletions lib/cusparse/libcusparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5415,9 +5415,9 @@ end
@gcsafe_ccall libcusparse.cusparseCreateBsr(spMatDescr::Ptr{cusparseSpMatDescr_t},
brows::Int64, bcols::Int64, bnnz::Int64,
rowBlockSize::Int64, colBlockSize::Int64,
bsrRowOffsets::Ptr{Cvoid},
bsrColInd::Ptr{Cvoid},
bsrValues::Ptr{Cvoid},
bsrRowOffsets::CuPtr{Cvoid},
bsrColInd::CuPtr{Cvoid},
bsrValues::CuPtr{Cvoid},
bsrRowOffsetsType::cusparseIndexType_t,
bsrColIndType::cusparseIndexType_t,
idxBase::cusparseIndexBase_t,
Expand All @@ -5434,9 +5434,9 @@ end
brows::Int64, bcols::Int64,
bnnz::Int64, rowBlockDim::Int64,
colBlockDim::Int64,
bsrRowOffsets::Ptr{Cvoid},
bsrColInd::Ptr{Cvoid},
bsrValues::Ptr{Cvoid},
bsrRowOffsets::CuPtr{Cvoid},
bsrColInd::CuPtr{Cvoid},
bsrValues::CuPtr{Cvoid},
bsrRowOffsetsType::cusparseIndexType_t,
bsrColIndType::cusparseIndexType_t,
idxBase::cusparseIndexBase_t,
Expand Down
10 changes: 10 additions & 0 deletions res/wrap/cusparse.toml
Original file line number Diff line number Diff line change
Expand Up @@ -990,3 +990,13 @@ needs_context = false

[api.cusparseSpMMOp.argtypes]
2 = "CuPtr{Cvoid}"

[api.cusparseCreateBsr.argtypes]
7 = "CuPtr{Cvoid}"
8 = "CuPtr{Cvoid}"
9 = "CuPtr{Cvoid}"

[api.cusparseCreateConstBsr.argtypes]
7 = "CuPtr{Cvoid}"
8 = "CuPtr{Cvoid}"
9 = "CuPtr{Cvoid}"
3 changes: 1 addition & 2 deletions test/libraries/cusparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -908,8 +908,7 @@ end
alpha = rand(elty)
beta = rand(elty)
@testset "$(typeof(d_A))" for d_A in [CuSparseMatrixCSR(A),
CuSparseMatrixCSC(A),
CuSparseMatrixBSR(A, blockdim)]
CuSparseMatrixCSC(A)]
d_B = CuArray(B)
d_C = CuArray(C)
@test_throws DimensionMismatch CUSPARSE.mm!('N','T',alpha,d_A,d_B,beta,d_C,'O')
Expand Down
Loading

0 comments on commit 4d85f27

Please sign in to comment.