From 14fcf216b10bad8ec2a74beec584b8914f19fe5b Mon Sep 17 00:00:00 2001 From: Atanas Dimitrov Date: Mon, 5 Aug 2024 18:52:26 +0300 Subject: [PATCH 1/9] Use tmp_path instead of leaving files --- tests/c_api_test/test_.py | 7 ++++--- tests/python_package_test/test_engine.py | 15 ++++++++------- tests/python_package_test/test_sklearn.py | 7 ++++--- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/tests/c_api_test/test_.py b/tests/c_api_test/test_.py index 77fb7f6e8ead..2931d8b38cc7 100644 --- a/tests/c_api_test/test_.py +++ b/tests/c_api_test/test_.py @@ -175,11 +175,12 @@ def test_dataset(): free_dataset(train) -def test_booster(): +def test_booster(tmp_path): binary_example_dir = Path(__file__).absolute().parents[2] / "examples" / "binary_classification" train = load_from_mat(binary_example_dir / "binary.train", None) test = load_from_mat(binary_example_dir / "binary.test", train) booster = ctypes.c_void_p() + model_path = tmp_path / "model.txt" LIB.LGBM_BoosterCreate(train, c_str("app=binary metric=auc num_leaves=31 verbose=0"), ctypes.byref(booster)) LIB.LGBM_BoosterAddValidData(booster, test) is_finished = ctypes.c_int(0) @@ -192,13 +193,13 @@ def test_booster(): ) if i % 10 == 0: print(f"{i} iteration test AUC {result[0]:.6f}") - LIB.LGBM_BoosterSaveModel(booster, ctypes.c_int(0), ctypes.c_int(-1), ctypes.c_int(0), c_str("model.txt")) + LIB.LGBM_BoosterSaveModel(booster, ctypes.c_int(0), ctypes.c_int(-1), ctypes.c_int(0), c_str(str(model_path))) LIB.LGBM_BoosterFree(booster) free_dataset(train) free_dataset(test) booster2 = ctypes.c_void_p() num_total_model = ctypes.c_int(0) - LIB.LGBM_BoosterCreateFromModelfile(c_str("model.txt"), ctypes.byref(num_total_model), ctypes.byref(booster2)) + LIB.LGBM_BoosterCreateFromModelfile(c_str(str(model_path)), ctypes.byref(num_total_model), ctypes.byref(booster2)) data = np.loadtxt(str(binary_example_dir / "binary.test"), dtype=np.float64) mat = data[:, 1:] preb = np.empty(mat.shape[0], dtype=np.float64) diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 9ff56206ca70..a1eab3a15aad 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -1112,15 +1112,15 @@ def _early_stop_after_seventh_iteration(env): assert bst.current_iteration() == 7 -def test_continue_train(): +def test_continue_train(tmp_path): X, y = make_synthetic_regression() X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42) params = {"objective": "regression", "metric": "l1", "verbose": -1} lgb_train = lgb.Dataset(X_train, y_train, free_raw_data=False) lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train, free_raw_data=False) init_gbm = lgb.train(params, lgb_train, num_boost_round=20) - model_name = "model.txt" - init_gbm.save_model(model_name) + model_path = tmp_path / "model.txt" + init_gbm.save_model(model_path) evals_result = {} gbm = lgb.train( params, @@ -1130,7 +1130,7 @@ def test_continue_train(): # test custom eval metrics feval=(lambda p, d: ("custom_mae", mean_absolute_error(p, d.get_label()), False)), callbacks=[lgb.record_evaluation(evals_result)], - init_model="model.txt", + init_model=model_path, ) ret = mean_absolute_error(y_test, gbm.predict(X_test)) assert ret < 13.6 @@ -1713,7 +1713,7 @@ def test_all_expected_params_are_written_out_to_model_text(tmp_path): # why fixed seed? 
# sometimes there is no difference how cols are treated (cat or not cat) -def test_pandas_categorical(rng_fixed_seed): +def test_pandas_categorical(rng_fixed_seed, tmp_path): pd = pytest.importorskip("pandas") X = pd.DataFrame( { @@ -1756,8 +1756,9 @@ def test_pandas_categorical(rng_fixed_seed): gbm3 = lgb.train(params, lgb_train, num_boost_round=10, categorical_feature=["A", "B", "C", "D"]) pred3 = gbm3.predict(X_test) assert lgb_train.categorical_feature == ["A", "B", "C", "D"] - gbm3.save_model("categorical.model") - gbm4 = lgb.Booster(model_file="categorical.model") + categorical_model_path = tmp_path / "categorical.model" + gbm3.save_model(categorical_model_path) + gbm4 = lgb.Booster(model_file=categorical_model_path) pred4 = gbm4.predict(X_test) model_str = gbm4.model_to_string() gbm4.model_from_string(model_str) diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index 6f0f7cb2ff3a..0101b33ce93f 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -558,7 +558,7 @@ def test_feature_importances_type(): # why fixed seed? # sometimes there is no difference how cols are treated (cat or not cat) -def test_pandas_categorical(rng_fixed_seed): +def test_pandas_categorical(rng_fixed_seed, tmp_path): pd = pytest.importorskip("pandas") X = pd.DataFrame( { @@ -593,8 +593,9 @@ def test_pandas_categorical(rng_fixed_seed): pred2 = gbm2.predict(X_test, raw_score=True) gbm3 = lgb.sklearn.LGBMClassifier(n_estimators=10).fit(X, y, categorical_feature=["A", "B", "C", "D"]) pred3 = gbm3.predict(X_test, raw_score=True) - gbm3.booster_.save_model("categorical.model") - gbm4 = lgb.Booster(model_file="categorical.model") + categorical_model_path = tmp_path / "categorical.model" + gbm3.booster_.save_model(categorical_model_path) + gbm4 = lgb.Booster(model_file=categorical_model_path) pred4 = gbm4.predict(X_test) gbm5 = lgb.sklearn.LGBMClassifier(n_estimators=10).fit(X, y, categorical_feature=["A", "B", "C", "D", "E"]) pred5 = gbm5.predict(X_test, raw_score=True) From 954e6a949b826a65b57ddfeabfe7357f0350f61d Mon Sep 17 00:00:00 2001 From: Atanas Dimitrov Date: Thu, 8 Aug 2024 16:11:02 +0300 Subject: [PATCH 2/9] Init --- include/LightGBM/boosting.h | 5 +++++ include/LightGBM/c_api.h | 6 ++++++ python-package/lightgbm/basic.py | 35 ++++++++++++++++++++++++++++++++ src/boosting/gbdt.cpp | 17 ++++++++++++++++ src/boosting/gbdt.h | 2 ++ src/c_api.cpp | 15 ++++++++++++++ 6 files changed, 80 insertions(+) diff --git a/include/LightGBM/boosting.h b/include/LightGBM/boosting.h index 2e620f8680f6..f59f1c5b4f23 100644 --- a/include/LightGBM/boosting.h +++ b/include/LightGBM/boosting.h @@ -76,6 +76,11 @@ class LIGHTGBM_EXPORT Boosting { */ virtual void RefitTree(const int* tree_leaf_prediction, const size_t nrow, const size_t ncol) = 0; + /*! + * \brief Change the leaf values of a tree and update the scores + */ + virtual void RefitTreeManual(int tree_idx, const double *vals) = 0; + /*! * \brief Training logic * \param gradients nullptr for using default objective, otherwise use self-defined boosting diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index b43f096c31ee..fa8af3a18045 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -778,6 +778,12 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterRefit(BoosterHandle handle, int32_t nrow, int32_t ncol); +/*! + */ +LIGHTGBM_C_EXPORT int LGBM_BoosterRefitTreeManual(BoosterHandle handle, + int32_t tree_idx, + const double *values); + /*! 
 * \brief Update the model by specifying gradient and Hessian directly
 * (this can be used to support customized loss functions).
diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index af4d757f480b..8e1ea81647c5 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -4912,6 +4912,41 @@ def refit(
         new_booster._network = self._network
         return new_booster
 
+    def refit_tree_manual(
+        self,
+        tree_id: int,
+        values: np.ndarray
+    ) -> None:
+        """Set all the outputs of a tree and recalculate the dataset scores.
+
+        .. versionadded:: 4.6.0
+
+        Parameters
+        ----------
+        tree_id : int
+            The index of the tree.
+        values : numpy 1-D array
+            Values to set as the outputs of the tree.
+            The number of elements should be equal to the number of leaves in the tree.
+
+        Returns
+        -------
+        self : Booster
+            Booster with the leaf outputs set.
+        """
+        values = _list_to_1d_numpy(values, dtype=np.float64, name="leaf_values")
+
+        _safe_call(
+            _LIB.LGBM_BoosterRefitTreeManual(
+                self._handle,
+                ctypes.c_int(tree_id),
+                values.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
+            )
+        )
+        self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)]
+        return self
+
+
     def get_leaf_output(self, tree_id: int, leaf_id: int) -> float:
         """Get the output of a leaf.
 
diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp
index 937b44fcc8aa..787d89e03e34 100644
--- a/src/boosting/gbdt.cpp
+++ b/src/boosting/gbdt.cpp
@@ -289,6 +289,23 @@ void GBDT::RefitTree(const int* tree_leaf_prediction, const size_t nrow, const s
   }
 }
 
+void GBDT::RefitTreeManual(int tree_idx, const double *vals) {
+  CHECK(tree_idx >= 0 && static_cast<size_t>(tree_idx) < models_.size());
+  // reset score
+  for (int leaf_id = 0; leaf_id < models_[tree_idx]->num_leaves(); ++leaf_id) {
+    models_[tree_idx]->SetLeafOutput(leaf_id, vals[leaf_id] - models_[tree_idx]->LeafOutput(leaf_id));
+  }
+  // add the delta
+  train_score_updater_->AddScore(models_[tree_idx].get(), tree_idx % num_tree_per_iteration_);
+  for (auto& score_updater : valid_score_updater_) {
+    score_updater->AddScore(models_[tree_idx].get(), tree_idx % num_tree_per_iteration_);
+  }
+  // update the model
+  for (int leaf_id = 0; leaf_id < models_[tree_idx]->num_leaves(); ++leaf_id) {
+    models_[tree_idx]->SetLeafOutput(leaf_id, vals[leaf_id]);
+  }
+}
+
 /* If the custom "average" is implemented it will be used in place of the label average (if enabled)
 *
 * An improvement to this is to have options to explicitly choose
diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h
index fa53a664b45b..06272662012e 100644
--- a/src/boosting/gbdt.h
+++ b/src/boosting/gbdt.h
@@ -145,6 +145,8 @@ class GBDT : public GBDTBase {
 
   void RefitTree(const int* tree_leaf_prediction, const size_t nrow, const size_t ncol) override;
 
+  void RefitTreeManual(int tree_idx, const double *vals) override;
+
   /*!
 * \brief Training logic
 * \param gradients nullptr for using default objective, otherwise use self-defined boosting
diff --git a/src/c_api.cpp b/src/c_api.cpp
index 98748bc9ff2f..7b1da686f207 100644
--- a/src/c_api.cpp
+++ b/src/c_api.cpp
@@ -412,6 +412,11 @@ class Booster {
     boosting_->RefitTree(leaf_preds, nrow, ncol);
   }
 
+  void RefitTreeManual(int tree_idx, const double *vals) {
+    UNIQUE_LOCK(mutex_)
+    boosting_->RefitTreeManual(tree_idx, vals);
+  }
+
   bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
     UNIQUE_LOCK(mutex_)
     return boosting_->TrainOneIter(gradients, hessians);
@@ -2058,6 +2063,16 @@ int LGBM_BoosterRefit(BoosterHandle handle, const int32_t* leaf_preds, int32_t n
   API_END();
 }
 
+int LGBM_BoosterRefitTreeManual(BoosterHandle handle,
+                                int tree_idx,
+                                const double *val) {
+  API_BEGIN();
+  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
+  ref_booster->RefitTreeManual(tree_idx, val);
+  API_END();
+}
+
+
 int LGBM_BoosterUpdateOneIter(BoosterHandle handle, int* is_finished) {
   API_BEGIN();
   Booster* ref_booster = reinterpret_cast<Booster*>(handle);

From a07f1a237a0bfb3acf7aa26d3f8c1f8adf20f091 Mon Sep 17 00:00:00 2001
From: Atanas Dimitrov
Date: Thu, 8 Aug 2024 16:23:05 +0300
Subject: [PATCH 3/9] Prevent faulty data

---
 python-package/lightgbm/basic.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 8e1ea81647c5..eae45435ff38 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -4936,6 +4936,9 @@ def refit_tree_manual(
         """
         values = _list_to_1d_numpy(values, dtype=np.float64, name="leaf_values")
 
+        if len(values) != self.num_leaves(tree_id):
+            raise ValueError("Length of values should be equal to the number of leaves in the tree")
+
         _safe_call(
             _LIB.LGBM_BoosterRefitTreeManual(

From 75a4ec5b9878206dc2c671a8466aadd3d6f5bac0 Mon Sep 17 00:00:00 2001
From: Atanas Dimitrov
Date: Fri, 9 Aug 2024 14:59:25 +0300
Subject: [PATCH 4/9] Add validation checks

---
 include/LightGBM/boosting.h      | 2 +-
 include/LightGBM/c_api.h         | 3 ++-
 python-package/lightgbm/basic.py | 6 ++----
 src/boosting/gbdt.cpp            | 4 ++--
 src/boosting/gbdt.h              | 2 +-
 src/c_api.cpp                    | 9 +++++----
 6 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/include/LightGBM/boosting.h b/include/LightGBM/boosting.h
index f59f1c5b4f23..52f27f44c538 100644
--- a/include/LightGBM/boosting.h
+++ b/include/LightGBM/boosting.h
@@ -79,7 +79,7 @@ class LIGHTGBM_EXPORT Boosting {
   /*!
    * \brief Change the leaf values of a tree and update the scores
    */
-  virtual void RefitTreeManual(int tree_idx, const double *vals) = 0;
+  virtual void RefitTreeManual(int tree_idx, const double *vals, const int vals_size) = 0;
 
   /*!
    * \brief Training logic
diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h
index fa8af3a18045..821d82197775 100644
--- a/include/LightGBM/c_api.h
+++ b/include/LightGBM/c_api.h
@@ -782,7 +782,8 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterRefit(BoosterHandle handle,
  */
 LIGHTGBM_C_EXPORT int LGBM_BoosterRefitTreeManual(BoosterHandle handle,
                                                   int32_t tree_idx,
-                                                  const double *values);
+                                                  const double *vals,
+                                                  const int vals_size);
 
 /*!
 * \brief Update the model by specifying gradient and Hessian directly
diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index eae45435ff38..90a264ae4a53 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -4936,14 +4936,12 @@ def refit_tree_manual(
         """
         values = _list_to_1d_numpy(values, dtype=np.float64, name="leaf_values")
 
-        if len(values) != self.num_leaves(tree_id):
-            raise ValueError("Length of values should be equal to the number of leaves in the tree")
-
         _safe_call(
             _LIB.LGBM_BoosterRefitTreeManual(
                 self._handle,
                 ctypes.c_int(tree_id),
-                values.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
+                values.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
+                ctypes.c_int(len(values)),
             )
         )
         self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)]
diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp
index 787d89e03e34..6d7234a9e167 100644
--- a/src/boosting/gbdt.cpp
+++ b/src/boosting/gbdt.cpp
@@ -289,8 +289,8 @@ void GBDT::RefitTree(const int* tree_leaf_prediction, const size_t nrow, const s
   }
 }
 
-void GBDT::RefitTreeManual(int tree_idx, const double *vals) {
-  CHECK(tree_idx >= 0 && static_cast<size_t>(tree_idx) < models_.size());
+void GBDT::RefitTreeManual(int tree_idx, const double *vals, const int vals_size) {
+  CHECK(tree_idx >= 0 && static_cast<size_t>(tree_idx) < models_.size() && vals_size == models_[tree_idx]->num_leaves());
   // reset score
   for (int leaf_id = 0; leaf_id < models_[tree_idx]->num_leaves(); ++leaf_id) {
     models_[tree_idx]->SetLeafOutput(leaf_id, vals[leaf_id] - models_[tree_idx]->LeafOutput(leaf_id));
diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h
index 06272662012e..0bc83bf895e6 100644
--- a/src/boosting/gbdt.h
+++ b/src/boosting/gbdt.h
@@ -145,7 +145,7 @@ class GBDT : public GBDTBase {
 
   void RefitTree(const int* tree_leaf_prediction, const size_t nrow, const size_t ncol) override;
 
-  void RefitTreeManual(int tree_idx, const double *vals) override;
+  void RefitTreeManual(int tree_idx, const double *vals, const int vals_size) override;
 
   /*!
 * \brief Training logic
diff --git a/src/c_api.cpp b/src/c_api.cpp
index 7b1da686f207..19bccf057eb8 100644
--- a/src/c_api.cpp
+++ b/src/c_api.cpp
@@ -412,9 +412,9 @@ class Booster {
     boosting_->RefitTree(leaf_preds, nrow, ncol);
   }
 
-  void RefitTreeManual(int tree_idx, const double *vals) {
+  void RefitTreeManual(int tree_idx, const double *vals, const int vals_size) {
     UNIQUE_LOCK(mutex_)
-    boosting_->RefitTreeManual(tree_idx, vals);
+    boosting_->RefitTreeManual(tree_idx, vals, vals_size);
   }
 
   bool TrainOneIter(const score_t* gradients, const score_t* hessians) {
@@ -2065,10 +2065,11 @@ int LGBM_BoosterRefit(BoosterHandle handle, const int32_t* leaf_preds, int32_t n
 
 int LGBM_BoosterRefitTreeManual(BoosterHandle handle,
                                 int tree_idx,
-                                const double *val) {
+                                const double *vals,
+                                const int vals_size) {
   API_BEGIN();
   Booster* ref_booster = reinterpret_cast<Booster*>(handle);
-  ref_booster->RefitTreeManual(tree_idx, val);
+  ref_booster->RefitTreeManual(tree_idx, vals, vals_size);
   API_END();
 }

From b26a21cbc43aa2b90e68c2a9d62e1290230b7c6f Mon Sep 17 00:00:00 2001
From: Atanas Dimitrov
Date: Fri, 16 Aug 2024 17:21:09 +0300
Subject: [PATCH 5/9] Add tests

---
 python-package/lightgbm/basic.py         |  7 +----
 tests/python_package_test/test_engine.py | 39 ++++++++++++++++++++++++
 2 files changed, 40 insertions(+), 6 deletions(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 90a264ae4a53..929411206721 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -4912,11 +4912,7 @@ def refit(
         new_booster._network = self._network
         return new_booster
 
-    def refit_tree_manual(
-        self,
-        tree_id: int,
-        values: np.ndarray
-    ) -> None:
+    def refit_tree_manual(self, tree_id: int, values: np.ndarray) -> "Booster":
         """Set all the outputs of a tree and recalculate the dataset scores.
@@ -4947,7 +4943,6 @@ def refit_tree_manual(
         self.__is_predicted_cur_iter = [False for _ in range(self.__num_dataset)]
         return self
 
-
     def get_leaf_output(self, tree_id: int, leaf_id: int) -> float:
         """Get the output of a leaf.
diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index a1eab3a15aad..95dd04a49219 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -2332,6 +2332,45 @@ def test_refit_dataset_params(rng):
     np.testing.assert_allclose(stored_weights, refit_weight)
 
 
+def test_refit_tree_manual():
+    def retrieve_leaves_from_tree(tree):
+        if "leaf_index" in tree:
+            return {tree["leaf_index"]: tree["leaf_value"]}
+
+        left_child = retrieve_leaves_from_tree(tree["left_child"])
+        right_child = retrieve_leaves_from_tree(tree["right_child"])
+
+        return left_child | right_child
+
+    def retrieve_leaves_from_booster(booster, iteration):
+        tree = booster.dump_model(0, iteration)["tree_info"][0]["tree_structure"]
+        return retrieve_leaves_from_tree(tree)
+
+    def debias_callback(env):
+        booster = env.model
+        curr_values = retrieve_leaves_from_booster(booster, env.iteration)
+        eval_pred = booster.predict(df)
+        delta = np.log(np.mean(y) / np.mean(eval_pred))
+        refitted_values = [curr_values[ix] + delta for ix in range(len(curr_values))]
+        booster.refit_tree_manual(env.iteration, refitted_values)
+
+    X, y = make_synthetic_regression()
+    y = np.abs(y)
+    df = pd_DataFrame(X, columns=["x1", "x2", "x3", "x4"])
+    ds = lgb.Dataset(df, y)
+
+    params = {
+        "verbose": -1,
+        "n_estimators": 5,
+        "num_leaves": 5,
+        "objective": "gamma",
+    }
+    bst = lgb.train(params, ds, callbacks=[debias_callback])
+
+    # Check if debiasing worked
+    np.testing.assert_allclose(bst.predict(df).mean(), y.mean())
+
+
 @pytest.mark.parametrize("boosting_type", ["rf", "dart"])
 def test_mape_for_specific_boosting_types(boosting_type):
     X, y = make_synthetic_regression()

From 9c5540258be51dd3da4c6e12755019b226497a88 Mon Sep 17 00:00:00 2001
From: Atanas Dimitrov
Date: Fri, 16 Aug 2024 18:00:44 +0300
Subject: [PATCH 6/9] Add comment

---
 include/LightGBM/c_api.h | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h
index 821d82197775..5bdae492d781 100644
--- a/include/LightGBM/c_api.h
+++ b/include/LightGBM/c_api.h
@@ -779,6 +779,14 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterRefit(BoosterHandle handle,
                                         int32_t ncol);
 
 /*!
+ * \brief Refit a single tree by specifying a new set of leaf scores, one for each leaf of the tree
+ * \note
+ * The length of the array referenced by ``vals`` must be equal to the number of leaves.
+ * \param handle Handle of the Booster model
+ * \param tree_idx Index of the tree to refit
+ * \param vals The new set of leaf scores, one for each leaf of the tree
+ * \param vals_size Number of leaf scores provided; must equal the number of leaves in the tree
+ * \return 0 when succeed, -1 when failure happens
  */
 LIGHTGBM_C_EXPORT int LGBM_BoosterRefitTreeManual(BoosterHandle handle,
                                                   int32_t tree_idx,

From a387aa50291c470e9dc293b027f1f0bc265643b6 Mon Sep 17 00:00:00 2001
From: Atanas Dimitrov
Date: Fri, 16 Aug 2024 18:20:08 +0300
Subject: [PATCH 7/9] Fix for 3.8

---
 tests/python_package_test/test_engine.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index d7e17de156ff..ab74346f9312 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -2341,7 +2341,7 @@ def retrieve_leaves_from_tree(tree):
         left_child = retrieve_leaves_from_tree(tree["left_child"])
         right_child = retrieve_leaves_from_tree(tree["right_child"])
 
-        return left_child | right_child
+        return {**left_child, **right_child}
 
     def retrieve_leaves_from_booster(booster, iteration):
         tree = booster.dump_model(0, iteration)["tree_info"][0]["tree_structure"]
         return retrieve_leaves_from_tree(tree)

From 1d73e4e6a6ba92098c35de90e9a95d01b6482476 Mon Sep 17 00:00:00 2001
From: Atanas Dimitrov
Date: Mon, 2 Sep 2024 17:48:10 +0300
Subject: [PATCH 8/9] Comments after code review

---
 src/boosting/gbdt.cpp                    | 6 ++++--
 tests/python_package_test/test_engine.py | 6 +++++-
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp
index 6d7234a9e167..a3e794a59301 100644
--- a/src/boosting/gbdt.cpp
+++ b/src/boosting/gbdt.cpp
@@ -290,8 +290,10 @@ void GBDT::RefitTree(const int* tree_leaf_prediction, const size_t nrow, const s
 }
 
 void GBDT::RefitTreeManual(int tree_idx, const double *vals, const int vals_size) {
-  CHECK(tree_idx >= 0 && static_cast<size_t>(tree_idx) < models_.size() && vals_size == models_[tree_idx]->num_leaves());
-  // reset score
+  CHECK(tree_idx >= 0);
+  CHECK(static_cast<size_t>(tree_idx) < models_.size());
+  CHECK(vals_size == models_[tree_idx]->num_leaves());
+  // reset score by adding the difference
   for (int leaf_id = 0; leaf_id < models_[tree_idx]->num_leaves(); ++leaf_id) {
     models_[tree_idx]->SetLeafOutput(leaf_id, vals[leaf_id] - models_[tree_idx]->LeafOutput(leaf_id));
   }
diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index ab74346f9312..7432355df795 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -2366,9 +2366,13 @@ def debias_callback(env):
         "num_leaves": 5,
         "objective": "gamma",
     }
-    bst = lgb.train(params, ds, callbacks=[debias_callback])
+
+    # Check that the model is biased when no callback is provided
+    bst = lgb.train(params, ds)
+    np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, bst.predict(df).mean(), y.mean())
 
     # Check if debiasing worked
+    bst = lgb.train(params, ds, callbacks=[debias_callback])
     np.testing.assert_allclose(bst.predict(df).mean(), y.mean())

From bdc173baae9c248a9716ecf87b83107f4371535f Mon Sep 17 00:00:00 2001
From: Atanas Dimitrov
Date: Tue, 3 Sep 2024 11:58:21 +0300
Subject: [PATCH 9/9] Appease linter

---
 src/boosting/gbdt.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp
index a3e794a59301..00b7a80dd494 100644
--- a/src/boosting/gbdt.cpp
+++ b/src/boosting/gbdt.cpp
@@ -290,9 +290,9 @@ 
void GBDT::RefitTree(const int* tree_leaf_prediction, const size_t nrow, const s
 }
 
 void GBDT::RefitTreeManual(int tree_idx, const double *vals, const int vals_size) {
-  CHECK(tree_idx >= 0);
-  CHECK(static_cast<size_t>(tree_idx) < models_.size());
-  CHECK(vals_size == models_[tree_idx]->num_leaves());
+  CHECK_GE(tree_idx, 0);
+  CHECK_LT(static_cast<size_t>(tree_idx), models_.size());
+  CHECK_EQ(vals_size, models_[tree_idx]->num_leaves());
   // reset score by adding the difference
   for (int leaf_id = 0; leaf_id < models_[tree_idx]->num_leaves(); ++leaf_id) {
     models_[tree_idx]->SetLeafOutput(leaf_id, vals[leaf_id] - models_[tree_idx]->LeafOutput(leaf_id));
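
Usage sketch (reviewer note, not part of the patches above): the example below shows how the new Python entry point added in PATCH 2-5 is meant to be called. It assumes a LightGBM build from this branch; the dataset, parameters, tree index, and scaling factor are illustrative assumptions, not taken from the series.

    import numpy as np
    import lightgbm as lgb
    from sklearn.datasets import make_regression

    # Train a small model (hypothetical data and parameters).
    X, y = make_regression(n_samples=1000, n_features=4, random_state=42)
    booster = lgb.train(
        {"objective": "regression", "num_leaves": 7, "verbose": -1},
        lgb.Dataset(X, y),
        num_boost_round=5,
    )

    # Read the current outputs of the first tree, one value per leaf.
    tree_id = 0
    n_leaves = booster.dump_model()["tree_info"][tree_id]["num_leaves"]
    old_values = np.array([booster.get_leaf_output(tree_id, i) for i in range(n_leaves)])

    # Overwrite the tree's leaf outputs (here: shrink them by 10%); the booster's
    # cached train/valid scores are recalculated through LGBM_BoosterRefitTreeManual.
    booster.refit_tree_manual(tree_id, 0.9 * old_values)

Passing an array whose length differs from the tree's leaf count fails the C-side length check introduced in PATCH 4 (a CHECK_EQ after PATCH 9) and surfaces as a LightGBMError in Python, which is why the Python-side ValueError from PATCH 3 could be dropped.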