
Commit 252cfe6
More fixes.
maleadt committed Feb 6, 2025
1 parent 6bb7f86 commit 252cfe6
Showing 4 changed files with 22 additions and 15 deletions.
4 changes: 1 addition & 3 deletions lib/cublas/CUBLAS.jl
@@ -104,11 +104,9 @@ function handle()
         end
 
         cublasSetStream_v2(new_handle, cuda.stream)
+        cublasSetPointerMode_v2(new_handle, CUBLAS_POINTER_MODE_DEVICE)
         math_mode!(new_handle, cuda.math_mode)
 
-        # default to device pointers everywhere
-        cublasSetPointerMode_v2(state.handle, CUBLAS_POINTER_MODE_DEVICE)
-
         (; handle=new_handle, cuda.stream, cuda.math_mode)
     end
     state = get!(states, cuda.context) do
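Note on the change above: the device-pointer-mode default is now applied directly to new_handle while it is being configured, rather than via state.handle, which does not yet refer to the freshly created state at that point. A condensed, hypothetical sketch of the handle-caching pattern this function uses (names follow the diff; error handling and task-local storage omitted):

    const states = Dict{CuContext,Any}()

    function handle_sketch(cuda = CUDA.active_state())
        state = get!(states, cuda.context) do
            handle_ref = Ref{cublasHandle_t}()
            cublasCreate_v2(handle_ref)
            new_handle = handle_ref[]

            cublasSetStream_v2(new_handle, cuda.stream)
            # configure once, at creation time: device pointer mode is the default
            cublasSetPointerMode_v2(new_handle, CUBLAS_POINTER_MODE_DEVICE)
            math_mode!(new_handle, cuda.math_mode)

            (; handle=new_handle, cuda.stream, cuda.math_mode)
        end
        state.handle
    end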
25 changes: 14 additions & 11 deletions lib/cublas/wrappers.jl
@@ -1445,13 +1445,9 @@ for (fname, fname_64, elty) in ((:cublasSgemmGroupedBatched, :cublasSgemmGrouped
         Cptrs = unsafe_batch(reduce(vcat, C))
 
         try
-            ## XXX: does not seem to support device pointers
+            ## XXX: cublasXgemmGroupedBatched does not seem to support device pointers
             cublasSetPointerMode_v2(handle(), CUBLAS_POINTER_MODE_HOST)
 
-            mode = Ref{cublasPointerMode_t}()
-            cublasGetPointerMode_v2(handle(), mode)
-            @show mode[]
-
             if CUBLAS.version() >= v"12.0"
                 $fname_64(handle(), transa, transb, m, n, k, alpha, Aptrs, lda,
                           Bptrs, ldb, beta, Cptrs, ldc, group_count, group_size)
@@ -1507,12 +1503,19 @@ for (fname, fname_64, elty) in ((:cublasSgemmGroupedBatched, :cublasSgemmGrouped
         Bptrs = unsafe_batch(B)
         Cptrs = unsafe_batch(C)
 
-        if CUBLAS.version() >= v"12.0"
-            $fname_64(handle(), transa, transb, m, n, k, alpha, Aptrs, lda,
-                      Bptrs, ldb, beta, Cptrs, ldc, group_count, group_size)
-        else
-            $fname(handle(), transa, transb, m, n, k, alpha, Aptrs, lda,
-                   Bptrs, ldb, beta, Cptrs, ldc, group_count, group_size)
+        try
+            ## XXX: cublasXgemmGroupedBatched does not seem to support device pointers
+            cublasSetPointerMode_v2(handle(), CUBLAS_POINTER_MODE_HOST)
+
+            if CUBLAS.version() >= v"12.0"
+                $fname_64(handle(), transa, transb, m, n, k, alpha, Aptrs, lda,
+                          Bptrs, ldb, beta, Cptrs, ldc, group_count, group_size)
+            else
+                $fname(handle(), transa, transb, m, n, k, alpha, Aptrs, lda,
+                       Bptrs, ldb, beta, Cptrs, ldc, group_count, group_size)
+            end
+        finally
+            cublasSetPointerMode_v2(handle(), CUBLAS_POINTER_MODE_DEVICE)
         end
         unsafe_free!(Cptrs)
         unsafe_free!(Bptrs)
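Both grouped-batched methods now use the same switch-and-restore dance around the library call. If it were needed in more places, the pattern could be factored out; a hypothetical helper (with_host_pointer_mode is not part of CUBLAS.jl), shown with the first call site:

    function with_host_pointer_mode(f)
        # cublasXgemmGroupedBatched only accepts host-side alpha/beta...
        cublasSetPointerMode_v2(handle(), CUBLAS_POINTER_MODE_HOST)
        try
            f()
        finally
            # ...so restore the device-pointer default afterwards, even on error
            cublasSetPointerMode_v2(handle(), CUBLAS_POINTER_MODE_DEVICE)
        end
    end

    with_host_pointer_mode() do
        cublasSgemmGroupedBatched(handle(), transa, transb, m, n, k, alpha, Aptrs,
                                  lda, Bptrs, ldb, beta, Cptrs, ldc, group_count,
                                  group_size)
    end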
3 changes: 2 additions & 1 deletion test/libraries/cublas/level1.jl
@@ -26,8 +26,9 @@ k = 13
         @test testf(*, rand(T, m)', rand(T, m))
         @test testf(norm, rand(T, m))
         @test testf(BLAS.asum, rand(T, m))
+
         @test testf(axpy!, rand(), rand(T, m), rand(T, m))
-        #@test testf(LinearAlgebra.axpby!, rand(), rand(T, m), rand(), rand(T, m))
+        @test testf(LinearAlgebra.axpby!, rand(), rand(T, m), rand(), rand(T, m))
         if T <: Complex
             @test testf(dot, rand(T, m), rand(T, m))
             x = rand(T, m)
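The axpby! test is re-enabled above. For context, testf runs a function twice, once on the host arguments and once with the arrays uploaded to the GPU, and compares the results; roughly (a simplified sketch, not the actual helper from CUDA.jl's test suite):

    function testf_sketch(f, xs...)
        ys = deepcopy(xs)   # keep pristine copies for the GPU run
        cpu = f(xs...)      # may mutate its array arguments (e.g. axpy!)
        gpu = f(map(y -> y isa AbstractArray ? CuArray(y) : y, ys)...)
        collect(cpu) ≈ collect(gpu)
    end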
5 changes: 5 additions & 0 deletions test/libraries/cublas/level3/gemm.jl
@@ -225,6 +225,7 @@ k = 13
             end
             @test_throws DimensionMismatch CUBLAS.gemm_batched('N','N',alpha,bd_A,bd_bad)
         end
+
         @testset "gemmBatchedEx!" begin
             # C = (alpha*A)*B + beta*C
             CUBLAS.gemmBatchedEx!('N','N',alpha,bd_A,bd_B,beta,bd_C)
@@ -236,6 +237,7 @@ k = 13
             end
             @test_throws DimensionMismatch CUBLAS.gemmBatchedEx!('N','N',alpha,bd_A,bd_bad,beta,bd_C)
         end
+
         nbatch = 10
         bA = rand(elty, m, k, nbatch)
         bB = rand(elty, k, n, nbatch)
Expand All @@ -256,6 +258,7 @@ k = 13
@test bC h_C
@test_throws DimensionMismatch CUBLAS.gemm_strided_batched!('N', 'N', alpha, bd_A, bd_B, beta, bd_bad)
end

@testset "gemmStridedBatchedEx!" begin
CUBLAS.gemmStridedBatchedEx!('N', 'N', alpha, bd_A, bd_B, beta, bd_C)
for i in 1:nbatch
@@ -265,6 +268,7 @@ k = 13
             @test bC ≈ h_C
             @test_throws DimensionMismatch CUBLAS.gemmStridedBatchedEx!('N', 'N', alpha, bd_A, bd_B, beta, bd_bad)
         end
+
         @testset "gemm_strided_batched" begin
             bd_C = CUBLAS.gemm_strided_batched('N', 'N', bd_A, bd_B)
 
@@ -402,6 +406,7 @@ k = 13
             @test C ≈ Array(dC) rtol=rtol
         end
     end
+
     # also test an unsupported combination (falling back to GPUArrays)
     if VERSION < v"1.11-" # JuliaGPU/CUDA.jl#2441
         AT=BFloat16
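For reference, the strided-batched entry points exercised above treat a 3-d array as a batch of matrices and multiply the slices pairwise. A usage sketch (assuming a functional GPU; sizes are arbitrary):

    using CUDA, CUDA.CUBLAS

    m, k, n, nbatch = 5, 7, 3, 10
    bd_A = CuArray(rand(Float32, m, k, nbatch))
    bd_B = CuArray(rand(Float32, k, n, nbatch))

    # C[:,:,i] = A[:,:,i] * B[:,:,i] for every batch index i, in one call
    bd_C = CUBLAS.gemm_strided_batched('N', 'N', bd_A, bd_B)

    hA, hB, hC = Array(bd_A), Array(bd_B), Array(bd_C)
    @assert all(hC[:,:,i] ≈ hA[:,:,i] * hB[:,:,i] for i in 1:nbatch)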
