Skip to content

Commit

Permalink
Fix pchs
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescMartiEscofetQC committed Jul 25, 2024
1 parent 8022aa1 commit 53bcd03
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tests/test_metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sklearn.base import BaseEstimator
from sklearn.linear_model import LinearRegression, LogisticRegression

from metalearners._typing import _ScikitModel
from metalearners.cross_fit_estimator import CrossFitEstimator
from metalearners.data_generation import insert_missing
from metalearners.drlearner import DRLearner
Expand Down Expand Up @@ -115,8 +116,8 @@ def predict_conditional_average_outcomes(self, X, is_oos, oos_method=None):
def _build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): ...

@classmethod
def _necessary_onnx_models(cls) -> set[str]:
return set()
def _necessary_onnx_models(cls) -> dict[str, list[_ScikitModel]]:
return {}


@pytest.mark.parametrize("is_classification", [True, False])
Expand Down

0 comments on commit 53bcd03

Please sign in to comment.