From c08dd6ae69dd38cdfb96ec2eea78bc40dad631e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francesc=20Mart=C3=AD=20Escofet?= Date: Thu, 27 Jun 2024 09:27:51 +0200 Subject: [PATCH] Reuse typing --- metalearners/grid_search.py | 8 ++++---- metalearners/metalearner.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/metalearners/grid_search.py b/metalearners/grid_search.py index 78ba23a..2352a89 100644 --- a/metalearners/grid_search.py +++ b/metalearners/grid_search.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause import time -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Mapping, Sequence from dataclasses import dataclass from functools import reduce from operator import add @@ -11,7 +11,7 @@ from joblib import Parallel, delayed from sklearn.model_selection import KFold, ParameterGrid -from metalearners._typing import Matrix, OosMethod, Vector, _ScikitModel +from metalearners._typing import Matrix, OosMethod, Scoring, Vector, _ScikitModel from metalearners._utils import index_matrix, index_vector from metalearners.cross_fit_estimator import OVERALL from metalearners.metalearner import PROPENSITY_MODEL, MetaLearner @@ -27,7 +27,7 @@ class _FitAndScoreJob: y_test: Vector w_test: Vector oos_method: OosMethod - scoring: Mapping[str, list[str | Callable]] | None + scoring: Scoring | None kwargs: dict cv_index: int @@ -163,7 +163,7 @@ def __init__( metalearner_params: Mapping, base_learner_grid: Mapping[str, Sequence[type[_ScikitModel]]], param_grid: Mapping[str, Mapping[str, Sequence]], - scoring: Mapping[str, list[str | Callable]] | None = None, + scoring: Scoring | None = None, cv: int = 5, n_jobs: int | None = None, random_state: int | None = None, diff --git a/metalearners/metalearner.py b/metalearners/metalearner.py index 4bce635..1efe82b 100644 --- a/metalearners/metalearner.py +++ b/metalearners/metalearner.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause from abc import ABC, abstractmethod -from collections.abc import Callable, Collection, Mapping, Sequence +from collections.abc import Callable, Collection, Sequence from copy import deepcopy from dataclasses import dataclass from typing import TypedDict @@ -856,7 +856,7 @@ def evaluate( w: Vector, is_oos: bool, oos_method: OosMethod = OVERALL, - scoring: Mapping[str, list[str | Callable]] | None = None, + scoring: Scoring | None = None, ) -> dict[str, float]: r"""Evaluate the MetaLearner.