From ebc6aee9fb0305afa822d206cd1b560999a9585f Mon Sep 17 00:00:00 2001 From: Kevin Klein <7267523+kklein@users.noreply.github.com> Date: Thu, 27 Jun 2024 13:21:40 +0200 Subject: [PATCH] Expect explicit dimensions in reshaping. (#35) --- metalearners/cross_fit_estimator.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/metalearners/cross_fit_estimator.py b/metalearners/cross_fit_estimator.py index 6760413..5b141ac 100644 --- a/metalearners/cross_fit_estimator.py +++ b/metalearners/cross_fit_estimator.py @@ -220,7 +220,7 @@ def _predict_all(self, X: Matrix, method: PredictMethod) -> np.ndarray: ) for i, estimator in enumerate(self._estimators): predictions[:, :, i] = np.reshape( - getattr(estimator, method)(X), (-1, n_outputs) + getattr(estimator, method)(X), (len(X), n_outputs) ) if n_outputs == 1: return predictions[:, 0, :] @@ -255,7 +255,8 @@ def _predict_in_sample( ) for estimator, indices in zip(self._estimators, self._test_indices): fold_predictions = np.reshape( - getattr(estimator, method)(index_matrix(X, indices)), (-1, n_outputs, 1) + getattr(estimator, method)(index_matrix(X, indices)), + (len(indices), n_outputs, 1), ) predictions[indices] = fold_predictions if n_outputs == 1: