Skip to content

Commit

Permalink
improve chunk partitioning
Browse files Browse the repository at this point in the history
  • Loading branch information
guo-yong-zhi committed Oct 16, 2024
1 parent 41e414a commit 2ee4963
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ function filttrain!(qtrees, inpool, outpool, nearlevel2; optimiser,
nchunks = min(length(queue), max(1, length(inpool)÷4))
Threads.@threads for ichunk in 1:nchunks
que = @inbounds queue[ichunk]
for ind in ichunk : nchunks : length(inpool)
for ind in QTrees.index_chunk(length(inpool), nchunks, ichunk)
i1, i2 = inpool[ind]
cp = QTrees._collision_randbfs(qtrees[i1], qtrees[i2], empty!(que))
if cp[1] >= nearlevel2
Expand Down
10 changes: 8 additions & 2 deletions src/qtree_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ end
# A collision work item: a pair of integer indices `(i1, i2)` mapped to an `Index`.
# NOTE(review): `Index` is declared outside this excerpt — presumably a quadtree node index; confirm.
const CoItem = Pair{Tuple{Int,Int}, Index}
# Any vector whose elements are themselves vectors of `Index`; the chunked loops
# in this file pick one inner vector per chunk (`queue[ichunk]`) as scratch space.
const AbstractThreadQueue = AbstractVector{<:AbstractVector{Index}}
# Build a fresh set of empty scratch queues: `2 * Threads.nthreads()` vectors,
# each holding `(Int, Int, Int)` tuples.
function thread_queue()
    nqueues = 2 * Threads.nthreads()
    return [Tuple{Int,Int,Int}[] for _ in 1:nqueues]
end
"""
    index_chunk(l, n, ichunk)

Return the contiguous index range covered by chunk number `ichunk` when the
indices `1:l` are split into `n` consecutive chunks.

Chunk lengths differ by at most one: the first `l % n` chunks hold
`l ÷ n + 1` indices and the remaining chunks hold `l ÷ n`. A chunk that
receives no indices is returned as an empty range.
"""
function index_chunk(l, n, ichunk) # ChunkSplitters.jl
    base, extra = divrem(l, n)
    # Chunks 1..extra are the "long" ones that absorb the remainder.
    islong = ichunk <= extra
    # Number of indices consumed by all chunks that precede this one.
    offset = (ichunk - 1) * base + (islong ? ichunk - 1 : extra)
    len = base + (islong ? 1 : 0)
    return (offset + 1):(offset + len)
end
# assume inkernelbounds(qtree, at) is true
function _totalcollisions_native(qtrees::AbstractVector, copairs;
colist=Vector{CoItem}(),
Expand All @@ -77,7 +83,7 @@ function _totalcollisions_native(qtrees::AbstractVector, copairs;
nchunks = min(length(queue), max(1, length(copairs)÷4))
Threads.@threads for ichunk in 1:nchunks
que = @inbounds queue[ichunk]
for ind in ichunk : nchunks : length(copairs)
for ind in index_chunk(length(copairs), nchunks, ichunk)
i1, i2 = copairs[ind]
empty!(que)
push!(que, at)
Expand All @@ -95,7 +101,7 @@ function _totalcollisions_native(qtrees::AbstractVector, coitems::Vector{CoItem}
nchunks = min(length(queue), max(1, length(coitems)÷4))
Threads.@threads for ichunk in 1:nchunks
que = @inbounds queue[ichunk]
for ind in ichunk : nchunks : length(coitems)
for ind in index_chunk(length(coitems), nchunks, ichunk)
(i1, i2), at = coitems[ind]
empty!(que)
push!(que, at)
Expand Down

0 comments on commit 2ee4963

Please sign in to comment.