diff --git a/docs/Parameters.rst b/docs/Parameters.rst index 93c241bce215..584237464fd1 100644 --- a/docs/Parameters.rst +++ b/docs/Parameters.rst @@ -404,6 +404,14 @@ Learning Control Parameters - see `this file `__ as an example +- ``forcedbins_filename`` :raw-html:`🔗︎`, default = ``""``, type = string + + - path to a ``.json`` file that specifies bin upper bounds for some or all features + + - ``.json`` file should contain an array of objects, each containing the name ``feature`` (integer feature number) and ``bin_upper_bounds`` (array of thresolds for binning) + + - see `this file `__ as an example + - ``refit_decay_rate`` :raw-html:`🔗︎`, default = ``0.9``, type = double, constraints: ``0.0 <= refit_decay_rate <= 1.0`` - decay rate of ``refit`` task, will use ``leaf_output = refit_decay_rate * old_leaf_output + (1.0 - refit_decay_rate) * new_leaf_output`` to refit trees diff --git a/include/LightGBM/bin.h b/include/LightGBM/bin.h index 46baee58fc46..1c5f62cd1907 100644 --- a/include/LightGBM/bin.h +++ b/include/LightGBM/bin.h @@ -146,9 +146,10 @@ class BinMapper { * \param bin_type Type of this bin * \param use_missing True to enable missing value handle * \param zero_as_missing True to use zero as missing value + * \param forced_upper_bounds Vector of split points that must be used (if this has size less than max_bin, remaining splits are found by the algorithm) */ void FindBin(double* values, int num_values, size_t total_sample_cnt, int max_bin, int min_data_in_bin, int min_split_data, BinType bin_type, - bool use_missing, bool zero_as_missing); + bool use_missing, bool zero_as_missing, std::vector forced_upper_bounds); /*! * \brief Use specific number of bin to calculate the size of this class diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h index 08b2a7352c0a..1c0c14f69508 100644 --- a/include/LightGBM/config.h +++ b/include/LightGBM/config.h @@ -402,6 +402,11 @@ struct Config { // desc = see `this file `__ as an example std::string forcedsplits_filename = ""; + // desc = path to a ``.json`` file that specifies bin upper bounds for some or all features + // desc = ``.json`` file should contain an array of objects, each containing the name ``feature`` (integer feature number) and ``bin_upper_bounds`` (array of thresolds for binning) + // desc = see `this file `__ as an example + std::string forcedbins_filename = ""; + // check = >=0.0 // check = <=1.0 // desc = decay rate of ``refit`` task, will use ``leaf_output = refit_decay_rate * old_leaf_output + (1.0 - refit_decay_rate) * new_leaf_output`` to refit trees diff --git a/include/LightGBM/dataset.h b/include/LightGBM/dataset.h index e688522fbb1a..900487eafbf4 100644 --- a/include/LightGBM/dataset.h +++ b/include/LightGBM/dataset.h @@ -596,6 +596,8 @@ class Dataset { void addFeaturesFrom(Dataset* other); + static std::vector> GetForcedBins(std::string forced_bins_path, int num_total_features); + private: std::string data_filename_; /*! \brief Store used features */ @@ -630,6 +632,7 @@ class Dataset { bool is_finish_load_; int max_bin_; std::vector max_bin_by_feature_; + std::vector> forced_bin_bounds_; int bin_construct_sample_cnt_; int min_data_in_bin_; bool use_missing_; diff --git a/src/io/bin.cpp b/src/io/bin.cpp index 617bdf5bac73..62713d1bddd3 100644 --- a/src/io/bin.cpp +++ b/src/io/bin.cpp @@ -150,8 +150,10 @@ namespace LightGBM { } std::vector FindBinWithZeroAsOneBin(const double* distinct_values, const int* counts, - int num_distinct_values, int max_bin, size_t total_sample_cnt, int min_data_in_bin) { + int num_distinct_values, int max_bin, size_t total_sample_cnt, int min_data_in_bin, std::vector forced_upper_bounds) { std::vector bin_upper_bound; + + // get list of distinct values int left_cnt_data = 0; int cnt_zero = 0; int right_cnt_data = 0; @@ -165,6 +167,7 @@ namespace LightGBM { } } + // get number of positive and negative distinct values int left_cnt = -1; for (int i = 0; i < num_distinct_values; ++i) { if (distinct_values[i] > -kZeroThreshold) { @@ -172,18 +175,9 @@ namespace LightGBM { break; } } - if (left_cnt < 0) { left_cnt = num_distinct_values; } - - if (left_cnt > 0) { - int left_max_bin = static_cast(static_cast(left_cnt_data) / (total_sample_cnt - cnt_zero) * (max_bin - 1)); - left_max_bin = std::max(1, left_max_bin); - bin_upper_bound = GreedyFindBin(distinct_values, counts, left_cnt, left_max_bin, left_cnt_data, min_data_in_bin); - bin_upper_bound.back() = -kZeroThreshold; - } - int right_start = -1; for (int i = left_cnt; i < num_distinct_values; ++i) { if (distinct_values[i] > kZeroThreshold) { @@ -192,21 +186,66 @@ namespace LightGBM { } } - if (right_start >= 0) { - int right_max_bin = max_bin - 1 - static_cast(bin_upper_bound.size()); - CHECK(right_max_bin > 0); - auto right_bounds = GreedyFindBin(distinct_values + right_start, counts + right_start, - num_distinct_values - right_start, right_max_bin, right_cnt_data, min_data_in_bin); + // include zero bounds if possible + if (max_bin == 2) { + if (left_cnt == 0) { + bin_upper_bound.push_back(kZeroThreshold); + } else { + bin_upper_bound.push_back(-kZeroThreshold); + } + } else if (max_bin >= 3) { + bin_upper_bound.push_back(-kZeroThreshold); bin_upper_bound.push_back(kZeroThreshold); - bin_upper_bound.insert(bin_upper_bound.end(), right_bounds.begin(), right_bounds.end()); - } else { - bin_upper_bound.push_back(std::numeric_limits::infinity()); } + + // add forced bounds, excluding zeros since we have already added zero bounds + int i = 0; + while (i < forced_upper_bounds.size()) { + if (std::fabs(forced_upper_bounds[i]) <= kZeroThreshold) { + forced_upper_bounds.erase(forced_upper_bounds.begin() + i); + } else { + ++i; + } + } + bin_upper_bound.push_back(std::numeric_limits::infinity()); + int max_to_insert = max_bin - static_cast(bin_upper_bound.size()); + int num_to_insert = std::min(max_to_insert, static_cast(forced_upper_bounds.size())); + if (num_to_insert > 0) { + bin_upper_bound.insert(bin_upper_bound.end(), forced_upper_bounds.begin(), forced_upper_bounds.begin() + num_to_insert); + } + std::sort(bin_upper_bound.begin(), bin_upper_bound.end()); + + // find remaining bounds + std::vector bounds_to_add; + int value_ind = 0; + for (int i = 0; i < bin_upper_bound.size(); ++i) { + int cnt_in_bin = 0; + int distinct_cnt_in_bin = 0; + int bin_start = value_ind; + while ((value_ind < num_distinct_values) && (distinct_values[value_ind] < bin_upper_bound[i])) { + cnt_in_bin += counts[value_ind]; + ++distinct_cnt_in_bin; + ++value_ind; + } + int bins_remaining = max_bin - static_cast(bin_upper_bound.size()) - static_cast(bounds_to_add.size()); + int num_sub_bins = static_cast(std::lround((static_cast(cnt_in_bin) * bins_remaining / total_sample_cnt))); + num_sub_bins = std::min(num_sub_bins, bins_remaining) + 1; + if (i == bin_upper_bound.size() - 1) { + num_sub_bins = bins_remaining + 1; + } + std::vector new_upper_bounds = GreedyFindBin(distinct_values + bin_start, counts + bin_start, distinct_cnt_in_bin, + num_sub_bins, cnt_in_bin, min_data_in_bin); + bounds_to_add.insert(bounds_to_add.end(), new_upper_bounds.begin(), new_upper_bounds.end() - 1); // last bound is infinity + } + bin_upper_bound.insert(bin_upper_bound.end(), bounds_to_add.begin(), bounds_to_add.end()); + std::sort(bin_upper_bound.begin(), bin_upper_bound.end()); + CHECK(bin_upper_bound.size() <= max_bin); return bin_upper_bound; } void BinMapper::FindBin(double* values, int num_sample_values, size_t total_sample_cnt, - int max_bin, int min_data_in_bin, int min_split_data, BinType bin_type, bool use_missing, bool zero_as_missing) { + int max_bin, int min_data_in_bin, int min_split_data, BinType bin_type, bool use_missing, bool zero_as_missing, + std::vector forced_upper_bounds) { int na_cnt = 0; int tmp_num_sample_values = 0; for (int i = 0; i < num_sample_values; ++i) { @@ -274,14 +313,17 @@ namespace LightGBM { int num_distinct_values = static_cast(distinct_values.size()); if (bin_type_ == BinType::NumericalBin) { if (missing_type_ == MissingType::Zero) { - bin_upper_bound_ = FindBinWithZeroAsOneBin(distinct_values.data(), counts.data(), num_distinct_values, max_bin, total_sample_cnt, min_data_in_bin); + bin_upper_bound_ = FindBinWithZeroAsOneBin(distinct_values.data(), counts.data(), num_distinct_values, max_bin, total_sample_cnt, + min_data_in_bin, forced_upper_bounds); if (bin_upper_bound_.size() == 2) { missing_type_ = MissingType::None; } } else if (missing_type_ == MissingType::None) { - bin_upper_bound_ = FindBinWithZeroAsOneBin(distinct_values.data(), counts.data(), num_distinct_values, max_bin, total_sample_cnt, min_data_in_bin); + bin_upper_bound_ = FindBinWithZeroAsOneBin(distinct_values.data(), counts.data(), num_distinct_values, max_bin, total_sample_cnt, + min_data_in_bin, forced_upper_bounds); } else { - bin_upper_bound_ = FindBinWithZeroAsOneBin(distinct_values.data(), counts.data(), num_distinct_values, max_bin - 1, total_sample_cnt - na_cnt, min_data_in_bin); + bin_upper_bound_ = FindBinWithZeroAsOneBin(distinct_values.data(), counts.data(), num_distinct_values, max_bin - 1, total_sample_cnt - na_cnt, + min_data_in_bin, forced_upper_bounds); bin_upper_bound_.push_back(NaN); } num_bin_ = static_cast(bin_upper_bound_.size()); diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp index 8d75b1cde3df..ad5b43811ebe 100644 --- a/src/io/config_auto.cpp +++ b/src/io/config_auto.cpp @@ -211,6 +211,7 @@ std::unordered_set Config::parameter_set({ "monotone_constraints", "feature_contri", "forcedsplits_filename", + "forcedbins_filename", "refit_decay_rate", "cegb_tradeoff", "cegb_penalty_split", @@ -396,6 +397,8 @@ void Config::GetMembersFromString(const std::unordered_map=0.0); CHECK(refit_decay_rate <=1.0); @@ -608,6 +611,7 @@ std::string Config::SaveMembersToString() const { str_buf << "[monotone_constraints: " << Common::Join(Common::ArrayCast(monotone_constraints), ",") << "]\n"; str_buf << "[feature_contri: " << Common::Join(feature_contri, ",") << "]\n"; str_buf << "[forcedsplits_filename: " << forcedsplits_filename << "]\n"; + str_buf << "[forcedbins_filename: " << forcedbins_filename << "]\n"; str_buf << "[refit_decay_rate: " << refit_decay_rate << "]\n"; str_buf << "[cegb_tradeoff: " << cegb_tradeoff << "]\n"; str_buf << "[cegb_penalty_split: " << cegb_penalty_split << "]\n"; diff --git a/src/io/dataset.cpp b/src/io/dataset.cpp index f201a40a1a7a..c931e945cd24 100644 --- a/src/io/dataset.cpp +++ b/src/io/dataset.cpp @@ -8,12 +8,17 @@ #include #include #include +#include #include #include #include #include #include +#include + +using namespace json11; + namespace LightGBM { @@ -324,6 +329,7 @@ void Dataset::Construct( max_bin_by_feature_.resize(num_total_features_); max_bin_by_feature_.assign(io_config.max_bin_by_feature.begin(), io_config.max_bin_by_feature.end()); } + forced_bin_bounds_ = Dataset::GetForcedBins(io_config.forcedbins_filename, num_total_features_); max_bin_ = io_config.max_bin; min_data_in_bin_ = io_config.min_data_in_bin; bin_construct_sample_cnt_ = io_config.bin_construct_sample_cnt; @@ -356,6 +362,12 @@ void Dataset::ResetConfig(const char* parameters) { if (param.count("sparse_threshold") && io_config.sparse_threshold != sparse_threshold_) { Log::Warning("Cannot change sparse_threshold after constructed Dataset handle."); } + if (param.count("forcedbins_filename")) { + std::vector> config_bounds = Dataset::GetForcedBins(io_config.forcedbins_filename, num_total_features_); + if (config_bounds != forced_bin_bounds_) { + Log::Warning("Cannot change forced bins after constructed Dataset handle."); + } + } if (!io_config.monotone_constraints.empty()) { CHECK(static_cast(num_total_features_) == io_config.monotone_constraints.size()); @@ -657,6 +669,10 @@ void Dataset::SaveBinaryFile(const char* bin_filename) { for (int i = 0; i < num_total_features_; ++i) { size_of_header += feature_names_[i].size() + sizeof(int); } + // size of forced bins + for (int i = 0; i < num_total_features_; ++i) { + size_of_header += forced_bin_bounds_[i].size() * sizeof(double) + sizeof(int); + } writer->Write(&size_of_header, sizeof(size_of_header)); // write header writer->Write(&num_data_, sizeof(num_data_)); @@ -705,6 +721,15 @@ void Dataset::SaveBinaryFile(const char* bin_filename) { const char* c_str = feature_names_[i].c_str(); writer->Write(c_str, sizeof(char) * str_len); } + // write forced bins + for (int i = 0; i < num_total_features_; ++i) { + int num_bounds = static_cast(forced_bin_bounds_[i].size()); + writer->Write(&num_bounds, sizeof(int)); + + for (int j = 0; j < forced_bin_bounds_[i].size(); ++j) { + writer->Write(&forced_bin_bounds_[i][j], sizeof(double)); + } + } // get size of meta data size_t size_of_metadata = metadata_.SizesInByte(); @@ -754,6 +779,13 @@ void Dataset::DumpTextFile(const char* text_filename) { for (auto n : feature_names_) { fprintf(file, "%s, ", n.c_str()); } + fprintf(file, "\nforced_bins: "); + for (int i = 0; i < num_total_features_; ++i) { + fprintf(file, "\nfeature %d: ", i); + for (int j = 0; j < forced_bin_bounds_[i].size(); ++j) { + fprintf(file, "%lf, ", forced_bin_bounds_[i][j]); + } + } std::vector> iterators; iterators.reserve(num_features_); for (int j = 0; j < num_features_; ++j) { @@ -1005,6 +1037,7 @@ void Dataset::addFeaturesFrom(Dataset* other) { PushVector(feature_names_, other->feature_names_); PushVector(feature2subfeature_, other->feature2subfeature_); PushVector(group_feature_cnt_, other->group_feature_cnt_); + PushVector(forced_bin_bounds_, other->forced_bin_bounds_); feature_groups_.reserve(other->feature_groups_.size()); for (auto& fg : other->feature_groups_) { feature_groups_.emplace_back(new FeatureGroup(*fg)); @@ -1027,10 +1060,39 @@ void Dataset::addFeaturesFrom(Dataset* other) { PushClearIfEmpty(monotone_types_, num_total_features_, other->monotone_types_, other->num_total_features_, (int8_t)0); PushClearIfEmpty(feature_penalty_, num_total_features_, other->feature_penalty_, other->num_total_features_, 1.0); - + PushClearIfEmpty(max_bin_by_feature_, num_total_features_, other->max_bin_by_feature_, other->num_total_features_, -1); num_features_ += other->num_features_; num_total_features_ += other->num_total_features_; num_groups_ += other->num_groups_; } + +std::vector> Dataset::GetForcedBins(std::string forced_bins_path, int num_total_features) { + std::vector> forced_bins(num_total_features, std::vector()); + if (forced_bins_path != "") { + std::ifstream forced_bins_stream(forced_bins_path.c_str()); + std::stringstream buffer; + buffer << forced_bins_stream.rdbuf(); + std::string err; + Json forced_bins_json = Json::parse(buffer.str(), err); + CHECK(forced_bins_json.is_array()); + std::vector forced_bins_arr = forced_bins_json.array_items(); + for (int i = 0; i < forced_bins_arr.size(); ++i) { + int feature_num = forced_bins_arr[i]["feature"].int_value(); + CHECK(feature_num < num_total_features); + std::vector bounds_arr = forced_bins_arr[i]["bin_upper_bound"].array_items(); + for (int j = 0; j < bounds_arr.size(); ++j) { + forced_bins[feature_num].push_back(bounds_arr[j].number_value()); + } + } + // remove duplicates + for (int i = 0; i < num_total_features; ++i) { + auto new_end = std::unique(forced_bins[i].begin(), forced_bins[i].end()); + forced_bins[i].erase(new_end, forced_bins[i].end()); + } + } + return forced_bins; +} + + } // namespace LightGBM diff --git a/src/io/dataset_loader.cpp b/src/io/dataset_loader.cpp index 1130d803ea36..f36d5b1df27d 100644 --- a/src/io/dataset_loader.cpp +++ b/src/io/dataset_loader.cpp @@ -3,7 +3,6 @@ * Licensed under the MIT License. See LICENSE file in the project root for license information. */ #include - #include #include #include @@ -458,6 +457,21 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b } dataset->feature_names_.emplace_back(str_buf.str()); } + // get forced_bin_bounds_ + dataset->forced_bin_bounds_ = std::vector>(dataset->num_total_features_, std::vector()); + for (int i = 0; i < dataset->num_total_features_; ++i) { + int num_bounds = *(reinterpret_cast(mem_ptr)); + mem_ptr += sizeof(int); + dataset->forced_bin_bounds_[i] = std::vector(); + const double* tmp_ptr_forced_bounds = reinterpret_cast(mem_ptr); + + for (int j = 0; j < num_bounds; ++j) { + double bound = tmp_ptr_forced_bounds[j]; + dataset->forced_bin_bounds_[i].push_back(bound); + } + mem_ptr += num_bounds * sizeof(double); + + } // read size of meta data read_cnt = reader->Read(buffer.data(), sizeof(size_t)); @@ -549,6 +563,7 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b return dataset.release(); } + Dataset* DatasetLoader::CostructFromSampleData(double** sample_values, int** sample_indices, int num_col, const int* num_per_col, size_t total_sample_size, data_size_t num_data) { @@ -565,6 +580,11 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values, CHECK(static_cast(num_col) == config_.max_bin_by_feature.size()); CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1); } + + // get forced split + std::string forced_bins_path = config_.forcedbins_filename; + std::vector> forced_bin_bounds = Dataset::GetForcedBins(forced_bins_path, num_col); + const data_size_t filter_cnt = static_cast( static_cast(config_.min_data_in_leaf * total_sample_size) / num_data); if (Network::num_machines() == 1) { @@ -585,12 +605,13 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values, if (config_.max_bin_by_feature.empty()) { bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size, config_.max_bin, config_.min_data_in_bin, filter_cnt, - bin_type, config_.use_missing, config_.zero_as_missing); + bin_type, config_.use_missing, config_.zero_as_missing, + forced_bin_bounds[i]); } else { bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size, config_.max_bin_by_feature[i], config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, - config_.zero_as_missing); + config_.zero_as_missing, forced_bin_bounds[i]); } OMP_LOOP_EX_END(); } @@ -630,12 +651,13 @@ Dataset* DatasetLoader::CostructFromSampleData(double** sample_values, if (config_.max_bin_by_feature.empty()) { bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i], total_sample_size, config_.max_bin, config_.min_data_in_bin, - filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing); + filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing, + forced_bin_bounds[i]); } else { bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i], total_sample_size, config_.max_bin_by_feature[start[rank] + i], config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, - config_.zero_as_missing); + config_.zero_as_missing, forced_bin_bounds[i]); } OMP_LOOP_EX_END(); } @@ -872,6 +894,10 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1); } + // get forced split + std::string forced_bins_path = config_.forcedbins_filename; + std::vector> forced_bin_bounds = Dataset::GetForcedBins(forced_bins_path, dataset->num_total_features_); + // check the range of label_idx, weight_idx and group_idx CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_); CHECK(weight_idx_ < 0 || weight_idx_ < dataset->num_total_features_); @@ -909,12 +935,13 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, if (config_.max_bin_by_feature.empty()) { bin_mappers[i]->FindBin(sample_values[i].data(), static_cast(sample_values[i].size()), sample_data.size(), config_.max_bin, config_.min_data_in_bin, - filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing); + filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing, + forced_bin_bounds[i]); } else { bin_mappers[i]->FindBin(sample_values[i].data(), static_cast(sample_values[i].size()), sample_data.size(), config_.max_bin_by_feature[i], config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing, - config_.zero_as_missing); + config_.zero_as_missing, forced_bin_bounds[i]); } OMP_LOOP_EX_END(); } @@ -955,13 +982,14 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines, bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(), static_cast(sample_values[start[rank] + i].size()), sample_data.size(), config_.max_bin, config_.min_data_in_bin, - filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing); + filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing, + forced_bin_bounds[i]); } else { bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(), static_cast(sample_values[start[rank] + i].size()), sample_data.size(), config_.max_bin_by_feature[i], config_.min_data_in_bin, filter_cnt, bin_type, - config_.use_missing, config_.zero_as_missing); + config_.use_missing, config_.zero_as_missing, forced_bin_bounds[i]); } OMP_LOOP_EX_END(); } diff --git a/tests/data/forced_bins.json b/tests/data/forced_bins.json new file mode 100644 index 000000000000..aa74c36ffb78 --- /dev/null +++ b/tests/data/forced_bins.json @@ -0,0 +1,10 @@ +[ + { + "feature": 0, + "bin_upper_bound": [ 0.3, 0.35, 0.4 ] + }, + { + "feature": 1, + "bin_upper_bound": [ -0.1, -0.15, -0.2 ] + } +] \ No newline at end of file diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 4c9a9eddc6c6..59ea0113f50a 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -895,7 +895,7 @@ def test_max_bin_by_feature(self): } lgb_data = lgb.Dataset(X, label=y) est = lgb.train(params, lgb_data, num_boost_round=1) - self.assertEqual(len(np.unique(est.predict(X))), 100) + self.assertEqual(len(np.unique(est.predict(X))), 99) params['max_bin_by_feature'] = [2, 100] lgb_data = lgb.Dataset(X, label=y) est = lgb.train(params, lgb_data, num_boost_round=1) @@ -1544,3 +1544,33 @@ def constant_metric(preds, train_data): decreasing_metric(preds, train_data)], early_stopping_rounds=5, verbose_eval=False) self.assertEqual(gbm.best_iteration, 1) + + def test_forced_bins(self): + x = np.zeros((100, 2)) + x[:, 0] = np.arange(0, 1, 0.01) + x[:, 1] = -np.arange(0, 1, 0.01) + y = np.arange(0, 1, 0.01) + forcedbins_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data/forced_bins.json') + params = {'objective': 'regression_l1', + 'max_bin': 6, + 'forcedbins_filename': forcedbins_filename, + 'num_leaves': 2, + 'min_data_in_leaf': 1, + 'verbose': -1, + 'seed': 0} + lgb_x = lgb.Dataset(x, label=y) + est = lgb.train(params, lgb_x, num_boost_round=100) + new_x = np.zeros((3, x.shape[1])) + new_x[:, 0] = [0.31, 0.37, 0.41] + new_x[:, 1] = [0, 0, 0] + predicted = est.predict(new_x) + self.assertEqual(len(np.unique(predicted)), 3) + new_x[:, 0] = [0, 0, 0] + new_x[:, 1] = [-0.25, -0.5, -0.9] + predicted = est.predict(new_x) + self.assertEqual(len(np.unique(predicted)), 1) + params['forcedbins_filename'] = '' + lgb_x = lgb.Dataset(x, label=y) + est = lgb.train(params, lgb_x, num_boost_round=100) + predicted = est.predict(new_x) + self.assertEqual(len(np.unique(predicted)), 3)