From 413e5b0007c28df885760635865415c425d0ad4e Mon Sep 17 00:00:00 2001 From: kklein Date: Thu, 15 Aug 2024 20:40:33 +0200 Subject: [PATCH] Clean up. --- metalearners/xlearner.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/metalearners/xlearner.py b/metalearners/xlearner.py index 082917d..bec07b5 100644 --- a/metalearners/xlearner.py +++ b/metalearners/xlearner.py @@ -539,7 +539,10 @@ def predict_conditional_average_outcomes( ) # TODO: Consider multiprocessing n_obs = len(X) - nuisance_tensors = self._nuisance_tensors(n_obs) + cao_tensor = self._nuisance_tensors(n_obs)[VARIANT_OUTCOME_MODEL][0] + predict_method_name = self.nuisance_model_specifications()[ + VARIANT_OUTCOME_MODEL + ]["predict_method"](self) conditional_average_outcomes_list = [] for tv in range(self.n_variants): @@ -554,19 +557,22 @@ def predict_conditional_average_outcomes( ) ) else: + # TODO: Consider moving this logic to CrossFitEstimator.predict. cfe = self._nuisance_models[VARIANT_OUTCOME_MODEL][tv] + conditional_average_outcome_estimates = cao_tensor.copy() + + for fold_index, test_indices in zip( + range(cfe.n_folds), cfe._test_indices # type: ignore[arg-type] + ): + fold_model = cfe._estimators[fold_index] + predict_method = getattr(fold_model, predict_method_name) + fold_estimates = predict_method(X[test_indices]) + conditional_average_outcome_estimates[test_indices] = fold_estimates + conditional_average_outcomes_list.append( - nuisance_tensors[VARIANT_OUTCOME_MODEL][0].copy() + conditional_average_outcome_estimates ) - for split_index, test_indices in enumerate(cfe._test_indices): # type: ignore[arg-type] - model = cfe._estimators[split_index] - predict_method_name = self.nuisance_model_specifications()[ - VARIANT_OUTCOME_MODEL - ]["predict_method"](self) - predict_method = getattr(model, predict_method_name) - conditional_average_outcomes_list[tv][test_indices] = ( - predict_method(X[test_indices]) - ) + return np.stack(conditional_average_outcomes_list, axis=1).reshape( n_obs, self.n_variants, -1 )