diff --git a/src/host/linalg.jl b/src/host/linalg.jl index a2a99019..570162d7 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -380,6 +380,20 @@ function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB, A::Abs generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add.alpha, _add.beta) end +if VERSION >= v"1.10-" + +function LinearAlgebra.generic_trimatmul!(C::AbstractGPUVecOrMat, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix, B::AbstractGPUVecOrMat) + tA = tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C' + generic_matmatmul!(C, wrap(A, tA), B, 1, 1) +end + +function LinearAlgebra.generic_mattrimul!(C::AbstractGPUMatrix, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix, B::AbstractGPUMatrix) + tA = tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C' + generic_matmatmul!(C, wrap(A, tA), B, 1, 1) +end + +end + if VERSION < v"1.10.0-DEV.1365" # catch other functions that are called by LinearAlgebra's mul! function LinearAlgebra.gemv!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, a::Number, b::Number) diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index da858608..57361008 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -132,6 +132,37 @@ @test istriu(A) == istriu(B) end end + + @testset "trimatmul" begin + n = 128 + b = AT(rand(Float32, n)) + B = AT(rand(Float32, n, n)) + + At = UpperTriangular(AT(rand(Float32, n,n))) + A = AT(At) + + @test At * b ≈ A * b + @test At * B ≈ A * B + + At = UnitUpperTriangular(AT(rand(Float32, n,n))) + A = AT(At) + + @test At * b ≈ A * b + @test At * B ≈ A * B + + + At = LowerTriangular(AT(rand(Float32, n,n))) + A = AT(At) + + @test At * b ≈ A * b + @test At * B ≈ A * B + + At = UnitLowerTriangular(AT(rand(Float32, n,n))) + A = AT(At) + + @test At * b ≈ A * b + @test At * B ≈ A * B + end end @testset "diagonal" begin