Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] allow use of early_stopping_round<=0 to turn off early stopping (fixes #6401) #6406

Merged
merged 10 commits into from
Apr 20, 2024
16 changes: 13 additions & 3 deletions python-package/lightgbm/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,7 @@ def __init__(
verbose: bool = True,
min_delta: Union[float, List[float]] = 0.0,
) -> None:
if not isinstance(stopping_rounds, int) or stopping_rounds <= 0:
raise ValueError(f"stopping_rounds should be an integer and greater than 0. got: {stopping_rounds}")
self.enabled = _should_enable_early_stopping(stopping_rounds)

self.order = 30
self.before_iteration = False
Expand All @@ -291,7 +290,6 @@ def __init__(
self.verbose = verbose
self.min_delta = min_delta

self.enabled = True
self._reset_storages()

def _reset_storages(self) -> None:
Expand Down Expand Up @@ -438,6 +436,18 @@ def __call__(self, env: CallbackEnv) -> None:
self._final_iteration_check(env, eval_name_splitted, i)


def _should_enable_early_stopping(stopping_rounds: Any) -> bool:
"""Check if early stopping should be activated.

This function will evaluate to True if the early stopping callback should be
activated (i.e. stopping_rounds > 0). It also provides an informative error if the
type is not int.
"""
if not isinstance(stopping_rounds, int):
raise TypeError(f"early_stopping_round should be an integer. Got " f"{type(stopping_rounds).__name__}")
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
return stopping_rounds > 0


def early_stopping(
stopping_rounds: int,
first_metric_only: bool = False,
Expand Down
4 changes: 2 additions & 2 deletions python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def train(
cb.__dict__.setdefault("order", i - len(callbacks))
callbacks_set = set(callbacks)

if "early_stopping_round" in params:
if callback._should_enable_early_stopping(params.get("early_stopping_round", 0)):
callbacks_set.add(
callback.early_stopping(
stopping_rounds=params["early_stopping_round"], # type: ignore[arg-type]
Expand Down Expand Up @@ -760,7 +760,7 @@ def cv(
cb.__dict__.setdefault("order", i - len(callbacks))
callbacks_set = set(callbacks)

if "early_stopping_round" in params:
if callback._should_enable_early_stopping(params.get("early_stopping_round", 0)):
callbacks_set.add(
callback.early_stopping(
stopping_rounds=params["early_stopping_round"], # type: ignore[arg-type]
Expand Down
12 changes: 6 additions & 6 deletions tests/python_package_test/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ def test_early_stopping_callback_is_picklable(serializer):


def test_early_stopping_callback_rejects_invalid_stopping_rounds_with_informative_errors():
with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: 0"):
lgb.early_stopping(stopping_rounds=0)
with pytest.raises(TypeError, match="early_stopping_round should be an integer. Got str"):
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
lgb.early_stopping(stopping_rounds="neverrrr")
jameslamb marked this conversation as resolved.
Show resolved Hide resolved

with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: -1"):
lgb.early_stopping(stopping_rounds=-1)

with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: neverrrr"):
lgb.early_stopping(stopping_rounds="neverrrr")
@pytest.mark.parametrize("stopping_rounds", [-10, -1, 0])
def test_early_stopping_callback_accepts_non_positive_stopping_rounds(stopping_rounds):
cb = lgb.early_stopping(stopping_rounds=stopping_rounds)
assert cb.enabled is False


@pytest.mark.parametrize("serializer", SERIALIZERS)
Expand Down
48 changes: 48 additions & 0 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,54 @@ def test_early_stopping_via_global_params(first_metric_only):
assert "error" in gbm.best_score[valid_set_name]


@pytest.mark.parametrize("early_stopping_round", [-10, -1, 0, None, "None"])
def test_early_stopping_is_not_enabled_for_non_positive_stopping_rounds(early_stopping_round):
X, y = load_breast_cancer(return_X_y=True)
num_trees = 5
params = {
"num_trees": num_trees,
"objective": "binary",
"metric": "None",
"verbose": -1,
"early_stopping_round": early_stopping_round,
"first_metric_only": True,
}
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
valid_set_name = "valid_set"

if early_stopping_round is None:
gbm = lgb.train(
params,
lgb_train,
feval=[constant_metric],
valid_sets=lgb_eval,
valid_names=valid_set_name,
)
assert "early_stopping_round" not in gbm.params
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
assert gbm.num_trees() == num_trees
elif early_stopping_round == "None":
with pytest.raises(TypeError, match="early_stopping_round should be an integer. Got str"):
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
gbm = lgb.train(
params,
lgb_train,
feval=[constant_metric],
valid_sets=lgb_eval,
valid_names=valid_set_name,
)
elif early_stopping_round <= 0:
gbm = lgb.train(
params,
lgb_train,
feval=[constant_metric],
valid_sets=lgb_eval,
valid_names=valid_set_name,
)
assert gbm.params["early_stopping_round"] == early_stopping_round
jameslamb marked this conversation as resolved.
Show resolved Hide resolved
assert gbm.num_trees() == num_trees


@pytest.mark.parametrize("first_only", [True, False])
@pytest.mark.parametrize("single_metric", [True, False])
@pytest.mark.parametrize("greater_is_better", [True, False])
Expand Down
Loading