Interaction constraints (#3126)
* Add interaction constraints functionality.

* Minor fixes.

* Minor fixes.

* Change lambda to function.

* Fix gpu bug, remove extra blank lines.

* Fix gpu bug.

* Fix style issues.

* Try to fix segfault on MACOS.

* Fix bug.

* Fix bug.

* Fix bugs.

* Change parameter format for R.

* Fix R style issues.

* Change string formatting code.

* Change docs to say R package not supported.

* Remove R functionality, moving to separate PR.

* Keep track of branch features in tree object.

* Only track branch features when feature interactions are enabled.

* Fix lint error.

* Update docs and simplify tests.
btrotta authored Jun 23, 2020
1 parent f5e5164 commit bca2da9
Showing 21 changed files with 233 additions and 44 deletions.
4 changes: 4 additions & 0 deletions R-package/R/lgb.train.R
@@ -124,6 +124,10 @@ lgb.train <- function(params = list(),
end_iteration <- begin_iteration + nrounds - 1L
}

if (!is.null(params[["interaction_constraints"]])) {
stop("lgb.train: interaction_constraints is not implemented")
}

# Update parameters with parsed parameters
data$update_params(params)

14 changes: 14 additions & 0 deletions docs/Parameters.rst
@@ -538,6 +538,20 @@ Learning Control Parameters

- note that the parent output ``w_p`` itself has smoothing applied, unless it is the root node, so that the smoothing effect accumulates with the tree depth

- ``interaction_constraints`` :raw-html:`<a id="interaction_constraints" title="Permalink to this parameter" href="#interaction_constraints">&#x1F517;&#xFE0E;</a>`, default = ``""``, type = string

- controls which features can appear in the same branch

- by default interaction constraints are disabled; to enable them you can specify

- for CLI, lists separated by commas, e.g. ``[0,1,2],[2,3]``

- for Python-package, list of lists, e.g. ``[[0, 1, 2], [2, 3]]``

- for R-package, **not yet supported**

- any two features can appear in the same branch only if there exists a constraint containing both features

- ``verbosity`` :raw-html:`<a id="verbosity" title="Permalink to this parameter" href="#verbosity">&#x1F517;&#xFE0E;</a>`, default = ``1``, type = int, aliases: ``verbose``

- controls the level of LightGBM's verbosity
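To make the new parameter concrete, here is a minimal usage sketch from the Python package, matching the list-of-lists format described above. The random data and the feature indices 0 through 3 are purely illustrative:

    import numpy as np
    import lightgbm as lgb

    X = np.random.rand(500, 4)
    y = np.random.rand(500)
    train_set = lgb.Dataset(X, label=y)

    params = {
        'objective': 'regression',
        # features 0, 1 and 2 may interact with one another, as may 2 and 3,
        # but e.g. 0 and 3 will never be used in the same branch
        'interaction_constraints': [[0, 1, 2], [2, 3]],
    }
    booster = lgb.train(params, train_set, num_boost_round=10)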
10 changes: 10 additions & 0 deletions include/LightGBM/config.h
@@ -505,6 +505,14 @@ struct Config {
// descl2 = note that the parent output ``w_p`` itself has smoothing applied, unless it is the root node, so that the smoothing effect accumulates with the tree depth
double path_smooth = 0;

// desc = controls which features can appear in the same branch
// desc = by default interaction constraints are disabled; to enable them you can specify
// descl2 = for CLI, lists separated by commas, e.g. ``[0,1,2],[2,3]``
// descl2 = for Python-package, list of lists, e.g. ``[[0, 1, 2], [2, 3]]``
// descl2 = for R-package, **not yet supported**
// desc = any two features can appear in the same branch only if there exists a constraint containing both features
std::string interaction_constraints = "";

// alias = verbose
// desc = controls the level of LightGBM's verbosity
// desc = ``< 0``: Fatal, ``= 0``: Error (Warning), ``= 1``: Info, ``> 1``: Debug
@@ -958,12 +966,14 @@ struct Config {
static const std::unordered_map<std::string, std::string>& alias_table();
static const std::unordered_set<std::string>& parameter_set();
std::vector<std::vector<double>> auc_mu_weights_matrix;
std::vector<std::vector<int>> interaction_constraints_vector;

private:
void CheckParamConflict();
void GetMembersFromString(const std::unordered_map<std::string, std::string>& params);
std::string SaveMembersToString() const;
void GetAucMuWeights();
void GetInteractionConstraints();
};

inline bool Config::GetString(
15 changes: 14 additions & 1 deletion include/LightGBM/tree.h
@@ -27,8 +27,9 @@ class Tree {
/*!
* \brief Constructor
* \param max_leaves The number of max leaves
* \param track_branch_features Whether to keep track of ancestors of leaf nodes
*/
explicit Tree(int max_leaves);
explicit Tree(int max_leaves, bool track_branch_features);

/*!
* \brief Constructor, from a string
@@ -148,6 +149,9 @@ class Tree {
/*! \brief Get feature of specific split*/
inline int split_feature(int split_idx) const { return split_feature_[split_idx]; }

/*! \brief Get features on leaf's branch*/
inline std::vector<int> branch_features(int leaf) const { return branch_features_[leaf]; }

inline double split_gain(int split_idx) const { return split_gain_[split_idx]; }

inline double internal_value(int node_idx) const {
@@ -436,6 +440,10 @@ class Tree {
std::vector<int> internal_count_;
/*! \brief Depth for leaves */
std::vector<int> leaf_depth_;
/*! \brief whether to keep track of ancestor nodes for each leaf (only needed when feature interactions are restricted) */
bool track_branch_features_;
/*! \brief Features on leaf's branch, original index */
std::vector<std::vector<int>> branch_features_;
double shrinkage_;
int max_depth_;
};
@@ -477,6 +485,11 @@ inline void Tree::Split(int leaf, int feature, int real_feature,
// update leaf depth
leaf_depth_[num_leaves_] = leaf_depth_[leaf] + 1;
leaf_depth_[leaf]++;
if (track_branch_features_) {
branch_features_[num_leaves_] = branch_features_[leaf];
branch_features_[num_leaves_].push_back(split_feature_[new_node_idx]);
branch_features_[leaf].push_back(split_feature_[new_node_idx]);
}
}

inline double Tree::Predict(const double* feature_values) const {
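A toy Python sketch of the bookkeeping in Tree::Split above: the new leaf inherits the ancestor features of the leaf being split, and both children then record the feature just split on. Leaf indices and feature ids here are made up for illustration:

    branch_features = [[]]  # one list per leaf; the root has no ancestor splits

    def split(leaf, feature):
        branch_features.append(list(branch_features[leaf]))  # new leaf copies the branch
        branch_features[-1].append(feature)                   # ... and adds the new split
        branch_features[leaf].append(feature)                 # the split leaf does too

    split(0, 5)   # split the root on feature 5 -> leaves 0 and 1
    split(1, 2)   # split leaf 1 on feature 2   -> leaves 1 and 2
    print(branch_features)  # [[5], [5, 2], [5, 2]]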
35 changes: 35 additions & 0 deletions include/LightGBM/utils/common.h
@@ -103,6 +103,30 @@ inline static std::vector<std::string> Split(const char* c_str, char delimiter)
return ret;
}

inline static std::vector<std::string> SplitBrackets(const char* c_str, char left_delimiter, char right_delimiter) {
std::vector<std::string> ret;
std::string str(c_str);
size_t i = 0;
size_t pos = 0;
bool open = false;
while (pos < str.length()) {
if (str[pos] == left_delimiter) {
open = true;
++pos;
i = pos;
} else if (str[pos] == right_delimiter && open) {
if (i < pos) {
ret.push_back(str.substr(i, pos - i));
}
open = false;
++pos;
} else {
++pos;
}
}
return ret;
}

inline static std::vector<std::string> SplitLines(const char* c_str) {
std::vector<std::string> ret;
std::string str(c_str);
@@ -503,6 +527,17 @@ inline static std::vector<T> StringToArray(const std::string& str, char delimite
return ret;
}

template<typename T>
inline static std::vector<std::vector<T>> StringToArrayofArrays(
const std::string& str, char left_bracket, char right_bracket, char delimiter) {
std::vector<std::string> strs = SplitBrackets(str.c_str(), left_bracket, right_bracket);
std::vector<std::vector<T>> ret;
for (const auto& s : strs) {
ret.push_back(StringToArray<T>(s, delimiter));
}
return ret;
}

template<typename T>
inline static std::vector<T> StringToArray(const std::string& str, int n) {
if (n == 0) {
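As intuition for SplitBrackets and StringToArrayofArrays above, here is a Python paraphrase of the same parsing, from the CLI-style string to a list of integer lists. This is a sketch only; the shipped implementation is the C++ shown in the diff:

    def split_brackets(s, left='[', right=']'):
        # Collect the substrings enclosed by left/right delimiters,
        # e.g. "[0,1,2],[2,3]" -> ["0,1,2", "2,3"].
        out, start, is_open = [], 0, False
        for pos, ch in enumerate(s):
            if ch == left:
                is_open, start = True, pos + 1
            elif ch == right and is_open:
                if start < pos:
                    out.append(s[start:pos])
                is_open = False
        return out

    def string_to_array_of_arrays(s):
        # "[0,1,2],[2,3]" -> [[0, 1, 2], [2, 3]]
        return [[int(x) for x in part.split(',')] for part in split_brackets(s)]

    print(string_to_array_of_arrays("[0,1,2],[2,3]"))  # [[0, 1, 2], [2, 3]]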
7 changes: 6 additions & 1 deletion python-package/lightgbm/basic.py
@@ -135,7 +135,12 @@ def param_dict_to_str(data):
pairs = []
for key, val in data.items():
if isinstance(val, (list, tuple, set)) or is_numpy_1d_array(val):
pairs.append(str(key) + '=' + ','.join(map(str, val)))
def to_string(x):
if isinstance(x, list):
return "[{}]".format(','.join(map(str, x)))
else:
return str(x)
pairs.append(str(key) + '=' + ','.join(map(to_string, val)))
elif isinstance(val, string_type) or isinstance(val, numeric_types) or is_numeric(val):
pairs.append(str(key) + '=' + str(val))
elif val is not None:
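The effect of the new to_string helper can be seen in isolation; this small sketch mirrors only the list/tuple branch of param_dict_to_str changed above:

    def to_string(x):
        # nested lists are wrapped in brackets so that a list of lists
        # serializes to the CLI format, e.g. "[0,1,2],[2,3]"
        if isinstance(x, list):
            return "[{}]".format(','.join(map(str, x)))
        return str(x)

    val = [[0, 1, 2], [2, 3]]
    print('interaction_constraints=' + ','.join(map(to_string, val)))
    # interaction_constraints=[0,1,2],[2,3]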
2 changes: 1 addition & 1 deletion src/boosting/gbdt.cpp
@@ -352,7 +352,7 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
bool should_continue = false;
for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
const size_t offset = static_cast<size_t>(cur_tree_id) * num_data_;
std::unique_ptr<Tree> new_tree(new Tree(2));
std::unique_ptr<Tree> new_tree(new Tree(2, false));
if (class_need_train_[cur_tree_id] && train_data_->num_features() > 0) {
auto grad = gradients + offset;
auto hess = hessians + offset;
2 changes: 1 addition & 1 deletion src/boosting/rf.hpp
@@ -109,7 +109,7 @@ class RF : public GBDT {
gradients = gradients_.data();
hessians = hessians_.data();
for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
std::unique_ptr<Tree> new_tree(new Tree(2));
std::unique_ptr<Tree> new_tree(new Tree(2, false));
size_t offset = static_cast<size_t>(cur_tree_id)* num_data_;
if (class_need_train_[cur_tree_id]) {
auto grad = gradients + offset;
10 changes: 10 additions & 0 deletions src/io/config.cpp
@@ -180,6 +180,14 @@ void Config::GetAucMuWeights() {
}
}

void Config::GetInteractionConstraints() {
if (interaction_constraints == "") {
interaction_constraints_vector = std::vector<std::vector<int>>();
} else {
interaction_constraints_vector = Common::StringToArrayofArrays<int>(interaction_constraints, '[', ']', ',');
}
}

void Config::Set(const std::unordered_map<std::string, std::string>& params) {
// generate seeds by seed.
if (GetInt(params, "seed", &seed)) {
@@ -204,6 +212,8 @@ void Config::Set(const std::unordered_map<std::string, std::string>& params) {

GetAucMuWeights();

GetInteractionConstraints();

// sort eval_at
std::sort(eval_at.begin(), eval_at.end());

4 changes: 4 additions & 0 deletions src/io/config_auto.cpp
@@ -230,6 +230,7 @@ const std::unordered_set<std::string>& Config::parameter_set() {
"cegb_penalty_feature_lazy",
"cegb_penalty_feature_coupled",
"path_smooth",
"interaction_constraints",
"verbosity",
"input_model",
"output_model",
@@ -454,6 +455,8 @@ void Config::GetMembersFromString(const std::unordered_map<std::string, std::str
GetDouble(params, "path_smooth", &path_smooth);
CHECK_GE(path_smooth, 0.0);

GetString(params, "interaction_constraints", &interaction_constraints);

GetInt(params, "verbosity", &verbosity);

GetString(params, "input_model", &input_model);
@@ -659,6 +662,7 @@ std::string Config::SaveMembersToString() const {
str_buf << "[cegb_penalty_feature_lazy: " << Common::Join(cegb_penalty_feature_lazy, ",") << "]\n";
str_buf << "[cegb_penalty_feature_coupled: " << Common::Join(cegb_penalty_feature_coupled, ",") << "]\n";
str_buf << "[path_smooth: " << path_smooth << "]\n";
str_buf << "[interaction_constraints: " << interaction_constraints << "]\n";
str_buf << "[verbosity: " << verbosity << "]\n";
str_buf << "[max_bin: " << max_bin << "]\n";
str_buf << "[max_bin_by_feature: " << Common::Join(max_bin_by_feature, ",") << "]\n";
7 changes: 5 additions & 2 deletions src/io/tree.cpp
@@ -14,8 +14,8 @@

namespace LightGBM {

Tree::Tree(int max_leaves)
:max_leaves_(max_leaves) {
Tree::Tree(int max_leaves, bool track_branch_features)
:max_leaves_(max_leaves), track_branch_features_(track_branch_features) {
left_child_.resize(max_leaves_ - 1);
right_child_.resize(max_leaves_ - 1);
split_feature_inner_.resize(max_leaves_ - 1);
@@ -32,6 +32,9 @@ Tree::Tree(int max_leaves)
internal_weight_.resize(max_leaves_ - 1);
internal_count_.resize(max_leaves_ - 1);
leaf_depth_.resize(max_leaves_);
if (track_branch_features_) {
branch_features_ = std::vector<std::vector<int>>(max_leaves_);
}
// root is in the depth 0
leaf_depth_[0] = 0;
num_leaves_ = 1;
76 changes: 69 additions & 7 deletions src/treelearner/col_sampler.hpp
@@ -13,6 +13,7 @@
#include <LightGBM/utils/random.h>

#include <algorithm>
#include <unordered_set>
#include <vector>

namespace LightGBM {
@@ -23,6 +24,10 @@ class ColSampler {
fraction_bynode_(config->feature_fraction_bynode),
seed_(config->feature_fraction_seed),
random_(config->feature_fraction_seed) {
for (auto constraint : config->interaction_constraints_vector) {
std::unordered_set<int> constraint_set(constraint.begin(), constraint.end());
interaction_constraints_.push_back(constraint_set);
}
}

static int GetCnt(size_t total_cnt, double fraction) {
@@ -83,32 +88,87 @@
}
}

std::vector<int8_t> GetByNode() {
if (fraction_bynode_ >= 1.0f) {
return std::vector<int8_t>(train_data_->num_features(), 1);
std::vector<int8_t> GetByNode(const Tree* tree, int leaf) {
// get interaction constraints for current branch
std::unordered_set<int> allowed_features;
if (!interaction_constraints_.empty()) {
std::vector<int> branch_features = tree->branch_features(leaf);
allowed_features.insert(branch_features.begin(), branch_features.end());
for (auto constraint : interaction_constraints_) {
int num_feat_found = 0;
if (branch_features.size() == 0) {
allowed_features.insert(constraint.begin(), constraint.end());
}
for (int feat : branch_features) {
if (constraint.count(feat) == 0) { break; }
++num_feat_found;
if (num_feat_found == static_cast<int>(branch_features.size())) {
allowed_features.insert(constraint.begin(), constraint.end());
break;
}
}
}
}

std::vector<int8_t> ret(train_data_->num_features(), 0);
if (fraction_bynode_ >= 1.0f) {
if (interaction_constraints_.empty()) {
return std::vector<int8_t>(train_data_->num_features(), 1);
} else {
for (int feat : allowed_features) {
int inner_feat = train_data_->InnerFeatureIndex(feat);
ret[inner_feat] = 1;
}
return ret;
}
}
if (need_reset_bytree_) {
auto used_feature_cnt = GetCnt(used_feature_indices_.size(), fraction_bynode_);
std::vector<int>* allowed_used_feature_indices;
std::vector<int> filtered_feature_indices;
if (interaction_constraints_.empty()) {
allowed_used_feature_indices = &used_feature_indices_;
} else {
for (int feat_ind : used_feature_indices_) {
if (allowed_features.count(valid_feature_indices_[feat_ind]) == 1) {
filtered_feature_indices.push_back(feat_ind);
}
}
used_feature_cnt = std::min(used_feature_cnt, static_cast<int>(filtered_feature_indices.size()));
allowed_used_feature_indices = &filtered_feature_indices;
}
auto sampled_indices = random_.Sample(
static_cast<int>(used_feature_indices_.size()), used_feature_cnt);
static_cast<int>((*allowed_used_feature_indices).size()), used_feature_cnt);
int omp_loop_size = static_cast<int>(sampled_indices.size());
#pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
for (int i = 0; i < omp_loop_size; ++i) {
int used_feature =
valid_feature_indices_[used_feature_indices_[sampled_indices[i]]];
valid_feature_indices_[(*allowed_used_feature_indices)[sampled_indices[i]]];
int inner_feature_index = train_data_->InnerFeatureIndex(used_feature);
ret[inner_feature_index] = 1;
}
} else {
auto used_feature_cnt =
GetCnt(valid_feature_indices_.size(), fraction_bynode_);
std::vector<int>* allowed_valid_feature_indices;
std::vector<int> filtered_feature_indices;
if (interaction_constraints_.empty()) {
allowed_valid_feature_indices = &valid_feature_indices_;
} else {
for (int feat : valid_feature_indices_) {
if (allowed_features.count(feat) == 1) {
filtered_feature_indices.push_back(feat);
}
}
allowed_valid_feature_indices = &filtered_feature_indices;
used_feature_cnt = std::min(used_feature_cnt, static_cast<int>(filtered_feature_indices.size()));
}
auto sampled_indices = random_.Sample(
static_cast<int>(valid_feature_indices_.size()), used_feature_cnt);
static_cast<int>((*allowed_valid_feature_indices).size()), used_feature_cnt);
int omp_loop_size = static_cast<int>(sampled_indices.size());
#pragma omp parallel for schedule(static, 512) if (omp_loop_size >= 1024)
for (int i = 0; i < omp_loop_size; ++i) {
int used_feature = valid_feature_indices_[sampled_indices[i]];
int used_feature = (*allowed_valid_feature_indices)[sampled_indices[i]];
int inner_feature_index = train_data_->InnerFeatureIndex(used_feature);
ret[inner_feature_index] = 1;
}
@@ -135,6 +195,8 @@ class ColSampler {
std::vector<int8_t> is_feature_used_;
std::vector<int> used_feature_indices_;
std::vector<int> valid_feature_indices_;
/*! \brief interaction constraints index in original (raw data) features */
std::vector<std::unordered_set<int>> interaction_constraints_;
};

} // namespace LightGBM
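To summarize the constraint filtering in GetByNode above, here is a Python paraphrase of how the set of allowed features is built for a leaf, intuition only, with the C++ in the diff being the actual implementation: a feature is allowed if it is already on the branch, or if some constraint contains the entire branch together with that feature.

    def allowed_features(branch_features, constraints):
        # branch_features: raw feature indices already used on the leaf's branch
        # constraints: list of sets, e.g. [{0, 1, 2}, {2, 3}]
        allowed = set(branch_features)
        for constraint in constraints:
            # an empty branch (the root) may use any constrained feature;
            # otherwise the branch must be fully contained in the constraint
            if not branch_features or all(f in constraint for f in branch_features):
                allowed |= constraint
        return allowed

    constraints = [{0, 1, 2}, {2, 3}]
    print(allowed_features([], constraints))      # {0, 1, 2, 3}
    print(allowed_features([2], constraints))     # {0, 1, 2, 3}
    print(allowed_features([0, 2], constraints))  # {0, 1, 2}
    print(allowed_features([3], constraints))     # {2, 3}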