diff --git a/src/highlevel.jl b/src/highlevel.jl index 5966b19..e764839 100644 --- a/src/highlevel.jl +++ b/src/highlevel.jl @@ -2,8 +2,8 @@ import Base.Operators.(*) import Base: scale!, norm, vecdot -import Base: A_mul_B!, At_mul_B, Ac_mul_B, A_mul_Bc, At_mul_Bt, Ac_mul_Bc, At_mul_Bt, - At_mul_B!, Ac_mul_B!, A_mul_Bc!, At_mul_Bt!, Ac_mul_Bc!, At_mul_Bt! +import Base: A_mul_B!, At_mul_B, A_mul_Bt, Ac_mul_B, A_mul_Bc, At_mul_Bt, Ac_mul_Bc, At_mul_Bt, + At_mul_B!, A_mul_Bt!, Ac_mul_B!, A_mul_Bc!, At_mul_Bt!, Ac_mul_Bc!, At_mul_Bt! cublas_size(t::Char, M::CuVecOrMat) = (size(M, t=='N' ? 1:2), size(M, t=='N' ? 2:1)) @@ -131,6 +131,7 @@ end # Mutating A_mul_B!(C::CuMatrix{T}, A::CuMatrix{T}, B::CuMatrix{T}) where {T <: CublasFloat} = gemm_wrapper!(C, 'N', 'N', A, B) At_mul_B!(C::CuMatrix, A::CuMatrix, B::CuMatrix) = gemm_wrapper!(C, 'T', 'N', A, B) +A_mul_Bt!(C::CuMatrix, A::CuMatrix, B::CuMatrix) = gemm_wrapper!(C, 'N', 'T', A, B) At_mul_Bt!(C::CuMatrix, A::CuMatrix, B::CuMatrix) = gemm_wrapper!(C, 'T', 'T', A, B) Ac_mul_B!(C::CuMatrix{T}, A::CuMatrix{T}, B::CuMatrix{T}) where {T<:CublasReal} = At_mul_B!(C, A, B) Ac_mul_B!(C::CuMatrix, A::CuMatrix, B::CuMatrix) = gemm_wrapper!(C, 'C', 'N', A, B)