Skip to content

Commit

Permalink
Create CuBLAS PointerModeGuard (pytorch#42639)
Browse files Browse the repository at this point in the history
Summary:
Adds an RAII guard for `cublasSetPointerMode()`.
Updates `dot_cuda` to use the guard, rather than exception catching.

Addresses this comment: pytorch#41377 (comment)

Pull Request resolved: pytorch#42639

Reviewed By: malfet

Differential Revision: D22969985

Pulled By: ezyang

fbshipit-source-id: b05c35d1884bb890f8767d6a4ef8b4724a329471
  • Loading branch information
kurtamohler authored and facebook-github-bot committed Aug 6, 2020
1 parent eb9ae7c commit 2360744
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 19 deletions.
19 changes: 19 additions & 0 deletions aten/src/ATen/cuda/CUDABlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,25 @@ namespace at {
namespace cuda {
namespace blas {

// RAII guard that sets the CuBLAS pointer mode and restores it to
// its previous value when the guard is destroyed
class PointerModeGuard {
public:
PointerModeGuard(cublasHandle_t handle, cublasPointerMode_t mode) :
handle(handle) {
TORCH_CUDABLAS_CHECK(cublasGetPointerMode(handle, &previous_mode));
TORCH_CUDABLAS_CHECK(cublasSetPointerMode(handle, mode));
}

~PointerModeGuard() {
cublasSetPointerMode(handle, previous_mode);
}

private:
cublasHandle_t handle;
cublasPointerMode_t previous_mode;
};

/* LEVEL 3 BLAS FUNCTIONS */

#define CUDABLAS_GEMM_ARGTYPES(Dtype) \
Expand Down
28 changes: 9 additions & 19 deletions aten/src/ATen/native/cuda/LinearAlgebra.cu
Original file line number Diff line number Diff line change
Expand Up @@ -396,25 +396,15 @@ Tensor dot_cuda(const Tensor& self, const Tensor& other) {
Tensor result = at::empty({}, self.options());

auto handle = at::cuda::getCurrentCUDABlasHandle();
cublasPointerMode_t previous_mode = CUBLAS_POINTER_MODE_DEVICE;
TORCH_CUDABLAS_CHECK(cublasGetPointerMode(handle, &previous_mode));
TORCH_CUDABLAS_CHECK(
cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE));

try {
at::cuda::blas::dot<scalar_t>(
handle,
n,
self.data_ptr<scalar_t>(),
incx,
other.data_ptr<scalar_t>(),
incy,
result.data_ptr<scalar_t>());
} catch (...) {
TORCH_CUDABLAS_CHECK(cublasSetPointerMode(handle, previous_mode));
throw;
}
TORCH_CUDABLAS_CHECK(cublasSetPointerMode(handle, previous_mode));
at::cuda::blas::PointerModeGuard pointerModeGuard(handle, CUBLAS_POINTER_MODE_DEVICE);
at::cuda::blas::dot<scalar_t>(
handle,
n,
self.data_ptr<scalar_t>(),
incx,
other.data_ptr<scalar_t>(),
incy,
result.data_ptr<scalar_t>());

return result;
});
Expand Down

0 comments on commit 2360744

Please sign in to comment.