Skip to content

Commit

Permalink
Secure horizontal federated scheme for GPU computation (dmlc#10601)
Browse files Browse the repository at this point in the history


---------

Co-authored-by: Jiaming Yuan <[email protected]>
  • Loading branch information
ZiyueXu77 and trivialfis authored Jul 30, 2024
1 parent be86d0e commit 5ce742a
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/collective/aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ void BroadcastGradient(Context const* ctx, MetaInfo const& info, GradFn&& grad_f
#if defined(XGBOOST_USE_FEDERATED)
// Need to encrypt the gradient before broadcasting.
common::Span<std::uint8_t> encrypted;
auto const& comm = GlobalCommGroup()->Ctx(ctx, ctx->Device());
auto const& comm = GlobalCommGroup()->Ctx(ctx, DeviceOrd::CPU());
auto const& fed = dynamic_cast<FederatedComm const&>(comm);
if (GetRank() == 0) {
// Obtain the gradient
Expand Down
98 changes: 91 additions & 7 deletions src/tree/updater_gpu_hist.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@
#include "xgboost/task.h" // for ObjInfo
#include "xgboost/tree_model.h"

#include "../collective/communicator-inl.h"
#include "../collective/allgather.h" // for AllgatherV

#if defined(XGBOOST_USE_FEDERATED)
#include "../../plugin/federated/federated_comm.h" // for FederatedComm
#else
#include "../common/error_msg.h" // for NoFederated
#endif

namespace xgboost::tree {
#if !defined(GTEST_TEST)
DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
Expand Down Expand Up @@ -514,17 +523,67 @@ struct GPUHistMakerDevice {
// num histograms is the number of contiguous histograms in memory to reduce over
void AllReduceHist(int nidx, int num_histograms) {
monitor.Start("AllReduce");
std::size_t n = page->Cuts().TotalBins() * 2 * num_histograms;
auto d_node_hist = hist.GetNodeHistogram(nidx).data();
using ReduceT = typename std::remove_pointer<decltype(d_node_hist)>::type::ValueT;
auto hist_vec = linalg::MakeVec(reinterpret_cast<ReduceT*>(d_node_hist), n, ctx_->Device());
auto rc = collective::GlobalSum(
ctx_, info_,
linalg::MakeVec(reinterpret_cast<ReduceT*>(d_node_hist),
page->Cuts().TotalBins() * 2 * num_histograms, ctx_->Device()));
ctx_, info_, hist_vec);
SafeColl(rc);

monitor.Stop("AllReduce");
}

#if defined(XGBOOST_USE_FEDERATED)
void AllReduceHistEncrypted(int nidx, int num_histograms) {
monitor.Start(__func__);
// Get encryption plugin
auto const &comm = collective::GlobalCommGroup()->Ctx(ctx_, DeviceOrd::CPU());
auto const &fed = dynamic_cast<collective::FederatedComm const &>(comm);
auto plugin = fed.EncryptionPlugin();

// Get the histogram data
std::size_t n = page->Cuts().TotalBins() * 2 * num_histograms;
auto d_node_hist = hist.GetNodeHistogram(nidx).data();
using ReduceT = typename std::remove_pointer<decltype(d_node_hist)>::type::ValueT;
auto hist_vec = linalg::MakeVec(reinterpret_cast<ReduceT*>(d_node_hist), n, ctx_->Device());

// copy the histogram out of GPU memory
common::Span<std::int8_t> erased = common::EraseType(hist_vec.Values());
std::vector<std::int8_t> h_data(erased.size());
dh::safe_cuda(cudaMemcpy(h_data.data(), erased.data(), erased.size(), cudaMemcpyDeviceToHost));

// call the encryption plugin
auto src_hist = common::Span{reinterpret_cast<double const *>(h_data.data()), n};
auto hist_buf = plugin->BuildEncryptedHistHori(src_hist);

// allgather
HostDeviceVector<std::int8_t> hist_entries;
std::vector<std::int64_t> recv_segments;
auto rc = collective::AllgatherV(ctx_, linalg::MakeVec(hist_buf),
&recv_segments, &hist_entries);
collective::SafeColl(rc);

// call the encryption plugin to decode the histograms
auto hist_aggr = plugin->SyncEncryptedHistHori(
common::RestoreType<std::uint8_t>(hist_entries.HostSpan()));

// reinterpret the aggregated histogram as a int64_t and aggregate
auto hist_aggr_64 = common::Span{
reinterpret_cast<std::int64_t *>(hist_aggr.data()), hist_aggr.size()};
int num_ranks = collective::GlobalCommGroup()->World();
for (size_t i = 0; i < n; i++) {
for (int j = 1; j < num_ranks; j++) {
hist_aggr_64[i] = hist_aggr_64[i] + hist_aggr_64[i + j * n];
}
}

// copy the aggregated histogram back to GPU memory
cudaMemcpy(erased.data(), hist_aggr_64.data(), erased.size(), cudaMemcpyHostToDevice);

monitor.Stop(__func__);
}
#endif

/**
* \brief Build GPU local histograms for the left and right child of some parent node
*/
Expand Down Expand Up @@ -559,7 +618,16 @@ struct GPUHistMakerDevice {
// Reduce all in one go
// This gives much better latency in a distributed setting
// when processing a large batch
this->AllReduceHist(hist_nidx.at(0), hist_nidx.size());
// If secure horizontal, perform AllReduce by calling the encryption plugin
if (collective::IsDistributed() && info_.IsRowSplit() && collective::IsEncrypted()) {
#if defined(XGBOOST_USE_FEDERATED)
this->AllReduceHistEncrypted(hist_nidx.at(0), hist_nidx.size());
#else
LOG(FATAL) << error::NoFederated();
#endif
} else {
this->AllReduceHist(hist_nidx.at(0), hist_nidx.size());
}

for (size_t i = 0; i < subtraction_nidx.size(); i++) {
auto build_hist_nidx = hist_nidx.at(i);
Expand All @@ -569,7 +637,15 @@ struct GPUHistMakerDevice {
if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) {
// Calculate other histogram manually
this->BuildHist(subtraction_trick_nidx);
this->AllReduceHist(subtraction_trick_nidx, 1);
if (collective::IsDistributed() && info_.IsRowSplit() && collective::IsEncrypted()) {
#if defined(XGBOOST_USE_FEDERATED)
this->AllReduceHistEncrypted(subtraction_trick_nidx, 1);
#else
LOG(FATAL) << error::NoFederated();
#endif
} else {
this->AllReduceHist(subtraction_trick_nidx, 1);
}
}
}
}
Expand Down Expand Up @@ -643,7 +719,15 @@ struct GPUHistMakerDevice {

hist.AllocateHistograms(ctx_, {kRootNIdx});
this->BuildHist(kRootNIdx);
this->AllReduceHist(kRootNIdx, 1);
if (collective::IsDistributed() && info_.IsRowSplit() && collective::IsEncrypted()) {
#if defined(XGBOOST_USE_FEDERATED)
this->AllReduceHistEncrypted(kRootNIdx, 1);
#else
LOG(FATAL) << error::NoFederated();
#endif
} else {
this->AllReduceHist(kRootNIdx, 1);
}

// Remember root stats
auto root_sum = quantiser.ToFloatingPoint(root_sum_quantised);
Expand Down

0 comments on commit 5ce742a

Please sign in to comment.