Skip to content

Commit

Permalink
[python-package] preserve params when copying Booster (fixes #5539)
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Sep 15, 2023
1 parent ab1eaa8 commit 5b917e2
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 2 deletions.
5 changes: 4 additions & 1 deletion python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3250,6 +3250,9 @@ def __init__(
'to create Booster instance')
self.params = params

# ensure params are updated on the C++ side
self.reset_parameter(params)

def __del__(self) -> None:
try:
if self._network:
Expand All @@ -3267,7 +3270,7 @@ def __copy__(self) -> "Booster":

def __deepcopy__(self, _) -> "Booster":
    """Return a deep copy of this Booster.

    The copy is reconstructed by round-tripping through the text model
    representation. ``self.params`` is passed through explicitly so that
    parameters set on the original Booster survive the copy
    (fixes microsoft/LightGBM#5539).
    """
    # num_iteration=-1 serializes all boosting rounds, not a truncated model
    model_str = self.model_to_string(num_iteration=-1)
    return Booster(model_str=model_str, params=self.params)

def __getstate__(self) -> Dict[str, Any]:
this = self.__dict__.copy()
Expand Down
5 changes: 4 additions & 1 deletion src/boosting/gbdt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,10 @@ void GBDT::ResetConfig(const Config* config) {

boosting_on_gpu_ = objective_function_ != nullptr && objective_function_->IsCUDAObjective() &&
!data_sample_strategy_->IsHessianChange(); // for sample strategy with Hessian change, fall back to boosting on CPU
tree_learner_->ResetBoostingOnGPU(boosting_on_gpu_);

if (tree_learner_ != nullptr) {
tree_learner_->ResetBoostingOnGPU(boosting_on_gpu_);
}

if (train_data_ != nullptr) {
data_sample_strategy_->ResetSampleConfig(new_config.get(), false);
Expand Down
3 changes: 3 additions & 0 deletions tests/c_api_test/test_.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ def test_booster():
c_str('model.txt'),
ctypes.byref(num_total_model),
ctypes.byref(booster2))
LIB.LGBM_BoosterResetParameter(
booster2,
c_str("app=binary metric=auc num_leaves=29 verbose=0"))
data = np.loadtxt(str(binary_example_dir / 'binary.test'), dtype=np.float64)
mat = data[:, 1:]
preb = np.empty(mat.shape[0], dtype=np.float64)
Expand Down
26 changes: 26 additions & 0 deletions tests/python_package_test/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,3 +823,29 @@ def test_feature_names_are_set_correctly_when_no_feature_names_passed_into_Datas
data=np.random.randn(100, 3),
)
assert ds.construct().feature_name == ["Column_0", "Column_1", "Column_2"]


def test_booster_deepcopy_preserves_parameters():
    # Train a tiny 2-round model, deep-copy it, and verify that the
    # training parameters survive the copy intact.
    training_params = {'num_leaves': 5, 'verbosity': -1}
    train_set = lgb.Dataset(np.random.rand(100, 2))
    bst = lgb.train(
        params=training_params,
        num_boost_round=2,
        train_set=train_set
    )
    bst_copy = copy.deepcopy(bst)
    # The copy carries exactly the same params as the original,
    # and the original still holds the values it was trained with.
    assert bst_copy.params == bst.params
    assert bst.params["num_leaves"] == 5
    assert bst.params["verbosity"] == -1


def test_booster_params_kwarg_overrides_params_from_model_string():
    # Fit a small model with known, distinctive parameter values.
    bst = lgb.train(
        params={'num_leaves': 5, 'learning_rate': 0.708, 'verbosity': -1},
        num_boost_round=2,
        train_set=lgb.Dataset(np.random.rand(100, 2))
    )
    # Reload it from its text representation, overriding one parameter
    # through the Booster constructor's params kwarg.
    model_text = bst.model_to_string()
    reloaded = lgb.Booster(
        params={'num_leaves': 7},
        model_str=model_text
    )
    # The explicit kwarg wins over the value stored in the model string;
    # parameters not overridden are recovered from the model string.
    assert reloaded.params["num_leaves"] == 7
    assert reloaded.params["learning_rate"] == 0.708

0 comments on commit 5b917e2

Please sign in to comment.