diff --git a/src/abstractsparse.jl b/src/abstractsparse.jl index 5eb81f7b..40360b12 100644 --- a/src/abstractsparse.jl +++ b/src/abstractsparse.jl @@ -124,6 +124,14 @@ julia> findnz(A) """ function findnz end +""" + iternz(A::AbstractSparseArray) + +Equivalent to `zip(findnz(A)...)` but does not allocated +``` +""" +function iternz end + widelength(x::AbstractSparseArray) = prod(Int64.(size(x))) diff --git a/src/sparsematrix.jl b/src/sparsematrix.jl index cc42959e..d2acbac5 100644 --- a/src/sparsematrix.jl +++ b/src/sparsematrix.jl @@ -1772,6 +1772,37 @@ function sparse_sortedlinearindices!(I::Vector{Ti}, V::Vector, m::Int, n::Int) w return SparseMatrixCSC(m, n, colptr, I, V) end + +abstract type SparseIndexIterate end +@inline getcolptr(x::SparseIndexIterate) = getcolptr(x.m) +@inline getrowval(x::SparseIndexIterate) = getrowval(x.m) +@inline getnzval(x::SparseIndexIterate) = getnzval(x.m) +@inline nonzeroinds(x::SparseIndexIterate) = nonzeroinds(x.m) +@inline nonzeros(x::SparseIndexIterate) = nonzeros(x.m) +@inline nnz(x::SparseIndexIterate) = nnz(x.m) + +Base.length(x::SparseIndexIterate) = nnz(x.m) +Base.size(x::SparseIndexIterate) = size(x.m) +Base.size(x::SparseIndexIterate, i) = size(x.m)[i] +struct IterateNZCSC{T<: AbstractSparseMatrixCSC} <: SparseIndexIterate + m::T +end + +Base.eltype(::IterateNZCSC{T}) where {Ti, Tv, T <: AbstractSparseMatrixCSC{Tv, Ti}} = Tuple{Ti, Ti, Tv} +Base.iterate(x::IterateNZCSC, state=(1, 0)) = @inbounds let (j, ind) = state + ind += 1 + while (j < size(x, 2)) && (ind > getcolptr(x)[j + 1] - 1) + j += 1 + end + (j > size(x, 2) || ind > getcolptr(x)[end] - 1) && return nothing + + (getrowval(x)[ind], j, getnzval(x)[ind]), (j, ind) +end + +iternz(S::AbstractSparseMatrixCSC) = IterateNZCSC(S) + + + """ sprand([rng],[type],m,[n],p::AbstractFloat,[rfn]) diff --git a/src/sparsevector.jl b/src/sparsevector.jl index efb1b4e6..d8424ae2 100644 --- a/src/sparsevector.jl +++ b/src/sparsevector.jl @@ -835,6 +835,26 @@ function _sparse_findprevnz(v::SparseVector, i::Integer) end end + + +struct IterateSparseVec{T<: AbstractSparseVector} <: SparseIndexIterate + m::T +end + +Base.eltype(::IterateSparseVec{T}) where {Ti, Tv, T <: AbstractSparseVector{Tv, Ti}} = Tuple{Ti, Tv} + +Base.iterate(x::IterateSparseVec, state=0) = @inbounds begin + state += 1 + if state > nnz(x) + nothing + else + (nonzeroinds(x)[state], nonzeros(x)[state]), state + end +end + +iternz(S::AbstractSparseVector) = IterateSparseVec(S) + + ### Generic functions operating on AbstractSparseVector ### getindex diff --git a/test/sparsematrix_constructors_indexing.jl b/test/sparsematrix_constructors_indexing.jl index f6875fde..1ed01f74 100644 --- a/test/sparsematrix_constructors_indexing.jl +++ b/test/sparsematrix_constructors_indexing.jl @@ -1578,4 +1578,10 @@ end @test_throws ArgumentError SparseArrays.expandptr([2; 3]) end +@testset "iteratenz" begin + for i in 1:20 + A = sprandn(100, 100, 1 / i) + @test collect(SparseArrays.iternz(A)) == collect(zip(findnz(A)...)) + end +end end \ No newline at end of file diff --git a/test/sparsevector.jl b/test/sparsevector.jl index 790d4aa3..63f14e71 100644 --- a/test/sparsevector.jl +++ b/test/sparsevector.jl @@ -345,6 +345,13 @@ end @test findnz(@view Xc[:,2]) == ([2], [1.25]) end end + +@testset "iteratenz (vector)" begin + for i in 1:10 + A = sprandn(100, i / 100) + @test collect(SparseArrays.iternz(A)) == collect(zip(findnz(A)...)) + end +end ### Array manipulation @testset "copy[!]" begin