diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index 04f0a74a5308..a84459db921b 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -32,7 +32,6 @@ OBJECTS= \ $(PKGROOT)/src/objective/objective.o \ $(PKGROOT)/src/objective/regression_obj.o \ $(PKGROOT)/src/objective/multiclass_obj.o \ - $(PKGROOT)/src/objective/rank_obj.o \ $(PKGROOT)/src/objective/lambdarank_obj.o \ $(PKGROOT)/src/objective/hinge.o \ $(PKGROOT)/src/objective/aft_obj.o \ diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index 969cb7ff42b1..25c577e3a184 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -32,7 +32,6 @@ OBJECTS= \ $(PKGROOT)/src/objective/objective.o \ $(PKGROOT)/src/objective/regression_obj.o \ $(PKGROOT)/src/objective/multiclass_obj.o \ - $(PKGROOT)/src/objective/rank_obj.o \ $(PKGROOT)/src/objective/lambdarank_obj.o \ $(PKGROOT)/src/objective/hinge.o \ $(PKGROOT)/src/objective/aft_obj.o \ diff --git a/doc/model.schema b/doc/model.schema index b9e2da3058db..103d9d9e4221 100644 --- a/doc/model.schema +++ b/doc/model.schema @@ -219,6 +219,16 @@ "num_pairsample": { "type": "string" }, "fix_list_weight": { "type": "string" } } + }, + "lambdarank_param": { + "type": "object", + "properties": { + "lambdarank_num_pair_per_sample": { "type": "string" }, + "lambdarank_pair_method": { "type": "string" }, + "lambdarank_unbiased": {"type": "string" }, + "lambdarank_bias_norm": {"type": "string" }, + "ndcg_exp_gain": {"type": "string"} + } } }, "type": "object", @@ -477,22 +487,22 @@ "type": "object", "properties": { "name": { "const": "rank:pairwise" }, - "lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"} + "lambda_rank_param": { "$ref": "#/definitions/lambdarank_param"} }, "required": [ "name", - "lambda_rank_param" + "lambdarank_param" ] }, { "type": "object", "properties": { "name": { "const": "rank:ndcg" }, - "lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"} + "lambda_rank_param": { "$ref": "#/definitions/lambdarank_param"} }, "required": [ "name", - "lambda_rank_param" + "lambdarank_param" ] }, { diff --git a/doc/parameter.rst b/doc/parameter.rst index c070e7018201..8c7cadcdc3b2 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -233,7 +233,7 @@ Parameters for Tree Booster .. note:: This parameter is working-in-progress. - The strategy used for training multi-target models, including multi-target regression - and multi-class classification. See :doc:`/tutorials/multioutput` for more information. + and multi-class classification. See :doc:`/tutorials/multioutput` for more information. - ``one_output_per_tree``: One model for each target. - ``multi_output_tree``: Use multi-target trees. @@ -380,9 +380,9 @@ Specify the learning task and the corresponding learning objective. The objectiv See :doc:`/tutorials/aft_survival_analysis` for details. - ``multi:softmax``: set XGBoost to do multiclass classification using the softmax objective, you also need to set num_class(number of classes) - ``multi:softprob``: same as softmax, but output a vector of ``ndata * nclass``, which can be further reshaped to ``ndata * nclass`` matrix. The result contains predicted probability of each data point belonging to each class. - - ``rank:pairwise``: Use LambdaMART to perform pairwise ranking where the pairwise loss is minimized - - ``rank:ndcg``: Use LambdaMART to perform list-wise ranking where `Normalized Discounted Cumulative Gain (NDCG) `_ is maximized - - ``rank:map``: Use LambdaMART to perform list-wise ranking where `Mean Average Precision (MAP) `_ is maximized + - ``rank:ndcg``: Use LambdaMART to perform pair-wise ranking where `Normalized Discounted Cumulative Gain (NDCG) `_ is maximized. This objective supports position debiasing for click data. + - ``rank:map``: Use LambdaMART to perform pair-wise ranking where `Mean Average Precision (MAP) `_ is maximized + - ``rank:pairwise``: Use LambdaRank to perform pair-wise ranking using the `ranknet` objective. - ``reg:gamma``: gamma regression with log-link. Output is a mean of gamma distribution. It might be useful, e.g., for modeling insurance claims severity, or for any outcome that might be `gamma-distributed `_. - ``reg:tweedie``: Tweedie regression with log-link. It might be useful, e.g., for modeling total loss in insurance, or for any outcome that might be `Tweedie-distributed `_. @@ -395,8 +395,9 @@ Specify the learning task and the corresponding learning objective. The objectiv * ``eval_metric`` [default according to objective] - - Evaluation metrics for validation data, a default metric will be assigned according to objective (rmse for regression, and logloss for classification, mean average precision for ranking) - - User can add multiple evaluation metrics. Python users: remember to pass the metrics in as list of parameters pairs instead of map, so that latter ``eval_metric`` won't override previous one + - Evaluation metrics for validation data, a default metric will be assigned according to objective (rmse for regression, and logloss for classification, `mean average precision` for ``rank:map``, etc.) + - User can add multiple evaluation metrics. Python users: remember to pass the metrics in as list of parameters pairs instead of map, so that latter ``eval_metric`` won't override previous ones + - The choices are listed below: - ``rmse``: `root mean square error `_ @@ -480,6 +481,36 @@ Parameter for using AFT Survival Loss (``survival:aft``) and Negative Log Likeli * ``aft_loss_distribution``: Probability Density Function, ``normal``, ``logistic``, or ``extreme``. +.. _ltr-param: + +Parameters for learning to rank (``rank:ndcg``, ``rank:map``, ``rank:pairwise``) +================================================================================ + +These are parameters specific to learning to rank task. See :doc:`Learning to Rank ` for an in-depth explanation. + +* ``lambdarank_pair_method`` [default = ``mean``] + + How to construct pairs for pair-wise learning. + + - ``mean``: Sample ``lambdarank_num_pair_per_sample`` pairs for each document in the query list. + - ``topk``: Focus on top-``lambdarank_num_pair_per_sample`` documents. Construct :math:`|query|` pairs for each document at the top-``lambdarank_num_pair_per_sample`` ranked by the model. + +* ``lambdarank_num_pair_per_sample`` [range = :math:`[1, \infty]`] + + It specifies the number of pairs sampled for each document when pair method is ``mean``, or the truncation level for queries when the pair method is ``topk``. For example, to train with ``ndcg@6``, set ``lambdarank_num_pair_per_sample`` to :math:`6` and ``lambdarank_pair_method`` to ``topk``. + +* ``lambdarank_unbiased`` [default = ``false``] + + Specify whether do we need to debias input click data. + +* ``lambdarank_bias_norm`` [default = 2.0] + + :math:`L_p` normalization for position debiasing, default is :math:`L_2`. Only relevant when ``lambdarank_unbiased`` is set to true. + +* ``ndcg_exp_gain`` [default = ``true``] + + Whether we should use exponential gain function for ``NDCG``. There are two forms of gain function for ``NDCG``, one is using relevance value directly while the other is using :math:`2^{rel} - 1` to emphasize on retrieving relevant documents. When ``ndcg_exp_gain`` is true (the default), relevance degree cannot be greater than 31. + *********************** Command Line Parameters *********************** diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py index 20a4c681e142..5566e0b2d401 100644 --- a/python-package/xgboost/testing/__init__.py +++ b/python-package/xgboost/testing/__init__.py @@ -431,8 +431,11 @@ def make_ltr( """Make a dataset for testing LTR.""" rng = np.random.default_rng(1994) X = rng.normal(0, 1.0, size=n_samples * n_features).reshape(n_samples, n_features) - y = rng.integers(0, max_rel, size=n_samples) - qid = rng.integers(0, n_query_groups, size=n_samples) + y = np.sum(X, axis=1) + y -= y.min() + y = np.round(y / y.max() * max_rel).astype(np.int32) + + qid = rng.integers(0, n_query_groups, size=n_samples, dtype=np.int32) w = rng.normal(0, 1.0, size=n_query_groups) w -= np.min(w) w /= np.max(w) diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index 4f272e939acd..c4549458d28f 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -493,7 +493,6 @@ class EvalMAPScore : public EvalRankWithCache { auto rank_idx = p_cache->SortedIdx(ctx_, predt.ConstHostSpan()); common::ParallelFor(p_cache->Groups(), ctx_->Threads(), [&](auto g) { - auto g_predt = h_predt.Slice(linalg::Range(gptr[g], gptr[g + 1])); auto g_label = h_label.Slice(linalg::Range(gptr[g], gptr[g + 1])); auto g_rank = rank_idx.subspan(gptr[g]); diff --git a/src/objective/lambdarank_obj.cc b/src/objective/lambdarank_obj.cc index 30957f81a4f7..d0ff5bda5bde 100644 --- a/src/objective/lambdarank_obj.cc +++ b/src/objective/lambdarank_obj.cc @@ -69,6 +69,7 @@ void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView label, + common::Span rank_idx, std::shared_ptr p_cache) { + auto h_n_rel = p_cache->NumRelevant(ctx); + auto gptr = p_cache->DataGroupPtr(ctx); + + CHECK_EQ(h_n_rel.size(), gptr.back()); + CHECK_EQ(h_n_rel.size(), label.Size()); + + auto h_acc = p_cache->Acc(ctx); + + common::ParallelFor(p_cache->Groups(), ctx->Threads(), [&](auto g) { + auto cnt = gptr[g + 1] - gptr[g]; + auto g_n_rel = h_n_rel.subspan(gptr[g], cnt); + auto g_rank = rank_idx.subspan(gptr[g], cnt); + auto g_label = label.Slice(linalg::Range(gptr[g], gptr[g + 1])); + + // The number of relevant documents at each position + g_n_rel[0] = g_label(g_rank[0]); + for (std::size_t k = 1; k < g_rank.size(); ++k) { + g_n_rel[k] = g_n_rel[k - 1] + g_label(g_rank[k]); + } + + // \sum l_k/k + auto g_acc = h_acc.subspan(gptr[g], cnt); + g_acc[0] = g_label(g_rank[0]) / 1.0; + + for (std::size_t k = 1; k < g_rank.size(); ++k) { + g_acc[k] = g_acc[k - 1] + (g_label(g_rank[k]) / static_cast(k + 1)); + } + }); +} +} // namespace cpu_impl + +class LambdaRankMAP : public LambdaRankObj { + public: + void GetGradientImpl(std::int32_t iter, const HostDeviceVector& predt, + const MetaInfo& info, HostDeviceVector* out_gpair) { + CHECK(param_.ndcg_exp_gain) << "NDCG gain can not be set for the MAP objective."; + if (ctx_->IsCUDA()) { + return cuda_impl::LambdaRankGetGradientMAP( + ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->gpu_id), + tj_minus_.View(ctx_->gpu_id), li_full_.View(ctx_->gpu_id), lj_full_.View(ctx_->gpu_id), + out_gpair); + } + + auto gptr = p_cache_->DataGroupPtr(ctx_).data(); + bst_group_t n_groups = p_cache_->Groups(); + + out_gpair->Resize(info.num_row_); + auto h_gpair = out_gpair->HostSpan(); + auto h_label = info.labels.HostView().Slice(linalg::All(), 0); + auto h_predt = predt.ConstHostSpan(); + auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt); + auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_); + + auto make_range = [&](bst_group_t g) { return linalg::Range(gptr[g], gptr[g + 1]); }; + + cpu_impl::MAPStat(ctx_, h_label, rank_idx, GetCache()); + auto n_rel = GetCache()->NumRelevant(ctx_); + auto acc = GetCache()->Acc(ctx_); + + auto delta_map = [&](auto y_high, auto y_low, std::size_t rank_high, std::size_t rank_low, + bst_group_t g) { + if (rank_high > rank_low) { + std::swap(rank_high, rank_low); + std::swap(y_high, y_low); + } + auto cnt = gptr[g + 1] - gptr[g]; + // In a hot loop + auto g_n_rel = common::Span{n_rel.data() + gptr[g], cnt}; + auto g_acc = common::Span{acc.data() + gptr[g], cnt}; + auto d = DeltaMAP(y_high, y_low, rank_high, rank_low, g_n_rel, g_acc); + return d; + }; + using D = decltype(delta_map); + + common::ParallelFor(n_groups, ctx_->Threads(), [&](auto g) { + auto cnt = gptr[g + 1] - gptr[g]; + auto w = h_weight[g]; + auto g_predt = h_predt.subspan(gptr[g], cnt); + auto g_gpair = h_gpair.subspan(gptr[g], cnt); + auto g_label = h_label.Slice(make_range(g)); + auto g_rank = rank_idx.subspan(gptr[g], cnt); + + auto args = std::make_tuple(this, iter, g_predt, g_label, w, g_rank, g, delta_map, g_gpair); + + if (param_.lambdarank_unbiased) { + std::apply(&LambdaRankMAP::CalcLambdaForGroup, args); + } else { + std::apply(&LambdaRankMAP::CalcLambdaForGroup, args); + } + }); + } + static char const* Name() { return "rank:map"; } + [[nodiscard]] const char* DefaultEvalMetric() const override { + return this->RankEvalMetric("map"); + } +}; + +#if !defined(XGBOOST_USE_CUDA) +namespace cuda_impl { +void MAPStat(Context const*, MetaInfo const&, common::Span, + std::shared_ptr) { + common::AssertGPUSupport(); +} + +void LambdaRankGetGradientMAP(Context const*, std::int32_t, HostDeviceVector const&, + const MetaInfo&, std::shared_ptr, + linalg::VectorView, // input bias ratio + linalg::VectorView, // input bias ratio + linalg::VectorView, linalg::VectorView, + HostDeviceVector*) { + common::AssertGPUSupport(); +} +} // namespace cuda_impl +#endif // !defined(XGBOOST_USE_CUDA) + +/** + * \brief The RankNet loss. + */ +class LambdaRankPairwise : public LambdaRankObj { + public: + void GetGradientImpl(std::int32_t iter, const HostDeviceVector& predt, + const MetaInfo& info, HostDeviceVector* out_gpair) { + CHECK(param_.ndcg_exp_gain) << "NDCG gain can not be set for the pairwise objective."; + if (ctx_->IsCUDA()) { + return cuda_impl::LambdaRankGetGradientPairwise( + ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->gpu_id), + tj_minus_.View(ctx_->gpu_id), li_full_.View(ctx_->gpu_id), lj_full_.View(ctx_->gpu_id), + out_gpair); + } + + auto gptr = p_cache_->DataGroupPtr(ctx_); + bst_group_t n_groups = p_cache_->Groups(); + + out_gpair->Resize(info.num_row_); + auto h_gpair = out_gpair->HostSpan(); + auto h_label = info.labels.HostView().Slice(linalg::All(), 0); + auto h_predt = predt.ConstHostSpan(); + auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_); + + auto make_range = [&](bst_group_t g) { return linalg::Range(gptr[g], gptr[g + 1]); }; + auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt); + + auto delta = [](auto...) { return 1.0; }; + using D = decltype(delta); + + common::ParallelFor(n_groups, ctx_->Threads(), [&](auto g) { + auto cnt = gptr[g + 1] - gptr[g]; + auto w = h_weight[g]; + auto g_predt = h_predt.subspan(gptr[g], cnt); + auto g_gpair = h_gpair.subspan(gptr[g], cnt); + auto g_label = h_label.Slice(make_range(g)); + auto g_rank = rank_idx.subspan(gptr[g], cnt); + + auto args = std::make_tuple(this, iter, g_predt, g_label, w, g_rank, g, delta, g_gpair); + if (param_.lambdarank_unbiased) { + std::apply(&LambdaRankPairwise::CalcLambdaForGroup, args); + } else { + std::apply(&LambdaRankPairwise::CalcLambdaForGroup, args); + } + }); + } + + static char const* Name() { return "rank:pairwise"; } + [[nodiscard]] const char* DefaultEvalMetric() const override { + return this->RankEvalMetric("ndcg"); + } +}; + +#if !defined(XGBOOST_USE_CUDA) +namespace cuda_impl { +void LambdaRankGetGradientPairwise(Context const*, std::int32_t, HostDeviceVector const&, + const MetaInfo&, std::shared_ptr, + linalg::VectorView, // input bias ratio + linalg::VectorView, // input bias ratio + linalg::VectorView, linalg::VectorView, + HostDeviceVector*) { + common::AssertGPUSupport(); +} +} // namespace cuda_impl +#endif // !defined(XGBOOST_USE_CUDA) + XGBOOST_REGISTER_OBJECTIVE(LambdaRankNDCG, LambdaRankNDCG::Name()) .describe("LambdaRank with NDCG loss as objective") .set_body([]() { return new LambdaRankNDCG{}; }); +XGBOOST_REGISTER_OBJECTIVE(LambdaRankPairwise, LambdaRankPairwise::Name()) + .describe("LambdaRank with RankNet loss as objective") + .set_body([]() { return new LambdaRankPairwise{}; }); + +XGBOOST_REGISTER_OBJECTIVE(LambdaRankMAP, LambdaRankMAP::Name()) + .describe("LambdaRank with MAP loss as objective.") + .set_body([]() { return new LambdaRankMAP{}; }); + DMLC_REGISTRY_FILE_TAG(lambdarank_obj); } // namespace xgboost::obj diff --git a/src/objective/lambdarank_obj.cu b/src/objective/lambdarank_obj.cu index 27b5872a8a49..110e4ae87914 100644 --- a/src/objective/lambdarank_obj.cu +++ b/src/objective/lambdarank_obj.cu @@ -390,6 +390,112 @@ void LambdaRankGetGradientNDCG(Context const* ctx, std::int32_t iter, Launch(ctx, iter, preds, info, p_cache, delta_ndcg, ti_plus, tj_minus, li, lj, out_gpair); } +void MAPStat(Context const* ctx, MetaInfo const& info, common::Span d_rank_idx, + std::shared_ptr p_cache) { + common::Span out_n_rel = p_cache->NumRelevant(ctx); + common::Span out_acc = p_cache->Acc(ctx); + + CHECK_EQ(out_n_rel.size(), info.num_row_); + CHECK_EQ(out_acc.size(), info.num_row_); + + auto group_ptr = p_cache->DataGroupPtr(ctx); + auto key_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), + [=] XGBOOST_DEVICE(std::size_t i) -> std::size_t { return dh::SegmentId(group_ptr, i); }); + auto label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0); + auto const* cuctx = ctx->CUDACtx(); + + { + // calculate number of relevant documents + auto val_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) -> double { + auto g = dh::SegmentId(group_ptr, i); + auto g_label = label.Slice(linalg::Range(group_ptr[g], group_ptr[g + 1])); + auto idx_in_group = i - group_ptr[g]; + auto g_sorted_idx = d_rank_idx.subspan(group_ptr[g], group_ptr[g + 1] - group_ptr[g]); + return static_cast(g_label(g_sorted_idx[idx_in_group])); + }); + thrust::inclusive_scan_by_key(cuctx->CTP(), key_it, key_it + info.num_row_, val_it, + out_n_rel.data()); + } + { + // \sum l_k/k + auto val_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) -> double { + auto g = dh::SegmentId(group_ptr, i); + auto g_label = label.Slice(linalg::Range(group_ptr[g], group_ptr[g + 1])); + auto g_sorted_idx = d_rank_idx.subspan(group_ptr[g], group_ptr[g + 1] - group_ptr[g]); + auto idx_in_group = i - group_ptr[g]; + double rank_in_group = idx_in_group + 1.0; + return static_cast(g_label(g_sorted_idx[idx_in_group])) / rank_in_group; + }); + thrust::inclusive_scan_by_key(cuctx->CTP(), key_it, key_it + info.num_row_, val_it, + out_acc.data()); + } +} + +void LambdaRankGetGradientMAP(Context const* ctx, std::int32_t iter, + HostDeviceVector const& predt, const MetaInfo& info, + std::shared_ptr p_cache, + linalg::VectorView ti_plus, // input bias ratio + linalg::VectorView tj_minus, // input bias ratio + linalg::VectorView li, linalg::VectorView lj, + HostDeviceVector* out_gpair) { + std::int32_t device_id = ctx->gpu_id; + dh::safe_cuda(cudaSetDevice(device_id)); + + info.labels.SetDevice(device_id); + predt.SetDevice(device_id); + + CHECK(p_cache); + + auto d_predt = predt.ConstDeviceSpan(); + auto const d_sorted_idx = p_cache->SortedIdx(ctx, d_predt); + + MAPStat(ctx, info, d_sorted_idx, p_cache); + auto d_n_rel = p_cache->NumRelevant(ctx); + auto d_acc = p_cache->Acc(ctx); + auto d_gptr = p_cache->DataGroupPtr(ctx).data(); + + auto delta_map = [=] XGBOOST_DEVICE(float y_high, float y_low, std::size_t rank_high, + std::size_t rank_low, bst_group_t g) { + if (rank_high > rank_low) { + thrust::swap(rank_high, rank_low); + thrust::swap(y_high, y_low); + } + auto cnt = d_gptr[g + 1] - d_gptr[g]; + auto g_n_rel = d_n_rel.subspan(d_gptr[g], cnt); + auto g_acc = d_acc.subspan(d_gptr[g], cnt); + auto d = DeltaMAP(y_high, y_low, rank_high, rank_low, g_n_rel, g_acc); + return d; + }; + + Launch(ctx, iter, predt, info, p_cache, delta_map, ti_plus, tj_minus, li, lj, out_gpair); +} + +void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter, + HostDeviceVector const& predt, const MetaInfo& info, + std::shared_ptr p_cache, + linalg::VectorView ti_plus, // input bias ratio + linalg::VectorView tj_minus, // input bias ratio + linalg::VectorView li, linalg::VectorView lj, + HostDeviceVector* out_gpair) { + std::int32_t device_id = ctx->gpu_id; + dh::safe_cuda(cudaSetDevice(device_id)); + + info.labels.SetDevice(device_id); + predt.SetDevice(device_id); + + auto d_predt = predt.ConstDeviceSpan(); + auto const d_sorted_idx = p_cache->SortedIdx(ctx, d_predt); + + auto delta = [] XGBOOST_DEVICE(float, float, std::size_t, std::size_t, bst_group_t) { + return 1.0; + }; + + Launch(ctx, iter, predt, info, p_cache, delta, ti_plus, tj_minus, li, lj, out_gpair); +} + namespace { struct ReduceOp { template diff --git a/src/objective/lambdarank_obj.h b/src/objective/lambdarank_obj.h index 0eb06e27cdc4..c2222c028582 100644 --- a/src/objective/lambdarank_obj.h +++ b/src/objective/lambdarank_obj.h @@ -156,6 +156,27 @@ void LambdaRankGetGradientNDCG(Context const* ctx, std::int32_t iter, linalg::VectorView li, linalg::VectorView lj, HostDeviceVector* out_gpair); +/** + * \brief Generate statistic for MAP used for calculating \Delta Z in lambda mart. + */ +void MAPStat(Context const* ctx, MetaInfo const& info, common::Span d_rank_idx, + std::shared_ptr p_cache); + +void LambdaRankGetGradientMAP(Context const* ctx, std::int32_t iter, + HostDeviceVector const& predt, MetaInfo const& info, + std::shared_ptr p_cache, + linalg::VectorView t_plus, // input bias ratio + linalg::VectorView t_minus, // input bias ratio + linalg::VectorView li, linalg::VectorView lj, + HostDeviceVector* out_gpair); + +void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter, + HostDeviceVector const& predt, const MetaInfo& info, + std::shared_ptr p_cache, + linalg::VectorView ti_plus, // input bias ratio + linalg::VectorView tj_minus, // input bias ratio + linalg::VectorView li, linalg::VectorView lj, + HostDeviceVector* out_gpair); void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView li_full, linalg::VectorView lj_full, @@ -165,6 +186,18 @@ void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView p_cache); } // namespace cuda_impl +namespace cpu_impl { +/** + * \brief Generate statistic for MAP used for calculating \Delta Z in lambda mart. + * + * \param label Ground truth relevance label. + * \param rank_idx Sorted index of prediction. + * \param p_cache An initialized MAPCache. + */ +void MAPStat(Context const* ctx, linalg::VectorView label, + common::Span rank_idx, std::shared_ptr p_cache); +} // namespace cpu_impl + /** * \param Construct pairs on CPU * diff --git a/src/objective/objective.cc b/src/objective/objective.cc index 7d2c37811d1a..85cd9803d4ef 100644 --- a/src/objective/objective.cc +++ b/src/objective/objective.cc @@ -47,7 +47,6 @@ DMLC_REGISTRY_LINK_TAG(regression_obj_gpu); DMLC_REGISTRY_LINK_TAG(quantile_obj_gpu); DMLC_REGISTRY_LINK_TAG(hinge_obj_gpu); DMLC_REGISTRY_LINK_TAG(multiclass_obj_gpu); -DMLC_REGISTRY_LINK_TAG(rank_obj_gpu); DMLC_REGISTRY_LINK_TAG(lambdarank_obj); DMLC_REGISTRY_LINK_TAG(lambdarank_obj_cu); #else @@ -55,7 +54,6 @@ DMLC_REGISTRY_LINK_TAG(regression_obj); DMLC_REGISTRY_LINK_TAG(quantile_obj); DMLC_REGISTRY_LINK_TAG(hinge_obj); DMLC_REGISTRY_LINK_TAG(multiclass_obj); -DMLC_REGISTRY_LINK_TAG(rank_obj); DMLC_REGISTRY_LINK_TAG(lambdarank_obj); #endif // XGBOOST_USE_CUDA } // namespace obj diff --git a/src/objective/rank_obj.cc b/src/objective/rank_obj.cc deleted file mode 100644 index 25cd9e643eff..000000000000 --- a/src/objective/rank_obj.cc +++ /dev/null @@ -1,17 +0,0 @@ -/*! - * Copyright 2019 XGBoost contributors - */ - -// Dummy file to keep the CUDA conditional compile trick. -#include -namespace xgboost { -namespace obj { - -DMLC_REGISTRY_FILE_TAG(rank_obj); - -} // namespace obj -} // namespace xgboost - -#ifndef XGBOOST_USE_CUDA -#include "rank_obj.cu" -#endif // XGBOOST_USE_CUDA diff --git a/src/objective/rank_obj.cu b/src/objective/rank_obj.cu deleted file mode 100644 index 23613d93d9d3..000000000000 --- a/src/objective/rank_obj.cu +++ /dev/null @@ -1,789 +0,0 @@ -/*! - * Copyright 2015-2022 XGBoost contributors - */ -#include -#include -#include -#include -#include -#include -#include - -#include "xgboost/json.h" -#include "xgboost/parameter.h" - -#include "../common/math.h" -#include "../common/random.h" - -#if defined(__CUDACC__) -#include -#include -#include -#include -#include - -#include - -#include "../common/device_helpers.cuh" -#endif - -namespace xgboost { -namespace obj { - -#if defined(XGBOOST_USE_CUDA) && !defined(GTEST_TEST) -DMLC_REGISTRY_FILE_TAG(rank_obj_gpu); -#endif // defined(XGBOOST_USE_CUDA) - -struct LambdaRankParam : public XGBoostParameter { - size_t num_pairsample; - float fix_list_weight; - // declare parameters - DMLC_DECLARE_PARAMETER(LambdaRankParam) { - DMLC_DECLARE_FIELD(num_pairsample).set_lower_bound(1).set_default(1) - .describe("Number of pair generated for each instance."); - DMLC_DECLARE_FIELD(fix_list_weight).set_lower_bound(0.0f).set_default(0.0f) - .describe("Normalize the weight of each list by this value," - " if equals 0, no effect will happen"); - } -}; - -#if defined(__CUDACC__) -// Helper functions - -template -XGBOOST_DEVICE __forceinline__ uint32_t -CountNumItemsToTheLeftOf(const T *__restrict__ items, uint32_t n, T v) { - return thrust::lower_bound(thrust::seq, items, items + n, v, - thrust::greater()) - - items; -} - -template -XGBOOST_DEVICE __forceinline__ uint32_t -CountNumItemsToTheRightOf(const T *__restrict__ items, uint32_t n, T v) { - return n - (thrust::upper_bound(thrust::seq, items, items + n, v, - thrust::greater()) - - items); -} -#endif - -/*! \brief helper information in a list */ -struct ListEntry { - /*! \brief the predict score we in the data */ - bst_float pred; - /*! \brief the actual label of the entry */ - bst_float label; - /*! \brief row index in the data matrix */ - unsigned rindex; - // constructor - ListEntry(bst_float pred, bst_float label, unsigned rindex) - : pred(pred), label(label), rindex(rindex) {} - // comparator by prediction - inline static bool CmpPred(const ListEntry &a, const ListEntry &b) { - return a.pred > b.pred; - } - // comparator by label - inline static bool CmpLabel(const ListEntry &a, const ListEntry &b) { - return a.label > b.label; - } -}; - -/*! \brief a pair in the lambda rank */ -struct LambdaPair { - /*! \brief positive index: this is a position in the list */ - unsigned pos_index; - /*! \brief negative index: this is a position in the list */ - unsigned neg_index; - /*! \brief weight to be filled in */ - bst_float weight; - // constructor - LambdaPair(unsigned pos_index, unsigned neg_index) - : pos_index(pos_index), neg_index(neg_index), weight(1.0f) {} - // constructor - LambdaPair(unsigned pos_index, unsigned neg_index, bst_float weight) - : pos_index(pos_index), neg_index(neg_index), weight(weight) {} -}; - -class PairwiseLambdaWeightComputer { - public: - /*! - * \brief get lambda weight for existing pairs - for pairwise objective - * \param list a list that is sorted by pred score - * \param io_pairs record of pairs, containing the pairs to fill in weights - */ - static void GetLambdaWeight(const std::vector&, - std::vector*) {} - - static char const* Name() { - return "rank:pairwise"; - } - -#if defined(__CUDACC__) - PairwiseLambdaWeightComputer(const bst_float*, - const bst_float*, - const dh::SegmentSorter&) {} - - class PairwiseLambdaWeightMultiplier { - public: - // Adjust the items weight by this value - __device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const { - return 1.0f; - } - }; - - inline const PairwiseLambdaWeightMultiplier GetWeightMultiplier() const { - return {}; - } -#endif -}; - -#if defined(__CUDACC__) -class BaseLambdaWeightMultiplier { - public: - BaseLambdaWeightMultiplier(const dh::SegmentSorter &segment_label_sorter, - const dh::SegmentSorter &segment_pred_sorter) - : dsorted_labels_(segment_label_sorter.GetItemsSpan()), - dorig_pos_(segment_label_sorter.GetOriginalPositionsSpan()), - dgroups_(segment_label_sorter.GetGroupsSpan()), - dindexable_sorted_preds_pos_(segment_pred_sorter.GetIndexableSortedPositionsSpan()) {} - - protected: - const common::Span dsorted_labels_; // Labels sorted within a group - const common::Span dorig_pos_; // Original indices of the labels - // before they are sorted - const common::Span dgroups_; // The group indices - // Where can a prediction for a label be found in the original array, when they are sorted - const common::Span dindexable_sorted_preds_pos_; -}; - -// While computing the weight that needs to be adjusted by this ranking objective, we need -// to figure out where positive and negative labels chosen earlier exists, if the group -// were to be sorted by its predictions. To accommodate this, we employ the following algorithm. -// For a given group, let's assume the following: -// labels: 1 5 9 2 4 8 0 7 6 3 -// predictions: 1 9 0 8 2 7 3 6 5 4 -// position: 0 1 2 3 4 5 6 7 8 9 -// -// After label sort: -// labels: 9 8 7 6 5 4 3 2 1 0 -// position: 2 5 7 8 1 4 9 3 0 6 -// -// After prediction sort: -// predictions: 9 8 7 6 5 4 3 2 1 0 -// position: 1 3 5 7 8 9 6 4 0 2 -// -// If a sorted label at position 'x' is chosen, then we need to find out where the prediction -// for this label 'x' exists, if the group were to be sorted by predictions. -// We first take the sorted prediction positions: -// position: 1 3 5 7 8 9 6 4 0 2 -// at indices: 0 1 2 3 4 5 6 7 8 9 -// -// We create a sorted prediction positional array, such that value at position 'x' gives -// us the position in the sorted prediction array where its related prediction lies. -// dindexable_sorted_preds_pos_: 8 0 9 1 7 2 6 3 4 5 -// at indices: 0 1 2 3 4 5 6 7 8 9 -// Basically, swap the previous 2 arrays, sort the indices and reorder positions -// for an O(1) lookup using the position where the sorted label exists. -// -// This type does that using the SegmentSorter -class IndexablePredictionSorter { - public: - IndexablePredictionSorter(const bst_float *dpreds, - const dh::SegmentSorter &segment_label_sorter) { - // Sort the predictions first - segment_pred_sorter_.SortItems(dpreds, segment_label_sorter.GetNumItems(), - segment_label_sorter.GetGroupSegmentsSpan()); - - // Create an index for the sorted prediction positions - segment_pred_sorter_.CreateIndexableSortedPositions(); - } - - inline const dh::SegmentSorter &GetPredictionSorter() const { - return segment_pred_sorter_; - } - - private: - dh::SegmentSorter segment_pred_sorter_; // For sorting the predictions -}; -#endif - -class MAPLambdaWeightComputer -#if defined(__CUDACC__) - : public IndexablePredictionSorter -#endif -{ - public: - struct MAPStats { - /*! \brief the accumulated precision */ - float ap_acc{0.0f}; - /*! - * \brief the accumulated precision, - * assuming a positive instance is missing - */ - float ap_acc_miss{0.0f}; - /*! - * \brief the accumulated precision, - * assuming that one more positive instance is inserted ahead - */ - float ap_acc_add{0.0f}; - /* \brief the accumulated positive instance count */ - float hits{0.0f}; - - XGBOOST_DEVICE MAPStats() {} // NOLINT - XGBOOST_DEVICE MAPStats(float ap_acc, float ap_acc_miss, float ap_acc_add, float hits) - : ap_acc(ap_acc), ap_acc_miss(ap_acc_miss), ap_acc_add(ap_acc_add), hits(hits) {} - - // For prefix scan - XGBOOST_DEVICE MAPStats operator +(const MAPStats &v1) const { - return {ap_acc + v1.ap_acc, ap_acc_miss + v1.ap_acc_miss, - ap_acc_add + v1.ap_acc_add, hits + v1.hits}; - } - - // For test purposes - compare for equality - XGBOOST_DEVICE bool operator ==(const MAPStats &rhs) const { - return ap_acc == rhs.ap_acc && ap_acc_miss == rhs.ap_acc_miss && - ap_acc_add == rhs.ap_acc_add && hits == rhs.hits; - } - }; - - private: - template - XGBOOST_DEVICE inline static void Swap(T &v0, T &v1) { -#if defined(__CUDACC__) - thrust::swap(v0, v1); -#else - std::swap(v0, v1); -#endif - } - - /*! - * \brief Obtain the delta MAP by trying to switch the positions of labels in pos_pred_pos or - * neg_pred_pos when sorted by predictions - * \param pos_pred_pos positive label's prediction value position when the groups prediction - * values are sorted - * \param neg_pred_pos negative label's prediction value position when the groups prediction - * values are sorted - * \param pos_label, neg_label the chosen positive and negative labels - * \param p_map_stats a vector containing the accumulated precisions for each position in a list - * \param map_stats_size size of the accumulated precisions vector - */ - XGBOOST_DEVICE inline static bst_float GetLambdaMAP( - int pos_pred_pos, int neg_pred_pos, - bst_float pos_label, bst_float neg_label, - const MAPStats *p_map_stats, uint32_t map_stats_size) { - if (pos_pred_pos == neg_pred_pos || p_map_stats[map_stats_size - 1].hits == 0) { - return 0.0f; - } - if (pos_pred_pos > neg_pred_pos) { - Swap(pos_pred_pos, neg_pred_pos); - Swap(pos_label, neg_label); - } - bst_float original = p_map_stats[neg_pred_pos].ap_acc; - if (pos_pred_pos != 0) original -= p_map_stats[pos_pred_pos - 1].ap_acc; - bst_float changed = 0; - bst_float label1 = pos_label > 0.0f ? 1.0f : 0.0f; - bst_float label2 = neg_label > 0.0f ? 1.0f : 0.0f; - if (label1 == label2) { - return 0.0; - } else if (label1 < label2) { - changed += p_map_stats[neg_pred_pos - 1].ap_acc_add - p_map_stats[pos_pred_pos].ap_acc_add; - changed += (p_map_stats[pos_pred_pos].hits + 1.0f) / (pos_pred_pos + 1); - } else { - changed += p_map_stats[neg_pred_pos - 1].ap_acc_miss - p_map_stats[pos_pred_pos].ap_acc_miss; - changed += p_map_stats[neg_pred_pos].hits / (neg_pred_pos + 1); - } - bst_float ans = (changed - original) / (p_map_stats[map_stats_size - 1].hits); - if (ans < 0) ans = -ans; - return ans; - } - - public: - /* - * \brief obtain preprocessing results for calculating delta MAP - * \param sorted_list the list containing entry information - * \param map_stats a vector containing the accumulated precisions for each position in a list - */ - inline static void GetMAPStats(const std::vector &sorted_list, - std::vector *p_map_acc) { - std::vector &map_acc = *p_map_acc; - map_acc.resize(sorted_list.size()); - bst_float hit = 0, acc1 = 0, acc2 = 0, acc3 = 0; - for (size_t i = 1; i <= sorted_list.size(); ++i) { - if (sorted_list[i - 1].label > 0.0f) { - hit++; - acc1 += hit / i; - acc2 += (hit - 1) / i; - acc3 += (hit + 1) / i; - } - map_acc[i - 1] = MAPStats(acc1, acc2, acc3, hit); - } - } - - static char const* Name() { - return "rank:map"; - } - - static void GetLambdaWeight(const std::vector &sorted_list, - std::vector *io_pairs) { - std::vector &pairs = *io_pairs; - std::vector map_stats; - GetMAPStats(sorted_list, &map_stats); - for (auto & pair : pairs) { - pair.weight *= - GetLambdaMAP(pair.pos_index, pair.neg_index, - sorted_list[pair.pos_index].label, sorted_list[pair.neg_index].label, - &map_stats[0], map_stats.size()); - } - } - -#if defined(__CUDACC__) - MAPLambdaWeightComputer(const bst_float *dpreds, - const bst_float *dlabels, - const dh::SegmentSorter &segment_label_sorter) - : IndexablePredictionSorter(dpreds, segment_label_sorter), - dmap_stats_(segment_label_sorter.GetNumItems(), MAPStats()), - weight_multiplier_(segment_label_sorter, *this) { - this->CreateMAPStats(dlabels, segment_label_sorter); - } - - void CreateMAPStats(const bst_float *dlabels, - const dh::SegmentSorter &segment_label_sorter) { - // For each group, go through the sorted prediction positions, and look up its corresponding - // label from the unsorted labels (from the original label list) - - // For each item in the group, compute its MAP stats. - // Interleave the computation of map stats amongst different groups. - - // First, determine postive labels in the dataset individually - auto nitems = segment_label_sorter.GetNumItems(); - dh::caching_device_vector dhits(nitems, 0); - // Original positions of the predictions after they have been sorted - const auto &pred_original_pos = this->GetPredictionSorter().GetOriginalPositionsSpan(); - // Unsorted labels - const float *unsorted_labels = dlabels; - auto DeterminePositiveLabelLambda = [=] __device__(uint32_t idx) { - return (unsorted_labels[pred_original_pos[idx]] > 0.0f) ? 1 : 0; - }; // NOLINT - - thrust::transform(thrust::make_counting_iterator(static_cast(0)), - thrust::make_counting_iterator(nitems), - dhits.begin(), - DeterminePositiveLabelLambda); - - // Allocator to be used by sort for managing space overhead while performing prefix scans - dh::XGBCachingDeviceAllocator alloc; - - // Next, prefix scan the positive labels that are segmented to accumulate them. - // This is required for computing the accumulated precisions - const auto &group_segments = segment_label_sorter.GetGroupSegmentsSpan(); - // Data segmented into different groups... - thrust::inclusive_scan_by_key(thrust::cuda::par(alloc), - dh::tcbegin(group_segments), dh::tcend(group_segments), - dhits.begin(), // Input value - dhits.begin()); // In-place scan - - // Compute accumulated precisions for each item, assuming positive and - // negative instances are missing. - // But first, compute individual item precisions - const auto *dhits_arr = dhits.data().get(); - // Group info on device - const auto &dgroups = segment_label_sorter.GetGroupsSpan(); - auto ComputeItemPrecisionLambda = [=] __device__(uint32_t idx) { - if (unsorted_labels[pred_original_pos[idx]] > 0.0f) { - auto idx_within_group = (idx - dgroups[group_segments[idx]]) + 1; - return MAPStats{static_cast(dhits_arr[idx]) / idx_within_group, - static_cast(dhits_arr[idx] - 1) / idx_within_group, - static_cast(dhits_arr[idx] + 1) / idx_within_group, - 1.0f}; - } - return MAPStats{}; - }; // NOLINT - - thrust::transform(thrust::make_counting_iterator(static_cast(0)), - thrust::make_counting_iterator(nitems), - this->dmap_stats_.begin(), - ComputeItemPrecisionLambda); - - // Lastly, compute the accumulated precisions for all the items segmented by groups. - // The precisions are accumulated within each group - thrust::inclusive_scan_by_key(thrust::cuda::par(alloc), - dh::tcbegin(group_segments), dh::tcend(group_segments), - this->dmap_stats_.begin(), // Input map stats - this->dmap_stats_.begin()); // In-place scan and output here - } - - inline const common::Span GetMapStatsSpan() const { - return { dmap_stats_.data().get(), dmap_stats_.size() }; - } - - // Type containing device pointers that can be cheaply copied on the kernel - class MAPLambdaWeightMultiplier : public BaseLambdaWeightMultiplier { - public: - MAPLambdaWeightMultiplier(const dh::SegmentSorter &segment_label_sorter, - const MAPLambdaWeightComputer &lwc) - : BaseLambdaWeightMultiplier(segment_label_sorter, lwc.GetPredictionSorter()), - dmap_stats_(lwc.GetMapStatsSpan()) {} - - // Adjust the items weight by this value - __device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const { - uint32_t group_begin = dgroups_[gidx]; - uint32_t group_end = dgroups_[gidx + 1]; - - auto pos_lab_orig_posn = dorig_pos_[pidx]; - auto neg_lab_orig_posn = dorig_pos_[nidx]; - KERNEL_CHECK(pos_lab_orig_posn != neg_lab_orig_posn); - - // Note: the label positive and negative indices are relative to the entire dataset. - // Hence, scale them back to an index within the group - auto pos_pred_pos = dindexable_sorted_preds_pos_[pos_lab_orig_posn] - group_begin; - auto neg_pred_pos = dindexable_sorted_preds_pos_[neg_lab_orig_posn] - group_begin; - return MAPLambdaWeightComputer::GetLambdaMAP( - pos_pred_pos, neg_pred_pos, - dsorted_labels_[pidx], dsorted_labels_[nidx], - &dmap_stats_[group_begin], group_end - group_begin); - } - - private: - common::Span dmap_stats_; // Start address of the map stats for every sorted - // prediction value - }; - - inline const MAPLambdaWeightMultiplier GetWeightMultiplier() const { return weight_multiplier_; } - - private: - dh::caching_device_vector dmap_stats_; - // This computes the adjustment to the weight - const MAPLambdaWeightMultiplier weight_multiplier_; -#endif -}; - -#if defined(__CUDACC__) -class SortedLabelList : dh::SegmentSorter { - private: - const LambdaRankParam ¶m_; // Objective configuration - - public: - explicit SortedLabelList(const LambdaRankParam ¶m) - : param_(param) {} - - // Sort the labels that are grouped by 'groups' - void Sort(const HostDeviceVector &dlabels, const std::vector &groups) { - this->SortItems(dlabels.ConstDevicePointer(), dlabels.Size(), groups); - } - - // This kernel can only run *after* the kernel in sort is completed, as they - // use the default stream - template - void ComputeGradients(const bst_float *dpreds, // Unsorted predictions - const bst_float *dlabels, // Unsorted labels - const HostDeviceVector &weights, - int iter, - GradientPair *out_gpair, - float weight_normalization_factor) { - // Group info on device - const auto &dgroups = this->GetGroupsSpan(); - uint32_t ngroups = this->GetNumGroups() + 1; - - uint32_t total_items = this->GetNumItems(); - uint32_t niter = param_.num_pairsample * total_items; - - float fix_list_weight = param_.fix_list_weight; - - const auto &original_pos = this->GetOriginalPositionsSpan(); - - uint32_t num_weights = weights.Size(); - auto dweights = num_weights ? weights.ConstDevicePointer() : nullptr; - - const auto &sorted_labels = this->GetItemsSpan(); - - // This is used to adjust the weight of different elements based on the different ranking - // objective function policies - LambdaWeightComputerT weight_computer(dpreds, dlabels, *this); - auto wmultiplier = weight_computer.GetWeightMultiplier(); - - int device_id = -1; - dh::safe_cuda(cudaGetDevice(&device_id)); - // For each instance in the group, compute the gradient pair concurrently - dh::LaunchN(niter, nullptr, [=] __device__(uint32_t idx) { - // First, determine the group 'idx' belongs to - uint32_t item_idx = idx % total_items; - uint32_t group_idx = - thrust::upper_bound(thrust::seq, dgroups.begin(), - dgroups.begin() + ngroups, item_idx) - - dgroups.begin(); - // Span of this group within the larger labels/predictions sorted tuple - uint32_t group_begin = dgroups[group_idx - 1]; - uint32_t group_end = dgroups[group_idx]; - uint32_t total_group_items = group_end - group_begin; - - // Are the labels diverse enough? If they are all the same, then there is nothing to pick - // from another group - bail sooner - if (sorted_labels[group_begin] == sorted_labels[group_end - 1]) return; - - // Find the number of labels less than and greater than the current label - // at the sorted index position item_idx - uint32_t nleft = CountNumItemsToTheLeftOf( - sorted_labels.data() + group_begin, item_idx - group_begin + 1, sorted_labels[item_idx]); - uint32_t nright = CountNumItemsToTheRightOf( - sorted_labels.data() + item_idx, group_end - item_idx, sorted_labels[item_idx]); - - // Create a minstd_rand object to act as our source of randomness - thrust::minstd_rand rng((iter + 1) * 1111); - rng.discard(((idx / total_items) * total_group_items) + item_idx - group_begin); - // Create a uniform_int_distribution to produce a sample from outside of the - // present label group - thrust::uniform_int_distribution dist(0, nleft + nright - 1); - - int sample = dist(rng); - int pos_idx = -1; // Bigger label - int neg_idx = -1; // Smaller label - // Are we picking a sample to the left/right of the current group? - if (sample < nleft) { - // Go left - pos_idx = sample + group_begin; - neg_idx = item_idx; - } else { - pos_idx = item_idx; - uint32_t items_in_group = total_group_items - nleft - nright; - neg_idx = sample + items_in_group + group_begin; - } - - // Compute and assign the gradients now - const float eps = 1e-16f; - bst_float p = common::Sigmoid(dpreds[original_pos[pos_idx]] - dpreds[original_pos[neg_idx]]); - bst_float g = p - 1.0f; - bst_float h = thrust::max(p * (1.0f - p), eps); - - // Rescale each gradient and hessian so that the group has a weighted constant - float scale = __frcp_ru(niter / total_items); - if (fix_list_weight != 0.0f) { - scale *= fix_list_weight / total_group_items; - } - - float weight = num_weights ? dweights[group_idx - 1] : 1.0f; - weight *= weight_normalization_factor; - weight *= wmultiplier.GetWeight(group_idx - 1, pos_idx, neg_idx); - weight *= scale; - // Accumulate gradient and hessian in both positive and negative indices - const GradientPair in_pos_gpair(g * weight, 2.0f * weight * h); - dh::AtomicAddGpair(&out_gpair[original_pos[pos_idx]], in_pos_gpair); - - const GradientPair in_neg_gpair(-g * weight, 2.0f * weight * h); - dh::AtomicAddGpair(&out_gpair[original_pos[neg_idx]], in_neg_gpair); - }); - - // Wait until the computations done by the kernel is complete - dh::safe_cuda(cudaStreamSynchronize(nullptr)); - } -}; -#endif - -// objective for lambda rank -template -class LambdaRankObj : public ObjFunction { - public: - void Configure(Args const &args) override { param_.UpdateAllowUnknown(args); } - ObjInfo Task() const override { return ObjInfo::kRanking; } - - void GetGradient(const HostDeviceVector& preds, - const MetaInfo& info, - int iter, - HostDeviceVector* out_gpair) override { - CHECK_EQ(preds.Size(), info.labels.Size()) << "label size predict size not match"; - - // quick consistency when group is not available - std::vector tgptr(2, 0); tgptr[1] = static_cast(info.labels.Size()); - const std::vector &gptr = info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_; - CHECK(gptr.size() != 0 && gptr.back() == info.labels.Size()) - << "group structure not consistent with #rows" << ", " - << "group ponter size: " << gptr.size() << ", " - << "labels size: " << info.labels.Size() << ", " - << "group pointer back: " << (gptr.size() == 0 ? 0 : gptr.back()); - -#if defined(__CUDACC__) - // Check if we have a GPU assignment; else, revert back to CPU - auto device = ctx_->gpu_id; - if (device >= 0) { - ComputeGradientsOnGPU(preds, info, iter, out_gpair, gptr); - } else { - // Revert back to CPU -#endif - ComputeGradientsOnCPU(preds, info, iter, out_gpair, gptr); -#if defined(__CUDACC__) - } -#endif - } - - const char* DefaultEvalMetric() const override { - return "map"; - } - - void SaveConfig(Json* p_out) const override { - auto& out = *p_out; - out["name"] = String(LambdaWeightComputerT::Name()); - out["lambda_rank_param"] = ToJson(param_); - } - - void LoadConfig(Json const& in) override { - FromJson(in["lambda_rank_param"], ¶m_); - } - - private: - bst_float ComputeWeightNormalizationFactor(const MetaInfo& info, - const std::vector &gptr) { - const auto ngroup = static_cast(gptr.size() - 1); - bst_float sum_weights = 0; - for (bst_omp_uint k = 0; k < ngroup; ++k) { - sum_weights += info.GetWeight(k); - } - return ngroup / sum_weights; - } - - void ComputeGradientsOnCPU(const HostDeviceVector& preds, - const MetaInfo& info, - int iter, - HostDeviceVector* out_gpair, - const std::vector &gptr) { - LOG(DEBUG) << "Computing " << LambdaWeightComputerT::Name() << " gradients on CPU."; - - bst_float weight_normalization_factor = ComputeWeightNormalizationFactor(info, gptr); - - const auto& preds_h = preds.HostVector(); - const auto& labels = info.labels.HostView(); - std::vector& gpair = out_gpair->HostVector(); - const auto ngroup = static_cast(gptr.size() - 1); - out_gpair->Resize(preds.Size()); - - dmlc::OMPException exc; -#pragma omp parallel num_threads(ctx_->Threads()) - { - exc.Run([&]() { - // parallel construct, declare random number generator here, so that each - // thread use its own random number generator, seed by thread id and current iteration - std::minstd_rand rnd((iter + 1) * 1111); - std::vector pairs; - std::vector lst; - std::vector< std::pair > rec; - - #pragma omp for schedule(static) - for (bst_omp_uint k = 0; k < ngroup; ++k) { - exc.Run([&]() { - lst.clear(); pairs.clear(); - for (unsigned j = gptr[k]; j < gptr[k+1]; ++j) { - lst.emplace_back(preds_h[j], labels(j), j); - gpair[j] = GradientPair(0.0f, 0.0f); - } - std::stable_sort(lst.begin(), lst.end(), ListEntry::CmpPred); - rec.resize(lst.size()); - for (unsigned i = 0; i < lst.size(); ++i) { - rec[i] = std::make_pair(lst[i].label, i); - } - std::stable_sort(rec.begin(), rec.end(), common::CmpFirst); - // enumerate buckets with same label - // for each item in the lst, grab another sample randomly - for (unsigned i = 0; i < rec.size(); ) { - unsigned j = i + 1; - while (j < rec.size() && rec[j].first == rec[i].first) ++j; - // bucket in [i,j), get a sample outside bucket - unsigned nleft = i, nright = static_cast(rec.size() - j); - if (nleft + nright != 0) { - int nsample = param_.num_pairsample; - while (nsample --) { - for (unsigned pid = i; pid < j; ++pid) { - unsigned ridx = - std::uniform_int_distribution(0, nleft + nright - 1)(rnd); - if (ridx < nleft) { - pairs.emplace_back(rec[ridx].second, rec[pid].second, - info.GetWeight(k) * weight_normalization_factor); - } else { - pairs.emplace_back(rec[pid].second, rec[ridx+j-i].second, - info.GetWeight(k) * weight_normalization_factor); - } - } - } - } - i = j; - } - // get lambda weight for the pairs - LambdaWeightComputerT::GetLambdaWeight(lst, &pairs); - // rescale each gradient and hessian so that the lst have constant weighted - float scale = 1.0f / param_.num_pairsample; - if (param_.fix_list_weight != 0.0f) { - scale *= param_.fix_list_weight / (gptr[k + 1] - gptr[k]); - } - for (auto & pair : pairs) { - const ListEntry &pos = lst[pair.pos_index]; - const ListEntry &neg = lst[pair.neg_index]; - const bst_float w = pair.weight * scale; - const float eps = 1e-16f; - bst_float p = common::Sigmoid(pos.pred - neg.pred); - bst_float g = p - 1.0f; - bst_float h = std::max(p * (1.0f - p), eps); - // accumulate gradient and hessian in both pid, and nid - gpair[pos.rindex] += GradientPair(g * w, 2.0f*w*h); - gpair[neg.rindex] += GradientPair(-g * w, 2.0f*w*h); - } - }); - } - }); - } - exc.Rethrow(); - } - -#if defined(__CUDACC__) - void ComputeGradientsOnGPU(const HostDeviceVector& preds, - const MetaInfo& info, - int iter, - HostDeviceVector* out_gpair, - const std::vector &gptr) { - LOG(DEBUG) << "Computing " << LambdaWeightComputerT::Name() << " gradients on GPU."; - - auto device = ctx_->gpu_id; - dh::safe_cuda(cudaSetDevice(device)); - - bst_float weight_normalization_factor = ComputeWeightNormalizationFactor(info, gptr); - - // Set the device ID and copy them to the device - out_gpair->SetDevice(device); - info.labels.SetDevice(device); - preds.SetDevice(device); - info.weights_.SetDevice(device); - - out_gpair->Resize(preds.Size()); - - auto d_preds = preds.ConstDevicePointer(); - auto d_gpair = out_gpair->DevicePointer(); - auto d_labels = info.labels.View(device); - - SortedLabelList slist(param_); - - // Sort the labels within the groups on the device - slist.Sort(*info.labels.Data(), gptr); - - // Initialize the gradients next - out_gpair->Fill(GradientPair(0.0f, 0.0f)); - - // Finally, compute the gradients - slist.ComputeGradients(d_preds, d_labels.Values().data(), info.weights_, - iter, d_gpair, weight_normalization_factor); - } -#endif - - LambdaRankParam param_; -}; - -#if !defined(GTEST_TEST) -// register the objective functions -DMLC_REGISTER_PARAMETER(LambdaRankParam); - -XGBOOST_REGISTER_OBJECTIVE(PairwiseRankObj, PairwiseLambdaWeightComputer::Name()) -.describe("Pairwise rank objective.") -.set_body([]() { return new LambdaRankObj(); }); - -XGBOOST_REGISTER_OBJECTIVE(LambdaRankObjMAP, MAPLambdaWeightComputer::Name()) -.describe("LambdaRank with MAP as objective.") -.set_body([]() { return new LambdaRankObj(); }); -#endif - -} // namespace obj -} // namespace xgboost diff --git a/tests/cpp/objective/test_lambdarank_obj.cc b/tests/cpp/objective/test_lambdarank_obj.cc index d02a55c1b7b8..c808e97f0c75 100644 --- a/tests/cpp/objective/test_lambdarank_obj.cc +++ b/tests/cpp/objective/test_lambdarank_obj.cc @@ -223,4 +223,125 @@ TEST(LambdaRank, MakePair) { ASSERT_EQ(n_pairs, info.num_row_ * param.NumPair()); } } + +void TestMAPStat(Context const* ctx) { + auto p_fmat = EmptyDMatrix(); + MetaInfo& info = p_fmat->Info(); + ltr::LambdaRankParam param; + param.UpdateAllowUnknown(Args{}); + + { + std::vector h_data{1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 1.0f}; + info.labels.Reshape(h_data.size(), 1); + info.labels.Data()->HostVector() = h_data; + info.num_row_ = h_data.size(); + + HostDeviceVector predt; + auto& h_predt = predt.HostVector(); + h_predt.resize(h_data.size()); + std::iota(h_predt.rbegin(), h_predt.rend(), 0.0f); + + auto p_cache = std::make_shared(ctx, info, param); + + predt.SetDevice(ctx->gpu_id); + auto rank_idx = + p_cache->SortedIdx(ctx, ctx->IsCPU() ? predt.ConstHostSpan() : predt.ConstDeviceSpan()); + + if (ctx->IsCPU()) { + obj::cpu_impl::MAPStat(ctx, info.labels.HostView().Slice(linalg::All(), 0), rank_idx, + p_cache); + } else { + obj::cuda_impl::MAPStat(ctx, info, rank_idx, p_cache); + } + + Context cpu_ctx; + auto n_rel = p_cache->NumRelevant(&cpu_ctx); + auto acc = p_cache->Acc(&cpu_ctx); + + ASSERT_EQ(n_rel[0], 1.0); + ASSERT_EQ(acc[0], 1.0); + + ASSERT_EQ(n_rel.back(), h_data.size() - 1.0); + ASSERT_NEAR(acc.back(), 1.95 + (1.0 / h_data.size()), kRtEps); + } + { + info.labels.Reshape(16); + auto& h_label = info.labels.Data()->HostVector(); + info.group_ptr_ = {0, 8, 16}; + info.num_row_ = info.labels.Shape(0); + + std::fill_n(h_label.begin(), 8, 1.0f); + std::fill_n(h_label.begin() + 8, 8, 0.0f); + HostDeviceVector predt; + auto& h_predt = predt.HostVector(); + h_predt.resize(h_label.size()); + std::iota(h_predt.rbegin(), h_predt.rbegin() + 8, 0.0f); + std::iota(h_predt.rbegin() + 8, h_predt.rend(), 0.0f); + + auto p_cache = std::make_shared(ctx, info, param); + + predt.SetDevice(ctx->gpu_id); + auto rank_idx = + p_cache->SortedIdx(ctx, ctx->IsCPU() ? predt.ConstHostSpan() : predt.ConstDeviceSpan()); + + if (ctx->IsCPU()) { + obj::cpu_impl::MAPStat(ctx, info.labels.HostView().Slice(linalg::All(), 0), rank_idx, + p_cache); + } else { + obj::cuda_impl::MAPStat(ctx, info, rank_idx, p_cache); + } + + Context cpu_ctx; + auto n_rel = p_cache->NumRelevant(&cpu_ctx); + ASSERT_EQ(n_rel[7], 8); // first group + ASSERT_EQ(n_rel.back(), 0); // second group + } +} + +TEST(LambdaRank, MAPStat) { + Context ctx; + TestMAPStat(&ctx); +} + +void TestMAPGPair(Context const* ctx) { + std::unique_ptr obj{xgboost::ObjFunction::Create("rank:map", ctx)}; + Args args; + obj->Configure(args); + + CheckConfigReload(obj, "rank:map"); + + CheckRankingObjFunction(obj, // obj + {0, 0.1f, 0, 0.1f}, // score + {0, 1, 0, 1}, // label + {2.0f, 2.0f}, // weight + {0, 2, 4}, // group + {1.2054923f, -1.2054923f, 1.2054923f, -1.2054923f}, // out grad + {1.2657166f, 1.2657166f, 1.2657166f, 1.2657166f}); + // disable the second query group with 0 weight + CheckRankingObjFunction(obj, // obj + {0, 0.1f, 0, 0.1f}, // score + {0, 1, 0, 1}, // label + {2.0f, 0.0f}, // weight + {0, 2, 4}, // group + {1.2054923f, -1.2054923f, .0f, .0f}, // out grad + {1.2657166f, 1.2657166f, .0f, .0f}); +} + +TEST(LambdaRank, MAPGPair) { + Context ctx; + TestMAPGPair(&ctx); +} + +void TestPairWiseGPair(Context const* ctx) { + std::unique_ptr obj{xgboost::ObjFunction::Create("rank:pairwise", ctx)}; + Args args; + obj->Configure(args); + + args.emplace_back("lambdarank_unbiased", "true"); +} + +TEST(LambdaRank, Pairwise) { + Context ctx; + TestPairWiseGPair(&ctx); +} } // namespace xgboost::obj diff --git a/tests/cpp/objective/test_lambdarank_obj.cu b/tests/cpp/objective/test_lambdarank_obj.cu index 01d020dda1cd..d0f448993487 100644 --- a/tests/cpp/objective/test_lambdarank_obj.cu +++ b/tests/cpp/objective/test_lambdarank_obj.cu @@ -18,6 +18,12 @@ TEST(LambdaRank, GPUNDCGJsonIO) { TestNDCGJsonIO(&ctx); } +TEST(LambdaRank, GPUMAPStat) { + Context ctx; + ctx.gpu_id = 0; + TestMAPStat(&ctx); +} + TEST(LambdaRank, GPUNDCGGPair) { Context ctx; ctx.gpu_id = 0; @@ -153,4 +159,10 @@ TEST(LambdaRank, RankItemCountOnRight) { RankItemCountImpl(sorted_items, wrapper, 1, static_cast(1)); RankItemCountImpl(sorted_items, wrapper, 0, static_cast(0)); } + +TEST(LambdaRank, GPUMAPGPair) { + Context ctx; + ctx.gpu_id = 0; + TestMAPGPair(&ctx); +} } // namespace xgboost::obj diff --git a/tests/cpp/objective/test_lambdarank_obj.h b/tests/cpp/objective/test_lambdarank_obj.h index aebe3ad54f3e..9539f1a3003e 100644 --- a/tests/cpp/objective/test_lambdarank_obj.h +++ b/tests/cpp/objective/test_lambdarank_obj.h @@ -18,6 +18,8 @@ #include "../helpers.h" // for EmptyDMatrix namespace xgboost::obj { +void TestMAPStat(Context const* ctx); + inline void TestNDCGJsonIO(Context const* ctx) { std::unique_ptr obj{ObjFunction::Create("rank:ndcg", ctx)}; @@ -37,6 +39,8 @@ void TestNDCGGPair(Context const* ctx); void TestUnbiasedNDCG(Context const* ctx); +void TestMAPGPair(Context const* ctx); + /** * \brief Initialize test data for make pair tests. */ diff --git a/tests/cpp/objective/test_ranking_obj.cc b/tests/cpp/objective/test_ranking_obj.cc deleted file mode 100644 index 2072f530e8da..000000000000 --- a/tests/cpp/objective/test_ranking_obj.cc +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright by Contributors -#include -#include -#include - -#include "../helpers.h" - -namespace xgboost { - -TEST(Objective, DeclareUnifiedTest(PairwiseRankingGPair)) { - std::vector> args; - xgboost::Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX); - - std::unique_ptr obj{xgboost::ObjFunction::Create("rank:pairwise", &ctx)}; - obj->Configure(args); - CheckConfigReload(obj, "rank:pairwise"); - - // Test with setting sample weight to second query group - CheckRankingObjFunction(obj, - {0, 0.1f, 0, 0.1f}, - {0, 1, 0, 1}, - {2.0f, 0.0f}, - {0, 2, 4}, - {1.9f, -1.9f, 0.0f, 0.0f}, - {1.995f, 1.995f, 0.0f, 0.0f}); - - CheckRankingObjFunction(obj, - {0, 0.1f, 0, 0.1f}, - {0, 1, 0, 1}, - {1.0f, 1.0f}, - {0, 2, 4}, - {0.95f, -0.95f, 0.95f, -0.95f}, - {0.9975f, 0.9975f, 0.9975f, 0.9975f}); - - ASSERT_NO_THROW(obj->DefaultEvalMetric()); -} - -TEST(Objective, DeclareUnifiedTest(PairwiseRankingGPairSameLabels)) { - std::vector> args; - xgboost::Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX); - - std::unique_ptr obj{ObjFunction::Create("rank:pairwise", &ctx)}; - obj->Configure(args); - // No computation of gradient/hessian, as there is no diversity in labels - CheckRankingObjFunction(obj, - {0, 0.1f, 0, 0.1f}, - {1, 1, 1, 1}, - {2.0f, 0.0f}, - {0, 2, 4}, - {0.0f, 0.0f, 0.0f, 0.0f}, - {0.0f, 0.0f, 0.0f, 0.0f}); - - ASSERT_NO_THROW(obj->DefaultEvalMetric()); -} - -TEST(Objective, DeclareUnifiedTest(MAPRankingGPair)) { - std::vector> args; - xgboost::Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX); - - std::unique_ptr obj{xgboost::ObjFunction::Create("rank:map", &ctx)}; - obj->Configure(args); - CheckConfigReload(obj, "rank:map"); - - // Test with setting sample weight to second query group - CheckRankingObjFunction(obj, - {0, 0.1f, 0, 0.1f}, - {0, 1, 0, 1}, - {2.0f, 0.0f}, - {0, 2, 4}, - {0.95f, -0.95f, 0.0f, 0.0f}, - {0.9975f, 0.9975f, 0.0f, 0.0f}); - - CheckRankingObjFunction(obj, - {0, 0.1f, 0, 0.1f}, - {0, 1, 0, 1}, - {1.0f, 1.0f}, - {0, 2, 4}, - {0.475f, -0.475f, 0.475f, -0.475f}, - {0.4988f, 0.4988f, 0.4988f, 0.4988f}); - ASSERT_NO_THROW(obj->DefaultEvalMetric()); -} - -} // namespace xgboost diff --git a/tests/cpp/objective/test_ranking_obj_gpu.cu b/tests/cpp/objective/test_ranking_obj_gpu.cu deleted file mode 100644 index cd40b49284f6..000000000000 --- a/tests/cpp/objective/test_ranking_obj_gpu.cu +++ /dev/null @@ -1,175 +0,0 @@ -/*! - * Copyright 2019-2021 by XGBoost Contributors - */ -#include - -#include "test_ranking_obj.cc" -#include "../../../src/objective/rank_obj.cu" - -namespace xgboost { - -template > -std::unique_ptr> -RankSegmentSorterTestImpl(const std::vector &group_indices, - const std::vector &hlabels, - const std::vector &expected_sorted_hlabels, - const std::vector &expected_orig_pos - ) { - std::unique_ptr> seg_sorter_ptr(new dh::SegmentSorter); - dh::SegmentSorter &seg_sorter(*seg_sorter_ptr); - - // Create a bunch of unsorted labels on the device and sort it via the segment sorter - dh::device_vector dlabels(hlabels); - seg_sorter.SortItems(dlabels.data().get(), dlabels.size(), group_indices, Comparator()); - - auto num_items = seg_sorter.GetItemsSpan().size(); - EXPECT_EQ(num_items, group_indices.back()); - EXPECT_EQ(seg_sorter.GetNumGroups(), group_indices.size() - 1); - - // Check the labels - dh::device_vector sorted_dlabels(num_items); - sorted_dlabels.assign(dh::tcbegin(seg_sorter.GetItemsSpan()), - dh::tcend(seg_sorter.GetItemsSpan())); - thrust::host_vector sorted_hlabels(sorted_dlabels); - EXPECT_EQ(expected_sorted_hlabels, sorted_hlabels); - - // Check the indices - dh::device_vector dorig_pos(num_items); - dorig_pos.assign(dh::tcbegin(seg_sorter.GetOriginalPositionsSpan()), - dh::tcend(seg_sorter.GetOriginalPositionsSpan())); - dh::device_vector horig_pos(dorig_pos); - EXPECT_EQ(expected_orig_pos, horig_pos); - - return seg_sorter_ptr; -} - -TEST(Objective, RankSegmentSorterTest) { - RankSegmentSorterTestImpl({0, 2, 4, 7, 10, 14, 18, 22, 26}, // Groups - {1, 1, // Labels - 1, 2, - 3, 2, 1, - 1, 2, 1, - 1, 3, 4, 2, - 1, 2, 1, 1, - 1, 2, 2, 3, - 3, 3, 1, 2}, - {1, 1, // Expected sorted labels - 2, 1, - 3, 2, 1, - 2, 1, 1, - 4, 3, 2, 1, - 2, 1, 1, 1, - 3, 2, 2, 1, - 3, 3, 2, 1}, - {0, 1, // Expected original positions - 3, 2, - 4, 5, 6, - 8, 7, 9, - 12, 11, 13, 10, - 15, 14, 16, 17, - 21, 19, 20, 18, - 22, 23, 25, 24}); -} - -TEST(Objective, RankSegmentSorterSingleGroupTest) { - RankSegmentSorterTestImpl({0, 7}, // Groups - {6, 1, 4, 3, 0, 5, 2}, // Labels - {6, 5, 4, 3, 2, 1, 0}, // Expected sorted labels - {0, 5, 2, 3, 6, 1, 4}); // Expected original positions -} - -TEST(Objective, RankSegmentSorterAscendingTest) { - RankSegmentSorterTestImpl>( - {0, 4, 7}, // Groups - {3, 1, 4, 2, // Labels - 6, 5, 7}, - {1, 2, 3, 4, // Expected sorted labels - 5, 6, 7}, - {1, 3, 0, 2, // Expected original positions - 5, 4, 6}); -} - -TEST(Objective, IndexableSortedItemsTest) { - std::vector hlabels = {3.1f, 1.2f, 2.3f, 4.4f, // Labels - 7.8f, 5.01f, 6.96f, - 10.3f, 8.7f, 11.4f, 9.45f, 11.4f}; - dh::device_vector dlabels(hlabels); - - auto segment_label_sorter = RankSegmentSorterTestImpl( - {0, 4, 7, 12}, // Groups - hlabels, - {4.4f, 3.1f, 2.3f, 1.2f, // Expected sorted labels - 7.8f, 6.96f, 5.01f, - 11.4f, 11.4f, 10.3f, 9.45f, 8.7f}, - {3, 0, 2, 1, // Expected original positions - 4, 6, 5, - 9, 11, 7, 10, 8}); - - segment_label_sorter->CreateIndexableSortedPositions(); - std::vector sorted_indices(segment_label_sorter->GetNumItems()); - dh::CopyDeviceSpanToVector(&sorted_indices, - segment_label_sorter->GetIndexableSortedPositionsSpan()); - std::vector expected_sorted_indices = { - 1, 3, 2, 0, - 4, 6, 5, - 9, 11, 7, 10, 8}; - EXPECT_EQ(expected_sorted_indices, sorted_indices); -} - -TEST(Objective, ComputeAndCompareMAPStatsTest) { - std::vector hlabels = {3.1f, 0.0f, 2.3f, 4.4f, // Labels - 0.0f, 5.01f, 0.0f, - 10.3f, 0.0f, 11.4f, 9.45f, 11.4f}; - dh::device_vector dlabels(hlabels); - - auto segment_label_sorter = RankSegmentSorterTestImpl( - {0, 4, 7, 12}, // Groups - hlabels, - {4.4f, 3.1f, 2.3f, 0.0f, // Expected sorted labels - 5.01f, 0.0f, 0.0f, - 11.4f, 11.4f, 10.3f, 9.45f, 0.0f}, - {3, 0, 2, 1, // Expected original positions - 5, 4, 6, - 9, 11, 7, 10, 8}); - - // Create MAP stats on the device first using the objective - std::vector hpreds{-9.78f, 24.367f, 0.908f, -11.47f, - -1.03f, -2.79f, -3.1f, - 104.22f, 103.1f, -101.7f, 100.5f, 45.1f}; - dh::device_vector dpreds(hpreds); - - xgboost::obj::MAPLambdaWeightComputer map_lw_computer(dpreds.data().get(), - dlabels.data().get(), - *segment_label_sorter); - - // Get the device MAP stats on host - std::vector dmap_stats( - segment_label_sorter->GetNumItems()); - dh::CopyDeviceSpanToVector(&dmap_stats, map_lw_computer.GetMapStatsSpan()); - - // Compute the MAP stats on host next to compare - std::vector hgroups(segment_label_sorter->GetNumGroups() + 1); - dh::CopyDeviceSpanToVector(&hgroups, segment_label_sorter->GetGroupsSpan()); - - for (size_t i = 0; i < hgroups.size() - 1; ++i) { - auto gbegin = hgroups[i]; - auto gend = hgroups[i + 1]; - std::vector lst_entry; - for (auto j = gbegin; j < gend; ++j) { - lst_entry.emplace_back(hpreds[j], hlabels[j], j); - } - std::stable_sort(lst_entry.begin(), lst_entry.end(), xgboost::obj::ListEntry::CmpPred); - - // Compute the MAP stats with this list and compare with the ones computed on the device - std::vector hmap_stats; - xgboost::obj::MAPLambdaWeightComputer::GetMAPStats(lst_entry, &hmap_stats); - for (auto j = gbegin; j < gend; ++j) { - EXPECT_EQ(dmap_stats[j].hits, hmap_stats[j - gbegin].hits); - EXPECT_NEAR(dmap_stats[j].ap_acc, hmap_stats[j - gbegin].ap_acc, 0.01f); - EXPECT_NEAR(dmap_stats[j].ap_acc_miss, hmap_stats[j - gbegin].ap_acc_miss, 0.01f); - EXPECT_NEAR(dmap_stats[j].ap_acc_add, hmap_stats[j - gbegin].ap_acc_add, 0.01f); - } - } -} - -} // namespace xgboost diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 67620e6ddf48..e0d3d680be68 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -176,7 +176,7 @@ def test_ranking(): def test_ranking_metric() -> None: from sklearn.metrics import roc_auc_score - X, y, qid, w = tm.make_ltr(512, 4, 3, 2) + X, y, qid, w = tm.make_ltr(512, 4, 3, 1) # use auc for test as ndcg_score in sklearn works only on label gain instead of exp # gain. # note that the auc in sklearn is different from the one in XGBoost. The one in diff --git a/tests/test_distributed/test_with_spark/test_spark_local.py b/tests/test_distributed/test_with_spark/test_spark_local.py index a5e0f028a060..6d88323ac49f 100644 --- a/tests/test_distributed/test_with_spark/test_spark_local.py +++ b/tests/test_distributed/test_with_spark/test_spark_local.py @@ -1343,61 +1343,94 @@ def test_unsupported_params(self): SparkXGBClassifier(evals_result={}) -class XgboostRankerLocalTest(SparkTestCase): - def setUp(self): - self.session.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "8") - self.ranker_df_train = self.session.createDataFrame( - [ - (Vectors.dense(1.0, 2.0, 3.0), 0, 0), - (Vectors.dense(4.0, 5.0, 6.0), 1, 0), - (Vectors.dense(9.0, 4.0, 8.0), 2, 0), - (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1), - (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1), - (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1), - ], - ["features", "label", "qid"], - ) - self.ranker_df_test = self.session.createDataFrame( - [ - (Vectors.dense(1.5, 2.0, 3.0), 0, -1.87988), - (Vectors.dense(4.5, 5.0, 6.0), 0, 0.29556), - (Vectors.dense(9.0, 4.5, 8.0), 0, 2.36570), - (Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, -1.87988), - (Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, -0.30612), - (Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, 2.44826), - ], - ["features", "qid", "expected_prediction"], - ) - self.ranker_df_train_1 = self.session.createDataFrame( - [ - (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 9), - (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 9), - (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 9), - (Vectors.dense(1.0, 2.0, 3.0), 0, 8), - (Vectors.dense(4.0, 5.0, 6.0), 1, 8), - (Vectors.dense(9.0, 4.0, 8.0), 2, 8), - (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 7), - (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 7), - (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 7), - (Vectors.dense(1.0, 2.0, 3.0), 0, 6), - (Vectors.dense(4.0, 5.0, 6.0), 1, 6), - (Vectors.dense(9.0, 4.0, 8.0), 2, 6), - ] - * 4, - ["features", "label", "qid"], - ) +LTRData = namedtuple("LTRData", ("df_train", "df_test", "df_train_1")) - def test_ranker(self): - ranker = SparkXGBRanker(qid_col="qid") - assert ranker.getOrDefault(ranker.objective) == "rank:pairwise" - model = ranker.fit(self.ranker_df_train) - pred_result = model.transform(self.ranker_df_test).collect() +@pytest.fixture +def ltr_data(spark: SparkSession) -> Generator[LTRData, None, None]: + spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "8") + ranker_df_train = spark.createDataFrame( + [ + (Vectors.dense(1.0, 2.0, 3.0), 0, 0), + (Vectors.dense(4.0, 5.0, 6.0), 1, 0), + (Vectors.dense(9.0, 4.0, 8.0), 2, 0), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1), + (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1), + (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1), + ], + ["features", "label", "qid"], + ) + X_train = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [9.0, 4.0, 8.0], + [np.NaN, 1.0, 5.5], + [np.NaN, 6.0, 7.5], + [np.NaN, 8.0, 9.5], + ] + ) + qid_train = np.array([0, 0, 0, 1, 1, 1]) + y_train = np.array([0, 1, 2, 0, 1, 2]) + + X_test = np.array( + [ + [1.5, 2.0, 3.0], + [4.5, 5.0, 6.0], + [9.0, 4.5, 8.0], + [np.NaN, 1.0, 6.0], + [np.NaN, 6.0, 7.0], + [np.NaN, 8.0, 10.5], + ] + ) + + ltr = xgb.XGBRanker(tree_method="approx", objective="rank:pairwise") + ltr.fit(X_train, y_train, qid=qid_train) + predt = ltr.predict(X_test) + + ranker_df_test = spark.createDataFrame( + [ + (Vectors.dense(1.5, 2.0, 3.0), 0, float(predt[0])), + (Vectors.dense(4.5, 5.0, 6.0), 0, float(predt[1])), + (Vectors.dense(9.0, 4.5, 8.0), 0, float(predt[2])), + (Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, float(predt[3])), + (Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, float(predt[4])), + (Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, float(predt[5])), + ], + ["features", "qid", "expected_prediction"], + ) + ranker_df_train_1 = spark.createDataFrame( + [ + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 9), + (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 9), + (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 9), + (Vectors.dense(1.0, 2.0, 3.0), 0, 8), + (Vectors.dense(4.0, 5.0, 6.0), 1, 8), + (Vectors.dense(9.0, 4.0, 8.0), 2, 8), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 7), + (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 7), + (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 7), + (Vectors.dense(1.0, 2.0, 3.0), 0, 6), + (Vectors.dense(4.0, 5.0, 6.0), 1, 6), + (Vectors.dense(9.0, 4.0, 8.0), 2, 6), + ] + * 4, + ["features", "label", "qid"], + ) + yield LTRData(ranker_df_train, ranker_df_test, ranker_df_train_1) + + +class TestPySparkLocalLETOR: + def test_ranker(self, ltr_data: LTRData) -> None: + ranker = SparkXGBRanker(qid_col="qid", objective="rank:pairwise") + assert ranker.getOrDefault(ranker.objective) == "rank:pairwise" + model = ranker.fit(ltr_data.df_train) + pred_result = model.transform(ltr_data.df_test).collect() for row in pred_result: assert np.isclose(row.prediction, row.expected_prediction, rtol=1e-3) - def test_ranker_qid_sorted(self): - ranker = SparkXGBRanker(qid_col="qid", num_workers=4) - assert ranker.getOrDefault(ranker.objective) == "rank:pairwise" - model = ranker.fit(self.ranker_df_train_1) - model.transform(self.ranker_df_test).collect() + def test_ranker_qid_sorted(self, ltr_data: LTRData) -> None: + ranker = SparkXGBRanker(qid_col="qid", num_workers=4, objective="rank:ndcg") + assert ranker.getOrDefault(ranker.objective) == "rank:ndcg" + model = ranker.fit(ltr_data.df_train_1) + model.transform(ltr_data.df_test).collect()