Skip to content

Commit

Permalink
Merge pull request #100 from jkni/faster-pq
Browse files Browse the repository at this point in the history
KMeansPlusPlusClusterer optimizations
  • Loading branch information
jkni authored Sep 29, 2023
2 parents f89e4c4 + 76fa03b commit 03b2850
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import io.github.jbellis.jvector.vector.VectorUtil;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
Expand All @@ -31,12 +30,11 @@ public class KMeansPlusPlusClusterer {
private final int k;
private final BiFunction<float[], float[], Float> distanceFunction;
private final Random random;
private final List<float[]>[] clusterPoints;
private final float[][] centroidDistances;
private final float[][] points;
private final int[] assignments;
private final float[][] centroids;

private final int[] centroidDenoms;
private final float[][] centroidNums;

/**
* Constructs a KMeansPlusPlusFloatClusterer with the specified number of clusters,
Expand All @@ -56,22 +54,18 @@ public KMeansPlusPlusClusterer(float[][] points, int k, BiFunction<float[], floa
this.points = points;
this.k = k;
this.distanceFunction = distanceFunction;
this.random = new Random();
this.clusterPoints = new List[k];
for (int i = 0; i < k; i++) {
this.clusterPoints[i] = new ArrayList<>();
}
centroidDistances = new float[k][k];
random = new Random();
centroidDenoms = new int[k];
centroidNums = new float[k][points[0].length];
centroids = chooseInitialCentroids(points);
updateCentroidDistances();
assignments = new int[points.length];
assignPointsToClusters();
}

/**
* Performs clustering on the provided set of points.
*
* @return a list of cluster centroids.
* @return an array of cluster centroids.
*/
public float[][] cluster(int maxIterations) {
for (int i = 0; i < maxIterations; i++) {
Expand All @@ -85,42 +79,20 @@ public float[][] cluster(int maxIterations) {

// This is broken out as a separate public method to allow implementing OPQ efficiently
public int clusterOnce() {
for (int j = 0; j < k; j++) {
if (clusterPoints[j].isEmpty()) {
// Handle empty cluster by choosing a random point
// (Choosing the highest-variance point is much slower and no better after a couple iterations)
centroids[j] = points[random.nextInt(points.length)];
} else {
centroids[j] = centroidOf(clusterPoints[j]);
}
}
int changedCount = assignPointsToClusters();
updateCentroidDistances();

return changedCount;
}

private void updateCentroidDistances() {
for (int m = 0; m < k; m++) {
for (int n = m + 1; n < k; n++) {
float distance = distanceFunction.apply(centroids[m], centroids[n]);
centroidDistances[m][n] = distance;
centroidDistances[n][m] = distance; // Distance matrix is symmetric
}
}
updateCentroids();
return assignPointsToClusters();
}

/**
* Chooses the initial centroids for clustering.
*
* The first centroid is chosen randomly from the data points. Subsequent centroids
* are selected with a probability proportional to the square of their distance
* to the nearest existing centroid. This ensures that the centroids are spread out
* across the data and not initialized too closely to each other, leading to better
* convergence and potentially improved final clusterings.
* *
*
* @param points a list of points from which centroids are chosen.
* @return a list of initial centroids.
* @return an array of initial centroids.
*/
private float[][] chooseInitialCentroids(float[][] points) {
float[][] centroids = new float[k][];
Expand Down Expand Up @@ -175,21 +147,18 @@ private float[][] chooseInitialCentroids(float[][] points) {
private int assignPointsToClusters() {
int changedCount = 0;

for (List<float[]> cluster : clusterPoints) {
cluster.clear();
}

for (int i = 0; i < points.length; i++) {
float[] point = points[i];
int clusterIndex = getNearestCluster(point, centroids);

// Check if assignment has changed
if (assignments[i] != clusterIndex) {
var oldAssignment = assignments[i];
var newAssignment = getNearestCluster(point, centroids);
if (newAssignment != oldAssignment) {
centroidDenoms[oldAssignment] = centroidDenoms[oldAssignment] - 1;
centroidDenoms[newAssignment] = centroidDenoms[newAssignment] + 1;
VectorUtil.subInPlace(centroidNums[oldAssignment], point);
VectorUtil.addInPlace(centroidNums[newAssignment], point);
assignments[i] = newAssignment;
changedCount++;
}

clusterPoints[clusterIndex].add(point);
assignments[i] = clusterIndex;
}

return changedCount;
Expand All @@ -213,6 +182,21 @@ private int getNearestCluster(float[] point, float[][] centroids) {
return nearestCluster;
}

/**
* Calculates centroids from centroidNums/centroidDenoms updated during point assignment
*/
private void updateCentroids() {
for (int i = 0; i < centroids.length; i++) {
var denom = centroidDenoms[i];
if (denom == 0) {
centroids[i] = points[random.nextInt(points.length)];
} else {
centroids[i] = Arrays.copyOf(centroidNums[i], centroidNums[i].length);
VectorUtil.divInPlace(centroids[i], centroidDenoms[i]);
}
}
}

/**
* Computes the centroid of a list of points.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,13 @@ public void addInPlace(float[] v1, float[] v2) {
}
}

@Override
public void subInPlace(float[] v1, float[] v2) {
for (int i = 0; i < v1.length; i++) {
v1[i] -= v2[i];
}
}

@Override
public float[] sub(float[] lhs, float[] rhs) {
float[] result = new float[lhs.length];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,11 @@ public static void addInPlace(float[] v1, float[] v2) {
impl.addInPlace(v1, v2);
}

public static void subInPlace(float[] v1, float[] v2) {
impl.subInPlace(v1, v2);
}


public static float[] sub(float[] lhs, float[] rhs) {
return impl.sub(lhs, rhs);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ public interface VectorUtilSupport {
/** Adds v2 into v1, in place (v1 will be modified) */
public void addInPlace(float[] v1, float[] v2);

/** Subtracts v2 from v1, in place (v1 will be modified) */
public void subInPlace(float[] v1, float[] v2);

/** @return lhs - rhs, element-wise */
public float[] sub(float[] lhs, float[] rhs);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ public void addInPlace(float[] v1, float[] v2) {
SimdOps.addInPlace(v1, v2);
}

@Override
public void subInPlace(float[] v1, float[] v2) {
SimdOps.subInPlace(v1, v2);
}

@Override
public float[] sub(float[] lhs, float[] rhs) {
return SimdOps.sub(lhs, rhs);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,26 @@ static void addInPlace(float[] v1, float[] v2) {
}
}

static void subInPlace(float[] v1, float[] v2) {
if (v1.length != v2.length) {
throw new IllegalArgumentException("Vectors must have the same length");
}

int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(v1.length);

// Process the vectorized part
for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) {
var a = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, v1, i);
var b = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, v2, i);
a.sub(b).intoArray(v1, i);
}

// Process the tail
for (int i = vectorizedLength; i < v1.length; i++) {
v1[i] = v1[i] - v2[i];
}
}

static float[] sub(float[] lhs, float[] rhs) {
if (lhs.length != rhs.length) {
throw new IllegalArgumentException("Vectors must have the same length");
Expand Down

0 comments on commit 03b2850

Please sign in to comment.