Skip to content

Commit

Permalink
feat: return id (#372)
Browse files Browse the repository at this point in the history
  • Loading branch information
friendlymatthew authored Jul 18, 2024
1 parent f9ea4ef commit 0273088
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 20 deletions.
16 changes: 8 additions & 8 deletions pkg/hnsw/hnsw.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,9 @@ func (h *Hnsw) selectNeighbors(nearestNeighbors *DistHeap) ([]*Item, error) {
return nearestItems, nil
}

func (h *Hnsw) InsertVector(q Point) error {
func (h *Hnsw) InsertVector(q Point) (Id, error) {
if !h.isValidPoint(q) {
return fmt.Errorf("invalid vector dimensionality")
return 0, fmt.Errorf("invalid vector dimensionality")
}

topLevel := h.friends[h.entryPointId].TopLevel()
Expand All @@ -199,12 +199,12 @@ func (h *Hnsw) InsertVector(q Point) error {
nnToQAtLevel, err := h.searchLevel(&q, entryItem, h.efConstruction, level)

if err != nil {
return fmt.Errorf("failed to search for nearest neighbors to Q at level %v: %w", level, err)
return 0, fmt.Errorf("failed to search for nearest neighbors to Q at level %v: %w", level, err)
}

neighbors, err := h.selectNeighbors(nnToQAtLevel)
if err != nil {
return fmt.Errorf("failed to select for nearest neighbors to Q at level %v: %w", level, err)
return 0, fmt.Errorf("failed to select for nearest neighbors to Q at level %v: %w", level, err)
}

// add bidirectional connections from neighbors to q at layer c
Expand All @@ -218,13 +218,13 @@ func (h *Hnsw) InsertVector(q Point) error {
for _, neighbor := range neighbors {
neighborFriendsAtLevel, err := h.friends[neighbor.id].GetFriendsAtLevel(level)
if err != nil {
return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err)
return 0, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err)
}

for neighborFriendsAtLevel.Len() > h.M {
_, err := neighborFriendsAtLevel.PopMaxItem()
if err != nil {
return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err)
return 0, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err)
}
}

Expand All @@ -233,7 +233,7 @@ func (h *Hnsw) InsertVector(q Point) error {

newEntryItem, err := nnToQAtLevel.PopMinItem()
if err != nil {
return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err)
return 0, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err)
}

entryItem = newEntryItem
Expand All @@ -243,7 +243,7 @@ func (h *Hnsw) InsertVector(q Point) error {
h.entryPointId = qId
}

return nil
return qId, nil
}

func (h *Hnsw) isValidPoint(point Point) bool {
Expand Down
20 changes: 10 additions & 10 deletions pkg/hnsw/hnsw_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ func TestHnsw_InsertVector(t *testing.T) {
t.Fatal("insert vector should have 2 elements")
}

err := h.InsertVector(q)
_, err := h.InsertVector(q)

if err != nil {
t.Fatal(err)
Expand All @@ -446,7 +446,7 @@ func TestHnsw_InsertVector(t *testing.T) {
t.Fatalf("expected friends and points map to have same length throughout insertion")
}

err := h.InsertVector(q)
_, err := h.InsertVector(q)
if err != nil {
return
}
Expand Down Expand Up @@ -495,7 +495,7 @@ func TestHnsw_InsertVector(t *testing.T) {
h := NewHnsw(2, 4, 4, Point{0, 0})

for _, cluster := range clusterA {
if err := h.InsertVector(cluster); err != nil {
if _, err := h.InsertVector(cluster); err != nil {
t.Fatalf("failed to insert vector: %v", err)
}
}
Expand All @@ -508,17 +508,17 @@ func TestHnsw_KnnSearch(t *testing.T) {
h := NewHnsw(2, 4, 4, Point{0, 0})

// id: 1
if err := h.InsertVector(Point{3, 3}); err != nil {
if _, err := h.InsertVector(Point{3, 3}); err != nil {
t.Fatalf("failed to insert point: %v, err: %v", Point{3, 3}, err)
}

// id: 2
if err := h.InsertVector(Point{4, 4}); err != nil {
if _, err := h.InsertVector(Point{4, 4}); err != nil {
t.Fatalf("failed to insert point %v, err: %v", Point{4, 4}, err)
}

// id: 3
if err := h.InsertVector(Point{5, 5}); err != nil {
if _, err := h.InsertVector(Point{5, 5}); err != nil {
t.Fatalf("failed to insert point %v, err: %v", Point{5, 5}, err)
}

Expand Down Expand Up @@ -551,7 +551,7 @@ func TestHnsw_KnnSearch(t *testing.T) {
clusterAGraph := NewHnsw(2, 4, 4, Point{0, 0})

for _, cluster := range clusterA {
if err := clusterAGraph.InsertVector(cluster); err != nil {
if _, err := clusterAGraph.InsertVector(cluster); err != nil {
t.Fatalf("failed to insert point: %v, err: %v", cluster, err)
}
}
Expand Down Expand Up @@ -591,7 +591,7 @@ func TestHnsw_KnnSearch(t *testing.T) {
h := NewHnsw(2, clusterCLen+1, clusterCLen+1, Point{0, 0})

for _, cluster := range clusterC {
if err := h.InsertVector(cluster); err != nil {
if _, err := h.InsertVector(cluster); err != nil {
t.Fatalf("failed to insert point: %v, err: %v", cluster, err)
}
}
Expand Down Expand Up @@ -627,7 +627,7 @@ func TestHnsw_KnnSearch(t *testing.T) {
t.Run("sequential search with upper bound params", func(t *testing.T) {
h := NewHnsw(2, 12, 12, Point{0, 0})
for i := 1; i <= 8; i++ {
if err := h.InsertVector(Point{float32(i), float32(i + 1)}); err != nil {
if _, err := h.InsertVector(Point{float32(i), float32(i + 1)}); err != nil {
t.Fatalf("failed to insert point: %v, err: %v", Point{float32(i), float32(i + 1)}, err)
}
}
Expand Down Expand Up @@ -682,7 +682,7 @@ func BenchmarkHnsw_KnnSearch(b *testing.B) {
b.Fatalf("expected point of dim 3, got dim: %v", len(point))
}

if err := h.InsertVector(point); err != nil {
if _, err := h.InsertVector(point); err != nil {
b.Fatalf("failed to insert point: %v, err: %v", point, err)
}
}
Expand Down
8 changes: 6 additions & 2 deletions pkg/vectorpage/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@ func TestNewVectorPageManager(t *testing.T) {
h := hnsw.NewHnsw(2, 10, 8, p0)

for i := 0; i < 100; i++ {
if err := h.InsertVector(hnsw.Point{float32(i), float32(i)}); err != nil {
id, err := h.InsertVector(hnsw.Point{float32(i), float32(i)})
if err != nil {
t.Fatal(err)
}
}

if id != hnsw.Id(i+1) {
t.Fatalf("expected id %d, got %d", id, i+1)
}
}
})
}

0 comments on commit 0273088

Please sign in to comment.