Skip to content

Commit

Permalink
Fixes bug in KMeansClusterer.
Browse files Browse the repository at this point in the history
  • Loading branch information
ppanopticon committed May 9, 2024
1 parent 7e8611d commit 54f45cc
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@ import org.vitrivr.cottontail.core.types.VectorValue
* @version 1.0.0
*/
interface Clusterer {

/** The [VectorDistance] used by this [Cluster]. */
val distance: VectorDistance<*>

/**
* Clusters a [List] of points [VectorValue] with this [Clusterer].
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,52 @@ package org.vitrivr.cottontail.utilities.math.clustering
import org.apache.commons.math3.exception.ConvergenceException
import org.apache.commons.math3.exception.util.LocalizedFormats
import org.apache.commons.math3.stat.descriptive.moment.Variance
import org.vitrivr.cottontail.core.queries.functions.math.distance.binary.VectorDistance
import org.vitrivr.cottontail.core.queries.functions.VectorisableFunction
import org.vitrivr.cottontail.core.queries.functions.math.distance.SIMD
import org.vitrivr.cottontail.core.queries.functions.math.distance.binary.squaredeuclidean.SquaredEuclideanDistance
import org.vitrivr.cottontail.core.types.Types
import org.vitrivr.cottontail.core.types.VectorValue
import org.vitrivr.cottontail.core.values.DoubleValue
import org.vitrivr.cottontail.core.values.IntValue
import java.util.*
import java.util.random.RandomGenerator
import kotlin.math.pow

/**
* A [Clusterer] that uses the k-means++ algorithm.
*
* This implementation is an adaption of the Apache Commons Math [org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer].
*
* @author Ralph Gasser
* @version 1.1.0
* @version 1.2.0
*/
class KMeansClusterer(val k: Int, override val distance: VectorDistance<*>, private val random: RandomGenerator, private val iterations: Int = MAX_ITERATIONS): Clusterer {
class KMeansClusterer<V: VectorValue<*>>(val k: Int, type: Types<V>, private val random: RandomGenerator, private val iterations: Int = MAX_ITERATIONS): Clusterer {

companion object {
/** Maximum number of iterations to use for k-means clustering; serves as the default for the `iterations` constructor argument. */
private const val MAX_ITERATIONS = 250
}

/** The [SquaredEuclideanDistance] implementation used by this [KMeansClusterer]. */
private val distance: SquaredEuclideanDistance<V>

init {
    /* Validate the constructor arguments before doing any work. */
    require(this.k > 0) { "The number of cluster centers must be greater than zero." }
    require(this.iterations > 0) { "The number of iterations must be greater than zero." }

    /* Select the squared Euclidean distance implementation that matches the vector type. */
    val scalarDistance = when (type) {
        is Types.DoubleVector -> SquaredEuclideanDistance.DoubleVector(type)
        is Types.FloatVector -> SquaredEuclideanDistance.FloatVector(type)
        is Types.IntVector -> SquaredEuclideanDistance.IntVector(type)
        is Types.LongVector -> SquaredEuclideanDistance.LongVector(type)
        else -> throw IllegalArgumentException("Unsupported vector type $type.")
    }

    /* Use the vectorised variant when SIMD is enabled and the function offers one; otherwise keep the scalar version. */
    val useVectorised = SIMD.SIMD_ENABLED && scalarDistance is VectorisableFunction<*>
    @Suppress("UNCHECKED_CAST")
    val selected = (if (useVectorised) (scalarDistance as VectorisableFunction<*>).vectorized() else scalarDistance) as SquaredEuclideanDistance<V>
    this.distance = selected
}

/**
Expand Down Expand Up @@ -81,7 +101,7 @@ class KMeansClusterer(val k: Int, override val distance: VectorDistance<*>, priv
private fun assignPointsToClusters(clusters: List<KMeansCluster>, points: List<VectorValue<*>>, assignments: IntArray): Int {
var assignedDifferently = 0
for ((pointIndex, p) in points.withIndex()) {
val clusterIndex: Int = getNearestCluster(clusters, p)
val clusterIndex: Int = getNearestCluster(clusters, p, assignments[pointIndex])
if (clusterIndex != assignments[pointIndex]) {
assignedDifferently++
}
Expand All @@ -99,14 +119,16 @@ class KMeansClusterer(val k: Int, override val distance: VectorDistance<*>, priv
* @param point The [VectorValue] for which the nearest [KMeansCluster] should be determined.
* @return The index of the nearest [KMeansCluster] to the given point
*/
private fun getNearestCluster(clusters: List<KMeansCluster>, point: VectorValue<*>): Int {
var minDistance = Double.MAX_VALUE
var minCluster = 0
private fun getNearestCluster(clusters: List<KMeansCluster>, point: VectorValue<*>, assigned: Int): Int {
var minDistance = this.distance(clusters[assigned].center, point)!!
var minCluster = assigned
for ((clusterIndex, c) in clusters.withIndex()) {
val distance: Double = this.distance(c.center, point)!!.value
if (distance < minDistance) {
minDistance = distance
minCluster = clusterIndex
if (clusterIndex != assigned) {
val distance: DoubleValue = this.distance.invokeOrMaximum(c.center, point, minDistance)
if (distance < minDistance) {
minDistance = distance
minCluster = clusterIndex
}
}
}
return minCluster
Expand All @@ -132,7 +154,7 @@ class KMeansClusterer(val k: Int, override val distance: VectorDistance<*>, priv

/* To keep track of the minimum distance squared of elements of pointList to elements of resultSet. */
val minDistSquared = DoubleArray(points.size) {
this.distance(points[firstPointIndex], points[it])!!.value.pow(2.0)
this.distance(points[firstPointIndex], points[it])!!.value
}

/* Initialize the elements. Since the only point in resultSet is firstPoint, this is very easy. */
Expand All @@ -143,7 +165,7 @@ class KMeansClusterer(val k: Int, override val distance: VectorDistance<*>, priv
/* Add one new data point as a center. Each point x is chosen with probability proportional to D(x)^2. */
val r = this.random.nextDouble() * distSqSum

/* The index of the next point to be added to the resultSet.. */
/* The index of the next point to be added to the resultSet. */
var nextPointIndex = -1

/* Sum through the squared min distances again, stopping when sum >= r. */
Expand Down Expand Up @@ -182,7 +204,7 @@ class KMeansClusterer(val k: Int, override val distance: VectorDistance<*>, priv
/* Now update elements of minDistSquared. We only have to compute the distance to the new center to do this. */
for (j in points.indices) {
if (!taken[j]) {
val d2: Double = distance(p, points[j])!!.value.pow(2.0)
val d2: Double = this.distance(p, points[j])!!.value
if (d2 < minDistSquared[j]) {
minDistSquared[j] = d2
}
Expand All @@ -208,10 +230,9 @@ class KMeansClusterer(val k: Int, override val distance: VectorDistance<*>, priv
for (cluster in clusters) {
if (cluster.points.isNotEmpty()) {
/* Compute the distance variance of the current cluster. */
val center = cluster.center
val stat = Variance()
for (point in cluster.points) {
stat.increment(this@KMeansClusterer.distance(point, center)!!.value)
stat.increment(this.distance(cluster.center, point)!!.value)
}
val variance = stat.result

Expand Down Expand Up @@ -250,7 +271,7 @@ class KMeansClusterer(val k: Int, override val distance: VectorDistance<*>, priv
/**
* A [Cluster] that is generated as a result of running this [KMeansClusterer].
*/
data class KMeansCluster(override val center: VectorValue<*>): Cluster {
inner class KMeansCluster(override val center: VectorValue<*>): Cluster {

/** The [List] of [VectorValue] held by this [KMeansCluster]. */
override val points: List<VectorValue<*>> = LinkedList()
Expand Down

0 comments on commit 54f45cc

Please sign in to comment.