Skip to content

Commit

Permalink
rename missed variables
Browse files Browse the repository at this point in the history
  • Loading branch information
kyracho committed Sep 4, 2024
1 parent 00fb48c commit b248794
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions metalearners/xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ def predict(
non_control_indices = ~control_indices

for treatment_variant in range(1, self.n_variants):
treatment_variant_indices = self._treatment_variants_mask[treatment_variant]
non_treatment_variant_mask = ~treatment_variant_indices
treatment_variant_mask = self._treatment_variants_mask[treatment_variant]
non_treatment_variant_mask = ~treatment_variant_mask
if is_oos:
tau_hat_treatment = self.predict_treatment(
X=X,
Expand Down Expand Up @@ -272,8 +272,8 @@ def predict(
oos_method=oos_method,
)

tau_hat_treatment[treatment_variant_indices] = self.predict_treatment(
X=index_matrix(X, treatment_variant_indices),
tau_hat_treatment[treatment_variant_mask] = self.predict_treatment(
X=index_matrix(X, treatment_variant_mask),
model_kind=TREATMENT_EFFECT_MODEL,
model_ord=treatment_variant - 1,
is_oos=False,
Expand Down

0 comments on commit b248794

Please sign in to comment.