diff --git a/demo/guide-python/learning_to_rank.py b/demo/guide-python/learning_to_rank.py
new file mode 100644
index 000000000000..37b7157f5029
--- /dev/null
+++ b/demo/guide-python/learning_to_rank.py
@@ -0,0 +1,212 @@
+"""
+Getting started with learning to rank
+=====================================
+
+ .. versionadded:: 2.0.0
+
+This is a demonstration of using XGBoost for learning to rank tasks using the
+MSLR_10k_letor dataset. For more information about the dataset, please visit its
+`description page <https://www.microsoft.com/en-us/research/project/mslr/>`_.
+
+This is a two-part demo: the first part contains a basic example of using XGBoost
+to train on relevance degree, and the second part simulates click data and enables
+position debiasing training.
+
+For an overview of learning to rank in XGBoost, please see
+:doc:`Learning to Rank </tutorials/learning_to_rank>`.
+"""
+from __future__ import annotations
+
+import argparse
+import json
+import os
+import pickle as pkl
+
+import numpy as np
+import pandas as pd
+from sklearn.datasets import load_svmlight_file
+
+import xgboost as xgb
+from xgboost.testing.data import RelDataCV, simulate_clicks, sort_ltr_samples
+
+
+def load_mlsr_10k(data_path: str, cache_path: str) -> RelDataCV:
+ """Load the MSLR10k dataset from data_path and cache a pickle object in cache_path.
+
+ Returns
+ -------
+
+    A :class:`RelDataCV` object holding the ``(X, y, qid)`` triplets of the train
+    and test splits.
+
+ """
+    root_path = os.path.expanduser(data_path)
+    cacheroot_path = os.path.expanduser(cache_path)
+    cache_path = os.path.join(cacheroot_path, "MSLR_10K_LETOR.pkl")
+
+ # Use only the Fold1 for demo:
+ # Train, Valid, Test
+ # {S1,S2,S3}, S4, S5
+ fold = 1
+
+ if not os.path.exists(cache_path):
+ fold_path = os.path.join(root_path, f"Fold{fold}")
+ train_path = os.path.join(fold_path, "train.txt")
+ valid_path = os.path.join(fold_path, "vali.txt")
+ test_path = os.path.join(fold_path, "test.txt")
+ X_train, y_train, qid_train = load_svmlight_file(
+ train_path, query_id=True, dtype=np.float32
+ )
+ y_train = y_train.astype(np.int32)
+ qid_train = qid_train.astype(np.int32)
+
+ X_valid, y_valid, qid_valid = load_svmlight_file(
+ valid_path, query_id=True, dtype=np.float32
+ )
+ y_valid = y_valid.astype(np.int32)
+ qid_valid = qid_valid.astype(np.int32)
+
+ X_test, y_test, qid_test = load_svmlight_file(
+ test_path, query_id=True, dtype=np.float32
+ )
+ y_test = y_test.astype(np.int32)
+ qid_test = qid_test.astype(np.int32)
+
+ data = RelDataCV(
+ train=(X_train, y_train, qid_train),
+ test=(X_test, y_test, qid_test),
+ max_rel=4,
+ )
+
+ with open(cache_path, "wb") as fd:
+ pkl.dump(data, fd)
+
+ with open(cache_path, "rb") as fd:
+ data = pkl.load(fd)
+
+ return data
+
+
+def ranking_demo(args: argparse.Namespace) -> None:
+ """Demonstration for learning to rank with relevance degree."""
+ data = load_mlsr_10k(args.data, args.cache)
+
+ # Sort data according to query index
+ X_train, y_train, qid_train = data.train
+ sorted_idx = np.argsort(qid_train)
+ X_train = X_train[sorted_idx]
+ y_train = y_train[sorted_idx]
+ qid_train = qid_train[sorted_idx]
+
+ X_test, y_test, qid_test = data.test
+ sorted_idx = np.argsort(qid_test)
+ X_test = X_test[sorted_idx]
+ y_test = y_test[sorted_idx]
+ qid_test = qid_test[sorted_idx]
+
+ ranker = xgb.XGBRanker(
+ tree_method="gpu_hist",
+ lambdarank_pair_method="topk",
+ lambdarank_num_pair_per_sample=13,
+ eval_metric=["ndcg@1", "ndcg@8"],
+ )
+ ranker.fit(
+ X_train,
+ y_train,
+ qid=qid_train,
+ eval_set=[(X_test, y_test)],
+ eval_qid=[qid_test],
+ verbose=True,
+ )
+
+
+def click_data_demo(args: argparse.Namespace) -> None:
+ """Demonstration for learning to rank with click data."""
+ data = load_mlsr_10k(args.data, args.cache)
+ train, test = simulate_clicks(data)
+ assert test is not None
+
+ assert train.X.shape[0] == train.click.size
+ assert test.X.shape[0] == test.click.size
+ assert test.score.dtype == np.float32
+ assert test.click.dtype == np.int32
+
+ X_train, clicks_train, y_train, qid_train = sort_ltr_samples(
+ train.X,
+ train.y,
+ train.qid,
+ train.click,
+ train.pos,
+ )
+ X_test, clicks_test, y_test, qid_test = sort_ltr_samples(
+ test.X,
+ test.y,
+ test.qid,
+ test.click,
+ test.pos,
+ )
+
+ class ShowPosition(xgb.callback.TrainingCallback):
+ def after_iteration(
+ self,
+ model: xgb.Booster,
+ epoch: int,
+ evals_log: xgb.callback.TrainingCallback.EvalsLog,
+ ) -> bool:
+ config = json.loads(model.save_config())
+ ti_plus = np.array(config["learner"]["objective"]["ti+"])
+ tj_minus = np.array(config["learner"]["objective"]["tj-"])
+ df = pd.DataFrame({"ti+": ti_plus, "tj-": tj_minus})
+ print(df)
+ return False
+
+ ranker = xgb.XGBRanker(
+ n_estimators=512,
+ tree_method="gpu_hist",
+ learning_rate=0.01,
+ reg_lambda=1.5,
+ subsample=0.8,
+ sampling_method="gradient_based",
+ # LTR specific parameters
+ objective="rank:ndcg",
+ # - Enable bias estimation
+ lambdarank_unbiased=True,
+ # - normalization (1 / (norm + 1))
+ lambdarank_bias_norm=1,
+ # - Focus on the top 12 documents
+ lambdarank_num_pair_per_sample=12,
+ lambdarank_pair_method="topk",
+ ndcg_exp_gain=True,
+ eval_metric=["ndcg@1", "ndcg@3", "ndcg@5", "ndcg@10"],
+ callbacks=[ShowPosition()],
+ )
+ ranker.fit(
+ X_train,
+ clicks_train,
+ qid=qid_train,
+ eval_set=[(X_test, y_test), (X_test, clicks_test)],
+ eval_qid=[qid_test, qid_test],
+ verbose=True,
+ )
+ ranker.predict(X_test)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Demonstration of learning to rank using XGBoost."
+ )
+ parser.add_argument(
+ "--data",
+ type=str,
+ help="Root directory of the MSLR-WEB10K data.",
+ required=True,
+ )
+ parser.add_argument(
+ "--cache",
+ type=str,
+ help="Directory for caching processed data.",
+ required=True,
+ )
+ args = parser.parse_args()
+
+ ranking_demo(args)
+ click_data_demo(args)
diff --git a/doc/contrib/coding_guide.rst b/doc/contrib/coding_guide.rst
index a080c2a31541..f939a17b2a94 100644
--- a/doc/contrib/coding_guide.rst
+++ b/doc/contrib/coding_guide.rst
@@ -16,8 +16,10 @@ C++ Coding Guideline
* Each line of text may contain up to 100 characters.
* The use of C++ exceptions is allowed.
-- Use C++11 features such as smart pointers, braced initializers, lambda functions, and ``std::thread``.
+- Use C++17 features such as smart pointers, braced initializers, lambda functions, and ``std::thread``.
- Use Doxygen to document all the interface code.
+- We have some comments around symbols imported by headers, some of which are hinted by `include-what-you-use <https://include-what-you-use.org>`_. These comments are not required.
+- We use clang-tidy and clang-format. You can check their configuration in the root directory of the XGBoost source tree.
- We have a series of automatic checks to ensure that all of our codebase complies with the Google style. Before submitting your pull request, you are encouraged to run the style checks on your machine. See :ref:`running_checks_locally`.
***********************
diff --git a/doc/tutorials/index.rst b/doc/tutorials/index.rst
index 310fd0170610..eb8c23726d56 100644
--- a/doc/tutorials/index.rst
+++ b/doc/tutorials/index.rst
@@ -21,6 +21,7 @@ See `Awesome XGBoost `_ for mo
monotonic
rf
feature_interaction_constraint
+ learning_to_rank
aft_survival_analysis
c_api_tutorial
input_format
diff --git a/doc/tutorials/learning_to_rank.rst b/doc/tutorials/learning_to_rank.rst
new file mode 100644
index 000000000000..b9883d236ac2
--- /dev/null
+++ b/doc/tutorials/learning_to_rank.rst
@@ -0,0 +1,201 @@
+################
+Learning to Rank
+################
+
+**Contents**
+
+.. contents::
+ :local:
+ :backlinks: none
+
+********
+Overview
+********
+Often in the context of information retrieval, learning-to-rank aims to train a model that arranges a set of query results into an ordered list `[1] <#references>`__. For supervised learning-to-rank, the predictors are sample documents encoded as a feature matrix, and the labels are relevance degrees for each sample. The relevance degree can be multi-level (graded) or binary (relevant or not). The training samples are often grouped by their query index, with each query group containing multiple query results.
+
+XGBoost implements learning to rank through a set of objective functions and performance metrics. The default objective is ``rank:ndcg`` based on the ``LambdaMART`` `[2] <#references>`__ algorithm, which in turn is an adaptation of the ``LambdaRank`` `[3] <#references>`__ framework to gradient boosting trees. For a history and a summary of the algorithm, see `[5] <#references>`__. The implementation in XGBoost features deterministic GPU computation, distributed training, position debiasing and two different pair construction strategies.
+
+************************************
+Training with the Pairwise Objective
+************************************
+``LambdaMART`` is a pairwise ranking model, meaning that it compares the relevance degree for every pair of samples in a query group and calculates a proxy gradient for each pair. The default objective ``rank:ndcg`` uses a surrogate gradient derived from the ``ndcg`` metric. To train an XGBoost model, we need an additional sorted array called ``qid`` for specifying the query group of input samples. An example input would look like this:
+
++-------+-----------+---------------+
+| QID | Label | Features |
++=======+===========+===============+
+| 1 | 0 | :math:`x_1` |
++-------+-----------+---------------+
+| 1 | 1 | :math:`x_2` |
++-------+-----------+---------------+
+| 1 | 0 | :math:`x_3` |
++-------+-----------+---------------+
+| 2 | 0 | :math:`x_4` |
++-------+-----------+---------------+
+| 2 | 1 | :math:`x_5` |
++-------+-----------+---------------+
+| 2 | 1 | :math:`x_6` |
++-------+-----------+---------------+
+| 2 | 1 | :math:`x_7` |
++-------+-----------+---------------+
+
+Notice that the samples are sorted based on their query index in non-decreasing order. In the above example, the first three samples belong to the first query and the next four samples belong to the second. For the sake of simplicity, we will use a synthetic binary learning-to-rank dataset in the following code snippets, with binary labels representing whether the result is relevant or not, and randomly assigned query group indices. For an example that uses a real world dataset, please see :ref:`sphx_glr_python_examples_learning_to_rank.py`.
+
+.. code-block:: python
+
+ from sklearn.datasets import make_classification
+ import numpy as np
+
+ import xgboost as xgb
+
+    # Make a synthetic ranking dataset for demonstration
+    seed = 1994
+    X, y = make_classification(random_state=seed)
+    rng = np.random.default_rng(seed)
+    n_query_groups = 3
+    qid = rng.integers(0, n_query_groups, size=X.shape[0])
+
+ # Sort the inputs based on query index
+ sorted_idx = np.argsort(qid)
+ X = X[sorted_idx, :]
+ y = y[sorted_idx]
+
+The simplest way to train a ranking model is by using the scikit-learn estimator interface. Continuing the previous snippet, we can train a simple ranking model without tuning:
+
+.. code-block:: python
+
+    ranker = xgb.XGBRanker(
+        tree_method="hist",
+        lambdarank_num_pair_per_sample=8,
+        objective="rank:ndcg",
+        lambdarank_pair_method="topk",
+    )
+ ranker.fit(X, y, qid=qid)
+
+Please note that, as of writing, there's no learning-to-rank interface in scikit-learn. As a result, the :py:class:`xgboost.XGBRanker` class does not fully conform to the scikit-learn estimator guideline and cannot be directly used with some of its utility functions. For instance, the ``roc_auc_score`` and ``ndcg_score`` in scikit-learn consider neither query group information nor the pairwise loss. Most of the metrics are implemented as part of XGBoost, but to use scikit-learn utilities like :py:func:`sklearn.model_selection.cross_val_score`, we need to make some adjustments in order to pass the ``qid`` as an additional parameter for :py:meth:`xgboost.XGBRanker.score`. Given a data frame ``X`` (either pandas or cuDF), add the column ``qid`` as follows:
+
+.. code-block:: python
+
+    import pandas as pd
+
+    df = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])])
+ df["qid"] = qid
+ ranker.fit(df, y) # No need to pass qid as a separate argument
+
+ from sklearn.model_selection import StratifiedGroupKFold, cross_val_score
+ # Works with cv in scikit-learn, along with HPO utilities like GridSearchCV
+ kfold = StratifiedGroupKFold(shuffle=False)
+ cross_val_score(ranker, df, y, cv=kfold, groups=df.qid)
+
+The above snippets build a model using ``LambdaMART`` with the ``NDCG@8`` metric. The outputs of a ranker are relevance scores:
+
+.. code-block:: python
+
+ scores = ranker.predict(X)
+    # Sort the relevance scores from most relevant to least relevant
+    sorted_idx = np.argsort(scores)[::-1]
+    scores = scores[sorted_idx]
+
+
+*************
+Position Bias
+*************
+
+.. versionadded:: 2.0.0
+
+.. note::
+
+    The feature is considered experimental. This is an active research area, and your input is much appreciated!
+
+Obtaining real relevance degrees for query results is an expensive and strenuous task, requiring human labelers to label all results one by one. When such a labeling task is infeasible, we might want to train the learning-to-rank model on user click data instead, as it is relatively easy to collect. Another advantage of using click data directly is that it can reflect the most up-to-date user preferences `[1] <#references>`__. However, user clicks are often biased, as users tend to choose results displayed in higher positions. User clicks are also noisy, as users might accidentally click on irrelevant documents. To ameliorate these issues, XGBoost implements the ``Unbiased LambdaMART`` `[4] <#references>`__ algorithm to debias the position-dependent click data. The feature can be enabled by the ``lambdarank_unbiased`` parameter; see :ref:`ltr-param` for related options and :ref:`sphx_glr_python_examples_learning_to_rank.py` for a worked example with simulated user clicks.
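+
+As a minimal sketch (reusing the synthetic ``X``, ``y``, and ``qid`` from the snippets above, with the labels standing in for collected clicks), debiased training can be enabled like this:
+
+.. code-block:: python
+
+    ranker = xgb.XGBRanker(
+        objective="rank:ndcg",
+        # Enable position debiasing for click data
+        lambdarank_unbiased=True,
+        # Normalization for the bias estimation, (1 / (norm + 1))
+        lambdarank_bias_norm=1,
+        # The worked demo combines debiasing with the ``topk`` strategy
+        lambdarank_pair_method="topk",
+        lambdarank_num_pair_per_sample=8,
+    )
+    ranker.fit(X, y, qid=qid)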
+
+****
+Loss
+****
+
+XGBoost implements different ``LambdaMART`` objectives based on different metrics. We list them here as a reference. Aside from those used as objective functions, XGBoost also implements metrics like ``pre`` (for precision) for evaluation. See :doc:`parameters </parameter>` for available options and the following sections for how to choose these objectives based on the number of effective pairs.
+
+* NDCG
+
+`Normalized Discounted Cumulative Gain` ``NDCG`` can be used with both binary relevance and multi-level relevance. If you are not sure about your data, this metric can be used as the default. The name for the objective is ``rank:ndcg``.
+
+
+* MAP
+
+`Mean average precision` ``MAP`` is a binary measure. It can be used when the relevance label is 0 or 1. The name for the objective is ``rank:map``.
+
+
+* Pairwise
+
+The `LambdaMART` algorithm scales the logistic loss with learning to rank metrics like ``NDCG`` in the hope of including ranking information into the loss function. The ``rank:pairwise`` loss is the original version of the pairwise loss, also known as the `RankNet loss` `[7] <#references>`__ or the `pairwise logistic loss`. Unlike the ``rank:map`` and the ``rank:ndcg``, no scaling is applied (:math:`|\Delta Z_{ij}| = 1`).
+
+Whether scaling with a LTR metric is actually more effective is still up for debate; `[8] <#references>`__ provides a theoretical foundation for general lambda loss functions and some insights into the framework.
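+
+As a reference sketch, each loss is selected through the ``objective`` parameter of :py:class:`xgboost.XGBRanker`; the metric choices below are illustrative, not recommendations:
+
+.. code-block:: python
+
+    # Multi-level or binary relevance; a reasonable default
+    ndcg_ranker = xgb.XGBRanker(objective="rank:ndcg", eval_metric=["ndcg@8"])
+    # Binary relevance only
+    map_ranker = xgb.XGBRanker(objective="rank:map", eval_metric=["map@8"])
+    # Unscaled pairwise logistic (RankNet) loss
+    pairwise_ranker = xgb.XGBRanker(objective="rank:pairwise")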
+
+******************
+Constructing Pairs
+******************
+
+There are two implemented strategies for constructing document pairs for :math:`\lambda`-gradient calculation. The first one is the ``mean`` method, another one is the ``topk`` method. The preferred strategy can be specified by the ``lambdarank_pair_method`` parameter.
+
+For the ``mean`` strategy, XGBoost samples ``lambdarank_num_pair_per_sample`` pairs for each document in a query list. For example, given a list of 3 documents and ``lambdarank_num_pair_per_sample`` set to 2, XGBoost will randomly sample 6 pairs, assuming the labels for these documents are different. On the other hand, if the pair method is set to ``topk``, XGBoost constructs about :math:`k \times |query|` pairs, with :math:`|query|` pairs for each sample at the top :math:`k = lambdarank\_num\_pair\_per\_sample` positions. The number of pairs counted here is an approximation since we skip pairs that have the same label.
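+
+Continuing the synthetic example from earlier, the two strategies are selected as follows (the pair counts here are illustrative):
+
+.. code-block:: python
+
+    # ``mean``: randomly sample 2 pairs for every document in a query group
+    mean_ranker = xgb.XGBRanker(
+        lambdarank_pair_method="mean", lambdarank_num_pair_per_sample=2
+    )
+    # ``topk``: construct pairs for each sample at the top 8 positions
+    topk_ranker = xgb.XGBRanker(
+        lambdarank_pair_method="topk", lambdarank_num_pair_per_sample=8
+    )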
+
+*********************
+Obtaining Good Result
+*********************
+
+Learning to rank is a sophisticated task and an active research area. It's not trivial to train a model that generalizes well. There are multiple loss functions available in XGBoost along with a set of hyperparameters. This section contains some hints for how to choose hyperparameters as a starting point. One can further optimize the model by tuning these hyperparameters.
+
+The first question would be how to choose an objective that matches the task at hand. If your input data has multi-level relevance degrees, then either ``rank:ndcg`` or ``rank:pairwise`` should be used. However, when the input has binary labels, we have multiple options based on the target metric. `[6] <#references>`__ provides some guidelines on this topic and users are encouraged to see the analysis done in their work. The choice should be based on the number of `effective pairs`, which refers to the number of pairs that can generate a non-zero gradient and contribute to training. `LambdaMART` with ``MRR`` has the smallest number of effective pairs, as the :math:`\lambda`-gradient is only non-zero when the pair contains a non-relevant document ranked higher than the top relevant document. As a result, it's not implemented in XGBoost. Since ``NDCG`` is a multi-level metric, it usually generates more effective pairs than ``MAP``.
+
+However, when there are sufficiently many effective pairs, it's shown in `[6] <#references>`__ that matching the target metric with the objective is of significance. When the target metric is ``MAP`` and you are using a large dataset that can provide a sufficient amount of effective pairs, ``rank:map`` can in theory yield higher ``MAP`` value than ``rank:ndcg``.
+
+The consideration of effective pairs also applies to the choice of pair method (``lambdarank_pair_method``) and the number of pairs for each sample (``lambdarank_num_pair_per_sample``). For example, the mean-``NDCG`` considers more pairs than ``NDCG@10``, so the former generates more effective pairs and provides more granularity than the latter. Also, using the ``mean`` strategy can help the model generalize with random sampling. However, one might want to focus the training on the top :math:`k` documents instead of using all pairs, to better fit their real-world application.
+
+When using the mean strategy for generating pairs, where the target metric (like ``NDCG``) is computed over the whole query list, users can specify how many pairs should be generated for each document by setting ``lambdarank_num_pair_per_sample``. XGBoost will randomly sample ``lambdarank_num_pair_per_sample`` pairs for each element in the query group (:math:`|pairs| = |query| \times num\_pair\_per\_sample`). Often, setting it to 1 can produce reasonable results. In cases where performance is inadequate due to an insufficient number of effective pairs being generated, set ``lambdarank_num_pair_per_sample`` to a higher value. As more document pairs are generated, more effective pairs will be generated as well.
+
+On the other hand, if you are prioritizing the top :math:`k` documents, the ``lambdarank_num_pair_per_sample`` should be set slightly higher than :math:`k` (with a few more documents) to obtain a good training result.
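+
+For instance, to optimize ``NDCG@10``, one might construct pairs for a few more than the top 10 documents, similar to the configuration used in the click-data demo (the exact value is only a starting point for tuning):
+
+.. code-block:: python
+
+    ranker = xgb.XGBRanker(
+        objective="rank:ndcg",
+        lambdarank_pair_method="topk",
+        # Slightly higher than the targeted top 10
+        lambdarank_num_pair_per_sample=12,
+        eval_metric=["ndcg@10"],
+    )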
+
+**Summary:** If you have a large amount of training data:
+
+* Use the target-matching objective.
+* Choose the ``topk`` strategy for generating document pairs (if it's appropriate for your application).
+
+On the other hand, if you have a comparatively small amount of training data:
+
+* Select ``NDCG`` or the RankNet loss (``rank:pairwise``).
+* Choose the ``mean`` strategy for generating document pairs, to obtain more effective pairs.
+
+For any method chosen, you can modify ``lambdarank_num_pair_per_sample`` to control the number of pairs generated.
+
+********************
+Distributed Training
+********************
+XGBoost implements distributed learning-to-rank with integration of multiple frameworks including Dask, Spark, and PySpark. The interface is similar to the single-node counterpart. Please refer to the documentation of the respective XGBoost interface for details. Scattering a query group onto multiple workers is theoretically sound but can affect the model's accuracy. For most use cases, the small discrepancy is not an issue, as the amount of training data is usually large when distributed training is used. As a result, users don't need to partition the data based on query groups. As long as each data partition is correctly sorted by query IDs, XGBoost can aggregate sample gradients accordingly.
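+
+As a minimal sketch of the Dask interface (assuming a running ``dask.distributed`` client and a dask data frame ``df`` whose partitions are already sorted by the ``qid`` column; the column names here are hypothetical):
+
+.. code-block:: python
+
+    from xgboost import dask as dxgb
+
+    ranker = dxgb.DaskXGBRanker(objective="rank:ndcg")
+    ranker.fit(df.drop(columns=["label", "qid"]), df["label"], qid=df["qid"])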
+
+*******************
+Reproducible Result
+*******************
+
+As with any other task, XGBoost should generate reproducible results given the same hardware and software environment (and data partitions, if the distributed interface is used). Even when the underlying environment changes, the result should still be consistent. However, when ``lambdarank_pair_method`` is set to ``mean``, XGBoost uses random sampling, and results may differ depending on the platform used. The random number generator used on Windows (Microsoft Visual C++) is different from the ones used on other platforms like Linux (GCC, Clang) [#f0]_, so the output varies significantly between these platforms.
+
+.. [#f0] The `minstd_rand` implementation differs on MSVC; the implementations from GCC and Thrust produce the same output.
+
+**********
+References
+**********
+
+[1] Tie-Yan Liu. 2009. "`Learning to Rank for Information Retrieval`_". Found. Trends Inf. Retr. 3, 3 (March 2009), 225–331.
+
+[2] Christopher J. C. Burges, Robert Ragno, and Quoc Viet Le. 2006. "`Learning to rank with nonsmooth cost functions`_". In Proceedings of the 19th International Conference on Neural Information Processing Systems (NIPS'06). MIT Press, Cambridge, MA, USA, 193–200.
+
+[3] Wu, Q., Burges, C.J.C., Svore, K.M. et al. "`Adapting boosting for information retrieval measures`_". Inf Retrieval 13, 254–270 (2010).
+
+[4] Ziniu Hu, Yang Wang, Qu Peng, Hang Li. "`Unbiased LambdaMART: An Unbiased Pairwise Learning-to-Rank Algorithm`_". Proceedings of the 2019 World Wide Web Conference.
+
+[5] Burges, Chris J.C. "`From RankNet to LambdaRank to LambdaMART: An Overview`_". MSR-TR-2010-82
+
+[6] Pinar Donmez, Krysta M. Svore, and Christopher J.C. Burges. 2009. "`On the local optimality of LambdaRank`_". In Proceedings of the 32nd international ACM SIGIR conference on Research and development in information retrieval (SIGIR '09). Association for Computing Machinery, New York, NY, USA, 460–467.
+
+[7] Chris Burges, Tal Shaked, Erin Renshaw, Ari Lazier, Matt Deeds, Nicole Hamilton, and Greg Hullender. 2005. "`Learning to rank using gradient descent`_". In Proceedings of the 22nd international conference on Machine learning (ICML '05). Association for Computing Machinery, New York, NY, USA, 89–96.
+
+[8] Xuanhui Wang and Cheng Li and Nadav Golbandi and Mike Bendersky and Marc Najork. 2018. "`The LambdaLoss Framework for Ranking Metric Optimization`_". Proceedings of The 27th ACM International Conference on Information and Knowledge Management (CIKM '18).
+
+.. _`Learning to Rank for Information Retrieval`: https://doi.org/10.1561/1500000016
+.. _`Learning to rank with nonsmooth cost functions`: https://dl.acm.org/doi/10.5555/2976456.2976481
+.. _`Adapting boosting for information retrieval measures`: https://doi.org/10.1007/s10791-009-9112-1
+.. _`Unbiased LambdaMART: An Unbiased Pairwise Learning-to-Rank Algorithm`: https://dl.acm.org/doi/10.1145/3308558.3313447
+.. _`From RankNet to LambdaRank to LambdaMART: An Overview`: https://www.microsoft.com/en-us/research/publication/from-ranknet-to-lambdarank-to-lambdamart-an-overview/
+.. _`On the local optimality of LambdaRank`: https://doi.org/10.1145/1571941.1572021
+.. _`Learning to rank using gradient descent`: https://doi.org/10.1145/1102351.1102363
+.. _`The LambdaLoss Framework for Ranking Metric Optimization`: https://dl.acm.org/doi/10.1145/3269206.3271784
diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala
index 5342aa563621..b8dca5d7040e 100644
--- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala
+++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala
@@ -220,7 +220,7 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
test("Ranking: train with Group") {
withGpuSparkSession(enableCsvConf()) { spark =>
- val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "rank:pairwise",
+ val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "rank:ndcg",
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist",
"features_cols" -> featureNames, "label_col" -> labelName)
val Array(trainingDf, testDf) = spark.read.option("header", "true").schema(schema)
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
index f63865fabc2d..6aec4d36ed6f 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
@@ -25,7 +25,7 @@ private[spark] trait LearningTaskParams extends Params {
/**
* Specify the learning task and the corresponding learning objective.
* options: reg:squarederror, reg:squaredlogerror, reg:logistic, binary:logistic, binary:logitraw,
- * count:poisson, multi:softmax, multi:softprob, rank:pairwise, reg:gamma.
+ * count:poisson, multi:softmax, multi:softprob, rank:ndcg, reg:gamma.
* default: reg:squarederror
*/
final val objective = new Param[String](this, "objective",
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
index c1e34224caca..d93b182e043e 100755
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
@@ -201,7 +201,7 @@ class XGBoostGeneralSuite extends AnyFunSuite with TmpFolderPerSuite with PerTes
sc,
buildTrainingRDD,
List("eta" -> "1", "max_depth" -> "6",
- "objective" -> "rank:pairwise", "num_round" -> 5, "num_workers" -> numWorkers,
+ "objective" -> "rank:ndcg", "num_round" -> 5, "num_workers" -> numWorkers,
"custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
"missing" -> Float.NaN).toMap)
@@ -268,7 +268,7 @@ class XGBoostGeneralSuite extends AnyFunSuite with TmpFolderPerSuite with PerTes
val training = buildDataFrameWithGroup(Ranking.train, 5)
val Array(train, eval1, eval2) = training.randomSplit(Array(0.6, 0.2, 0.2), 0)
val paramMap1 = Map("eta" -> "1", "max_depth" -> "6",
- "objective" -> "rank:pairwise",
+ "objective" -> "rank:ndcg",
"num_round" -> 5, "num_workers" -> numWorkers, "group_col" -> "group")
val xgb1 = new XGBoostRegressor(paramMap1).setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
val model1 = xgb1.fit(train)
@@ -281,7 +281,7 @@ class XGBoostGeneralSuite extends AnyFunSuite with TmpFolderPerSuite with PerTes
assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(1))
val paramMap2 = Map("eta" -> "1", "max_depth" -> "6",
- "objective" -> "rank:pairwise",
+ "objective" -> "rank:ndcg",
"num_round" -> 5, "num_workers" -> numWorkers, "group_col" -> "group",
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
val xgb2 = new XGBoostRegressor(paramMap2)
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
index efcb38cf62e1..1bdea7a827bd 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
@@ -121,7 +121,7 @@ class XGBoostRegressorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu
test("ranking: use group data") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "rank:pairwise", "num_workers" -> numWorkers, "num_round" -> 5,
+ "objective" -> "rank:ndcg", "num_workers" -> numWorkers, "num_round" -> 5,
"group_col" -> "group", "tree_method" -> treeMethod)
val trainingDF = buildDataFrameWithGroup(Ranking.train)
diff --git a/python-package/xgboost/_typing.py b/python-package/xgboost/_typing.py
index 774681031cec..39952aca9845 100644
--- a/python-package/xgboost/_typing.py
+++ b/python-package/xgboost/_typing.py
@@ -31,7 +31,7 @@
PathLike = Union[str, os.PathLike]
CupyT = ArrayLike # maybe need a stub for cupy arrays
NumpyOrCupy = Any
-NumpyDType = Union[str, Type[np.number]]
+NumpyDType = Union[str, Type[np.number]] # pylint: disable=invalid-name
PandasDType = Any # real type is pandas.core.dtypes.base.ExtensionDtype
FloatCompatible = Union[float, np.float32, np.float64]
diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index 9b5949cdb8b7..43d531a9d7bc 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -1796,7 +1796,11 @@ def _get_qid(
@xgboost_model_doc(
- """Implementation of the Scikit-Learn API for XGBoost Ranking.""",
+ """Implementation of the Scikit-Learn API for XGBoost Ranking.
+
+See :doc:`Learning to Rank </tutorials/learning_to_rank>` for an introduction.
+
+ """,
["estimators", "model"],
end_note="""
.. note::
@@ -1845,7 +1849,7 @@ def _get_qid(
class XGBRanker(XGBModel, XGBRankerMixIn):
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
@_deprecate_positional_args
- def __init__(self, *, objective: str = "rank:pairwise", **kwargs: Any):
+ def __init__(self, *, objective: str = "rank:ndcg", **kwargs: Any):
super().__init__(objective=objective, **kwargs)
if callable(self.objective):
raise ValueError("custom objective function not supported by XGBRanker")
@@ -2029,7 +2033,7 @@ def fit(
self._Booster = train(
params,
train_dmatrix,
- self.get_num_boosting_rounds(),
+ num_boost_round=self.get_num_boosting_rounds(),
early_stopping_rounds=early_stopping_rounds,
evals=evals,
evals_result=evals_result,
diff --git a/python-package/xgboost/testing/data.py b/python-package/xgboost/testing/data.py
index 477d0cf3d6f0..b5f07d98161e 100644
--- a/python-package/xgboost/testing/data.py
+++ b/python-package/xgboost/testing/data.py
@@ -1,11 +1,14 @@
+# pylint: disable=invalid-name
"""Utilities for data generation."""
import os
import zipfile
-from typing import Any, Generator, List, Tuple, Union
+from dataclasses import dataclass
+from typing import Any, Generator, List, NamedTuple, Optional, Tuple, Union
from urllib import request
import numpy as np
import pytest
+from numpy import typing as npt
from numpy.random import Generator as RNG
from scipy import sparse
@@ -340,3 +343,263 @@ def get_mq2008(
y_valid,
qid_valid,
)
+
+
+RelData = Tuple[sparse.csr_matrix, npt.NDArray[np.int32], npt.NDArray[np.int32]]
+
+
+@dataclass
+class ClickFold:
+ """A structure containing information about generated user-click data."""
+
+ X: sparse.csr_matrix
+ y: npt.NDArray[np.int32]
+ qid: npt.NDArray[np.int32]
+ score: npt.NDArray[np.float32]
+ click: npt.NDArray[np.int32]
+ pos: npt.NDArray[np.int64]
+
+
+class RelDataCV(NamedTuple):
+ """Simple data struct for holding a train-test split of a learning to rank dataset."""
+
+ train: RelData
+ test: RelData
+ max_rel: int
+
+ def is_binary(self) -> bool:
+ """Whether the label consists of binary relevance degree."""
+ return self.max_rel == 1
+
+
+class PBM: # pylint: disable=too-few-public-methods
+ """Simulate click data with position bias model. There are other models available in
+ `ULTRA `_ like the cascading model.
+
+ References
+ ----------
+ Unbiased LambdaMART: An Unbiased Pairwise Learning-to-Rank Algorithm
+
+ """
+
+ def __init__(self, eta: float) -> None:
+ # click probability for each relevance degree. (from 0 to 4)
+ self.click_prob = np.array([0.1, 0.16, 0.28, 0.52, 1.0])
+ exam_prob = np.array(
+ [0.68, 0.61, 0.48, 0.34, 0.28, 0.20, 0.11, 0.10, 0.08, 0.06]
+ )
+ # Observation probability, encoding positional bias for each position
+ self.exam_prob = np.power(exam_prob, eta)
+
+ def sample_clicks_for_query(
+ self, labels: npt.NDArray[np.int32], position: npt.NDArray[np.int64]
+ ) -> npt.NDArray[np.int32]:
+ """Sample clicks for one query based on input relevance degree and position.
+
+ Parameters
+ ----------
+
+        labels :
+            Relevance degree for each document in the query.
+        position :
+            Position of each document in the initial ranked list.
+
+ """
+ labels = np.array(labels, copy=True)
+
+        # Clip negative labels to the minimum relevance degree.
+        labels[labels < 0] = 0
+        # Clip to the maximum relevance degree (index -1 selects the last entry).
+        labels[labels >= len(self.click_prob)] = -1
+        click_prob = self.click_prob[labels]
+
+        assert position.size == labels.size
+        ranks = np.array(position, copy=True)
+        # Positions beyond the observed range share the last examination probability.
+        ranks[ranks >= self.exam_prob.size] = -1
+        exam_prob = self.exam_prob[ranks]
+
+ rng = np.random.default_rng(1994)
+ prob = rng.random(size=labels.shape[0], dtype=np.float32)
+
+ clicks: npt.NDArray[np.int32] = np.zeros(labels.shape, dtype=np.int32)
+ clicks[prob < exam_prob * click_prob] = 1
+ return clicks
+
+
+def rlencode(x: npt.NDArray[np.int32]) -> Tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
+ """Run length encoding using numpy, modified from:
+ https://gist.github.com/nvictus/66627b580c13068589957d6ab0919e66
+
+ """
+ x = np.asarray(x)
+ n = x.size
+ starts = np.r_[0, np.flatnonzero(~np.isclose(x[1:], x[:-1], equal_nan=True)) + 1]
+ lengths = np.diff(np.r_[starts, n])
+ values = x[starts]
+ indptr = np.append(starts, np.array([x.size]))
+
+ return indptr, lengths, values
+
+
+def init_rank_score(
+ X: sparse.csr_matrix,
+ y: npt.NDArray[np.int32],
+ qid: npt.NDArray[np.int32],
+ sample_rate: float = 0.1,
+) -> npt.NDArray[np.float32]:
+ """We use XGBoost to generate the initial score instead of SVMRank for
+ simplicity. Sample rate is set to 0.1 by default so that we can test with small
+ datasets.
+
+ """
+ # random sample
+ rng = np.random.default_rng(1994)
+ n_samples = int(X.shape[0] * sample_rate)
+ index = np.arange(0, X.shape[0], dtype=np.uint64)
+ rng.shuffle(index)
+ index = index[:n_samples]
+
+ X_train = X[index]
+ y_train = y[index]
+ qid_train = qid[index]
+
+ # Sort training data based on query id, required by XGBoost.
+ sorted_idx = np.argsort(qid_train)
+ X_train = X_train[sorted_idx]
+ y_train = y_train[sorted_idx]
+ qid_train = qid_train[sorted_idx]
+
+ ltr = xgboost.XGBRanker(objective="rank:ndcg", tree_method="hist")
+ ltr.fit(X_train, y_train, qid=qid_train)
+
+ # Use the original order of the data.
+ scores = ltr.predict(X)
+ return scores
+
+
+def simulate_one_fold(
+ fold: Tuple[sparse.csr_matrix, npt.NDArray[np.int32], npt.NDArray[np.int32]],
+ scores_fold: npt.NDArray[np.float32],
+) -> ClickFold:
+ """Simulate clicks for one fold."""
+ X_fold, y_fold, qid_fold = fold
+ assert qid_fold.dtype == np.int32
+
+ qids = np.unique(qid_fold)
+
+ position = np.empty((y_fold.size,), dtype=np.int64)
+ clicks = np.empty((y_fold.size,), dtype=np.int32)
+ pbm = PBM(eta=1.0)
+
+    # Iterate with a boolean mask instead of regrouping the data, so that the
+    # original data partition from the dataset authors is preserved.
+ for q in qids:
+ qid_mask = q == qid_fold
+ qid_mask = qid_mask.reshape(qid_mask.shape[0])
+ query_scores = scores_fold[qid_mask]
+ # Initial rank list, scores sorted to decreasing order
+ query_position = np.argsort(query_scores)[::-1]
+ position[qid_mask] = query_position
+ # get labels
+ relevance_degrees = y_fold[qid_mask]
+ query_clicks = pbm.sample_clicks_for_query(relevance_degrees, query_position)
+ clicks[qid_mask] = query_clicks
+
+ assert X_fold.shape[0] == qid_fold.shape[0], (X_fold.shape, qid_fold.shape)
+ assert X_fold.shape[0] == clicks.shape[0], (X_fold.shape, clicks.shape)
+
+ return ClickFold(X_fold, y_fold, qid_fold, scores_fold, clicks, position)
+
+
+# pylint: disable=too-many-locals
+def simulate_clicks(cv_data: RelDataCV) -> Tuple[ClickFold, Optional[ClickFold]]:
+ """Simulate click data using position biased model (PBM)."""
+ X, y, qid = list(zip(cv_data.train, cv_data.test))
+
+ # ptr to train-test split
+ indptr = np.array([0] + [v.shape[0] for v in X])
+ indptr = np.cumsum(indptr)
+
+ assert len(indptr) == 2 + 1 # train, test
+ X_full = sparse.vstack(X)
+ y_full = np.concatenate(y)
+ qid_full = np.concatenate(qid)
+
+ # Obtain initial relevance score for click simulation
+ scores_full = init_rank_score(X_full, y_full, qid_full)
+ # partition it back to (train, test) tuple
+ scores = [scores_full[indptr[i - 1] : indptr[i]] for i in range(1, indptr.size)]
+
+ X_lst, y_lst, q_lst, s_lst, c_lst, p_lst = [], [], [], [], [], []
+ for i in range(indptr.size - 1):
+ fold = simulate_one_fold((X[i], y[i], qid[i]), scores[i])
+ X_lst.append(fold.X)
+ y_lst.append(fold.y)
+ q_lst.append(fold.qid)
+ s_lst.append(fold.score)
+ c_lst.append(fold.click)
+ p_lst.append(fold.pos)
+
+    # Sanity check: the partitioned scores must match the scores before the split.
+    for i in range(indptr.size - 1):
+        assert (s_lst[i] == scores[i]).all()
+
+ if len(X_lst) == 1:
+ train = ClickFold(X_lst[0], y_lst[0], q_lst[0], s_lst[0], c_lst[0], p_lst[0])
+ test = None
+ else:
+ train, test = (
+ ClickFold(X_lst[i], y_lst[i], q_lst[i], s_lst[i], c_lst[i], p_lst[i])
+ for i in range(len(X_lst))
+ )
+ return train, test
+
+
+def sort_ltr_samples(
+ X: sparse.csr_matrix,
+ y: npt.NDArray[np.int32],
+ qid: npt.NDArray[np.int32],
+ clicks: npt.NDArray[np.int32],
+ pos: npt.NDArray[np.int64],
+) -> Tuple[
+ sparse.csr_matrix,
+ npt.NDArray[np.int32],
+ npt.NDArray[np.int32],
+ npt.NDArray[np.int32],
+]:
+ """Sort data based on query index and position."""
+ sorted_idx = np.argsort(qid)
+ X = X[sorted_idx]
+ clicks = clicks[sorted_idx]
+ qid = qid[sorted_idx]
+ pos = pos[sorted_idx]
+
+ indptr, _, _ = rlencode(qid)
+
+ for i in range(1, indptr.size):
+ beg = indptr[i - 1]
+ end = indptr[i]
+
+ assert beg < end, (beg, end)
+ assert np.unique(qid[beg:end]).size == 1, (beg, end)
+
+ query_pos = pos[beg:end]
+ assert query_pos.min() == 0, query_pos.min()
+ assert query_pos.max() >= query_pos.size - 1, (
+ query_pos.max(),
+ query_pos.size,
+ i,
+ np.unique(qid[beg:end]),
+ )
+ sorted_idx = np.argsort(query_pos)
+
+ X[beg:end] = X[beg:end][sorted_idx]
+ clicks[beg:end] = clicks[beg:end][sorted_idx]
+ y[beg:end] = y[beg:end][sorted_idx]
+        # Not strictly necessary: qid is constant within the group.
+ qid[beg:end] = qid[beg:end][sorted_idx]
+
+ data = X, clicks, y, qid
+
+ return data
diff --git a/python-package/xgboost/testing/params.py b/python-package/xgboost/testing/params.py
index e6ba73e1f541..8dc91b6017be 100644
--- a/python-package/xgboost/testing/params.py
+++ b/python-package/xgboost/testing/params.py
@@ -67,3 +67,17 @@
"max_cat_threshold": strategies.integers(1, 128),
}
)
+
+lambdarank_parameter_strategy = strategies.fixed_dictionaries(
+ {
+ "lambdarank_unbiased": strategies.sampled_from([True, False]),
+ "lambdarank_pair_method": strategies.sampled_from(["topk", "mean"]),
+ "lambdarank_num_pair_per_sample": strategies.integers(1, 8),
+ "lambdarank_bias_norm": strategies.floats(0.5, 2.0),
+ "objective": strategies.sampled_from(
+ ["rank:ndcg", "rank:map", "rank:pairwise"]
+ ),
+ }
+).filter(
+ lambda x: not (x["lambdarank_unbiased"] and x["lambdarank_pair_method"] == "mean")
+)
diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu
index bd1b0b2d89d3..b6888610b586 100644
--- a/src/metric/elementwise_metric.cu
+++ b/src/metric/elementwise_metric.cu
@@ -483,9 +483,13 @@ class QuantileError : public MetricNoCache {
const char* Name() const override { return "quantile"; }
void LoadConfig(Json const& in) override {
- auto const& name = get(in["name"]);
- CHECK_EQ(name, "quantile");
- FromJson(in["quantile_loss_param"], ¶m_);
+ auto const& obj = get