From 2ba639f2617a476d88c7f27b5d033e2757a9853a Mon Sep 17 00:00:00 2001 From: Darcie Delzell Date: Wed, 3 Apr 2024 14:06:57 -0500 Subject: [PATCH 1/7] Remove early stopping callback addition when stopping_rounds <=0 --- python-package/lightgbm/callback.py | 14 +++++--- python-package/lightgbm/engine.py | 17 +++++++--- tests/python_package_test/test_callback.py | 9 ++---- tests/python_package_test/test_engine.py | 37 ++++++++++++++++++++++ 4 files changed, 62 insertions(+), 15 deletions(-) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 5947796dcb3f..9253717fe170 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -280,18 +280,24 @@ def __init__( verbose: bool = True, min_delta: Union[float, List[float]] = 0.0, ) -> None: - if not isinstance(stopping_rounds, int) or stopping_rounds <= 0: - raise ValueError(f"stopping_rounds should be an integer and greater than 0. got: {stopping_rounds}") + if not isinstance(stopping_rounds, int): + raise TypeError( + f"stopping_rounds should be an integer. Got {type(stopping_rounds)}") + + self.stopping_rounds = stopping_rounds + + if stopping_rounds > 0: + self.enabled = True + else: + self.enabled = False self.order = 30 self.before_iteration = False - self.stopping_rounds = stopping_rounds self.first_metric_only = first_metric_only self.verbose = verbose self.min_delta = min_delta - self.enabled = True self._reset_storages() def _reset_storages(self) -> None: diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index 561349a44146..76644f93eed2 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -188,8 +188,17 @@ def train( params=params, default_value=None, ) - if params["early_stopping_round"] is None: - params.pop("early_stopping_round") + if "early_stopping_round" in params: + if params["early_stopping_round"] is None: + params.pop("early_stopping_round") + # the check below happens if the callback is instantiated, but if the user + # passes a non-numeric value `params.get("early_stopping_round", 0) > 0` below + # will fail prior to the callback instantiation, so a TypeError should be raised + # here as well + elif not isinstance(params["early_stopping_round"], int): + raise TypeError( + f"stopping_rounds should be an integer. Got {type(params['early_stopping_round'])}") + first_metric_only = params.get("first_metric_only", False) predictor: Optional[_InnerPredictor] = None @@ -236,7 +245,7 @@ def train( cb.__dict__.setdefault("order", i - len(callbacks)) callbacks_set = set(callbacks) - if "early_stopping_round" in params: + if params.get("early_stopping_round", 0) > 0: callbacks_set.add( callback.early_stopping( stopping_rounds=params["early_stopping_round"], # type: ignore[arg-type] @@ -760,7 +769,7 @@ def cv( cb.__dict__.setdefault("order", i - len(callbacks)) callbacks_set = set(callbacks) - if "early_stopping_round" in params: + if params.get("early_stopping_round", 0) > 0: callbacks_set.add( callback.early_stopping( stopping_rounds=params["early_stopping_round"], # type: ignore[arg-type] diff --git a/tests/python_package_test/test_callback.py b/tests/python_package_test/test_callback.py index a13ee9c0e6e9..1f3fc626cdd0 100644 --- a/tests/python_package_test/test_callback.py +++ b/tests/python_package_test/test_callback.py @@ -22,13 +22,8 @@ def test_early_stopping_callback_is_picklable(serializer): def test_early_stopping_callback_rejects_invalid_stopping_rounds_with_informative_errors(): - with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: 0"): - lgb.early_stopping(stopping_rounds=0) - - with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: -1"): - lgb.early_stopping(stopping_rounds=-1) - - with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: neverrrr"): + with pytest.raises(TypeError, match="stopping_rounds should be an integer. Got " + ""): lgb.early_stopping(stopping_rounds="neverrrr") diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 3fad36b34407..3cf54d0b06c3 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -938,6 +938,43 @@ def test_early_stopping_via_global_params(first_metric_only): assert "error" in gbm.best_score[valid_set_name] +@pytest.mark.parametrize("early_stopping_round", [-10, -1, 0, None, "None"]) +def test_early_stopping_non_positive_values(early_stopping_round): + X, y = load_breast_cancer(return_X_y=True) + num_trees = 5 + params = { + "num_trees": num_trees, + "objective": "binary", + "metric": "None", + "verbose": -1, + "early_stopping_round": early_stopping_round, + "first_metric_only": True, + } + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42) + lgb_train = lgb.Dataset(X_train, y_train) + lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train) + valid_set_name = "valid_set" + + if early_stopping_round is None: + gbm = lgb.train( + params, lgb_train, feval=[decreasing_metric, constant_metric], + valid_sets=lgb_eval, valid_names=valid_set_name + ) + assert not "early_stopping_round" in gbm.params + elif early_stopping_round == "None": + with pytest.raises(TypeError): + gbm = lgb.train( + params, lgb_train, feval=[decreasing_metric, constant_metric], + valid_sets=lgb_eval, valid_names=valid_set_name + ) + elif early_stopping_round <=0: + gbm = lgb.train( + params, lgb_train, feval=[decreasing_metric, constant_metric], + valid_sets=lgb_eval, valid_names=valid_set_name + ) + assert gbm.params["early_stopping_round"] == early_stopping_round + + @pytest.mark.parametrize("first_only", [True, False]) @pytest.mark.parametrize("single_metric", [True, False]) @pytest.mark.parametrize("greater_is_better", [True, False]) From 1dd9e9a84fa729fe5384b0e559fd4a573c33667f Mon Sep 17 00:00:00 2001 From: Darcie Delzell Date: Thu, 4 Apr 2024 11:09:12 -0500 Subject: [PATCH 2/7] Add should_enable_early_stopping function --- python-package/lightgbm/callback.py | 21 +++++++++++------- python-package/lightgbm/engine.py | 17 ++++----------- tests/python_package_test/test_callback.py | 3 +-- tests/python_package_test/test_engine.py | 25 +++++++++++++++------- 4 files changed, 35 insertions(+), 31 deletions(-) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 9253717fe170..bc44840e48e8 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -280,16 +280,9 @@ def __init__( verbose: bool = True, min_delta: Union[float, List[float]] = 0.0, ) -> None: - if not isinstance(stopping_rounds, int): - raise TypeError( - f"stopping_rounds should be an integer. Got {type(stopping_rounds)}") - self.stopping_rounds = stopping_rounds - if stopping_rounds > 0: - self.enabled = True - else: - self.enabled = False + self.enabled = _should_enable_early_stopping(stopping_rounds) self.order = 30 self.before_iteration = False @@ -444,6 +437,18 @@ def __call__(self, env: CallbackEnv) -> None: self._final_iteration_check(env, eval_name_splitted, i) +def _should_enable_early_stopping(stopping_rounds: Any) -> bool: + """Check if early stopping should be activated. + + This function will evaluate to True if the early stopping callback should be + activated (i.e. stopping_rounds > 0). It also provides an informative error if the + type is not int. + """ + if not isinstance(stopping_rounds, int): + raise TypeError(f"early_stopping_round should be an integer. Got {type(stopping_rounds)}") + return stopping_rounds > 0 + + def early_stopping( stopping_rounds: int, first_metric_only: bool = False, diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index 76644f93eed2..a19b29e7b584 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -188,17 +188,8 @@ def train( params=params, default_value=None, ) - if "early_stopping_round" in params: - if params["early_stopping_round"] is None: - params.pop("early_stopping_round") - # the check below happens if the callback is instantiated, but if the user - # passes a non-numeric value `params.get("early_stopping_round", 0) > 0` below - # will fail prior to the callback instantiation, so a TypeError should be raised - # here as well - elif not isinstance(params["early_stopping_round"], int): - raise TypeError( - f"stopping_rounds should be an integer. Got {type(params['early_stopping_round'])}") - + if params["early_stopping_round"] is None: + params.pop("early_stopping_round") first_metric_only = params.get("first_metric_only", False) predictor: Optional[_InnerPredictor] = None @@ -245,7 +236,7 @@ def train( cb.__dict__.setdefault("order", i - len(callbacks)) callbacks_set = set(callbacks) - if params.get("early_stopping_round", 0) > 0: + if callback._should_enable_early_stopping(params.get("early_stopping_round", 0)): callbacks_set.add( callback.early_stopping( stopping_rounds=params["early_stopping_round"], # type: ignore[arg-type] @@ -769,7 +760,7 @@ def cv( cb.__dict__.setdefault("order", i - len(callbacks)) callbacks_set = set(callbacks) - if params.get("early_stopping_round", 0) > 0: + if callback._should_enable_early_stopping(params.get("early_stopping_round", 0)): callbacks_set.add( callback.early_stopping( stopping_rounds=params["early_stopping_round"], # type: ignore[arg-type] diff --git a/tests/python_package_test/test_callback.py b/tests/python_package_test/test_callback.py index 1f3fc626cdd0..ea3e400e1046 100644 --- a/tests/python_package_test/test_callback.py +++ b/tests/python_package_test/test_callback.py @@ -22,8 +22,7 @@ def test_early_stopping_callback_is_picklable(serializer): def test_early_stopping_callback_rejects_invalid_stopping_rounds_with_informative_errors(): - with pytest.raises(TypeError, match="stopping_rounds should be an integer. Got " - ""): + with pytest.raises(TypeError, match="early_stopping_round should be an integer. Got " ""): lgb.early_stopping(stopping_rounds="neverrrr") diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 3cf54d0b06c3..c70f5564e038 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -957,20 +957,29 @@ def test_early_stopping_non_positive_values(early_stopping_round): if early_stopping_round is None: gbm = lgb.train( - params, lgb_train, feval=[decreasing_metric, constant_metric], - valid_sets=lgb_eval, valid_names=valid_set_name + params, + lgb_train, + feval=[decreasing_metric, constant_metric], + valid_sets=lgb_eval, + valid_names=valid_set_name, ) - assert not "early_stopping_round" in gbm.params + assert "early_stopping_round" not in gbm.params elif early_stopping_round == "None": with pytest.raises(TypeError): gbm = lgb.train( - params, lgb_train, feval=[decreasing_metric, constant_metric], - valid_sets=lgb_eval, valid_names=valid_set_name + params, + lgb_train, + feval=[decreasing_metric, constant_metric], + valid_sets=lgb_eval, + valid_names=valid_set_name, ) - elif early_stopping_round <=0: + elif early_stopping_round <= 0: gbm = lgb.train( - params, lgb_train, feval=[decreasing_metric, constant_metric], - valid_sets=lgb_eval, valid_names=valid_set_name + params, + lgb_train, + feval=[decreasing_metric, constant_metric], + valid_sets=lgb_eval, + valid_names=valid_set_name, ) assert gbm.params["early_stopping_round"] == early_stopping_round From 260fb9483910f7e52d15365b39448d5adef5ba8c Mon Sep 17 00:00:00 2001 From: Darcie Delzell Date: Fri, 5 Apr 2024 12:44:50 -0500 Subject: [PATCH 3/7] WIP --- python-package/lightgbm/callback.py | 5 ++--- tests/python_package_test/test_callback.py | 8 +++++++- tests/python_package_test/test_engine.py | 12 +++++++----- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index bc44840e48e8..8548de3b0afd 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -280,13 +280,12 @@ def __init__( verbose: bool = True, min_delta: Union[float, List[float]] = 0.0, ) -> None: - self.stopping_rounds = stopping_rounds - self.enabled = _should_enable_early_stopping(stopping_rounds) self.order = 30 self.before_iteration = False + self.stopping_rounds = stopping_rounds self.first_metric_only = first_metric_only self.verbose = verbose self.min_delta = min_delta @@ -445,7 +444,7 @@ def _should_enable_early_stopping(stopping_rounds: Any) -> bool: type is not int. """ if not isinstance(stopping_rounds, int): - raise TypeError(f"early_stopping_round should be an integer. Got {type(stopping_rounds)}") + raise TypeError(f"early_stopping_round should be an integer. Got " f"{type(stopping_rounds).__name__}") return stopping_rounds > 0 diff --git a/tests/python_package_test/test_callback.py b/tests/python_package_test/test_callback.py index ea3e400e1046..e7d29947d7ef 100644 --- a/tests/python_package_test/test_callback.py +++ b/tests/python_package_test/test_callback.py @@ -22,10 +22,16 @@ def test_early_stopping_callback_is_picklable(serializer): def test_early_stopping_callback_rejects_invalid_stopping_rounds_with_informative_errors(): - with pytest.raises(TypeError, match="early_stopping_round should be an integer. Got " ""): + with pytest.raises(TypeError, match="early_stopping_round should be an integer. Got str"): lgb.early_stopping(stopping_rounds="neverrrr") +@pytest.mark.parametrize("stopping_rounds", [-10, -1, 0]) +def test_early_stopping_callback_accepts_non_positive_stopping_rounds(stopping_rounds): + cb = lgb.early_stopping(stopping_rounds=stopping_rounds) + assert cb.enabled is False + + @pytest.mark.parametrize("serializer", SERIALIZERS) def test_log_evaluation_callback_is_picklable(serializer): periods = 42 diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index c70f5564e038..ccf6fb37f42f 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -939,7 +939,7 @@ def test_early_stopping_via_global_params(first_metric_only): @pytest.mark.parametrize("early_stopping_round", [-10, -1, 0, None, "None"]) -def test_early_stopping_non_positive_values(early_stopping_round): +def test_early_stopping_is_not_enabled_for_non_positive_stopping_rounds(early_stopping_round): X, y = load_breast_cancer(return_X_y=True) num_trees = 5 params = { @@ -959,17 +959,18 @@ def test_early_stopping_non_positive_values(early_stopping_round): gbm = lgb.train( params, lgb_train, - feval=[decreasing_metric, constant_metric], + feval=[constant_metric], valid_sets=lgb_eval, valid_names=valid_set_name, ) assert "early_stopping_round" not in gbm.params + assert gbm.num_trees() == num_trees elif early_stopping_round == "None": - with pytest.raises(TypeError): + with pytest.raises(TypeError, match="early_stopping_round should be an integer. Got str"): gbm = lgb.train( params, lgb_train, - feval=[decreasing_metric, constant_metric], + feval=[constant_metric], valid_sets=lgb_eval, valid_names=valid_set_name, ) @@ -977,11 +978,12 @@ def test_early_stopping_non_positive_values(early_stopping_round): gbm = lgb.train( params, lgb_train, - feval=[decreasing_metric, constant_metric], + feval=[constant_metric], valid_sets=lgb_eval, valid_names=valid_set_name, ) assert gbm.params["early_stopping_round"] == early_stopping_round + assert gbm.num_trees() == num_trees @pytest.mark.parametrize("first_only", [True, False]) From 26990fd6ac62f5b2e9364e5f3441082d389451f9 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Fri, 19 Apr 2024 21:00:26 -0500 Subject: [PATCH 4/7] Update python-package/lightgbm/callback.py --- python-package/lightgbm/callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 8548de3b0afd..6d2519c248c4 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -444,7 +444,7 @@ def _should_enable_early_stopping(stopping_rounds: Any) -> bool: type is not int. """ if not isinstance(stopping_rounds, int): - raise TypeError(f"early_stopping_round should be an integer. Got " f"{type(stopping_rounds).__name__}") + raise TypeError(f"early_stopping_round should be an integer. Got {type(stopping_rounds).__name__}") return stopping_rounds > 0 From a8bf7e17df80ed108749d017b7329ca7bc20faea Mon Sep 17 00:00:00 2001 From: James Lamb Date: Fri, 19 Apr 2024 21:04:29 -0500 Subject: [PATCH 5/7] Update python-package/lightgbm/callback.py --- python-package/lightgbm/callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py index 6d2519c248c4..e776ea953bd1 100644 --- a/python-package/lightgbm/callback.py +++ b/python-package/lightgbm/callback.py @@ -444,7 +444,7 @@ def _should_enable_early_stopping(stopping_rounds: Any) -> bool: type is not int. """ if not isinstance(stopping_rounds, int): - raise TypeError(f"early_stopping_round should be an integer. Got {type(stopping_rounds).__name__}") + raise TypeError(f"early_stopping_round should be an integer. Got '{type(stopping_rounds).__name__}'") return stopping_rounds > 0 From b6835ebb5919175a3f575f10d3c36aceca8d3b06 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Fri, 19 Apr 2024 21:04:47 -0500 Subject: [PATCH 6/7] Update tests/python_package_test/test_callback.py --- tests/python_package_test/test_callback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python_package_test/test_callback.py b/tests/python_package_test/test_callback.py index e7d29947d7ef..48c7a29e8705 100644 --- a/tests/python_package_test/test_callback.py +++ b/tests/python_package_test/test_callback.py @@ -22,7 +22,7 @@ def test_early_stopping_callback_is_picklable(serializer): def test_early_stopping_callback_rejects_invalid_stopping_rounds_with_informative_errors(): - with pytest.raises(TypeError, match="early_stopping_round should be an integer. Got str"): + with pytest.raises(TypeError, match="early_stopping_round should be an integer. Got 'str'"): lgb.early_stopping(stopping_rounds="neverrrr") From 178ea94b8c83bf447035316e1ab9083eb827a690 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Fri, 19 Apr 2024 21:41:55 -0500 Subject: [PATCH 7/7] Update tests/python_package_test/test_engine.py --- tests/python_package_test/test_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index ccf6fb37f42f..05c5792b1836 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -966,7 +966,7 @@ def test_early_stopping_is_not_enabled_for_non_positive_stopping_rounds(early_st assert "early_stopping_round" not in gbm.params assert gbm.num_trees() == num_trees elif early_stopping_round == "None": - with pytest.raises(TypeError, match="early_stopping_round should be an integer. Got str"): + with pytest.raises(TypeError, match="early_stopping_round should be an integer. Got 'str'"): gbm = lgb.train( params, lgb_train,