diff --git a/R-package/R/lgb.train.R b/R-package/R/lgb.train.R
index a78bf72eb203..1ba5bf086647 100644
--- a/R-package/R/lgb.train.R
+++ b/R-package/R/lgb.train.R
@@ -124,6 +124,10 @@ lgb.train <- function(params = list(),
     end_iteration <- begin_iteration + nrounds - 1L
   }
 
+  if (!is.null(params[["interaction_constraints"]])) {
+    stop("lgb.train: interaction_constraints is not implemented")
+  }
+
   # Update parameters with parsed parameters
   data$update_params(params)
 
diff --git a/docs/Parameters.rst b/docs/Parameters.rst
index 64251f8d4573..01362fb9af34 100644
--- a/docs/Parameters.rst
+++ b/docs/Parameters.rst
@@ -538,6 +538,20 @@ Learning Control Parameters
 
   - note that the parent output ``w_p`` itself has smoothing applied, unless it is the root node, so that the smoothing effect accumulates with the tree depth
 
+- ``interaction_constraints`` :raw-html:`🔗︎`, default = ``""``, type = string
+
+  - controls which features can appear in the same branch
+
+  - by default interaction constraints are disabled, to enable them you can specify
+
+    - for CLI, lists separated by commas, e.g. ``[0,1,2],[2,3]``
+
+    - for Python-package, list of lists, e.g. ``[[0, 1, 2], [2, 3]]``
+
+    - for R-package, **not yet supported**
+
+  - any two features can appear in the same branch only if there exists a constraint containing both features
+
 - ``verbosity`` :raw-html:`🔗︎`, default = ``1``, type = int, aliases: ``verbose``
 
   - controls the level of LightGBM's verbosity
 
diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h
index 5cdc6139dc0e..2a3335c1c0ad 100644
--- a/include/LightGBM/config.h
+++ b/include/LightGBM/config.h
@@ -505,6 +505,14 @@ struct Config {
   // descl2 = note that the parent output ``w_p`` itself has smoothing applied, unless it is the root node, so that the smoothing effect accumulates with the tree depth
   double path_smooth = 0;
 
+  // desc = controls which features can appear in the same branch
+  // desc = by default interaction constraints are disabled, to enable them you can specify
+  // descl2 = for CLI, lists separated by commas, e.g. ``[0,1,2],[2,3]``
+  // descl2 = for Python-package, list of lists, e.g. ``[[0, 1, 2], [2, 3]]``
+  // descl2 = for R-package, **not yet supported**
+  // desc = any two features can appear in the same branch only if there exists a constraint containing both features
+  std::string interaction_constraints = "";
+
   // alias = verbose
   // desc = controls the level of LightGBM's verbosity
   // desc = ``< 0``: Fatal, ``= 0``: Error (Warning), ``= 1``: Info, ``> 1``: Debug
@@ -958,12 +966,14 @@ struct Config {
   static const std::unordered_map<std::string, std::string>& alias_table();
   static const std::unordered_set<std::string>& parameter_set();
   std::vector<std::vector<double>> auc_mu_weights_matrix;
+  std::vector<std::vector<int>> interaction_constraints_vector;
 
  private:
   void CheckParamConflict();
   void GetMembersFromString(const std::unordered_map<std::string, std::string>& params);
   std::string SaveMembersToString() const;
   void GetAucMuWeights();
+  void GetInteractionConstraints();
 };
 
 inline bool Config::GetString(
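A minimal usage sketch of the two parameter formats described above, for the Python package (the feature indices and the commented-out dataset are illustrative only):

    import lightgbm as lgb

    params = {
        'objective': 'regression',
        # features 0, 1, 2 may appear together on a branch, as may 2 and 3
        'interaction_constraints': [[0, 1, 2], [2, 3]],
    }
    # booster = lgb.train(params, lgb.Dataset(X, label=y))

    # equivalent CLI / config-file form:
    #   interaction_constraints=[0,1,2],[2,3]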
diff --git a/include/LightGBM/tree.h b/include/LightGBM/tree.h
index 467ce0c652f8..5ce3ff9b3eb1 100644
--- a/include/LightGBM/tree.h
+++ b/include/LightGBM/tree.h
@@ -27,8 +27,9 @@ class Tree {
   /*!
   * \brief Constructor
   * \param max_leaves The number of max leaves
+  * \param track_branch_features Whether to keep track of ancestors of leaf nodes
   */
-  explicit Tree(int max_leaves);
+  explicit Tree(int max_leaves, bool track_branch_features);
 
   /*!
   * \brief Constructor, from a string
@@ -148,6 +149,9 @@ class Tree {
   /*! \brief Get feature of specific split*/
   inline int split_feature(int split_idx) const { return split_feature_[split_idx]; }
 
+  /*! \brief Get features on leaf's branch*/
+  inline std::vector<int> branch_features(int leaf) const { return branch_features_[leaf]; }
+
   inline double split_gain(int split_idx) const { return split_gain_[split_idx]; }
 
   inline double internal_value(int node_idx) const {
@@ -436,6 +440,10 @@ class Tree {
   std::vector<int> internal_count_;
   /*! \brief Depth for leaves */
   std::vector<int> leaf_depth_;
+  /*! \brief Whether to keep track of ancestor nodes for each leaf (only needed when feature interactions are restricted) */
+  bool track_branch_features_;
+  /*! \brief Features on leaf's branch, original index */
+  std::vector<std::vector<int>> branch_features_;
   double shrinkage_;
   int max_depth_;
 };
 
@@ -477,6 +485,11 @@ inline void Tree::Split(int leaf, int feature, int real_feature,
   // update leaf depth
   leaf_depth_[num_leaves_] = leaf_depth_[leaf] + 1;
   leaf_depth_[leaf]++;
+  if (track_branch_features_) {
+    branch_features_[num_leaves_] = branch_features_[leaf];
+    branch_features_[num_leaves_].push_back(split_feature_[new_node_idx]);
+    branch_features_[leaf].push_back(split_feature_[new_node_idx]);
+  }
 }
 
 inline double Tree::Predict(const double* feature_values) const {
diff --git a/include/LightGBM/utils/common.h b/include/LightGBM/utils/common.h
index 825d9692f762..bdc769e52226 100644
--- a/include/LightGBM/utils/common.h
+++ b/include/LightGBM/utils/common.h
@@ -103,6 +103,30 @@ inline static std::vector<std::string> Split(const char* c_str, char delimiter) {
   return ret;
 }
 
+inline static std::vector<std::string> SplitBrackets(const char* c_str, char left_delimiter, char right_delimiter) {
+  std::vector<std::string> ret;
+  std::string str(c_str);
+  size_t i = 0;
+  size_t pos = 0;
+  bool open = false;
+  while (pos < str.length()) {
+    if (str[pos] == left_delimiter) {
+      open = true;
+      ++pos;
+      i = pos;
+    } else if (str[pos] == right_delimiter && open) {
+      if (i < pos) {
+        ret.push_back(str.substr(i, pos - i));
+      }
+      open = false;
+      ++pos;
+    } else {
+      ++pos;
+    }
+  }
+  return ret;
+}
+
 inline static std::vector<std::string> SplitLines(const char* c_str) {
   std::vector<std::string> ret;
   std::string str(c_str);
@@ -503,6 +527,17 @@ inline static std::vector<T> StringToArray(const std::string& str, char delimiter) {
   return ret;
 }
 
+template<typename T>
+inline static std::vector<std::vector<T>> StringToArrayofArrays(
+    const std::string& str, char left_bracket, char right_bracket, char delimiter) {
+  std::vector<std::string> strs = SplitBrackets(str.c_str(), left_bracket, right_bracket);
+  std::vector<std::vector<T>> ret;
+  for (const auto& s : strs) {
+    ret.push_back(StringToArray<T>(s, delimiter));
+  }
+  return ret;
+}
+
 template<typename T>
 inline static std::vector<T> StringToArray(const std::string& str, int n) {
   if (n == 0) {
diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 07b7efd410e3..01a5f31e51b6 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -135,7 +135,12 @@ def param_dict_to_str(data):
     pairs = []
     for key, val in data.items():
         if isinstance(val, (list, tuple, set)) or is_numpy_1d_array(val):
-            pairs.append(str(key) + '=' + ','.join(map(str, val)))
+            def to_string(x):
+                if isinstance(x, list):
+                    return "[{}]".format(','.join(map(str, x)))
+                else:
+                    return str(x)
+            pairs.append(str(key) + '=' + ','.join(map(to_string, val)))
         elif isinstance(val, string_type) or isinstance(val, numeric_types) or is_numeric(val):
             pairs.append(str(key) + '=' + str(val))
         elif val is not None:
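The basic.py change above is what lets a list of lists survive serialization into the single parameter string the C++ side expects. A stand-alone sketch of the resulting behavior (this re-implements the helper for illustration; it is not the library function itself):

    def to_string(x):
        # mirror of the helper added to param_dict_to_str above
        if isinstance(x, list):
            return "[{}]".format(','.join(map(str, x)))
        return str(x)

    val = [[0, 1, 2], [2, 3]]
    assert ','.join(map(to_string, val)) == '[0,1,2],[2,3]'

The produced string is exactly the CLI form of the parameter, which the C++ side parses back with SplitBrackets/StringToArrayofArrays.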
diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp
index 25a0946c90de..7871bbfb086c 100644
--- a/src/boosting/gbdt.cpp
+++ b/src/boosting/gbdt.cpp
@@ -352,7 +352,7 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
   bool should_continue = false;
   for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
     const size_t offset = static_cast<size_t>(cur_tree_id) * num_data_;
-    std::unique_ptr<Tree> new_tree(new Tree(2));
+    std::unique_ptr<Tree> new_tree(new Tree(2, false));
     if (class_need_train_[cur_tree_id] && train_data_->num_features() > 0) {
       auto grad = gradients + offset;
       auto hess = hessians + offset;
diff --git a/src/boosting/rf.hpp b/src/boosting/rf.hpp
index dd9be038aac9..5c90202a515e 100644
--- a/src/boosting/rf.hpp
+++ b/src/boosting/rf.hpp
@@ -109,7 +109,7 @@ class RF : public GBDT {
     gradients = gradients_.data();
     hessians = hessians_.data();
     for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
-      std::unique_ptr<Tree> new_tree(new Tree(2));
+      std::unique_ptr<Tree> new_tree(new Tree(2, false));
       size_t offset = static_cast<size_t>(cur_tree_id) * num_data_;
       if (class_need_train_[cur_tree_id]) {
         auto grad = gradients + offset;
diff --git a/src/io/config.cpp b/src/io/config.cpp
index d31b7b839a3e..d569a7401e17 100644
--- a/src/io/config.cpp
+++ b/src/io/config.cpp
@@ -180,6 +180,14 @@ void Config::GetAucMuWeights() {
   }
 }
 
+void Config::GetInteractionConstraints() {
+  if (interaction_constraints == "") {
+    interaction_constraints_vector = std::vector<std::vector<int>>();
+  } else {
+    interaction_constraints_vector = Common::StringToArrayofArrays<int>(interaction_constraints, '[', ']', ',');
+  }
+}
+
 void Config::Set(const std::unordered_map<std::string, std::string>& params) {
   // generate seeds by seed.
   if (GetInt(params, "seed", &seed)) {
@@ -204,6 +212,8 @@ void Config::Set(const std::unordered_map<std::string, std::string>& params) {
 
   GetAucMuWeights();
 
+  GetInteractionConstraints();
+
   // sort eval_at
   std::sort(eval_at.begin(), eval_at.end());
 
diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp
index 5881571d16f3..807cad785021 100644
--- a/src/io/config_auto.cpp
+++ b/src/io/config_auto.cpp
@@ -230,6 +230,7 @@ const std::unordered_set<std::string>& Config::parameter_set() {
   "cegb_penalty_feature_lazy",
   "cegb_penalty_feature_coupled",
   "path_smooth",
+  "interaction_constraints",
   "verbosity",
   "input_model",
   "output_model",
@@ -454,6 +455,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::string>& params) {
 
   GetDouble(params, "path_smooth", &path_smooth);
 
+  GetString(params, "interaction_constraints", &interaction_constraints);
+
diff --git a/src/io/tree.cpp b/src/io/tree.cpp
--- a/src/io/tree.cpp
+++ b/src/io/tree.cpp
@@ ... @@
-Tree::Tree(int max_leaves)
-    : max_leaves_(max_leaves) {
+Tree::Tree(int max_leaves, bool track_branch_features)
+    : max_leaves_(max_leaves), track_branch_features_(track_branch_features) {
+  if (track_branch_features_) {
+    branch_features_ = std::vector<std::vector<int>>(max_leaves_);
+  }
   // root is in the depth 0
   leaf_depth_[0] = 0;
   num_leaves_ = 1;
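GetInteractionConstraints() above relies on SplitBrackets and StringToArrayofArrays from common.h. A minimal Python sketch of that parsing path, mirroring the C++ logic (the helper names here are ours, not LightGBM API):

    # "[0,1,2],[2,3]" -> [[0, 1, 2], [2, 3]]
    def split_brackets(s, left='[', right=']'):
        groups, start, is_open = [], 0, False
        for pos, ch in enumerate(s):
            if ch == left:
                is_open, start = True, pos + 1
            elif ch == right and is_open:
                if start < pos:
                    groups.append(s[start:pos])
                is_open = False
            # characters outside brackets (e.g. the separating comma) are skipped
        return groups

    def string_to_array_of_arrays(s):
        return [[int(tok) for tok in group.split(',')] for group in split_brackets(s)]

    assert string_to_array_of_arrays("[0,1,2],[2,3]") == [[0, 1, 2], [2, 3]]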
diff --git a/src/treelearner/col_sampler.hpp b/src/treelearner/col_sampler.hpp
index 68a98d159271..cd2884812552 100644
--- a/src/treelearner/col_sampler.hpp
+++ b/src/treelearner/col_sampler.hpp
@@ -13,6 +13,7 @@
 #include <cstdint>
 #include <memory>
+#include <unordered_set>
 #include <vector>
 
 namespace LightGBM {
@@ -23,6 +24,10 @@ class ColSampler {
         fraction_bynode_(config->feature_fraction_bynode),
         seed_(config->feature_fraction_seed),
         random_(config->feature_fraction_seed) {
+    for (auto constraint : config->interaction_constraints_vector) {
+      std::unordered_set<int> constraint_set(constraint.begin(), constraint.end());
+      interaction_constraints_.push_back(constraint_set);
+    }
   }
 
   static int GetCnt(size_t total_cnt, double fraction) {
@@ -83,32 +88,87 @@ class ColSampler {
     }
   }
 
-  std::vector<int8_t> GetByNode() {
-    if (fraction_bynode_ >= 1.0f) {
-      return std::vector<int8_t>(train_data_->num_features(), 1);
+  std::vector<int8_t> GetByNode(const Tree* tree, int leaf) {
+    // get interaction constraints for current branch
+    std::unordered_set<int> allowed_features;
+    if (!interaction_constraints_.empty()) {
+      std::vector<int> branch_features = tree->branch_features(leaf);
+      allowed_features.insert(branch_features.begin(), branch_features.end());
+      for (auto constraint : interaction_constraints_) {
+        int num_feat_found = 0;
+        if (branch_features.size() == 0) {
+          allowed_features.insert(constraint.begin(), constraint.end());
+        }
+        for (int feat : branch_features) {
+          if (constraint.count(feat) == 0) { break; }
+          ++num_feat_found;
+          if (num_feat_found == static_cast<int>(branch_features.size())) {
+            allowed_features.insert(constraint.begin(), constraint.end());
+            break;
+          }
+        }
+      }
     }
+    std::vector<int8_t> ret(train_data_->num_features(), 0);
+    if (fraction_bynode_ >= 1.0f) {
+      if (interaction_constraints_.empty()) {
+        return std::vector<int8_t>(train_data_->num_features(), 1);
+      } else {
+        for (int feat : allowed_features) {
+          int inner_feat = train_data_->InnerFeatureIndex(feat);
+          ret[inner_feat] = 1;
+        }
+        return ret;
+      }
+    }
     if (need_reset_bytree_) {
       auto used_feature_cnt = GetCnt(used_feature_indices_.size(), fraction_bynode_);
+      std::vector<int>* allowed_used_feature_indices;
+      std::vector<int> filtered_feature_indices;
+      if (interaction_constraints_.empty()) {
+        allowed_used_feature_indices = &used_feature_indices_;
+      } else {
+        for (int feat_ind : used_feature_indices_) {
+          if (allowed_features.count(valid_feature_indices_[feat_ind]) == 1) {
+            filtered_feature_indices.push_back(feat_ind);
+          }
+        }
+        used_feature_cnt = std::min(used_feature_cnt, static_cast<int>(filtered_feature_indices.size()));
+        allowed_used_feature_indices = &filtered_feature_indices;
+      }
       auto sampled_indices = random_.Sample(
-          static_cast<int>(used_feature_indices_.size()), used_feature_cnt);
+          static_cast<int>((*allowed_used_feature_indices).size()), used_feature_cnt);
       int omp_loop_size = static_cast<int>(sampled_indices.size());
 #pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
       for (int i = 0; i < omp_loop_size; ++i) {
         int used_feature =
-            valid_feature_indices_[used_feature_indices_[sampled_indices[i]]];
+            valid_feature_indices_[(*allowed_used_feature_indices)[sampled_indices[i]]];
         int inner_feature_index = train_data_->InnerFeatureIndex(used_feature);
         ret[inner_feature_index] = 1;
       }
     } else {
       auto used_feature_cnt = GetCnt(valid_feature_indices_.size(), fraction_bynode_);
+      std::vector<int>* allowed_valid_feature_indices;
+      std::vector<int> filtered_feature_indices;
+      if (interaction_constraints_.empty()) {
+        allowed_valid_feature_indices = &valid_feature_indices_;
+      } else {
+        for (int feat : valid_feature_indices_) {
+          if (allowed_features.count(feat) == 1) {
+            filtered_feature_indices.push_back(feat);
+          }
+        }
+        allowed_valid_feature_indices = &filtered_feature_indices;
+        used_feature_cnt = std::min(used_feature_cnt, static_cast<int>(filtered_feature_indices.size()));
+      }
       auto sampled_indices = random_.Sample(
-          static_cast<int>(valid_feature_indices_.size()), used_feature_cnt);
+          static_cast<int>((*allowed_valid_feature_indices).size()), used_feature_cnt);
       int omp_loop_size = static_cast<int>(sampled_indices.size());
 #pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
       for (int i = 0; i < omp_loop_size; ++i) {
-        int used_feature = valid_feature_indices_[sampled_indices[i]];
+        int used_feature = (*allowed_valid_feature_indices)[sampled_indices[i]];
         int inner_feature_index = train_data_->InnerFeatureIndex(used_feature);
         ret[inner_feature_index] = 1;
       }
@@ -135,6 +195,8 @@ class ColSampler {
   std::vector<int8_t> is_feature_used_;
   std::vector<int> used_feature_indices_;
   std::vector<int> valid_feature_indices_;
+  /*! \brief interaction constraints index in original (raw data) features */
+  std::vector<std::unordered_set<int>> interaction_constraints_;
 };
 
 }  // namespace LightGBM
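The core of GetByNode above is the allowed-feature computation: a constraint group contributes its features only if it contains every feature already used on the branch. A Python sketch of just that logic (the function name is illustrative):

    def allowed_features_for_leaf(branch_features, constraint_groups):
        # branch_features: features already used from the root down to this leaf
        branch = set(branch_features)
        allowed = set(branch)  # features already on the branch stay allowed
        for group in constraint_groups:
            if branch <= group:  # group contains all branch features
                allowed |= group
        return allowed

    groups = [{0, 1, 2}, {2, 3}]
    assert allowed_features_for_leaf([], groups) == {0, 1, 2, 3}    # root: any constrained feature
    assert allowed_features_for_leaf([2], groups) == {0, 1, 2, 3}   # 2 belongs to both groups
    assert allowed_features_for_leaf([0, 2], groups) == {0, 1, 2}   # only the first group fits

Note that once constraints are active, a feature that appears in no group is never allowed, even at the root.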
diff --git a/src/treelearner/data_parallel_tree_learner.cpp b/src/treelearner/data_parallel_tree_learner.cpp
index f91dcdc9b250..0d6f9df251b6 100644
--- a/src/treelearner/data_parallel_tree_learner.cpp
+++ b/src/treelearner/data_parallel_tree_learner.cpp
@@ -152,7 +152,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
 }
 
 template <typename TREELEARNER_T>
-void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
+void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits(const Tree* tree) {
   TREELEARNER_T::ConstructHistograms(
       this->col_sampler_.is_feature_used_bytree(), true);
   // construct local histograms
@@ -169,17 +169,17 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
   Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, sizeof(hist_t), block_start_.data(), block_len_.data(),
                          output_buffer_.data(), static_cast<comm_size_t>(output_buffer_.size()), &HistogramSumReducer);
   this->FindBestSplitsFromHistograms(
-      this->col_sampler_.is_feature_used_bytree(), true);
+      this->col_sampler_.is_feature_used_bytree(), true, tree);
 }
 
 template <typename TREELEARNER_T>
-void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool) {
+void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool, const Tree* tree) {
   std::vector<SplitInfo> smaller_bests_per_thread(this->share_state_->num_threads);
   std::vector<SplitInfo> larger_bests_per_thread(this->share_state_->num_threads);
   std::vector<int8_t> smaller_node_used_features =
-      this->col_sampler_.GetByNode();
+      this->col_sampler_.GetByNode(tree, this->smaller_leaf_splits_->leaf_index());
   std::vector<int8_t> larger_node_used_features =
-      this->col_sampler_.GetByNode();
+      this->col_sampler_.GetByNode(tree, this->larger_leaf_splits_->leaf_index());
   OMP_INIT_EX();
 #pragma omp parallel for schedule(static)
   for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
diff --git a/src/treelearner/feature_parallel_tree_learner.cpp b/src/treelearner/feature_parallel_tree_learner.cpp
index 74df187d46b2..c5202f3d706d 100644
--- a/src/treelearner/feature_parallel_tree_learner.cpp
+++ b/src/treelearner/feature_parallel_tree_learner.cpp
@@ -57,8 +57,9 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
 }
 
 template <typename TREELEARNER_T>
-void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) {
-  TREELEARNER_T::FindBestSplitsFromHistograms(is_feature_used, use_subtract);
+void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(
+    const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree* tree) {
+  TREELEARNER_T::FindBestSplitsFromHistograms(is_feature_used, use_subtract, tree);
   SplitInfo smaller_best_split, larger_best_split;
   // get best split at smaller leaf
   smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->leaf_index()];
diff --git a/src/treelearner/gpu_tree_learner.cpp b/src/treelearner/gpu_tree_learner.cpp
index f8f57e4b4236..43ccadfd176f 100644
--- a/src/treelearner/gpu_tree_learner.cpp
+++ b/src/treelearner/gpu_tree_learner.cpp
@@ -1055,8 +1055,8 @@ void GPUTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_u
   }
 }
 
-void GPUTreeLearner::FindBestSplits() {
-  SerialTreeLearner::FindBestSplits();
+void GPUTreeLearner::FindBestSplits(const Tree* tree) {
+  SerialTreeLearner::FindBestSplits(tree);
 
 #if GPU_DEBUG >= 3
   for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
diff --git a/src/treelearner/gpu_tree_learner.h b/src/treelearner/gpu_tree_learner.h
index 428b2b5a5a06..a909c57cbadc 100644
--- a/src/treelearner/gpu_tree_learner.h
+++ b/src/treelearner/gpu_tree_learner.h
@@ -66,7 +66,7 @@ class GPUTreeLearner: public SerialTreeLearner {
  protected:
   void BeforeTrain() override;
   bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override;
-  void FindBestSplits() override;
+  void FindBestSplits(const Tree* tree) override;
   void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;
   void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
 
diff --git a/src/treelearner/parallel_tree_learner.h b/src/treelearner/parallel_tree_learner.h
index dde47d4989da..137697408e8d 100644
--- a/src/treelearner/parallel_tree_learner.h
+++ b/src/treelearner/parallel_tree_learner.h
@@ -31,7 +31,7 @@ class FeatureParallelTreeLearner: public TREELEARNER_T {
 
  protected:
   void BeforeTrain() override;
-  void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
+  void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree* tree) override;
 
  private:
   /*! \brief rank of local machine */
@@ -59,8 +59,8 @@ class DataParallelTreeLearner: public TREELEARNER_T {
 
  protected:
   void BeforeTrain() override;
-  void FindBestSplits() override;
-  void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
+  void FindBestSplits(const Tree* tree) override;
+  void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree* tree) override;
   void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;
 
   inline data_size_t GetGlobalDataCountInLeaf(int leaf_idx) const override {
@@ -114,8 +114,8 @@ class VotingParallelTreeLearner: public TREELEARNER_T {
 
  protected:
   void BeforeTrain() override;
   bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override;
-  void FindBestSplits() override;
-  void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
+  void FindBestSplits(const Tree* tree) override;
+  void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree* tree) override;
   void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;
 
   inline data_size_t GetGlobalDataCountInLeaf(int leaf_idx) const override {
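For orientation before the serial tree learner changes below: when track_branch_features is on, Tree::Split (see tree.h above) copies the parent's path features to both children and appends the new split feature. A small Python sketch of that bookkeeping (names shortened, sizes illustrative):

    max_leaves = 31  # illustrative
    branch_features = [[] for _ in range(max_leaves)]  # features on each leaf's branch

    def split(leaf, new_leaf, feature):
        # mirror of the track_branch_features_ block in Tree::Split:
        # both children of `leaf` inherit its path features plus `feature`
        branch_features[new_leaf] = branch_features[leaf] + [feature]
        branch_features[leaf] = branch_features[leaf] + [feature]

    split(0, 1, feature=5)  # first split: root on feature 5
    split(0, 2, feature=3)  # split the left child on feature 3
    assert branch_features[0] == [5, 3] and branch_features[2] == [5, 3]

These per-leaf feature lists are exactly what ColSampler::GetByNode queries through tree->branch_features(leaf).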
diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp
index a68a65ee91b4..db5cd0b4395d 100644
--- a/src/treelearner/serial_tree_learner.cpp
+++ b/src/treelearner/serial_tree_learner.cpp
@@ -163,7 +163,8 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
   // some initial works before training
   BeforeTrain();
 
-  auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves));
+  bool track_branch_features = !(config_->interaction_constraints_vector.empty());
+  auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves, track_branch_features));
   auto tree_prt = tree.get();
   constraints_->ShareTreePointer(tree_prt);
 
@@ -179,7 +180,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
     // some initial works before finding best split
     if (BeforeFindBestSplit(tree_prt, left_leaf, right_leaf)) {
       // find best threshold for every feature
-      FindBestSplits();
+      FindBestSplits(tree_prt);
     }
     // Get a leaf with max split gain
     int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_));
@@ -310,7 +311,7 @@ bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int
   return true;
 }
 
-void SerialTreeLearner::FindBestSplits() {
+void SerialTreeLearner::FindBestSplits(const Tree* tree) {
   std::vector<int8_t> is_feature_used(num_features_, 0);
 #pragma omp parallel for schedule(static, 256) if (num_features_ >= 512)
   for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
@@ -324,7 +325,7 @@ void SerialTreeLearner::FindBestSplits() {
   }
   bool use_subtract = parent_leaf_histogram_array_ != nullptr;
   ConstructHistograms(is_feature_used, use_subtract);
-  FindBestSplitsFromHistograms(is_feature_used, use_subtract);
+  FindBestSplitsFromHistograms(is_feature_used, use_subtract, tree);
 }
 
 void SerialTreeLearner::ConstructHistograms(
@@ -353,13 +354,16 @@ void SerialTreeLearner::ConstructHistograms(
 }
 
 void SerialTreeLearner::FindBestSplitsFromHistograms(
-    const std::vector<int8_t>& is_feature_used, bool use_subtract) {
+    const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree* tree) {
   Common::FunctionTimer fun_timer(
       "SerialTreeLearner::FindBestSplitsFromHistograms", global_timer);
   std::vector<SplitInfo> smaller_best(share_state_->num_threads);
   std::vector<SplitInfo> larger_best(share_state_->num_threads);
-  std::vector<int8_t> smaller_node_used_features = col_sampler_.GetByNode();
-  std::vector<int8_t> larger_node_used_features = col_sampler_.GetByNode();
+  std::vector<int8_t> smaller_node_used_features = col_sampler_.GetByNode(tree, smaller_leaf_splits_->leaf_index());
+  std::vector<int8_t> larger_node_used_features;
+  if (larger_leaf_splits_->leaf_index() >= 0) {
+    larger_node_used_features = col_sampler_.GetByNode(tree, larger_leaf_splits_->leaf_index());
+  }
   OMP_INIT_EX();
   // find splits
 #pragma omp parallel for schedule(static) num_threads(share_state_->num_threads)
@@ -437,7 +441,7 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, int* left_leaf,
     // before processing next node from queue, store info for current left/right leaf
     // store "best split" for left and right, even if they might be overwritten by forced split
     if (BeforeFindBestSplit(tree, *left_leaf, *right_leaf)) {
-      FindBestSplits();
+      FindBestSplits(tree);
     }
     // then, compute own splits
     SplitInfo left_split;
diff --git a/src/treelearner/serial_tree_learner.h b/src/treelearner/serial_tree_learner.h
index 6a0d7f0e9a6d..e6ac8e3ad09c 100644
--- a/src/treelearner/serial_tree_learner.h
+++ b/src/treelearner/serial_tree_learner.h
@@ -134,11 +134,11 @@ class SerialTreeLearner: public TreeLearner {
   */
   virtual bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf);
 
-  virtual void FindBestSplits();
+  virtual void FindBestSplits(const Tree* tree);
 
   virtual void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract);
 
-  virtual void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract);
+  virtual void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree*);
 
   /*!
   * \brief Partition tree and data according best split.
@@ -196,7 +196,6 @@ class SerialTreeLearner: public TreeLearner {
   std::unique_ptr<LeafSplits> smaller_leaf_splits_;
   /*! \brief stores best thresholds for all feature for larger leaf */
   std::unique_ptr<LeafSplits> larger_leaf_splits_;
-
 #ifdef USE_GPU
   /*! \brief gradients of current iteration, ordered for cache optimized, aligned to 4K page */
   std::vector<score_t, Common::AlignmentAllocator<score_t, kAlignedSize>> ordered_gradients_;
diff --git a/src/treelearner/voting_parallel_tree_learner.cpp b/src/treelearner/voting_parallel_tree_learner.cpp
index d14e0d614ce0..1c9c36ba8bbd 100644
--- a/src/treelearner/voting_parallel_tree_learner.cpp
+++ b/src/treelearner/voting_parallel_tree_learner.cpp
@@ -241,7 +241,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::CopyLocalHistogram(const std::vec
 }
 
 template <typename TREELEARNER_T>
-void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
+void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits(const Tree* tree) {
   // use local data to find local best splits
   std::vector<int8_t> is_feature_used(this->num_features_, 0);
 #pragma omp parallel for schedule(static)
@@ -343,17 +343,17 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
   Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, sizeof(hist_t), block_start_.data(), block_len_.data(),
                          output_buffer_.data(), static_cast<comm_size_t>(output_buffer_.size()), &HistogramSumReducer);
-  this->FindBestSplitsFromHistograms(is_feature_used, false);
+  this->FindBestSplitsFromHistograms(is_feature_used, false, tree);
 }
 
 template <typename TREELEARNER_T>
-void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool) {
+void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool, const Tree* tree) {
   std::vector<SplitInfo> smaller_bests_per_thread(this->share_state_->num_threads);
   std::vector<SplitInfo> larger_bests_per_thread(this->share_state_->num_threads);
   std::vector<int8_t> smaller_node_used_features =
-      this->col_sampler_.GetByNode();
+      this->col_sampler_.GetByNode(tree, this->smaller_leaf_splits_->leaf_index());
   std::vector<int8_t> larger_node_used_features =
-      this->col_sampler_.GetByNode();
+      this->col_sampler_.GetByNode(tree, this->larger_leaf_splits_->leaf_index());
 
   // find best split from local aggregated histograms
   OMP_INIT_EX();
diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index 36f532dd3a62..dc48fc9d3a39 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -2185,3 +2185,28 @@ def _imptcs_to_numpy(X, impcts_dict):
                     'split_gain', 'threshold', 'decision_type', 'missing_direction',
                     'missing_type', 'weight', 'count'):
             self.assertIsNone(tree_df.loc[0, col])
+
+    def test_interaction_constraints(self):
+        X, y = load_boston(True)
+        num_features = X.shape[1]
+        train_data = lgb.Dataset(X, label=y)
+        # check that constraint containing all features is equivalent to no constraint
+        params = {'verbose': -1,
+                  'seed': 0}
+        est = lgb.train(params, train_data, num_boost_round=10)
+        pred1 = est.predict(X)
+        est = lgb.train(dict(params, interaction_constraints=[list(range(num_features))]), train_data,
+                        num_boost_round=10)
+        pred2 = est.predict(X)
+        np.testing.assert_allclose(pred1, pred2)
+        # check that constraint partitioning the features reduces train accuracy
+        est = lgb.train(dict(params, interaction_constraints=[list(range(num_features // 2)),
+                                                              list(range(num_features // 2, num_features))]),
+                        train_data, num_boost_round=10)
+        pred3 = est.predict(X)
+        self.assertLess(mean_squared_error(y, pred1), mean_squared_error(y, pred3))
+        # check that constraints consisting of single features reduce accuracy further
+        est = lgb.train(dict(params, interaction_constraints=[[i] for i in range(num_features)]), train_data,
+                        num_boost_round=10)
+        pred4 = est.predict(X)
+        self.assertLess(mean_squared_error(y, pred3), mean_squared_error(y, pred4))
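A possible follow-up check for this test, not included in the patch: walk the dumped model and assert that the features used on every root-to-leaf branch fit inside a single constraint group. This uses only the existing Booster.dump_model() output; the helper names are ours:

    def branch_feature_sets(node, path=()):
        # yield the set of split features on each root-to-leaf path
        if 'split_feature' not in node:  # leaf node
            yield set(path)
            return
        path = path + (node['split_feature'],)
        yield from branch_feature_sets(node['left_child'], path)
        yield from branch_feature_sets(node['right_child'], path)

    def check_constraints(booster, groups):
        groups = [set(g) for g in groups]
        for tree in booster.dump_model()['tree_info']:
            for branch in branch_feature_sets(tree['tree_structure']):
                assert any(branch <= g for g in groups), branch

For example, check_constraints(est, [[i] for i in range(num_features)]) should pass for the last model trained above, since every branch may then use at most one feature.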