diff --git a/tests/baseline_images/test_cluster/test_silhouette/test_colormap_as_colors_silhouette.png b/tests/baseline_images/test_cluster/test_silhouette/test_colormap_as_colors_silhouette.png index 1133e8455..4b63f15d1 100644 Binary files a/tests/baseline_images/test_cluster/test_silhouette/test_colormap_as_colors_silhouette.png and b/tests/baseline_images/test_cluster/test_silhouette/test_colormap_as_colors_silhouette.png differ diff --git a/tests/baseline_images/test_cluster/test_silhouette/test_colormap_silhouette.png b/tests/baseline_images/test_cluster/test_silhouette/test_colormap_silhouette.png index bf4a0c090..33e1b7c95 100644 Binary files a/tests/baseline_images/test_cluster/test_silhouette/test_colormap_silhouette.png and b/tests/baseline_images/test_cluster/test_silhouette/test_colormap_silhouette.png differ diff --git a/tests/baseline_images/test_cluster/test_silhouette/test_colors_silhouette.png b/tests/baseline_images/test_cluster/test_silhouette/test_colors_silhouette.png index c292d91c4..2661d601d 100644 Binary files a/tests/baseline_images/test_cluster/test_silhouette/test_colors_silhouette.png and b/tests/baseline_images/test_cluster/test_silhouette/test_colors_silhouette.png differ diff --git a/tests/baseline_images/test_cluster/test_silhouette/test_integrated_kmeans_silhouette.png b/tests/baseline_images/test_cluster/test_silhouette/test_integrated_kmeans_silhouette.png index d8dd1083b..fa3c8e234 100644 Binary files a/tests/baseline_images/test_cluster/test_silhouette/test_integrated_kmeans_silhouette.png and b/tests/baseline_images/test_cluster/test_silhouette/test_integrated_kmeans_silhouette.png differ diff --git a/tests/baseline_images/test_cluster/test_silhouette/test_integrated_mini_batch_kmeans_silhouette.png b/tests/baseline_images/test_cluster/test_silhouette/test_integrated_mini_batch_kmeans_silhouette.png index 90066eaeb..8668abb6c 100644 Binary files a/tests/baseline_images/test_cluster/test_silhouette/test_integrated_mini_batch_kmeans_silhouette.png and b/tests/baseline_images/test_cluster/test_silhouette/test_integrated_mini_batch_kmeans_silhouette.png differ diff --git a/tests/baseline_images/test_cluster/test_silhouette/test_quick_method.png b/tests/baseline_images/test_cluster/test_silhouette/test_quick_method.png index 0c08a484d..e916d9f9f 100644 Binary files a/tests/baseline_images/test_cluster/test_silhouette/test_quick_method.png and b/tests/baseline_images/test_cluster/test_silhouette/test_quick_method.png differ diff --git a/tests/test_cluster/test_silhouette.py b/tests/test_cluster/test_silhouette.py index 6f6615857..0c68a4868 100644 --- a/tests/test_cluster/test_silhouette.py +++ b/tests/test_cluster/test_silhouette.py @@ -20,14 +20,15 @@ import sys import pytest import matplotlib.pyplot as plt +import numpy as np from sklearn.datasets import make_blobs from sklearn.cluster import KMeans, MiniBatchKMeans +from sklearn.cluster import SpectralClustering, AgglomerativeClustering from unittest import mock from tests.base import VisualTestCase -from yellowbrick.datasets import load_nfl from yellowbrick.cluster.silhouette import SilhouetteVisualizer, silhouette_visualizer @@ -53,7 +54,6 @@ def test_integrated_kmeans_silhouette(self): n_samples=1000, n_features=12, centers=8, shuffle=False, random_state=0 ) - fig = plt.figure() ax = fig.add_subplot() @@ -62,7 +62,6 @@ def test_integrated_kmeans_silhouette(self): visualizer.finalize() self.assert_images_similar(visualizer, remove_legend=True) - @pytest.mark.xfail(sys.platform == "win32", reason="images not close on windows") def test_integrated_mini_batch_kmeans_silhouette(self): @@ -84,7 +83,6 @@ def test_integrated_mini_batch_kmeans_silhouette(self): visualizer.finalize() self.assert_images_similar(visualizer, remove_legend=True) - @pytest.mark.skip(reason="no negative silhouette example available yet") def test_negative_silhouette_score(self): @@ -103,7 +101,6 @@ def test_colormap_silhouette(self): n_samples=1000, n_features=12, centers=8, shuffle=False, random_state=0 ) - fig = plt.figure() ax = fig.add_subplot() @@ -138,7 +135,7 @@ def test_colors_silhouette(self): visualizer.finalize() self.assert_images_similar(visualizer, remove_legend=True) - + def test_colormap_as_colors_silhouette(self): """ Test no exceptions for modifying the colors in a silhouette visualizer @@ -162,7 +159,7 @@ def test_colormap_as_colors_silhouette(self): 3.2 if sys.platform == "win32" else 0.01 ) # Fails on AppVeyor with RMS 3.143 self.assert_images_similar(visualizer, remove_legend=True, tol=tol) - + def test_quick_method(self): """ Test the quick method producing a valid visualization @@ -177,29 +174,44 @@ def test_quick_method(self): self.assert_images_similar(oz) - @pytest.mark.xfail( - reason="""third test fails with AssertionError: Expected fit - to be called once. Called 0 times.""" - ) def test_with_fitted(self): """ Test that visualizer properly handles an already-fitted model """ - X, y = load_nfl(return_dataset=True).to_numpy() - - model = MiniBatchKMeans().fit(X, y) + X, y = make_blobs( + n_samples=100, n_features=5, centers=3, shuffle=False, random_state=112 + ) + model = MiniBatchKMeans().fit(X) + labels = model.predict(X) with mock.patch.object(model, "fit") as mockfit: oz = SilhouetteVisualizer(model) - oz.fit(X, y) + oz.fit(X) mockfit.assert_not_called() with mock.patch.object(model, "fit") as mockfit: oz = SilhouetteVisualizer(model, is_fitted=True) - oz.fit(X, y) + oz.fit(X) mockfit.assert_not_called() - with mock.patch.object(model, "fit") as mockfit: + with mock.patch.object(model, "fit_predict", return_value=labels) as mockfit: oz = SilhouetteVisualizer(model, is_fitted=False) - oz.fit(X, y) - mockfit.assert_called_once_with(X, y) + oz.fit(X) + mockfit.assert_called_once_with(X, None) + + @pytest.mark.parametrize( + "model", + [SpectralClustering, AgglomerativeClustering], + ) + def test_clusterer_without_predict(self, model): + """ + Assert that clustering estimators that don't implement + a predict() method utilize fit_predict() + """ + X = np.array([[1, 2], [1, 4], [1, 0], [4, 2], [4, 4], [4, 0]]) + try: + visualizer = SilhouetteVisualizer(model(n_clusters=2)) + visualizer.fit(X) + visualizer.finalize() + except AttributeError: + self.fail("could not use fit or fit_predict methods") diff --git a/yellowbrick/cluster/silhouette.py b/yellowbrick/cluster/silhouette.py index b847496d9..0e5572fa2 100644 --- a/yellowbrick/cluster/silhouette.py +++ b/yellowbrick/cluster/silhouette.py @@ -23,6 +23,35 @@ from sklearn.metrics import silhouette_score, silhouette_samples +try: + from sklearn.metrics.pairwise import _VALID_METRICS +except ImportError: + _VALID_METRICS = [ + "cityblock", + "cosine", + "euclidean", + "l1", + "l2", + "manhattan", + "braycurtis", + "canberra", + "chebyshev", + "correlation", + "dice", + "hamming", + "jaccard", + "kulsinski", + "mahalanobis", + "minkowski", + "rogerstanimoto", + "russellrao", + "seuclidean", + "sokalmichener", + "sokalsneath", + "sqeuclidean", + "yule", + ] + from yellowbrick.utils import check_fitted from yellowbrick.style import resolve_colors from yellowbrick.cluster.base import ClusteringScoreVisualizer @@ -113,7 +142,6 @@ class SilhouetteVisualizer(ClusteringScoreVisualizer): """ def __init__(self, estimator, ax=None, colors=None, is_fitted="auto", **kwargs): - # Initialize the visualizer bases super(SilhouetteVisualizer, self).__init__( estimator, ax=ax, is_fitted=is_fitted, **kwargs @@ -130,23 +158,47 @@ def __init__(self, estimator, ax=None, colors=None, is_fitted="auto", **kwargs): def fit(self, X, y=None, **kwargs): """ Fits the model and generates the silhouette visualization. + + Unlike other visualizers that use the score() method to draw the results, this + visualizer errs on visualizing on fit since this is when the clusters are + computed. This means that a predict call is required in fit (or a fit_predict) + in order to produce the visualization. """ - # TODO: decide to use this method or the score method to draw. - # NOTE: Probably this would be better in score, but the standard score - # is a little different and I'm not sure how it's used. + # If the estimator is not fitted, fit it; then call predict to get the labels + # for computing the silhoutte score on. If the estimator is already fitted, then + # attempt to predict the labels, but if the estimator is stateless, fit and + # predict on the data specified. At the end of this block, no matter the fitted + # state of the estimator and the method, we should have cluster labels for X. if not check_fitted(self.estimator, is_fitted_by=self.is_fitted): - # Fit the wrapped estimator - self.estimator.fit(X, y, **kwargs) + if hasattr(self.estimator, "fit_predict"): + labels = self.estimator.fit_predict(X, y, **kwargs) + else: + self.estimator.fit(X, y, **kwargs) + labels = self.estimator.predict(X) + else: + if hasattr(self.estimator, "predict"): + labels = self.estimator.predict(X) + else: + labels = self.estimator.fit_predict(X, y, **kwargs) # Get the properties of the dataset self.n_samples_ = X.shape[0] - self.n_clusters_ = self.estimator.n_clusters + + # Compute the number of available clusters from the estimator + if hasattr(self.estimator, "n_clusters"): + self.n_clusters_ = self.estimator.n_clusters + else: + unique_labels = set(labels) + n_noise_clusters = 1 if -1 in unique_labels else 0 + self.n_clusters_ = len(unique_labels) - n_noise_clusters + + # Identify the distance metric to use for silhouette scoring + metric = self._identify_silhouette_metric() # Compute the scores of the cluster - labels = self.estimator.predict(X) - self.silhouette_score_ = silhouette_score(X, labels) - self.silhouette_samples_ = silhouette_samples(X, labels) + self.silhouette_score_ = silhouette_score(X, labels, metric=metric) + self.silhouette_samples_ = silhouette_samples(X, labels, metric=metric) # Draw the silhouette figure self.draw(labels) @@ -185,7 +237,6 @@ def draw(self, labels): # For each cluster, plot the silhouette scores self.y_tick_pos_ = [] for idx in range(self.n_clusters_): - # Collect silhouette scores for samples in the current cluster . values = self.silhouette_samples_[labels == idx] values.sort() @@ -260,6 +311,26 @@ def finalize(self): # Show legend (Average Silhouette Score axis) self.ax.legend(loc="best") + def _identify_silhouette_metric(self): + """ + The Silhouette metric must be one of the distance options allowed by + metrics.pairwise.pairwise_distances or a callable. This method attempts to + discover a valid distance metric from the underlying estimator or returns + "euclidean" by default. + """ + if hasattr(self.estimator, "metric"): + if callable(self.estimator.metric): + return self.estimator.metric + + if self.estimator.metric in _VALID_METRICS: + return self.estimator.metric + + if hasattr(self.estimator, "affinity"): + if self.estimator.affinity in _VALID_METRICS: + return self.estimator.affinity + + return "euclidean" + ########################################################################## ## Quick Method