From c0bdcbd6a263906523e602549cfe7b2ae1f89cba Mon Sep 17 00:00:00 2001 From: kklein Date: Thu, 15 Aug 2024 21:02:36 +0200 Subject: [PATCH] Filter properly. --- metalearners/xlearner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metalearners/xlearner.py b/metalearners/xlearner.py index fa95edd..fd76719 100644 --- a/metalearners/xlearner.py +++ b/metalearners/xlearner.py @@ -566,7 +566,7 @@ def predict_conditional_average_outcomes( ): fold_model = cfe._estimators[fold_index] predict_method = getattr(fold_model, predict_method_name) - fold_estimates = predict_method(X[test_indices]) + fold_estimates = predict_method(index_matrix(X, test_indices)) conditional_average_outcome_estimates[test_indices] = fold_estimates conditional_average_outcomes_list.append(