Skip to content

Commit

Permalink
Fix cao estimation only taking place for seen variant.
Browse files Browse the repository at this point in the history
  • Loading branch information
kklein committed Aug 15, 2024
1 parent 6a43c9c commit bbfff15
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions metalearners/xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def fit_all_treatment(
imputed_te_control, imputed_te_treatment = self._pseudo_outcome(
y, w, treatment_variant, conditional_average_outcome_estimates
)

treatment_jobs.append(
self._treatment_joblib_specifications(
X=X,
Expand Down Expand Up @@ -221,6 +222,7 @@ def fit_all_treatment(
delayed(_fit_cross_fit_estimator_joblib)(job) for job in treatment_jobs
)
self._assign_joblib_treatment_results(results)

return self

def predict(
Expand Down Expand Up @@ -564,13 +566,15 @@ def predict_conditional_average_outcomes(
cfe = self._nuisance_models[VARIANT_OUTCOME_MODEL][tv]
conditional_average_outcome_estimates = cao_tensor.copy()

for fold_index, test_indices in zip(
range(cfe.n_folds), cfe._test_indices # type: ignore[arg-type]
for fold_index, (train_indices, prediction_indices) in enumerate(
self._cv_split_indices
):
fold_model = cfe._estimators[fold_index]
predict_method = getattr(fold_model, predict_method_name)
fold_estimates = predict_method(index_matrix(X, test_indices))
conditional_average_outcome_estimates[test_indices] = fold_estimates
fold_estimates = predict_method(index_matrix(X, prediction_indices))
conditional_average_outcome_estimates[prediction_indices] = (
fold_estimates
)

conditional_average_outcomes_list.append(
conditional_average_outcome_estimates
Expand Down

0 comments on commit bbfff15

Please sign in to comment.