From 8be6095ece5bbb2059b70287e17a650d84aca39f Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 13 Mar 2023 22:16:31 +0800 Subject: [PATCH] Implement NDCG cache. (#8893) --- src/common/ranking_utils.cc | 98 +++++++++- src/common/ranking_utils.cu | 207 +++++++++++++++++++++ src/common/ranking_utils.cuh | 40 +++++ src/common/ranking_utils.h | 238 ++++++++++++++++++++++++- tests/cpp/common/test_ranking_utils.cc | 119 ++++++++++++- tests/cpp/common/test_ranking_utils.cu | 98 ++++++++++ tests/cpp/common/test_ranking_utils.h | 9 + 7 files changed, 798 insertions(+), 11 deletions(-) create mode 100644 src/common/ranking_utils.cu create mode 100644 src/common/ranking_utils.cuh create mode 100644 tests/cpp/common/test_ranking_utils.cu create mode 100644 tests/cpp/common/test_ranking_utils.h diff --git a/src/common/ranking_utils.cc b/src/common/ranking_utils.cc index 8fad9a2060b0..c8069784b301 100644 --- a/src/common/ranking_utils.cc +++ b/src/common/ranking_utils.cc @@ -6,9 +6,7 @@ #include // for copy_n, max, min, none_of, all_of #include // for size_t #include // for sscanf -#include // for exception #include // for greater -#include // for reverse_iterator #include // for char_traits, string #include "algorithm.h" // for ArgSort @@ -18,10 +16,102 @@ #include "xgboost/base.h" // for bst_group_t #include "xgboost/context.h" // for Context #include "xgboost/data.h" // for MetaInfo -#include "xgboost/linalg.h" // for All, TensorView, Range, Tensor, Vector -#include "xgboost/logging.h" // for Error, LogCheck_EQ, CHECK_EQ +#include "xgboost/linalg.h" // for All, TensorView, Range +#include "xgboost/logging.h" // for CHECK_EQ namespace xgboost::ltr { +void RankingCache::InitOnCPU(Context const* ctx, MetaInfo const& info) { + if (info.group_ptr_.empty()) { + group_ptr_.Resize(2, 0); + group_ptr_.HostVector()[1] = info.num_row_; + } else { + group_ptr_.HostVector() = info.group_ptr_; + } + + auto const& gptr = group_ptr_.ConstHostVector(); + for (std::size_t i = 1; i < gptr.size(); ++i) { + std::size_t n = gptr[i] - gptr[i - 1]; + max_group_size_ = std::max(max_group_size_, n); + } + + double sum_weights = 0; + auto n_groups = Groups(); + auto weight = common::MakeOptionalWeights(ctx, info.weights_); + for (bst_omp_uint k = 0; k < n_groups; ++k) { + sum_weights += weight[k]; + } + weight_norm_ = static_cast(n_groups) / sum_weights; +} + +common::Span RankingCache::MakeRankOnCPU(Context const* ctx, + common::Span predt) { + auto gptr = this->DataGroupPtr(ctx); + auto rank = this->sorted_idx_cache_.HostSpan(); + CHECK_EQ(rank.size(), predt.size()); + + common::ParallelFor(this->Groups(), ctx->Threads(), [&](auto g) { + auto cnt = gptr[g + 1] - gptr[g]; + auto g_predt = predt.subspan(gptr[g], cnt); + auto g_rank = rank.subspan(gptr[g], cnt); + auto sorted_idx = common::ArgSort( + ctx, g_predt.data(), g_predt.data() + g_predt.size(), std::greater<>{}); + CHECK_EQ(g_rank.size(), sorted_idx.size()); + std::copy_n(sorted_idx.data(), sorted_idx.size(), g_rank.data()); + }); + + return rank; +} + +#if !defined(XGBOOST_USE_CUDA) +void RankingCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUSupport(); } +common::Span RankingCache::MakeRankOnCUDA(Context const*, + common::Span) { + common::AssertGPUSupport(); + return {}; +} +#endif // !defined() + +void NDCGCache::InitOnCPU(Context const* ctx, MetaInfo const& info) { + auto const h_group_ptr = this->DataGroupPtr(ctx); + + discounts_.Resize(MaxGroupSize(), 0); + auto& h_discounts = discounts_.HostVector(); + for (std::size_t i = 0; i < 
MaxGroupSize(); ++i) { + h_discounts[i] = CalcDCGDiscount(i); + } + + auto n_groups = h_group_ptr.size() - 1; + auto h_labels = info.labels.HostView().Slice(linalg::All(), 0); + + CheckNDCGLabels(this->Param(), h_labels, + [](auto beg, auto end, auto op) { return std::none_of(beg, end, op); }); + + inv_idcg_.Reshape(n_groups); + auto h_inv_idcg = inv_idcg_.HostView(); + std::size_t topk = this->Param().TopK(); + auto const exp_gain = this->Param().ndcg_exp_gain; + + common::ParallelFor(n_groups, ctx->Threads(), [&](auto g) { + auto g_labels = h_labels.Slice(linalg::Range(h_group_ptr[g], h_group_ptr[g + 1])); + auto sorted_idx = common::ArgSort(ctx, linalg::cbegin(g_labels), + linalg::cend(g_labels), std::greater<>{}); + + double idcg{0.0}; + for (std::size_t i = 0; i < std::min(g_labels.Size(), topk); ++i) { + if (exp_gain) { + idcg += h_discounts[i] * CalcDCGGain(g_labels(sorted_idx[i])); + } else { + idcg += h_discounts[i] * g_labels(sorted_idx[i]); + } + } + h_inv_idcg(g) = CalcInvIDCG(idcg); + }); +} + +#if !defined(XGBOOST_USE_CUDA) +void NDCGCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUSupport(); } +#endif // !defined(XGBOOST_USE_CUDA) + DMLC_REGISTER_PARAMETER(LambdaRankParam); std::string ParseMetricName(StringView name, StringView param, position_t* topn, bool* minus) { diff --git a/src/common/ranking_utils.cu b/src/common/ranking_utils.cu new file mode 100644 index 000000000000..ce9cda4e24e5 --- /dev/null +++ b/src/common/ranking_utils.cu @@ -0,0 +1,207 @@ +/** + * Copyright 2023 by XGBoost Contributors + */ +#include // for maximum +#include // for make_counting_iterator +#include // for none_of, all_of +#include // for pair, make_pair +#include // for reduce +#include // for inclusive_scan + +#include // for size_t + +#include "algorithm.cuh" // for SegmentedArgSort +#include "cuda_context.cuh" // for CUDAContext +#include "device_helpers.cuh" // for MakeTransformIterator, LaunchN +#include "optional_weight.h" // for MakeOptionalWeights, OptionalWeights +#include "ranking_utils.cuh" // for ThreadsForMean +#include "ranking_utils.h" +#include "threading_utils.cuh" // for SegmentedTrapezoidThreads +#include "xgboost/base.h" // for XGBOOST_DEVICE, bst_group_t +#include "xgboost/context.h" // for Context +#include "xgboost/linalg.h" // for VectorView, All, Range +#include "xgboost/logging.h" // for CHECK +#include "xgboost/span.h" // for Span + +namespace xgboost::ltr { +namespace cuda_impl { +void CalcQueriesDCG(Context const* ctx, linalg::VectorView d_labels, + common::Span d_sorted_idx, bool exp_gain, + common::Span d_group_ptr, std::size_t k, + linalg::VectorView out_dcg) { + CHECK_EQ(d_group_ptr.size() - 1, out_dcg.Size()); + using IdxGroup = thrust::pair; + auto group_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ull), [=] XGBOOST_DEVICE(std::size_t idx) { + return thrust::make_pair(idx, dh::SegmentId(d_group_ptr, idx)); // NOLINT + }); + auto value_it = dh::MakeTransformIterator( + group_it, + [exp_gain, d_labels, d_group_ptr, k, + d_sorted_idx] XGBOOST_DEVICE(IdxGroup const& l) -> double { + auto g_begin = d_group_ptr[l.second]; + auto g_size = d_group_ptr[l.second + 1] - g_begin; + + auto idx_in_group = l.first - g_begin; + if (idx_in_group >= k) { + return 0.0; + } + double gain{0.0}; + auto g_sorted_idx = d_sorted_idx.subspan(g_begin, g_size); + auto g_labels = d_labels.Slice(linalg::Range(g_begin, g_begin + g_size)); + + if (exp_gain) { + gain = ltr::CalcDCGGain(g_labels(g_sorted_idx[idx_in_group])); + } else { + gain = 
g_labels(g_sorted_idx[idx_in_group]); + } + double discount = CalcDCGDiscount(idx_in_group); + return gain * discount; + }); + + CHECK(out_dcg.Contiguous()); + std::size_t bytes; + cub::DeviceSegmentedReduce::Sum(nullptr, bytes, value_it, out_dcg.Values().data(), + d_group_ptr.size() - 1, d_group_ptr.data(), + d_group_ptr.data() + 1, ctx->CUDACtx()->Stream()); + dh::TemporaryArray temp(bytes); + cub::DeviceSegmentedReduce::Sum(temp.data().get(), bytes, value_it, out_dcg.Values().data(), + d_group_ptr.size() - 1, d_group_ptr.data(), + d_group_ptr.data() + 1, ctx->CUDACtx()->Stream()); +} + +void CalcQueriesInvIDCG(Context const* ctx, linalg::VectorView d_labels, + common::Span d_group_ptr, + linalg::VectorView out_inv_IDCG, ltr::LambdaRankParam const& p) { + CHECK_GE(d_group_ptr.size(), 2ul); + size_t n_groups = d_group_ptr.size() - 1; + CHECK_EQ(out_inv_IDCG.Size(), n_groups); + dh::device_vector sorted_idx(d_labels.Size()); + auto d_sorted_idx = dh::ToSpan(sorted_idx); + common::SegmentedArgSort(ctx, d_labels.Values(), d_group_ptr, d_sorted_idx); + CalcQueriesDCG(ctx, d_labels, d_sorted_idx, p.ndcg_exp_gain, d_group_ptr, p.TopK(), out_inv_IDCG); + dh::LaunchN(out_inv_IDCG.Size(), ctx->CUDACtx()->Stream(), + [out_inv_IDCG] XGBOOST_DEVICE(size_t idx) mutable { + double idcg = out_inv_IDCG(idx); + out_inv_IDCG(idx) = CalcInvIDCG(idcg); + }); +} +} // namespace cuda_impl + +namespace { +struct CheckNDCGOp { + CUDAContext const* cuctx; + template + bool operator()(It beg, It end, Op op) { + return thrust::none_of(cuctx->CTP(), beg, end, op); + } +}; +struct CheckMAPOp { + CUDAContext const* cuctx; + template + bool operator()(It beg, It end, Op op) { + return thrust::all_of(cuctx->CTP(), beg, end, op); + } +}; + +struct ThreadGroupOp { + common::Span d_group_ptr; + std::size_t n_pairs; + + common::Span out_thread_group_ptr; + + XGBOOST_DEVICE void operator()(std::size_t i) { + out_thread_group_ptr[i + 1] = + cuda_impl::ThreadsForMean(d_group_ptr[i + 1] - d_group_ptr[i], n_pairs); + } +}; + +struct GroupSizeOp { + common::Span d_group_ptr; + + XGBOOST_DEVICE auto operator()(std::size_t i) -> std::size_t { + return d_group_ptr[i + 1] - d_group_ptr[i]; + } +}; + +struct WeightOp { + common::OptionalWeights d_weight; + XGBOOST_DEVICE auto operator()(std::size_t i) -> double { return d_weight[i]; } +}; +} // anonymous namespace + +void RankingCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) { + CUDAContext const* cuctx = ctx->CUDACtx(); + + group_ptr_.SetDevice(ctx->gpu_id); + if (info.group_ptr_.empty()) { + group_ptr_.Resize(2, 0); + group_ptr_.HostVector()[1] = info.num_row_; + } else { + auto const& h_group_ptr = info.group_ptr_; + group_ptr_.Resize(h_group_ptr.size()); + auto d_group_ptr = group_ptr_.DeviceSpan(); + dh::safe_cuda(cudaMemcpyAsync(d_group_ptr.data(), h_group_ptr.data(), d_group_ptr.size_bytes(), + cudaMemcpyHostToDevice, cuctx->Stream())); + } + + auto d_group_ptr = DataGroupPtr(ctx); + std::size_t n_groups = Groups(); + + auto it = dh::MakeTransformIterator(thrust::make_counting_iterator(0ul), + GroupSizeOp{d_group_ptr}); + max_group_size_ = + thrust::reduce(cuctx->CTP(), it, it + n_groups, 0ul, thrust::maximum{}); + + threads_group_ptr_.SetDevice(ctx->gpu_id); + threads_group_ptr_.Resize(n_groups + 1, 0); + auto d_threads_group_ptr = threads_group_ptr_.DeviceSpan(); + if (param_.HasTruncation()) { + n_cuda_threads_ = + common::SegmentedTrapezoidThreads(d_group_ptr, d_threads_group_ptr, Param().NumPair()); + } else { + auto n_pairs = Param().NumPair(); + 
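+      // Mean pair strategy: every sample in a group is assigned `n_pairs` threads, so a
+      // group of size n owns n * n_pairs of them. ThreadGroupOp writes each group's
+      // thread count and the inclusive scan below turns those counts into offsets,
+      // mirroring SegmentedTrapezoidThreads in the truncation branch above.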
dh::LaunchN(n_groups, cuctx->Stream(), + ThreadGroupOp{d_group_ptr, n_pairs, d_threads_group_ptr}); + thrust::inclusive_scan(cuctx->CTP(), dh::tcbegin(d_threads_group_ptr), + dh::tcend(d_threads_group_ptr), dh::tbegin(d_threads_group_ptr)); + n_cuda_threads_ = info.num_row_ * param_.NumPair(); + } + + sorted_idx_cache_.SetDevice(ctx->gpu_id); + sorted_idx_cache_.Resize(info.labels.Size(), 0); + + auto weight = common::MakeOptionalWeights(ctx, info.weights_); + auto w_it = + dh::MakeTransformIterator(thrust::make_counting_iterator(0ul), WeightOp{weight}); + weight_norm_ = static_cast(n_groups) / thrust::reduce(w_it, w_it + n_groups); +} + +common::Span RankingCache::MakeRankOnCUDA(Context const* ctx, + common::Span predt) { + auto d_sorted_idx = sorted_idx_cache_.DeviceSpan(); + auto d_group_ptr = DataGroupPtr(ctx); + common::SegmentedArgSort(ctx, predt, d_group_ptr, d_sorted_idx); + return d_sorted_idx; +} + +void NDCGCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) { + CUDAContext const* cuctx = ctx->CUDACtx(); + auto labels = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0); + CheckNDCGLabels(this->Param(), labels, CheckNDCGOp{cuctx}); + + auto d_group_ptr = this->DataGroupPtr(ctx); + + std::size_t n_groups = d_group_ptr.size() - 1; + inv_idcg_ = linalg::Zeros(ctx, n_groups); + auto d_inv_idcg = inv_idcg_.View(ctx->gpu_id); + cuda_impl::CalcQueriesInvIDCG(ctx, labels, d_group_ptr, d_inv_idcg, this->Param()); + CHECK_GE(this->Param().NumPair(), 1ul); + + discounts_.SetDevice(ctx->gpu_id); + discounts_.Resize(MaxGroupSize()); + auto d_discount = discounts_.DeviceSpan(); + dh::LaunchN(MaxGroupSize(), cuctx->Stream(), + [=] XGBOOST_DEVICE(std::size_t i) { d_discount[i] = CalcDCGDiscount(i); }); +} +} // namespace xgboost::ltr diff --git a/src/common/ranking_utils.cuh b/src/common/ranking_utils.cuh new file mode 100644 index 000000000000..297f5157ecfb --- /dev/null +++ b/src/common/ranking_utils.cuh @@ -0,0 +1,40 @@ +/** + * Copyright 2023 by XGBoost Contributors + */ +#ifndef XGBOOST_COMMON_RANKING_UTILS_CUH_ +#define XGBOOST_COMMON_RANKING_UTILS_CUH_ + +#include // for size_t + +#include "ranking_utils.h" // for LambdaRankParam +#include "xgboost/base.h" // for bst_group_t, XGBOOST_DEVICE +#include "xgboost/context.h" // for Context +#include "xgboost/linalg.h" // for VectorView +#include "xgboost/span.h" // for Span + +namespace xgboost { +namespace ltr { +namespace cuda_impl { +void CalcQueriesDCG(Context const *ctx, linalg::VectorView d_labels, + common::Span d_sorted_idx, bool exp_gain, + common::Span d_group_ptr, std::size_t k, + linalg::VectorView out_dcg); + +void CalcQueriesInvIDCG(Context const *ctx, linalg::VectorView d_labels, + common::Span d_group_ptr, + linalg::VectorView out_inv_IDCG, ltr::LambdaRankParam const &p); + +// Functions for creating number of threads for CUDA, and getting back the number of pairs +// from the number of threads. 
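+// For the mean pair strategy a group of `group_size` samples with `n_pairs` pairs per
+// sample is assigned group_size * n_pairs threads; PairsForGroup is the inverse mapping
+// that recovers the number of pairs per sample from a group's thread count.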
+XGBOOST_DEVICE __forceinline__ std::size_t ThreadsForMean(std::size_t group_size, + std::size_t n_pairs) { + return group_size * n_pairs; +} +XGBOOST_DEVICE __forceinline__ std::size_t PairsForGroup(std::size_t n_threads, + std::size_t group_size) { + return n_threads / group_size; +} +} // namespace cuda_impl +} // namespace ltr +} // namespace xgboost +#endif // XGBOOST_COMMON_RANKING_UTILS_CUH_ diff --git a/src/common/ranking_utils.h b/src/common/ranking_utils.h index 631de4d70324..88283fba286e 100644 --- a/src/common/ranking_utils.h +++ b/src/common/ranking_utils.h @@ -11,7 +11,6 @@ #include // for char_traits, string #include // for vector -#include "./math.h" // for CloseTo #include "dmlc/parameter.h" // for FieldEntry, DMLC_DECLARE_FIELD #include "error_msg.h" // for GroupWeight, GroupSize #include "xgboost/base.h" // for XGBOOST_DEVICE, bst_group_t @@ -19,7 +18,7 @@ #include "xgboost/data.h" // for MetaInfo #include "xgboost/host_device_vector.h" // for HostDeviceVector #include "xgboost/linalg.h" // for Vector, VectorView, Tensor -#include "xgboost/logging.h" // for LogCheck_EQ, CHECK_EQ, CHECK +#include "xgboost/logging.h" // for CHECK_EQ, CHECK #include "xgboost/parameter.h" // for XGBoostParameter #include "xgboost/span.h" // for Span #include "xgboost/string_view.h" // for StringView @@ -34,6 +33,25 @@ using rel_degree_t = std::uint32_t; // NOLINT */ using position_t = std::uint32_t; // NOLINT +/** + * \brief Maximum relevance degree for NDCG + */ +constexpr std::size_t MaxRel() { return sizeof(rel_degree_t) * 8 - 1; } +static_assert(MaxRel() == 31); + +XGBOOST_DEVICE inline double CalcDCGGain(rel_degree_t label) { + return static_cast((1u << label) - 1); +} + +XGBOOST_DEVICE inline double CalcDCGDiscount(std::size_t idx) { + return 1.0 / std::log2(static_cast(idx) + 2.0); +} + +XGBOOST_DEVICE inline double CalcInvIDCG(double idcg) { + auto inv_idcg = (idcg == 0.0 ? 0.0 : (1.0 / idcg)); // handle irrelevant document + return inv_idcg; +} + enum class PairMethod : std::int32_t { kTopK = 0, kMean = 1, @@ -115,7 +133,7 @@ struct LambdaRankParam : public XGBoostParameter { .describe("Number of pairs for each sample in the list."); DMLC_DECLARE_FIELD(lambdarank_unbiased) .set_default(false) - .describe("Unbiased lambda mart. Use IPW to debias click position"); + .describe("Unbiased lambda mart. Use extended IPW to debias click position"); DMLC_DECLARE_FIELD(lambdarank_bias_norm) .set_default(2.0) .set_lower_bound(0.0) @@ -126,6 +144,220 @@ struct LambdaRankParam : public XGBoostParameter { } }; +/** + * \brief Common cached items for ranking tasks. + */ +class RankingCache { + private: + void InitOnCPU(Context const* ctx, MetaInfo const& info); + void InitOnCUDA(Context const* ctx, MetaInfo const& info); + // Cached parameter + LambdaRankParam param_; + // offset to data groups. + HostDeviceVector group_ptr_; + // store the sorted index of prediction. + HostDeviceVector sorted_idx_cache_; + // Maximum size of group + std::size_t max_group_size_{0}; + // Normalization for weight + double weight_norm_{1.0}; + /** + * CUDA cache + */ + // offset to threads assigned to each group for gradient calculation + HostDeviceVector threads_group_ptr_; + // Sorted index of label for finding buckets. 
+ HostDeviceVector y_sorted_idx_cache_; + // Cached labels sorted by the model + HostDeviceVector y_ranked_by_model_; + // store rounding factor for objective for each group + linalg::Vector roundings_; + // rounding factor for cost + HostDeviceVector cost_rounding_; + // temporary storage for creating rounding factors. Stored as byte to avoid having cuda + // data structure in here. + HostDeviceVector max_lambdas_; + // total number of cuda threads used for gradient calculation + std::size_t n_cuda_threads_{0}; + + // Create model rank list on GPU + common::Span MakeRankOnCUDA(Context const* ctx, + common::Span predt); + // Create model rank list on CPU + common::Span MakeRankOnCPU(Context const* ctx, + common::Span predt); + + protected: + [[nodiscard]] std::size_t MaxGroupSize() const { return max_group_size_; } + + public: + RankingCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p) : param_{p} { + CHECK(param_.GetInitialised()); + if (!info.group_ptr_.empty()) { + CHECK_EQ(info.group_ptr_.back(), info.labels.Size()) + << error::GroupSize() << "the size of label."; + } + if (ctx->IsCPU()) { + this->InitOnCPU(ctx, info); + } else { + this->InitOnCUDA(ctx, info); + } + if (!info.weights_.Empty()) { + CHECK_EQ(Groups(), info.weights_.Size()) << error::GroupWeight(); + } + } + [[nodiscard]] std::size_t MaxPositionSize() const { + // Use truncation level as bound. + if (param_.HasTruncation()) { + return param_.NumPair(); + } + // Hardcoded maximum size of positions to track. We don't need too many of them as the + // bias decreases exponentially. + return std::min(max_group_size_, static_cast(32)); + } + // Constructed as [1, n_samples] if group ptr is not supplied by the user + common::Span DataGroupPtr(Context const* ctx) const { + group_ptr_.SetDevice(ctx->gpu_id); + return ctx->IsCPU() ? group_ptr_.ConstHostSpan() : group_ptr_.ConstDeviceSpan(); + } + + [[nodiscard]] auto const& Param() const { return param_; } + [[nodiscard]] std::size_t Groups() const { return group_ptr_.Size() - 1; } + [[nodiscard]] double WeightNorm() const { return weight_norm_; } + + // Create a rank list by model prediction + common::Span SortedIdx(Context const* ctx, common::Span predt) { + if (sorted_idx_cache_.Empty()) { + sorted_idx_cache_.SetDevice(ctx->gpu_id); + sorted_idx_cache_.Resize(predt.size()); + } + if (ctx->IsCPU()) { + return this->MakeRankOnCPU(ctx, predt); + } else { + return this->MakeRankOnCUDA(ctx, predt); + } + } + // The function simply returns a uninitialized buffer as this is only used by the + // objective for creating pairs. + common::Span SortedIdxY(Context const* ctx, std::size_t n_samples) { + CHECK(ctx->IsCUDA()); + if (y_sorted_idx_cache_.Empty()) { + y_sorted_idx_cache_.SetDevice(ctx->gpu_id); + y_sorted_idx_cache_.Resize(n_samples); + } + return y_sorted_idx_cache_.DeviceSpan(); + } + common::Span RankedY(Context const* ctx, std::size_t n_samples) { + CHECK(ctx->IsCUDA()); + if (y_ranked_by_model_.Empty()) { + y_ranked_by_model_.SetDevice(ctx->gpu_id); + y_ranked_by_model_.Resize(n_samples); + } + return y_ranked_by_model_.DeviceSpan(); + } + + // CUDA cache getters, the cache is shared between metric and objective, some of these + // fields are lazy initialized to avoid unnecessary allocation. 
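+  // The thread partition (CUDAThreadsGroupPtr/CUDAThreads) is computed in InitOnCUDA;
+  // the rounding buffers returned by the getters further below are allocated lazily on
+  // first access.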
+ [[nodiscard]] common::Span CUDAThreadsGroupPtr() const { + CHECK(!threads_group_ptr_.Empty()); + return threads_group_ptr_.ConstDeviceSpan(); + } + [[nodiscard]] std::size_t CUDAThreads() const { return n_cuda_threads_; } + + linalg::VectorView CUDARounding(Context const* ctx) { + if (roundings_.Size() == 0) { + roundings_.SetDevice(ctx->gpu_id); + roundings_.Reshape(Groups()); + } + return roundings_.View(ctx->gpu_id); + } + common::Span CUDACostRounding(Context const* ctx) { + if (cost_rounding_.Size() == 0) { + cost_rounding_.SetDevice(ctx->gpu_id); + cost_rounding_.Resize(1); + } + return cost_rounding_.DeviceSpan(); + } + template + common::Span MaxLambdas(Context const* ctx, std::size_t n) { + max_lambdas_.SetDevice(ctx->gpu_id); + std::size_t bytes = n * sizeof(Type); + if (bytes != max_lambdas_.Size()) { + max_lambdas_.Resize(bytes); + } + return common::Span{reinterpret_cast(max_lambdas_.DevicePointer()), n}; + } +}; + +class NDCGCache : public RankingCache { + // NDCG discount + HostDeviceVector discounts_; + // 1.0 / IDCG + linalg::Vector inv_idcg_; + /** + * CUDA cache + */ + // store the intermediate DCG calculation result for metric + linalg::Vector dcg_; + + public: + void InitOnCPU(Context const* ctx, MetaInfo const& info); + void InitOnCUDA(Context const* ctx, MetaInfo const& info); + + public: + NDCGCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p) + : RankingCache{ctx, info, p} { + if (ctx->IsCPU()) { + this->InitOnCPU(ctx, info); + } else { + this->InitOnCUDA(ctx, info); + } + } + + linalg::VectorView InvIDCG(Context const* ctx) const { + return inv_idcg_.View(ctx->gpu_id); + } + common::Span Discount(Context const* ctx) const { + return ctx->IsCPU() ? discounts_.ConstHostSpan() : discounts_.ConstDeviceSpan(); + } + linalg::VectorView Dcg(Context const* ctx) { + if (dcg_.Size() == 0) { + dcg_.SetDevice(ctx->gpu_id); + dcg_.Reshape(this->Groups()); + } + return dcg_.View(ctx->gpu_id); + } +}; + +/** + * \brief Validate label for NDCG + * + * \tparam NoneOf Implementation of std::none_of. Specified as a parameter to reuse the + * check for both CPU and GPU. + */ +template +void CheckNDCGLabels(ltr::LambdaRankParam const& p, linalg::VectorView labels, + NoneOf none_of) { + auto d_labels = labels.Values(); + if (p.ndcg_exp_gain) { + auto label_is_integer = + none_of(d_labels.data(), d_labels.data() + d_labels.size(), [] XGBOOST_DEVICE(float v) { + auto l = std::floor(v); + return std::fabs(l - v) > kRtEps || v < 0.0f; + }); + CHECK(label_is_integer) + << "When using relevance degree as target, label must be either 0 or positive integer."; + } + + if (p.ndcg_exp_gain) { + auto label_is_valid = none_of(d_labels.data(), d_labels.data() + d_labels.size(), + [] XGBOOST_DEVICE(ltr::rel_degree_t v) { return v > MaxRel(); }); + CHECK(label_is_valid) << "Relevance degress must be lesser than or equal to " << MaxRel() + << " when the exponential NDCG gain function is used. " + << "Set `ndcg_exp_gain` to false to use custom DCG gain."; + } +} + /** * \brief Parse name for ranking metric given parameters. * diff --git a/tests/cpp/common/test_ranking_utils.cc b/tests/cpp/common/test_ranking_utils.cc index c73cffed7e27..9240db0d4814 100644 --- a/tests/cpp/common/test_ranking_utils.cc +++ b/tests/cpp/common/test_ranking_utils.cc @@ -1,16 +1,25 @@ /** * Copyright 2023 by XGBoost Contributors */ -#include // for Test, AssertionResult, Message, TestPartR... -#include // for ASSERT_NEAR, ASSERT_T... 
-#include // for Args +#include "test_ranking_utils.h" + +#include +#include // for Args, bst_group_t, kRtEps #include // for Context +#include // for MetaInfo, DMatrix +#include // for HostDeviceVector +#include // for Error #include // for StringView +#include // for size_t #include // for uint32_t -#include // for pair +#include // for iota +#include // for move +#include // for vector +#include "../../../src/common/numeric.h" // for Iota #include "../../../src/common/ranking_utils.h" // for LambdaRankParam, ParseMetricName, MakeMet... +#include "../helpers.h" // for EmptyDMatrix namespace xgboost::ltr { TEST(RankingUtils, LambdaRankParam) { @@ -66,4 +75,106 @@ TEST(RankingUtils, MakeMetricName) { name = MakeMetricName("map", 2, false); ASSERT_EQ(name, "map@2"); } + +void TestRankingCache(Context const* ctx) { + auto p_fmat = EmptyDMatrix(); + MetaInfo& info = p_fmat->Info(); + + info.num_row_ = 16; + info.labels.Reshape(info.num_row_); + auto& h_label = info.labels.Data()->HostVector(); + for (std::size_t i = 0; i < h_label.size(); ++i) { + h_label[i] = i % 2; + } + + LambdaRankParam param; + param.UpdateAllowUnknown(Args{}); + + RankingCache cache{ctx, info, param}; + + HostDeviceVector predt(info.num_row_, 0); + auto& h_predt = predt.HostVector(); + std::iota(h_predt.begin(), h_predt.end(), 0.0f); + predt.SetDevice(ctx->gpu_id); + + auto rank_idx = + cache.SortedIdx(ctx, ctx->IsCPU() ? predt.ConstHostSpan() : predt.ConstDeviceSpan()); + + for (std::size_t i = 0; i < rank_idx.size(); ++i) { + ASSERT_EQ(rank_idx[i], rank_idx.size() - i - 1); + } +} + +TEST(RankingCache, InitFromCPU) { + Context ctx; + TestRankingCache(&ctx); +} + +void TestNDCGCache(Context const* ctx) { + auto p_fmat = EmptyDMatrix(); + MetaInfo& info = p_fmat->Info(); + LambdaRankParam param; + param.UpdateAllowUnknown(Args{}); + + { + // empty + NDCGCache cache{ctx, info, param}; + ASSERT_EQ(cache.DataGroupPtr(ctx).size(), 2); + } + + info.num_row_ = 3; + info.group_ptr_ = {static_cast(0), static_cast(info.num_row_)}; + + { + auto fail = [&]() { NDCGCache cache{ctx, info, param}; }; + // empty label + ASSERT_THROW(fail(), dmlc::Error); + info.labels = linalg::Matrix{{0.0f, 0.1f, 0.2f}, {3}, Context::kCpuId}; + // invalid label + ASSERT_THROW(fail(), dmlc::Error); + auto h_labels = info.labels.HostView(); + for (std::size_t i = 0; i < h_labels.Size(); ++i) { + h_labels(i) *= 10; + } + param.UpdateAllowUnknown(Args{{"ndcg_exp_gain", "false"}}); + NDCGCache cache{ctx, info, param}; + Context cpuctx; + auto inv_idcg = cache.InvIDCG(&cpuctx); + ASSERT_EQ(inv_idcg.Size(), 1); + ASSERT_NEAR(1.0 / inv_idcg(0), 2.63093, kRtEps); + } + + { + param.UpdateAllowUnknown(Args{{"lambdarank_unbiased", "false"}}); + + std::vector h_data(32); + + common::Iota(ctx, h_data.begin(), h_data.end(), 0.0f); + info.labels.Reshape(h_data.size()); + info.num_row_ = h_data.size(); + info.group_ptr_.back() = info.num_row_; + info.labels.Data()->HostVector() = std::move(h_data); + + { + NDCGCache cache{ctx, info, param}; + Context cpuctx; + auto inv_idcg = cache.InvIDCG(&cpuctx); + ASSERT_NEAR(inv_idcg(0), 0.00551782, kRtEps); + } + + param.UpdateAllowUnknown( + Args{{"lambdarank_num_pair_per_sample", "3"}, {"lambdarank_pair_method", "topk"}}); + { + NDCGCache cache{ctx, info, param}; + Context cpuctx; + auto inv_idcg = cache.InvIDCG(&cpuctx); + ASSERT_NEAR(inv_idcg(0), 0.01552123, kRtEps); + } + } +} + +TEST(NDCGCache, InitFromCPU) { + Context ctx; + TestNDCGCache(&ctx); +} } // namespace xgboost::ltr diff --git 
a/tests/cpp/common/test_ranking_utils.cu b/tests/cpp/common/test_ranking_utils.cu new file mode 100644 index 000000000000..5fda42c724be --- /dev/null +++ b/tests/cpp/common/test_ranking_utils.cu @@ -0,0 +1,98 @@ +/** + * Copyright 2023 by XGBoost Contributors + */ +#include +#include // for Args, XGBOOST_DEVICE, bst_group_t, kRtEps +#include // for Context +#include // for MakeTensorView, Vector + +#include // for size_t +#include // for shared_ptr +#include // for iota +#include // for vector + +#include "../../../src/common/algorithm.cuh" // for SegmentedSequence +#include "../../../src/common/cuda_context.cuh" // for CUDAContext +#include "../../../src/common/device_helpers.cuh" // for device_vector, ToSpan +#include "../../../src/common/ranking_utils.cuh" // for CalcQueriesInvIDCG +#include "../../../src/common/ranking_utils.h" // for LambdaRankParam, RankingCache +#include "../helpers.h" // for EmptyDMatrix +#include "test_ranking_utils.h" // for TestNDCGCache +#include "xgboost/data.h" // for MetaInfo +#include "xgboost/host_device_vector.h" // for HostDeviceVector + +namespace xgboost::ltr { +void TestCalcQueriesInvIDCG() { + Context ctx; + ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}}); + std::size_t n_groups = 5, n_samples_per_group = 32; + + dh::device_vector scores(n_samples_per_group * n_groups); + dh::device_vector group_ptr(n_groups + 1); + auto d_group_ptr = dh::ToSpan(group_ptr); + dh::LaunchN(d_group_ptr.size(), ctx.CUDACtx()->Stream(), + [=] XGBOOST_DEVICE(std::size_t i) { d_group_ptr[i] = i * n_samples_per_group; }); + + auto d_scores = dh::ToSpan(scores); + common::SegmentedSequence(&ctx, d_group_ptr, d_scores); + + linalg::Vector inv_IDCG({n_groups}, ctx.gpu_id); + + ltr::LambdaRankParam p; + p.UpdateAllowUnknown(Args{{"ndcg_exp_gain", "false"}}); + + cuda_impl::CalcQueriesInvIDCG(&ctx, linalg::MakeTensorView(&ctx, d_scores, d_scores.size()), + dh::ToSpan(group_ptr), inv_IDCG.View(ctx.gpu_id), p); + for (std::size_t i = 0; i < n_groups; ++i) { + double inv_idcg = inv_IDCG(i); + ASSERT_NEAR(inv_idcg, 0.00551782, kRtEps); + } +} + +TEST(RankingUtils, CalcQueriesInvIDCG) { TestCalcQueriesInvIDCG(); } + +namespace { +void TestRankingCache(Context const* ctx) { + auto p_fmat = EmptyDMatrix(); + MetaInfo& info = p_fmat->Info(); + + info.num_row_ = 16; + info.labels.Reshape(info.num_row_); + auto& h_label = info.labels.Data()->HostVector(); + for (std::size_t i = 0; i < h_label.size(); ++i) { + h_label[i] = i % 2; + } + + LambdaRankParam param; + param.UpdateAllowUnknown(Args{}); + + RankingCache cache{ctx, info, param}; + + HostDeviceVector predt(info.num_row_, 0); + auto& h_predt = predt.HostVector(); + std::iota(h_predt.begin(), h_predt.end(), 0.0f); + predt.SetDevice(ctx->gpu_id); + + auto rank_idx = + cache.SortedIdx(ctx, ctx->IsCPU() ? 
predt.ConstHostSpan() : predt.ConstDeviceSpan()); + + std::vector h_rank_idx(rank_idx.size()); + dh::CopyDeviceSpanToVector(&h_rank_idx, rank_idx); + for (std::size_t i = 0; i < rank_idx.size(); ++i) { + ASSERT_EQ(h_rank_idx[i], h_rank_idx.size() - i - 1); + } +} +} // namespace + +TEST(RankingCache, InitFromGPU) { + Context ctx; + ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}}); + TestRankingCache(&ctx); +} + +TEST(NDCGCache, InitFromGPU) { + Context ctx; + ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}}); + TestNDCGCache(&ctx); +} +} // namespace xgboost::ltr diff --git a/tests/cpp/common/test_ranking_utils.h b/tests/cpp/common/test_ranking_utils.h new file mode 100644 index 000000000000..ede687ff4edb --- /dev/null +++ b/tests/cpp/common/test_ranking_utils.h @@ -0,0 +1,9 @@ +/** + * Copyright 2023 by XGBoost Contributors + */ +#pragma once +#include // for Context + +namespace xgboost::ltr { +void TestNDCGCache(Context const* ctx); +} // namespace xgboost::ltr
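Note (not part of the patch): a minimal, self-contained sketch of the DCG arithmetic the cache implements, using the formulas added in ranking_utils.h -- gain(l) = 2^l - 1 (or the raw label when ndcg_exp_gain is false), discount(i) = 1 / log2(i + 2), and inv_idcg = 1 / idcg with 0 for an all-irrelevant query. With linear gain, no truncation, and labels 0..31 it reproduces the value asserted in TestNDCGCache and TestCalcQueriesInvIDCG (~0.00551782). The file name and main() wrapper are illustrative only.

// ndcg_sketch.cc -- illustrative only; builds with any C++17 compiler.
#include <algorithm>   // for sort
#include <cmath>       // for log2
#include <cstddef>     // for size_t
#include <cstdio>      // for printf
#include <functional>  // for greater
#include <numeric>     // for iota
#include <vector>      // for vector

int main() {
  std::vector<double> labels(32);
  std::iota(labels.begin(), labels.end(), 0.0);               // relevance 0, 1, ..., 31
  std::sort(labels.begin(), labels.end(), std::greater<>{});  // ideal (descending) order

  double idcg = 0.0;
  for (std::size_t i = 0; i < labels.size(); ++i) {
    double discount = 1.0 / std::log2(static_cast<double>(i) + 2.0);  // CalcDCGDiscount
    idcg += labels[i] * discount;  // linear gain, i.e. ndcg_exp_gain = false
  }
  double inv_idcg = (idcg == 0.0) ? 0.0 : 1.0 / idcg;  // CalcInvIDCG
  std::printf("%.8f\n", inv_idcg);                     // prints ~0.00551782
  return 0;
}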