Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pr4 advanced method monotone constraints #3264

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
95be175
No need to pass the tree to all functions related to monotone constrai…
Jun 10, 2020
bb6668c
Fix OppositeChildShouldBeUpdated numerical split optimisation.
Jun 10, 2020
75ca708
No need to use constraints when computing the output of the root.
Jun 10, 2020
38b9ab1
Refactor existing constraints.
Jun 10, 2020
447eb3b
Add advanced constraints method.
Jun 10, 2020
d7e8a9e
Update tests.
Jun 10, 2020
8bee2cb
Add override.
Jul 29, 2020
0029358
linting.
Jul 31, 2020
5acfb14
Add override.
CharlesAuguste Aug 7, 2020
770f93f
Simplify condition in LeftRightContainsRelevantInformation.
CharlesAuguste Aug 9, 2020
e1ed799
Add virtual destructor to FeatureConstraint.
CharlesAuguste Aug 9, 2020
af52340
Remove redundant blank line.
CharlesAuguste Aug 9, 2020
04c53e7
linting of else.
CharlesAuguste Aug 9, 2020
2e13eaf
Indentation.
CharlesAuguste Aug 9, 2020
b9443b3
Lint else.
CharlesAuguste Aug 9, 2020
12f67d7
Replaced non-const reference by pointers.
CharlesAuguste Aug 9, 2020
6a5d2ed
Forgotten reference.
CharlesAuguste Aug 23, 2020
e78a5bc
Leverage USE_MC for efficiency.
CharlesAuguste Aug 23, 2020
6801322
Make constraints const again in feature_histogram.hpp.
CharlesAuguste Aug 23, 2020
7fc04cf
Update docs.
CharlesAuguste Aug 24, 2020
7f1c05a
Add "advanced" to the monotone constraints options.
CharlesAuguste Aug 30, 2020
24290e0
Update monotone constraints restrictions.
CharlesAuguste Sep 12, 2020
56bc0da
Fix loop iterator.
Sep 12, 2020
e47148f
Fix loop iterator.
Sep 12, 2020
bea1edd
Remove superfluous parenthesis.
CharlesAuguste Sep 12, 2020
8cf7aa8
Fix loop iterator.
Sep 12, 2020
81226b8
Fix loop iterator.
Sep 12, 2020
6b66558
Fix loop iterator.
Sep 12, 2020
250dfe7
Fix loop iterator.
Sep 12, 2020
73d9752
Fix loop iterator.
Sep 12, 2020
afa744f
Fix loop iterator.
Sep 12, 2020
7e9987b
Fix loop iterator.
Sep 12, 2020
9da9d09
Fix loop iterator.
Sep 12, 2020
184c4ef
Remove std namespace qualifier.
CharlesAuguste Sep 12, 2020
e9f6953
Fix unsigned_int size_t comparison.
CharlesAuguste Sep 12, 2020
1b38dc4
Set num_features as int for consistency with the rest of the codebase.
CharlesAuguste Sep 12, 2020
21f32d2
Make sure constraints exist before recomputing them.
CharlesAuguste Sep 12, 2020
609f78a
Initialize previous constraints in UpdateConstraints.
CharlesAuguste Sep 12, 2020
f554a24
Update monotone constraints restrictions.
CharlesAuguste Sep 14, 2020
6b3d73d
Refactor UpdateConstraints loop.
CharlesAuguste Sep 14, 2020
5774cf4
Update src/io/config.cpp
Sep 14, 2020
6ec24f4
Delete white spaces.
CharlesAuguste Sep 21, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
No need to pass the tree to all functions related to monotone constrai…
…nts because the pointer is shared.
  • Loading branch information
Charles Auguste authored and CharlesAuguste committed Sep 20, 2020
commit 95be175046a582b1c787fc3f9b95329baebf2a45
71 changes: 35 additions & 36 deletions src/treelearner/monotone_constraints.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ class LeafConstraintsBase {
virtual ~LeafConstraintsBase() {}
virtual const ConstraintEntry& Get(int leaf_idx) const = 0;
virtual void Reset() = 0;
virtual void BeforeSplit(const Tree* tree, int leaf, int new_leaf,
virtual void BeforeSplit(int leaf, int new_leaf,
int8_t monotone_type) = 0;
virtual std::vector<int> Update(
const Tree* tree, bool is_numerical_split,
bool is_numerical_split,
int leaf, int new_leaf, int8_t monotone_type, double right_output,
double left_output, int split_feature, const SplitInfo& split_info,
const std::vector<SplitInfo>& best_split_per_leaf) = 0;
Expand All @@ -78,7 +78,7 @@ class LeafConstraintsBase {
tree_ = tree;
}

private:
protected:
const Tree* tree_;
};

Expand All @@ -94,10 +94,9 @@ class BasicLeafConstraints : public LeafConstraintsBase {
}
}

void BeforeSplit(const Tree*, int, int, int8_t) override {}
void BeforeSplit(int, int, int8_t) override {}

std::vector<int> Update(const Tree*,
bool is_numerical_split, int leaf, int new_leaf,
std::vector<int> Update(bool is_numerical_split, int leaf, int new_leaf,
int8_t monotone_type, double right_output,
double left_output, int, const SplitInfo& ,
const std::vector<SplitInfo>&) override {
Expand Down Expand Up @@ -138,7 +137,7 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
leaves_to_update_.clear();
}

void BeforeSplit(const Tree* tree, int leaf, int new_leaf,
void BeforeSplit(int leaf, int new_leaf,
int8_t monotone_type) override {
if (monotone_type != 0 || leaf_is_in_monotone_subtree_[leaf]) {
leaf_is_in_monotone_subtree_[leaf] = true;
Expand All @@ -148,7 +147,7 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
CHECK_GE(new_leaf - 1, 0);
CHECK_LT(static_cast<size_t>(new_leaf - 1), node_parent_.size());
#endif
node_parent_[new_leaf - 1] = tree->leaf_parent(leaf);
node_parent_[new_leaf - 1] = tree_->leaf_parent(leaf);
}

void UpdateConstraintsWithOutputs(bool is_numerical_split, int leaf,
Expand All @@ -166,7 +165,7 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
}
}

std::vector<int> Update(const Tree* tree, bool is_numerical_split, int leaf,
std::vector<int> Update(bool is_numerical_split, int leaf,
int new_leaf, int8_t monotone_type,
double right_output, double left_output,
int split_feature, const SplitInfo& split_info,
Expand All @@ -177,7 +176,7 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
monotone_type, right_output, left_output);

// Initialize variables to store information while going up the tree
int depth = tree->leaf_depth(new_leaf) - 1;
int depth = tree_->leaf_depth(new_leaf) - 1;

std::vector<int> features_of_splits_going_up_from_original_leaf;
std::vector<uint32_t> thresholds_of_splits_going_up_from_original_leaf;
Expand All @@ -187,7 +186,7 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
thresholds_of_splits_going_up_from_original_leaf.reserve(depth);
was_original_leaf_right_child_of_split.reserve(depth);

GoUpToFindLeavesToUpdate(tree, tree->leaf_parent(new_leaf),
GoUpToFindLeavesToUpdate(tree_->leaf_parent(new_leaf),
&features_of_splits_going_up_from_original_leaf,
&thresholds_of_splits_going_up_from_original_leaf,
&was_original_leaf_right_child_of_split,
Expand Down Expand Up @@ -232,7 +231,7 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
// Recursive function that goes up the tree, and then down to find leaves that
// have constraints to be updated
void GoUpToFindLeavesToUpdate(
const Tree* tree, int node_idx,
int node_idx,
std::vector<int>* features_of_splits_going_up_from_original_leaf,
std::vector<uint32_t>* thresholds_of_splits_going_up_from_original_leaf,
std::vector<bool>* was_original_leaf_right_child_of_split,
Expand All @@ -245,11 +244,11 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
int parent_idx = node_parent_[node_idx];
// if not at the root
if (parent_idx != -1) {
int inner_feature = tree->split_feature_inner(parent_idx);
int feature = tree->split_feature(parent_idx);
int inner_feature = tree_->split_feature_inner(parent_idx);
int feature = tree_->split_feature(parent_idx);
int8_t monotone_type = config_->monotone_constraints[feature];
bool is_in_right_child = tree->right_child(parent_idx) == node_idx;
bool is_split_numerical = tree->IsNumericalSplit(node_idx);
bool is_in_right_child = tree_->right_child(parent_idx) == node_idx;
bool is_split_numerical = tree_->IsNumericalSplit(node_idx);

// this is just an optimisation not to waste time going down in subtrees
// where there won't be any leaf to update
Expand All @@ -264,8 +263,8 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
if (monotone_type != 0) {
// these variables correspond to the current split we encounter going
// up the tree
int left_child_idx = tree->left_child(parent_idx);
int right_child_idx = tree->right_child(parent_idx);
int left_child_idx = tree_->left_child(parent_idx);
int right_child_idx = tree_->right_child(parent_idx);
bool left_child_is_curr_idx = (left_child_idx == node_idx);
int opposite_child_idx =
(left_child_is_curr_idx) ? right_child_idx : left_child_idx;
Expand All @@ -277,7 +276,7 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
// so the code needs to go down in the opposite child
// to see which leaves' constraints need to be updated
GoDownToFindLeavesToUpdate(
tree, opposite_child_idx,
opposite_child_idx,
*features_of_splits_going_up_from_original_leaf,
*thresholds_of_splits_going_up_from_original_leaf,
*was_original_leaf_right_child_of_split,
Expand All @@ -290,24 +289,24 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
// is actually contiguous to the original 2 leaves and should be updated
// so the variables associated with the split need to be recorded
was_original_leaf_right_child_of_split->push_back(
tree->right_child(parent_idx) == node_idx);
tree_->right_child(parent_idx) == node_idx);
thresholds_of_splits_going_up_from_original_leaf->push_back(
tree->threshold_in_bin(parent_idx));
tree_->threshold_in_bin(parent_idx));
features_of_splits_going_up_from_original_leaf->push_back(
tree->split_feature_inner(parent_idx));
tree_->split_feature_inner(parent_idx));
}

// since current node is not the root, keep going up
GoUpToFindLeavesToUpdate(
tree, parent_idx, features_of_splits_going_up_from_original_leaf,
parent_idx, features_of_splits_going_up_from_original_leaf,
thresholds_of_splits_going_up_from_original_leaf,
was_original_leaf_right_child_of_split, split_feature, split_info,
split_threshold, best_split_per_leaf);
}
}

void GoDownToFindLeavesToUpdate(
const Tree* tree, int node_idx,
int node_idx,
const std::vector<int>& features_of_splits_going_up_from_original_leaf,
const std::vector<uint32_t>&
thresholds_of_splits_going_up_from_original_leaf,
Expand Down Expand Up @@ -345,9 +344,9 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {

#ifdef DEBUG
if (update_max_constraints) {
CHECK_GE(min_max_constraints.first, tree->LeafOutput(leaf_idx));
CHECK_GE(min_max_constraints.first, tree_->LeafOutput(leaf_idx));
} else {
CHECK_LE(min_max_constraints.second, tree->LeafOutput(leaf_idx));
CHECK_LE(min_max_constraints.second, tree_->LeafOutput(leaf_idx));
}
#endif
// depending on which split made the current leaf and the original leaves contiguous,
Expand All @@ -368,12 +367,12 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
} else { // if node
// check if the children are contiguous with the original leaf
std::pair<bool, bool> keep_going_left_right = ShouldKeepGoingLeftRight(
tree, node_idx, features_of_splits_going_up_from_original_leaf,
node_idx, features_of_splits_going_up_from_original_leaf,
thresholds_of_splits_going_up_from_original_leaf,
was_original_leaf_right_child_of_split);
int inner_feature = tree->split_feature_inner(node_idx);
uint32_t threshold = tree->threshold_in_bin(node_idx);
bool is_split_numerical = tree->IsNumericalSplit(node_idx);
int inner_feature = tree_->split_feature_inner(node_idx);
uint32_t threshold = tree_->threshold_in_bin(node_idx);
bool is_split_numerical = tree_->IsNumericalSplit(node_idx);
bool use_left_leaf_for_update_right = true;
bool use_right_leaf_for_update_left = true;
// if the split is on the same feature (categorical variables not supported)
Expand All @@ -392,7 +391,7 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
// go down left
if (keep_going_left_right.first) {
GoDownToFindLeavesToUpdate(
tree, tree->left_child(node_idx),
tree_->left_child(node_idx),
features_of_splits_going_up_from_original_leaf,
thresholds_of_splits_going_up_from_original_leaf,
was_original_leaf_right_child_of_split, update_max_constraints,
Expand All @@ -403,7 +402,7 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
// go down right
if (keep_going_left_right.second) {
GoDownToFindLeavesToUpdate(
tree, tree->right_child(node_idx),
tree_->right_child(node_idx),
features_of_splits_going_up_from_original_leaf,
thresholds_of_splits_going_up_from_original_leaf,
was_original_leaf_right_child_of_split, update_max_constraints,
Expand All @@ -415,14 +414,14 @@ class IntermediateLeafConstraints : public BasicLeafConstraints {
}

std::pair<bool, bool> ShouldKeepGoingLeftRight(
const Tree* tree, int node_idx,
int node_idx,
const std::vector<int>& features_of_splits_going_up_from_original_leaf,
const std::vector<uint32_t>&
thresholds_of_splits_going_up_from_original_leaf,
const std::vector<bool>& was_original_leaf_right_child_of_split) {
int inner_feature = tree->split_feature_inner(node_idx);
uint32_t threshold = tree->threshold_in_bin(node_idx);
bool is_split_numerical = tree->IsNumericalSplit(node_idx);
int inner_feature = tree_->split_feature_inner(node_idx);
uint32_t threshold = tree_->threshold_in_bin(node_idx);
bool is_split_numerical = tree_->IsNumericalSplit(node_idx);

bool keep_going_right = true;
bool keep_going_left = true;
Expand Down
4 changes: 2 additions & 2 deletions src/treelearner/serial_tree_learner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
auto next_leaf_id = tree->NextLeafId();

// update before tree split
constraints_->BeforeSplit(tree, best_leaf, next_leaf_id,
constraints_->BeforeSplit(best_leaf, next_leaf_id,
best_split_info.monotone_type);

bool is_numerical_split =
Expand Down Expand Up @@ -657,7 +657,7 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
best_split_info.left_output);
}
auto leaves_need_update = constraints_->Update(
tree, is_numerical_split, *left_leaf, *right_leaf,
is_numerical_split, *left_leaf, *right_leaf,
best_split_info.monotone_type, best_split_info.right_output,
best_split_info.left_output, inner_feature_index, best_split_info,
best_split_per_leaf_);
Expand Down