diff --git a/optuna/integration/lightgbm.py b/optuna/integration/lightgbm.py index f2da5b52af7..6e7b8b38756 100644 --- a/optuna/integration/lightgbm.py +++ b/optuna/integration/lightgbm.py @@ -1,15 +1,20 @@ +from __future__ import annotations + import sys -from typing import List -from typing import Optional +from typing import TYPE_CHECKING import optuna from optuna._imports import try_import from optuna.integration import _lightgbm_tuner as tuner +if TYPE_CHECKING: + from lightgbm.basic import _LGBM_BoosterEvalMethodResultType + from lightgbm.callback import CallbackEnv + + with try_import() as _imports: import lightgbm as lgb - from lightgbm.callback import CallbackEnv # Attach lightgbm API. if _imports.is_successful(): @@ -88,27 +93,35 @@ def __init__( self._report_interval = report_interval def _find_evaluation_result( - self, target_valid_name: str, env: "CallbackEnv" - ) -> Optional[List]: - for evaluation_result in env.evaluation_result_list: + self, target_valid_name: str, env: CallbackEnv + ) -> _LGBM_BoosterEvalMethodResultType | None: + evaluation_result_list = env.evaluation_result_list + if evaluation_result_list is None: + return None + + for evaluation_result in evaluation_result_list: valid_name, metric, current_score, is_higher_better = evaluation_result[:4] # The prefix "valid " is added to metric name since LightGBM v4.0.0. if valid_name != target_valid_name or ( metric != "valid " + self._metric and metric != self._metric ): continue - return evaluation_result return None - def __call__(self, env: "CallbackEnv") -> None: + def __call__(self, env: CallbackEnv) -> None: if (env.iteration + 1) % self._report_interval == 0: # If this callback has been passed to `lightgbm.cv` function, # the value of `is_cv` becomes `True`. See also: - # https://github.com/Microsoft/LightGBM/blob/v2.2.2/python-package/lightgbm/engine.py#L329 + # https://github.com/microsoft/LightGBM/blob/v4.1.0/python-package/lightgbm/engine.py#L533 # Note that `5` is not the number of folds but the length of sequence. - is_cv = len(env.evaluation_result_list) > 0 and len(env.evaluation_result_list[0]) == 5 + evaluation_result_list = env.evaluation_result_list + is_cv = ( + evaluation_result_list is not None + and len(evaluation_result_list) > 0 + and len(evaluation_result_list[0]) == 5 + ) if is_cv: target_valid_name = "cv_agg" else: