From 96b614ae9b8beb1c9b03365d01f3f3802998989c Mon Sep 17 00:00:00 2001 From: Jonathan Ellis Date: Tue, 24 Dec 2024 10:01:13 -0600 Subject: [PATCH] MutablePQVectors grows dynamically, this is a better fit for CompactionGraph --- .../jvector/pq/ImmutablePQVectors.java | 5 ++ .../jbellis/jvector/pq/MutablePQVectors.java | 63 ++++++++++++------- .../github/jbellis/jvector/pq/PQVectors.java | 18 +++--- .../github/jbellis/jvector/example/Bench.java | 3 +- .../jbellis/jvector/example/SiftSmall.java | 2 +- .../jvector/pq/TestProductQuantization.java | 5 +- 6 files changed, 60 insertions(+), 36 deletions(-) diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ImmutablePQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ImmutablePQVectors.java index e0902cd7f..fef169e12 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ImmutablePQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ImmutablePQVectors.java @@ -32,4 +32,9 @@ public ImmutablePQVectors(ProductQuantization pq, ByteSequence[] compressedDa this.vectorCount = vectorCount; this.vectorsPerChunk = vectorsPerChunk; } + + @Override + protected int validChunkCount() { + return compressedDataChunks.length; + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/MutablePQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/MutablePQVectors.java index b9e5f165d..6820155c6 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/MutablePQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/MutablePQVectors.java @@ -26,46 +26,61 @@ public class MutablePQVectors extends PQVectors implements MutableCompressedVectors> { private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); + private static final int VECTORS_PER_CHUNK = 1024; + private static final int INITIAL_CHUNKS = 10; + private static final float GROWTH_FACTOR = 1.5f; + /** - * Construct a mutable PQVectors instance with the given ProductQuantization and maximum number of vectors that will be - * stored in this instance. The vectors are split into chunks to avoid exceeding the maximum array size. + * Construct a mutable PQVectors instance with the given ProductQuantization. + * The vectors storage will grow dynamically as needed. * @param pq the ProductQuantization to use - * @param maximumVectorCount the maximum number of vectors that will be stored in this instance */ - public MutablePQVectors(ProductQuantization pq, int maximumVectorCount) { + public MutablePQVectors(ProductQuantization pq) { super(pq); this.vectorCount = 0; - - // Calculate if we need to split into multiple chunks - int compressedDimension = pq.compressedVectorSize(); - long totalSize = (long) maximumVectorCount * compressedDimension; - this.vectorsPerChunk = totalSize <= MAX_CHUNK_SIZE ? maximumVectorCount : MAX_CHUNK_SIZE / compressedDimension; - - int fullSizeChunks = maximumVectorCount / vectorsPerChunk; - int totalChunks = maximumVectorCount % vectorsPerChunk == 0 ? fullSizeChunks : fullSizeChunks + 1; - ByteSequence[] chunks = new ByteSequence[totalChunks]; - int chunkBytes = vectorsPerChunk * compressedDimension; - for (int i = 0; i < fullSizeChunks; i++) - chunks[i] = vectorTypeSupport.createByteSequence(chunkBytes); - - // Last chunk might be smaller - if (totalChunks > fullSizeChunks) { - int remainingVectors = maximumVectorCount % vectorsPerChunk; - chunks[fullSizeChunks] = vectorTypeSupport.createByteSequence(remainingVectors * compressedDimension); - } - - this.compressedDataChunks = chunks; + this.vectorsPerChunk = VECTORS_PER_CHUNK; + this.compressedDataChunks = new ByteSequence[INITIAL_CHUNKS]; } @Override public void encodeAndSet(int ordinal, VectorFloat vector) { + ensureChunkCapacity(ordinal); vectorCount = max(vectorCount, ordinal + 1); pq.encodeTo(vector, get(ordinal)); } @Override public void setZero(int ordinal) { + ensureChunkCapacity(ordinal); vectorCount = max(vectorCount, ordinal + 1); get(ordinal).zero(); } + + private void ensureChunkCapacity(int ordinal) { + int chunkOrdinal = ordinal / vectorsPerChunk; + + // Grow backing array if needed + if (chunkOrdinal >= compressedDataChunks.length) { + int newLength = max(chunkOrdinal + 1, (int)(compressedDataChunks.length * GROWTH_FACTOR)); + ByteSequence[] newChunks = new ByteSequence[newLength]; + System.arraycopy(compressedDataChunks, 0, newChunks, 0, compressedDataChunks.length); + compressedDataChunks = newChunks; + } + + // Allocate all chunks up to and including the required one + int chunkBytes = VECTORS_PER_CHUNK * pq.compressedVectorSize(); + for (int i = validChunkCount(); i <= chunkOrdinal; i++) { + if (compressedDataChunks[i] == null) { + compressedDataChunks[i] = vectorTypeSupport.createByteSequence(chunkBytes); + } + } + } + + @Override + protected int validChunkCount() { + if (vectorCount == 0) + return 0; + int chunkOrdinal = (vectorCount - 1) / vectorsPerChunk; + return chunkOrdinal + 1; + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java index e1d8310a5..14029c9d8 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java @@ -36,7 +36,7 @@ public abstract class PQVectors implements CompressedVectors { private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); - static final int MAX_CHUNK_SIZE = Integer.MAX_VALUE - 16; // standard Java array size limit with some headroom + private static final int MAX_CHUNK_SIZE = Integer.MAX_VALUE - 16; // standard Java array size limit with some headroom final ProductQuantization pq; protected ByteSequence[] compressedDataChunks; @@ -148,11 +148,16 @@ public void write(DataOutput out, int version) throws IOException // compressed vectors out.writeInt(vectorCount); out.writeInt(pq.getSubspaceCount()); - for (ByteSequence chunk : compressedDataChunks) { - vectorTypeSupport.writeByteSequence(out, chunk); + for (int i = 0; i < validChunkCount(); i++) { + vectorTypeSupport.writeByteSequence(out, compressedDataChunks[i]); } } + /** + * @return the number of chunks that have actually been allocated (<= compressedDataChunks.length) + */ + protected abstract int validChunkCount(); + /** * We consider two PQVectors equal when their PQs are equal and their compressed data is equal. We ignore the * chunking strategy in the comparison since this is an implementation detail. @@ -303,10 +308,10 @@ public long ramBytesUsed() { int AH_BYTES = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; long codebooksSize = pq.ramBytesUsed(); - long chunksArraySize = OH_BYTES + AH_BYTES + (long) compressedDataChunks.length * REF_BYTES; + long chunksArraySize = OH_BYTES + AH_BYTES + (long) validChunkCount() * REF_BYTES; long dataSize = 0; - for (ByteSequence chunk : compressedDataChunks) { - dataSize += chunk.ramBytesUsed(); + for (int i = 0; i < validChunkCount(); i++) { + dataSize += compressedDataChunks[i].ramBytesUsed(); } return codebooksSize + chunksArraySize + dataSize; } @@ -318,5 +323,4 @@ public String toString() { ", count=" + vectorCount + '}'; } - } diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java index ae7c24c36..d0aedae88 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java @@ -55,8 +55,7 @@ public static void main(String[] args) throws IOException { ds -> new PQParameters(ds.getDimension() / 8, 256, ds.similarityFunction == VectorSimilarityFunction.EUCLIDEAN, UNWEIGHTED) ); List> featureSets = Arrays.asList( - EnumSet.of(FeatureId.INLINE_VECTORS), - EnumSet.of(FeatureId.INLINE_VECTORS, FeatureId.FUSED_ADC) + EnumSet.of(FeatureId.INLINE_VECTORS) ); // args is list of regexes, possibly needing to be split by whitespace. diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java index 4401be5e5..a216a4803 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java @@ -214,7 +214,7 @@ public static void siftDiskAnnLTM(List> baseVectors, List[] vectors = generate(2 * DEFAULT_CLUSTERS, 2, 1_000); var ravv = new ListRandomAccessVectorValues(List.of(vectors), vectors[0].length()); var pq = ProductQuantization.compute(ravv, 1, DEFAULT_CLUSTERS, false); - var pqv = new MutablePQVectors(pq, Integer.MAX_VALUE); + var pqv = new MutablePQVectors(pq); + // force allocation of a lot of backing storage + pqv.setZero(Integer.MAX_VALUE - 1); // write it out and load it, it's okay that it's zeros - pqv.setZero(Integer.MAX_VALUE - 1); // sets internal count var fileOut = File.createTempFile("pqtest", ".pq"); try (var out = new DataOutputStream(new FileOutputStream(fileOut))) { pqv.write(out);