diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py
index f7e574ece020..5b4e3f194c31 100644
--- a/python-package/lightgbm/callback.py
+++ b/python-package/lightgbm/callback.py
@@ -1,12 +1,20 @@
 # coding: utf-8
 """Callbacks library."""
 import collections
-from operator import gt, lt
+from functools import partial
 from typing import Any, Callable, Dict, List, Union
 
 from .basic import _ConfigAliases, _log_info, _log_warning
 
 
+def _gt_delta(curr_score: float, best_score: float, delta: float) -> bool:
+    return curr_score > best_score + delta
+
+
+def _lt_delta(curr_score: float, best_score: float, delta: float) -> bool:
+    return curr_score < best_score - delta
+
+
 class EarlyStopException(Exception):
     """Exception of early stopping."""
 
@@ -181,11 +189,11 @@ def _callback(env: CallbackEnv) -> None:
     return _callback
 
 
-def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True) -> Callable:
+def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, min_delta: Union[float, List[float]] = 0.0) -> Callable:
     """Create a callback that activates early stopping.
 
     Activates early stopping.
-    The model will train until the validation score stops improving.
+    The model will train until the validation score doesn't improve by at least ``min_delta``.
     Validation score needs to improve at least every ``stopping_rounds`` round(s)
     to continue training.
     Requires at least one validation data and one metric.
@@ -203,6 +211,10 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
         Whether to log message with early stopping information.
         By default, standard output resource is used.
         Use ``register_logger()`` function to register a custom logger.
+    min_delta : float or list of float, optional (default=0.0)
+        Minimum improvement in score to keep training.
+        If float, this single value is used for all metrics.
+        If list, its length should match the total number of metrics.
 
     Returns
     -------
@@ -229,17 +241,43 @@ def _init(env: CallbackEnv) -> None:
         if verbose:
             _log_info(f"Training until validation scores don't improve for {stopping_rounds} rounds")
 
+        n_metrics = len(set(m[1] for m in env.evaluation_result_list))
+        n_datasets = len(env.evaluation_result_list) // n_metrics
+        if isinstance(min_delta, list):
+            if not all(t >= 0 for t in min_delta):
+                raise ValueError('Values for early stopping min_delta must be non-negative.')
+            if len(min_delta) == 0:
+                if verbose:
+                    _log_info('Disabling min_delta for early stopping.')
+                deltas = [0.0] * n_datasets * n_metrics
+            elif len(min_delta) == 1:
+                if verbose:
+                    _log_info(f'Using {min_delta[0]} as min_delta for all metrics.')
+                deltas = min_delta * n_datasets * n_metrics
+            else:
+                if len(min_delta) != n_metrics:
+                    raise ValueError('Must provide a single value for min_delta or as many as metrics.')
+                if first_metric_only and verbose:
+                    _log_info(f'Using only {min_delta[0]} as early stopping min_delta.')
+                deltas = min_delta * n_datasets
+        else:
+            if min_delta < 0:
+                raise ValueError('Early stopping min_delta must be non-negative.')
+            if min_delta > 0 and n_metrics > 1 and not first_metric_only and verbose:
+                _log_info(f'Using {min_delta} as min_delta for all metrics.')
+            deltas = [min_delta] * n_datasets * n_metrics
+
         # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
         first_metric[0] = env.evaluation_result_list[0][1].split(" ")[-1]
-        for eval_ret in env.evaluation_result_list:
+        for eval_ret, delta in zip(env.evaluation_result_list, deltas):
             best_iter.append(0)
             best_score_list.append(None)
-            if eval_ret[3]:
+            if eval_ret[3]:  # greater is better
                 best_score.append(float('-inf'))
-                cmp_op.append(gt)
+                cmp_op.append(partial(_gt_delta, delta=delta))
             else:
                 best_score.append(float('inf'))
-                cmp_op.append(lt)
+                cmp_op.append(partial(_lt_delta, delta=delta))
 
     def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
         if env.iteration == env.end_iteration - 1:
diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index b44cee469a22..925b3f81c531 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -642,6 +642,81 @@ def test_early_stopping():
     assert 'binary_logloss' in gbm.best_score[valid_set_name]
 
 
+@pytest.mark.parametrize('first_only', [True, False])
+@pytest.mark.parametrize('single_metric', [True, False])
+@pytest.mark.parametrize('greater_is_better', [True, False])
+def test_early_stopping_min_delta(first_only, single_metric, greater_is_better):
+    if single_metric and not first_only:
+        pytest.skip("first_metric_only doesn't affect single metric.")
+    metric2min_delta = {
+        'auc': 0.001,
+        'binary_logloss': 0.01,
+        'average_precision': 0.001,
+        'mape': 0.01,
+    }
+    if single_metric:
+        if greater_is_better:
+            metric = 'auc'
+        else:
+            metric = 'binary_logloss'
+    else:
+        if first_only:
+            if greater_is_better:
+                metric = ['auc', 'binary_logloss']
+            else:
+                metric = ['binary_logloss', 'auc']
+        else:
+            if greater_is_better:
+                metric = ['auc', 'average_precision']
+            else:
+                metric = ['binary_logloss', 'mape']
+
+    X, y = load_breast_cancer(return_X_y=True)
+    X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=0)
+    train_ds = lgb.Dataset(X_train, y_train)
+    valid_ds = lgb.Dataset(X_valid, y_valid, reference=train_ds)
+
+    params = {'objective': 'binary', 'metric': metric, 'verbose': -1}
+    if isinstance(metric, str):
+        min_delta = metric2min_delta[metric]
+    elif first_only:
+        min_delta = metric2min_delta[metric[0]]
+    else:
+        min_delta = [metric2min_delta[m] for m in metric]
+    train_kwargs = dict(
+        params=params,
+        train_set=train_ds,
+        num_boost_round=50,
+        valid_sets=[train_ds, valid_ds],
+        valid_names=['training', 'valid'],
+    )
+
+    # regular early stopping
+    train_kwargs['callbacks'] = [lgb.callback.early_stopping(10, first_only, verbose=0)]
+    evals_result = {}
+    bst = lgb.train(evals_result=evals_result, **train_kwargs)
+    scores = np.vstack(list(evals_result['valid'].values())).T
+
+    # positive min_delta
+    train_kwargs['callbacks'] = [lgb.callback.early_stopping(10, first_only, verbose=0, min_delta=min_delta)]
+    delta_result = {}
+    delta_bst = lgb.train(evals_result=delta_result, **train_kwargs)
+    delta_scores = np.vstack(list(delta_result['valid'].values())).T
+
+    if first_only:
+        scores = scores[:, 0]
+        delta_scores = delta_scores[:, 0]
+
+    assert delta_bst.num_trees() < bst.num_trees()
+    np.testing.assert_allclose(scores[:len(delta_scores)], delta_scores)
+    last_score = delta_scores[-1]
+    best_score = delta_scores[delta_bst.num_trees() - 1]
+    if greater_is_better:
+        assert np.less_equal(last_score, best_score + min_delta).any()
+    else:
+        assert np.greater_equal(last_score, best_score - min_delta).any()
+
+
 def test_continue_train():
     X, y = load_boston(return_X_y=True)
     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
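
For reference, a minimal usage sketch of the new min_delta argument (the metrics, delta values, and dataset below are illustrative, not part of the patch). When a list is passed, it supplies one delta per metric, in the order the metrics appear in the evaluation results:

import lightgbm as lgb
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=0)
train_ds = lgb.Dataset(X_train, y_train)
valid_ds = lgb.Dataset(X_valid, y_valid, reference=train_ds)

params = {'objective': 'binary', 'metric': ['auc', 'binary_logloss'], 'verbose': -1}

# auc must beat its best score by more than 0.001, and binary_logloss by more
# than 0.01, for a round to count as an improvement; training stops once a
# metric goes 10 rounds without such an improvement.
bst = lgb.train(
    params,
    train_ds,
    num_boost_round=100,
    valid_sets=[valid_ds],
    valid_names=['valid'],
    callbacks=[lgb.early_stopping(10, min_delta=[0.001, 0.01])],
)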