Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Oct 5, 2023
1 parent bd3366a commit 7a98d82
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 4 deletions.
4 changes: 2 additions & 2 deletions python-package/lightgbm/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,8 @@ def __init__(
min_delta: Union[float, List[float]] = 0.0
) -> None:

if stopping_rounds <= 0:
raise ValueError(f"stopping_rounds should be greater than zero. got: {stopping_rounds}")
if not isinstance(stopping_rounds, int) or stopping_rounds <= 0:
raise ValueError(f"stopping_rounds should be an integer and greater than 0. got: {stopping_rounds}")

self.order = 30
self.before_iteration = False
Expand Down
11 changes: 11 additions & 0 deletions tests/python_package_test/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ def test_early_stopping_callback_is_picklable(serializer):
assert callback.stopping_rounds == rounds


def test_early_stopping_callback_rejects_invalid_stopping_rounds_with_informative_errors():
with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: 0"):
lgb.early_stopping(stopping_rounds=0)

with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: -1"):
lgb.early_stopping(stopping_rounds=-1)

with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: neverrrr"):
lgb.early_stopping(stopping_rounds="neverrrr")


@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_log_evaluation_callback_is_picklable(serializer):
periods = 42
Expand Down
4 changes: 2 additions & 2 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4501,9 +4501,9 @@ def test_train_raises_informative_error_if_any_valid_sets_are_not_dataset_object

def test_train_raises_informative_error_for_params_of_wrong_type():
X, y = make_synthetic_regression()
params = {"early_stopping_round": "too-many"}
params = {"num_leaves": "too-many"}
dtrain = lgb.Dataset(X, label=y)
with pytest.raises(lgb.basic.LightGBMError, match="Parameter early_stopping_round should be of type int, got \"too-many\""):
with pytest.raises(lgb.basic.LightGBMError, match="Parameter num_leaves should be of type int, got \"too-many\""):
lgb.train(params, dtrain)


Expand Down

0 comments on commit 7a98d82

Please sign in to comment.