Skip to content

Commit

Permalink
[python-package] fix mypy errors related to eval result tuples (#6097)
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb authored Sep 13, 2023
1 parent 921479b commit 1a6e6ff
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 11 deletions.
1 change: 1 addition & 0 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
_LGBM_EvalFunctionResultType = Tuple[str, float, bool]
_LGBM_BoosterBestScoreType = Dict[str, Dict[str, float]]
_LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool]
_LGBM_BoosterEvalMethodResultWithStandardDeviationType = Tuple[str, str, float, bool, float]
_LGBM_CategoricalFeatureConfiguration = Union[List[str], List[int], "Literal['auto']"]
_LGBM_FeatureNameConfiguration = Union[List[str], "Literal['auto']"]
_LGBM_GroupType = Union[
Expand Down
11 changes: 6 additions & 5 deletions python-package/lightgbm/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from collections import OrderedDict
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

from .basic import Booster, _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning
from .basic import (Booster, _ConfigAliases, _LGBM_BoosterEvalMethodResultType,
_LGBM_BoosterEvalMethodResultWithStandardDeviationType, _log_info, _log_warning)

if TYPE_CHECKING:
from .engine import CVBooster
Expand All @@ -20,11 +21,11 @@
_EvalResultDict = Dict[str, Dict[str, List[Any]]]
_EvalResultTuple = Union[
_LGBM_BoosterEvalMethodResultType,
Tuple[str, str, float, bool, float]
_LGBM_BoosterEvalMethodResultWithStandardDeviationType
]
_ListOfEvalResultTuples = Union[
List[_LGBM_BoosterEvalMethodResultType],
List[Tuple[str, str, float, bool, float]]
List[_LGBM_BoosterEvalMethodResultWithStandardDeviationType]
]


Expand Down Expand Up @@ -54,7 +55,7 @@ class CallbackEnv:
iteration: int
begin_iteration: int
end_iteration: int
evaluation_result_list: Optional[List[_LGBM_BoosterEvalMethodResultType]]
evaluation_result_list: Optional[_ListOfEvalResultTuples]


def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str:
Expand Down
12 changes: 6 additions & 6 deletions python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@

from . import callback
from .basic import (Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _InnerPredictor,
_LGBM_BoosterEvalMethodResultType, _LGBM_CategoricalFeatureConfiguration,
_LGBM_CustomObjectiveFunction, _LGBM_EvalFunctionResultType, _LGBM_FeatureNameConfiguration,
_log_warning)
_LGBM_BoosterEvalMethodResultType, _LGBM_BoosterEvalMethodResultWithStandardDeviationType,
_LGBM_CategoricalFeatureConfiguration, _LGBM_CustomObjectiveFunction, _LGBM_EvalFunctionResultType,
_LGBM_FeatureNameConfiguration, _log_warning)
from .compat import SKLEARN_INSTALLED, _LGBMBaseCrossValidator, _LGBMGroupKFold, _LGBMStratifiedKFold

__all__ = [
Expand Down Expand Up @@ -519,8 +519,8 @@ def _make_n_folds(


def _agg_cv_result(
raw_results: List[List[Tuple[str, str, float, bool]]]
) -> List[Tuple[str, str, float, bool, float]]:
raw_results: List[List[_LGBM_BoosterEvalMethodResultType]]
) -> List[_LGBM_BoosterEvalMethodResultWithStandardDeviationType]:
"""Aggregate cross-validation results."""
cvmap: Dict[str, List[float]] = OrderedDict()
metric_type: Dict[str, bool] = {}
Expand All @@ -530,7 +530,7 @@ def _agg_cv_result(
metric_type[key] = one_line[3]
cvmap.setdefault(key, [])
cvmap[key].append(one_line[2])
return [('cv_agg', k, np.mean(v), metric_type[k], np.std(v)) for k, v in cvmap.items()]
return [('cv_agg', k, float(np.mean(v)), metric_type[k], float(np.std(v))) for k, v in cvmap.items()]


def cv(
Expand Down

0 comments on commit 1a6e6ff

Please sign in to comment.