Skip to content
This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Commit

Permalink
A_mul_Bt
Browse files Browse the repository at this point in the history
  • Loading branch information
MikeInnes committed Aug 23, 2017
1 parent 746505a commit 31453b8
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/highlevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ import Base.Operators.(*)

import Base: scale!, norm, vecdot

import Base: A_mul_B!, At_mul_B, Ac_mul_B, A_mul_Bc, At_mul_Bt, Ac_mul_Bc, At_mul_Bt,
At_mul_B!, Ac_mul_B!, A_mul_Bc!, At_mul_Bt!, Ac_mul_Bc!, At_mul_Bt!
import Base: A_mul_B!, At_mul_B, A_mul_Bt, Ac_mul_B, A_mul_Bc, At_mul_Bt, Ac_mul_Bc, At_mul_Bt,
At_mul_B!, A_mul_Bt!, Ac_mul_B!, A_mul_Bc!, At_mul_Bt!, Ac_mul_Bc!, At_mul_Bt!

cublas_size(t::Char, M::CuVecOrMat) = (size(M, t=='N' ? 1:2), size(M, t=='N' ? 2:1))

Expand Down Expand Up @@ -131,6 +131,7 @@ end
# Mutating
A_mul_B!(C::CuMatrix{T}, A::CuMatrix{T}, B::CuMatrix{T}) where {T <: CublasFloat} = gemm_wrapper!(C, 'N', 'N', A, B)
At_mul_B!(C::CuMatrix, A::CuMatrix, B::CuMatrix) = gemm_wrapper!(C, 'T', 'N', A, B)
A_mul_Bt!(C::CuMatrix, A::CuMatrix, B::CuMatrix) = gemm_wrapper!(C, 'N', 'T', A, B)
At_mul_Bt!(C::CuMatrix, A::CuMatrix, B::CuMatrix) = gemm_wrapper!(C, 'T', 'T', A, B)
Ac_mul_B!(C::CuMatrix{T}, A::CuMatrix{T}, B::CuMatrix{T}) where {T<:CublasReal} = At_mul_B!(C, A, B)
Ac_mul_B!(C::CuMatrix, A::CuMatrix, B::CuMatrix) = gemm_wrapper!(C, 'C', 'N', A, B)
Expand Down

0 comments on commit 31453b8

Please sign in to comment.