diff --git a/metalearners/xlearner.py b/metalearners/xlearner.py index fa95edd9..fd767194 100644 --- a/metalearners/xlearner.py +++ b/metalearners/xlearner.py @@ -566,7 +566,7 @@ def predict_conditional_average_outcomes( ): fold_model = cfe._estimators[fold_index] predict_method = getattr(fold_model, predict_method_name) - fold_estimates = predict_method(X[test_indices]) + fold_estimates = predict_method(index_matrix(X, test_indices)) conditional_average_outcome_estimates[test_indices] = fold_estimates conditional_average_outcomes_list.append(