Skip to content

Commit

Permalink
fix code linting and test scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 committed Aug 1, 2024
1 parent 26aaded commit aa5b51b
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 26 deletions.
6 changes: 4 additions & 2 deletions src/tree/gpu_hist/evaluate_splits.cu
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,8 @@ void GPUHistEvaluator::EvaluateSplits(Context const *ctx, const std::vector<bst_
dh::TemporaryArray<DeviceSplitCandidate> splits_out_storage(d_inputs.size());
auto out_splits = dh::ToSpan(splits_out_storage);

bool is_passive_party = is_column_split_ && collective::IsEncrypted() && collective::GetRank() != 0;
bool is_passive_party = is_column_split_ && collective::IsEncrypted()
&& collective::GetRank() != 0;
bool is_active_party = !is_passive_party;
// Under secure vertical setting, only the active party is able to evaluate the split
// based on global histogram. Other parties will receive the final best split information
Expand All @@ -421,7 +422,8 @@ void GPUHistEvaluator::EvaluateSplits(Context const *ctx, const std::vector<bst_
// With regular column-wise data split, we gather the split candidates from
// all the workers and find the global best candidates.
auto const world_size = collective::GetWorldSize();
dh::TemporaryArray<DeviceSplitCandidate> all_candidate_storage(out_splits.size() * world_size);
dh::TemporaryArray<DeviceSplitCandidate> all_candidate_storage(
out_splits.size() * world_size);
auto all_candidates = dh::ToSpan(all_candidate_storage);
auto current_rank =
all_candidates.subspan(collective::GetRank() * out_splits.size(), out_splits.size());
Expand Down
21 changes: 13 additions & 8 deletions src/tree/gpu_hist/histogram.cu
Original file line number Diff line number Diff line change
Expand Up @@ -388,23 +388,25 @@ void DeviceHistogramBuilder::BuildHistogram(Context const* ctx,
common::Span<const std::uint32_t> ridx,
common::Span<GradientPairInt64> histogram,
GradientQuantiser rounding, MetaInfo const& info) {

auto IsSecureVertical = !info.IsRowSplit() && collective::IsDistributed() && collective::IsEncrypted();
auto IsSecureVertical = !info.IsRowSplit() && collective::IsDistributed()
&& collective::IsEncrypted();
if (!IsSecureVertical) {
// Regular training, build histogram locally
this->p_impl_->BuildHistogram(ctx->CUDACtx(), matrix, feature_groups, gpair, ridx, histogram, rounding);
this->p_impl_->BuildHistogram(ctx->CUDACtx(), matrix, feature_groups,
gpair, ridx, histogram, rounding);
} else {
// Encrypted vertical, build histogram using federated plugin
auto const &comm = collective::GlobalCommGroup()->Ctx(ctx, DeviceOrd::CPU());
auto const &fed = dynamic_cast<collective::FederatedComm const &>(comm);
auto plugin = fed.EncryptionPlugin();

// Transmit matrix to plugin
if(!is_aggr_context_initialized_){
if (!is_aggr_context_initialized_) {
// Get cutptrs
std::vector<uint32_t> h_cuts_ptr(matrix.feature_segments.size());
dh::CopyDeviceSpanToVector(&h_cuts_ptr, matrix.feature_segments);
common::Span<std::uint32_t const> cutptrs = common::Span<std::uint32_t const>(h_cuts_ptr.data(), h_cuts_ptr.size());
common::Span<std::uint32_t const> cutptrs =
common::Span<std::uint32_t const>(h_cuts_ptr.data(), h_cuts_ptr.size());

// Get bin_idx matrix
auto kRows = matrix.n_rows;
Expand All @@ -414,7 +416,8 @@ void DeviceHistogramBuilder::BuildHistogram(Context const* ctx,
thrust::device_vector<bst_float> matrix_d(kRows * kCols);
dh::LaunchN(kRows * kCols, ReadMatrixFunction(matrix, kCols, matrix_d.data().get()));
thrust::copy(matrix_d.begin(), matrix_d.end(), h_bin_idx.begin());
common::Span<std::int32_t const> bin_idx = common::Span<std::int32_t const>(h_bin_idx.data(), h_bin_idx.size());
common::Span<std::int32_t const> bin_idx =
common::Span<std::int32_t const>(h_bin_idx.data(), h_bin_idx.size());

// Initialize plugin context
plugin->Reset(h_cuts_ptr, h_bin_idx);
Expand Down Expand Up @@ -443,12 +446,14 @@ void DeviceHistogramBuilder::BuildHistogram(Context const* ctx,
HostDeviceVector<std::int8_t> hist_entries;
std::vector<std::int64_t> recv_segments;
collective::SafeColl(
collective::AllgatherV(ctx, linalg::MakeVec(hist_data), &recv_segments, &hist_entries));
collective::AllgatherV(ctx, linalg::MakeVec(hist_data),
&recv_segments, &hist_entries));

// Call the plugin here to get the resulting histogram. Histogram from all workers are
// gathered to the label owner.
common::Span<double> hist_aggr =
plugin->SyncEncryptedHistVert(common::RestoreType<std::uint8_t>(hist_entries.HostSpan()));
plugin->SyncEncryptedHistVert(
common::RestoreType<std::uint8_t>(hist_entries.HostSpan()));

// Post process the AllGathered data
auto world_size = collective::GetWorldSize();
Expand Down
28 changes: 14 additions & 14 deletions tests/cpp/tree/gpu_hist/test_histogram.cu
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
auto quantiser = GradientQuantiser(&ctx, gpair.DeviceSpan(), MetaInfo());
DeviceHistogramBuilder builder;
builder.Reset(&ctx, feature_groups.DeviceAccessor(ctx.Device()), force_global);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
builder.BuildHistogram(&ctx, page->GetDeviceAccessor(ctx.Device()),
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
d_histogram, quantiser);
d_histogram, quantiser, MetaInfo());

std::vector<GradientPairInt64> histogram_h(num_bins);
dh::safe_cuda(cudaMemcpy(histogram_h.data(), d_histogram.data(),
Expand All @@ -95,9 +95,9 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
auto quantiser = GradientQuantiser(&ctx, gpair.DeviceSpan(), MetaInfo());
DeviceHistogramBuilder builder;
builder.Reset(&ctx, feature_groups.DeviceAccessor(ctx.Device()), force_global);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
builder.BuildHistogram(&ctx, page->GetDeviceAccessor(ctx.Device()),
feature_groups.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
d_new_histogram, quantiser);
d_new_histogram, quantiser, MetaInfo());

std::vector<GradientPairInt64> new_histogram_h(num_bins);
dh::safe_cuda(cudaMemcpy(new_histogram_h.data(), d_new_histogram.data(),
Expand All @@ -119,9 +119,9 @@ void TestDeterministicHistogram(bool is_dense, int shm_size, bool force_global)
dh::device_vector<GradientPairInt64> baseline(num_bins);
DeviceHistogramBuilder builder;
builder.Reset(&ctx, single_group.DeviceAccessor(ctx.Device()), force_global);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
builder.BuildHistogram(&ctx, page->GetDeviceAccessor(ctx.Device()),
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
dh::ToSpan(baseline), quantiser);
dh::ToSpan(baseline), quantiser, MetaInfo());

std::vector<GradientPairInt64> baseline_h(num_bins);
dh::safe_cuda(cudaMemcpy(baseline_h.data(), baseline.data().get(),
Expand Down Expand Up @@ -185,9 +185,9 @@ void TestGPUHistogramCategorical(size_t num_categories) {
FeatureGroups single_group(page->Cuts());
DeviceHistogramBuilder builder;
builder.Reset(&ctx, single_group.DeviceAccessor(ctx.Device()), false);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
builder.BuildHistogram(&ctx, page->GetDeviceAccessor(ctx.Device()),
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
dh::ToSpan(cat_hist), quantiser);
dh::ToSpan(cat_hist), quantiser, MetaInfo());
}

/**
Expand All @@ -201,9 +201,9 @@ void TestGPUHistogramCategorical(size_t num_categories) {
FeatureGroups single_group(page->Cuts());
DeviceHistogramBuilder builder;
builder.Reset(&ctx, single_group.DeviceAccessor(ctx.Device()), false);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
builder.BuildHistogram(&ctx, page->GetDeviceAccessor(ctx.Device()),
single_group.DeviceAccessor(ctx.Device()), gpair.DeviceSpan(), ridx,
dh::ToSpan(encode_hist), quantiser);
dh::ToSpan(encode_hist), quantiser, MetaInfo());
}

std::vector<GradientPairInt64> h_cat_hist(cat_hist.size());
Expand Down Expand Up @@ -350,9 +350,9 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<f
auto d_histogram = dh::ToSpan(multi_hist);
DeviceHistogramBuilder builder;
builder.Reset(&ctx, fg->DeviceAccessor(ctx.Device()), force_global);
builder.BuildHistogram(ctx.CUDACtx(), impl->GetDeviceAccessor(ctx.Device()),
builder.BuildHistogram(&ctx, impl->GetDeviceAccessor(ctx.Device()),
fg->DeviceAccessor(ctx.Device()), gpair.ConstDeviceSpan(), ridx,
d_histogram, quantiser);
d_histogram, quantiser, MetaInfo());
++k;
}
ASSERT_EQ(k, n_batches);
Expand All @@ -373,9 +373,9 @@ class HistogramExternalMemoryTest : public ::testing::TestWithParam<std::tuple<f
auto d_histogram = dh::ToSpan(single_hist);
DeviceHistogramBuilder builder;
builder.Reset(&ctx, fg->DeviceAccessor(ctx.Device()), force_global);
builder.BuildHistogram(ctx.CUDACtx(), page.GetDeviceAccessor(ctx.Device()),
builder.BuildHistogram(&ctx, page.GetDeviceAccessor(ctx.Device()),
fg->DeviceAccessor(ctx.Device()), gpair.ConstDeviceSpan(), ridx,
d_histogram, quantiser);
d_histogram, quantiser, MetaInfo());
}

std::vector<GradientPairInt64> h_single(single_hist.size());
Expand Down
4 changes: 2 additions & 2 deletions tests/cpp/tree/test_gpu_hist.cu
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ void TestBuildHist(bool use_shared_memory_histograms) {
DeviceHistogramBuilder builder;
builder.Reset(&ctx, maker.feature_groups->DeviceAccessor(ctx.Device()),
!use_shared_memory_histograms);
builder.BuildHistogram(ctx.CUDACtx(), page->GetDeviceAccessor(ctx.Device()),
builder.BuildHistogram(&ctx, page->GetDeviceAccessor(ctx.Device()),
maker.feature_groups->DeviceAccessor(ctx.Device()), gpair.DeviceSpan(),
maker.row_partitioner->GetRows(0), maker.hist.GetNodeHistogram(0),
*maker.quantiser);
*maker.quantiser, MetaInfo());

DeviceHistogramStorage<>& d_hist = maker.hist;

Expand Down

0 comments on commit aa5b51b

Please sign in to comment.