From c6b16b072b6fce822e1cae71918e7d4d6b9b090e Mon Sep 17 00:00:00 2001
From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com>
Date: Mon, 17 Jun 2024 12:54:52 -0400
Subject: [PATCH] faster update

---
 pkg/hnsw/heap.go | 29 ++++++++++++++++-------------
 1 file changed, 16 insertions(+), 13 deletions(-)

diff --git a/pkg/hnsw/heap.go b/pkg/hnsw/heap.go
index 7006d28..80eba38 100644
--- a/pkg/hnsw/heap.go
+++ b/pkg/hnsw/heap.go
@@ -13,13 +13,13 @@ var EmptyHeapError = fmt.Errorf("Empty Heap")
 
 type DistHeap struct {
 	items   []*Item
-	visited map[Id]bool
+	visited map[Id]int
 }
 
 func NewDistHeap() *DistHeap {
 	d := &DistHeap{
 		items:   make([]*Item, 0),
-		visited: make(map[Id]bool),
+		visited: make(map[Id]int),
 	}
 	return d
 }
@@ -86,20 +86,19 @@ func (d *DistHeap) PopMaxItem() (*Item, error) {
 	return d.Pop(), nil
 }
 func (d *DistHeap) Insert(id Id, dist float32) {
-	if d.visited[id] {
-		for idx, item := range d.items {
-			if item.id == id {
-				item.dist = dist
-				d.Fix(idx)
-				return
-			}
-		}
-	} else {
+	index, ok := d.visited[id]
+
+	if !ok {
 		d.Push(&Item{id: id, dist: dist})
+		d.visited[id] = d.Len() - 1
 		d.up(d.Len() - 1)
-		d.visited[id] = true
+		return
 	}
+
+	d.items[index].dist = dist
+	d.Fix(index)
 }
+
 func (d *DistHeap) Fix(i int) {
 	if !d.down(i, d.Len()) {
 		d.up(i)
@@ -109,7 +108,10 @@ func (d *DistHeap) Fix(i int) {
 func (d DistHeap) IsEmpty() bool      { return len(d.items) == 0 }
 func (d DistHeap) Len() int           { return len(d.items) }
 func (d DistHeap) Less(i, j int) bool { return d.items[i].dist < d.items[j].dist }
-func (d DistHeap) Swap(i, j int)      { d.items[i], d.items[j] = d.items[j], d.items[i] }
+func (d DistHeap) Swap(i, j int) {
+	d.visited[d.items[i].id], d.visited[d.items[j].id] = j, i
+	d.items[i], d.items[j] = d.items[j], d.items[i]
+}
 func (d *DistHeap) Push(x *Item) {
 	(*d).items = append((*d).items, x)
 }
@@ -118,5 +120,6 @@ func (d *DistHeap) Pop() *Item {
 	n := len(old)
 	x := old[n-1]
 	(*d).items = old[0 : n-1]
+	delete(d.visited, x.id)
 	return x
 }