Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clear split info buffer in cost efficient gradient boosting before every iteration (fix partially #3679) #5164

Merged
merged 9 commits into from
Jun 8, 2022
13 changes: 13 additions & 0 deletions src/treelearner/cost_effective_gradient_boosting.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class CostEfficientGradientBoosting {
return true;
}
}

void Init() {
auto train_data = tree_learner_->train_data_;
if (!init_) {
Expand Down Expand Up @@ -63,6 +64,17 @@ class CostEfficientGradientBoosting {
}
init_ = true;
}

void BeforeTrain() {
// clear the splits in splits_per_leaf_
const int num_total_splits = static_cast<int>(splits_per_leaf_.size());
StrikerRUS marked this conversation as resolved.
Show resolved Hide resolved
const int num_threads = OMP_NUM_THREADS();
#pragma omp parallel for schedule(static) num_threads(num_threads)
for (int i = 0; i < num_total_splits; ++i) {
splits_per_leaf_[i].Reset();
}
}

double DetlaGain(int feature_index, int real_fidx, int leaf_index,
int num_data_in_leaf, SplitInfo split_info) {
auto config = tree_learner_->config_;
Expand All @@ -82,6 +94,7 @@ class CostEfficientGradientBoosting {
feature_index] = split_info;
return delta;
}

void UpdateLeafBestSplits(Tree* tree, int best_leaf,
const SplitInfo* best_split_info,
std::vector<SplitInfo>* best_split_per_leaf) {
Expand Down
2 changes: 2 additions & 0 deletions src/treelearner/serial_tree_learner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,8 @@ void SerialTreeLearner::BeforeTrain() {
}

larger_leaf_splits_->Init();

cegb_->BeforeTrain();
}

bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) {
Expand Down