diff --git a/src/collective/aggregator.h b/src/collective/aggregator.h
index 8ebd39b900fe..ed548a0f173b 100644
--- a/src/collective/aggregator.h
+++ b/src/collective/aggregator.h
@@ -230,6 +230,10 @@ void BroadcastGradient(Context const* ctx, MetaInfo const& info, GradFn&& grad_f
 #else
     LOG(FATAL) << error::NoFederated();
 #endif
+
+    // !!!Temporarily turn on regular gradient broadcasting for testing
+    // encrypted vertical
+    ApplyWithLabels(ctx, info, out_gpair->Data(), [&] { grad_fn(out_gpair); });
   } else {
     ApplyWithLabels(ctx, info, out_gpair->Data(), [&] { grad_fn(out_gpair); });
   }
diff --git a/src/common/quantile.cu b/src/common/quantile.cu
index d0356ae421c7..d5e927b30414 100644
--- a/src/common/quantile.cu
+++ b/src/common/quantile.cu
@@ -673,6 +673,7 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i
     }
   }
+  auto secure_vertical = is_column_split && collective::IsEncrypted();
   // Set up output cuts
   for (bst_feature_t i = 0; i < num_columns_; ++i) {
     size_t column_size = std::max(static_cast<size_t>(1ul), this->Column(i).size());
@@ -681,6 +682,11 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i
       CheckMaxCat(max_values[i].value, column_size);
       h_out_columns_ptr.push_back(max_values[i].value + 1);  // includes both max_cat and 0.
     } else {
+      // If vertical and secure mode, we need to sync the max_num_bins across workers
+      // to create the same global number of cut point bins for easier future processing
+      if (secure_vertical) {
+        collective::SafeColl(collective::Allreduce(ctx, &column_size, collective::Op::kMax));
+      }
       h_out_columns_ptr.push_back(
           std::min(static_cast<size_t>(column_size), static_cast<size_t>(num_bins_)));
     }
@@ -711,6 +717,10 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i
       out_column[0] = kRtEps;
       assert(out_column.size() == 1);
     }
+    // For secure vertical split, fill all cut values with a dummy value
+    if (secure_vertical) {
+      out_column[idx] = kRtEps;
+    }
     return;
   }
@@ -736,6 +746,19 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i
     out_column[idx] = in_column[idx+1].value;
   });
+  if (secure_vertical) {
+    // Cut values need to be synced across all workers via Allreduce.
+    // TODO: apply the same inference indexing as the CPU path; skip for now.
+    auto cut_values_device = p_cuts->cut_values_.DeviceSpan();
+    std::vector<float> cut_values_host(cut_values_device.size());
+    dh::CopyDeviceSpanToVector(&cut_values_host, cut_values_device);
+    auto rc = collective::Allreduce(ctx, &cut_values_host, collective::Op::kSum);
+    SafeColl(rc);
+    dh::safe_cuda(cudaMemcpyAsync(cut_values_device.data(), cut_values_host.data(),
+                                  cut_values_device.size() * sizeof(float),
+                                  cudaMemcpyHostToDevice));
+  }
+
   p_cuts->SetCategorical(this->has_categorical_, max_cat);
   timer_.Stop(__func__);
 }
diff --git a/src/tree/gpu_hist/evaluate_splits.cu b/src/tree/gpu_hist/evaluate_splits.cu
index 5e225a13f142..3b5c1b76fa1c 100644
--- a/src/tree/gpu_hist/evaluate_splits.cu
+++ b/src/tree/gpu_hist/evaluate_splits.cu
@@ -6,6 +6,7 @@
 #include
 #include "../../collective/allgather.h"
+#include "../../collective/broadcast.h"
 #include "../../common/categorical.h"
 #include "../../data/ellpack_page.cuh"
 #include "evaluate_splits.cuh"
@@ -404,34 +405,48 @@ void GPUHistEvaluator::EvaluateSplits(Context const *ctx, const std::vector
   dh::TemporaryArray<DeviceSplitCandidate> splits_out_storage(d_inputs.size());
   auto out_splits = dh::ToSpan(splits_out_storage);
-  this->LaunchEvaluateSplits(max_active_features, d_inputs, shared_inputs,
-                             evaluator, out_splits);
-  if (is_column_split_) {
-    // With column-wise data split, we gather the split candidates from all the workers and find the
-    // global best candidates.
-    auto const world_size = collective::GetWorldSize();
-    dh::TemporaryArray<DeviceSplitCandidate> all_candidate_storage(out_splits.size() * world_size);
-    auto all_candidates = dh::ToSpan(all_candidate_storage);
-    auto current_rank =
-        all_candidates.subspan(collective::GetRank() * out_splits.size(), out_splits.size());
-    dh::safe_cuda(cudaMemcpyAsync(current_rank.data(), out_splits.data(),
-                                  out_splits.size() * sizeof(DeviceSplitCandidate),
-                                  cudaMemcpyDeviceToDevice));
-    auto rc = collective::Allgather(
-        ctx, linalg::MakeVec(all_candidates.data(), all_candidates.size(), ctx->Device()));
-    collective::SafeColl(rc);
-
-    // Reduce to get the best candidate from all workers.
-    dh::LaunchN(out_splits.size(), ctx->CUDACtx()->Stream(),
-                [world_size, all_candidates, out_splits] __device__(size_t i) {
-                  out_splits[i] = all_candidates[i];
-                  for (auto rank = 1; rank < world_size; rank++) {
-                    out_splits[i] = out_splits[i] + all_candidates[rank * out_splits.size() + i];
-                  }
-                });
+  bool is_passive_party = is_column_split_ && collective::IsEncrypted() && collective::GetRank() != 0;
+  bool is_active_party = !is_passive_party;
+  // Under the secure vertical setting, only the active party is able to evaluate the split
+  // based on the global histogram; the other parties will receive the final best split information.
+  // Hence the computation below is not performed by the passive parties.
+  if (is_active_party) {
+    this->LaunchEvaluateSplits(max_active_features, d_inputs, shared_inputs,
+                               evaluator, out_splits);
+  }
+
+  if (is_column_split_) {
+    if (!collective::IsEncrypted()) {
+      // With regular column-wise data split, we gather the split candidates from
+      // all the workers and find the global best candidates.
+      auto const world_size = collective::GetWorldSize();
+      dh::TemporaryArray<DeviceSplitCandidate> all_candidate_storage(out_splits.size() * world_size);
+      auto all_candidates = dh::ToSpan(all_candidate_storage);
+      auto current_rank =
+          all_candidates.subspan(collective::GetRank() * out_splits.size(), out_splits.size());
+      dh::safe_cuda(cudaMemcpyAsync(current_rank.data(), out_splits.data(),
+                                    out_splits.size() * sizeof(DeviceSplitCandidate),
+                                    cudaMemcpyDeviceToDevice));
+      auto rc = collective::Allgather(
+          ctx, linalg::MakeVec(all_candidates.data(), all_candidates.size(), ctx->Device()));
+      collective::SafeColl(rc);
+      // Reduce to get the best candidate from all workers.
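+      // Note: DeviceSplitCandidate's operator+ keeps the candidate with the larger loss_chg,
+      // so the element-wise accumulation below reduces the gathered candidates to the single
+      // best split per node across all workers.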
+      dh::LaunchN(out_splits.size(), ctx->CUDACtx()->Stream(),
+                  [world_size, all_candidates, out_splits] __device__(size_t i) {
+                    out_splits[i] = all_candidates[i];
+                    for (auto rank = 1; rank < world_size; rank++) {
+                      out_splits[i] = out_splits[i] + all_candidates[rank * out_splits.size() + i];
+                    }
+                  });
+    } else {
+      // With encrypted column-wise data split, we distribute the best split candidates
+      // from Rank 0 to all other workers
+      auto rc = collective::Broadcast(
+          ctx, linalg::MakeVec(out_splits.data(), out_splits.size(), ctx->Device()), 0);
+      collective::SafeColl(rc);
+    }
+  }
   auto d_sorted_idx = this->SortedIdx(d_inputs.size(), shared_inputs.feature_values.size());
   auto d_entries = out_entries;
   auto device_cats_accessor = this->DeviceCatStorage(nidx);
diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu
index 372a5c09ba0c..d0d2634a1363 100644
--- a/src/tree/gpu_hist/histogram.cu
+++ b/src/tree/gpu_hist/histogram.cu
@@ -14,6 +14,13 @@
 #include "row_partitioner.cuh"
 #include "xgboost/base.h"
+#include "../../common/device_helpers.cuh"
+#if defined(XGBOOST_USE_FEDERATED)
+#include "../../../plugin/federated/federated_hist.h"  // for FederatedHistPolicy
+#else
+#include "../../common/error_msg.h"  // for NoFederated
+#endif
+
 namespace xgboost::tree {
 namespace {
 struct Pair {
@@ -309,6 +316,21 @@ class DeviceHistogramBuilderImpl {
              bool force_global_memory) {
     this->kernel_ = std::make_unique>(ctx, feature_groups, force_global_memory);
     this->force_global_memory_ = force_global_memory;
+
+    std::cout << "Reset DeviceHistogramBuilderImpl" << std::endl;
+
+    // Reset federated plugin
+    // At the start of every round, transmit the matrix to the plugin
+#if defined(XGBOOST_USE_FEDERATED)
+    // Get encryption plugin
+    auto const &comm = collective::GlobalCommGroup()->Ctx(ctx, DeviceOrd::CPU());
+    auto const &fed = dynamic_cast<collective::FederatedComm const &>(comm);
+    auto plugin = fed.EncryptionPlugin();
+    // Reset plugin
+    // plugin->Reset();
+#endif
   }
 
   void BuildHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
@@ -354,13 +376,64 @@ void DeviceHistogramBuilder::Reset(Context const* ctx, FeatureGroupsAccessor con
   this->p_impl_->Reset(ctx, feature_groups, force_global_memory);
 }
 
-void DeviceHistogramBuilder::BuildHistogram(CUDAContext const* ctx,
+void DeviceHistogramBuilder::BuildHistogram(Context const* ctx,
                                             EllpackDeviceAccessor const& matrix,
                                             FeatureGroupsAccessor const& feature_groups,
                                             common::Span<GradientPair const> gpair,
                                             common::Span<const std::uint32_t> ridx,
                                             common::Span<GradientPairInt64> histogram,
-                                            GradientQuantiser rounding) {
-  this->p_impl_->BuildHistogram(ctx, matrix, feature_groups, gpair, ridx, histogram, rounding);
+                                            GradientQuantiser rounding, MetaInfo const& info) {
+  auto IsSecureVertical =
+      !info.IsRowSplit() && collective::IsDistributed() && collective::IsEncrypted();
+  if (!IsSecureVertical) {
+    // Regular training, build the histogram locally
+    this->p_impl_->BuildHistogram(ctx->CUDACtx(), matrix, feature_groups, gpair, ridx, histogram,
+                                  rounding);
+  } else {
+    // Encrypted vertical, build the histogram using the federated plugin
+    auto const &comm = collective::GlobalCommGroup()->Ctx(ctx, DeviceOrd::CPU());
+    auto const &fed = dynamic_cast<collective::FederatedComm const &>(comm);
+    auto plugin = fed.EncryptionPlugin();
+    // Transmit the matrix to the plugin
+    // plugin->TransmitMatrix(matrix);
+    // Transmit the row indices to the plugin
+    // plugin->TransmitRowIndices(ridx);
+
+    // !!!Temporarily turn on regular histogram building for testing
+    // encrypted vertical
+    this->p_impl_->BuildHistogram(ctx->CUDACtx(), matrix, feature_groups, gpair, ridx, histogram,
+                                  rounding);
+
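+    // Note: the gradients have been quantised into GradientPairInt64 by the GradientQuantiser,
+    // so the histogram can be aggregated across workers below with an exact integer Allreduce.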
+    // Further histogram sync process - simulated;
+    // only the last stage is needed under the plugin system
+
+    // copy histogram data to host
+    std::vector<GradientPairInt64> host_histogram(histogram.size());
+    dh::CopyDeviceSpanToVector(&host_histogram, histogram);
+    // flatten into a plain int64 vector
+    std::vector<std::int64_t> host_histogram_64(histogram.size() * 2);
+    for (std::size_t i = 0; i < host_histogram.size(); i++) {
+      host_histogram_64[i * 2] = host_histogram[i].GetQuantisedGrad();
+      host_histogram_64[i * 2 + 1] = host_histogram[i].GetQuantisedHess();
+    }
+    // aggregate the quantised histograms across workers
+    auto rc = collective::Allreduce(ctx, &host_histogram_64, collective::Op::kSum);
+    SafeColl(rc);
+    // convert back to GradientPairInt64
+    // only copy to Rank 0, clear other ranks to simulate the plugin scenario
+    for (std::size_t i = 0; i < host_histogram.size(); i++) {
+      GradientPairInt64 hist_item(host_histogram_64[i * 2], host_histogram_64[i * 2 + 1]);
+      GradientPairInt64 hist_item_empty(0, 0);
+      if (collective::GetRank() != 0) {
+        host_histogram[i] = hist_item_empty;
+      } else {
+        host_histogram[i] = hist_item;
+      }
+    }
+    // copy the aggregated histogram back to GPU memory
+    // at this point, the histogram contains full information from all parties
+    dh::safe_cuda(cudaMemcpyAsync(histogram.data(), host_histogram.data(),
+                                  histogram.size() * sizeof(GradientPairInt64),
+                                  cudaMemcpyHostToDevice));
+  }
 }
 }  // namespace xgboost::tree
diff --git a/src/tree/gpu_hist/histogram.cuh b/src/tree/gpu_hist/histogram.cuh
index 87c60a8bfdbc..5c0d58f2df46 100644
--- a/src/tree/gpu_hist/histogram.cuh
+++ b/src/tree/gpu_hist/histogram.cuh
@@ -178,11 +178,12 @@ class DeviceHistogramBuilder {
   void Reset(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
              bool force_global_memory);
 
-  void BuildHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
+  void BuildHistogram(Context const* ctx, EllpackDeviceAccessor const& matrix,
                       FeatureGroupsAccessor const& feature_groups,
                       common::Span<GradientPair const> gpair,
                       common::Span<const std::uint32_t> ridx,
-                      common::Span<GradientPairInt64> histogram, GradientQuantiser rounding);
+                      common::Span<GradientPairInt64> histogram, GradientQuantiser rounding,
+                      MetaInfo const& info);
 };
 }  // namespace xgboost::tree
 #endif  // HISTOGRAM_CUH_
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 088fc199786b..d588e4db9d2b 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -249,9 +249,9 @@ struct GPUHistMakerDevice {
   void BuildHist(int nidx) {
     auto d_node_hist = hist.GetNodeHistogram(nidx);
     auto d_ridx = row_partitioner->GetRows(nidx);
-    this->histogram_.BuildHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->Device()),
+    this->histogram_.BuildHistogram(ctx_, page->GetDeviceAccessor(ctx_->Device()),
                                     feature_groups->DeviceAccessor(ctx_->Device()), gpair, d_ridx,
-                                    d_node_hist, *quantiser);
+                                    d_node_hist, *quantiser, info_);
   }
 
   // Attempt to do subtraction trick