Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[oneMKL] Interface variants of trsm! and trmm! #479

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
54 changes: 46 additions & 8 deletions lib/mkl/linalg.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# interfacing with LinearAlgebra standard library

import LinearAlgebra
using LinearAlgebra: Transpose, Adjoint,
using LinearAlgebra: Transpose, Adjoint, AdjOrTrans,
Hermitian, Symmetric,
LowerTriangular, UnitLowerTriangular,
UpperTriangular, UnitUpperTriangular,
MulAddMul, wrap
UpperOrLowerTriangular, MulAddMul, wrap

#
# BLAS 1
Expand Down Expand Up @@ -163,12 +163,50 @@ function LinearAlgebra.generic_matmatmul!(C::oneStridedMatrix, tA, tB, A::oneStr
GPUArrays.generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
end

const AdjOrTransOroneMatrix{T} = Union{oneStridedMatrix{T}, AdjOrTrans{<:T,<:oneStridedMatrix}}

function LinearAlgebra.generic_trimatmul!(
C::oneStridedMatrix{T}, uplocA, isunitcA,
tfunA::Function, A::oneStridedMatrix{T},
triB::UpperOrLowerTriangular{T, <: AdjOrTransOroneMatrix{T}},
) where {T<:onemklFloat}
uplocB = LinearAlgebra.uplo_char(triB)
isunitcB = LinearAlgebra.isunit_char(triB)
B = parent(triB)
tfunB = LinearAlgebra.wrapperop(B)
transa = tfunA === identity ? 'N' : tfunA === transpose ? 'T' : 'C'
transb = tfunB === identity ? 'N' : tfunB === transpose ? 'T' : 'C'
if uplocA == 'L' && tfunA === identity && tfunB === identity && uplocB == 'U' && isunitcB == 'N' # lower * upper
triu!(B)
trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
elseif uplocA == 'U' && tfunA === identity && tfunB === identity && uplocB == 'L' && isunitcB == 'N' # upper * lower
tril!(B)
trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
elseif uplocA == 'U' && tfunA === identity && tfunB !== identity && uplocB == 'U' && isunitcA == 'N'
# operation is reversed to avoid executing the tranpose
triu!(A)
trmm!('R', uplocB, transb, isunitcB, one(T), parent(B), A, C)
elseif uplocA == 'L' && tfunA !== identity && tfunB === identity && uplocB == 'L' && isunitcB == 'N'
tril!(B)
trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
elseif uplocA == 'U' && tfunA !== identity && tfunB === identity && uplocB == 'U' && isunitcB == 'N'
triu!(B)
trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C)
elseif uplocA == 'L' && tfunA === identity && tfunB !== identity && uplocB == 'L' && isunitcA == 'N'
tril!(A)
trmm!('R', uplocB, transb, isunitcB, one(T), parent(B), A, C)
else
throw("mixed triangular-triangular multiplication") # TODO: rethink
end
return C
end

# triangular
LinearAlgebra.generic_trimatmul!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
trmm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B))
trmm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C)
LinearAlgebra.generic_mattrimul!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
trmm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A))
LinearAlgebra.generic_trimatdiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
trsm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B))
LinearAlgebra.generic_mattridiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
trsm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A))
trmm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, A, C)
LinearAlgebra.generic_trimatdiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::AbstractMatrix{T}) where {T<:onemklFloat} =
trsm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C)
LinearAlgebra.generic_mattridiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
trsm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, A, C)
70 changes: 70 additions & 0 deletions lib/mkl/wrappers_blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1139,6 +1139,76 @@ function trsm(side::Char,
trsm!(side, uplo, transa, diag, alpha, A, copy(B))
end

for (mmname_variant, smname_variant, elty) in
((:onemklDtrmm_variant, :onemklDtrsm_variant, :Float64),
(:onemklStrmm_variant, :onemklStrsm_variant, :Float32),
(:onemklZtrmm_variant, :onemklZtrsm_variant, :ComplexF64),
(:onemklCtrmm_variant, :onemklCtrsm_variant, :ComplexF32))
@eval begin
function trmm!(side::Char,
uplo::Char,
transa::Char,
diag::Char,
alpha::Number,
beta::Number,
A::oneStridedMatrix{$elty},
B::oneStridedMatrix{$elty},
C::oneStridedMatrix{$elty})
m, n = size(B)
mA, nA = size(A)
if mA != nA throw(DimensionMismatch("A must be square")) end
if nA != (side == 'L' ? m : n) throw(DimensionMismatch("trmm!")) end
lda = max(1,stride(A,2))
ldb = max(1,stride(B,2))
ldc = max(1,stride(C,2))
queue = global_queue(context(A), device())
$mmname_variant(sycl_queue(queue), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb, beta, C, ldc)
B
end

function trsm!(side::Char,
uplo::Char,
transa::Char,
diag::Char,
alpha::Number,
beta::Number,
A::oneStridedMatrix{$elty},
B::oneStridedMatrix{$elty},
C::oneStridedMatrix{$elty})
m, n = size(B)
mA, nA = size(A)
if mA != nA throw(DimensionMismatch("A must be square")) end
if nA != (side == 'L' ? m : n) throw(DimensionMismatch("trsm!")) end
lda = max(1,stride(A,2))
ldb = max(1,stride(B,2))
ldc = max(1,stride(C,2))
queue = global_queue(context(A), device())
$smname_variant(sycl_queue(queue), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb, beta, C, ldc)
B
end
end
end
function trmm!(side::Char,
uplo::Char,
transa::Char,
diag::Char,
alpha::Number,
A::oneStridedMatrix{T},
B::oneStridedMatrix{T},
C::oneStridedMatrix{T}) where T
trmm!(side, uplo, transa, diag, alpha, zero(T), A, B, C)
end
function trsm!(side::Char,
uplo::Char,
transa::Char,
diag::Char,
alpha::Number,
A::oneStridedMatrix{T},
B::oneStridedMatrix{T},
C::oneStridedMatrix{T}) where T
trsm!(side, uplo, transa, diag, alpha, zero(T), A, B, C)
end

## hemm
for (fname, elty) in ((:onemklZhemm,:ComplexF64),
(:onemklChemm,:ComplexF32))
Expand Down
24 changes: 24 additions & 0 deletions test/onemkl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,14 @@ end
# move to host and compare
h_C = Array(dB)
@test C ≈ h_C

C = rand(T,m,n)
dC = oneArray(C)
beta = zero(T) # rand(T)
oneMKL.trmm!('L','U','N','N',alpha,beta,dA,dB,dC)
h_C = Array(dC)
D = alpha*A*B + beta*C
@test D ≈ h_C
end

@testset "trmm" begin
Expand All @@ -684,6 +692,14 @@ end
dC = copy(dB)
oneMKL.trsm!('L','U','N','N',alpha,dA,dC)
@test C ≈ Array(dC)

C = rand(T,m,n)
dC = oneArray(C)
beta = rand(T)
oneMKL.trsm!('L','U','N','N',alpha,beta,dA,dB,dC)
h_C = Array(dC)
D = alpha*(A\B) + beta*C
@test D ≈ h_C
end

@testset "left trsm" begin
Expand Down Expand Up @@ -725,6 +741,14 @@ end
dC = copy(dA)
oneMKL.trsm!('R','U','N','N',alpha,dB,dC)
@test C ≈ Array(dC)

C = rand(T,m,m)
dC = oneArray(C)
beta = rand(T)
oneMKL.trsm!('R','U','N','N',alpha,beta,dA,dB,dC)
h_C = Array(dC)
D = alpha*(A/B) + beta*C
@test D ≈ h_C
end

@testset "right trsm" begin
Expand Down