Skip to content

Commit

Permalink
add the iteratenz API
Browse files Browse the repository at this point in the history
  • Loading branch information
SobhanMP committed Jun 26, 2022
1 parent 16b28ce commit fab4d98
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/abstractsparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,4 +124,12 @@ julia> findnz(A)
"""
function findnz end

"""
iternz(A::AbstractSparseArray)
Equivalent to `zip(findnz(A)...)` but does not allocated
```
"""
function iternz end

widelength(x::AbstractSparseArray) = prod(Int64.(size(x)))
31 changes: 31 additions & 0 deletions src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1700,6 +1700,37 @@ function sparse_sortedlinearindices!(I::Vector{Ti}, V::Vector, m::Int, n::Int) w
return SparseMatrixCSC(m, n, colptr, I, V)
end


abstract type SparseIndexIterate end
@inline getcolptr(x::SparseIndexIterate) = getcolptr(x.m)
@inline getrowval(x::SparseIndexIterate) = getrowval(x.m)
@inline getnzval(x::SparseIndexIterate) = getnzval(x.m)
@inline nonzeroinds(x::SparseIndexIterate) = nonzeroinds(x.m)
@inline nonzeros(x::SparseIndexIterate) = nonzeros(x.m)
@inline nnz(x::SparseIndexIterate) = nnz(x.m)

Base.length(x::SparseIndexIterate) = nnz(x.m)
Base.size(x::SparseIndexIterate) = size(x.m)
Base.size(x::SparseIndexIterate, i) = size(x.m)[i]
struct IterateNZCSC{T<: AbstractSparseMatrixCSC} <: SparseIndexIterate
m::T
end

Base.eltype(::IterateNZCSC{T}) where {Ti, Tv, T <: AbstractSparseMatrixCSC{Tv, Ti}} = Tuple{Ti, Ti, Tv}
Base.iterate(x::IterateNZCSC, state=(1, 0)) = @inbounds let (j, ind) = state
ind += 1
while (j < size(x, 2)) && (ind > getcolptr(x)[j + 1] - 1)
j += 1
end
(j > size(x, 2) || ind > getcolptr(x)[end] - 1) && return nothing

(getrowval(x)[ind], j, getnzval(x)[ind]), (j, ind)
end

iternz(S::AbstractSparseMatrixCSC) = IterateNZCSC(S)



"""
sprand([rng],[type],m,[n],p::AbstractFloat,[rfn])
Expand Down
20 changes: 20 additions & 0 deletions src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,26 @@ function _sparse_findprevnz(v::SparseVector, i::Integer)
end
end



struct IterateSparseVec{T<: AbstractSparseVector} <: SparseIndexIterate
m::T
end

Base.eltype(::IterateSparseVec{T}) where {Ti, Tv, T <: AbstractSparseVector{Tv, Ti}} = Tuple{Ti, Tv}

Base.iterate(x::IterateSparseVec, state=0) = @inbounds begin
state += 1
if state > nnz(x)
nothing
else
(nonzeroinds(x)[state], nonzeros(x)[state]), state
end
end

iternz(S::AbstractSparseVector) = IterateSparseVec(S)


### Generic functions operating on AbstractSparseVector

### getindex
Expand Down
6 changes: 6 additions & 0 deletions test/sparsematrix_constructors_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1578,4 +1578,10 @@ end
@test_throws ArgumentError SparseArrays.expandptr([2; 3])
end

@testset "iteratenz" begin
for i in 1:20
A = sprandn(100, 100, 1 / i)
@test collect(SparseArrays.iternz(A)) == collect(zip(findnz(A)...))
end
end
end
7 changes: 7 additions & 0 deletions test/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,13 @@ end
@test findnz(xc) == ([2, 3, 5], [1.25, 0, -0.75])
end
end

@testset "iteratenz (vector)" begin
for i in 1:10
A = sprandn(100, i / 100)
@test collect(SparseArrays.iternz(A)) == collect(zip(findnz(A)...))
end
end
### Array manipulation

@testset "copy[!]" begin
Expand Down

0 comments on commit fab4d98

Please sign in to comment.