diff --git a/metalearners/_typing.py b/metalearners/_typing.py
index 03f6192..a7f39d4 100644
--- a/metalearners/_typing.py
+++ b/metalearners/_typing.py
@@ -1,7 +1,7 @@
 # # Copyright (c) QuantCo 2024-2024
 # # SPDX-License-Identifier: BSD-3-Clause
 
-from collections.abc import Collection
+from collections.abc import Collection, Mapping
 from typing import Literal, Protocol, Union
 
 import numpy as np
@@ -16,7 +16,7 @@
 # https://mypy.readthedocs.io/en/stable/literal_types.html#limitations
 OosMethod = Literal["overall", "median", "mean"]
 
-Params = dict[str, int | float | str]
+Params = Mapping[str, int | float | str]
 
 Features = Collection[str] | Collection[int]
 # ruff is not happy about the usage of Union.
diff --git a/tests/test_learner.py b/tests/test_learner.py
index df7c6da..e76018e 100644
--- a/tests/test_learner.py
+++ b/tests/test_learner.py
@@ -588,13 +588,13 @@ def test_conditional_average_outcomes_smoke(
     factory = metalearner_factory(metalearner_prefix)
     learner = factory(
         nuisance_model_factory=_tree_base_learner(is_classification),
-        nuisance_model_params={"n_estimators": 1},  # type: ignore
+        nuisance_model_params={"n_estimators": 1},
         is_classification=is_classification,
         n_variants=len(np.unique(df[treatment_column])),
         treatment_model_factory=LGBMRegressor,
-        treatment_model_params={"n_estimators": 1},  # type: ignore
+        treatment_model_params={"n_estimators": 1},
         propensity_model_factory=LGBMClassifier,
-        propensity_model_params={"n_estimators": 1},  # type: ignore
+        propensity_model_params={"n_estimators": 1},
         n_folds=2,
     )
     learner.fit(df[feature_columns], df[outcome_column], df[treatment_column])
@@ -626,7 +626,7 @@ def test_conditional_average_outcomes_smoke_multi_class(
     y = rng.integers(0, n_classes, size=sample_size)
     learner = factory(
         nuisance_model_factory=_tree_base_learner(True),
-        nuisance_model_params={"n_estimators": 1},  # type: ignore
+        nuisance_model_params={"n_estimators": 1},
         n_variants=n_variants,
         is_classification=True,
         n_folds=2,
@@ -665,13 +665,13 @@ def test_predict_smoke(
     y = rng.standard_normal(sample_size)
     learner = factory(
         nuisance_model_factory=_tree_base_learner(is_classification),
-        nuisance_model_params={"n_estimators": 1},  # type: ignore
+        nuisance_model_params={"n_estimators": 1},
         n_variants=n_variants,
         is_classification=is_classification,
         treatment_model_factory=LGBMRegressor,
-        treatment_model_params={"n_estimators": 1},  # type: ignore
+        treatment_model_params={"n_estimators": 1},
         propensity_model_factory=LGBMClassifier,
-        propensity_model_params={"n_estimators": 1},  # type: ignore
+        propensity_model_params={"n_estimators": 1},
         n_folds=2,
     )
     learner.fit(X, y, w)
diff --git a/tests/test_metalearner.py b/tests/test_metalearner.py
index 3fa4015..1e41423 100644
--- a/tests/test_metalearner.py
+++ b/tests/test_metalearner.py
@@ -737,7 +737,7 @@ def test_feature_importances_smoke(
         nuisance_model_factory=LinearRegression,
         treatment_model_factory=LGBMRegressor,
         propensity_model_factory=LogisticRegression,
-        treatment_model_params={"n_estimators": 1},  # type: ignore
+        treatment_model_params={"n_estimators": 1},
     )
     ml.fit(X=X, y=y, w=w)
 
@@ -895,7 +895,7 @@ def test_shap_values_smoke(
         nuisance_model_factory=LinearRegression,
         treatment_model_factory=LGBMRegressor,
         propensity_model_factory=LogisticRegression,
-        treatment_model_params={"n_estimators": 1},  # type: ignore
+        treatment_model_params={"n_estimators": 1},
     )
     ml.fit(X=X, y=y, w=w)