From f47f82301601e234536dcb6592253dce7ad84be6 Mon Sep 17 00:00:00 2001
From: Alberto Veneri <veneri.alberto@gmail.com>
Date: Thu, 7 Apr 2022 16:00:26 +0200
Subject: [PATCH 01/21] First version of the new parameter
 "tree_interaction_constraints"

---
 include/LightGBM/config.h                | 10 ++++
 include/LightGBM/tree.h                  | 10 ++++
 src/io/config.cpp                        | 12 ++++-
 src/io/config_auto.cpp                   |  5 ++
 src/treelearner/col_sampler.hpp          | 59 ++++++++++++++++++++----
 src/treelearner/serial_tree_learner.cpp  | 18 +++++++-
 src/treelearner/serial_tree_learner.h    |  2 +-
 tests/python_package_test/test_engine.py | 52 ++++++++++++++++++++-
 8 files changed, 155 insertions(+), 13 deletions(-)

diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h
index 83c228fe5dc6..3898a736c8dc 100644
--- a/include/LightGBM/config.h
+++ b/include/LightGBM/config.h
@@ -543,6 +543,14 @@ struct Config {
   // desc = any two features can only appear in the same branch only if there exists a constraint containing both features
   std::string interaction_constraints = "";
 
+  // desc = controls which features can appear in the same tree
+  // desc = by default tree interaction constraints are disabled, to enable them you can specify
+  // descl2 = for CLI, lists separated by commas, e.g. ``[0,1,2],[2,3]``
+  // descl2 = for Python-package, list of lists, e.g. ``[[0, 1, 2], [2, 3]]``
+  // descl2 = for R-package, list of character or numeric vectors, e.g. ``list(c("var1", "var2", "var3"), c("var3", "var4"))`` or ``list(c(1L, 2L, 3L), c(3L, 4L))``. Numeric vectors should use 1-based indexing, where ``1L`` is the first feature, ``2L`` is the second feature, etc
+  // desc = any two features can appear in the same tree only if there exists a constraint containing both features
+  std::string tree_interaction_constraints = "";
+
   // alias = verbose
   // desc = controls the level of LightGBM's verbosity
   // desc = ``< 0``: Fatal, ``= 0``: Error (Warning), ``= 1``: Info, ``> 1``: Debug
@@ -1065,6 +1073,7 @@ struct Config {
   static const std::unordered_set<std::string>& parameter_set();
   std::vector<std::vector<double>> auc_mu_weights_matrix;
   std::vector<std::vector<int>> interaction_constraints_vector;
+  std::vector<std::vector<int>> tree_interaction_constraints_vector;
   static const std::string DumpAliases();
 
  private:
@@ -1073,6 +1082,7 @@ struct Config {
   std::string SaveMembersToString() const;
   void GetAucMuWeights();
   void GetInteractionConstraints();
+  void GetTreeInteractionConstraints();
 };
 
 inline bool Config::GetString(
diff --git a/include/LightGBM/tree.h b/include/LightGBM/tree.h
index 6ff0370e2ea6..906dcebdde1a 100644
--- a/include/LightGBM/tree.h
+++ b/include/LightGBM/tree.h
@@ -158,6 +158,10 @@ class Tree {
   /*! \brief Get features on leaf's branch*/
   inline std::vector<int> branch_features(int leaf) const { return branch_features_[leaf]; }
 
+  std::set<int> tree_features() const {
+    return tree_features_;
+  }
+
   inline double split_gain(int split_idx) const { return split_gain_[split_idx]; }
 
   inline double internal_value(int node_idx) const {
@@ -520,6 +524,10 @@ class Tree {
   bool track_branch_features_;
   /*! \brief Features on leaf's branch, original index */
   std::vector<std::vector<int>> branch_features_;
+
+  /*! \brief Features used by the tree, original index */
+  std::set<int> tree_features_;
+
   double shrinkage_;
   int max_depth_;
   /*! \brief Tree has linear model at each leaf */
@@ -579,7 +587,9 @@ inline void Tree::Split(int leaf, int feature, int real_feature,
     branch_features_[num_leaves_] = branch_features_[leaf];
     branch_features_[num_leaves_].push_back(split_feature_[new_node_idx]);
     branch_features_[leaf].push_back(split_feature_[new_node_idx]);
+    tree_features_.insert(split_feature_[new_node_idx]);
   }
+
 }
 
 inline double Tree::Predict(const double* feature_values) const {
diff --git a/src/io/config.cpp b/src/io/config.cpp
index 090ce79b830f..a9d0bde77d9b 100644
--- a/src/io/config.cpp
+++ b/src/io/config.cpp
@@ -185,13 +185,21 @@ void Config::GetAucMuWeights() {
 }
 
 void Config::GetInteractionConstraints() {
-  if (interaction_constraints == "") {
+  if (interaction_constraints.empty()) {
     interaction_constraints_vector = std::vector<std::vector<int>>();
   } else {
     interaction_constraints_vector = Common::StringToArrayofArrays<int>(interaction_constraints, '[', ']', ',');
   }
 }
 
+void Config::GetTreeInteractionConstraints() {
+  if (tree_interaction_constraints.empty()) {
+    tree_interaction_constraints_vector = std::vector<std::vector<int>>();
+  } else {
+    tree_interaction_constraints_vector = Common::StringToArrayofArrays<int>(tree_interaction_constraints, '[', ']', ',');
+  }
+}
+
 void Config::Set(const std::unordered_map<std::string, std::string>& params) {
   // generate seeds by seed.
   if (GetInt(params, "seed", &seed)) {
@@ -221,6 +229,8 @@ void Config::Set(const std::unordered_map<std::string, std::string>& params) {
 
   GetInteractionConstraints();
 
+  GetTreeInteractionConstraints();
+
   // sort eval_at
   std::sort(eval_at.begin(), eval_at.end());
 
diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp
index 9f3dd7a188f1..fed93121c942 100644
--- a/src/io/config_auto.cpp
+++ b/src/io/config_auto.cpp
@@ -245,6 +245,7 @@ const std::unordered_set<std::string>& Config::parameter_set() {
   "cegb_penalty_feature_coupled",
   "path_smooth",
   "interaction_constraints",
+  "tree_interaction_constraints",
   "verbosity",
   "input_model",
   "output_model",
@@ -482,6 +483,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
 
   GetString(params, "interaction_constraints", &interaction_constraints);
 
+  GetString(params, "tree_interaction_constraints", &tree_interaction_constraints);
+
   GetInt(params, "verbosity", &verbosity);
 
   GetString(params, "input_model", &input_model);
@@ -703,6 +706,7 @@ std::string Config::SaveMembersToString() const {
   str_buf << "[cegb_penalty_feature_coupled: " << Common::Join(cegb_penalty_feature_coupled, ",") << "]\n";
   str_buf << "[path_smooth: " << path_smooth << "]\n";
   str_buf << "[interaction_constraints: " << interaction_constraints << "]\n";
+  str_buf << "[tree_interaction_constraints: " << tree_interaction_constraints << "]\n";
   str_buf << "[verbosity: " << verbosity << "]\n";
   str_buf << "[saved_feature_importance_type: " << saved_feature_importance_type << "]\n";
   str_buf << "[linear_tree: " << linear_tree << "]\n";
@@ -822,6 +826,7 @@ const std::string Config::DumpAliases() {
   str_buf << "\"cegb_penalty_feature_coupled\": [], ";
   str_buf << "\"path_smooth\": [], ";
   str_buf << "\"interaction_constraints\": [], ";
+  str_buf << "\"tree_interaction_constraints\": [], ";
   str_buf << "\"verbosity\": [\"verbose\"], ";
   str_buf << "\"input_model\": [\"model_input\", \"model_in\"], ";
   str_buf << "\"output_model\": [\"model_output\", \"model_out\"], ";
diff --git a/src/treelearner/col_sampler.hpp b/src/treelearner/col_sampler.hpp
index 6debe9db60ca..82b08f90fbdb 100644
--- a/src/treelearner/col_sampler.hpp
+++ b/src/treelearner/col_sampler.hpp
@@ -28,6 +28,10 @@ class ColSampler {
       std::unordered_set<int> constraint_set(constraint.begin(), constraint.end());
       interaction_constraints_.push_back(constraint_set);
     }
+    for (auto constraint : config->tree_interaction_constraints_vector) {
+      std::unordered_set<int> constraint_set(constraint.begin(), constraint.end());
+      tree_interaction_constraints_.push_back(constraint_set);
+    }
   }
 
   static int GetCnt(size_t total_cnt, double fraction) {
@@ -89,30 +93,67 @@ class ColSampler {
   }
 
   std::vector<int8_t> GetByNode(const Tree* tree, int leaf) {
+    std::unordered_set<int> tree_allowed_features;
+    if (!tree_interaction_constraints_.empty()) {
+      std::set<int> tree_features = tree->tree_features();
+      tree_allowed_features.insert(tree_features.begin(), tree_features.end());
+      for (auto constraint : tree_interaction_constraints_) {
+        int num_feat_found = 0;
+
+        if (tree_features.empty()) {
+          tree_allowed_features.insert(constraint.begin(), constraint.end());
+        }
+
+        for (int feat : tree_features) {
+          if (constraint.count(feat) == 0) { break; }
+          ++num_feat_found;
+          if (num_feat_found == static_cast<int>(tree_features.size())) {
+            tree_allowed_features.insert(constraint.begin(), constraint.end());
+            break;
+          }
+        }
+      }
+    }
+
     // get interaction constraints for current branch
-    std::unordered_set<int> allowed_features;
+    std::unordered_set<int> branch_allowed_features;
     if (!interaction_constraints_.empty()) {
       std::vector<int> branch_features = tree->branch_features(leaf);
-      allowed_features.insert(branch_features.begin(), branch_features.end());
       for (auto constraint : interaction_constraints_) {
         int num_feat_found = 0;
-        if (branch_features.size() == 0) {
-          allowed_features.insert(constraint.begin(), constraint.end());
+        if (branch_features.empty()) {
+          branch_allowed_features.insert(constraint.begin(), constraint.end());
         }
         for (int feat : branch_features) {
           if (constraint.count(feat) == 0) { break; }
           ++num_feat_found;
           if (num_feat_found == static_cast<int>(branch_features.size())) {
-            allowed_features.insert(constraint.begin(), constraint.end());
+            branch_allowed_features.insert(constraint.begin(), constraint.end());
             break;
           }
         }
       }
     }
 
+    // intersect allowed features for branch and tree
+    std::unordered_set<int> allowed_features;
+
+    if(tree_interaction_constraints_.empty() && !interaction_constraints_.empty()) {
+      allowed_features.insert(branch_allowed_features.begin(), branch_allowed_features.end());
+    } else if(!tree_interaction_constraints_.empty() && interaction_constraints_.empty()){
+      allowed_features.insert(tree_allowed_features.begin(), tree_allowed_features.end());
+    } else {
+      for (int element : tree_allowed_features) {
+        if (branch_allowed_features.count(element) > 0) {
+          allowed_features.insert(element);
+        }
+      }
+    }
+
+
     std::vector<int8_t> ret(train_data_->num_features(), 0);
     if (fraction_bynode_ >= 1.0f) {
-      if (interaction_constraints_.empty()) {
+      if (interaction_constraints_.empty() && tree_interaction_constraints_.empty()) {
         return std::vector<int8_t>(train_data_->num_features(), 1);
       } else {
         for (int feat : allowed_features) {
@@ -128,7 +169,7 @@ class ColSampler {
       auto used_feature_cnt = GetCnt(used_feature_indices_.size(), fraction_bynode_);
       std::vector<int>* allowed_used_feature_indices;
       std::vector<int> filtered_feature_indices;
-      if (interaction_constraints_.empty()) {
+      if (interaction_constraints_.empty() && tree_interaction_constraints_.empty()) {
         allowed_used_feature_indices = &used_feature_indices_;
       } else {
         for (int feat_ind : used_feature_indices_) {
@@ -154,7 +195,7 @@ class ColSampler {
           GetCnt(valid_feature_indices_.size(), fraction_bynode_);
       std::vector<int>* allowed_valid_feature_indices;
       std::vector<int> filtered_feature_indices;
-      if (interaction_constraints_.empty()) {
+      if (interaction_constraints_.empty() && tree_interaction_constraints_.empty()) {
         allowed_valid_feature_indices = &valid_feature_indices_;
       } else {
         for (int feat : valid_feature_indices_) {
@@ -199,6 +240,8 @@ class ColSampler {
   std::vector<int> valid_feature_indices_;
   /*! \brief interaction constraints index in original (raw data) features */
   std::vector<std::unordered_set<int>> interaction_constraints_;
+  /*! \brief tree interaction constraints index in original (raw data) features */
+  std::vector<std::unordered_set<int>> tree_interaction_constraints_;
 };
 
 }  // namespace LightGBM
diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp
index 304c712f0723..25fe7738fcdb 100644
--- a/src/treelearner/serial_tree_learner.cpp
+++ b/src/treelearner/serial_tree_learner.cpp
@@ -172,7 +172,8 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
   // some initial works before training
   BeforeTrain();
 
-  bool track_branch_features = !(config_->interaction_constraints_vector.empty());
+  bool track_branch_features = !(config_->interaction_constraints_vector.empty()
+                                 && config_->tree_interaction_constraints_vector.empty());
   auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves, track_branch_features, false));
   auto tree_ptr = tree.get();
   constraints_->ShareTreePointer(tree_ptr);
@@ -282,6 +283,19 @@ void SerialTreeLearner::BeforeTrain() {
 
 bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) {
   Common::FunctionTimer fun_timer("SerialTreeLearner::BeforeFindBestSplit", global_timer);
+
+  #pragma omp parallel for schedule(static)
+  for (int i = 0; i < config_->num_leaves; ++i) {
+    int feat_index = best_split_per_leaf_[i].feature;
+    if(feat_index == -1) continue;
+
+    int inner_feat_index = train_data_->InnerFeatureIndex(feat_index);
+    auto allowed_feature = col_sampler_.GetByNode(tree, i);
+    if(!allowed_feature[inner_feat_index]){
+      RecomputeBestSplitForLeaf(tree, i, &best_split_per_leaf_[i]);
+    }
+  }
+
   // check depth of current leaf
   if (config_->max_depth > 0) {
     // only need to check left leaf, since right leaf is in same level of left leaf
@@ -801,7 +815,7 @@ double SerialTreeLearner::GetParentOutput(const Tree* tree, const LeafSplits* le
   return parent_output;
 }
 
-void SerialTreeLearner::RecomputeBestSplitForLeaf(Tree* tree, int leaf, SplitInfo* split) {
+void SerialTreeLearner::RecomputeBestSplitForLeaf(const Tree* tree, int leaf, SplitInfo* split) {
   FeatureHistogram* histogram_array_;
   if (!histogram_pool_.Get(leaf, &histogram_array_)) {
     Log::Warning(
diff --git a/src/treelearner/serial_tree_learner.h b/src/treelearner/serial_tree_learner.h
index 7d05debbc12b..c507cd2ffde1 100644
--- a/src/treelearner/serial_tree_learner.h
+++ b/src/treelearner/serial_tree_learner.h
@@ -129,7 +129,7 @@ class SerialTreeLearner: public TreeLearner {
 
   void GetShareStates(const Dataset* dataset, bool is_constant_hessian, bool is_first_time);
 
-  void RecomputeBestSplitForLeaf(Tree* tree, int leaf, SplitInfo* split);
+  void RecomputeBestSplitForLeaf(const Tree* tree, int leaf, SplitInfo* split);
 
   /*!
   * \brief Some initial works before training
diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index 1b202b413a2b..a47c1fd4081b 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -2600,7 +2600,7 @@ def metrics_combination_cv_regression(metric_list, assumed_iteration,
                                       feval=lambda preds, train_data: [constant_metric(preds, train_data),
                                                                        decreasing_metric(preds, train_data)])
 
-
+# TODO: investigate why this test fails
 def test_node_level_subcol():
     X, y = load_breast_cancer(return_X_y=True)
     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
@@ -2918,6 +2918,56 @@ def test_interaction_constraints():
                                                           [1] + list(range(2, num_features))]),
                     train_data, num_boost_round=10)
 
+@pytest.mark.skipif(getenv('TASK', '') == 'cuda_exp', reason='Interaction constraints are not yet supported by CUDA Experimental version')
+def test_tree_interaction_constraints():
+    def check_consistency(est, tree_interaction_constraints):
+        feat_to_index = {feat: i for i, feat in enumerate(est.feature_name())}
+        tree_df = est.trees_to_dataframe()
+        inter_found = set()
+        for tree_index in tree_df["tree_index"].unique():
+            tree_df_per_index = tree_df[tree_df["tree_index"] == tree_index]
+            feat_used = [feat_to_index[feat] for feat in tree_df_per_index["split_feature"].unique() if feat is not None]
+            inter_found.add(tuple(sorted(feat_used)))
+        print(inter_found)
+        for feats_found in inter_found:
+            found = False
+            for real_contraints in tree_interaction_constraints:
+                if set(feats_found) <= set(real_contraints):
+                    found = True
+                    break
+            assert found is True
+    X, y = load_boston(return_X_y=True)
+    num_features = X.shape[1]
+    train_data = lgb.Dataset(X, label=y)
+    # check that tree constraint containing all features is equivalent to no constraint
+    params = {'verbose': -1,
+              'seed': 0}
+    est = lgb.train(params, train_data, num_boost_round=10)
+    pred1 = est.predict(X)
+    est = lgb.train(dict(params, tree_interaction_constraints=[list(range(num_features))]), train_data,
+                    num_boost_round=10)
+    pred2 = est.predict(X)
+    np.testing.assert_allclose(pred1, pred2)
+    # check that each tree uses only features contained in a single 2-feature constraint set
+    tree_interaction_constraints = [[i, i + 1] for i in range(0, num_features - 1, 2)]
+    print(tree_interaction_constraints)
+    new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints)
+    est = lgb.train(new_params, train_data, num_boost_round=100)
+    check_consistency(est, tree_interaction_constraints)
+    # check that tree interaction constraints work with multiple sets of several features
+    tree_interaction_constraints = [[i for i in range(i, i + 5)] for i in range(0, num_features - 5, 5)]
+    print(tree_interaction_constraints)
+    new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints)
+    est = lgb.train(new_params, train_data, num_boost_round=100)
+    check_consistency(est, tree_interaction_constraints)
+    # check that tree interaction constraints work with single-feature sets (one feature per tree at most)
+    tree_interaction_constraints = [[i] for i in range(num_features)]
+    print(tree_interaction_constraints)
+    new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints)
+    est = lgb.train(new_params, train_data, num_boost_round=100)
+    check_consistency(est, tree_interaction_constraints)
+
+
 
 def test_linear_trees(tmp_path):
     # check that setting linear_tree=True fits better than ordinary trees when data has linear relationship

From 57301984863932a1c2fcd57b7bb90b9181c8730b Mon Sep 17 00:00:00 2001
From: Alberto Veneri <veneri.alberto@gmail.com>
Date: Thu, 7 Apr 2022 16:05:20 +0200
Subject: [PATCH 02/21] readme update

---
 README.md | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/README.md b/README.md
index 1cb7b9019ff5..83487085d341 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,12 @@
+
+Forked version of LightGBM by veneres
+===============================
+
+This forked version of LightGBM is made to be used with the proof of concept of ILMART, a new additive model based on
+LambdaMART.
+Further information will be shared soon.
+
+
 <img src=https://github.com/microsoft/LightGBM/blob/master/docs/logo/LightGBM_logo_black_text.svg width=300 />
 
 Light Gradient Boosting Machine

From 5d69338f8c8331da10906f10fd65bfca9f36c680 Mon Sep 17 00:00:00 2001
From: Alberto Veneri <veneri.alberto@gmail.com>
Date: Thu, 7 Apr 2022 16:00:26 +0200
Subject: [PATCH 03/21] First version of the new parameter
 "tree_interaction_constraints"

---
 include/LightGBM/config.h                |  10 ++
 include/LightGBM/tree.h                  |  10 ++
 src/io/config.cpp                        |  12 +-
 src/io/config_auto.cpp                   | 142 +++++++++++++++++++++++
 src/treelearner/col_sampler.hpp          |  59 ++++++++--
 src/treelearner/serial_tree_learner.cpp  |  18 ++-
 src/treelearner/serial_tree_learner.h    |   2 +-
 tests/python_package_test/test_engine.py |  52 ++++++++-
 8 files changed, 292 insertions(+), 13 deletions(-)

diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h
index 6500cb77272d..e46f4c51e871 100644
--- a/include/LightGBM/config.h
+++ b/include/LightGBM/config.h
@@ -571,6 +571,14 @@ struct Config {
   // desc = any two features can only appear in the same branch only if there exists a constraint containing both features
   std::string interaction_constraints = "";
 
+  // desc = controls which features can appear in the same tree
+  // desc = by default tree interaction constraints are disabled, to enable them you can specify
+  // descl2 = for CLI, lists separated by commas, e.g. ``[0,1,2],[2,3]``
+  // descl2 = for Python-package, list of lists, e.g. ``[[0, 1, 2], [2, 3]]``
+  // descl2 = for R-package, list of character or numeric vectors, e.g. ``list(c("var1", "var2", "var3"), c("var3", "var4"))`` or ``list(c(1L, 2L, 3L), c(3L, 4L))``. Numeric vectors should use 1-based indexing, where ``1L`` is the first feature, ``2L`` is the second feature, etc
+  // desc = any two features can appear in the same tree only if there exists a constraint containing both features
+  std::string tree_interaction_constraints = "";
+
   // alias = verbose
   // desc = controls the level of LightGBM's verbosity
   // desc = ``< 0``: Fatal, ``= 0``: Error (Warning), ``= 1``: Info, ``> 1``: Debug
@@ -1126,6 +1134,7 @@ struct Config {
   static const std::unordered_set<std::string>& parameter_set();
   std::vector<std::vector<double>> auc_mu_weights_matrix;
   std::vector<std::vector<int>> interaction_constraints_vector;
+  std::vector<std::vector<int>> tree_interaction_constraints_vector;
   static const std::unordered_map<std::string, std::string>& ParameterTypes();
   static const std::string DumpAliases();
 
@@ -1135,6 +1144,7 @@ struct Config {
   std::string SaveMembersToString() const;
   void GetAucMuWeights();
   void GetInteractionConstraints();
+  void GetTreeInteractionConstraints();
 };
 
 inline bool Config::GetString(
diff --git a/include/LightGBM/tree.h b/include/LightGBM/tree.h
index 0c4a41f46a87..2bcfde43cf34 100644
--- a/include/LightGBM/tree.h
+++ b/include/LightGBM/tree.h
@@ -158,6 +158,10 @@ class Tree {
   /*! \brief Get features on leaf's branch*/
   inline std::vector<int> branch_features(int leaf) const { return branch_features_[leaf]; }
 
+  std::set<int> tree_features() const {
+    return tree_features_;
+  }
+
   inline double split_gain(int split_idx) const { return split_gain_[split_idx]; }
 
   inline double internal_value(int node_idx) const {
@@ -520,6 +524,10 @@ class Tree {
   bool track_branch_features_;
   /*! \brief Features on leaf's branch, original index */
   std::vector<std::vector<int>> branch_features_;
+
+  /*! \brief Features used by the tree, original index */
+  std::set<int> tree_features_;
+
   double shrinkage_;
   int max_depth_;
   /*! \brief Tree has linear model at each leaf */
@@ -579,7 +587,9 @@ inline void Tree::Split(int leaf, int feature, int real_feature,
     branch_features_[num_leaves_] = branch_features_[leaf];
     branch_features_[num_leaves_].push_back(split_feature_[new_node_idx]);
     branch_features_[leaf].push_back(split_feature_[new_node_idx]);
+    tree_features_.insert(split_feature_[new_node_idx]);
   }
+
 }
 
 inline double Tree::Predict(const double* feature_values) const {
diff --git a/src/io/config.cpp b/src/io/config.cpp
index e25bb6d4fd70..4388bbc91136 100644
--- a/src/io/config.cpp
+++ b/src/io/config.cpp
@@ -232,13 +232,21 @@ void Config::GetAucMuWeights() {
 }
 
 void Config::GetInteractionConstraints() {
-  if (interaction_constraints == "") {
+  if (interaction_constraints.empty()) {
     interaction_constraints_vector = std::vector<std::vector<int>>();
   } else {
     interaction_constraints_vector = Common::StringToArrayofArrays<int>(interaction_constraints, '[', ']', ',');
   }
 }
 
+void Config::GetTreeInteractionConstraints() {
+  if (tree_interaction_constraints.empty()) {
+    tree_interaction_constraints_vector = std::vector<std::vector<int>>();
+  } else {
+    tree_interaction_constraints_vector = Common::StringToArrayofArrays<int>(tree_interaction_constraints, '[', ']', ',');
+  }
+}
+
 void Config::Set(const std::unordered_map<std::string, std::string>& params) {
   // generate seeds by seed.
   if (GetInt(params, "seed", &seed)) {
@@ -269,6 +277,8 @@ void Config::Set(const std::unordered_map<std::string, std::string>& params) {
 
   GetInteractionConstraints();
 
+  GetTreeInteractionConstraints();
+
   // sort eval_at
   std::sort(eval_at.begin(), eval_at.end());
 
diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp
index 394614af3f33..e7090c752fd2 100644
--- a/src/io/config_auto.cpp
+++ b/src/io/config_auto.cpp
@@ -246,6 +246,7 @@ const std::unordered_set<std::string>& Config::parameter_set() {
   "cegb_penalty_feature_coupled",
   "path_smooth",
   "interaction_constraints",
+  "tree_interaction_constraints",
   "verbosity",
   "input_model",
   "output_model",
@@ -488,6 +489,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
 
   GetString(params, "interaction_constraints", &interaction_constraints);
 
+  GetString(params, "tree_interaction_constraints", &tree_interaction_constraints);
+
   GetInt(params, "verbosity", &verbosity);
 
   GetString(params, "input_model", &input_model);
@@ -722,6 +725,7 @@ std::string Config::SaveMembersToString() const {
   str_buf << "[cegb_penalty_feature_coupled: " << Common::Join(cegb_penalty_feature_coupled, ",") << "]\n";
   str_buf << "[path_smooth: " << path_smooth << "]\n";
   str_buf << "[interaction_constraints: " << interaction_constraints << "]\n";
+  str_buf << "[tree_interaction_constraints: " << tree_interaction_constraints << "]\n";
   str_buf << "[verbosity: " << verbosity << "]\n";
   str_buf << "[saved_feature_importance_type: " << saved_feature_importance_type << "]\n";
   str_buf << "[use_quantized_grad: " << use_quantized_grad << "]\n";
@@ -922,6 +926,144 @@ const std::unordered_map<std::string, std::vector<std::string>>& Config::paramet
     {"num_gpu", {}},
   });
   return map;
+const std::string Config::DumpAliases() {
+  std::stringstream str_buf;
+  str_buf << "{";
+  str_buf << "\"config\": [\"config_file\"], ";
+  str_buf << "\"task\": [\"task_type\"], ";
+  str_buf << "\"objective\": [\"objective_type\", \"app\", \"application\", \"loss\"], ";
+  str_buf << "\"boosting\": [\"boosting_type\", \"boost\"], ";
+  str_buf << "\"data\": [\"train\", \"train_data\", \"train_data_file\", \"data_filename\"], ";
+  str_buf << "\"valid\": [\"test\", \"valid_data\", \"valid_data_file\", \"test_data\", \"test_data_file\", \"valid_filenames\"], ";
+  str_buf << "\"num_iterations\": [\"num_iteration\", \"n_iter\", \"num_tree\", \"num_trees\", \"num_round\", \"num_rounds\", \"nrounds\", \"num_boost_round\", \"n_estimators\", \"max_iter\"], ";
+  str_buf << "\"learning_rate\": [\"shrinkage_rate\", \"eta\"], ";
+  str_buf << "\"num_leaves\": [\"num_leaf\", \"max_leaves\", \"max_leaf\", \"max_leaf_nodes\"], ";
+  str_buf << "\"tree_learner\": [\"tree\", \"tree_type\", \"tree_learner_type\"], ";
+  str_buf << "\"num_threads\": [\"num_thread\", \"nthread\", \"nthreads\", \"n_jobs\"], ";
+  str_buf << "\"device_type\": [\"device\"], ";
+  str_buf << "\"seed\": [\"random_seed\", \"random_state\"], ";
+  str_buf << "\"deterministic\": [], ";
+  str_buf << "\"force_col_wise\": [], ";
+  str_buf << "\"force_row_wise\": [], ";
+  str_buf << "\"histogram_pool_size\": [\"hist_pool_size\"], ";
+  str_buf << "\"max_depth\": [], ";
+  str_buf << "\"min_data_in_leaf\": [\"min_data_per_leaf\", \"min_data\", \"min_child_samples\", \"min_samples_leaf\"], ";
+  str_buf << "\"min_sum_hessian_in_leaf\": [\"min_sum_hessian_per_leaf\", \"min_sum_hessian\", \"min_hessian\", \"min_child_weight\"], ";
+  str_buf << "\"bagging_fraction\": [\"sub_row\", \"subsample\", \"bagging\"], ";
+  str_buf << "\"pos_bagging_fraction\": [\"pos_sub_row\", \"pos_subsample\", \"pos_bagging\"], ";
+  str_buf << "\"neg_bagging_fraction\": [\"neg_sub_row\", \"neg_subsample\", \"neg_bagging\"], ";
+  str_buf << "\"bagging_freq\": [\"subsample_freq\"], ";
+  str_buf << "\"bagging_seed\": [\"bagging_fraction_seed\"], ";
+  str_buf << "\"feature_fraction\": [\"sub_feature\", \"colsample_bytree\"], ";
+  str_buf << "\"feature_fraction_bynode\": [\"sub_feature_bynode\", \"colsample_bynode\"], ";
+  str_buf << "\"feature_fraction_seed\": [], ";
+  str_buf << "\"extra_trees\": [\"extra_tree\"], ";
+  str_buf << "\"extra_seed\": [], ";
+  str_buf << "\"early_stopping_round\": [\"early_stopping_rounds\", \"early_stopping\", \"n_iter_no_change\"], ";
+  str_buf << "\"first_metric_only\": [], ";
+  str_buf << "\"max_delta_step\": [\"max_tree_output\", \"max_leaf_output\"], ";
+  str_buf << "\"lambda_l1\": [\"reg_alpha\", \"l1_regularization\"], ";
+  str_buf << "\"lambda_l2\": [\"reg_lambda\", \"lambda\", \"l2_regularization\"], ";
+  str_buf << "\"linear_lambda\": [], ";
+  str_buf << "\"min_gain_to_split\": [\"min_split_gain\"], ";
+  str_buf << "\"drop_rate\": [\"rate_drop\"], ";
+  str_buf << "\"max_drop\": [], ";
+  str_buf << "\"skip_drop\": [], ";
+  str_buf << "\"xgboost_dart_mode\": [], ";
+  str_buf << "\"uniform_drop\": [], ";
+  str_buf << "\"drop_seed\": [], ";
+  str_buf << "\"top_rate\": [], ";
+  str_buf << "\"other_rate\": [], ";
+  str_buf << "\"min_data_per_group\": [], ";
+  str_buf << "\"max_cat_threshold\": [], ";
+  str_buf << "\"cat_l2\": [], ";
+  str_buf << "\"cat_smooth\": [], ";
+  str_buf << "\"max_cat_to_onehot\": [], ";
+  str_buf << "\"top_k\": [\"topk\"], ";
+  str_buf << "\"monotone_constraints\": [\"mc\", \"monotone_constraint\", \"monotonic_cst\"], ";
+  str_buf << "\"monotone_constraints_method\": [\"monotone_constraining_method\", \"mc_method\"], ";
+  str_buf << "\"monotone_penalty\": [\"monotone_splits_penalty\", \"ms_penalty\", \"mc_penalty\"], ";
+  str_buf << "\"feature_contri\": [\"feature_contrib\", \"fc\", \"fp\", \"feature_penalty\"], ";
+  str_buf << "\"forcedsplits_filename\": [\"fs\", \"forced_splits_filename\", \"forced_splits_file\", \"forced_splits\"], ";
+  str_buf << "\"refit_decay_rate\": [], ";
+  str_buf << "\"cegb_tradeoff\": [], ";
+  str_buf << "\"cegb_penalty_split\": [], ";
+  str_buf << "\"cegb_penalty_feature_lazy\": [], ";
+  str_buf << "\"cegb_penalty_feature_coupled\": [], ";
+  str_buf << "\"path_smooth\": [], ";
+  str_buf << "\"interaction_constraints\": [], ";
+  str_buf << "\"tree_interaction_constraints\": [], ";
+  str_buf << "\"verbosity\": [\"verbose\"], ";
+  str_buf << "\"input_model\": [\"model_input\", \"model_in\"], ";
+  str_buf << "\"output_model\": [\"model_output\", \"model_out\"], ";
+  str_buf << "\"saved_feature_importance_type\": [], ";
+  str_buf << "\"snapshot_freq\": [\"save_period\"], ";
+  str_buf << "\"linear_tree\": [\"linear_trees\"], ";
+  str_buf << "\"max_bin\": [\"max_bins\"], ";
+  str_buf << "\"max_bin_by_feature\": [], ";
+  str_buf << "\"min_data_in_bin\": [], ";
+  str_buf << "\"bin_construct_sample_cnt\": [\"subsample_for_bin\"], ";
+  str_buf << "\"data_random_seed\": [\"data_seed\"], ";
+  str_buf << "\"is_enable_sparse\": [\"is_sparse\", \"enable_sparse\", \"sparse\"], ";
+  str_buf << "\"enable_bundle\": [\"is_enable_bundle\", \"bundle\"], ";
+  str_buf << "\"use_missing\": [], ";
+  str_buf << "\"zero_as_missing\": [], ";
+  str_buf << "\"feature_pre_filter\": [], ";
+  str_buf << "\"pre_partition\": [\"is_pre_partition\"], ";
+  str_buf << "\"two_round\": [\"two_round_loading\", \"use_two_round_loading\"], ";
+  str_buf << "\"header\": [\"has_header\"], ";
+  str_buf << "\"label_column\": [\"label\"], ";
+  str_buf << "\"weight_column\": [\"weight\"], ";
+  str_buf << "\"group_column\": [\"group\", \"group_id\", \"query_column\", \"query\", \"query_id\"], ";
+  str_buf << "\"ignore_column\": [\"ignore_feature\", \"blacklist\"], ";
+  str_buf << "\"categorical_feature\": [\"cat_feature\", \"categorical_column\", \"cat_column\", \"categorical_features\"], ";
+  str_buf << "\"forcedbins_filename\": [], ";
+  str_buf << "\"save_binary\": [\"is_save_binary\", \"is_save_binary_file\"], ";
+  str_buf << "\"precise_float_parser\": [], ";
+  str_buf << "\"parser_config_file\": [], ";
+  str_buf << "\"start_iteration_predict\": [], ";
+  str_buf << "\"num_iteration_predict\": [], ";
+  str_buf << "\"predict_raw_score\": [\"is_predict_raw_score\", \"predict_rawscore\", \"raw_score\"], ";
+  str_buf << "\"predict_leaf_index\": [\"is_predict_leaf_index\", \"leaf_index\"], ";
+  str_buf << "\"predict_contrib\": [\"is_predict_contrib\", \"contrib\"], ";
+  str_buf << "\"predict_disable_shape_check\": [], ";
+  str_buf << "\"pred_early_stop\": [], ";
+  str_buf << "\"pred_early_stop_freq\": [], ";
+  str_buf << "\"pred_early_stop_margin\": [], ";
+  str_buf << "\"output_result\": [\"predict_result\", \"prediction_result\", \"predict_name\", \"prediction_name\", \"pred_name\", \"name_pred\"], ";
+  str_buf << "\"convert_model_language\": [], ";
+  str_buf << "\"convert_model\": [\"convert_model_file\"], ";
+  str_buf << "\"objective_seed\": [], ";
+  str_buf << "\"num_class\": [\"num_classes\"], ";
+  str_buf << "\"is_unbalance\": [\"unbalance\", \"unbalanced_sets\"], ";
+  str_buf << "\"scale_pos_weight\": [], ";
+  str_buf << "\"sigmoid\": [], ";
+  str_buf << "\"boost_from_average\": [], ";
+  str_buf << "\"reg_sqrt\": [], ";
+  str_buf << "\"alpha\": [], ";
+  str_buf << "\"fair_c\": [], ";
+  str_buf << "\"poisson_max_delta_step\": [], ";
+  str_buf << "\"tweedie_variance_power\": [], ";
+  str_buf << "\"lambdarank_truncation_level\": [], ";
+  str_buf << "\"lambdarank_norm\": [], ";
+  str_buf << "\"label_gain\": [], ";
+  str_buf << "\"metric\": [\"metrics\", \"metric_types\"], ";
+  str_buf << "\"metric_freq\": [\"output_freq\"], ";
+  str_buf << "\"is_provide_training_metric\": [\"training_metric\", \"is_training_metric\", \"train_metric\"], ";
+  str_buf << "\"eval_at\": [\"ndcg_eval_at\", \"ndcg_at\", \"map_eval_at\", \"map_at\"], ";
+  str_buf << "\"multi_error_top_k\": [], ";
+  str_buf << "\"auc_mu_weights\": [], ";
+  str_buf << "\"num_machines\": [\"num_machine\"], ";
+  str_buf << "\"local_listen_port\": [\"local_port\", \"port\"], ";
+  str_buf << "\"time_out\": [], ";
+  str_buf << "\"machine_list_filename\": [\"machine_list_file\", \"machine_list\", \"mlist\"], ";
+  str_buf << "\"machines\": [\"workers\", \"nodes\"], ";
+  str_buf << "\"gpu_platform_id\": [], ";
+  str_buf << "\"gpu_device_id\": [], ";
+  str_buf << "\"gpu_use_dp\": [], ";
+  str_buf << "\"num_gpu\": []";
+  str_buf << "}";
+  return str_buf.str();
 }
 
 const std::unordered_map<std::string, std::string>& Config::ParameterTypes() {
diff --git a/src/treelearner/col_sampler.hpp b/src/treelearner/col_sampler.hpp
index c70b07e50efa..87294369906d 100644
--- a/src/treelearner/col_sampler.hpp
+++ b/src/treelearner/col_sampler.hpp
@@ -28,6 +28,10 @@ class ColSampler {
       std::unordered_set<int> constraint_set(constraint.begin(), constraint.end());
       interaction_constraints_.push_back(constraint_set);
     }
+    for (auto constraint : config->tree_interaction_constraints_vector) {
+      std::unordered_set<int> constraint_set(constraint.begin(), constraint.end());
+      tree_interaction_constraints_.push_back(constraint_set);
+    }
   }
 
   static int GetCnt(size_t total_cnt, double fraction) {
@@ -89,30 +93,67 @@ class ColSampler {
   }
 
   std::vector<int8_t> GetByNode(const Tree* tree, int leaf) {
+    std::unordered_set<int> tree_allowed_features;
+    if (!tree_interaction_constraints_.empty()) {
+      std::set<int> tree_features = tree->tree_features();
+      tree_allowed_features.insert(tree_features.begin(), tree_features.end());
+      for (auto constraint : tree_interaction_constraints_) {
+        int num_feat_found = 0;
+
+        if (tree_features.empty()) {
+          tree_allowed_features.insert(constraint.begin(), constraint.end());
+        }
+
+        for (int feat : tree_features) {
+          if (constraint.count(feat) == 0) { break; }
+          ++num_feat_found;
+          if (num_feat_found == static_cast<int>(tree_features.size())) {
+            tree_allowed_features.insert(constraint.begin(), constraint.end());
+            break;
+          }
+        }
+      }
+    }
+
     // get interaction constraints for current branch
-    std::unordered_set<int> allowed_features;
+    std::unordered_set<int> branch_allowed_features;
     if (!interaction_constraints_.empty()) {
       std::vector<int> branch_features = tree->branch_features(leaf);
-      allowed_features.insert(branch_features.begin(), branch_features.end());
       for (auto constraint : interaction_constraints_) {
         int num_feat_found = 0;
-        if (branch_features.size() == 0) {
-          allowed_features.insert(constraint.begin(), constraint.end());
+        if (branch_features.empty()) {
+          branch_allowed_features.insert(constraint.begin(), constraint.end());
         }
         for (int feat : branch_features) {
           if (constraint.count(feat) == 0) { break; }
           ++num_feat_found;
           if (num_feat_found == static_cast<int>(branch_features.size())) {
-            allowed_features.insert(constraint.begin(), constraint.end());
+            branch_allowed_features.insert(constraint.begin(), constraint.end());
             break;
           }
         }
       }
     }
 
+    // intersect allowed features for branch and tree
+    std::unordered_set<int> allowed_features;
+
+    if(tree_interaction_constraints_.empty() && !interaction_constraints_.empty()) {
+      allowed_features.insert(branch_allowed_features.begin(), branch_allowed_features.end());
+    } else if(!tree_interaction_constraints_.empty() && interaction_constraints_.empty()){
+      allowed_features.insert(tree_allowed_features.begin(), tree_allowed_features.end());
+    } else {
+      for (int element : tree_allowed_features) {
+        if (branch_allowed_features.count(element) > 0) {
+          allowed_features.insert(element);
+        }
+      }
+    }
+
+
     std::vector<int8_t> ret(train_data_->num_features(), 0);
     if (fraction_bynode_ >= 1.0f) {
-      if (interaction_constraints_.empty()) {
+      if (interaction_constraints_.empty() && tree_interaction_constraints_.empty()) {
         return std::vector<int8_t>(train_data_->num_features(), 1);
       } else {
         for (int feat : allowed_features) {
@@ -128,7 +169,7 @@ class ColSampler {
       auto used_feature_cnt = GetCnt(used_feature_indices_.size(), fraction_bynode_);
       std::vector<int>* allowed_used_feature_indices;
       std::vector<int> filtered_feature_indices;
-      if (interaction_constraints_.empty()) {
+      if (interaction_constraints_.empty() && tree_interaction_constraints_.empty()) {
         allowed_used_feature_indices = &used_feature_indices_;
       } else {
         for (int feat_ind : used_feature_indices_) {
@@ -154,7 +195,7 @@ class ColSampler {
           GetCnt(valid_feature_indices_.size(), fraction_bynode_);
       std::vector<int>* allowed_valid_feature_indices;
       std::vector<int> filtered_feature_indices;
-      if (interaction_constraints_.empty()) {
+      if (interaction_constraints_.empty() && tree_interaction_constraints_.empty()) {
         allowed_valid_feature_indices = &valid_feature_indices_;
       } else {
         for (int feat : valid_feature_indices_) {
@@ -199,6 +240,8 @@ class ColSampler {
   std::vector<int> valid_feature_indices_;
   /*! \brief interaction constraints index in original (raw data) features */
   std::vector<std::unordered_set<int>> interaction_constraints_;
+  /*! \brief tree nteraction constraints index in original (raw data) features */
+  std::vector<std::unordered_set<int>> tree_interaction_constraints_;
 };
 
 }  // namespace LightGBM
diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp
index d5c5cc59ef3a..4301ce4f0e24 100644
--- a/src/treelearner/serial_tree_learner.cpp
+++ b/src/treelearner/serial_tree_learner.cpp
@@ -196,7 +196,8 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
   // some initial works before training
   BeforeTrain();
 
-  bool track_branch_features = !(config_->interaction_constraints_vector.empty());
+  bool track_branch_features = !(config_->interaction_constraints_vector.empty()
+                                 && config_->tree_interaction_constraints_vector.empty());
   auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves, track_branch_features, false));
   auto tree_ptr = tree.get();
   constraints_->ShareTreePointer(tree_ptr);
@@ -333,6 +334,19 @@ void SerialTreeLearner::BeforeTrain() {
 
 bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) {
   Common::FunctionTimer fun_timer("SerialTreeLearner::BeforeFindBestSplit", global_timer);
+
+  #pragma omp parallel for schedule(static)
+  for (int i = 0; i < config_->num_leaves; ++i) {
+    int feat_index = best_split_per_leaf_[i].feature;
+    if(feat_index == -1) continue;
+
+    int inner_feat_index = train_data_->InnerFeatureIndex(feat_index);
+    auto allowed_feature = col_sampler_.GetByNode(tree, i);
+    if(!allowed_feature[inner_feat_index]){
+      RecomputeBestSplitForLeaf(tree, i, &best_split_per_leaf_[i]);
+    }
+  }
+
   // check depth of current leaf
   if (config_->max_depth > 0) {
     // only need to check left leaf, since right leaf is in same level of left leaf
@@ -1010,7 +1024,7 @@ double SerialTreeLearner::GetParentOutput(const Tree* tree, const LeafSplits* le
   return parent_output;
 }
 
-void SerialTreeLearner::RecomputeBestSplitForLeaf(Tree* tree, int leaf, SplitInfo* split) {
+void SerialTreeLearner::RecomputeBestSplitForLeaf(const Tree* tree, int leaf, SplitInfo* split) {
   FeatureHistogram* histogram_array_;
   if (!histogram_pool_.Get(leaf, &histogram_array_)) {
     Log::Warning(
diff --git a/src/treelearner/serial_tree_learner.h b/src/treelearner/serial_tree_learner.h
index 43ff6a4b1e13..9531a24c3e06 100644
--- a/src/treelearner/serial_tree_learner.h
+++ b/src/treelearner/serial_tree_learner.h
@@ -130,7 +130,7 @@ class SerialTreeLearner: public TreeLearner {
 
   void GetShareStates(const Dataset* dataset, bool is_constant_hessian, bool is_first_time);
 
-  void RecomputeBestSplitForLeaf(Tree* tree, int leaf, SplitInfo* split);
+  void RecomputeBestSplitForLeaf(const Tree* tree, int leaf, SplitInfo* split);
 
   /*!
   * \brief Some initial works before training
diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index e355e5ab074a..82a608a1fa4e 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -3330,7 +3330,7 @@ def metrics_combination_cv_regression(metric_list, assumed_iteration,
                                       feval=lambda preds, train_data: [constant_metric(preds, train_data),
                                                                        decreasing_metric(preds, train_data)])
 
-
+#TODO investigate why this test fails
 def test_node_level_subcol():
     X, y = load_breast_cancer(return_X_y=True)
     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
@@ -3664,6 +3664,56 @@ def test_interaction_constraints():
                                                           [1] + list(range(2, num_features))]),
                     train_data, num_boost_round=10)
 
+@pytest.mark.skipif(getenv('TASK', '') == 'cuda_exp', reason='Interaction constraints are not yet supported by CUDA Experimental version')
+def test_tree_interaction_constraints():
+    def check_consistency(est, tree_interaction_constraints):
+        feat_to_index = {feat: i for i, feat in enumerate(est.feature_name())}
+        tree_df = est.trees_to_dataframe()
+        inter_found = set()
+        for tree_index in tree_df["tree_index"].unique():
+            tree_df_per_index = tree_df[tree_df["tree_index"] == tree_index]
+            feat_used = [feat_to_index[feat] for feat in tree_df_per_index["split_feature"].unique() if feat is not None]
+            inter_found.add(tuple(sorted(feat_used)))
+        print(inter_found)
+        for feats_found in inter_found:
+            found = False
+            for real_contraints in tree_interaction_constraints:
+                if set(feats_found) <= set(real_contraints):
+                    found = True
+                    break
+            assert found is True
+    X, y = load_boston(return_X_y=True)
+    num_features = X.shape[1]
+    train_data = lgb.Dataset(X, label=y)
+    # check that tree constraint containing all features is equivalent to no constraint
+    params = {'verbose': -1,
+              'seed': 0}
+    est = lgb.train(params, train_data, num_boost_round=10)
+    pred1 = est.predict(X)
+    est = lgb.train(dict(params, tree_interaction_constraints=[list(range(num_features))]), train_data,
+                    num_boost_round=10)
+    pred2 = est.predict(X)
+    np.testing.assert_allclose(pred1, pred2)
+    # check that each tree is composed exactly of 2 features contained in the contrained set
+    tree_interaction_constraints = [[i, i + 1] for i in range(0, num_features - 1, 2)]
+    print(tree_interaction_constraints)
+    new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints)
+    est = lgb.train(new_params, train_data, num_boost_round=100)
+    check_consistency(est, tree_interaction_constraints)
+    # check if tree features interaction constraints works with multiple set of features
+    tree_interaction_constraints = [[i for i in range(i, i + 5)] for i in range(0, num_features - 5, 5)]
+    print(tree_interaction_constraints)
+    new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints)
+    est = lgb.train(new_params, train_data, num_boost_round=100)
+    check_consistency(est, tree_interaction_constraints)
+    # check if tree features interaction constraints works with multiple set of features
+    tree_interaction_constraints = [[i] for i in range(num_features)]
+    print(tree_interaction_constraints)
+    new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints)
+    est = lgb.train(new_params, train_data, num_boost_round=100)
+    check_consistency(est, tree_interaction_constraints)
+
+
 
 def test_linear_trees_num_threads():
     # check that number of threads does not affect result

From ec9ed61ca59fe36c7c6ac33371a10cae50a75b3b Mon Sep 17 00:00:00 2001
From: Alberto Veneri <veneri.alberto@gmail.com>
Date: Thu, 7 Apr 2022 16:05:20 +0200
Subject: [PATCH 04/21] readme update

---
 README.md | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/README.md b/README.md
index 3b3fe40790db..67d4f858d7ac 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,12 @@
+
+Forked version of LightGBM by veneres
+===============================
+
+This forked version of LightGBM is made to be used with the proof of concept of ILMART, a new additive model based on
+LambdaMART.
+Further information will be shared soon.
+
+
 <img src=https://github.com/microsoft/LightGBM/blob/master/docs/logo/LightGBM_logo_black_text.svg width=300 />
 
 Light Gradient Boosting Machine

From d1966c209916fff31967b22eb0626220dc146917 Mon Sep 17 00:00:00 2001
From: Alberto Veneri <alberto.veneri@unive.it>
Date: Wed, 14 Feb 2024 17:20:09 +0100
Subject: [PATCH 05/21] Updated readme

---
 README.md | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/README.md b/README.md
index 67d4f858d7ac..3b3fe40790db 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,3 @@
-
-Forked version of LightGBM by veneres
-===============================
-
-This forked version of LightGBM is made to be used with the proof of concept of ILMART, a new additive model based on
-LambdaMART.
-Further information will be shared soon.
-
-
 <img src=https://github.com/microsoft/LightGBM/blob/master/docs/logo/LightGBM_logo_black_text.svg width=300 />
 
 Light Gradient Boosting Machine

From 848fd58b01829b8159f75af7fd88e9551e7d24aa Mon Sep 17 00:00:00 2001
From: Alberto Veneri <alberto.veneri@unive.it>
Date: Wed, 14 Feb 2024 17:57:42 +0100
Subject: [PATCH 06/21] Fix missing parenthesis

---
 src/io/config_auto.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp
index e7090c752fd2..ae7f9567f6fb 100644
--- a/src/io/config_auto.cpp
+++ b/src/io/config_auto.cpp
@@ -926,6 +926,7 @@ const std::unordered_map<std::string, std::vector<std::string>>& Config::paramet
     {"num_gpu", {}},
   });
   return map;
+}
 const std::string Config::DumpAliases() {
   std::stringstream str_buf;
   str_buf << "{";

From d32b7f6d131b957552bcb8179d6d8feb3f56453d Mon Sep 17 00:00:00 2001
From: Alberto Veneri <alberto.veneri@unive.it>
Date: Wed, 14 Feb 2024 18:09:30 +0100
Subject: [PATCH 07/21] Temporarily remove a new test

---
 tests/python_package_test/test_engine.py | 50 ------------------------
 1 file changed, 50 deletions(-)

diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index b182b001dbc0..3b7433570761 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -3664,56 +3664,6 @@ def test_interaction_constraints():
                                                           [1] + list(range(2, num_features))]),
                     train_data, num_boost_round=10)
 
-@pytest.mark.skipif(getenv('TASK', '') == 'cuda_exp', reason='Interaction constraints are not yet supported by CUDA Experimental version')
-def test_tree_interaction_constraints():
-    def check_consistency(est, tree_interaction_constraints):
-        feat_to_index = {feat: i for i, feat in enumerate(est.feature_name())}
-        tree_df = est.trees_to_dataframe()
-        inter_found = set()
-        for tree_index in tree_df["tree_index"].unique():
-            tree_df_per_index = tree_df[tree_df["tree_index"] == tree_index]
-            feat_used = [feat_to_index[feat] for feat in tree_df_per_index["split_feature"].unique() if feat is not None]
-            inter_found.add(tuple(sorted(feat_used)))
-        print(inter_found)
-        for feats_found in inter_found:
-            found = False
-            for real_contraints in tree_interaction_constraints:
-                if set(feats_found) <= set(real_contraints):
-                    found = True
-                    break
-            assert found is True
-    X, y = load_boston(return_X_y=True)
-    num_features = X.shape[1]
-    train_data = lgb.Dataset(X, label=y)
-    # check that tree constraint containing all features is equivalent to no constraint
-    params = {'verbose': -1,
-              'seed': 0}
-    est = lgb.train(params, train_data, num_boost_round=10)
-    pred1 = est.predict(X)
-    est = lgb.train(dict(params, tree_interaction_constraints=[list(range(num_features))]), train_data,
-                    num_boost_round=10)
-    pred2 = est.predict(X)
-    np.testing.assert_allclose(pred1, pred2)
-    # check that each tree is composed exactly of 2 features contained in the contrained set
-    tree_interaction_constraints = [[i, i + 1] for i in range(0, num_features - 1, 2)]
-    print(tree_interaction_constraints)
-    new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints)
-    est = lgb.train(new_params, train_data, num_boost_round=100)
-    check_consistency(est, tree_interaction_constraints)
-    # check if tree features interaction constraints works with multiple set of features
-    tree_interaction_constraints = [[i for i in range(i, i + 5)] for i in range(0, num_features - 5, 5)]
-    print(tree_interaction_constraints)
-    new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints)
-    est = lgb.train(new_params, train_data, num_boost_round=100)
-    check_consistency(est, tree_interaction_constraints)
-    # check if tree features interaction constraints works with multiple set of features
-    tree_interaction_constraints = [[i] for i in range(num_features)]
-    print(tree_interaction_constraints)
-    new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints)
-    est = lgb.train(new_params, train_data, num_boost_round=100)
-    check_consistency(est, tree_interaction_constraints)
-
-
 def test_linear_trees(tmp_path):
     # check that setting linear_tree=True fits better than ordinary trees when data has linear relationship
     np.random.seed(0)

From d216823ddcef5965f2da9887ee356a908f8040f4 Mon Sep 17 00:00:00 2001
From: veneres <alberto.veneri@unive.it>
Date: Thu, 15 Feb 2024 11:31:40 +0100
Subject: [PATCH 08/21] Merge with private repository edits

---
 include/LightGBM/config.h                |  10 ++
 include/LightGBM/tree.h                  |   5 +-
 python-package/lightgbm/engine.py        |  10 ++
 python-package/pyproject.toml            |  36 ++++--
 src/boosting/gbdt.cpp                    |  13 ++
 src/boosting/gbdt.h                      |   2 +
 src/io/config_auto.cpp                   | 153 ++---------------------
 src/io/tree.cpp                          |  12 ++
 src/treelearner/col_sampler.hpp          | 109 ++++++++++------
 src/treelearner/serial_tree_learner.cpp  |   4 +-
 tests/python_package_test/test_engine.py |  23 +++-
 11 files changed, 181 insertions(+), 196 deletions(-)

diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h
index e46f4c51e871..080f3408590f 100644
--- a/include/LightGBM/config.h
+++ b/include/LightGBM/config.h
@@ -579,6 +579,16 @@ struct Config {
   // desc = any two features can only appear in the same tree only if there exists a constraint containing both features
   std::string tree_interaction_constraints = "";
 
+
+  // desc = controls how many features can appear in the same tree
+  // desc = by default (n_tree_interaction_constraints = 0) interaction constraints are disabled
+  int n_tree_interaction_constraints = 0;
+
+  // desc = controls how many features interactions can be added to the final model
+  // desc = by default no limit is imposed on the interaction with max_interactions = 0
+  // desc = any two features can only appear in the same tree only if there exists a constraint containing both features
+  int max_interactions = 0;
+
   // alias = verbose
   // desc = controls the level of LightGBM's verbosity
   // desc = ``< 0``: Fatal, ``= 0``: Error (Warning), ``= 1``: Info, ``> 1``: Debug
diff --git a/include/LightGBM/tree.h b/include/LightGBM/tree.h
index 2bcfde43cf34..5472bebf0517 100644
--- a/include/LightGBM/tree.h
+++ b/include/LightGBM/tree.h
@@ -156,10 +156,10 @@ class Tree {
   inline int split_feature_inner(int split_idx) const { return split_feature_inner_[split_idx]; }
 
   /*! \brief Get features on leaf's branch*/
-  inline std::vector<int> branch_features(int leaf) const { return branch_features_[leaf]; }
+  std::vector<int> branch_features(int leaf) const { return branch_features_[leaf]; }
 
   std::set<int> tree_features() const {
-    return tree_features_;
+     return tree_features_;
   }
 
   inline double split_gain(int split_idx) const { return split_gain_[split_idx]; }
@@ -589,7 +589,6 @@ inline void Tree::Split(int leaf, int feature, int real_feature,
     branch_features_[leaf].push_back(split_feature_[new_node_idx]);
     tree_features_.insert(split_feature_[new_node_idx]);
   }
-
 }
 
 inline double Tree::Predict(const double* feature_values) const {
diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py
index 822aa3b35017..daa7e823c614 100644
--- a/python-package/lightgbm/engine.py
+++ b/python-package/lightgbm/engine.py
@@ -264,6 +264,7 @@ def train(
     booster.best_iteration = 0
 
     # start training
+    interactions_used = set()
     for i in range(init_iteration, init_iteration + num_boost_round):
         for cb in callbacks_before_iter:
             cb(callback.CallbackEnv(model=booster,
@@ -276,6 +277,15 @@ def train(
         booster.update(fobj=fobj)
 
         evaluation_result_list: List[_LGBM_BoosterEvalMethodResultType] = []
+        if params["max_interactions"] > 0:
+            interaction_used = booster.dump_model(num_iteration=1, start_iteration=i)["tree_info"][0]["tree_features"]
+            interaction_used.sort()
+            interactions_used.add(tuple(interaction_used))
+
+            if len(interactions_used) == params["max_interactions"]:
+                params["tree_interaction_constraints"] = [list(feats) for feats in interactions_used]
+                params["max_interactions"] = 0
+                booster.reset_parameter(params)
         # check evaluation result.
         if valid_sets is not None:
             if is_valid_contain_train:
diff --git a/python-package/pyproject.toml b/python-package/pyproject.toml
index 0aef91b1cd93..cb0c276faaeb 100644
--- a/python-package/pyproject.toml
+++ b/python-package/pyproject.toml
@@ -95,11 +95,29 @@ ignore_missing_imports = true
 exclude = [
     "build",
     "compile",
-    "docs",
     "external_libs",
     "lightgbm-python",
-    "setup.py"
 ]
+line-length = 120
+
+# this should be set to the oldest version of python LightGBM supports
+target-version = "py37"
+
+[tool.ruff.format]
+docstring-code-format = false
+exclude = [
+    "build/*.py",
+    "compile/*.py",
+    "examples/*.py",
+    "external_libs/*.py",
+    "lightgbm-python/*.py",
+    "python-package/*.py",
+    "tests/*.py"
+]
+indent-style = "space"
+quote-style = "double"
+
+[tool.ruff.lint]
 ignore = [
     # (pydocstyle) Missing docstring in magic method
     "D105",
@@ -125,10 +143,13 @@ select = [
     "T",
 ]
 
-# this should be set to the oldest version of python LightGBM supports
-target-version = "py37"
-
-[tool.ruff.per-file-ignores]
+[tool.ruff.lint.per-file-ignores]
+"docs/conf.py" = [
+    # (flake8-bugbear) raise exceptions with "raise ... from errr"
+    "B904",
+    # (flake8-print) flake8-print
+    "T"
+]
 "examples/*" = [
     # pydocstyle
     "D",
@@ -144,6 +165,5 @@ target-version = "py37"
     "T"
 ]
 
-[tool.ruff.pydocstyle]
-
+[tool.ruff.lint.pydocstyle]
 convention = "numpy"
diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp
index b75adab6d971..e60f43c21f35 100644
--- a/src/boosting/gbdt.cpp
+++ b/src/boosting/gbdt.cpp
@@ -244,6 +244,19 @@ void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
       std::string snapshot_out = model_output_path + ".snapshot_iter_" + std::to_string(iter + 1);
       SaveModelToFile(0, -1, config_->saved_feature_importance_type, snapshot_out.c_str());
     }
+    interactions_used.insert(models_[models_.size() - 1]->tree_features());
+
+    if (config_->max_interactions != 0 && (int)interactions_used.size() >= config_->max_interactions) {
+      auto new_config = std::unique_ptr<Config>(new Config(*config_));
+      new_config->tree_interaction_constraints_vector.clear();
+      for (auto &inter_set : interactions_used) {
+        new_config->tree_interaction_constraints_vector.emplace_back(
+            inter_set.begin(), inter_set.end());
+      }
+      new_config->max_interactions = 0;
+
+      ResetConfig(new_config.release());
+    }
   }
 }
 
diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h
index 28ebee446fad..a79c952c492a 100644
--- a/src/boosting/gbdt.h
+++ b/src/boosting/gbdt.h
@@ -542,6 +542,8 @@ class GBDT : public GBDTBase {
   std::vector<std::vector<std::string>> best_msg_;
   /*! \brief Trained models(trees) */
   std::vector<std::unique_ptr<Tree>> models_;
+  /*! \brief Distinct sets of features used by the trained trees */
+  std::set<std::set<int>> interactions_used;
   /*! \brief Max feature index of training data*/
   int max_feature_idx_;
   /*! \brief Parser config file content */
diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp
index ae7f9567f6fb..c09a6bd18736 100644
--- a/src/io/config_auto.cpp
+++ b/src/io/config_auto.cpp
@@ -247,6 +247,8 @@ const std::unordered_set<std::string>& Config::parameter_set() {
   "path_smooth",
   "interaction_constraints",
   "tree_interaction_constraints",
+  "n_tree_interaction_constraints",
+  "max_interactions",
   "verbosity",
   "input_model",
   "output_model",
@@ -320,7 +322,7 @@ const std::unordered_set<std::string>& Config::parameter_set() {
   "gpu_platform_id",
   "gpu_device_id",
   "gpu_use_dp",
-  "num_gpu",
+  "num_gpu"
   });
   return params;
 }
@@ -489,7 +491,13 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
 
   GetString(params, "interaction_constraints", &interaction_constraints);
 
-  GetString(params, "tree_interaction_constraints", &tree_interaction_constraints);
+  GetString(params, "tree_interaction_constraints",&tree_interaction_constraints);
+
+  GetInt(params, "n_tree_interaction_constraints",&n_tree_interaction_constraints);
+  CHECK_GT(n_tree_interaction_constraints, -1);
+
+  GetInt(params, "max_interactions", &max_interactions);
+  CHECK_GT(max_interactions, -1);
 
   GetInt(params, "verbosity", &verbosity);
 
@@ -663,6 +671,7 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
 
   GetInt(params, "num_gpu", &num_gpu);
   CHECK_GT(num_gpu, 0);
+
 }
 
 std::string Config::SaveMembersToString() const {
@@ -726,6 +735,7 @@ std::string Config::SaveMembersToString() const {
   str_buf << "[path_smooth: " << path_smooth << "]\n";
   str_buf << "[interaction_constraints: " << interaction_constraints << "]\n";
   str_buf << "[tree_interaction_constraints: " << tree_interaction_constraints << "]\n";
+  str_buf << "[max_interactions: " << max_interactions << "]\n";
   str_buf << "[verbosity: " << verbosity << "]\n";
   str_buf << "[saved_feature_importance_type: " << saved_feature_importance_type << "]\n";
   str_buf << "[use_quantized_grad: " << use_quantized_grad << "]\n";
@@ -927,145 +937,6 @@ const std::unordered_map<std::string, std::vector<std::string>>& Config::paramet
   });
   return map;
 }
-const std::string Config::DumpAliases() {
-  std::stringstream str_buf;
-  str_buf << "{";
-  str_buf << "\"config\": [\"config_file\"], ";
-  str_buf << "\"task\": [\"task_type\"], ";
-  str_buf << "\"objective\": [\"objective_type\", \"app\", \"application\", \"loss\"], ";
-  str_buf << "\"boosting\": [\"boosting_type\", \"boost\"], ";
-  str_buf << "\"data\": [\"train\", \"train_data\", \"train_data_file\", \"data_filename\"], ";
-  str_buf << "\"valid\": [\"test\", \"valid_data\", \"valid_data_file\", \"test_data\", \"test_data_file\", \"valid_filenames\"], ";
-  str_buf << "\"num_iterations\": [\"num_iteration\", \"n_iter\", \"num_tree\", \"num_trees\", \"num_round\", \"num_rounds\", \"nrounds\", \"num_boost_round\", \"n_estimators\", \"max_iter\"], ";
-  str_buf << "\"learning_rate\": [\"shrinkage_rate\", \"eta\"], ";
-  str_buf << "\"num_leaves\": [\"num_leaf\", \"max_leaves\", \"max_leaf\", \"max_leaf_nodes\"], ";
-  str_buf << "\"tree_learner\": [\"tree\", \"tree_type\", \"tree_learner_type\"], ";
-  str_buf << "\"num_threads\": [\"num_thread\", \"nthread\", \"nthreads\", \"n_jobs\"], ";
-  str_buf << "\"device_type\": [\"device\"], ";
-  str_buf << "\"seed\": [\"random_seed\", \"random_state\"], ";
-  str_buf << "\"deterministic\": [], ";
-  str_buf << "\"force_col_wise\": [], ";
-  str_buf << "\"force_row_wise\": [], ";
-  str_buf << "\"histogram_pool_size\": [\"hist_pool_size\"], ";
-  str_buf << "\"max_depth\": [], ";
-  str_buf << "\"min_data_in_leaf\": [\"min_data_per_leaf\", \"min_data\", \"min_child_samples\", \"min_samples_leaf\"], ";
-  str_buf << "\"min_sum_hessian_in_leaf\": [\"min_sum_hessian_per_leaf\", \"min_sum_hessian\", \"min_hessian\", \"min_child_weight\"], ";
-  str_buf << "\"bagging_fraction\": [\"sub_row\", \"subsample\", \"bagging\"], ";
-  str_buf << "\"pos_bagging_fraction\": [\"pos_sub_row\", \"pos_subsample\", \"pos_bagging\"], ";
-  str_buf << "\"neg_bagging_fraction\": [\"neg_sub_row\", \"neg_subsample\", \"neg_bagging\"], ";
-  str_buf << "\"bagging_freq\": [\"subsample_freq\"], ";
-  str_buf << "\"bagging_seed\": [\"bagging_fraction_seed\"], ";
-  str_buf << "\"feature_fraction\": [\"sub_feature\", \"colsample_bytree\"], ";
-  str_buf << "\"feature_fraction_bynode\": [\"sub_feature_bynode\", \"colsample_bynode\"], ";
-  str_buf << "\"feature_fraction_seed\": [], ";
-  str_buf << "\"extra_trees\": [\"extra_tree\"], ";
-  str_buf << "\"extra_seed\": [], ";
-  str_buf << "\"early_stopping_round\": [\"early_stopping_rounds\", \"early_stopping\", \"n_iter_no_change\"], ";
-  str_buf << "\"first_metric_only\": [], ";
-  str_buf << "\"max_delta_step\": [\"max_tree_output\", \"max_leaf_output\"], ";
-  str_buf << "\"lambda_l1\": [\"reg_alpha\", \"l1_regularization\"], ";
-  str_buf << "\"lambda_l2\": [\"reg_lambda\", \"lambda\", \"l2_regularization\"], ";
-  str_buf << "\"linear_lambda\": [], ";
-  str_buf << "\"min_gain_to_split\": [\"min_split_gain\"], ";
-  str_buf << "\"drop_rate\": [\"rate_drop\"], ";
-  str_buf << "\"max_drop\": [], ";
-  str_buf << "\"skip_drop\": [], ";
-  str_buf << "\"xgboost_dart_mode\": [], ";
-  str_buf << "\"uniform_drop\": [], ";
-  str_buf << "\"drop_seed\": [], ";
-  str_buf << "\"top_rate\": [], ";
-  str_buf << "\"other_rate\": [], ";
-  str_buf << "\"min_data_per_group\": [], ";
-  str_buf << "\"max_cat_threshold\": [], ";
-  str_buf << "\"cat_l2\": [], ";
-  str_buf << "\"cat_smooth\": [], ";
-  str_buf << "\"max_cat_to_onehot\": [], ";
-  str_buf << "\"top_k\": [\"topk\"], ";
-  str_buf << "\"monotone_constraints\": [\"mc\", \"monotone_constraint\", \"monotonic_cst\"], ";
-  str_buf << "\"monotone_constraints_method\": [\"monotone_constraining_method\", \"mc_method\"], ";
-  str_buf << "\"monotone_penalty\": [\"monotone_splits_penalty\", \"ms_penalty\", \"mc_penalty\"], ";
-  str_buf << "\"feature_contri\": [\"feature_contrib\", \"fc\", \"fp\", \"feature_penalty\"], ";
-  str_buf << "\"forcedsplits_filename\": [\"fs\", \"forced_splits_filename\", \"forced_splits_file\", \"forced_splits\"], ";
-  str_buf << "\"refit_decay_rate\": [], ";
-  str_buf << "\"cegb_tradeoff\": [], ";
-  str_buf << "\"cegb_penalty_split\": [], ";
-  str_buf << "\"cegb_penalty_feature_lazy\": [], ";
-  str_buf << "\"cegb_penalty_feature_coupled\": [], ";
-  str_buf << "\"path_smooth\": [], ";
-  str_buf << "\"interaction_constraints\": [], ";
-  str_buf << "\"tree_interaction_constraints\": [], ";
-  str_buf << "\"verbosity\": [\"verbose\"], ";
-  str_buf << "\"input_model\": [\"model_input\", \"model_in\"], ";
-  str_buf << "\"output_model\": [\"model_output\", \"model_out\"], ";
-  str_buf << "\"saved_feature_importance_type\": [], ";
-  str_buf << "\"snapshot_freq\": [\"save_period\"], ";
-  str_buf << "\"linear_tree\": [\"linear_trees\"], ";
-  str_buf << "\"max_bin\": [\"max_bins\"], ";
-  str_buf << "\"max_bin_by_feature\": [], ";
-  str_buf << "\"min_data_in_bin\": [], ";
-  str_buf << "\"bin_construct_sample_cnt\": [\"subsample_for_bin\"], ";
-  str_buf << "\"data_random_seed\": [\"data_seed\"], ";
-  str_buf << "\"is_enable_sparse\": [\"is_sparse\", \"enable_sparse\", \"sparse\"], ";
-  str_buf << "\"enable_bundle\": [\"is_enable_bundle\", \"bundle\"], ";
-  str_buf << "\"use_missing\": [], ";
-  str_buf << "\"zero_as_missing\": [], ";
-  str_buf << "\"feature_pre_filter\": [], ";
-  str_buf << "\"pre_partition\": [\"is_pre_partition\"], ";
-  str_buf << "\"two_round\": [\"two_round_loading\", \"use_two_round_loading\"], ";
-  str_buf << "\"header\": [\"has_header\"], ";
-  str_buf << "\"label_column\": [\"label\"], ";
-  str_buf << "\"weight_column\": [\"weight\"], ";
-  str_buf << "\"group_column\": [\"group\", \"group_id\", \"query_column\", \"query\", \"query_id\"], ";
-  str_buf << "\"ignore_column\": [\"ignore_feature\", \"blacklist\"], ";
-  str_buf << "\"categorical_feature\": [\"cat_feature\", \"categorical_column\", \"cat_column\", \"categorical_features\"], ";
-  str_buf << "\"forcedbins_filename\": [], ";
-  str_buf << "\"save_binary\": [\"is_save_binary\", \"is_save_binary_file\"], ";
-  str_buf << "\"precise_float_parser\": [], ";
-  str_buf << "\"parser_config_file\": [], ";
-  str_buf << "\"start_iteration_predict\": [], ";
-  str_buf << "\"num_iteration_predict\": [], ";
-  str_buf << "\"predict_raw_score\": [\"is_predict_raw_score\", \"predict_rawscore\", \"raw_score\"], ";
-  str_buf << "\"predict_leaf_index\": [\"is_predict_leaf_index\", \"leaf_index\"], ";
-  str_buf << "\"predict_contrib\": [\"is_predict_contrib\", \"contrib\"], ";
-  str_buf << "\"predict_disable_shape_check\": [], ";
-  str_buf << "\"pred_early_stop\": [], ";
-  str_buf << "\"pred_early_stop_freq\": [], ";
-  str_buf << "\"pred_early_stop_margin\": [], ";
-  str_buf << "\"output_result\": [\"predict_result\", \"prediction_result\", \"predict_name\", \"prediction_name\", \"pred_name\", \"name_pred\"], ";
-  str_buf << "\"convert_model_language\": [], ";
-  str_buf << "\"convert_model\": [\"convert_model_file\"], ";
-  str_buf << "\"objective_seed\": [], ";
-  str_buf << "\"num_class\": [\"num_classes\"], ";
-  str_buf << "\"is_unbalance\": [\"unbalance\", \"unbalanced_sets\"], ";
-  str_buf << "\"scale_pos_weight\": [], ";
-  str_buf << "\"sigmoid\": [], ";
-  str_buf << "\"boost_from_average\": [], ";
-  str_buf << "\"reg_sqrt\": [], ";
-  str_buf << "\"alpha\": [], ";
-  str_buf << "\"fair_c\": [], ";
-  str_buf << "\"poisson_max_delta_step\": [], ";
-  str_buf << "\"tweedie_variance_power\": [], ";
-  str_buf << "\"lambdarank_truncation_level\": [], ";
-  str_buf << "\"lambdarank_norm\": [], ";
-  str_buf << "\"label_gain\": [], ";
-  str_buf << "\"metric\": [\"metrics\", \"metric_types\"], ";
-  str_buf << "\"metric_freq\": [\"output_freq\"], ";
-  str_buf << "\"is_provide_training_metric\": [\"training_metric\", \"is_training_metric\", \"train_metric\"], ";
-  str_buf << "\"eval_at\": [\"ndcg_eval_at\", \"ndcg_at\", \"map_eval_at\", \"map_at\"], ";
-  str_buf << "\"multi_error_top_k\": [], ";
-  str_buf << "\"auc_mu_weights\": [], ";
-  str_buf << "\"num_machines\": [\"num_machine\"], ";
-  str_buf << "\"local_listen_port\": [\"local_port\", \"port\"], ";
-  str_buf << "\"time_out\": [], ";
-  str_buf << "\"machine_list_filename\": [\"machine_list_file\", \"machine_list\", \"mlist\"], ";
-  str_buf << "\"machines\": [\"workers\", \"nodes\"], ";
-  str_buf << "\"gpu_platform_id\": [], ";
-  str_buf << "\"gpu_device_id\": [], ";
-  str_buf << "\"gpu_use_dp\": [], ";
-  str_buf << "\"num_gpu\": []";
-  str_buf << "}";
-  return str_buf.str();
-}
 
 const std::unordered_map<std::string, std::string>& Config::ParameterTypes() {
   static std::unordered_map<std::string, std::string> map({
diff --git a/src/io/tree.cpp b/src/io/tree.cpp
index 4312b4f65002..97e9c969e1d6 100644
--- a/src/io/tree.cpp
+++ b/src/io/tree.cpp
@@ -415,6 +415,18 @@ std::string Tree::ToJSON() const {
   str_buf << "\"num_leaves\":" << num_leaves_ << "," << '\n';
   str_buf << "\"num_cat\":" << num_cat_ << "," << '\n';
   str_buf << "\"shrinkage\":" << shrinkage_ << "," << '\n';
+
+  auto feats_used = tree_features();
+  size_t i = 0;
+  str_buf << "\"tree_features\":[";
+  for (int feat: feats_used) {
+    str_buf << feat;
+    if(i != feats_used.size() - 1) {
+      str_buf << ",";
+    }
+    ++i;
+  }
+  str_buf <<    "]," << '\n';
   if (num_leaves_ == 1) {
     if (is_linear_) {
       str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << ", " << "\n";
diff --git a/src/treelearner/col_sampler.hpp b/src/treelearner/col_sampler.hpp
index 87294369906d..642cecd914b2 100644
--- a/src/treelearner/col_sampler.hpp
+++ b/src/treelearner/col_sampler.hpp
@@ -28,10 +28,13 @@ class ColSampler {
       std::unordered_set<int> constraint_set(constraint.begin(), constraint.end());
       interaction_constraints_.push_back(constraint_set);
     }
+
     for (auto constraint : config->tree_interaction_constraints_vector) {
-      std::unordered_set<int> constraint_set(constraint.begin(), constraint.end());
+      std::unordered_set<int> constraint_set(constraint.begin(),constraint.end());
       tree_interaction_constraints_.push_back(constraint_set);
     }
+
+    n_tree_interaction_constraints_ = config-> n_tree_interaction_constraints;
   }
 
   static int GetCnt(size_t total_cnt, double fraction) {
@@ -72,6 +75,11 @@ class ColSampler {
       used_cnt_bytree_ =
           GetCnt(valid_feature_indices_.size(), fraction_bytree_);
     }
+    tree_interaction_constraints_.clear();
+    for (auto constraint : config->tree_interaction_constraints_vector) {
+      std::unordered_set<int> constraint_set(constraint.begin(), constraint.end());
+      tree_interaction_constraints_.push_back(constraint_set);
+    }
     ResetByTree();
   }
 
@@ -92,55 +100,74 @@ class ColSampler {
     }
   }
 
-  std::vector<int8_t> GetByNode(const Tree* tree, int leaf) {
-    std::unordered_set<int> tree_allowed_features;
-    if (!tree_interaction_constraints_.empty()) {
-      std::set<int> tree_features = tree->tree_features();
-      tree_allowed_features.insert(tree_features.begin(), tree_features.end());
-      for (auto constraint : tree_interaction_constraints_) {
-        int num_feat_found = 0;
-
-        if (tree_features.empty()) {
+  void ComputeTreeAllowedFeatures(std::unordered_set<int> &tree_allowed_features, std::set<int> &tree_features) {
+    tree_allowed_features.insert(tree_features.begin(), tree_features.end());
+    if(tree_interaction_constraints_.empty()){
+        for (int i = 0; i < train_data_->num_features(); ++i) {
+            tree_allowed_features.insert(tree_allowed_features.end(), i);
+        }
+    }
+    for (auto constraint : tree_interaction_constraints_) {
+      int num_feat_found = 0;
+      if (tree_features.empty()) {
+        tree_allowed_features.insert(constraint.begin(), constraint.end());
+      }
+      for (int feat : tree_features) {
+        if (constraint.count(feat) == 0) { break; }
+        ++num_feat_found;
+        if (num_feat_found == static_cast<int>(tree_features.size())) {
           tree_allowed_features.insert(constraint.begin(), constraint.end());
+          break;
         }
+      }
+    }
+  }
 
-        for (int feat : tree_features) {
-          if (constraint.count(feat) == 0) { break; }
-          ++num_feat_found;
-          if (num_feat_found == static_cast<int>(tree_features.size())) {
-            tree_allowed_features.insert(constraint.begin(), constraint.end());
-            break;
+    void ComputeBranchAllowedFeatures(const Tree *tree, int leaf, std::unordered_set<int> &branch_allowed_features) {
+        if (!interaction_constraints_.empty()) {
+          std::vector<int> branch_features = tree->branch_features(leaf);
+          for (auto constraint : interaction_constraints_) {
+            int num_feat_found = 0;
+            if (branch_features.empty()) {
+              branch_allowed_features.insert(constraint.begin(), constraint.end());
+            }
+            for (int feat : branch_features) {
+              if (constraint.count(feat) == 0) { break; }
+              ++num_feat_found;
+              if (num_feat_found == static_cast<int>(branch_features.size())) {
+                branch_allowed_features.insert(constraint.begin(), constraint.end());
+                break;
+              }
+            }
           }
         }
-      }
     }
 
-    // get interaction constraints for current branch
-    std::unordered_set<int> branch_allowed_features;
-    if (!interaction_constraints_.empty()) {
-      std::vector<int> branch_features = tree->branch_features(leaf);
-      for (auto constraint : interaction_constraints_) {
-        int num_feat_found = 0;
-        if (branch_features.empty()) {
-          branch_allowed_features.insert(constraint.begin(), constraint.end());
-        }
-        for (int feat : branch_features) {
-          if (constraint.count(feat) == 0) { break; }
-          ++num_feat_found;
-          if (num_feat_found == static_cast<int>(branch_features.size())) {
-            branch_allowed_features.insert(constraint.begin(), constraint.end());
-            break;
+    std::vector<int8_t> GetByNode(const Tree* tree, int leaf) {
+    // get interaction constraints for current tree
+    std::unordered_set<int> tree_allowed_features;
+    if (!tree_interaction_constraints_.empty() || n_tree_interaction_constraints_ > 0) {
+      std::set<int> tree_features = tree->tree_features();
+      if(n_tree_interaction_constraints_ == 0 || tree_features.size() < (unsigned long) n_tree_interaction_constraints_){
+          ComputeTreeAllowedFeatures(tree_allowed_features, tree_features);
+      } else{
+          for(int feat: tree_features){
+              tree_allowed_features.insert(feat);
           }
-        }
       }
     }
+    // get interaction constraints for current branch
+    std::unordered_set<int> branch_allowed_features;
+
+    ComputeBranchAllowedFeatures(tree, leaf, branch_allowed_features);
 
-    // intersect allowed features for branch and tree
+
+        // intersect allowed features for branch and tree
     std::unordered_set<int> allowed_features;
 
-    if(tree_interaction_constraints_.empty() && !interaction_constraints_.empty()) {
+    if((tree_interaction_constraints_.empty() && n_tree_interaction_constraints_ == 0) && !interaction_constraints_.empty()) {
       allowed_features.insert(branch_allowed_features.begin(), branch_allowed_features.end());
-    } else if(!tree_interaction_constraints_.empty() && interaction_constraints_.empty()){
+    } else if(!(tree_interaction_constraints_.empty() && n_tree_interaction_constraints_ == 0) && interaction_constraints_.empty()){
       allowed_features.insert(tree_allowed_features.begin(), tree_allowed_features.end());
     } else {
       for (int element : tree_allowed_features) {
@@ -150,10 +177,9 @@ class ColSampler {
       }
     }
 
-
     std::vector<int8_t> ret(train_data_->num_features(), 0);
     if (fraction_bynode_ >= 1.0f) {
-      if (interaction_constraints_.empty() && tree_interaction_constraints_.empty()) {
+      if (interaction_constraints_.empty() && tree_interaction_constraints_.empty() && n_tree_interaction_constraints_ == 0) {
         return std::vector<int8_t>(train_data_->num_features(), 1);
       } else {
         for (int feat : allowed_features) {
@@ -169,7 +195,7 @@ class ColSampler {
       auto used_feature_cnt = GetCnt(used_feature_indices_.size(), fraction_bynode_);
       std::vector<int>* allowed_used_feature_indices;
       std::vector<int> filtered_feature_indices;
-      if (interaction_constraints_.empty() && tree_interaction_constraints_.empty()) {
+      if (interaction_constraints_.empty() && tree_interaction_constraints_.empty() && n_tree_interaction_constraints_ == 0) {
         allowed_used_feature_indices = &used_feature_indices_;
       } else {
         for (int feat_ind : used_feature_indices_) {
@@ -195,7 +221,7 @@ class ColSampler {
           GetCnt(valid_feature_indices_.size(), fraction_bynode_);
       std::vector<int>* allowed_valid_feature_indices;
       std::vector<int> filtered_feature_indices;
-      if (interaction_constraints_.empty() && tree_interaction_constraints_.empty()) {
+      if (interaction_constraints_.empty() && tree_interaction_constraints_.empty() && n_tree_interaction_constraints_ == 0) {
         allowed_valid_feature_indices = &valid_feature_indices_;
       } else {
         for (int feat : valid_feature_indices_) {
@@ -240,8 +266,9 @@ class ColSampler {
   std::vector<int> valid_feature_indices_;
   /*! \brief interaction constraints index in original (raw data) features */
   std::vector<std::unordered_set<int>> interaction_constraints_;
-  /*! \brief tree nteraction constraints index in original (raw data) features */
+  /*! \brief tree interaction constraints index in original (raw data) features */
   std::vector<std::unordered_set<int>> tree_interaction_constraints_;
+  int n_tree_interaction_constraints_;
 };
 
 }  // namespace LightGBM
diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp
index 4301ce4f0e24..aed1a949be6a 100644
--- a/src/treelearner/serial_tree_learner.cpp
+++ b/src/treelearner/serial_tree_learner.cpp
@@ -197,7 +197,8 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
   BeforeTrain();
 
   bool track_branch_features = !(config_->interaction_constraints_vector.empty()
-                                 && config_->tree_interaction_constraints_vector.empty());
+                                 && config_->tree_interaction_constraints_vector.empty()
+                                 && config_->n_tree_interaction_constraints == 0);
   auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves, track_branch_features, false));
   auto tree_ptr = tree.get();
   constraints_->ShareTreePointer(tree_ptr);
@@ -334,7 +335,6 @@ void SerialTreeLearner::BeforeTrain() {
 
 bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) {
   Common::FunctionTimer fun_timer("SerialTreeLearner::BeforeFindBestSplit", global_timer);
-
   #pragma omp parallel for schedule(static)
   for (int i = 0; i < config_->num_leaves; ++i) {
     int feat_index = best_split_per_leaf_[i].feature;
diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index 3b7433570761..e355e5ab074a 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -3330,7 +3330,7 @@ def metrics_combination_cv_regression(metric_list, assumed_iteration,
                                       feval=lambda preds, train_data: [constant_metric(preds, train_data),
                                                                        decreasing_metric(preds, train_data)])
 
-#TODO investigate why this test fails
+
 def test_node_level_subcol():
     X, y = load_breast_cancer(return_X_y=True)
     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
@@ -3664,6 +3664,27 @@ def test_interaction_constraints():
                                                           [1] + list(range(2, num_features))]),
                     train_data, num_boost_round=10)
 
+
+def test_linear_trees_num_threads():
+    # check that number of threads does not affect result
+    np.random.seed(0)
+    x = np.arange(0, 1000, 0.1)
+    y = 2 * x + np.random.normal(0, 0.1, len(x))
+    x = x[:, np.newaxis]
+    lgb_train = lgb.Dataset(x, label=y)
+    params = {'verbose': -1,
+              'objective': 'regression',
+              'seed': 0,
+              'linear_tree': True,
+              'num_threads': 2}
+    est = lgb.train(params, lgb_train, num_boost_round=100)
+    pred1 = est.predict(x)
+    params["num_threads"] = 4
+    est = lgb.train(params, lgb_train, num_boost_round=100)
+    pred2 = est.predict(x)
+    np.testing.assert_allclose(pred1, pred2)
+
+
 def test_linear_trees(tmp_path):
     # check that setting linear_tree=True fits better than ordinary trees when data has linear relationship
     np.random.seed(0)

From 137bc6d4ca04cf8c7bb317cf62ee0a23fb789945 Mon Sep 17 00:00:00 2001
From: veneres <alberto.veneri@unive.it>
Date: Thu, 15 Feb 2024 13:59:37 +0100
Subject: [PATCH 09/21] Resolve lint errors reported by GitHub Actions

---
 docs/Parameters.rst                     | 22 ++++++++++++++++
 include/LightGBM/config.h               |  1 -
 include/LightGBM/tree.h                 |  1 +
 python-package/lightgbm/engine.py       |  2 +-
 src/boosting/gbdt.cpp                   |  2 +-
 src/boosting/gbdt.h                     |  1 +
 src/io/config_auto.cpp                  |  5 ++--
 src/io/tree.cpp                         |  4 +--
 src/treelearner/col_sampler.hpp         | 35 +++++++++++++------------
 src/treelearner/serial_tree_learner.cpp |  4 +--
 10 files changed, 50 insertions(+), 27 deletions(-)

diff --git a/docs/Parameters.rst b/docs/Parameters.rst
index 341cdd487c71..952b59c2701d 100644
--- a/docs/Parameters.rst
+++ b/docs/Parameters.rst
@@ -618,6 +618,28 @@ Learning Control Parameters
 
    -  any two features can only appear in the same branch only if there exists a constraint containing both features
 
+-  ``tree_interaction_constraints`` :raw-html:`<a id="interaction_constraints" title="Permalink to this parameter" href="#interaction_constraints">&#x1F517;&#xFE0E;</a>`, default = ``""``, type = string
+
+   -  controls which features can appear in the same tree
+
+   -  by default interaction constraints are disabled, to enable them you can specify
+
+      -  for CLI, lists separated by commas, e.g. ``[0,1,2],[2,3]``
+
+      -  for Python-package, list of lists, e.g. ``[[0, 1, 2], [2, 3]]``
+
+      -  for R-package, list of character or numeric vectors, e.g. ``list(c("var1", "var2", "var3"), c("var3", "var4"))`` or ``list(c(1L, 2L, 3L), c(3L, 4L))``. Numeric vectors should use 1-based indexing, where ``1L`` is the first feature, ``2L`` is the second feature, etc
+
+   -  any two features can only appear in the same tree only if there exists a constraint containing both features
+
+-  ``n_tree_interaction_constraints`` :raw-html:`<a id="n_tree_interaction_constraints" title="Permalink to this parameter" href="#n_tree_interaction_constraints">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int
+    - controls how many features can appear in the same tree
+    - by default (n_tree_interaction_constraints = 0) interaction constraints are disabled
+
+-  ``max_interactions`` :raw-html:`<a id="max_interactions" title="Permalink to this parameter" href="#max_interactions">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int
+    - controls how many features interactions can be added to the final model
+    - by default no limit is imposed on the interaction with max_interactions = 0
+
 -  ``verbosity`` :raw-html:`<a id="verbosity" title="Permalink to this parameter" href="#verbosity">&#x1F517;&#xFE0E;</a>`, default = ``1``, type = int, aliases: ``verbose``
 
    -  controls the level of LightGBM's verbosity
diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h
index 080f3408590f..97932dfe0501 100644
--- a/include/LightGBM/config.h
+++ b/include/LightGBM/config.h
@@ -586,7 +586,6 @@ struct Config {
 
   // desc = controls how many features interactions can be added to the final model
   // desc = by default no limit is imposed on the interaction with max_interactions = 0
-  // desc = any two features can only appear in the same tree only if there exists a constraint containing both features
   int max_interactions = 0;
 
   // alias = verbose
diff --git a/include/LightGBM/tree.h b/include/LightGBM/tree.h
index 5472bebf0517..231de949b844 100644
--- a/include/LightGBM/tree.h
+++ b/include/LightGBM/tree.h
@@ -13,6 +13,7 @@
 #include <memory>
 #include <unordered_map>
 #include <vector>
+#include <set>
 
 namespace LightGBM {
 
diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py
index daa7e823c614..ba7ff864737e 100644
--- a/python-package/lightgbm/engine.py
+++ b/python-package/lightgbm/engine.py
@@ -277,7 +277,7 @@ def train(
         booster.update(fobj=fobj)
 
         evaluation_result_list: List[_LGBM_BoosterEvalMethodResultType] = []
-        if params["max_interactions"] > 0:
+        if params.get("max_interactions", 0) > 0:
             interaction_used = booster.dump_model(num_iteration=1, start_iteration=i)["tree_info"][0]["tree_features"]
             interaction_used.sort()
             interactions_used.add(tuple(interaction_used))
diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp
index e60f43c21f35..11c22841bd15 100644
--- a/src/boosting/gbdt.cpp
+++ b/src/boosting/gbdt.cpp
@@ -246,7 +246,7 @@ void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
     }
     interactions_used.insert(models_[models_.size() - 1]->tree_features());
 
-    if (config_->max_interactions != 0 && (int)interactions_used.size() >= config_->max_interactions) {
+    if (config_->max_interactions != 0 && static_cast<int>(interactions_used.size()) >= config_->max_interactions) {
       auto new_config = std::unique_ptr<Config>(new Config(*config_));
       new_config->tree_interaction_constraints_vector.clear();
       for (auto &inter_set : interactions_used) {
diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h
index a79c952c492a..b96f15216246 100644
--- a/src/boosting/gbdt.h
+++ b/src/boosting/gbdt.h
@@ -23,6 +23,7 @@
 #include <unordered_map>
 #include <utility>
 #include <vector>
+#include <set>
 
 #include "cuda/cuda_score_updater.hpp"
 #include "score_updater.hpp"
diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp
index c09a6bd18736..42a518cb3029 100644
--- a/src/io/config_auto.cpp
+++ b/src/io/config_auto.cpp
@@ -491,9 +491,9 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
 
   GetString(params, "interaction_constraints", &interaction_constraints);
 
-  GetString(params, "tree_interaction_constraints",&tree_interaction_constraints);
+  GetString(params, "tree_interaction_constraints", &tree_interaction_constraints);
 
-  GetInt(params, "n_tree_interaction_constraints",&n_tree_interaction_constraints);
+  GetInt(params, "n_tree_interaction_constraints", &n_tree_interaction_constraints);
   CHECK_GT(n_tree_interaction_constraints, -1);
 
   GetInt(params, "max_interactions", &max_interactions);
@@ -671,7 +671,6 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
 
   GetInt(params, "num_gpu", &num_gpu);
   CHECK_GT(num_gpu, 0);
-
 }
 
 std::string Config::SaveMembersToString() const {
diff --git a/src/io/tree.cpp b/src/io/tree.cpp
index 97e9c969e1d6..7313db2d61a4 100644
--- a/src/io/tree.cpp
+++ b/src/io/tree.cpp
@@ -419,9 +419,9 @@ std::string Tree::ToJSON() const {
   auto feats_used = tree_features();
   size_t i = 0;
   str_buf << "\"tree_features\":[";
-  for (int feat: feats_used) {
+  for (int feat : feats_used) {
     str_buf << feat;
-    if(i != feats_used.size() - 1) {
+    if (i != feats_used.size() - 1) {
       str_buf << ",";
     }
     ++i;
diff --git a/src/treelearner/col_sampler.hpp b/src/treelearner/col_sampler.hpp
index 642cecd914b2..08d709af6e34 100644
--- a/src/treelearner/col_sampler.hpp
+++ b/src/treelearner/col_sampler.hpp
@@ -15,6 +15,7 @@
 #include <algorithm>
 #include <unordered_set>
 #include <vector>
+#include <set>
 
 namespace LightGBM {
 class ColSampler {
@@ -30,7 +31,7 @@ class ColSampler {
     }
 
     for (auto constraint : config->tree_interaction_constraints_vector) {
-      std::unordered_set<int> constraint_set(constraint.begin(),constraint.end());
+      std::unordered_set<int> constraint_set(constraint.begin(), constraint.end());
       tree_interaction_constraints_.push_back(constraint_set);
     }
 
@@ -100,22 +101,22 @@ class ColSampler {
     }
   }
 
-  void ComputeTreeAllowedFeatures(std::unordered_set<int> &tree_allowed_features, std::set<int> &tree_features) {
-    tree_allowed_features.insert(tree_features.begin(), tree_features.end());
-    if(tree_interaction_constraints_.empty()){
+  void ComputeTreeAllowedFeatures(std::unordered_set<int> &tree_allowed_features, std::set<int> *tree_features) {
+    tree_allowed_features.insert((*tree_features).begin(), (*tree_features).end());
+    if (tree_interaction_constraints_.empty()) {
         for (int i = 0; i < train_data_->num_features(); ++i) {
             tree_allowed_features.insert(tree_allowed_features.end(), i);
         }
     }
     for (auto constraint : tree_interaction_constraints_) {
       int num_feat_found = 0;
-      if (tree_features.empty()) {
+      if ((*tree_features).empty()) {
         tree_allowed_features.insert(constraint.begin(), constraint.end());
       }
-      for (int feat : tree_features) {
+      for (int feat : *tree_features) {
         if (constraint.count(feat) == 0) { break; }
         ++num_feat_found;
-        if (num_feat_found == static_cast<int>(tree_features.size())) {
+        if (num_feat_found == static_cast<int>((*tree_features).size())) {
           tree_allowed_features.insert(constraint.begin(), constraint.end());
           break;
         }
@@ -123,19 +124,19 @@ class ColSampler {
     }
   }
 
-    void ComputeBranchAllowedFeatures(const Tree *tree, int leaf, std::unordered_set<int> &branch_allowed_features) {
+    void ComputeBranchAllowedFeatures(const Tree *tree, int leaf, std::unordered_set<int> *branch_allowed_features) {
         if (!interaction_constraints_.empty()) {
           std::vector<int> branch_features = tree->branch_features(leaf);
           for (auto constraint : interaction_constraints_) {
             int num_feat_found = 0;
             if (branch_features.empty()) {
-              branch_allowed_features.insert(constraint.begin(), constraint.end());
+                (*branch_allowed_features).insert(constraint.begin(), constraint.end());
             }
             for (int feat : branch_features) {
               if (constraint.count(feat) == 0) { break; }
               ++num_feat_found;
               if (num_feat_found == static_cast<int>(branch_features.size())) {
-                branch_allowed_features.insert(constraint.begin(), constraint.end());
+                  (*branch_allowed_features).insert(constraint.begin(), constraint.end());
                 break;
               }
             }
@@ -148,10 +149,10 @@ class ColSampler {
     std::unordered_set<int> tree_allowed_features;
     if (!tree_interaction_constraints_.empty() || n_tree_interaction_constraints_ > 0) {
       std::set<int> tree_features = tree->tree_features();
-      if(n_tree_interaction_constraints_ == 0 || tree_features.size() < (unsigned long) n_tree_interaction_constraints_){
-          ComputeTreeAllowedFeatures(tree_allowed_features, tree_features);
-      } else{
-          for(int feat: tree_features){
+      if (n_tree_interaction_constraints_ == 0 || tree_features.size() < (std::set<int>::size_type) n_tree_interaction_constraints_) {
+          ComputeTreeAllowedFeatures(tree_allowed_features, &tree_features);
+      } else {
+          for(int feat : tree_features) {
               tree_allowed_features.insert(feat);
           }
       }
@@ -159,15 +160,15 @@ class ColSampler {
     // get interaction constraints for current branch
     std::unordered_set<int> branch_allowed_features;
 
-    ComputeBranchAllowedFeatures(tree, leaf, branch_allowed_features);
+    ComputeBranchAllowedFeatures(tree, leaf, &branch_allowed_features);
 
 
         // intersect allowed features for branch and tree
     std::unordered_set<int> allowed_features;
 
-    if((tree_interaction_constraints_.empty() && n_tree_interaction_constraints_ == 0) && !interaction_constraints_.empty()) {
+    if ((tree_interaction_constraints_.empty() && n_tree_interaction_constraints_ == 0) && !interaction_constraints_.empty()) {
       allowed_features.insert(branch_allowed_features.begin(), branch_allowed_features.end());
-    } else if(!(tree_interaction_constraints_.empty() && n_tree_interaction_constraints_ == 0) && interaction_constraints_.empty()){
+    } else if (!(tree_interaction_constraints_.empty() && n_tree_interaction_constraints_ == 0) && interaction_constraints_.empty()) {
       allowed_features.insert(tree_allowed_features.begin(), tree_allowed_features.end());
     } else {
       for (int element : tree_allowed_features) {
diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp
index aed1a949be6a..43cdc09accce 100644
--- a/src/treelearner/serial_tree_learner.cpp
+++ b/src/treelearner/serial_tree_learner.cpp
@@ -338,11 +338,11 @@ bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int
   #pragma omp parallel for schedule(static)
   for (int i = 0; i < config_->num_leaves; ++i) {
     int feat_index = best_split_per_leaf_[i].feature;
-    if(feat_index == -1) continue;
+    if (feat_index == -1) continue;
 
     int inner_feat_index = train_data_->InnerFeatureIndex(feat_index);
     auto allowed_feature = col_sampler_.GetByNode(tree, i);
-    if(!allowed_feature[inner_feat_index]){
+    if (!allowed_feature[inner_feat_index]) {
       RecomputeBestSplitForLeaf(tree, i, &best_split_per_leaf_[i]);
     }
   }

From 9b3fb5ec813792c171c704f3cbd128e205f26bb6 Mon Sep 17 00:00:00 2001
From: veneres <alberto.veneri@unive.it>
Date: Thu, 15 Feb 2024 14:32:59 +0100
Subject: [PATCH 10/21] Fix docs; pass tree_allowed_features by pointer

---
 docs/Parameters.rst             |  2 ++
 include/LightGBM/config.h       |  1 -
 src/treelearner/col_sampler.hpp | 16 ++++++++--------
 3 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/docs/Parameters.rst b/docs/Parameters.rst
index 952b59c2701d..05fbc26e1d9d 100644
--- a/docs/Parameters.rst
+++ b/docs/Parameters.rst
@@ -634,10 +634,12 @@ Learning Control Parameters
 
 -  ``n_tree_interaction_constraints`` :raw-html:`<a id="n_tree_interaction_constraints" title="Permalink to this parameter" href="#n_tree_interaction_constraints">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int
     - controls how many features can appear in the same tree
+
     - by default (n_tree_interaction_constraints = 0) interaction constraints are disabled
 
 -  ``max_interactions`` :raw-html:`<a id="max_interactions" title="Permalink to this parameter" href="#max_interactions">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int
     - controls how many features interactions can be added to the final model
+    
     - by default no limit is imposed on the interaction with max_interactions = 0
 
 -  ``verbosity`` :raw-html:`<a id="verbosity" title="Permalink to this parameter" href="#verbosity">&#x1F517;&#xFE0E;</a>`, default = ``1``, type = int, aliases: ``verbose``
diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h
index 97932dfe0501..248e41f95ef6 100644
--- a/include/LightGBM/config.h
+++ b/include/LightGBM/config.h
@@ -579,7 +579,6 @@ struct Config {
   // desc = any two features can only appear in the same tree only if there exists a constraint containing both features
   std::string tree_interaction_constraints = "";
 
-
   // desc = controls how many features can appear in the same tree
   // desc = by default (n_tree_interaction_constraints = 0) interaction constraints are disabled
   int n_tree_interaction_constraints = 0;
diff --git a/src/treelearner/col_sampler.hpp b/src/treelearner/col_sampler.hpp
index 08d709af6e34..cbfa7331f380 100644
--- a/src/treelearner/col_sampler.hpp
+++ b/src/treelearner/col_sampler.hpp
@@ -101,23 +101,23 @@ class ColSampler {
     }
   }
 
-  void ComputeTreeAllowedFeatures(std::unordered_set<int> &tree_allowed_features, std::set<int> *tree_features) {
-    tree_allowed_features.insert((*tree_features).begin(), (*tree_features).end());
+  void ComputeTreeAllowedFeatures(std::unordered_set<int> *tree_allowed_features, std::set<int> *tree_features) {
+    tree_allowed_features->insert(tree_features->begin(), tree_features->end());
     if (tree_interaction_constraints_.empty()) {
         for (int i = 0; i < train_data_->num_features(); ++i) {
-            tree_allowed_features.insert(tree_allowed_features.end(), i);
+            tree_allowed_features->insert(tree_allowed_features->end(), i);
         }
     }
     for (auto constraint : tree_interaction_constraints_) {
       int num_feat_found = 0;
-      if ((*tree_features).empty()) {
-        tree_allowed_features.insert(constraint.begin(), constraint.end());
+      if (tree_features->empty()) {
+        tree_allowed_features->insert(constraint.begin(), constraint.end());
       }
       for (int feat : *tree_features) {
         if (constraint.count(feat) == 0) { break; }
         ++num_feat_found;
-        if (num_feat_found == static_cast<int>((*tree_features).size())) {
-          tree_allowed_features.insert(constraint.begin(), constraint.end());
+        if (num_feat_found == static_cast<int>(tree_features->size())) {
+          tree_allowed_features->insert(constraint.begin(), constraint.end());
           break;
         }
       }
@@ -150,7 +150,7 @@ class ColSampler {
     if (!tree_interaction_constraints_.empty() || n_tree_interaction_constraints_ > 0) {
       std::set<int> tree_features = tree->tree_features();
       if (n_tree_interaction_constraints_ == 0 || tree_features.size() < (std::set<int>::size_type) n_tree_interaction_constraints_) {
-          ComputeTreeAllowedFeatures(tree_allowed_features, &tree_features);
+          ComputeTreeAllowedFeatures(&tree_allowed_features, &tree_features);
       } else {
           for(int feat : tree_features) {
               tree_allowed_features.insert(feat);

From 997e06b426a786f5b2d2ba6485315802e5ebbc2b Mon Sep 17 00:00:00 2001
From: veneres <alberto.veneri@unive.it>
Date: Thu, 15 Feb 2024 14:38:18 +0100
Subject: [PATCH 11/21] Fix tree_interaction_constraints docs anchor and lint spacing

---
 docs/Parameters.rst             | 6 ++++--
 src/treelearner/col_sampler.hpp | 2 +-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/docs/Parameters.rst b/docs/Parameters.rst
index 05fbc26e1d9d..427647dd6fec 100644
--- a/docs/Parameters.rst
+++ b/docs/Parameters.rst
@@ -618,7 +618,7 @@ Learning Control Parameters
 
    -  any two features can only appear in the same branch only if there exists a constraint containing both features
 
--  ``tree_interaction_constraints`` :raw-html:`<a id="interaction_constraints" title="Permalink to this parameter" href="#interaction_constraints">&#x1F517;&#xFE0E;</a>`, default = ``""``, type = string
+-  ``tree_interaction_constraints`` :raw-html:`<a id="tree_interaction_constraints" title="Permalink to this parameter" href="#tree_interaction_constraints">&#x1F517;&#xFE0E;</a>`, default = ``""``, type = string
 
    -  controls which features can appear in the same tree
 
@@ -633,13 +633,15 @@ Learning Control Parameters
    -  any two features can only appear in the same tree only if there exists a constraint containing both features
 
 -  ``n_tree_interaction_constraints`` :raw-html:`<a id="n_tree_interaction_constraints" title="Permalink to this parameter" href="#n_tree_interaction_constraints">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int
+
     - controls how many features can appear in the same tree
 
     - by default (n_tree_interaction_constraints = 0) interaction constraints are disabled
 
 -  ``max_interactions`` :raw-html:`<a id="max_interactions" title="Permalink to this parameter" href="#max_interactions">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int
+
     - controls how many features interactions can be added to the final model
-    
+
     - by default no limit is imposed on the interaction with max_interactions = 0
 
 -  ``verbosity`` :raw-html:`<a id="verbosity" title="Permalink to this parameter" href="#verbosity">&#x1F517;&#xFE0E;</a>`, default = ``1``, type = int, aliases: ``verbose``
diff --git a/src/treelearner/col_sampler.hpp b/src/treelearner/col_sampler.hpp
index cbfa7331f380..00e105581708 100644
--- a/src/treelearner/col_sampler.hpp
+++ b/src/treelearner/col_sampler.hpp
@@ -152,7 +152,7 @@ class ColSampler {
       if (n_tree_interaction_constraints_ == 0 || tree_features.size() < (std::set<int>::size_type) n_tree_interaction_constraints_) {
           ComputeTreeAllowedFeatures(&tree_allowed_features, &tree_features);
       } else {
-          for(int feat : tree_features) {
+          for (int feat : tree_features) {
               tree_allowed_features.insert(feat);
           }
       }

From 64ff80cfac9fc2e06dd390b781957ad8adbcd720 Mon Sep 17 00:00:00 2001
From: veneres <alberto.veneri@unive.it>
Date: Thu, 15 Feb 2024 14:45:07 +0100
Subject: [PATCH 12/21] Fix docs and linting

---
 docs/Parameters.rst                     | 8 ++++----
 src/treelearner/serial_tree_learner.cpp | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/docs/Parameters.rst b/docs/Parameters.rst
index 427647dd6fec..2dfc4e7dcee3 100644
--- a/docs/Parameters.rst
+++ b/docs/Parameters.rst
@@ -634,15 +634,15 @@ Learning Control Parameters
 
 -  ``n_tree_interaction_constraints`` :raw-html:`<a id="n_tree_interaction_constraints" title="Permalink to this parameter" href="#n_tree_interaction_constraints">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int
 
-    - controls how many features can appear in the same tree
+   - controls how many features can appear in the same tree
 
-    - by default (n_tree_interaction_constraints = 0) interaction constraints are disabled
+   - by default (n_tree_interaction_constraints = 0) interaction constraints are disabled
 
 -  ``max_interactions`` :raw-html:`<a id="max_interactions" title="Permalink to this parameter" href="#max_interactions">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int
 
-    - controls how many features interactions can be added to the final model
+   - controls how many features interactions can be added to the final model
 
-    - by default no limit is imposed on the interaction with max_interactions = 0
+   - by default no limit is imposed on the interaction with max_interactions = 0
 
 -  ``verbosity`` :raw-html:`<a id="verbosity" title="Permalink to this parameter" href="#verbosity">&#x1F517;&#xFE0E;</a>`, default = ``1``, type = int, aliases: ``verbose``
 
diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp
index 43cdc09accce..b60ff45d84c1 100644
--- a/src/treelearner/serial_tree_learner.cpp
+++ b/src/treelearner/serial_tree_learner.cpp
@@ -335,7 +335,7 @@ void SerialTreeLearner::BeforeTrain() {
 
 bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) {
   Common::FunctionTimer fun_timer("SerialTreeLearner::BeforeFindBestSplit", global_timer);
-  #pragma omp parallel for schedule(static)
+  #pragma omp parallel for schedule(static) num_threads(OMP_NUM_THREADS())
   for (int i = 0; i < config_->num_leaves; ++i) {
     int feat_index = best_split_per_leaf_[i].feature;
     if (feat_index == -1) continue;

From ee8d6e6793c6cc57c138094149a432dc5441e91d Mon Sep 17 00:00:00 2001
From: veneres <alberto.veneri@unive.it>
Date: Thu, 15 Feb 2024 14:53:55 +0100
Subject: [PATCH 13/21] Fix docs

---
 docs/Parameters.rst | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/docs/Parameters.rst b/docs/Parameters.rst
index 2dfc4e7dcee3..cc8b599a73bd 100644
--- a/docs/Parameters.rst
+++ b/docs/Parameters.rst
@@ -634,15 +634,15 @@ Learning Control Parameters
 
 -  ``n_tree_interaction_constraints`` :raw-html:`<a id="n_tree_interaction_constraints" title="Permalink to this parameter" href="#n_tree_interaction_constraints">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int
 
-   - controls how many features can appear in the same tree
+   -  controls how many features can appear in the same tree
 
-   - by default (n_tree_interaction_constraints = 0) interaction constraints are disabled
+   -  by default (n_tree_interaction_constraints = 0) interaction constraints are disabled
 
 -  ``max_interactions`` :raw-html:`<a id="max_interactions" title="Permalink to this parameter" href="#max_interactions">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int
 
-   - controls how many features interactions can be added to the final model
+   -  controls how many features interactions can be added to the final model
 
-   - by default no limit is imposed on the interaction with max_interactions = 0
+   -  by default no limit is imposed on the interaction with max_interactions = 0
 
 -  ``verbosity`` :raw-html:`<a id="verbosity" title="Permalink to this parameter" href="#verbosity">&#x1F517;&#xFE0E;</a>`, default = ``1``, type = int, aliases: ``verbose``
 

From 09acfcf1da971bb900f8a0b236f43acd34e37dc6 Mon Sep 17 00:00:00 2001
From: veneres <alberto.veneri@unive.it>
Date: Thu, 15 Feb 2024 15:12:55 +0100
Subject: [PATCH 14/21] Add non-negativity checks and register new params in config_auto

---
 docs/Parameters.rst       |  4 ++--
 include/LightGBM/config.h |  2 ++
 src/io/config_auto.cpp    | 13 ++++++++++---
 3 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/docs/Parameters.rst b/docs/Parameters.rst
index cc8b599a73bd..5be86c8a50d4 100644
--- a/docs/Parameters.rst
+++ b/docs/Parameters.rst
@@ -632,13 +632,13 @@ Learning Control Parameters
 
    -  any two features can only appear in the same tree only if there exists a constraint containing both features
 
--  ``n_tree_interaction_constraints`` :raw-html:`<a id="n_tree_interaction_constraints" title="Permalink to this parameter" href="#n_tree_interaction_constraints">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int
+-  ``n_tree_interaction_constraints`` :raw-html:`<a id="n_tree_interaction_constraints" title="Permalink to this parameter" href="#n_tree_interaction_constraints">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int, constraints: ``n_tree_interaction_constraints >=  0.0``
 
    -  controls how many features can appear in the same tree
 
    -  by default (n_tree_interaction_constraints = 0) interaction constraints are disabled
 
--  ``max_interactions`` :raw-html:`<a id="max_interactions" title="Permalink to this parameter" href="#max_interactions">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int
+-  ``max_interactions`` :raw-html:`<a id="max_interactions" title="Permalink to this parameter" href="#max_interactions">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int, constraints: ``max_interactions >=  0.0``
 
    -  controls how many features interactions can be added to the final model
 
diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h
index 248e41f95ef6..648468bee2d4 100644
--- a/include/LightGBM/config.h
+++ b/include/LightGBM/config.h
@@ -579,10 +579,12 @@ struct Config {
   // desc = any two features can only appear in the same tree only if there exists a constraint containing both features
   std::string tree_interaction_constraints = "";
 
+  // check = >= 0.0
   // desc = controls how many features can appear in the same tree
   // desc = by default (n_tree_interaction_constraints = 0) interaction constraints are disabled
   int n_tree_interaction_constraints = 0;
 
+  // check = >= 0.0
   // desc = controls how many features interactions can be added to the final model
   // desc = by default no limit is imposed on the interaction with max_interactions = 0
   int max_interactions = 0;
diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp
index 42a518cb3029..e8debbf8c05f 100644
--- a/src/io/config_auto.cpp
+++ b/src/io/config_auto.cpp
@@ -322,7 +322,7 @@ const std::unordered_set<std::string>& Config::parameter_set() {
   "gpu_platform_id",
   "gpu_device_id",
   "gpu_use_dp",
-  "num_gpu"
+  "num_gpu",
   });
   return params;
 }
@@ -494,10 +494,10 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
   GetString(params, "tree_interaction_constraints", &tree_interaction_constraints);
 
   GetInt(params, "n_tree_interaction_constraints", &n_tree_interaction_constraints);
-  CHECK_GT(n_tree_interaction_constraints, -1);
+  CHECK_GE(n_tree_interaction_constraints,  0.0);
 
   GetInt(params, "max_interactions", &max_interactions);
-  CHECK_GT(max_interactions, -1);
+  CHECK_GE(max_interactions,  0.0);
 
   GetInt(params, "verbosity", &verbosity);
 
@@ -734,6 +734,7 @@ std::string Config::SaveMembersToString() const {
   str_buf << "[path_smooth: " << path_smooth << "]\n";
   str_buf << "[interaction_constraints: " << interaction_constraints << "]\n";
   str_buf << "[tree_interaction_constraints: " << tree_interaction_constraints << "]\n";
+  str_buf << "[n_tree_interaction_constraints: " << n_tree_interaction_constraints << "]\n";
   str_buf << "[max_interactions: " << max_interactions << "]\n";
   str_buf << "[verbosity: " << verbosity << "]\n";
   str_buf << "[saved_feature_importance_type: " << saved_feature_importance_type << "]\n";
@@ -859,6 +860,9 @@ const std::unordered_map<std::string, std::vector<std::string>>& Config::paramet
     {"cegb_penalty_feature_coupled", {}},
     {"path_smooth", {}},
     {"interaction_constraints", {}},
+    {"tree_interaction_constraints", {}},
+    {"n_tree_interaction_constraints", {}},
+    {"max_interactions", {}},
     {"verbosity", {"verbose"}},
     {"input_model", {"model_input", "model_in"}},
     {"output_model", {"model_output", "model_out"}},
@@ -1002,6 +1006,9 @@ const std::unordered_map<std::string, std::string>& Config::ParameterTypes() {
     {"cegb_penalty_feature_coupled", "vector<double>"},
     {"path_smooth", "double"},
     {"interaction_constraints", "vector<vector<int>>"},
+    {"tree_interaction_constraints", "string"},
+    {"n_tree_interaction_constraints", "int"},
+    {"max_interactions", "int"},
     {"verbosity", "int"},
     {"input_model", "string"},
     {"output_model", "string"},

From 0d66bea3e112f2a496315c28a3c8a03b8132171f Mon Sep 17 00:00:00 2001
From: veneres <alberto.veneri@unive.it>
Date: Fri, 16 Feb 2024 15:49:15 +0100
Subject: [PATCH 15/21] Boolean guards added for constrained learning

---
 include/LightGBM/tree.h                  |  5 +-
 src/boosting/gbdt.cpp                    |  9 ++--
 src/boosting/gbdt.h                      |  2 +-
 src/treelearner/col_sampler.hpp          |  2 +-
 src/treelearner/serial_tree_learner.cpp  | 21 +++++----
 tests/python_package_test/test_engine.py | 60 ++++++++++++++++++++++++
 6 files changed, 83 insertions(+), 16 deletions(-)

diff --git a/include/LightGBM/tree.h b/include/LightGBM/tree.h
index 231de949b844..dc1303c00b6b 100644
--- a/include/LightGBM/tree.h
+++ b/include/LightGBM/tree.h
@@ -157,8 +157,9 @@ class Tree {
   inline int split_feature_inner(int split_idx) const { return split_feature_inner_[split_idx]; }
 
   /*! \brief Get features on leaf's branch*/
-  std::vector<int> branch_features(int leaf) const { return branch_features_[leaf]; }
+  inline std::vector<int> branch_features(int leaf) const { return branch_features_[leaf]; }
 
+  /*! \brief Get unique features used by the current tree*/
   std::set<int> tree_features() const {
      return tree_features_;
   }
@@ -324,6 +325,8 @@ class Tree {
 
   inline bool is_linear() const { return is_linear_; }
 
+  inline bool is_tracking_branch_features() const { return track_branch_features_; }
+
   #ifdef USE_CUDA
   inline bool is_cuda_tree() const { return is_cuda_tree_; }
   #endif  // USE_CUDA
diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp
index 11c22841bd15..8ded1d9ea07c 100644
--- a/src/boosting/gbdt.cpp
+++ b/src/boosting/gbdt.cpp
@@ -244,17 +244,18 @@ void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
       std::string snapshot_out = model_output_path + ".snapshot_iter_" + std::to_string(iter + 1);
       SaveModelToFile(0, -1, config_->saved_feature_importance_type, snapshot_out.c_str());
     }
-    interactions_used.insert(models_[models_.size() - 1]->tree_features());
+
+    if (config_->max_interactions != 0) {
+      interactions_used.insert(models_[models_.size() - 1]->tree_features());
+    }
 
     if (config_->max_interactions != 0 && static_cast<int>(interactions_used.size()) >= config_->max_interactions) {
       auto new_config = std::unique_ptr<Config>(new Config(*config_));
       new_config->tree_interaction_constraints_vector.clear();
       for (auto &inter_set : interactions_used) {
-        new_config->tree_interaction_constraints_vector.emplace_back(
-            inter_set.begin(), inter_set.end());
+        new_config->tree_interaction_constraints_vector.emplace_back(inter_set.begin(), inter_set.end());
       }
       new_config->max_interactions = 0;
-
       ResetConfig(new_config.release());
     }
   }
diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h
index b96f15216246..d3dc80a837ec 100644
--- a/src/boosting/gbdt.h
+++ b/src/boosting/gbdt.h
@@ -543,7 +543,7 @@ class GBDT : public GBDTBase {
   std::vector<std::vector<std::string>> best_msg_;
   /*! \brief Trained models(trees) */
   std::vector<std::unique_ptr<Tree>> models_;
-  /*! \brief Trained models(trees) */
+  /*! \brief Distinct sets of features, one set per tree, collected from the trained trees */
   std::set<std::set<int>> interactions_used;
   /*! \brief Max feature index of training data*/
   int max_feature_idx_;
diff --git a/src/treelearner/col_sampler.hpp b/src/treelearner/col_sampler.hpp
index 00e105581708..7018c7b5051a 100644
--- a/src/treelearner/col_sampler.hpp
+++ b/src/treelearner/col_sampler.hpp
@@ -163,7 +163,7 @@ class ColSampler {
     ComputeBranchAllowedFeatures(tree, leaf, &branch_allowed_features);
 
 
-        // intersect allowed features for branch and tree
+    // intersect allowed features for branch and tree
     std::unordered_set<int> allowed_features;
 
     if ((tree_interaction_constraints_.empty() && n_tree_interaction_constraints_ == 0) && !interaction_constraints_.empty()) {
diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp
index b60ff45d84c1..a790bcf93388 100644
--- a/src/treelearner/serial_tree_learner.cpp
+++ b/src/treelearner/serial_tree_learner.cpp
@@ -335,16 +335,19 @@ void SerialTreeLearner::BeforeTrain() {
 
 bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) {
   Common::FunctionTimer fun_timer("SerialTreeLearner::BeforeFindBestSplit", global_timer);
-  #pragma omp parallel for schedule(static) num_threads(OMP_NUM_THREADS())
-  for (int i = 0; i < config_->num_leaves; ++i) {
-    int feat_index = best_split_per_leaf_[i].feature;
-    if (feat_index == -1) continue;
 
-    int inner_feat_index = train_data_->InnerFeatureIndex(feat_index);
-    auto allowed_feature = col_sampler_.GetByNode(tree, i);
-    if (!allowed_feature[inner_feat_index]) {
-      RecomputeBestSplitForLeaf(tree, i, &best_split_per_leaf_[i]);
-    }
+  if (tree->is_tracking_branch_features()) {
+      #pragma omp parallel for schedule(static) num_threads(OMP_NUM_THREADS())
+      for (int i = 0; i < config_->num_leaves; ++i) {
+          int feat_index = best_split_per_leaf_[i].feature;
+          if (feat_index == -1) continue;
+
+          int inner_feat_index = train_data_->InnerFeatureIndex(feat_index);
+          auto allowed_feature = col_sampler_.GetByNode(tree, i);
+          if (!allowed_feature[inner_feat_index]) {
+              RecomputeBestSplitForLeaf(tree, i, &best_split_per_leaf_[i]);
+          }
+      }
   }
 
   // check depth of current leaf
diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index e355e5ab074a..517b9c2700f7 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -3665,6 +3665,66 @@ def test_interaction_constraints():
                     train_data, num_boost_round=10)
 
 
+'''
+@pytest.mark.skipif(getenv('TASK', '') == 'cuda_exp', reason='Interaction constraints are not yet supported by CUDA Experimental version')
+def test_tree_interaction_constraints():
+    def check_consistency(est, tree_interaction_constraints):
+        feat_to_index = {feat: i for i, feat in enumerate(est.feature_name())}
+        tree_df = est.trees_to_dataframe()
+        inter_found = set()
+        for tree_index in tree_df["tree_index"].unique():
+            tree_df_per_index = tree_df[tree_df["tree_index"] == tree_index]
+            feat_used = [feat_to_index[feat] for feat in tree_df_per_index["split_feature"].unique() if feat is not None]
+            inter_found.add(tuple(sorted(feat_used)))
+        print(inter_found)
+        for feats_found in inter_found:
+            found = False
+            for real_contraints in tree_interaction_constraints:
+                if set(feats_found) <= set(real_contraints):
+                    found = True
+                    break
+            assert found is True
+    X, y = make_synthetic_regression(n_samples=200)
+    num_features = X.shape[1]
+    train_data = lgb.Dataset(X, label=y)
+    # check that constraint containing all features is equivalent to no constraint
+    params = {'verbose': -1,
+              'seed': 0}
+    est = lgb.train(params, train_data, num_boost_round=10)
+    pred1 = est.predict(X)
+    est = lgb.train(dict(params, interaction_constraints=[list(range(num_features))]), train_data,
+                    num_boost_round=10)
+    num_features = X.shape[1]
+    train_data = lgb.Dataset(X, label=y)
+    # check that tree constraint containing all features is equivalent to no constraint
+    params = {'verbose': -1,
+              'seed': 0}
+    est = lgb.train(params, train_data, num_boost_round=10)
+    pred1 = est.predict(X)
+    est = lgb.train(dict(params, tree_interaction_constraints=[list(range(num_features))]), train_data,
+                    num_boost_round=10)
+    pred2 = est.predict(X)
+    np.testing.assert_allclose(pred1, pred2)
+    # check that each tree is composed exactly of 2 features contained in the contrained set
+    tree_interaction_constraints = [[i, i + 1] for i in range(0, num_features - 1, 2)]
+    print(tree_interaction_constraints)
+    new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints)
+    est = lgb.train(new_params, train_data, num_boost_round=100)
+    check_consistency(est, tree_interaction_constraints)
+    # check if tree features interaction constraints works with multiple set of features
+    tree_interaction_constraints = [[i for i in range(i, i + 5)] for i in range(0, num_features - 5, 5)]
+    print(tree_interaction_constraints)
+    new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints)
+    est = lgb.train(new_params, train_data, num_boost_round=100)
+    check_consistency(est, tree_interaction_constraints)
+    # check if tree features interaction constraints works with multiple set of features
+    tree_interaction_constraints = [[i] for i in range(num_features)]
+    print(tree_interaction_constraints)
+    new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints)
+    est = lgb.train(new_params, train_data, num_boost_round=100)
+    check_consistency(est, tree_interaction_constraints)
+'''
+
 def test_linear_trees_num_threads():
     # check that number of threads does not affect result
     np.random.seed(0)

From 84287f12296d883c3a2becee412bc93c67f6e2aa Mon Sep 17 00:00:00 2001
From: veneres <alberto.veneri@unive.it>
Date: Fri, 16 Feb 2024 17:54:59 +0100
Subject: [PATCH 16/21] Track branch features when max_interactions is set

---
 src/treelearner/serial_tree_learner.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp
index a790bcf93388..164ae1f5f25d 100644
--- a/src/treelearner/serial_tree_learner.cpp
+++ b/src/treelearner/serial_tree_learner.cpp
@@ -198,7 +198,8 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
 
   bool track_branch_features = !(config_->interaction_constraints_vector.empty()
                                  && config_->tree_interaction_constraints_vector.empty()
-                                 && config_->n_tree_interaction_constraints == 0);
+                                 && config_->n_tree_interaction_constraints == 0
+                                 && config_->max_interactions == 0);
   auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves, track_branch_features, false));
   auto tree_ptr = tree.get();
   constraints_->ShareTreePointer(tree_ptr);

From 227ec1b58b79d494c5a033b6729d1310f6d57036 Mon Sep 17 00:00:00 2001
From: Alberto Veneri <alberto.veneri@unive.it>
Date: Wed, 21 Feb 2024 10:49:16 +0100
Subject: [PATCH 17/21] Param name refactor

---
 docs/Parameters.rst                     |  4 ++--
 include/LightGBM/config.h               |  4 ++--
 src/io/config_auto.cpp                  | 12 ++++++------
 src/treelearner/col_sampler.hpp         | 18 +++++++++---------
 src/treelearner/serial_tree_learner.cpp |  2 +-
 5 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/docs/Parameters.rst b/docs/Parameters.rst
index 5be86c8a50d4..8bf02c1ba4da 100644
--- a/docs/Parameters.rst
+++ b/docs/Parameters.rst
@@ -632,11 +632,11 @@ Learning Control Parameters
 
    -  any two features can only appear in the same tree only if there exists a constraint containing both features
 
--  ``n_tree_interaction_constraints`` :raw-html:`<a id="n_tree_interaction_constraints" title="Permalink to this parameter" href="#n_tree_interaction_constraints">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int, constraints: ``n_tree_interaction_constraints >=  0.0``
+-  ``max_tree_interactions`` :raw-html:`<a id="max_tree_interactions" title="Permalink to this parameter" href="#max_tree_interactions">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int, constraints: ``max_tree_interactions >=  0.0``
 
    -  controls how many features can appear in the same tree
 
-   -  by default (n_tree_interaction_constraints = 0) interaction constraints are disabled
+   -  by default (max_tree_interactions = 0) interaction constraints are disabled
 
 -  ``max_interactions`` :raw-html:`<a id="max_interactions" title="Permalink to this parameter" href="#max_interactions">&#x1F517;&#xFE0E;</a>`, default = ``0``, type = int, constraints: ``max_interactions >=  0.0``
 
diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h
index 648468bee2d4..b8693814749f 100644
--- a/include/LightGBM/config.h
+++ b/include/LightGBM/config.h
@@ -581,8 +581,8 @@ struct Config {
 
   // check = >= 0.0
   // desc = controls how many features can appear in the same tree
-  // desc = by default (n_tree_interaction_constraints = 0) interaction constraints are disabled
-  int n_tree_interaction_constraints = 0;
+  // desc = by default (max_tree_interactions = 0) interaction constraints are disabled
+  int max_tree_interactions = 0;
 
   // check = >= 0.0
   // desc = controls how many features interactions can be added to the final model
diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp
index e8debbf8c05f..784e23e57fe9 100644
--- a/src/io/config_auto.cpp
+++ b/src/io/config_auto.cpp
@@ -247,7 +247,7 @@ const std::unordered_set<std::string>& Config::parameter_set() {
   "path_smooth",
   "interaction_constraints",
   "tree_interaction_constraints",
-  "n_tree_interaction_constraints",
+  "max_tree_interactions",
   "max_interactions",
   "verbosity",
   "input_model",
@@ -493,8 +493,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
 
   GetString(params, "tree_interaction_constraints", &tree_interaction_constraints);
 
-  GetInt(params, "n_tree_interaction_constraints", &n_tree_interaction_constraints);
-  CHECK_GE(n_tree_interaction_constraints,  0.0);
+  GetInt(params, "max_tree_interactions", &max_tree_interactions);
+  CHECK_GE(max_tree_interactions,  0.0);
 
   GetInt(params, "max_interactions", &max_interactions);
   CHECK_GE(max_interactions,  0.0);
@@ -734,7 +734,7 @@ std::string Config::SaveMembersToString() const {
   str_buf << "[path_smooth: " << path_smooth << "]\n";
   str_buf << "[interaction_constraints: " << interaction_constraints << "]\n";
   str_buf << "[tree_interaction_constraints: " << tree_interaction_constraints << "]\n";
-  str_buf << "[n_tree_interaction_constraints: " << n_tree_interaction_constraints << "]\n";
+  str_buf << "[max_tree_interactions: " << max_tree_interactions << "]\n";
   str_buf << "[max_interactions: " << max_interactions << "]\n";
   str_buf << "[verbosity: " << verbosity << "]\n";
   str_buf << "[saved_feature_importance_type: " << saved_feature_importance_type << "]\n";
@@ -861,7 +861,7 @@ const std::unordered_map<std::string, std::vector<std::string>>& Config::paramet
     {"path_smooth", {}},
     {"interaction_constraints", {}},
     {"tree_interaction_constraints", {}},
-    {"n_tree_interaction_constraints", {}},
+    {"max_tree_interactions", {}},
     {"max_interactions", {}},
     {"verbosity", {"verbose"}},
     {"input_model", {"model_input", "model_in"}},
@@ -1007,7 +1007,7 @@ const std::unordered_map<std::string, std::string>& Config::ParameterTypes() {
     {"path_smooth", "double"},
     {"interaction_constraints", "vector<vector<int>>"},
     {"tree_interaction_constraints", "string"},
-    {"n_tree_interaction_constraints", "int"},
+    {"max_tree_interactions", "int"},
     {"max_interactions", "int"},
     {"verbosity", "int"},
     {"input_model", "string"},
diff --git a/src/treelearner/col_sampler.hpp b/src/treelearner/col_sampler.hpp
index 7018c7b5051a..b35880da4be4 100644
--- a/src/treelearner/col_sampler.hpp
+++ b/src/treelearner/col_sampler.hpp
@@ -35,7 +35,7 @@ class ColSampler {
       tree_interaction_constraints_.push_back(constraint_set);
     }
 
-    n_tree_interaction_constraints_ = config-> n_tree_interaction_constraints;
+    max_tree_interactions_ = config-> max_tree_interactions;
   }
 
   static int GetCnt(size_t total_cnt, double fraction) {
@@ -147,9 +147,9 @@ class ColSampler {
     std::vector<int8_t> GetByNode(const Tree* tree, int leaf) {
     // get interaction constraints for current tree
     std::unordered_set<int> tree_allowed_features;
-    if (!tree_interaction_constraints_.empty() || n_tree_interaction_constraints_ > 0) {
+    if (!tree_interaction_constraints_.empty() || max_tree_interactions_ > 0) {
       std::set<int> tree_features = tree->tree_features();
-      if (n_tree_interaction_constraints_ == 0 || tree_features.size() < (std::set<int>::size_type) n_tree_interaction_constraints_) {
+      if (max_tree_interactions_ == 0 || tree_features.size() < (std::set<int>::size_type) max_tree_interactions_) {
           ComputeTreeAllowedFeatures(&tree_allowed_features, &tree_features);
       } else {
           for (int feat : tree_features) {
@@ -166,9 +166,9 @@ class ColSampler {
     // intersect allowed features for branch and tree
     std::unordered_set<int> allowed_features;
 
-    if ((tree_interaction_constraints_.empty() && n_tree_interaction_constraints_ == 0) && !interaction_constraints_.empty()) {
+    if ((tree_interaction_constraints_.empty() && max_tree_interactions_ == 0) && !interaction_constraints_.empty()) {
       allowed_features.insert(branch_allowed_features.begin(), branch_allowed_features.end());
-    } else if (!(tree_interaction_constraints_.empty() && n_tree_interaction_constraints_ == 0) && interaction_constraints_.empty()) {
+    } else if (!(tree_interaction_constraints_.empty() && max_tree_interactions_ == 0) && interaction_constraints_.empty()) {
       allowed_features.insert(tree_allowed_features.begin(), tree_allowed_features.end());
     } else {
       for (int element : tree_allowed_features) {
@@ -180,7 +180,7 @@ class ColSampler {
 
     std::vector<int8_t> ret(train_data_->num_features(), 0);
     if (fraction_bynode_ >= 1.0f) {
-      if (interaction_constraints_.empty() && tree_interaction_constraints_.empty() && n_tree_interaction_constraints_ == 0) {
+      if (interaction_constraints_.empty() && tree_interaction_constraints_.empty() && max_tree_interactions_ == 0) {
         return std::vector<int8_t>(train_data_->num_features(), 1);
       } else {
         for (int feat : allowed_features) {
@@ -196,7 +196,7 @@ class ColSampler {
       auto used_feature_cnt = GetCnt(used_feature_indices_.size(), fraction_bynode_);
       std::vector<int>* allowed_used_feature_indices;
       std::vector<int> filtered_feature_indices;
-      if (interaction_constraints_.empty() && tree_interaction_constraints_.empty() && n_tree_interaction_constraints_ == 0) {
+      if (interaction_constraints_.empty() && tree_interaction_constraints_.empty() && max_tree_interactions_ == 0) {
         allowed_used_feature_indices = &used_feature_indices_;
       } else {
         for (int feat_ind : used_feature_indices_) {
@@ -222,7 +222,7 @@ class ColSampler {
           GetCnt(valid_feature_indices_.size(), fraction_bynode_);
       std::vector<int>* allowed_valid_feature_indices;
       std::vector<int> filtered_feature_indices;
-      if (interaction_constraints_.empty() && tree_interaction_constraints_.empty() && n_tree_interaction_constraints_ == 0) {
+      if (interaction_constraints_.empty() && tree_interaction_constraints_.empty() && max_tree_interactions_ == 0) {
         allowed_valid_feature_indices = &valid_feature_indices_;
       } else {
         for (int feat : valid_feature_indices_) {
@@ -269,7 +269,7 @@ class ColSampler {
   std::vector<std::unordered_set<int>> interaction_constraints_;
 
   std::vector<std::unordered_set<int>> tree_interaction_constraints_;
-  int n_tree_interaction_constraints_;
+  int max_tree_interactions_;
 };
 
 }  // namespace LightGBM
diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp
index 164ae1f5f25d..43b46c72bd28 100644
--- a/src/treelearner/serial_tree_learner.cpp
+++ b/src/treelearner/serial_tree_learner.cpp
@@ -198,7 +198,7 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
 
   bool track_branch_features = !(config_->interaction_constraints_vector.empty()
                                  && config_->tree_interaction_constraints_vector.empty()
-                                 && config_->n_tree_interaction_constraints == 0
+                                 && config_->max_tree_interactions == 0
                                  && config_->max_interactions == 0);
   auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves, track_branch_features, false));
   auto tree_ptr = tree.get();

From ca3dac56ebcea3876a4fa2d8435ff3355a140977 Mon Sep 17 00:00:00 2001
From: Alberto Veneri <alberto.veneri@unive.it>
Date: Wed, 21 Feb 2024 10:49:48 +0100
Subject: [PATCH 18/21] Interaction constraints test added

---
 tests/python_package_test/test_engine.py | 116 ++++++++++++++++++-----
 1 file changed, 94 insertions(+), 22 deletions(-)

diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index 517b9c2700f7..09eb2a503639 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -3664,8 +3664,6 @@ def test_interaction_constraints():
                                                           [1] + list(range(2, num_features))]),
                     train_data, num_boost_round=10)
 
-
-'''
 @pytest.mark.skipif(getenv('TASK', '') == 'cuda_exp', reason='Interaction constraints are not yet supported by CUDA Experimental version')
 def test_tree_interaction_constraints():
     def check_consistency(est, tree_interaction_constraints):
@@ -3676,7 +3674,6 @@ def check_consistency(est, tree_interaction_constraints):
             tree_df_per_index = tree_df[tree_df["tree_index"] == tree_index]
             feat_used = [feat_to_index[feat] for feat in tree_df_per_index["split_feature"].unique() if feat is not None]
             inter_found.add(tuple(sorted(feat_used)))
-        print(inter_found)
         for feats_found in inter_found:
             found = False
             for real_contraints in tree_interaction_constraints:
@@ -3684,16 +3681,7 @@ def check_consistency(est, tree_interaction_constraints):
                     found = True
                     break
             assert found is True
-    X, y = make_synthetic_regression(n_samples=200)
-    num_features = X.shape[1]
-    train_data = lgb.Dataset(X, label=y)
-    # check that constraint containing all features is equivalent to no constraint
-    params = {'verbose': -1,
-              'seed': 0}
-    est = lgb.train(params, train_data, num_boost_round=10)
-    pred1 = est.predict(X)
-    est = lgb.train(dict(params, interaction_constraints=[list(range(num_features))]), train_data,
-                    num_boost_round=10)
+    X, y = make_synthetic_regression(n_samples=400, n_features=30)
     num_features = X.shape[1]
     train_data = lgb.Dataset(X, label=y)
     # check that tree constraint containing all features is equivalent to no constraint
@@ -3705,25 +3693,109 @@ def check_consistency(est, tree_interaction_constraints):
                     num_boost_round=10)
     pred2 = est.predict(X)
     np.testing.assert_allclose(pred1, pred2)
-    # check that each tree is composed exactly of 2 features contained in the contrained set
-    tree_interaction_constraints = [[i, i + 1] for i in range(0, num_features - 1, 2)]
-    print(tree_interaction_constraints)
+
+    # check that each tree is composed exactly of 1 feature
+    tree_interaction_constraints = [[i] for i in range(num_features)]
     new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints)
     est = lgb.train(new_params, train_data, num_boost_round=100)
     check_consistency(est, tree_interaction_constraints)
-    # check if tree features interaction constraints works with multiple set of features
-    tree_interaction_constraints = [[i for i in range(i, i + 5)] for i in range(0, num_features - 5, 5)]
-    print(tree_interaction_constraints)
+
+    # check that each tree is composed exactly of 2 features contained in the constrained set
+    tree_interaction_constraints = [[i, i + 1] for i in range(0, num_features - 1, 2)]
     new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints)
     est = lgb.train(new_params, train_data, num_boost_round=100)
     check_consistency(est, tree_interaction_constraints)
+
     # check if tree features interaction constraints works with multiple set of features
-    tree_interaction_constraints = [[i] for i in range(num_features)]
-    print(tree_interaction_constraints)
+    tree_interaction_constraints = [[i for i in range(i, i + 5)] for i in range(0, num_features - 5, 5)]
     new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints)
     est = lgb.train(new_params, train_data, num_boost_round=100)
     check_consistency(est, tree_interaction_constraints)
-'''
+
+
+
+@pytest.mark.skipif(getenv('TASK', '') == 'cuda_exp', reason='Interaction constraints are not yet supported by CUDA Experimental version')
+def test_max_tree_interactions():
+    def check_n_interactions(est):
+        feat_to_index = {feat: i for i, feat in enumerate(est.feature_name())}
+        tree_df = est.trees_to_dataframe()
+        max_n_interactions = 0
+        for tree_index in tree_df["tree_index"].unique():
+            tree_df_per_index = tree_df[tree_df["tree_index"] == tree_index]
+            feat_used = [feat_to_index[feat] for feat in tree_df_per_index["split_feature"].unique() if feat is not None]
+            max_n_interactions = max(max_n_interactions, len(feat_used))
+        assert max_n_interactions <= est.params["max_tree_interactions"]
+
+    X, y = make_synthetic_regression(n_samples=400, n_features=30)
+    train_data = lgb.Dataset(X, label=y)
+    # check that limiting the number of interaction to the number of features is equivalent to no constraint
+    params = {'verbose': -1,
+              'seed': 0}
+    est = lgb.train(params, train_data, num_boost_round=100)
+    pred1 = est.predict(X)
+    est = lgb.train(dict(params, max_tree_interactions=400), train_data, num_boost_round=100)
+    pred2 = est.predict(X)
+
+    check_n_interactions(est)
+    np.testing.assert_allclose(pred1, pred2)
+
+    # check that the forest has only 1 interaction per tree
+    max_tree_interactions = 1
+    new_params = dict(params, max_tree_interactions=max_tree_interactions)
+    est = lgb.train(new_params, train_data, num_boost_round=100)
+    check_n_interactions(est)
+
+    # check that the forest has at most 10 features that interact in a tree
+    max_tree_interactions = 10
+    new_params = dict(params, max_tree_interactions=max_tree_interactions)
+    est = lgb.train(new_params, train_data, num_boost_round=100)
+    check_n_interactions(est)
+
+@pytest.mark.skipif(getenv('TASK', '') == 'cuda_exp', reason='Interaction constraints are not yet supported by CUDA Experimental version')
+def test_max_interactions():
+    def check_interactions(est, max_interactions):
+        feat_to_index = {feat: i for i, feat in enumerate(est.feature_name())}
+        tree_df = est.trees_to_dataframe()
+        inter_found = set()
+        for tree_index in tree_df["tree_index"].unique():
+            tree_df_per_index = tree_df[tree_df["tree_index"] == tree_index]
+            feat_used = [feat_to_index[feat] for feat in tree_df_per_index["split_feature"].unique() if feat is not None]
+            add_inter = True
+            for inter in inter_found:
+                if set(feat_used) <= set(inter): # the interaction found is a subset of another interaction
+                    add_inter = False
+            if add_inter:
+                inter_found.add(tuple(sorted(feat_used)))
+        assert len(inter_found) <= max_interactions
+
+    X, y = make_synthetic_regression(n_samples=400, n_features=30)
+    train_data = lgb.Dataset(X, label=y)
+    # check that limiting the number of distinct interactions to the number of trees is equivalent to no constraint
+    params = {'verbose': -1,
+              'seed': 0}
+    est = lgb.train(params, train_data, num_boost_round=100)
+    pred1 = est.predict(X)
+
+    max_interactions = 100
+    est = lgb.train(dict(params, max_interactions=max_interactions), train_data, num_boost_round=100)
+    pred2 = est.predict(X)
+
+    check_interactions(est, max_interactions)
+    np.testing.assert_allclose(pred1, pred2)
+
+    # check that the forest has only 1 interaction
+    max_interactions = 1
+    new_params = dict(params, max_interactions=max_interactions)
+    est = lgb.train(new_params, train_data, num_boost_round=100)
+    check_interactions(est, max_interactions)
+
+    # check that the forest has at most 10 interactions
+    max_interactions = 10
+    new_params = dict(params, max_interactions=max_interactions)
+    est = lgb.train(new_params, train_data, num_boost_round=100)
+    check_interactions(est, max_interactions)
+
+
 
 def test_linear_trees_num_threads():
     # check that number of threads does not affect result

From ab0435242b01412a066d7be267961f2468fb9bc7 Mon Sep 17 00:00:00 2001
From: Alberto Veneri <alberto.veneri@unive.it>
Date: Wed, 21 Feb 2024 10:55:53 +0100
Subject: [PATCH 19/21] Addressed: Unnecessary `list` comprehension (rewrite
 using `list()`)

---
 tests/python_package_test/test_engine.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index 09eb2a503639..20f58304f633 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -3707,7 +3707,7 @@ def check_consistency(est, tree_interaction_constraints):
     check_consistency(est, tree_interaction_constraints)
 
     # check if tree features interaction constraints works with multiple set of features
-    tree_interaction_constraints = [[i for i in range(i, i + 5)] for i in range(0, num_features - 5, 5)]
+    tree_interaction_constraints = [list(range(i, i + 5)) for i in range(0, num_features - 5, 5)]
     new_params = dict(params, tree_interaction_constraints=tree_interaction_constraints)
     est = lgb.train(new_params, train_data, num_boost_round=100)
     check_consistency(est, tree_interaction_constraints)

From c0a45917e0e960b032f5a179f6314a41215148e5 Mon Sep 17 00:00:00 2001
From: Alberto Veneri <alberto.veneri@unive.it>
Date: Thu, 22 Feb 2024 13:30:31 +0100
Subject: [PATCH 20/21] Skip constraint test on CUDA for the moment

---
 tests/python_package_test/test_engine.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index b9fe2705bddb..cca8d389452a 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -3458,7 +3458,7 @@ def test_interaction_constraints():
         num_boost_round=10,
     )
 
-@pytest.mark.skipif(getenv('TASK', '') == 'cuda_exp', reason='Interaction constraints are not yet supported by CUDA Experimental version')
+@pytest.mark.skipif(getenv('TASK', '') == 'cuda', reason='Interaction constraints are not yet supported on the CUDA version')
 def test_tree_interaction_constraints():
     def check_consistency(est, tree_interaction_constraints):
         feat_to_index = {feat: i for i, feat in enumerate(est.feature_name())}
@@ -3508,7 +3508,7 @@ def check_consistency(est, tree_interaction_constraints):
 
 
 
-@pytest.mark.skipif(getenv('TASK', '') == 'cuda_exp', reason='Interaction constraints are not yet supported by CUDA Experimental version')
+@pytest.mark.skipif(getenv('TASK', '') == 'cuda', reason='Interaction constraints are not yet supported on the CUDA version')
 def test_max_tree_interactions():
     def check_n_interactions(est):
         feat_to_index = {feat: i for i, feat in enumerate(est.feature_name())}
@@ -3545,7 +3545,7 @@ def check_n_interactions(est):
     est = lgb.train(new_params, train_data, num_boost_round=100)
     check_n_interactions(est)
 
-@pytest.mark.skipif(getenv('TASK', '') == 'cuda_exp', reason='Interaction constraints are not yet supported by CUDA Experimental version')
+@pytest.mark.skipif(getenv('TASK', '') == 'cuda', reason='Interaction constraints are not yet supported on the CUDA version')
 def test_max_interactions():
     def check_interactions(est, max_interactions):
         feat_to_index = {feat: i for i, feat in enumerate(est.feature_name())}

From 8165317e322f94391300db9f2b170d9ef4005ae6 Mon Sep 17 00:00:00 2001
From: Alberto Veneri <alberto.veneri@unive.it>
Date: Thu, 22 Feb 2024 13:47:48 +0100
Subject: [PATCH 21/21] Reformat file for ruff check

---
 tests/python_package_test/test_engine.py | 45 +++++++++++++++---------
 1 file changed, 28 insertions(+), 17 deletions(-)

diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index cca8d389452a..681f77cb91b1 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -3458,7 +3458,10 @@ def test_interaction_constraints():
         num_boost_round=10,
     )
 
-@pytest.mark.skipif(getenv('TASK', '') == 'cuda', reason='Interaction constraints are not yet supported on the CUDA version')
+
+@pytest.mark.skipif(
+    getenv("TASK", "") == "cuda", reason="Interaction constraints are not yet supported on the CUDA version"
+)
 def test_tree_interaction_constraints():
     def check_consistency(est, tree_interaction_constraints):
         feat_to_index = {feat: i for i, feat in enumerate(est.feature_name())}
@@ -3466,7 +3469,9 @@ def check_consistency(est, tree_interaction_constraints):
         inter_found = set()
         for tree_index in tree_df["tree_index"].unique():
             tree_df_per_index = tree_df[tree_df["tree_index"] == tree_index]
-            feat_used = [feat_to_index[feat] for feat in tree_df_per_index["split_feature"].unique() if feat is not None]
+            feat_used = [
+                feat_to_index[feat] for feat in tree_df_per_index["split_feature"].unique() if feat is not None
+            ]
             inter_found.add(tuple(sorted(feat_used)))
         for feats_found in inter_found:
             found = False
@@ -3475,16 +3480,17 @@ def check_consistency(est, tree_interaction_constraints):
                     found = True
                     break
             assert found is True
+
     X, y = make_synthetic_regression(n_samples=400, n_features=30)
     num_features = X.shape[1]
     train_data = lgb.Dataset(X, label=y)
     # check that tree constraint containing all features is equivalent to no constraint
-    params = {'verbose': -1,
-              'seed': 0}
+    params = {"verbose": -1, "seed": 0}
     est = lgb.train(params, train_data, num_boost_round=10)
     pred1 = est.predict(X)
-    est = lgb.train(dict(params, tree_interaction_constraints=[list(range(num_features))]), train_data,
-                    num_boost_round=10)
+    est = lgb.train(
+        dict(params, tree_interaction_constraints=[list(range(num_features))]), train_data, num_boost_round=10
+    )
     pred2 = est.predict(X)
     np.testing.assert_allclose(pred1, pred2)
 
@@ -3507,8 +3513,9 @@ def check_consistency(est, tree_interaction_constraints):
     check_consistency(est, tree_interaction_constraints)
 
 
-
-@pytest.mark.skipif(getenv('TASK', '') == 'cuda', reason='Interaction constraints are not yet supported on the CUDA version')
+@pytest.mark.skipif(
+    getenv("TASK", "") == "cuda", reason="Interaction constraints are not yet supported on the CUDA version"
+)
 def test_max_tree_interactions():
     def check_n_interactions(est):
         feat_to_index = {feat: i for i, feat in enumerate(est.feature_name())}
@@ -3516,15 +3523,16 @@ def check_n_interactions(est):
         max_n_interactions = 0
         for tree_index in tree_df["tree_index"].unique():
             tree_df_per_index = tree_df[tree_df["tree_index"] == tree_index]
-            feat_used = [feat_to_index[feat] for feat in tree_df_per_index["split_feature"].unique() if feat is not None]
+            feat_used = [
+                feat_to_index[feat] for feat in tree_df_per_index["split_feature"].unique() if feat is not None
+            ]
             max_n_interactions = max(max_n_interactions, len(feat_used))
         assert max_n_interactions <= est.params["max_tree_interactions"]
 
     X, y = make_synthetic_regression(n_samples=400, n_features=30)
     train_data = lgb.Dataset(X, label=y)
     # check that limiting the number of interaction to the number of features is equivalent to no constraint
-    params = {'verbose': -1,
-              'seed': 0}
+    params = {"verbose": -1, "seed": 0}
     est = lgb.train(params, train_data, num_boost_round=100)
     pred1 = est.predict(X)
     est = lgb.train(dict(params, max_tree_interactions=400), train_data, num_boost_round=100)
@@ -3545,7 +3553,10 @@ def check_n_interactions(est):
     est = lgb.train(new_params, train_data, num_boost_round=100)
     check_n_interactions(est)
 
-@pytest.mark.skipif(getenv('TASK', '') == 'cuda', reason='Interaction constraints are not yet supported on the CUDA version')
+
+@pytest.mark.skipif(
+    getenv("TASK", "") == "cuda", reason="Interaction constraints are not yet supported on the CUDA version"
+)
 def test_max_interactions():
     def check_interactions(est, max_interactions):
         feat_to_index = {feat: i for i, feat in enumerate(est.feature_name())}
@@ -3553,10 +3564,12 @@ def check_interactions(est, max_interactions):
         inter_found = set()
         for tree_index in tree_df["tree_index"].unique():
             tree_df_per_index = tree_df[tree_df["tree_index"] == tree_index]
-            feat_used = [feat_to_index[feat] for feat in tree_df_per_index["split_feature"].unique() if feat is not None]
+            feat_used = [
+                feat_to_index[feat] for feat in tree_df_per_index["split_feature"].unique() if feat is not None
+            ]
             add_inter = True
             for inter in inter_found:
-                if set(feat_used) <= set(inter): # the interaction found is a subset of another interaction
+                if set(feat_used) <= set(inter):  # the interaction found is a subset of another interaction
                     add_inter = False
             if add_inter:
                 inter_found.add(tuple(sorted(feat_used)))
@@ -3565,8 +3578,7 @@ def check_interactions(est, max_interactions):
     X, y = make_synthetic_regression(n_samples=400, n_features=30)
     train_data = lgb.Dataset(X, label=y)
     # check that limiting the number of distinct interactions to the number of trees is equivalent to no constraint
-    params = {'verbose': -1,
-              'seed': 0}
+    params = {"verbose": -1, "seed": 0}
     est = lgb.train(params, train_data, num_boost_round=100)
     pred1 = est.predict(X)
 
@@ -3590,7 +3602,6 @@ def check_interactions(est, max_interactions):
     check_interactions(est, max_interactions)
 
 
-
 def test_linear_trees_num_threads():
     # check that number of threads does not affect result
     np.random.seed(0)