Skip to content

Commit

Permalink
Only rank 0 needs the histogram sync result
Browse files Browse the repository at this point in the history
  • Loading branch information
ZiyueXu77 authored Aug 1, 2024
1 parent 7480ed3 commit 4587b2e
Showing 1 changed file with 26 additions and 23 deletions.
49 changes: 26 additions & 23 deletions src/tree/gpu_hist/histogram.cu
Original file line number Diff line number Diff line change
Expand Up @@ -451,36 +451,39 @@ void DeviceHistogramBuilder::BuildHistogram(Context const* ctx,
&recv_segments, &hist_entries));

// Call the plugin here to get the resulting histogram. Histograms from all workers are
// gathered to the label owner.
// gathered to the label owner.
common::Span<double> hist_aggr =
plugin->SyncEncryptedHistVert(
common::RestoreType<std::uint8_t>(hist_entries.HostSpan()));

// Post process the AllGathered data
auto world_size = collective::GetWorldSize();
std::vector<GradientPairInt64> host_histogram(histogram.size());
for (auto i = 0; i < histogram.size(); i++) {
double grad = 0.0;
double hess = 0.0;
for (auto rank = 0; rank < world_size; ++rank) {
auto idx = rank * histogram.size() + i;
grad += hist_aggr[idx * 2];
hess += hist_aggr[idx * 2 + 1];
// This is only needed by Rank 0
if (collective::GetRank() == 0) {
auto world_size = collective::GetWorldSize();
std::vector<GradientPairInt64> host_histogram(histogram.size());
for (auto i = 0; i < histogram.size(); i++) {
double grad = 0.0;
double hess = 0.0;
for (auto rank = 0; rank < world_size; ++rank) {
auto idx = rank * histogram.size() + i;
grad += hist_aggr[idx * 2];
hess += hist_aggr[idx * 2 + 1];
}
GradientPairPrecise hist_item(grad, hess);
GradientPairPrecise hist_item_empty(0.0, 0.0);
if (collective::GetRank() != 0) {
host_histogram[i] = rounding.ToFixedPoint(hist_item_empty);
} else {
host_histogram[i] = rounding.ToFixedPoint(hist_item);
}
}
GradientPairPrecise hist_item(grad, hess);
GradientPairPrecise hist_item_empty(0.0, 0.0);
if (collective::GetRank() != 0) {
host_histogram[i] = rounding.ToFixedPoint(hist_item_empty);
} else {
host_histogram[i] = rounding.ToFixedPoint(hist_item);
}
}

// copy the aggregated histogram back to GPU memory
// at this point, the histogram contains full information from all parties
dh::safe_cuda(cudaMemcpyAsync(histogram.data(), host_histogram.data(),
histogram.size() * sizeof(GradientPairInt64),
cudaMemcpyHostToDevice));
// copy the aggregated histogram back to GPU memory
// at this point, the histogram contains full information from all parties
dh::safe_cuda(cudaMemcpyAsync(histogram.data(), host_histogram.data(),
histogram.size() * sizeof(GradientPairInt64),
cudaMemcpyHostToDevice));
}
#else
LOG(FATAL) << error::NoFederated();
#endif
Expand Down

0 comments on commit 4587b2e

Please sign in to comment.