
[python-package] early stopping min_delta (fixes #2526) #4580

Merged · 17 commits · Nov 10, 2021

Changes from 4 commits
40 changes: 34 additions & 6 deletions python-package/lightgbm/callback.py
@@ -1,12 +1,20 @@
# coding: utf-8
"""Callbacks library."""
import collections
from operator import gt, lt
from functools import partial
from typing import Any, Callable, Dict, List, Union

from .basic import _ConfigAliases, _log_info, _log_warning


def _gt_threshold(curr_score, best_score, threshold):
return curr_score > best_score + threshold


def _lt_threshold(curr_score, best_score, threshold):
return curr_score < best_score - threshold


class EarlyStopException(Exception):
"""Exception of early stopping."""

@@ -143,7 +151,7 @@ def _callback(env: CallbackEnv) -> None:
return _callback


def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True) -> Callable:
def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, threshold: Union[float, List[float]] = 0.0) -> Callable:
"""Create a callback that activates early stopping.

Activates early stopping.
@@ -162,6 +170,8 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
Whether to use only the first metric for early stopping.
verbose : bool, optional (default=True)
Whether to print message with early stopping information.
threshold : float or list of float, optional (default=0.0)
Minimum improvement in score to keep training.

Returns
-------
@@ -188,17 +198,35 @@ def _init(env: CallbackEnv) -> None:
if verbose:
_log_info(f"Training until validation scores don't improve for {stopping_rounds} rounds")

n_metrics = len(set(m[1] for m in env.evaluation_result_list))
n_datasets = len(env.evaluation_result_list) // n_metrics
if isinstance(threshold, list):
if not all(t >= 0 for t in threshold):
raise ValueError('Early stopping thresholds must be non-negative.')
if len(threshold) > 1:
if len(threshold) != n_metrics:
raise ValueError('Must provide a single early stopping threshold or as many as metrics.')
if first_metric_only:
_log_warning(f'Using only {threshold[0]} as early stopping threshold.')
tholds = threshold * n_datasets
else:
if threshold < 0:
raise ValueError('Early stopping threshold must be non-negative.')
if threshold > 0 and n_metrics > 1 and not first_metric_only:
_log_warning(f'Using {threshold} as the early stopping threshold for all metrics.')
tholds = [threshold] * len(env.evaluation_result_list)

# split is needed for "<dataset type> <metric>" case (e.g. "train l1")
first_metric[0] = env.evaluation_result_list[0][1].split(" ")[-1]
for eval_ret in env.evaluation_result_list:
for i, eval_ret in enumerate(env.evaluation_result_list):
best_iter.append(0)
best_score_list.append(None)
if eval_ret[3]:
if eval_ret[3]: # greater is better
best_score.append(float('-inf'))
cmp_op.append(gt)
cmp_op.append(partial(_gt_threshold, threshold=tholds[i]))
else:
best_score.append(float('inf'))
cmp_op.append(lt)
cmp_op.append(partial(_lt_threshold, threshold=tholds[i]))

def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
if env.iteration == env.end_iteration - 1:
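
For reviewers trying out the branch, a minimal usage sketch of the new `threshold` argument through the public `early_stopping` callback (dataset construction and the 0.01 value are purely illustrative; the argument is named `threshold` at this commit even though the PR title refers to the feature as `min_delta`):

```python
import lightgbm as lgb
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1_000, random_state=0)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=0)
train_ds = lgb.Dataset(X_train, y_train)
valid_ds = lgb.Dataset(X_valid, y_valid, reference=train_ds)

# Count a round as an improvement only when binary_logloss drops below the
# best score by more than 0.01; stop after 10 rounds without such an improvement.
stopper = lgb.early_stopping(stopping_rounds=10, verbose=True, threshold=0.01)
bst = lgb.train(
    {'objective': 'binary', 'metric': 'binary_logloss', 'verbose': -1},
    train_ds,
    num_boost_round=200,
    valid_sets=[valid_ds],
    callbacks=[stopper],
)
print(bst.best_iteration)
```

For a greater-is-better metric such as auc the comparison flips via `_gt_threshold`: a round only counts as an improvement when it exceeds the best score plus the threshold.
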
13 changes: 9 additions & 4 deletions python-package/lightgbm/engine.py
@@ -34,6 +34,7 @@ def train(
feature_name: Union[List[str], str] = 'auto',
categorical_feature: Union[List[str], List[int], str] = 'auto',
early_stopping_rounds: Optional[int] = None,
early_stopping_threshold: Union[float, List[float]] = 0.0,
evals_result: Optional[Dict[str, Any]] = None,
verbose_eval: Union[bool, int] = True,
learning_rates: Optional[Union[List[float], Callable[[int], float]]] = None,
@@ -121,6 +122,8 @@ def train(
To check only the first metric, set the ``first_metric_only`` parameter to ``True`` in ``params``.
The index of iteration that has the best performance will be saved in the ``best_iteration`` field
if early stopping logic is enabled by setting ``early_stopping_rounds``.
early_stopping_threshold : float or list of float (default=0.0)
Minimum improvement in score to keep training.
evals_result: dict or None, optional (default=None)
Dictionary used to store all evaluation results of all the items in ``valid_sets``.
This should be initialized outside of your call to ``train()`` and should be empty.
@@ -239,7 +242,7 @@ def train(
callbacks.add(callback.print_evaluation(verbose_eval))

if early_stopping_rounds is not None and early_stopping_rounds > 0:
callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=bool(verbose_eval)))
callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=bool(verbose_eval), threshold=early_stopping_threshold))

if learning_rates is not None:
callbacks.add(callback.reset_parameter(learning_rate=learning_rates))
@@ -421,8 +424,8 @@ def cv(params, train_set, num_boost_round=100,
folds=None, nfold=5, stratified=True, shuffle=True,
metrics=None, fobj=None, feval=None, init_model=None,
feature_name='auto', categorical_feature='auto',
early_stopping_rounds=None, fpreproc=None,
verbose_eval=None, show_stdv=True, seed=0,
early_stopping_rounds=None, early_stopping_threshold=0.0,
fpreproc=None, verbose_eval=None, show_stdv=True, seed=0,
callbacks=None, eval_train_metric=False,
return_cvbooster=False):
"""Perform the cross-validation with given parameters.
@@ -515,6 +518,8 @@ def cv(params, train_set, num_boost_round=100,
Requires at least one metric. If there's more than one, will check all of them.
To check only the first metric, set the ``first_metric_only`` parameter to ``True`` in ``params``.
Last entry in evaluation history is the one from the best iteration.
early_stopping_threshold : float or list of float (default=0.0)
Minimum improvement in score to keep training.
fpreproc : callable or None, optional (default=None)
Preprocessing function that takes (dtrain, dtest, params)
and returns transformed versions of those.
@@ -600,7 +605,7 @@ def cv(params, train_set, num_boost_round=100,
cb.__dict__.setdefault('order', i - len(callbacks))
callbacks = set(callbacks)
if early_stopping_rounds is not None and early_stopping_rounds > 0:
callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=False))
callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=False, threshold=early_stopping_threshold))
if verbose_eval is True:
callbacks.add(callback.print_evaluation(show_stdv=show_stdv))
elif isinstance(verbose_eval, int):
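
A sketch of the `train()` entry point with the new `early_stopping_threshold` keyword, passing one threshold per metric (values are illustrative; the callback's `_init` validates the list length against the number of metrics and repeats it across validation sets):

```python
import lightgbm as lgb
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1_000, random_state=0)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=0)
train_ds = lgb.Dataset(X_train, y_train)
valid_ds = lgb.Dataset(X_valid, y_valid, reference=train_ds)

params = {'objective': 'binary', 'metric': ['binary_logloss', 'auc'], 'verbose': -1}

# One threshold per tracked metric; a single float would be applied to every
# metric instead (and emit the warning added above when several are tracked).
bst = lgb.train(
    params,
    train_ds,
    num_boost_round=200,
    valid_sets=[valid_ds],
    early_stopping_rounds=10,
    early_stopping_threshold=[0.01, 0.001],
)
print(bst.best_iteration)
```

The same keyword is forwarded from `cv()` to the callback, there with `verbose=False`.
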
30 changes: 18 additions & 12 deletions python-package/lightgbm/sklearn.py
@@ -230,6 +230,8 @@ def __call__(self, preds, dataset):
If there's more than one, will check all of them. But the training data is ignored anyway.
To check only the first metric, set the ``first_metric_only`` parameter to ``True``
in additional parameters ``**kwargs`` of the model constructor.
early_stopping_threshold : float or list of float (default=0.0)
Minimum improvement in score to keep training.
verbose : bool or int, optional (default=True)
Requires at least one evaluation data.
If True, the eval metric on the eval set is printed at each boosting stage.
@@ -570,8 +572,8 @@ def fit(self, X, y,
sample_weight=None, init_score=None, group=None,
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_class_weight=None, eval_init_score=None, eval_group=None,
eval_metric=None, early_stopping_rounds=None, verbose=True,
feature_name='auto', categorical_feature='auto',
eval_metric=None, early_stopping_rounds=None, early_stopping_threshold=0.0,
verbose=True, feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Docstring is set after definition, using a template."""
if self._objective is None:
@@ -711,7 +713,7 @@ def _get_meta_data(collection, name, i):

self._Booster = train(params, train_set,
self.n_estimators, valid_sets=valid_sets, valid_names=eval_names,
early_stopping_rounds=early_stopping_rounds,
early_stopping_rounds=early_stopping_rounds, early_stopping_threshold=early_stopping_threshold,
evals_result=evals_result, fobj=self._fobj, feval=eval_metrics_callable,
verbose_eval=verbose, feature_name=feature_name,
callbacks=callbacks, init_model=init_model)
@@ -843,13 +845,14 @@ def fit(self, X, y,
sample_weight=None, init_score=None,
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_init_score=None, eval_metric=None, early_stopping_rounds=None,
verbose=True, feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
early_stopping_threshold=0.0, verbose=True, feature_name='auto',
categorical_feature='auto', callbacks=None, init_model=None):
"""Docstring is inherited from the LGBMModel."""
super().fit(X, y, sample_weight=sample_weight, init_score=init_score,
eval_set=eval_set, eval_names=eval_names, eval_sample_weight=eval_sample_weight,
eval_init_score=eval_init_score, eval_metric=eval_metric,
early_stopping_rounds=early_stopping_rounds, verbose=verbose, feature_name=feature_name,
early_stopping_rounds=early_stopping_rounds, early_stopping_threshold=early_stopping_threshold,
verbose=verbose, feature_name=feature_name,
categorical_feature=categorical_feature, callbacks=callbacks, init_model=init_model)
return self

@@ -869,8 +872,8 @@ def fit(self, X, y,
sample_weight=None, init_score=None,
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_class_weight=None, eval_init_score=None, eval_metric=None,
early_stopping_rounds=None, verbose=True,
feature_name='auto', categorical_feature='auto',
early_stopping_rounds=None, early_stopping_threshold=0.0,
verbose=True, feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Docstring is inherited from the LGBMModel."""
_LGBMAssertAllFinite(y)
@@ -922,7 +925,8 @@ def fit(self, X, y,
eval_names=eval_names, eval_sample_weight=eval_sample_weight,
eval_class_weight=eval_class_weight, eval_init_score=eval_init_score,
eval_metric=eval_metric, early_stopping_rounds=early_stopping_rounds,
verbose=verbose, feature_name=feature_name, categorical_feature=categorical_feature,
early_stopping_threshold=early_stopping_threshold, verbose=verbose,
feature_name=feature_name, categorical_feature=categorical_feature,
callbacks=callbacks, init_model=init_model)
return self

@@ -997,7 +1001,8 @@ def fit(self, X, y,
sample_weight=None, init_score=None, group=None,
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_init_score=None, eval_group=None, eval_metric=None,
eval_at=(1, 2, 3, 4, 5), early_stopping_rounds=None, verbose=True,
eval_at=(1, 2, 3, 4, 5), early_stopping_rounds=None,
early_stopping_threshold=0.0, verbose=True,
feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Docstring is inherited from the LGBMModel."""
@@ -1021,8 +1026,9 @@ def fit(self, X, y,
super().fit(X, y, sample_weight=sample_weight, init_score=init_score, group=group,
eval_set=eval_set, eval_names=eval_names, eval_sample_weight=eval_sample_weight,
eval_init_score=eval_init_score, eval_group=eval_group, eval_metric=eval_metric,
early_stopping_rounds=early_stopping_rounds, verbose=verbose, feature_name=feature_name,
categorical_feature=categorical_feature, callbacks=callbacks, init_model=init_model)
early_stopping_rounds=early_stopping_rounds, early_stopping_threshold=early_stopping_threshold,
verbose=verbose, feature_name=feature_name, categorical_feature=categorical_feature,
callbacks=callbacks, init_model=init_model)
return self

_base_doc = LGBMModel.fit.__doc__
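
And a sketch of the scikit-learn interface, where `fit()` forwards the new keyword to `train()` (values illustrative; the parameter name follows this commit):

```python
import lightgbm as lgb
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1_000, random_state=0)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=0)

clf = lgb.LGBMClassifier(n_estimators=200)
clf.fit(
    X_train, y_train,
    eval_set=[(X_valid, y_valid)],
    eval_metric='binary_logloss',
    early_stopping_rounds=10,
    early_stopping_threshold=0.01,  # forwarded through train() to the early_stopping callback
    verbose=False,
)
print(clf.best_iteration_)
```
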
77 changes: 76 additions & 1 deletion tests/python_package_test/test_engine.py
@@ -11,7 +11,7 @@
import psutil
import pytest
from scipy.sparse import csr_matrix, isspmatrix_csc, isspmatrix_csr
from sklearn.datasets import load_svmlight_file, make_multilabel_classification
from sklearn.datasets import load_svmlight_file, make_classification, make_multilabel_classification
from sklearn.metrics import average_precision_score, log_loss, mean_absolute_error, mean_squared_error, roc_auc_score
from sklearn.model_selection import GroupKFold, TimeSeriesSplit, train_test_split

@@ -642,6 +642,81 @@ def test_early_stopping():
assert 'binary_logloss' in gbm.best_score[valid_set_name]


@pytest.mark.parametrize('first_only', [True, False])
@pytest.mark.parametrize('single_metric', [True, False])
@pytest.mark.parametrize('greater_is_better', [True, False])
def test_early_stopping_threshold(first_only, single_metric, greater_is_better):
if single_metric and not first_only:
pytest.skip("first_metric_only doesn't affect single metric.")
metric2threshold = {
'auc': 0.001,
'binary_logloss': 0.01,
'average_precision': 0.001,
'mape': 0.001,
}
if single_metric:
if greater_is_better:
metric = 'auc'
else:
metric = 'binary_logloss'
else:
if first_only:
if greater_is_better:
metric = ['auc', 'binary_logloss']
else:
metric = ['binary_logloss', 'auc']
else:
if greater_is_better:
metric = ['auc', 'average_precision']
else:
metric = ['binary_logloss', 'mape']

X, y = make_classification(n_samples=1_000, n_features=2, n_redundant=0, n_classes=2, random_state=0)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=0)
train_ds = lgb.Dataset(X_train, y_train)
valid_ds = lgb.Dataset(X_valid, y_valid, reference=train_ds)

params = {'objective': 'binary', 'metric': metric, 'first_metric_only': first_only, 'verbose': -1}
if isinstance(metric, str):
threshold = metric2threshold[metric]
elif first_only:
threshold = metric2threshold[metric[0]]
else:
threshold = [metric2threshold[m] for m in metric]
train_kwargs = dict(
params=params,
train_set=train_ds,
num_boost_round=100,
valid_sets=[train_ds, valid_ds],
valid_names=['training', 'valid'],
early_stopping_rounds=10,
verbose_eval=0,
)

# regular early stopping
evals_result = {}
bst = lgb.train(evals_result=evals_result, **train_kwargs)
scores = np.vstack([res for res in evals_result['valid'].values()]).T

# positive threshold
threshold_result = {}
threshold_bst = lgb.train(early_stopping_threshold=threshold, evals_result=threshold_result, **train_kwargs)
threshold_scores = np.vstack([res for res in threshold_result['valid'].values()]).T

if first_only:
scores = scores[:, 0]
threshold_scores = threshold_scores[:, 0]

assert threshold_bst.num_trees() < bst.num_trees()
np.testing.assert_allclose(scores[:len(threshold_scores)], threshold_scores)
last_score = threshold_scores[-1]
best_score = threshold_scores[threshold_bst.num_trees() - 1]
if greater_is_better:
assert np.less_equal(last_score, best_score + threshold).any()
else:
assert np.greater_equal(last_score, best_score - threshold).any()


def test_continue_train():
X, y = load_boston(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)