From 72e8331eabb0b93a0859e44d827d31a168a4ec9d Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 15 Mar 2023 03:26:17 +0800 Subject: [PATCH] Reimplement the NDCG metric. (#8906) - Add support for non-exp gain. - Cache the DMatrix object to avoid re-calculating the IDCG. - Make GPU implementation deterministic. (no atomic add) --- include/xgboost/cache.h | 20 +++ src/metric/rank_metric.cc | 240 +++++++++++++++++++++------ src/metric/rank_metric.cu | 158 +++++++----------- src/metric/rank_metric.h | 33 ++++ tests/cpp/metric/test_rank_metric.cc | 82 +++++++-- 5 files changed, 368 insertions(+), 165 deletions(-) create mode 100644 src/metric/rank_metric.h diff --git a/include/xgboost/cache.h b/include/xgboost/cache.h index 6195e730c153..32e1b21ac3f6 100644 --- a/include/xgboost/cache.h +++ b/include/xgboost/cache.h @@ -161,6 +161,26 @@ class DMatrixCache { } return container_.at(key).value; } + /** + * \brief Re-initialize the item in cache. + * + * Since the shared_ptr is used to hold the item, any reference that lives outside of + * the cache can no-longer be reached from the cache. + * + * We use reset instead of erase to avoid walking through the whole cache for renewing + * a single item. (the cache is FIFO, needs to maintain the order). + */ + template + std::shared_ptr ResetItem(std::shared_ptr m, Args const&... args) { + std::lock_guard guard{lock_}; + CheckConsistent(); + auto key = Key{m.get(), std::this_thread::get_id()}; + auto it = container_.find(key); + CHECK(it != container_.cend()); + it->second = {m, std::make_shared(args...)}; + CheckConsistent(); + return it->second.value; + } /** * \brief Get a const reference to the underlying hash map. Clear expired caches before * returning. diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index 69e6e24cd1e1..c2aa48cab853 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -20,23 +20,51 @@ // corresponding headers that brings in those function declaration can't be included with CUDA). // This precludes the CPU and GPU logic to coexist inside a .cu file -#include -#include +#include "rank_metric.h" -#include -#include +#include +#include -#include "../collective/communicator-inl.h" -#include "../common/algorithm.h" // Sort -#include "../common/math.h" -#include "../common/ranking_utils.h" // MakeMetricName -#include "../common/threading_utils.h" -#include "metric_common.h" -#include "xgboost/host_device_vector.h" +#include // for stable_sort, copy, fill_n, min, max +#include // for array +#include // for log, sqrt +#include // for size_t, std +#include // for uint32_t +#include // for less, greater +#include // for operator!=, _Rb_tree_const_iterator +#include // for allocator, unique_ptr, shared_ptr, __shared_... +#include // for accumulate +#include // for operator<<, basic_ostream, ostringstream +#include // for char_traits, operator<, basic_string, to_string +#include // for pair, make_pair +#include // for vector + +#include "../collective/communicator-inl.h" // for IsDistributed, Allreduce +#include "../collective/communicator.h" // for Operation +#include "../common/algorithm.h" // for ArgSort, Sort +#include "../common/linalg_op.h" // for cbegin, cend +#include "../common/math.h" // for CmpFirst +#include "../common/optional_weight.h" // for OptionalWeights, MakeOptionalWeights +#include "../common/ranking_utils.h" // for LambdaRankParam, NDCGCache, ParseMetricName +#include "../common/threading_utils.h" // for ParallelFor +#include "../common/transform_iterator.h" // for IndexTransformIter +#include "dmlc/common.h" // for OMPException +#include "metric_common.h" // for MetricNoCache, GPUMetric, PackedReduceResult +#include "xgboost/base.h" // for bst_float, bst_omp_uint, bst_group_t, Args +#include "xgboost/cache.h" // for DMatrixCache +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for MetaInfo, DMatrix +#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/json.h" // for Json, FromJson, IsA, ToJson, get, Null, Object +#include "xgboost/linalg.h" // for Tensor, TensorView, Range, VectorView, MakeT... +#include "xgboost/logging.h" // for CHECK, ConsoleLogger, LOG_INFO, CHECK_EQ +#include "xgboost/metric.h" // for MetricReg, XGBOOST_REGISTER_METRIC, Metric +#include "xgboost/span.h" // for Span, operator!= +#include "xgboost/string_view.h" // for StringView namespace { -using PredIndPair = std::pair; +using PredIndPair = std::pair; using PredIndPairContainer = std::vector; /* @@ -87,8 +115,7 @@ class PerGroupWeightPolicy { } // anonymous namespace -namespace xgboost { -namespace metric { +namespace xgboost::metric { // tag the this file, used by force static link later. DMLC_REGISTRY_FILE_TAG(rank_metric); @@ -257,40 +284,6 @@ struct EvalPrecision : public EvalRank { } }; -/*! \brief NDCG: Normalized Discounted Cumulative Gain at N */ -struct EvalNDCG : public EvalRank { - private: - double CalcDCG(const PredIndPairContainer &rec) const { - double sumdcg = 0.0; - for (size_t i = 0; i < rec.size() && i < this->topn; ++i) { - const unsigned rel = rec[i].second; - if (rel != 0) { - sumdcg += ((1 << rel) - 1) / std::log2(i + 2.0); - } - } - return sumdcg; - } - - public: - explicit EvalNDCG(const char* name, const char* param) : EvalRank(name, param) {} - - double EvalGroup(PredIndPairContainer *recptr) const override { - PredIndPairContainer &rec(*recptr); - std::stable_sort(rec.begin(), rec.end(), common::CmpFirst); - double dcg = CalcDCG(rec); - std::stable_sort(rec.begin(), rec.end(), common::CmpSecond); - double idcg = CalcDCG(rec); - if (idcg == 0.0f) { - if (this->minus) { - return 0.0f; - } else { - return 1.0f; - } - } - return dcg/idcg; - } -}; - /*! \brief Mean Average Precision at N, for both classification and rank */ struct EvalMAP : public EvalRank { public: @@ -377,10 +370,6 @@ XGBOOST_REGISTER_METRIC(Precision, "pre") .describe("precision@k for rank.") .set_body([](const char* param) { return new EvalPrecision("pre", param); }); -XGBOOST_REGISTER_METRIC(NDCG, "ndcg") -.describe("ndcg@k for rank.") -.set_body([](const char* param) { return new EvalNDCG("ndcg", param); }); - XGBOOST_REGISTER_METRIC(MAP, "map") .describe("map@k for rank.") .set_body([](const char* param) { return new EvalMAP("map", param); }); @@ -388,5 +377,148 @@ XGBOOST_REGISTER_METRIC(MAP, "map") XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik") .describe("Negative log partial likelihood of Cox proportional hazards model.") .set_body([](const char*) { return new EvalCox(); }); -} // namespace metric -} // namespace xgboost + +// ranking metrics that requires cache +template +class EvalRankWithCache : public Metric { + protected: + ltr::LambdaRankParam param_; + bool minus_{false}; + std::string name_; + + DMatrixCache cache_{DMatrixCache::DefaultSize()}; + + public: + EvalRankWithCache(StringView name, const char* param) { + auto constexpr kMax = ltr::LambdaRankParam::NotSet(); + std::uint32_t topn{kMax}; + this->name_ = ltr::ParseMetricName(name, param, &topn, &minus_); + if (topn != kMax) { + param_.UpdateAllowUnknown(Args{{"lambdarank_num_pair_per_sample", std::to_string(topn)}, + {"lambdarank_pair_method", "topk"}}); + } + param_.UpdateAllowUnknown(Args{}); + } + void Configure(Args const&) override { + // do not configure, otherwise the ndcg param will be forced into the same as the one in + // objective. + } + void LoadConfig(Json const& in) override { + if (IsA(in)) { + return; + } + auto const& obj = get(in); + auto it = obj.find("lambdarank_param"); + if (it != obj.cend()) { + FromJson(it->second, ¶m_); + } + } + + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["name"] = String{this->Name()}; + out["lambdarank_param"] = ToJson(param_); + } + + double Evaluate(HostDeviceVector const& preds, std::shared_ptr p_fmat) override { + auto const& info = p_fmat->Info(); + auto p_cache = cache_.CacheItem(p_fmat, ctx_, info, param_); + if (p_cache->Param() != param_) { + p_cache = cache_.ResetItem(p_fmat, ctx_, info, param_); + } + CHECK(p_cache->Param() == param_); + CHECK_EQ(preds.Size(), info.labels.Size()); + + return this->Eval(preds, info, p_cache); + } + + virtual double Eval(HostDeviceVector const& preds, MetaInfo const& info, + std::shared_ptr p_cache) = 0; +}; + +namespace { +double Finalize(double score, double sw) { + std::array dat{score, sw}; + collective::Allreduce(dat.data(), dat.size()); + if (sw > 0.0) { + score = score / sw; + } + + CHECK_LE(score, 1.0 + kRtEps) + << "Invalid output score, might be caused by invalid query group weight."; + score = std::min(1.0, score); + + return score; +} +} // namespace + +/** + * \brief Implement the NDCG score function for learning to rank. + * + * Ties are ignored, which can lead to different result with other implementations. + */ +class EvalNDCG : public EvalRankWithCache { + public: + using EvalRankWithCache::EvalRankWithCache; + const char* Name() const override { return name_.c_str(); } + + double Eval(HostDeviceVector const& preds, MetaInfo const& info, + std::shared_ptr p_cache) override { + if (ctx_->IsCUDA()) { + auto ndcg = cuda_impl::NDCGScore(ctx_, info, preds, minus_, p_cache); + return Finalize(ndcg.Residue(), ndcg.Weights()); + } + + // group local ndcg + auto group_ptr = p_cache->DataGroupPtr(ctx_); + bst_group_t n_groups = group_ptr.size() - 1; + auto ndcg_gloc = p_cache->Dcg(ctx_); + std::fill_n(ndcg_gloc.Values().data(), ndcg_gloc.Size(), 0.0); + + auto h_inv_idcg = p_cache->InvIDCG(ctx_); + auto p_discount = p_cache->Discount(ctx_).data(); + + auto h_label = info.labels.HostView(); + auto h_predt = linalg::MakeTensorView(ctx_, &preds, preds.Size()); + auto weights = common::MakeOptionalWeights(ctx_, info.weights_); + + common::ParallelFor(n_groups, ctx_->Threads(), [&](auto g) { + auto g_predt = h_predt.Slice(linalg::Range(group_ptr[g], group_ptr[g + 1])); + auto g_labels = h_label.Slice(linalg::Range(group_ptr[g], group_ptr[g + 1]), 0); + auto sorted_idx = common::ArgSort(ctx_, linalg::cbegin(g_predt), + linalg::cend(g_predt), std::greater<>{}); + double ndcg{.0}; + double inv_idcg = h_inv_idcg(g); + if (inv_idcg <= 0.0) { + ndcg_gloc(g) = minus_ ? 0.0 : 1.0; + return; + } + std::size_t n{std::min(sorted_idx.size(), static_cast(param_.TopK()))}; + if (param_.ndcg_exp_gain) { + for (std::size_t i = 0; i < n; ++i) { + ndcg += p_discount[i] * ltr::CalcDCGGain(g_labels(sorted_idx[i])) * inv_idcg; + } + } else { + for (std::size_t i = 0; i < n; ++i) { + ndcg += p_discount[i] * g_labels(sorted_idx[i]) * inv_idcg; + } + } + ndcg_gloc(g) += ndcg * weights[g]; + }); + double sum_w{0}; + if (weights.Empty()) { + sum_w = n_groups; + } else { + sum_w = std::accumulate(weights.weights.cbegin(), weights.weights.cend(), 0.0); + } + auto ndcg = std::accumulate(linalg::cbegin(ndcg_gloc), linalg::cend(ndcg_gloc), 0.0); + return Finalize(ndcg, sum_w); + } +}; + +XGBOOST_REGISTER_METRIC(EvalNDCG, "ndcg") + .describe("ndcg@k for ranking.") + .set_body([](char const* param) { + return new EvalNDCG{"ndcg", param}; + }); +} // namespace xgboost::metric diff --git a/src/metric/rank_metric.cu b/src/metric/rank_metric.cu index 5f98db7a93cd..4ab422a96838 100644 --- a/src/metric/rank_metric.cu +++ b/src/metric/rank_metric.cu @@ -2,22 +2,29 @@ * Copyright 2020-2023 by XGBoost Contributors */ #include -#include // make_counting_iterator -#include // reduce -#include - -#include // std::size_t -#include // std::shared_ptr - -#include "../common/cuda_context.cuh" // CUDAContext +#include // for make_counting_iterator +#include // for reduce + +#include // for transform +#include // for size_t +#include // for shared_ptr +#include // for vector + +#include "../common/cuda_context.cuh" // for CUDAContext +#include "../common/device_helpers.cuh" // for MakeTransformIterator +#include "../common/optional_weight.h" // for MakeOptionalWeights +#include "../common/ranking_utils.cuh" // for CalcQueriesDCG, NDCGCache #include "metric_common.h" -#include "xgboost/base.h" // XGBOOST_DEVICE -#include "xgboost/context.h" // Context -#include "xgboost/data.h" // MetaInfo -#include "xgboost/host_device_vector.h" // HostDeviceVector - -namespace xgboost { -namespace metric { +#include "rank_metric.h" +#include "xgboost/base.h" // for XGBOOST_DEVICE +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for MetaInfo +#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/linalg.h" // for MakeTensorView +#include "xgboost/logging.h" // for CHECK +#include "xgboost/metric.h" + +namespace xgboost::metric { // tag the this file, used by force static link later. DMLC_REGISTRY_FILE_TAG(rank_metric_gpu); @@ -117,81 +124,6 @@ struct EvalPrecisionGpu { } }; -/*! \brief NDCG: Normalized Discounted Cumulative Gain at N */ -struct EvalNDCGGpu { - public: - static void ComputeDCG(const dh::SegmentSorter &pred_sorter, - const float *dlabels, - const EvalRankConfig &ecfg, - // The order in which labels have to be accessed. The order is determined - // by sorting the predictions or the labels for the entire dataset - const xgboost::common::Span &dlabels_sort_order, - dh::caching_device_vector *dcgptr) { - dh::caching_device_vector &dcgs(*dcgptr); - // Group info on device - const auto &dgroups = pred_sorter.GetGroupsSpan(); - const auto &dgroup_idx = pred_sorter.GetGroupSegmentsSpan(); - - // First, determine non zero labels in the dataset individually - auto DetermineNonTrivialLabelLambda = [=] __device__(uint32_t idx) { - return (static_cast(dlabels[dlabels_sort_order[idx]])); - }; // NOLINT - - // Find each group's DCG value - const auto nitems = pred_sorter.GetNumItems(); - auto *ddcgs = dcgs.data().get(); - - int device_id = -1; - dh::safe_cuda(cudaGetDevice(&device_id)); - - // For each group item compute the aggregated precision - dh::LaunchN(nitems, nullptr, [=] __device__(uint32_t idx) { - const auto group_idx = dgroup_idx[idx]; - const auto group_begin = dgroups[group_idx]; - const auto ridx = idx - group_begin; - auto label = DetermineNonTrivialLabelLambda(idx); - if (ridx < ecfg.topn && label) { - atomicAdd(&ddcgs[group_idx], ((1 << label) - 1) / std::log2(ridx + 2.0)); - } - }); - } - - static double EvalMetric(const dh::SegmentSorter &pred_sorter, - const float *dlabels, - const EvalRankConfig &ecfg) { - // Sort the labels and compute IDCG - dh::SegmentSorter segment_label_sorter; - segment_label_sorter.SortItems(dlabels, pred_sorter.GetNumItems(), - pred_sorter.GetGroupSegmentsSpan()); - - uint32_t ngroups = pred_sorter.GetNumGroups(); - - dh::caching_device_vector idcg(ngroups, 0); - ComputeDCG(pred_sorter, dlabels, ecfg, segment_label_sorter.GetOriginalPositionsSpan(), &idcg); - - // Compute the DCG values next - dh::caching_device_vector dcg(ngroups, 0); - ComputeDCG(pred_sorter, dlabels, ecfg, pred_sorter.GetOriginalPositionsSpan(), &dcg); - - double *ddcg = dcg.data().get(); - double *didcg = idcg.data().get(); - - int device_id = -1; - dh::safe_cuda(cudaGetDevice(&device_id)); - // Compute the group's DCG and reduce it across all groups - dh::LaunchN(ngroups, nullptr, [=] __device__(uint32_t gidx) { - if (didcg[gidx] == 0.0f) { - ddcg[gidx] = (ecfg.minus) ? 0.0f : 1.0f; - } else { - ddcg[gidx] /= didcg[gidx]; - } - }); - - // Allocator to be used for managing space overhead while performing reductions - dh::XGBCachingDeviceAllocator alloc; - return thrust::reduce(thrust::cuda::par(alloc), dcg.begin(), dcg.end()); - } -}; /*! \brief Mean Average Precision at N, for both classification and rank */ struct EvalMAPGpu { @@ -272,12 +204,46 @@ XGBOOST_REGISTER_GPU_METRIC(PrecisionGpu, "pre") .describe("precision@k for rank computed on GPU.") .set_body([](const char* param) { return new EvalRankGpu("pre", param); }); -XGBOOST_REGISTER_GPU_METRIC(NDCGGpu, "ndcg") -.describe("ndcg@k for rank computed on GPU.") -.set_body([](const char* param) { return new EvalRankGpu("ndcg", param); }); - XGBOOST_REGISTER_GPU_METRIC(MAPGpu, "map") .describe("map@k for rank computed on GPU.") .set_body([](const char* param) { return new EvalRankGpu("map", param); }); -} // namespace metric -} // namespace xgboost + +namespace cuda_impl { +PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info, + HostDeviceVector const &predt, bool minus, + std::shared_ptr p_cache) { + CHECK(p_cache); + + auto const &p = p_cache->Param(); + auto d_weight = common::MakeOptionalWeights(ctx, info.weights_); + if (!d_weight.Empty()) { + CHECK_EQ(d_weight.weights.size(), p_cache->Groups()); + } + auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0); + predt.SetDevice(ctx->gpu_id); + auto d_predt = linalg::MakeTensorView(ctx, predt.ConstDeviceSpan(), predt.Size()); + + auto d_group_ptr = p_cache->DataGroupPtr(ctx); + auto n_groups = info.group_ptr_.size() - 1; + + auto d_inv_idcg = p_cache->InvIDCG(ctx); + auto d_sorted_idx = p_cache->SortedIdx(ctx, d_predt.Values()); + auto d_out_dcg = p_cache->Dcg(ctx); + + ltr::cuda_impl::CalcQueriesDCG(ctx, d_label, d_sorted_idx, p.ndcg_exp_gain, d_group_ptr, p.TopK(), + d_out_dcg); + + auto it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) { + if (d_inv_idcg(i) <= 0.0) { + return PackedReduceResult{minus ? 0.0 : 1.0, static_cast(d_weight[i])}; + } + return PackedReduceResult{d_out_dcg(i) * d_inv_idcg(i) * d_weight[i], + static_cast(d_weight[i])}; + }); + auto pair = thrust::reduce(ctx->CUDACtx()->CTP(), it, it + d_out_dcg.Size(), + PackedReduceResult{0.0, 0.0}); + return pair; +} +} // namespace cuda_impl +} // namespace xgboost::metric diff --git a/src/metric/rank_metric.h b/src/metric/rank_metric.h new file mode 100644 index 000000000000..0be0d4ee8c94 --- /dev/null +++ b/src/metric/rank_metric.h @@ -0,0 +1,33 @@ +#ifndef XGBOOST_METRIC_RANK_METRIC_H_ +#define XGBOOST_METRIC_RANK_METRIC_H_ +/** + * Copyright 2023 by XGBoost Contributors + */ +#include // for shared_ptr + +#include "../common/common.h" // for AssertGPUSupport +#include "../common/ranking_utils.h" // for NDCGCache +#include "metric_common.h" // for PackedReduceResult +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for MetaInfo +#include "xgboost/host_device_vector.h" // for HostDeviceVector + +namespace xgboost { +namespace metric { +namespace cuda_impl { +PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info, + HostDeviceVector const &predt, bool minus, + std::shared_ptr p_cache); + +#if !defined(XGBOOST_USE_CUDA) +inline PackedReduceResult NDCGScore(Context const *, MetaInfo const &, + HostDeviceVector const &, bool, + std::shared_ptr) { + common::AssertGPUSupport(); + return {}; +} +#endif +} // namespace cuda_impl +} // namespace metric +} // namespace xgboost +#endif // XGBOOST_METRIC_RANK_METRIC_H_ diff --git a/tests/cpp/metric/test_rank_metric.cc b/tests/cpp/metric/test_rank_metric.cc index 1edbd9fc8d76..337ddbc8afa5 100644 --- a/tests/cpp/metric/test_rank_metric.cc +++ b/tests/cpp/metric/test_rank_metric.cc @@ -1,7 +1,20 @@ -// Copyright by Contributors -#include - -#include "../helpers.h" +/** + * Copyright 2016-2023 by XGBoost Contributors + */ +#include // for Test, EXPECT_NEAR, ASSERT_STREQ +#include // for Context +#include // for MetaInfo, DMatrix +#include // for Matrix +#include // for Metric + +#include // for max +#include // for unique_ptr +#include // for vector + +#include "../helpers.h" // for GetMetricEval, CreateEmptyGe... +#include "xgboost/base.h" // for bst_float, kRtEps +#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/json.h" // for Json, String, Object #if !defined(__CUDACC__) TEST(Metric, AMS) { @@ -51,15 +64,17 @@ TEST(Metric, DeclareUnifiedTest(Precision)) { delete metric; } +namespace xgboost { +namespace metric { TEST(Metric, DeclareUnifiedTest(NDCG)) { - auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX); - xgboost::Metric * metric = xgboost::Metric::Create("ndcg", &ctx); + auto ctx = CreateEmptyGenericParam(GPUIDX); + Metric * metric = xgboost::Metric::Create("ndcg", &ctx); ASSERT_STREQ(metric->Name(), "ndcg"); EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {})); - EXPECT_NEAR(GetMetricEval(metric, + ASSERT_NEAR(GetMetricEval(metric, xgboost::HostDeviceVector{}, {}), 1, 1e-10); - EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10); + ASSERT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10); EXPECT_NEAR(GetMetricEval(metric, {0.1f, 0.9f, 0.1f, 0.9f}, { 0, 0, 1, 1}), @@ -80,7 +95,7 @@ TEST(Metric, DeclareUnifiedTest(NDCG)) { EXPECT_NEAR(GetMetricEval(metric, xgboost::HostDeviceVector{}, {}), 0, 1e-10); - EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10); + ASSERT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1.f, 1e-10); EXPECT_NEAR(GetMetricEval(metric, {0.1f, 0.9f, 0.1f, 0.9f}, { 0, 0, 1, 1}), @@ -91,29 +106,30 @@ TEST(Metric, DeclareUnifiedTest(NDCG)) { EXPECT_NEAR(GetMetricEval(metric, xgboost::HostDeviceVector{}, {}), 0, 1e-10); - EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10); + EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1.f, 1e-10); EXPECT_NEAR(GetMetricEval(metric, {0.1f, 0.9f, 0.1f, 0.9f}, { 0, 0, 1, 1}), - 0.6509f, 0.001f); + 0.6509f, 0.001f); delete metric; metric = xgboost::Metric::Create("ndcg@2-", &ctx); ASSERT_STREQ(metric->Name(), "ndcg@2-"); - EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10); + EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1.f, 1e-10); EXPECT_NEAR(GetMetricEval(metric, {0.1f, 0.9f, 0.1f, 0.9f}, { 0, 0, 1, 1}), - 0.3868f, 0.001f); + 1.f - 0.3868f, 1.f - 0.001f); delete metric; } TEST(Metric, DeclareUnifiedTest(MAP)) { auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX); - xgboost::Metric * metric = xgboost::Metric::Create("map", &ctx); + Metric * metric = xgboost::Metric::Create("map", &ctx); ASSERT_STREQ(metric->Name(), "map"); - EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10); + EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, kRtEps); + EXPECT_NEAR(GetMetricEval(metric, {0.1f, 0.9f, 0.1f, 0.9f}, { 0, 0, 1, 1}), @@ -154,3 +170,39 @@ TEST(Metric, DeclareUnifiedTest(MAP)) { 0.25f, 0.001f); delete metric; } + +TEST(Metric, DeclareUnifiedTest(NDCGExpGain)) { + Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX); + + auto p_fmat = xgboost::RandomDataGenerator{0, 0, 0}.GenerateDMatrix(); + MetaInfo& info = p_fmat->Info(); + info.labels = linalg::Matrix{{10.0f, 0.0f, 0.0f, 1.0f, 5.0f}, {5}, ctx.gpu_id}; + info.num_row_ = info.labels.Shape(0); + info.group_ptr_.resize(2); + info.group_ptr_[0] = 0; + info.group_ptr_[1] = info.num_row_; + HostDeviceVector predt{{0.1f, 0.2f, 0.3f, 4.0f, 70.0f}}; + + std::unique_ptr metric{Metric::Create("ndcg", &ctx)}; + Json config{Object{}}; + config["name"] = String{"ndcg"}; + config["lambdarank_param"] = Object{}; + config["lambdarank_param"]["ndcg_exp_gain"] = String{"true"}; + config["lambdarank_param"]["lambdarank_num_pair_per_sample"] = String{"32"}; + metric->LoadConfig(config); + + auto ndcg = metric->Evaluate(predt, p_fmat); + ASSERT_NEAR(ndcg, 0.409738f, kRtEps); + + config["lambdarank_param"]["ndcg_exp_gain"] = String{"false"}; + metric->LoadConfig(config); + + ndcg = metric->Evaluate(predt, p_fmat); + ASSERT_NEAR(ndcg, 0.695694f, kRtEps); + + predt.HostVector() = info.labels.Data()->HostVector(); + ndcg = metric->Evaluate(predt, p_fmat); + ASSERT_NEAR(ndcg, 1.0, kRtEps); +} +} // namespace metric +} // namespace xgboost