diff --git a/metalearners/metalearner.py b/metalearners/metalearner.py index b3627d86..793dca59 100644 --- a/metalearners/metalearner.py +++ b/metalearners/metalearner.py @@ -31,6 +31,7 @@ ONNX_PROBABILITIES_OUTPUTS, default_metric, index_matrix, + index_vector, safe_len, validate_model_and_predict_method, validate_number_positive, @@ -347,9 +348,9 @@ def _validate_outcome(self, y: Vector, w: Vector) -> None: f" Yet we found {len(np.unique(y))} classes." ) if self.is_classification: - classes_0 = set(np.unique(y[w == 0])) + classes_0 = set(np.unique(index_vector(y, w == 0))) for tv in range(self.n_variants): - if set(np.unique(y[w == tv])) != classes_0: + if set(np.unique(index_vector(y, w == tv))) != classes_0: raise ValueError( f"Variants 0 and {tv} have seen different sets of classification outcomes. Please check your data." )