From 557d9d6943aee60ed9dea00fe247c0b6d377bc50 Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Sun, 8 Oct 2023 16:32:31 -0500 Subject: [PATCH] - fix removeDeletedNeighbors - NN path can include self, so exclude it from candidates array --- .../jvector/graph/ConcurrentNeighborSet.java | 19 ++++++++++++++++-- .../jvector/graph/GraphIndexBuilder.java | 20 +++++++++++-------- .../jbellis/jvector/graph/TestDeletions.java | 6 ++++-- 3 files changed, 33 insertions(+), 12 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborSet.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborSet.java index 1ad454ca5..fcc55260a 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborSet.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/ConcurrentNeighborSet.java @@ -108,9 +108,24 @@ public void cleanup() { public boolean removeDeletedNeighbors(Bits deletedNodes) { AtomicBoolean found = new AtomicBoolean(); neighborsRef.getAndUpdate(current -> { + // build a set of the entries we want to retain + var toRetain = new FixedBitSet(current.size); + for (int i = 0; i < current.size; i++) { + if (deletedNodes.get(current.node[i])) { + found.set(true); + } else { + toRetain.set(i); + } + } + + // if we're retaining everything, no need to make a copy + if (!found.get()) { + return current; + } + + // copy and purge the deleted ones var next = current.copy(); - next.retain(Bits.inverseOf(deletedNodes)); - found.set(next.size < current.size); + next.retain(toRetain); return next; }); return found.get(); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index ff58ae87b..9ffbcbe09 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -217,7 +217,7 @@ public long addGraphNode(int node, RandomAccessVectorValues vectors) { // Update neighbors with these candidates. // TODO if we made NeighborArray an interface we could wrap the NodeScore[] directly instead of copying - var natural = toScratchCandidates(result.getNodes(), naturalScratchPooled.get()); + var natural = toScratchCandidates(result.getNodes(), result.getNodes().length, naturalScratchPooled.get()); var concurrent = getConcurrentCandidates(node, inProgressBefore, concurrentScratchPooled.get(), vectors, vc.get()); updateNeighbors(node, natural, concurrent); graph.markComplete(node); @@ -332,7 +332,7 @@ public int length() { var value = v1.get().vectorValue(node); NeighborSimilarity.ExactScoreFunction scoreFunction = i -> scoreBetween(v2.get().vectorValue(i), value); var result = gs.get().searchInternal(scoreFunction, null, beamWidth, graph.entry(), notSelfBits); - var candidates = getPathCandidates(result.getVisited(), scoreFunction, scratch.get()); + var candidates = getPathCandidates(result.getVisited(), node, scoreFunction, scratch.get()); updateNeighbors(node, candidates, NeighborArray.EMPTY); } } @@ -381,19 +381,23 @@ private void updateNeighbors(int node, NeighborArray natural, NeighborArray conc /** * compute the scores for the nodes set in `visited` and return them in a NeighborArray */ - private NeighborArray getPathCandidates(BitSet visited, NeighborSimilarity.ExactScoreFunction scoreFunction, NeighborArray scratch) { + private NeighborArray getPathCandidates(BitSet visited, int node, NeighborSimilarity.ExactScoreFunction scoreFunction, NeighborArray scratch) { + // doing a single sort is faster than repeatedly calling insertSorted SearchResult.NodeScore[] candidates = new SearchResult.NodeScore[visited.cardinality()]; int j = 0; for (int i = visited.nextSetBit(0); i != NO_MORE_DOCS; i = visited.nextSetBit(i + 1)) { - candidates[j++] = new SearchResult.NodeScore(i, scoreFunction.similarityTo(i)); + if (i != node) { + candidates[j++] = new SearchResult.NodeScore(i, scoreFunction.similarityTo(i)); + } } - Arrays.sort(candidates, Comparator.comparingDouble(ns -> ns.score)); - return toScratchCandidates(candidates, scratch); + Arrays.sort(candidates, 0, j, Comparator.comparingDouble(ns -> -ns.score)); + return toScratchCandidates(candidates, j, scratch); } - private NeighborArray toScratchCandidates(SearchResult.NodeScore[] candidates, NeighborArray scratch) { + private NeighborArray toScratchCandidates(SearchResult.NodeScore[] candidates, int count, NeighborArray scratch) { scratch.clear(); - for (SearchResult.NodeScore candidate : candidates) { + for (int i = 0; i < count; i++) { + var candidate = candidates[i]; scratch.addInOrder(candidate.node, candidate.score); } return scratch; diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestDeletions.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestDeletions.java index bba7839c4..9bc8380df 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestDeletions.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestDeletions.java @@ -8,6 +8,8 @@ import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import org.junit.Test; +import java.util.Arrays; + import static io.github.jbellis.jvector.graph.GraphIndexTestCase.createRandomFloatVectors; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; @@ -70,8 +72,8 @@ public void testCleanup() { assertEquals(ravv.size() - nDeleted, graph.size()); // cleanup should have added new connections to the node that would otherwise have been disconnected - var v = ravv.vectorValue(nodeToIsolate); - var results = GraphSearcher.search(v, 1, ravv, VectorEncoding.FLOAT32, VectorSimilarityFunction.COSINE, graph, Bits.ALL); + var v = Arrays.copyOf(ravv.vectorValue(nodeToIsolate), ravv.dimension); + var results = GraphSearcher.search(v, 10, ravv, VectorEncoding.FLOAT32, VectorSimilarityFunction.COSINE, graph, Bits.ALL); assertEquals(nodeToIsolate, results.getNodes()[0].node); } }