diff --git a/opusfilter/autogen_cluster.py b/opusfilter/autogen_cluster.py index fe604e6..d25e0b5 100644 --- a/opusfilter/autogen_cluster.py +++ b/opusfilter/autogen_cluster.py @@ -131,8 +131,13 @@ def plot(self, plt): """Plot clustering and histograms""" plt.figure(figsize=(10, 10)) data_t = PCA(n_components=2).fit_transform(self.standard_data) - colors = ['orange' if lbl == self.noisy_label else 'blue' for lbl in self.labels] - plt.scatter(data_t[:, 0], data_t[:, 1], c=colors, marker=',', s=1) + for label_id in [self.noisy_label, self.clean_label]: + points = np.where(self.labels == label_id) + plt.scatter(data_t[points, 0], data_t[points, 1], + c='orange' if label_id == self.noisy_label else 'blue', + label='noisy' if label_id == self.noisy_label else 'clean', + marker=',', s=1) + plt.legend() plt.title('Clusters') noisy_samples = self.df.iloc[np.where(self.labels == self.noisy_label)] clean_samples = self.df.iloc[np.where(self.labels == self.clean_label)]