Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PAM algorithm to K-Medoids #73

Merged
merged 33 commits into from
Nov 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
91891ae
Add pam algorithm
Nov 6, 2020
40e42e7
Merge remote-tracking branch 'upstream/master' into kmedoid_pam
Nov 7, 2020
4362f85
pam algorithm, not naive.
Nov 8, 2020
5a86b95
black reformat
Nov 8, 2020
6e6c90d
Fix mistake in code
Nov 8, 2020
2ed1803
optimization of the algorithm for speed, review from @kno10
Nov 10, 2020
8422c17
remove generator for couples
Nov 10, 2020
6eae84b
fix mistake
Nov 10, 2020
2a92978
Update pam review 2
Nov 11, 2020
ecce8c8
fix mistake
Nov 12, 2020
1cba61c
cython implementation
Nov 13, 2020
258c262
add test
Nov 13, 2020
9078628
disable openmp for windows and mac
Nov 13, 2020
482bc37
fix black
Nov 13, 2020
50d0eb3
fix setup.py for windows
Nov 13, 2020
7637891
remove test
Nov 13, 2020
eeaa2a3
change review
Nov 15, 2020
093f8b0
Merge branch 'master' into kmedoid_pam
TimotheeMathieu Nov 18, 2020
bd04827
fix black
Nov 18, 2020
e979579
Add build, remove parallel computing
Nov 21, 2020
b51d23b
Apply suggestions from code review
TimotheeMathieu Nov 21, 2020
e675bdb
apply suggested change & rename alternating to alternate.
Nov 21, 2020
4250681
fix test
Nov 21, 2020
8f2ada3
Merge remote-tracking branch 'upstream/master' into kmedoid_pam
Nov 21, 2020
552294b
make build default. Allow max_iter = 0 for build-only algo
Nov 21, 2020
b024b8e
Test for method and init
Nov 21, 2020
018c9c7
test on blobs example
Nov 21, 2020
2f6368f
fix typo
Nov 21, 2020
f1a33ad
fix difference long/long long windows vs linux
Nov 21, 2020
498d9b6
try another fix for windows/linux long difference
Nov 21, 2020
daa9879
test another fix cython long/int on different platforms
Nov 21, 2020
213bb2e
test all in int, cython kmedoid
Nov 22, 2020
9c15afa
explain test_kmedoid_results
Nov 26, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@
["sklearn_extra/utils/_cyfht.pyx"],
include_dirs=[np.get_include()],
),
Extension(
"sklearn_extra.cluster._k_medoids_helper",
["sklearn_extra/cluster/_k_medoids_helper.pyx"],
include_dirs=[np.get_include()],
),
Extension(
"sklearn_extra.robust._robust_weighted_estimator_helper",
["sklearn_extra/robust/_robust_weighted_estimator_helper.pyx"],
Expand Down
77 changes: 63 additions & 14 deletions sklearn_extra/cluster/_k_medoids.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from sklearn.utils.validation import check_is_fitted
from sklearn.exceptions import ConvergenceWarning

# cython implementation of swap step in PAM algorithm.
from ._k_medoids_helper import _compute_optimal_swap, _build


class KMedoids(BaseEstimator, ClusterMixin, TransformerMixin):
"""k-medoids clustering.
Expand All @@ -35,17 +38,27 @@ class KMedoids(BaseEstimator, ClusterMixin, TransformerMixin):
metric : string, or callable, optional, default: 'euclidean'
What distance metric to use. See :func:metrics.pairwise_distances

init : {'random', 'heuristic', 'k-medoids++'}, optional, default: 'heuristic'
method : {'alternate', 'pam'}, default: 'alternate'
Which algorithm to use.

init : {'random', 'heuristic', 'k-medoids++', 'build'}, optional, default: 'build'
Specify medoid initialization method. 'random' selects n_clusters
elements from the dataset. 'heuristic' picks the n_clusters points
with the smallest sum distance to every other point. 'k-medoids++'
follows an approach based on k-means++_, and in general, gives initial
medoids which are more separated than those generated by the other methods.
'build' is a greedy initialization of the medoids used in the original PAM
algorithm. Often 'build' is more efficient but slower than other
initializations on big datasets and it is also very non-robust,
if there are outliers in the dataset, use another initialization.

.. _k-means++: https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf

max_iter : int, optional, default : 300
Specify the maximum number of iterations when fitting.
Specify the maximum number of iterations when fitting. It can be zero in
which case only the initialization is computed which may be suitable for
large datasets when the initialization is sufficiently efficient
(i.e. for 'build' init).

random_state : int, RandomState instance or None, optional
Specify random state for the random number generator. Used to
Expand Down Expand Up @@ -112,24 +125,25 @@ def __init__(
self,
n_clusters=8,
metric="euclidean",
method="alternate",
init="heuristic",
max_iter=300,
random_state=None,
):
self.n_clusters = n_clusters
self.metric = metric
self.method = method
self.init = init
self.max_iter = max_iter
self.random_state = random_state

def _check_nonnegative_int(self, value, desc):
def _check_nonnegative_int(self, value, desc, strict=True):
"""Validates if value is a valid integer > 0"""

if (
value is None
or value <= 0
or not isinstance(value, (int, np.integer))
):
if strict:
negative = (value is None) or (value <= 0)
else:
negative = (value is None) or (value < 0)
if negative or not isinstance(value, (int, np.integer)):
raise ValueError(
"%s should be a nonnegative integer. "
"%s was given" % (desc, value)
Expand All @@ -140,10 +154,10 @@ def _check_init_args(self):

# Check n_clusters and max_iter
self._check_nonnegative_int(self.n_clusters, "n_clusters")
self._check_nonnegative_int(self.max_iter, "max_iter")
self._check_nonnegative_int(self.max_iter, "max_iter", False)

# Check init
init_methods = ["random", "heuristic", "k-medoids++"]
init_methods = ["random", "heuristic", "k-medoids++", "build"]
if self.init not in init_methods:
raise ValueError(
"init needs to be one of "
Expand Down Expand Up @@ -183,15 +197,44 @@ def fit(self, X, y=None):
)
labels = None

if self.method == "pam":
# Compute the distance to the first and second closest points
# among medoids.
Djs, Ejs = np.sort(D[medoid_idxs], axis=0)[[0, 1]]

# Continue the algorithm as long as
# the medoids keep changing and the maximum number
# of iterations is not exceeded

for self.n_iter_ in range(0, self.max_iter):
old_medoid_idxs = np.copy(medoid_idxs)
labels = np.argmin(D[medoid_idxs, :], axis=0)

# Update medoids with the new cluster indices
self._update_medoid_idxs_in_place(D, labels, medoid_idxs)
if self.method == "alternate":
# Update medoids with the new cluster indices
self._update_medoid_idxs_in_place(D, labels, medoid_idxs)
elif self.method == "pam":
not_medoid_idxs = np.delete(np.arange(len(D)), medoid_idxs)
optimal_swap = _compute_optimal_swap(
D,
medoid_idxs.astype(np.intc),
not_medoid_idxs.astype(np.intc),
Djs,
Ejs,
self.n_clusters,
)
if optimal_swap is not None:
i, j, _ = optimal_swap
medoid_idxs[medoid_idxs == i] = j

# update Djs and Ejs with new medoids
Djs, Ejs = np.sort(D[medoid_idxs], axis=0)[[0, 1]]
else:
raise ValueError(
f"method={self.method} is not supported. Supported methods "
f"are 'pam' and 'alternate'."
)

if np.all(old_medoid_idxs == medoid_idxs):
break
elif self.n_iter_ == self.max_iter - 1:
Expand All @@ -210,7 +253,7 @@ def fit(self, X, y=None):

# Expose labels_ which are the assignments of
# the training data to clusters
self.labels_ = labels
self.labels_ = np.argmin(D[medoid_idxs, :], axis=0)
self.medoid_indices_ = medoid_idxs
self.inertia_ = self._compute_inertia(self.transform(X))

Expand Down Expand Up @@ -252,6 +295,10 @@ def _update_medoid_idxs_in_place(self, D, labels, medoid_idxs):
if min_cost < curr_cost:
medoid_idxs[k] = cluster_k_idxs[min_cost_idx]

def _compute_cost(self, D, medoid_idxs):
""" Compute the cose for a given configuration of the medoids"""
return self._compute_inertia(D[:, medoid_idxs])

def transform(self, X):
"""Transforms X to cluster-distance space.

Expand Down Expand Up @@ -339,6 +386,8 @@ def _initialize_medoids(self, D, n_clusters, random_state_):
medoids = np.argpartition(np.sum(D, axis=1), n_clusters - 1)[
:n_clusters
]
elif self.init == "build": # Build initialization
medoids = _build(D, n_clusters).astype(np.int64)
else:
raise ValueError(f"init value '{self.init}' not recognized")

Expand Down
110 changes: 110 additions & 0 deletions sklearn_extra/cluster/_k_medoids_helper.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# cython: infer_types=True
# Fast swap step and build step in PAM algorithm for k_medoid.
# Author: Timothée Mathieu
# License: 3-clause BSD

cimport cython

import numpy as np
cimport numpy as np
from cython cimport floating, integral

@cython.boundscheck(False) # Deactivate bounds checking
def _compute_optimal_swap( floating[:,:] D,
int[:] medoid_idxs,
int[:] not_medoid_idxs,
floating[:] Djs,
floating[:] Ejs,
int n_clusters):
"""Compute best cost change for all the possible swaps."""

# Initialize best cost change and the associated swap couple.
cdef (int, int, floating) best_cost_change = (1, 1, 0.0)
cdef int sample_size = len(D)
cdef int i, j, h, id_i, id_h, id_j
cdef floating cost_change
cdef int not_medoid_shape = sample_size - n_clusters
cdef bint cluster_i_bool, not_cluster_i_bool, second_best_medoid
cdef bint not_second_best_medoid

# Compute the change in cost for each swap.
for h in range(not_medoid_shape):
# id of the potential new medoid.
id_h = not_medoid_idxs[h]
for i in range(n_clusters):
# id of the medoid we want to replace.
id_i = medoid_idxs[i]
cost_change = 0.0
# compute for all not-selected points the change in cost
for j in range(not_medoid_shape):
id_j = not_medoid_idxs[j]
cluster_i_bool = D[id_i, id_j] == Djs[id_j]
not_cluster_i_bool = D[id_i, id_j] != Djs[id_j]
second_best_medoid = D[id_h, id_j] < Ejs[id_j]
not_second_best_medoid = D[id_h, id_j] >= Ejs[id_j]

if cluster_i_bool & second_best_medoid:
cost_change += D[id_j, id_h] - Djs[id_j]
elif cluster_i_bool & not_second_best_medoid:
cost_change += Ejs[id_j] - Djs[id_j]
elif not_cluster_i_bool & (D[id_j, id_h] < Djs[id_j]):
cost_change += D[id_j, id_h] - Djs[id_j]

# same for i
second_best_medoid = D[id_h, id_i] < Ejs[id_i]
if second_best_medoid:
cost_change += D[id_i, id_h]
else:
cost_change += Ejs[id_i]

if cost_change < best_cost_change[2]:
best_cost_change = (id_i, id_h, cost_change)

# If one of the swap decrease the objective, return that swap.
if best_cost_change[2] < 0:
return best_cost_change
else:
return None




def _build( floating[:, :] D, int n_clusters):
"""Compute BUILD initialization, a greedy medoid initialization."""

cdef int[:] medoid_idxs = np.zeros(n_clusters, dtype = np.intc)
cdef int sample_size = len(D)
cdef int[:] not_medoid_idxs = np.zeros(sample_size, dtype = np.intc)
cdef int i, j, id_i, id_j

medoid_idxs[0] = np.argmin(np.sum(D,axis=0))
not_medoid_idxs = np.delete(not_medoid_idxs, medoid_idxs[0])

cdef int n_medoids_current = 1

cdef floating[:] Dj = D[medoid_idxs[0]].copy()
cdef floating cost_change
cdef (int, int) new_medoid = (medoid_idxs[0], 0)
cdef floating cost_change_max

for _ in range(n_clusters -1):
cost_change_max = 0
for i in range(sample_size - n_medoids_current):
id_i = not_medoid_idxs[i]
cost_change = 0
for j in range(sample_size - n_medoids_current):
id_j = not_medoid_idxs[j]
cost_change += max(0, Dj[id_j] - D[id_i, id_j])
if cost_change >= cost_change_max:
cost_change_max = cost_change
new_medoid = (id_i, i)


medoid_idxs[n_medoids_current] = new_medoid[0]
n_medoids_current += 1
not_medoid_idxs = np.delete(not_medoid_idxs, new_medoid[1])


for id_j in range(sample_size):
Dj[id_j] = min(Dj[id_j], D[id_j, new_medoid[0]])
return np.array(medoid_idxs)
40 changes: 38 additions & 2 deletions sklearn_extra/cluster/tests/test_k_medoids.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,46 @@

from sklearn_extra.cluster import KMedoids
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs


seed = 0
X = np.random.RandomState(seed).rand(100, 5)

# test kmedoid's results
rng = np.random.RandomState(seed)
X_cc, y_cc = make_blobs(
n_samples=100,
centers=np.array([[-1, -1], [1, 1]]),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can reduce the cluster_std to 0.2 some such value? Currently it's at 1.0 and that's why there is an overlap between clusters and we detect some fraction of points in the wrong cluster. Or alternatively comment that the 0.8 value in test is due to the data and not the the kmedoids itself.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I commented to explain. I prefer to keep the clusters not so separable in order to test the stability of the method at the same time.

random_state=rng,
shuffle=False,
)


@pytest.mark.parametrize("method", ["alternate", "pam"])
@pytest.mark.parametrize(
"init", ["random", "heuristic", "build", "k-medoids++"]
)
def test_kmedoid_results(method, init):
expected = np.hstack([np.zeros(50), np.ones(50)])
km = KMedoids(n_clusters=2, init=init, method=method)
km.fit(X_cc)
# This test use data that are not perfectly separable so the
# accuracy is not 1. Accuracy around 0.85
assert (np.mean(km.labels_ == expected) > 0.8) or (
1 - np.mean(km.labels_ == expected) > 0.8
)


def test_medoids_invalid_method():
with pytest.raises(ValueError, match="invalid is not supported"):
KMedoids(n_clusters=1, method="invalid").fit([[0, 1], [1, 1]])


def test_medoids_invalid_init():
with pytest.raises(ValueError, match="init needs to be one of"):
KMedoids(n_clusters=1, init="invalid").fit([[0, 1], [1, 1]])


def test_kmedoids_input_validation_and_fit_check():
rng = np.random.RandomState(seed)
Expand All @@ -28,9 +64,9 @@ def test_kmedoids_input_validation_and_fit_check():
with pytest.raises(ValueError, match=msg):
KMedoids(n_clusters=None).fit(X)

msg = "max_iter should be a nonnegative integer. 0 was given"
msg = "max_iter should be a nonnegative integer. -1 was given"
with pytest.raises(ValueError, match=msg):
KMedoids(n_clusters=1, max_iter=0).fit(X)
KMedoids(n_clusters=1, max_iter=-1).fit(X)

msg = "max_iter should be a nonnegative integer. None was given"
with pytest.raises(ValueError, match=msg):
Expand Down