Skip to content

Commit

Permalink
Further accelerate fast stack method
Browse files Browse the repository at this point in the history
  • Loading branch information
jtackm committed Jun 5, 2024
1 parent 759e2ee commit 1b1259e
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions src/preprocessing.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
function _fast_stack_sparse(vecs::Vector{SparseVector{T1, T2}}) where {T1 <: Real, T2 <: Integer}
"""Fast method for stacking sparse columns"""
n_rows = length(vecs[1])
@assert all(length(x) == n_rows for x in vecs)
n_rows_total = length(vecs[1])
@assert all(length(x) == n_rows_total for x in vecs)
nnz_total = sum(nnz(x) for x in vecs)
rids, cids, nzvals = zeros(T2, nnz_total), zeros(T2, nnz_total), zeros(T1, nnz_total)

rids, cids, nzvals = Int[], Int[], T1[]

for (col_i, v) in enumerate(vecs)
nnz_i = 1
@inbounds for (col_i, v) in enumerate(vecs)
n_val = nnz(v)

if n_val > 0
append!(rids, rowvals(v))
append!(cids, repeat([col_i], n_val))
append!(nzvals, nonzeros(v))
nnz_range = nnz_i:nnz_i+n_val-1
rids[nnz_range] .= rowvals(v)
cids[nnz_range] .= col_i
nzvals[nnz_range] .= nonzeros(v)
nnz_i += n_val
end
end

n_cols = length(vecs)
return sparse(rids, cids, nzvals, n_rows, n_cols)
return sparse(rids, cids, nzvals, n_rows_total, n_cols)
end

function stack_or_hcat(vecs::AbstractVector{<:AbstractArray})
Expand Down

0 comments on commit 1b1259e

Please sign in to comment.