Skip to content

Commit

Permalink
more changes
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Sep 16, 2023
1 parent aea7f0f commit c0e00c0
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 5 deletions.
5 changes: 5 additions & 0 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3246,8 +3246,13 @@ def __init__(
elif model_str is not None:
self.model_from_string(model_str)
# ensure params are updated on the C++ side
# NOTE: models loaded from file are initially set to "boosting: GBDT", so "boosting"
# shouldn't be passed through here
self.params = params
boosting_type = params.pop("boosting", None)
self.reset_parameter(params)
if boosting_type is not None:
params["boosting"] = boosting_type
else:
raise TypeError('Need at least one training dataset or model file or model string '
'to create Booster instance')
Expand Down
30 changes: 25 additions & 5 deletions tests/python_package_test/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import lightgbm as lgb
from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame, pd_Series

from .utils import dummy_obj, load_breast_cancer, mse_obj
from .utils import BOOSTING_TYPES, dummy_obj, load_breast_cancer, mse_obj


def test_basic(tmp_path):
Expand Down Expand Up @@ -825,8 +825,11 @@ def test_feature_names_are_set_correctly_when_no_feature_names_passed_into_Datas
assert ds.construct().feature_name == ["Column_0", "Column_1", "Column_2"]


def test_booster_deepcopy_preserves_parameters():
@pytest.mark.parametrize('boosting_type', BOOSTING_TYPES)
def test_booster_deepcopy_preserves_parameters(boosting_type):
orig_params = {
'boosting': boosting_type,
'feature_fraction': 0.708,
'num_leaves': 5,
'verbosity': -1
}
Expand All @@ -841,11 +844,19 @@ def test_booster_deepcopy_preserves_parameters():
assert bst.params["verbosity"] == -1

# passed-in params shouldn't have been modified outside of lightgbm
assert orig_params == {'num_leaves': 5, 'verbosity': -1}
assert orig_params == {
'boosting': boosting_type,
'feature_fraction': 0.708,
'num_leaves': 5,
'verbosity': -1
}


def test_booster_params_kwarg_overrides_params_from_model_string():
@pytest.mark.parametrize('boosting_type', BOOSTING_TYPES)
def test_booster_params_kwarg_overrides_params_from_model_string(boosting_type):
orig_params = {
'boosting': boosting_type,
'feature_fraction': 0.708,
'num_leaves': 5,
'verbosity': -1
}
Expand All @@ -863,5 +874,14 @@ def test_booster_params_kwarg_overrides_params_from_model_string():
assert bst2.params["num_leaves"] == 7
assert "[num_leaves: 7]" in bst2.model_to_string()

# boosting type should have been preserved in the new model
if boosting_type != "gbdt":
raise RuntimeError

# passed-in params shouldn't have been modified outside of lightgbm
assert orig_params == {'num_leaves': 5, 'verbosity': -1}
assert orig_params == {
'boosting': boosting_type,
'feature_fraction': 0.708,
'num_leaves': 5,
'verbosity': -1
}
1 change: 1 addition & 0 deletions tests/python_package_test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import lightgbm as lgb

BOOSTING_TYPES = ['gbdt', 'dart', 'goss', 'rf']
SERIALIZERS = ["pickle", "joblib", "cloudpickle"]


Expand Down

0 comments on commit c0e00c0

Please sign in to comment.