Skip to content

Commit

Permalink
Merge pull request optuna#4923 from optuna/fix-lightgbm-checks-integr…
Browse files Browse the repository at this point in the history
…ation

Fix the `checks-integration` errors on LightGBMTuner
  • Loading branch information
eukaryo authored Sep 15, 2023
2 parents 0ff4164 + c957c4a commit 68ed591
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions optuna/integration/lightgbm.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
from __future__ import annotations

import sys
from typing import List
from typing import Optional
from typing import TYPE_CHECKING

import optuna
from optuna._imports import try_import
from optuna.integration import _lightgbm_tuner as tuner


if TYPE_CHECKING:
from lightgbm.basic import _LGBM_BoosterEvalMethodResultType
from lightgbm.callback import CallbackEnv


with try_import() as _imports:
import lightgbm as lgb
from lightgbm.callback import CallbackEnv

# Attach lightgbm API.
if _imports.is_successful():
Expand Down Expand Up @@ -88,27 +93,35 @@ def __init__(
self._report_interval = report_interval

def _find_evaluation_result(
self, target_valid_name: str, env: "CallbackEnv"
) -> Optional[List]:
for evaluation_result in env.evaluation_result_list:
self, target_valid_name: str, env: CallbackEnv
) -> _LGBM_BoosterEvalMethodResultType | None:
evaluation_result_list = env.evaluation_result_list
if evaluation_result_list is None:
return None

for evaluation_result in evaluation_result_list:
valid_name, metric, current_score, is_higher_better = evaluation_result[:4]
# The prefix "valid " is added to metric name since LightGBM v4.0.0.
if valid_name != target_valid_name or (
metric != "valid " + self._metric and metric != self._metric
):
continue

return evaluation_result

return None

def __call__(self, env: "CallbackEnv") -> None:
def __call__(self, env: CallbackEnv) -> None:
if (env.iteration + 1) % self._report_interval == 0:
# If this callback has been passed to `lightgbm.cv` function,
# the value of `is_cv` becomes `True`. See also:
# https://github.com/Microsoft/LightGBM/blob/v2.2.2/python-package/lightgbm/engine.py#L329
# https://github.com/microsoft/LightGBM/blob/v4.1.0/python-package/lightgbm/engine.py#L533
# Note that `5` is not the number of folds but the length of sequence.
is_cv = len(env.evaluation_result_list) > 0 and len(env.evaluation_result_list[0]) == 5
evaluation_result_list = env.evaluation_result_list
is_cv = (
evaluation_result_list is not None
and len(evaluation_result_list) > 0
and len(evaluation_result_list[0]) == 5
)
if is_cv:
target_valid_name = "cv_agg"
else:
Expand Down

0 comments on commit 68ed591

Please sign in to comment.