From 6becb4fdfe213d327acab47225595ad9df785dfd Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 24 Nov 2023 08:24:43 +0100 Subject: [PATCH] Add specializations for istriu/istril to speed up isdiag. (#502) --- src/host/linalg.jl | 32 ++++++++++++++++++++++++++++++++ test/testsuite/linalg.jl | 16 ++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 6856f587..36bb8582 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -165,6 +165,38 @@ function LinearAlgebra.triu!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T return A end +# check if upper triangular starting from the kth superdiagonal. +function LinearAlgebra.istriu(A::AbstractGPUMatrix, k::Integer = 0) + function mapper(a, I) + row, col = Tuple(I) + if col < row + k + return iszero(a) + else + true + end + end + function reducer(a, b) + a && b + end + mapreduce(mapper, reducer, A, eachindex(IndexCartesian(), A); init=true) +end + +# check if lower triangular starting from the kth subdiagonal. +function LinearAlgebra.istril(A::AbstractGPUMatrix, k::Integer = 0) + function mapper(a, I) + row, col = Tuple(I) + if col > row + k + return iszero(a) + else + true + end + end + function reducer(a, b) + a && b + end + mapreduce(mapper, reducer, A, eachindex(IndexCartesian(), A); init=true) +end + ## diagonal diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index 13d92525..24d6b8a6 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -105,6 +105,22 @@ @test gpu_c isa TR end end + + @testset "istril" begin + for rows in 3:4, cols in 3:4, diag in -4:4 + A = tril(rand(Float32, rows,cols), diag) + B = AT(A) + @test istril(A) == istril(B) + end + end + + @testset "istriu" begin + for rows in 3:4, cols in 3:4, diag in -4:4 + A = triu(rand(Float32, rows,cols), diag) + B = AT(A) + @test istriu(A) == istriu(B) + end + end end @testset "diagonal" begin