From b248794f7ddff11847e8a1c4c86fc3ca9c9ebba6 Mon Sep 17 00:00:00 2001 From: kyracho Date: Tue, 3 Sep 2024 18:31:32 -0700 Subject: [PATCH] rename missed variables --- metalearners/xlearner.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/metalearners/xlearner.py b/metalearners/xlearner.py index 02a94ed8..f9ff5497 100644 --- a/metalearners/xlearner.py +++ b/metalearners/xlearner.py @@ -243,8 +243,8 @@ def predict( non_control_indices = ~control_indices for treatment_variant in range(1, self.n_variants): - treatment_variant_indices = self._treatment_variants_mask[treatment_variant] - non_treatment_variant_mask = ~treatment_variant_indices + treatment_variant_mask = self._treatment_variants_mask[treatment_variant] + non_treatment_variant_mask = ~treatment_variant_mask if is_oos: tau_hat_treatment = self.predict_treatment( X=X, @@ -272,8 +272,8 @@ def predict( oos_method=oos_method, ) - tau_hat_treatment[treatment_variant_indices] = self.predict_treatment( - X=index_matrix(X, treatment_variant_indices), + tau_hat_treatment[treatment_variant_mask] = self.predict_treatment( + X=index_matrix(X, treatment_variant_mask), model_kind=TREATMENT_EFFECT_MODEL, model_ord=treatment_variant - 1, is_oos=False,