Skip to content

Commit

Permalink
replace test that allocated multiple GB of PQVectors with calculateChunkParameters, this makes JUnit's small VMs happy
Browse files Browse the repository at this point in the history
  • Loading branch information
jbellis committed Dec 24, 2024
1 parent 96b614a commit ac6caa1
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package io.github.jbellis.jvector.pq;

import io.github.jbellis.jvector.annotations.VisibleForTesting;
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
Expand All @@ -36,7 +37,7 @@

public abstract class PQVectors implements CompressedVectors {
private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport();
private static final int MAX_CHUNK_SIZE = Integer.MAX_VALUE - 16; // standard Java array size limit with some headroom
static final int MAX_CHUNK_SIZE = Integer.MAX_VALUE - 16; // standard Java array size limit with some headroom

final ProductQuantization pq;
protected ByteSequence<?>[] compressedDataChunks;
Expand All @@ -52,22 +53,15 @@ public static ImmutablePQVectors load(RandomAccessReader in) throws IOException
var pq = ProductQuantization.load(in);

// read the vectors
int vectorCount = in.readInt();
if (vectorCount < 0) {
throw new IOException("Invalid compressed vector count " + vectorCount);
}

int vectorCount = in.readInt();
int compressedDimension = in.readInt();
if (compressedDimension < 0) {
throw new IOException("Invalid compressed vector dimension " + compressedDimension);
}

// Calculate if we need to split into multiple chunks
long totalSize = (long) vectorCount * compressedDimension;
int vectorsPerChunk = totalSize <= PQVectors.MAX_CHUNK_SIZE ? vectorCount : PQVectors.MAX_CHUNK_SIZE / compressedDimension;

int fullSizeChunks = vectorCount / vectorsPerChunk;
int totalChunks = vectorCount % vectorsPerChunk == 0 ? fullSizeChunks : fullSizeChunks + 1;

int[] params = calculateChunkParameters(vectorCount, compressedDimension);
int vectorsPerChunk = params[0];
int totalChunks = params[1];
int fullSizeChunks = params[2];
int remainingVectors = params[3];

ByteSequence<?>[] chunks = new ByteSequence<?>[totalChunks];
int chunkBytes = vectorsPerChunk * compressedDimension;

Expand All @@ -77,13 +71,38 @@ public static ImmutablePQVectors load(RandomAccessReader in) throws IOException

// Last chunk might be smaller
if (totalChunks > fullSizeChunks) {
int remainingVectors = vectorCount % vectorsPerChunk;
chunks[fullSizeChunks] = vectorTypeSupport.readByteSequence(in, remainingVectors * compressedDimension);
}

return new ImmutablePQVectors(pq, chunks, vectorCount, vectorsPerChunk);
}

/**
 * Calculate chunking parameters for the given vector count and compressed dimension.
 * <p>
 * Compressed vectors are stored across one or more backing chunks, each at most
 * MAX_CHUNK_SIZE bytes so that every chunk fits in a standard Java array.
 *
 * @param vectorCount number of compressed vectors; may be zero
 * @param compressedDimension number of bytes per compressed vector
 * @return array of [vectorsPerChunk, totalChunks, fullSizeChunks, remainingVectors]
 * @throws IllegalArgumentException if either argument is negative, or if a single
 *         vector of {@code compressedDimension} bytes cannot fit in one chunk
 */
@VisibleForTesting
static int[] calculateChunkParameters(int vectorCount, int compressedDimension) {
    if (vectorCount < 0) {
        throw new IllegalArgumentException("Invalid vector count " + vectorCount);
    }
    if (compressedDimension < 0) {
        throw new IllegalArgumentException("Invalid compressed dimension " + compressedDimension);
    }
    // A single vector must fit within one chunk, or no chunking scheme can work
    if (compressedDimension > MAX_CHUNK_SIZE) {
        throw new IllegalArgumentException("Compressed dimension " + compressedDimension + " too large for chunking");
    }
    // Edge case: no vectors at all. Report a positive vectorsPerChunk so callers
    // never divide by zero, with zero chunks to allocate. (Previously this fell
    // into the vectorsPerChunk == 0 path and threw a misleading exception.)
    if (vectorCount == 0) {
        return new int[] {1, 0, 0, 0};
    }

    long totalSize = (long) vectorCount * compressedDimension;
    // If everything fits in one chunk, use a single chunk holding all vectors;
    // otherwise pack as many whole vectors per chunk as MAX_CHUNK_SIZE allows.
    int vectorsPerChunk = totalSize <= MAX_CHUNK_SIZE ? vectorCount : MAX_CHUNK_SIZE / compressedDimension;

    int fullSizeChunks = vectorCount / vectorsPerChunk;
    int remainingVectors = vectorCount % vectorsPerChunk;
    int totalChunks = remainingVectors == 0 ? fullSizeChunks : fullSizeChunks + 1;

    return new int[] {vectorsPerChunk, totalChunks, fullSizeChunks, remainingVectors};
}

public static PQVectors load(RandomAccessReader in, long offset) throws IOException {
in.seek(offset);
return load(in);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import com.carrotsearch.randomizedtesting.RandomizedTest;
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
import io.github.jbellis.jvector.disk.SimpleMappedReader;
import io.github.jbellis.jvector.disk.SimpleReader;
import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
Expand All @@ -32,12 +31,9 @@
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

Expand All @@ -49,8 +45,8 @@
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertTrue;

@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
public class TestProductQuantization extends RandomizedTest {
Expand Down Expand Up @@ -157,7 +153,7 @@ public void testConvergenceAnisotropic() {
}
// System.out.println("iterations=" + iterations);

assertTrue(improvedLoss < initialLoss, "improvedLoss=" + improvedLoss + " initialLoss=" + initialLoss);
assertTrue(improvedLoss < initialLoss);
}

/**
Expand Down Expand Up @@ -253,25 +249,61 @@ public void testSaveVersion0() throws Exception {
assertArrayEquals(contents1, contents2);
}

@Test
public void testPQVectorsAllocation() throws IOException {
// test that MPVQ gets the math right in an allocation edge case
var R = getRandom();
VectorFloat<?>[] vectors = generate(2 * DEFAULT_CLUSTERS, 2, 1_000);
var ravv = new ListRandomAccessVectorValues(List.of(vectors), vectors[0].length());
var pq = ProductQuantization.compute(ravv, 1, DEFAULT_CLUSTERS, false);
var pqv = new MutablePQVectors(pq);
// force allocation of a lot of backing storage
pqv.setZero(Integer.MAX_VALUE - 1);
// Sanity-checks the invariants that any chunking of expectedTotalVectors vectors
// of the given dimension must satisfy: positive chunk sizes, chunks under the
// array-size cap, and counts that add back up to the expected total.
private void validateChunkMath(int[] params, int expectedTotalVectors, int dimension) {
    final int vectorsPerChunk = params[0];
    final int totalChunks = params[1];
    final int fullSizeChunks = params[2];
    final int remainingVectors = params[3];

    // Individual parameters must be sane on their own
    assertTrue("vectorsPerChunk must be positive", vectorsPerChunk > 0);
    assertTrue("totalChunks must be positive", totalChunks > 0);
    assertTrue("fullSizeChunks must be non-negative", fullSizeChunks >= 0);
    assertTrue("remainingVectors must be non-negative", remainingVectors >= 0);
    assertTrue("fullSizeChunks must not exceed totalChunks", fullSizeChunks <= totalChunks);
    assertTrue("remainingVectors must be less than vectorsPerChunk", remainingVectors < vectorsPerChunk);

    // No chunk may exceed the array-size cap
    final long chunkBytes = (long) vectorsPerChunk * dimension;
    assertTrue("Chunk size must not exceed MAX_CHUNK_SIZE",
               chunkBytes <= PQVectors.MAX_CHUNK_SIZE);

    // Full chunks plus the partial remainder must account for every vector
    final long reconstructedTotal = (long) fullSizeChunks * vectorsPerChunk + remainingVectors;
    assertEquals("Total vectors must match expected count",
                 expectedTotalVectors, reconstructedTotal);

    // A partial chunk exists exactly when there is a remainder
    final int partialChunks = remainingVectors > 0 ? 1 : 0;
    assertEquals("Total chunks must match full + partial chunks",
                 totalChunks, fullSizeChunks + partialChunks);
}

// write it out and load it, it's okay that it's zeros
var fileOut = File.createTempFile("pqtest", ".pq");
try (var out = new DataOutputStream(new FileOutputStream(fileOut))) {
pqv.write(out);
}
// exercise the load() allocation path
try (var in = new SimpleReader(fileOut.toPath())) {
var pqv2 = PQVectors.load(in);
}
@Test
public void testPQVectorsChunkCalculation() {
    // Single-chunk case: everything fits comfortably under MAX_CHUNK_SIZE
    int[] result = PQVectors.calculateChunkParameters(1000, 8);
    validateChunkMath(result, 1000, 8);
    assertEquals(1000, result[0]); // vectorsPerChunk
    assertEquals(1, result[1]);    // numChunks
    assertEquals(1, result[2]);    // fullSizeChunks
    assertEquals(0, result[3]);    // remainingVectors

    // A huge vector count forces the data to be split across several chunks
    int hugeCount = Integer.MAX_VALUE - 1;
    int tinyDim = 8;
    result = PQVectors.calculateChunkParameters(hugeCount, tinyDim);
    validateChunkMath(result, hugeCount, tinyDim);
    assertTrue(result[0] > 0);
    assertTrue(result[1] > 1);

    // Very wide vectors: only a handful fit in each chunk
    int tinyCount = 1000;
    int hugeDim = Integer.MAX_VALUE / 2;
    result = PQVectors.calculateChunkParameters(tinyCount, hugeDim);
    validateChunkMath(result, tinyCount, hugeDim);
    assertTrue(result[0] > 0);

    // Negative inputs are rejected up front
    assertThrows(IllegalArgumentException.class, () -> PQVectors.calculateChunkParameters(-1, 8));
    assertThrows(IllegalArgumentException.class, () -> PQVectors.calculateChunkParameters(100, -1));
}
}

0 comments on commit ac6caa1

Please sign in to comment.