From f74875ed60e696ee7d223ddb409e66f51bddbb47 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Tue, 18 Apr 2023 22:20:40 -0500 Subject: [PATCH] [python-package] move validation up earlier in cv() and train() (#5836) --- python-package/lightgbm/engine.py | 29 +++++++++++++-------- tests/python_package_test/test_engine.py | 32 ++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 10 deletions(-) diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index 5e6e615fb4c9..859f513a19a5 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -141,6 +141,20 @@ def train( booster : Booster The trained Booster model. """ + if not isinstance(train_set, Dataset): + raise TypeError(f"train() only accepts Dataset object, train_set has type '{type(train_set).__name__}'.") + + if num_boost_round <= 0: + raise ValueError(f"num_boost_round must be greater than 0. Got {num_boost_round}.") + + if isinstance(valid_sets, list): + for i, valid_item in enumerate(valid_sets): + if not isinstance(valid_item, Dataset): + raise TypeError( + "Every item in valid_sets must be a Dataset object. " + f"Item {i} has type '{type(valid_item).__name__}'." + ) + # create predictor first params = copy.deepcopy(params) params = _choose_param_value( @@ -167,17 +181,12 @@ def train( params.pop("early_stopping_round") first_metric_only = params.get('first_metric_only', False) - if num_boost_round <= 0: - raise ValueError("num_boost_round should be greater than zero.") predictor: Optional[_InnerPredictor] = None if isinstance(init_model, (str, Path)): predictor = _InnerPredictor(model_file=init_model, pred_parameter=params) elif isinstance(init_model, Booster): predictor = init_model._to_predictor(pred_parameter=dict(init_model.params, **params)) init_iteration = predictor.num_total_iteration if predictor is not None else 0 - # check dataset - if not isinstance(train_set, Dataset): - raise TypeError("Training only accepts Dataset object") train_set._update_params(params) \ ._set_predictor(predictor) \ @@ -200,8 +209,6 @@ def train( if valid_names is not None: train_data_name = valid_names[i] continue - if not isinstance(valid_data, Dataset): - raise TypeError("Training only accepts Dataset object") reduced_valid_sets.append(valid_data._update_params(params).set_reference(train_set)) if valid_names is not None and len(valid_names) > i: name_valid_sets.append(valid_names[i]) @@ -647,7 +654,11 @@ def cv( If ``return_cvbooster=True``, also returns trained boosters via ``cvbooster`` key. """ if not isinstance(train_set, Dataset): - raise TypeError("Training only accepts Dataset object") + raise TypeError(f"cv() only accepts Dataset object, train_set has type '{type(train_set).__name__}'.") + + if num_boost_round <= 0: + raise ValueError(f"num_boost_round must be greater than 0. Got {num_boost_round}.") + params = copy.deepcopy(params) params = _choose_param_value( main_param_name='objective', @@ -673,8 +684,6 @@ def cv( params.pop("early_stopping_round") first_metric_only = params.get('first_metric_only', False) - if num_boost_round <= 0: - raise ValueError("num_boost_round should be greater than zero.") if isinstance(init_model, (str, Path)): predictor = _InnerPredictor(model_file=init_model, pred_parameter=params) elif isinstance(init_model, Booster): diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 9e7f8e8a8d5a..d152c2c359d3 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -4017,6 +4017,38 @@ def test_validate_features(): bst.refit(df2, y, validate_features=False) +def test_train_and_cv_raise_informative_error_for_train_set_of_wrong_type(): + with pytest.raises(TypeError, match=r"train\(\) only accepts Dataset object, train_set has type 'list'\."): + lgb.train({}, train_set=[]) + with pytest.raises(TypeError, match=r"cv\(\) only accepts Dataset object, train_set has type 'list'\."): + lgb.cv({}, train_set=[]) + + +@pytest.mark.parametrize('num_boost_round', [-7, -1, 0]) +def test_train_and_cv_raise_informative_error_for_impossible_num_boost_round(num_boost_round): + X, y = make_synthetic_regression(n_samples=100) + error_msg = rf"num_boost_round must be greater than 0\. Got {num_boost_round}\." + with pytest.raises(ValueError, match=error_msg): + lgb.train({}, train_set=lgb.Dataset(X, y), num_boost_round=num_boost_round) + with pytest.raises(ValueError, match=error_msg): + lgb.cv({}, train_set=lgb.Dataset(X, y), num_boost_round=num_boost_round) + + +def test_train_raises_informative_error_if_any_valid_sets_are_not_dataset_objects(): + X, y = make_synthetic_regression(n_samples=100) + X_valid = X * 2.0 + with pytest.raises(TypeError, match=r"Every item in valid_sets must be a Dataset object\. Item 1 has type 'tuple'\."): + lgb.train( + params={}, + train_set=lgb.Dataset(X, y), + valid_sets=[ + lgb.Dataset(X_valid, y), + ([1.0], [2.0]), + [5.6, 5.7, 5.8] + ] + ) + + def test_train_raises_informative_error_for_params_of_wrong_type(): X, y = make_synthetic_regression() params = {"early_stopping_round": "too-many"}