diff --git a/R-package/R/lgb.train.R b/R-package/R/lgb.train.R
index a78bf72eb203..1ba5bf086647 100644
--- a/R-package/R/lgb.train.R
+++ b/R-package/R/lgb.train.R
@@ -124,6 +124,10 @@ lgb.train <- function(params = list(),
end_iteration <- begin_iteration + nrounds - 1L
}
+ if (!is.null(params[["interaction_constraints"]])) {
+ stop("lgb.train: interaction_constraints is not implemented")
+ }
+
# Update parameters with parsed parameters
data$update_params(params)
diff --git a/docs/Parameters.rst b/docs/Parameters.rst
index 64251f8d4573..01362fb9af34 100644
--- a/docs/Parameters.rst
+++ b/docs/Parameters.rst
@@ -538,6 +538,20 @@ Learning Control Parameters
- note that the parent output ``w_p`` itself has smoothing applied, unless it is the root node, so that the smoothing effect accumulates with the tree depth
+- ``interaction_constraints`` :raw-html:`🔗︎`, default = ``""``, type = string
+
+ - controls which features can appear in the same branch
+
+ - interaction constraints are disabled by default; to enable them, specify the allowed interactions in one of the following formats
+
+ - for CLI, lists separated by commas, e.g. ``[0,1,2],[2,3]``
+
+ - for Python-package, list of lists, e.g. ``[[0, 1, 2], [2, 3]]``
+
+ - for R-package, **not yet supported**
+
+ - any two features can appear in the same branch only if there exists a constraint containing both features
+
- ``verbosity`` :raw-html:`🔗︎`, default = ``1``, type = int, aliases: ``verbose``
- controls the level of LightGBM's verbosity
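As a usage illustration of the parameter formats documented above, a minimal Python sketch (synthetic data; parameter values invented for illustration only):

```python
import lightgbm as lgb
import numpy as np

X = np.random.rand(500, 4)
y = np.random.rand(500)
train_data = lgb.Dataset(X, label=y)

# Features 0, 1, 2 may interact with each other, and features 2, 3 with
# each other; features 0 and 3 can never appear in the same branch.
params = {"objective": "regression",
          "interaction_constraints": [[0, 1, 2], [2, 3]]}
booster = lgb.train(params, train_data, num_boost_round=10)
```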
diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h
index 5cdc6139dc0e..2a3335c1c0ad 100644
--- a/include/LightGBM/config.h
+++ b/include/LightGBM/config.h
@@ -505,6 +505,14 @@ struct Config {
// descl2 = note that the parent output ``w_p`` itself has smoothing applied, unless it is the root node, so that the smoothing effect accumulates with the tree depth
double path_smooth = 0;
+ // desc = controls which features can appear in the same branch
+ // desc = interaction constraints are disabled by default; to enable them, specify the allowed interactions in one of the following formats
+ // descl2 = for CLI, lists separated by commas, e.g. ``[0,1,2],[2,3]``
+ // descl2 = for Python-package, list of lists, e.g. ``[[0, 1, 2], [2, 3]]``
+ // descl2 = for R-package, **not yet supported**
+ // desc = any two features can appear in the same branch only if there exists a constraint containing both features
+ std::string interaction_constraints = "";
+
// alias = verbose
// desc = controls the level of LightGBM's verbosity
// desc = ``< 0``: Fatal, ``= 0``: Error (Warning), ``= 1``: Info, ``> 1``: Debug
@@ -958,12 +966,14 @@ struct Config {
static const std::unordered_map<std::string, std::string>& alias_table();
static const std::unordered_set<std::string>& parameter_set();
std::vector<std::vector<double>> auc_mu_weights_matrix;
+ std::vector<std::vector<int>> interaction_constraints_vector;
private:
void CheckParamConflict();
void GetMembersFromString(const std::unordered_map<std::string, std::string>& params);
std::string SaveMembersToString() const;
void GetAucMuWeights();
+ void GetInteractionConstraints();
};
inline bool Config::GetString(
diff --git a/include/LightGBM/tree.h b/include/LightGBM/tree.h
index 467ce0c652f8..5ce3ff9b3eb1 100644
--- a/include/LightGBM/tree.h
+++ b/include/LightGBM/tree.h
@@ -27,8 +27,9 @@ class Tree {
/*!
* \brief Constructor
* \param max_leaves The number of max leaves
+ * \param track_branch_features Whether to keep track of ancestors of leaf nodes
*/
- explicit Tree(int max_leaves);
+ explicit Tree(int max_leaves, bool track_branch_features);
/*!
* \brief Constructor, from a string
@@ -148,6 +149,9 @@ class Tree {
/*! \brief Get feature of specific split*/
inline int split_feature(int split_idx) const { return split_feature_[split_idx]; }
+ /*! \brief Get features on leaf's branch*/
+ inline std::vector<int> branch_features(int leaf) const { return branch_features_[leaf]; }
+
inline double split_gain(int split_idx) const { return split_gain_[split_idx]; }
inline double internal_value(int node_idx) const {
@@ -436,6 +440,10 @@ class Tree {
std::vector<int> internal_count_;
/*! \brief Depth for leaves */
std::vector<int> leaf_depth_;
+ /*! \brief whether to keep track of ancestor nodes for each leaf (only needed when feature interactions are restricted) */
+ bool track_branch_features_;
+ /*! \brief Features on leaf's branch, original index */
+ std::vector<std::vector<int>> branch_features_;
double shrinkage_;
int max_depth_;
};
@@ -477,6 +485,11 @@ inline void Tree::Split(int leaf, int feature, int real_feature,
// update leaf depth
leaf_depth_[num_leaves_] = leaf_depth_[leaf] + 1;
leaf_depth_[leaf]++;
+ if (track_branch_features_) {
+ branch_features_[num_leaves_] = branch_features_[leaf];
+ branch_features_[num_leaves_].push_back(split_feature_[new_node_idx]);
+ branch_features_[leaf].push_back(split_feature_[new_node_idx]);
+ }
}
inline double Tree::Predict(const double* feature_values) const {
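To make the `track_branch_features_` bookkeeping in `Tree::Split` above concrete, a minimal Python sketch (not LightGBM code) of how `branch_features_` propagates when a leaf is split:

```python
# Each leaf stores the features used on its root-to-leaf path. On a split,
# the new leaf copies the parent's path, then both children append the
# feature just split on (mirroring Tree::Split above).
def split(branch_features, leaf, new_leaf, split_feature):
    branch_features[new_leaf] = branch_features[leaf] + [split_feature]
    branch_features[leaf] = branch_features[leaf] + [split_feature]

branch_features = {0: []}            # the root has no ancestor features
split(branch_features, 0, 1, 5)      # split the root on feature 5
split(branch_features, 1, 2, 3)      # split leaf 1 on feature 3
assert branch_features[2] == [5, 3]  # leaf 2's branch uses features 5 and 3
```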
diff --git a/include/LightGBM/utils/common.h b/include/LightGBM/utils/common.h
index 825d9692f762..bdc769e52226 100644
--- a/include/LightGBM/utils/common.h
+++ b/include/LightGBM/utils/common.h
@@ -103,6 +103,30 @@ inline static std::vector<std::string> Split(const char* c_str, char delimiter)
return ret;
}
+inline static std::vector<std::string> SplitBrackets(const char* c_str, char left_delimiter, char right_delimiter) {
+ std::vector<std::string> ret;
+ std::string str(c_str);
+ size_t i = 0;
+ size_t pos = 0;
+ bool open = false;
+ while (pos < str.length()) {
+ if (str[pos] == left_delimiter) {
+ open = true;
+ ++pos;
+ i = pos;
+ } else if (str[pos] == right_delimiter && open) {
+ if (i < pos) {
+ ret.push_back(str.substr(i, pos - i));
+ }
+ open = false;
+ ++pos;
+ } else {
+ ++pos;
+ }
+ }
+ return ret;
+}
+
inline static std::vector<std::string> SplitLines(const char* c_str) {
std::vector<std::string> ret;
std::string str(c_str);
@@ -503,6 +527,17 @@ inline static std::vector<T> StringToArray(const std::string& str, char delimiter)
return ret;
}
+template<typename T>
+inline static std::vector<std::vector<T>> StringToArrayofArrays(
+ const std::string& str, char left_bracket, char right_bracket, char delimiter) {
+ std::vector<std::string> strs = SplitBrackets(str.c_str(), left_bracket, right_bracket);
+ std::vector<std::vector<T>> ret;
+ for (const auto& s : strs) {
+ ret.push_back(StringToArray<T>(s, delimiter));
+ }
+ return ret;
+}
+
template<typename T>
inline static std::vector<T> StringToArray(const std::string& str, int n) {
if (n == 0) {
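The two helpers above implement the CLI parsing path; a Python mirror (illustrative only, not part of the package) showing how `[0,1,2],[2,3]` becomes a list of integer lists:

```python
# Mirror of SplitBrackets: collect the substrings between each matched
# pair of bracket delimiters, skipping empty groups.
def split_brackets(s, left, right):
    groups, start, inside = [], 0, False
    for pos, ch in enumerate(s):
        if ch == left:
            inside, start = True, pos + 1
        elif ch == right and inside:
            if start < pos:
                groups.append(s[start:pos])
            inside = False
    return groups

# Mirror of StringToArrayofArrays<int>: parse each group with the delimiter.
def string_to_array_of_arrays(s, left, right, delimiter):
    return [[int(tok) for tok in group.split(delimiter)]
            for group in split_brackets(s, left, right)]

assert string_to_array_of_arrays("[0,1,2],[2,3]", "[", "]", ",") == [[0, 1, 2], [2, 3]]
```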
diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 07b7efd410e3..01a5f31e51b6 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -135,7 +135,12 @@ def param_dict_to_str(data):
pairs = []
for key, val in data.items():
if isinstance(val, (list, tuple, set)) or is_numpy_1d_array(val):
- pairs.append(str(key) + '=' + ','.join(map(str, val)))
+ def to_string(x):
+ if isinstance(x, list):
+ return "[{}]".format(','.join(map(str, x)))
+ else:
+ return str(x)
+ pairs.append(str(key) + '=' + ','.join(map(to_string, val)))
elif isinstance(val, string_type) or isinstance(val, numeric_types) or is_numeric(val):
pairs.append(str(key) + '=' + str(val))
elif val is not None:
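With the `param_dict_to_str` change above, a list-of-lists parameter is serialized into the bracketed CLI form that the C++ side parses; a quick standalone check of the same logic:

```python
def to_string(x):
    # Nested lists are serialized in bracketed CLI form; scalars unchanged.
    if isinstance(x, list):
        return "[{}]".format(','.join(map(str, x)))
    return str(x)

val = [[0, 1, 2], [2, 3]]
assert ','.join(map(to_string, val)) == "[0,1,2],[2,3]"
```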
diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp
index 25a0946c90de..7871bbfb086c 100644
--- a/src/boosting/gbdt.cpp
+++ b/src/boosting/gbdt.cpp
@@ -352,7 +352,7 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
bool should_continue = false;
for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
const size_t offset = static_cast<size_t>(cur_tree_id) * num_data_;
- std::unique_ptr<Tree> new_tree(new Tree(2));
+ std::unique_ptr<Tree> new_tree(new Tree(2, false));
if (class_need_train_[cur_tree_id] && train_data_->num_features() > 0) {
auto grad = gradients + offset;
auto hess = hessians + offset;
diff --git a/src/boosting/rf.hpp b/src/boosting/rf.hpp
index dd9be038aac9..5c90202a515e 100644
--- a/src/boosting/rf.hpp
+++ b/src/boosting/rf.hpp
@@ -109,7 +109,7 @@ class RF : public GBDT {
gradients = gradients_.data();
hessians = hessians_.data();
for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
- std::unique_ptr<Tree> new_tree(new Tree(2));
+ std::unique_ptr<Tree> new_tree(new Tree(2, false));
size_t offset = static_cast<size_t>(cur_tree_id)* num_data_;
if (class_need_train_[cur_tree_id]) {
auto grad = gradients + offset;
diff --git a/src/io/config.cpp b/src/io/config.cpp
index d31b7b839a3e..d569a7401e17 100644
--- a/src/io/config.cpp
+++ b/src/io/config.cpp
@@ -180,6 +180,14 @@ void Config::GetAucMuWeights() {
}
}
+void Config::GetInteractionConstraints() {
+ if (interaction_constraints == "") {
+ interaction_constraints_vector = std::vector<std::vector<int>>();
+ } else {
+ interaction_constraints_vector = Common::StringToArrayofArrays<int>(interaction_constraints, '[', ']', ',');
+ }
+}
+
void Config::Set(const std::unordered_map<std::string, std::string>& params) {
// generate seeds by seed.
if (GetInt(params, "seed", &seed)) {
@@ -204,6 +212,8 @@ void Config::Set(const std::unordered_map<std::string, std::string>& params) {
GetAucMuWeights();
+ GetInteractionConstraints();
+
// sort eval_at
std::sort(eval_at.begin(), eval_at.end());
diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp
index 5881571d16f3..807cad785021 100644
--- a/src/io/config_auto.cpp
+++ b/src/io/config_auto.cpp
@@ -230,6 +230,7 @@ const std::unordered_set<std::string>& Config::parameter_set() {
"cegb_penalty_feature_lazy",
"cegb_penalty_feature_coupled",
"path_smooth",
+ "interaction_constraints",
"verbosity",
"input_model",
"output_model",
@@ -454,6 +455,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::string>& params) {
GetString(params, "path_smooth", &path_smooth);
+ GetString(params, "interaction_constraints", &interaction_constraints);
+
diff --git a/src/io/tree.cpp b/src/io/tree.cpp
--- a/src/io/tree.cpp
+++ b/src/io/tree.cpp
@@ ... @@ Tree::Tree(int max_leaves, bool track_branch_features)
+ if (track_branch_features_) {
+ branch_features_ = std::vector<std::vector<int>>(max_leaves_);
+ }
// root is in the depth 0
leaf_depth_[0] = 0;
num_leaves_ = 1;
diff --git a/src/treelearner/col_sampler.hpp b/src/treelearner/col_sampler.hpp
index 68a98d159271..cd2884812552 100644
--- a/src/treelearner/col_sampler.hpp
+++ b/src/treelearner/col_sampler.hpp
@@ -13,6 +13,7 @@
#include <cstdint>
#include <memory>
+#include <unordered_set>
#include <vector>
namespace LightGBM {
@@ -23,6 +24,10 @@ class ColSampler {
fraction_bynode_(config->feature_fraction_bynode),
seed_(config->feature_fraction_seed),
random_(config->feature_fraction_seed) {
+ for (auto constraint : config->interaction_constraints_vector) {
+ std::unordered_set<int> constraint_set(constraint.begin(), constraint.end());
+ interaction_constraints_.push_back(constraint_set);
+ }
}
static int GetCnt(size_t total_cnt, double fraction) {
@@ -83,32 +88,87 @@ class ColSampler {
}
}
- std::vector<int8_t> GetByNode() {
- if (fraction_bynode_ >= 1.0f) {
- return std::vector<int8_t>(train_data_->num_features(), 1);
+ std::vector<int8_t> GetByNode(const Tree* tree, int leaf) {
+ // get interaction constraints for current branch
+ std::unordered_set<int> allowed_features;
+ if (!interaction_constraints_.empty()) {
+ std::vector<int> branch_features = tree->branch_features(leaf);
+ allowed_features.insert(branch_features.begin(), branch_features.end());
+ for (auto constraint : interaction_constraints_) {
+ int num_feat_found = 0;
+ if (branch_features.size() == 0) {
+ allowed_features.insert(constraint.begin(), constraint.end());
+ }
+ for (int feat : branch_features) {
+ if (constraint.count(feat) == 0) { break; }
+ ++num_feat_found;
+ if (num_feat_found == static_cast<int>(branch_features.size())) {
+ allowed_features.insert(constraint.begin(), constraint.end());
+ break;
+ }
+ }
+ }
}
+
std::vector<int8_t> ret(train_data_->num_features(), 0);
+ if (fraction_bynode_ >= 1.0f) {
+ if (interaction_constraints_.empty()) {
+ return std::vector<int8_t>(train_data_->num_features(), 1);
+ } else {
+ for (int feat : allowed_features) {
+ int inner_feat = train_data_->InnerFeatureIndex(feat);
+ ret[inner_feat] = 1;
+ }
+ return ret;
+ }
+ }
if (need_reset_bytree_) {
auto used_feature_cnt = GetCnt(used_feature_indices_.size(), fraction_bynode_);
+ std::vector<int>* allowed_used_feature_indices;
+ std::vector<int> filtered_feature_indices;
+ if (interaction_constraints_.empty()) {
+ allowed_used_feature_indices = &used_feature_indices_;
+ } else {
+ for (int feat_ind : used_feature_indices_) {
+ if (allowed_features.count(valid_feature_indices_[feat_ind]) == 1) {
+ filtered_feature_indices.push_back(feat_ind);
+ }
+ }
+ used_feature_cnt = std::min(used_feature_cnt, static_cast<int>(filtered_feature_indices.size()));
+ allowed_used_feature_indices = &filtered_feature_indices;
+ }
auto sampled_indices = random_.Sample(
- static_cast<int>(used_feature_indices_.size()), used_feature_cnt);
+ static_cast<int>((*allowed_used_feature_indices).size()), used_feature_cnt);
int omp_loop_size = static_cast<int>(sampled_indices.size());
#pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
for (int i = 0; i < omp_loop_size; ++i) {
int used_feature =
- valid_feature_indices_[used_feature_indices_[sampled_indices[i]]];
+ valid_feature_indices_[(*allowed_used_feature_indices)[sampled_indices[i]]];
int inner_feature_index = train_data_->InnerFeatureIndex(used_feature);
ret[inner_feature_index] = 1;
}
} else {
auto used_feature_cnt =
GetCnt(valid_feature_indices_.size(), fraction_bynode_);
+ std::vector<int>* allowed_valid_feature_indices;
+ std::vector<int> filtered_feature_indices;
+ if (interaction_constraints_.empty()) {
+ allowed_valid_feature_indices = &valid_feature_indices_;
+ } else {
+ for (int feat : valid_feature_indices_) {
+ if (allowed_features.count(feat) == 1) {
+ filtered_feature_indices.push_back(feat);
+ }
+ }
+ allowed_valid_feature_indices = &filtered_feature_indices;
+ used_feature_cnt = std::min(used_feature_cnt, static_cast<int>(filtered_feature_indices.size()));
+ }
auto sampled_indices = random_.Sample(
- static_cast<int>(valid_feature_indices_.size()), used_feature_cnt);
+ static_cast<int>((*allowed_valid_feature_indices).size()), used_feature_cnt);
int omp_loop_size = static_cast<int>(sampled_indices.size());
#pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
for (int i = 0; i < omp_loop_size; ++i) {
- int used_feature = valid_feature_indices_[sampled_indices[i]];
+ int used_feature = (*allowed_valid_feature_indices)[sampled_indices[i]];
int inner_feature_index = train_data_->InnerFeatureIndex(used_feature);
ret[inner_feature_index] = 1;
}
@@ -135,6 +195,8 @@ class ColSampler {
std::vector<int8_t> is_feature_used_;
std::vector<int> used_feature_indices_;
std::vector<int> valid_feature_indices_;
+ /*! \brief interaction constraints index in original (raw data) features */
+ std::vector<std::unordered_set<int>> interaction_constraints_;
};
} // namespace LightGBM
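The core rule in `GetByNode` above: a constraint contributes its features to the allowed set only if it contains every feature already used on the leaf's branch (at the root, where the branch is empty, every constraint matches). A compact Python sketch of that rule (illustrative only):

```python
# allowed = branch features (re-splitting on an ancestor feature is always
# allowed) plus the union of all constraints that cover the whole branch.
def allowed_features(branch_features, constraints):
    allowed = set(branch_features)
    for constraint in constraints:
        if set(branch_features) <= constraint:
            allowed |= constraint
    return allowed

constraints = [{0, 1, 2}, {2, 3}]
assert allowed_features([], constraints) == {0, 1, 2, 3}   # root: union of all
assert allowed_features([0], constraints) == {0, 1, 2}     # only the first matches
assert allowed_features([2], constraints) == {0, 1, 2, 3}  # both contain feature 2
```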
diff --git a/src/treelearner/data_parallel_tree_learner.cpp b/src/treelearner/data_parallel_tree_learner.cpp
index f91dcdc9b250..0d6f9df251b6 100644
--- a/src/treelearner/data_parallel_tree_learner.cpp
+++ b/src/treelearner/data_parallel_tree_learner.cpp
@@ -152,7 +152,7 @@ void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
}
template <typename TREELEARNER_T>
-void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
+void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits(const Tree* tree) {
TREELEARNER_T::ConstructHistograms(
this->col_sampler_.is_feature_used_bytree(), true);
// construct local histograms
@@ -169,17 +169,17 @@ void DataParallelTreeLearner::FindBestSplits() {
Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, sizeof(hist_t), block_start_.data(),
block_len_.data(), output_buffer_.data(), static_cast<comm_size_t>(output_buffer_.size()), &HistogramSumReducer);
this->FindBestSplitsFromHistograms(
- this->col_sampler_.is_feature_used_bytree(), true);
+ this->col_sampler_.is_feature_used_bytree(), true, tree);
}
template <typename TREELEARNER_T>
-void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool) {
+void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool, const Tree* tree) {
std::vector<SplitInfo> smaller_bests_per_thread(this->share_state_->num_threads);
std::vector<SplitInfo> larger_bests_per_thread(this->share_state_->num_threads);
std::vector<int8_t> smaller_node_used_features =
- this->col_sampler_.GetByNode();
+ this->col_sampler_.GetByNode(tree, this->smaller_leaf_splits_->leaf_index());
std::vector<int8_t> larger_node_used_features =
- this->col_sampler_.GetByNode();
+ this->col_sampler_.GetByNode(tree, this->larger_leaf_splits_->leaf_index());
OMP_INIT_EX();
#pragma omp parallel for schedule(static)
for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
diff --git a/src/treelearner/feature_parallel_tree_learner.cpp b/src/treelearner/feature_parallel_tree_learner.cpp
index 74df187d46b2..c5202f3d706d 100644
--- a/src/treelearner/feature_parallel_tree_learner.cpp
+++ b/src/treelearner/feature_parallel_tree_learner.cpp
@@ -57,8 +57,9 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
}
template <typename TREELEARNER_T>
-void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) {
- TREELEARNER_T::FindBestSplitsFromHistograms(is_feature_used, use_subtract);
+void FeatureParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(
+ const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree* tree) {
+ TREELEARNER_T::FindBestSplitsFromHistograms(is_feature_used, use_subtract, tree);
SplitInfo smaller_best_split, larger_best_split;
// get best split at smaller leaf
smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->leaf_index()];
diff --git a/src/treelearner/gpu_tree_learner.cpp b/src/treelearner/gpu_tree_learner.cpp
index f8f57e4b4236..43ccadfd176f 100644
--- a/src/treelearner/gpu_tree_learner.cpp
+++ b/src/treelearner/gpu_tree_learner.cpp
@@ -1055,8 +1055,8 @@ void GPUTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_used,
}
}
-void GPUTreeLearner::FindBestSplits() {
- SerialTreeLearner::FindBestSplits();
+void GPUTreeLearner::FindBestSplits(const Tree* tree) {
+ SerialTreeLearner::FindBestSplits(tree);
#if GPU_DEBUG >= 3
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
diff --git a/src/treelearner/gpu_tree_learner.h b/src/treelearner/gpu_tree_learner.h
index 428b2b5a5a06..a909c57cbadc 100644
--- a/src/treelearner/gpu_tree_learner.h
+++ b/src/treelearner/gpu_tree_learner.h
@@ -66,7 +66,7 @@ class GPUTreeLearner: public SerialTreeLearner {
protected:
void BeforeTrain() override;
bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override;
- void FindBestSplits() override;
+ void FindBestSplits(const Tree* tree) override;
void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;
void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
diff --git a/src/treelearner/parallel_tree_learner.h b/src/treelearner/parallel_tree_learner.h
index dde47d4989da..137697408e8d 100644
--- a/src/treelearner/parallel_tree_learner.h
+++ b/src/treelearner/parallel_tree_learner.h
@@ -31,7 +31,7 @@ class FeatureParallelTreeLearner: public TREELEARNER_T {
protected:
void BeforeTrain() override;
- void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
+ void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree* tree) override;
private:
/*! \brief rank of local machine */
@@ -59,8 +59,8 @@ class DataParallelTreeLearner: public TREELEARNER_T {
protected:
void BeforeTrain() override;
- void FindBestSplits() override;
- void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
+ void FindBestSplits(const Tree* tree) override;
+ void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree* tree) override;
void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;
inline data_size_t GetGlobalDataCountInLeaf(int leaf_idx) const override {
@@ -114,8 +114,8 @@ class VotingParallelTreeLearner: public TREELEARNER_T {
protected:
void BeforeTrain() override;
bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) override;
- void FindBestSplits() override;
- void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) override;
+ void FindBestSplits(const Tree* tree) override;
+ void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree* tree) override;
void Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) override;
inline data_size_t GetGlobalDataCountInLeaf(int leaf_idx) const override {
diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp
index a68a65ee91b4..db5cd0b4395d 100644
--- a/src/treelearner/serial_tree_learner.cpp
+++ b/src/treelearner/serial_tree_learner.cpp
@@ -163,7 +163,8 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
// some initial works before training
BeforeTrain();
- auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves));
+ bool track_branch_features = !(config_->interaction_constraints_vector.empty());
+ auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves, track_branch_features));
auto tree_prt = tree.get();
constraints_->ShareTreePointer(tree_prt);
@@ -179,7 +180,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
// some initial works before finding best split
if (BeforeFindBestSplit(tree_prt, left_leaf, right_leaf)) {
// find best threshold for every feature
- FindBestSplits();
+ FindBestSplits(tree_prt);
}
// Get a leaf with max split gain
int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_));
@@ -310,7 +311,7 @@ bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int
return true;
}
-void SerialTreeLearner::FindBestSplits() {
+void SerialTreeLearner::FindBestSplits(const Tree* tree) {
std::vector<int8_t> is_feature_used(num_features_, 0);
#pragma omp parallel for schedule(static, 256) if (num_features_ >= 512)
for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
@@ -324,7 +325,7 @@ void SerialTreeLearner::FindBestSplits() {
}
bool use_subtract = parent_leaf_histogram_array_ != nullptr;
ConstructHistograms(is_feature_used, use_subtract);
- FindBestSplitsFromHistograms(is_feature_used, use_subtract);
+ FindBestSplitsFromHistograms(is_feature_used, use_subtract, tree);
}
void SerialTreeLearner::ConstructHistograms(
@@ -353,13 +354,16 @@ void SerialTreeLearner::ConstructHistograms(
}
void SerialTreeLearner::FindBestSplitsFromHistograms(
- const std::vector<int8_t>& is_feature_used, bool use_subtract) {
+ const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree* tree) {
Common::FunctionTimer fun_timer(
"SerialTreeLearner::FindBestSplitsFromHistograms", global_timer);
std::vector<SplitInfo> smaller_best(share_state_->num_threads);
std::vector<SplitInfo> larger_best(share_state_->num_threads);
- std::vector<int8_t> smaller_node_used_features = col_sampler_.GetByNode();
- std::vector<int8_t> larger_node_used_features = col_sampler_.GetByNode();
+ std::vector<int8_t> smaller_node_used_features = col_sampler_.GetByNode(tree, smaller_leaf_splits_->leaf_index());
+ std::vector<int8_t> larger_node_used_features;
+ if (larger_leaf_splits_->leaf_index() >= 0) {
+ larger_node_used_features = col_sampler_.GetByNode(tree, larger_leaf_splits_->leaf_index());
+ }
OMP_INIT_EX();
// find splits
#pragma omp parallel for schedule(static) num_threads(share_state_->num_threads)
@@ -437,7 +441,7 @@ int32_t SerialTreeLearner::ForceSplits(Tree* tree, int* left_leaf,
// before processing next node from queue, store info for current left/right leaf
// store "best split" for left and right, even if they might be overwritten by forced split
if (BeforeFindBestSplit(tree, *left_leaf, *right_leaf)) {
- FindBestSplits();
+ FindBestSplits(tree);
}
// then, compute own splits
SplitInfo left_split;
diff --git a/src/treelearner/serial_tree_learner.h b/src/treelearner/serial_tree_learner.h
index 6a0d7f0e9a6d..e6ac8e3ad09c 100644
--- a/src/treelearner/serial_tree_learner.h
+++ b/src/treelearner/serial_tree_learner.h
@@ -134,11 +134,11 @@ class SerialTreeLearner: public TreeLearner {
*/
virtual bool BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf);
- virtual void FindBestSplits();
+ virtual void FindBestSplits(const Tree* tree);
virtual void ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract);
- virtual void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract);
+ virtual void FindBestSplitsFromHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree*);
/*!
* \brief Partition tree and data according best split.
@@ -196,7 +196,6 @@ class SerialTreeLearner: public TreeLearner {
std::unique_ptr<LeafSplits> smaller_leaf_splits_;
/*! \brief stores best thresholds for all features for larger leaf */
std::unique_ptr<LeafSplits> larger_leaf_splits_;
-
#ifdef USE_GPU
/*! \brief gradients of current iteration, ordered for cache optimized, aligned to 4K page */
std::vector<score_t, Common::AlignmentAllocator<score_t, kAlignedSize>> ordered_gradients_;
diff --git a/src/treelearner/voting_parallel_tree_learner.cpp b/src/treelearner/voting_parallel_tree_learner.cpp
index d14e0d614ce0..1c9c36ba8bbd 100644
--- a/src/treelearner/voting_parallel_tree_learner.cpp
+++ b/src/treelearner/voting_parallel_tree_learner.cpp
@@ -241,7 +241,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::CopyLocalHistogram(const std::vec
}
template <typename TREELEARNER_T>
-void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
+void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits(const Tree* tree) {
// use local data to find local best splits
std::vector<int8_t> is_feature_used(this->num_features_, 0);
#pragma omp parallel for schedule(static)
@@ -343,17 +343,17 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, sizeof(hist_t), block_start_.data(), block_len_.data(),
output_buffer_.data(), static_cast<comm_size_t>(output_buffer_.size()), &HistogramSumReducer);
- this->FindBestSplitsFromHistograms(is_feature_used, false);
+ this->FindBestSplitsFromHistograms(is_feature_used, false, tree);
}
template <typename TREELEARNER_T>
-void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool) {
+void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool, const Tree* tree) {
std::vector<SplitInfo> smaller_bests_per_thread(this->share_state_->num_threads);
std::vector<SplitInfo> larger_bests_per_thread(this->share_state_->num_threads);
std::vector<int8_t> smaller_node_used_features =
- this->col_sampler_.GetByNode();
+ this->col_sampler_.GetByNode(tree, this->smaller_leaf_splits_->leaf_index());
std::vector<int8_t> larger_node_used_features =
- this->col_sampler_.GetByNode();
+ this->col_sampler_.GetByNode(tree, this->larger_leaf_splits_->leaf_index());
// find best split from local aggregated histograms
OMP_INIT_EX();
diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index 36f532dd3a62..dc48fc9d3a39 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -2185,3 +2185,28 @@ def _imptcs_to_numpy(X, impcts_dict):
'split_gain', 'threshold', 'decision_type', 'missing_direction',
'missing_type', 'weight', 'count'):
self.assertIsNone(tree_df.loc[0, col])
+
+ def test_interaction_constraints(self):
+ X, y = load_boston(True)
+ num_features = X.shape[1]
+ train_data = lgb.Dataset(X, label=y)
+ # check that a constraint containing all features is equivalent to no constraint
+ params = {'verbose': -1,
+ 'seed': 0}
+ est = lgb.train(params, train_data, num_boost_round=10)
+ pred1 = est.predict(X)
+ est = lgb.train(dict(params, interaction_constraints=[list(range(num_features))]), train_data,
+ num_boost_round=10)
+ pred2 = est.predict(X)
+ np.testing.assert_allclose(pred1, pred2)
+ # check that a constraint partitioning the features reduces train accuracy
+ est = lgb.train(dict(params, interaction_constraints=[list(range(num_features // 2)),
+ list(range(num_features // 2, num_features))]),
+ train_data, num_boost_round=10)
+ pred3 = est.predict(X)
+ self.assertLess(mean_squared_error(y, pred1), mean_squared_error(y, pred3))
+ # check that constraints consisting of single features reduce accuracy further
+ est = lgb.train(dict(params, interaction_constraints=[[i] for i in range(num_features)]), train_data,
+ num_boost_round=10)
+ pred4 = est.predict(X)
+ self.assertLess(mean_squared_error(y, pred3), mean_squared_error(y, pred4))