Skip to content

Commit

Permalink
Add specializations for istriu/istril to speed up isdiag. (#502)
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt authored Nov 24, 2023
1 parent 15bf446 commit 6becb4f
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 0 deletions.
32 changes: 32 additions & 0 deletions src/host/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,38 @@ function LinearAlgebra.triu!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
return A
end

# check if upper triangular starting from the kth superdiagonal.
function LinearAlgebra.istriu(A::AbstractGPUMatrix, k::Integer = 0)
function mapper(a, I)
row, col = Tuple(I)
if col < row + k
return iszero(a)
else
true
end
end
function reducer(a, b)
a && b
end
mapreduce(mapper, reducer, A, eachindex(IndexCartesian(), A); init=true)
end

# check if lower triangular starting from the kth subdiagonal.
function LinearAlgebra.istril(A::AbstractGPUMatrix, k::Integer = 0)
function mapper(a, I)
row, col = Tuple(I)
if col > row + k
return iszero(a)
else
true
end
end
function reducer(a, b)
a && b
end
mapreduce(mapper, reducer, A, eachindex(IndexCartesian(), A); init=true)
end


## diagonal

Expand Down
16 changes: 16 additions & 0 deletions test/testsuite/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,22 @@
@test gpu_c isa TR
end
end

@testset "istril" begin
for rows in 3:4, cols in 3:4, diag in -4:4
A = tril(rand(Float32, rows,cols), diag)
B = AT(A)
@test istril(A) == istril(B)
end
end

@testset "istriu" begin
for rows in 3:4, cols in 3:4, diag in -4:4
A = triu(rand(Float32, rows,cols), diag)
B = AT(A)
@test istriu(A) == istriu(B)
end
end
end

@testset "diagonal" begin
Expand Down

0 comments on commit 6becb4f

Please sign in to comment.