From 5e85438171dc560904663f30ef3d86517d9e5e47 Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Tue, 27 Feb 2024 11:27:21 -0500 Subject: [PATCH] modify inference behavior of secure vertical from split value to index for training phase --- src/common/quantile.cc | 88 +++++++++++++++++++++++++------ src/learner.cc | 57 +++++++++++++++++++- src/tree/common_row_partitioner.h | 42 ++++++++++----- src/tree/hist/evaluate_splits.h | 34 ++++++++---- src/tree/updater_approx.cc | 2 +- src/tree/updater_quantile_hist.cc | 4 +- 6 files changed, 184 insertions(+), 43 deletions(-) diff --git a/src/common/quantile.cc b/src/common/quantile.cc index c3b0d431c35c..602073428eb1 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -5,6 +5,7 @@ #include #include +#include #include "../collective/aggregator.h" #include "../data/adapter.h" @@ -367,7 +368,7 @@ void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_b } template -void AddCutPointSecure(typename SketchType::SummaryContainer const &summary, int max_bin, +double AddCutPointSecure(typename SketchType::SummaryContainer const &summary, int max_bin, HistogramCuts *cuts) { // For secure vertical pipeline, we fill the cut values corresponding to empty columns // with a vector of minimum value @@ -388,12 +389,15 @@ void AddCutPointSecure(typename SketchType::SummaryContainer const &summary, int cut_values.push_back(cpt); } } + return cut_values.back(); } - // if empty column, fill the cut values with 0 + // if empty column, fill the cut values with NaN else { for (size_t i = 1; i < required_cuts; ++i) { - cut_values.push_back(0.0); + //cut_values.push_back(0.0); + cut_values.push_back(std::numeric_limits::quiet_NaN()); } + return std::numeric_limits::quiet_NaN(); } } @@ -448,6 +452,7 @@ void SketchContainerImpl::MakeCuts(Context const *ctx, MetaInfo const for (size_t fid = 0; fid < reduced.size(); ++fid) { size_t max_num_bins = std::min(num_cuts[fid], max_bins_); // If vertical and secure mode, we need to sync the max_num_bins aross workers + // to create the same global number of cut point bins for easier future processing if (info.IsVerticalFederated() && info.IsSecure()) { collective::Allreduce(&max_num_bins, 1); } @@ -457,17 +462,31 @@ void SketchContainerImpl::MakeCuts(Context const *ctx, MetaInfo const } else { // use special AddCutPoint scheme for secure vertical federated learning if (info.IsVerticalFederated() && info.IsSecure()) { - AddCutPointSecure(a, max_num_bins, p_cuts); + double last_value = AddCutPointSecure(a, max_num_bins, p_cuts); + // push a value that is greater than anything if the feature is not empty + // i.e. if the last value is not NaN + if (!std::isnan(last_value)) { + const bst_float cpt = + (a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid]; + // this must be bigger than last value in a scale + const bst_float last = cpt + (fabs(cpt) + 1e-5f); + p_cuts->cut_values_.HostVector().push_back(last); + } + else { + // if the feature is empty, push a NaN value + p_cuts->cut_values_.HostVector().push_back(std::numeric_limits::quiet_NaN()); + } } else { AddCutPoint(a, max_num_bins, p_cuts); + // push a value that is greater than anything + const bst_float cpt = + (a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid]; + // this must be bigger than last value in a scale + const bst_float last = cpt + (fabs(cpt) + 1e-5f); + p_cuts->cut_values_.HostVector().push_back(last); } - // push a value that is greater than anything - const bst_float cpt = - (a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid]; - // this must be bigger than last value in a scale - const bst_float last = cpt + (fabs(cpt) + 1e-5f); - p_cuts->cut_values_.HostVector().push_back(last); + } // Ensure that every feature gets at least one quantile point @@ -477,12 +496,49 @@ void SketchContainerImpl::MakeCuts(Context const *ctx, MetaInfo const p_cuts->cut_ptrs_.HostVector().push_back(cut_size); } - if (info.IsVerticalFederated() && info.IsSecure()) { - // cut values need to be synced across all workers via Allreduce - auto cut_val = p_cuts->cut_values_.HostVector().data(); - std::size_t n = p_cuts->cut_values_.HostVector().size(); - collective::Allreduce(cut_val, n); - } + +/* + // save the cut values and cut pointers to files for examination + if (collective::GetRank() == 0) { + //print the entries to file for debug + std::ofstream file; + file.open("cut_info_0.txt", std::ios_base::app); + file << " Total cut ptr count: " << p_cuts->cut_ptrs_.HostVector().size() << std::endl; + file << " Total cut count: " << p_cuts->cut_values_.HostVector().size() << std::endl; + //iterate through the cut pointers + for (auto i = 0; i < p_cuts->cut_ptrs_.HostVector().size(); i++) { + file << "cut_ptr " << i << ": " << p_cuts->cut_ptrs_.HostVector()[i] << std::endl; + } + //iterate through the cut values + for (auto i = 0; i < p_cuts->cut_values_.HostVector().size(); i++) { + file << "cut_value " << i << ": " << p_cuts->cut_values_.HostVector()[i] << std::endl; + } + file.close(); + } + if (collective::GetRank() == 1) { + //print the entries to file for debug + std::ofstream file; + file.open("cut_info_1.txt", std::ios_base::app); + file << " Total cut ptr count: " << p_cuts->cut_ptrs_.HostVector().size() << std::endl; + file << " Total cut count: " << p_cuts->cut_values_.HostVector().size() << std::endl; + //iterate through the cut pointers + for (auto i = 0; i < p_cuts->cut_ptrs_.HostVector().size(); i++) { + file << "cut_ptr " << i << ": " << p_cuts->cut_ptrs_.HostVector()[i] << std::endl; + } + //iterate through the cut values + for (auto i = 0; i < p_cuts->cut_values_.HostVector().size(); i++) { + file << "cut_value " << i << ": " << p_cuts->cut_values_.HostVector()[i] << std::endl; + } + file.close(); + } + + if (info.IsVerticalFederated() && info.IsSecure()) { + // cut values need to be synced across all workers via Allreduce + auto cut_val = p_cuts->cut_values_.HostVector().data(); + std::size_t n = p_cuts->cut_values_.HostVector().size(); + collective::Allreduce(cut_val, n); + } + */ p_cuts->SetCategorical(this->has_categorical_, max_cat); monitor_.Stop(__func__); diff --git a/src/learner.cc b/src/learner.cc index db72f71644cb..78c6c15ea35d 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -1285,6 +1285,45 @@ class LearnerImpl : public LearnerIO { monitor_.Start("GetGradient"); GetGradient(predt.predictions, train->Info(), iter, &gpair_); monitor_.Stop("GetGradient"); + + + + if(collective::GetRank()==0){ + //print the total number of samples + std::cout << "Total number of samples: " << train->Info().labels.Size() << std::endl; + auto i = 0; + // print the first five predictions + for (auto p : predt.predictions.HostVector()) { + std::cout << "Prediction " << i << ": " << p << std::endl; + i++; + if (i == 5) { + break; + } + } + + // print the first five labels + std::cout << "Labels: " << std::endl; + i = 0; + while ( i<5 ) { + std::cout << "Label " << i << ": " << train->Info().labels.HostView()(i) << std::endl; + i++; + } + + // print the first five gradients + std::cout << "Gradients: " << std::endl; + i = 0; + for (auto p : gpair_.Data()->HostVector()) { + std::cout << "Gradient " << i << ": " << p.GetGrad() << std::endl; + i++; + if (i == 5) { + break; + } + } + } + + + + TrainingObserver::Instance().Observe(*gpair_.Data(), "Gradients"); gbm_->DoBoost(train.get(), &gpair_, &predt, obj_.get()); @@ -1333,6 +1372,9 @@ class LearnerImpl : public LearnerIO { std::shared_ptr m = data_sets[i]; auto &predt = prediction_container_.Cache(m, ctx_.Device()); this->ValidateDMatrix(m.get(), false); + if(collective::GetRank()==0){ + std::cout << "data size = " << data_sets[i]->Info().num_row_ << std::endl; + } this->PredictRaw(m.get(), &predt, false, 0, 0); auto &out = output_predictions_.Cache(m, ctx_.Device()).predictions; @@ -1341,7 +1383,15 @@ class LearnerImpl : public LearnerIO { obj_->EvalTransform(&out); for (auto& ev : metrics_) { - os << '\t' << data_names[i] << '-' << ev->Name() << ':' << ev->Evaluate(out, m); + + auto metric = ev->Evaluate(out, m); + + if(collective::GetRank()==0){ + std::cout << "eval result = " << metric << std::endl; + } + + + os << '\t' << data_names[i] << '-' << ev->Name() << ':' << metric; //ev->Evaluate(out, m); } } @@ -1446,6 +1496,11 @@ class LearnerImpl : public LearnerIO { CHECK(gbm_ != nullptr) << "Predict must happen after Load or configuration"; this->CheckModelInitialized(); this->ValidateDMatrix(data, false); + + if(collective::GetRank()==0){ + std::cout << "PredictRaw training ? " << training << std::endl; + } + gbm_->PredictBatch(data, out_preds, training, layer_begin, layer_end); } diff --git a/src/tree/common_row_partitioner.h b/src/tree/common_row_partitioner.h index 4360c0b1314e..2152155fd9ef 100644 --- a/src/tree/common_row_partitioner.h +++ b/src/tree/common_row_partitioner.h @@ -89,8 +89,8 @@ class CommonRowPartitioner { CommonRowPartitioner() = default; CommonRowPartitioner(Context const* ctx, bst_row_t num_row, bst_row_t _base_rowid, - bool is_col_split) - : base_rowid{_base_rowid}, is_col_split_{is_col_split} { + bool is_col_split, bool is_secure) + : base_rowid{_base_rowid}, is_col_split_{is_col_split}, is_secure_{is_secure} { row_set_collection_.Clear(); std::vector& row_indices = *row_set_collection_.Data(); row_indices.resize(num_row); @@ -106,7 +106,7 @@ class CommonRowPartitioner { template void FindSplitConditions(const std::vector& nodes, const RegTree& tree, - const GHistIndexMatrix& gmat, std::vector* split_conditions) { + const GHistIndexMatrix& gmat, std::vector* split_conditions, bool is_index) { auto const& ptrs = gmat.cut.Ptrs(); auto const& vals = gmat.cut.Values(); @@ -114,18 +114,25 @@ class CommonRowPartitioner { bst_node_t const nidx = nodes[i].nid; bst_feature_t const fidx = tree.SplitIndex(nidx); float const split_pt = tree.SplitCond(nidx); - std::uint32_t const lower_bound = ptrs[fidx]; - std::uint32_t const upper_bound = ptrs[fidx + 1]; - bst_bin_t split_cond = -1; - // convert floating-point split_pt into corresponding bin_id - // split_cond = -1 indicates that split_pt is less than all known cut points - CHECK_LT(upper_bound, static_cast(std::numeric_limits::max())); - for (auto bound = lower_bound; bound < upper_bound; ++bound) { - if (split_pt == vals[bound]) { - split_cond = static_cast(bound); + if (is_index) { + // if the split_pt is already recorded as a bin_id, use it directly + (*split_conditions)[i] = static_cast(split_pt); + } + else { + // otherwise find the bin_id that corresponds to split_pt + std::uint32_t const lower_bound = ptrs[fidx]; + std::uint32_t const upper_bound = ptrs[fidx + 1]; + bst_bin_t split_cond = -1; + // convert floating-point split_pt into corresponding bin_id + // split_cond = -1 indicates that split_pt is less than all known cut points + CHECK_LT(upper_bound, static_cast(std::numeric_limits::max())); + for (auto bound = lower_bound; bound < upper_bound; ++bound) { + if (split_pt == vals[bound]) { + split_cond = static_cast(bound); + } } + (*split_conditions)[i] = split_cond; } - (*split_conditions)[i] = split_cond; } } @@ -194,7 +201,13 @@ class CommonRowPartitioner { std::vector split_conditions; if (column_matrix.IsInitialized()) { split_conditions.resize(n_nodes); - FindSplitConditions(nodes, *p_tree, gmat, &split_conditions); + if (is_secure_) { + // in secure mode, the split index is kept instead of the split value + FindSplitConditions(nodes, *p_tree, gmat, &split_conditions, true); + } + else { + FindSplitConditions(nodes, *p_tree, gmat, &split_conditions, false); + } } // 2.1 Create a blocked space of size SUM(samples in each node) @@ -294,6 +307,7 @@ class CommonRowPartitioner { common::PartitionBuilder partition_builder_; common::RowSetCollection row_set_collection_; bool is_col_split_; + bool is_secure_; ColumnSplitHelper column_split_helper_; }; diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h index 1ee896102086..3a7179704e33 100644 --- a/src/tree/hist/evaluate_splits.h +++ b/src/tree/hist/evaluate_splits.h @@ -264,20 +264,36 @@ class HistEvaluator { static_cast(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) - parent.root_gain); - split_pt = cut_val[i]; // not used for partition based - best.Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum); + if (!is_secure_) { + split_pt = cut_val[i]; // not used for partition based + best.Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum); + } + else { + // secure mode: record the best split point, rather than the actual value since it is not accessible + best.Update(loss_chg, fidx, i, d_step == -1, false, left_sum, right_sum); + } + } else { // backward enumeration: split at left bound of each bin loss_chg = static_cast(evaluator.CalcSplitGain(*param_, nidx, fidx, GradStats{right_sum}, GradStats{left_sum}) - parent.root_gain); - if (i == imin) { - split_pt = cut.MinValues()[fidx]; - } else { - split_pt = cut_val[i - 1]; + if (!is_secure_) { + if (i == imin) { + split_pt = cut.MinValues()[fidx]; + } else { + split_pt = cut_val[i - 1]; + } + best.Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum); + } + else { + // secure mode: record the best split point, rather than the actual value since it is not accessible + if (i != imin) { + i = i - 1; + } + best.Update(loss_chg, fidx, i, d_step == -1, false, right_sum, left_sum); } - best.Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum); } } } @@ -387,7 +403,7 @@ class HistEvaluator { auto grad_stats = EnumerateSplit<+1>(cut, histogram, fidx, nidx, evaluator, best); // print the best split for each feature - // std::cout << "Best split for feature " << fidx << " is " << best->split_value << " with gain " << best->loss_chg << std::endl; + //std::cout << "Current best split at feature " << fidx << " is: " << std::endl << *best << std::endl; if (SplitContainsMissingValues(grad_stats, snode_[nidx])) { @@ -408,7 +424,7 @@ class HistEvaluator { if (is_col_split_) { // With column-wise data split, we gather the best splits from all the workers and update the - // expand entries accordingly. + // expand entries accordingly. Update() takes care of selecting the best one. // Note that under secure vertical setting, only the label owner is able to evaluate the split // based on the global histogram. The other parties will receive the final best splits // allgather is capable of performing this (0-gain entries for non-label owners), diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index 42546188ff52..bafd274d38cd 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -87,7 +87,7 @@ class GloablApproxBuilder { CHECK_EQ(n_total_bins, page.cut.TotalBins()); } partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid, - p_fmat->Info().IsColumnSplit()); + p_fmat->Info().IsColumnSplit(), p_fmat->Info().IsSecure()); n_batches_++; } diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 2403aa8a6bdd..a63067c229a2 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -163,7 +163,7 @@ class MultiTargetHistBuilder { } else { CHECK_EQ(n_total_bins, page.cut.TotalBins()); } - partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->Info().IsColumnSplit()); + partitioner_.emplace_back(ctx_, page.Size(), page.base_rowid, p_fmat->Info().IsColumnSplit(), p_fmat->Info().IsSecure()); } bst_target_t n_targets = p_tree->NumTargets(); @@ -355,7 +355,7 @@ class HistUpdater { CHECK_EQ(n_total_bins, page.cut.TotalBins()); } partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid, - fmat->Info().IsColumnSplit()); + fmat->Info().IsColumnSplit(), fmat->Info().IsSecure()); } histogram_builder_->Reset(ctx_, n_total_bins, 1, HistBatch(param_), collective::IsDistributed(), fmat->Info().IsColumnSplit(), fmat->Info().IsSecure(), hist_param_);