diff --git a/src/io/bin.cpp b/src/io/bin.cpp index 9b105e282923..2e79a80266b6 100644 --- a/src/io/bin.cpp +++ b/src/io/bin.cpp @@ -181,6 +181,9 @@ namespace LightGBM { int left_max_bin = static_cast(static_cast(left_cnt_data) / (total_sample_cnt - cnt_zero) * (max_bin - 1)); left_max_bin = std::max(1, left_max_bin); bin_upper_bound = GreedyFindBin(distinct_values, counts, left_cnt, left_max_bin, left_cnt_data, min_data_in_bin); + if (bin_upper_bound.size() > 0) { + bin_upper_bound.back() = -kZeroThreshold; + } } int right_start = -1; @@ -191,32 +194,16 @@ namespace LightGBM { } } - if (bin_upper_bound.size() == 0) { - if (max_bin > 2) { - // create zero bin - bin_upper_bound.push_back(-kZeroThreshold); - bin_upper_bound.push_back(kZeroThreshold); - } - else if (max_bin > 1) { - bin_upper_bound.push_back(kZeroThreshold); - } - } else { - bin_upper_bound.back() = -kZeroThreshold; - if (max_bin > 2) { - // create zero bin - bin_upper_bound.push_back(kZeroThreshold); - } - } - - int right_max_bin = max_bin - static_cast(bin_upper_bound.size()); - if ((right_start >= 0) && (right_max_bin > 0)) { + int right_max_bin = max_bin - 1 - static_cast(bin_upper_bound.size()); + if (right_start >= 0 && right_max_bin > 0) { auto right_bounds = GreedyFindBin(distinct_values + right_start, counts + right_start, num_distinct_values - right_start, right_max_bin, right_cnt_data, min_data_in_bin); + bin_upper_bound.push_back(kZeroThreshold); bin_upper_bound.insert(bin_upper_bound.end(), right_bounds.begin(), right_bounds.end()); } else { bin_upper_bound.push_back(std::numeric_limits::infinity()); } - CHECK(bin_upper_bound.size() <= max_bin); + CHECK(bin_upper_bound.size() <= static_cast(max_bin)); return bin_upper_bound; } diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 2039742dc9ff..9a34de869724 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -921,7 +921,7 @@ def test_max_bin_by_feature(self): } lgb_data = lgb.Dataset(X, label=y) est = lgb.train(params, lgb_data, num_boost_round=1) - self.assertEqual(len(np.unique(est.predict(X))), 99) + self.assertEqual(len(np.unique(est.predict(X))), 100) params['max_bin_by_feature'] = [2, 100] lgb_data = lgb.Dataset(X, label=y) est = lgb.train(params, lgb_data, num_boost_round=1)