From 661d4fc4dfdb47ddf233a2a937a95397d9288508 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Tue, 20 Aug 2019 11:34:25 +0800 Subject: [PATCH 1/3] fix the bug in bin with small values --- src/io/bin.cpp | 25 ++++++------------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/src/io/bin.cpp b/src/io/bin.cpp index 9b105e282923..12cdada6387a 100644 --- a/src/io/bin.cpp +++ b/src/io/bin.cpp @@ -181,6 +181,9 @@ namespace LightGBM { int left_max_bin = static_cast<int>(static_cast<double>(left_cnt_data) / (total_sample_cnt - cnt_zero) * (max_bin - 1)); left_max_bin = std::max(1, left_max_bin); bin_upper_bound = GreedyFindBin(distinct_values, counts, left_cnt, left_max_bin, left_cnt_data, min_data_in_bin); + if (bin_upper_bound.size() > 0) { + bin_upper_bound.back() = -kZeroThreshold; + } } int right_start = -1; @@ -191,27 +194,11 @@ namespace LightGBM { } } - if (bin_upper_bound.size() == 0) { - if (max_bin > 2) { - // create zero bin - bin_upper_bound.push_back(-kZeroThreshold); - bin_upper_bound.push_back(kZeroThreshold); - } - else if (max_bin > 1) { - bin_upper_bound.push_back(kZeroThreshold); - } - } else { - bin_upper_bound.back() = -kZeroThreshold; - if (max_bin > 2) { - // create zero bin - bin_upper_bound.push_back(kZeroThreshold); - } - } - - int right_max_bin = max_bin - static_cast<int>(bin_upper_bound.size()); - if ((right_start >= 0) && (right_max_bin > 0)) { + int right_max_bin = max_bin - 1 - static_cast<int>(bin_upper_bound.size()); + if (right_start >= 0 && right_max_bin > 0) { auto right_bounds = GreedyFindBin(distinct_values + right_start, counts + right_start, num_distinct_values - right_start, right_max_bin, right_cnt_data, min_data_in_bin); + bin_upper_bound.push_back(kZeroThreshold); bin_upper_bound.insert(bin_upper_bound.end(), right_bounds.begin(), right_bounds.end()); } else { bin_upper_bound.push_back(std::numeric_limits<double>::infinity()); From d7725b5f3164f992585a890defd4f37149e4d2f0 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Tue, 20 Aug 2019 
11:39:58 +0800 Subject: [PATCH 2/3] Update bin.cpp --- src/io/bin.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/io/bin.cpp b/src/io/bin.cpp index 12cdada6387a..2e79a80266b6 100644 --- a/src/io/bin.cpp +++ b/src/io/bin.cpp @@ -203,7 +203,7 @@ namespace LightGBM { } else { bin_upper_bound.push_back(std::numeric_limits<double>::infinity()); } - CHECK(bin_upper_bound.size() <= max_bin); + CHECK(bin_upper_bound.size() <= static_cast<size_t>(max_bin)); return bin_upper_bound; } From e555345f6ed2a1a3f8f660bf49604fa7023e8e59 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Tue, 20 Aug 2019 12:27:40 +0800 Subject: [PATCH 3/3] Update test_engine.py --- tests/python_package_test/test_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 2039742dc9ff..9a34de869724 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -921,7 +921,7 @@ def test_max_bin_by_feature(self): } lgb_data = lgb.Dataset(X, label=y) est = lgb.train(params, lgb_data, num_boost_round=1) - self.assertEqual(len(np.unique(est.predict(X))), 99) + self.assertEqual(len(np.unique(est.predict(X))), 100) params['max_bin_by_feature'] = [2, 100] lgb_data = lgb.Dataset(X, label=y) est = lgb.train(params, lgb_data, num_boost_round=1)