Skip to content

Commit

Permalink
Subselect with index helper.
Browse files Browse the repository at this point in the history
  • Loading branch information
kklein committed Feb 6, 2025
1 parent 942a3de commit ccda67b
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions metalearners/metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
ONNX_PROBABILITIES_OUTPUTS,
default_metric,
index_matrix,
index_vector,
safe_len,
validate_model_and_predict_method,
validate_number_positive,
Expand Down Expand Up @@ -347,9 +348,9 @@ def _validate_outcome(self, y: Vector, w: Vector) -> None:
f" Yet we found {len(np.unique(y))} classes."
)
if self.is_classification:
classes_0 = set(np.unique(y[w == 0]))
classes_0 = set(np.unique(index_vector(y, w == 0)))
for tv in range(self.n_variants):
if set(np.unique(y[w == tv])) != classes_0:
if set(np.unique(index_vector(y, w == tv))) != classes_0:
raise ValueError(
f"Variants 0 and {tv} have seen different sets of classification outcomes. Please check your data."
)
Expand Down

0 comments on commit ccda67b

Please sign in to comment.