Skip to content
This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Commit

Permalink
Merge pull request #32 from JuliaGPU/cudart
Browse files Browse the repository at this point in the history
Fix versioned codeblocks
  • Loading branch information
MikeInnes authored Aug 14, 2017
2 parents 07b2d6d + c829309 commit 2e0fa4c
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 131 deletions.
1 change: 1 addition & 0 deletions src/CUBLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ module CUBLAS

importall Base.LinAlg.BLAS

using CUDAdrv
using CUDAdrv: OwnedPtr, CuArray, CuVector, CuMatrix

CuVecOrMat{T} = Union{CuVector{T},CuMatrix{T}}
Expand Down
189 changes: 92 additions & 97 deletions src/libcublas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -675,102 +675,97 @@ function cublasZtrttp(handle, uplo, n, A, lda, AP)
statuscheck(ccall( (:cublasZtrttp, libcublas), cublasStatus_t, (cublasHandle_t, cublasFillMode_t, Cint, Ptr{cuDoubleComplex}, Cint, Ptr{cuDoubleComplex}), handle, uplo, n, A, lda, AP))
end

try
if (CUDArt.runtime_version() >= 7500) #these functions were introduced with CUDA v7.5
# Wrap the extended ("Ex") variants of functions (e.g. Nrm2Ex, GemmEx, etc.) (CUDA 7.5+)
# Invoke `cublasNrm2Ex` from libcublas.
# Arguments are forwarded to the C call in the order given. `x` and `result`
# travel as Ptr{Void}, each accompanied by a cudaDataType_t tag, and
# `executionType` is a cudaDataType_t as well. The returned cublasStatus_t is
# handed to `statuscheck`, whose result is returned.
function cublasNrm2Ex(handle, n, x, xType, incx, result, resultType, executionType)
    status = ccall((:cublasNrm2Ex, libcublas), cublasStatus_t,
                   (cublasHandle_t, Cint,
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Void}, cudaDataType_t,
                    cudaDataType_t),
                   handle, n,
                   x, xType, incx,
                   result, resultType,
                   executionType)
    return statuscheck(status)
end
# Invoke `cublasDotEx` from libcublas.
# `x`, `y` and `result` are passed as Ptr{Void}, each accompanied by a
# cudaDataType_t tag; `incx`/`incy` are Cint. The returned cublasStatus_t is
# handed to `statuscheck`, whose result is returned.
function cublasDotEx(handle, n, x, xType, incx, y, yType, incy, result, resultType, executionType)
    status = ccall((:cublasDotEx, libcublas), cublasStatus_t,
                   (cublasHandle_t, Cint,
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Void}, cudaDataType_t,
                    cudaDataType_t),
                   handle, n,
                   x, xType, incx,
                   y, yType, incy,
                   result, resultType,
                   executionType)
    return statuscheck(status)
end
# Invoke `cublasDotcEx` from libcublas.
# Same argument layout as `cublasDotEx`: `x`, `y` and `result` are Ptr{Void},
# each accompanied by a cudaDataType_t tag. The returned cublasStatus_t is
# handed to `statuscheck`, whose result is returned.
function cublasDotcEx(handle, n, x, xType, incx, y, yType, incy, result, resultType, executionType)
    status = ccall((:cublasDotcEx, libcublas), cublasStatus_t,
                   (cublasHandle_t, Cint,
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Void}, cudaDataType_t,
                    cudaDataType_t),
                   handle, n,
                   x, xType, incx,
                   y, yType, incy,
                   result, resultType,
                   executionType)
    return statuscheck(status)
end
# Invoke `cublasScalEx` from libcublas.
# `alpha` and `x` are passed as Ptr{Void}, each accompanied by a
# cudaDataType_t tag; `executionType` is a cudaDataType_t too. The returned
# cublasStatus_t is handed to `statuscheck`, whose result is returned.
function cublasScalEx(handle, n, alpha, alphaType, x, xType, incx, executionType)
    status = ccall((:cublasScalEx, libcublas), cublasStatus_t,
                   (cublasHandle_t, Cint,
                    Ptr{Void}, cudaDataType_t,
                    Ptr{Void}, cudaDataType_t, Cint,
                    cudaDataType_t),
                   handle, n,
                   alpha, alphaType,
                   x, xType, incx,
                   executionType)
    return statuscheck(status)
end
# Invoke `cublasAxpyEx` from libcublas.
# `alpha`, `x` and `y` are passed as Ptr{Void}, each accompanied by a
# cudaDataType_t tag; `incx`/`incy` are Cint. The returned cublasStatus_t is
# handed to `statuscheck`, whose result is returned.
function cublasAxpyEx(handle, n, alpha, alphaType, x, xType, incx, y, yType, incy, executionType)
    status = ccall((:cublasAxpyEx, libcublas), cublasStatus_t,
                   (cublasHandle_t, Cint,
                    Ptr{Void}, cudaDataType_t,
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Void}, cudaDataType_t, Cint,
                    cudaDataType_t),
                   handle, n,
                   alpha, alphaType,
                   x, xType, incx,
                   y, yType, incy,
                   executionType)
    return statuscheck(status)
end
# Invoke `cublasCgemm3mEx` from libcublas.
# `alpha` and `beta` are passed as Ptr{cuComplex}; `A`, `B` and `C` are
# Ptr{Void}, each followed by a cudaDataType_t tag and a Cint leading
# dimension. The returned cublasStatus_t is handed to `statuscheck`, whose
# result is returned.
function cublasCgemm3mEx(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc)
    status = ccall((:cublasCgemm3mEx, libcublas), cublasStatus_t,
                   (cublasHandle_t, cublasOperation_t, cublasOperation_t,
                    Cint, Cint, Cint,
                    Ptr{cuComplex},
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{cuComplex},
                    Ptr{Void}, cudaDataType_t, Cint),
                   handle, transa, transb,
                   m, n, k,
                   alpha,
                   A, Atype, lda,
                   B, Btype, ldb,
                   beta,
                   C, Ctype, ldc)
    return statuscheck(status)
end
# Invoke `cublasSgemmEx` from libcublas.
# `alpha` and `beta` are passed as Ptr{Cfloat}; `A`, `B` and `C` are
# Ptr{Void}, each followed by a cudaDataType_t tag and a Cint leading
# dimension. The returned cublasStatus_t is handed to `statuscheck`, whose
# result is returned.
function cublasSgemmEx(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc)
    status = ccall((:cublasSgemmEx, libcublas), cublasStatus_t,
                   (cublasHandle_t, cublasOperation_t, cublasOperation_t,
                    Cint, Cint, Cint,
                    Ptr{Cfloat},
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Cfloat},
                    Ptr{Void}, cudaDataType_t, Cint),
                   handle, transa, transb,
                   m, n, k,
                   alpha,
                   A, Atype, lda,
                   B, Btype, ldb,
                   beta,
                   C, Ctype, ldc)
    return statuscheck(status)
end
# Invoke `cublasGemmEx` from libcublas.
# `alpha`, `beta`, `A`, `B` and `C` are all passed as untyped Ptr{Void};
# `A`, `B` and `C` are each followed by a cudaDataType_t tag and a Cint
# leading dimension. `computeType` is a cudaDataType_t and `algo` a
# cublasGemmAlgo_t. The returned cublasStatus_t is handed to `statuscheck`,
# whose result is returned.
#
# Note: the pointer type for `beta` must be spelled `Ptr{Void}` (Julia's
# `Void`); a lowercase `void` is not a defined name and would raise an
# UndefVarError instead of performing the call.
function cublasGemmEx(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo)
    status = ccall((:cublasGemmEx, libcublas), cublasStatus_t,
                   (cublasHandle_t, cublasOperation_t, cublasOperation_t,
                    Cint, Cint, Cint,
                    Ptr{Void},
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Void},
                    Ptr{Void}, cudaDataType_t, Cint,
                    cudaDataType_t, cublasGemmAlgo_t),
                   handle, transa, transb,
                   m, n, k,
                   alpha,
                   A, Atype, lda,
                   B, Btype, ldb,
                   beta,
                   C, Ctype, ldc,
                   computeType, algo)
    return statuscheck(status)
end
# Invoke `cublasCgemmEx` from libcublas.
# `alpha` and `beta` are passed as Ptr{cuComplex}; `A`, `B` and `C` are
# Ptr{Void}, each followed by a cudaDataType_t tag and a Cint leading
# dimension. The returned cublasStatus_t is handed to `statuscheck`, whose
# result is returned.
function cublasCgemmEx(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc)
    status = ccall((:cublasCgemmEx, libcublas), cublasStatus_t,
                   (cublasHandle_t, cublasOperation_t, cublasOperation_t,
                    Cint, Cint, Cint,
                    Ptr{cuComplex},
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{cuComplex},
                    Ptr{Void}, cudaDataType_t, Cint),
                   handle, transa, transb,
                   m, n, k,
                   alpha,
                   A, Atype, lda,
                   B, Btype, ldb,
                   beta,
                   C, Ctype, ldc)
    return statuscheck(status)
end
# Invoke `cublasCsyrkEx` from libcublas.
# `alpha` and `beta` are passed as Ptr{cuComplex}; `A` and `C` are Ptr{Void},
# each followed by a cudaDataType_t tag and a Cint leading dimension. The
# returned cublasStatus_t is handed to `statuscheck`, whose result is returned.
function cublasCsyrkEx(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, Ctype, ldc)
    status = ccall((:cublasCsyrkEx, libcublas), cublasStatus_t,
                   (cublasHandle_t, cublasFillMode_t, cublasOperation_t,
                    Cint, Cint,
                    Ptr{cuComplex},
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{cuComplex},
                    Ptr{Void}, cudaDataType_t, Cint),
                   handle, uplo, trans,
                   n, k,
                   alpha,
                   A, Atype, lda,
                   beta,
                   C, Ctype, ldc)
    return statuscheck(status)
end
# Invoke `cublasCsyrk3mEx` from libcublas.
# `alpha` and `beta` are passed as Ptr{cuComplex}; `A` and `C` are Ptr{Void},
# each followed by a cudaDataType_t tag and a Cint leading dimension. The
# returned cublasStatus_t is handed to `statuscheck`, whose result is returned.
function cublasCsyrk3mEx(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, Ctype, ldc)
    status = ccall((:cublasCsyrk3mEx, libcublas), cublasStatus_t,
                   (cublasHandle_t, cublasFillMode_t, cublasOperation_t,
                    Cint, Cint,
                    Ptr{cuComplex},
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{cuComplex},
                    Ptr{Void}, cudaDataType_t, Cint),
                   handle, uplo, trans,
                   n, k,
                   alpha,
                   A, Atype, lda,
                   beta,
                   C, Ctype, ldc)
    return statuscheck(status)
end
# Invoke `cublasCherkEx` from libcublas.
# `alpha` and `beta` are passed as Ptr{Cfloat}; `A` and `C` are Ptr{Void},
# each followed by a cudaDataType_t tag and a Cint leading dimension. The
# returned cublasStatus_t is handed to `statuscheck`, whose result is returned.
function cublasCherkEx(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, Ctype, ldc)
    status = ccall((:cublasCherkEx, libcublas), cublasStatus_t,
                   (cublasHandle_t, cublasFillMode_t, cublasOperation_t,
                    Cint, Cint,
                    Ptr{Cfloat},
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Cfloat},
                    Ptr{Void}, cudaDataType_t, Cint),
                   handle, uplo, trans,
                   n, k,
                   alpha,
                   A, Atype, lda,
                   beta,
                   C, Ctype, ldc)
    return statuscheck(status)
end
# Invoke `cublasCherk3mEx` from libcublas.
# `alpha` and `beta` are passed as Ptr{Cfloat}; `A` and `C` are Ptr{Void},
# each followed by a cudaDataType_t tag and a Cint leading dimension. The
# returned cublasStatus_t is handed to `statuscheck`, whose result is returned.
function cublasCherk3mEx(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, Ctype, ldc)
    status = ccall((:cublasCherk3mEx, libcublas), cublasStatus_t,
                   (cublasHandle_t, cublasFillMode_t, cublasOperation_t,
                    Cint, Cint,
                    Ptr{Cfloat},
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Cfloat},
                    Ptr{Void}, cudaDataType_t, Cint),
                   handle, uplo, trans,
                   n, k,
                   alpha,
                   A, Atype, lda,
                   beta,
                   C, Ctype, ldc)
    return statuscheck(status)
end
# Wrap FP16 functions (CUDA 7.5+)
# Invoke `cublasHgemm` from libcublas.
# `alpha`, `beta`, `A`, `B` and `C` are all passed as Ptr{__half}; the leading
# dimensions are Cint. The returned cublasStatus_t is handed to `statuscheck`,
# whose result is returned.
function cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc)
    status = ccall((:cublasHgemm, libcublas), cublasStatus_t,
                   (cublasHandle_t, cublasOperation_t, cublasOperation_t,
                    Cint, Cint, Cint,
                    Ptr{__half},
                    Ptr{__half}, Cint,
                    Ptr{__half}, Cint,
                    Ptr{__half},
                    Ptr{__half}, Cint),
                   handle, transa, transb,
                   m, n, k,
                   alpha,
                   A, lda,
                   B, ldb,
                   beta,
                   C, ldc)
    return statuscheck(status)
end
# Invoke `cublasHgemmStridedBatched` from libcublas.
# `alpha`, `beta`, `A`, `B` and `C` are all passed as Ptr{__half}; each matrix
# is followed by a Cint leading dimension and a Clonglong stride, and
# `batchCount` is a Cint. The returned cublasStatus_t is handed to
# `statuscheck`, whose result is returned.
function cublasHgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount)
    status = ccall((:cublasHgemmStridedBatched, libcublas), cublasStatus_t,
                   (cublasHandle_t, cublasOperation_t, cublasOperation_t,
                    Cint, Cint, Cint,
                    Ptr{__half},
                    Ptr{__half}, Cint, Clonglong,
                    Ptr{__half}, Cint, Clonglong,
                    Ptr{__half},
                    Ptr{__half}, Cint, Clonglong,
                    Cint),
                   handle, transa, transb,
                   m, n, k,
                   alpha,
                   A, lda, strideA,
                   B, ldb, strideB,
                   beta,
                   C, ldc, strideC,
                   batchCount)
    return statuscheck(status)
end
if CUDAdrv.version() >= v"7.5.0" #these functions were introduced with CUDA v7.5
# Wrap the extended ("Ex") variants of functions (e.g. Nrm2Ex, GemmEx, etc.) (CUDA 7.5+)
# Invoke `cublasNrm2Ex` from libcublas.
# Arguments are forwarded to the C call in the order given. `x` and `result`
# travel as Ptr{Void}, each accompanied by a cudaDataType_t tag, and
# `executionType` is a cudaDataType_t as well. The returned cublasStatus_t is
# handed to `statuscheck`, whose result is returned.
function cublasNrm2Ex(handle, n, x, xType, incx, result, resultType, executionType)
    status = ccall((:cublasNrm2Ex, libcublas), cublasStatus_t,
                   (cublasHandle_t, Cint,
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Void}, cudaDataType_t,
                    cudaDataType_t),
                   handle, n,
                   x, xType, incx,
                   result, resultType,
                   executionType)
    return statuscheck(status)
end
# Invoke `cublasDotEx` from libcublas.
# `x`, `y` and `result` are passed as Ptr{Void}, each accompanied by a
# cudaDataType_t tag; `incx`/`incy` are Cint. The returned cublasStatus_t is
# handed to `statuscheck`, whose result is returned.
function cublasDotEx(handle, n, x, xType, incx, y, yType, incy, result, resultType, executionType)
    status = ccall((:cublasDotEx, libcublas), cublasStatus_t,
                   (cublasHandle_t, Cint,
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Void}, cudaDataType_t,
                    cudaDataType_t),
                   handle, n,
                   x, xType, incx,
                   y, yType, incy,
                   result, resultType,
                   executionType)
    return statuscheck(status)
end
# Invoke `cublasDotcEx` from libcublas.
# Same argument layout as `cublasDotEx`: `x`, `y` and `result` are Ptr{Void},
# each accompanied by a cudaDataType_t tag. The returned cublasStatus_t is
# handed to `statuscheck`, whose result is returned.
function cublasDotcEx(handle, n, x, xType, incx, y, yType, incy, result, resultType, executionType)
    status = ccall((:cublasDotcEx, libcublas), cublasStatus_t,
                   (cublasHandle_t, Cint,
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Void}, cudaDataType_t,
                    cudaDataType_t),
                   handle, n,
                   x, xType, incx,
                   y, yType, incy,
                   result, resultType,
                   executionType)
    return statuscheck(status)
end
# Invoke `cublasScalEx` from libcublas.
# `alpha` and `x` are passed as Ptr{Void}, each accompanied by a
# cudaDataType_t tag; `executionType` is a cudaDataType_t too. The returned
# cublasStatus_t is handed to `statuscheck`, whose result is returned.
function cublasScalEx(handle, n, alpha, alphaType, x, xType, incx, executionType)
    status = ccall((:cublasScalEx, libcublas), cublasStatus_t,
                   (cublasHandle_t, Cint,
                    Ptr{Void}, cudaDataType_t,
                    Ptr{Void}, cudaDataType_t, Cint,
                    cudaDataType_t),
                   handle, n,
                   alpha, alphaType,
                   x, xType, incx,
                   executionType)
    return statuscheck(status)
end
# Invoke `cublasAxpyEx` from libcublas.
# `alpha`, `x` and `y` are passed as Ptr{Void}, each accompanied by a
# cudaDataType_t tag; `incx`/`incy` are Cint. The returned cublasStatus_t is
# handed to `statuscheck`, whose result is returned.
function cublasAxpyEx(handle, n, alpha, alphaType, x, xType, incx, y, yType, incy, executionType)
    status = ccall((:cublasAxpyEx, libcublas), cublasStatus_t,
                   (cublasHandle_t, Cint,
                    Ptr{Void}, cudaDataType_t,
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Void}, cudaDataType_t, Cint,
                    cudaDataType_t),
                   handle, n,
                   alpha, alphaType,
                   x, xType, incx,
                   y, yType, incy,
                   executionType)
    return statuscheck(status)
end
# Invoke `cublasCgemm3mEx` from libcublas.
# `alpha` and `beta` are passed as Ptr{cuComplex}; `A`, `B` and `C` are
# Ptr{Void}, each followed by a cudaDataType_t tag and a Cint leading
# dimension. The returned cublasStatus_t is handed to `statuscheck`, whose
# result is returned.
function cublasCgemm3mEx(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc)
    status = ccall((:cublasCgemm3mEx, libcublas), cublasStatus_t,
                   (cublasHandle_t, cublasOperation_t, cublasOperation_t,
                    Cint, Cint, Cint,
                    Ptr{cuComplex},
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{cuComplex},
                    Ptr{Void}, cudaDataType_t, Cint),
                   handle, transa, transb,
                   m, n, k,
                   alpha,
                   A, Atype, lda,
                   B, Btype, ldb,
                   beta,
                   C, Ctype, ldc)
    return statuscheck(status)
end
# Invoke `cublasSgemmEx` from libcublas.
# `alpha` and `beta` are passed as Ptr{Cfloat}; `A`, `B` and `C` are
# Ptr{Void}, each followed by a cudaDataType_t tag and a Cint leading
# dimension. The returned cublasStatus_t is handed to `statuscheck`, whose
# result is returned.
function cublasSgemmEx(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc)
    status = ccall((:cublasSgemmEx, libcublas), cublasStatus_t,
                   (cublasHandle_t, cublasOperation_t, cublasOperation_t,
                    Cint, Cint, Cint,
                    Ptr{Cfloat},
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Cfloat},
                    Ptr{Void}, cudaDataType_t, Cint),
                   handle, transa, transb,
                   m, n, k,
                   alpha,
                   A, Atype, lda,
                   B, Btype, ldb,
                   beta,
                   C, Ctype, ldc)
    return statuscheck(status)
end
# Invoke `cublasGemmEx` from libcublas.
# `alpha`, `beta`, `A`, `B` and `C` are all passed as untyped Ptr{Void};
# `A`, `B` and `C` are each followed by a cudaDataType_t tag and a Cint
# leading dimension. `computeType` is a cudaDataType_t and `algo` a
# cublasGemmAlgo_t. The returned cublasStatus_t is handed to `statuscheck`,
# whose result is returned.
function cublasGemmEx(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc, computeType, algo)
    status = ccall((:cublasGemmEx, libcublas), cublasStatus_t,
                   (cublasHandle_t, cublasOperation_t, cublasOperation_t,
                    Cint, Cint, Cint,
                    Ptr{Void},
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Void},
                    Ptr{Void}, cudaDataType_t, Cint,
                    cudaDataType_t, cublasGemmAlgo_t),
                   handle, transa, transb,
                   m, n, k,
                   alpha,
                   A, Atype, lda,
                   B, Btype, ldb,
                   beta,
                   C, Ctype, ldc,
                   computeType, algo)
    return statuscheck(status)
end
# Invoke `cublasCgemmEx` from libcublas.
# `alpha` and `beta` are passed as Ptr{cuComplex}; `A`, `B` and `C` are
# Ptr{Void}, each followed by a cudaDataType_t tag and a Cint leading
# dimension. The returned cublasStatus_t is handed to `statuscheck`, whose
# result is returned.
function cublasCgemmEx(handle, transa, transb, m, n, k, alpha, A, Atype, lda, B, Btype, ldb, beta, C, Ctype, ldc)
    status = ccall((:cublasCgemmEx, libcublas), cublasStatus_t,
                   (cublasHandle_t, cublasOperation_t, cublasOperation_t,
                    Cint, Cint, Cint,
                    Ptr{cuComplex},
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{cuComplex},
                    Ptr{Void}, cudaDataType_t, Cint),
                   handle, transa, transb,
                   m, n, k,
                   alpha,
                   A, Atype, lda,
                   B, Btype, ldb,
                   beta,
                   C, Ctype, ldc)
    return statuscheck(status)
end
# Invoke `cublasCsyrkEx` from libcublas.
# `alpha` and `beta` are passed as Ptr{cuComplex}; `A` and `C` are Ptr{Void},
# each followed by a cudaDataType_t tag and a Cint leading dimension. The
# returned cublasStatus_t is handed to `statuscheck`, whose result is returned.
function cublasCsyrkEx(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, Ctype, ldc)
    status = ccall((:cublasCsyrkEx, libcublas), cublasStatus_t,
                   (cublasHandle_t, cublasFillMode_t, cublasOperation_t,
                    Cint, Cint,
                    Ptr{cuComplex},
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{cuComplex},
                    Ptr{Void}, cudaDataType_t, Cint),
                   handle, uplo, trans,
                   n, k,
                   alpha,
                   A, Atype, lda,
                   beta,
                   C, Ctype, ldc)
    return statuscheck(status)
end
# Invoke `cublasCsyrk3mEx` from libcublas.
# `alpha` and `beta` are passed as Ptr{cuComplex}; `A` and `C` are Ptr{Void},
# each followed by a cudaDataType_t tag and a Cint leading dimension. The
# returned cublasStatus_t is handed to `statuscheck`, whose result is returned.
function cublasCsyrk3mEx(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, Ctype, ldc)
    status = ccall((:cublasCsyrk3mEx, libcublas), cublasStatus_t,
                   (cublasHandle_t, cublasFillMode_t, cublasOperation_t,
                    Cint, Cint,
                    Ptr{cuComplex},
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{cuComplex},
                    Ptr{Void}, cudaDataType_t, Cint),
                   handle, uplo, trans,
                   n, k,
                   alpha,
                   A, Atype, lda,
                   beta,
                   C, Ctype, ldc)
    return statuscheck(status)
end
# Invoke `cublasCherkEx` from libcublas.
# `alpha` and `beta` are passed as Ptr{Cfloat}; `A` and `C` are Ptr{Void},
# each followed by a cudaDataType_t tag and a Cint leading dimension. The
# returned cublasStatus_t is handed to `statuscheck`, whose result is returned.
function cublasCherkEx(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, Ctype, ldc)
    status = ccall((:cublasCherkEx, libcublas), cublasStatus_t,
                   (cublasHandle_t, cublasFillMode_t, cublasOperation_t,
                    Cint, Cint,
                    Ptr{Cfloat},
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Cfloat},
                    Ptr{Void}, cudaDataType_t, Cint),
                   handle, uplo, trans,
                   n, k,
                   alpha,
                   A, Atype, lda,
                   beta,
                   C, Ctype, ldc)
    return statuscheck(status)
end
# Invoke `cublasCherk3mEx` from libcublas.
# `alpha` and `beta` are passed as Ptr{Cfloat}; `A` and `C` are Ptr{Void},
# each followed by a cudaDataType_t tag and a Cint leading dimension. The
# returned cublasStatus_t is handed to `statuscheck`, whose result is returned.
function cublasCherk3mEx(handle, uplo, trans, n, k, alpha, A, Atype, lda, beta, C, Ctype, ldc)
    status = ccall((:cublasCherk3mEx, libcublas), cublasStatus_t,
                   (cublasHandle_t, cublasFillMode_t, cublasOperation_t,
                    Cint, Cint,
                    Ptr{Cfloat},
                    Ptr{Void}, cudaDataType_t, Cint,
                    Ptr{Cfloat},
                    Ptr{Void}, cudaDataType_t, Cint),
                   handle, uplo, trans,
                   n, k,
                   alpha,
                   A, Atype, lda,
                   beta,
                   C, Ctype, ldc)
    return statuscheck(status)
end
# Wrap FP16 functions (CUDA 7.5+)
# Invoke `cublasHgemm` from libcublas.
# `alpha`, `beta`, `A`, `B` and `C` are all passed as Ptr{__half}; the leading
# dimensions are Cint. The returned cublasStatus_t is handed to `statuscheck`,
# whose result is returned.
function cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc)
    status = ccall((:cublasHgemm, libcublas), cublasStatus_t,
                   (cublasHandle_t, cublasOperation_t, cublasOperation_t,
                    Cint, Cint, Cint,
                    Ptr{__half},
                    Ptr{__half}, Cint,
                    Ptr{__half}, Cint,
                    Ptr{__half},
                    Ptr{__half}, Cint),
                   handle, transa, transb,
                   m, n, k,
                   alpha,
                   A, lda,
                   B, ldb,
                   beta,
                   C, ldc)
    return statuscheck(status)
end
# Invoke `cublasHgemmStridedBatched` from libcublas.
# `alpha`, `beta`, `A`, `B` and `C` are all passed as Ptr{__half}; each matrix
# is followed by a Cint leading dimension and a Clonglong stride, and
# `batchCount` is a Cint. The returned cublasStatus_t is handed to
# `statuscheck`, whose result is returned.
function cublasHgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount)
    status = ccall((:cublasHgemmStridedBatched, libcublas), cublasStatus_t,
                   (cublasHandle_t, cublasOperation_t, cublasOperation_t,
                    Cint, Cint, Cint,
                    Ptr{__half},
                    Ptr{__half}, Cint, Clonglong,
                    Ptr{__half}, Cint, Clonglong,
                    Ptr{__half},
                    Ptr{__half}, Cint, Clonglong,
                    Cint),
                   handle, transa, transb,
                   m, n, k,
                   alpha,
                   A, lda, strideA,
                   B, ldb, strideB,
                   beta,
                   C, ldc, strideC,
                   batchCount)
    return statuscheck(status)
end
catch exception
Base.show_backtrace(STDOUT, backtrace());
println();
end
64 changes: 30 additions & 34 deletions src/libcublas_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,42 +62,38 @@ const CublasFloat = Union{Float64,Float32,Complex128,Complex64}
const CublasReal = Union{Float64,Float32}
const CublasComplex = Union{Complex128,Complex64}
# FP16 (cuda_fp16.h) in cuda
const __half = Float16;
const __half = Float16
# Immutable pair of two `__half` values.
# NOTE(review): presumably mirrors the `__half2` type from CUDA's cuda_fp16.h
# (two packed FP16 values) — confirm the field order/layout against the header.
immutable __half2
x1::__half  # first half-precision component
x2::__half  # second half-precision component
end
try
# Constants needed by the extended ("Ex") cuBLAS wrappers. They are only
# defined when `CUDArt.runtime_version()` reports 7500 or higher.
if CUDArt.runtime_version() >= 7500
    # Selector for the GEMM algorithm used by cublasGemmEx() (CUDA 7.5+)
    const cublasGemmAlgo_t = Int32
    const CUBLAS_GEMM_DFALT = -1
    const CUBLAS_GEMM_ALGO0 = 0
    const CUBLAS_GEMM_ALGO1 = 1
    const CUBLAS_GEMM_ALGO2 = 2
    const CUBLAS_GEMM_ALGO3 = 3
    const CUBLAS_GEMM_ALGO4 = 4
    const CUBLAS_GEMM_ALGO5 = 5
    const CUBLAS_GEMM_ALGO6 = 6
    const CUBLAS_GEMM_ALGO7 = 7

    # Data-type tags for cublas<t>gemmEx() and cublasGemmEx() (CUDA 7.5+),
    # listed in ascending numeric order
    const cudaDataType_t = UInt32
    const CUDA_R_32F = UInt32(0)
    const CUDA_R_64F = UInt32(1)
    const CUDA_R_16F = UInt32(2)
    const CUDA_R_8I = UInt32(3)
    const CUDA_C_32F = UInt32(4)
    const CUDA_C_64F = UInt32(5)
    const CUDA_C_16F = UInt32(6)
    const CUDA_C_8I = UInt32(7)
    const CUDA_R_8U = UInt32(8)
    const CUDA_C_8U = UInt32(9)
    const CUDA_R_32I = UInt32(10)
    const CUDA_C_32I = UInt32(11)
    const CUDA_R_32U = UInt32(12)
    const CUDA_C_32U = UInt32(13)
end
catch exception
Base.show_backtrace(STDOUT, backtrace());
println();

# Constants needed by the extended ("Ex") cuBLAS wrappers, which exist only on
# CUDA 7.5 and later. The guard compares against v"7.5.0" so that it matches
# the "CUDA 7.5+" requirement stated below and the guard around the wrappers
# in libcublas.jl; a threshold of v"0.7.5" would be satisfied by any release
# and would make the check meaningless.
if CUDAdrv.version() >= v"7.5.0"
    # specify which GEMM algorithm to use in cublasGemmEx() (CUDA 7.5+)
    const cublasGemmAlgo_t = Int32
    const CUBLAS_GEMM_DFALT = -1
    const CUBLAS_GEMM_ALGO0 = 0
    const CUBLAS_GEMM_ALGO1 = 1
    const CUBLAS_GEMM_ALGO2 = 2
    const CUBLAS_GEMM_ALGO3 = 3
    const CUBLAS_GEMM_ALGO4 = 4
    const CUBLAS_GEMM_ALGO5 = 5
    const CUBLAS_GEMM_ALGO6 = 6
    const CUBLAS_GEMM_ALGO7 = 7
    # specify which DataType to use with cublas<t>gemmEx() and cublasGemmEx() (CUDA 7.5+) functions
    const cudaDataType_t = UInt32
    const CUDA_R_16F = UInt32(2)
    const CUDA_C_16F = UInt32(6)
    const CUDA_R_32F = UInt32(0)
    const CUDA_C_32F = UInt32(4)
    const CUDA_R_64F = UInt32(1)
    const CUDA_C_64F = UInt32(5)
    const CUDA_R_8I = UInt32(3)
    const CUDA_C_8I = UInt32(7)
    const CUDA_R_8U = UInt32(8)
    const CUDA_C_8U = UInt32(9)
    const CUDA_R_32I = UInt32(10)
    const CUDA_C_32I = UInt32(11)
    const CUDA_R_32U = UInt32(12)
    const CUDA_C_32U = UInt32(13)
end

0 comments on commit 2e0fa4c

Please sign in to comment.