diff --git a/metalearners/xlearner.py b/metalearners/xlearner.py index fd76719..e15c842 100644 --- a/metalearners/xlearner.py +++ b/metalearners/xlearner.py @@ -374,8 +374,8 @@ def evaluate( tv_imputed_te_control, tv_imputed_te_treatment = self._pseudo_outcome( y, w, treatment_variant, conditional_average_outcome_estimates ) - imputed_te_control.append(tv_imputed_te_control) - imputed_te_treatment.append(tv_imputed_te_treatment) + imputed_te_control.append(tv_imputed_te_control[w == 0]) + imputed_te_treatment.append(tv_imputed_te_treatment[w == treatment_variant]) te_treatment_evaluation = _evaluate_model_kind( self._treatment_models[TREATMENT_EFFECT_MODEL],