Skip to content

Commit

Permalink
[python-package] allow use of early_stopping_round<=0 to turn off ear…
Browse files Browse the repository at this point in the history
…ly stopping (fixes #6401) (#6406)
  • Loading branch information
ddelzell authored Apr 20, 2024
1 parent f6ecd4d commit 1ce2357
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 11 deletions.
16 changes: 13 additions & 3 deletions python-package/lightgbm/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,7 @@ def __init__(
verbose: bool = True,
min_delta: Union[float, List[float]] = 0.0,
) -> None:
if not isinstance(stopping_rounds, int) or stopping_rounds <= 0:
raise ValueError(f"stopping_rounds should be an integer and greater than 0. got: {stopping_rounds}")
self.enabled = _should_enable_early_stopping(stopping_rounds)

self.order = 30
self.before_iteration = False
Expand All @@ -291,7 +290,6 @@ def __init__(
self.verbose = verbose
self.min_delta = min_delta

self.enabled = True
self._reset_storages()

def _reset_storages(self) -> None:
Expand Down Expand Up @@ -438,6 +436,18 @@ def __call__(self, env: CallbackEnv) -> None:
self._final_iteration_check(env, eval_name_splitted, i)


def _should_enable_early_stopping(stopping_rounds: Any) -> bool:
"""Check if early stopping should be activated.
This function will evaluate to True if the early stopping callback should be
activated (i.e. stopping_rounds > 0). It also provides an informative error if the
type is not int.
"""
if not isinstance(stopping_rounds, int):
raise TypeError(f"early_stopping_round should be an integer. Got '{type(stopping_rounds).__name__}'")
return stopping_rounds > 0


def early_stopping(
stopping_rounds: int,
first_metric_only: bool = False,
Expand Down
4 changes: 2 additions & 2 deletions python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def train(
cb.__dict__.setdefault("order", i - len(callbacks))
callbacks_set = set(callbacks)

if "early_stopping_round" in params:
if callback._should_enable_early_stopping(params.get("early_stopping_round", 0)):
callbacks_set.add(
callback.early_stopping(
stopping_rounds=params["early_stopping_round"], # type: ignore[arg-type]
Expand Down Expand Up @@ -760,7 +760,7 @@ def cv(
cb.__dict__.setdefault("order", i - len(callbacks))
callbacks_set = set(callbacks)

if "early_stopping_round" in params:
if callback._should_enable_early_stopping(params.get("early_stopping_round", 0)):
callbacks_set.add(
callback.early_stopping(
stopping_rounds=params["early_stopping_round"], # type: ignore[arg-type]
Expand Down
12 changes: 6 additions & 6 deletions tests/python_package_test/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ def test_early_stopping_callback_is_picklable(serializer):


def test_early_stopping_callback_rejects_invalid_stopping_rounds_with_informative_errors():
with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: 0"):
lgb.early_stopping(stopping_rounds=0)
with pytest.raises(TypeError, match="early_stopping_round should be an integer. Got 'str'"):
lgb.early_stopping(stopping_rounds="neverrrr")

with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: -1"):
lgb.early_stopping(stopping_rounds=-1)

with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: neverrrr"):
lgb.early_stopping(stopping_rounds="neverrrr")
@pytest.mark.parametrize("stopping_rounds", [-10, -1, 0])
def test_early_stopping_callback_accepts_non_positive_stopping_rounds(stopping_rounds):
cb = lgb.early_stopping(stopping_rounds=stopping_rounds)
assert cb.enabled is False


@pytest.mark.parametrize("serializer", SERIALIZERS)
Expand Down
48 changes: 48 additions & 0 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,54 @@ def test_early_stopping_via_global_params(first_metric_only):
assert "error" in gbm.best_score[valid_set_name]


@pytest.mark.parametrize("early_stopping_round", [-10, -1, 0, None, "None"])
def test_early_stopping_is_not_enabled_for_non_positive_stopping_rounds(early_stopping_round):
X, y = load_breast_cancer(return_X_y=True)
num_trees = 5
params = {
"num_trees": num_trees,
"objective": "binary",
"metric": "None",
"verbose": -1,
"early_stopping_round": early_stopping_round,
"first_metric_only": True,
}
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
valid_set_name = "valid_set"

if early_stopping_round is None:
gbm = lgb.train(
params,
lgb_train,
feval=[constant_metric],
valid_sets=lgb_eval,
valid_names=valid_set_name,
)
assert "early_stopping_round" not in gbm.params
assert gbm.num_trees() == num_trees
elif early_stopping_round == "None":
with pytest.raises(TypeError, match="early_stopping_round should be an integer. Got 'str'"):
gbm = lgb.train(
params,
lgb_train,
feval=[constant_metric],
valid_sets=lgb_eval,
valid_names=valid_set_name,
)
elif early_stopping_round <= 0:
gbm = lgb.train(
params,
lgb_train,
feval=[constant_metric],
valid_sets=lgb_eval,
valid_names=valid_set_name,
)
assert gbm.params["early_stopping_round"] == early_stopping_round
assert gbm.num_trees() == num_trees


@pytest.mark.parametrize("first_only", [True, False])
@pytest.mark.parametrize("single_metric", [True, False])
@pytest.mark.parametrize("greater_is_better", [True, False])
Expand Down

0 comments on commit 1ce2357

Please sign in to comment.