Skip to content

Commit

Permalink
rename variables
Browse files Browse the repository at this point in the history
  • Loading branch information
guo-yong-zhi committed Nov 7, 2024
1 parent d6f34a0 commit 17454d7
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 62 deletions.
4 changes: 2 additions & 2 deletions examples/dynamiccollisions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ for i in 1:10
C1 = dynamiccollisions(colliders);
test_collision_results(C1, qts)
Stuffing.Trainer.batchsteps!(qts, C1) #move objects in C1
begin #other moved labels
begin #other moved objects
QTrees.shift!(qts[3], 1, -1, -1)
QTrees.shift!(qts[7], 1, 1, 1)
union!(colliders, [3, 7]) #other updated labels
union!(colliders, [3, 7]) #other moved objects
end
end
50 changes: 25 additions & 25 deletions src/fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,16 @@ trainepoch_E!(;inputs) = Dict(:colist => Vector{QTrees.CoItem}(),
:queue => QTrees.thread_queue(),
:itemlist => Vector{QTrees.CoItem}(),
:pairlist => Vector{Tuple{Int, Int}}(),
:updated => Set{Int}(),
:moved => Set{Int}(),
:spqtree => QTrees.hash_spacial_qtree(inputs))
trainepoch_E!(s::Symbol) = get(Dict(:patience => 20, :epochs => 2000), s, nothing)
function trainepoch_E!(qtrees::AbstractVector{<:ShiftedQTree}; optimiser=(t, Δ) -> Δ ./ 6,
colist=Vector{QTrees.CoItem}(), updated=Set{Int}(), kargs...)
colist=Vector{QTrees.CoItem}(), moved=Set{Int}(), kargs...)
totalcollisions(qtrees; colist=empty!(colist), kargs...)
nc = length(colist)
if nc == 0 return nc end
batchsteps!(qtrees, colist, optimiser)
inds = union!(empty!(updated), first.(colist) |> Iterators.flatten)
inds = union!(empty!(moved), first.(colist) |> Iterators.flatten)
# @show length(qtrees),length(inds)
for ni in 1:length(qtrees) ÷ length(inds)
totalcollisions(qtrees, inds; colist=empty!(colist), kargs...)
Expand All @@ -104,24 +104,24 @@ trainepoch_EM!(;inputs) = Dict(:colist => Vector{QTrees.CoItem}(),
:memory => intlru(length(inputs)),
:itemlist => Vector{QTrees.CoItem}(),
:pairlist => Vector{Tuple{Int, Int}}(),
:updated => Set{Int}(),
:moved => Set{Int}(),
:spqtree => QTrees.hash_spacial_qtree(inputs))
trainepoch_EM!(s::Symbol) = get(Dict(:patience => 10, :epochs => 1000), s, nothing)
function trainepoch_EM!(qtrees::AbstractVector{<:ShiftedQTree}; memory, optimiser=(t, Δ) -> Δ ./ 6,
colist=Vector{QTrees.CoItem}(), updated=Set{Int}(), kargs...)
colist=Vector{QTrees.CoItem}(), moved=Set{Int}(), kargs...)
totalcollisions(qtrees; colist=empty!(colist), kargs...)
nc = length(colist)
if nc == 0 return nc end
batchsteps!(qtrees, colist, optimiser)
inds = union!(empty!(updated), first.(colist) |> Iterators.flatten)
inds = union!(empty!(moved), first.(colist) |> Iterators.flatten)
# @show length(inds)
push!.(memory, inds)
inds = collect(memory, length(inds) * 2)
for ni in 1:2length(qtrees) ÷ length(inds)
totalcollisions(qtrees, inds; colist=empty!(colist), kargs...)
batchsteps!(qtrees, colist, optimiser)
if ni > 2length(colist) break end
inds2 = union!(empty!(updated), first.(colist) |> Iterators.flatten)
inds2 = union!(empty!(moved), first.(colist) |> Iterators.flatten)
# @show length(qtrees),length(inds),length(inds2)
for ni2 in 1:2length(inds) ÷ length(inds2)
totalcollisions(qtrees, inds2; colist=empty!(colist), kargs...)
Expand All @@ -135,27 +135,27 @@ end
trainepoch_EM2!(;inputs) = trainepoch_EM!(;inputs=inputs)
trainepoch_EM2!(s::Symbol) = trainepoch_EM!(s)
function trainepoch_EM2!(qtrees::AbstractVector{<:ShiftedQTree}; memory, optimiser=(t, Δ) -> Δ ./ 6,
colist=Vector{QTrees.CoItem}(), updated=Set{Int}(), kargs...)
colist=Vector{QTrees.CoItem}(), moved=Set{Int}(), kargs...)
totalcollisions(qtrees; colist=empty!(colist), kargs...)
nc = length(colist)
if nc == 0 return nc end
batchsteps!(qtrees, colist, optimiser)
inds = union!(empty!(updated), first.(colist) |> Iterators.flatten)
inds = union!(empty!(moved), first.(colist) |> Iterators.flatten)
# @show length(inds)
push!.(memory, inds)
inds = collect(memory, length(inds) * 4)
for ni in 1:2length(qtrees) ÷ length(inds)
totalcollisions(qtrees, inds; colist=empty!(colist), kargs...)
batchsteps!(qtrees, colist, optimiser)
if ni > 2length(colist) break end
inds2 = union!(empty!(updated), first.(colist) |> Iterators.flatten)
inds2 = union!(empty!(moved), first.(colist) |> Iterators.flatten)
push!.(memory, inds2)
inds2 = collect(memory, length(inds2) * 2)
for ni2 in 1:2length(inds) ÷ length(inds2)
totalcollisions(qtrees, inds2; colist=empty!(colist), kargs...)
batchsteps!(qtrees, colist, optimiser)
if ni2 > 2length(colist) break end
inds3 = union!(empty!(updated), first.(colist) |> Iterators.flatten)
inds3 = union!(empty!(moved), first.(colist) |> Iterators.flatten)
# @show length(qtrees),length(inds),length(inds2),length(inds3)
for ni3 in 1:2length(inds2) ÷ length(inds3)
totalcollisions(qtrees, inds3; colist=empty!(colist), kargs...)
Expand All @@ -171,34 +171,34 @@ end
trainepoch_EM3!(;inputs) = trainepoch_EM!(;inputs=inputs)
trainepoch_EM3!(s::Symbol) = trainepoch_EM!(s)
function trainepoch_EM3!(qtrees::AbstractVector{<:ShiftedQTree}; memory, optimiser=(t, Δ) -> Δ ./ 6,
colist=Vector{QTrees.CoItem}(), updated=Set{Int}(), kargs...)
colist=Vector{QTrees.CoItem}(), moved=Set{Int}(), kargs...)
totalcollisions(qtrees; colist=empty!(colist), kargs...)
nc = length(colist)
if nc == 0 return nc end
batchsteps!(qtrees, colist, optimiser)
inds = union!(empty!(updated), first.(colist) |> Iterators.flatten)
inds = union!(empty!(moved), first.(colist) |> Iterators.flatten)
# @show length(inds)
push!.(memory, inds)
inds = collect(memory, length(inds) * 8)
for ni in 1:2length(qtrees) ÷ length(inds)
totalcollisions(qtrees, inds; colist=empty!(colist), kargs...)
batchsteps!(qtrees, colist, optimiser)
if ni > 2length(colist) break end
inds2 = union!(empty!(updated), first.(colist) |> Iterators.flatten)
inds2 = union!(empty!(moved), first.(colist) |> Iterators.flatten)
push!.(memory, inds2)
inds2 = collect(memory, length(inds2) * 4)
for ni2 in 1:2length(inds) ÷ length(inds2)
totalcollisions(qtrees, inds2; colist=empty!(colist), kargs...)
batchsteps!(qtrees, colist, optimiser)
if ni2 > 2length(colist) break end
inds3 = union!(empty!(updated), first.(colist) |> Iterators.flatten)
inds3 = union!(empty!(moved), first.(colist) |> Iterators.flatten)
push!.(memory, inds3)
inds3 = collect(memory, length(inds3) * 2)
for ni3 in 1:2length(inds2) ÷ length(inds3)
totalcollisions(qtrees, inds3; colist=empty!(colist), kargs...)
batchsteps!(qtrees, colist, optimiser)
if ni3 > 2length(colist) break end
inds4 = union!(empty!(updated), first.(colist) |> Iterators.flatten)
inds4 = union!(empty!(moved), first.(colist) |> Iterators.flatten)
# @show length(qtrees),length(inds),length(inds2),length(inds3)
for ni4 in 1:2length(inds3) ÷ length(inds4)
totalcollisions(qtrees, inds4; colist=empty!(colist), kargs...)
Expand All @@ -216,15 +216,15 @@ trainepoch_D!(;inputs) = Dict(:colist => Vector{QTrees.CoItem}(),
:queue => QTrees.thread_queue(),
:itemlist => Vector{QTrees.CoItem}(),
:pairlist => Vector{Tuple{Int, Int}}(),
:updated => DynamicColliders(inputs),
:moved => DynamicColliders(inputs),
:loops => 10,
:sptqree2 => QTrees.hash_spacial_qtree(inputs), #fllowing 2 tiems: pre-allocating for dynamiccollisions
:spqtree => QTrees.hash_spacial_qtree(inputs), #fllowing 2 tiems: pre-allocating for dynamiccollisions
:treenodestack => Vector{QTrees.SpacialQTreeNode}())
trainepoch_D!(s::Symbol) = get(Dict(:patience => 10, :epochs => 2000), s, nothing)
function trainepoch_D!(qtrees::AbstractVector{<:ShiftedQTree}; optimiser=(t, Δ) -> Δ ./ 6,
colist=Vector{QTrees.CoItem}(), updated=DynamicColliders(qtrees), loops=1, kargs...)
colist=Vector{QTrees.CoItem}(), moved=DynamicColliders(qtrees), loops=1, kargs...)
for ni in 1:loops
dynamiccollisions(updated; colist=empty!(colist), kargs...)
dynamiccollisions(moved; colist=empty!(colist), kargs...)
nc = length(colist)
if nc == 0 return nc end
batchsteps!(qtrees, colist, optimiser)
Expand Down Expand Up @@ -482,7 +482,7 @@ function train!(ts, epochs::Number=-1, args...;
end
epochs >= 0 || (epochs = trainer(:epochs))
@debug "epochs: $epochs, " * (reposition_flag ? "patience: $patience" : "reposition off")
updated = get(resource, :updated, nothing)
moved = get(resource, :moved, nothing)
while ep < epochs
callback_pre(ep)
nc = trainer(ts, args...; resource..., optimiser=optimiser, unique=false, kargs...)
Expand All @@ -498,7 +498,7 @@ function train!(ts, epochs::Number=-1, args...;
if length(repositioned) > 0
reset!(indi_r)
reset!.(optimiser, ts[repositioned])
updated !== nothing && union!(updated, repositioned)
moved !== nothing && union!(moved, repositioned) # repositioned may be not in colist
end
repositioned_set = Set(repositioned)
if last_repositioned == repositioned_set
Expand All @@ -512,8 +512,8 @@ function train!(ts, epochs::Number=-1, args...;
@debug "The repositioning strategy failed after $ep epochs"
break
end
moved = randommove!(ts, colist)
@debug "@epoch $ep, random move $(length(moved)>0 ? moved : "nothing")"
randommoved = randommove!(ts, colist) # must be in colist
@debug "@epoch $ep, random move $(length(randommoved)>0 ? randommoved : "nothing")"
end
end
callback(ep)
Expand All @@ -525,7 +525,7 @@ function train!(ts, epochs::Number=-1, args...;
else
@debug "$outlabels out of bounds"
nc += outlen
updated !== nothing && union!(updated, outlabels)
moved !== nothing && union!(moved, outlabels)
end
end
if indi_s.age > max(1, patience, epochs / 50)
Expand Down
49 changes: 26 additions & 23 deletions src/qtree_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -265,16 +265,16 @@ function totalcollisions(args...; kargs...)
end
end
function partialcollisions(qtrees::AbstractVector,
spqtree::LinkedSpacialQTree=linked_spacial_qtree(qtrees),
linkedspqtree::LinkedSpacialQTree=linked_spacial_qtree(qtrees),
labels::AbstractSet{Int}=Set(1:length(qtrees));
colist=Vector{CoItem}(), itemlist::AbstractVector{CoItem}=Vector{CoItem}(),
treenodestack = Vector{SpacialQTreeNode}(),
unique=true, kargs...)
empty!(itemlist)
locate!(qtrees, labels, spqtree) #需要将labels中的label移动到链表首
locate!(qtrees, labels, linkedspqtree) #需要将labels中的label移动到链表首
for label in labels
# @show label
for listnode in spacial_indexesof(spqtree, label)
for listnode in spacial_indexesof(linkedspqtree, label)
# 更prev的node都是move过的,在其向后遍历时会加入与当前node的pair,故不需要向前遍历
# 但要保证更prev的node在`labels`中
treenode = seek_treenode(listnode)
Expand All @@ -284,7 +284,7 @@ function partialcollisions(qtrees::AbstractVector,
while !isroot(tn)
tn = tn.parent #root不是哨兵,值需要遍历
if !isemptylabels(tn)
plbs = Iterators.filter(!in(labels), labelsof(tn)) #moved了的plb不加入,等候其向下遍历时加,避免重复
plbs = Iterators.filter(!in(labels), labelsof(tn)) #move了的plb不加入,等候其向下遍历时加,避免重复
collisions_boundsfilter(qtrees, spindex, label, plbs, itemlist, colist)
end
end
Expand All @@ -309,7 +309,7 @@ function partialcollisions(qtrees::AbstractVector,
# @show itemlist
end
end
emptyflag && remove_tree_node(spqtree, tn)
emptyflag && remove_tree_node(linkedspqtree, tn)
end
end
end
Expand All @@ -333,40 +333,43 @@ end
# Base.iterate(s::UpdatedSet, args...) = iterate(s.set, args...)
# Base.in(item, s::UpdatedSet) = in(item, s.set)
# Base.in(s::UpdatedSet) = in(s.set)
function totalcollisions_kw(qtrees; sptqree2=hash_spacial_qtree(qtrees),
spqtree=nothing, treenodestack=nothing, kargs...)
totalcollisions(qtrees; spqtree=sptqree2, kargs...)
function totalcollisions_kw(args...;
linkedspqtree=nothing, treenodestack=nothing, kargs...)
totalcollisions(args...; kargs...)
end
function partialcollisions_kw(args...;
spqtree=nothing, pairlist=nothing, kargs...)
partialcollisions(args...; kargs...)
end
partialcollisions_kw(qtrees, spqtree, labels; sptqree2=nothing, pairlist=nothing, kargs...) = partialcollisions(qtrees, spqtree, labels; kargs...)
function dynamiccollisions(qtrees::AbstractVector,
spqtree::LinkedSpacialQTree=linked_spacial_qtree(qtrees),
labels::AbstractSet{Int}=Set(1:length(qtrees));
updated::AbstractSet{Int},
linkedspqtree::LinkedSpacialQTree=linked_spacial_qtree(qtrees),
moved::AbstractSet{Int}=Set(1:length(qtrees));
unlocated::AbstractSet{Int},
kargs...)
if length(labels) / length(qtrees) > 0.6
if length(moved) / length(qtrees) > 0.6
r = totalcollisions_kw(qtrees; kargs...)
union!(updated, labels)
union!(unlocated, moved)
else
locate!(qtrees, (i for i in updated if i labels), spqtree)
empty!(updated)
r = partialcollisions_kw(qtrees, spqtree, labels; kargs...)
locate!(qtrees, (i for i in unlocated if i moved), linkedspqtree)
empty!(unlocated)
r = partialcollisions_kw(qtrees, linkedspqtree, moved; kargs...)
end
empty!(labels)
union!(labels, first.(r) |> Iterators.flatten)
empty!(moved)
union!(moved, first.(r) |> Iterators.flatten)
r
end
struct DynamicColliders
qtrees::Vector{U8SQTree}
spqtree::LinkedSpacialQTree
colabels::Set{Int}
updated::Set{Int}
moved::Set{Int}
unlocated::Set{Int}
end
function Base.union!(dc::DynamicColliders, c)
union!(dc.colabels, c)
union!(dc.moved, c)
end
DynamicColliders(qtrees::AbstractVector{U8SQTree}) = DynamicColliders(qtrees, linked_spacial_qtree(qtrees), Set(1:length(qtrees)), Set{Int}())
function dynamiccollisions(colliders::DynamicColliders; kargs...)
r = dynamiccollisions(colliders.qtrees, colliders.spqtree, colliders.colabels; updated=colliders.updated, kargs...)
r = dynamiccollisions(colliders.qtrees, colliders.spqtree, colliders.moved; unlocated=colliders.unlocated, kargs...)
r
end
########## place!
Expand Down
24 changes: 12 additions & 12 deletions test/test_qtrees.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ testqtree = Stuffing.testqtree
cln = QTrees.totalcollisions_native(qts)
clp = QTrees.partialcollisions(qts, spt2, Set(1:length(qts)))
@test Set(Set.(first.(cls))) == Set(Set.(first.(cln))) == Set(Set.(first.(clp)))
updated = first.(clp)|>Iterators.flatten|>Set
clp = QTrees.partialcollisions(qts, spt2, updated, unique=true)
moved = first.(clp)|>Iterators.flatten|>Set
clp = QTrees.partialcollisions(qts, spt2, moved, unique=true)
@test Set(Set.(first.(cln))) == Set(Set.(first.(clp)))
clp = QTrees.partialcollisions(qts)
@test Set(Set.(first.(cln))) == Set(Set.(first.(clp)))
Expand All @@ -68,32 +68,32 @@ testqtree = Stuffing.testqtree
c2set = Set([Set(p) for p in first.(QTrees.totalcollisions(qts)) if !isdisjoint(p, labels)])
@test c2set == Set(Set.(first.(c1)))
# dynamic
updated = Set(1:length(qts));
moved = Set(1:length(qts));
spqtree = linked_spacial_qtree(qts);
for i in 1:10
C1 = partialcollisions(qts, spqtree, updated);
union!(updated, first.(C1) |> Iterators.flatten) #all collided labels
C1 = partialcollisions(qts, spqtree, moved);
union!(moved, first.(C1) |> Iterators.flatten) #all collided labels
C2 = QTrees.totalcollisions_native(qts);
@test length(C1) == length(C2)
@test Set(Set.(first.(C1))) == Set(Set.(first.(C2)))
Stuffing.Trainer.batchsteps!(qts, C1)
QTrees.shift!(qts[3], 1, 1, 1)
QTrees.shift!(qts[7], 1, -1, -1)
union!(updated, [3, 7]) #other updated labels
union!(moved, [3, 7]) #other moved objects
end
colabels = Set(1:length(qts));
updated = Set{Int}();
moved = Set(1:length(qts));
unlocated = Set{Int}();
spqtree = linked_spacial_qtree(qts);
for i in 1:10
C1 = dynamiccollisions(qts, spqtree, colabels; updated=updated);
C1 = dynamiccollisions(qts, spqtree, moved; unlocated=unlocated);
C2 = QTrees.totalcollisions_native(qts);
C3 = QTrees.totalcollisions(qts);
@test length(C1) == length(C2) || length(C1) == length(C3)
@test Set(Set.(first.(C1))) == Set(Set.(first.(C2))) || Set(Set.(first.(C1))) == Set(Set.(first.(C3)))
Stuffing.Trainer.batchsteps!(qts, C1)
QTrees.shift!(qts[3], 1, -1, -1)
QTrees.shift!(qts[7], 1, 1, 1)
union!(colabels, [3, 7]) #other updated labels
union!(moved, [3, 7]) #other moved objects
end
colliders = DynamicColliders(qts);
for i in 1:10
Expand All @@ -105,10 +105,10 @@ testqtree = Stuffing.testqtree
@test Set(Set.(first.(C1))) == Set(Set.(first.(C2))) || Set(Set.(first.(C1))) == Set(Set.(first.(C3)))
end
Stuffing.Trainer.batchsteps!(qts, C1) #move objects in C1
begin #other moved labels
begin #other moved objects
QTrees.shift!(qts[3], 1, -1, -1)
QTrees.shift!(qts[7], 1, 1, 1)
union!(colliders, [3, 7]) #other updated labels
union!(colliders, [3, 7]) #other moved objects
end
end
#edge cases
Expand Down

0 comments on commit 17454d7

Please sign in to comment.