diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in
index ed3f10571edf..33452e142ed1 100644
--- a/R-package/src/Makevars.in
+++ b/R-package/src/Makevars.in
@@ -32,7 +32,7 @@ OBJECTS= \
$(PKGROOT)/src/objective/objective.o \
$(PKGROOT)/src/objective/regression_obj.o \
$(PKGROOT)/src/objective/multiclass_obj.o \
- $(PKGROOT)/src/objective/rank_obj.o \
+ $(PKGROOT)/src/objective/lambdarank_obj.o \
$(PKGROOT)/src/objective/hinge.o \
$(PKGROOT)/src/objective/aft_obj.o \
$(PKGROOT)/src/objective/adaptive.o \
diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win
index 024ba1aa19b7..f176642643cc 100644
--- a/R-package/src/Makevars.win
+++ b/R-package/src/Makevars.win
@@ -32,7 +32,7 @@ OBJECTS= \
$(PKGROOT)/src/objective/objective.o \
$(PKGROOT)/src/objective/regression_obj.o \
$(PKGROOT)/src/objective/multiclass_obj.o \
- $(PKGROOT)/src/objective/rank_obj.o \
+ $(PKGROOT)/src/objective/lambdarank_obj.o \
$(PKGROOT)/src/objective/hinge.o \
$(PKGROOT)/src/objective/aft_obj.o \
$(PKGROOT)/src/objective/adaptive.o \
diff --git a/demo/guide-python/learning_to_rank.py b/demo/guide-python/learning_to_rank.py
new file mode 100644
index 000000000000..9dd46cb6c0d4
--- /dev/null
+++ b/demo/guide-python/learning_to_rank.py
@@ -0,0 +1,462 @@
+"""
+Getting started with learning to rank
+=====================================
+
+ .. versionadded:: 2.0.0
+
+This is a demonstration of using XGBoost for learning to rank tasks using the
+MSLR_10k_letor dataset. For more information about the dataset, please visit its
+`description page `_.
+
+This is a two-part demo, the first one contains a basic example of using XGBoost to
+train on relevance degree, and the second part simulates click data and enables the
+position debiasing training.
+
+For an overview of learning to rank in XGBoost, please see
+:doc:`Learning to Rank `.
+"""
+from __future__ import annotations
+
+import argparse
+import os
+import pickle as pkl
+from collections import namedtuple
+from time import time
+from typing import List, NamedTuple, Tuple, TypedDict
+
+import numpy as np
+from numpy import typing as npt
+from scipy import sparse
+from scipy.sparse import csr_matrix
+from sklearn.datasets import load_svmlight_file
+
+import xgboost as xgb
+
+
+class PBM:
+ """Simulate click data with position bias model. There are other models available in
+ `ULTRA `_ like the cascading model.
+
+ References
+ ----------
+ Unbiased LambdaMART: An Unbiased Pairwise Learning-to-Rank Algorithm
+
+ """
+
+ def __init__(self, eta: float) -> None:
+ # click probability for each relevance degree. (from 0 to 4)
+ self.click_prob = np.array([0.1, 0.16, 0.28, 0.52, 1.0])
+ exam_prob = np.array(
+ [0.68, 0.61, 0.48, 0.34, 0.28, 0.20, 0.11, 0.10, 0.08, 0.06]
+ )
+ self.exam_prob = np.power(exam_prob, eta)
+
+ def sample_clicks_for_query(
+ self, labels: npt.NDArray[np.int32], position: npt.NDArray[np.int64]
+ ) -> npt.NDArray[np.int32]:
+ """Sample clicks for one query based on input relevance degree and position.
+
+ Parameters
+ ----------
+
+ labels :
+ relevance_degree
+
+ """
+ labels = np.array(labels, copy=True)
+
+ click_prob = np.zeros(labels.shape)
+        # clamp labels below the valid range to the lowest relevance degree
+        labels[labels < 0] = 0
+        # map labels above the valid range to the highest relevance degree (last index)
+        labels[labels >= len(self.click_prob)] = -1
+ click_prob = self.click_prob[labels]
+
+ exam_prob = np.zeros(labels.shape)
+ assert position.size == labels.size
+ ranks = np.array(position, copy=True)
+        # map positions beyond the examination probability table to its last entry
+ ranks[ranks >= self.exam_prob.size] = -1
+ exam_prob = self.exam_prob[ranks]
+
+ rng = np.random.default_rng(1994)
+ prob = rng.random(size=labels.shape[0], dtype=np.float32)
+
+ clicks: npt.NDArray[np.int32] = np.zeros(labels.shape, dtype=np.int32)
+ clicks[prob < exam_prob * click_prob] = 1
+ return clicks
+
+
+# relevance degree data
+RelData = Tuple[csr_matrix, npt.NDArray[np.int32], npt.NDArray[np.int32]]
+
+
+class RelDataCV(NamedTuple):
+ train: RelData
+ valid: RelData
+ test: RelData
+
+
+def load_mlsr_10k(data_path: str, cache_path: str) -> RelDataCV:
+ """Load the MSLR10k dataset from data_path and cache a pickle object in cache_path.
+
+ Returns
+ -------
+
+    An instance of ``RelDataCV``, containing the (X, y, qid) tuples for the train, valid and test folds.
+
+ """
+    root_path = os.path.expanduser(data_path)
+    cacheroot_path = os.path.expanduser(cache_path)
+ cache_path = os.path.join(cacheroot_path, "MSLR_10K_LETOR.pkl")
+
+ # Use only the Fold1 for demo:
+ # Train, Valid, Test
+ # {S1,S2,S3}, S4, S5
+ fold = 1
+
+ if not os.path.exists(cache_path):
+ fold_path = os.path.join(root_path, f"Fold{fold}")
+ train_path = os.path.join(fold_path, "train.txt")
+ valid_path = os.path.join(fold_path, "vali.txt")
+ test_path = os.path.join(fold_path, "test.txt")
+ X_train, y_train, qid_train = load_svmlight_file(
+ train_path, query_id=True, dtype=np.float32
+ )
+ y_train = y_train.astype(np.int32)
+ qid_train = qid_train.astype(np.int32)
+
+ X_valid, y_valid, qid_valid = load_svmlight_file(
+ valid_path, query_id=True, dtype=np.float32
+ )
+ y_valid = y_valid.astype(np.int32)
+ qid_valid = qid_valid.astype(np.int32)
+
+ X_test, y_test, qid_test = load_svmlight_file(
+ test_path, query_id=True, dtype=np.float32
+ )
+ y_test = y_test.astype(np.int32)
+ qid_test = qid_test.astype(np.int32)
+
+ data = RelDataCV(
+ train=(X_train, y_train, qid_train),
+ valid=(X_valid, y_valid, qid_valid),
+ test=(X_test, y_test, qid_test),
+ )
+
+ with open(cache_path, "wb") as fd:
+ pkl.dump(data, fd)
+
+ with open(cache_path, "rb") as fd:
+ data = pkl.load(fd)
+
+ return data
+
+
+def ranking_demo(args: argparse.Namespace) -> None:
+ """Demonstration for learning to rank with relevance degree."""
+ data = load_mlsr_10k(args.data, args.cache)
+
+ X_train, y_train, qid_train = data.train
+ sorted_idx = np.argsort(qid_train)
+ X_train = X_train[sorted_idx]
+ y_train = y_train[sorted_idx]
+ qid_train = qid_train[sorted_idx]
+
+ X_valid, y_valid, qid_valid = data.valid
+ sorted_idx = np.argsort(qid_valid)
+ X_valid = X_valid[sorted_idx]
+ y_valid = y_valid[sorted_idx]
+ qid_valid = qid_valid[sorted_idx]
+
+ ranker = xgb.XGBRanker(
+ tree_method="gpu_hist",
+ lambdarank_pair_method="topk",
+ lambdarank_num_pair_per_sample=13,
+ eval_metric=["ndcg@1", "ndcg@8"],
+ )
+ ranker.fit(
+ X_train,
+ y_train,
+ qid=qid_train,
+ eval_set=[(X_valid, y_valid)],
+ eval_qid=[qid_valid],
+ verbose=True,
+ )
+
+
+def rlencode(x: npt.NDArray[np.int32]) -> Tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
+ """Run length encoding using numpy, modified from:
+ https://gist.github.com/nvictus/66627b580c13068589957d6ab0919e66
+
+ """
+ x = np.asarray(x)
+ n = x.size
+ starts = np.r_[0, np.flatnonzero(~np.isclose(x[1:], x[:-1], equal_nan=True)) + 1]
+ lengths = np.diff(np.r_[starts, n])
+ values = x[starts]
+ indptr = np.append(starts, np.array([x.size]))
+
+ return indptr, lengths, values
+
+
+ClickFold = namedtuple("ClickFold", ("X", "y", "q", "s", "c", "p"))
+
+
+def simulate_clicks(args: argparse.Namespace) -> ClickFold:
+ """Simulate click data using position biased model (PBM)."""
+
+ def init_rank_score(
+ X: csr_matrix,
+ y: npt.NDArray[np.int32],
+ qid: npt.NDArray[np.int32],
+ sample_rate: float = 0.01,
+ ) -> npt.NDArray[np.float32]:
+ """We use XGBoost to generate the initial score instead of SVMRank for
+ simplicity.
+
+ """
+ # random sample
+ _rng = np.random.default_rng(1994)
+ n_samples = int(X.shape[0] * sample_rate)
+ index = np.arange(0, X.shape[0], dtype=np.uint64)
+ _rng.shuffle(index)
+ index = index[:n_samples]
+
+ X_train = X[index]
+ y_train = y[index]
+ qid_train = qid[index]
+
+ # Sort training data based on query id, required by XGBoost.
+ sorted_idx = np.argsort(qid_train)
+ X_train = X_train[sorted_idx]
+ y_train = y_train[sorted_idx]
+ qid_train = qid_train[sorted_idx]
+
+ ltr = xgb.XGBRanker(objective="rank:ndcg", tree_method="gpu_hist")
+ ltr.fit(X_train, y_train, qid=qid_train)
+
+ # Use the original order of the data.
+ scores = ltr.predict(X)
+ return scores
+
+ def simulate_one_fold(fold, scores_fold: npt.NDArray[np.float32]) -> ClickFold:
+ """Simulate clicks for one fold."""
+ X_fold, y_fold, qid_fold = fold
+ assert qid_fold.dtype == np.int32
+ indptr, lengths, values = rlencode(qid_fold)
+
+ qids = np.unique(qid_fold)
+
+ position = np.empty((y_fold.size,), dtype=np.int64)
+ clicks = np.empty((y_fold.size,), dtype=np.int32)
+ pbm = PBM(eta=1.0)
+
+ # Avoid grouping by qid as we want to preserve the original data partition by
+ # the dataset authors.
+ for q in qids:
+ qid_mask = q == qid_fold
+ query_scores = scores_fold[qid_mask]
+ # Initial rank list, scores sorted to decreasing order
+ query_position = np.argsort(query_scores)[::-1]
+ position[qid_mask] = query_position
+ # get labels
+ relevance_degrees = y_fold[qid_mask]
+ query_clicks = pbm.sample_clicks_for_query(
+ relevance_degrees, query_position
+ )
+ clicks[qid_mask] = query_clicks
+
+ assert X_fold.shape[0] == qid_fold.shape[0], (X_fold.shape, qid_fold.shape)
+ assert X_fold.shape[0] == clicks.shape[0], (X_fold.shape, clicks.shape)
+
+ return ClickFold(X_fold, y_fold, qid_fold, scores_fold, clicks, position)
+
+ cache_path = os.path.join(
+ os.path.expanduser(args.cache), "MSLR_10K_LETOR_Clicks.pkl"
+ )
+ if os.path.exists(cache_path):
+ print("Found existing cache for clicks.")
+ with open(cache_path, "rb") as fdr:
+ new_folds = pkl.load(fdr)
+ return new_folds
+
+ cv_data = load_mlsr_10k(args.data, args.cache)
+ X, y, qid = list(zip(cv_data.train, cv_data.valid, cv_data.test))
+
+ indptr = np.array([0] + [v.shape[0] for v in X])
+ indptr = np.cumsum(indptr)
+
+ assert len(indptr) == 3 + 1 # train, valid, test
+ X_full = sparse.vstack(X)
+ y_full = np.concatenate(y)
+ qid_full = np.concatenate(qid)
+
+ # Skip the data cleaning here for demonstration purposes.
+
+ # Obtain initial relevance score for click simulation
+ scores_full = init_rank_score(X_full, y_full, qid_full)
+ # partition it back to train,valid,test tuple
+ scores = [scores_full[indptr[i - 1] : indptr[i]] for i in range(1, indptr.size)]
+
+ (
+ X_full,
+ y_full,
+ qid_full,
+ scores_ret,
+ clicks_full,
+ position_full,
+ ) = simulate_one_fold((X_full, y_full, qid_full), scores_full)
+
+ scores_check_1 = [
+ scores_ret[indptr[i - 1] : indptr[i]] for i in range(1, indptr.size)
+ ]
+ for i in range(3):
+ assert (scores_check_1[i] == scores[i]).all()
+
+ position = [position_full[indptr[i - 1] : indptr[i]] for i in range(1, indptr.size)]
+ clicks = [clicks_full[indptr[i - 1] : indptr[i]] for i in range(1, indptr.size)]
+
+ with open(cache_path, "wb") as fdw:
+ data = ClickFold(X, y, qid, scores, clicks, position)
+ pkl.dump(data, fdw)
+
+ return data
+
+
+def sort_samples(
+ X: csr_matrix,
+ y: npt.NDArray[np.int32],
+ qid: npt.NDArray[np.int32],
+ clicks: npt.NDArray[np.int32],
+ pos: npt.NDArray[np.int64],
+ cache_path: str,
+) -> Tuple[
+ csr_matrix, npt.NDArray[np.int32], npt.NDArray[np.int32], npt.NDArray[np.int32]
+]:
+ """Sort data based on query index and position."""
+ if os.path.exists(cache_path):
+ print(f"Found existing cache: {cache_path}")
+ with open(cache_path, "rb") as fdr:
+ data = pkl.load(fdr)
+ return data
+
+ s = time()
+ sorted_idx = np.argsort(qid)
+ X = X[sorted_idx]
+ clicks = clicks[sorted_idx]
+ qid = qid[sorted_idx]
+ pos = pos[sorted_idx]
+
+ indptr, lengths, values = rlencode(qid)
+
+ for i in range(1, indptr.size):
+ beg = indptr[i - 1]
+ end = indptr[i]
+
+ assert beg < end, (beg, end)
+ assert np.unique(qid[beg:end]).size == 1, (beg, end)
+
+ query_pos = pos[beg:end]
+ assert query_pos.min() == 0, query_pos.min()
+ assert query_pos.max() >= query_pos.size - 1, (
+ query_pos.max(),
+ query_pos.size,
+ i,
+ np.unique(qid[beg:end]),
+ )
+ sorted_idx = np.argsort(query_pos)
+
+ X[beg:end] = X[beg:end][sorted_idx]
+ clicks[beg:end] = clicks[beg:end][sorted_idx]
+ y[beg:end] = y[beg:end][sorted_idx]
+ # not necessary
+ qid[beg:end] = qid[beg:end][sorted_idx]
+
+ e = time()
+ print("Sort samples:", e - s)
+ data = X, clicks, y, qid
+
+ with open(cache_path, "wb") as fdw:
+ pkl.dump(data, fdw)
+
+ return data
+
+
+def click_data_demo(args: argparse.Namespace) -> None:
+ """Demonstration for learning to rank with click data."""
+ folds = simulate_clicks(args)
+
+ train = [pack[0] for pack in folds]
+ valid = [pack[1] for pack in folds]
+ test = [pack[2] for pack in folds]
+
+ X_train, y_train, qid_train, scores_train, clicks_train, position_train = train
+ assert X_train.shape[0] == clicks_train.size
+ X_valid, y_valid, qid_valid, scores_valid, clicks_valid, position_valid = valid
+ assert X_valid.shape[0] == clicks_valid.size
+ assert scores_valid.dtype == np.float32
+ assert clicks_valid.dtype == np.int32
+ cache_path = os.path.expanduser(args.cache)
+
+ X_train, clicks_train, y_train, qid_train = sort_samples(
+ X_train,
+ y_train,
+ qid_train,
+ clicks_train,
+ position_train,
+ os.path.join(cache_path, "sorted.train.pkl"),
+ )
+ X_valid, clicks_valid, y_valid, qid_valid = sort_samples(
+ X_valid,
+ y_valid,
+ qid_valid,
+ clicks_valid,
+ position_valid,
+ os.path.join(cache_path, "sorted.valid.pkl"),
+ )
+
+ ranker = xgb.XGBRanker(
+ n_estimators=512,
+ tree_method="hist",
+ boost_from_average=0,
+ grow_policy="lossguide",
+ learning_rate=0.1,
+ # LTR specific parameters
+ objective="rank:ndcg",
+ lambdarank_unbiased=True,
+ lambdarank_bias_norm=0.5,
+ lambdarank_num_pair_per_sample=8,
+ lambdarank_pair_method="topk",
+ ndcg_exp_gain=True,
+ eval_metric=["ndcg@1", "ndcg@8", "ndcg@32"],
+ )
+ ranker.fit(
+ X_train,
+ clicks_train,
+ qid=qid_train,
+ eval_set=[(X_train, clicks_train), (X_valid, y_valid)],
+ eval_qid=[qid_train, qid_valid],
+ verbose=True,
+ )
+ X_test = test[0]
+ ranker.predict(X_test)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Demonstration of learning to rank using XGBoost."
+ )
+ parser.add_argument(
+ "--data", type=str, help="Root directory of the MSLR data.", required=True
+ )
+ parser.add_argument(
+ "--cache",
+ type=str,
+ help="Directory for caching processed data.",
+ required=True,
+ )
+ args = parser.parse_args()
+
+ ranking_demo(args)
+ click_data_demo(args)
diff --git a/demo/guide-python/quantile_regression.py b/demo/guide-python/quantile_regression.py
index d92115bf08d7..e6dc0847f4a3 100644
--- a/demo/guide-python/quantile_regression.py
+++ b/demo/guide-python/quantile_regression.py
@@ -2,6 +2,8 @@
Quantile Regression
===================
+ .. versionadded:: 2.0.0
+
The script is inspired by this awesome example in sklearn:
https://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_quantile.html
diff --git a/doc/contrib/coding_guide.rst b/doc/contrib/coding_guide.rst
index a080c2a31541..57bf07de4d74 100644
--- a/doc/contrib/coding_guide.rst
+++ b/doc/contrib/coding_guide.rst
@@ -16,8 +16,10 @@ C++ Coding Guideline
* Each line of text may contain up to 100 characters.
* The use of C++ exceptions is allowed.
-- Use C++11 features such as smart pointers, braced initializers, lambda functions, and ``std::thread``.
+- Use C++14 features such as smart pointers, braced initializers, lambda functions, and ``std::thread``.
- Use Doxygen to document all the interface code.
+- We add comments documenting the symbols imported by headers; some of these are hinted by `include-what-you-use `_. This is not required.
+- We use clang-tidy and clang-format. You can check their configuration in the root directory of the XGBoost source tree.
- We have a series of automatic checks to ensure that all of our codebase complies with the Google style. Before submitting your pull request, you are encouraged to run the style checks on your machine. See :ref:`running_checks_locally`.
***********************
diff --git a/doc/model.schema b/doc/model.schema
index 07a871820b5a..9d4c74607342 100644
--- a/doc/model.schema
+++ b/doc/model.schema
@@ -238,6 +238,16 @@
"num_pairsample": { "type": "string" },
"fix_list_weight": { "type": "string" }
}
+ },
+ "lambdarank_param": {
+ "type": "object",
+ "properties": {
+ "lambdarank_num_pair_per_sample": { "type": "string" },
+ "lambdarank_pair_method": { "type": "string" },
+ "lambdarank_unbiased": {"type": "string" },
+ "lambdarank_bias_norm": {"type": "string" },
+ "ndcg_exp_gain": {"type": "string"}
+ }
}
},
"type": "object",
@@ -496,22 +506,22 @@
"type": "object",
"properties": {
"name": { "const": "rank:pairwise" },
- "lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"}
+        "lambdarank_param": { "$ref": "#/definitions/lambdarank_param"}
},
"required": [
"name",
- "lambda_rank_param"
+ "lambdarank_param"
]
},
{
"type": "object",
"properties": {
"name": { "const": "rank:ndcg" },
- "lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"}
+        "lambdarank_param": { "$ref": "#/definitions/lambdarank_param"}
},
"required": [
"name",
- "lambda_rank_param"
+ "lambdarank_param"
]
},
{
diff --git a/doc/parameter.rst b/doc/parameter.rst
index 99d6f0585936..38a3bf37b70d 100644
--- a/doc/parameter.rst
+++ b/doc/parameter.rst
@@ -363,8 +363,7 @@ Specify the learning task and the corresponding learning objective. The objectiv
- ``aft_loss_distribution``: Probability Density Function used by ``survival:aft`` objective and ``aft-nloglik`` metric.
- ``multi:softmax``: set XGBoost to do multiclass classification using the softmax objective, you also need to set num_class(number of classes)
- ``multi:softprob``: same as softmax, but output a vector of ``ndata * nclass``, which can be further reshaped to ``ndata * nclass`` matrix. The result contains predicted probability of each data point belonging to each class.
- - ``rank:pairwise``: Use LambdaMART to perform pairwise ranking where the pairwise loss is minimized
- - ``rank:ndcg``: Use LambdaMART to perform list-wise ranking where `Normalized Discounted Cumulative Gain (NDCG) `_ is maximized
+ - ``rank:ndcg``: Use LambdaMART to perform list-wise ranking where `Normalized Discounted Cumulative Gain (NDCG) `_ is maximized. This objective supports position debiasing for click data.
- ``rank:map``: Use LambdaMART to perform list-wise ranking where `Mean Average Precision (MAP) `_ is maximized
- ``reg:gamma``: gamma regression with log-link. Output is a mean of gamma distribution. It might be useful, e.g., for modeling insurance claims severity, or for any outcome that might be `gamma-distributed `_.
- ``reg:tweedie``: Tweedie regression with log-link. It might be useful, e.g., for modeling total loss in insurance, or for any outcome that might be `Tweedie-distributed `_.
@@ -378,8 +377,9 @@ Specify the learning task and the corresponding learning objective. The objectiv
* ``eval_metric`` [default according to objective]
- - Evaluation metrics for validation data, a default metric will be assigned according to objective (rmse for regression, and logloss for classification, mean average precision for ranking)
- - User can add multiple evaluation metrics. Python users: remember to pass the metrics in as list of parameters pairs instead of map, so that latter ``eval_metric`` won't override previous one
+ - Evaluation metrics for validation data, a default metric will be assigned according to objective (rmse for regression, and logloss for classification, `mean average precision` for ``rank:map``, etc.)
+ - User can add multiple evaluation metrics. Python users: remember to pass the metrics in as list of parameters pairs instead of map, so that latter ``eval_metric`` won't override previous ones
+
- The choices are listed below:
- ``rmse``: `root mean square error `_
@@ -408,8 +408,17 @@ Specify the learning task and the corresponding learning objective. The objectiv
- ``ndcg``: `Normalized Discounted Cumulative Gain `_
- ``map``: `Mean Average Precision `_
- - ``ndcg@n``, ``map@n``: 'n' can be assigned as an integer to cut off the top positions in the lists for evaluation.
- - ``ndcg-``, ``map-``, ``ndcg@n-``, ``map@n-``: In XGBoost, NDCG and MAP will evaluate the score of a list without any positive samples as 1. By adding "-" in the evaluation metric XGBoost will evaluate these score as 0 to be consistent under some conditions.
+
+ The `average precision` is defined as:
+
+ .. math::
+
+     AP@l = \frac{1}{\min{(l, N)}}\sum^l_{k=1}P@k \cdot I_{(k)}
+
+ where :math:`I_{(k)}` is an indicator function that equals to :math:`1` when the document at :math:`k` is relevant and :math:`0` otherwise. The :math:`P@k` is the precision at :math:`k`, and :math:`N` is the total number of relevant documents. Lastly, the `mean average precision` is defined as the weighted average across all queries.
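+
+  As a small illustrative example (not taken from the implementation), consider a single query with :math:`N = 2` relevant documents in total, where the ranked list at cutoff :math:`l = 3` is relevant, irrelevant, relevant. Then :math:`P@1 = 1`, :math:`P@3 = 2/3`, only positions :math:`1` and :math:`3` contribute, and:
+
+  .. math::
+
+     AP@3 = \frac{1}{\min{(3, 2)}}\left(1 + \frac{2}{3}\right) = \frac{5}{6} \approx 0.83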
+
+ - ``ndcg@n``, ``map@n``: :math:`n` can be assigned as an integer to cut off the top positions in the lists for evaluation.
+ - ``ndcg-``, ``map-``, ``ndcg@n-``, ``map@n-``: In XGBoost, the NDCG and MAP evaluate the score of a list without any positive samples as :math:`1`. By appending "-" to the evaluation metric name, we can ask XGBoost to evaluate these scores as :math:`0` to be consistent under some conditions.
- ``poisson-nloglik``: negative log-likelihood for Poisson regression
- ``gamma-nloglik``: negative log-likelihood for gamma regression
- ``cox-nloglik``: negative partial log-likelihood for Cox proportional hazards regression
@@ -447,6 +456,37 @@ Parameter for using Quantile Loss (``reg:quantileerror``)
* ``quantile_alpha``: A scala or a list of targeted quantiles.
+
+.. _ltr-param:
+
+Parameters for learning to rank (``rank:ndcg``, ``rank:map``, ``rank:pairwise``)
+================================================================================
+
+These are parameters specific to the learning to rank task. See :doc:`Learning to Rank ` for an in-depth explanation, and the short example after this parameter list for how the options fit together.
+
+* ``lambdarank_pair_method`` [default = ``mean``]
+
+ How to construct pairs for pair-wise learning.
+
+ - ``mean``: Sample ``lambdarank_num_pair_per_sample`` pairs for each document in the query list.
+  - ``topk``: Focus on the top-``lambdarank_num_pair_per_sample`` documents. Construct :math:`|query|` pairs for each document ranked in the top-``lambdarank_num_pair_per_sample`` positions by the model.
+
+* ``lambdarank_num_pair_per_sample`` [range = :math:`[1, \infty]`]
+
+ It specifies the number of pairs sampled for each document when pair method is ``mean``, or the truncation level for queries when the pair method is ``topk``. For example, to train with ``ndcg@6``, set ``lambdarank_num_pair_per_sample`` to :math:`6` and ``lambdarank_pair_method`` to ``topk``.
+
+* ``lambdarank_unbiased`` [default = ``false``]
+
+  Specify whether to debias the input click data.
+
+* ``lambdarank_bias_norm`` [default = 2.0]
+
+ :math:`L_p` normalization for position debiasing, default is :math:`L_2`. Only relevant when ``lambdarank_unbiased`` is set to true.
+
+* ``ndcg_exp_gain`` [default = ``true``]
+
+  Whether we should use the exponential gain function for ``NDCG``. There are two forms of gain function for ``NDCG``: one uses the relevance value directly, while the other uses :math:`2^{rel} - 1` to emphasize retrieving relevant documents. When ``ndcg_exp_gain`` is true (the default), the relevance degree cannot be greater than 31.
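+
+As a minimal sketch of how these options fit together (the arrays ``X``, ``y`` and ``qid`` below are placeholders for your own data, sorted by query id), the parameters go into the regular parameter dictionary of the native interface:
+
+.. code-block:: python
+
+  import xgboost as xgb
+
+  # X: feature matrix, y: relevance degrees, qid: query ids, all sorted by qid.
+  Xy = xgb.DMatrix(X, y, qid=qid)
+  params = {
+      "objective": "rank:ndcg",
+      "lambdarank_pair_method": "topk",
+      "lambdarank_num_pair_per_sample": 8,
+      "eval_metric": "ndcg@8",
+  }
+  booster = xgb.train(params, Xy, num_boost_round=100, evals=[(Xy, "train")])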
+
***********************
Command Line Parameters
***********************
diff --git a/doc/tutorials/dask.rst b/doc/tutorials/dask.rst
index 87b2bf9968b9..fd5155f468cd 100644
--- a/doc/tutorials/dask.rst
+++ b/doc/tutorials/dask.rst
@@ -503,6 +503,7 @@ dask config is used:
reg = dxgb.DaskXGBRegressor()
+Please note that XGBoost requires a different port from Dask. By default, on a unix-like system XGBoost uses port 0 to find an available port, which may fail if the user is running in a docker environment where ports are restricted.
************
IPv6 Support
diff --git a/doc/tutorials/index.rst b/doc/tutorials/index.rst
index 310fd0170610..eb8c23726d56 100644
--- a/doc/tutorials/index.rst
+++ b/doc/tutorials/index.rst
@@ -21,6 +21,7 @@ See `Awesome XGBoost `_ for mo
monotonic
rf
feature_interaction_constraint
+ learning_to_rank
aft_survival_analysis
c_api_tutorial
input_format
diff --git a/doc/tutorials/learning_to_rank.rst b/doc/tutorials/learning_to_rank.rst
new file mode 100644
index 000000000000..5afd9314f055
--- /dev/null
+++ b/doc/tutorials/learning_to_rank.rst
@@ -0,0 +1,177 @@
+################
+Learning to Rank
+################
+
+**Contents**
+
+.. contents::
+ :local:
+ :backlinks: none
+
+********
+Overview
+********
+Often in the context of information retrieval, learning to rank aims to train a model that arranges a set of query results into an ordered list `[1] <#references>`__. For supervised learning to rank, the predictors are sample documents encoded as a feature matrix, and the labels are the relevance degrees of the samples. The relevance degree can be multi-level (graded) or binary (relevant or not). The training samples are often grouped by their query index, with each query group containing multiple query results.
+
+XGBoost implements learning to rank through a set of objective functions and performance metrics. The default objective is ``rank:ndcg``, based on the ``LambdaMART`` `[2] <#references>`__ algorithm, which in turn is an adaptation of the ``LambdaRank`` `[3] <#references>`__ framework to gradient boosting trees. For a history and a summary of the algorithm, see `[5] <#references>`__. The implementation in XGBoost features deterministic GPU computation, distributed training, position debiasing and two different pair construction strategies.
+
+************************************
+Training with the Pairwise Objective
+************************************
+``LambdaMART`` is a pairwise ranking model, meaning that it compares the relevance degree for every pair of samples in a query group and calculates a proxy gradient for each pair. The default objective ``rank:ndcg`` uses the surrogate gradient derived from the ``ndcg`` metric. To train an XGBoost model, we need an additional sorted array called ``qid`` for specifying the query group of input samples. An example input would look like this:
+
++-------+-----------+---------------+
+| QID | Label | Features |
++=======+===========+===============+
+| 1 | 0 | :math:`x_1` |
++-------+-----------+---------------+
+| 1 | 1 | :math:`x_2` |
++-------+-----------+---------------+
+| 1 | 0 | :math:`x_3` |
++-------+-----------+---------------+
+| 2 | 0 | :math:`x_4` |
++-------+-----------+---------------+
+| 2 | 1 | :math:`x_5` |
++-------+-----------+---------------+
+| 2 | 1 | :math:`x_6` |
++-------+-----------+---------------+
+| 2 | 1 | :math:`x_7` |
++-------+-----------+---------------+
+
+Notice that the samples are sorted based on their query index in a non-decreasing order. Here the first three samples belong to the first query and the next four samples belong to the second. For the sake of simplicity, we will use a pseudo binary learning to rank dataset in the following snippets, with binary labels representing whether the result is relevant or not, and randomly assign the query group index to each sample. For an example that uses a real world dataset, please see :ref:`sphx_glr_python_examples_learning_to_rank.py`.
+
+.. code-block:: python
+
+  from sklearn.datasets import make_classification
+  import numpy as np
+
+  import xgboost as xgb
+
+  # Make a pseudo ranking dataset for demonstration
+  seed = 1994
+  X, y = make_classification(random_state=seed)
+  rng = np.random.default_rng(seed)
+  n_query_groups = 3
+  qid = rng.integers(0, n_query_groups, size=X.shape[0])
+
+  # Sort the inputs based on query index
+  sorted_idx = np.argsort(qid)
+  X = X[sorted_idx, :]
+  y = y[sorted_idx]
+  qid = qid[sorted_idx]
+
+The simplest way to train a ranking model is by using the sklearn estimator interface. Continuing the previous snippet, we can train a simple ranking model without tuning:
+
+.. code-block:: python
+
+ ranker = xgb.XGBRanker(tree_method="hist", lambdarank_num_pair_per_sample=8, objective="rank:ndcg", lambdarank_pair_method="topk")
+ ranker.fit(X, y, qid=qid)
+
+Please note that, as of writing, there's no learning to rank interface in sklearn. As a result, the :py:class:`xgboost.XGBRanker` class does not fully conform to the sklearn estimator guideline and cannot be directly used with some of its utility functions. For instance, the ``auc_score`` and ``ndcg_score`` in sklearn don't consider group information nor the pairwise loss. Most of the metrics are implemented as part of XGBoost, but to use sklearn utilities like :py:func:`sklearn.model_selection.cross_val_score`, we need to make some adjustments in order to pass the ``qid`` as an additional parameter to :py:meth:`xgboost.XGBRanker.score`. The ``X`` for :py:class:`xgboost.XGBRanker` may contain a special column called ``qid`` when it's a pandas dataframe or a cuDF dataframe:
+
+.. code-block:: python
+
+  import pandas as pd
+
+  df = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])])
+  df["qid"] = qid
+  ranker.fit(df, y)  # No need to pass qid as a separate argument
+
+  from sklearn.model_selection import StratifiedGroupKFold, cross_val_score
+  # Works with cv in sklearn, along with HPO utilities like grid search cv.
+  kfold = StratifiedGroupKFold(shuffle=False)
+  cross_val_score(ranker, df, y, cv=kfold, groups=df.qid)
+
+The above snippets build a model using ``LambdaMART`` with the ``NDCG@8`` metric. The outputs of a ranker are relevance scores:
+
+.. code-block:: python
+
+ scores = ranker.predict(X)
+ sorted_idx = np.argsort(scores)[::-1]
+ # Sort the relevance scores from most relevant to least relevant
+ scores = scores[sorted_idx]
+
+
+*************
+Position Bias
+*************
+The real relevance degree for a query result is difficult to obtain, as it often requires human judges to examine the content of the query results. When such labeled data is absent, we might want to train the model on ground truth data like user clicks. Another upside of using click data directly is that it can reflect the up-to-date relevance status `[1] <#references>`__. However, user clicks are often noisy and biased, as users tend to choose results displayed at higher positions. To ameliorate this issue, XGBoost implements the ``Unbiased LambdaMART`` `[4] <#references>`__ algorithm to debias the position-dependent click data. The feature can be enabled by the ``lambdarank_unbiased`` parameter; see :ref:`ltr-param` for related options and :ref:`sphx_glr_python_examples_learning_to_rank.py` for a worked example with simulated user clicks.
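+
+A minimal sketch of enabling the feature through the sklearn interface is shown below. The arrays ``X``, ``clicks`` and ``qid`` are placeholders for your own click log, and within each query the rows are assumed to be sorted by the position at which the results were displayed, as in the linked example:
+
+.. code-block:: python
+
+  import xgboost as xgb
+
+  ranker = xgb.XGBRanker(
+      objective="rank:ndcg",
+      # Debias the position-dependent clicks.
+      lambdarank_unbiased=True,
+      lambdarank_num_pair_per_sample=8,
+      lambdarank_pair_method="topk",
+      eval_metric="ndcg@8",
+  )
+  # ``clicks`` contains the binary click labels collected from user interactions.
+  ranker.fit(X, clicks, qid=qid)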
+
+****
+Loss
+****
+
+XGBoost implements different ``LambdaMART`` objectives based on different metrics. We list them here as a reference. Other than those used as objective functions, XGBoost also implements metrics like ``pre`` (for precision) for evaluation. See :doc:`parameters ` for available options and the following sections for how to choose these objectives based on the number of effective pairs.
+
+* NDCG
+`Normalized Discounted Cumulative Gain` ``NDCG`` can be used with both binary relevance and multi-level relevance. If you are not sure about your data, this metric can be used as the default. The name for the objective is ``rank:ndcg``.
+
+
+* MAP
+`Mean average precision` ``MAP`` is a binary measure. It can be used when the relevance label is 0 or 1. The name for the objective is ``rank:map``.
+
+
+* Pairwise
+
+The `LambdaMART` algorithm scales the logistic loss with learning to rank metrics like ``NDCG`` in the hope of including ranking information into the loss function. The ``rank:pairwise`` loss is the original version of the pairwise loss, also known as the `RankNet loss` `[7] <#references>`__ or the `pairwise logistic loss`. Unlike ``rank:map`` and ``rank:ndcg``, no scaling is applied (:math:`|\Delta Z_{ij}| = 1`).
+
+Whether scaling with an LTR metric is actually more effective is still up for debate; `[8] <#references>`__ provides a theoretical foundation for general lambda loss functions and some insights into the framework.
+
+******************
+Constructing Pairs
+******************
+
+There are two implemented strategies for constructing document pairs for :math:`\lambda`-gradient calculation. The first one is the ``mean`` method, the other is the ``topk`` method. The preferred strategy can be specified by the ``lambdarank_pair_method`` parameter.
+
+For the ``mean`` strategy, XGBoost samples ``lambdarank_num_pair_per_sample`` pairs for each document in a query list. For example, given a list of 3 documents with ``lambdarank_num_pair_per_sample`` set to 2, XGBoost will randomly sample 6 pairs, assuming the labels for these documents are different. On the other hand, if the pair method is set to ``topk``, XGBoost constructs about :math:`k \times |query|` pairs, with :math:`|query|` pairs for each sample at the top :math:`k` positions, where :math:`k` is ``lambdarank_num_pair_per_sample``. The number of pairs counted here is an approximation since we skip pairs that have the same label.
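+
+The following sketch shows the two strategies side by side (``X``, ``y`` and ``qid`` are placeholders for a sorted ranking dataset as in the earlier snippets):
+
+.. code-block:: python
+
+  import xgboost as xgb
+
+  # Randomly sample 2 pairs per document from within each query group.
+  sampled = xgb.XGBRanker(
+      objective="rank:ndcg", lambdarank_pair_method="mean", lambdarank_num_pair_per_sample=2
+  )
+  sampled.fit(X, y, qid=qid)
+
+  # Construct pairs around the documents ranked in the top 10 positions by the model.
+  topk = xgb.XGBRanker(
+      objective="rank:ndcg", lambdarank_pair_method="topk", lambdarank_num_pair_per_sample=10
+  )
+  topk.fit(X, y, qid=qid)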
+
+*********************
+Obtaining Good Result
+*********************
+
+Learning to rank is a sophisticated task and a field of heated research. It's not trivial to train a model that generalizes well. There are multiple loss functions available in XGBoost along with a set of hyper-parameters. This section contains some hints for how to choose those parameters as a starting point. One can further optimize the model by tuning these parameters.
+
+The first question would be how to choose an objective that matches the task at hand. If your input data has multi-level relevance degrees, then either ``rank:ndcg`` or ``rank:pairwise`` should be used. However, when the input is binary we have multiple options based on the target metric. `[6] <#references>`__ provides some guidelines on this topic and users are encouraged to see the analysis done in their work. The choice should be based on the number of `effective pairs`, which refers to the number of pairs that can generate non-zero gradient and contribute to training. `LambdaMART` with ``MRR`` has the least amount of effective pairs, as the :math:`\lambda`-gradient is only non-zero when the pair contains a non-relevant document ranked higher than the top relevant document. As a result, it's not implemented in XGBoost. Since ``NDCG`` is a multi-level metric, it usually generates more effective pairs than ``MAP``.
+
+However, when there's a sufficient amount of effective pairs, it's shown in `[6] <#references>`__ that matching the target metric with the objective is of significance. When the target metric is ``MAP`` and you are using a large dataset that can provide a sufficient amount of effective pairs, ``rank:map`` can in theory yield higher ``MAP`` value than the ``rank:ndcg``.
+
+The choice of the pair method (``lambdarank_pair_method``) and the number of pairs per sample (``lambdarank_num_pair_per_sample``) follows similar reasoning: since the mean-``NDCG`` considers more pairs than ``NDCG@10``, it can generate more effective pairs and provide more granularity. Also, using the ``mean`` strategy can help the model generalize through random sampling. However, in practice one might want to focus the training on the top :math:`k` documents instead of using all pairs; the tradeoff should be made based on the user's goal.
+
+When using the mean value instead of targeting a specific position, i.e. when calculating the target metric (like ``NDCG``) over the whole query list, users can specify how many pairs they want per document by setting ``lambdarank_num_pair_per_sample``, and XGBoost will randomly sample this number of pairs for each element in the query group (:math:`|pairs| = |query| \times num\_pair\_per\_sample`). Often, setting it to 1 can produce a reasonable result, with higher values producing more pairs (in the hope that a reasonable amount of them are effective). On the other hand, if you are prioritizing the top :math:`k` documents, ``lambdarank_num_pair_per_sample`` should be set slightly higher than :math:`k` (with a few more documents) to obtain a good training result.
+
+In summary, when starting the training: if you have a large dataset, consider using the target-matching objective; otherwise ``NDCG`` or the RankNet loss (``rank:pairwise``) might be preferred. With the same target metric, prefer the ``topk`` pair method (with ``lambdarank_num_pair_per_sample`` specifying the top :math:`k` documents) if your dataset is large, and use the ``mean`` version otherwise. Lastly, ``lambdarank_num_pair_per_sample`` can be used to control the number of pairs for both methods.
+
+********************
+Distributed Training
+********************
+XGBoost implements distributed learning to rank with integration of multiple frameworks including dask, spark, and pyspark. The interface is similar to the single-node version. Please refer to the documentation of the respective XGBoost interface for details. Scattering a query group onto multiple workers is theoretically sound but can affect the model's accuracy. For most use cases the small discrepancy is not an issue, since the dataset is usually large when distributed training is involved. As a result, users don't need to partition the data based on group information. Given that the dataset is correctly sorted, XGBoost can aggregate sample gradients accordingly.
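+
+For instance, a sketch with the dask interface might look like the following, assuming ``client`` is a ``distributed.Client`` and ``X``, ``y`` and ``qid`` are dask collections sorted by query id within each partition:
+
+.. code-block:: python
+
+  from xgboost import dask as dxgb
+
+  ranker = dxgb.DaskXGBRanker(objective="rank:ndcg", lambdarank_pair_method="topk")
+  ranker.client = client
+  ranker.fit(X, y, qid=qid)
+  scores = ranker.predict(X)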
+
+*******************
+Reproducible Result
+*******************
+
+Like any other task, XGBoost should generate reproducible results given the same hardware and software environment, along with the same data partitions if a distributed interface is used. Even when the underlying environment has changed, the result should still be consistent. However, when ``lambdarank_pair_method`` is set to ``mean``, XGBoost uses random sampling, and since the random number generator used on Windows (MSVC) differs from the one used on other platforms like Linux (GCC, Clang), the output may vary significantly between these platforms.
+
+**********
+References
+**********
+
+[1] Tie-Yan Liu. 2009. "`Learning to Rank for Information Retrieval`_". Found. Trends Inf. Retr. 3, 3 (March 2009), 225–331.
+
+[2] Christopher J. C. Burges, Robert Ragno, and Quoc Viet Le. 2006. "`Learning to rank with nonsmooth cost functions`_". In Proceedings of the 19th International Conference on Neural Information Processing Systems (NIPS'06). MIT Press, Cambridge, MA, USA, 193–200.
+
+[3] Wu, Q., Burges, C.J.C., Svore, K.M. et al. "`Adapting boosting for information retrieval measures`_". Inf Retrieval 13, 254–270 (2010).
+
+[4] Ziniu Hu, Yang Wang, Qu Peng, Hang Li. "`Unbiased LambdaMART: An Unbiased Pairwise Learning-to-Rank Algorithm`_". Proceedings of the 2019 World Wide Web Conference.
+
+[5] Burges, Chris J.C. "`From RankNet to LambdaRank to LambdaMART: An Overview`_". MSR-TR-2010-82
+
+[6] Pinar Donmez, Krysta M. Svore, and Christopher J.C. Burges. 2009. "`On the local optimality of LambdaRank`_". In Proceedings of the 32nd international ACM SIGIR conference on Research and development in information retrieval (SIGIR '09). Association for Computing Machinery, New York, NY, USA, 460–467.
+
+[7] Chris Burges, Tal Shaked, Erin Renshaw, Ari Lazier, Matt Deeds, Nicole Hamilton, and Greg Hullender. 2005. "`Learning to rank using gradient descent`_". In Proceedings of the 22nd international conference on Machine learning (ICML '05). Association for Computing Machinery, New York, NY, USA, 89–96.
+
+[8] Xuanhui Wang and Cheng Li and Nadav Golbandi and Mike Bendersky and Marc Najork. 2018. "`The LambdaLoss Framework for Ranking Metric Optimization`_". Proceedings of The 27th ACM International Conference on Information and Knowledge Management (CIKM '18).
+
+.. _`Learning to Rank for Information Retrieval`: https://doi.org/10.1561/1500000016
+.. _`Learning to rank with nonsmooth cost functions`: https://dl.acm.org/doi/10.5555/2976456.2976481
+.. _`Adapting boosting for information retrieval measures`: https://doi.org/10.1007/s10791-009-9112-1
+.. _`Unbiased LambdaMART: An Unbiased Pairwise Learning-to-Rank Algorithm`: https://dl.acm.org/doi/10.1145/3308558.3313447
+.. _`From RankNet to LambdaRank to LambdaMART: An Overview`: https://www.microsoft.com/en-us/research/publication/from-ranknet-to-lambdarank-to-lambdamart-an-overview/
+.. _`On the local optimality of LambdaRank`: https://doi.org/10.1145/1571941.1572021
+.. _`Learning to rank using gradient descent`: https://doi.org/10.1145/1102351.1102363
diff --git a/include/xgboost/cache.h b/include/xgboost/cache.h
index 781f45b1c474..66dec55f96a4 100644
--- a/include/xgboost/cache.h
+++ b/include/xgboost/cache.h
@@ -15,6 +15,7 @@
#include <utility>  // for move
#include <vector>  // for vector
+
namespace xgboost {
class DMatrix;
/**
@@ -149,6 +150,26 @@ class DMatrixCache {
}
return container_.at(key).value;
}
+ /**
+ * \brief Re-initialize the item in cache.
+ *
+ * Since the shared_ptr is used to hold the item, any reference that lives outside of
+   * the cache can no longer be reached from the cache.
+ *
+   * We use reset instead of erase to avoid walking through the whole cache for renewing
+   * a single item (the cache is FIFO and needs to maintain the order).
+ */
+  template <typename... Args>
+  std::shared_ptr<CacheT> ResetItem(std::shared_ptr<DMatrix> m, Args const&... args) {
+    std::lock_guard<std::mutex> guard{lock_};
+ CheckConsistent();
+ auto key = Key{m.get(), std::this_thread::get_id()};
+ auto it = container_.find(key);
+ CHECK(it != container_.cend());
+    it->second = {m, std::make_shared<CacheT>(args...)};
+ CheckConsistent();
+ return it->second.value;
+ }
/**
* \brief Get a const reference to the underlying hash map. Clear expired caches before
* returning.
diff --git a/include/xgboost/data.h b/include/xgboost/data.h
index ec78c588d5d9..d7abb9db39bd 100644
--- a/include/xgboost/data.h
+++ b/include/xgboost/data.h
@@ -1,5 +1,5 @@
-/*!
- * Copyright (c) 2015-2022 by XGBoost Contributors
+/**
+ * Copyright 2015-2023 by XGBoost Contributors
* \file data.h
* \brief The input data structure of xgboost.
* \author Tianqi Chen
@@ -17,6 +17,8 @@
#include
#include
+#include <cstddef>  // std::size_t
+#include <cstdint>  // std::uint64_t
#include
#include
#include
@@ -60,9 +62,8 @@ class MetaInfo {
linalg::Tensor labels;
/*! \brief data split mode */
DataSplitMode data_split_mode{DataSplitMode::kRow};
- /*!
- * \brief the index of begin and end of a group
- * needed when the learning task is ranking.
+ /**
+ * \brief the index of begin and end of a group, needed when the learning task is ranking.
*/
std::vector group_ptr_; // NOLINT
/*! \brief weights of each instance, optional */
diff --git a/include/xgboost/objective.h b/include/xgboost/objective.h
index a04d2e453df7..4fecf56884b2 100644
--- a/include/xgboost/objective.h
+++ b/include/xgboost/objective.h
@@ -11,6 +11,7 @@
#include
#include
#include
+#include <xgboost/json.h>  // for Json, Null
#include
#include
diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala
index 5342aa563621..b8dca5d7040e 100644
--- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala
+++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala
@@ -220,7 +220,7 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite {
test("Ranking: train with Group") {
withGpuSparkSession(enableCsvConf()) { spark =>
- val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "rank:pairwise",
+ val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "rank:ndcg",
"num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist",
"features_cols" -> featureNames, "label_col" -> labelName)
val Array(trainingDf, testDf) = spark.read.option("header", "true").schema(schema)
diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
index f63865fabc2d..6aec4d36ed6f 100644
--- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
+++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala
@@ -25,7 +25,7 @@ private[spark] trait LearningTaskParams extends Params {
/**
* Specify the learning task and the corresponding learning objective.
* options: reg:squarederror, reg:squaredlogerror, reg:logistic, binary:logistic, binary:logitraw,
- * count:poisson, multi:softmax, multi:softprob, rank:pairwise, reg:gamma.
+ * count:poisson, multi:softmax, multi:softprob, rank:ndcg, reg:gamma.
* default: reg:squarederror
*/
final val objective = new Param[String](this, "objective",
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
index 0bf8c2fbb426..3e6879d83148 100755
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala
@@ -201,7 +201,7 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
sc,
buildTrainingRDD,
List("eta" -> "1", "max_depth" -> "6",
- "objective" -> "rank:pairwise", "num_round" -> 5, "num_workers" -> numWorkers,
+ "objective" -> "rank:ndcg", "num_round" -> 5, "num_workers" -> numWorkers,
"custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false,
"missing" -> Float.NaN).toMap)
@@ -268,7 +268,7 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
val training = buildDataFrameWithGroup(Ranking.train, 5)
val Array(train, eval1, eval2) = training.randomSplit(Array(0.6, 0.2, 0.2), 0)
val paramMap1 = Map("eta" -> "1", "max_depth" -> "6",
- "objective" -> "rank:pairwise",
+ "objective" -> "rank:ndcg",
"num_round" -> 5, "num_workers" -> numWorkers, "group_col" -> "group")
val xgb1 = new XGBoostRegressor(paramMap1).setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2))
val model1 = xgb1.fit(train)
@@ -281,7 +281,7 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest {
assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(1))
val paramMap2 = Map("eta" -> "1", "max_depth" -> "6",
- "objective" -> "rank:pairwise",
+ "objective" -> "rank:ndcg",
"num_round" -> 5, "num_workers" -> numWorkers, "group_col" -> "group",
"eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2))
val xgb2 = new XGBoostRegressor(paramMap2)
diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
index 4e3d59b25e57..b8b6aeefa0e0 100644
--- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
+++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala
@@ -121,7 +121,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest with TmpFolderPerSuite
test("ranking: use group data") {
val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1",
- "objective" -> "rank:pairwise", "num_workers" -> numWorkers, "num_round" -> 5,
+ "objective" -> "rank:ndcg", "num_workers" -> numWorkers, "num_round" -> 5,
"group_col" -> "group", "tree_method" -> treeMethod)
val trainingDF = buildDataFrameWithGroup(Ranking.train)
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index 5a0cfb3a2ece..4b1e8be1b5af 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -288,10 +288,10 @@ def _check_call(ret: int) -> None:
def build_info() -> dict:
- """Build information of XGBoost. The returned value format is not stable. Also, please
- note that build time dependency is not the same as runtime dependency. For instance,
- it's possible to build XGBoost with older CUDA version but run it with the lastest
- one.
+ """Build information of XGBoost. The returned value format is not stable. Also,
+ please note that build time dependency is not the same as runtime dependency. For
+ instance, it's possible to build XGBoost with older CUDA version but run it with the
+    latest one.
.. versionadded:: 1.6.0
@@ -675,28 +675,28 @@ def __init__(
data :
Data source of DMatrix. See :ref:`py-data` for a list of supported input
types.
- label : array_like
+ label :
Label of the training data.
- weight : array_like
+ weight :
Weight for each instance.
- .. note:: For ranking task, weights are per-group.
+ .. note::
- In ranking task, one weight is assigned to each group (not each
- data point). This is because we only care about the relative
- ordering of data points within each group, so it doesn't make
- sense to assign weights to individual data points.
+ For ranking task, weights are per-group. In ranking task, one weight
+ is assigned to each group (not each data point). This is because we
+ only care about the relative ordering of data points within each group,
+ so it doesn't make sense to assign weights to individual data points.
- base_margin: array_like
+ base_margin :
Base margin used for boosting from existing model.
- missing : float, optional
- Value in the input data which needs to be present as a missing
- value. If None, defaults to np.nan.
- silent : boolean, optional
+ missing :
+ Value in the input data which needs to be present as a missing value. If
+ None, defaults to np.nan.
+ silent :
Whether print messages during construction
- feature_names : list, optional
+ feature_names :
Set names for features.
- feature_types : FeatureTypes
+ feature_types :
Set types for features. When `enable_categorical` is set to `True`, string
"c" represents categorical data type while "q" represents numerical feature
@@ -706,20 +706,20 @@ def __init__(
`.cat.codes` method. This is useful when users want to specify categorical
features without having to construct a dataframe as input.
- nthread : integer, optional
+ nthread :
Number of threads to use for loading data when parallelization is
applicable. If -1, uses maximum threads available on the system.
- group : array_like
+ group :
Group size for all ranking group.
- qid : array_like
+ qid :
Query ID for data samples, used for ranking.
- label_lower_bound : array_like
+ label_lower_bound :
Lower bound for survival training.
- label_upper_bound : array_like
+ label_upper_bound :
Upper bound for survival training.
- feature_weights : array_like, optional
+ feature_weights :
Set feature weights for column sampling.
- enable_categorical: boolean, optional
+ enable_categorical :
.. versionadded:: 1.3.0
@@ -1760,6 +1760,7 @@ def save_config(self) -> str:
string.
.. versionadded:: 1.0.0
+
"""
json_string = ctypes.c_char_p()
length = c_bst_ulong()
diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index 3204f5a2a61e..75433ca5b320 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -1830,7 +1830,11 @@ def _get_qid(
@xgboost_model_doc(
- """Implementation of the Scikit-Learn API for XGBoost Ranking.""",
+ """Implementation of the Scikit-Learn API for XGBoost Ranking.
+
+See :doc:`Learning to Rank ` for an introduction.
+
+ """,
["estimators", "model"],
end_note="""
.. note::
@@ -1879,7 +1883,7 @@ def _get_qid(
class XGBRanker(XGBModel, XGBRankerMixIn):
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
@_deprecate_positional_args
- def __init__(self, *, objective: str = "rank:pairwise", **kwargs: Any):
+ def __init__(self, *, objective: str = "rank:ndcg", **kwargs: Any):
super().__init__(objective=objective, **kwargs)
if callable(self.objective):
raise ValueError("custom objective function not supported by XGBRanker")
diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py
index 3b33e87749f3..bb13b5523ed2 100644
--- a/python-package/xgboost/testing/__init__.py
+++ b/python-package/xgboost/testing/__init__.py
@@ -14,6 +14,7 @@
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from io import StringIO
+from pathlib import Path
from platform import system
from typing import (
Any,
@@ -443,7 +444,7 @@ def get_mq2008(
from sklearn.datasets import load_svmlight_files
src = "https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip"
- target = dpath + "/MQ2008.zip"
+ target = os.path.join(os.path.expanduser(dpath), "MQ2008.zip")
if not os.path.exists(target):
request.urlretrieve(url=src, filename=target)
@@ -462,9 +463,9 @@ def get_mq2008(
qid_valid,
) = load_svmlight_files(
(
- dpath + "MQ2008/Fold1/train.txt",
- dpath + "MQ2008/Fold1/test.txt",
- dpath + "MQ2008/Fold1/vali.txt",
+ Path(dpath) / "MQ2008" / "Fold1" / "train.txt",
+ Path(dpath) / "MQ2008" / "Fold1" / "test.txt",
+ Path(dpath) / "MQ2008" / "Fold1" / "vali.txt",
),
query_id=True,
zero_based=False,
diff --git a/python-package/xgboost/testing/params.py b/python-package/xgboost/testing/params.py
index 3af3306da40e..fe730f74f51f 100644
--- a/python-package/xgboost/testing/params.py
+++ b/python-package/xgboost/testing/params.py
@@ -47,3 +47,15 @@
"max_cat_threshold": strategies.integers(1, 128),
}
)
+
+lambdarank_parameter_strategy = strategies.fixed_dictionaries(
+ {
+ "lambdarank_unbiased": strategies.sampled_from([True, False]),
+ "lambdarank_pair_method": strategies.sampled_from(["topk", "mean"]),
+ "lambdarank_num_pair_per_sample": strategies.integers(1, 8),
+ "lambdarank_bias_norm": strategies.floats(0.5, 2.0),
+ "objective": strategies.sampled_from(
+ ["rank:ndcg", "rank:map", "rank:pairwise"]
+ ),
+ }
+)
diff --git a/src/common/algorithm.h b/src/common/algorithm.h
index 739a84968d45..2f5fccedd592 100644
--- a/src/common/algorithm.h
+++ b/src/common/algorithm.h
@@ -3,14 +3,17 @@
*/
#ifndef XGBOOST_COMMON_ALGORITHM_H_
#define XGBOOST_COMMON_ALGORITHM_H_
-#include <algorithm>  // upper_bound, stable_sort, sort, max
-#include <cstddef>    // size_t
-#include <functional>  // less
-#include <iterator>   // iterator_traits, distance
-#include <vector>     // vector
+#include <algorithm>    // for upper_bound, stable_sort, sort, max, all_of, none_of, min
+#include <cstddef>      // for size_t
+#include <functional>   // for less
+#include <iterator>     // for iterator_traits, distance
+#include <type_traits>  // for is_same
+#include <vector>       // for vector
-#include "numeric.h" // Iota
-#include "xgboost/context.h" // Context
+#include "common.h" // for DivRoundUp
+#include "numeric.h" // for Iota
+#include "threading_utils.h" // for MemStackAllocator, DefaultMaxThreads, ParallelFor
+#include "xgboost/context.h" // for Context
// clang with libstdc++ works as well
#if defined(__GNUC__) && (__GNUC__ >= 4) && !defined(__sun) && !defined(sun) && \
@@ -83,6 +86,55 @@ std::vector ArgSort(Context const *ctx, Iter begin, Iter end, Comp comp = s
StableSort(ctx, result.begin(), result.end(), op);
return result;
}
+
+namespace detail {
+template <typename It, typename Op>
+bool Logical(Context const *ctx, It first, It last, Op op) {
+ auto n = std::distance(first, last);
+ auto n_threads =
+      std::max(std::min(n, static_cast<decltype(n)>(ctx->Threads())), static_cast<decltype(n)>(1));
+  common::MemStackAllocator<bool, common::DefaultMaxThreads()> tloc{
+      static_cast<std::size_t>(n_threads), false};
+ CHECK_GE(n, 0);
+ CHECK_GE(ctx->Threads(), 1);
+  static_assert(std::is_same<decltype(n), decltype(n_threads)>::value, "");
+ auto const n_per_thread = common::DivRoundUp(n, ctx->Threads());
+  common::ParallelFor(static_cast<std::size_t>(n_threads), n_threads, [&](auto t) {
+ auto begin = t * n_per_thread;
+ auto end = std::min(begin + n_per_thread, n);
+
+ auto first_tloc = first + begin;
+ auto last_tloc = first + end;
+ if (first_tloc >= last_tloc) {
+ tloc[t] = true;
+ return;
+ }
+ bool result = op(first_tloc, last_tloc);
+ tloc[t] = result;
+ });
+ return std::all_of(tloc.cbegin(), tloc.cend(), [](auto v) { return v; });
+}
+} // namespace detail
+
+/**
+ * \brief Parallel version of std::none_of
+ */
+template <typename It, typename Pred>
+bool NoneOf(Context const *ctx, It first, It last, Pred predicate) {
+ return detail::Logical(ctx, first, last, [&predicate](It first, It last) {
+ return std::none_of(first, last, predicate);
+ });
+}
+
+/**
+ * \brief Parallel version of std::all_of
+ */
+template <typename It, typename Pred>
+bool AllOf(Context const *ctx, It first, It last, Pred predicate) {
+ return detail::Logical(ctx, first, last, [&predicate](It first, It last) {
+ return std::all_of(first, last, predicate);
+ });
+}
} // namespace common
} // namespace xgboost
diff --git a/src/common/error_msg.h b/src/common/error_msg.h
new file mode 100644
index 000000000000..c6f1e3d0d41c
--- /dev/null
+++ b/src/common/error_msg.h
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2023 by XGBoost contributors
+ *
+ * \brief Common error message for various checks.
+ */
+#ifndef XGBOOST_COMMON_ERROR_MSG_H_
+#define XGBOOST_COMMON_ERROR_MSG_H_
+
+#include "xgboost/string_view.h" // for StringView
+
+namespace xgboost {
+namespace error {
+constexpr StringView GroupWeight() {
+ return "Size of weight must equal to the number of query groups when ranking group is used.";
+}
+
+constexpr StringView GroupSize() {
+ return "Invalid query group structure. The number of rows obtained from group doesn't equal to ";
+}
+
+constexpr StringView LabelScoreSize() {
+ return "The size of label doesn't match the size of prediction.";
+}
+} // namespace error
+} // namespace xgboost
+#endif // XGBOOST_COMMON_ERROR_MSG_H_
diff --git a/src/common/math.h b/src/common/math.h
index 71a494544be1..25befdbef218 100644
--- a/src/common/math.h
+++ b/src/common/math.h
@@ -1,5 +1,5 @@
-/*!
- * Copyright 2015 by Contributors
+/**
+ * Copyright 2015-2023 by XGBoost Contributors
* \file math.h
* \brief additional math utils
* \author Tianqi Chen
@@ -7,16 +7,19 @@
#ifndef XGBOOST_COMMON_MATH_H_
#define XGBOOST_COMMON_MATH_H_
-#include
+#include <xgboost/base.h>  // for XGBOOST_DEVICE
-#include
-#include
-#include
-#include
-#include
+#include <algorithm>    // for max
+#include <cmath>        // for exp, abs, log, lgamma
+#include <limits>       // for numeric_limits
+#include <type_traits>  // for is_floating_point, conditional, is_signed, is_same, declval, enable_if
+#include <utility>      // for pair
namespace xgboost {
namespace common {
+
+template <typename T> XGBOOST_DEVICE T Sqr(T const &w) { return w * w; }
+
/*!
* \brief calculate the sigmoid of the input.
* \param x input parameter
@@ -30,9 +33,12 @@ XGBOOST_DEVICE inline float Sigmoid(float x) {
return y;
}
-template <typename T>
-XGBOOST_DEVICE inline static T Sqr(T a) { return a * a; }
-
+XGBOOST_DEVICE inline double Sigmoid(double x) {
+ double constexpr kEps = 1e-16; // avoid 0 div
+ auto denom = std::exp(-x) + 1.0 + kEps;
+ auto y = 1.0 / denom;
+ return y;
+}
/*!
* \brief Equality test for both integer and floating point.
*/
@@ -134,10 +140,6 @@ inline static bool CmpFirst(const std::pair<float, unsigned> &a,
                              const std::pair<float, unsigned> &b) {
   return a.first > b.first;
 }
-inline static bool CmpSecond(const std::pair<float, unsigned> &a,
-                             const std::pair<float, unsigned> &b) {
- return a.second > b.second;
-}
// Redefined here to workaround a VC bug that doesn't support overloading for integer
// types.
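Note: the new double-precision `Sigmoid` overload adds a tiny epsilon to the denominator so the division is guarded even for extreme inputs. A quick numerical check (Python, illustrative only; `sigmoid` here is a standalone sketch, not part of this patch):

    import math

    def sigmoid(x, eps=1e-16):
        # Mirrors the double overload above: 1 / (exp(-x) + 1 + eps).
        return 1.0 / (math.exp(-x) + 1.0 + eps)

    assert abs(sigmoid(0.0) - 0.5) < 1e-12
    assert 0.0 < sigmoid(-5.0) < sigmoid(5.0) < 1.0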
diff --git a/src/common/numeric.h b/src/common/numeric.h
index 6a1c15fd08b4..687134c122d4 100644
--- a/src/common/numeric.h
+++ b/src/common/numeric.h
@@ -6,8 +6,10 @@
#include // OMPException
-#include <algorithm>  // std::max
-#include <iterator>   // std::iterator_traits
+#include <algorithm>  // for std::max
+#include <cstddef>    // for size_t
+#include <cstdint>    // for int32_t
+#include <iterator>   // for iterator_traits
#include
#include "common.h" // AssertGPUSupport
@@ -111,11 +113,11 @@ inline double Reduce(Context const*, HostDeviceVector const&) {
namespace cpu_impl {
template
V Reduce(Context const* ctx, It first, It second, V const& init) {
- size_t n = std::distance(first, second);
- common::MemStackAllocator result_tloc(ctx->Threads(), init);
- common::ParallelFor(n, ctx->Threads(),
- [&](auto i) { result_tloc[omp_get_thread_num()] += first[i]; });
- auto result = std::accumulate(result_tloc.cbegin(), result_tloc.cbegin() + ctx->Threads(), init);
+ std::size_t n = std::distance(first, second);
+  auto n_threads = static_cast<std::int32_t>(std::min(n, static_cast<std::size_t>(ctx->Threads())));
+  common::MemStackAllocator<V, common::DefaultMaxThreads()> result_tloc(n_threads, init);
+ common::ParallelFor(n, n_threads, [&](auto i) { result_tloc[omp_get_thread_num()] += first[i]; });
+ auto result = std::accumulate(result_tloc.cbegin(), result_tloc.cbegin() + n_threads, init);
return result;
}
} // namespace cpu_impl
diff --git a/src/common/ranking_utils.cc b/src/common/ranking_utils.cc
index f0b1c1a5ee77..cae564480a0d 100644
--- a/src/common/ranking_utils.cc
+++ b/src/common/ranking_utils.cc
@@ -3,15 +3,134 @@
*/
#include "ranking_utils.h"
-#include <cstdint>  // std::uint32_t
-#include <sstream>  // std::ostringstream
-#include // std::string,std::sscanf
+#include <algorithm>   // for copy_n, max, min
+#include <cstddef>     // for size_t
+#include <cstdio>      // for sscanf
+#include <exception>   // for exception
+#include <functional>  // for greater
+#include <iterator>    // for reverse_iterator
+#include <string>      // for char_traits, string
-#include "xgboost/string_view.h" // StringView
+#include "algorithm.h" // for ArgSort, NoneOf, AllOf
+#include "linalg_op.h" // for cbegin, cend
+#include "optional_weight.h" // for MakeOptionalWeights
+#include "threading_utils.h" // for ParallelFor
+#include "xgboost/base.h" // for bst_group_t
+#include "xgboost/context.h" // for Context
+#include "xgboost/data.h" // for MetaInfo
+#include "xgboost/linalg.h" // for All, TensorView, Range, Tensor, Vector
+#include "xgboost/logging.h" // for Error, LogCheck_EQ, CHECK_EQ
-namespace xgboost {
-namespace ltr {
-std::string MakeMetricName(StringView name, StringView param, std::uint32_t* topn, bool* minus) {
+namespace xgboost::ltr {
+void RankingCache::InitOnCPU(Context const* ctx, MetaInfo const& info) {
+ if (info.group_ptr_.empty()) {
+ group_ptr_.Resize(2, 0);
+ group_ptr_.HostVector()[1] = info.num_row_;
+ } else {
+ group_ptr_.HostVector() = info.group_ptr_;
+ }
+
+ auto const& gptr = group_ptr_.ConstHostVector();
+ for (std::size_t i = 1; i < gptr.size(); ++i) {
+ std::size_t n = gptr[i] - gptr[i - 1];
+ max_group_size_ = std::max(max_group_size_, n);
+ }
+
+ double sum_weights = 0;
+ auto n_groups = Groups();
+ auto weight = common::MakeOptionalWeights(ctx, info.weights_);
+ for (bst_omp_uint k = 0; k < n_groups; ++k) {
+ sum_weights += weight[k];
+ }
+  weight_norm_ = static_cast<double>(n_groups) / sum_weights;
+
+ auto const& h_label = info.labels.HostView().Slice(linalg::All(), 0);
+ is_binary_ = IsBinaryRel(
+ h_label, [ctx](auto beg, auto end, auto op) { return common::AllOf(ctx, beg, end, op); });
+}
+
+common::Span RankingCache::MakeRankOnCPU(Context const* ctx,
+ common::Span predt) {
+ auto gptr = this->DataGroupPtr(ctx);
+ auto rank = this->sorted_idx_cache_.HostSpan();
+ CHECK_EQ(rank.size(), predt.size());
+
+ common::ParallelFor(this->Groups(), ctx->Threads(), [&](auto g) {
+ auto cnt = gptr[g + 1] - gptr[g];
+ auto g_predt = predt.subspan(gptr[g], cnt);
+ auto g_rank = rank.subspan(gptr[g], cnt);
+ auto sorted_idx = common::ArgSort(
+ ctx, g_predt.data(), g_predt.data() + g_predt.size(), std::greater<>{});
+ CHECK_EQ(g_rank.size(), sorted_idx.size());
+ std::copy_n(sorted_idx.data(), sorted_idx.size(), g_rank.data());
+ });
+
+ return rank;
+}
+
+#if !defined(XGBOOST_USE_CUDA)
+void RankingCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUSupport(); }
+common::Span RankingCache::MakeRankOnCUDA(Context const*,
+ common::Span) {
+ common::AssertGPUSupport();
+ return {};
+}
+#endif  // !defined(XGBOOST_USE_CUDA)
+
+void NDCGCache::InitOnCPU(Context const* ctx, MetaInfo const& info) {
+ auto const h_group_ptr = this->DataGroupPtr(ctx);
+
+ discounts_.Resize(MaxGroupSize(), 0);
+ auto& h_discounts = discounts_.HostVector();
+ for (std::size_t i = 0; i < MaxGroupSize(); ++i) {
+ h_discounts[i] = CalcDCGDiscount(i);
+ }
+
+ auto n_groups = h_group_ptr.size() - 1;
+ auto h_labels = info.labels.HostView().Slice(linalg::All(), 0);
+
+ CheckNDCGLabels(this->Param(), h_labels,
+ [&](auto beg, auto end, auto op) { return common::NoneOf(ctx, beg, end, op); });
+
+ inv_idcg_.Reshape(n_groups);
+ auto h_inv_idcg = inv_idcg_.HostView();
+ std::size_t topk = this->Param().TopK();
+ auto const exp_gain = this->Param().ndcg_exp_gain;
+
+ common::ParallelFor(n_groups, ctx->Threads(), [&](auto g) {
+ auto g_labels = h_labels.Slice(linalg::Range(h_group_ptr[g], h_group_ptr[g + 1]));
+ auto sorted_idx = common::ArgSort(ctx, linalg::cbegin(g_labels),
+ linalg::cend(g_labels), std::greater<>{});
+
+ double idcg{0.0};
+ for (std::size_t i = 0; i < std::min(g_labels.Size(), topk); ++i) {
+ if (exp_gain) {
+ idcg += h_discounts[i] * CalcDCGGain(g_labels(sorted_idx[i]));
+ } else {
+ idcg += h_discounts[i] * g_labels(sorted_idx[i]);
+ }
+ }
+ h_inv_idcg(g) = CalcInvIDCG(idcg);
+ });
+}
+
+#if !defined(XGBOOST_USE_CUDA)
+void NDCGCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUSupport(); }
+#endif // !defined(XGBOOST_USE_CUDA)
+
+DMLC_REGISTER_PARAMETER(LambdaRankParam);
+
+void MAPCache::InitOnCPU(Context const* ctx, MetaInfo const& info) {
+ auto const& h_label = info.labels.HostView().Slice(linalg::All(), 0);
+ CheckMapLabels(h_label,
+ [ctx](auto beg, auto end, auto op) { return common::AllOf(ctx, beg, end, op); });
+}
+
+#if !defined(XGBOOST_USE_CUDA)
+void MAPCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUSupport(); }
+#endif // !defined(XGBOOST_USE_CUDA)
+
+std::string ParseMetricName(StringView name, StringView param, position_t* topn, bool* minus) {
std::string out_name;
if (!param.empty()) {
std::ostringstream os;
@@ -30,5 +149,18 @@ std::string MakeMetricName(StringView name, StringView param, std::uint32_t* top
}
return out_name;
}
-} // namespace ltr
-} // namespace xgboost
+
+std::string MakeMetricName(StringView name, position_t topn, bool minus) {
+ std::ostringstream ss;
+ if (topn == LambdaRankParam::NotSet()) {
+ ss << name;
+ } else {
+ ss << name << "@" << topn;
+ }
+ if (minus) {
+ ss << "-";
+ }
+ std::string out_name = ss.str();
+ return out_name;
+}
+} // namespace xgboost::ltr
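Note: `NDCGCache::InitOnCPU` above caches two quantities per query group: the position discounts 1 / log2(i + 2) and the inverse IDCG, where the gain is 2^rel - 1 when `ndcg_exp_gain` is enabled. A NumPy sketch of the same computation for a single group (illustrative only; `inv_idcg` is a hypothetical helper, not part of this patch):

    import numpy as np

    def inv_idcg(labels, topk=None, exp_gain=True):
        # Ideal ordering: labels sorted by relevance, truncated at top-k.
        k = labels.size if topk is None else min(topk, labels.size)
        ideal = np.sort(labels.astype(np.float64))[::-1][:k]
        gain = np.power(2.0, ideal) - 1.0 if exp_gain else ideal   # CalcDCGGain
        discount = 1.0 / np.log2(np.arange(k) + 2.0)               # CalcDCGDiscount
        idcg = float(np.sum(gain * discount))
        return 0.0 if idcg == 0.0 else 1.0 / idcg                  # CalcInvIDCG

    print(inv_idcg(np.array([3, 2, 0, 1]), topk=32))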
diff --git a/src/common/ranking_utils.cu b/src/common/ranking_utils.cu
new file mode 100644
index 000000000000..5ee1ed36c367
--- /dev/null
+++ b/src/common/ranking_utils.cu
@@ -0,0 +1,213 @@
+/**
+ * Copyright 2023 by XGBoost Contributors
+ */
+#include <thrust/iterator/counting_iterator.h>  // for make_counting_iterator
+#include <thrust/logical.h>                     // for none_of, all_of
+#include <thrust/pair.h>                        // for pair, make_pair
+#include <thrust/reduce.h>                      // for reduce
+#include <thrust/scan.h>                        // for inclusive_scan
+
+#include <cstddef>  // for size_t
+
+#include "algorithm.cuh" // for SegmentedArgSort
+#include "cuda_context.cuh" // for CUDAContext
+#include "device_helpers.cuh" // for MakeTransformIterator, LaunchN
+#include "optional_weight.h" // for MakeOptionalWeights, OptionalWeights
+#include "ranking_utils.cuh" // for ThreadsForMean
+#include "ranking_utils.h"
+#include "threading_utils.cuh" // for SegmentedTrapezoidThreads
+#include "xgboost/base.h" // for XGBOOST_DEVICE
+#include "xgboost/context.h" // for Context
+#include "xgboost/linalg.h" // for VectorView
+#include "xgboost/span.h" // for Span
+
+namespace xgboost::ltr {
+namespace cuda_impl {
+void CalcQueriesDCG(Context const* ctx, linalg::VectorView d_labels,
+ common::Span d_sorted_idx, bool exp_gain,
+ common::Span d_group_ptr, std::size_t k,
+ linalg::VectorView out_dcg) {
+ CHECK_EQ(d_group_ptr.size() - 1, out_dcg.Size());
+ using IdxGroup = thrust::pair;
+ auto group_it = dh::MakeTransformIterator(
+ thrust::make_counting_iterator(0ull), [=] XGBOOST_DEVICE(std::size_t idx) {
+ return thrust::make_pair(idx, dh::SegmentId(d_group_ptr, idx)); // NOLINT
+ });
+ auto value_it = dh::MakeTransformIterator(
+ group_it,
+ [exp_gain, d_labels, d_group_ptr, k,
+ d_sorted_idx] XGBOOST_DEVICE(IdxGroup const& l) -> double {
+ auto g_begin = d_group_ptr[l.second];
+ auto g_size = d_group_ptr[l.second + 1] - g_begin;
+
+ auto idx_in_group = l.first - g_begin;
+ if (idx_in_group >= k) {
+ return 0.0;
+ }
+ double gain{0.0};
+ auto g_sorted_idx = d_sorted_idx.subspan(g_begin, g_size);
+ auto g_labels = d_labels.Slice(linalg::Range(g_begin, g_begin + g_size));
+
+ if (exp_gain) {
+ gain = ltr::CalcDCGGain(g_labels(g_sorted_idx[idx_in_group]));
+ } else {
+ gain = g_labels(g_sorted_idx[idx_in_group]);
+ }
+ double discount = CalcDCGDiscount(idx_in_group);
+ return gain * discount;
+ });
+
+ CHECK(out_dcg.Contiguous());
+ std::size_t bytes;
+ cub::DeviceSegmentedReduce::Sum(nullptr, bytes, value_it, out_dcg.Values().data(),
+ d_group_ptr.size() - 1, d_group_ptr.data(),
+ d_group_ptr.data() + 1, ctx->CUDACtx()->Stream());
+ dh::TemporaryArray temp(bytes);
+ cub::DeviceSegmentedReduce::Sum(temp.data().get(), bytes, value_it, out_dcg.Values().data(),
+ d_group_ptr.size() - 1, d_group_ptr.data(),
+ d_group_ptr.data() + 1, ctx->CUDACtx()->Stream());
+}
+
+void CalcQueriesInvIDCG(Context const* ctx, linalg::VectorView d_labels,
+ common::Span d_group_ptr,
+ linalg::VectorView out_inv_IDCG, ltr::LambdaRankParam const& p) {
+ CHECK_GE(d_group_ptr.size(), 2ul);
+ size_t n_groups = d_group_ptr.size() - 1;
+ CHECK_EQ(out_inv_IDCG.Size(), n_groups);
+ dh::device_vector sorted_idx(d_labels.Size());
+ auto d_sorted_idx = dh::ToSpan(sorted_idx);
+ common::SegmentedArgSort(ctx, d_labels.Values(), d_group_ptr, d_sorted_idx);
+ CalcQueriesDCG(ctx, d_labels, d_sorted_idx, p.ndcg_exp_gain, d_group_ptr, p.TopK(), out_inv_IDCG);
+ dh::LaunchN(out_inv_IDCG.Size(), ctx->CUDACtx()->Stream(),
+ [out_inv_IDCG] XGBOOST_DEVICE(size_t idx) mutable {
+ double idcg = out_inv_IDCG(idx);
+ out_inv_IDCG(idx) = CalcInvIDCG(idcg);
+ });
+}
+} // namespace cuda_impl
+
+namespace {
+struct CheckNDCGOp {
+ CUDAContext const* cuctx;
+  template <typename It, typename Op>
+ bool operator()(It beg, It end, Op op) {
+ return thrust::none_of(cuctx->CTP(), beg, end, op);
+ }
+};
+struct CheckMAPOp {
+ CUDAContext const* cuctx;
+  template <typename It, typename Op>
+ bool operator()(It beg, It end, Op op) {
+ return thrust::all_of(cuctx->CTP(), beg, end, op);
+ }
+};
+
+struct ThreadGroupOp {
+ common::Span d_group_ptr;
+ std::size_t n_pairs;
+
+ common::Span out_thread_group_ptr;
+
+ XGBOOST_DEVICE void operator()(std::size_t i) {
+ out_thread_group_ptr[i + 1] =
+ cuda_impl::ThreadsForMean(d_group_ptr[i + 1] - d_group_ptr[i], n_pairs);
+ }
+};
+
+struct GroupSizeOp {
+ common::Span d_group_ptr;
+
+ XGBOOST_DEVICE auto operator()(std::size_t i) -> std::size_t {
+ return d_group_ptr[i + 1] - d_group_ptr[i];
+ }
+};
+
+struct WeightOp {
+ common::OptionalWeights d_weight;
+ XGBOOST_DEVICE auto operator()(std::size_t i) -> double { return d_weight[i]; }
+};
+} // anonymous namespace
+
+void RankingCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
+ CUDAContext const* cuctx = ctx->CUDACtx();
+
+ group_ptr_.SetDevice(ctx->gpu_id);
+ if (info.group_ptr_.empty()) {
+ group_ptr_.Resize(2, 0);
+ group_ptr_.HostVector()[1] = info.num_row_;
+ } else {
+ auto const& h_group_ptr = info.group_ptr_;
+ group_ptr_.Resize(h_group_ptr.size());
+ auto d_group_ptr = group_ptr_.DeviceSpan();
+ dh::safe_cuda(cudaMemcpyAsync(d_group_ptr.data(), h_group_ptr.data(), d_group_ptr.size_bytes(),
+ cudaMemcpyHostToDevice, cuctx->Stream()));
+ }
+
+ auto d_group_ptr = DataGroupPtr(ctx);
+ std::size_t n_groups = Groups();
+
+ auto it = dh::MakeTransformIterator(thrust::make_counting_iterator(0ul),
+ GroupSizeOp{d_group_ptr});
+ max_group_size_ =
+ thrust::reduce(cuctx->CTP(), it, it + n_groups, 0ul, thrust::maximum{});
+
+ threads_group_ptr_.SetDevice(ctx->gpu_id);
+ threads_group_ptr_.Resize(n_groups + 1, 0);
+ auto d_threads_group_ptr = threads_group_ptr_.DeviceSpan();
+ if (param_.HasTruncation()) {
+ n_cuda_threads_ =
+ common::SegmentedTrapezoidThreads(d_group_ptr, d_threads_group_ptr, Param().NumPair());
+ } else {
+ auto n_pairs = Param().NumPair();
+ dh::LaunchN(n_groups, cuctx->Stream(),
+ ThreadGroupOp{d_group_ptr, n_pairs, d_threads_group_ptr});
+ thrust::inclusive_scan(cuctx->CTP(), dh::tcbegin(d_threads_group_ptr),
+ dh::tcend(d_threads_group_ptr), dh::tbegin(d_threads_group_ptr));
+ n_cuda_threads_ = info.num_row_ * param_.NumPair();
+ }
+
+ sorted_idx_cache_.SetDevice(ctx->gpu_id);
+ sorted_idx_cache_.Resize(info.labels.Size(), 0);
+
+ auto weight = common::MakeOptionalWeights(ctx, info.weights_);
+ auto w_it =
+ dh::MakeTransformIterator(thrust::make_counting_iterator(0ul), WeightOp{weight});
+ weight_norm_ = static_cast(n_groups) / thrust::reduce(w_it, w_it + n_groups);
+
+ auto const d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
+ is_binary_ = IsBinaryRel(d_label, CheckMAPOp{ctx->CUDACtx()});
+}
+
+common::Span RankingCache::MakeRankOnCUDA(Context const* ctx,
+ common::Span predt) {
+ auto d_sorted_idx = sorted_idx_cache_.DeviceSpan();
+ auto d_group_ptr = DataGroupPtr(ctx);
+ common::SegmentedArgSort(ctx, predt, d_group_ptr, d_sorted_idx);
+ return d_sorted_idx;
+}
+
+void NDCGCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
+ CUDAContext const* cuctx = ctx->CUDACtx();
+ auto labels = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
+ CheckNDCGLabels(this->Param(), labels, CheckNDCGOp{cuctx});
+
+ auto d_group_ptr = this->DataGroupPtr(ctx);
+
+ std::size_t n_groups = d_group_ptr.size() - 1;
+ inv_idcg_ = linalg::Zeros(ctx, n_groups);
+ auto d_inv_idcg = inv_idcg_.View(ctx->gpu_id);
+ cuda_impl::CalcQueriesInvIDCG(ctx, labels, d_group_ptr, d_inv_idcg, this->Param());
+ CHECK_GE(this->Param().NumPair(), 1ul);
+
+ discounts_.SetDevice(ctx->gpu_id);
+ discounts_.Resize(MaxGroupSize());
+ auto d_discount = discounts_.DeviceSpan();
+ dh::LaunchN(MaxGroupSize(), cuctx->Stream(),
+ [=] XGBOOST_DEVICE(std::size_t i) { d_discount[i] = CalcDCGDiscount(i); });
+}
+
+void MAPCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
+ auto const d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
+ CheckMapLabels(d_label, CheckMAPOp{ctx->CUDACtx()});
+}
+} // namespace xgboost::ltr
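Note: `CalcQueriesDCG` above computes gain * discount per document and then reduces it per query group with `cub::DeviceSegmentedReduce::Sum`. The segmented part is equivalent to the following NumPy sketch (illustrative only; the toy `group_ptr` and labels are placeholders, and the within-group sort by prediction is assumed to have happened already):

    import numpy as np

    # Two query groups delimited by a CSR-style group pointer; relevance labels are
    # listed in the order the model ranked the documents within each group.
    group_ptr = np.array([0, 3, 7])
    ranked_labels = np.array([3, 0, 1, 2, 2, 0, 1], dtype=np.float64)

    gain = np.power(2.0, ranked_labels) - 1.0                 # exponential gain 2^rel - 1
    rank_in_group = np.concatenate(
        [np.arange(e - b) for b, e in zip(group_ptr[:-1], group_ptr[1:])]
    )
    discount = 1.0 / np.log2(rank_in_group + 2.0)             # 1 / log2(position + 2)
    dcg_per_group = np.add.reduceat(gain * discount, group_ptr[:-1])
    print(dcg_per_group)                                      # one DCG value per query group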
diff --git a/src/common/ranking_utils.cuh b/src/common/ranking_utils.cuh
new file mode 100644
index 000000000000..297f5157ecfb
--- /dev/null
+++ b/src/common/ranking_utils.cuh
@@ -0,0 +1,40 @@
+/**
+ * Copyright 2023 by XGBoost Contributors
+ */
+#ifndef XGBOOST_COMMON_RANKING_UTILS_CUH_
+#define XGBOOST_COMMON_RANKING_UTILS_CUH_
+
+#include <cstddef>  // for size_t
+
+#include "ranking_utils.h" // for LambdaRankParam
+#include "xgboost/base.h" // for bst_group_t, XGBOOST_DEVICE
+#include "xgboost/context.h" // for Context
+#include "xgboost/linalg.h" // for VectorView
+#include "xgboost/span.h" // for Span
+
+namespace xgboost {
+namespace ltr {
+namespace cuda_impl {
+void CalcQueriesDCG(Context const *ctx, linalg::VectorView d_labels,
+ common::Span d_sorted_idx, bool exp_gain,
+ common::Span d_group_ptr, std::size_t k,
+ linalg::VectorView out_dcg);
+
+void CalcQueriesInvIDCG(Context const *ctx, linalg::VectorView d_labels,
+ common::Span d_group_ptr,
+ linalg::VectorView out_inv_IDCG, ltr::LambdaRankParam const &p);
+
+// Functions for computing the number of CUDA threads assigned to a group, and for
+// recovering the number of pairs from that thread count.
+XGBOOST_DEVICE __forceinline__ std::size_t ThreadsForMean(std::size_t group_size,
+ std::size_t n_pairs) {
+ return group_size * n_pairs;
+}
+XGBOOST_DEVICE __forceinline__ std::size_t PairsForGroup(std::size_t n_threads,
+ std::size_t group_size) {
+ return n_threads / group_size;
+}
+} // namespace cuda_impl
+} // namespace ltr
+} // namespace xgboost
+#endif // XGBOOST_COMMON_RANKING_UTILS_CUH_
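Note: `ThreadsForMean` and `PairsForGroup` are inverses of each other for the mean pair method: every document in a group is assigned `n_pairs` threads, and dividing a group's thread count by its size recovers the pair count. A one-line check (Python, illustrative only):

    def threads_for_mean(group_size, n_pairs):
        return group_size * n_pairs

    def pairs_for_group(n_threads, group_size):
        return n_threads // group_size

    assert pairs_for_group(threads_for_mean(group_size=10, n_pairs=4), group_size=10) == 4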
diff --git a/src/common/ranking_utils.h b/src/common/ranking_utils.h
index 35ee36c2185d..b3259f3f154e 100644
--- a/src/common/ranking_utils.h
+++ b/src/common/ranking_utils.h
@@ -3,17 +3,434 @@
*/
#ifndef XGBOOST_COMMON_RANKING_UTILS_H_
#define XGBOOST_COMMON_RANKING_UTILS_H_
+#include <algorithm>  // for min
+#include <cmath>      // for log2, fabs, floor
+#include <cstddef>    // for size_t
+#include <cstdint>    // for uint32_t, uint8_t, int32_t
+#include <limits>     // for numeric_limits
+#include <string>     // for char_traits, string
+#include <vector>     // for vector
-#include <cstddef>  // std::size_t
-#include <cstdint>  // std::uint32_t
-#include <string>   // std::string
+#include "./math.h" // for CloseTo
+#include "dmlc/parameter.h" // for FieldEntry, DMLC_DECLARE_FIELD
+#include "error_msg.h" // for GroupWeight, GroupSize
+#include "xgboost/base.h" // for XGBOOST_DEVICE, bst_group_t
+#include "xgboost/context.h" // for Context
+#include "xgboost/data.h" // for MetaInfo
+#include "xgboost/host_device_vector.h" // for HostDeviceVector
+#include "xgboost/linalg.h" // for Vector, VectorView, Tensor
+#include "xgboost/logging.h" // for LogCheck_EQ, CHECK_EQ, CHECK
+#include "xgboost/parameter.h" // for XGBoostParameter
+#include "xgboost/span.h" // for Span
+#include "xgboost/string_view.h" // for StringView
-#include "xgboost/string_view.h" // StringView
+namespace xgboost::ltr {
+/**
+ * \brief Relevance degree
+ */
+using rel_degree_t = std::uint32_t; // NOLINT
+/**
+ * \brief top-k position
+ */
+using position_t = std::uint32_t; // NOLINT
+
+/**
+ * \brief Maximum relevance degree for NDCG
+ */
+constexpr std::size_t MaxRel() { return sizeof(rel_degree_t) * 8 - 1; }
+static_assert(MaxRel() == 31);
+
+XGBOOST_DEVICE inline double CalcDCGGain(rel_degree_t label) {
+ return static_cast((1u << label) - 1);
+}
+
+XGBOOST_DEVICE inline double CalcDCGDiscount(std::size_t idx) {
+ return 1.0 / std::log2(static_cast(idx) + 2.0);
+}
+
+XGBOOST_DEVICE inline double CalcInvIDCG(double idcg) {
+ auto inv_idcg = (idcg == 0.0 ? 0.0 : (1.0 / idcg)); // handle irrelevant document
+ return inv_idcg;
+}
+
+enum class PairMethod : std::int32_t {
+ kTopK = 0,
+ kMean = 1,
+};
+} // namespace xgboost::ltr
+
+DECLARE_FIELD_ENUM_CLASS(xgboost::ltr::PairMethod);
+
+namespace xgboost::ltr {
+struct LambdaRankParam : public XGBoostParameter<LambdaRankParam> {
+ private:
+ static constexpr position_t DefaultK() { return 32; }
+  static constexpr position_t DefaultSamplePairs() { return 1; }  // FIXME: choose a better default
+
+ protected:
+ // pairs
+ // should be accessed by getter for auto configuration.
+ // nolint so that we can keep the string name.
+ PairMethod lambdarank_pair_method; // NOLINT
+ std::size_t lambdarank_num_pair_per_sample; // NOLINT
+
+ public:
+  static constexpr position_t NotSet() { return std::numeric_limits<position_t>::max(); }
+
+ // unbiased
+ bool lambdarank_unbiased;
+ double lambdarank_bias_norm;
+ // ndcg
+ bool ndcg_exp_gain;
+
+ bool operator==(LambdaRankParam const& that) const {
+ return lambdarank_pair_method == that.lambdarank_pair_method &&
+ lambdarank_num_pair_per_sample == that.lambdarank_num_pair_per_sample &&
+ lambdarank_unbiased == that.lambdarank_unbiased &&
+ lambdarank_bias_norm == that.lambdarank_bias_norm && ndcg_exp_gain == that.ndcg_exp_gain;
+ }
+ bool operator!=(LambdaRankParam const& that) const { return !(*this == that); }
+
+ [[nodiscard]] double Regularizer() const { return 1.0 / (1.0 + this->lambdarank_bias_norm); }
+
+ /**
+ * \brief Get number of pairs for each sample
+ */
+ [[nodiscard]] position_t NumPair() const {
+ if (lambdarank_num_pair_per_sample == NotSet()) {
+ switch (lambdarank_pair_method) {
+ case PairMethod::kMean:
+ return DefaultSamplePairs();
+ case PairMethod::kTopK:
+ return DefaultK();
+ default:
+ LOG(FATAL) << "Unreachable.";
+ }
+ } else {
+ return lambdarank_num_pair_per_sample;
+ }
+ LOG(FATAL) << "Unreachable.";
+ return 0;
+ }
+
+ [[nodiscard]] bool HasTruncation() const { return lambdarank_pair_method == PairMethod::kTopK; }
+
+ // Used for evaluation metric and cache initialization, iterate through top-k or the whole list
+ [[nodiscard]] auto TopK() const {
+ if (HasTruncation()) {
+ return NumPair();
+ } else {
+ return NotSet();
+ }
+ }
+
+ DMLC_DECLARE_PARAMETER(LambdaRankParam) {
+ DMLC_DECLARE_FIELD(lambdarank_pair_method)
+ .set_default(PairMethod::kMean)
+ .add_enum("mean", PairMethod::kMean)
+ .add_enum("topk", PairMethod::kTopK)
+ .describe("Method for constructing pairs.");
+ DMLC_DECLARE_FIELD(lambdarank_num_pair_per_sample)
+ .set_default(NotSet())
+ .set_lower_bound(1)
+ .describe("Number of pairs for each sample in the list.");
+ DMLC_DECLARE_FIELD(lambdarank_unbiased)
+ .set_default(false)
+ .describe("Unbiased lambda mart. Use IPW to debias click position");
+ DMLC_DECLARE_FIELD(lambdarank_bias_norm)
+ .set_default(2.0)
+ .set_lower_bound(0.0)
+ .describe("Lp regularization for unbiased lambdarank.");
+ DMLC_DECLARE_FIELD(ndcg_exp_gain)
+ .set_default(true)
+ .describe("When set to true, the label gain is 2^rel - 1, otherwise it's rel.");
+ }
+};
+
+/**
+ * \brief Common cached items for ranking tasks.
+ */
+class RankingCache {
+ private:
+ void InitOnCPU(Context const* ctx, MetaInfo const& info);
+ void InitOnCUDA(Context const* ctx, MetaInfo const& info);
+ // Cached parameter
+ LambdaRankParam param_;
+ // offset to data groups.
+ HostDeviceVector group_ptr_;
+ // store the sorted index of prediction.
+ HostDeviceVector sorted_idx_cache_;
+ // Maximum size of group
+ std::size_t max_group_size_{0};
+ // Normalization for weight
+ double weight_norm_{1.0};
+ // Whether label is binary
+ bool is_binary_{false};
+ /**
+ * CUDA cache
+ */
+ // offset to threads assigned to each group for gradient calculation
+ HostDeviceVector threads_group_ptr_;
+ // Sorted index of label for finding buckets.
+ HostDeviceVector y_sorted_idx_cache_;
+ // Cached labels sorted by the model
+ HostDeviceVector y_ranked_by_model_;
+ // store rounding factor for objective for each group
+ linalg::Vector roundings_;
+ // rounding factor for cost
+ HostDeviceVector cost_rounding_;
+ // temporary storage for creating rounding factors. Stored as byte to avoid having cuda
+ // data structure in here.
+ HostDeviceVector max_lambdas_;
+ // total number of cuda threads used for gradient calculation
+ std::size_t n_cuda_threads_{0};
+
+ // Create model rank list on GPU
+ common::Span MakeRankOnCUDA(Context const* ctx,
+ common::Span predt);
+ // Create model rank list on CPU
+ common::Span MakeRankOnCPU(Context const* ctx,
+ common::Span predt);
-namespace xgboost {
-namespace ltr {
+ protected:
+ [[nodiscard]] std::size_t MaxGroupSize() const { return max_group_size_; }
+
+ public:
+ RankingCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p) : param_{p} {
+ CHECK(param_.GetInitialised());
+ if (!info.group_ptr_.empty()) {
+ CHECK_EQ(info.group_ptr_.back(), info.labels.Size())
+ << error::GroupSize() << "the size of label.";
+ }
+ if (ctx->IsCPU()) {
+ this->InitOnCPU(ctx, info);
+ } else {
+ this->InitOnCUDA(ctx, info);
+ }
+ if (!info.weights_.Empty()) {
+ CHECK_EQ(Groups(), info.weights_.Size()) << error::GroupWeight();
+ }
+ }
+ [[nodiscard]] std::size_t MaxPositionSize() const {
+ // Use truncation level as bound.
+ if (param_.HasTruncation()) {
+ return param_.NumPair();
+ }
+ // Hardcoded maximum size of positions to track. We don't need too many of them as the
+ // bias decreases exponentially.
+    return std::min(max_group_size_, static_cast<std::size_t>(32));
+ }
+  // Constructed as [0, n_samples] if group ptr is not supplied by the user.
+ common::Span DataGroupPtr(Context const* ctx) const {
+ group_ptr_.SetDevice(ctx->gpu_id);
+ return ctx->IsCPU() ? group_ptr_.ConstHostSpan() : group_ptr_.ConstDeviceSpan();
+ }
+
+ [[nodiscard]] auto const& Param() const { return param_; }
+ [[nodiscard]] std::size_t Groups() const { return group_ptr_.Size() - 1; }
+ [[nodiscard]] double WeightNorm() const { return weight_norm_; }
+ [[nodiscard]] bool IsBinary() const { return is_binary_; }
+
+ // Create a rank list by model prediction
+ common::Span SortedIdx(Context const* ctx, common::Span predt) {
+ if (sorted_idx_cache_.Empty()) {
+ sorted_idx_cache_.SetDevice(ctx->gpu_id);
+ sorted_idx_cache_.Resize(predt.size());
+ }
+ if (ctx->IsCPU()) {
+ return this->MakeRankOnCPU(ctx, predt);
+ } else {
+ return this->MakeRankOnCUDA(ctx, predt);
+ }
+ }
+  // The function simply returns an uninitialized buffer as this is only used by the
+  // objective for creating pairs.
+ common::Span SortedIdxY(Context const* ctx, std::size_t n_samples) {
+ CHECK(ctx->IsCUDA());
+ if (y_sorted_idx_cache_.Empty()) {
+ y_sorted_idx_cache_.SetDevice(ctx->gpu_id);
+ y_sorted_idx_cache_.Resize(n_samples);
+ }
+ return y_sorted_idx_cache_.DeviceSpan();
+ }
+ common::Span RankedY(Context const* ctx, std::size_t n_samples) {
+ CHECK(ctx->IsCUDA());
+ if (y_ranked_by_model_.Empty()) {
+ y_ranked_by_model_.SetDevice(ctx->gpu_id);
+ y_ranked_by_model_.Resize(n_samples);
+ }
+ return y_ranked_by_model_.DeviceSpan();
+ }
+
+  // CUDA cache getters. The cache is shared between the metric and the objective; some of
+  // these fields are lazily initialized to avoid unnecessary allocation.
+ [[nodiscard]] common::Span CUDAThreadsGroupPtr() const {
+ CHECK(!threads_group_ptr_.Empty());
+ return threads_group_ptr_.ConstDeviceSpan();
+ }
+ [[nodiscard]] std::size_t CUDAThreads() const { return n_cuda_threads_; }
+
+ linalg::VectorView CUDARounding(Context const* ctx) {
+ if (roundings_.Size() == 0) {
+ roundings_.SetDevice(ctx->gpu_id);
+ roundings_.Reshape(Groups());
+ }
+ return roundings_.View(ctx->gpu_id);
+ }
+ common::Span CUDACostRounding(Context const* ctx) {
+ if (cost_rounding_.Size() == 0) {
+ cost_rounding_.SetDevice(ctx->gpu_id);
+ cost_rounding_.Resize(1);
+ }
+ return cost_rounding_.DeviceSpan();
+ }
+  template <typename Type>
+ common::Span MaxLambdas(Context const* ctx, std::size_t n) {
+ max_lambdas_.SetDevice(ctx->gpu_id);
+ std::size_t bytes = n * sizeof(Type);
+ if (bytes != max_lambdas_.Size()) {
+ max_lambdas_.Resize(bytes);
+ }
+    return common::Span<Type>{reinterpret_cast<Type *>(max_lambdas_.DevicePointer()), n};
+ }
+};
+
+class NDCGCache : public RankingCache {
+ // NDCG discount
+ HostDeviceVector discounts_;
+ // 1.0 / IDCG
+ linalg::Vector inv_idcg_;
+ /**
+ * CUDA cache
+ */
+ // store the intermediate DCG calculation result for metric
+ linalg::Vector dcg_;
+
+ public:
+ void InitOnCPU(Context const* ctx, MetaInfo const& info);
+ void InitOnCUDA(Context const* ctx, MetaInfo const& info);
+
+ public:
+ NDCGCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p)
+ : RankingCache{ctx, info, p} {
+ if (ctx->IsCPU()) {
+ this->InitOnCPU(ctx, info);
+ } else {
+ this->InitOnCUDA(ctx, info);
+ }
+ }
+
+ linalg::VectorView InvIDCG(Context const* ctx) const {
+ return inv_idcg_.View(ctx->gpu_id);
+ }
+ common::Span Discount(Context const* ctx) const {
+ return ctx->IsCPU() ? discounts_.ConstHostSpan() : discounts_.ConstDeviceSpan();
+ }
+ linalg::VectorView Dcg(Context const* ctx) {
+ if (dcg_.Size() == 0) {
+ dcg_.SetDevice(ctx->gpu_id);
+ dcg_.Reshape(this->Groups());
+ }
+ return dcg_.View(ctx->gpu_id);
+ }
+};
+
+/**
+ * \brief Validate label for NDCG
+ *
+ * \tparam NoneOf Implementation of std::none_of. Specified as a parameter to reuse the
+ * check for both CPU and GPU.
+ */
+template <typename NoneOf>
+void CheckNDCGLabels(ltr::LambdaRankParam const& p, linalg::VectorView labels,
+ NoneOf none_of) {
+ auto d_labels = labels.Values();
+ if (p.ndcg_exp_gain) {
+ auto label_is_integer =
+ none_of(d_labels.data(), d_labels.data() + d_labels.size(), [] XGBOOST_DEVICE(float v) {
+ auto l = std::floor(v);
+ return std::fabs(l - v) > kRtEps || v < 0.0f;
+ });
+ CHECK(label_is_integer)
+ << "When using relevance degree as target, label must be either 0 or positive integer.";
+ }
+
+ if (p.ndcg_exp_gain) {
+ auto label_is_valid = none_of(d_labels.data(), d_labels.data() + d_labels.size(),
+ [] XGBOOST_DEVICE(ltr::rel_degree_t v) { return v > MaxRel(); });
+    CHECK(label_is_valid) << "Relevance degree must be less than or equal to " << MaxRel()
+ << " when the exponential NDCG gain function is used. "
+ << "Set `ndcg_exp_gain` to false to use custom DCG gain.";
+ }
+}
+
+template <typename AllOf>
+bool IsBinaryRel(linalg::VectorView label, AllOf all_of) {
+ auto s_label = label.Values();
+ return all_of(s_label.data(), s_label.data() + s_label.size(), [] XGBOOST_DEVICE(float y) {
+ return std::abs(y - 1.0f) < kRtEps || std::abs(y - 0.0f) < kRtEps;
+ });
+}
/**
- * \brief Construct name for ranking metric given parameters.
+ * \brief Validate label for MAP
+ *
+ * \tparam AllOf Implementation of std::all_of. Specified as a parameter to reuse the check for
+ * both CPU and GPU.
+ */
+template <typename AllOf>
+void CheckMapLabels(linalg::VectorView label, AllOf all_of) {
+ auto s_label = label.Values();
+ auto is_binary = IsBinaryRel(label, all_of);
+ CHECK(is_binary) << "MAP can only be used with binary labels.";
+}
+
+class MAPCache : public RankingCache {
+ // Total number of relevant documents for each group
+ HostDeviceVector n_rel_;
+ // \sum l_k/k
+ HostDeviceVector acc_;
+ HostDeviceVector map_;
+ // Number of samples in this dataset.
+ std::size_t n_samples_{0};
+
+ void InitOnCPU(Context const* ctx, MetaInfo const& info);
+ void InitOnCUDA(Context const* ctx, MetaInfo const& info);
+
+ public:
+ MAPCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p)
+      : RankingCache{ctx, info, p}, n_samples_{static_cast<std::size_t>(info.num_row_)} {
+ if (ctx->IsCPU()) {
+ this->InitOnCPU(ctx, info);
+ } else {
+ this->InitOnCUDA(ctx, info);
+ }
+ }
+
+ common::Span NumRelevant(Context const* ctx) {
+ if (n_rel_.Empty()) {
+ n_rel_.SetDevice(ctx->gpu_id);
+ n_rel_.Resize(n_samples_);
+ }
+ return ctx->IsCPU() ? n_rel_.HostSpan() : n_rel_.DeviceSpan();
+ }
+ common::Span Acc(Context const* ctx) {
+ if (acc_.Empty()) {
+ acc_.SetDevice(ctx->gpu_id);
+ acc_.Resize(n_samples_);
+ }
+ return ctx->IsCPU() ? acc_.HostSpan() : acc_.DeviceSpan();
+ }
+ common::Span Map(Context const* ctx) {
+ if (map_.Empty()) {
+ map_.SetDevice(ctx->gpu_id);
+ map_.Resize(this->Groups());
+ }
+ return ctx->IsCPU() ? map_.HostSpan() : map_.DeviceSpan();
+ }
+};
+
+/**
+ * \brief Parse name for ranking metric given parameters.
*
* \param [in] name Null terminated string for metric name
* \param [in] param Null terminated string for parameter like the `3-` in `ndcg@3-`.
@@ -23,7 +440,11 @@ namespace ltr {
*
* \return The name of the metric.
*/
-std::string MakeMetricName(StringView name, StringView param, std::uint32_t* topn, bool* minus);
-} // namespace ltr
-} // namespace xgboost
+std::string ParseMetricName(StringView name, StringView param, position_t* topn, bool* minus);
+
+/**
+ * \brief Build the metric name string from the top-n level and the minus flag.
+ */
+std::string MakeMetricName(StringView name, position_t topn, bool minus);
+} // namespace xgboost::ltr
#endif // XGBOOST_COMMON_RANKING_UTILS_H_
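Note: `LambdaRankParam` above defines the user-facing knobs (`lambdarank_pair_method`, `lambdarank_num_pair_per_sample`, `lambdarank_unbiased`, `lambdarank_bias_norm`, `ndcg_exp_gain`) behind the `rank:ndcg`, `rank:map` and `rank:pairwise` objectives. A minimal usage sketch with the Python package (illustrative only; the synthetic data, query ids and chosen parameter values are placeholders):

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    X = rng.normal(size=(100, 10))
    y = rng.integers(0, 4, size=100)        # relevance degrees in [0, 3]
    qid = np.repeat(np.arange(10), 10)      # 10 queries with 10 documents each

    ranker = xgb.XGBRanker(
        objective="rank:ndcg",
        lambdarank_pair_method="topk",
        lambdarank_num_pair_per_sample=8,
        ndcg_exp_gain=True,
    )
    ranker.fit(X, y, qid=qid)
    scores = ranker.predict(X)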
diff --git a/src/common/threading_utils.cuh b/src/common/threading_utils.cuh
index c21d312d2e03..db5fe82f94ac 100644
--- a/src/common/threading_utils.cuh
+++ b/src/common/threading_utils.cuh
@@ -43,36 +43,33 @@ XGBOOST_DEVICE inline std::size_t DiscreteTrapezoidArea(std::size_t n, std::size
* with h <= n
*/
template
-inline size_t
-SegmentedTrapezoidThreads(xgboost::common::Span group_ptr,
- xgboost::common::Span out_group_threads_ptr,
- size_t h) {
+std::size_t SegmentedTrapezoidThreads(xgboost::common::Span group_ptr,
+ xgboost::common::Span out_group_threads_ptr,
+ std::size_t h) {
CHECK_GE(group_ptr.size(), 1);
CHECK_EQ(group_ptr.size(), out_group_threads_ptr.size());
- dh::LaunchN(
- group_ptr.size(), [=] XGBOOST_DEVICE(size_t idx) {
- if (idx == 0) {
- out_group_threads_ptr[0] = 0;
- return;
- }
+ dh::LaunchN(group_ptr.size(), [=] XGBOOST_DEVICE(std::size_t idx) {
+ if (idx == 0) {
+ out_group_threads_ptr[0] = 0;
+ return;
+ }
- size_t cnt = static_cast(group_ptr[idx] - group_ptr[idx - 1]);
- out_group_threads_ptr[idx] = DiscreteTrapezoidArea(cnt, h);
- });
+ std::size_t cnt = static_cast(group_ptr[idx] - group_ptr[idx - 1]);
+ out_group_threads_ptr[idx] = DiscreteTrapezoidArea(cnt, h);
+ });
dh::InclusiveSum(out_group_threads_ptr.data(), out_group_threads_ptr.data(),
out_group_threads_ptr.size());
- size_t total = 0;
- dh::safe_cuda(cudaMemcpy(
- &total, out_group_threads_ptr.data() + out_group_threads_ptr.size() - 1,
- sizeof(total), cudaMemcpyDeviceToHost));
+ std::size_t total = 0;
+ dh::safe_cuda(cudaMemcpy(&total, out_group_threads_ptr.data() + out_group_threads_ptr.size() - 1,
+ sizeof(total), cudaMemcpyDeviceToHost));
return total;
}
/**
* Called inside kernel to obtain coordinate from trapezoid grid.
*/
-XGBOOST_DEVICE inline void UnravelTrapeziodIdx(size_t i_idx, size_t n,
- size_t *out_i, size_t *out_j) {
+XGBOOST_DEVICE inline void UnravelTrapeziodIdx(std::size_t i_idx, std::size_t n, std::size_t *out_i,
+ std::size_t *out_j) {
auto &i = *out_i;
auto &j = *out_j;
double idx = static_cast(i_idx);
diff --git a/src/common/threading_utils.h b/src/common/threading_utils.h
index a52695e02590..d80008cc0809 100644
--- a/src/common/threading_utils.h
+++ b/src/common/threading_utils.h
@@ -8,9 +8,11 @@
#include
#include
-#include <cstdint>       // std::int32_t
+#include <cstdint>       // for int32_t
+#include <cstdlib>       // for malloc, free
#include
-#include <type_traits>   // std::is_signed
+#include <new>           // for bad_alloc
+#include <type_traits>   // for is_signed
#include
#include "xgboost/logging.h"
@@ -266,7 +268,7 @@ class MemStackAllocator {
if (MaxStackSize >= required_size_) {
ptr_ = stack_mem_;
} else {
-      ptr_ = reinterpret_cast<T*>(malloc(required_size_ * sizeof(T)));
+      ptr_ = reinterpret_cast<T*>(std::malloc(required_size_ * sizeof(T)));
}
if (!ptr_) {
throw std::bad_alloc{};
@@ -278,7 +280,7 @@ class MemStackAllocator {
~MemStackAllocator() {
if (required_size_ > MaxStackSize) {
- free(ptr_);
+ std::free(ptr_);
}
}
T& operator[](size_t i) { return ptr_[i]; }
diff --git a/src/data/data.cc b/src/data/data.cc
index d24048a2ab23..238aaefd47ce 100644
--- a/src/data/data.cc
+++ b/src/data/data.cc
@@ -7,11 +7,13 @@
#include
#include
+#include
#include
#include "../collective/communicator-inl.h"
#include "../common/algorithm.h" // StableSort
#include "../common/api_entry.h" // XGBAPIThreadLocalEntry
+#include "../common/error_msg.h" // for GroupWeight, GroupSize
#include "../common/group_data.h"
#include "../common/io.h"
#include "../common/linalg_op.h"
@@ -32,6 +34,7 @@
#include "xgboost/context.h"
#include "xgboost/host_device_vector.h"
#include "xgboost/learner.h"
+#include "xgboost/linalg.h" // Vector
#include "xgboost/logging.h"
#include "xgboost/string_view.h"
#include "xgboost/version_config.h"
@@ -485,7 +488,7 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) {
}
// uint info
if (key == "group") {
- linalg::Tensor t;
+ linalg::Vector t;
CopyTensorInfoImpl(ctx, arr, &t);
auto const& h_groups = t.Data()->HostVector();
group_ptr_.clear();
@@ -510,6 +513,7 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) {
data::ValidateQueryGroup(group_ptr_);
return;
}
+
// float info
linalg::Tensor t;
CopyTensorInfoImpl<1>(ctx, arr, &t);
@@ -700,58 +704,63 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col
}
}
+namespace {
+template <typename T>
+void CheckDevice(std::int32_t device, HostDeviceVector<T> const& v) {
+ CHECK(v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device)
+ << "Data is resided on a different device than `gpu_id`. "
+ << "Device that data is on: " << v.DeviceIdx() << ", "
+ << "`gpu_id` for XGBoost: " << device;
+}
+template <typename T, std::int32_t D>
+void CheckDevice(std::int32_t device, linalg::Tensor<T, D> const& v) {
+ CheckDevice(device, *v.Data());
+}
+} // anonymous namespace
+
void MetaInfo::Validate(std::int32_t device) const {
if (group_ptr_.size() != 0 && weights_.Size() != 0) {
- CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)
- << "Size of weights must equal to number of groups when ranking "
- "group is used.";
+ CHECK_EQ(group_ptr_.size(), weights_.Size() + 1) << error::GroupWeight();
return;
}
if (group_ptr_.size() != 0) {
CHECK_EQ(group_ptr_.back(), num_row_)
- << "Invalid group structure. Number of rows obtained from groups "
- "doesn't equal to actual number of rows given by data.";
+ << error::GroupSize() << "the actual number of rows given by data.";
}
- auto check_device = [device](HostDeviceVector const& v) {
- CHECK(v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device)
- << "Data is resided on a different device than `gpu_id`. "
- << "Device that data is on: " << v.DeviceIdx() << ", "
- << "`gpu_id` for XGBoost: " << device;
- };
if (weights_.Size() != 0) {
CHECK_EQ(weights_.Size(), num_row_)
<< "Size of weights must equal to number of rows.";
- check_device(weights_);
+ CheckDevice(device, weights_);
return;
}
if (labels.Size() != 0) {
CHECK_EQ(labels.Shape(0), num_row_) << "Size of labels must equal to number of rows.";
- check_device(*labels.Data());
+ CheckDevice(device, labels);
return;
}
if (labels_lower_bound_.Size() != 0) {
CHECK_EQ(labels_lower_bound_.Size(), num_row_)
<< "Size of label_lower_bound must equal to number of rows.";
- check_device(labels_lower_bound_);
+ CheckDevice(device, labels_lower_bound_);
return;
}
if (feature_weights.Size() != 0) {
CHECK_EQ(feature_weights.Size(), num_col_)
<< "Size of feature_weights must equal to number of columns.";
- check_device(feature_weights);
+ CheckDevice(device, feature_weights);
}
if (labels_upper_bound_.Size() != 0) {
CHECK_EQ(labels_upper_bound_.Size(), num_row_)
<< "Size of label_upper_bound must equal to number of rows.";
- check_device(labels_upper_bound_);
+ CheckDevice(device, labels_upper_bound_);
return;
}
CHECK_LE(num_nonzero_, num_col_ * num_row_);
if (base_margin_.Size() != 0) {
CHECK_EQ(base_margin_.Size() % num_row_, 0)
<< "Size of base margin must be a multiple of number of rows.";
- check_device(*base_margin_.Data());
+ CheckDevice(device, base_margin_);
}
}
@@ -1025,6 +1034,8 @@ SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const {
bool SparsePage::IsIndicesSorted(int32_t n_threads) const {
auto& h_offset = this->offset.HostVector();
auto& h_data = this->data.HostVector();
+ n_threads = std::max(std::min(static_cast(n_threads), this->Size()),
+ static_cast(1));
std::vector is_sorted_tloc(n_threads, 0);
common::ParallelFor(this->Size(), n_threads, [&](auto i) {
auto beg = h_offset[i];
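Note: the `group` meta info handled in `MetaInfo::SetInfoFromHost` above is a list of per-query group sizes; it is stored internally as the cumulative `group_ptr_` offsets that `MetaInfo::Validate` later checks against the number of rows. A small sketch of that conversion (NumPy, illustrative only):

    import numpy as np

    group_sizes = np.array([3, 5, 2], dtype=np.uint32)   # documents per query
    group_ptr = np.zeros(group_sizes.size + 1, dtype=np.uint64)
    group_ptr[1:] = np.cumsum(group_sizes)               # -> [0, 3, 8, 10]

    num_row = 10
    assert group_ptr[-1] == num_row, (
        "Invalid query group structure. The number of rows obtained from group "
        "doesn't equal the actual number of rows given by data."
    )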
diff --git a/src/learner.cc b/src/learner.cc
index 0e47c694cc92..27f8973a6c4b 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -113,7 +113,7 @@ struct LearnerModelParamLegacy : public dmlc::Parameter
}
// Skip other legacy fields.
- Json ToJson() const {
+ [[nodiscard]] Json ToJson() const {
Json obj{Object{}};
char floats[NumericLimits::kToCharsSize];
auto ret = to_chars(floats, floats + NumericLimits::kToCharsSize, base_score);
@@ -163,7 +163,7 @@ struct LearnerModelParamLegacy : public dmlc::Parameter
from_chars(str.c_str(), str.c_str() + str.size(), base_score);
}
- LearnerModelParamLegacy ByteSwap() const {
+ [[nodiscard]] LearnerModelParamLegacy ByteSwap() const {
LearnerModelParamLegacy x = *this;
dmlc::ByteSwap(&x.base_score, sizeof(x.base_score), 1);
dmlc::ByteSwap(&x.num_feature, sizeof(x.num_feature), 1);
diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu
index e06be9730e8f..728d8c230dd3 100644
--- a/src/metric/elementwise_metric.cu
+++ b/src/metric/elementwise_metric.cu
@@ -485,9 +485,13 @@ class QuantileError : public MetricNoCache {
const char* Name() const override { return "quantile"; }
void LoadConfig(Json const& in) override {
- auto const& name = get(in["name"]);
- CHECK_EQ(name, "quantile");
- FromJson(in["quantile_loss_param"], ¶m_);
+ auto const& obj = get