Skip to content

Commit

Permalink
Fix in-sample evaluate.
Browse files Browse the repository at this point in the history
  • Loading branch information
kklein committed Aug 15, 2024
1 parent fe16b75 commit 410e9e7
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions metalearners/xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,8 @@ def evaluate(

variant_outcome_evaluation = _evaluate_model_kind(
cfes=self._nuisance_models[VARIANT_OUTCOME_MODEL],
Xs=[X[w == tv] for tv in range(self.n_variants)],
ys=[y[w == tv] for tv in range(self.n_variants)],
Xs=[X] * self.n_variants,
ys=[y] * self.n_variants,
scorers=safe_scoring[VARIANT_OUTCOME_MODEL],
model_kind=VARIANT_OUTCOME_MODEL,
is_oos=is_oos,
Expand Down Expand Up @@ -374,12 +374,12 @@ def evaluate(
tv_imputed_te_control, tv_imputed_te_treatment = self._pseudo_outcome(
y, w, treatment_variant, conditional_average_outcome_estimates
)
imputed_te_control.append(tv_imputed_te_control[w == 0])
imputed_te_treatment.append(tv_imputed_te_treatment[w == treatment_variant])
imputed_te_control.append(tv_imputed_te_control)
imputed_te_treatment.append(tv_imputed_te_treatment)

te_treatment_evaluation = _evaluate_model_kind(
self._treatment_models[TREATMENT_EFFECT_MODEL],
Xs=[X[w == tv] for tv in range(1, self.n_variants)],
Xs=[X] * self.n_variants,
ys=imputed_te_treatment,
scorers=safe_scoring[TREATMENT_EFFECT_MODEL],
model_kind=TREATMENT_EFFECT_MODEL,
Expand All @@ -391,7 +391,7 @@ def evaluate(

te_control_evaluation = _evaluate_model_kind(
self._treatment_models[CONTROL_EFFECT_MODEL],
Xs=[X[w == 0] for _ in range(1, self.n_variants)],
Xs=[X] * self.n_variants,
ys=imputed_te_control,
scorers=safe_scoring[CONTROL_EFFECT_MODEL],
model_kind=CONTROL_EFFECT_MODEL,
Expand Down

0 comments on commit 410e9e7

Please sign in to comment.