From 703d11aa12e9a5eb8a87acf55907c911ed414c57 Mon Sep 17 00:00:00 2001 From: Yu Shi Date: Thu, 21 Apr 2022 08:44:31 +0000 Subject: [PATCH 1/7] clear split info buffer in cegb_ before every iteration --- .../cost_effective_gradient_boosting.hpp | 13 +++++++++++++ src/treelearner/serial_tree_learner.cpp | 2 ++ 2 files changed, 15 insertions(+) diff --git a/src/treelearner/cost_effective_gradient_boosting.hpp b/src/treelearner/cost_effective_gradient_boosting.hpp index 4bc149148c79..3feebef8fe82 100644 --- a/src/treelearner/cost_effective_gradient_boosting.hpp +++ b/src/treelearner/cost_effective_gradient_boosting.hpp @@ -32,6 +32,7 @@ class CostEfficientGradientBoosting { return true; } } + void Init() { auto train_data = tree_learner_->train_data_; if (!init_) { @@ -63,6 +64,17 @@ class CostEfficientGradientBoosting { } init_ = true; } + + void BeforeTrain() { + // clear the splits in splits_per_leaf_ + const int num_total_splits = static_cast(splits_per_leaf_.size()); + const int num_threads = OMP_NUM_THREADS(); + #pragma omp parallel for schedule(static) num_threads(num_threads) + for (int i = 0; i < num_total_splits; ++i) { + splits_per_leaf_[i].Reset(); + } + } + double DetlaGain(int feature_index, int real_fidx, int leaf_index, int num_data_in_leaf, SplitInfo split_info) { auto config = tree_learner_->config_; @@ -82,6 +94,7 @@ class CostEfficientGradientBoosting { feature_index] = split_info; return delta; } + void UpdateLeafBestSplits(Tree* tree, int best_leaf, const SplitInfo* best_split_info, std::vector* best_split_per_leaf) { diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp index 304c712f0723..79dc017ae4d5 100644 --- a/src/treelearner/serial_tree_learner.cpp +++ b/src/treelearner/serial_tree_learner.cpp @@ -278,6 +278,8 @@ void SerialTreeLearner::BeforeTrain() { } larger_leaf_splits_->Init(); + + cegb_->BeforeTrain(); } bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int 
right_leaf) { From fad48d4d09aa6b772883417b6913f460f10af46d Mon Sep 17 00:00:00 2001 From: Yu Shi Date: Thu, 5 May 2022 02:47:03 +0000 Subject: [PATCH 2/7] check nullable of cegb_ in serial_tree_learner.cpp --- src/treelearner/serial_tree_learner.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp index 79dc017ae4d5..1aa9e57b04d7 100644 --- a/src/treelearner/serial_tree_learner.cpp +++ b/src/treelearner/serial_tree_learner.cpp @@ -279,7 +279,9 @@ void SerialTreeLearner::BeforeTrain() { larger_leaf_splits_->Init(); - cegb_->BeforeTrain(); + if (cegb_ != nullptr) { + cegb_->BeforeTrain(); + } } bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) { From 0bc4333218e79e793ebb55af48eb7e897f4a02aa Mon Sep 17 00:00:00 2001 From: Yu Shi Date: Mon, 9 May 2022 03:36:37 +0000 Subject: [PATCH 3/7] add a test case for checking the split buffer in CEGB --- tests/python_package_test/test_basic.py | 46 +++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 4d6c367d8150..f8f0dc7e3b40 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -430,6 +430,52 @@ def test_cegb_scaling_equalities(tmp_path): assert p1txt == p2txt +def test_cegb_split_buffer_clean(): + # modified from https://github.com/microsoft/LightGBM/issues/3679#issuecomment-938652811 + # and https://github.com/microsoft/LightGBM/pull/5087 + # test that the ``splits_per_leaf_`` of CEGB is cleaned before training a new tree + # which is done in the fix #5164 + # without the fix: + # Check failed: (best_split_info.left_count) > (0) + + R, C = 1000, 100 + seed = 29 + np.random.seed(seed) + data = np.random.randn(R, C) + for i in range(1, C): + data[i] += data[0] * np.random.randn() + + N = int(0.8 * len(data)) + train_data = data[:N] 
+ test_data = data[N:] + train_y = np.sum(train_data, axis=1) + test_y = np.sum(test_data, axis=1) + + train = lgb.Dataset(train_data, train_y, free_raw_data=True) + test = lgb.Dataset(test_data, test_y, free_raw_data=True, reference=train) + + # The test is run twice, on cpu and gpu + params = { + 'device': "cpu", + 'boosting_type': 'gbdt', + 'objective': 'regression', + 'max_bin': 255, + 'num_leaves': 31, + 'seed': 0, + 'learning_rate': 0.1, + 'min_data_in_leaf': 0, + 'verbose': 2, + 'min_split_gain': 1000.0, + 'cegb_penalty_feature_coupled': 5 * np.arange(C), + 'cegb_penalty_split': 0.0002, + 'cegb_tradeoff': 10.0, + 'num_threads': 16, + 'force_col_wise': True, + } + + lgb.train(params, train, num_boost_round=20, valid_sets=test) + + def test_consistent_state_for_dataset_fields(): def check_asserts(data): From 45de0a42303e5885487cdb29b74e75a42dea6bb3 Mon Sep 17 00:00:00 2001 From: Yu Shi Date: Tue, 10 May 2022 02:28:55 +0000 Subject: [PATCH 4/7] switch to Threading::For instead of raw OpenMP --- .../cost_effective_gradient_boosting.hpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/treelearner/cost_effective_gradient_boosting.hpp b/src/treelearner/cost_effective_gradient_boosting.hpp index e18fe70bd22a..4c29deb82de4 100644 --- a/src/treelearner/cost_effective_gradient_boosting.hpp +++ b/src/treelearner/cost_effective_gradient_boosting.hpp @@ -10,6 +10,7 @@ #include #include #include +#include #include @@ -67,12 +68,12 @@ class CostEfficientGradientBoosting { void BeforeTrain() { // clear the splits in splits_per_leaf_ - const int num_total_splits = static_cast(splits_per_leaf_.size()); - const int num_threads = OMP_NUM_THREADS(); - #pragma omp parallel for schedule(static) num_threads(num_threads) - for (int i = 0; i < num_total_splits; ++i) { - splits_per_leaf_[i].Reset(); - } + Threading::For(0, splits_per_leaf_.size(), 1024, + [this] (int /*thread_index*/, size_t start, size_t end) { + for (size_t i = start; i < end; ++i) { 
+ splits_per_leaf_[i].Reset(); + } + }); } double DeltaGain(int feature_index, int real_fidx, int leaf_index, From a04d00470072ac1ddaea9853dd912607b309596e Mon Sep 17 00:00:00 2001 From: Yu Shi Date: Wed, 11 May 2022 02:33:21 +0000 Subject: [PATCH 5/7] apply review suggestions --- tests/python_package_test/test_basic.py | 46 ----------------------- tests/python_package_test/test_engine.py | 47 ++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 46 deletions(-) diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index f8f0dc7e3b40..4d6c367d8150 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -430,52 +430,6 @@ def test_cegb_scaling_equalities(tmp_path): assert p1txt == p2txt -def test_cegb_split_buffer_clean(): - # modified from https://github.com/microsoft/LightGBM/issues/3679#issuecomment-938652811 - # and https://github.com/microsoft/LightGBM/pull/5087 - # test that the ``splits_per_leaf_`` of CEGB is cleaned before training a new tree - # which is done in the fix #5164 - # without the fix: - # Check failed: (best_split_info.left_count) > (0) - - R, C = 1000, 100 - seed = 29 - np.random.seed(seed) - data = np.random.randn(R, C) - for i in range(1, C): - data[i] += data[0] * np.random.randn() - - N = int(0.8 * len(data)) - train_data = data[:N] - test_data = data[N:] - train_y = np.sum(train_data, axis=1) - test_y = np.sum(test_data, axis=1) - - train = lgb.Dataset(train_data, train_y, free_raw_data=True) - test = lgb.Dataset(test_data, test_y, free_raw_data=True, reference=train) - - # The test is run twice, on cpu and gpu - params = { - 'device': "cpu", - 'boosting_type': 'gbdt', - 'objective': 'regression', - 'max_bin': 255, - 'num_leaves': 31, - 'seed': 0, - 'learning_rate': 0.1, - 'min_data_in_leaf': 0, - 'verbose': 2, - 'min_split_gain': 1000.0, - 'cegb_penalty_feature_coupled': 5 * np.arange(C), - 'cegb_penalty_split': 0.0002, - 'cegb_tradeoff': 10.0, - 
'num_threads': 16, - 'force_col_wise': True, - } - - lgb.train(params, train, num_boost_round=20, valid_sets=test) - - def test_consistent_state_for_dataset_fields(): def check_asserts(data): diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index df840a768539..737b54d1e29e 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -3566,3 +3566,50 @@ def test_boost_from_average_with_single_leaf_trees(): preds = model.predict(X) mean_preds = np.mean(preds) assert y.min() <= mean_preds <= y.max() + + +def test_cegb_split_buffer_clean(): + # modified from https://github.com/microsoft/LightGBM/issues/3679#issuecomment-938652811 + # and https://github.com/microsoft/LightGBM/pull/5087 + # test that the ``splits_per_leaf_`` of CEGB is cleaned before training a new tree + # which is done in the fix #5164 + # without the fix: + # Check failed: (best_split_info.left_count) > (0) + + R, C = 1000, 100 + seed = 29 + np.random.seed(seed) + data = np.random.randn(R, C) + for i in range(1, C): + data[i] += data[0] * np.random.randn() + + N = int(0.8 * len(data)) + train_data = data[:N] + test_data = data[N:] + train_y = np.sum(train_data, axis=1) + test_y = np.sum(test_data, axis=1) + + train = lgb.Dataset(train_data, train_y, free_raw_data=True) + + params = { + 'device': "cpu", + 'boosting_type': 'gbdt', + 'objective': 'regression', + 'max_bin': 255, + 'num_leaves': 31, + 'seed': 0, + 'learning_rate': 0.1, + 'min_data_in_leaf': 0, + 'verbose': -1, + 'min_split_gain': 1000.0, + 'cegb_penalty_feature_coupled': 5 * np.arange(C), + 'cegb_penalty_split': 0.0002, + 'cegb_tradeoff': 10.0, + 'num_threads': 16, + 'force_col_wise': True, + } + + model = lgb.train(params, train, num_boost_round=10) + predicts = model.predict(test_data) + rmse = np.sqrt(np.mean((predicts - test_y) ** 2)) + assert rmse < 10.0 From f8e81700e6e4425d9cd9145b1173217855fd8d76 Mon Sep 17 00:00:00 2001 From: Yu Shi Date: 
Thu, 12 May 2022 02:30:43 +0000 Subject: [PATCH 6/7] apply review comments --- tests/python_package_test/test_engine.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 737b54d1e29e..722d237b8780 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -3605,11 +3605,10 @@ def test_cegb_split_buffer_clean(): 'cegb_penalty_feature_coupled': 5 * np.arange(C), 'cegb_penalty_split': 0.0002, 'cegb_tradeoff': 10.0, - 'num_threads': 16, 'force_col_wise': True, } model = lgb.train(params, train, num_boost_round=10) predicts = model.predict(test_data) - rmse = np.sqrt(np.mean((predicts - test_y) ** 2)) + rmse = np.sqrt(mean_squared_error(test_y, predicts)) assert rmse < 10.0 From 5235131e90539115206f6e168863ddbec7bdb007 Mon Sep 17 00:00:00 2001 From: Yu Shi Date: Mon, 30 May 2022 02:50:49 +0000 Subject: [PATCH 7/7] remove device cpu --- tests/python_package_test/test_engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 722d237b8780..cb3357271542 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -3592,7 +3592,6 @@ def test_cegb_split_buffer_clean(): train = lgb.Dataset(train_data, train_y, free_raw_data=True) params = { - 'device': "cpu", 'boosting_type': 'gbdt', 'objective': 'regression', 'max_bin': 255,