diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h index 1e1e87352d76ce..60696e6674f816 100644 --- a/aten/src/ATen/cuda/CUDABlas.h +++ b/aten/src/ATen/cuda/CUDABlas.h @@ -19,6 +19,25 @@ namespace at { namespace cuda { namespace blas { +// RAII guard that sets the CuBLAS pointer mode and restores it to +// its previous value when the guard is destroyed +class PointerModeGuard { +public: + PointerModeGuard(cublasHandle_t handle, cublasPointerMode_t mode) : + handle(handle) { + TORCH_CUDABLAS_CHECK(cublasGetPointerMode(handle, &previous_mode)); + TORCH_CUDABLAS_CHECK(cublasSetPointerMode(handle, mode)); + } + + ~PointerModeGuard() { + cublasSetPointerMode(handle, previous_mode); + } + +private: + cublasHandle_t handle; + cublasPointerMode_t previous_mode; +}; + /* LEVEL 3 BLAS FUNCTIONS */ #define CUDABLAS_GEMM_ARGTYPES(Dtype) \ diff --git a/aten/src/ATen/native/cuda/LinearAlgebra.cu b/aten/src/ATen/native/cuda/LinearAlgebra.cu index 7d208af40b1286..439699f322ddf1 100644 --- a/aten/src/ATen/native/cuda/LinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/LinearAlgebra.cu @@ -396,25 +396,15 @@ Tensor dot_cuda(const Tensor& self, const Tensor& other) { Tensor result = at::empty({}, self.options()); auto handle = at::cuda::getCurrentCUDABlasHandle(); - cublasPointerMode_t previous_mode = CUBLAS_POINTER_MODE_DEVICE; - TORCH_CUDABLAS_CHECK(cublasGetPointerMode(handle, &previous_mode)); - TORCH_CUDABLAS_CHECK( - cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE)); - - try { - at::cuda::blas::dot( - handle, - n, - self.data_ptr(), - incx, - other.data_ptr(), - incy, - result.data_ptr()); - } catch (...) { - TORCH_CUDABLAS_CHECK(cublasSetPointerMode(handle, previous_mode)); - throw; - } - TORCH_CUDABLAS_CHECK(cublasSetPointerMode(handle, previous_mode)); + at::cuda::blas::PointerModeGuard pointerModeGuard(handle, CUBLAS_POINTER_MODE_DEVICE); + at::cuda::blas::dot( + handle, + n, + self.data_ptr(), + incx, + other.data_ptr(), + incy, + result.data_ptr()); return result; });