Skip to content

Commit

Permalink
Restore grouped batched gemm functionality.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Feb 6, 2025
1 parent f0be27f commit 6bb7f86
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 15 deletions.
6 changes: 3 additions & 3 deletions lib/cublas/CUBLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ function handle()
cublasSetStream_v2(new_handle, cuda.stream)
math_mode!(new_handle, cuda.math_mode)

# default to device pointers everywhere. this has to configure `new_handle`, the handle
# being set up here; `state` is only assigned further down, from the value returned below.
cublasSetPointerMode_v2(new_handle, CUBLAS_POINTER_MODE_DEVICE)

(; handle=new_handle, cuda.stream, cuda.math_mode)
end
state = get!(states, cuda.context) do
Expand All @@ -130,9 +133,6 @@ function handle()
states[cuda.context] = state = update_math_mode(cuda, state)
end

# set pointer mode to device
cublasSetPointerMode_v2(state.handle, CUBLAS_POINTER_MODE_DEVICE)

return state.handle
end

Expand Down
29 changes: 20 additions & 9 deletions lib/cublas/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1401,8 +1401,7 @@ end
end

## (GE) general matrix-matrix multiplication grouped batched
# does NOT work with device side scalar pointers
#= for (fname, fname_64, elty) in ((:cublasSgemmGroupedBatched, :cublasSgemmGroupedBatched_64, :Float32),
for (fname, fname_64, elty) in ((:cublasSgemmGroupedBatched, :cublasSgemmGroupedBatched_64, :Float32),
(:cublasDgemmGroupedBatched, :cublasDgemmGroupedBatched_64, :Float64))
@eval begin
function gemm_grouped_batched!(transA::Vector{Char},
Expand Down Expand Up @@ -1445,12 +1444,23 @@ end
Bptrs = unsafe_batch(reduce(vcat, B))
Cptrs = unsafe_batch(reduce(vcat, C))

if CUBLAS.version() >= v"12.0"
$fname_64(handle(), transa, transb, m, n, k, alpha, Aptrs, lda,
Bptrs, ldb, beta, Cptrs, ldc, group_count, group_size)
else
$fname(handle(), transa, transb, m, n, k, alpha, Aptrs, lda,
Bptrs, ldb, beta, Cptrs, ldc, group_count, group_size)
# XXX: the grouped batched gemm does not seem to support device-side scalar pointers,
#      so temporarily switch the handle to host pointer mode for this call only, and
#      always restore the device pointer mode afterwards (even if the call throws).
hdl = handle()
try
    cublasSetPointerMode_v2(hdl, CUBLAS_POINTER_MODE_HOST)

    # use the 64-bit entry point when the library version provides it
    if CUBLAS.version() >= v"12.0"
        $fname_64(hdl, transa, transb, m, n, k, alpha, Aptrs, lda,
                  Bptrs, ldb, beta, Cptrs, ldc, group_count, group_size)
    else
        $fname(hdl, transa, transb, m, n, k, alpha, Aptrs, lda,
               Bptrs, ldb, beta, Cptrs, ldc, group_count, group_size)
    end
finally
    cublasSetPointerMode_v2(hdl, CUBLAS_POINTER_MODE_DEVICE)
end
unsafe_free!(Cptrs)
unsafe_free!(Bptrs)
Expand Down Expand Up @@ -1540,7 +1550,8 @@ function gemm_grouped_batched(transA::Vector{Char}, transB::Vector{Char},
alpha = [one(T) for i = 1:length(transA)]
gemm_grouped_batched(transA, transB, alpha, A, B)
end
=#


## (GE) general matrix-matrix multiplication batched
for (fname, fname_64, elty) in ((:cublasDgemmBatched, :cublasDgemmBatched_64, :Float64),
(:cublasSgemmBatched, :cublasSgemmBatched_64, :Float32),
Expand Down
3 changes: 0 additions & 3 deletions test/libraries/cublas/level3/gemm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,6 @@ k = 13
end
end

# TODO does not work with device side pointers
#=
if CUDA.CUBLAS.version() >= v"12.4.2"
@testset "elty = $elty" for elty in [Float32, Float64]
num_groups = 10
Expand Down Expand Up @@ -372,7 +370,6 @@ k = 13
end
end
end
=#

@testset "mixed-precision matmul" begin
m,k,n = 4,4,4
Expand Down

0 comments on commit 6bb7f86

Please sign in to comment.