Skip to content

Commit

Permalink
updates according to comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 committed Feb 28, 2024
1 parent 7e407a8 commit add5dcd
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 47 deletions.
61 changes: 19 additions & 42 deletions src/common/quantile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -362,50 +362,32 @@ void SketchContainerImpl<WQSketch>::AllReduce(

template <typename SketchType>
void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_bin,
HistogramCuts *cuts) {
HistogramCuts *cuts, bool secure) {
size_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
auto &cut_values = cuts->cut_values_.HostVector();
// we use the min_value as the first (0th) element, hence starting from 1.
for (size_t i = 1; i < required_cuts; ++i) {
bst_float cpt = summary.data[i].value;
if (i == 1 || cpt > cut_values.back()) {
cut_values.push_back(cpt);
}
}
}

template <typename SketchType>
void AddCutPointSecure(typename SketchType::SummaryContainer const &summary, int max_bin,
HistogramCuts *cuts) {
// For secure vertical pipeline, we fill the cut values corresponding to empty columns
// with a vector of minimum value
size_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
// make a copy of required_cuts for mode selection
size_t required_cuts_original = required_cuts;
// make a copy of required_cuts for mode selection
size_t required_cuts_original = required_cuts;
if (secure) {
// Sync the required_cuts across all workers
collective::Allreduce<collective::Operation::kMax>(&required_cuts, 1);

// add the cut points
auto &cut_values = cuts->cut_values_.HostVector();
// if not empty column, fill the cut values with the actual values
if (required_cuts_original > 0) {
// we use the min_value as the first (0th) element, hence starting from 1.
for (size_t i = 1; i < required_cuts; ++i) {
bst_float cpt = summary.data[i].value;
if (i == 1 || cpt > cut_values.back()) {
cut_values.push_back(cpt);
}
}
}
auto &cut_values = cuts->cut_values_.HostVector();
// if empty column, fill the cut values with 0
if (secure && (required_cuts_original == 0)){
for (size_t i = 1; i < required_cuts; ++i) {
cut_values.push_back(0.0);
}
// if empty column, fill the cut values with 0
else {
for (size_t i = 1; i < required_cuts; ++i) {
cut_values.push_back(0.0);
}
else {
// we use the min_value as the first (0th) element, hence starting from 1.
for (size_t i = 1; i < required_cuts; ++i) {
bst_float cpt = summary.data[i].value;
if (i == 1 || cpt > cut_values.back()) {
cut_values.push_back(cpt);
}
}
}
}


auto AddCategories(std::set<float> const &categories, HistogramCuts *cuts) {
if (std::any_of(categories.cbegin(), categories.cend(), InvalidCat)) {
InvalidCategory();
Expand Down Expand Up @@ -464,12 +446,7 @@ void SketchContainerImpl<WQSketch>::MakeCuts(Context const *ctx, MetaInfo const
max_cat = std::max(max_cat, AddCategories(categories_.at(fid), p_cuts));
} else {
// use special AddCutPoint scheme for secure vertical federated learning
if (info.IsVerticalFederated() && info.IsSecure()) {
AddCutPointSecure<WQSketch>(a, max_num_bins, p_cuts);
}
else {
AddCutPoint<WQSketch>(a, max_num_bins, p_cuts);
}
AddCutPoint<WQSketch>(a, max_num_bins, p_cuts, info.IsSecure());
// push a value that is greater than anything
const bst_float cpt =
(a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid];
Expand Down
7 changes: 4 additions & 3 deletions src/tree/hist/evaluate_splits.h
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,10 @@ class HistEvaluator {
}
});

for (unsigned nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
for (auto tidx = 0; tidx < n_threads; ++tidx) {
entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split);
for (unsigned nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
for (auto tidx = 0; tidx < n_threads; ++tidx) {
entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split);
}
}
}

Expand Down
5 changes: 3 additions & 2 deletions src/tree/hist/histogram.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,9 @@ class HistogramBuilder {
if (is_distributed_ && is_col_split_ && is_secure_) {
// Under secure vertical mode, we perform allgather for all nodes
CHECK(!nodes_to_build.empty());
// in theory the operation is AllGather, but with current system functionality,
// we use AllReduce to simulate the AllGather operation
// in theory the operation is AllGather, under current histogram setting of
// same length with 0s for empty slots,
// AllReduce is the most efficient way of achieving the global histogram
auto first_nidx = nodes_to_build.front();
std::size_t n = n_total_bins * nodes_to_build.size() * 2;
collective::Allreduce<collective::Operation::kSum>(
Expand Down

0 comments on commit add5dcd

Please sign in to comment.