diff --git a/doc/parameter.rst b/doc/parameter.rst index 99d6f0585936..ac566af749f9 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -408,8 +408,17 @@ Specify the learning task and the corresponding learning objective. The objectiv - ``ndcg``: `Normalized Discounted Cumulative Gain `_ - ``map``: `Mean Average Precision `_ - - ``ndcg@n``, ``map@n``: 'n' can be assigned as an integer to cut off the top positions in the lists for evaluation. - - ``ndcg-``, ``map-``, ``ndcg@n-``, ``map@n-``: In XGBoost, NDCG and MAP will evaluate the score of a list without any positive samples as 1. By adding "-" in the evaluation metric XGBoost will evaluate these score as 0 to be consistent under some conditions. + + The `average precision` is defined as: + + .. math:: + + AP@l = \frac{1}{\min(l, N)}\sum^l_{k=1}P@k \cdot I_{(k)} + + where :math:`I_{(k)}` is an indicator function that equals :math:`1` when the document at position :math:`k` is relevant and :math:`0` otherwise. The :math:`P@k` is the precision at :math:`k`, and :math:`N` is the total number of relevant documents. Lastly, the `mean average precision` is defined as the weighted average across all queries. + + - ``ndcg@n``, ``map@n``: :math:`n` can be assigned as an integer to cut off the top positions in the lists for evaluation. + - ``ndcg-``, ``map-``, ``ndcg@n-``, ``map@n-``: In XGBoost, the NDCG and MAP evaluate the score of a list without any positive samples as :math:`1`. By appending "-" to the evaluation metric name, we can ask XGBoost to evaluate these scores as :math:`0` to be consistent under some conditions. 
- ``poisson-nloglik``: negative log-likelihood for Poisson regression - ``gamma-nloglik``: negative log-likelihood for gamma regression - ``cox-nloglik``: negative partial log-likelihood for Cox proportional hazards regression diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py index 3b33e87749f3..bb13b5523ed2 100644 --- a/python-package/xgboost/testing/__init__.py +++ b/python-package/xgboost/testing/__init__.py @@ -14,6 +14,7 @@ from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager from io import StringIO +from pathlib import Path from platform import system from typing import ( Any, @@ -443,7 +444,7 @@ def get_mq2008( from sklearn.datasets import load_svmlight_files src = "https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip" - target = dpath + "/MQ2008.zip" + target = os.path.join(os.path.expanduser(dpath), "MQ2008.zip") if not os.path.exists(target): request.urlretrieve(url=src, filename=target) @@ -462,9 +463,9 @@ def get_mq2008( qid_valid, ) = load_svmlight_files( ( - dpath + "MQ2008/Fold1/train.txt", - dpath + "MQ2008/Fold1/test.txt", - dpath + "MQ2008/Fold1/vali.txt", + Path(dpath) / "MQ2008" / "Fold1" / "train.txt", + Path(dpath) / "MQ2008" / "Fold1" / "test.txt", + Path(dpath) / "MQ2008" / "Fold1" / "vali.txt", ), query_id=True, zero_based=False, diff --git a/python-package/xgboost/testing/ranking.py b/python-package/xgboost/testing/ranking.py index fe4fc8404567..7c75012c2bc6 100644 --- a/python-package/xgboost/testing/ranking.py +++ b/python-package/xgboost/testing/ranking.py @@ -48,7 +48,12 @@ def run_ranking_qid_df(impl: ModuleType, tree_method: str) -> None: def neg_mse(*args: Any, **kwargs: Any) -> float: return -float(mean_squared_error(*args, **kwargs)) - ranker = xgb.XGBRanker(n_estimators=3, eval_metric=neg_mse, tree_method=tree_method) + ranker = xgb.XGBRanker( + n_estimators=3, + eval_metric=neg_mse, + tree_method=tree_method, + 
disable_default_eval_metric=True, + ) ranker.fit(df, y, eval_set=[(valid_df, y)]) score = ranker.score(valid_df, y) assert np.isclose(score, ranker.evals_result()["validation_0"]["neg_mse"][-1]) diff --git a/src/common/error_msg.h b/src/common/error_msg.h index 484595316e26..3dbb7f52c150 100644 --- a/src/common/error_msg.h +++ b/src/common/error_msg.h @@ -22,7 +22,7 @@ constexpr StringView LabelScoreSize() { } constexpr StringView InfInData() { - return "Input data contains `inf` while `missing` is not set to `inf`"; + return "Input data contains `inf` or a value too large, while `missing` is not set to `inf`"; } } // namespace xgboost::error #endif // XGBOOST_COMMON_ERROR_MSG_H_ diff --git a/src/common/numeric.h b/src/common/numeric.h index 6a1c15fd08b4..2da85502ad17 100644 --- a/src/common/numeric.h +++ b/src/common/numeric.h @@ -1,13 +1,15 @@ -/*! - * Copyright 2022, XGBoost contributors. +/** + * Copyright 2022-2023 by XGBoost contributors. */ #ifndef XGBOOST_COMMON_NUMERIC_H_ #define XGBOOST_COMMON_NUMERIC_H_ #include // OMPException -#include // std::max -#include // std::iterator_traits +#include // for std::max +#include // for size_t +#include // for int32_t +#include // for iterator_traits #include #include "common.h" // AssertGPUSupport @@ -15,8 +17,7 @@ #include "xgboost/context.h" // Context #include "xgboost/host_device_vector.h" // HostDeviceVector -namespace xgboost { -namespace common { +namespace xgboost::common { /** * \brief Run length encode on CPU, input must be sorted. 
@@ -111,11 +112,11 @@ inline double Reduce(Context const*, HostDeviceVector const&) { namespace cpu_impl { template V Reduce(Context const* ctx, It first, It second, V const& init) { - size_t n = std::distance(first, second); - common::MemStackAllocator result_tloc(ctx->Threads(), init); - common::ParallelFor(n, ctx->Threads(), - [&](auto i) { result_tloc[omp_get_thread_num()] += first[i]; }); - auto result = std::accumulate(result_tloc.cbegin(), result_tloc.cbegin() + ctx->Threads(), init); + std::size_t n = std::distance(first, second); + auto n_threads = static_cast(std::min(n, static_cast(ctx->Threads()))); + common::MemStackAllocator result_tloc(n_threads, init); + common::ParallelFor(n, n_threads, [&](auto i) { result_tloc[omp_get_thread_num()] += first[i]; }); + auto result = std::accumulate(result_tloc.cbegin(), result_tloc.cbegin() + n_threads, init); return result; } } // namespace cpu_impl @@ -144,7 +145,6 @@ void Iota(Context const* ctx, It first, It last, }); } } -} // namespace common -} // namespace xgboost +} // namespace xgboost::common #endif // XGBOOST_COMMON_NUMERIC_H_ diff --git a/src/common/ranking_utils.cc b/src/common/ranking_utils.cc index c8069784b301..d831b551c7d0 100644 --- a/src/common/ranking_utils.cc +++ b/src/common/ranking_utils.cc @@ -114,6 +114,15 @@ void NDCGCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUS DMLC_REGISTER_PARAMETER(LambdaRankParam); +void MAPCache::InitOnCPU(Context const*, MetaInfo const& info) { + auto const& h_label = info.labels.HostView().Slice(linalg::All(), 0); + CheckMapLabels(h_label, [](auto beg, auto end, auto op) { return std::all_of(beg, end, op); }); +} + +#if !defined(XGBOOST_USE_CUDA) +void MAPCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUSupport(); } +#endif // !defined(XGBOOST_USE_CUDA) + std::string ParseMetricName(StringView name, StringView param, position_t* topn, bool* minus) { std::string out_name; if (!param.empty()) { diff --git 
a/src/common/ranking_utils.cu b/src/common/ranking_utils.cu index ce9cda4e24e5..8fbf89818cf6 100644 --- a/src/common/ranking_utils.cu +++ b/src/common/ranking_utils.cu @@ -204,4 +204,9 @@ void NDCGCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) { dh::LaunchN(MaxGroupSize(), cuctx->Stream(), [=] XGBOOST_DEVICE(std::size_t i) { d_discount[i] = CalcDCGDiscount(i); }); } + +void MAPCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) { + auto const d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0); + CheckMapLabels(d_label, CheckMAPOp{ctx->CUDACtx()}); +} } // namespace xgboost::ltr diff --git a/src/common/ranking_utils.h b/src/common/ranking_utils.h index 88283fba286e..727f918f26ed 100644 --- a/src/common/ranking_utils.h +++ b/src/common/ranking_utils.h @@ -358,6 +358,71 @@ void CheckNDCGLabels(ltr::LambdaRankParam const& p, linalg::VectorView +bool IsBinaryRel(linalg::VectorView label, AllOf all_of) { + auto s_label = label.Values(); + return all_of(s_label.data(), s_label.data() + s_label.size(), [] XGBOOST_DEVICE(float y) { + return std::abs(y - 1.0f) < kRtEps || std::abs(y - 0.0f) < kRtEps; + }); +} +/** + * \brief Validate label for MAP + * + * \tparam Implementation of std::all_of. Specified as a parameter to reuse the check for + * both CPU and GPU. + */ +template +void CheckMapLabels(linalg::VectorView label, AllOf all_of) { + auto s_label = label.Values(); + auto is_binary = IsBinaryRel(label, all_of); + CHECK(is_binary) << "MAP can only be used with binary labels."; +} + +class MAPCache : public RankingCache { + // Total number of relevant documents for each group + HostDeviceVector n_rel_; + // \sum l_k/k + HostDeviceVector acc_; + HostDeviceVector map_; + // Number of samples in this dataset. 
+ std::size_t n_samples_{0}; + + void InitOnCPU(Context const* ctx, MetaInfo const& info); + void InitOnCUDA(Context const* ctx, MetaInfo const& info); + + public: + MAPCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p) + : RankingCache{ctx, info, p}, n_samples_{static_cast(info.num_row_)} { + if (ctx->IsCPU()) { + this->InitOnCPU(ctx, info); + } else { + this->InitOnCUDA(ctx, info); + } + } + + common::Span NumRelevant(Context const* ctx) { + if (n_rel_.Empty()) { + n_rel_.SetDevice(ctx->gpu_id); + n_rel_.Resize(n_samples_); + } + return ctx->IsCPU() ? n_rel_.HostSpan() : n_rel_.DeviceSpan(); + } + common::Span Acc(Context const* ctx) { + if (acc_.Empty()) { + acc_.SetDevice(ctx->gpu_id); + acc_.Resize(n_samples_); + } + return ctx->IsCPU() ? acc_.HostSpan() : acc_.DeviceSpan(); + } + common::Span Map(Context const* ctx) { + if (map_.Empty()) { + map_.SetDevice(ctx->gpu_id); + map_.Resize(this->Groups()); + } + return ctx->IsCPU() ? map_.HostSpan() : map_.DeviceSpan(); + } +}; + /** * \brief Parse name for ranking metric given parameters. 
* diff --git a/src/common/threading_utils.h b/src/common/threading_utils.h index a52695e02590..d80008cc0809 100644 --- a/src/common/threading_utils.h +++ b/src/common/threading_utils.h @@ -8,9 +8,11 @@ #include #include -#include // std::int32_t +#include // for int32_t +#include // for malloc, free #include -#include // std::is_signed +#include // for bad_alloc +#include // for is_signed #include #include "xgboost/logging.h" @@ -266,7 +268,7 @@ class MemStackAllocator { if (MaxStackSize >= required_size_) { ptr_ = stack_mem_; } else { - ptr_ = reinterpret_cast(malloc(required_size_ * sizeof(T))); + ptr_ = reinterpret_cast(std::malloc(required_size_ * sizeof(T))); } if (!ptr_) { throw std::bad_alloc{}; @@ -278,7 +280,7 @@ class MemStackAllocator { ~MemStackAllocator() { if (required_size_ > MaxStackSize) { - free(ptr_); + std::free(ptr_); } } T& operator[](size_t i) { return ptr_[i]; } diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index c2aa48cab853..3a1416b0ff44 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -284,37 +284,6 @@ struct EvalPrecision : public EvalRank { } }; -/*! \brief Mean Average Precision at N, for both classification and rank */ -struct EvalMAP : public EvalRank { - public: - explicit EvalMAP(const char* name, const char* param) : EvalRank(name, param) {} - - double EvalGroup(PredIndPairContainer *recptr) const override { - PredIndPairContainer &rec(*recptr); - std::stable_sort(rec.begin(), rec.end(), common::CmpFirst); - unsigned nhits = 0; - double sumap = 0.0; - for (size_t i = 0; i < rec.size(); ++i) { - if (rec[i].second != 0) { - nhits += 1; - if (i < this->topn) { - sumap += static_cast(nhits) / (i + 1); - } - } - } - if (nhits != 0) { - sumap /= nhits; - return sumap; - } else { - if (this->minus) { - return 0.0; - } else { - return 1.0; - } - } - } -}; - /*! 
\brief Cox: Partial likelihood of the Cox proportional hazards model */ struct EvalCox : public MetricNoCache { public: @@ -370,10 +339,6 @@ XGBOOST_REGISTER_METRIC(Precision, "pre") .describe("precision@k for rank.") .set_body([](const char* param) { return new EvalPrecision("pre", param); }); -XGBOOST_REGISTER_METRIC(MAP, "map") -.describe("map@k for rank.") -.set_body([](const char* param) { return new EvalMAP("map", param); }); - XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik") .describe("Negative log partial likelihood of Cox proportional hazards model.") .set_body([](const char*) { return new EvalCox(); }); @@ -516,6 +481,68 @@ class EvalNDCG : public EvalRankWithCache { } }; +class EvalMAPScore : public EvalRankWithCache { + public: + using EvalRankWithCache::EvalRankWithCache; + const char* Name() const override { return name_.c_str(); } + + double Eval(HostDeviceVector const& predt, MetaInfo const& info, + std::shared_ptr p_cache) override { + if (ctx_->IsCUDA()) { + auto map = cuda_impl::MAPScore(ctx_, info, predt, minus_, p_cache); + return Finalize(map.Residue(), map.Weights()); + } + + auto gptr = p_cache->DataGroupPtr(ctx_); + auto h_label = info.labels.HostView().Slice(linalg::All(), 0); + auto h_predt = linalg::MakeTensorView(ctx_, &predt, predt.Size()); + + auto map_gloc = p_cache->Map(ctx_); + std::fill_n(map_gloc.data(), map_gloc.size(), 0.0); + auto rank_idx = p_cache->SortedIdx(ctx_, predt.ConstHostSpan()); + + common::ParallelFor(p_cache->Groups(), ctx_->Threads(), [&](auto g) { + auto g_predt = h_predt.Slice(linalg::Range(gptr[g], gptr[g + 1])); + auto g_label = h_label.Slice(linalg::Range(gptr[g], gptr[g + 1])); + auto g_rank = rank_idx.subspan(gptr[g]); + + auto n = std::min(static_cast(param_.TopK()), g_label.Size()); + double n_hits{0.0}; + for (std::size_t i = 0; i < n; ++i) { + auto p = g_label(g_rank[i]); + n_hits += p; + map_gloc[g] += n_hits / static_cast((i + 1)) * p; + } + for (std::size_t i = n; i < g_label.Size(); ++i) { + n_hits += 
g_label(g_rank[i]); + } + if (n_hits > 0.0) { + map_gloc[g] /= std::min(n_hits, static_cast(param_.TopK())); + } else { + map_gloc[g] = minus_ ? 0.0 : 1.0; + } + }); + + auto sw = 0.0; + auto weight = common::MakeOptionalWeights(ctx_, info.weights_); + if (!weight.Empty()) { + CHECK_EQ(weight.weights.size(), p_cache->Groups()); + } + for (std::size_t i = 0; i < map_gloc.size(); ++i) { + map_gloc[i] = map_gloc[i] * weight[i]; + sw += weight[i]; + } + auto sum = std::accumulate(map_gloc.cbegin(), map_gloc.cend(), 0.0); + return Finalize(sum, sw); + } +}; + +XGBOOST_REGISTER_METRIC(EvalMAP, "map") + .describe("map@k for ranking.") + .set_body([](char const* param) { + return new EvalMAPScore{"map", param}; + }); + XGBOOST_REGISTER_METRIC(EvalNDCG, "ndcg") .describe("ndcg@k for ranking.") .set_body([](char const* param) { diff --git a/src/metric/rank_metric.cu b/src/metric/rank_metric.cu index 4ab422a96838..00116ebdb2ad 100644 --- a/src/metric/rank_metric.cu +++ b/src/metric/rank_metric.cu @@ -125,89 +125,10 @@ struct EvalPrecisionGpu { }; -/*! \brief Mean Average Precision at N, for both classification and rank */ -struct EvalMAPGpu { - public: - static double EvalMetric(const dh::SegmentSorter &pred_sorter, - const float *dlabels, - const EvalRankConfig &ecfg) { - // Group info on device - const auto &dgroups = pred_sorter.GetGroupsSpan(); - const auto ngroups = pred_sorter.GetNumGroups(); - const auto &dgroup_idx = pred_sorter.GetGroupSegmentsSpan(); - - // Original positions of the predictions after they have been sorted - const auto &dpreds_orig_pos = pred_sorter.GetOriginalPositionsSpan(); - - // First, determine non zero labels in the dataset individually - const auto nitems = pred_sorter.GetNumItems(); - dh::caching_device_vector hits(nitems, 0); - auto DetermineNonTrivialLabelLambda = [=] __device__(uint32_t idx) { - return (static_cast(dlabels[dpreds_orig_pos[idx]]) != 0) ? 
1 : 0; - }; // NOLINT - - thrust::transform(thrust::make_counting_iterator(static_cast(0)), - thrust::make_counting_iterator(nitems), - hits.begin(), - DetermineNonTrivialLabelLambda); - - // Allocator to be used by sort for managing space overhead while performing prefix scans - dh::XGBCachingDeviceAllocator alloc; - - // Next, prefix scan the nontrivial labels that are segmented to accumulate them. - // This is required for computing the metric sum - // Data segmented into different groups... - thrust::inclusive_scan_by_key(thrust::cuda::par(alloc), - dh::tcbegin(dgroup_idx), dh::tcend(dgroup_idx), - hits.begin(), // Input value - hits.begin()); // In-place scan - - // Find each group's metric sum - dh::caching_device_vector sumap(ngroups, 0); - auto *dsumap = sumap.data().get(); - const auto *dhits = hits.data().get(); - - int device_id = -1; - dh::safe_cuda(cudaGetDevice(&device_id)); - // For each group item compute the aggregated precision - dh::LaunchN(nitems, nullptr, [=] __device__(uint32_t idx) { - if (DetermineNonTrivialLabelLambda(idx)) { - const auto group_idx = dgroup_idx[idx]; - const auto group_begin = dgroups[group_idx]; - const auto ridx = idx - group_begin; - if (ridx < ecfg.topn) { - atomicAdd(&dsumap[group_idx], - static_cast(dhits[idx]) / (ridx + 1)); - } - } - }); - - // Aggregate the group's item precisions - dh::LaunchN(ngroups, nullptr, [=] __device__(uint32_t gidx) { - auto nhits = dgroups[gidx + 1] ? 
dhits[dgroups[gidx + 1] - 1] : 0; - if (nhits != 0) { - dsumap[gidx] /= nhits; - } else { - if (ecfg.minus) { - dsumap[gidx] = 0; - } else { - dsumap[gidx] = 1; - } - } - }); - - return thrust::reduce(thrust::cuda::par(alloc), sumap.begin(), sumap.end()); - } -}; - XGBOOST_REGISTER_GPU_METRIC(PrecisionGpu, "pre") .describe("precision@k for rank computed on GPU.") .set_body([](const char* param) { return new EvalRankGpu("pre", param); }); -XGBOOST_REGISTER_GPU_METRIC(MAPGpu, "map") -.describe("map@k for rank computed on GPU.") -.set_body([](const char* param) { return new EvalRankGpu("map", param); }); - namespace cuda_impl { PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info, HostDeviceVector const &predt, bool minus, @@ -245,5 +166,87 @@ PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info, PackedReduceResult{0.0, 0.0}); return pair; } + +PackedReduceResult MAPScore(Context const *ctx, MetaInfo const &info, + HostDeviceVector const &predt, bool minus, + std::shared_ptr p_cache) { + auto d_group_ptr = p_cache->DataGroupPtr(ctx); + auto n_groups = info.group_ptr_.size() - 1; + auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0); + + predt.SetDevice(ctx->gpu_id); + auto d_rank_idx = p_cache->SortedIdx(ctx, predt.ConstDeviceSpan()); + auto key_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), + [=] XGBOOST_DEVICE(std::size_t i) { return dh::SegmentId(d_group_ptr, i); }); + + auto get_label = [=] XGBOOST_DEVICE(std::size_t i) { + auto g = key_it[i]; + auto g_begin = d_group_ptr[g]; + auto g_end = d_group_ptr[g + 1]; + i -= g_begin; + auto g_label = d_label.Slice(linalg::Range(g_begin, g_end)); + auto g_rank = d_rank_idx.subspan(g_begin, g_end - g_begin); + return g_label(g_rank[i]); + }; + auto it = dh::MakeTransformIterator(thrust::make_counting_iterator(0ul), get_label); + + auto cuctx = ctx->CUDACtx(); + auto n_rel = p_cache->NumRelevant(ctx); + thrust::inclusive_scan_by_key(cuctx->CTP(), 
key_it, key_it + d_label.Size(), it, n_rel.data()); + + double topk = p_cache->Param().TopK(); + auto map = p_cache->Map(ctx); + thrust::fill_n(cuctx->CTP(), map.data(), map.size(), 0.0); + { + auto val_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) { + auto g = key_it[i]; + auto g_begin = d_group_ptr[g]; + auto g_end = d_group_ptr[g + 1]; + i -= g_begin; + if (i >= topk) { + return 0.0; + } + + auto g_label = d_label.Slice(linalg::Range(g_begin, g_end)); + auto g_rank = d_rank_idx.subspan(g_begin, g_end - g_begin); + auto label = g_label(g_rank[i]); + + auto g_n_rel = n_rel.subspan(g_begin, g_end - g_begin); + auto nhits = g_n_rel[i]; + return nhits / static_cast(i + 1) * label; + }); + + std::size_t bytes; + cub::DeviceSegmentedReduce::Sum(nullptr, bytes, val_it, map.data(), p_cache->Groups(), + d_group_ptr.data(), d_group_ptr.data() + 1, cuctx->Stream()); + dh::TemporaryArray temp(bytes); + cub::DeviceSegmentedReduce::Sum(temp.data().get(), bytes, val_it, map.data(), p_cache->Groups(), + d_group_ptr.data(), d_group_ptr.data() + 1, cuctx->Stream()); + } + + PackedReduceResult result{0.0, 0.0}; + { + auto d_weight = common::MakeOptionalWeights(ctx, info.weights_); + if (!d_weight.Empty()) { + CHECK_EQ(d_weight.weights.size(), p_cache->Groups()); + } + auto val_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t g) { + auto g_begin = d_group_ptr[g]; + auto g_end = d_group_ptr[g + 1]; + auto g_n_rel = n_rel.subspan(g_begin, g_end - g_begin); + if (!g_n_rel.empty() && g_n_rel.back() > 0.0) { + return PackedReduceResult{map[g] * d_weight[g] / std::min(g_n_rel.back(), topk), + static_cast(d_weight[g])}; + } + return PackedReduceResult{minus ? 
0.0 : 1.0, static_cast(d_weight[g])}; + }); + result = + thrust::reduce(cuctx->CTP(), val_it, val_it + map.size(), PackedReduceResult{0.0, 0.0}); + } + return result; +} } // namespace cuda_impl } // namespace xgboost::metric diff --git a/src/metric/rank_metric.h b/src/metric/rank_metric.h index 0be0d4ee8c94..b3b121973ef8 100644 --- a/src/metric/rank_metric.h +++ b/src/metric/rank_metric.h @@ -6,7 +6,7 @@ #include // for shared_ptr #include "../common/common.h" // for AssertGPUSupport -#include "../common/ranking_utils.h" // for NDCGCache +#include "../common/ranking_utils.h" // for NDCGCache, MAPCache #include "metric_common.h" // for PackedReduceResult #include "xgboost/context.h" // for Context #include "xgboost/data.h" // for MetaInfo @@ -19,6 +19,10 @@ PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info, HostDeviceVector const &predt, bool minus, std::shared_ptr p_cache); +PackedReduceResult MAPScore(Context const *ctx, MetaInfo const &info, + HostDeviceVector const &predt, bool minus, + std::shared_ptr p_cache); + #if !defined(XGBOOST_USE_CUDA) inline PackedReduceResult NDCGScore(Context const *, MetaInfo const &, HostDeviceVector const &, bool, @@ -26,6 +30,13 @@ inline PackedReduceResult NDCGScore(Context const *, MetaInfo const &, common::AssertGPUSupport(); return {}; } + +inline PackedReduceResult MAPScore(Context const *, MetaInfo const &, + HostDeviceVector const &, bool, + std::shared_ptr) { + common::AssertGPUSupport(); + return {}; +} #endif } // namespace cuda_impl } // namespace metric diff --git a/tests/cpp/common/test_ranking_utils.cc b/tests/cpp/common/test_ranking_utils.cc index 9240db0d4814..919102278b98 100644 --- a/tests/cpp/common/test_ranking_utils.cc +++ b/tests/cpp/common/test_ranking_utils.cc @@ -177,4 +177,36 @@ TEST(NDCGCache, InitFromCPU) { Context ctx; TestNDCGCache(&ctx); } + +void TestMAPCache(Context const* ctx) { + auto p_fmat = EmptyDMatrix(); + MetaInfo& info = p_fmat->Info(); + LambdaRankParam param; + 
param.UpdateAllowUnknown(Args{}); + + std::vector h_data(32); + + common::Iota(ctx, h_data.begin(), h_data.end(), 0.0f); + info.labels.Reshape(h_data.size()); + info.num_row_ = h_data.size(); + info.labels.Data()->HostVector() = std::move(h_data); + + auto fail = [&]() { std::make_shared(ctx, info, param); }; + // binary label + ASSERT_THROW(fail(), dmlc::Error); + + h_data = std::vector(32, 0.0f); + h_data[1] = 1.0f; + info.labels.Data()->HostVector() = h_data; + auto p_cache = std::make_shared(ctx, info, param); + + ASSERT_EQ(p_cache->Acc(ctx).size(), info.num_row_); + ASSERT_EQ(p_cache->NumRelevant(ctx).size(), info.num_row_); +} + +TEST(MAPCache, InitFromCPU) { + Context ctx; + ctx.Init(Args{}); + TestMAPCache(&ctx); +} } // namespace xgboost::ltr diff --git a/tests/cpp/common/test_ranking_utils.cu b/tests/cpp/common/test_ranking_utils.cu index 5fda42c724be..db0ff3b66908 100644 --- a/tests/cpp/common/test_ranking_utils.cu +++ b/tests/cpp/common/test_ranking_utils.cu @@ -95,4 +95,10 @@ TEST(NDCGCache, InitFromGPU) { ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}}); TestNDCGCache(&ctx); } + +TEST(MAPCache, InitFromGPU) { + Context ctx; + ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}}); + TestMAPCache(&ctx); +} } // namespace xgboost::ltr diff --git a/tests/cpp/common/test_ranking_utils.h b/tests/cpp/common/test_ranking_utils.h index ede687ff4edb..8ff92df9a649 100644 --- a/tests/cpp/common/test_ranking_utils.h +++ b/tests/cpp/common/test_ranking_utils.h @@ -6,4 +6,6 @@ namespace xgboost::ltr { void TestNDCGCache(Context const* ctx); + +void TestMAPCache(Context const* ctx); } // namespace xgboost::ltr diff --git a/tests/cpp/metric/test_rank_metric.cc b/tests/cpp/metric/test_rank_metric.cc index 337ddbc8afa5..3e1028c48d7e 100644 --- a/tests/cpp/metric/test_rank_metric.cc +++ b/tests/cpp/metric/test_rank_metric.cc @@ -141,7 +141,7 @@ TEST(Metric, DeclareUnifiedTest(MAP)) { // Rank metric with group info EXPECT_NEAR(GetMetricEval(metric, {0.1f, 0.9f, 0.2f, 0.8f, 0.4f, 
1.7f}, - {2, 7, 1, 0, 5, 0}, // Labels + {1, 1, 1, 0, 1, 0}, // Labels {}, // Weights {0, 2, 5, 6}), // Group info 0.8611f, 0.001f); diff --git a/tests/python-gpu/test_gpu_ranking.py b/tests/python-gpu/test_gpu_ranking.py index d86c1aa142af..b8be5dda169a 100644 --- a/tests/python-gpu/test_gpu_ranking.py +++ b/tests/python-gpu/test_gpu_ranking.py @@ -1,194 +1,130 @@ -import itertools import os -import shutil -import urllib.request -import zipfile +from typing import Dict import numpy as np +import pytest import xgboost from xgboost import testing as tm -pytestmark = tm.timeout(10) - - -class TestRanking: - @classmethod - def setup_class(cls): - """ - Download and setup the test fixtures - """ - from sklearn.datasets import load_svmlight_files - - # download the test data - cls.dpath = os.path.join(tm.demo_dir(__file__), "rank/") - src = 'https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip' - target = os.path.join(cls.dpath, "MQ2008.zip") - - if os.path.exists(cls.dpath) and os.path.exists(target): - print("Skipping dataset download...") - else: - urllib.request.urlretrieve(url=src, filename=target) - with zipfile.ZipFile(target, 'r') as f: - f.extractall(path=cls.dpath) - - (x_train, y_train, qid_train, x_test, y_test, qid_test, - x_valid, y_valid, qid_valid) = load_svmlight_files( - (cls.dpath + "MQ2008/Fold1/train.txt", - cls.dpath + "MQ2008/Fold1/test.txt", - cls.dpath + "MQ2008/Fold1/vali.txt"), - query_id=True, zero_based=False) - # instantiate the matrices - cls.dtrain = xgboost.DMatrix(x_train, y_train) - cls.dvalid = xgboost.DMatrix(x_valid, y_valid) - cls.dtest = xgboost.DMatrix(x_test, y_test) - # set the group counts from the query IDs - cls.dtrain.set_group([len(list(items)) - for _key, items in itertools.groupby(qid_train)]) - cls.dtest.set_group([len(list(items)) - for _key, items in itertools.groupby(qid_test)]) - cls.dvalid.set_group([len(list(items)) - for _key, items in itertools.groupby(qid_valid)]) - # save the query IDs for testing - 
cls.qid_train = qid_train - cls.qid_test = qid_test - cls.qid_valid = qid_valid - - def setup_weighted(x, y, groups): - # Setup weighted data - data = xgboost.DMatrix(x, y) - groups_segment = [len(list(items)) - for _key, items in itertools.groupby(groups)] - data.set_group(groups_segment) - n_groups = len(groups_segment) - weights = np.ones((n_groups,)) - data.set_weight(weights) - return data - - cls.dtrain_w = setup_weighted(x_train, y_train, qid_train) - cls.dtest_w = setup_weighted(x_test, y_test, qid_test) - cls.dvalid_w = setup_weighted(x_valid, y_valid, qid_valid) - - # model training parameters - cls.params = {'booster': 'gbtree', - 'tree_method': 'gpu_hist', - 'gpu_id': 0, - 'predictor': 'gpu_predictor'} - cls.cpu_params = {'booster': 'gbtree', - 'tree_method': 'hist', - 'gpu_id': -1, - 'predictor': 'cpu_predictor'} - - @classmethod - def teardown_class(cls): - """ - Cleanup test artifacts from download and unpacking - :return: - """ - os.remove(os.path.join(cls.dpath, "MQ2008.zip")) - shutil.rmtree(os.path.join(cls.dpath, "MQ2008")) - - @classmethod - def __test_training_with_rank_objective(cls, rank_objective, metric_name, tolerance=1e-02): - """ - Internal method that trains the dataset using the rank objective on GPU and CPU, evaluates - the metric and determines if the delta between the metric is within the tolerance level - :return: - """ - # specify validations set to watch performance - watchlist = [(cls.dtest, 'eval'), (cls.dtrain, 'train')] - - num_trees = 100 - check_metric_improvement_rounds = 10 - - evals_result = {} - cls.params['objective'] = rank_objective - cls.params['eval_metric'] = metric_name - bst = xgboost.train( - cls.params, cls.dtrain, num_boost_round=num_trees, - early_stopping_rounds=check_metric_improvement_rounds, - evals=watchlist, evals_result=evals_result) - gpu_map_metric = evals_result['train'][metric_name][-1] - - evals_result = {} - cls.cpu_params['objective'] = rank_objective - cls.cpu_params['eval_metric'] = 
metric_name - bstc = xgboost.train( - cls.cpu_params, cls.dtrain, num_boost_round=num_trees, - early_stopping_rounds=check_metric_improvement_rounds, - evals=watchlist, evals_result=evals_result) - cpu_map_metric = evals_result['train'][metric_name][-1] - - assert np.allclose(gpu_map_metric, cpu_map_metric, tolerance, - tolerance) - assert np.allclose(bst.best_score, bstc.best_score, tolerance, - tolerance) - - evals_result_weighted = {} - watchlist = [(cls.dtest_w, 'eval'), (cls.dtrain_w, 'train')] - bst_w = xgboost.train( - cls.params, cls.dtrain_w, num_boost_round=num_trees, - early_stopping_rounds=check_metric_improvement_rounds, - evals=watchlist, evals_result=evals_result_weighted) - weighted_metric = evals_result_weighted['train'][metric_name][-1] - # GPU Ranking is not deterministic due to `AtomicAddGpair`, - # remove tolerance once the issue is resolved. - # https://github.com/dmlc/xgboost/issues/5561 - assert np.allclose(bst_w.best_score, bst.best_score, - tolerance, tolerance) - assert np.allclose(weighted_metric, gpu_map_metric, - tolerance, tolerance) - - def test_training_rank_pairwise_map_metric(self): - """ - Train an XGBoost ranking model with pairwise objective function and compare map metric - """ - self.__test_training_with_rank_objective('rank:pairwise', 'map') - - def test_training_rank_pairwise_auc_metric(self): - """ - Train an XGBoost ranking model with pairwise objective function and compare auc metric - """ - self.__test_training_with_rank_objective('rank:pairwise', 'auc') - - def test_training_rank_pairwise_ndcg_metric(self): - """ - Train an XGBoost ranking model with pairwise objective function and compare ndcg metric - """ - self.__test_training_with_rank_objective('rank:pairwise', 'ndcg') - - def test_training_rank_ndcg_map(self): - """ - Train an XGBoost ranking model with ndcg objective function and compare map metric - """ - self.__test_training_with_rank_objective('rank:ndcg', 'map') - - def test_training_rank_ndcg_auc(self): - 
""" - Train an XGBoost ranking model with ndcg objective function and compare auc metric - """ - self.__test_training_with_rank_objective('rank:ndcg', 'auc') - - def test_training_rank_ndcg_ndcg(self): - """ - Train an XGBoost ranking model with ndcg objective function and compare ndcg metric - """ - self.__test_training_with_rank_objective('rank:ndcg', 'ndcg') - - def test_training_rank_map_map(self): - """ - Train an XGBoost ranking model with map objective function and compare map metric - """ - self.__test_training_with_rank_objective('rank:map', 'map') - - def test_training_rank_map_auc(self): - """ - Train an XGBoost ranking model with map objective function and compare auc metric - """ - self.__test_training_with_rank_objective('rank:map', 'auc') - - def test_training_rank_map_ndcg(self): - """ - Train an XGBoost ranking model with map objective function and compare ndcg metric - """ - self.__test_training_with_rank_objective('rank:map', 'ndcg') +pytestmark = tm.timeout(30) + + +def comp_training_with_rank_objective( + dtrain: xgboost.DMatrix, + dtest: xgboost.DMatrix, + rank_objective: str, + metric_name: str, + tolerance: float = 1e-02, +) -> None: + """Internal method that trains the dataset using the rank objective on GPU and CPU, + evaluates the metric and determines if the delta between the metric is within the + tolerance level. 
+ + """ + # specify validations set to watch performance + watchlist = [(dtest, "eval"), (dtrain, "train")] + + params = { + "booster": "gbtree", + "tree_method": "gpu_hist", + "gpu_id": 0, + "predictor": "gpu_predictor", + } + + num_trees = 100 + check_metric_improvement_rounds = 10 + + evals_result: Dict[str, Dict] = {} + params["objective"] = rank_objective + params["eval_metric"] = metric_name + bst = xgboost.train( + params, + dtrain, + num_boost_round=num_trees, + early_stopping_rounds=check_metric_improvement_rounds, + evals=watchlist, + evals_result=evals_result, + ) + gpu_scores = evals_result["train"][metric_name][-1] + + evals_result = {} + + cpu_params = { + "booster": "gbtree", + "tree_method": "hist", + "gpu_id": -1, + "predictor": "cpu_predictor", + } + cpu_params["objective"] = rank_objective + cpu_params["eval_metric"] = metric_name + bstc = xgboost.train( + cpu_params, + dtrain, + num_boost_round=num_trees, + early_stopping_rounds=check_metric_improvement_rounds, + evals=watchlist, + evals_result=evals_result, + ) + cpu_scores = evals_result["train"][metric_name][-1] + + info = (rank_objective, metric_name) + assert np.allclose(gpu_scores, cpu_scores, tolerance, tolerance), info + assert np.allclose(bst.best_score, bstc.best_score, tolerance, tolerance), info + + evals_result_weighted: Dict[str, Dict] = {} + dtest.set_weight(np.ones((dtest.get_group().size,))) + dtrain.set_weight(np.ones((dtrain.get_group().size,))) + watchlist = [(dtest, "eval"), (dtrain, "train")] + bst_w = xgboost.train( + params, + dtrain, + num_boost_round=num_trees, + early_stopping_rounds=check_metric_improvement_rounds, + evals=watchlist, + evals_result=evals_result_weighted, + ) + weighted_metric = evals_result_weighted["train"][metric_name][-1] + + tolerance = 1e-5 + assert np.allclose(bst_w.best_score, bst.best_score, tolerance, tolerance) + assert np.allclose(weighted_metric, gpu_scores, tolerance, tolerance) + + +@pytest.mark.parametrize( + "objective,metric", + [ + 
("rank:pairwise", "auc"), + ("rank:pairwise", "ndcg"), + ("rank:pairwise", "map"), + ("rank:ndcg", "auc"), + ("rank:ndcg", "ndcg"), + ("rank:ndcg", "map"), + ("rank:map", "auc"), + ("rank:map", "ndcg"), + ("rank:map", "map"), + ], +) +def test_with_mq2008(objective, metric) -> None: + ( + x_train, + y_train, + qid_train, + x_test, + y_test, + qid_test, + x_valid, + y_valid, + qid_valid, + ) = tm.get_mq2008(os.path.join(os.path.join(tm.demo_dir(__file__), "rank"))) + + if metric.find("map") != -1 or objective.find("map") != -1: + y_train[y_train <= 1] = 0.0 + y_train[y_train > 1] = 1.0 + y_test[y_test <= 1] = 0.0 + y_test[y_test > 1] = 1.0 + + dtrain = xgboost.DMatrix(x_train, y_train, qid=qid_train) + dtest = xgboost.DMatrix(x_test, y_test, qid=qid_test) + + comp_training_with_rank_objective(dtrain, dtest, objective, metric) diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index baef690ee32e..c34b7d2d1edd 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -128,12 +128,23 @@ def test_ranking(): x_test = np.random.rand(100, 10) - params = {'tree_method': 'exact', 'objective': 'rank:pairwise', - 'learning_rate': 0.1, 'gamma': 1.0, 'min_child_weight': 0.1, - 'max_depth': 6, 'n_estimators': 4} + params = { + "tree_method": "exact", + "learning_rate": 0.1, + "gamma": 1.0, + "min_child_weight": 0.1, + "max_depth": 6, + "eval_metric": "ndcg", + "n_estimators": 4, + } model = xgb.sklearn.XGBRanker(**params) - model.fit(x_train, y_train, group=train_group, - eval_set=[(x_valid, y_valid)], eval_group=[valid_group]) + model.fit( + x_train, + y_train, + group=train_group, + eval_set=[(x_valid, y_valid)], + eval_group=[valid_group], + ) assert model.evals_result() pred = model.predict(x_test) @@ -145,11 +156,18 @@ def test_ranking(): assert train_data.get_label().shape[0] == x_train.shape[0] valid_data.set_group(valid_group) - params_orig = {'tree_method': 'exact', 'objective': 'rank:pairwise', - 'eta': 
0.1, 'gamma': 1.0, - 'min_child_weight': 0.1, 'max_depth': 6} - xgb_model_orig = xgb.train(params_orig, train_data, num_boost_round=4, - evals=[(valid_data, 'validation')]) + params_orig = { + "tree_method": "exact", + "objective": "rank:pairwise", + "eta": 0.1, + "gamma": 1.0, + "min_child_weight": 0.1, + "max_depth": 6, + "eval_metric": "ndcg", + } + xgb_model_orig = xgb.train( + params_orig, train_data, num_boost_round=4, evals=[(valid_data, "validation")] + ) pred_orig = xgb_model_orig.predict(test_data) np.testing.assert_almost_equal(pred, pred_orig) @@ -165,7 +183,11 @@ def test_ranking_metric() -> None: # sklearn compares the number of mis-classified docs, while the one in xgboost # compares the number of mis-classified pairs. ltr = xgb.XGBRanker( - eval_metric=roc_auc_score, n_estimators=10, tree_method="hist", max_depth=2 + eval_metric=roc_auc_score, + n_estimators=10, + tree_method="hist", + max_depth=2, + objective="rank:pairwise", ) ltr.fit( X,