implement alternative vertical pipeline in GPU
ZiyueXu77 committed Jul 30, 2024
1 parent 3cc863a commit f2b876d
Showing 6 changed files with 148 additions and 32 deletions.
4 changes: 4 additions & 0 deletions src/collective/aggregator.h
@@ -230,6 +230,10 @@ void BroadcastGradient(Context const* ctx, MetaInfo const& info, GradFn&& grad_f
#else
LOG(FATAL) << error::NoFederated();
#endif

// !!!Temporarily turn on regular gradient broadcasting for testing
// encrypted vertical
ApplyWithLabels(ctx, info, out_gpair->Data(), [&] { grad_fn(out_gpair); });
} else {
ApplyWithLabels(ctx, info, out_gpair->Data(), [&] { grad_fn(out_gpair); });
}
23 changes: 23 additions & 0 deletions src/common/quantile.cu
@@ -673,6 +673,7 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i
}
}

auto secure_vertical = is_column_split && collective::IsEncrypted();
// Set up output cuts
for (bst_feature_t i = 0; i < num_columns_; ++i) {
size_t column_size = std::max(static_cast<size_t>(1ul), this->Column(i).size());
@@ -681,6 +682,11 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i
CheckMaxCat(max_values[i].value, column_size);
h_out_columns_ptr.push_back(max_values[i].value + 1); // includes both max_cat and 0.
} else {
// If vertical and secure mode, we need to sync the maximum column size across workers
// to create the same global number of cut point bins for easier future processing
if (secure_vertical) {
collective::SafeColl(collective::Allreduce(ctx, &column_size, collective::Op::kMax));
}
h_out_columns_ptr.push_back(
std::min(static_cast<size_t>(column_size), static_cast<size_t>(num_bins_)));
}
@@ -711,6 +717,10 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i
out_column[0] = kRtEps;
assert(out_column.size() == 1);
}
// For secure vertical split, fill all cut values with a dummy value
if (secure_vertical) {
out_column[idx] = kRtEps;
}
return;
}

@@ -736,6 +746,19 @@ void SketchContainer::MakeCuts(Context const* ctx, HistogramCuts* p_cuts, bool i
out_column[idx] = in_column[idx+1].value;
});

if (secure_vertical) {
// Cut values need to be synced across all workers via Allreduce.
// TODO: apply the same inference indexing as the CPU path; skipped for now.
auto cut_values_device = p_cuts->cut_values_.DeviceSpan();
std::vector<float> cut_values_host(cut_values_device.size());
dh::CopyDeviceSpanToVector(&cut_values_host, cut_values_device);
auto rc = collective::Allreduce(ctx, &cut_values_host, collective::Op::kSum);
SafeColl(rc);
dh::safe_cuda(cudaMemcpyAsync(cut_values_device.data(), cut_values_host.data(),
cut_values_device.size() * sizeof(float),
cudaMemcpyHostToDevice));
}

p_cuts->SetCategorical(this->has_categorical_, max_cat);
timer_.Stop(__func__);
}
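For reference, a minimal standalone sketch (not part of the commit; the worker count and cut values are made up) of why the dummy-fill plus Allreduce(kSum) pattern above recovers the global cut values on every worker: exactly one worker writes the real cuts for each column while the others contribute only kRtEps, so the element-wise sum is the real value plus a negligible (world_size - 1) * kRtEps offset.

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  constexpr float kRtEps = 1e-6f;  // same dummy value as in the diff above
  // Two workers, three global cut slots: worker 0 owns the column behind
  // slots 0-1, worker 1 owns the column behind slot 2.
  std::vector<float> worker0 = {0.5f, 1.5f, kRtEps};
  std::vector<float> worker1 = {kRtEps, kRtEps, 7.0f};
  // Element-wise sum stands in for collective::Allreduce(..., Op::kSum).
  std::vector<float> synced(worker0.size());
  for (std::size_t i = 0; i < synced.size(); ++i) {
    synced[i] = worker0[i] + worker1[i];
  }
  for (float v : synced) {
    std::cout << v << "\n";  // ~0.5, ~1.5, ~7.0 (each off by one kRtEps)
  }
  return 0;
}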
65 changes: 40 additions & 25 deletions src/tree/gpu_hist/evaluate_splits.cu
@@ -6,6 +6,7 @@
#include <limits>

#include "../../collective/allgather.h"
#include "../../collective/broadcast.h"
#include "../../common/categorical.h"
#include "../../data/ellpack_page.cuh"
#include "evaluate_splits.cuh"
@@ -404,34 +405,48 @@ void GPUHistEvaluator::EvaluateSplits(Context const *ctx, const std::vector<bst_

dh::TemporaryArray<DeviceSplitCandidate> splits_out_storage(d_inputs.size());
auto out_splits = dh::ToSpan(splits_out_storage);
this->LaunchEvaluateSplits(max_active_features, d_inputs, shared_inputs,
evaluator, out_splits);

if (is_column_split_) {
// With column-wise data split, we gather the split candidates from all the workers and find the
// global best candidates.
auto const world_size = collective::GetWorldSize();
dh::TemporaryArray<DeviceSplitCandidate> all_candidate_storage(out_splits.size() * world_size);
auto all_candidates = dh::ToSpan(all_candidate_storage);
auto current_rank =
all_candidates.subspan(collective::GetRank() * out_splits.size(), out_splits.size());
dh::safe_cuda(cudaMemcpyAsync(current_rank.data(), out_splits.data(),
out_splits.size() * sizeof(DeviceSplitCandidate),
cudaMemcpyDeviceToDevice));
auto rc = collective::Allgather(
ctx, linalg::MakeVec(all_candidates.data(), all_candidates.size(), ctx->Device()));
collective::SafeColl(rc);

// Reduce to get the best candidate from all workers.
dh::LaunchN(out_splits.size(), ctx->CUDACtx()->Stream(),
[world_size, all_candidates, out_splits] __device__(size_t i) {
out_splits[i] = all_candidates[i];
for (auto rank = 1; rank < world_size; rank++) {
out_splits[i] = out_splits[i] + all_candidates[rank * out_splits.size() + i];
}
});
bool is_passive_party = is_column_split_ && collective::IsEncrypted() && collective::GetRank() != 0;
bool is_active_party = !is_passive_party;
// Under the secure vertical setting, only the active party is able to evaluate the split
// based on the global histogram. The other parties will receive the final best split information,
// hence the computation below is not performed by the passive parties.
if (is_active_party) {
this->LaunchEvaluateSplits(max_active_features, d_inputs, shared_inputs,
evaluator, out_splits);
}

if (is_column_split_) {
if (!collective::IsEncrypted()) {
// With regular column-wise data split, we gather the split candidates from
// all the workers and find the global best candidates.
auto const world_size = collective::GetWorldSize();
dh::TemporaryArray<DeviceSplitCandidate> all_candidate_storage(out_splits.size() * world_size);
auto all_candidates = dh::ToSpan(all_candidate_storage);
auto current_rank =
all_candidates.subspan(collective::GetRank() * out_splits.size(), out_splits.size());
dh::safe_cuda(cudaMemcpyAsync(current_rank.data(), out_splits.data(),
out_splits.size() * sizeof(DeviceSplitCandidate),
cudaMemcpyDeviceToDevice));
auto rc = collective::Allgather(
ctx, linalg::MakeVec(all_candidates.data(), all_candidates.size(), ctx->Device()));
collective::SafeColl(rc);
// Reduce to get the best candidate from all workers.
dh::LaunchN(out_splits.size(), ctx->CUDACtx()->Stream(),
[world_size, all_candidates, out_splits] __device__(size_t i) {
out_splits[i] = all_candidates[i];
for (auto rank = 1; rank < world_size; rank++) {
out_splits[i] = out_splits[i] + all_candidates[rank * out_splits.size() + i];
}
});
} else {
// With encrypted column-wise data split, we distribute the best split candidates
// from Rank 0 to all other workers
auto rc = collective::Broadcast(
ctx, linalg::MakeVec(out_splits.data(), out_splits.size(), ctx->Device()), 0);
collective::SafeColl(rc);
}
}
auto d_sorted_idx = this->SortedIdx(d_inputs.size(), shared_inputs.feature_values.size());
auto d_entries = out_entries;
auto device_cats_accessor = this->DeviceCatStorage(nidx);
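A minimal standalone sketch of the two synchronization strategies above (not part of the commit; the candidate type, gains, and ranks are made up): under the regular column split every worker proposes a candidate and the reduction keeps the best one, while under the encrypted split only Rank 0 holds an evaluated candidate and the other ranks simply receive a copy of it.

#include <cstddef>
#include <iostream>
#include <vector>

// Stand-in for DeviceSplitCandidate; operator+ mimics the "reduce to the best
// candidate" use in the kernel above by keeping the candidate with larger gain.
struct ToySplitCandidate {
  float loss_chg{0.f};
  int feature{-1};
  ToySplitCandidate operator+(ToySplitCandidate const& other) const {
    return loss_chg >= other.loss_chg ? *this : other;
  }
};

int main() {
  // Regular column split: candidates gathered from three workers (Allgather),
  // then reduced to the global best.
  std::vector<ToySplitCandidate> all_candidates = {{0.3f, 1}, {0.9f, 4}, {0.1f, 7}};
  ToySplitCandidate best = all_candidates[0];
  for (std::size_t rank = 1; rank < all_candidates.size(); ++rank) {
    best = best + all_candidates[rank];
  }
  std::cout << "regular path picks feature " << best.feature << "\n";  // feature 4

  // Encrypted column split: only Rank 0 evaluated a candidate; the Broadcast
  // step amounts to every other rank overwriting its buffer with Rank 0's.
  ToySplitCandidate rank0_best{0.9f, 4};
  ToySplitCandidate local_buffer;   // passive party: never evaluated locally
  local_buffer = rank0_best;        // what Broadcast(..., root 0) achieves
  std::cout << "encrypted path receives feature " << local_buffer.feature << "\n";
  return 0;
}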
79 changes: 76 additions & 3 deletions src/tree/gpu_hist/histogram.cu
@@ -14,6 +14,13 @@
#include "row_partitioner.cuh"
#include "xgboost/base.h"

#include "../../common/device_helpers.cuh"
#if defined(XGBOOST_USE_FEDERATED)
#include "../../../plugin/federated/federated_hist.h" // for FederataedHistPolicy
#else
#include "../../common/error_msg.h" // for NoFederated
#endif

namespace xgboost::tree {
namespace {
struct Pair {
@@ -309,6 +316,21 @@ class DeviceHistogramBuilderImpl {
bool force_global_memory) {
this->kernel_ = std::make_unique<HistogramKernel<>>(ctx, feature_groups, force_global_memory);
this->force_global_memory_ = force_global_memory;

std::cout << "Reset DeviceHistogramBuilderImpl" << std::endl;

// Reset federated plugin
// At the start of every round, transmit the matrix to the plugin
#if defined(XGBOOST_USE_FEDERATED)
// Get encryption plugin
auto const &comm = collective::GlobalCommGroup()->Ctx(ctx, DeviceOrd::CPU());
auto const &fed = dynamic_cast<collective::FederatedComm const &>(comm);
auto plugin = fed.EncryptionPlugin();
// Reset plugin
//plugin->Reset();
#endif

}

void BuildHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
@@ -354,13 +376,64 @@ void DeviceHistogramBuilder::Reset(Context const* ctx, FeatureGroupsAccessor con
this->p_impl_->Reset(ctx, feature_groups, force_global_memory);
}

void DeviceHistogramBuilder::BuildHistogram(CUDAContext const* ctx,
void DeviceHistogramBuilder::BuildHistogram(Context const* ctx,
EllpackDeviceAccessor const& matrix,
FeatureGroupsAccessor const& feature_groups,
common::Span<GradientPair const> gpair,
common::Span<const std::uint32_t> ridx,
common::Span<GradientPairInt64> histogram,
GradientQuantiser rounding) {
this->p_impl_->BuildHistogram(ctx, matrix, feature_groups, gpair, ridx, histogram, rounding);
GradientQuantiser rounding, MetaInfo const& info) {

auto IsSecureVertical = !info.IsRowSplit() && collective::IsDistributed() && collective::IsEncrypted();
if (!IsSecureVertical) {
// Regular training, build histogram locally
this->p_impl_->BuildHistogram(ctx->CUDACtx(), matrix, feature_groups, gpair, ridx, histogram, rounding);
} else {
// Encrypted vertical, build histogram using federated plugin
auto const &comm = collective::GlobalCommGroup()->Ctx(ctx, DeviceOrd::CPU());
auto const &fed = dynamic_cast<collective::FederatedComm const &>(comm);
auto plugin = fed.EncryptionPlugin();
// Transmit matrix to plugin
//plugin->TransmitMatrix(matrix);
// Transmit row indices to plugin
//plugin->TransmitRowIndices(ridx);

// !!!Temporarily turn on regular histogram building for testing
// encrypted vertical
this->p_impl_->BuildHistogram(ctx->CUDACtx(), matrix, feature_groups, gpair, ridx, histogram, rounding);

// The remaining histogram sync process is simulated here;
// only the last stage would be needed under the plugin system.

// copy histogram data to host
std::vector<GradientPairInt64> host_histogram(histogram.size());
dh::CopyDeviceSpanToVector(&host_histogram, histogram);
// convert to a flat vector of int64 values
std::vector<std::int64_t> host_histogram_64(histogram.size() * 2);
for (std::size_t i = 0; i < host_histogram.size(); i++) {
host_histogram_64[i * 2] = host_histogram[i].GetQuantisedGrad();
host_histogram_64[i * 2 + 1] = host_histogram[i].GetQuantisedHess();
}
// aggregate the histograms across workers via Allreduce
auto rc = collective::Allreduce(ctx, &host_histogram_64, collective::Op::kSum);
SafeColl(rc);
// convert back to GradientPairInt64
// keep the aggregated result only on Rank 0; clear the other ranks to simulate the plugin scenario
for (std::size_t i = 0; i < host_histogram.size(); i++) {
GradientPairInt64 hist_item(host_histogram_64[i * 2], host_histogram_64[i * 2 + 1]);
GradientPairInt64 hist_item_empty(0, 0);
if (collective::GetRank() != 0) {
host_histogram[i] = hist_item_empty;
} else {
host_histogram[i] = hist_item;
}
}
// copy the aggregated histogram back to GPU memory
// at this point, the histogram on Rank 0 contains the full information from all parties
dh::safe_cuda(cudaMemcpyAsync(histogram.data(), host_histogram.data(),
histogram.size() * sizeof(GradientPairInt64),
cudaMemcpyHostToDevice));

}
}
} // namespace xgboost::tree
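The host-side round trip above boils down to packing each quantised gradient pair into two int64 slots, summing the flat buffer across workers, and unpacking it again. A minimal standalone sketch of that packing scheme (not part of the commit; the pair type and values are made up):

#include <cstdint>
#include <iostream>
#include <vector>

struct ToyQuantisedPair {  // stand-in for GradientPairInt64
  std::int64_t grad{0};
  std::int64_t hess{0};
};

int main() {
  std::vector<ToyQuantisedPair> hist = {{3, 1}, {-2, 5}};
  // Pack: pair i occupies slots 2*i (gradient) and 2*i + 1 (hessian).
  std::vector<std::int64_t> flat(hist.size() * 2);
  for (std::size_t i = 0; i < hist.size(); ++i) {
    flat[2 * i] = hist[i].grad;
    flat[2 * i + 1] = hist[i].hess;
  }
  // ... an Allreduce(kSum) over `flat` would run here in the real code ...
  // Unpack the (summed) buffer back into pairs.
  for (std::size_t i = 0; i < hist.size(); ++i) {
    hist[i] = {flat[2 * i], flat[2 * i + 1]};
  }
  std::cout << hist[1].grad << " " << hist[1].hess << "\n";  // -2 5
  return 0;
}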
5 changes: 3 additions & 2 deletions src/tree/gpu_hist/histogram.cuh
@@ -178,11 +178,12 @@ class DeviceHistogramBuilder {

void Reset(Context const* ctx, FeatureGroupsAccessor const& feature_groups,
bool force_global_memory);
void BuildHistogram(CUDAContext const* ctx, EllpackDeviceAccessor const& matrix,
void BuildHistogram(Context const* ctx, EllpackDeviceAccessor const& matrix,
FeatureGroupsAccessor const& feature_groups,
common::Span<GradientPair const> gpair,
common::Span<const std::uint32_t> ridx,
common::Span<GradientPairInt64> histogram, GradientQuantiser rounding);
common::Span<GradientPairInt64> histogram, GradientQuantiser rounding,
MetaInfo const& info);
};
} // namespace xgboost::tree
#endif // HISTOGRAM_CUH_
4 changes: 2 additions & 2 deletions src/tree/updater_gpu_hist.cu
@@ -249,9 +249,9 @@ struct GPUHistMakerDevice {
void BuildHist(int nidx) {
auto d_node_hist = hist.GetNodeHistogram(nidx);
auto d_ridx = row_partitioner->GetRows(nidx);
this->histogram_.BuildHistogram(ctx_->CUDACtx(), page->GetDeviceAccessor(ctx_->Device()),
this->histogram_.BuildHistogram(ctx_, page->GetDeviceAccessor(ctx_->Device()),
feature_groups->DeviceAccessor(ctx_->Device()), gpair, d_ridx,
d_node_hist, *quantiser);
d_node_hist, *quantiser, info_);
}

// Attempt to do subtraction trick
