From 113fea845dd20a100ade570f679b45b0c6a9e139 Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Fri, 19 Apr 2024 13:06:53 +0200 Subject: [PATCH] Add generic_trimatmul and generic_mattrimul --- src/host/linalg.jl | 14 ++++++++++++++ test/testsuite/linalg.jl | 31 +++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index a2a99019..570162d7 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -380,6 +380,20 @@ function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB, A::Abs generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add.alpha, _add.beta) end +if VERSION >= v"1.10-" + +function LinearAlgebra.generic_trimatmul!(C::AbstractGPUVecOrMat, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix, B::AbstractGPUVecOrMat) + tA = tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C' + generic_matmatmul!(C, wrap(A, tA), B, 1, 1) +end + +function LinearAlgebra.generic_mattrimul!(C::AbstractGPUMatrix, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix, B::AbstractGPUMatrix) + tA = tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C' + generic_matmatmul!(C, wrap(A, tA), B, 1, 1) +end + +end + if VERSION < v"1.10.0-DEV.1365" # catch other functions that are called by LinearAlgebra's mul! function LinearAlgebra.gemv!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, a::Number, b::Number) diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index da858608..57361008 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -132,6 +132,37 @@ @test istriu(A) == istriu(B) end end + + @testset "trimatmul" begin + n = 128 + b = AT(rand(Float32, n)) + B = AT(rand(Float32, n, n)) + + At = UpperTriangular(AT(rand(Float32, n,n))) + A = AT(At) + + @test At * b ≈ A * b + @test At * B ≈ A * B + + At = UnitUpperTriangular(AT(rand(Float32, n,n))) + A = AT(At) + + @test At * b ≈ A * b + @test At * B ≈ A * B + + + At = LowerTriangular(AT(rand(Float32, n,n))) + A = AT(At) + + @test At * b ≈ A * b + @test At * B ≈ A * B + + At = UnitLowerTriangular(AT(rand(Float32, n,n))) + A = AT(At) + + @test At * b ≈ A * b + @test At * B ≈ A * B + end end @testset "diagonal" begin