From 4fd3aad5735e3b80eefe7b068f3407d7dd0c0924 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Thu, 5 Dec 2024 13:39:20 +0530 Subject: [PATCH] Generalize `istriu`/`istril` to accept a band index (#590) Currently, only `istriu(S)` and `istril(S)` are specialized for sparse matrices, and this PR generalizes these to accept the band index `k`. This improves performance. --- src/sparsematrix.jl | 10 ++++++---- test/sparsematrix_ops.jl | 13 +++++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/sparsematrix.jl b/src/sparsematrix.jl index 4fa2adf9..6d6d880d 100644 --- a/src/sparsematrix.jl +++ b/src/sparsematrix.jl @@ -4141,7 +4141,7 @@ function is_hermsym(A::AbstractSparseMatrixCSC, check::Function) return true end -function istriu(A::AbstractSparseMatrixCSC) +function istriu(A::AbstractSparseMatrixCSC, k::Integer=0) m, n = size(A) colptr = getcolptr(A) rowval = rowvals(A) @@ -4150,7 +4150,8 @@ function istriu(A::AbstractSparseMatrixCSC) for col = 1:min(n, m-1) l1 = colptr[col+1]-1 for i = 0 : (l1 - colptr[col]) - if rowval[l1-i] <= col + if rowval[l1-i] <= col - k + # rows preceeding the index would also lie above the band break end if _isnotzero(nzval[l1-i]) @@ -4161,7 +4162,7 @@ function istriu(A::AbstractSparseMatrixCSC) return true end -function istril(A::AbstractSparseMatrixCSC) +function istril(A::AbstractSparseMatrixCSC, k::Integer=0) m, n = size(A) colptr = getcolptr(A) rowval = rowvals(A) @@ -4169,7 +4170,8 @@ function istril(A::AbstractSparseMatrixCSC) for col = 2:n for i = colptr[col] : (colptr[col+1]-1) - if rowval[i] >= col + if rowval[i] >= col - k + # subsequent rows would also lie below the band break end if _isnotzero(nzval[i]) diff --git a/test/sparsematrix_ops.jl b/test/sparsematrix_ops.jl index 94e268e0..d99418a8 100644 --- a/test/sparsematrix_ops.jl +++ b/test/sparsematrix_ops.jl @@ -626,4 +626,17 @@ end @test_throws ArgumentError copytrito!(M, S, 'M') end +@testset "istriu/istril" begin + for T in Any[Tridiagonal(1:3, 1:4, 1:3), + Bidiagonal(1:4, 1:3, :U), Bidiagonal(1:4, 1:3, :L), + Diagonal(1:4), + diagm(-2=>1:2, 2=>1:2)] + S = sparse(T) + for k in -5:5 + @test istriu(S, k) == istriu(T, k) + @test istril(S, k) == istril(T, k) + end + end +end + end # module