Skip to content

Commit

Permalink
fix the bug in bin with small values (#2342)
Browse files Browse the repository at this point in the history
* fix the bug in bin with small values

* Update bin.cpp

* Update test_engine.py
  • Loading branch information
guolinke authored Aug 20, 2019
1 parent 86c6a2d commit 20f94c5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 21 deletions.
27 changes: 7 additions & 20 deletions src/io/bin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ namespace LightGBM {
int left_max_bin = static_cast<int>(static_cast<double>(left_cnt_data) / (total_sample_cnt - cnt_zero) * (max_bin - 1));
left_max_bin = std::max(1, left_max_bin);
bin_upper_bound = GreedyFindBin(distinct_values, counts, left_cnt, left_max_bin, left_cnt_data, min_data_in_bin);
if (bin_upper_bound.size() > 0) {
bin_upper_bound.back() = -kZeroThreshold;
}
}

int right_start = -1;
Expand All @@ -191,32 +194,16 @@ namespace LightGBM {
}
}

if (bin_upper_bound.size() == 0) {
if (max_bin > 2) {
// create zero bin
bin_upper_bound.push_back(-kZeroThreshold);
bin_upper_bound.push_back(kZeroThreshold);
}
else if (max_bin > 1) {
bin_upper_bound.push_back(kZeroThreshold);
}
} else {
bin_upper_bound.back() = -kZeroThreshold;
if (max_bin > 2) {
// create zero bin
bin_upper_bound.push_back(kZeroThreshold);
}
}

int right_max_bin = max_bin - static_cast<int>(bin_upper_bound.size());
if ((right_start >= 0) && (right_max_bin > 0)) {
int right_max_bin = max_bin - 1 - static_cast<int>(bin_upper_bound.size());
if (right_start >= 0 && right_max_bin > 0) {
auto right_bounds = GreedyFindBin(distinct_values + right_start, counts + right_start,
num_distinct_values - right_start, right_max_bin, right_cnt_data, min_data_in_bin);
bin_upper_bound.push_back(kZeroThreshold);
bin_upper_bound.insert(bin_upper_bound.end(), right_bounds.begin(), right_bounds.end());
} else {
bin_upper_bound.push_back(std::numeric_limits<double>::infinity());
}
CHECK(bin_upper_bound.size() <= max_bin);
CHECK(bin_upper_bound.size() <= static_cast<size_t>(max_bin));
return bin_upper_bound;
}

Expand Down
2 changes: 1 addition & 1 deletion tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,7 +921,7 @@ def test_max_bin_by_feature(self):
}
lgb_data = lgb.Dataset(X, label=y)
est = lgb.train(params, lgb_data, num_boost_round=1)
self.assertEqual(len(np.unique(est.predict(X))), 99)
self.assertEqual(len(np.unique(est.predict(X))), 100)
params['max_bin_by_feature'] = [2, 100]
lgb_data = lgb.Dataset(X, label=y)
est = lgb.train(params, lgb_data, num_boost_round=1)
Expand Down

0 comments on commit 20f94c5

Please sign in to comment.