Skip to content
This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Commit

Permalink
Merge pull request #30 from MikeInnes/cudart
Browse files Browse the repository at this point in the history
Use CUDAdrv instead of CUDArt
  • Loading branch information
MikeInnes authored Aug 10, 2017
2 parents ecc09c9 + 05b68d0 commit c053b1d
Show file tree
Hide file tree
Showing 7 changed files with 533 additions and 519 deletions.
2 changes: 1 addition & 1 deletion REQUIRE
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
julia 0.5
CUDArt 0.3.0
CUDAdrv 0.5.1
8 changes: 6 additions & 2 deletions src/CUBLAS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ module CUBLAS

importall Base.LinAlg.BLAS

using CUDArt
using CUDArt.rt.cudaStream_t
using CUDAdrv: OwnedPtr, CuArray, CuVector, CuMatrix

# Alias covering both CuVector{T} and CuMatrix{T}, used by the high-level wrappers
# for arguments that may be either a device vector or a device matrix.
CuVecOrMat{T} = Union{CuVector{T},CuMatrix{T}}

const BlasChar = Char #import Base.LinAlg.BlasChar
import Base.one
Expand Down Expand Up @@ -77,6 +78,9 @@ if isempty(libcublas)
error("CUBLAS library cannot be found. Please make sure that CUDA is installed")
end

# Typedef needed by libcublas
const cudaStream_t = Ptr{Void}

include("libcublas.jl")

# setup cublas handle
Expand Down
392 changes: 196 additions & 196 deletions src/blas.jl

Large diffs are not rendered by default.

73 changes: 35 additions & 38 deletions src/highlevel.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import Base.Operators.(*)

import Base: scale!, scale, norm, vecdot
import Base: scale!, norm, vecdot

# Base multiplication entry points that this file adds methods to.
# A_mul_Bt and A_mul_Bt! are listed (in place of the previously duplicated At_mul_Bt /
# At_mul_Bt! entries) so that the A_mul_Bt method defined in this file extends Base's
# generic function rather than introducing an unrelated module-local one.
import Base: A_mul_B!, At_mul_B, Ac_mul_B, A_mul_Bt, A_mul_Bc, At_mul_Bt, Ac_mul_Bc,
             At_mul_B!, Ac_mul_B!, A_mul_Bt!, A_mul_Bc!, At_mul_Bt!, Ac_mul_Bc!

cublas_size(t::Char, M::CudaVecOrMat) = (size(M, t=='N' ? 1:2), size(M, t=='N' ? 2:1))
# Return the two dimension sizes of `M` as a tuple: (size 1, size 2) when `t == 'N'`,
# and swapped to (size 2, size 1) for any other `t`.
function cublas_size(t::Char, M::CuVecOrMat)
    plain = t == 'N'
    first_extent = size(M, plain ? 1 : 2)
    second_extent = size(M, plain ? 2 : 1)
    return (first_extent, second_extent)
end

###########
#
Expand All @@ -16,13 +16,12 @@ cublas_size(t::Char, M::CudaVecOrMat) = (size(M, t=='N' ? 1:2), size(M, t=='N' ?
#######
# SCAL
#######
scale!{T<:CublasFloat}(x::CudaArray{T}, k::Number) = CUBLAS.scal!(length(x), k, x, 1)
scale{T<:CublasFloat}(x::CudaArray{T}, k::Number) = CUBLAS.scal!(length(x), k, copy(x), 1)
# Scale `x` by `k` through CUBLAS.scal!, passing the full element count of `x` and a
# trailing argument of 1 (presumably the element stride — confirm against scal!).
# Returns whatever CUBLAS.scal! returns.
function scale!{T<:CublasFloat}(x::CuArray{T}, k::Number)
    n = length(x)
    return CUBLAS.scal!(n, k, x, 1)
end

#######
# DOT
#######
function dot{T <: CublasFloat, TI<:Integer}(x::CudaVector{T}, rx::Union{UnitRange{TI},Range{TI}}, y::CudaVector{T}, ry::Union{UnitRange{TI},Range{TI}})
function dot{T <: CublasFloat, TI<:Integer}(x::CuVector{T}, rx::Union{UnitRange{TI},Range{TI}}, y::CuVector{T}, ry::Union{UnitRange{TI},Range{TI}})
if length(rx) != length(ry)
throw(DimensionMismatch("length of rx, $(length(rx)), does not equal length of ry, $(length(ry))"))
end
Expand All @@ -35,17 +34,17 @@ function dot{T <: CublasFloat, TI<:Integer}(x::CudaVector{T}, rx::Union{UnitRang
dot(length(rx), pointer(x)+(first(rx)-1)*sizeof(T), step(rx), pointer(y)+(first(ry)-1)*sizeof(T), step(ry))
end

At_mul_B{T<:CublasReal}(x::CudaVector{T}, y::CudaVector{T}) = [CUBLAS.dot(x, y)]
At_mul_B{T<:CublasComplex}(x::CudaVector{T}, y::CudaVector{T}) = [CUBLAS.dotu(x, y)]
Ac_mul_B{T<:CublasComplex}(x::CudaVector{T}, y::CudaVector{T}) = [CUBLAS.dotc(x, y)]
# Vector-vector products, each wrapped in a one-element Vector.
# Real element types go through CUBLAS.dot; complex element types use CUBLAS.dotu for
# At_mul_B and CUBLAS.dotc for Ac_mul_B.
function At_mul_B{T<:CublasReal}(x::CuVector{T}, y::CuVector{T})
    return [CUBLAS.dot(x, y)]
end
function At_mul_B{T<:CublasComplex}(x::CuVector{T}, y::CuVector{T})
    return [CUBLAS.dotu(x, y)]
end
function Ac_mul_B{T<:CublasComplex}(x::CuVector{T}, y::CuVector{T})
    return [CUBLAS.dotc(x, y)]
end

vecdot{T<:CublasReal}(x::CudaVector{T}, y::CudaVector{T}) = dot(x, y)
vecdot{T<:CublasComplex}(x::CudaVector{T}, y::CudaVector{T}) = dotc(x, y)
# vecdot for device vectors: forwards to `dot` for real element types and to `dotc`
# for complex element types.
function vecdot{T<:CublasReal}(x::CuVector{T}, y::CuVector{T})
    return dot(x, y)
end
function vecdot{T<:CublasComplex}(x::CuVector{T}, y::CuVector{T})
    return dotc(x, y)
end

#######
# NRM2
#######
norm(x::CudaArray) = nrm2(x)
# norm of a device array: forwards directly to nrm2.
function norm(x::CuArray)
    return nrm2(x)
end


############
Expand All @@ -58,7 +57,7 @@ norm(x::CudaArray) = nrm2(x)
#########
# GEMV
##########
function gemv_wrapper!{T<:CublasFloat}(y::CudaVector{T}, tA::Char, A::CudaMatrix{T}, x::CudaVector{T},
function gemv_wrapper!{T<:CublasFloat}(y::CuVector{T}, tA::Char, A::CuMatrix{T}, x::CuVector{T},
alpha = one(T), beta = zero(T))
mA, nA = cublas_size(tA, A)
if nA != length(x)
Expand All @@ -76,20 +75,20 @@ function gemv_wrapper!{T<:CublasFloat}(y::CudaVector{T}, tA::Char, A::CudaMatrix
gemv!(tA, alpha, A, x, beta, y)
end

A_mul_B!{T<:CublasFloat}(y::CudaVector{T}, A::CudaMatrix{T}, x::CudaVector{T}) = gemv_wrapper!(y, 'N', A, x)
At_mul_B!{T<:CublasFloat}(y::CudaVector{T}, A::CudaMatrix{T}, x::CudaVector{T}) = gemv_wrapper!(y, 'T', A, x)
Ac_mul_B!{T<:CublasFloat}(y::CudaVector{T}, A::CudaMatrix{T}, x::CudaVector{T}) = gemv_wrapper!(y, 'T', A, x)
Ac_mul_B!{T<:CublasComplex}(y::CudaVector{T}, A::CudaMatrix{T}, x::CudaVector{T}) = gemv_wrapper!(y, 'C', A, x)
# Matrix-vector products written into `y`; every method forwards to gemv_wrapper!
# with a transpose flag: 'N' for A_mul_B!, 'T' for At_mul_B!. For Ac_mul_B! the
# CublasFloat method passes 'T' and the CublasComplex method passes 'C'.
function A_mul_B!{T<:CublasFloat}(y::CuVector{T}, A::CuMatrix{T}, x::CuVector{T})
    return gemv_wrapper!(y, 'N', A, x)
end
function At_mul_B!{T<:CublasFloat}(y::CuVector{T}, A::CuMatrix{T}, x::CuVector{T})
    return gemv_wrapper!(y, 'T', A, x)
end
function Ac_mul_B!{T<:CublasFloat}(y::CuVector{T}, A::CuMatrix{T}, x::CuVector{T})
    return gemv_wrapper!(y, 'T', A, x)
end
function Ac_mul_B!{T<:CublasComplex}(y::CuVector{T}, A::CuMatrix{T}, x::CuVector{T})
    return gemv_wrapper!(y, 'C', A, x)
end

function (*){T<:CublasFloat}(A::CudaMatrix{T}, x::CudaVector{T})
# A * x for a device matrix and device vector: allocates a result of length size(A,1)
# via `similar` and fills it with A_mul_B!.
function (*){T<:CublasFloat}(A::CuMatrix{T}, x::CuVector{T})
    y = similar(x, T, size(A, 1))
    return A_mul_B!(y, A, x)
end

function At_mul_B{T<:CublasFloat}(A::CudaMatrix{T}, x::CudaVector{T})
# At_mul_B for a device matrix and device vector: allocates a result of length
# size(A,2) via `similar` and fills it with At_mul_B!.
function At_mul_B{T<:CublasFloat}(A::CuMatrix{T}, x::CuVector{T})
    y = similar(x, T, size(A, 2))
    return At_mul_B!(y, A, x)
end

function Ac_mul_B{T<:CublasFloat}(A::CudaMatrix{T}, x::CudaVector{T})
# Ac_mul_B for a device matrix and device vector: allocates a result of length
# size(A,2) via `similar` and fills it with Ac_mul_B!.
function Ac_mul_B{T<:CublasFloat}(A::CuMatrix{T}, x::CuVector{T})
    y = similar(x, T, size(A, 2))
    return Ac_mul_B!(y, A, x)
end

Expand All @@ -103,9 +102,9 @@ end
########
# GEMM
########
function gemm_wrapper!{T <: CublasFloat}(C::CudaVecOrMat{T}, tA::Char, tB::Char,
A::CudaVecOrMat{T},
B::CudaVecOrMat{T},
function gemm_wrapper!{T <: CublasFloat}(C::CuVecOrMat{T}, tA::Char, tB::Char,
A::CuVecOrMat{T},
B::CuVecOrMat{T},
alpha = one(T),
beta = zero(T))
mA, nA = cublas_size(tA, A)
Expand All @@ -130,51 +129,49 @@ function gemm_wrapper!{T <: CublasFloat}(C::CudaVecOrMat{T}, tA::Char, tB::Char,
end

# Mutating
A_mul_B!{T <: CublasFloat}(C::CudaMatrix{T}, A::CudaMatrix{T}, B::CudaMatrix{T}) = gemm_wrapper!(C, 'N', 'N', A, B)
At_mul_B!(C::CudaMatrix, A::CudaMatrix, B::CudaMatrix) = gemm_wrapper!(C, 'T', 'N', A, B)
At_mul_Bt!(C::CudaMatrix, A::CudaMatrix, B::CudaMatrix) = gemm_wrapper!(C, 'T', 'T', A, B)
Ac_mul_B!{T<:CublasReal}(C::CudaMatrix{T}, A::CudaMatrix{T}, B::CudaMatrix{T}) = At_mul_B!(C, A, B)
Ac_mul_B!(C::CudaMatrix, A::CudaMatrix, B::CudaMatrix) = gemm_wrapper!(C, 'C', 'N', A, B)
# Matrix-matrix products written into `C`. Each method forwards to gemm_wrapper!
# with the transpose flags for A and B, except the CublasReal Ac_mul_B! method,
# which forwards to At_mul_B!.
function A_mul_B!{T <: CublasFloat}(C::CuMatrix{T}, A::CuMatrix{T}, B::CuMatrix{T})
    return gemm_wrapper!(C, 'N', 'N', A, B)
end
function At_mul_B!(C::CuMatrix, A::CuMatrix, B::CuMatrix)
    return gemm_wrapper!(C, 'T', 'N', A, B)
end
function At_mul_Bt!(C::CuMatrix, A::CuMatrix, B::CuMatrix)
    return gemm_wrapper!(C, 'T', 'T', A, B)
end
function Ac_mul_B!{T<:CublasReal}(C::CuMatrix{T}, A::CuMatrix{T}, B::CuMatrix{T})
    return At_mul_B!(C, A, B)
end
function Ac_mul_B!(C::CuMatrix, A::CuMatrix, B::CuMatrix)
    return gemm_wrapper!(C, 'C', 'N', A, B)
end

function A_mul_B!{T}(C::CudaMatrix{T}, A::CudaVecOrMat{T}, B::CudaVecOrMat{T})
# A_mul_B! accepting vector-or-matrix operands: forwards to gemm_wrapper! with no
# transposition on either operand.
A_mul_B!{T}(C::CuMatrix{T}, A::CuVecOrMat{T}, B::CuVecOrMat{T}) =
    gemm_wrapper!(C, 'N', 'N', A, B)

# Non mutating

# A_mul_Bx
function (*){T <: CublasFloat}(A::CudaMatrix{T}, B::CudaMatrix{T})
# A * B for device matrices: allocates a (size(A,1), size(B,2)) result via `similar`
# and fills it with A_mul_B!.
function (*){T <: CublasFloat}(A::CuMatrix{T}, B::CuMatrix{T})
    dims = (size(A, 1), size(B, 2))
    C = similar(B, T, dims)
    return A_mul_B!(C, A, B)
end

function A_mul_Bt{T}(A::CudaMatrix{T}, B::CudaMatrix{T})
# A_mul_Bt for device matrices: allocates a (size(A,1), size(B,1)) result via
# `similar` and passes it to A_mul_Bt!.
function A_mul_Bt{T}(A::CuMatrix{T}, B::CuMatrix{T})
    dims = (size(A, 1), size(B, 1))
    C = similar(B, T, dims)
    return A_mul_Bt!(C, A, B)
end

function A_mul_Bc{T}(A::CudaMatrix{T}, B::CudaMatrix{T})
# A_mul_Bc for device matrices: allocates a (size(A,1), size(B,1)) result via
# `similar` and passes it to A_mul_Bc!.
function A_mul_Bc{T}(A::CuMatrix{T}, B::CuMatrix{T})
    dims = (size(A, 1), size(B, 1))
    C = similar(B, T, dims)
    return A_mul_Bc!(C, A, B)
end

# At_mul_Bx
function At_mul_B{T}(A::CudaMatrix{T}, B::CudaMatrix{T})
# At_mul_B for device matrices: allocates a (size(A,2), size(B,2)) result via
# `similar` and passes it to At_mul_B!.
function At_mul_B{T}(A::CuMatrix{T}, B::CuMatrix{T})
    dims = (size(A, 2), size(B, 2))
    C = similar(B, T, dims)
    return At_mul_B!(C, A, B)
end

function At_mul_Bt{T}(A::CudaMatrix{T}, B::CudaMatrix{T})
# At_mul_Bt for device matrices: allocates a (size(A,2), size(B,1)) result via
# `similar` and passes it to At_mul_Bt!.
function At_mul_Bt{T}(A::CuMatrix{T}, B::CuMatrix{T})
    dims = (size(A, 2), size(B, 1))
    C = similar(B, T, dims)
    return At_mul_Bt!(C, A, B)
end

# Ac_mul_Bx
function Ac_mul_B{T}(A::CudaMatrix{T}, B::CudaMatrix{T})
# Ac_mul_B for device matrices: allocates a (size(A,2), size(B,2)) result via
# `similar` and passes it to Ac_mul_B!.
function Ac_mul_B{T}(A::CuMatrix{T}, B::CuMatrix{T})
    dims = (size(A, 2), size(B, 2))
    C = similar(B, T, dims)
    return Ac_mul_B!(C, A, B)
end

function Ac_mul_Bt{T,S}(A::CudaMatrix{T}, B::CudaMatrix{S})
# Ac_mul_Bt for device matrices: allocates a (size(A,2), size(B,1)) result with
# element type T via `similar` and fills it through gemm_wrapper! using the 'C' flag
# for A and the 'T' flag for B.
#
# The previous body called Ac_mul_Bt itself with three arguments; only this
# two-argument method exists and this file defines no Ac_mul_Bt!, so that call could
# only raise a MethodError. Going straight to gemm_wrapper! avoids the missing method.
#
# NOTE(review): gemm_wrapper! is declared with a single element type T shared by C, A
# and B, so a call with T != S still fails with a MethodError — confirm whether mixed
# element types are meant to be supported here.
function Ac_mul_Bt{T,S}(A::CuMatrix{T}, B::CuMatrix{S})
    C = similar(B, T, (size(A, 2), size(B, 1)))
    return gemm_wrapper!(C, 'C', 'T', A, B)
end

function Ac_mul_Bc{T,S}(A::CudaMatrix{T}, B::CudaMatrix{S})
# Ac_mul_Bc for device matrices: allocates a (size(A,2), size(B,1)) result with
# element type T via `similar` and passes it to Ac_mul_Bc!.
function Ac_mul_Bc{T,S}(A::CuMatrix{T}, B::CuMatrix{S})
    dims = (size(A, 2), size(B, 1))
    C = similar(B, T, dims)
    return Ac_mul_Bc!(C, A, B)
end


2 changes: 1 addition & 1 deletion src/libcublas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -773,4 +773,4 @@ try
catch exception
Base.show_backtrace(STDOUT, backtrace());
println();
end
end
2 changes: 1 addition & 1 deletion src/libcublas_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,4 @@ try
catch exception
Base.show_backtrace(STDOUT, backtrace());
println();
end
end
Loading

0 comments on commit c053b1d

Please sign in to comment.