Skip to content

Commit

Permalink
Merge pull request #109 from jkni/jvector-43
Browse files Browse the repository at this point in the history
Add improved test coverage for on-disk graph caching
  • Loading branch information
jkni authored Oct 2, 2023
2 parents 96edf4e + b865815 commit 11e2954
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

package io.github.jbellis.jvector;

import io.github.jbellis.jvector.disk.OnDiskGraphIndex;
import io.github.jbellis.jvector.graph.GraphIndex;
import io.github.jbellis.jvector.graph.NodesIterator;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.vector.VectorUtil;

import java.io.BufferedOutputStream;
Expand Down Expand Up @@ -106,6 +108,14 @@ public static byte[] randomVector8(Random random, int dim) {
return bvec;
}

public static <T> void writeGraph(GraphIndex<T> graph, RandomAccessVectorValues<T> vectors, Path outputPath) throws IOException {
try (var indexOutputWriter = openFileForWriting(outputPath))
{
OnDiskGraphIndex.write(graph, vectors, indexOutputWriter);
indexOutputWriter.flush();
}
}

public static class FullyConnectedGraphIndex<T> implements GraphIndex<T> {
private final int entryNode;
private final int size;
Expand Down Expand Up @@ -133,7 +143,7 @@ public View<T> getView() {

@Override
public int maxDegree() {
return size;
return size - 1;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.github.jbellis.jvector.disk;

import com.carrotsearch.randomizedtesting.RandomizedTest;
import io.github.jbellis.jvector.TestUtil;
import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static io.github.jbellis.jvector.TestUtil.writeGraph;
import static org.junit.Assert.*;

public class TestGraphCache extends RandomizedTest {
private Path testDirectory;
private Path onDiskGraphIndexPath;
private RandomAccessVectorValues<float[]> vectors;


@Before
public void setup() throws IOException {
var fullyConnectedGraph = new TestUtil.FullyConnectedGraphIndex<float[]>(0, 6);
vectors = new ListRandomAccessVectorValues(IntStream.range(0, 6).mapToObj(i -> TestUtil.randomVector(getRandom(), 2)).collect(Collectors.toList()), 2);
testDirectory = Files.createTempDirectory(this.getClass().getSimpleName());
onDiskGraphIndexPath = testDirectory.resolve("fullyConnectedGraph");
writeGraph(fullyConnectedGraph, vectors, onDiskGraphIndexPath);
}

@After
public void tearDown() {
TestUtil.deleteQuietly(testDirectory);
}

@Test
public void testGraphCacheLoading() throws Exception {
try (var marr = new SimpleMappedReader(onDiskGraphIndexPath.toAbsolutePath().toString());
var onDiskGraph = new OnDiskGraphIndex<float[]>(marr::duplicate, 0))
{
var none = GraphCache.load(onDiskGraph, -1);
assertEquals(0, none.ramBytesUsed());
assertNull(none.getNode(0));
var zero = GraphCache.load(onDiskGraph, 0);
assertNotNull(zero.getNode(0));
assertNull(zero.getNode(1));
var one = GraphCache.load(onDiskGraph, 1);
// move from caching entry node to entry node + all its neighbors (5)
assertEquals(one.ramBytesUsed(), zero.ramBytesUsed() * (onDiskGraph.size()));
for (int i = 0; i < 6; i++) {
assertArrayEquals(one.getNode(i).vector, vectors.vectorValue(i), 0);
// fully connected,
assertEquals(one.getNode(i).neighbors.length, onDiskGraph.maxDegree());
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ byte[] randomVector(int dim) {

@Override
AbstractMockVectorValues<byte[]> vectorValues(int size, int dimension) {
return MockByteVectorValues.fromValues(GraphIndexTestCase.createRandomByteVectors(size, dimension, getRandom()));
return MockByteVectorValues.fromValues(createRandomByteVectors(size, dimension, getRandom()));
}

static boolean fitsInByte(float v) {
Expand Down Expand Up @@ -87,7 +87,7 @@ AbstractMockVectorValues<byte[]> vectorValues(
int pregeneratedOffset) {
byte[][] vectors = new byte[size][];
byte[][] randomVectors =
GraphIndexTestCase.createRandomByteVectors(size - pregeneratedVectorValues.values.length, dimension, getRandom());
createRandomByteVectors(size - pregeneratedVectorValues.values.length, dimension, getRandom());

for (int i = 0; i < pregeneratedOffset; i++) {
vectors[i] = randomVectors[i];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ float[] randomVector(int dim) {

@Override
AbstractMockVectorValues<float[]> vectorValues(int size, int dimension) {
return MockVectorValues.fromValues(GraphIndexTestCase.createRandomFloatVectors(size, dimension, getRandom()));
return MockVectorValues.fromValues(createRandomFloatVectors(size, dimension, getRandom()));
}

@Override
Expand All @@ -74,7 +74,7 @@ AbstractMockVectorValues<float[]> vectorValues(
int pregeneratedOffset) {
float[][] vectors = new float[size][];
float[][] randomVectors =
GraphIndexTestCase.createRandomFloatVectors(
createRandomFloatVectors(
size - pregeneratedVectorValues.values.length, dimension, getRandom());

for (int i = 0; i < pregeneratedOffset; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,21 +70,13 @@ private static <T> void validateGraph(GraphIndex.View<T> expectedView, GraphInde
}
}

private static <T> void writeGraph(GraphIndex<T> graph, RandomAccessVectorValues<T> vectors, Path outputPath) throws IOException {
try (var indexOutputWriter = TestUtil.openFileForWriting(outputPath))
{
OnDiskGraphIndex.write(graph, vectors, indexOutputWriter);
indexOutputWriter.flush();
}
}

@Test
public void testSimpleGraphs() throws Exception {
for (var graph : List.of(fullyConnectedGraph, randomlyConnectedGraph))
{
var outputPath = testDirectory.resolve("test_graph_" + graph.getClass().getSimpleName());
var ravv = new GraphIndexTestCase.CircularFloatVectorValues(graph.size());
writeGraph(graph, ravv, outputPath);
TestUtil.writeGraph(graph, ravv, outputPath);
try (var marr = new SimpleMappedReader(outputPath.toAbsolutePath().toString());
var onDiskGraph = new OnDiskGraphIndex<float[]>(marr::duplicate, 0);
var onDiskView = onDiskGraph.getView())
Expand All @@ -107,15 +99,18 @@ public void testLargeGraph() throws Exception
var graph = new TestUtil.RandomlyConnectedGraphIndex<float[]>(100_000, 16, getRandom());
var outputPath = testDirectory.resolve("large_graph");
var ravv = new GraphIndexTestCase.CircularFloatVectorValues(graph.size());
writeGraph(graph, ravv, outputPath);
TestUtil.writeGraph(graph, ravv, outputPath);

try (var marr = new SimpleMappedReader(outputPath.toAbsolutePath().toString());
var onDiskGraph = new OnDiskGraphIndex<float[]>(marr::duplicate, 0);
var onDiskView = onDiskGraph.getView())
var onDiskView = onDiskGraph.getView();
var cachedOnDiskGraph = new CachingGraphIndex(onDiskGraph);
var cachedOnDiskView = cachedOnDiskGraph.getView())
{
validateGraph(graph.getView(), onDiskView);
validateGraph(graph.getView(), new CachingGraphIndex(onDiskGraph).getView());
validateGraph(graph.getView(), cachedOnDiskView);
validateVectors(onDiskView, ravv);
validateVectors(cachedOnDiskView, ravv);
}
}
}

0 comments on commit 11e2954

Please sign in to comment.