Skip to content

Commit

Permalink
[python-package] early stopping min_delta (fixes #2526) (#4580)
Browse files Browse the repository at this point in the history
* initial changes

* initial version

* better handling of cases

* warn only with positive threshold

* remove early_stopping_threshold from high-level functions

* remove remaining early_stopping_threshold

* update test to use callback

* better handling of cases

* rename threshold to min_delta

enhance parameter description

update tests

* Apply suggestions from code review

Co-authored-by: Nikita Titov <[email protected]>

* reduce num_boost_round in tests

* Apply suggestions from code review

Co-authored-by: Nikita Titov <[email protected]>

* trigger ci

Co-authored-by: Nikita Titov <[email protected]>
Co-authored-by: Nikita Titov <[email protected]>
  • Loading branch information
3 people authored Nov 10, 2021
1 parent 0a4d190 commit 99e0a4b
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 7 deletions.
52 changes: 45 additions & 7 deletions python-package/lightgbm/callback.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
# coding: utf-8
"""Callbacks library."""
import collections
from operator import gt, lt
from functools import partial
from typing import Any, Callable, Dict, List, Union

from .basic import _ConfigAliases, _log_info, _log_warning


def _gt_delta(curr_score: float, best_score: float, delta: float) -> bool:
return curr_score > best_score + delta


def _lt_delta(curr_score: float, best_score: float, delta: float) -> bool:
return curr_score < best_score - delta


class EarlyStopException(Exception):
"""Exception of early stopping."""

Expand Down Expand Up @@ -181,11 +189,11 @@ def _callback(env: CallbackEnv) -> None:
return _callback


def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True) -> Callable:
def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, min_delta: Union[float, List[float]] = 0.0) -> Callable:
"""Create a callback that activates early stopping.
Activates early stopping.
The model will train until the validation score stops improving.
The model will train until the validation score doesn't improve by at least ``min_delta``.
Validation score needs to improve at least every ``stopping_rounds`` round(s)
to continue training.
Requires at least one validation data and one metric.
Expand All @@ -203,6 +211,10 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
Whether to log message with early stopping information.
By default, standard output resource is used.
Use ``register_logger()`` function to register a custom logger.
min_delta : float or list of float, optional (default=0.0)
Minimum improvement in score to keep training.
If float, this single value is used for all metrics.
If list, its length should match the total number of metrics.
Returns
-------
Expand All @@ -229,17 +241,43 @@ def _init(env: CallbackEnv) -> None:
if verbose:
_log_info(f"Training until validation scores don't improve for {stopping_rounds} rounds")

n_metrics = len(set(m[1] for m in env.evaluation_result_list))
n_datasets = len(env.evaluation_result_list) // n_metrics
if isinstance(min_delta, list):
if not all(t >= 0 for t in min_delta):
raise ValueError('Values for early stopping min_delta must be non-negative.')
if len(min_delta) == 0:
if verbose:
_log_info('Disabling min_delta for early stopping.')
deltas = [0.0] * n_datasets * n_metrics
elif len(min_delta) == 1:
if verbose:
_log_info(f'Using {min_delta[0]} as min_delta for all metrics.')
deltas = min_delta * n_datasets * n_metrics
else:
if len(min_delta) != n_metrics:
raise ValueError('Must provide a single value for min_delta or as many as metrics.')
if first_metric_only and verbose:
_log_info(f'Using only {min_delta[0]} as early stopping min_delta.')
deltas = min_delta * n_datasets
else:
if min_delta < 0:
raise ValueError('Early stopping min_delta must be non-negative.')
if min_delta > 0 and n_metrics > 1 and not first_metric_only and verbose:
_log_info(f'Using {min_delta} as min_delta for all metrics.')
deltas = [min_delta] * n_datasets * n_metrics

# split is needed for "<dataset type> <metric>" case (e.g. "train l1")
first_metric[0] = env.evaluation_result_list[0][1].split(" ")[-1]
for eval_ret in env.evaluation_result_list:
for eval_ret, delta in zip(env.evaluation_result_list, deltas):
best_iter.append(0)
best_score_list.append(None)
if eval_ret[3]:
if eval_ret[3]: # greater is better
best_score.append(float('-inf'))
cmp_op.append(gt)
cmp_op.append(partial(_gt_delta, delta=delta))
else:
best_score.append(float('inf'))
cmp_op.append(lt)
cmp_op.append(partial(_lt_delta, delta=delta))

def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
if env.iteration == env.end_iteration - 1:
Expand Down
75 changes: 75 additions & 0 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,81 @@ def test_early_stopping():
assert 'binary_logloss' in gbm.best_score[valid_set_name]


@pytest.mark.parametrize('first_only', [True, False])
@pytest.mark.parametrize('single_metric', [True, False])
@pytest.mark.parametrize('greater_is_better', [True, False])
def test_early_stopping_min_delta(first_only, single_metric, greater_is_better):
    """Check that a positive ``min_delta`` makes early stopping trigger sooner.

    Trains the same binary model twice on the breast-cancer data: once with
    the plain early-stopping callback and once with ``min_delta`` set.  The
    run with ``min_delta`` must end up with fewer trees, must produce the same
    validation scores over the iterations both runs share, and its last score
    must not beat its best score by more than ``min_delta`` for at least one
    of the compared metrics.
    """
    if single_metric and not first_only:
        pytest.skip("first_metric_only doesn't affect single metric.")

    metric2min_delta = {
        'auc': 0.001,
        'binary_logloss': 0.01,
        'average_precision': 0.001,
        'mape': 0.01,
    }

    # Pick the metric(s) so that the first one has the requested direction.
    if single_metric:
        metric = 'auc' if greater_is_better else 'binary_logloss'
    elif first_only:
        # Mixed directions are fine here: only the first metric is monitored.
        if greater_is_better:
            metric = ['auc', 'binary_logloss']
        else:
            metric = ['binary_logloss', 'auc']
    else:
        # All metrics are monitored, so both share the same direction.
        if greater_is_better:
            metric = ['auc', 'average_precision']
        else:
            metric = ['binary_logloss', 'mape']

    X, y = load_breast_cancer(return_X_y=True)
    X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=0)
    train_ds = lgb.Dataset(X_train, y_train)
    valid_ds = lgb.Dataset(X_valid, y_valid, reference=train_ds)

    params = {'objective': 'binary', 'metric': metric, 'verbose': -1}
    if isinstance(metric, str):
        min_delta = metric2min_delta[metric]
    elif first_only:
        min_delta = metric2min_delta[metric[0]]
    else:
        # One threshold per metric, in the same order as ``metric``.
        min_delta = [metric2min_delta[m] for m in metric]
    train_kwargs = dict(
        params=params,
        train_set=train_ds,
        num_boost_round=50,
        valid_sets=[train_ds, valid_ds],
        valid_names=['training', 'valid'],
    )

    # regular early stopping
    train_kwargs['callbacks'] = [lgb.callback.early_stopping(10, first_only, verbose=False)]
    evals_result = {}
    bst = lgb.train(evals_result=evals_result, **train_kwargs)
    # One row per iteration, one column per metric.
    scores = np.vstack(list(evals_result['valid'].values())).T

    # positive min_delta
    train_kwargs['callbacks'] = [
        lgb.callback.early_stopping(10, first_only, verbose=False, min_delta=min_delta)
    ]
    delta_result = {}
    delta_bst = lgb.train(evals_result=delta_result, **train_kwargs)
    delta_scores = np.vstack(list(delta_result['valid'].values())).T

    if first_only:
        # Only the first metric drives early stopping, so compare that column only.
        scores = scores[:, 0]
        delta_scores = delta_scores[:, 0]

    assert delta_bst.num_trees() < bst.num_trees()
    np.testing.assert_allclose(scores[:len(delta_scores)], delta_scores)
    last_score = delta_scores[-1]
    # NOTE(review): presumably num_trees() equals the best iteration of the
    # returned booster, making this the best recorded score — confirm.
    best_score = delta_scores[delta_bst.num_trees() - 1]
    if greater_is_better:
        assert np.less_equal(last_score, best_score + min_delta).any()
    else:
        assert np.greater_equal(last_score, best_score - min_delta).any()


def test_continue_train():
X, y = load_boston(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
Expand Down

0 comments on commit 99e0a4b

Please sign in to comment.