From 252cfe61fef04df462ca0a6b66fbcdc3299c8260 Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Thu, 6 Feb 2025 14:45:44 +0100
Subject: [PATCH] CUBLAS: fix pointer mode handling.

---
 lib/cublas/CUBLAS.jl                 |  4 +---
 lib/cublas/wrappers.jl               | 25 ++++++++++++++-----------
 test/libraries/cublas/level1.jl      |  3 ++-
 test/libraries/cublas/level3/gemm.jl |  5 +++++
 4 files changed, 22 insertions(+), 15 deletions(-)

diff --git a/lib/cublas/CUBLAS.jl b/lib/cublas/CUBLAS.jl
index b25369a830..30ef9e4df5 100644
--- a/lib/cublas/CUBLAS.jl
+++ b/lib/cublas/CUBLAS.jl
@@ -104,11 +104,9 @@ function handle()
         end
 
         cublasSetStream_v2(new_handle, cuda.stream)
+        cublasSetPointerMode_v2(new_handle, CUBLAS_POINTER_MODE_DEVICE)
         math_mode!(new_handle, cuda.math_mode)
 
-        # default to device pointers everywhere
-        cublasSetPointerMode_v2(state.handle, CUBLAS_POINTER_MODE_DEVICE)
-
         (; handle=new_handle, cuda.stream, cuda.math_mode)
     end
     state = get!(states, cuda.context) do
diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl
index 6c056049db..85299a0fd5 100644
--- a/lib/cublas/wrappers.jl
+++ b/lib/cublas/wrappers.jl
@@ -1445,13 +1445,9 @@ for (fname, fname_64, elty) in ((:cublasSgemmGroupedBatched, :cublasSgemmGrouped
         Cptrs = unsafe_batch(reduce(vcat, C))
 
         try
-            ## XXX: does not seem to support device pointers
+            ## XXX: cublasXgemmGroupedBatched does not seem to support device pointers
             cublasSetPointerMode_v2(handle(), CUBLAS_POINTER_MODE_HOST)
 
-            mode = Ref{cublasPointerMode_t}()
-            cublasGetPointerMode_v2(handle(), mode)
-            @show mode[]
-
             if CUBLAS.version() >= v"12.0"
                 $fname_64(handle(), transa, transb, m, n, k, alpha, Aptrs, lda,
                           Bptrs, ldb, beta, Cptrs, ldc, group_count, group_size)
@@ -1507,12 +1503,19 @@
         Bptrs = unsafe_batch(B)
         Cptrs = unsafe_batch(C)
 
-        if CUBLAS.version() >= v"12.0"
-            $fname_64(handle(), transa, transb, m, n, k, alpha, Aptrs, lda,
-                      Bptrs, ldb, beta, Cptrs, ldc, group_count, group_size)
-        else
-            $fname(handle(), transa, transb, m, n, k, alpha, Aptrs, lda,
-                   Bptrs, ldb, beta, Cptrs, ldc, group_count, group_size)
+        try
+            ## XXX: cublasXgemmGroupedBatched does not seem to support device pointers
+            cublasSetPointerMode_v2(handle(), CUBLAS_POINTER_MODE_HOST)
+
+            if CUBLAS.version() >= v"12.0"
+                $fname_64(handle(), transa, transb, m, n, k, alpha, Aptrs, lda,
+                          Bptrs, ldb, beta, Cptrs, ldc, group_count, group_size)
+            else
+                $fname(handle(), transa, transb, m, n, k, alpha, Aptrs, lda,
+                       Bptrs, ldb, beta, Cptrs, ldc, group_count, group_size)
+            end
+        finally
+            cublasSetPointerMode_v2(handle(), CUBLAS_POINTER_MODE_DEVICE)
         end
         unsafe_free!(Cptrs)
         unsafe_free!(Bptrs)
diff --git a/test/libraries/cublas/level1.jl b/test/libraries/cublas/level1.jl
index 0258db3b0c..1b1b978b4f 100644
--- a/test/libraries/cublas/level1.jl
+++ b/test/libraries/cublas/level1.jl
@@ -26,8 +26,9 @@ k = 13
        @test testf(*, rand(T, m)', rand(T, m))
        @test testf(norm, rand(T, m))
        @test testf(BLAS.asum, rand(T, m))
+       @test testf(axpy!, rand(), rand(T, m), rand(T, m))
 
-       #@test testf(LinearAlgebra.axpby!, rand(), rand(T, m), rand(), rand(T, m))
+       @test testf(LinearAlgebra.axpby!, rand(), rand(T, m), rand(), rand(T, m))
        if T <: Complex
            @test testf(dot, rand(T, m), rand(T, m))
            x = rand(T, m)
diff --git a/test/libraries/cublas/level3/gemm.jl b/test/libraries/cublas/level3/gemm.jl
index 9742c355b0..ab5e1c02e6 100644
--- a/test/libraries/cublas/level3/gemm.jl
+++ b/test/libraries/cublas/level3/gemm.jl
@@ -225,6 +225,7 @@ k = 13
            end
            @test_throws DimensionMismatch CUBLAS.gemm_batched('N','N',alpha,bd_A,bd_bad)
        end
+
        @testset "gemmBatchedEx!" begin
            # C = (alpha*A)*B + beta*C
            CUBLAS.gemmBatchedEx!('N','N',alpha,bd_A,bd_B,beta,bd_C)
@@ -236,6 +237,7 @@ k = 13
            end
            @test_throws DimensionMismatch CUBLAS.gemmBatchedEx!('N','N',alpha,bd_A,bd_bad,beta,bd_C)
        end
+
        nbatch = 10
        bA = rand(elty, m, k, nbatch)
        bB = rand(elty, k, n, nbatch)
@@ -256,6 +258,7 @@ k = 13
            @test bC ≈ h_C
            @test_throws DimensionMismatch CUBLAS.gemm_strided_batched!('N', 'N', alpha, bd_A, bd_B, beta, bd_bad)
        end
+
        @testset "gemmStridedBatchedEx!" begin
            CUBLAS.gemmStridedBatchedEx!('N', 'N', alpha, bd_A, bd_B, beta, bd_C)
            for i in 1:nbatch
@@ -265,6 +268,7 @@ k = 13
            @test bC ≈ h_C
            @test_throws DimensionMismatch CUBLAS.gemmStridedBatchedEx!('N', 'N', alpha, bd_A, bd_B, beta, bd_bad)
        end
+
        @testset "gemm_strided_batched" begin

            bd_C = CUBLAS.gemm_strided_batched('N', 'N', bd_A, bd_B)
@@ -402,6 +406,7 @@ k = 13
                @test C ≈ Array(dC) rtol=rtol
            end
        end
+
        # also test an unsupported combination (falling back to GPUArrays)
        if VERSION < v"1.11-" # JuliaGPU/CUDA.jl#2441
            AT=BFloat16
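
A note on the pattern above, not part of the patch itself: CUBLAS reads scalar
arguments such as alpha and beta from either host or device memory depending on
the handle's pointer mode. Since handle() now defaults every new handle to
CUBLAS_POINTER_MODE_DEVICE, wrappers for routines that only accept host scalars
(here cublasXgemmGroupedBatched) must switch to host mode and restore the
default afterwards. A minimal sketch of that save/restore scope, using the
CUBLAS.handle() accessor and the low-level cublasSetPointerMode_v2 wrapper seen
in the diff; the helper name with_host_pointer_mode is made up for illustration:

    using CUDA
    using CUDA.CUBLAS: handle, cublasSetPointerMode_v2,
                       CUBLAS_POINTER_MODE_HOST, CUBLAS_POINTER_MODE_DEVICE

    # Run `f` with the task-local CUBLAS handle in host pointer mode,
    # restoring the device-pointer default even if `f` throws.
    function with_host_pointer_mode(f)
        cublasSetPointerMode_v2(handle(), CUBLAS_POINTER_MODE_HOST)
        try
            f()
        finally
            cublasSetPointerMode_v2(handle(), CUBLAS_POINTER_MODE_DEVICE)
        end
    end

Wrapping both gemm_grouped_batched! methods in such a helper would avoid
repeating the try/finally block in wrappers.jl.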