From 1b1259e08f6cf00b0ed47429184dc3cbaaf6e75c Mon Sep 17 00:00:00 2001 From: janko Date: Wed, 5 Jun 2024 13:51:49 +0200 Subject: [PATCH] Further accelerate fast stack method --- src/preprocessing.jl | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/preprocessing.jl b/src/preprocessing.jl index 9eb8d12..5cb9c7b 100644 --- a/src/preprocessing.jl +++ b/src/preprocessing.jl @@ -1,22 +1,25 @@ function _fast_stack_sparse(vecs::Vector{SparseVector{T1, T2}}) where {T1 <: Real, T2 <: Integer} """Fast method for stacking sparse columns""" - n_rows = length(vecs[1]) - @assert all(length(x) == n_rows for x in vecs) + n_rows_total = length(vecs[1]) + @assert all(length(x) == n_rows_total for x in vecs) + nnz_total = sum(nnz(x) for x in vecs) + rids, cids, nzvals = zeros(T2, nnz_total), zeros(T2, nnz_total), zeros(T1, nnz_total) - rids, cids, nzvals = Int[], Int[], T1[] - - for (col_i, v) in enumerate(vecs) + nnz_i = 1 + @inbounds for (col_i, v) in enumerate(vecs) n_val = nnz(v) if n_val > 0 - append!(rids, rowvals(v)) - append!(cids, repeat([col_i], n_val)) - append!(nzvals, nonzeros(v)) + nnz_range = nnz_i:nnz_i+n_val-1 + rids[nnz_range] .= rowvals(v) + cids[nnz_range] .= col_i + nzvals[nnz_range] .= nonzeros(v) + nnz_i += n_val end end - + n_cols = length(vecs) - return sparse(rids, cids, nzvals, n_rows, n_cols) + return sparse(rids, cids, nzvals, n_rows_total, n_cols) end function stack_or_hcat(vecs::AbstractVector{<:AbstractArray})