diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 9a852b093cee..38fe74cc016f 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -3246,8 +3246,13 @@ def __init__( elif model_str is not None: self.model_from_string(model_str) # ensure params are updated on the C++ side + # NOTE: models loaded from file are initially set to "boosting: GBDT", so "boosting" + # shouldn't be passed through here self.params = params + boosting_type = params.pop("boosting", None) self.reset_parameter(params) + if boosting_type is not None: + params["boosting"] = boosting_type else: raise TypeError('Need at least one training dataset or model file or model string ' 'to create Booster instance') diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index b11c0d2f2548..f3330b616016 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -15,7 +15,7 @@ import lightgbm as lgb from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame, pd_Series -from .utils import dummy_obj, load_breast_cancer, mse_obj +from .utils import BOOSTING_TYPES, dummy_obj, load_breast_cancer, mse_obj def test_basic(tmp_path): @@ -825,8 +825,11 @@ def test_feature_names_are_set_correctly_when_no_feature_names_passed_into_Datas assert ds.construct().feature_name == ["Column_0", "Column_1", "Column_2"] -def test_booster_deepcopy_preserves_parameters(): +@pytest.mark.parametrize('boosting_type', BOOSTING_TYPES) +def test_booster_deepcopy_preserves_parameters(boosting_type): orig_params = { + 'boosting': boosting_type, + 'feature_fraction': 0.708, 'num_leaves': 5, 'verbosity': -1 } @@ -841,11 +844,19 @@ def test_booster_deepcopy_preserves_parameters(): assert bst.params["verbosity"] == -1 # passed-in params shouldn't have been modified outside of lightgbm - assert orig_params == {'num_leaves': 5, 'verbosity': -1} + assert orig_params == { + 'boosting': boosting_type, + 'feature_fraction': 0.708, + 'num_leaves': 5, + 'verbosity': -1 + } -def test_booster_params_kwarg_overrides_params_from_model_string(): +@pytest.mark.parametrize('boosting_type', BOOSTING_TYPES) +def test_booster_params_kwarg_overrides_params_from_model_string(boosting_type): orig_params = { + 'boosting': boosting_type, + 'feature_fraction': 0.708, 'num_leaves': 5, 'verbosity': -1 } @@ -863,5 +874,14 @@ def test_booster_params_kwarg_overrides_params_from_model_string(): assert bst2.params["num_leaves"] == 7 assert "[num_leaves: 7]" in bst2.model_to_string() + # boosting type should have been preserved in the new model + if boosting_type != "gbdt": + raise RuntimeError + # passed-in params shouldn't have been modified outside of lightgbm - assert orig_params == {'num_leaves': 5, 'verbosity': -1} + assert orig_params == { + 'boosting': boosting_type, + 'feature_fraction': 0.708, + 'num_leaves': 5, + 'verbosity': -1 + } diff --git a/tests/python_package_test/utils.py b/tests/python_package_test/utils.py index df01e29852e7..4d7cae3c0f62 100644 --- a/tests/python_package_test/utils.py +++ b/tests/python_package_test/utils.py @@ -10,6 +10,7 @@ import lightgbm as lgb +BOOSTING_TYPES = ['gbdt', 'dart', 'goss', 'rf'] SERIALIZERS = ["pickle", "joblib", "cloudpickle"]