diff --git a/metalearners/drlearner.py b/metalearners/drlearner.py index ca5dd1a..40e8a09 100644 --- a/metalearners/drlearner.py +++ b/metalearners/drlearner.py @@ -223,7 +223,7 @@ def evaluate( model_kind=VARIANT_OUTCOME_MODEL, is_oos=is_oos, oos_method=oos_method, - is_treatment=False, + is_treatment_model=False, ) propensity_evaluation = _evaluate_model_kind( @@ -234,7 +234,7 @@ def evaluate( model_kind=PROPENSITY_MODEL, is_oos=is_oos, oos_method=oos_method, - is_treatment=False, + is_treatment_model=False, ) pseudo_outcome: list[np.ndarray] = [] @@ -257,7 +257,7 @@ def evaluate( model_kind=TREATMENT_MODEL, is_oos=is_oos, oos_method=oos_method, - is_treatment=True, + is_treatment_model=True, ) return variant_outcome_evaluation | propensity_evaluation | treatment_evaluation