Skip to content

Commit

Permalink
Add generic_trimatmul and generic_mattrimul
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich committed Apr 20, 2024
1 parent 228fb9e commit 113fea8
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/host/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,20 @@ function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB, A::Abs
generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add.alpha, _add.beta)
end

if VERSION >= v"1.10-"

function LinearAlgebra.generic_trimatmul!(C::AbstractGPUVecOrMat, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix, B::AbstractGPUVecOrMat)
tA = tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C'
generic_matmatmul!(C, wrap(A, tA), B, 1, 1)
end

function LinearAlgebra.generic_mattrimul!(C::AbstractGPUMatrix, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix, B::AbstractGPUMatrix)
tA = tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C'
generic_matmatmul!(C, wrap(A, tA), B, 1, 1)
end

end

if VERSION < v"1.10.0-DEV.1365"
# catch other functions that are called by LinearAlgebra's mul!
function LinearAlgebra.gemv!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, a::Number, b::Number)
Expand Down
31 changes: 31 additions & 0 deletions test/testsuite/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,37 @@
@test istriu(A) == istriu(B)
end
end

@testset "trimatmul" begin
n = 128
b = AT(rand(Float32, n))
B = AT(rand(Float32, n, n))

At = UpperTriangular(AT(rand(Float32, n,n)))
A = AT(At)

@test At * b A * b
@test At * B A * B

At = UnitUpperTriangular(AT(rand(Float32, n,n)))
A = AT(At)

@test At * b A * b
@test At * B A * B


At = LowerTriangular(AT(rand(Float32, n,n)))
A = AT(At)

@test At * b A * b
@test At * B A * B

At = UnitLowerTriangular(AT(rand(Float32, n,n)))
A = AT(At)

@test At * b A * b
@test At * B A * B
end
end

@testset "diagonal" begin
Expand Down

0 comments on commit 113fea8

Please sign in to comment.