[python-package] allow use of early_stopping_round<=0 to turn off early stopping (fixes #6401) #6406

Merged (10 commits, Apr 20, 2024)
14 changes: 10 additions & 4 deletions python-package/lightgbm/callback.py
@@ -280,18 +280,24 @@ def __init__(
verbose: bool = True,
min_delta: Union[float, List[float]] = 0.0,
) -> None:
if not isinstance(stopping_rounds, int) or stopping_rounds <= 0:
raise ValueError(f"stopping_rounds should be an integer and greater than 0. got: {stopping_rounds}")
if not isinstance(stopping_rounds, int):
raise TypeError(
f"stopping_rounds should be an integer. Got {type(stopping_rounds)}")

self.stopping_rounds = stopping_rounds

if stopping_rounds > 0:
self.enabled = True
else:
self.enabled = False

self.order = 30
self.before_iteration = False

self.stopping_rounds = stopping_rounds
self.first_metric_only = first_metric_only
self.verbose = verbose
self.min_delta = min_delta

self.enabled = True
self._reset_storages()

def _reset_storages(self) -> None:
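The effect of this hunk on the callback itself: non-positive `stopping_rounds` values now construct a disabled callback instead of raising `ValueError`, while non-integer values raise `TypeError`. A minimal sketch of the new behavior (assumes a LightGBM build that includes this patch; not part of the diff):

```python
import lightgbm as lgb

# Positive values keep the previous behavior: the callback is active.
cb_on = lgb.early_stopping(stopping_rounds=5)
assert cb_on.enabled

# Zero or negative values no longer raise; the callback is created but disabled.
cb_off = lgb.early_stopping(stopping_rounds=0)
assert not cb_off.enabled

# Non-integer values now fail with TypeError (previously ValueError).
try:
    lgb.early_stopping(stopping_rounds="neverrrr")
except TypeError as err:
    print(err)  # stopping_rounds should be an integer. Got <class 'str'>
```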
17 changes: 13 additions & 4 deletions python-package/lightgbm/engine.py
@@ -188,8 +188,17 @@ def train(
params=params,
default_value=None,
)
if params["early_stopping_round"] is None:
params.pop("early_stopping_round")
if "early_stopping_round" in params:
if params["early_stopping_round"] is None:
params.pop("early_stopping_round")
# the check below happens if the callback is instantiated, but if the user
# passes a non-numeric value `params.get("early_stopping_round", 0) > 0` below
# will fail prior to the callback instantiation, so a TypeError should be raised
# here as well
elif not isinstance(params["early_stopping_round"], int):
raise TypeError(
f"stopping_rounds should be an integer. Got {type(params['early_stopping_round'])}")

first_metric_only = params.get("first_metric_only", False)

predictor: Optional[_InnerPredictor] = None
@@ -236,7 +245,7 @@ def train(
cb.__dict__.setdefault("order", i - len(callbacks))
callbacks_set = set(callbacks)

if "early_stopping_round" in params:
if params.get("early_stopping_round", 0) > 0:
callbacks_set.add(
callback.early_stopping(
stopping_rounds=params["early_stopping_round"], # type: ignore[arg-type]
@@ -760,7 +769,7 @@ def cv(
cb.__dict__.setdefault("order", i - len(callbacks))
callbacks_set = set(callbacks)

if "early_stopping_round" in params:
if params.get("early_stopping_round", 0) > 0:
callbacks_set.add(
callback.early_stopping(
stopping_rounds=params["early_stopping_round"], # type: ignore[arg-type]
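Taken together, the `engine.py` hunks mean that `train()` and `cv()` only attach the early-stopping callback when `early_stopping_round` is a positive integer; zero or negative values are kept in `params` but produce no callback. A rough usage sketch under that assumption (uses scikit-learn for a toy dataset; not part of the diff):

```python
import lightgbm as lgb
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=0)
dtrain = lgb.Dataset(X_train, label=y_train)
dvalid = lgb.Dataset(X_valid, label=y_valid, reference=dtrain)

params = {
    "objective": "binary",
    "metric": "binary_logloss",
    "verbose": -1,
    # <= 0 now means "do not enable early stopping" instead of raising an error
    "early_stopping_round": 0,
}

bst = lgb.train(params, dtrain, num_boost_round=50, valid_sets=[dvalid])

# No early-stopping callback was registered, so all 50 boosting rounds run.
print(bst.current_iteration())  # expected: 50
print(bst.params["early_stopping_round"])  # 0 is preserved in the stored params
```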
9 changes: 2 additions & 7 deletions tests/python_package_test/test_callback.py
@@ -22,13 +22,8 @@ def test_early_stopping_callback_is_picklable(serializer):


def test_early_stopping_callback_rejects_invalid_stopping_rounds_with_informative_errors():
with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: 0"):
lgb.early_stopping(stopping_rounds=0)

with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: -1"):
lgb.early_stopping(stopping_rounds=-1)

with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: neverrrr"):
with pytest.raises(TypeError, match="stopping_rounds should be an integer. Got "
"<class 'str'>"):
lgb.early_stopping(stopping_rounds="neverrrr")


37 changes: 37 additions & 0 deletions tests/python_package_test/test_engine.py
@@ -938,6 +938,43 @@ def test_early_stopping_via_global_params(first_metric_only):
assert "error" in gbm.best_score[valid_set_name]


@pytest.mark.parametrize("early_stopping_round", [-10, -1, 0, None, "None"])
def test_early_stopping_non_positive_values(early_stopping_round):
X, y = load_breast_cancer(return_X_y=True)
num_trees = 5
params = {
"num_trees": num_trees,
"objective": "binary",
"metric": "None",
"verbose": -1,
"early_stopping_round": early_stopping_round,
"first_metric_only": True,
}
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
valid_set_name = "valid_set"

if early_stopping_round is None:
gbm = lgb.train(
params, lgb_train, feval=[decreasing_metric, constant_metric],
valid_sets=lgb_eval, valid_names=valid_set_name
)
assert not "early_stopping_round" in gbm.params
elif early_stopping_round == "None":
with pytest.raises(TypeError):
gbm = lgb.train(
params, lgb_train, feval=[decreasing_metric, constant_metric],
valid_sets=lgb_eval, valid_names=valid_set_name
)
elif early_stopping_round <= 0:
gbm = lgb.train(
params, lgb_train, feval=[decreasing_metric, constant_metric],
valid_sets=lgb_eval, valid_names=valid_set_name
)
assert gbm.params["early_stopping_round"] == early_stopping_round


@pytest.mark.parametrize("first_only", [True, False])
@pytest.mark.parametrize("single_metric", [True, False])
@pytest.mark.parametrize("greater_is_better", [True, False])