diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index ffd4028de..49269b00d 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -33,6 +33,7 @@ import java.util.Random; import java.util.Set; import java.util.concurrent.ConcurrentSkipListSet; +import java.util.concurrent.ForkJoinPool; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.IntStream; @@ -69,6 +70,9 @@ public class GraphIndexBuilder<T> { private final int dimension; // for convenience so we don't have to go to the pool for this private final NodeSimilarity similarity; + private final ForkJoinPool simdExecutor; + private final ForkJoinPool parallelExecutor; + private final AtomicInteger updateEntryNodeIn = new AtomicInteger(10_000); /** @@ -93,6 +97,37 @@ public GraphIndexBuilder( int beamWidth, float neighborOverflow, float alpha) { + this(vectorValues, vectorEncoding, similarityFunction, M, beamWidth, neighborOverflow, alpha, + PhysicalCoreExecutor.pool(), ForkJoinPool.commonPool()); + } + + /** + * Reads all the vectors from vector values, builds a graph connecting them by their dense + * ordinals, using the given hyperparameter settings, and returns the resulting graph. + * + * @param vectorValues the vectors whose relations are represented by the graph - must provide a + * different view over those vectors than the one used to add via addGraphNode. + * @param M – the maximum number of connections a node can have + * @param beamWidth the size of the beam search to use when finding nearest neighbors. + * @param neighborOverflow the ratio of extra neighbors to allow temporarily when inserting a + * node. larger values will build more efficiently, but use more memory. 
+ * @param alpha how aggressive pruning diverse neighbors should be. Set alpha > 1.0 to + * allow longer edges. If alpha = 1.0 then the equivalent of the lowest level of + * an HNSW graph will be created, which is usually not what you want. + * @param simdExecutor ForkJoinPool instance for SIMD operations; ideally a pool sized to + * the number of physical cores. + * @param parallelExecutor ForkJoinPool instance for parallel stream operations + */ + public GraphIndexBuilder( + RandomAccessVectorValues<T> vectorValues, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction, + int M, + int beamWidth, + float neighborOverflow, + float alpha, + ForkJoinPool simdExecutor, + ForkJoinPool parallelExecutor) { vectors = vectorValues.isValueShared() ? PoolingSupport.newThreadBased(vectorValues::copy) : PoolingSupport.newNoPooling(vectorValues); vectorsCopy = vectorValues.isValueShared() ? PoolingSupport.newThreadBased(vectorValues::copy) : PoolingSupport.newNoPooling(vectorValues); dimension = vectorValues.dimension(); @@ -107,6 +142,8 @@ public GraphIndexBuilder( throw new IllegalArgumentException("beamWidth must be positive"); } this.beamWidth = beamWidth; + this.simdExecutor = simdExecutor; + this.parallelExecutor = parallelExecutor; similarity = node1 -> { try (var v = vectors.get(); var vc = vectorsCopy.get()) { @@ -130,13 +167,13 @@ public OnHeapGraphIndex<T> build() { size = v.get().size(); } - PhysicalCoreExecutor.instance.execute(() -> { + simdExecutor.submit(() -> { IntStream.range(0, size).parallel().forEach(i -> { try (var v1 = vectors.get()) { addGraphNode(i, v1.get()); } }); - }); + }).join(); cleanup(); return graph; @@ -163,12 +200,12 @@ public void cleanup() { removeDeletedNodes(); // clean up overflowed neighbor lists - IntStream.range(0, graph.getIdUpperBound()).parallel().forEach(i -> { + parallelExecutor.submit(() -> IntStream.range(0, graph.getIdUpperBound()).parallel().forEach(i -> { var neighbors = 
graph.getNeighbors(i); if (neighbors != null) { neighbors.cleanup(); } - }); + })).join(); // reconnect any orphaned nodes. this will maintain neighbors size reconnectOrphanedNodes(); @@ -187,9 +224,9 @@ private void reconnectOrphanedNodes() { var connectedNodes = new AtomicFixedBitSet(graph.getIdUpperBound()); connectedNodes.set(graph.entry()); var entryNeighbors = graph.getNeighbors(graph.entry()).getCurrent(); - IntStream.range(0, entryNeighbors.size).parallel().forEach(node -> { + parallelExecutor.submit(() -> IntStream.range(0, entryNeighbors.size).parallel().forEach(node -> { findConnected(connectedNodes, entryNeighbors.node[node]); - }); + })).join(); // reconnect unreachable nodes var nReconnected = new AtomicInteger(); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/BinaryQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/BinaryQuantization.java index f295c4e48..ea1c777d4 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/BinaryQuantization.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/BinaryQuantization.java @@ -19,7 +19,6 @@ import io.github.jbellis.jvector.disk.Io; import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; -import io.github.jbellis.jvector.util.PhysicalCoreExecutor; import io.github.jbellis.jvector.util.PoolingSupport; import io.github.jbellis.jvector.vector.VectorUtil; @@ -27,6 +26,7 @@ import java.io.IOException; import java.util.Arrays; import java.util.List; +import java.util.concurrent.ForkJoinPool; import java.util.concurrent.ThreadLocalRandom; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -45,10 +45,14 @@ public BinaryQuantization(float[] globalCentroid) { } public static BinaryQuantization compute(RandomAccessVectorValues<float[]> ravv) { + return compute(ravv, ForkJoinPool.commonPool()); + } + + public static BinaryQuantization 
compute(RandomAccessVectorValues<float[]> ravv, ForkJoinPool parallelExecutor) { // limit the number of vectors we train on var P = min(1.0f, ProductQuantization.MAX_PQ_TRAINING_SET_SIZE / (float) ravv.size()); var ravvCopy = ravv.isValueShared() ? PoolingSupport.newThreadBased(ravv::copy) : PoolingSupport.newNoPooling(ravv); - var vectors = IntStream.range(0, ravv.size()).parallel() + var vectors = parallelExecutor.submit(() -> IntStream.range(0, ravv.size()).parallel() .filter(i -> ThreadLocalRandom.current().nextFloat() < P) .mapToObj(targetOrd -> { try (var pooledRavv = ravvCopy.get()) { @@ -57,7 +61,8 @@ public static BinaryQuantization compute(RandomAccessVectorValues<float[]> ravv) return localRavv.isValueShared() ? Arrays.copyOf(v, v.length) : v; } }) - .collect(Collectors.toList()); + .collect(Collectors.toList())) + .join(); // compute the centroid of the training set float[] globalCentroid = KMeansPlusPlusClusterer.centroidOf(vectors); @@ -70,8 +75,8 @@ public CompressedVectors createCompressedVectors(Object[] compressedVectors) { } @Override - public long[][] encodeAll(List<float[]> vectors) { - return PhysicalCoreExecutor.instance.submit(() -> vectors.stream().parallel().map(this::encode).toArray(long[][]::new)); + public long[][] encodeAll(List<float[]> vectors, ForkJoinPool simdExecutor) { + return simdExecutor.submit(() -> vectors.stream().parallel().map(this::encode).toArray(long[][]::new)).join(); } /** diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java index cf5c32c57..550706543 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java @@ -19,9 +19,9 @@ import io.github.jbellis.jvector.disk.Io; import io.github.jbellis.jvector.disk.RandomAccessReader; import 
io.github.jbellis.jvector.graph.RandomAccessVectorValues; +import io.github.jbellis.jvector.util.PhysicalCoreExecutor; import io.github.jbellis.jvector.util.PoolingSupport; import io.github.jbellis.jvector.util.RamUsageEstimator; -import io.github.jbellis.jvector.util.PhysicalCoreExecutor; import io.github.jbellis.jvector.vector.VectorUtil; import java.io.DataOutput; @@ -30,6 +30,7 @@ import java.util.Arrays; import java.util.List; import java.util.Objects; +import java.util.concurrent.ForkJoinPool; import java.util.concurrent.ThreadLocalRandom; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -62,11 +63,31 @@ public class ProductQuantization implements VectorCompressor<byte[]> { * (not recommended when using the quantization for dot product) */ public static ProductQuantization compute(RandomAccessVectorValues<float[]> ravv, int M, boolean globallyCenter) { + return compute(ravv, M, globallyCenter, PhysicalCoreExecutor.pool(), ForkJoinPool.commonPool()); + } + + /** + * Initializes the codebooks by clustering the input data using Product Quantization. + * + * @param ravv the vectors to quantize + * @param M number of subspaces + * @param globallyCenter whether to center the vectors globally before quantization + * (not recommended when using the quantization for dot product) + * @param simdExecutor ForkJoinPool instance for SIMD operations; ideally a pool sized to + * the number of physical cores. + * @param parallelExecutor ForkJoinPool instance for parallel stream operations + */ + public static ProductQuantization compute( + RandomAccessVectorValues<float[]> ravv, + int M, + boolean globallyCenter, + ForkJoinPool simdExecutor, + ForkJoinPool parallelExecutor) { // limit the number of vectors we train on var P = min(1.0f, MAX_PQ_TRAINING_SET_SIZE / (float) ravv.size()); var ravvCopy = ravv.isValueShared() ? 
PoolingSupport.newThreadBased(ravv::copy) : PoolingSupport.newNoPooling(ravv); var subvectorSizesAndOffsets = getSubvectorSizesAndOffsets(ravv.dimension(), M); - var vectors = IntStream.range(0, ravv.size()).parallel() + var vectors = parallelExecutor.submit(() -> IntStream.range(0, ravv.size()).parallel() .filter(i -> ThreadLocalRandom.current().nextFloat() < P) .mapToObj(targetOrd -> { try (var pooledRavv = ravvCopy.get()) { @@ -75,7 +96,8 @@ public static ProductQuantization compute(RandomAccessVectorValues<float[]> ravv return localRavv.isValueShared() ? Arrays.copyOf(v, v.length) : v; } }) - .collect(Collectors.toList()); + .collect(Collectors.toList())) + .join(); // subtract the centroid from each training vector float[] globalCentroid; @@ -83,13 +105,13 @@ public static ProductQuantization compute(RandomAccessVectorValues<float[]> ravv globalCentroid = KMeansPlusPlusClusterer.centroidOf(vectors); // subtract the centroid from each vector List<float[]> finalVectors = vectors; - vectors = PhysicalCoreExecutor.instance.submit(() -> finalVectors.stream().parallel().map(v -> VectorUtil.sub(v, globalCentroid)).collect(Collectors.toList())); + vectors = simdExecutor.submit(() -> finalVectors.stream().parallel().map(v -> VectorUtil.sub(v, globalCentroid)).collect(Collectors.toList())).join(); } else { globalCentroid = null; } // derive the codebooks - var codebooks = createCodebooks(vectors, M, subvectorSizesAndOffsets); + var codebooks = createCodebooks(vectors, M, subvectorSizesAndOffsets, simdExecutor); return new ProductQuantization(codebooks, globalCentroid); } @@ -116,8 +138,9 @@ public CompressedVectors createCompressedVectors(Object[] compressedVectors) { /** * Encodes the given vectors in parallel using the PQ codebooks. 
*/ - public byte[][] encodeAll(List<float[]> vectors) { - return PhysicalCoreExecutor.instance.submit(() ->vectors.stream().parallel().map(this::encode).toArray(byte[][]::new)); + @Override + public byte[][] encodeAll(List<float[]> vectors, ForkJoinPool simdExecutor) { + return simdExecutor.submit(() ->vectors.stream().parallel().map(this::encode).toArray(byte[][]::new)).join(); } /** @@ -125,6 +148,7 @@ public byte[][] encodeAll(List<float[]> vectors) { * * @return one byte per subspace */ + @Override public byte[] encode(float[] vector) { if (globalCentroid != null) { vector = VectorUtil.sub(vector, globalCentroid); @@ -190,8 +214,8 @@ private static String arraySummary(float[] a) { return "[" + String.join(", ", b) + "]"; } - static float[][][] createCodebooks(List<float[]> vectors, int M, int[][] subvectorSizeAndOffset) { - return PhysicalCoreExecutor.instance.submit(() -> IntStream.range(0, M).parallel() + static float[][][] createCodebooks(List<float[]> vectors, int M, int[][] subvectorSizeAndOffset, ForkJoinPool simdExecutor) { + return simdExecutor.submit(() -> IntStream.range(0, M).parallel() .mapToObj(m -> { float[][] subvectors = vectors.stream().parallel() .map(vector -> getSubVector(vector, m, subvectorSizeAndOffset)) @@ -199,7 +223,8 @@ static float[][][] createCodebooks(List<float[]> vectors, int M, int[][] subvect var clusterer = new KMeansPlusPlusClusterer(subvectors, CLUSTERS, VectorUtil::squareDistance); return clusterer.cluster(K_MEANS_ITERATIONS); }) - .toArray(float[][][]::new)); + .toArray(float[][][]::new)) + .join(); } static int closetCentroidIndex(float[] subvector, float[][] codebook) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/VectorCompressor.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/VectorCompressor.java index 83e53bf12..1594cff84 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/VectorCompressor.java +++ 
b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/VectorCompressor.java @@ -16,16 +16,24 @@ package io.github.jbellis.jvector.pq; +import io.github.jbellis.jvector.util.PhysicalCoreExecutor; + import java.io.DataOutput; import java.io.IOException; import java.util.List; +import java.util.concurrent.ForkJoinPool; /** * Interface for vector compression. T is the encoded (compressed) vector type; * it will be an array type. */ public interface VectorCompressor<T> { - T[] encodeAll(List<float[]> vectors); + + default T[] encodeAll(List<float[]> vectors) { + return encodeAll(vectors, PhysicalCoreExecutor.pool()); + } + + T[] encodeAll(List<float[]> vectors, ForkJoinPool simdExecutor); T encode(float[] v); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/PhysicalCoreExecutor.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/PhysicalCoreExecutor.java index 166c422e9..95ba1db9c 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/PhysicalCoreExecutor.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/PhysicalCoreExecutor.java @@ -36,6 +36,10 @@ public class PhysicalCoreExecutor { public static final PhysicalCoreExecutor instance = new PhysicalCoreExecutor(physicalCoreCount); + public static ForkJoinPool pool() { + return instance.pool; + } + private final ForkJoinPool pool; private PhysicalCoreExecutor(int cores) {