Skip to content

Commit

Permalink
Merge branch 'rename-t-to-w' of github.com:Quantco/metalearners into …
Browse files Browse the repository at this point in the history
…rename-t-to-w
  • Loading branch information
kklein committed Jun 27, 2024
2 parents 070cb00 + ad06f94 commit 24fd9aa
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions metalearners/cross_fit_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def _predict_all(self, X: Matrix, method: PredictMethod) -> np.ndarray:
)
for i, estimator in enumerate(self._estimators):
predictions[:, :, i] = np.reshape(
getattr(estimator, method)(X), (-1, n_outputs)
getattr(estimator, method)(X), (len(X), n_outputs)
)
if n_outputs == 1:
return predictions[:, 0, :]
Expand Down Expand Up @@ -255,7 +255,8 @@ def _predict_in_sample(
)
for estimator, indices in zip(self._estimators, self._test_indices):
fold_predictions = np.reshape(
getattr(estimator, method)(index_matrix(X, indices)), (-1, n_outputs, 1)
getattr(estimator, method)(index_matrix(X, indices)),
(len(indices), n_outputs, 1),
)
predictions[indices] = fold_predictions
if n_outputs == 1:
Expand Down

0 comments on commit 24fd9aa

Please sign in to comment.