diff --git a/metalearners/tlearner.py b/metalearners/tlearner.py index 875a986..11dacaa 100644 --- a/metalearners/tlearner.py +++ b/metalearners/tlearner.py @@ -118,7 +118,6 @@ def evaluate( ) -> dict[str, float]: safe_scoring = self._scoring(scoring) - return _evaluate_model_kind( cfes=self._nuisance_models[VARIANT_OUTCOME_MODEL], Xs=[X[w == tv] for tv in range(self.n_variants)], @@ -127,5 +126,5 @@ def evaluate( model_kind=VARIANT_OUTCOME_MODEL, is_oos=is_oos, oos_method=oos_method, - is_treatment=False, + is_treatment_model=False, )