diff --git a/src/collective/aggregator.h b/src/collective/aggregator.h
index f3d2aa090aa9..8ebd39b900fe 100644
--- a/src/collective/aggregator.h
+++ b/src/collective/aggregator.h
@@ -199,7 +199,7 @@ void BroadcastGradient(Context const* ctx, MetaInfo const& info, GradFn&& grad_f
 #if defined(XGBOOST_USE_FEDERATED)
     // Need to encrypt the gradient before broadcasting.
     common::Span<std::uint8_t> encrypted;
-    auto const& comm = GlobalCommGroup()->Ctx(ctx, ctx->Device());
+    auto const& comm = GlobalCommGroup()->Ctx(ctx, DeviceOrd::CPU());
     auto const& fed = dynamic_cast<FederatedComm const&>(comm);
     if (GetRank() == 0) {
       // Obtain the gradient
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 83f84ec1f4a5..088fc199786b 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -44,6 +44,15 @@
 #include "xgboost/task.h"  // for ObjInfo
 #include "xgboost/tree_model.h"
 
+#include "../collective/communicator-inl.h"
+#include "../collective/allgather.h"  // for AllgatherV
+
+#if defined(XGBOOST_USE_FEDERATED)
+#include "../../plugin/federated/federated_comm.h"  // for FederatedComm
+#else
+#include "../common/error_msg.h"  // for NoFederated
+#endif
+
 namespace xgboost::tree {
 #if !defined(GTEST_TEST)
 DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
@@ -514,17 +523,67 @@ struct GPUHistMakerDevice {
   // num histograms is the number of contiguous histograms in memory to reduce over
   void AllReduceHist(int nidx, int num_histograms) {
     monitor.Start("AllReduce");
+    std::size_t n = page->Cuts().TotalBins() * 2 * num_histograms;
     auto d_node_hist = hist.GetNodeHistogram(nidx).data();
     using ReduceT = typename std::remove_pointer<decltype(d_node_hist)>::type::ValueT;
+    auto hist_vec = linalg::MakeVec(reinterpret_cast<ReduceT*>(d_node_hist), n, ctx_->Device());
     auto rc = collective::GlobalSum(
-        ctx_, info_,
-        linalg::MakeVec(reinterpret_cast<ReduceT*>(d_node_hist),
-                        page->Cuts().TotalBins() * 2 * num_histograms, ctx_->Device()));
+        ctx_, info_, hist_vec);
     SafeColl(rc);
-
     monitor.Stop("AllReduce");
   }
+
+#if defined(XGBOOST_USE_FEDERATED)
+  void AllReduceHistEncrypted(int nidx, int num_histograms) {
+    monitor.Start(__func__);
+    // Get the encryption plugin.
+    auto const& comm = collective::GlobalCommGroup()->Ctx(ctx_, DeviceOrd::CPU());
+    auto const& fed = dynamic_cast<collective::FederatedComm const&>(comm);
+    auto plugin = fed.EncryptionPlugin();
+
+    // Get the histogram data.
+    std::size_t n = page->Cuts().TotalBins() * 2 * num_histograms;
+    auto d_node_hist = hist.GetNodeHistogram(nidx).data();
+    using ReduceT = typename std::remove_pointer<decltype(d_node_hist)>::type::ValueT;
+    auto hist_vec = linalg::MakeVec(reinterpret_cast<ReduceT*>(d_node_hist), n, ctx_->Device());
+
+    // Copy the histogram out of GPU memory.
+    common::Span<std::int8_t> erased = common::EraseType(hist_vec.Values());
+    std::vector<std::int8_t> h_data(erased.size());
+    dh::safe_cuda(cudaMemcpy(h_data.data(), erased.data(), erased.size(), cudaMemcpyDeviceToHost));
+
+    // Call the encryption plugin.
+    auto src_hist = common::Span<double const>{reinterpret_cast<double const*>(h_data.data()), n};
+    auto hist_buf = plugin->BuildEncryptedHistHori(src_hist);
+
+    // Allgather the encrypted histograms from all workers.
+    HostDeviceVector<std::int8_t> hist_entries;
+    std::vector<std::int64_t> recv_segments;
+    auto rc = collective::AllgatherV(ctx_, linalg::MakeVec(hist_buf),
+                                     &recv_segments, &hist_entries);
+    collective::SafeColl(rc);
+
+    // Call the encryption plugin to decode the histograms.
+    auto hist_aggr = plugin->SyncEncryptedHistHori(
+        common::RestoreType<std::uint8_t>(hist_entries.HostSpan()));
+
+    // Reinterpret the aggregated histogram as int64_t and aggregate across ranks.
+    auto hist_aggr_64 = common::Span<std::int64_t>{
+        reinterpret_cast<std::int64_t*>(hist_aggr.data()), hist_aggr.size()};
+    int num_ranks = collective::GlobalCommGroup()->World();
+    for (size_t i = 0; i < n; i++) {
+      for (int j = 1; j < num_ranks; j++) {
+        hist_aggr_64[i] = hist_aggr_64[i] + hist_aggr_64[i + j * n];
+      }
+    }
+
+    // Copy the aggregated histogram back to GPU memory.
+    dh::safe_cuda(cudaMemcpy(erased.data(), hist_aggr_64.data(), erased.size(), cudaMemcpyHostToDevice));
+
+    monitor.Stop(__func__);
+  }
+#endif
+
 
   /**
    * \brief Build GPU local histograms for the left and right child of some parent node
    */
@@ -559,7 +618,16 @@ struct GPUHistMakerDevice {
     // Reduce all in one go
     // This gives much better latency in a distributed setting
     // when processing a large batch
-    this->AllReduceHist(hist_nidx.at(0), hist_nidx.size());
+    // For secure horizontal mode, perform the allreduce through the encryption plugin.
+    if (collective::IsDistributed() && info_.IsRowSplit() && collective::IsEncrypted()) {
+#if defined(XGBOOST_USE_FEDERATED)
+      this->AllReduceHistEncrypted(hist_nidx.at(0), hist_nidx.size());
+#else
+      LOG(FATAL) << error::NoFederated();
+#endif
+    } else {
+      this->AllReduceHist(hist_nidx.at(0), hist_nidx.size());
+    }
 
     for (size_t i = 0; i < subtraction_nidx.size(); i++) {
       auto build_hist_nidx = hist_nidx.at(i);
@@ -569,7 +637,15 @@ struct GPUHistMakerDevice {
       if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) {
         // Calculate other histogram manually
         this->BuildHist(subtraction_trick_nidx);
-        this->AllReduceHist(subtraction_trick_nidx, 1);
+        if (collective::IsDistributed() && info_.IsRowSplit() && collective::IsEncrypted()) {
+#if defined(XGBOOST_USE_FEDERATED)
+          this->AllReduceHistEncrypted(subtraction_trick_nidx, 1);
+#else
+          LOG(FATAL) << error::NoFederated();
+#endif
+        } else {
+          this->AllReduceHist(subtraction_trick_nidx, 1);
+        }
       }
     }
   }
@@ -643,7 +719,15 @@ struct GPUHistMakerDevice {
 
     hist.AllocateHistograms(ctx_, {kRootNIdx});
     this->BuildHist(kRootNIdx);
-    this->AllReduceHist(kRootNIdx, 1);
+    if (collective::IsDistributed() && info_.IsRowSplit() && collective::IsEncrypted()) {
+#if defined(XGBOOST_USE_FEDERATED)
+      this->AllReduceHistEncrypted(kRootNIdx, 1);
+#else
+      LOG(FATAL) << error::NoFederated();
+#endif
+    } else {
+      this->AllReduceHist(kRootNIdx, 1);
+    }
 
     // Remember root stats
     auto root_sum = quantiser.ToFloatingPoint(root_sum_quantised);
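
Note on the host-side reduction in `AllReduceHistEncrypted`: `SyncEncryptedHistHori` hands back the decrypted histograms of all workers concatenated rank-major (`n` values per rank), typed as `double` even though the GPU histogram entries are quantised `int64_t`. The nested loop therefore reinterprets the buffer and folds every rank's copy into the first `n` slots. Below is a minimal standalone sketch of that layout and fold; the function name `FoldRankHistograms` and the toy data are illustrative, not part of the patch:

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Fold the concatenated per-rank histograms (rank-major, n entries each)
// into the first n entries, mirroring the loop in AllReduceHistEncrypted:
// the buffer is typed double, but the payload is quantised int64, so the
// sum is done over the reinterpreted integers.
void FoldRankHistograms(std::vector<double>* hist_aggr, std::size_t n, int num_ranks) {
  auto* hist64 = reinterpret_cast<std::int64_t*>(hist_aggr->data());
  for (std::size_t i = 0; i < n; ++i) {
    for (int j = 1; j < num_ranks; ++j) {
      hist64[i] += hist64[i + j * n];
    }
  }
}

int main() {
  std::size_t const n = 4;  // bins * 2 (gradient/hessian), toy size
  int const num_ranks = 3;
  // Fake gathered buffer: rank r contributes the integer (r + 1) in every slot,
  // stored bit-for-bit inside a double, as the plugin round-trip would produce.
  std::vector<double> gathered(n * num_ranks);
  for (int r = 0; r < num_ranks; ++r) {
    for (std::size_t i = 0; i < n; ++i) {
      std::int64_t v = r + 1;
      std::memcpy(&gathered[r * n + i], &v, sizeof v);
    }
  }
  FoldRankHistograms(&gathered, n, num_ranks);
  auto const* hist64 = reinterpret_cast<std::int64_t const*>(gathered.data());
  for (std::size_t i = 0; i < n; ++i) {
    std::cout << hist64[i] << " ";  // prints 6 = 1 + 2 + 3 for each slot
  }
  std::cout << "\n";
}
```

Summing the reinterpreted integers rather than the doubles themselves is the point of the reinterpretation: the quantised histogram is exact integer arithmetic, and adding the raw bit patterns as floating point would corrupt it.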