diff --git a/docs/Parameters.rst b/docs/Parameters.rst index 341cdd487c71..8bf02c1ba4da 100644 --- a/docs/Parameters.rst +++ b/docs/Parameters.rst @@ -618,6 +618,32 @@ Learning Control Parameters - any two features can only appear in the same branch only if there exists a constraint containing both features +- ``tree_interaction_constraints`` :raw-html:`🔗︎`, default = ``""``, type = string + + - controls which features can appear in the same tree + + - by default tree interaction constraints are disabled; to enable them you can specify + + - for CLI, lists separated by commas, e.g. ``[0,1,2],[2,3]`` + + - for Python-package, list of lists, e.g. ``[[0, 1, 2], [2, 3]]`` + + - for R-package, list of character or numeric vectors, e.g. ``list(c("var1", "var2", "var3"), c("var3", "var4"))`` or ``list(c(1L, 2L, 3L), c(3L, 4L))``. Numeric vectors should use 1-based indexing, where ``1L`` is the first feature, ``2L`` is the second feature, etc + + - any two features can appear in the same tree only if there exists a constraint containing both features + +- ``max_tree_interactions`` :raw-html:`🔗︎`, default = ``0``, type = int, constraints: ``max_tree_interactions >= 0.0`` + + - controls how many features can appear in the same tree + + - by default (``max_tree_interactions = 0``) this limit is disabled + +- ``max_interactions`` :raw-html:`🔗︎`, default = ``0``, type = int, constraints: ``max_interactions >= 0.0`` + + - controls how many distinct feature interactions (sets of features used together in one tree) can be added to the final model + + - by default (``max_interactions = 0``) no limit is imposed on the number of interactions + - ``verbosity`` :raw-html:`🔗︎`, default = ``1``, type = int, aliases: ``verbose`` - controls the level of LightGBM's verbosity diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h index 6500cb77272d..b8693814749f 100644 --- a/include/LightGBM/config.h +++ b/include/LightGBM/config.h @@ -571,6 +571,24 @@ struct Config { // desc = any two features can only appear in the same branch only 
if there exists a constraint containing both features std::string interaction_constraints = ""; + // desc = controls which features can appear in the same tree + // desc = by default tree interaction constraints are disabled; to enable them you can specify + // descl2 = for CLI, lists separated by commas, e.g. ``[0,1,2],[2,3]`` + // descl2 = for Python-package, list of lists, e.g. ``[[0, 1, 2], [2, 3]]`` + // descl2 = for R-package, list of character or numeric vectors, e.g. ``list(c("var1", "var2", "var3"), c("var3", "var4"))`` or ``list(c(1L, 2L, 3L), c(3L, 4L))``. Numeric vectors should use 1-based indexing, where ``1L`` is the first feature, ``2L`` is the second feature, etc + // desc = any two features can appear in the same tree only if there exists a constraint containing both features + std::string tree_interaction_constraints = ""; + + // check = >= 0.0 + // desc = controls how many features can appear in the same tree + // desc = by default (``max_tree_interactions = 0``) this limit is disabled + int max_tree_interactions = 0; + + // check = >= 0.0 + // desc = controls how many distinct feature interactions (sets of features used together in one tree) can be added to the final model + // desc = by default (``max_interactions = 0``) no limit is imposed on the number of interactions + int max_interactions = 0; + // alias = verbose // desc = controls the level of LightGBM's verbosity // desc = ``< 0``: Fatal, ``= 0``: Error (Warning), ``= 1``: Info, ``> 1``: Debug @@ -1126,6 +1144,7 @@ struct Config { static const std::unordered_set& parameter_set(); std::vector> auc_mu_weights_matrix; std::vector> interaction_constraints_vector; + std::vector> tree_interaction_constraints_vector; static const std::unordered_map& ParameterTypes(); static const std::string DumpAliases(); @@ -1135,6 +1154,7 @@ struct Config { std::string SaveMembersToString() const; void GetAucMuWeights(); void GetInteractionConstraints(); + void GetTreeInteractionConstraints(); }; inline bool Config::GetString( diff --git
a/include/LightGBM/tree.h b/include/LightGBM/tree.h index 0c4a41f46a87..dc1303c00b6b 100644 --- a/include/LightGBM/tree.h +++ b/include/LightGBM/tree.h @@ -13,6 +13,7 @@ #include #include #include +#include namespace LightGBM { @@ -158,6 +159,11 @@ class Tree { /*! \brief Get features on leaf's branch*/ inline std::vector branch_features(int leaf) const { return branch_features_[leaf]; } + /*! \brief Get unique features used by the current tree*/ + std::set tree_features() const { + return tree_features_; + } + inline double split_gain(int split_idx) const { return split_gain_[split_idx]; } inline double internal_value(int node_idx) const { @@ -319,6 +325,8 @@ class Tree { inline bool is_linear() const { return is_linear_; } + inline bool is_tracking_branch_features() const { return track_branch_features_; } + #ifdef USE_CUDA inline bool is_cuda_tree() const { return is_cuda_tree_; } #endif // USE_CUDA @@ -520,6 +528,10 @@ class Tree { bool track_branch_features_; /*! \brief Features on leaf's branch, original index */ std::vector> branch_features_; + + /*! \brief Features used by the tree, original index */ + std::set tree_features_; + double shrinkage_; int max_depth_; /*! 
\brief Tree has linear model at each leaf */ @@ -579,6 +591,7 @@ inline void Tree::Split(int leaf, int feature, int real_feature, branch_features_[num_leaves_] = branch_features_[leaf]; branch_features_[num_leaves_].push_back(split_feature_[new_node_idx]); branch_features_[leaf].push_back(split_feature_[new_node_idx]); + tree_features_.insert(split_feature_[new_node_idx]); } } diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index e1779f0723be..8c9f0afe4a78 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -275,6 +275,7 @@ def train( booster.best_iteration = 0 # start training + interactions_used = set() for i in range(init_iteration, init_iteration + num_boost_round): for cb in callbacks_before_iter: cb(callback.CallbackEnv(model=booster, @@ -287,6 +288,15 @@ def train( booster.update(fobj=fobj) evaluation_result_list: List[_LGBM_BoosterEvalMethodResultType] = [] + if params.get("max_interactions", 0) > 0: + interaction_used = booster.dump_model(num_iteration=1, start_iteration=i)["tree_info"][0]["tree_features"] + interaction_used.sort() + interactions_used.add(tuple(interaction_used)) + + if len(interactions_used) == params["max_interactions"]: + params["tree_interaction_constraints"] = [list(feats) for feats in interactions_used] + params["max_interactions"] = 0 + booster.reset_parameter(params) # check evaluation result. 
if valid_sets is not None: if is_valid_contain_train: diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp index b75adab6d971..8ded1d9ea07c 100644 --- a/src/boosting/gbdt.cpp +++ b/src/boosting/gbdt.cpp @@ -244,6 +244,20 @@ void GBDT::Train(int snapshot_freq, const std::string& model_output_path) { std::string snapshot_out = model_output_path + ".snapshot_iter_" + std::to_string(iter + 1); SaveModelToFile(0, -1, config_->saved_feature_importance_type, snapshot_out.c_str()); } + + if (config_->max_interactions != 0) { + interactions_used.insert(models_[models_.size() - 1]->tree_features()); + } + + if (config_->max_interactions != 0 && static_cast(interactions_used.size()) >= config_->max_interactions) { + auto new_config = std::unique_ptr(new Config(*config_)); + new_config->tree_interaction_constraints_vector.clear(); + for (auto &inter_set : interactions_used) { + new_config->tree_interaction_constraints_vector.emplace_back(inter_set.begin(), inter_set.end()); + } + new_config->max_interactions = 0; + ResetConfig(new_config.release()); + } } } diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h index 28ebee446fad..d3dc80a837ec 100644 --- a/src/boosting/gbdt.h +++ b/src/boosting/gbdt.h @@ -23,6 +23,7 @@ #include #include #include +#include #include "cuda/cuda_score_updater.hpp" #include "score_updater.hpp" @@ -542,6 +543,8 @@ class GBDT : public GBDTBase { std::vector> best_msg_; /*! \brief Trained models(trees) */ std::vector> models_; + /*! \brief Set of set of features used in all the models */ + std::set> interactions_used; /*! \brief Max feature index of training data*/ int max_feature_idx_; /*! 
\brief Parser config file content */ diff --git a/src/io/config.cpp b/src/io/config.cpp index e25bb6d4fd70..4388bbc91136 100644 --- a/src/io/config.cpp +++ b/src/io/config.cpp @@ -232,13 +232,21 @@ void Config::GetAucMuWeights() { } void Config::GetInteractionConstraints() { - if (interaction_constraints == "") { + if (interaction_constraints.empty()) { interaction_constraints_vector = std::vector>(); } else { interaction_constraints_vector = Common::StringToArrayofArrays(interaction_constraints, '[', ']', ','); } } +void Config::GetTreeInteractionConstraints() { + if (tree_interaction_constraints.empty()) { + tree_interaction_constraints_vector = std::vector>(); + } else { + tree_interaction_constraints_vector = Common::StringToArrayofArrays(tree_interaction_constraints, '[', ']', ','); + } +} + void Config::Set(const std::unordered_map& params) { // generate seeds by seed. if (GetInt(params, "seed", &seed)) { @@ -269,6 +277,8 @@ void Config::Set(const std::unordered_map& params) { GetInteractionConstraints(); + GetTreeInteractionConstraints(); + // sort eval_at std::sort(eval_at.begin(), eval_at.end()); diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp index 394614af3f33..784e23e57fe9 100644 --- a/src/io/config_auto.cpp +++ b/src/io/config_auto.cpp @@ -246,6 +246,9 @@ const std::unordered_set& Config::parameter_set() { "cegb_penalty_feature_coupled", "path_smooth", "interaction_constraints", + "tree_interaction_constraints", + "max_tree_interactions", + "max_interactions", "verbosity", "input_model", "output_model", @@ -488,6 +491,14 @@ void Config::GetMembersFromString(const std::unordered_map>& Config::paramet {"cegb_penalty_feature_coupled", {}}, {"path_smooth", {}}, {"interaction_constraints", {}}, + {"tree_interaction_constraints", {}}, + {"max_tree_interactions", {}}, + {"max_interactions", {}}, {"verbosity", {"verbose"}}, {"input_model", {"model_input", "model_in"}}, {"output_model", {"model_output", "model_out"}}, @@ -989,6 +1006,9 @@ const 
std::unordered_map& Config::ParameterTypes() { {"cegb_penalty_feature_coupled", "vector"}, {"path_smooth", "double"}, {"interaction_constraints", "vector>"}, + {"tree_interaction_constraints", "string"}, + {"max_tree_interactions", "int"}, + {"max_interactions", "int"}, {"verbosity", "int"}, {"input_model", "string"}, {"output_model", "string"}, diff --git a/src/io/tree.cpp b/src/io/tree.cpp index 4312b4f65002..7313db2d61a4 100644 --- a/src/io/tree.cpp +++ b/src/io/tree.cpp @@ -415,6 +415,18 @@ std::string Tree::ToJSON() const { str_buf << "\"num_leaves\":" << num_leaves_ << "," << '\n'; str_buf << "\"num_cat\":" << num_cat_ << "," << '\n'; str_buf << "\"shrinkage\":" << shrinkage_ << "," << '\n'; + + auto feats_used = tree_features(); + size_t i = 0; + str_buf << "\"tree_features\":["; + for (int feat : feats_used) { + str_buf << feat; + if (i != feats_used.size() - 1) { + str_buf << ","; + } + ++i; + } + str_buf << "]," << '\n'; if (num_leaves_ == 1) { if (is_linear_) { str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << ", " << "\n"; diff --git a/src/treelearner/col_sampler.hpp b/src/treelearner/col_sampler.hpp index c70b07e50efa..b35880da4be4 100644 --- a/src/treelearner/col_sampler.hpp +++ b/src/treelearner/col_sampler.hpp @@ -15,6 +15,7 @@ #include #include #include +#include namespace LightGBM { class ColSampler { @@ -28,6 +29,13 @@ class ColSampler { std::unordered_set constraint_set(constraint.begin(), constraint.end()); interaction_constraints_.push_back(constraint_set); } + + for (auto constraint : config->tree_interaction_constraints_vector) { + std::unordered_set constraint_set(constraint.begin(), constraint.end()); + tree_interaction_constraints_.push_back(constraint_set); + } + + max_tree_interactions_ = config-> max_tree_interactions; } static int GetCnt(size_t total_cnt, double fraction) { @@ -68,6 +76,11 @@ class ColSampler { used_cnt_bytree_ = GetCnt(valid_feature_indices_.size(), fraction_bytree_); } + 
tree_interaction_constraints_.clear(); + for (auto constraint : config->tree_interaction_constraints_vector) { + std::unordered_set constraint_set(constraint.begin(), constraint.end()); + tree_interaction_constraints_.push_back(constraint_set); + } ResetByTree(); } @@ -88,31 +101,86 @@ class ColSampler { } } - std::vector GetByNode(const Tree* tree, int leaf) { - // get interaction constraints for current branch - std::unordered_set allowed_features; - if (!interaction_constraints_.empty()) { - std::vector branch_features = tree->branch_features(leaf); - allowed_features.insert(branch_features.begin(), branch_features.end()); - for (auto constraint : interaction_constraints_) { - int num_feat_found = 0; - if (branch_features.size() == 0) { - allowed_features.insert(constraint.begin(), constraint.end()); + void ComputeTreeAllowedFeatures(std::unordered_set *tree_allowed_features, std::set *tree_features) { + tree_allowed_features->insert(tree_features->begin(), tree_features->end()); + if (tree_interaction_constraints_.empty()) { + for (int i = 0; i < train_data_->num_features(); ++i) { + tree_allowed_features->insert(tree_allowed_features->end(), i); + } + } + for (auto constraint : tree_interaction_constraints_) { + int num_feat_found = 0; + if (tree_features->empty()) { + tree_allowed_features->insert(constraint.begin(), constraint.end()); + } + for (int feat : *tree_features) { + if (constraint.count(feat) == 0) { break; } + ++num_feat_found; + if (num_feat_found == static_cast(tree_features->size())) { + tree_allowed_features->insert(constraint.begin(), constraint.end()); + break; + } + } + } + } + + void ComputeBranchAllowedFeatures(const Tree *tree, int leaf, std::unordered_set *branch_allowed_features) { + if (!interaction_constraints_.empty()) { + std::vector branch_features = tree->branch_features(leaf); + for (auto constraint : interaction_constraints_) { + int num_feat_found = 0; + if (branch_features.empty()) { + 
(*branch_allowed_features).insert(constraint.begin(), constraint.end()); + } + for (int feat : branch_features) { + if (constraint.count(feat) == 0) { break; } + ++num_feat_found; + if (num_feat_found == static_cast(branch_features.size())) { + (*branch_allowed_features).insert(constraint.begin(), constraint.end()); + break; + } + } + } } - for (int feat : branch_features) { - if (constraint.count(feat) == 0) { break; } - ++num_feat_found; - if (num_feat_found == static_cast(branch_features.size())) { - allowed_features.insert(constraint.begin(), constraint.end()); - break; + } + + std::vector GetByNode(const Tree* tree, int leaf) { + // get interaction constraints for current tree + std::unordered_set tree_allowed_features; + if (!tree_interaction_constraints_.empty() || max_tree_interactions_ > 0) { + std::set tree_features = tree->tree_features(); + if (max_tree_interactions_ == 0 || tree_features.size() < (std::set::size_type) max_tree_interactions_) { + ComputeTreeAllowedFeatures(&tree_allowed_features, &tree_features); + } else { + for (int feat : tree_features) { + tree_allowed_features.insert(feat); } + } + } + // get interaction constraints for current branch + std::unordered_set branch_allowed_features; + + ComputeBranchAllowedFeatures(tree, leaf, &branch_allowed_features); + + + // intersect allowed features for branch and tree + std::unordered_set allowed_features; + + if ((tree_interaction_constraints_.empty() && max_tree_interactions_ == 0) && !interaction_constraints_.empty()) { + allowed_features.insert(branch_allowed_features.begin(), branch_allowed_features.end()); + } else if (!(tree_interaction_constraints_.empty() && max_tree_interactions_ == 0) && interaction_constraints_.empty()) { + allowed_features.insert(tree_allowed_features.begin(), tree_allowed_features.end()); + } else { + for (int element : tree_allowed_features) { + if (branch_allowed_features.count(element) > 0) { + allowed_features.insert(element); } } } std::vector 
ret(train_data_->num_features(), 0); if (fraction_bynode_ >= 1.0f) { - if (interaction_constraints_.empty()) { + if (interaction_constraints_.empty() && tree_interaction_constraints_.empty() && max_tree_interactions_ == 0) { return std::vector(train_data_->num_features(), 1); } else { for (int feat : allowed_features) { @@ -128,7 +196,7 @@ class ColSampler { auto used_feature_cnt = GetCnt(used_feature_indices_.size(), fraction_bynode_); std::vector* allowed_used_feature_indices; std::vector filtered_feature_indices; - if (interaction_constraints_.empty()) { + if (interaction_constraints_.empty() && tree_interaction_constraints_.empty() && max_tree_interactions_ == 0) { allowed_used_feature_indices = &used_feature_indices_; } else { for (int feat_ind : used_feature_indices_) { @@ -154,7 +222,7 @@ class ColSampler { GetCnt(valid_feature_indices_.size(), fraction_bynode_); std::vector* allowed_valid_feature_indices; std::vector filtered_feature_indices; - if (interaction_constraints_.empty()) { + if (interaction_constraints_.empty() && tree_interaction_constraints_.empty() && max_tree_interactions_ == 0) { allowed_valid_feature_indices = &valid_feature_indices_; } else { for (int feat : valid_feature_indices_) { @@ -199,6 +267,9 @@ class ColSampler { std::vector valid_feature_indices_; /*! 
\brief interaction constraints index in original (raw data) features */ std::vector> interaction_constraints_; + + std::vector> tree_interaction_constraints_; + int max_tree_interactions_; }; } // namespace LightGBM diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp index d5c5cc59ef3a..43b46c72bd28 100644 --- a/src/treelearner/serial_tree_learner.cpp +++ b/src/treelearner/serial_tree_learner.cpp @@ -196,7 +196,10 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians // some initial works before training BeforeTrain(); - bool track_branch_features = !(config_->interaction_constraints_vector.empty()); + bool track_branch_features = !(config_->interaction_constraints_vector.empty() + && config_->tree_interaction_constraints_vector.empty() + && config_->max_tree_interactions == 0 + && config_->max_interactions == 0); auto tree = std::unique_ptr(new Tree(config_->num_leaves, track_branch_features, false)); auto tree_ptr = tree.get(); constraints_->ShareTreePointer(tree_ptr); @@ -333,6 +336,21 @@ void SerialTreeLearner::BeforeTrain() { bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) { Common::FunctionTimer fun_timer("SerialTreeLearner::BeforeFindBestSplit", global_timer); + + if (tree->is_tracking_branch_features()) { + #pragma omp parallel for schedule(static) num_threads(OMP_NUM_THREADS()) + for (int i = 0; i < config_->num_leaves; ++i) { + int feat_index = best_split_per_leaf_[i].feature; + if (feat_index == -1) continue; + + int inner_feat_index = train_data_->InnerFeatureIndex(feat_index); + auto allowed_feature = col_sampler_.GetByNode(tree, i); + if (!allowed_feature[inner_feat_index]) { + RecomputeBestSplitForLeaf(tree, i, &best_split_per_leaf_[i]); + } + } + } + // check depth of current leaf if (config_->max_depth > 0) { // only need to check left leaf, since right leaf is in same level of left leaf @@ -1010,7 +1028,7 @@ double 
SerialTreeLearner::GetParentOutput(const Tree* tree, const LeafSplits* le return parent_output; } -void SerialTreeLearner::RecomputeBestSplitForLeaf(Tree* tree, int leaf, SplitInfo* split) { +void SerialTreeLearner::RecomputeBestSplitForLeaf(const Tree* tree, int leaf, SplitInfo* split) { FeatureHistogram* histogram_array_; if (!histogram_pool_.Get(leaf, &histogram_array_)) { Log::Warning( diff --git a/src/treelearner/serial_tree_learner.h b/src/treelearner/serial_tree_learner.h index 43ff6a4b1e13..9531a24c3e06 100644 --- a/src/treelearner/serial_tree_learner.h +++ b/src/treelearner/serial_tree_learner.h @@ -130,7 +130,7 @@ class SerialTreeLearner: public TreeLearner { void GetShareStates(const Dataset* dataset, bool is_constant_hessian, bool is_first_time); - void RecomputeBestSplitForLeaf(Tree* tree, int leaf, SplitInfo* split); + void RecomputeBestSplitForLeaf(const Tree* tree, int leaf, SplitInfo* split); /*! * \brief Some initial works before training diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index ccde38977d2d..681f77cb91b1 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -3459,6 +3459,149 @@ def test_interaction_constraints(): ) +@pytest.mark.skipif( + getenv("TASK", "") == "cuda", reason="Interaction constraints are not yet supported on the CUDA version" +) +def test_tree_interaction_constraints(): + def check_consistency(est, tree_interaction_constraints): + feat_to_index = {feat: i for i, feat in enumerate(est.feature_name())} + tree_df = est.trees_to_dataframe() + inter_found = set() + for tree_index in tree_df["tree_index"].unique(): + tree_df_per_index = tree_df[tree_df["tree_index"] == tree_index] + feat_used = [ + feat_to_index[feat] for feat in tree_df_per_index["split_feature"].unique() if feat is not None + ] + inter_found.add(tuple(sorted(feat_used))) + for feats_found in inter_found: + found = False + for real_contraints in 
tree_interaction_constraints: + if set(feats_found) <= set(real_contraints): + found = True + break + assert found is True + + X, y = make_synthetic_regression(n_samples=400, n_features=30) + num_features = X.shape[1] + train_data = lgb.Dataset(X, label=y) + # check that tree constraint containing all features is equivalent to no constraint + params = {"verbose": -1, "seed": 0} + est = lgb.train(params, train_data, num_boost_round=10) + pred1 = est.predict(X) + est = lgb.train( + dict(params, tree_interaction_constraints=[list(range(num_features))]), train_data, num_boost_round=10 + ) + pred2 = est.predict(X) + np.testing.assert_allclose(pred1, pred2) + + # check that each tree is composed exactly of 1 feature + tree_interaction_constraints = [[i] for i in range(num_features)] + new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints) + est = lgb.train(new_params, train_data, num_boost_round=100) + check_consistency(est, tree_interaction_constraints) + + # check that each tree is composed exactly of 2 features contained in the constrained set + tree_interaction_constraints = [[i, i + 1] for i in range(0, num_features - 1, 2)] + new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints) + est = lgb.train(new_params, train_data, num_boost_round=100) + check_consistency(est, tree_interaction_constraints) + + # check if tree features interaction constraints works with multiple set of features + tree_interaction_constraints = [list(range(i, i + 5)) for i in range(0, num_features - 5, 5)] + new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints) + est = lgb.train(new_params, train_data, num_boost_round=100) + check_consistency(est, tree_interaction_constraints) + + +@pytest.mark.skipif( + getenv("TASK", "") == "cuda", reason="Interaction constraints are not yet supported on the CUDA version" +) +def test_max_tree_interactions(): + def check_n_interactions(est): + feat_to_index = {feat: i 
for i, feat in enumerate(est.feature_name())} + tree_df = est.trees_to_dataframe() + max_n_interactions = 0 + for tree_index in tree_df["tree_index"].unique(): + tree_df_per_index = tree_df[tree_df["tree_index"] == tree_index] + feat_used = [ + feat_to_index[feat] for feat in tree_df_per_index["split_feature"].unique() if feat is not None + ] + max_n_interactions = max(max_n_interactions, len(feat_used)) + assert max_n_interactions <= est.params["max_tree_interactions"] + + X, y = make_synthetic_regression(n_samples=400, n_features=30) + train_data = lgb.Dataset(X, label=y) + # check that limiting the number of interaction to the number of features is equivalent to no constraint + params = {"verbose": -1, "seed": 0} + est = lgb.train(params, train_data, num_boost_round=100) + pred1 = est.predict(X) + est = lgb.train(dict(params, max_tree_interactions=400), train_data, num_boost_round=100) + pred2 = est.predict(X) + + check_n_interactions(est) + np.testing.assert_allclose(pred1, pred2) + + # check that the forest has only 1 interaction per tree + max_tree_interactions = 1 + new_params = dict(params, max_tree_interactions=max_tree_interactions) + est = lgb.train(new_params, train_data, num_boost_round=100) + check_n_interactions(est) + + # check that the forest has at most 10 features that interact in a tree + max_tree_interactions = 10 + new_params = dict(params, max_tree_interactions=max_tree_interactions) + est = lgb.train(new_params, train_data, num_boost_round=100) + check_n_interactions(est) + + +@pytest.mark.skipif( + getenv("TASK", "") == "cuda", reason="Interaction constraints are not yet supported on the CUDA version" +) +def test_max_interactions(): + def check_interactions(est, max_interactions): + feat_to_index = {feat: i for i, feat in enumerate(est.feature_name())} + tree_df = est.trees_to_dataframe() + inter_found = set() + for tree_index in tree_df["tree_index"].unique(): + tree_df_per_index = tree_df[tree_df["tree_index"] == tree_index] + feat_used 
= [ + feat_to_index[feat] for feat in tree_df_per_index["split_feature"].unique() if feat is not None + ] + add_inter = True + for inter in inter_found: + if set(feat_used) <= set(inter): # the interaction found is a subset of another interaction + add_inter = False + if add_inter: + inter_found.add(tuple(sorted(feat_used))) + assert len(inter_found) <= max_interactions + + X, y = make_synthetic_regression(n_samples=400, n_features=30) + train_data = lgb.Dataset(X, label=y) + # check that limiting the number of distinct interactions to the number of trees is equivalent to no constraint + params = {"verbose": -1, "seed": 0} + est = lgb.train(params, train_data, num_boost_round=100) + pred1 = est.predict(X) + + max_interactions = 100 + est = lgb.train(dict(params, max_interactions=max_interactions), train_data, num_boost_round=100) + pred2 = est.predict(X) + + check_interactions(est, max_interactions) + np.testing.assert_allclose(pred1, pred2) + + # check that the forest has only 1 interaction + max_interactions = 1 + new_params = dict(params, max_interactions=max_interactions) + est = lgb.train(new_params, train_data, num_boost_round=100) + check_interactions(est, max_interactions) + + # check that the forest has at most 10 interactions + max_interactions = 10 + new_params = dict(params, max_interactions=max_interactions) + est = lgb.train(new_params, train_data, num_boost_round=100) + check_interactions(est, max_interactions) + + def test_linear_trees_num_threads(): # check that number of threads does not affect result np.random.seed(0)