Skip to content

Commit

Permalink
memory corruption fix for distributed data parallel version before SyncUpGlobalBestSplit (#3110)
Browse files Browse the repository at this point in the history

* memory corruption fix for distributed data parallel version before SyncUpGlobalBestSplit

* updated based on comments

* updated voting and feature parallel based on comments

* fixing macOS failure

* rename variable
  • Loading branch information
imatiach-msft authored May 26, 2020
1 parent 51b84df commit 8ead7cc
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 4 deletions.
8 changes: 7 additions & 1 deletion src/treelearner/data_parallel_tree_learner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,14 @@ void DataParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, boo
// Get local rank and global machine size
rank_ = Network::rank();
num_machines_ = Network::num_machines();

auto max_cat_threshold = this->config_->max_cat_threshold;
// need to be able to hold smaller and larger best splits in SyncUpGlobalBestSplit
size_t split_info_size = static_cast<size_t>(SplitInfo::Size(max_cat_threshold) * 2);
size_t histogram_size = static_cast<size_t>(this->train_data_->NumTotalBin() * kHistEntrySize);

// allocate buffer for communication
size_t buffer_size = this->train_data_->NumTotalBin() * kHistEntrySize;
size_t buffer_size = std::max(histogram_size, split_info_size);

input_buffer_.resize(buffer_size);
output_buffer_.resize(buffer_size);
Expand Down
9 changes: 7 additions & 2 deletions src/treelearner/feature_parallel_tree_learner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,13 @@ void FeatureParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data,
TREELEARNER_T::Init(train_data, is_constant_hessian);
rank_ = Network::rank();
num_machines_ = Network::num_machines();
input_buffer_.resize((sizeof(SplitInfo) + sizeof(uint32_t) * this->config_->max_cat_threshold) * 2);
output_buffer_.resize((sizeof(SplitInfo) + sizeof(uint32_t) * this->config_->max_cat_threshold) * 2);

auto max_cat_threshold = this->config_->max_cat_threshold;
// need to be able to hold smaller and larger best splits in SyncUpGlobalBestSplit
int split_info_size = SplitInfo::Size(max_cat_threshold) * 2;

input_buffer_.resize(split_info_size);
output_buffer_.resize(split_info_size);
}


Expand Down
2 changes: 1 addition & 1 deletion src/treelearner/split_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ struct SplitInfo {
bool default_left = true;
int8_t monotone_type = 0;
inline static int Size(int max_cat_threshold) {
return 2 * sizeof(int) + sizeof(uint32_t) + sizeof(bool) + sizeof(double) * 9 + sizeof(data_size_t) * 2 + max_cat_threshold * sizeof(uint32_t) + sizeof(int8_t);
return 2 * sizeof(int) + sizeof(uint32_t) + sizeof(bool) + sizeof(double) * 7 + sizeof(data_size_t) * 2 + max_cat_threshold * sizeof(uint32_t) + sizeof(int8_t);
}

inline void CopyTo(char* buffer) const {
Expand Down
4 changes: 4 additions & 0 deletions src/treelearner/voting_parallel_tree_learner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b
}
// calculate buffer size
size_t buffer_size = 2 * top_k_ * std::max(max_bin * kHistEntrySize, sizeof(LightSplitInfo) * num_machines_);
auto max_cat_threshold = this->config_->max_cat_threshold;
// need to be able to hold smaller and larger best splits in SyncUpGlobalBestSplit
size_t split_info_size = static_cast<size_t>(SplitInfo::Size(max_cat_threshold) * 2);
buffer_size = std::max(buffer_size, split_info_size);
// left and right on same time, so need double size
input_buffer_.resize(buffer_size);
output_buffer_.resize(buffer_size);
Expand Down

0 comments on commit 8ead7cc

Please sign in to comment.