Skip to content

Commit

Permalink
refactor: simplify and optimize dynamic collision handling
Browse files Browse the repository at this point in the history
  • Loading branch information
guo-yong-zhi committed Nov 7, 2024
1 parent 7136051 commit 7a2dd48
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 72 deletions.
44 changes: 11 additions & 33 deletions examples/dynamiccollisions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,45 +2,23 @@ using Stuffing
qts = Stuffing.randqtrees(300);
println("visualization:")
println(repr("text/plain", overlap(qts)))

println("""To get all collisions via `partialcollisions`, the set `updated` should contains all collided labels in last round.
And the labels of other moved objects, if any, should be contained too.
`partialcollisions` is faster than `totalcollisions` when the set `updated` is small.
""")
updated = Set(1:length(qts));
spqtree = linked_spacial_qtree(qts);
for i in 1:10
C1 = partialcollisions(qts, spqtree, updated);
union!(updated, first.(C1) |> Iterators.flatten) #all collided labels
begin #test
C2 = QTrees.totalcollisions_native(qts);
@assert length(C1) == length(C2)
@assert Set(Set.(first.(C1))) == Set(Set.(first.(C2)))
end
Stuffing.Trainer.batchsteps!(qts, C1) #move objects in C1
begin #other moved labels
QTrees.shift!(qts[3], 1, 1, 1)
QTrees.shift!(qts[7], 1, -1, -1)
union!(updated, [3, 7])
end
function test_collision_results(C1, qts)
C2 = QTrees.totalcollisions_native(qts);
C3 = QTrees.totalcollisions(qts);
# sometimes C2!=C3. When objects are out of the scope, the `totalcollisions_native` will miss them.
# But `totalcollisions` may not (not promise).
@assert length(C1) == length(C2) || length(C1) == length(C3)
@assert Set(Set.(first.(C1))) == Set(Set.(first.(C2))) || Set(Set.(first.(C1))) == Set(Set.(first.(C3)))
end
println("Things are similar but much simpler for `dynamiccollisions`.")
println("And `dynamiccollisions` is faster than `partialcollisions` when the `updated` is not that small.")

colliders = DynamicColliders(qts);
for i in 1:10
C1 = dynamiccollisions(colliders);
begin #test
C2 = QTrees.totalcollisions_native(qts);
C3 = QTrees.totalcollisions(qts);
#sometimes C2!=C3. When objects are out of the scope, the `totalcollisions_native` will miss them.
#But `totalcollisions` may not (not promise).
@assert length(C1) == length(C2) || length(C1) == length(C3)
@assert Set(Set.(first.(C1))) == Set(Set.(first.(C2))) || Set(Set.(first.(C1))) == Set(Set.(first.(C3)))
end
test_collision_results(C1, qts)
Stuffing.Trainer.batchsteps!(qts, C1) #move objects in C1
begin #other moved labels
QTrees.shift!(qts[3], 1, -1, -1)
QTrees.shift!(qts[7], 1, 1, 1)
union!(colliders.updated, [3, 7]) #other updated labels
union!(colliders, [3, 7]) #other updated labels
end
end
end
12 changes: 5 additions & 7 deletions src/fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,20 +216,18 @@ trainepoch_D!(;inputs) = Dict(:colist => Vector{QTrees.CoItem}(),
:queue => QTrees.thread_queue(),
:itemlist => Vector{QTrees.CoItem}(),
:pairlist => Vector{Tuple{Int, Int}}(),
:updated => QTrees.UpdatedSet(1:length(inputs)),
:updated => DynamicColliders(inputs),
:loops => 10,
:spqtree => QTrees.linked_spacial_qtree(inputs), #fllowing 3 tiems: pre-allocating for dynamiccollisions
:sptqree2 => QTrees.hash_spacial_qtree(inputs),
:sptqree2 => QTrees.hash_spacial_qtree(inputs), #fllowing 2 tiems: pre-allocating for dynamiccollisions
:treenodestack => Vector{QTrees.SpacialQTreeNode}())
trainepoch_D!(s::Symbol) = get(Dict(:patience => 10, :epochs => 2000), s, nothing)
function trainepoch_D!(qtrees::AbstractVector{<:ShiftedQTree}; spqtree, optimiser=(t, Δ) -> Δ ./ 6,
colist=Vector{QTrees.CoItem}(), updated=Set{Int}(), loops=1, kargs...)
function trainepoch_D!(qtrees::AbstractVector{<:ShiftedQTree}; optimiser=(t, Δ) -> Δ ./ 6,
colist=Vector{QTrees.CoItem}(), updated=DynamicColliders(qtrees), loops=1, kargs...)
for ni in 1:loops
dynamiccollisions(qtrees, spqtree, updated; colist=empty!(colist), kargs...)
dynamiccollisions(updated; colist=empty!(colist), kargs...)
nc = length(colist)
if nc == 0 return nc end
batchsteps!(qtrees, colist, optimiser)
union!(updated, first.(colist) |> Iterators.flatten)
end
length(colist)
end
Expand Down
65 changes: 37 additions & 28 deletions src/qtree_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,13 @@ function locate!(qt::AbstractStackedQTree, spqtree::Union{HashSpacialQTree, Link
push!(spqtree, inds, label)
nothing
end
function locate!(qts::AbstractVector, spqtree=hash_spacial_qtree(qts))
function locate!(qts::AbstractVector, spqtree::Union{HashSpacialQTree, LinkedSpacialQTree}=hash_spacial_qtree(qts))
for (i, qt) in enumerate(qts)
locate!(qt, spqtree, i)
end
spqtree
end
function locate!(qts::AbstractVector, labels::Union{AbstractVector{Int},AbstractSet{Int}}, spqtree=hash_spacial_qtree(qts))
function locate!(qts::AbstractVector, labels, spqtree=hash_spacial_qtree(qts))
for l in labels
locate!(qts[l], spqtree, l)
end
Expand Down Expand Up @@ -313,51 +313,60 @@ function partialcollisions(qtrees::AbstractVector,
end
end
end
empty!(labels)
# @show length(itemlist), length(colist)
r = _totalcollisions_native(qtrees, itemlist; colist=colist, kargs...)
unique ? unique!(first, sort!(r)) : r
end
mutable struct UpdatedSet{T} <: AbstractSet{T}
updatelen::Int
maxlen::Int
set::Set{T}
end
UpdatedSet(maxlen::Int) = UpdatedSet(maxlen, maxlen, Set{Int}(1:maxlen))
UpdatedSet(labels, maxlen::Int=length(labels)) = UpdatedSet(length(labels), maxlen, Set{Int}(labels))
function Base.union!(s::UpdatedSet, c)
s.updatelen = length(c)
length(s.set) == s.maxlen ? s : union!(s.set, c)
end
Base.empty!(s::UpdatedSet) = empty!(s.set)
Base.length(s::UpdatedSet) = length(s.set)
Base.iterate(s::UpdatedSet, args...) = iterate(s.set, args...)
Base.in(item, s::UpdatedSet) = in(item, s.set)
Base.in(s::UpdatedSet) = in(s.set)
# mutable struct UpdatedSet{T} <: AbstractSet{T}
# updatelen::Int
# maxlen::Int
# set::Set{T}
# end
# UpdatedSet(maxlen::Int) = UpdatedSet(maxlen, maxlen, Set{Int}(1:maxlen))
# UpdatedSet(labels, maxlen::Int=length(labels)) = UpdatedSet(length(labels), maxlen, Set{Int}(labels))
# function Base.union!(s::UpdatedSet, c)
# s.updatelen = length(c)
# length(s.set) == s.maxlen ? s : union!(s.set, c)
# end
# Base.empty!(s::UpdatedSet) = empty!(s.set)
# Base.length(s::UpdatedSet) = length(s.set)
# Base.iterate(s::UpdatedSet, args...) = iterate(s.set, args...)
# Base.in(item, s::UpdatedSet) = in(item, s.set)
# Base.in(s::UpdatedSet) = in(s.set)
function totalcollisions_kw(qtrees; sptqree2=hash_spacial_qtree(qtrees),
spqtree=nothing, treenodestack=nothing, kargs...)
totalcollisions(qtrees; spqtree=sptqree2, kargs...)
end
partialcollisions_kw(qtrees, spqtree, updated; sptqree2=nothing, pairlist=nothing, kargs...) = partialcollisions(qtrees, spqtree, updated; kargs...)
partialcollisions_kw(qtrees, spqtree, labels; sptqree2=nothing, pairlist=nothing, kargs...) = partialcollisions(qtrees, spqtree, labels; kargs...)
function dynamiccollisions(qtrees::AbstractVector,
spqtree::LinkedSpacialQTree=linked_spacial_qtree(qtrees),
updated::UpdatedSet{Int}=UpdatedSet(1:length(qtrees));
labels::AbstractSet{Int}=Set(1:length(qtrees));
updated::AbstractSet{Int}=Set(1:length(qtrees)),
kargs...)
if updated.updatelen / updated.maxlen > 0.6
return totalcollisions_kw(qtrees; kargs...)
if length(labels) / length(qtrees) > 0.6
r = totalcollisions_kw(qtrees; kargs...)
else
return partialcollisions_kw(qtrees, spqtree, updated; kargs...)
locate!(qtrees, (i for i in updated if i labels), spqtree)
empty!(updated)
r = partialcollisions_kw(qtrees, spqtree, labels; kargs...)
end
empty!(labels)
union!(labels, first.(r) |> Iterators.flatten)
r
end
struct DynamicColliders
qtrees::Vector{U8SQTree}
spqtree::LinkedSpacialQTree
updated::UpdatedSet
colabels::Set{Int}
updated::Set{Int}
end
function Base.union!(dc::DynamicColliders, c)
union!(dc.colabels, c)
union!(dc.updated, c)
end
DynamicColliders(qtrees::AbstractVector{U8SQTree}) = DynamicColliders(qtrees, linked_spacial_qtree(qtrees), UpdatedSet(length(qtrees)))
DynamicColliders(qtrees::AbstractVector{U8SQTree}) = DynamicColliders(qtrees, linked_spacial_qtree(qtrees), Set(1:length(qtrees)), Set(1:length(qtrees)))
function dynamiccollisions(colliders::DynamicColliders; kargs...)
r = dynamiccollisions(colliders.qtrees, colliders.spqtree, colliders.updated; kargs...)
union!(colliders.updated, first.(r) |> Iterators.flatten)
r = dynamiccollisions(colliders.qtrees, colliders.spqtree, colliders.colabels; updated=colliders.updated, kargs...)
r
end
########## place!
Expand Down
9 changes: 5 additions & 4 deletions test/test_qtrees.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,11 @@ testqtree = Stuffing.testqtree
QTrees.shift!(qts[7], 1, -1, -1)
union!(updated, [3, 7]) #other updated labels
end
updated = QTrees.UpdatedSet(length(qts));
colabels = Set(1:length(qts));
updated = Set(1:length(qts));
spqtree = linked_spacial_qtree(qts);
for i in 1:10
C1 = dynamiccollisions(qts, spqtree, updated);
union!(updated, first.(C1) |> Iterators.flatten) #all collided labels
C1 = dynamiccollisions(qts, spqtree, colabels; updated=updated);
C2 = QTrees.totalcollisions_native(qts);
C3 = QTrees.totalcollisions(qts);
@test length(C1) == length(C2) || length(C1) == length(C3)
Expand All @@ -94,6 +94,7 @@ testqtree = Stuffing.testqtree
QTrees.shift!(qts[3], 1, -1, -1)
QTrees.shift!(qts[7], 1, 1, 1)
union!(updated, [3, 7]) #other updated labels
union!(colabels, [3, 7]) #other updated labels
end
colliders = DynamicColliders(qts);
for i in 1:10
Expand All @@ -108,7 +109,7 @@ testqtree = Stuffing.testqtree
begin #other moved labels
QTrees.shift!(qts[3], 1, -1, -1)
QTrees.shift!(qts[7], 1, 1, 1)
union!(colliders.updated, [3, 7]) #other updated labels
union!(colliders, [3, 7]) #other updated labels
end
end
#edge cases
Expand Down

0 comments on commit 7a2dd48

Please sign in to comment.