diff --git a/docs/Parameters.rst b/docs/Parameters.rst
index 341cdd487c71..8bf02c1ba4da 100644
--- a/docs/Parameters.rst
+++ b/docs/Parameters.rst
@@ -618,6 +618,32 @@ Learning Control Parameters
- any two features can only appear in the same branch only if there exists a constraint containing both features
+- ``tree_interaction_constraints`` :raw-html:`🔗︎`, default = ``""``, type = string
+
+ - controls which features can appear in the same tree
+
+ - by default interaction constraints are disabled, to enable them you can specify
+
+ - for CLI, lists separated by commas, e.g. ``[0,1,2],[2,3]``
+
+ - for Python-package, list of lists, e.g. ``[[0, 1, 2], [2, 3]]``
+
+ - for R-package, list of character or numeric vectors, e.g. ``list(c("var1", "var2", "var3"), c("var3", "var4"))`` or ``list(c(1L, 2L, 3L), c(3L, 4L))``. Numeric vectors should use 1-based indexing, where ``1L`` is the first feature, ``2L`` is the second feature, etc
+
+ - any two features can only appear in the same tree only if there exists a constraint containing both features
+
+- ``max_tree_interactions`` :raw-html:`🔗︎`, default = ``0``, type = int, constraints: ``max_tree_interactions >= 0.0``
+
+ - controls how many features can appear in the same tree
+
+ - by default (max_tree_interactions = 0) interaction constraints are disabled
+
+- ``max_interactions`` :raw-html:`🔗︎`, default = ``0``, type = int, constraints: ``max_interactions >= 0.0``
+
+ - controls how many features interactions can be added to the final model
+
+ - by default no limit is imposed on the interaction with max_interactions = 0
+
- ``verbosity`` :raw-html:`🔗︎`, default = ``1``, type = int, aliases: ``verbose``
- controls the level of LightGBM's verbosity
diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h
index 6500cb77272d..b8693814749f 100644
--- a/include/LightGBM/config.h
+++ b/include/LightGBM/config.h
@@ -571,6 +571,24 @@ struct Config {
// desc = any two features can only appear in the same branch only if there exists a constraint containing both features
std::string interaction_constraints = "";
+ // desc = controls which features can appear in the same tree
+ // desc = by default interaction constraints are disabled, to enable them you can specify
+ // descl2 = for CLI, lists separated by commas, e.g. ``[0,1,2],[2,3]``
+ // descl2 = for Python-package, list of lists, e.g. ``[[0, 1, 2], [2, 3]]``
+ // descl2 = for R-package, list of character or numeric vectors, e.g. ``list(c("var1", "var2", "var3"), c("var3", "var4"))`` or ``list(c(1L, 2L, 3L), c(3L, 4L))``. Numeric vectors should use 1-based indexing, where ``1L`` is the first feature, ``2L`` is the second feature, etc
+ // desc = any two features can only appear in the same tree only if there exists a constraint containing both features
+ std::string tree_interaction_constraints = "";
+
+ // check = >= 0.0
+ // desc = controls how many features can appear in the same tree
+ // desc = by default (max_tree_interactions = 0) interaction constraints are disabled
+ int max_tree_interactions = 0;
+
+ // check = >= 0.0
+ // desc = controls how many features interactions can be added to the final model
+ // desc = by default no limit is imposed on the interaction with max_interactions = 0
+ int max_interactions = 0;
+
// alias = verbose
// desc = controls the level of LightGBM's verbosity
// desc = ``< 0``: Fatal, ``= 0``: Error (Warning), ``= 1``: Info, ``> 1``: Debug
@@ -1126,6 +1144,7 @@ struct Config {
static const std::unordered_set& parameter_set();
std::vector> auc_mu_weights_matrix;
std::vector> interaction_constraints_vector;
+ std::vector> tree_interaction_constraints_vector;
static const std::unordered_map& ParameterTypes();
static const std::string DumpAliases();
@@ -1135,6 +1154,7 @@ struct Config {
std::string SaveMembersToString() const;
void GetAucMuWeights();
void GetInteractionConstraints();
+ void GetTreeInteractionConstraints();
};
inline bool Config::GetString(
diff --git a/include/LightGBM/tree.h b/include/LightGBM/tree.h
index 0c4a41f46a87..dc1303c00b6b 100644
--- a/include/LightGBM/tree.h
+++ b/include/LightGBM/tree.h
@@ -13,6 +13,7 @@
#include
#include
#include
+#include
namespace LightGBM {
@@ -158,6 +159,11 @@ class Tree {
/*! \brief Get features on leaf's branch*/
inline std::vector branch_features(int leaf) const { return branch_features_[leaf]; }
+ /*! \brief Get unique features used by the current tree*/
+ std::set tree_features() const {
+ return tree_features_;
+ }
+
inline double split_gain(int split_idx) const { return split_gain_[split_idx]; }
inline double internal_value(int node_idx) const {
@@ -319,6 +325,8 @@ class Tree {
inline bool is_linear() const { return is_linear_; }
+ inline bool is_tracking_branch_features() const { return track_branch_features_; }
+
#ifdef USE_CUDA
inline bool is_cuda_tree() const { return is_cuda_tree_; }
#endif // USE_CUDA
@@ -520,6 +528,10 @@ class Tree {
bool track_branch_features_;
/*! \brief Features on leaf's branch, original index */
std::vector> branch_features_;
+
+ /*! \brief Features used by the tree, original index */
+ std::set tree_features_;
+
double shrinkage_;
int max_depth_;
/*! \brief Tree has linear model at each leaf */
@@ -579,6 +591,7 @@ inline void Tree::Split(int leaf, int feature, int real_feature,
branch_features_[num_leaves_] = branch_features_[leaf];
branch_features_[num_leaves_].push_back(split_feature_[new_node_idx]);
branch_features_[leaf].push_back(split_feature_[new_node_idx]);
+ tree_features_.insert(split_feature_[new_node_idx]);
}
}
diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py
index e1779f0723be..8c9f0afe4a78 100644
--- a/python-package/lightgbm/engine.py
+++ b/python-package/lightgbm/engine.py
@@ -275,6 +275,7 @@ def train(
booster.best_iteration = 0
# start training
+ interactions_used = set()
for i in range(init_iteration, init_iteration + num_boost_round):
for cb in callbacks_before_iter:
cb(callback.CallbackEnv(model=booster,
@@ -287,6 +288,15 @@ def train(
booster.update(fobj=fobj)
evaluation_result_list: List[_LGBM_BoosterEvalMethodResultType] = []
+ if params.get("max_interactions", 0) > 0:
+ interaction_used = booster.dump_model(num_iteration=1, start_iteration=i)["tree_info"][0]["tree_features"]
+ interaction_used.sort()
+ interactions_used.add(tuple(interaction_used))
+
+ if len(interactions_used) == params["max_interactions"]:
+ params["tree_interaction_constraints"] = [list(feats) for feats in interactions_used]
+ params["max_interactions"] = 0
+ booster.reset_parameter(params)
# check evaluation result.
if valid_sets is not None:
if is_valid_contain_train:
diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp
index b75adab6d971..8ded1d9ea07c 100644
--- a/src/boosting/gbdt.cpp
+++ b/src/boosting/gbdt.cpp
@@ -244,6 +244,20 @@ void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
std::string snapshot_out = model_output_path + ".snapshot_iter_" + std::to_string(iter + 1);
SaveModelToFile(0, -1, config_->saved_feature_importance_type, snapshot_out.c_str());
}
+
+ if (config_->max_interactions != 0) {
+ interactions_used.insert(models_[models_.size() - 1]->tree_features());
+ }
+
+ if (config_->max_interactions != 0 && static_cast(interactions_used.size()) >= config_->max_interactions) {
+ auto new_config = std::unique_ptr(new Config(*config_));
+ new_config->tree_interaction_constraints_vector.clear();
+ for (auto &inter_set : interactions_used) {
+ new_config->tree_interaction_constraints_vector.emplace_back(inter_set.begin(), inter_set.end());
+ }
+ new_config->max_interactions = 0;
+ ResetConfig(new_config.release());
+ }
}
}
diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h
index 28ebee446fad..d3dc80a837ec 100644
--- a/src/boosting/gbdt.h
+++ b/src/boosting/gbdt.h
@@ -23,6 +23,7 @@
#include
#include
#include
+#include
#include "cuda/cuda_score_updater.hpp"
#include "score_updater.hpp"
@@ -542,6 +543,8 @@ class GBDT : public GBDTBase {
std::vector> best_msg_;
/*! \brief Trained models(trees) */
std::vector> models_;
+ /*! \brief Set of set of features used in all the models */
+ std::set> interactions_used;
/*! \brief Max feature index of training data*/
int max_feature_idx_;
/*! \brief Parser config file content */
diff --git a/src/io/config.cpp b/src/io/config.cpp
index e25bb6d4fd70..4388bbc91136 100644
--- a/src/io/config.cpp
+++ b/src/io/config.cpp
@@ -232,13 +232,21 @@ void Config::GetAucMuWeights() {
}
void Config::GetInteractionConstraints() {
- if (interaction_constraints == "") {
+ if (interaction_constraints.empty()) {
interaction_constraints_vector = std::vector>();
} else {
interaction_constraints_vector = Common::StringToArrayofArrays(interaction_constraints, '[', ']', ',');
}
}
+void Config::GetTreeInteractionConstraints() {
+ if (tree_interaction_constraints.empty()) {
+ tree_interaction_constraints_vector = std::vector>();
+ } else {
+ tree_interaction_constraints_vector = Common::StringToArrayofArrays(tree_interaction_constraints, '[', ']', ',');
+ }
+}
+
void Config::Set(const std::unordered_map& params) {
// generate seeds by seed.
if (GetInt(params, "seed", &seed)) {
@@ -269,6 +277,8 @@ void Config::Set(const std::unordered_map& params) {
GetInteractionConstraints();
+ GetTreeInteractionConstraints();
+
// sort eval_at
std::sort(eval_at.begin(), eval_at.end());
diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp
index 394614af3f33..784e23e57fe9 100644
--- a/src/io/config_auto.cpp
+++ b/src/io/config_auto.cpp
@@ -246,6 +246,9 @@ const std::unordered_set& Config::parameter_set() {
"cegb_penalty_feature_coupled",
"path_smooth",
"interaction_constraints",
+ "tree_interaction_constraints",
+ "max_tree_interactions",
+ "max_interactions",
"verbosity",
"input_model",
"output_model",
@@ -488,6 +491,14 @@ void Config::GetMembersFromString(const std::unordered_map>& Config::paramet
{"cegb_penalty_feature_coupled", {}},
{"path_smooth", {}},
{"interaction_constraints", {}},
+ {"tree_interaction_constraints", {}},
+ {"max_tree_interactions", {}},
+ {"max_interactions", {}},
{"verbosity", {"verbose"}},
{"input_model", {"model_input", "model_in"}},
{"output_model", {"model_output", "model_out"}},
@@ -989,6 +1006,9 @@ const std::unordered_map& Config::ParameterTypes() {
{"cegb_penalty_feature_coupled", "vector"},
{"path_smooth", "double"},
{"interaction_constraints", "vector>"},
+ {"tree_interaction_constraints", "string"},
+ {"max_tree_interactions", "int"},
+ {"max_interactions", "int"},
{"verbosity", "int"},
{"input_model", "string"},
{"output_model", "string"},
diff --git a/src/io/tree.cpp b/src/io/tree.cpp
index 4312b4f65002..7313db2d61a4 100644
--- a/src/io/tree.cpp
+++ b/src/io/tree.cpp
@@ -415,6 +415,18 @@ std::string Tree::ToJSON() const {
str_buf << "\"num_leaves\":" << num_leaves_ << "," << '\n';
str_buf << "\"num_cat\":" << num_cat_ << "," << '\n';
str_buf << "\"shrinkage\":" << shrinkage_ << "," << '\n';
+
+ auto feats_used = tree_features();
+ size_t i = 0;
+ str_buf << "\"tree_features\":[";
+ for (int feat : feats_used) {
+ str_buf << feat;
+ if (i != feats_used.size() - 1) {
+ str_buf << ",";
+ }
+ ++i;
+ }
+ str_buf << "]," << '\n';
if (num_leaves_ == 1) {
if (is_linear_) {
str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << ", " << "\n";
diff --git a/src/treelearner/col_sampler.hpp b/src/treelearner/col_sampler.hpp
index c70b07e50efa..b35880da4be4 100644
--- a/src/treelearner/col_sampler.hpp
+++ b/src/treelearner/col_sampler.hpp
@@ -15,6 +15,7 @@
#include
#include
#include
+#include
namespace LightGBM {
class ColSampler {
@@ -28,6 +29,13 @@ class ColSampler {
std::unordered_set constraint_set(constraint.begin(), constraint.end());
interaction_constraints_.push_back(constraint_set);
}
+
+ for (auto constraint : config->tree_interaction_constraints_vector) {
+ std::unordered_set constraint_set(constraint.begin(), constraint.end());
+ tree_interaction_constraints_.push_back(constraint_set);
+ }
+
+ max_tree_interactions_ = config-> max_tree_interactions;
}
static int GetCnt(size_t total_cnt, double fraction) {
@@ -68,6 +76,11 @@ class ColSampler {
used_cnt_bytree_ =
GetCnt(valid_feature_indices_.size(), fraction_bytree_);
}
+ tree_interaction_constraints_.clear();
+ for (auto constraint : config->tree_interaction_constraints_vector) {
+ std::unordered_set constraint_set(constraint.begin(), constraint.end());
+ tree_interaction_constraints_.push_back(constraint_set);
+ }
ResetByTree();
}
@@ -88,31 +101,86 @@ class ColSampler {
}
}
- std::vector GetByNode(const Tree* tree, int leaf) {
- // get interaction constraints for current branch
- std::unordered_set allowed_features;
- if (!interaction_constraints_.empty()) {
- std::vector branch_features = tree->branch_features(leaf);
- allowed_features.insert(branch_features.begin(), branch_features.end());
- for (auto constraint : interaction_constraints_) {
- int num_feat_found = 0;
- if (branch_features.size() == 0) {
- allowed_features.insert(constraint.begin(), constraint.end());
+ void ComputeTreeAllowedFeatures(std::unordered_set *tree_allowed_features, std::set *tree_features) {
+ tree_allowed_features->insert(tree_features->begin(), tree_features->end());
+ if (tree_interaction_constraints_.empty()) {
+ for (int i = 0; i < train_data_->num_features(); ++i) {
+ tree_allowed_features->insert(tree_allowed_features->end(), i);
+ }
+ }
+ for (auto constraint : tree_interaction_constraints_) {
+ int num_feat_found = 0;
+ if (tree_features->empty()) {
+ tree_allowed_features->insert(constraint.begin(), constraint.end());
+ }
+ for (int feat : *tree_features) {
+ if (constraint.count(feat) == 0) { break; }
+ ++num_feat_found;
+ if (num_feat_found == static_cast(tree_features->size())) {
+ tree_allowed_features->insert(constraint.begin(), constraint.end());
+ break;
+ }
+ }
+ }
+ }
+
+ void ComputeBranchAllowedFeatures(const Tree *tree, int leaf, std::unordered_set *branch_allowed_features) {
+ if (!interaction_constraints_.empty()) {
+ std::vector branch_features = tree->branch_features(leaf);
+ for (auto constraint : interaction_constraints_) {
+ int num_feat_found = 0;
+ if (branch_features.empty()) {
+ (*branch_allowed_features).insert(constraint.begin(), constraint.end());
+ }
+ for (int feat : branch_features) {
+ if (constraint.count(feat) == 0) { break; }
+ ++num_feat_found;
+ if (num_feat_found == static_cast(branch_features.size())) {
+ (*branch_allowed_features).insert(constraint.begin(), constraint.end());
+ break;
+ }
+ }
+ }
}
- for (int feat : branch_features) {
- if (constraint.count(feat) == 0) { break; }
- ++num_feat_found;
- if (num_feat_found == static_cast(branch_features.size())) {
- allowed_features.insert(constraint.begin(), constraint.end());
- break;
+ }
+
+ std::vector GetByNode(const Tree* tree, int leaf) {
+ // get interaction constraints for current tree
+ std::unordered_set tree_allowed_features;
+ if (!tree_interaction_constraints_.empty() || max_tree_interactions_ > 0) {
+ std::set tree_features = tree->tree_features();
+ if (max_tree_interactions_ == 0 || tree_features.size() < (std::set::size_type) max_tree_interactions_) {
+ ComputeTreeAllowedFeatures(&tree_allowed_features, &tree_features);
+ } else {
+ for (int feat : tree_features) {
+ tree_allowed_features.insert(feat);
}
+ }
+ }
+ // get interaction constraints for current branch
+ std::unordered_set branch_allowed_features;
+
+ ComputeBranchAllowedFeatures(tree, leaf, &branch_allowed_features);
+
+
+ // intersect allowed features for branch and tree
+ std::unordered_set allowed_features;
+
+ if ((tree_interaction_constraints_.empty() && max_tree_interactions_ == 0) && !interaction_constraints_.empty()) {
+ allowed_features.insert(branch_allowed_features.begin(), branch_allowed_features.end());
+ } else if (!(tree_interaction_constraints_.empty() && max_tree_interactions_ == 0) && interaction_constraints_.empty()) {
+ allowed_features.insert(tree_allowed_features.begin(), tree_allowed_features.end());
+ } else {
+ for (int element : tree_allowed_features) {
+ if (branch_allowed_features.count(element) > 0) {
+ allowed_features.insert(element);
}
}
}
std::vector ret(train_data_->num_features(), 0);
if (fraction_bynode_ >= 1.0f) {
- if (interaction_constraints_.empty()) {
+ if (interaction_constraints_.empty() && tree_interaction_constraints_.empty() && max_tree_interactions_ == 0) {
return std::vector(train_data_->num_features(), 1);
} else {
for (int feat : allowed_features) {
@@ -128,7 +196,7 @@ class ColSampler {
auto used_feature_cnt = GetCnt(used_feature_indices_.size(), fraction_bynode_);
std::vector* allowed_used_feature_indices;
std::vector filtered_feature_indices;
- if (interaction_constraints_.empty()) {
+ if (interaction_constraints_.empty() && tree_interaction_constraints_.empty() && max_tree_interactions_ == 0) {
allowed_used_feature_indices = &used_feature_indices_;
} else {
for (int feat_ind : used_feature_indices_) {
@@ -154,7 +222,7 @@ class ColSampler {
GetCnt(valid_feature_indices_.size(), fraction_bynode_);
std::vector* allowed_valid_feature_indices;
std::vector filtered_feature_indices;
- if (interaction_constraints_.empty()) {
+ if (interaction_constraints_.empty() && tree_interaction_constraints_.empty() && max_tree_interactions_ == 0) {
allowed_valid_feature_indices = &valid_feature_indices_;
} else {
for (int feat : valid_feature_indices_) {
@@ -199,6 +267,9 @@ class ColSampler {
std::vector valid_feature_indices_;
/*! \brief interaction constraints index in original (raw data) features */
std::vector> interaction_constraints_;
+
+ std::vector> tree_interaction_constraints_;
+ int max_tree_interactions_;
};
} // namespace LightGBM
diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp
index d5c5cc59ef3a..43b46c72bd28 100644
--- a/src/treelearner/serial_tree_learner.cpp
+++ b/src/treelearner/serial_tree_learner.cpp
@@ -196,7 +196,10 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
// some initial works before training
BeforeTrain();
- bool track_branch_features = !(config_->interaction_constraints_vector.empty());
+ bool track_branch_features = !(config_->interaction_constraints_vector.empty()
+ && config_->tree_interaction_constraints_vector.empty()
+ && config_->max_tree_interactions == 0
+ && config_->max_interactions == 0);
auto tree = std::unique_ptr(new Tree(config_->num_leaves, track_branch_features, false));
auto tree_ptr = tree.get();
constraints_->ShareTreePointer(tree_ptr);
@@ -333,6 +336,21 @@ void SerialTreeLearner::BeforeTrain() {
bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) {
Common::FunctionTimer fun_timer("SerialTreeLearner::BeforeFindBestSplit", global_timer);
+
+ if (tree->is_tracking_branch_features()) {
+ #pragma omp parallel for schedule(static) num_threads(OMP_NUM_THREADS())
+ for (int i = 0; i < config_->num_leaves; ++i) {
+ int feat_index = best_split_per_leaf_[i].feature;
+ if (feat_index == -1) continue;
+
+ int inner_feat_index = train_data_->InnerFeatureIndex(feat_index);
+ auto allowed_feature = col_sampler_.GetByNode(tree, i);
+ if (!allowed_feature[inner_feat_index]) {
+ RecomputeBestSplitForLeaf(tree, i, &best_split_per_leaf_[i]);
+ }
+ }
+ }
+
// check depth of current leaf
if (config_->max_depth > 0) {
// only need to check left leaf, since right leaf is in same level of left leaf
@@ -1010,7 +1028,7 @@ double SerialTreeLearner::GetParentOutput(const Tree* tree, const LeafSplits* le
return parent_output;
}
-void SerialTreeLearner::RecomputeBestSplitForLeaf(Tree* tree, int leaf, SplitInfo* split) {
+void SerialTreeLearner::RecomputeBestSplitForLeaf(const Tree* tree, int leaf, SplitInfo* split) {
FeatureHistogram* histogram_array_;
if (!histogram_pool_.Get(leaf, &histogram_array_)) {
Log::Warning(
diff --git a/src/treelearner/serial_tree_learner.h b/src/treelearner/serial_tree_learner.h
index 43ff6a4b1e13..9531a24c3e06 100644
--- a/src/treelearner/serial_tree_learner.h
+++ b/src/treelearner/serial_tree_learner.h
@@ -130,7 +130,7 @@ class SerialTreeLearner: public TreeLearner {
void GetShareStates(const Dataset* dataset, bool is_constant_hessian, bool is_first_time);
- void RecomputeBestSplitForLeaf(Tree* tree, int leaf, SplitInfo* split);
+ void RecomputeBestSplitForLeaf(const Tree* tree, int leaf, SplitInfo* split);
/*!
* \brief Some initial works before training
diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index ccde38977d2d..681f77cb91b1 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -3459,6 +3459,149 @@ def test_interaction_constraints():
)
+@pytest.mark.skipif(
+ getenv("TASK", "") == "cuda", reason="Interaction constraints are not yet supported on the CUDA version"
+)
+def test_tree_interaction_constraints():
+ def check_consistency(est, tree_interaction_constraints):
+ feat_to_index = {feat: i for i, feat in enumerate(est.feature_name())}
+ tree_df = est.trees_to_dataframe()
+ inter_found = set()
+ for tree_index in tree_df["tree_index"].unique():
+ tree_df_per_index = tree_df[tree_df["tree_index"] == tree_index]
+ feat_used = [
+ feat_to_index[feat] for feat in tree_df_per_index["split_feature"].unique() if feat is not None
+ ]
+ inter_found.add(tuple(sorted(feat_used)))
+ for feats_found in inter_found:
+ found = False
+ for real_contraints in tree_interaction_constraints:
+ if set(feats_found) <= set(real_contraints):
+ found = True
+ break
+ assert found is True
+
+ X, y = make_synthetic_regression(n_samples=400, n_features=30)
+ num_features = X.shape[1]
+ train_data = lgb.Dataset(X, label=y)
+ # check that tree constraint containing all features is equivalent to no constraint
+ params = {"verbose": -1, "seed": 0}
+ est = lgb.train(params, train_data, num_boost_round=10)
+ pred1 = est.predict(X)
+ est = lgb.train(
+ dict(params, tree_interaction_constraints=[list(range(num_features))]), train_data, num_boost_round=10
+ )
+ pred2 = est.predict(X)
+ np.testing.assert_allclose(pred1, pred2)
+
+ # check that each tree is composed exactly of 1 feature
+ tree_interaction_constraints = [[i] for i in range(num_features)]
+ new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints)
+ est = lgb.train(new_params, train_data, num_boost_round=100)
+ check_consistency(est, tree_interaction_constraints)
+
+ # check that each tree is composed exactly of 2 features contained in the constrained set
+ tree_interaction_constraints = [[i, i + 1] for i in range(0, num_features - 1, 2)]
+ new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints)
+ est = lgb.train(new_params, train_data, num_boost_round=100)
+ check_consistency(est, tree_interaction_constraints)
+
+ # check if tree features interaction constraints works with multiple set of features
+ tree_interaction_constraints = [list(range(i, i + 5)) for i in range(0, num_features - 5, 5)]
+ new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints)
+ est = lgb.train(new_params, train_data, num_boost_round=100)
+ check_consistency(est, tree_interaction_constraints)
+
+
+@pytest.mark.skipif(
+ getenv("TASK", "") == "cuda", reason="Interaction constraints are not yet supported on the CUDA version"
+)
+def test_max_tree_interactions():
+ def check_n_interactions(est):
+ feat_to_index = {feat: i for i, feat in enumerate(est.feature_name())}
+ tree_df = est.trees_to_dataframe()
+ max_n_interactions = 0
+ for tree_index in tree_df["tree_index"].unique():
+ tree_df_per_index = tree_df[tree_df["tree_index"] == tree_index]
+ feat_used = [
+ feat_to_index[feat] for feat in tree_df_per_index["split_feature"].unique() if feat is not None
+ ]
+ max_n_interactions = max(max_n_interactions, len(feat_used))
+ assert max_n_interactions <= est.params["max_tree_interactions"]
+
+ X, y = make_synthetic_regression(n_samples=400, n_features=30)
+ train_data = lgb.Dataset(X, label=y)
+ # check that limiting the number of interaction to the number of features is equivalent to no constraint
+ params = {"verbose": -1, "seed": 0}
+ est = lgb.train(params, train_data, num_boost_round=100)
+ pred1 = est.predict(X)
+ est = lgb.train(dict(params, max_tree_interactions=400), train_data, num_boost_round=100)
+ pred2 = est.predict(X)
+
+ check_n_interactions(est)
+ np.testing.assert_allclose(pred1, pred2)
+
+ # check that the forest has only 1 interaction per tree
+ max_tree_interactions = 1
+ new_params = dict(params, max_tree_interactions=max_tree_interactions)
+ est = lgb.train(new_params, train_data, num_boost_round=100)
+ check_n_interactions(est)
+
+ # check that the forest has at most 10 features that interact in a tree
+ max_tree_interactions = 10
+ new_params = dict(params, max_tree_interactions=max_tree_interactions)
+ est = lgb.train(new_params, train_data, num_boost_round=100)
+ check_n_interactions(est)
+
+
+@pytest.mark.skipif(
+ getenv("TASK", "") == "cuda", reason="Interaction constraints are not yet supported on the CUDA version"
+)
+def test_max_interactions():
+ def check_interactions(est, max_interactions):
+ feat_to_index = {feat: i for i, feat in enumerate(est.feature_name())}
+ tree_df = est.trees_to_dataframe()
+ inter_found = set()
+ for tree_index in tree_df["tree_index"].unique():
+ tree_df_per_index = tree_df[tree_df["tree_index"] == tree_index]
+ feat_used = [
+ feat_to_index[feat] for feat in tree_df_per_index["split_feature"].unique() if feat is not None
+ ]
+ add_inter = True
+ for inter in inter_found:
+ if set(feat_used) <= set(inter): # the interaction found is a subset of another interaction
+ add_inter = False
+ if add_inter:
+ inter_found.add(tuple(sorted(feat_used)))
+ assert len(inter_found) <= max_interactions
+
+ X, y = make_synthetic_regression(n_samples=400, n_features=30)
+ train_data = lgb.Dataset(X, label=y)
+ # check that limiting the number of distinct interactions to the number of trees is equivalent to no constraint
+ params = {"verbose": -1, "seed": 0}
+ est = lgb.train(params, train_data, num_boost_round=100)
+ pred1 = est.predict(X)
+
+ max_interactions = 100
+ est = lgb.train(dict(params, max_interactions=max_interactions), train_data, num_boost_round=100)
+ pred2 = est.predict(X)
+
+ check_interactions(est, max_interactions)
+ np.testing.assert_allclose(pred1, pred2)
+
+ # check that the forest has only 1 interaction
+ max_interactions = 1
+ new_params = dict(params, max_interactions=max_interactions)
+ est = lgb.train(new_params, train_data, num_boost_round=100)
+ check_interactions(est, max_interactions)
+
+ # check that the forest has at most 10 interactions
+ max_interactions = 10
+ new_params = dict(params, max_interactions=max_interactions)
+ est = lgb.train(new_params, train_data, num_boost_round=100)
+ check_interactions(est, max_interactions)
+
+
def test_linear_trees_num_threads():
# check that number of threads does not affect result
np.random.seed(0)