From 1ce2357193fe4955c7415386d1823e95ec06893b Mon Sep 17 00:00:00 2001 From: Darcie Delzell <108882028+ddelzell@users.noreply.github.com> Date: Fri, 19 Apr 2024 23:31:30 -0500 Subject: [PATCH] [python-package] allow use of early_stopping_round<=0 to turn off early stopping (fixes #6401) (#6406) --- python-package/lightgbm/callback.py | 16 ++++++-- python-package/lightgbm/engine.py | 4 +- tests/python_package_test/test_callback.py | 12 +++--- tests/python_package_test/test_engine.py | 48 ++++++++++++++++++++++ 4 files changed, 69 insertions(+), 11 deletions(-) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 5947796dcb3f..e776ea953bd1 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -280,8 +280,7 @@ def __init__( verbose: bool = True, min_delta: Union[float, List[float]] = 0.0, ) -> None: - if not isinstance(stopping_rounds, int) or stopping_rounds <= 0: - raise ValueError(f"stopping_rounds should be an integer and greater than 0. got: {stopping_rounds}") + self.enabled = _should_enable_early_stopping(stopping_rounds) self.order = 30 self.before_iteration = False @@ -291,7 +290,6 @@ def __init__( self.verbose = verbose self.min_delta = min_delta - self.enabled = True self._reset_storages() def _reset_storages(self) -> None: @@ -438,6 +436,18 @@ def __call__(self, env: CallbackEnv) -> None: self._final_iteration_check(env, eval_name_splitted, i) +def _should_enable_early_stopping(stopping_rounds: Any) -> bool: + """Check if early stopping should be activated. + + This function will evaluate to True if the early stopping callback should be + activated (i.e. stopping_rounds > 0). It also provides an informative error if the + type is not int. + """ + if not isinstance(stopping_rounds, int): + raise TypeError(f"early_stopping_round should be an integer. Got '{type(stopping_rounds).__name__}'") + return stopping_rounds > 0 + + def early_stopping( stopping_rounds: int, first_metric_only: bool = False, diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index 561349a44146..a19b29e7b584 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -236,7 +236,7 @@ def train( cb.__dict__.setdefault("order", i - len(callbacks)) callbacks_set = set(callbacks) - if "early_stopping_round" in params: + if callback._should_enable_early_stopping(params.get("early_stopping_round", 0)): callbacks_set.add( callback.early_stopping( stopping_rounds=params["early_stopping_round"], # type: ignore[arg-type] @@ -760,7 +760,7 @@ def cv( cb.__dict__.setdefault("order", i - len(callbacks)) callbacks_set = set(callbacks) - if "early_stopping_round" in params: + if callback._should_enable_early_stopping(params.get("early_stopping_round", 0)): callbacks_set.add( callback.early_stopping( stopping_rounds=params["early_stopping_round"], # type: ignore[arg-type] diff --git a/tests/python_package_test/test_callback.py b/tests/python_package_test/test_callback.py index a13ee9c0e6e9..48c7a29e8705 100644 --- a/tests/python_package_test/test_callback.py +++ b/tests/python_package_test/test_callback.py @@ -22,14 +22,14 @@ def test_early_stopping_callback_is_picklable(serializer): def test_early_stopping_callback_rejects_invalid_stopping_rounds_with_informative_errors(): - with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: 0"): - lgb.early_stopping(stopping_rounds=0) + with pytest.raises(TypeError, match="early_stopping_round should be an integer. Got 'str'"): + lgb.early_stopping(stopping_rounds="neverrrr") - with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: -1"): - lgb.early_stopping(stopping_rounds=-1) - with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: neverrrr"): - lgb.early_stopping(stopping_rounds="neverrrr") +@pytest.mark.parametrize("stopping_rounds", [-10, -1, 0]) +def test_early_stopping_callback_accepts_non_positive_stopping_rounds(stopping_rounds): + cb = lgb.early_stopping(stopping_rounds=stopping_rounds) + assert cb.enabled is False @pytest.mark.parametrize("serializer", SERIALIZERS) diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 3fad36b34407..05c5792b1836 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -938,6 +938,54 @@ def test_early_stopping_via_global_params(first_metric_only): assert "error" in gbm.best_score[valid_set_name] +@pytest.mark.parametrize("early_stopping_round", [-10, -1, 0, None, "None"]) +def test_early_stopping_is_not_enabled_for_non_positive_stopping_rounds(early_stopping_round): + X, y = load_breast_cancer(return_X_y=True) + num_trees = 5 + params = { + "num_trees": num_trees, + "objective": "binary", + "metric": "None", + "verbose": -1, + "early_stopping_round": early_stopping_round, + "first_metric_only": True, + } + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42) + lgb_train = lgb.Dataset(X_train, y_train) + lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train) + valid_set_name = "valid_set" + + if early_stopping_round is None: + gbm = lgb.train( + params, + lgb_train, + feval=[constant_metric], + valid_sets=lgb_eval, + valid_names=valid_set_name, + ) + assert "early_stopping_round" not in gbm.params + assert gbm.num_trees() == num_trees + elif early_stopping_round == "None": + with pytest.raises(TypeError, match="early_stopping_round should be an integer. Got 'str'"): + gbm = lgb.train( + params, + lgb_train, + feval=[constant_metric], + valid_sets=lgb_eval, + valid_names=valid_set_name, + ) + elif early_stopping_round <= 0: + gbm = lgb.train( + params, + lgb_train, + feval=[constant_metric], + valid_sets=lgb_eval, + valid_names=valid_set_name, + ) + assert gbm.params["early_stopping_round"] == early_stopping_round + assert gbm.num_trees() == num_trees + + @pytest.mark.parametrize("first_only", [True, False]) @pytest.mark.parametrize("single_metric", [True, False]) @pytest.mark.parametrize("greater_is_better", [True, False])