From aa7ea19ccd0c162c60ad8dcc503719af87fad764 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francesc=20Mart=C3=AD=20Escofet?= <154450563+FrancescMartiEscofetQC@users.noreply.github.com> Date: Fri, 7 Jun 2024 09:00:38 +0200 Subject: [PATCH] Make params `Mapping` instead of `dict` (#163) --- metalearners/_typing.py | 4 ++-- tests/test_learner.py | 14 +++++++------- tests/test_metalearner.py | 4 ++-- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/metalearners/_typing.py b/metalearners/_typing.py index 03f61926..a7f39d46 100644 --- a/metalearners/_typing.py +++ b/metalearners/_typing.py @@ -1,7 +1,7 @@ # # Copyright (c) QuantCo 2024-2024 # # SPDX-License-Identifier: BSD-3-Clause -from collections.abc import Collection +from collections.abc import Collection, Mapping from typing import Literal, Protocol, Union import numpy as np @@ -16,7 +16,7 @@ # https://mypy.readthedocs.io/en/stable/literal_types.html#limitations OosMethod = Literal["overall", "median", "mean"] -Params = dict[str, int | float | str] +Params = Mapping[str, int | float | str] Features = Collection[str] | Collection[int] # ruff is not happy about the usage of Union. diff --git a/tests/test_learner.py b/tests/test_learner.py index df7c6da2..e76018e7 100644 --- a/tests/test_learner.py +++ b/tests/test_learner.py @@ -588,13 +588,13 @@ def test_conditional_average_outcomes_smoke( factory = metalearner_factory(metalearner_prefix) learner = factory( nuisance_model_factory=_tree_base_learner(is_classification), - nuisance_model_params={"n_estimators": 1}, # type: ignore + nuisance_model_params={"n_estimators": 1}, is_classification=is_classification, n_variants=len(np.unique(df[treatment_column])), treatment_model_factory=LGBMRegressor, - treatment_model_params={"n_estimators": 1}, # type: ignore + treatment_model_params={"n_estimators": 1}, propensity_model_factory=LGBMClassifier, - propensity_model_params={"n_estimators": 1}, # type: ignore + propensity_model_params={"n_estimators": 1}, n_folds=2, ) learner.fit(df[feature_columns], df[outcome_column], df[treatment_column]) @@ -626,7 +626,7 @@ def test_conditional_average_outcomes_smoke_multi_class( y = rng.integers(0, n_classes, size=sample_size) learner = factory( nuisance_model_factory=_tree_base_learner(True), - nuisance_model_params={"n_estimators": 1}, # type: ignore + nuisance_model_params={"n_estimators": 1}, n_variants=n_variants, is_classification=True, n_folds=2, @@ -665,13 +665,13 @@ def test_predict_smoke( y = rng.standard_normal(sample_size) learner = factory( nuisance_model_factory=_tree_base_learner(is_classification), - nuisance_model_params={"n_estimators": 1}, # type: ignore + nuisance_model_params={"n_estimators": 1}, n_variants=n_variants, is_classification=is_classification, treatment_model_factory=LGBMRegressor, - treatment_model_params={"n_estimators": 1}, # type: ignore + treatment_model_params={"n_estimators": 1}, propensity_model_factory=LGBMClassifier, - propensity_model_params={"n_estimators": 1}, # type: ignore + propensity_model_params={"n_estimators": 1}, n_folds=2, ) learner.fit(X, y, w) diff --git a/tests/test_metalearner.py b/tests/test_metalearner.py index 3fa4015b..1e414239 100644 --- a/tests/test_metalearner.py +++ b/tests/test_metalearner.py @@ -737,7 +737,7 @@ def test_feature_importances_smoke( nuisance_model_factory=LinearRegression, treatment_model_factory=LGBMRegressor, propensity_model_factory=LogisticRegression, - treatment_model_params={"n_estimators": 1}, # type: ignore + treatment_model_params={"n_estimators": 1}, ) ml.fit(X=X, y=y, w=w) @@ -895,7 +895,7 @@ def test_shap_values_smoke( nuisance_model_factory=LinearRegression, treatment_model_factory=LGBMRegressor, propensity_model_factory=LogisticRegression, - treatment_model_params={"n_estimators": 1}, # type: ignore + treatment_model_params={"n_estimators": 1}, ) ml.fit(X=X, y=y, w=w)