From 53bcd036fb47980ae35b2547122b39522a01d9a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francesc=20Mart=C3=AD=20Escofet?= Date: Thu, 25 Jul 2024 14:34:54 +0200 Subject: [PATCH] Fix pchs --- tests/test_metalearner.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_metalearner.py b/tests/test_metalearner.py index 5e18153..aa360b0 100644 --- a/tests/test_metalearner.py +++ b/tests/test_metalearner.py @@ -13,6 +13,7 @@ from sklearn.base import BaseEstimator from sklearn.linear_model import LinearRegression, LogisticRegression +from metalearners._typing import _ScikitModel from metalearners.cross_fit_estimator import CrossFitEstimator from metalearners.data_generation import insert_missing from metalearners.drlearner import DRLearner @@ -115,8 +116,8 @@ def predict_conditional_average_outcomes(self, X, is_oos, oos_method=None): def _build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): ... @classmethod - def _necessary_onnx_models(cls) -> set[str]: - return set() + def _necessary_onnx_models(cls) -> dict[str, list[_ScikitModel]]: + return {} @pytest.mark.parametrize("is_classification", [True, False])