Require context in aggregators. (dmlc#10075)
trivialfis authored Feb 27, 2024
1 parent 761845f commit 5ac2332
Showing 23 changed files with 190 additions and 144 deletions.
2 changes: 1 addition & 1 deletion .clang-format
@@ -17,7 +17,7 @@ AllowShortEnumsOnASingleLine: true
AllowShortBlocksOnASingleLine: Never
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: All
AllowShortLambdasOnASingleLine: All
AllowShortLambdasOnASingleLine: Inline
AllowShortIfStatementsOnASingleLine: WithoutElse
AllowShortLoopsOnASingleLine: true
AlwaysBreakAfterDefinitionReturnType: None
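A note on the .clang-format change: with AllowShortLambdasOnASingleLine set to All, any short lambda could be merged onto a single line; Inline merges a short lambda only when it is a function argument. A hedged illustration based on clang-format's documented SLS_Inline semantics (not code from this commit):

// Still merged onto one line: the lambda is an argument of std::sort.
std::sort(v.begin(), v.end(), [](int a, int b) { return a < b; });
// No longer merged: the lambda is bound to a variable.
auto cmp = [](int a, int b) {
  return a < b;
};
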
12 changes: 10 additions & 2 deletions include/xgboost/collective/result.h
@@ -1,8 +1,10 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#pragma once

#include <xgboost/logging.h>

#include <memory> // for unique_ptr
#include <sstream> // for stringstream
#include <stack> // for stack
@@ -160,10 +162,16 @@ struct Result {

// We don't have a monad, a simple helper would do.
template <typename Fn>
Result operator<<(Result&& r, Fn&& fn) {
[[nodiscard]] Result operator<<(Result&& r, Fn&& fn) {
if (!r.OK()) {
return std::forward<Result>(r);
}
return fn();
}

inline void SafeColl(Result const& rc) {
if (!rc.OK()) {
LOG(FATAL) << rc.Report();
}
}
} // namespace xgboost::collective
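For context on how the two additions compose: operator<< chains fallible steps and short-circuits on the first failure, while SafeColl turns an unchecked failure into a fatal log. A minimal usage sketch, with hypothetical step functions that are not part of this commit:

#include "xgboost/collective/result.h"

namespace xgboost::collective {
// Hypothetical steps, for illustration only.
Result Connect() { return Success(); }
Result Handshake() { return Fail("handshake refused"); }

void Demo() {
  // The chain short-circuits at Handshake(), so the final lambda never runs.
  auto rc = Connect() << [] { return Handshake(); } << [] { return Success(); };
  SafeColl(rc);  // rc is not OK here, so this logs rc.Report() via LOG(FATAL)
}
}  // namespace xgboost::collective
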
46 changes: 25 additions & 21 deletions src/collective/aggregator.h
@@ -1,22 +1,21 @@
/**
* Copyright 2023 by XGBoost contributors
* Copyright 2023-2024, XGBoost contributors
*
 * Higher level functions built on top of the Communicator API, taking care of behavioral differences
* between row-split vs column-split distributed training, and horizontal vs vertical federated
* learning.
*/
#pragma once
#include <xgboost/data.h>

#include <limits>
#include <string>
#include <utility>
#include <vector>

#include "communicator-inl.h"
#include "xgboost/collective/result.h" // for Result
#include "xgboost/data.h" // for MetaINfo

namespace xgboost {
namespace collective {
namespace xgboost::collective {

/**
* @brief Apply the given function where the labels are.
@@ -31,15 +30,16 @@ namespace collective {
* @param size The size of the buffer.
* @param function The function used to calculate the results.
*/
template <typename Function>
void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&& function) {
template <typename FN>
void ApplyWithLabels(Context const*, MetaInfo const& info, void* buffer, std::size_t size,
FN&& function) {
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the calculation is done there and result
// broadcast to other workers.
std::string message;
if (collective::GetRank() == 0) {
try {
std::forward<Function>(function)();
std::forward<FN>(function)();
} catch (dmlc::Error& e) {
message = e.what();
}
@@ -52,7 +52,7 @@ void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&&
LOG(FATAL) << &message[0];
}
} else {
std::forward<Function>(function)();
std::forward<FN>(function)();
}
}

@@ -70,7 +70,8 @@ void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&&
* @param function The function used to calculate the results.
*/
template <typename T, typename Function>
void ApplyWithLabels(MetaInfo const& info, HostDeviceVector<T>* result, Function&& function) {
void ApplyWithLabels(Context const*, MetaInfo const& info, HostDeviceVector<T>* result,
Function&& function) {
if (info.IsVerticalFederated()) {
// We assume labels are only available on worker 0, so the calculation is done there and result
// broadcast to other workers.
@@ -114,7 +115,9 @@ void ApplyWithLabels(MetaInfo const& info, HostDeviceVector<T>* result, Function
* @return The global max of the input.
*/
template <typename T>
T GlobalMax(MetaInfo const& info, T value) {
std::enable_if_t<std::is_trivially_copy_assignable_v<T>, T> GlobalMax(Context const*,
MetaInfo const& info,
T value) {
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kMax>(&value, 1);
}
@@ -132,16 +135,18 @@ T GlobalMax(MetaInfo const& info, T value) {
* @param values Pointer to the inputs to sum.
* @param size Number of values to sum.
*/
template <typename T>
void GlobalSum(MetaInfo const& info, T* values, size_t size) {
template <typename T, std::int32_t kDim>
[[nodiscard]] Result GlobalSum(Context const*, MetaInfo const& info,
linalg::TensorView<T, kDim> values) {
if (info.IsRowSplit()) {
collective::Allreduce<collective::Operation::kSum>(values, size);
collective::Allreduce<collective::Operation::kSum>(values.Values().data(), values.Size());
}
return Success();
}

template <typename Container>
void GlobalSum(MetaInfo const& info, Container* values) {
GlobalSum(info, values->data(), values->size());
[[nodiscard]] Result GlobalSum(Context const* ctx, MetaInfo const& info, Container* values) {
return GlobalSum(ctx, info, linalg::MakeVec(values->data(), values->size()));
}

/**
@@ -157,16 +162,15 @@ void GlobalSum(MetaInfo const& info, Container* values) {
* @return The global ratio of the two inputs.
*/
template <typename T>
T GlobalRatio(MetaInfo const& info, T dividend, T divisor) {
T GlobalRatio(Context const* ctx, MetaInfo const& info, T dividend, T divisor) {
std::array<T, 2> results{dividend, divisor};
GlobalSum(info, &results);
auto rc = GlobalSum(ctx, info, linalg::MakeVec(results.data(), results.size()));
collective::SafeColl(rc);
std::tie(dividend, divisor) = std::tuple_cat(results);
if (divisor <= 0) {
return std::numeric_limits<T>::quiet_NaN();
} else {
return dividend / divisor;
}
}

} // namespace collective
} // namespace xgboost
} // namespace xgboost::collective
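Taken together, the aggregator API now threads a Context through every call, and GlobalSum reports errors via a Result over a linalg::TensorView instead of mutating a raw pointer/size pair. A sketch of the new call shape (function and variable names are illustrative; aggregator.h is an internal header under src/collective):

#include <vector>

#include "../collective/aggregator.h"   // for GlobalSum (internal header)
#include "xgboost/collective/result.h"  // for SafeColl
#include "xgboost/context.h"            // for Context
#include "xgboost/data.h"               // for MetaInfo
#include "xgboost/linalg.h"             // for MakeVec

namespace xgboost {
// Illustrative only: sum a per-worker buffer across all workers.
void SumAcrossWorkers(Context const* ctx, MetaInfo const& info) {
  std::vector<double> sums{1.0, 2.0, 3.0};
  // Old shape: collective::GlobalSum(info, sums.data(), sums.size());
  auto rc = collective::GlobalSum(ctx, info, linalg::MakeVec(sums.data(), sums.size()));
  collective::SafeColl(rc);  // surface any collective failure immediately
}
}  // namespace xgboost
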
34 changes: 21 additions & 13 deletions src/common/quantile.cc
@@ -1,5 +1,5 @@
/*!
* Copyright 2020-2022 by XGBoost Contributors
/**
* Copyright 2020-2024, XGBoost Contributors
*/
#include "quantile.h"

@@ -145,7 +145,7 @@ struct QuantileAllreduce {

template <typename WQSketch>
void SketchContainerImpl<WQSketch>::GatherSketchInfo(
Context const *, MetaInfo const &info,
Context const *ctx, MetaInfo const &info,
std::vector<typename WQSketch::SummaryContainer> const &reduced,
std::vector<size_t> *p_worker_segments, std::vector<bst_row_t> *p_sketches_scan,
std::vector<typename WQSketch::Entry> *p_global_sketches) {
@@ -171,7 +171,9 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
std::partial_sum(sketch_size.cbegin(), sketch_size.cend(), sketches_scan.begin() + beg_scan + 1);

// Gather all column pointers
collective::GlobalSum(info, sketches_scan.data(), sketches_scan.size());
auto rc =
collective::GlobalSum(ctx, info, linalg::MakeVec(sketches_scan.data(), sketches_scan.size()));
collective::SafeColl(rc);
for (int32_t i = 0; i < world; ++i) {
size_t back = (i + 1) * (n_columns + 1) - 1;
auto n_entries = sketches_scan.at(back);
@@ -199,14 +201,15 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(

static_assert(sizeof(typename WQSketch::Entry) / 4 == sizeof(float),
"Unexpected size of sketch entry.");
collective::GlobalSum(
info,
reinterpret_cast<float *>(global_sketches.data()),
global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float));
rc = collective::GlobalSum(
ctx, info,
linalg::MakeVec(reinterpret_cast<float *>(global_sketches.data()),
global_sketches.size() * sizeof(typename WQSketch::Entry) / sizeof(float)));
collective::SafeColl(rc);
}

template <typename WQSketch>
void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const*, MetaInfo const& info) {
void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const* ctx, MetaInfo const& info) {
auto world_size = collective::GetWorldSize();
auto rank = collective::GetRank();
if (world_size == 1 || info.IsColumnSplit()) {
@@ -226,7 +229,8 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const*, MetaInfo
std::vector<size_t> global_feat_ptrs(feature_ptr.size() * world_size, 0);
size_t feat_begin = rank * feature_ptr.size(); // pointer to current worker
std::copy(feature_ptr.begin(), feature_ptr.end(), global_feat_ptrs.begin() + feat_begin);
collective::GlobalSum(info, global_feat_ptrs.data(), global_feat_ptrs.size());
auto rc = collective::GlobalSum(
ctx, info, linalg::MakeVec(global_feat_ptrs.data(), global_feat_ptrs.size()));

// move all categories into a flatten vector to prepare for allreduce
size_t total = feature_ptr.back();
@@ -239,7 +243,8 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const*, MetaInfo
// indptr for indexing workers
std::vector<size_t> global_worker_ptr(world_size + 1, 0);
global_worker_ptr[rank + 1] = total; // shift 1 to right for constructing the indptr
collective::GlobalSum(info, global_worker_ptr.data(), global_worker_ptr.size());
rc = collective::GlobalSum(ctx, info,
linalg::MakeVec(global_worker_ptr.data(), global_worker_ptr.size()));
std::partial_sum(global_worker_ptr.cbegin(), global_worker_ptr.cend(), global_worker_ptr.begin());
// total number of categories in all workers with all features
auto gtotal = global_worker_ptr.back();
@@ -251,7 +256,8 @@ void SketchContainerImpl<WQSketch>::AllreduceCategories(Context const*, MetaInfo
CHECK_EQ(rank_size, total);
std::copy(flatten.cbegin(), flatten.cend(), global_categories.begin() + rank_begin);
// gather values from all workers.
collective::GlobalSum(info, global_categories.data(), global_categories.size());
rc = collective::GlobalSum(ctx, info,
linalg::MakeVec(global_categories.data(), global_categories.size()));
QuantileAllreduce<float> allreduce_result{global_categories, global_worker_ptr, global_feat_ptrs,
categories_.size()};
ParallelFor(categories_.size(), n_threads_, [&](auto fidx) {
@@ -293,7 +299,9 @@ void SketchContainerImpl<WQSketch>::AllReduce(

// Prune the intermediate num cuts for synchronization.
std::vector<bst_row_t> global_column_size(columns_size_);
collective::GlobalSum(info, &global_column_size);
auto rc = collective::GlobalSum(
ctx, info, linalg::MakeVec(global_column_size.data(), global_column_size.size()));
collective::SafeColl(rc);

ParallelFor(sketches_.size(), n_threads_, [&](size_t i) {
int32_t intermediate_num_cuts = static_cast<int32_t>(
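One idiom in this file is worth spelling out, since it appears in each hunk above: an allgather is emulated with an allreduce-sum. Every worker copies its slice into a zero-initialized global buffer at its own offset, and the element-wise sum across workers reassembles all slices. A toy sketch of the idiom, with offsets and sizes assumed precomputed (names are illustrative; includes as in the aggregator sketch above, plus <algorithm> and <cstddef>):

void GatherSlices(Context const* ctx, MetaInfo const& info, std::vector<float> const& local,
                  std::size_t my_offset, std::size_t total_size) {
  std::vector<float> global(total_size, 0.0f);  // zeros everywhere but our slice
  std::copy(local.cbegin(), local.cend(), global.begin() + my_offset);
  // Summing the zero-padded buffers element-wise concatenates the slices.
  auto rc = collective::GlobalSum(ctx, info, linalg::MakeVec(global.data(), global.size()));
  collective::SafeColl(rc);
}
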
6 changes: 3 additions & 3 deletions src/learner.cc
@@ -1,5 +1,5 @@
/**
* Copyright 2014-2023 by XGBoost Contributors
* Copyright 2014-2024, XGBoost Contributors
* \file learner.cc
* \brief Implementation of learning algorithm.
* \author Tianqi Chen
@@ -846,7 +846,7 @@ class LearnerConfiguration : public Learner {

void InitEstimation(MetaInfo const& info, linalg::Tensor<float, 1>* base_score) {
base_score->Reshape(1);
collective::ApplyWithLabels(info, base_score->Data(),
collective::ApplyWithLabels(this->Ctx(), info, base_score->Data(),
[&] { UsePtr(obj_)->InitEstimation(info, base_score); });
}
};
@@ -1472,7 +1472,7 @@ class LearnerImpl : public LearnerIO {
void GetGradient(HostDeviceVector<bst_float> const& preds, MetaInfo const& info,
std::int32_t iter, linalg::Matrix<GradientPair>* out_gpair) {
out_gpair->Reshape(info.num_row_, this->learner_model_param_.OutputLength());
collective::ApplyWithLabels(info, out_gpair->Data(),
collective::ApplyWithLabels(&ctx_, info, out_gpair->Data(),
[&] { obj_->GetGradient(preds, info, iter, out_gpair); });
}

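Both learner.cc call sites follow the same pattern: under vertical federated learning, ApplyWithLabels runs the label-dependent lambda on rank 0 (the only worker holding labels) and broadcasts the filled buffer, while in all other modes every worker runs it locally. A sketch of the call-site shape (FillFromLabels is hypothetical, not part of this commit):

#include "xgboost/host_device_vector.h"  // for HostDeviceVector

namespace xgboost {
// Hypothetical label-dependent computation, for illustration only.
void FillFromLabels(MetaInfo const& info, HostDeviceVector<float>* out) {
  out->Resize(info.num_row_, 0.0f);  // e.g. derive per-row values from the labels
}

void Example(Context const* ctx, MetaInfo const& info) {
  HostDeviceVector<float> out;
  collective::ApplyWithLabels(ctx, info, &out, [&] { FillFromLabels(info, &out); });
}
}  // namespace xgboost
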
10 changes: 6 additions & 4 deletions src/metric/auc.cc
@@ -1,5 +1,5 @@
/**
* Copyright 2021-2023 by XGBoost Contributors
* Copyright 2021-2024, XGBoost Contributors
*/
#include "auc.h"

@@ -112,7 +112,9 @@ double MultiClassOVR(Context const *ctx, common::Span<float const> predts, MetaI

// we have 2 averages going in here, first is among workers, second is among
// classes. allreduce sums up fp/tp auc for each class.
collective::GlobalSum(info, &results.Values());
auto rc = collective::GlobalSum(ctx, info, results);
collective::SafeColl(rc);

double auc_sum{0};
double tp_sum{0};
for (size_t c = 0; c < n_classes; ++c) {
@@ -286,7 +288,7 @@ class EvalAUC : public MetricNoCache {
InvalidGroupAUC();
}

auc = collective::GlobalRatio(info, auc, static_cast<double>(valid_groups));
auc = collective::GlobalRatio(ctx_, info, auc, static_cast<double>(valid_groups));
if (!std::isnan(auc)) {
CHECK_LE(auc, 1) << "Total AUC across groups: " << auc * valid_groups
<< ", valid groups: " << valid_groups;
@@ -307,7 +309,7 @@ class EvalAUC : public MetricNoCache {
std::tie(fp, tp, auc) =
static_cast<Curve *>(this)->EvalBinary(preds, info);
}
auc = collective::GlobalRatio(info, auc, fp * tp);
auc = collective::GlobalRatio(ctx_, info, auc, fp * tp);
if (!std::isnan(auc)) {
CHECK_LE(auc, 1.0);
}
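As a concrete reading of the GlobalRatio calls above: if two workers hold local (auc, weight) pairs (3.0, 4.0) and (1.0, 1.0), both components are summed across workers and the metric becomes (3.0 + 1.0) / (4.0 + 1.0) = 0.8; if the summed divisor is zero or negative, GlobalRatio returns quiet NaN, which the callers guard with std::isnan.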