From add5dcdc8a671eeac59f0a93652ffcc93a395c59 Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Wed, 28 Feb 2024 17:12:43 -0500 Subject: [PATCH] updates according to comments --- src/common/quantile.cc | 61 ++++++++++----------------------- src/tree/hist/evaluate_splits.h | 7 ++-- src/tree/hist/histogram.h | 5 +-- 3 files changed, 26 insertions(+), 47 deletions(-) diff --git a/src/common/quantile.cc b/src/common/quantile.cc index e10b256d2863..93ceed9ac64e 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -362,50 +362,32 @@ void SketchContainerImpl::AllReduce( template void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_bin, - HistogramCuts *cuts) { + HistogramCuts *cuts, bool secure) { size_t required_cuts = std::min(summary.size, static_cast(max_bin)); - auto &cut_values = cuts->cut_values_.HostVector(); - // we use the min_value as the first (0th) element, hence starting from 1. - for (size_t i = 1; i < required_cuts; ++i) { - bst_float cpt = summary.data[i].value; - if (i == 1 || cpt > cut_values.back()) { - cut_values.push_back(cpt); - } - } -} - -template -void AddCutPointSecure(typename SketchType::SummaryContainer const &summary, int max_bin, - HistogramCuts *cuts) { - // For secure vertical pipeline, we fill the cut values corresponding to empty columns - // with a vector of minimum value - size_t required_cuts = std::min(summary.size, static_cast(max_bin)); - // make a copy of required_cuts for mode selection - size_t required_cuts_original = required_cuts; + // make a copy of required_cuts for mode selection + size_t required_cuts_original = required_cuts; + if (secure) { // Sync the required_cuts across all workers collective::Allreduce(&required_cuts, 1); - - // add the cut points - auto &cut_values = cuts->cut_values_.HostVector(); - // if not empty column, fill the cut values with the actual values - if (required_cuts_original > 0) { - // we use the min_value as the first (0th) element, hence starting from 1. - for (size_t i = 1; i < required_cuts; ++i) { - bst_float cpt = summary.data[i].value; - if (i == 1 || cpt > cut_values.back()) { - cut_values.push_back(cpt); - } - } + } + auto &cut_values = cuts->cut_values_.HostVector(); + // if empty column, fill the cut values with 0 + if (secure && (required_cuts_original == 0)){ + for (size_t i = 1; i < required_cuts; ++i) { + cut_values.push_back(0.0); } - // if empty column, fill the cut values with 0 - else { - for (size_t i = 1; i < required_cuts; ++i) { - cut_values.push_back(0.0); + } + else { + // we use the min_value as the first (0th) element, hence starting from 1. + for (size_t i = 1; i < required_cuts; ++i) { + bst_float cpt = summary.data[i].value; + if (i == 1 || cpt > cut_values.back()) { + cut_values.push_back(cpt); } } + } } - auto AddCategories(std::set const &categories, HistogramCuts *cuts) { if (std::any_of(categories.cbegin(), categories.cend(), InvalidCat)) { InvalidCategory(); @@ -464,12 +446,7 @@ void SketchContainerImpl::MakeCuts(Context const *ctx, MetaInfo const max_cat = std::max(max_cat, AddCategories(categories_.at(fid), p_cuts)); } else { // use special AddCutPoint scheme for secure vertical federated learning - if (info.IsVerticalFederated() && info.IsSecure()) { - AddCutPointSecure(a, max_num_bins, p_cuts); - } - else { - AddCutPoint(a, max_num_bins, p_cuts); - } + AddCutPoint(a, max_num_bins, p_cuts, info.IsSecure()); // push a value that is greater than anything const bst_float cpt = (a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid]; diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h index a41081070333..a96bd3c2b181 100644 --- a/src/tree/hist/evaluate_splits.h +++ b/src/tree/hist/evaluate_splits.h @@ -399,9 +399,10 @@ class HistEvaluator { } }); - for (unsigned nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) { - for (auto tidx = 0; tidx < n_threads; ++tidx) { - entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split); + for (unsigned nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) { + for (auto tidx = 0; tidx < n_threads; ++tidx) { + entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split); + } } } diff --git a/src/tree/hist/histogram.h b/src/tree/hist/histogram.h index 9d5d22eed782..ca589fc6c189 100644 --- a/src/tree/hist/histogram.h +++ b/src/tree/hist/histogram.h @@ -196,8 +196,9 @@ class HistogramBuilder { if (is_distributed_ && is_col_split_ && is_secure_) { // Under secure vertical mode, we perform allgather for all nodes CHECK(!nodes_to_build.empty()); - // in theory the operation is AllGather, but with current system functionality, - // we use AllReduce to simulate the AllGather operation + // in theory the operation is AllGather, under current histogram setting of + // same length with 0s for empty slots, + // AllReduce is the most efficient way of achieving the global histogram auto first_nidx = nodes_to_build.front(); std::size_t n = n_total_bins * nodes_to_build.size() * 2; collective::Allreduce(