diff --git a/python-package/lightgbm/__init__.py b/python-package/lightgbm/__init__.py
index 5815bc602bde..0dc5b75cfdf2 100644
--- a/python-package/lightgbm/__init__.py
+++ b/python-package/lightgbm/__init__.py
@@ -6,7 +6,7 @@
 from pathlib import Path
 
 from .basic import Booster, Dataset, Sequence, register_logger
-from .callback import early_stopping, log_evaluation, record_evaluation, reset_parameter
+from .callback import EarlyStopException, early_stopping, log_evaluation, record_evaluation, reset_parameter
 from .engine import CVBooster, cv, train
 
 try:
@@ -32,5 +32,5 @@ 'train', 'cv',
            'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
            'DaskLGBMRegressor', 'DaskLGBMClassifier', 'DaskLGBMRanker',
-           'log_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping',
+           'log_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping', 'EarlyStopException',
            'plot_importance', 'plot_split_value_histogram', 'plot_metric', 'plot_tree',
            'create_tree_digraph']
diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py
index 84733e6505e1..9302eedd824f 100644
--- a/python-package/lightgbm/callback.py
+++ b/python-package/lightgbm/callback.py
@@ -12,6 +12,7 @@
     from .engine import CVBooster
 
 __all__ = [
+    'EarlyStopException',
     'early_stopping',
     'log_evaluation',
     'record_evaluation',
@@ -30,7 +31,11 @@
 
 
 class EarlyStopException(Exception):
-    """Exception of early stopping."""
+    """Exception of early stopping.
+
+    Raise this from a callback passed in via keyword argument ``callbacks``
+    in ``cv()`` or ``train()`` to trigger early stopping.
+    """
 
     def __init__(self, best_iteration: int, best_score: _ListOfEvalResultTuples) -> None:
         """Create early stopping exception.
@@ -39,6 +44,7 @@ def __init__(self, best_iteration: int, best_score: _ListOfEvalResultTuples) ->
         ----------
         best_iteration : int
             The best iteration stopped.
+            0-based. For example, pass ``best_iteration=2`` to indicate that the third iteration was the best one.
         best_score : list of (eval_name, metric_name, eval_result, is_higher_better) tuple or (eval_name, metric_name, eval_result, is_higher_better, stdv) tuple
             Scores for each metric, on each validation set, as of the best iteration.
         """
diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index 25413d7ea072..e55df17607b4 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -1092,6 +1092,33 @@ def test_early_stopping_min_delta(first_only, single_metric, greater_is_better):
         assert np.greater_equal(last_score, best_score - min_delta).any()
 
 
+def test_early_stopping_can_be_triggered_via_custom_callback():
+    X, y = make_synthetic_regression()
+
+    def _early_stop_after_seventh_iteration(env):
+        if env.iteration == 6:
+            exc = lgb.EarlyStopException(
+                best_iteration=6,
+                best_score=[("some_validation_set", "some_metric", 0.708, True)]
+            )
+            raise exc
+
+    bst = lgb.train(
+        params={
+            "objective": "regression",
+            "verbose": -1,
+            "num_leaves": 2
+        },
+        train_set=lgb.Dataset(X, label=y),
+        num_boost_round=23,
+        callbacks=[_early_stop_after_seventh_iteration]
+    )
+    assert bst.num_trees() == 7
+    assert bst.best_score["some_validation_set"]["some_metric"] == 0.708
+    assert bst.best_iteration == 7
+    assert bst.current_iteration() == 7
+
+
 def test_continue_train():
     X, y = make_synthetic_regression()
     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
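
For reviewers, a minimal usage sketch (not part of the diff) of the behavior the new test exercises: raising ``lgb.EarlyStopException`` from a callback passed via ``callbacks`` stops training at that iteration. The synthetic data, the callback name ``stop_after_five_iterations``, and the placeholder evaluation tuple (``"my_validation_set"``, ``"my_metric"``, ``0.5``) are illustrative assumptions, not names from this PR.

```python
import lightgbm as lgb
import numpy as np

# Illustrative synthetic regression data (hypothetical, not from the diff).
rng = np.random.default_rng(42)
X = rng.normal(size=(1_000, 5))
y = X.sum(axis=1) + rng.normal(scale=0.1, size=1_000)


def stop_after_five_iterations(env) -> None:
    # env.iteration is 0-based, so iteration 4 is the fifth one.
    if env.iteration == 4:
        raise lgb.EarlyStopException(
            best_iteration=4,
            best_score=[("my_validation_set", "my_metric", 0.5, True)],
        )


bst = lgb.train(
    params={"objective": "regression", "verbose": -1},
    train_set=lgb.Dataset(X, label=y),
    num_boost_round=100,
    callbacks=[stop_after_five_iterations],
)

# Training stopped after 5 trees were built.
print(bst.current_iteration())  # 5
print(bst.best_iteration)       # 5
```

Note the off-by-one convention the test asserts: ``EarlyStopException`` takes a 0-based ``best_iteration``, while the ``Booster.best_iteration`` reported after training is 1-based (``best_iteration=6`` going in yields ``bst.best_iteration == 7`` coming out).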