Skip to content

Commit

Permalink
[python-package] reorganize early stopping callback
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Sep 27, 2023
1 parent 60a4a13 commit 177bd50
Showing 1 changed file with 45 additions and 19 deletions.
64 changes: 45 additions & 19 deletions python-package/lightgbm/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,12 @@ def __call__(self, env: CallbackEnv) -> None:
if new_param != env.params.get(key, None):
new_parameters[key] = new_param
if new_parameters:
env.model.reset_parameter(new_parameters)
if isinstance(env.model, Booster):
env.model.reset_parameter(new_parameters)
else:
# CVBooster holds a list of Booster objects, each needs to be updated
for i in range(len(env.model.boosters)):
env.model.boosters[i].reset_parameter(new_parameters)
env.params.update(new_parameters)


Expand Down Expand Up @@ -291,32 +296,49 @@ def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
return curr_score < best_score - delta

def _is_train_set(self, ds_name: str, eval_name: str, train_name: str) -> bool:
return (ds_name == "cv_agg" and eval_name == "train") or ds_name == train_name
def _is_train_set(self, ds_name: str, eval_name: str, env: CallbackEnv) -> bool:
"""Check, by name, if a given Dataset is the training data"""
# for lgb.cv() with eval_train_metric=True, evaluation is also done on the training set
# and those metrics are considered for early stopping
if ds_name == "cv_agg" and eval_name == "train":
return True

# for lgb.train(), it's possible to pass the training data via valid_sets with any eval_name
if isinstance(env.model, Booster):
if ds_name == env.model._train_data_name:
return True

return False

def _init(self, env: CallbackEnv) -> None:
if env.evaluation_result_list is None or env.evaluation_result_list == []:
raise ValueError(
"For early stopping, at least one dataset and eval metric is required for evaluation"
)

if self.stopping_rounds <= 0:
raise ValueError(f"stopping_rounds should be greater than zero. got: {self.stopping_rounds}")

is_dart = any(env.params.get(alias, "") == 'dart' for alias in _ConfigAliases.get("boosting"))
only_train_set = (
len(env.evaluation_result_list) == 1
and self._is_train_set(
ds_name=env.evaluation_result_list[0][0],
eval_name=env.evaluation_result_list[0][1].split(" ")[0],
train_name=env.model._train_data_name)
)
self.enabled = not is_dart and not only_train_set
if not self.enabled:
if is_dart:
_log_warning('Early stopping is not available in dart mode')
elif only_train_set:
_log_warning('Only training set found, disabling early stopping.')
if is_dart:
self.enabled = False
_log_warning('Early stopping is not available in dart mode')
return

if self.stopping_rounds <= 0:
raise ValueError("stopping_rounds should be greater than zero.")
# validation sets are guaranteed to not be identical to the training data in cv()
if isinstance(env.model, Booster):
only_train_set = (
len(env.evaluation_result_list) == 1
and self._is_train_set(
ds_name=env.evaluation_result_list[0][0],
eval_name=env.evaluation_result_list[0][1].split(" ")[0],
env=env
)
)
if only_train_set:
self.enabled = False
_log_warning('Only training set found, disabling early stopping.')
return

if self.verbose:
_log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")
Expand Down Expand Up @@ -395,7 +417,11 @@ def __call__(self, env: CallbackEnv) -> None:
eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
continue # use only the first metric for early stopping
if self._is_train_set(env.evaluation_result_list[i][0], eval_name_splitted[0], env.model._train_data_name):
if self._is_train_set(
ds_name=env.evaluation_result_list[i][0],
eval_name=eval_name_splitted[0],
env=env
):
continue # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
if self.verbose:
Expand Down

0 comments on commit 177bd50

Please sign in to comment.