From b865815e63b1480df9499f6668da5b4370845dba Mon Sep 17 00:00:00 2001 From: Joel Knighton Date: Mon, 2 Oct 2023 14:02:46 -0500 Subject: [PATCH] Add improved test coverage for on-disk graph caching --- .../io/github/jbellis/jvector/TestUtil.java | 12 ++- .../jbellis/jvector/disk/TestGraphCache.java | 77 +++++++++++++++++++ .../jvector/graph/TestByteVectorGraph.java | 4 +- .../jvector/graph/TestFloatVectorGraph.java | 4 +- .../jvector/graph/TestOnDiskGraphIndex.java | 19 ++--- 5 files changed, 99 insertions(+), 17 deletions(-) create mode 100644 jvector-tests/src/test/java/io/github/jbellis/jvector/disk/TestGraphCache.java diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java index 9f95c9bbe..2c1c0eaa0 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java @@ -16,8 +16,10 @@ package io.github.jbellis.jvector; +import io.github.jbellis.jvector.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.GraphIndex; import io.github.jbellis.jvector.graph.NodesIterator; +import io.github.jbellis.jvector.graph.RandomAccessVectorValues; import io.github.jbellis.jvector.vector.VectorUtil; import java.io.BufferedOutputStream; @@ -106,6 +108,14 @@ public static byte[] randomVector8(Random random, int dim) { return bvec; } + public static void writeGraph(GraphIndex graph, RandomAccessVectorValues vectors, Path outputPath) throws IOException { + try (var indexOutputWriter = openFileForWriting(outputPath)) + { + OnDiskGraphIndex.write(graph, vectors, indexOutputWriter); + indexOutputWriter.flush(); + } + } + public static class FullyConnectedGraphIndex implements GraphIndex { private final int entryNode; private final int size; @@ -133,7 +143,7 @@ public View getView() { @Override public int maxDegree() { - return size; + return size - 1; } @Override diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/disk/TestGraphCache.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/disk/TestGraphCache.java new file mode 100644 index 000000000..13fe5655e --- /dev/null +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/disk/TestGraphCache.java @@ -0,0 +1,77 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.disk; + +import com.carrotsearch.randomizedtesting.RandomizedTest; +import io.github.jbellis.jvector.TestUtil; +import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; +import io.github.jbellis.jvector.graph.RandomAccessVectorValues; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static io.github.jbellis.jvector.TestUtil.writeGraph; +import static org.junit.Assert.*; + +public class TestGraphCache extends RandomizedTest { + private Path testDirectory; + private Path onDiskGraphIndexPath; + private RandomAccessVectorValues vectors; + + + @Before + public void setup() throws IOException { + var fullyConnectedGraph = new TestUtil.FullyConnectedGraphIndex(0, 6); + vectors = new ListRandomAccessVectorValues(IntStream.range(0, 6).mapToObj(i -> TestUtil.randomVector(getRandom(), 2)).collect(Collectors.toList()), 2); + testDirectory = Files.createTempDirectory(this.getClass().getSimpleName()); + onDiskGraphIndexPath = testDirectory.resolve("fullyConnectedGraph"); + writeGraph(fullyConnectedGraph, vectors, onDiskGraphIndexPath); + } + + @After + public void tearDown() { + TestUtil.deleteQuietly(testDirectory); + } + + @Test + public void testGraphCacheLoading() throws Exception { + try (var marr = new SimpleMappedReader(onDiskGraphIndexPath.toAbsolutePath().toString()); + var onDiskGraph = new OnDiskGraphIndex(marr::duplicate, 0)) + { + var none = GraphCache.load(onDiskGraph, -1); + assertEquals(0, none.ramBytesUsed()); + assertNull(none.getNode(0)); + var zero = GraphCache.load(onDiskGraph, 0); + assertNotNull(zero.getNode(0)); + assertNull(zero.getNode(1)); + var one = GraphCache.load(onDiskGraph, 1); + // move from caching entry node to entry node + all its neighbors (5) + assertEquals(one.ramBytesUsed(), zero.ramBytesUsed() * (onDiskGraph.size())); + for (int i = 0; i < 6; i++) { + assertArrayEquals(one.getNode(i).vector, vectors.vectorValue(i), 0); + // fully connected, + assertEquals(one.getNode(i).neighbors.length, onDiskGraph.maxDegree()); + } + } + } +} \ No newline at end of file diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestByteVectorGraph.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestByteVectorGraph.java index a5ce29b4e..f3f561cdb 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestByteVectorGraph.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestByteVectorGraph.java @@ -51,7 +51,7 @@ byte[] randomVector(int dim) { @Override AbstractMockVectorValues vectorValues(int size, int dimension) { - return MockByteVectorValues.fromValues(GraphIndexTestCase.createRandomByteVectors(size, dimension, getRandom())); + return MockByteVectorValues.fromValues(createRandomByteVectors(size, dimension, getRandom())); } static boolean fitsInByte(float v) { @@ -87,7 +87,7 @@ AbstractMockVectorValues vectorValues( int pregeneratedOffset) { byte[][] vectors = new byte[size][]; byte[][] randomVectors = - GraphIndexTestCase.createRandomByteVectors(size - pregeneratedVectorValues.values.length, dimension, getRandom()); + createRandomByteVectors(size - pregeneratedVectorValues.values.length, dimension, getRandom()); for (int i = 0; i < pregeneratedOffset; i++) { vectors[i] = randomVectors[i]; diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestFloatVectorGraph.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestFloatVectorGraph.java index 8b878f721..e1978eafc 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestFloatVectorGraph.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestFloatVectorGraph.java @@ -58,7 +58,7 @@ float[] randomVector(int dim) { @Override AbstractMockVectorValues vectorValues(int size, int dimension) { - return MockVectorValues.fromValues(GraphIndexTestCase.createRandomFloatVectors(size, dimension, getRandom())); + return MockVectorValues.fromValues(createRandomFloatVectors(size, dimension, getRandom())); } @Override @@ -74,7 +74,7 @@ AbstractMockVectorValues vectorValues( int pregeneratedOffset) { float[][] vectors = new float[size][]; float[][] randomVectors = - GraphIndexTestCase.createRandomFloatVectors( + createRandomFloatVectors( size - pregeneratedVectorValues.values.length, dimension, getRandom()); for (int i = 0; i < pregeneratedOffset; i++) { diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestOnDiskGraphIndex.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestOnDiskGraphIndex.java index 5c7e37bfc..16272dc7c 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestOnDiskGraphIndex.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestOnDiskGraphIndex.java @@ -70,21 +70,13 @@ private static void validateGraph(GraphIndex.View expectedView, GraphInde } } - private static void writeGraph(GraphIndex graph, RandomAccessVectorValues vectors, Path outputPath) throws IOException { - try (var indexOutputWriter = TestUtil.openFileForWriting(outputPath)) - { - OnDiskGraphIndex.write(graph, vectors, indexOutputWriter); - indexOutputWriter.flush(); - } - } - @Test public void testSimpleGraphs() throws Exception { for (var graph : List.of(fullyConnectedGraph, randomlyConnectedGraph)) { var outputPath = testDirectory.resolve("test_graph_" + graph.getClass().getSimpleName()); var ravv = new GraphIndexTestCase.CircularFloatVectorValues(graph.size()); - writeGraph(graph, ravv, outputPath); + TestUtil.writeGraph(graph, ravv, outputPath); try (var marr = new SimpleMappedReader(outputPath.toAbsolutePath().toString()); var onDiskGraph = new OnDiskGraphIndex(marr::duplicate, 0); var onDiskView = onDiskGraph.getView()) @@ -107,15 +99,18 @@ public void testLargeGraph() throws Exception var graph = new TestUtil.RandomlyConnectedGraphIndex(100_000, 16, getRandom()); var outputPath = testDirectory.resolve("large_graph"); var ravv = new GraphIndexTestCase.CircularFloatVectorValues(graph.size()); - writeGraph(graph, ravv, outputPath); + TestUtil.writeGraph(graph, ravv, outputPath); try (var marr = new SimpleMappedReader(outputPath.toAbsolutePath().toString()); var onDiskGraph = new OnDiskGraphIndex(marr::duplicate, 0); - var onDiskView = onDiskGraph.getView()) + var onDiskView = onDiskGraph.getView(); + var cachedOnDiskGraph = new CachingGraphIndex(onDiskGraph); + var cachedOnDiskView = cachedOnDiskGraph.getView()) { validateGraph(graph.getView(), onDiskView); - validateGraph(graph.getView(), new CachingGraphIndex(onDiskGraph).getView()); + validateGraph(graph.getView(), cachedOnDiskView); validateVectors(onDiskView, ravv); + validateVectors(cachedOnDiskView, ravv); } } }