Skip to content

Commit

Permalink
[fix] fix quantized training (fixes #5982) (fixes #5994) (#6092)
Browse files Browse the repository at this point in the history
* fix leaf splits update after split in quantized training

* fix preparation ordered gradients for quantized training

* remove force_row_wise in distributed test for quantized training

* Update src/treelearner/leaf_splits.hpp

---------

Co-authored-by: James Lamb <[email protected]>
  • Loading branch information
shiyu1994 and jameslamb authored Sep 12, 2023
1 parent cd39520 commit a92bf37
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 32 deletions.
37 changes: 25 additions & 12 deletions src/io/dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1278,21 +1278,34 @@ void Dataset::ConstructHistogramsInner(
auto ptr_ordered_grad = gradients;
auto ptr_ordered_hess = hessians;
if (num_used_dense_group > 0) {
if (USE_INDICES) {
if (USE_HESSIAN) {
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
if (USE_QUANT_GRAD) {
int16_t* ordered_gradients_and_hessians = reinterpret_cast<int16_t*>(ordered_gradients);
const int16_t* gradients_and_hessians = reinterpret_cast<const int16_t*>(gradients);
if (USE_INDICES) {
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
for (data_size_t i = 0; i < num_data; ++i) {
ordered_gradients[i] = gradients[data_indices[i]];
ordered_hessians[i] = hessians[data_indices[i]];
ordered_gradients_and_hessians[i] = gradients_and_hessians[data_indices[i]];
}
ptr_ordered_grad = ordered_gradients;
ptr_ordered_hess = ordered_hessians;
} else {
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
for (data_size_t i = 0; i < num_data; ++i) {
ordered_gradients[i] = gradients[data_indices[i]];
ptr_ordered_grad = reinterpret_cast<const score_t*>(ordered_gradients);
ptr_ordered_hess = nullptr;
}
} else {
if (USE_INDICES) {
if (USE_HESSIAN) {
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
for (data_size_t i = 0; i < num_data; ++i) {
ordered_gradients[i] = gradients[data_indices[i]];
ordered_hessians[i] = hessians[data_indices[i]];
}
ptr_ordered_grad = ordered_gradients;
ptr_ordered_hess = ordered_hessians;
} else {
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
for (data_size_t i = 0; i < num_data; ++i) {
ordered_gradients[i] = gradients[data_indices[i]];
}
ptr_ordered_grad = ordered_gradients;
}
ptr_ordered_grad = ordered_gradients;
}
}
OMP_INIT_EX();
Expand Down
19 changes: 19 additions & 0 deletions src/treelearner/leaf_splits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,25 @@ class LeafSplits {
weight_ = weight;
}

/*!
* \brief Init split on current leaf on partial data.
* \param leaf Index of current leaf
* \param data_partition current data partition
* \param sum_gradients
* \param sum_hessians
* \param sum_gradients_and_hessians
* \param weight
*/
void Init(int leaf, const DataPartition* data_partition, double sum_gradients,
double sum_hessians, int64_t sum_gradients_and_hessians, double weight) {
leaf_index_ = leaf;
data_indices_ = data_partition->GetIndexOnLeaf(leaf, &num_data_in_leaf_);
sum_gradients_ = sum_gradients;
sum_hessians_ = sum_hessians;
int_sum_gradients_and_hessians_ = sum_gradients_and_hessians;
weight_ = weight;
}

/*!
* \brief Init split on current leaf on partial data.
* \param leaf Index of current leaf
Expand Down
115 changes: 96 additions & 19 deletions src/treelearner/serial_tree_learner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -841,32 +841,65 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
#endif

// init the leaves that used on next iteration
if (best_split_info.left_count < best_split_info.right_count) {
CHECK_GT(best_split_info.left_count, 0);
smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
best_split_info.left_sum_gradient,
best_split_info.left_sum_hessian,
best_split_info.left_output);
larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
best_split_info.right_sum_gradient,
best_split_info.right_sum_hessian,
best_split_info.right_output);
if (!config_->use_quantized_grad) {
if (best_split_info.left_count < best_split_info.right_count) {
CHECK_GT(best_split_info.left_count, 0);
smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
best_split_info.left_sum_gradient,
best_split_info.left_sum_hessian,
best_split_info.left_output);
larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
best_split_info.right_sum_gradient,
best_split_info.right_sum_hessian,
best_split_info.right_output);
} else {
CHECK_GT(best_split_info.right_count, 0);
smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
best_split_info.right_sum_gradient,
best_split_info.right_sum_hessian,
best_split_info.right_output);
larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
best_split_info.left_sum_gradient,
best_split_info.left_sum_hessian,
best_split_info.left_output);
}
} else {
CHECK_GT(best_split_info.right_count, 0);
smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
best_split_info.right_sum_gradient,
best_split_info.right_sum_hessian,
best_split_info.right_output);
larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
best_split_info.left_sum_gradient,
best_split_info.left_sum_hessian,
best_split_info.left_output);
if (best_split_info.left_count < best_split_info.right_count) {
CHECK_GT(best_split_info.left_count, 0);
smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
best_split_info.left_sum_gradient,
best_split_info.left_sum_hessian,
best_split_info.left_sum_gradient_and_hessian,
best_split_info.left_output);
larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
best_split_info.right_sum_gradient,
best_split_info.right_sum_hessian,
best_split_info.right_sum_gradient_and_hessian,
best_split_info.right_output);
} else {
CHECK_GT(best_split_info.right_count, 0);
smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
best_split_info.right_sum_gradient,
best_split_info.right_sum_hessian,
best_split_info.right_sum_gradient_and_hessian,
best_split_info.right_output);
larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
best_split_info.left_sum_gradient,
best_split_info.left_sum_hessian,
best_split_info.left_sum_gradient_and_hessian,
best_split_info.left_output);
}
}
if (config_->use_quantized_grad && config_->tree_learner != std::string("data")) {
gradient_discretizer_->SetNumBitsInHistogramBin<false>(*left_leaf, *right_leaf,
data_partition_->leaf_count(*left_leaf),
data_partition_->leaf_count(*right_leaf));
}

#ifdef DEBUG
CheckSplit(best_split_info, *left_leaf, *right_leaf);
#endif

auto leaves_need_update = constraints_->Update(
is_numerical_split, *left_leaf, *right_leaf,
best_split_info.monotone_type, best_split_info.right_output,
Expand Down Expand Up @@ -1024,4 +1057,48 @@ std::vector<int8_t> node_used_features = col_sampler_.GetByNode(tree, leaf);
*split = bests[best_idx];
}

#ifdef DEBUG
void SerialTreeLearner::CheckSplit(const SplitInfo& best_split_info, const int left_leaf_index, const int right_leaf_index) {
data_size_t num_data_in_left = 0;
data_size_t num_data_in_right = 0;
const data_size_t* data_indices_in_left = data_partition_->GetIndexOnLeaf(left_leaf_index, &num_data_in_left);
const data_size_t* data_indices_in_right = data_partition_->GetIndexOnLeaf(right_leaf_index, &num_data_in_right);
if (config_->use_quantized_grad) {
int32_t sum_left_gradient = 0;
int32_t sum_left_hessian = 0;
int32_t sum_right_gradient = 0;
int32_t sum_right_hessian = 0;
const int8_t* discretized_grad_and_hess = gradient_discretizer_->discretized_gradients_and_hessians();
for (data_size_t i = 0; i < num_data_in_left; ++i) {
const data_size_t index = data_indices_in_left[i];
sum_left_gradient += discretized_grad_and_hess[2 * index + 1];
sum_left_hessian += discretized_grad_and_hess[2 * index];
}
for (data_size_t i = 0; i < num_data_in_right; ++i) {
const data_size_t index = data_indices_in_right[i];
sum_right_gradient += discretized_grad_and_hess[2 * index + 1];
sum_right_hessian += discretized_grad_and_hess[2 * index];
}
Log::Warning("============================ start leaf split info ============================");
Log::Warning("left_leaf_index = %d, right_leaf_index = %d", left_leaf_index, right_leaf_index);
Log::Warning("num_data_in_left = %d, num_data_in_right = %d", num_data_in_left, num_data_in_right);
Log::Warning("sum_left_gradient = %d, best_split_info->left_sum_gradient_and_hessian.gradient = %d", sum_left_gradient,
static_cast<int32_t>(best_split_info.left_sum_gradient_and_hessian >> 32));
Log::Warning("sum_left_hessian = %d, best_split_info->left_sum_gradient_and_hessian.hessian = %d", sum_left_hessian,
static_cast<int32_t>(best_split_info.left_sum_gradient_and_hessian & 0x00000000ffffffff));
Log::Warning("sum_right_gradient = %d, best_split_info->right_sum_gradient_and_hessian.gradient = %d", sum_right_gradient,
static_cast<int32_t>(best_split_info.right_sum_gradient_and_hessian >> 32));
Log::Warning("sum_right_hessian = %d, best_split_info->right_sum_gradient_and_hessian.hessian = %d", sum_right_hessian,
static_cast<int32_t>(best_split_info.right_sum_gradient_and_hessian & 0x00000000ffffffff));
CHECK_EQ(num_data_in_left, best_split_info.left_count);
CHECK_EQ(num_data_in_right, best_split_info.right_count);
CHECK_EQ(sum_left_gradient, static_cast<int32_t>(best_split_info.left_sum_gradient_and_hessian >> 32))
CHECK_EQ(sum_left_hessian, static_cast<int32_t>(best_split_info.left_sum_gradient_and_hessian & 0x00000000ffffffff));
CHECK_EQ(sum_right_gradient, static_cast<int32_t>(best_split_info.right_sum_gradient_and_hessian >> 32));
CHECK_EQ(sum_right_hessian, static_cast<int32_t>(best_split_info.right_sum_gradient_and_hessian & 0x00000000ffffffff));
Log::Warning("============================ end leaf split info ============================");
}
}
#endif

} // namespace LightGBM
2 changes: 2 additions & 0 deletions src/treelearner/serial_tree_learner.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@ class SerialTreeLearner: public TreeLearner {

std::set<int> FindAllForceFeatures(Json force_split_leaf_setting);

#ifdef DEBUG
void CheckSplit(const SplitInfo& best_split_info, const int left_leaf_index, const int right_leaf_index);
#endif

/*!
* \brief Get the number of data in a leaf
Expand Down
1 change: 0 additions & 1 deletion tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1838,7 +1838,6 @@ def test_distributed_quantized_training(cluster):
'num_grad_quant_bins': 30,
'quant_train_renew_leaf': True,
'verbose': -1,
'force_row_wise': True,
}

quant_dask_classifier = lgb.DaskLGBMRegressor(
Expand Down

0 comments on commit a92bf37

Please sign in to comment.