diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java index 14029c9d..929cbde5 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java @@ -16,6 +16,7 @@ package io.github.jbellis.jvector.pq; +import io.github.jbellis.jvector.annotations.VisibleForTesting; import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; @@ -36,7 +37,7 @@ public abstract class PQVectors implements CompressedVectors { private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); - private static final int MAX_CHUNK_SIZE = Integer.MAX_VALUE - 16; // standard Java array size limit with some headroom + static final int MAX_CHUNK_SIZE = Integer.MAX_VALUE - 16; // standard Java array size limit with some headroom final ProductQuantization pq; protected ByteSequence[] compressedDataChunks; @@ -52,22 +53,15 @@ public static ImmutablePQVectors load(RandomAccessReader in) throws IOException var pq = ProductQuantization.load(in); // read the vectors - int vectorCount = in.readInt(); - if (vectorCount < 0) { - throw new IOException("Invalid compressed vector count " + vectorCount); - } - + int vectorCount = in.readInt(); int compressedDimension = in.readInt(); - if (compressedDimension < 0) { - throw new IOException("Invalid compressed vector dimension " + compressedDimension); - } - - // Calculate if we need to split into multiple chunks - long totalSize = (long) vectorCount * compressedDimension; - int vectorsPerChunk = totalSize <= PQVectors.MAX_CHUNK_SIZE ? vectorCount : PQVectors.MAX_CHUNK_SIZE / compressedDimension; - - int fullSizeChunks = vectorCount / vectorsPerChunk; - int totalChunks = vectorCount % vectorsPerChunk == 0 ? fullSizeChunks : fullSizeChunks + 1; + + int[] params = calculateChunkParameters(vectorCount, compressedDimension); + int vectorsPerChunk = params[0]; + int totalChunks = params[1]; + int fullSizeChunks = params[2]; + int remainingVectors = params[3]; + ByteSequence[] chunks = new ByteSequence[totalChunks]; int chunkBytes = vectorsPerChunk * compressedDimension; @@ -77,13 +71,38 @@ public static ImmutablePQVectors load(RandomAccessReader in) throws IOException // Last chunk might be smaller if (totalChunks > fullSizeChunks) { - int remainingVectors = vectorCount % vectorsPerChunk; chunks[fullSizeChunks] = vectorTypeSupport.readByteSequence(in, remainingVectors * compressedDimension); } return new ImmutablePQVectors(pq, chunks, vectorCount, vectorsPerChunk); } + /** + * Calculate chunking parameters for the given vector count and compressed dimension + * @return array of [vectorsPerChunk, totalChunks, fullSizeChunks, remainingVectors] + */ + @VisibleForTesting + static int[] calculateChunkParameters(int vectorCount, int compressedDimension) { + if (vectorCount < 0) { + throw new IllegalArgumentException("Invalid vector count " + vectorCount); + } + if (compressedDimension < 0) { + throw new IllegalArgumentException("Invalid compressed dimension " + compressedDimension); + } + + long totalSize = (long) vectorCount * compressedDimension; + int vectorsPerChunk = totalSize <= MAX_CHUNK_SIZE ? vectorCount : MAX_CHUNK_SIZE / compressedDimension; + if (vectorsPerChunk == 0) { + throw new IllegalArgumentException("Compressed dimension " + compressedDimension + " too large for chunking"); + } + + int fullSizeChunks = vectorCount / vectorsPerChunk; + int totalChunks = vectorCount % vectorsPerChunk == 0 ? fullSizeChunks : fullSizeChunks + 1; + + int remainingVectors = vectorCount % vectorsPerChunk; + return new int[] {vectorsPerChunk, totalChunks, fullSizeChunks, remainingVectors}; + } + public static PQVectors load(RandomAccessReader in, long offset) throws IOException { in.seek(offset); return load(in); diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestProductQuantization.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestProductQuantization.java index 26ce9581..0c9a0dde 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestProductQuantization.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestProductQuantization.java @@ -19,7 +19,6 @@ import com.carrotsearch.randomizedtesting.RandomizedTest; import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import io.github.jbellis.jvector.disk.SimpleMappedReader; -import io.github.jbellis.jvector.disk.SimpleReader; import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.VectorUtil; @@ -32,12 +31,9 @@ import java.io.DataOutputStream; import java.io.File; import java.io.FileOutputStream; -import java.io.IOException; import java.nio.file.Files; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.concurrent.ThreadLocalRandom; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -49,8 +45,8 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; -import static org.junit.jupiter.api.Assertions.assertTrue; @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class TestProductQuantization extends RandomizedTest { @@ -157,7 +153,7 @@ public void testConvergenceAnisotropic() { } // System.out.println("iterations=" + iterations); - assertTrue(improvedLoss < initialLoss, "improvedLoss=" + improvedLoss + " initialLoss=" + initialLoss); + assertTrue(improvedLoss < initialLoss); } /** @@ -253,25 +249,61 @@ public void testSaveVersion0() throws Exception { assertArrayEquals(contents1, contents2); } - @Test - public void testPQVectorsAllocation() throws IOException { - // test that MPVQ gets the math right in an allocation edge case - var R = getRandom(); - VectorFloat[] vectors = generate(2 * DEFAULT_CLUSTERS, 2, 1_000); - var ravv = new ListRandomAccessVectorValues(List.of(vectors), vectors[0].length()); - var pq = ProductQuantization.compute(ravv, 1, DEFAULT_CLUSTERS, false); - var pqv = new MutablePQVectors(pq); - // force allocation of a lot of backing storage - pqv.setZero(Integer.MAX_VALUE - 1); + private void validateChunkMath(int[] params, int expectedTotalVectors, int dimension) { + int vectorsPerChunk = params[0]; + int totalChunks = params[1]; + int fullSizeChunks = params[2]; + int remainingVectors = params[3]; + + // Basic parameter validation + assertTrue("vectorsPerChunk must be positive", vectorsPerChunk > 0); + assertTrue("totalChunks must be positive", totalChunks > 0); + assertTrue("fullSizeChunks must be non-negative", fullSizeChunks >= 0); + assertTrue("remainingVectors must be non-negative", remainingVectors >= 0); + assertTrue("fullSizeChunks must not exceed totalChunks", fullSizeChunks <= totalChunks); + assertTrue("remainingVectors must be less than vectorsPerChunk", remainingVectors < vectorsPerChunk); + + // Chunk size validation + assertTrue("Chunk size must not exceed MAX_CHUNK_SIZE", + (long) vectorsPerChunk * dimension <= PQVectors.MAX_CHUNK_SIZE); + + // Total vectors validation + long calculatedTotal = (long) fullSizeChunks * vectorsPerChunk + remainingVectors; + assertEquals("Total vectors must match expected count", + expectedTotalVectors, calculatedTotal); + + // Chunk count validation + assertEquals("Total chunks must match full + partial chunks", + totalChunks, fullSizeChunks + (remainingVectors > 0 ? 1 : 0)); + } - // write it out and load it, it's okay that it's zeros - var fileOut = File.createTempFile("pqtest", ".pq"); - try (var out = new DataOutputStream(new FileOutputStream(fileOut))) { - pqv.write(out); - } - // exercise the load() allocation path - try (var in = new SimpleReader(fileOut.toPath())) { - var pqv2 = PQVectors.load(in); - } + @Test + public void testPQVectorsChunkCalculation() { + // Test normal case + int[] params = PQVectors.calculateChunkParameters(1000, 8); + validateChunkMath(params, 1000, 8); + assertEquals(1000, params[0]); // vectorsPerChunk + assertEquals(1, params[1]); // numChunks + assertEquals(1, params[2]); // fullSizeChunks + assertEquals(0, params[3]); // remainingVectors + + // Test case requiring multiple chunks + int bigVectorCount = Integer.MAX_VALUE - 1; + int smallDim = 8; + params = PQVectors.calculateChunkParameters(bigVectorCount, smallDim); + validateChunkMath(params, bigVectorCount, smallDim); + assertTrue(params[0] > 0); + assertTrue(params[1] > 1); + + // Test edge case with large dimension + int smallVectorCount = 1000; + int bigDim = Integer.MAX_VALUE / 2; + params = PQVectors.calculateChunkParameters(smallVectorCount, bigDim); + validateChunkMath(params, smallVectorCount, bigDim); + assertTrue(params[0] > 0); + + // Test invalid inputs + assertThrows(IllegalArgumentException.class, () -> PQVectors.calculateChunkParameters(-1, 8)); + assertThrows(IllegalArgumentException.class, () -> PQVectors.calculateChunkParameters(100, -1)); } }