diff --git a/tests/test_metalearner.py b/tests/test_metalearner.py index 5e18153..aa360b0 100644 --- a/tests/test_metalearner.py +++ b/tests/test_metalearner.py @@ -13,6 +13,7 @@ from sklearn.base import BaseEstimator from sklearn.linear_model import LinearRegression, LogisticRegression +from metalearners._typing import _ScikitModel from metalearners.cross_fit_estimator import CrossFitEstimator from metalearners.data_generation import insert_missing from metalearners.drlearner import DRLearner @@ -115,8 +116,8 @@ def predict_conditional_average_outcomes(self, X, is_oos, oos_method=None): def _build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): ... @classmethod - def _necessary_onnx_models(cls) -> set[str]: - return set() + def _necessary_onnx_models(cls) -> dict[str, list[_ScikitModel]]: + return {} @pytest.mark.parametrize("is_classification", [True, False])