From 3005f213f9deee34d40702e7cc7c1cd5f72b814b Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 30 May 2023 08:49:09 +0800 Subject: [PATCH] Cleanup GPU ranking metric. --- src/common/device_helpers.cuh | 170 ---------------------------------- src/metric/metric.cc | 24 +---- src/metric/metric_common.h | 56 ++--------- src/metric/rank_metric.cc | 84 +---------------- src/metric/rank_metric.cu | 101 -------------------- 5 files changed, 12 insertions(+), 423 deletions(-) diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 4aadfb0c083b..db38b2222e4c 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -825,176 +825,6 @@ XGBOOST_DEVICE auto tcrend(xgboost::common::Span const &span) { // NOLINT return tcrbegin(span) + span.size(); } -// This type sorts an array which is divided into multiple groups. The sorting is influenced -// by the function object 'Comparator' -template -class SegmentSorter { - private: - // Items sorted within the group - caching_device_vector ditems_; - - // Original position of the items before they are sorted descending within their groups - caching_device_vector doriginal_pos_; - - // Segments within the original list that delineates the different groups - caching_device_vector group_segments_; - - // Need this on the device as it is used in the kernels - caching_device_vector dgroups_; // Group information on device - - // Where did the item that was originally present at position 'x' move to after they are sorted - caching_device_vector dindexable_sorted_pos_; - - // Initialize everything but the segments - void Init(uint32_t num_elems) { - ditems_.resize(num_elems); - - doriginal_pos_.resize(num_elems); - thrust::sequence(doriginal_pos_.begin(), doriginal_pos_.end()); - } - - // Initialize all with group info - void Init(const std::vector &groups) { - uint32_t num_elems = groups.back(); - this->Init(num_elems); - this->CreateGroupSegments(groups); - } - - public: - // This needs to be public due to device lambda - void CreateGroupSegments(const std::vector &groups) { - uint32_t num_elems = groups.back(); - group_segments_.resize(num_elems, 0); - - dgroups_ = groups; - - if (GetNumGroups() == 1) return; // There are no segments; hence, no need to compute them - - // Define the segments by assigning a group ID to each element - const uint32_t *dgroups = dgroups_.data().get(); - uint32_t ngroups = dgroups_.size(); - auto ComputeGroupIDLambda = [=] __device__(uint32_t idx) { - return thrust::upper_bound(thrust::seq, dgroups, dgroups + ngroups, idx) - - dgroups - 1; - }; // NOLINT - - thrust::transform(thrust::make_counting_iterator(static_cast(0)), - thrust::make_counting_iterator(num_elems), - group_segments_.begin(), - ComputeGroupIDLambda); - } - - // Accessors that returns device pointer - inline uint32_t GetNumItems() const { return ditems_.size(); } - inline const xgboost::common::Span GetItemsSpan() const { - return { ditems_.data().get(), ditems_.size() }; - } - - inline const xgboost::common::Span GetOriginalPositionsSpan() const { - return { doriginal_pos_.data().get(), doriginal_pos_.size() }; - } - - inline const xgboost::common::Span GetGroupSegmentsSpan() const { - return { group_segments_.data().get(), group_segments_.size() }; - } - - inline uint32_t GetNumGroups() const { return dgroups_.size() - 1; } - inline const xgboost::common::Span GetGroupsSpan() const { - return { dgroups_.data().get(), dgroups_.size() }; - } - - inline const xgboost::common::Span GetIndexableSortedPositionsSpan() const { - return { dindexable_sorted_pos_.data().get(), dindexable_sorted_pos_.size() }; - } - - // Sort an array that is divided into multiple groups. The array is sorted within each group. - // This version provides the group information that is on the host. - // The array is sorted based on an adaptable binary predicate. By default a stateless predicate - // is used. - template > - void SortItems(const T *ditems, uint32_t item_size, const std::vector &groups, - const Comparator &comp = Comparator()) { - this->Init(groups); - this->SortItems(ditems, item_size, this->GetGroupSegmentsSpan(), comp); - } - - // Sort an array that is divided into multiple groups. The array is sorted within each group. - // This version provides the group information that is on the device. - // The array is sorted based on an adaptable binary predicate. By default a stateless predicate - // is used. - template > - void SortItems(const T *ditems, uint32_t item_size, - const xgboost::common::Span &group_segments, - const Comparator &comp = Comparator()) { - this->Init(item_size); - - // Sort the items that are grouped. We would like to avoid using predicates to perform the sort, - // as thrust resorts to using a merge sort as opposed to a much much faster radix sort - // when comparators are used. Hence, the following algorithm is used. This is done so that - // we can grab the appropriate related values from the original list later, after the - // items are sorted. - // - // Here is the internal representation: - // dgroups_: [ 0, 3, 5, 8, 10 ] - // group_segments_: 0 0 0 | 1 1 | 2 2 2 | 3 3 - // doriginal_pos_: 0 1 2 | 3 4 | 5 6 7 | 8 9 - // ditems_: 1 0 1 | 2 1 | 1 3 3 | 4 4 (from original items) - // - // Sort the items first and make a note of the original positions in doriginal_pos_ - // based on the sort - // ditems_: 4 4 3 3 2 1 1 1 1 0 - // doriginal_pos_: 8 9 6 7 3 0 2 4 5 1 - // NOTE: This consumes space, but is much faster than some of the other approaches - sorting - // in kernel, sorting using predicates etc. - - ditems_.assign(thrust::device_ptr(ditems), - thrust::device_ptr(ditems) + item_size); - - // Allocator to be used by sort for managing space overhead while sorting - dh::XGBCachingDeviceAllocator alloc; - - thrust::stable_sort_by_key(thrust::cuda::par(alloc), - ditems_.begin(), ditems_.end(), - doriginal_pos_.begin(), comp); - - if (GetNumGroups() == 1) return; // The entire array is sorted, as it isn't segmented - - // Next, gather the segments based on the doriginal_pos_. This is to reflect the - // holisitic item sort order on the segments - // group_segments_c_: 3 3 2 2 1 0 0 1 2 0 - // doriginal_pos_: 8 9 6 7 3 0 2 4 5 1 (stays the same) - caching_device_vector group_segments_c(item_size); - thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(), - dh::tcbegin(group_segments), group_segments_c.begin()); - - // Now, sort the group segments so that you may bring the items within the group together, - // in the process also noting the relative changes to the doriginal_pos_ while that happens - // group_segments_c_: 0 0 0 1 1 2 2 2 3 3 - // doriginal_pos_: 0 2 1 3 4 6 7 5 8 9 - thrust::stable_sort_by_key(thrust::cuda::par(alloc), - group_segments_c.begin(), group_segments_c.end(), - doriginal_pos_.begin(), thrust::less()); - - // Finally, gather the original items based on doriginal_pos_ to sort the input and - // to store them in ditems_ - // doriginal_pos_: 0 2 1 3 4 6 7 5 8 9 (stays the same) - // ditems_: 1 1 0 2 1 3 3 1 4 4 (from unsorted items - ditems) - thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(), - thrust::device_ptr(ditems), ditems_.begin()); - } - - // Determine where an item that was originally present at position 'x' has been relocated to - // after a sort. Creation of such an index has to be explicitly requested after a sort - void CreateIndexableSortedPositions() { - dindexable_sorted_pos_.resize(GetNumItems()); - thrust::scatter(thrust::make_counting_iterator(static_cast(0)), - thrust::make_counting_iterator(GetNumItems()), // Rearrange indices... - // ...based on this map - dh::tcbegin(GetOriginalPositionsSpan()), - dindexable_sorted_pos_.begin()); // Write results into this - } -}; - // Atomic add function for gradients template XGBOOST_DEV_INLINE void AtomicAddGpair(OutputGradientT* dest, diff --git a/src/metric/metric.cc b/src/metric/metric.cc index ebb5798272d3..d7e2683ecc02 100644 --- a/src/metric/metric.cc +++ b/src/metric/metric.cc @@ -52,32 +52,13 @@ Metric::Create(const std::string& name, Context const* ctx) { metric->ctx_ = ctx; return metric; } - -GPUMetric* GPUMetric::CreateGPUMetric(const std::string& name, Context const* ctx) { - auto metric = CreateMetricImpl(name); - if (metric == nullptr) { - LOG(WARNING) << "Cannot find a GPU metric builder for metric " << name - << ". Resorting to the CPU builder"; - return nullptr; - } - - // Narrowing reference only for the compiler to allow assignment to a base class member. - // As such, using this narrowed reference to refer to derived members will be an illegal op. - // This is moot, as this type is stateless. - auto casted = static_cast(metric); - CHECK(casted); - casted->ctx_ = ctx; - return casted; -} } // namespace xgboost namespace dmlc { DMLC_REGISTRY_ENABLE(::xgboost::MetricReg); -DMLC_REGISTRY_ENABLE(::xgboost::MetricGPUReg); } -namespace xgboost { -namespace metric { +namespace xgboost::metric { // List of files that will be force linked in static links. DMLC_REGISTRY_LINK_TAG(auc); DMLC_REGISTRY_LINK_TAG(elementwise_metric); @@ -88,5 +69,4 @@ DMLC_REGISTRY_LINK_TAG(rank_metric); DMLC_REGISTRY_LINK_TAG(auc_gpu); DMLC_REGISTRY_LINK_TAG(rank_metric_gpu); #endif -} // namespace metric -} // namespace xgboost +} // namespace xgboost::metric diff --git a/src/metric/metric_common.h b/src/metric/metric_common.h index a6fad715849b..1b148ab0f47c 100644 --- a/src/metric/metric_common.h +++ b/src/metric/metric_common.h @@ -23,53 +23,14 @@ class MetricNoCache : public Metric { double Evaluate(HostDeviceVector const &predts, std::shared_ptr p_fmat) final { double result{0.0}; - auto const& info = p_fmat->Info(); - collective::ApplyWithLabels(info, &result, sizeof(double), [&] { - result = this->Eval(predts, info); - }); + auto const &info = p_fmat->Info(); + collective::ApplyWithLabels(info, &result, sizeof(double), + [&] { result = this->Eval(predts, info); }); return result; } }; -// This creates a GPU metric instance dynamically and adds it to the GPU metric registry, if not -// present already. This is created when there is a device ordinal present and if xgboost -// is compiled with CUDA support -struct GPUMetric : public MetricNoCache { - static GPUMetric *CreateGPUMetric(const std::string &name, Context const *tparam); -}; - -/*! - * \brief Internal registry entries for GPU Metric factory functions. - * The additional parameter const char* param gives the value after @, can be null. - * For example, metric map@3, then: param == "3". - */ -struct MetricGPUReg - : public dmlc::FunctionRegEntryBase > { -}; - -/*! - * \brief Macro to register metric computed on GPU. - * - * \code - * // example of registering a objective ndcg@k - * XGBOOST_REGISTER_GPU_METRIC(NDCG_GPU, "ndcg") - * .describe("NDCG metric computer on GPU.") - * .set_body([](const char* param) { - * int at_k = atoi(param); - * return new NDCG(at_k); - * }); - * \endcode - */ - -// Note: Metric names registered in the GPU registry should follow this convention: -// - GPU metric types should be registered with the same name as the non GPU metric types -#define XGBOOST_REGISTER_GPU_METRIC(UniqueId, Name) \ - ::xgboost::MetricGPUReg& __make_ ## MetricGPUReg ## _ ## UniqueId ## __ = \ - ::dmlc::Registry< ::xgboost::MetricGPUReg>::Get()->__REGISTER__(Name) - namespace metric { - // Ranking config to be used on device and host struct EvalRankConfig { public: @@ -81,8 +42,8 @@ struct EvalRankConfig { }; class PackedReduceResult { - double residue_sum_ { 0 }; - double weights_sum_ { 0 }; + double residue_sum_{0}; + double weights_sum_{0}; public: XGBOOST_DEVICE PackedReduceResult() {} // NOLINT @@ -91,16 +52,15 @@ class PackedReduceResult { XGBOOST_DEVICE PackedReduceResult operator+(PackedReduceResult const &other) const { - return PackedReduceResult{residue_sum_ + other.residue_sum_, - weights_sum_ + other.weights_sum_}; + return PackedReduceResult{residue_sum_ + other.residue_sum_, weights_sum_ + other.weights_sum_}; } PackedReduceResult &operator+=(PackedReduceResult const &other) { this->residue_sum_ += other.residue_sum_; this->weights_sum_ += other.weights_sum_; return *this; } - double Residue() const { return residue_sum_; } - double Weights() const { return weights_sum_; } + [[nodiscard]] double Residue() const { return residue_sum_; } + [[nodiscard]] double Weights() const { return weights_sum_; } }; } // namespace metric diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index 03110e457659..e8a27934ec29 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -1,25 +1,6 @@ /** * Copyright 2020-2023 by XGBoost contributors */ -// When device ordinal is present, we would want to build the metrics on the GPU. It is *not* -// possible for a valid device ordinal to be present for non GPU builds. However, it is possible -// for an invalid device ordinal to be specified in GPU builds - to train/predict and/or compute -// the metrics on CPU. To accommodate these scenarios, the following is done for the metrics -// accelerated on the GPU. -// - An internal GPU registry holds all the GPU metric types (defined in the .cu file) -// - An instance of the appropriate GPU metric type is created when a device ordinal is present -// - If the creation is successful, the metric computation is done on the device -// - else, it falls back on the CPU -// - The GPU metric types are *only* registered when xgboost is built for GPUs -// -// This is done for 2 reasons: -// - Clear separation of CPU and GPU logic -// - Sorting datasets containing large number of rows is (much) faster when parallel sort -// semantics is used on the CPU. The __gnu_parallel/concurrency primitives needed to perform -// this cannot be used when the translation unit is compiled using the 'nvcc' compiler (as the -// corresponding headers that brings in those function declaration can't be included with CUDA). -// This precludes the CPU and GPU logic to coexist inside a .cu file - #include "rank_metric.h" #include @@ -57,55 +38,8 @@ #include "xgboost/string_view.h" // for StringView namespace { - using PredIndPair = std::pair; using PredIndPairContainer = std::vector; - -/* - * Adapter to access instance weights. - * - * - For ranking task, weights are per-group - * - For binary classification task, weights are per-instance - * - * WeightPolicy::GetWeightOfInstance() : - * get weight associated with an individual instance, using index into - * `info.weights` - * WeightPolicy::GetWeightOfSortedRecord() : - * get weight associated with an individual instance, using index into - * sorted records `rec` (in ascending order of predicted labels). `rec` is - * of type PredIndPairContainer - */ - -class PerInstanceWeightPolicy { - public: - inline static xgboost::bst_float - GetWeightOfInstance(const xgboost::MetaInfo& info, - unsigned instance_id, unsigned) { - return info.GetWeight(instance_id); - } - inline static xgboost::bst_float - GetWeightOfSortedRecord(const xgboost::MetaInfo& info, - const PredIndPairContainer& rec, - unsigned record_id, unsigned) { - return info.GetWeight(rec[record_id].second); - } -}; - -class PerGroupWeightPolicy { - public: - inline static xgboost::bst_float - GetWeightOfInstance(const xgboost::MetaInfo& info, - unsigned, unsigned group_id) { - return info.GetWeight(group_id); - } - - inline static xgboost::bst_float - GetWeightOfSortedRecord(const xgboost::MetaInfo& info, - const PredIndPairContainer&, - unsigned, unsigned group_id) { - return info.GetWeight(group_id); - } -}; } // anonymous namespace namespace xgboost::metric { @@ -177,10 +111,6 @@ struct EvalAMS : public MetricNoCache { /*! \brief Evaluate rank list */ struct EvalRank : public MetricNoCache, public EvalRankConfig { - private: - // This is used to compute the ranking metrics on the GPU - for training jobs that run on the GPU. - std::unique_ptr rank_gpu_; - public: double Eval(const HostDeviceVector& preds, const MetaInfo& info) override { CHECK_EQ(preds.Size(), info.labels.Size()) @@ -199,20 +129,10 @@ struct EvalRank : public MetricNoCache, public EvalRankConfig { // sum statistics double sum_metric = 0.0f; - // Check and see if we have the GPU metric registered in the internal registry - if (ctx_->gpu_id >= 0) { - if (!rank_gpu_) { - rank_gpu_.reset(GPUMetric::CreateGPUMetric(this->Name(), ctx_)); - } - if (rank_gpu_) { - sum_metric = rank_gpu_->Eval(preds, info); - } - } - CHECK(ctx_); std::vector sum_tloc(ctx_->Threads(), 0.0); - if (!rank_gpu_ || ctx_->gpu_id < 0) { + { const auto& labels = info.labels.View(Context::kCpuId); const auto &h_preds = preds.ConstHostVector(); @@ -295,7 +215,7 @@ struct EvalCox : public MetricNoCache { return out/num_events; // normalize by the number of events } - const char* Name() const override { + [[nodiscard]] const char* Name() const override { return "cox-nloglik"; } }; diff --git a/src/metric/rank_metric.cu b/src/metric/rank_metric.cu index 6fe7ba908f3d..9a724eb3b973 100644 --- a/src/metric/rank_metric.cu +++ b/src/metric/rank_metric.cu @@ -28,107 +28,6 @@ namespace xgboost::metric { // tag the this file, used by force static link later. DMLC_REGISTRY_FILE_TAG(rank_metric_gpu); -/*! \brief Evaluate rank list on GPU */ -template -struct EvalRankGpu : public GPUMetric, public EvalRankConfig { - public: - double Eval(const HostDeviceVector &preds, const MetaInfo &info) override { - // Sanity check is done by the caller - std::vector tgptr(2, 0); - tgptr[1] = static_cast(preds.Size()); - const std::vector &gptr = info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_; - - const auto ngroups = static_cast(gptr.size() - 1); - - auto device = ctx_->gpu_id; - dh::safe_cuda(cudaSetDevice(device)); - - info.labels.SetDevice(device); - preds.SetDevice(device); - - auto dpreds = preds.ConstDevicePointer(); - auto dlabels = info.labels.View(device); - - // Sort all the predictions - dh::SegmentSorter segment_pred_sorter; - segment_pred_sorter.SortItems(dpreds, preds.Size(), gptr); - - // Compute individual group metric and sum them up - return EvalMetricT::EvalMetric(segment_pred_sorter, dlabels.Values().data(), *this); - } - - [[nodiscard]] const char* Name() const override { - return name.c_str(); - } - - explicit EvalRankGpu(const char* name, const char* param) { - using namespace std; // NOLINT(*) - if (param != nullptr) { - std::ostringstream os; - if (sscanf(param, "%u[-]?", &this->topn) == 1) { - os << name << '@' << param; - this->name = os.str(); - } else { - os << name << param; - this->name = os.str(); - } - if (param[strlen(param) - 1] == '-') { - this->minus = true; - } - } else { - this->name = name; - } - } -}; - -/*! \brief Precision at N, for both classification and rank */ -struct EvalPrecisionGpu { - public: - static double EvalMetric(const dh::SegmentSorter &pred_sorter, - const float *dlabels, - const EvalRankConfig &ecfg) { - // Group info on device - const auto &dgroups = pred_sorter.GetGroupsSpan(); - const auto ngroups = pred_sorter.GetNumGroups(); - const auto &dgroup_idx = pred_sorter.GetGroupSegmentsSpan(); - - // Original positions of the predictions after they have been sorted - const auto &dpreds_orig_pos = pred_sorter.GetOriginalPositionsSpan(); - - // First, determine non zero labels in the dataset individually - auto DetermineNonTrivialLabelLambda = [=] __device__(uint32_t idx) { - return (static_cast(dlabels[dpreds_orig_pos[idx]]) != 0) ? 1 : 0; - }; // NOLINT - - // Find each group's metric sum - dh::caching_device_vector hits(ngroups, 0); - const auto nitems = pred_sorter.GetNumItems(); - auto *dhits = hits.data().get(); - - int device_id = -1; - dh::safe_cuda(cudaGetDevice(&device_id)); - // For each group item compute the aggregated precision - dh::LaunchN(nitems, nullptr, [=] __device__(uint32_t idx) { - const auto group_idx = dgroup_idx[idx]; - const auto group_begin = dgroups[group_idx]; - const auto ridx = idx - group_begin; - if (ridx < ecfg.topn && DetermineNonTrivialLabelLambda(idx)) { - atomicAdd(&dhits[group_idx], 1); - } - }); - - // Allocator to be used for managing space overhead while performing reductions - dh::XGBCachingDeviceAllocator alloc; - return static_cast(thrust::reduce(thrust::cuda::par(alloc), - hits.begin(), hits.end())) / ecfg.topn; - } -}; - - -XGBOOST_REGISTER_GPU_METRIC(PrecisionGpu, "pre") -.describe("precision@k for rank computed on GPU.") -.set_body([](const char* param) { return new EvalRankGpu("pre", param); }); - namespace cuda_impl { PackedReduceResult PreScore(Context const *ctx, MetaInfo const &info, HostDeviceVector const &predt, bool minus,