From e38d0ef2f874fdd1eb6064b826cfc70f79a4c4dc Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 4 Oct 2023 00:46:09 +0530 Subject: [PATCH] replace ind2sub/sub2ind by CartesianIndices/LinearIndices --- src/sparsematrix.jl | 11 ++++++++--- src/sparsevector.jl | 13 +++++++++---- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/src/sparsematrix.jl b/src/sparsematrix.jl index e2dfacce..be671e5b 100644 --- a/src/sparsematrix.jl +++ b/src/sparsematrix.jl @@ -3082,12 +3082,15 @@ function getindex(A::AbstractSparseMatrixCSC{Tv,Ti}, I::AbstractArray) where {Tv colptrB[colB] = 1 idxB = 1 + CartIndsA = CartesianIndices(szA) + CartIndsB = CartesianIndices(szB) + for i in 1:n @boundscheck checkbounds(A, I[i]) - row,col = Base._ind2sub(szA, I[i]) + row,col = Tuple(CartIndsA[I[i]]) for r in colptrA[col]:(colptrA[col+1]-1) @inbounds if rowvalA[r] == row - rowB,colB = Base._ind2sub(szB, i) + rowB,colB = Tuple(CartIndsB[i]) colptrB[colB+1] += 1 rowvalB[idxB] = rowB nzvalB[idxB] = nzvalA[r] @@ -3591,13 +3594,15 @@ function setindex!(A::AbstractSparseMatrixCSC, x::AbstractArray, Ix::AbstractVec isa(x, AbstractArray) && setindex_shape_check(x, length(I)) + CartIndsA = CartesianIndices(szA) + lastcol = 0 (nrowA, ncolA) = szA @inbounds for xidx in 1:n sxidx = S[xidx] (sxidx < n) && (I[sxidx] == I[sxidx+1]) && continue - row,col = Base._ind2sub(szA, I[sxidx]) + row,col = Tuple(CartIndsA[I[sxidx]]) v = x[sxidx] if col > lastcol diff --git a/src/sparsevector.jl b/src/sparsevector.jl index fc555e69..73bd525b 100644 --- a/src/sparsevector.jl +++ b/src/sparsevector.jl @@ -779,9 +779,12 @@ function getindex(A::AbstractSparseMatrixCSC{Tv}, I::AbstractUnitRange) where Tv rowvalB = Vector{Int}(undef, nnzB) nzvalB = Vector{Tv}(undef, nnzB) + CartIndsA = CartesianIndices(szA) + LinIndsA = LinearIndices(szA) + if nnzB > 0 - rowstart,colstart = Base._ind2sub(szA, first(I)) - rowend,colend = Base._ind2sub(szA, last(I)) + rowstart,colstart = Tuple(CartIndsA[first(I)]) + rowend,colend = Tuple(CartIndsA[last(I)]) idxB = 1 @inbounds for col in colstart:colend @@ -790,7 +793,7 @@ function getindex(A::AbstractSparseMatrixCSC{Tv}, I::AbstractUnitRange) where Tv for r in colptrA[col]:(colptrA[col+1]-1) rowA = rowvalA[r] if minrow <= rowA <= maxrow - rowvalB[idxB] = Base._sub2ind(szA, rowA, col) - first(I) + 1 + rowvalB[idxB] = LinIndsA[rowA, col] - first(I) + 1 nzvalB[idxB] = nzvalA[r] idxB += 1 end @@ -818,9 +821,11 @@ function getindex(A::AbstractSparseMatrixCSC{Tv,Ti}, I::AbstractVector) where {T rowvalB = Vector{Ti}(undef, nnzB) nzvalB = Vector{Tv}(undef, nnzB) + CartIndsA = CartesianIndices(szA) + idxB = 1 for i in 1:n - row,col = Base._ind2sub(szA, I[i]) + row,col = Tuple(CartIndsA[I[i]]) for r in colptrA[col]:(colptrA[col+1]-1) @inbounds if rowvalA[r] == row if idxB <= nnzB