Skip to content

Commit

Permalink
- fix removeDeletedNeighbors
Browse files Browse the repository at this point in the history
- NN path can include self, so exclude it from candidates array
  • Loading branch information
jbellis committed Oct 9, 2023
1 parent 35bf868 commit 557d9d6
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,24 @@ public void cleanup() {
public boolean removeDeletedNeighbors(Bits deletedNodes) {
AtomicBoolean found = new AtomicBoolean();
neighborsRef.getAndUpdate(current -> {
// build a set of the entries we want to retain
var toRetain = new FixedBitSet(current.size);
for (int i = 0; i < current.size; i++) {
if (deletedNodes.get(current.node[i])) {
found.set(true);
} else {
toRetain.set(i);
}
}

// if we're retaining everything, no need to make a copy
if (!found.get()) {
return current;
}

// copy and purge the deleted ones
var next = current.copy();
next.retain(Bits.inverseOf(deletedNodes));
found.set(next.size < current.size);
next.retain(toRetain);
return next;
});
return found.get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ public long addGraphNode(int node, RandomAccessVectorValues<T> vectors) {

// Update neighbors with these candidates.
// TODO if we made NeighborArray an interface we could wrap the NodeScore[] directly instead of copying
var natural = toScratchCandidates(result.getNodes(), naturalScratchPooled.get());
var natural = toScratchCandidates(result.getNodes(), result.getNodes().length, naturalScratchPooled.get());
var concurrent = getConcurrentCandidates(node, inProgressBefore, concurrentScratchPooled.get(), vectors, vc.get());
updateNeighbors(node, natural, concurrent);
graph.markComplete(node);
Expand Down Expand Up @@ -332,7 +332,7 @@ public int length() {
var value = v1.get().vectorValue(node);
NeighborSimilarity.ExactScoreFunction scoreFunction = i -> scoreBetween(v2.get().vectorValue(i), value);
var result = gs.get().searchInternal(scoreFunction, null, beamWidth, graph.entry(), notSelfBits);
var candidates = getPathCandidates(result.getVisited(), scoreFunction, scratch.get());
var candidates = getPathCandidates(result.getVisited(), node, scoreFunction, scratch.get());
updateNeighbors(node, candidates, NeighborArray.EMPTY);
}
}
Expand Down Expand Up @@ -381,19 +381,23 @@ private void updateNeighbors(int node, NeighborArray natural, NeighborArray conc
/**
* compute the scores for the nodes set in `visited` and return them in a NeighborArray
*/
private NeighborArray getPathCandidates(BitSet visited, NeighborSimilarity.ExactScoreFunction scoreFunction, NeighborArray scratch) {
private NeighborArray getPathCandidates(BitSet visited, int node, NeighborSimilarity.ExactScoreFunction scoreFunction, NeighborArray scratch) {
// doing a single sort is faster than repeatedly calling insertSorted
SearchResult.NodeScore[] candidates = new SearchResult.NodeScore[visited.cardinality()];
int j = 0;
for (int i = visited.nextSetBit(0); i != NO_MORE_DOCS; i = visited.nextSetBit(i + 1)) {
candidates[j++] = new SearchResult.NodeScore(i, scoreFunction.similarityTo(i));
if (i != node) {
candidates[j++] = new SearchResult.NodeScore(i, scoreFunction.similarityTo(i));
}
}
Arrays.sort(candidates, Comparator.comparingDouble(ns -> ns.score));
return toScratchCandidates(candidates, scratch);
Arrays.sort(candidates, 0, j, Comparator.comparingDouble(ns -> -ns.score));
return toScratchCandidates(candidates, j, scratch);
}

private NeighborArray toScratchCandidates(SearchResult.NodeScore[] candidates, NeighborArray scratch) {
private NeighborArray toScratchCandidates(SearchResult.NodeScore[] candidates, int count, NeighborArray scratch) {
scratch.clear();
for (SearchResult.NodeScore candidate : candidates) {
for (int i = 0; i < count; i++) {
var candidate = candidates[i];
scratch.addInOrder(candidate.node, candidate.score);
}
return scratch;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import org.junit.Test;

import java.util.Arrays;

import static io.github.jbellis.jvector.graph.GraphIndexTestCase.createRandomFloatVectors;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotEquals;
Expand Down Expand Up @@ -70,8 +72,8 @@ public void testCleanup() {
assertEquals(ravv.size() - nDeleted, graph.size());

// cleanup should have added new connections to the node that would otherwise have been disconnected
var v = ravv.vectorValue(nodeToIsolate);
var results = GraphSearcher.search(v, 1, ravv, VectorEncoding.FLOAT32, VectorSimilarityFunction.COSINE, graph, Bits.ALL);
var v = Arrays.copyOf(ravv.vectorValue(nodeToIsolate), ravv.dimension);
var results = GraphSearcher.search(v, 10, ravv, VectorEncoding.FLOAT32, VectorSimilarityFunction.COSINE, graph, Bits.ALL);
assertEquals(nodeToIsolate, results.getNodes()[0].node);
}
}

0 comments on commit 557d9d6

Please sign in to comment.