Skip to content

Commit

Permalink
Clean up.
Browse files Browse the repository at this point in the history
  • Loading branch information
kklein committed Aug 15, 2024
1 parent 5afa7af commit 413e5b0
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions metalearners/xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,10 @@ def predict_conditional_average_outcomes(
)
# TODO: Consider multiprocessing
n_obs = len(X)
nuisance_tensors = self._nuisance_tensors(n_obs)
cao_tensor = self._nuisance_tensors(n_obs)[VARIANT_OUTCOME_MODEL][0]
predict_method_name = self.nuisance_model_specifications()[
VARIANT_OUTCOME_MODEL
]["predict_method"](self)
conditional_average_outcomes_list = []

for tv in range(self.n_variants):
Expand All @@ -554,19 +557,22 @@ def predict_conditional_average_outcomes(
)
)
else:
# TODO: Consider moving this logic to CrossFitEstimator.predict.
cfe = self._nuisance_models[VARIANT_OUTCOME_MODEL][tv]
conditional_average_outcome_estimates = cao_tensor.copy()

for fold_index, test_indices in zip(
range(cfe.n_folds), cfe._test_indices # type: ignore[arg-type]
):
fold_model = cfe._estimators[fold_index]
predict_method = getattr(fold_model, predict_method_name)
fold_estimates = predict_method(X[test_indices])
conditional_average_outcome_estimates[test_indices] = fold_estimates

conditional_average_outcomes_list.append(
nuisance_tensors[VARIANT_OUTCOME_MODEL][0].copy()
conditional_average_outcome_estimates
)
for split_index, test_indices in enumerate(cfe._test_indices): # type: ignore[arg-type]
model = cfe._estimators[split_index]
predict_method_name = self.nuisance_model_specifications()[
VARIANT_OUTCOME_MODEL
]["predict_method"](self)
predict_method = getattr(model, predict_method_name)
conditional_average_outcomes_list[tv][test_indices] = (
predict_method(X[test_indices])
)

return np.stack(conditional_average_outcomes_list, axis=1).reshape(
n_obs, self.n_variants, -1
)

0 comments on commit 413e5b0

Please sign in to comment.