Skip to content

Commit

Permalink
Merge pull request #75 from paulmach/knearest-sort
Browse files Browse the repository at this point in the history
quadtree: sort KNearest results closest first
  • Loading branch information
paulmach authored Oct 16, 2021
2 parents 418be84 + e7643d0 commit 59d7b22
Show file tree
Hide file tree
Showing 5 changed files with 232 additions and 47 deletions.
34 changes: 34 additions & 0 deletions quadtree/benchmarks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,37 @@ func BenchmarkRandomInBound1000Buf(b *testing.B) {
buf = qt.InBound(buf, p.Bound().Pad(0.1))
}
}

func BenchmarkRandomKNearest10(b *testing.B) {
r := rand.New(rand.NewSource(43))

qt := New(orb.Bound{Min: orb.Point{0, 0}, Max: orb.Point{1, 1}})
for i := 0; i < 1000; i++ {
qt.Add(orb.Point{r.Float64(), r.Float64()})
}

buf := make([]orb.Pointer, 0, 10)

b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
qt.KNearest(buf[:0], orb.Point{r.Float64(), r.Float64()}, 10)
}
}

func BenchmarkRandomKNearest100(b *testing.B) {
r := rand.New(rand.NewSource(43))

qt := New(orb.Bound{Min: orb.Point{0, 0}, Max: orb.Point{1, 1}})
for i := 0; i < 1000; i++ {
qt.Add(orb.Point{r.Float64(), r.Float64()})
}

buf := make([]orb.Pointer, 0, 100)

b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
qt.KNearest(buf[:0], orb.Point{r.Float64(), r.Float64()}, 100)
}
}
102 changes: 102 additions & 0 deletions quadtree/maxheap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package quadtree

import "github.com/paulmach/orb"

// maxHeap is used for the knearest list. We need a way to maintain
// the furthest point from the query point in the list, hence maxHeap.
// When we find a point closer than the furthest away, we remove
// furthest and add the new point to the heap.
type maxHeap []*heapItem

type heapItem struct {
point orb.Pointer
distance float64
}

func (h *maxHeap) Push(point orb.Pointer, distance float64) {
// Common usage is Push followed by a Pop if we have > k points.
// We're reusing the k+1 heapItem object to reduce memory allocations.
// First we manaully lengthen the slice,
// then we see if the last item has been allocated already.

prevLen := len(*h)
*h = (*h)[:prevLen+1]
if (*h)[prevLen] == nil {
(*h)[prevLen] = &heapItem{point: point, distance: distance}
} else {
(*h)[prevLen].point = point
(*h)[prevLen].distance = distance
}

i := len(*h) - 1
for i > 0 {
up := ((i + 1) >> 1) - 1
parent := (*h)[up]

if distance < parent.distance {
// parent is further so we're done fixing up the heap.
break
}

// swap nodes
// (*h)[i] = parent
(*h)[i].point = parent.point
(*h)[i].distance = parent.distance

// (*h)[up] = item
(*h)[up].point = point
(*h)[up].distance = distance

i = up
}
}

// Pop returns the "greatest" item in the list.
// The returned item should not be saved across push/pop operations.
func (h *maxHeap) Pop() *heapItem {
removed := (*h)[0]
lastItem := (*h)[len(*h)-1]
(*h) = (*h)[:len(*h)-1]

mh := (*h)
if len(mh) == 0 {
return removed
}

// move the last item to the top and reset the heap
mh[0] = lastItem

i := 0
current := mh[i]
for {
right := (i + 1) << 1
left := right - 1

childIndex := i
child := mh[childIndex]

// swap with biggest child
if left < len(mh) && child.distance < mh[left].distance {
childIndex = left
child = mh[left]
}

if right < len(mh) && child.distance < mh[right].distance {
childIndex = right
child = mh[right]
}

// non bigger, so quit
if childIndex == i {
break
}

// swap the nodes
mh[i] = child
mh[childIndex] = current

i = childIndex
}

return removed
}
27 changes: 27 additions & 0 deletions quadtree/maxheap_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package quadtree

import (
"math/rand"
"testing"
)

func TestMaxHeap(t *testing.T) {
r := rand.New(rand.NewSource(22))

for i := 1; i < 100; i++ {
h := make(maxHeap, 0, i)
for j := 0; j < i; j++ {
h.Push(nil, r.Float64())
}

current := h.Pop().distance
for len(h) > 0 {
next := h.Pop().distance
if next > current {
t.Errorf("incorrect")
}

current = next
}
}
}
97 changes: 50 additions & 47 deletions quadtree/quadtree.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package quadtree

import (
"container/heap"
"errors"
"math"

Expand Down Expand Up @@ -213,6 +212,7 @@ func (q *Quadtree) Matching(p orb.Point, f FilterFunc) orb.Pointer {
// KNearest returns k closest Value/Pointer in the quadtree.
// This function is thread safe. Multiple goroutines can read from a pre-created tree.
// An optional buffer parameter is provided to allow for the reuse of result slice memory.
// The points are returned in a sorted order, nearest first.
// This function allows defining a maximum distance in order to reduce search iterations.
func (q *Quadtree) KNearest(buf []orb.Pointer, p orb.Point, k int, maxDistance ...float64) []orb.Pointer {
return q.KNearestMatching(buf, p, k, nil, maxDistance...)
Expand All @@ -222,6 +222,7 @@ func (q *Quadtree) KNearest(buf []orb.Pointer, p orb.Point, k int, maxDistance .
// the given filter function returns true. This function is thread safe.
// Multiple goroutines can read from a pre-created tree. An optional buffer
// parameter is provided to allow for the reuse of result slice memory.
// The points are returned in a sorted order, nearest first.
// This function allows defining a maximum distance in order to reduce search iterations.
func (q *Quadtree) KNearestMatching(buf []orb.Pointer, p orb.Point, k int, f FilterFunc, maxDistance ...float64) []orb.Pointer {
if q.root == nil {
Expand All @@ -233,13 +234,13 @@ func (q *Quadtree) KNearestMatching(buf []orb.Pointer, p orb.Point, k int, f Fil
point: p,
filter: f,
k: k,
closest: newPointsQueue(k),
maxHeap: make(maxHeap, 0, k+1),
closestBound: &b,
maxDistSquared: math.MaxFloat64,
}

if len(maxDistance) > 0 {
v.maxDistSquared = math.Pow(maxDistance[0], 2)
v.maxDistSquared = maxDistance[0] * maxDistance[0]
}

newVisit(v).Visit(q.root,
Expand All @@ -250,15 +251,16 @@ func (q *Quadtree) KNearestMatching(buf []orb.Pointer, p orb.Point, k int, f Fil
)

//repack result
if cap(buf) < len(v.closest) {
buf = make([]orb.Pointer, 0, len(v.closest))
if cap(buf) < len(v.maxHeap) {
buf = make([]orb.Pointer, len(v.maxHeap))
} else {
buf = buf[:0]
buf = buf[:len(v.maxHeap)]
}

for _, element := range v.closest {
buf = append(buf, element.point)
for i := len(v.maxHeap) - 1; i >= 0; i-- {
buf[i] = v.maxHeap.Pop().point
}

return buf
}

Expand Down Expand Up @@ -405,53 +407,53 @@ func (v *findVisitor) Visit(n *node) {
}
}

type pointsQueueItem struct {
point orb.Pointer
distance float64 // distance to point and priority inside the queue
index int // point index in queue
}
// type pointsQueueItem struct {
// point orb.Pointer
// distance float64 // distance to point and priority inside the queue
// index int // point index in queue
// }

type pointsQueue []pointsQueueItem
// type pointsQueue []pointsQueueItem

func newPointsQueue(capacity int) pointsQueue {
// We make capacity+1 because we need additional place for the greatest element
return make([]pointsQueueItem, 0, capacity+1)
}
// func newPointsQueue(capacity int) pointsQueue {
// // We make capacity+1 because we need additional place for the greatest element
// return make([]pointsQueueItem, 0, capacity+1)
// }

func (pq pointsQueue) Len() int { return len(pq) }
// func (pq pointsQueue) Len() int { return len(pq) }

func (pq pointsQueue) Less(i, j int) bool {
// We want pop longest distances so Less was inverted
return pq[i].distance > pq[j].distance
}
// func (pq pointsQueue) Less(i, j int) bool {
// // We want pop longest distances so Less was inverted
// return pq[i].distance > pq[j].distance
// }

func (pq pointsQueue) Swap(i, j int) {
pq[i], pq[j] = pq[j], pq[i]
pq[i].index = i
pq[j].index = j
}
// func (pq pointsQueue) Swap(i, j int) {
// pq[i], pq[j] = pq[j], pq[i]
// pq[i].index = i
// pq[j].index = j
// }

func (pq *pointsQueue) Push(x interface{}) {
n := len(*pq)
item := x.(pointsQueueItem)
item.index = n
*pq = append(*pq, item)
}
// func (pq *pointsQueue) Push(x interface{}) {
// n := len(*pq)
// item := x.(pointsQueueItem)
// item.index = n
// *pq = append(*pq, item)
// }

func (pq *pointsQueue) Pop() interface{} {
old := *pq
n := len(old)
item := old[n-1]
item.index = -1
*pq = old[0 : n-1]
return item
}
// func (pq *pointsQueue) Pop() interface{} {
// old := *pq
// n := len(old)
// item := old[n-1]
// item.index = -1
// *pq = old[0 : n-1]
// return item
// }

type nearestVisitor struct {
point orb.Point
filter FilterFunc
k int
closest pointsQueue
maxHeap maxHeap
closestBound *orb.Bound
maxDistSquared float64
}
Expand All @@ -472,13 +474,14 @@ func (v *nearestVisitor) Visit(n *node) {

point := n.Value.Point()
if d := planar.DistanceSquared(point, v.point); d < v.maxDistSquared {
heap.Push(&v.closest, pointsQueueItem{point: n.Value, distance: d})
if v.closest.Len() > v.k {
heap.Pop(&v.closest)
v.maxHeap.Push(n.Value, d)
if len(v.maxHeap) > v.k {

v.maxHeap.Pop()

// Actually this is a hack. We know how heap works and obtain
// top element without function call
top := v.closest[0]
top := v.maxHeap[0]

v.maxDistSquared = top.distance

Expand Down
19 changes: 19 additions & 0 deletions quadtree/quadtree_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,25 @@ func TestQuadtreeKNearest(t *testing.T) {
}
}

func TestQuadtreeKNearest_sorted(t *testing.T) {
q := New(orb.Bound{Max: orb.Point{5, 5}})
q.Add(orb.Point{0, 0})
q.Add(orb.Point{1, 1})
q.Add(orb.Point{2, 2})
q.Add(orb.Point{3, 3})
q.Add(orb.Point{4, 4})
q.Add(orb.Point{5, 5})

nearest := q.KNearest(nil, orb.Point{2.25, 2.25}, 5)

expected := []orb.Point{{2, 2}, {3, 3}, {1, 1}, {4, 4}, {0, 0}}
for i, p := range expected {
if n := nearest[i].Point(); !n.Equal(p) {
t.Errorf("incorrect point %d: %v", i, n)
}
}
}

func TestQuadtreeKNearest_DistanceLimit(t *testing.T) {
type dataPointer struct {
orb.Pointer
Expand Down

0 comments on commit 59d7b22

Please sign in to comment.