Skip to content

Commit

Permalink
Merge branch 'vertical-federated-learning' into SecureBoostInf
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 authored Mar 1, 2024
2 parents 8a2e1b1 + fe73294 commit 3adef40
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 21 deletions.
3 changes: 2 additions & 1 deletion include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ class MetaInfo {
}

/** @brief Whether the data is split column-wise. */
bool IsColumnSplit() const { return (data_split_mode == DataSplitMode::kCol) || (data_split_mode == DataSplitMode::kColSecure); }
bool IsColumnSplit() const { return (data_split_mode == DataSplitMode::kCol)
|| (data_split_mode == DataSplitMode::kColSecure); }

/** @brief Whether the data is split column-wise with secure computation. */
bool IsSecure() const { return data_split_mode == DataSplitMode::kColSecure; }
Expand Down
32 changes: 26 additions & 6 deletions src/common/quantile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -362,14 +362,27 @@ void SketchContainerImpl<WQSketch>::AllReduce(

template <typename SketchType>
void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_bin,
HistogramCuts *cuts) {
HistogramCuts *cuts, bool secure) {
size_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
// make a copy of required_cuts for mode selection
size_t required_cuts_original = required_cuts;
if (secure) {
// Sync the required_cuts across all workers
collective::Allreduce<collective::Operation::kMax>(&required_cuts, 1);
}
auto &cut_values = cuts->cut_values_.HostVector();
// we use the min_value as the first (0th) element, hence starting from 1.
for (size_t i = 1; i < required_cuts; ++i) {
bst_float cpt = summary.data[i].value;
if (i == 1 || cpt > cut_values.back()) {
cut_values.push_back(cpt);
// if empty column, fill the cut values with 0
if (secure && (required_cuts_original == 0)) {
for (size_t i = 1; i < required_cuts; ++i) {
cut_values.push_back(0.0);
}
} else {
// we use the min_value as the first (0th) element, hence starting from 1.
for (size_t i = 1; i < required_cuts; ++i) {
bst_float cpt = summary.data[i].value;
if (i == 1 || cpt > cut_values.back()) {
cut_values.push_back(cpt);
}
}
}
}
Expand Down Expand Up @@ -501,6 +514,13 @@ void SketchContainerImpl<WQSketch>::MakeCuts(Context const *ctx, MetaInfo const
p_cuts->cut_ptrs_.HostVector().push_back(cut_size);
}

if (info.IsVerticalFederated() && info.IsSecure()) {
// cut values need to be synced across all workers via Allreduce
auto cut_val = p_cuts->cut_values_.HostVector().data();
std::size_t n = p_cuts->cut_values_.HostVector().size();
collective::Allreduce<collective::Operation::kSum>(cut_val, n);
}

p_cuts->SetCategorical(this->has_categorical_, max_cat);
monitor_.Stop(__func__);
}
Expand Down
17 changes: 10 additions & 7 deletions src/tree/hist/evaluate_splits.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,6 @@ class HistEvaluator {
}
}
}

p_best->Update(best);
return left_sum;
}
Expand Down Expand Up @@ -370,9 +369,9 @@ class HistEvaluator {
auto evaluator = tree_evaluator_.GetEvaluator();
auto const &cut_ptrs = cut.Ptrs();

// Under secure vertical setting, only the label owner is able to evaluate the split
// based on the global histogram. The other parties will only receive the final best split information
// Hence the below computation is not performed by the non-label owners under secure vertical setting
// Under the secure vertical setting, only the active party is able to evaluate the split
// based on the global histogram. Other parties will receive the final best split information.
// Hence the computation below is not performed by the passive parties.
if ((!is_secure_) || (collective::GetRank() == 0)) {
// Evaluate the splits for each feature
common::ParallelFor2d(space, n_threads, [&](size_t nidx_in_set, common::Range1d r) {
Expand Down Expand Up @@ -414,9 +413,10 @@ class HistEvaluator {
}
});

for (unsigned nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
for (auto tidx = 0; tidx < n_threads; ++tidx) {
entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split);
for (unsigned nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
for (auto tidx = 0; tidx < n_threads; ++tidx) {
entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split);
}
}
}

Expand All @@ -427,6 +427,9 @@ class HistEvaluator {
// based on the global histogram. The other parties will receive the final best splits
// allgather is capable of performing this (0-gain entries for non-label owners),
// expand entries accordingly.
// Note that under secure vertical setting, only the label owner is able to evaluate the split
// based on the global histogram. The other parties will receive the final best splits;
// allgather is capable of performing this (0-gain entries for non-label owners).
auto all_entries = AllgatherColumnSplit(entries);
for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) {
for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
Expand Down
8 changes: 5 additions & 3 deletions src/tree/hist/histogram.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,9 @@ class HistogramBuilder {
if (is_distributed_ && is_col_split_ && is_secure_) {
// Under secure vertical mode, we perform allgather for all nodes
CHECK(!nodes_to_build.empty());
// in theory the operation is AllGather, but with current system functionality,
// we use AllReduce to simulate the AllGather operation
// In theory the operation is AllGather, but under the current histogram setting, where
// all histograms have the same length with 0s for empty slots,
// AllReduce is the most efficient way of achieving the global histogram.
auto first_nidx = nodes_to_build.front();
std::size_t n = n_total_bins * nodes_to_build.size() * 2;
collective::Allreduce<collective::Operation::kSum>(
Expand Down Expand Up @@ -343,7 +344,8 @@ class MultiHistogramBuilder {
[[nodiscard]] auto &Histogram(bst_target_t t) { return target_builders_[t].Histogram(); }

void Reset(Context const *ctx, bst_bin_t total_bins, bst_target_t n_targets, BatchParam const &p,
bool is_distributed, bool is_col_split, bool is_secure, HistMakerTrainParam const *param) {
bool is_distributed, bool is_col_split, bool is_secure,
HistMakerTrainParam const *param) {
ctx_ = ctx;
target_builders_.resize(n_targets);
CHECK_GE(n_targets, 1);
Expand Down
4 changes: 2 additions & 2 deletions src/tree/updater_approx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ class GloablApproxBuilder {
}

histogram_builder_.Reset(ctx_, n_total_bins, p_tree->NumTargets(), BatchSpec(*param_, hess),
collective::IsDistributed(), p_fmat->Info().IsColumnSplit(), p_fmat->Info().IsSecure(),
hist_param_);
collective::IsDistributed(), p_fmat->Info().IsColumnSplit(),
p_fmat->Info().IsSecure(), hist_param_);
monitor_->Stop(__func__);
}

Expand Down
4 changes: 2 additions & 2 deletions src/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ class MultiTargetHistBuilder {
bst_target_t n_targets = p_tree->NumTargets();
histogram_builder_ = std::make_unique<MultiHistogramBuilder>();
histogram_builder_->Reset(ctx_, n_total_bins, n_targets, HistBatch(param_),
collective::IsDistributed(), p_fmat->Info().IsColumnSplit(), p_fmat->Info().IsSecure(),
hist_param_);
collective::IsDistributed(), p_fmat->Info().IsColumnSplit(),
p_fmat->Info().IsSecure(), hist_param_);

evaluator_ = std::make_unique<HistMultiEvaluator>(ctx_, p_fmat->Info(), param_, col_sampler_);
p_last_tree_ = p_tree;
Expand Down

0 comments on commit 3adef40

Please sign in to comment.