Skip to content

Commit

Permalink
vectorsEncountered not always in sync with resultQueue, causing NPE when breaking out of loop due to threshold probability (#150)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbellis authored Nov 10, 2023
1 parent 0b93065 commit 3b6bfe6
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -189,20 +189,15 @@ SearchResult searchInternal(NodeSimilarity.ScoreFunction scoreFunction,
visited.set(ep);
numVisited++;
candidates.push(ep, score);
if (acceptOrds.get(ep) && score >= threshold) {
resultsQueue.push(ep, score);
}

// A bound that holds the minimum similarity to the query vector that a candidate vector must
// have to be considered.
float minAcceptedSimilarity = Float.NEGATIVE_INFINITY;
if (resultsQueue.size() >= topK) {
minAcceptedSimilarity = resultsQueue.topScore();
}

while (candidates.size() > 0 && !resultsQueue.incomplete()) {
// get the best candidate (closest or best scoring)
if (candidates.topScore() < minAcceptedSimilarity) {
// done when best candidate is worse than the worst result so far
float topCandidateScore = candidates.topScore();
if (topCandidateScore < minAcceptedSimilarity) {
break;
}

Expand All @@ -211,10 +206,21 @@ SearchResult searchInternal(NodeSimilarity.ScoreFunction scoreFunction,
break;
}

// add the top candidate to the resultset
int topCandidateNode = candidates.pop();
if (!scoreFunction.isExact()) {
vectorsEncountered.put(topCandidateNode, view.getVector(topCandidateNode));
if (acceptOrds.get(topCandidateNode)
&& topCandidateScore >= threshold
&& resultsQueue.push(topCandidateNode, topCandidateScore))
{
if (resultsQueue.size() >= topK) {
minAcceptedSimilarity = resultsQueue.topScore();
}
if (!scoreFunction.isExact()) {
vectorsEncountered.put(topCandidateNode, view.getVector(topCandidateNode));
}
}

// add its neighbors to the candidates queue
for (var it = view.getNeighborsIterator(topCandidateNode); it.hasNext(); ) {
int friendOrd = it.nextInt();
if (visited.getAndSet(friendOrd)) {
Expand All @@ -224,14 +230,8 @@ SearchResult searchInternal(NodeSimilarity.ScoreFunction scoreFunction,

float friendSimilarity = scoreFunction.similarityTo(friendOrd);
scoreTracker.track(friendSimilarity);

if (friendSimilarity >= minAcceptedSimilarity) {
candidates.push(friendOrd, friendSimilarity);
if (acceptOrds.get(friendOrd) && friendSimilarity >= threshold) {
if (resultsQueue.push(friendOrd, friendSimilarity) && resultsQueue.size() >= topK) {
minAcceptedSimilarity = resultsQueue.topScore();
}
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,26 @@

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
import io.github.jbellis.jvector.LuceneTestCase;
import io.github.jbellis.jvector.TestUtil;
import io.github.jbellis.jvector.disk.OnDiskGraphIndex;
import io.github.jbellis.jvector.disk.SimpleMappedReader;
import io.github.jbellis.jvector.pq.PQVectors;
import io.github.jbellis.jvector.pq.ProductQuantization;
import io.github.jbellis.jvector.util.Bits;
import io.github.jbellis.jvector.vector.VectorEncoding;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import org.junit.Test;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;

@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
public class Test2DThreshold extends LuceneTestCase {
@Test
public void testThreshold() {
public void testThreshold() throws IOException {
var R = getRandom();
// generate 2D vectors
float[][] vectors = new float[10000][2];
Expand All @@ -40,9 +48,10 @@ public void testThreshold() {

var ravv = new ListRandomAccessVectorValues(List.of(vectors), 2);
var builder = new GraphIndexBuilder<>(ravv, VectorEncoding.FLOAT32, VectorSimilarityFunction.EUCLIDEAN, 6, 32, 1.2f, 1.4f);
var graph = builder.build();
var searcher = new GraphSearcher.Builder<>(graph.getView()).build();
var onHeapGraph = builder.build();

// test raw vectors
var searcher = new GraphSearcher.Builder<>(onHeapGraph.getView()).build();
for (int i = 0; i < 10; i++) {
TestParams tp = createTestParams(vectors);

Expand All @@ -52,6 +61,27 @@ public void testThreshold() {
assert result.getVisitedCount() < vectors.length : "visited all vectors for threshold " + tp.th;
assert result.getNodes().length >= 0.9 * tp.exactCount : "returned " + result.getNodes().length + " nodes for threshold " + tp.th + " but should have returned at least " + tp.exactCount;
}

// test compressed
Path outputPath = Files.createTempFile("graph", ".jvector");
TestUtil.writeGraph(onHeapGraph, ravv, outputPath);
var pq = ProductQuantization.compute(ravv, ravv.dimension(), false);
var cv = new PQVectors(pq, pq.encodeAll(List.of(vectors)));

try (var marr = new SimpleMappedReader(outputPath.toAbsolutePath().toString());
var onDiskGraph = new OnDiskGraphIndex<float[]>(marr::duplicate, 0))
{
for (int i = 0; i < 10; i++) {
TestParams tp = createTestParams(vectors);
searcher = new GraphSearcher.Builder<>(onDiskGraph.getView()).build();
NodeSimilarity.ReRanker<float[]> reranker = (j, map) -> VectorSimilarityFunction.EUCLIDEAN.compare(tp.q, map.get(j));
var asf = cv.approximateScoreFunctionFor(tp.q, VectorSimilarityFunction.EUCLIDEAN);
var result = searcher.search(asf, reranker, vectors.length, tp.th, Bits.ALL);

assert result.getVisitedCount() < vectors.length : "visited all vectors for threshold " + tp.th;
}
}

}

// it's not an interesting test if all the vectors are within the threshold
Expand Down

0 comments on commit 3b6bfe6

Please sign in to comment.