From 7941aefb004626af0e41c03a9a67603d35a9ec2d Mon Sep 17 00:00:00 2001 From: kklein Date: Thu, 15 Aug 2024 00:45:52 +0200 Subject: [PATCH] Reuse conditional average outcome estimates for X-Learner pseudo outcome. --- metalearners/xlearner.py | 65 +++++++++++++++++++++++----------------- 1 file changed, 37 insertions(+), 28 deletions(-) diff --git a/metalearners/xlearner.py b/metalearners/xlearner.py index 5965b96c..28bee892 100644 --- a/metalearners/xlearner.py +++ b/metalearners/xlearner.py @@ -173,9 +173,17 @@ def fit_all_treatment( qualified_fit_params = self._qualified_fit_params(fit_params) treatment_jobs: list[_ParallelJoblibSpecification] = [] + + conditional_average_outcome_estimates = ( + self.predict_conditional_average_outcomes( + X=X, + is_oos=False, + ) + ) + for treatment_variant in range(1, self.n_variants): imputed_te_control, imputed_te_treatment = self._pseudo_outcome( - X, y, w, treatment_variant + y, w, treatment_variant, conditional_average_outcome_estimates ) treatment_jobs.append( self._treatment_joblib_specifications( @@ -270,6 +278,7 @@ def predict( oos_method=oos_method, ) ) + tau_hat_treatment[treatment_variant_indices] = self.predict_treatment( X=index_matrix(X, treatment_variant_indices), model_kind=TREATMENT_EFFECT_MODEL, @@ -350,11 +359,19 @@ def evaluate( feature_set=self.feature_set[PROPENSITY_MODEL], ) + conditional_average_outcome_estimates = ( + self.predict_conditional_average_outcomes( + X=X, + is_oos=is_oos, + oos_method=oos_method, + ) + ) + imputed_te_control: list[np.ndarray] = [] imputed_te_treatment: list[np.ndarray] = [] for treatment_variant in range(1, self.n_variants): tv_imputed_te_control, tv_imputed_te_treatment = self._pseudo_outcome( - X, y, w, treatment_variant + y, w, treatment_variant, conditional_average_outcome_estimates ) imputed_te_control.append(tv_imputed_te_control) imputed_te_treatment.append(tv_imputed_te_treatment) @@ -391,7 +408,11 @@ def evaluate( ) def _pseudo_outcome( - self, X: Matrix, y: Vector, w: Vector, treatment_variant: int + self, + y: Vector, + w: Vector, + treatment_variant: int, + conditional_average_outcome_estimates: np.ndarray, ) -> tuple[np.ndarray, np.ndarray]: """Compute the X-Learner pseudo outcome. @@ -404,38 +425,26 @@ def _pseudo_outcome( """ validate_valid_treatment_variant_not_control(treatment_variant, self.n_variants) - mask = (w == treatment_variant) | (w == 0) + treatment_indices = w == treatment_variant + control_indices = w == 0 - X_filt = X[mask] - y_filt = y[mask] - w_filt = (w[mask] == treatment_variant).astype(int) + treatment_outcome = index_matrix( + conditional_average_outcome_estimates, control_indices + )[:, treatment_variant] + control_outcome = index_matrix( + conditional_average_outcome_estimates, treatment_indices + )[:, 0] - treatment_indices = w_filt == 1 - control_indices = w_filt == 0 - - # This is always oos because the VARIANT_OUTCOME_MODEL[0] is used to predict the - # control outcomes of the treated observations and vice versa. - control_outcome = self.predict_nuisance( - X=index_matrix(X_filt, treatment_indices), - model_kind=VARIANT_OUTCOME_MODEL, - model_ord=0, - is_oos=True, - oos_method=OVERALL, - ) - treatment_outcome = self.predict_nuisance( - X=index_matrix(X_filt, control_indices), - model_kind=VARIANT_OUTCOME_MODEL, - model_ord=treatment_variant, - is_oos=True, - oos_method=OVERALL, - ) if self.is_classification: # Get the probability of positive class, multiclass is currently not supported. control_outcome = control_outcome[:, 1] treatment_outcome = treatment_outcome[:, 1] + else: + control_outcome = control_outcome[:, 0] + treatment_outcome = treatment_outcome[:, 0] - imputed_te_treatment = y_filt[treatment_indices] - control_outcome - imputed_te_control = treatment_outcome - y_filt[control_indices] + imputed_te_treatment = y[treatment_indices] - control_outcome + imputed_te_control = treatment_outcome - y[control_indices] return imputed_te_control, imputed_te_treatment