diff --git a/enso/sample/random_sampler.py b/enso/sample/random_sampler.py index eb6f66e..bbfd114 100644 --- a/enso/sample/random_sampler.py +++ b/enso/sample/random_sampler.py @@ -100,7 +100,9 @@ def _choose_starting_points(self, n_points=3): for cls in self.classes: indices = [i for i, val in enumerate(self.train_labels) if cls in val] index = random.choice(indices) - points.append(self.train_indices[index]) + train_index = self.train_indices[index] + if train_index not in points: + points.append(train_index) return points @property