Skip to content

Commit

Permalink
[secure boost] Vertical pipeline with hist sync (dmlc#10037)
Browse files Browse the repository at this point in the history
The first phase is to implement an alternative vertical pipeline that syncs the histograms from clients to the label owner.
  • Loading branch information
ZiyueXu77 authored Mar 1, 2024
1 parent 5ac2332 commit fe73294
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 76 deletions.
9 changes: 7 additions & 2 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ enum class DataType : uint8_t {

enum class FeatureType : uint8_t { kNumerical = 0, kCategorical = 1 };

enum class DataSplitMode : int { kRow = 0, kCol = 1 };
/*! \brief Mode describing how the dataset is split. */
enum class DataSplitMode : int {
  kRow = 0,
  kCol = 1,
  // Column-wise split with secure computation.
  kColSecure = 2
};

/*!
* \brief Meta information about dataset, always sit in memory.
Expand Down Expand Up @@ -186,7 +186,12 @@ class MetaInfo {
}

/** @brief Whether the data is split column-wise. */
bool IsColumnSplit() const { return data_split_mode == DataSplitMode::kCol; }
bool IsColumnSplit() const {
  // Both the plain and the secure column-wise modes count as a column split.
  switch (data_split_mode) {
    case DataSplitMode::kCol:
    case DataSplitMode::kColSecure:
      return true;
    default:
      return false;
  }
}

/** @brief Whether the data split mode is the secure column-wise one (kColSecure). */
bool IsSecure() const {
  bool const secure_col_split = (DataSplitMode::kColSecure == data_split_mode);
  return secure_col_split;
}

/** @brief Whether this is learning-to-rank data, i.e. group pointers are present. */
bool IsRanking() const {
  bool const has_groups = !group_ptr_.empty();
  return has_groups;
}

Expand Down
39 changes: 32 additions & 7 deletions src/common/quantile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -362,14 +362,27 @@ void SketchContainerImpl<WQSketch>::AllReduce(

template <typename SketchType>
void AddCutPoint(typename SketchType::SummaryContainer const &summary, int max_bin,
HistogramCuts *cuts) {
HistogramCuts *cuts, bool secure) {
size_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
// make a copy of required_cuts for mode selection
size_t required_cuts_original = required_cuts;
if (secure) {
// Sync the required_cuts across all workers
collective::Allreduce<collective::Operation::kMax>(&required_cuts, 1);
}
auto &cut_values = cuts->cut_values_.HostVector();
// we use the min_value as the first (0th) element, hence starting from 1.
for (size_t i = 1; i < required_cuts; ++i) {
bst_float cpt = summary.data[i].value;
if (i == 1 || cpt > cut_values.back()) {
cut_values.push_back(cpt);
// if empty column, fill the cut values with 0
if (secure && (required_cuts_original == 0)) {
for (size_t i = 1; i < required_cuts; ++i) {
cut_values.push_back(0.0);
}
} else {
// we use the min_value as the first (0th) element, hence starting from 1.
for (size_t i = 1; i < required_cuts; ++i) {
bst_float cpt = summary.data[i].value;
if (i == 1 || cpt > cut_values.back()) {
cut_values.push_back(cpt);
}
}
}
}
Expand Down Expand Up @@ -423,11 +436,16 @@ void SketchContainerImpl<WQSketch>::MakeCuts(Context const *ctx, MetaInfo const
float max_cat{-1.f};
for (size_t fid = 0; fid < reduced.size(); ++fid) {
size_t max_num_bins = std::min(num_cuts[fid], max_bins_);
// In vertical and secure mode, we need to sync max_num_bins across workers
if (info.IsVerticalFederated() && info.IsSecure()) {
collective::Allreduce<collective::Operation::kMax>(&max_num_bins, 1);
}
typename WQSketch::SummaryContainer const &a = final_summaries[fid];
if (IsCat(feature_types_, fid)) {
max_cat = std::max(max_cat, AddCategories(categories_.at(fid), p_cuts));
} else {
AddCutPoint<WQSketch>(a, max_num_bins, p_cuts);
// use special AddCutPoint scheme for secure vertical federated learning
AddCutPoint<WQSketch>(a, max_num_bins, p_cuts, info.IsSecure());
// push a value that is greater than anything
const bst_float cpt =
(a.size > 0) ? a.data[a.size - 1].value : p_cuts->min_vals_.HostVector()[fid];
Expand All @@ -443,6 +461,13 @@ void SketchContainerImpl<WQSketch>::MakeCuts(Context const *ctx, MetaInfo const
p_cuts->cut_ptrs_.HostVector().push_back(cut_size);
}

if (info.IsVerticalFederated() && info.IsSecure()) {
// cut values need to be synced across all workers via Allreduce
auto cut_val = p_cuts->cut_values_.HostVector().data();
std::size_t n = p_cuts->cut_values_.HostVector().size();
collective::Allreduce<collective::Operation::kSum>(cut_val, n);
}

p_cuts->SetCategorical(this->has_categorical_, max_cat);
monitor_.Stop(__func__);
}
Expand Down
94 changes: 53 additions & 41 deletions src/tree/hist/evaluate_splits.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class HistEvaluator {
std::shared_ptr<common::ColumnSampler> column_sampler_;
TreeEvaluator tree_evaluator_;
bool is_col_split_{false};
bool is_secure_{false};
FeatureInteractionConstraintHost interaction_constraints_;
std::vector<NodeEntry> snode_;

Expand Down Expand Up @@ -321,7 +322,6 @@ class HistEvaluator {
}
}
}

p_best->Update(best);
return left_sum;
}
Expand Down Expand Up @@ -353,54 +353,63 @@ class HistEvaluator {
auto evaluator = tree_evaluator_.GetEvaluator();
auto const &cut_ptrs = cut.Ptrs();

common::ParallelFor2d(space, n_threads, [&](size_t nidx_in_set, common::Range1d r) {
auto tidx = omp_get_thread_num();
auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx];
auto best = &entry->split;
auto nidx = entry->nid;
auto histogram = hist[nidx];
auto features_set = features[nidx_in_set]->ConstHostSpan();
for (auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) {
auto fidx = features_set[fidx_in_set];
bool is_cat = common::IsCat(feature_types, fidx);
if (!interaction_constraints_.Query(nidx, fidx)) {
continue;
}
if (is_cat) {
auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx];
if (common::UseOneHot(n_bins, param_->max_cat_to_onehot)) {
EnumerateOneHot(cut, histogram, fidx, nidx, evaluator, best);
} else {
std::vector<size_t> sorted_idx(n_bins);
std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
auto feat_hist = histogram.subspan(cut_ptrs[fidx], n_bins);
// Sort the histogram to get contiguous partitions.
std::stable_sort(sorted_idx.begin(), sorted_idx.end(), [&](size_t l, size_t r) {
auto ret = evaluator.CalcWeightCat(*param_, feat_hist[l]) <
evaluator.CalcWeightCat(*param_, feat_hist[r]);
return ret;
});
EnumeratePart<+1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
EnumeratePart<-1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
// Under the secure vertical setting, only the active party (rank 0) is able to evaluate the
// splits based on the global histogram; the other parties receive the final best split
// information. Hence the computation below is not performed by the passive parties.
if ((!is_secure_) || (collective::GetRank() == 0)) {
// Evaluate the splits for each feature
common::ParallelFor2d(space, n_threads, [&](size_t nidx_in_set, common::Range1d r) {
auto tidx = omp_get_thread_num();
auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx];
auto best = &entry->split;
auto nidx = entry->nid;
auto histogram = hist[nidx];
auto features_set = features[nidx_in_set]->ConstHostSpan();
for (auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) {
auto fidx = features_set[fidx_in_set];
bool is_cat = common::IsCat(feature_types, fidx);
if (!interaction_constraints_.Query(nidx, fidx)) {
continue;
}
} else {
auto grad_stats = EnumerateSplit<+1>(cut, histogram, fidx, nidx, evaluator, best);
if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
EnumerateSplit<-1>(cut, histogram, fidx, nidx, evaluator, best);
if (is_cat) {
auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx];
if (common::UseOneHot(n_bins, param_->max_cat_to_onehot)) {
EnumerateOneHot(cut, histogram, fidx, nidx, evaluator, best);
} else {
std::vector<size_t> sorted_idx(n_bins);
std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
auto feat_hist = histogram.subspan(cut_ptrs[fidx], n_bins);
// Sort the histogram to get contiguous partitions.
std::stable_sort(sorted_idx.begin(), sorted_idx.end(), [&](size_t l, size_t r) {
auto ret = evaluator.CalcWeightCat(*param_, feat_hist[l]) <
evaluator.CalcWeightCat(*param_, feat_hist[r]);
return ret;
});
EnumeratePart<+1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
EnumeratePart<-1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
}
} else {
auto grad_stats = EnumerateSplit<+1>(cut, histogram, fidx, nidx, evaluator, best);
if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
EnumerateSplit<-1>(cut, histogram, fidx, nidx, evaluator, best);
}
}
}
}
});
});

for (unsigned nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
for (auto tidx = 0; tidx < n_threads; ++tidx) {
entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split);
for (unsigned nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
for (auto tidx = 0; tidx < n_threads; ++tidx) {
entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split);
}
}
}

if (is_col_split_) {
// With column-wise data split, we gather the best splits from all the workers and update the
// expand entries accordingly.
// Note that under the secure vertical setting, only the label owner is able to evaluate the
// splits based on the global histogram, and the other parties receive the final best splits.
// Allgather is capable of performing this (non-label owners contribute 0-gain entries).
auto all_entries = AllgatherColumnSplit(entries);
for (auto worker = 0; worker < collective::GetWorldSize(); ++worker) {
for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
Expand Down Expand Up @@ -480,7 +489,8 @@ class HistEvaluator {
param_{param},
column_sampler_{std::move(sampler)},
tree_evaluator_{*param, static_cast<bst_feature_t>(info.num_col_), DeviceOrd::CPU()},
is_col_split_{info.IsColumnSplit()} {
is_col_split_{info.IsColumnSplit()},
is_secure_{info.IsSecure()}{
interaction_constraints_.Configure(*param, info.num_col_);
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
param_->colsample_bynode, param_->colsample_bylevel,
Expand All @@ -496,6 +506,7 @@ class HistMultiEvaluator {
std::shared_ptr<common::ColumnSampler> column_sampler_;
Context const *ctx_;
bool is_col_split_{false};
bool is_secure_{false};

private:
static double MultiCalcSplitGain(TrainParam const &param,
Expand Down Expand Up @@ -709,7 +720,8 @@ class HistMultiEvaluator {
: param_{param},
column_sampler_{std::move(sampler)},
ctx_{ctx},
is_col_split_{info.IsColumnSplit()} {
is_col_split_{info.IsColumnSplit()},
is_secure_{info.IsSecure()} {
interaction_constraints_.Configure(*param, info.num_col_);
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
param_->colsample_bynode, param_->colsample_bylevel,
Expand Down
22 changes: 19 additions & 3 deletions src/tree/hist/histogram.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class HistogramBuilder {
// Whether XGBoost is running in distributed environment.
bool is_distributed_{false};
bool is_col_split_{false};
bool is_secure_{false};

public:
/**
Expand All @@ -60,13 +61,14 @@ class HistogramBuilder {
* of using global rabit variable.
*/
void Reset(Context const *ctx, bst_bin_t total_bins, BatchParam const &p, bool is_distributed,
bool is_col_split, HistMakerTrainParam const *param) {
bool is_col_split, bool is_secure, HistMakerTrainParam const *param) {
n_threads_ = ctx->Threads();
param_ = p;
hist_.Reset(total_bins, param->max_cached_hist_node);
buffer_.Init(total_bins);
is_distributed_ = is_distributed;
is_col_split_ = is_col_split;
is_secure_ = is_secure;
}

template <bool any_missing>
Expand Down Expand Up @@ -175,6 +177,7 @@ class HistogramBuilder {
std::vector<bst_node_t> const &nodes_to_build,
std::vector<bst_node_t> const &nodes_to_trick) {
auto n_total_bins = buffer_.TotalBins();

common::BlockedSpace2d space(
nodes_to_build.size(), [&](std::size_t) { return n_total_bins; }, 1024);
common::ParallelFor2d(space, this->n_threads_, [&](size_t node, common::Range1d r) {
Expand All @@ -190,6 +193,18 @@ class HistogramBuilder {
reinterpret_cast<double *>(this->hist_[first_nidx].data()), n);
}

if (is_distributed_ && is_col_split_ && is_secure_) {
// Under secure vertical mode, we perform allgather for all nodes
CHECK(!nodes_to_build.empty());
// In theory the operation is an AllGather. Under the current histogram setting, where the
// histograms have the same length with 0s for empty slots,
// AllReduce is the most efficient way of achieving the global histogram.
auto first_nidx = nodes_to_build.front();
std::size_t n = n_total_bins * nodes_to_build.size() * 2;
collective::Allreduce<collective::Operation::kSum>(
reinterpret_cast<double *>(this->hist_[first_nidx].data()), n);
}

common::BlockedSpace2d const &subspace =
nodes_to_trick.size() == nodes_to_build.size()
? space
Expand Down Expand Up @@ -329,12 +344,13 @@ class MultiHistogramBuilder {
[[nodiscard]] auto &Histogram(bst_target_t t) { return target_builders_[t].Histogram(); }

void Reset(Context const *ctx, bst_bin_t total_bins, bst_target_t n_targets, BatchParam const &p,
bool is_distributed, bool is_col_split, HistMakerTrainParam const *param) {
bool is_distributed, bool is_col_split, bool is_secure,
HistMakerTrainParam const *param) {
ctx_ = ctx;
target_builders_.resize(n_targets);
CHECK_GE(n_targets, 1);
for (auto &v : target_builders_) {
v.Reset(ctx, total_bins, p, is_distributed, is_col_split, param);
v.Reset(ctx, total_bins, p, is_distributed, is_col_split, is_secure, param);
}
}
};
Expand Down
2 changes: 1 addition & 1 deletion src/tree/updater_approx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class GloablApproxBuilder {

histogram_builder_.Reset(ctx_, n_total_bins, p_tree->NumTargets(), BatchSpec(*param_, hess),
collective::IsDistributed(), p_fmat->Info().IsColumnSplit(),
hist_param_);
p_fmat->Info().IsSecure(), hist_param_);
monitor_->Stop(__func__);
}

Expand Down
4 changes: 2 additions & 2 deletions src/tree/updater_quantile_hist.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ class MultiTargetHistBuilder {
histogram_builder_ = std::make_unique<MultiHistogramBuilder>();
histogram_builder_->Reset(ctx_, n_total_bins, n_targets, HistBatch(param_),
collective::IsDistributed(), p_fmat->Info().IsColumnSplit(),
hist_param_);
p_fmat->Info().IsSecure(), hist_param_);

evaluator_ = std::make_unique<HistMultiEvaluator>(ctx_, p_fmat->Info(), param_, col_sampler_);
p_last_tree_ = p_tree;
Expand Down Expand Up @@ -357,7 +357,7 @@ class HistUpdater {
fmat->Info().IsColumnSplit());
}
histogram_builder_->Reset(ctx_, n_total_bins, 1, HistBatch(param_), collective::IsDistributed(),
fmat->Info().IsColumnSplit(), hist_param_);
fmat->Info().IsColumnSplit(), fmat->Info().IsSecure(), hist_param_);
evaluator_ = std::make_unique<HistEvaluator>(ctx_, this->param_, fmat->Info(), col_sampler_);
p_last_tree_ = p_tree;
monitor_->Stop(__func__);
Expand Down
Loading

0 comments on commit fe73294

Please sign in to comment.