diff --git a/metalearners/slearner.py b/metalearners/slearner.py index aacc93b..4e3e2b7 100644 --- a/metalearners/slearner.py +++ b/metalearners/slearner.py @@ -172,7 +172,7 @@ def evaluate( model_kind=_BASE_MODEL, is_oos=is_oos, oos_method=oos_method, - is_treatment=False, + is_treatment_model=False, ) def predict_conditional_average_outcomes(