diff --git a/src/treelearner/data_parallel_tree_learner.cpp b/src/treelearner/data_parallel_tree_learner.cpp index 98dca40edb0d..f91dcdc9b250 100644 --- a/src/treelearner/data_parallel_tree_learner.cpp +++ b/src/treelearner/data_parallel_tree_learner.cpp @@ -26,8 +26,14 @@ void DataParallelTreeLearner::Init(const Dataset* train_data, boo // Get local rank and global machine size rank_ = Network::rank(); num_machines_ = Network::num_machines(); + + auto max_cat_threshold = this->config_->max_cat_threshold; + // need to be able to hold smaller and larger best splits in SyncUpGlobalBestSplit + size_t split_info_size = static_cast(SplitInfo::Size(max_cat_threshold) * 2); + size_t histogram_size = static_cast(this->train_data_->NumTotalBin() * kHistEntrySize); + // allocate buffer for communication - size_t buffer_size = this->train_data_->NumTotalBin() * kHistEntrySize; + size_t buffer_size = std::max(histogram_size, split_info_size); input_buffer_.resize(buffer_size); output_buffer_.resize(buffer_size); diff --git a/src/treelearner/feature_parallel_tree_learner.cpp b/src/treelearner/feature_parallel_tree_learner.cpp index 1523f004eb8d..74df187d46b2 100644 --- a/src/treelearner/feature_parallel_tree_learner.cpp +++ b/src/treelearner/feature_parallel_tree_learner.cpp @@ -24,8 +24,13 @@ void FeatureParallelTreeLearner::Init(const Dataset* train_data, TREELEARNER_T::Init(train_data, is_constant_hessian); rank_ = Network::rank(); num_machines_ = Network::num_machines(); - input_buffer_.resize((sizeof(SplitInfo) + sizeof(uint32_t) * this->config_->max_cat_threshold) * 2); - output_buffer_.resize((sizeof(SplitInfo) + sizeof(uint32_t) * this->config_->max_cat_threshold) * 2); + + auto max_cat_threshold = this->config_->max_cat_threshold; + // need to be able to hold smaller and larger best splits in SyncUpGlobalBestSplit + int split_info_size = SplitInfo::Size(max_cat_threshold) * 2; + + input_buffer_.resize(split_info_size); + output_buffer_.resize(split_info_size); } diff --git a/src/treelearner/split_info.hpp b/src/treelearner/split_info.hpp index 86653522dd04..492434d5160f 100644 --- a/src/treelearner/split_info.hpp +++ b/src/treelearner/split_info.hpp @@ -49,7 +49,7 @@ struct SplitInfo { bool default_left = true; int8_t monotone_type = 0; inline static int Size(int max_cat_threshold) { - return 2 * sizeof(int) + sizeof(uint32_t) + sizeof(bool) + sizeof(double) * 9 + sizeof(data_size_t) * 2 + max_cat_threshold * sizeof(uint32_t) + sizeof(int8_t); + return 2 * sizeof(int) + sizeof(uint32_t) + sizeof(bool) + sizeof(double) * 7 + sizeof(data_size_t) * 2 + max_cat_threshold * sizeof(uint32_t) + sizeof(int8_t); } inline void CopyTo(char* buffer) const { diff --git a/src/treelearner/voting_parallel_tree_learner.cpp b/src/treelearner/voting_parallel_tree_learner.cpp index 660a633193d8..d14e0d614ce0 100644 --- a/src/treelearner/voting_parallel_tree_learner.cpp +++ b/src/treelearner/voting_parallel_tree_learner.cpp @@ -37,6 +37,10 @@ void VotingParallelTreeLearner::Init(const Dataset* train_data, b } // calculate buffer size size_t buffer_size = 2 * top_k_ * std::max(max_bin * kHistEntrySize, sizeof(LightSplitInfo) * num_machines_); + auto max_cat_threshold = this->config_->max_cat_threshold; + // need to be able to hold smaller and larger best splits in SyncUpGlobalBestSplit + size_t split_info_size = static_cast(SplitInfo::Size(max_cat_threshold) * 2); + buffer_size = std::max(buffer_size, split_info_size); // left and right on same time, so need double size input_buffer_.resize(buffer_size); output_buffer_.resize(buffer_size);