Add optional FJP args for indexing and quantization
By default, SIMD operations use `PhysicalCoreExecutor`
and non-SIMD operations use `ForkJoinPool.commonPool()`.

With this addition, one can define custom fork-join pools for these tasks
in `GraphIndexBuilder`, `ProductQuantization` and `BinaryQuantization`.
mdogan committed Dec 1, 2023
1 parent b1c6986 commit 6e1ae9a
Showing 5 changed files with 101 additions and 22 deletions.
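For quick reference, these are the public entry points that gain explicit ForkJoinPool parameters in this commit (signatures copied from the diff below; simdExecutor is for SIMD work, parallelExecutor for parallel stream work, matching the commit message):

// GraphIndexBuilder: new constructor overload
public GraphIndexBuilder(RandomAccessVectorValues<T> vectorValues,
                         VectorEncoding vectorEncoding,
                         VectorSimilarityFunction similarityFunction,
                         int M, int beamWidth, float neighborOverflow, float alpha,
                         ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor)

// Quantization training entry points
public static BinaryQuantization compute(RandomAccessVectorValues<float[]> ravv, ForkJoinPool parallelExecutor)
public static ProductQuantization compute(RandomAccessVectorValues<float[]> ravv, int M, boolean globallyCenter,
                                          ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor)

// VectorCompressor: encodeAll now takes the SIMD pool; the old signature becomes a default method
// that delegates to PhysicalCoreExecutor.pool()
T[] encodeAll(List<float[]> vectors, ForkJoinPool simdExecutor)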
io/github/jbellis/jvector/graph/GraphIndexBuilder.java
@@ -33,6 +33,7 @@
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.IntStream;
@@ -69,6 +70,9 @@ public class GraphIndexBuilder<T> {
private final int dimension; // for convenience so we don't have to go to the pool for this
private final NodeSimilarity similarity;

private final ForkJoinPool simdExecutor;
private final ForkJoinPool parallelExecutor;

private final AtomicInteger updateEntryNodeIn = new AtomicInteger(10_000);

/**
@@ -93,6 +97,37 @@ public GraphIndexBuilder(
int beamWidth,
float neighborOverflow,
float alpha) {
this(vectorValues, vectorEncoding, similarityFunction, M, beamWidth, neighborOverflow, alpha,
PhysicalCoreExecutor.pool(), ForkJoinPool.commonPool());
}

/**
* Reads all the vectors from vector values, builds a graph connecting them by their dense
* ordinals, using the given hyperparameter settings, and returns the resulting graph.
*
* @param vectorValues the vectors whose relations are represented by the graph - must provide a
* different view over those vectors than the one used to add via addGraphNode.
* @param M the maximum number of connections a node can have
* @param beamWidth the size of the beam search to use when finding nearest neighbors.
* @param neighborOverflow the ratio of extra neighbors to allow temporarily when inserting a
* node. larger values will build more efficiently, but use more memory.
* @param alpha how aggressively diverse neighbors should be pruned. Set alpha &gt; 1.0 to
* allow longer edges. If alpha = 1.0 then the equivalent of the lowest level of
* an HNSW graph will be created, which is usually not what you want.
* @param simdExecutor ForkJoinPool instance for SIMD operations; ideally a pool sized to the
* number of physical cores.
* @param parallelExecutor ForkJoinPool instance for parallel stream operations
*/
public GraphIndexBuilder(
RandomAccessVectorValues<T> vectorValues,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
int M,
int beamWidth,
float neighborOverflow,
float alpha,
ForkJoinPool simdExecutor,
ForkJoinPool parallelExecutor) {
vectors = vectorValues.isValueShared() ? PoolingSupport.newThreadBased(vectorValues::copy) : PoolingSupport.newNoPooling(vectorValues);
vectorsCopy = vectorValues.isValueShared() ? PoolingSupport.newThreadBased(vectorValues::copy) : PoolingSupport.newNoPooling(vectorValues);
dimension = vectorValues.dimension();
@@ -107,6 +142,8 @@ public GraphIndexBuilder(
throw new IllegalArgumentException("beamWidth must be positive");
}
this.beamWidth = beamWidth;
this.simdExecutor = simdExecutor;
this.parallelExecutor = parallelExecutor;

similarity = node1 -> {
try (var v = vectors.get(); var vc = vectorsCopy.get()) {
@@ -130,13 +167,13 @@ public OnHeapGraphIndex<T> build() {
size = v.get().size();
}

PhysicalCoreExecutor.instance.execute(() -> {
simdExecutor.submit(() -> {
IntStream.range(0, size).parallel().forEach(i -> {
try (var v1 = vectors.get()) {
addGraphNode(i, v1.get());
}
});
});
}).join();

cleanup();
return graph;
@@ -163,12 +200,12 @@ public void cleanup() {
removeDeletedNodes();

// clean up overflowed neighbor lists
IntStream.range(0, graph.getIdUpperBound()).parallel().forEach(i -> {
parallelExecutor.submit(() -> IntStream.range(0, graph.getIdUpperBound()).parallel().forEach(i -> {
var neighbors = graph.getNeighbors(i);
if (neighbors != null) {
neighbors.cleanup();
}
});
})).join();

// reconnect any orphaned nodes. this will maintain neighbors size
reconnectOrphanedNodes();
@@ -187,9 +224,9 @@ private void reconnectOrphanedNodes() {
var connectedNodes = new AtomicFixedBitSet(graph.getIdUpperBound());
connectedNodes.set(graph.entry());
var entryNeighbors = graph.getNeighbors(graph.entry()).getCurrent();
IntStream.range(0, entryNeighbors.size).parallel().forEach(node -> {
parallelExecutor.submit(() -> IntStream.range(0, entryNeighbors.size).parallel().forEach(node -> {
findConnected(connectedNodes, entryNeighbors.node[node]);
});
})).join();

// reconnect unreachable nodes
var nReconnected = new AtomicInteger();
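A minimal usage sketch for the new GraphIndexBuilder constructor. Only the constructor and build() signatures come from this commit; the vector source ravv, the pool sizes, the hyperparameter values, the FLOAT32/COSINE enum constants, and the exact import locations are illustrative assumptions.

import java.util.concurrent.ForkJoinPool;
import io.github.jbellis.jvector.graph.GraphIndexBuilder;
import io.github.jbellis.jvector.graph.OnHeapGraphIndex;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.vector.VectorEncoding;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;

static OnHeapGraphIndex<float[]> buildIndex(RandomAccessVectorValues<float[]> ravv) {
    ForkJoinPool simdPool = new ForkJoinPool(8);         // ideally sized to the physical core count
    ForkJoinPool streamPool = ForkJoinPool.commonPool(); // used for the non-SIMD parallel streams
    GraphIndexBuilder<float[]> builder = new GraphIndexBuilder<>(
            ravv,
            VectorEncoding.FLOAT32,
            VectorSimilarityFunction.COSINE,
            16,     // M: max connections per node
            100,    // beamWidth
            1.2f,   // neighborOverflow
            1.4f,   // alpha
            simdPool,
            streamPool);
    return builder.build();                              // build() fans work out over simdPool and streamPool
}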
io/github/jbellis/jvector/pq/BinaryQuantization.java
@@ -19,14 +19,14 @@
import io.github.jbellis.jvector.disk.Io;
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.util.PhysicalCoreExecutor;
import io.github.jbellis.jvector.util.PoolingSupport;
import io.github.jbellis.jvector.vector.VectorUtil;

import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
@@ -45,10 +45,14 @@ public BinaryQuantization(float[] globalCentroid) {
}

public static BinaryQuantization compute(RandomAccessVectorValues<float[]> ravv) {
return compute(ravv, ForkJoinPool.commonPool());
}

public static BinaryQuantization compute(RandomAccessVectorValues<float[]> ravv, ForkJoinPool parallelExecutor) {
// limit the number of vectors we train on
var P = min(1.0f, ProductQuantization.MAX_PQ_TRAINING_SET_SIZE / (float) ravv.size());
var ravvCopy = ravv.isValueShared() ? PoolingSupport.newThreadBased(ravv::copy) : PoolingSupport.newNoPooling(ravv);
var vectors = IntStream.range(0, ravv.size()).parallel()
var vectors = parallelExecutor.submit(() -> IntStream.range(0, ravv.size()).parallel()
.filter(i -> ThreadLocalRandom.current().nextFloat() < P)
.mapToObj(targetOrd -> {
try (var pooledRavv = ravvCopy.get()) {
@@ -57,7 +61,8 @@ public static BinaryQuantization compute(RandomAccessVectorValues<float[]> ravv)
return localRavv.isValueShared() ? Arrays.copyOf(v, v.length) : v;
}
})
.collect(Collectors.toList());
.collect(Collectors.toList()))
.join();

// compute the centroid of the training set
float[] globalCentroid = KMeansPlusPlusClusterer.centroidOf(vectors);
@@ -70,8 +75,8 @@ public CompressedVectors createCompressedVectors(Object[] compressedVectors) {
}

@Override
public long[][] encodeAll(List<float[]> vectors) {
return PhysicalCoreExecutor.instance.submit(() -> vectors.stream().parallel().map(this::encode).toArray(long[][]::new));
public long[][] encodeAll(List<float[]> vectors, ForkJoinPool simdExecutor) {
return simdExecutor.submit(() -> vectors.stream().parallel().map(this::encode).toArray(long[][]::new)).join();
}

/**
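A sketch of the new BinaryQuantization calls; ravv, vectors, and the pool size are illustrative assumptions, while the compute and encodeAll signatures are the ones added above.

ForkJoinPool trainingPool = new ForkJoinPool(4);   // drives the parallel training-set sampling stream
BinaryQuantization bq = BinaryQuantization.compute(ravv, trainingPool);

// encodeAll now takes the SIMD pool explicitly; passing PhysicalCoreExecutor.pool() mirrors the previous default
long[][] encoded = bq.encodeAll(vectors, PhysicalCoreExecutor.pool());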
io/github/jbellis/jvector/pq/ProductQuantization.java
@@ -19,9 +19,9 @@
import io.github.jbellis.jvector.disk.Io;
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.util.PhysicalCoreExecutor;
import io.github.jbellis.jvector.util.PoolingSupport;
import io.github.jbellis.jvector.util.RamUsageEstimator;
import io.github.jbellis.jvector.util.PhysicalCoreExecutor;
import io.github.jbellis.jvector.vector.VectorUtil;

import java.io.DataOutput;
@@ -30,6 +30,7 @@
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
@@ -62,11 +63,31 @@ public class ProductQuantization implements VectorCompressor<byte[]> {
* (not recommended when using the quantization for dot product)
*/
public static ProductQuantization compute(RandomAccessVectorValues<float[]> ravv, int M, boolean globallyCenter) {
return compute(ravv, M, globallyCenter, PhysicalCoreExecutor.pool(), ForkJoinPool.commonPool());
}

/**
* Initializes the codebooks by clustering the input data using Product Quantization.
*
* @param ravv the vectors to quantize
* @param M number of subspaces
* @param globallyCenter whether to center the vectors globally before quantization
* (not recommended when using the quantization for dot product)
* @param simdExecutor ForkJoinPool instance for SIMD operations; ideally a pool sized to the
* number of physical cores.
* @param parallelExecutor ForkJoinPool instance for parallel stream operations
*/
public static ProductQuantization compute(
RandomAccessVectorValues<float[]> ravv,
int M,
boolean globallyCenter,
ForkJoinPool simdExecutor,
ForkJoinPool parallelExecutor) {
// limit the number of vectors we train on
var P = min(1.0f, MAX_PQ_TRAINING_SET_SIZE / (float) ravv.size());
var ravvCopy = ravv.isValueShared() ? PoolingSupport.newThreadBased(ravv::copy) : PoolingSupport.newNoPooling(ravv);
var subvectorSizesAndOffsets = getSubvectorSizesAndOffsets(ravv.dimension(), M);
var vectors = IntStream.range(0, ravv.size()).parallel()
var vectors = parallelExecutor.submit(() -> IntStream.range(0, ravv.size()).parallel()
.filter(i -> ThreadLocalRandom.current().nextFloat() < P)
.mapToObj(targetOrd -> {
try (var pooledRavv = ravvCopy.get()) {
@@ -75,21 +96,22 @@ public static ProductQuantization compute(RandomAccessVectorValues<float[]> ravv
return localRavv.isValueShared() ? Arrays.copyOf(v, v.length) : v;
}
})
.collect(Collectors.toList());
.collect(Collectors.toList()))
.join();

// subtract the centroid from each training vector
float[] globalCentroid;
if (globallyCenter) {
globalCentroid = KMeansPlusPlusClusterer.centroidOf(vectors);
// subtract the centroid from each vector
List<float[]> finalVectors = vectors;
vectors = PhysicalCoreExecutor.instance.submit(() -> finalVectors.stream().parallel().map(v -> VectorUtil.sub(v, globalCentroid)).collect(Collectors.toList()));
vectors = simdExecutor.submit(() -> finalVectors.stream().parallel().map(v -> VectorUtil.sub(v, globalCentroid)).collect(Collectors.toList())).join();
} else {
globalCentroid = null;
}

// derive the codebooks
var codebooks = createCodebooks(vectors, M, subvectorSizesAndOffsets);
var codebooks = createCodebooks(vectors, M, subvectorSizesAndOffsets, simdExecutor);
return new ProductQuantization(codebooks, globalCentroid);
}

Expand All @@ -116,15 +138,17 @@ public CompressedVectors createCompressedVectors(Object[] compressedVectors) {
/**
* Encodes the given vectors in parallel using the PQ codebooks.
*/
public byte[][] encodeAll(List<float[]> vectors) {
return PhysicalCoreExecutor.instance.submit(() -> vectors.stream().parallel().map(this::encode).toArray(byte[][]::new));
@Override
public byte[][] encodeAll(List<float[]> vectors, ForkJoinPool simdExecutor) {
return simdExecutor.submit(() -> vectors.stream().parallel().map(this::encode).toArray(byte[][]::new)).join();
}

/**
* Encodes the input vector using the PQ codebooks.
*
* @return one byte per subspace
*/
@Override
public byte[] encode(float[] vector) {
if (globalCentroid != null) {
vector = VectorUtil.sub(vector, globalCentroid);
@@ -190,16 +214,17 @@ private static String arraySummary(float[] a) {
return "[" + String.join(", ", b) + "]";
}

static float[][][] createCodebooks(List<float[]> vectors, int M, int[][] subvectorSizeAndOffset) {
return PhysicalCoreExecutor.instance.submit(() -> IntStream.range(0, M).parallel()
static float[][][] createCodebooks(List<float[]> vectors, int M, int[][] subvectorSizeAndOffset, ForkJoinPool simdExecutor) {
return simdExecutor.submit(() -> IntStream.range(0, M).parallel()
.mapToObj(m -> {
float[][] subvectors = vectors.stream().parallel()
.map(vector -> getSubVector(vector, m, subvectorSizeAndOffset))
.toArray(float[][]::new);
var clusterer = new KMeansPlusPlusClusterer(subvectors, CLUSTERS, VectorUtil::squareDistance);
return clusterer.cluster(K_MEANS_ITERATIONS);
})
.toArray(float[][][]::new));
.toArray(float[][][]::new))
.join();
}

static int closetCentroidIndex(float[] subvector, float[][] codebook) {
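A sketch of the new ProductQuantization.compute overload with explicit pools; the vector source, the M value, and the pool choices are illustrative assumptions, and the comments reflect how the pools are used in the hunks above.

ProductQuantization pq = ProductQuantization.compute(
        ravv,                         // RandomAccessVectorValues<float[]> (assumed to exist)
        16,                           // M: number of subspaces
        true,                         // globallyCenter
        PhysicalCoreExecutor.pool(),  // simdExecutor: global centering and codebook training
        ForkJoinPool.commonPool());   // parallelExecutor: sampling the training vectors

byte[][] codes = pq.encodeAll(vectors, PhysicalCoreExecutor.pool());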
io/github/jbellis/jvector/pq/VectorCompressor.java
@@ -16,16 +16,24 @@

package io.github.jbellis.jvector.pq;

import io.github.jbellis.jvector.util.PhysicalCoreExecutor;

import java.io.DataOutput;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.ForkJoinPool;

/**
* Interface for vector compression. T is the encoded (compressed) vector type;
* it will be an array type.
*/
public interface VectorCompressor<T> {
T[] encodeAll(List<float[]> vectors);

default T[] encodeAll(List<float[]> vectors) {
return encodeAll(vectors, PhysicalCoreExecutor.pool());
}

T[] encodeAll(List<float[]> vectors, ForkJoinPool simdExecutor);

T encode(float[] v);

io/github/jbellis/jvector/util/PhysicalCoreExecutor.java
@@ -36,6 +36,10 @@ public class PhysicalCoreExecutor {

public static final PhysicalCoreExecutor instance = new PhysicalCoreExecutor(physicalCoreCount);

public static ForkJoinPool pool() {
return instance.pool;
}

private final ForkJoinPool pool;

private PhysicalCoreExecutor(int cores) {
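The new static PhysicalCoreExecutor.pool() accessor is what the VectorCompressor default method delegates to, so after this commit the two calls below are equivalent (pq and vectors as in the earlier sketches):

byte[][] defaultPool  = pq.encodeAll(vectors);                               // delegates to PhysicalCoreExecutor.pool()
byte[][] explicitPool = pq.encodeAll(vectors, PhysicalCoreExecutor.pool());  // same pool, passed explicitly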
