From 6d75fe95fb9ed66b7fde2bf5763d82895050c91c Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 8 Feb 2023 04:02:44 +0800 Subject: [PATCH] Reimplement lambdamart ndcg. * Simplify the implementation for both CPU and GPU. Fix JSON IO. Check labels. Put idx into cache. Optimize. File tag. Weights. Trivial tests. Compatibility. Lint. Fix swap. Device weight. tidy. Easier to read R failure. enum. Fix global configuration. Tidy. msvc omp. dask. Remove ndcg specific parameter. Drop label type for smaller PR. Fix rebase. Fixes. Don't mess with includes. Fixes. Format. Use omp util. Restore some old code. Revert. Port changes from the work on quantile loss. python binding. param. Cleanup. conditional parallel. types. Move doc. fix. need metric rewrite. rename ctx. extract. Work on metric. Metric Init estimation. extract tests, compute ties. cleanup. notes. extract optional weights. init. cleanup. old metric format. note. ndcg cache. nested. debug. fix. log2. Begin CUDA work. temp. Extract sort and latest cuda. truncation. dcg. dispatch. try different gain type. start looking into ub. note. consider writing a doc. check exp gain. Reimplement lambdamart ndcg. * Simplify the implementation for both CPU and GPU. Fix JSON IO. Check labels. Put idx into cache. Optimize. File tag. Weights. Trivial tests. Compatibility. Lint. Fix swap. Device weight. tidy. Easier to read R failure. enum. Fix global configuration. Tidy. msvc omp. dask. Remove ndcg specific parameter. Drop label type for smaller PR. Fix rebase. Fixes. Don't mess with includes. Fixes. Format. Use omp util. Restore some old code. Revert. Port changes from the work on quantile loss. python binding. param. Cleanup. conditional parallel. types. Move doc. fix. need metric rewrite. rename ctx. extract. Work on metric. Metric Init estimation. extract tests, compute ties. cleanup. notes. extract optional weights. init. cleanup. old metric format. note. ndcg cache. nested. debug. fix. log2. Begin CUDA work. temp. Extract sort and latest cuda. truncation. dcg. dispatch. try different gain type. start looking into ub. note. consider writing a doc. check exp gain. Start looking into unbiased. lambda. Extract the ndcg cache. header. cleanup namespace. small check. namespace. init with param. gain. extract. groups. Cleanup. disable. debug. remove. Revert "remove." This reverts commit ea025f9e8085f5624db8bffbee801dbcd60f3ff5. sigmoid. cleanup. metric name. check scores. note. check map. extract utilities. avoid inline. fix. header. extract more. note. note. note. start working on map. fix. continue map. map. matrix. Remove map. note. format. move check. cleanup. use cached discount, use double. cleanup. Add position to the Python interface. pass it into lambda. Full ratio. rank. comment. some work on GPU. compile. move cache initialization. descending. Fix arg sort. basic ndcg score. metric weight. config. extract. pass position again. Define a metric decorator. position. decorate metric.. return. note. irrelevant docs. fix weights. header. Share the bias. Use position check info. use cache for param. note. prepare to work on deterministic gpu. rounding. Extract op. cleanup. Use it. check label. ditch launchn. rounding. Move rounding into cache. fix check label. GPU fixes. Irrelevant doc. try to avoid inf. mad. Work on metric cache. Cleanup sort. use cache. cache others. revert. add test for metric. fixes. msg. note. remove reduce by key. comments. check position. stream. min. small cleanup. use atomic for now. fill. no inline. norm. remove op. start gpu. cleanup. use gpu for update. segmented reduce. revert. comments. comments. fix. comments. fix bounds. comments. cache. pointer. fixes. no spark. revert. Cleanup. cleanup. work on gain type. fix. notes. make metric name. remove. revert. revert. comment. revert. Move back into rank metric. Set name in objective. fix. Don't configure. note. merge tests. accept empty group. fixes. float. revert and fix. not mutable. prototype for cache. extract. convert to DMatrix. cache. Extract the cache. Port changes. fix & cleanup. cleanup. cleanup. Rename. restore. remove. header. revert. rename. rename. doc. cleanup. doc. cleanup. tests. tests. split up. jvm parameters. doc. Fix. Use cache in cox. Revert "Use cache in cox." This reverts commit e1cec376eab37c22d93180ea4ebecc828af9ca2e. Remove pairwise. iwyu. rename. Move. Merge. ranking utils. Fixes. rename. Comments. todos. Small cleanup. doc. Start working on demo. move some code here. rename. Update doc. Update doc. Work on demo. work on demo. demo. Demo. Specify the max rel degree. remove position. Fix. Work on demo. demo. Using only one fold. cache. demo. schema. comments. Lint. fix test. automake. macos. schema. test. schema. lint. fix tests. Implement MAP and pair sampling. revert sorting. Work on ranknet. remove. Don't upgrade cost if larger than. Extract GPU make pairs. error message. Remove. Cleanup some gpu tests. Move. Move NDCG test. fix weights. Move rest of the tests. Remove. Work on tests. fixes. Cleanup. header. cleanup. Update document. update document. fix build. cpplint. rename. Fixes and cleanup. Cleanup tests. lint. fix tests. debug macos non-openmp checks. macos. fix ndcg test. Ensure number of threads is smaller than the number of inputs. fix. Debug macos. fixes. Add weight normalization. Note on reproducible result. Don't normalize if it's binary. old ctk. Use old objective. Update doc. Convert pyspark tests. black. Fix rebase. Fix rebase. Start looking into CV. Hacky score function. extract parsing. Cleanup and tests. Lint & note. test check. Update document. Update tests & doc. Support custom metric as well. c++-17. cleanup old metrics. rename. Fixes. Fix cxx test. test cudf. start converting tests. pylint. fix data load. Cleanup the tests. Parameter tests. isort. Fix test. Specify src path for isort. 17 goodies. --- R-package/src/Makevars.in | 2 +- R-package/src/Makevars.win | 2 +- demo/guide-python/learning_to_rank.py | 462 +++++++++ demo/guide-python/quantile_regression.py | 2 + doc/contrib/coding_guide.rst | 4 +- doc/model.schema | 18 +- doc/parameter.rst | 52 +- doc/tutorials/dask.rst | 1 + doc/tutorials/index.rst | 1 + doc/tutorials/learning_to_rank.rst | 177 ++++ include/xgboost/cache.h | 21 + include/xgboost/data.h | 11 +- include/xgboost/objective.h | 1 + .../spark/GpuXGBoostRegressorSuite.scala | 2 +- .../spark/params/LearningTaskParams.scala | 2 +- .../scala/spark/XGBoostGeneralSuite.scala | 6 +- .../scala/spark/XGBoostRegressorSuite.scala | 2 +- python-package/xgboost/core.py | 51 +- python-package/xgboost/sklearn.py | 8 +- python-package/xgboost/testing/__init__.py | 9 +- python-package/xgboost/testing/params.py | 12 + src/common/algorithm.h | 66 +- src/common/error_msg.h | 26 + src/common/math.h | 32 +- src/common/numeric.h | 16 +- src/common/ranking_utils.cc | 150 ++- src/common/ranking_utils.cu | 213 ++++ src/common/ranking_utils.cuh | 40 + src/common/ranking_utils.h | 441 +++++++- src/common/threading_utils.cuh | 35 +- src/common/threading_utils.h | 10 +- src/data/data.cc | 47 +- src/learner.cc | 4 +- src/metric/elementwise_metric.cu | 10 +- src/metric/rank_metric.cc | 330 ++++-- src/metric/rank_metric.cu | 298 +++--- src/metric/rank_metric.h | 44 + src/objective/init_estimation.cc | 8 +- src/objective/init_estimation.h | 6 +- src/objective/lambdarank_obj.cc | 603 +++++++++++ src/objective/lambdarank_obj.cu | 597 +++++++++++ src/objective/lambdarank_obj.cuh | 172 ++++ src/objective/lambdarank_obj.h | 262 +++++ src/objective/objective.cc | 5 +- src/objective/rank_obj.cc | 17 - src/objective/rank_obj.cu | 961 ------------------ tests/ci_build/lint_python.py | 1 + tests/cpp/common/test_algorithm.cc | 61 +- tests/cpp/common/test_ranking_utils.cc | 65 +- tests/cpp/common/test_ranking_utils.cu | 62 ++ tests/cpp/common/test_ranking_utils.h | 114 +++ tests/cpp/metric/test_rank_metric.cc | 84 +- tests/cpp/objective/test_lambdarank_obj.cc | 273 +++++ tests/cpp/objective/test_lambdarank_obj.cu | 161 +++ tests/cpp/objective/test_lambdarank_obj.h | 46 + tests/cpp/objective/test_ranking_obj.cc | 128 --- tests/cpp/objective/test_ranking_obj_gpu.cu | 268 ----- tests/python-gpu/test_gpu_ranking.py | 306 +++--- tests/python/test_eval_metrics.py | 4 +- tests/python/test_ranking.py | 27 +- tests/python/test_with_sklearn.py | 43 +- .../test_with_spark/test_spark_local.py | 109 +- 62 files changed, 4932 insertions(+), 2059 deletions(-) create mode 100644 demo/guide-python/learning_to_rank.py create mode 100644 doc/tutorials/learning_to_rank.rst create mode 100644 src/common/error_msg.h create mode 100644 src/common/ranking_utils.cu create mode 100644 src/common/ranking_utils.cuh create mode 100644 src/metric/rank_metric.h create mode 100644 src/objective/lambdarank_obj.cc create mode 100644 src/objective/lambdarank_obj.cu create mode 100644 src/objective/lambdarank_obj.cuh create mode 100644 src/objective/lambdarank_obj.h delete mode 100644 src/objective/rank_obj.cc delete mode 100644 src/objective/rank_obj.cu create mode 100644 tests/cpp/common/test_ranking_utils.cu create mode 100644 tests/cpp/common/test_ranking_utils.h create mode 100644 tests/cpp/objective/test_lambdarank_obj.cc create mode 100644 tests/cpp/objective/test_lambdarank_obj.cu create mode 100644 tests/cpp/objective/test_lambdarank_obj.h delete mode 100644 tests/cpp/objective/test_ranking_obj.cc delete mode 100644 tests/cpp/objective/test_ranking_obj_gpu.cu diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in index ed3f10571edf..33452e142ed1 100644 --- a/R-package/src/Makevars.in +++ b/R-package/src/Makevars.in @@ -32,7 +32,7 @@ OBJECTS= \ $(PKGROOT)/src/objective/objective.o \ $(PKGROOT)/src/objective/regression_obj.o \ $(PKGROOT)/src/objective/multiclass_obj.o \ - $(PKGROOT)/src/objective/rank_obj.o \ + $(PKGROOT)/src/objective/lambdarank_obj.o \ $(PKGROOT)/src/objective/hinge.o \ $(PKGROOT)/src/objective/aft_obj.o \ $(PKGROOT)/src/objective/adaptive.o \ diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win index 024ba1aa19b7..f176642643cc 100644 --- a/R-package/src/Makevars.win +++ b/R-package/src/Makevars.win @@ -32,7 +32,7 @@ OBJECTS= \ $(PKGROOT)/src/objective/objective.o \ $(PKGROOT)/src/objective/regression_obj.o \ $(PKGROOT)/src/objective/multiclass_obj.o \ - $(PKGROOT)/src/objective/rank_obj.o \ + $(PKGROOT)/src/objective/lambdarank_obj.o \ $(PKGROOT)/src/objective/hinge.o \ $(PKGROOT)/src/objective/aft_obj.o \ $(PKGROOT)/src/objective/adaptive.o \ diff --git a/demo/guide-python/learning_to_rank.py b/demo/guide-python/learning_to_rank.py new file mode 100644 index 000000000000..9dd46cb6c0d4 --- /dev/null +++ b/demo/guide-python/learning_to_rank.py @@ -0,0 +1,462 @@ +""" +Getting started with learning to rank +===================================== + + .. versionadded:: 2.0.0 + +This is a demonstration of using XGBoost for learning to rank tasks using the +MSLR_10k_letor dataset. For more infomation about the dataset, please visit its +`description page `_. + +This is a two-part demo, the first one contains a basic example of using XGBoost to +train on relevance degree, and the second part simulates click data and enable the +position debiasing training. + +For an overview of learning to rank in XGBoost, please see +:doc:`Learning to Rank `. +""" +from __future__ import annotations + +import argparse +import os +import pickle as pkl +from collections import namedtuple +from time import time +from typing import List, NamedTuple, Tuple, TypedDict + +import numpy as np +from numpy import typing as npt +from scipy import sparse +from scipy.sparse import csr_matrix +from sklearn.datasets import load_svmlight_file + +import xgboost as xgb + + +class PBM: + """Simulate click data with position bias model. There are other models available in + `ULTRA `_ like the cascading model. + + References + ---------- + Unbiased LambdaMART: An Unbiased Pairwise Learning-to-Rank Algorithm + + """ + + def __init__(self, eta: float) -> None: + # click probability for each relevance degree. (from 0 to 4) + self.click_prob = np.array([0.1, 0.16, 0.28, 0.52, 1.0]) + exam_prob = np.array( + [0.68, 0.61, 0.48, 0.34, 0.28, 0.20, 0.11, 0.10, 0.08, 0.06] + ) + self.exam_prob = np.power(exam_prob, eta) + + def sample_clicks_for_query( + self, labels: npt.NDArray[np.int32], position: npt.NDArray[np.int64] + ) -> npt.NDArray[np.int32]: + """Sample clicks for one query based on input relevance degree and position. + + Parameters + ---------- + + labels : + relevance_degree + + """ + labels = np.array(labels, copy=True) + + click_prob = np.zeros(labels.shape) + # minimum + labels[labels < 0] = 0 + # maximum + labels[labels >= len(self.click_prob)] = -1 + click_prob = self.click_prob[labels] + + exam_prob = np.zeros(labels.shape) + assert position.size == labels.size + ranks = np.array(position, copy=True) + # maximum + ranks[ranks >= self.exam_prob.size] = -1 + exam_prob = self.exam_prob[ranks] + + rng = np.random.default_rng(1994) + prob = rng.random(size=labels.shape[0], dtype=np.float32) + + clicks: npt.NDArray[np.int32] = np.zeros(labels.shape, dtype=np.int32) + clicks[prob < exam_prob * click_prob] = 1 + return clicks + + +# relevance degree data +RelData = Tuple[csr_matrix, npt.NDArray[np.int32], npt.NDArray[np.int32]] + + +class RelDataCV(NamedTuple): + train: RelData + valid: RelData + test: RelData + + +def load_mlsr_10k(data_path: str, cache_path: str) -> RelDataCV: + """Load the MSLR10k dataset from data_path and cache a pickle object in cache_path. + + Returns + ------- + + A list of tuples [(X, y, qid), ...]. + + """ + root_path = os.path.expanduser(args.data) + cacheroot_path = os.path.expanduser(args.cache) + cache_path = os.path.join(cacheroot_path, "MSLR_10K_LETOR.pkl") + + # Use only the Fold1 for demo: + # Train, Valid, Test + # {S1,S2,S3}, S4, S5 + fold = 1 + + if not os.path.exists(cache_path): + fold_path = os.path.join(root_path, f"Fold{fold}") + train_path = os.path.join(fold_path, "train.txt") + valid_path = os.path.join(fold_path, "vali.txt") + test_path = os.path.join(fold_path, "test.txt") + X_train, y_train, qid_train = load_svmlight_file( + train_path, query_id=True, dtype=np.float32 + ) + y_train = y_train.astype(np.int32) + qid_train = qid_train.astype(np.int32) + + X_valid, y_valid, qid_valid = load_svmlight_file( + valid_path, query_id=True, dtype=np.float32 + ) + y_valid = y_valid.astype(np.int32) + qid_valid = qid_valid.astype(np.int32) + + X_test, y_test, qid_test = load_svmlight_file( + test_path, query_id=True, dtype=np.float32 + ) + y_test = y_test.astype(np.int32) + qid_test = qid_test.astype(np.int32) + + data = RelDataCV( + train=(X_train, y_train, qid_train), + valid=(X_valid, y_valid, qid_valid), + test=(X_test, y_test, qid_test), + ) + + with open(cache_path, "wb") as fd: + pkl.dump(data, fd) + + with open(cache_path, "rb") as fd: + data = pkl.load(fd) + + return data + + +def ranking_demo(args: argparse.Namespace) -> None: + """Demonstration for learning to rank with relevance degree.""" + data = load_mlsr_10k(args.data, args.cache) + + X_train, y_train, qid_train = data.train + sorted_idx = np.argsort(qid_train) + X_train = X_train[sorted_idx] + y_train = y_train[sorted_idx] + qid_train = qid_train[sorted_idx] + + X_valid, y_valid, qid_valid = data.valid + sorted_idx = np.argsort(qid_valid) + X_valid = X_valid[sorted_idx] + y_valid = y_valid[sorted_idx] + qid_valid = qid_valid[sorted_idx] + + ranker = xgb.XGBRanker( + tree_method="gpu_hist", + lambdarank_pair_method="topk", + lambdarank_num_pair_per_sample=13, + eval_metric=["ndcg@1", "ndcg@8"], + ) + ranker.fit( + X_train, + y_train, + qid=qid_train, + eval_set=[(X_valid, y_valid)], + eval_qid=[qid_valid], + verbose=True, + ) + + +def rlencode(x: npt.NDArray[np.int32]) -> Tuple[npt.NDArray, npt.NDArray, npt.NDArray]: + """Run length encoding using numpy, modified from: + https://gist.github.com/nvictus/66627b580c13068589957d6ab0919e66 + + """ + x = np.asarray(x) + n = x.size + starts = np.r_[0, np.flatnonzero(~np.isclose(x[1:], x[:-1], equal_nan=True)) + 1] + lengths = np.diff(np.r_[starts, n]) + values = x[starts] + indptr = np.append(starts, np.array([x.size])) + + return indptr, lengths, values + + +ClickFold = namedtuple("ClickFold", ("X", "y", "q", "s", "c", "p")) + + +def simulate_clicks(args: argparse.Namespace) -> ClickFold: + """Simulate click data using position biased model (PBM).""" + + def init_rank_score( + X: csr_matrix, + y: npt.NDArray[np.int32], + qid: npt.NDArray[np.int32], + sample_rate: float = 0.01, + ) -> npt.NDArray[np.float32]: + """We use XGBoost to generate the initial score instead of SVMRank for + simplicity. + + """ + # random sample + _rng = np.random.default_rng(1994) + n_samples = int(X.shape[0] * sample_rate) + index = np.arange(0, X.shape[0], dtype=np.uint64) + _rng.shuffle(index) + index = index[:n_samples] + + X_train = X[index] + y_train = y[index] + qid_train = qid[index] + + # Sort training data based on query id, required by XGBoost. + sorted_idx = np.argsort(qid_train) + X_train = X_train[sorted_idx] + y_train = y_train[sorted_idx] + qid_train = qid_train[sorted_idx] + + ltr = xgb.XGBRanker(objective="rank:ndcg", tree_method="gpu_hist") + ltr.fit(X_train, y_train, qid=qid_train) + + # Use the original order of the data. + scores = ltr.predict(X) + return scores + + def simulate_one_fold(fold, scores_fold: npt.NDArray[np.float32]) -> ClickFold: + """Simulate clicks for one fold.""" + X_fold, y_fold, qid_fold = fold + assert qid_fold.dtype == np.int32 + indptr, lengths, values = rlencode(qid_fold) + + qids = np.unique(qid_fold) + + position = np.empty((y_fold.size,), dtype=np.int64) + clicks = np.empty((y_fold.size,), dtype=np.int32) + pbm = PBM(eta=1.0) + + # Avoid grouping by qid as we want to preserve the original data partition by + # the dataset authors. + for q in qids: + qid_mask = q == qid_fold + query_scores = scores_fold[qid_mask] + # Initial rank list, scores sorted to decreasing order + query_position = np.argsort(query_scores)[::-1] + position[qid_mask] = query_position + # get labels + relevance_degrees = y_fold[qid_mask] + query_clicks = pbm.sample_clicks_for_query( + relevance_degrees, query_position + ) + clicks[qid_mask] = query_clicks + + assert X_fold.shape[0] == qid_fold.shape[0], (X_fold.shape, qid_fold.shape) + assert X_fold.shape[0] == clicks.shape[0], (X_fold.shape, clicks.shape) + + return ClickFold(X_fold, y_fold, qid_fold, scores_fold, clicks, position) + + cache_path = os.path.join( + os.path.expanduser(args.cache), "MSLR_10K_LETOR_Clicks.pkl" + ) + if os.path.exists(cache_path): + print("Found existing cache for clicks.") + with open(cache_path, "rb") as fdr: + new_folds = pkl.load(fdr) + return new_folds + + cv_data = load_mlsr_10k(args.data, args.cache) + X, y, qid = list(zip(cv_data.train, cv_data.valid, cv_data.test)) + + indptr = np.array([0] + [v.shape[0] for v in X]) + indptr = np.cumsum(indptr) + + assert len(indptr) == 3 + 1 # train, valid, test + X_full = sparse.vstack(X) + y_full = np.concatenate(y) + qid_full = np.concatenate(qid) + + # Skip the data cleaning here for demonstration purposes. + + # Obtain initial relevance score for click simulation + scores_full = init_rank_score(X_full, y_full, qid_full) + # partition it back to train,valid,test tuple + scores = [scores_full[indptr[i - 1] : indptr[i]] for i in range(1, indptr.size)] + + ( + X_full, + y_full, + qid_full, + scores_ret, + clicks_full, + position_full, + ) = simulate_one_fold((X_full, y_full, qid_full), scores_full) + + scores_check_1 = [ + scores_ret[indptr[i - 1] : indptr[i]] for i in range(1, indptr.size) + ] + for i in range(3): + assert (scores_check_1[i] == scores[i]).all() + + position = [position_full[indptr[i - 1] : indptr[i]] for i in range(1, indptr.size)] + clicks = [clicks_full[indptr[i - 1] : indptr[i]] for i in range(1, indptr.size)] + + with open(cache_path, "wb") as fdw: + data = ClickFold(X, y, qid, scores, clicks, position) + pkl.dump(data, fdw) + + return data + + +def sort_samples( + X: csr_matrix, + y: npt.NDArray[np.int32], + qid: npt.NDArray[np.int32], + clicks: npt.NDArray[np.int32], + pos: npt.NDArray[np.int64], + cache_path: str, +) -> Tuple[ + csr_matrix, npt.NDArray[np.int32], npt.NDArray[np.int32], npt.NDArray[np.int32] +]: + """Sort data based on query index and position.""" + if os.path.exists(cache_path): + print(f"Found existing cache: {cache_path}") + with open(cache_path, "rb") as fdr: + data = pkl.load(fdr) + return data + + s = time() + sorted_idx = np.argsort(qid) + X = X[sorted_idx] + clicks = clicks[sorted_idx] + qid = qid[sorted_idx] + pos = pos[sorted_idx] + + indptr, lengths, values = rlencode(qid) + + for i in range(1, indptr.size): + beg = indptr[i - 1] + end = indptr[i] + + assert beg < end, (beg, end) + assert np.unique(qid[beg:end]).size == 1, (beg, end) + + query_pos = pos[beg:end] + assert query_pos.min() == 0, query_pos.min() + assert query_pos.max() >= query_pos.size - 1, ( + query_pos.max(), + query_pos.size, + i, + np.unique(qid[beg:end]), + ) + sorted_idx = np.argsort(query_pos) + + X[beg:end] = X[beg:end][sorted_idx] + clicks[beg:end] = clicks[beg:end][sorted_idx] + y[beg:end] = y[beg:end][sorted_idx] + # not necessary + qid[beg:end] = qid[beg:end][sorted_idx] + + e = time() + print("Sort samples:", e - s) + data = X, clicks, y, qid + + with open(cache_path, "wb") as fdw: + pkl.dump(data, fdw) + + return data + + +def click_data_demo(args: argparse.Namespace) -> None: + """Demonstration for learning to rank with click data.""" + folds = simulate_clicks(args) + + train = [pack[0] for pack in folds] + valid = [pack[1] for pack in folds] + test = [pack[2] for pack in folds] + + X_train, y_train, qid_train, scores_train, clicks_train, position_train = train + assert X_train.shape[0] == clicks_train.size + X_valid, y_valid, qid_valid, scores_valid, clicks_valid, position_valid = valid + assert X_valid.shape[0] == clicks_valid.size + assert scores_valid.dtype == np.float32 + assert clicks_valid.dtype == np.int32 + cache_path = os.path.expanduser(args.cache) + + X_train, clicks_train, y_train, qid_train = sort_samples( + X_train, + y_train, + qid_train, + clicks_train, + position_train, + os.path.join(cache_path, "sorted.train.pkl"), + ) + X_valid, clicks_valid, y_valid, qid_valid = sort_samples( + X_valid, + y_valid, + qid_valid, + clicks_valid, + position_valid, + os.path.join(cache_path, "sorted.valid.pkl"), + ) + + ranker = xgb.XGBRanker( + n_estimators=512, + tree_method="hist", + boost_from_average=0, + grow_policy="lossguide", + learning_rate=0.1, + # LTR specific parameters + objective="rank:ndcg", + lambdarank_unbiased=True, + lambdarank_bias_norm=0.5, + lambdarank_num_pair_per_sample=8, + lambdarank_pair_method="topk", + ndcg_exp_gain=True, + eval_metric=["ndcg@1", "ndcg@8", "ndcg@32"], + ) + ranker.fit( + X_train, + clicks_train, + qid=qid_train, + eval_set=[(X_train, clicks_train), (X_valid, y_valid)], + eval_qid=[qid_train, qid_valid], + verbose=True, + ) + X_test = test[0] + ranker.predict(X_test) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Demonstration of learning to rank using XGBoost." + ) + parser.add_argument( + "--data", type=str, help="Root directory of the MSLR data.", required=True + ) + parser.add_argument( + "--cache", + type=str, + help="Directory for caching processed data.", + required=True, + ) + args = parser.parse_args() + + ranking_demo(args) + click_data_demo(args) diff --git a/demo/guide-python/quantile_regression.py b/demo/guide-python/quantile_regression.py index d92115bf08d7..e6dc0847f4a3 100644 --- a/demo/guide-python/quantile_regression.py +++ b/demo/guide-python/quantile_regression.py @@ -2,6 +2,8 @@ Quantile Regression =================== + .. versionadded:: 2.0.0 + The script is inspired by this awesome example in sklearn: https://scikit-learn.org/stable/auto_examples/ensemble/plot_gradient_boosting_quantile.html diff --git a/doc/contrib/coding_guide.rst b/doc/contrib/coding_guide.rst index a080c2a31541..57bf07de4d74 100644 --- a/doc/contrib/coding_guide.rst +++ b/doc/contrib/coding_guide.rst @@ -16,8 +16,10 @@ C++ Coding Guideline * Each line of text may contain up to 100 characters. * The use of C++ exceptions is allowed. -- Use C++11 features such as smart pointers, braced initializers, lambda functions, and ``std::thread``. +- Use C++14 features such as smart pointers, braced initializers, lambda functions, and ``std::thread``. - Use Doxygen to document all the interface code. +- We have some comments around symbols imported by headers, some of those are hinted by `include-what-you-use `_. It's not required. +- We use clang-tidy and clang-format. You can check their configuration in the root directory of the XGBoost source tree. - We have a series of automatic checks to ensure that all of our codebase complies with the Google style. Before submitting your pull request, you are encouraged to run the style checks on your machine. See :ref:`running_checks_locally`. *********************** diff --git a/doc/model.schema b/doc/model.schema index 07a871820b5a..9d4c74607342 100644 --- a/doc/model.schema +++ b/doc/model.schema @@ -238,6 +238,16 @@ "num_pairsample": { "type": "string" }, "fix_list_weight": { "type": "string" } } + }, + "lambdarank_param": { + "type": "object", + "properties": { + "lambdarank_num_pair_per_sample": { "type": "string" }, + "lambdarank_pair_method": { "type": "string" }, + "lambdarank_unbiased": {"type": "string" }, + "lambdarank_bias_norm": {"type": "string" }, + "ndcg_exp_gain": {"type": "string"} + } } }, "type": "object", @@ -496,22 +506,22 @@ "type": "object", "properties": { "name": { "const": "rank:pairwise" }, - "lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"} + "lambda_rank_param": { "$ref": "#/definitions/lambdarank_param"} }, "required": [ "name", - "lambda_rank_param" + "lambdarank_param" ] }, { "type": "object", "properties": { "name": { "const": "rank:ndcg" }, - "lambda_rank_param": { "$ref": "#/definitions/lambda_rank_param"} + "lambda_rank_param": { "$ref": "#/definitions/lambdarank_param"} }, "required": [ "name", - "lambda_rank_param" + "lambdarank_param" ] }, { diff --git a/doc/parameter.rst b/doc/parameter.rst index 99d6f0585936..38a3bf37b70d 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -363,8 +363,7 @@ Specify the learning task and the corresponding learning objective. The objectiv - ``aft_loss_distribution``: Probability Density Function used by ``survival:aft`` objective and ``aft-nloglik`` metric. - ``multi:softmax``: set XGBoost to do multiclass classification using the softmax objective, you also need to set num_class(number of classes) - ``multi:softprob``: same as softmax, but output a vector of ``ndata * nclass``, which can be further reshaped to ``ndata * nclass`` matrix. The result contains predicted probability of each data point belonging to each class. - - ``rank:pairwise``: Use LambdaMART to perform pairwise ranking where the pairwise loss is minimized - - ``rank:ndcg``: Use LambdaMART to perform list-wise ranking where `Normalized Discounted Cumulative Gain (NDCG) `_ is maximized + - ``rank:ndcg``: Use LambdaMART to perform list-wise ranking where `Normalized Discounted Cumulative Gain (NDCG) `_ is maximized. This objective supports position debiasing for click data. - ``rank:map``: Use LambdaMART to perform list-wise ranking where `Mean Average Precision (MAP) `_ is maximized - ``reg:gamma``: gamma regression with log-link. Output is a mean of gamma distribution. It might be useful, e.g., for modeling insurance claims severity, or for any outcome that might be `gamma-distributed `_. - ``reg:tweedie``: Tweedie regression with log-link. It might be useful, e.g., for modeling total loss in insurance, or for any outcome that might be `Tweedie-distributed `_. @@ -378,8 +377,9 @@ Specify the learning task and the corresponding learning objective. The objectiv * ``eval_metric`` [default according to objective] - - Evaluation metrics for validation data, a default metric will be assigned according to objective (rmse for regression, and logloss for classification, mean average precision for ranking) - - User can add multiple evaluation metrics. Python users: remember to pass the metrics in as list of parameters pairs instead of map, so that latter ``eval_metric`` won't override previous one + - Evaluation metrics for validation data, a default metric will be assigned according to objective (rmse for regression, and logloss for classification, `mean average precision` for ``rank:map``, etc.) + - User can add multiple evaluation metrics. Python users: remember to pass the metrics in as list of parameters pairs instead of map, so that latter ``eval_metric`` won't override previous ones + - The choices are listed below: - ``rmse``: `root mean square error `_ @@ -408,8 +408,17 @@ Specify the learning task and the corresponding learning objective. The objectiv - ``ndcg``: `Normalized Discounted Cumulative Gain `_ - ``map``: `Mean Average Precision `_ - - ``ndcg@n``, ``map@n``: 'n' can be assigned as an integer to cut off the top positions in the lists for evaluation. - - ``ndcg-``, ``map-``, ``ndcg@n-``, ``map@n-``: In XGBoost, NDCG and MAP will evaluate the score of a list without any positive samples as 1. By adding "-" in the evaluation metric XGBoost will evaluate these score as 0 to be consistent under some conditions. + + The `average precision` is defined as: + + .. math:: + + AP@l = \frac{1}{min{(l, N)}}\sum^l_{k=1}P@k \cdot I_{(k)} + + where :math:`I_{(k)}` is an indicator function that equals to :math:`1` when the document at :math:`k` is relevant and :math:`0` otherwise. The :math:`P@k` is the precision at :math:`k`, and :math:`N` is the total number of relevant documents. Lastly, the `mean average precision` is defined as the weighted average across all queries. + + - ``ndcg@n``, ``map@n``: :math:`n` can be assigned as an integer to cut off the top positions in the lists for evaluation. + - ``ndcg-``, ``map-``, ``ndcg@n-``, ``map@n-``: In XGBoost, the NDCG and MAP evaluate the score of a list without any positive samples as :math:`1`. By appending "-" to the evaluation metric name, we can ask XGBoost to evaluate these scores as :math:`0` to be consistent under some conditions. - ``poisson-nloglik``: negative log-likelihood for Poisson regression - ``gamma-nloglik``: negative log-likelihood for gamma regression - ``cox-nloglik``: negative partial log-likelihood for Cox proportional hazards regression @@ -447,6 +456,37 @@ Parameter for using Quantile Loss (``reg:quantileerror``) * ``quantile_alpha``: A scala or a list of targeted quantiles. + +.. _ltr-param: + +Parameters for learning to rank (``rank:ndcg``, ``rank:map``, ``rank:pairwise``) +================================================================================ + +These are parameters specific to learning to rank task. See :doc:`Learning to Rank ` for an in-depth explanation. + +* ``lambdarank_pair_method`` [default = ``mean``] + + How to construct pairs for pair-wise learning. + + - ``mean``: Sample ``lambdarank_num_pair_per_sample`` pairs for each document in the query list. + - ``topk``: Focus on top-``lambdarank_num_pair_per_sample`` documents. Construct :math:`|query|` pairs for each document at the top-``lambdarank_num_pair_per_sample`` ranked by the model. + +* ``lambdarank_num_pair_per_sample`` [range = :math:`[1, \infty]`] + + It specifies the number of pairs sampled for each document when pair method is ``mean``, or the truncation level for queries when the pair method is ``topk``. For example, to train with ``ndcg@6``, set ``lambdarank_num_pair_per_sample`` to :math:`6` and ``lambdarank_pair_method`` to ``topk``. + +* ``lambdarank_unbiased`` [default = ``false``] + + Specify whether do we need to debias input click data. + +* ``lambdarank_bias_norm`` [default = 2.0] + + :math:`L_p` normalization for position debiasing, default is :math:`L_2`. Only relevant when ``lambdarank_unbiased`` is set to true. + +* ``ndcg_exp_gain`` [default = ``true``] + + Whether we should use exponential gain function for ``NDCG``. There are two forms of gain function for ``NDCG``, one is using relevance value directly while the other is using :math:`2^{rel} - 1` to emphasize on retrieving relevant documents. When ``ndcg_exp_gain`` is true (the default), relevance degree cannot be greater than 31. + *********************** Command Line Parameters *********************** diff --git a/doc/tutorials/dask.rst b/doc/tutorials/dask.rst index 87b2bf9968b9..fd5155f468cd 100644 --- a/doc/tutorials/dask.rst +++ b/doc/tutorials/dask.rst @@ -503,6 +503,7 @@ dask config is used: reg = dxgb.DaskXGBRegressor() +Please note that XGBoost requires a different port than dask. By default, on a unix-like system XGBoost uses the port 0 to find available ports, which may fail if a user is running in a docker environment where ports are restricted. ************ IPv6 Support diff --git a/doc/tutorials/index.rst b/doc/tutorials/index.rst index 310fd0170610..eb8c23726d56 100644 --- a/doc/tutorials/index.rst +++ b/doc/tutorials/index.rst @@ -21,6 +21,7 @@ See `Awesome XGBoost `_ for mo monotonic rf feature_interaction_constraint + learning_to_rank aft_survival_analysis c_api_tutorial input_format diff --git a/doc/tutorials/learning_to_rank.rst b/doc/tutorials/learning_to_rank.rst new file mode 100644 index 000000000000..5afd9314f055 --- /dev/null +++ b/doc/tutorials/learning_to_rank.rst @@ -0,0 +1,177 @@ +################ +Learning to Rank +################ + +**Contents** + +.. contents:: + :local: + :backlinks: none + +******** +Overview +******** +Often in the context of information retrieval, learning to rank aims to train a model that arranges a set of query results into an ordered list `[1] <#references>`__. For surprivised learning to rank, the predictors are sample documents encoded as feature matrix, and the labels are relevance degree for each sample. Relevance degree can be multi-level (graded) or binary (relevant or not). The training samples are often grouped by their query index with each query group containing multiple query results. + +XGBoost implements learning to rank through a set of objective functions and performane metrics. The default objective is ``rank:ndcg`` based on the ``LambdaMART`` `[2] <#references>`__ algorithm, which in turn is an adaptation of the ``LambdaRank`` `[3] <#references>`__ framework to gradient boosting trees. For a history and a summary of the algorithm, see `[5] <#references>`__. The implementation in XGBoost features deterministic GPU computation, distributed training, position debiasing and two different pair construction strategies. + +************************************ +Training with the Pariwise Objective +************************************ +``LambdaMART`` is a pairwise ranking model, meaning that it compares the relevance degree for every pair of samples in a query group and calculate a proxy gradient for each pair. The default objective ``rank:ndcg`` is using the surrogate gradient derived from the ``ndcg`` metric. To train a XGBoost model, we need an additional sorted array called ``qid`` for specifying the query group of input samples. An example input would look like this: + ++-------+-----------+---------------+ +| QID | Label | Features | ++=======+===========+===============+ +| 1 | 0 | :math:`x_1` | ++-------+-----------+---------------+ +| 1 | 1 | :math:`x_2` | ++-------+-----------+---------------+ +| 1 | 0 | :math:`x_3` | ++-------+-----------+---------------+ +| 2 | 0 | :math:`x_4` | ++-------+-----------+---------------+ +| 2 | 1 | :math:`x_5` | ++-------+-----------+---------------+ +| 2 | 1 | :math:`x_6` | ++-------+-----------+---------------+ +| 2 | 1 | :math:`x_7` | ++-------+-----------+---------------+ + +Notice that the samples are sorted based on their query index in an non-decreasing order. Here the first three samples belong to the first query and the next four samples belong to the second. For the sake of simplicity, we will use a pseudo binary learning to rank dataset in the following snippets, with binary labels representing whether the result is relevant or not, and randomly assign the query group index to each sample. For an example that uses a real world dataset, please see :ref:`sphx_glr_python_examples_learning_to_rank.py`. + +.. code-block:: python + + from sklearn.datasets import make_classification + import numpy as np + + import xgboost as xgb + + # Make a pseudo ranking dataset for demonstration + X, y = make_classification(random_state=rng) + rng = np.random.default_rng(1994) + n_query_groups = 3 + qid = rng.integers(0, 3, size=X.shape[0]) + + # Sort the inputs based on query index + sorted_idx = np.argsort(qid) + X = X[sorted_idx, :] + y = y[sorted_idx] + +The simpliest way to train a ranking model is by using the sklearn estimator interface. Continuing the previous snippet, we can train a simple ranking model without tuning: + +.. code-block:: python + + ranker = xgb.XGBRanker(tree_method="hist", lambdarank_num_pair_per_sample=8, objective="rank:ndcg", lambdarank_pair_method="topk") + ranker.fit(X, y, qid=qid) + +Please note that, as of writing, there's no learning to rank interface in sklearn. As a result, the :py:class:`xgboost.XGBRanker` does not fully conform the sklearn estimator guideline and can not be directly used with some of its utility functions. For instances, the ``auc_score`` and ``ndcg_score`` in sklearn don't consider group information nor the pairwise loss. Most of the metrics are implemented as part of XGBoost, but to use sklearn utilities like :py:func:`sklearn.model_selection.cross_validation`, we need to make some adjustments in order to pass the `qid` as an additional parameter for :py:meth:`xgboost.XGBRanker.score`. The `X` for :py:class:`xgboost.XGBRanker` may contain a special column called ``qid`` when it's a pandas dataframe or a cuDF dataframe: + +.. code-block:: python + + df = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])) + df["qid"] = qid + ranker.fit(df, y) # No need to pass qid as a separate argument + + from sklearn.model_selection import StratifiedGroupKFold, cross_val_score + # Works with cv in sklearn, along with HPO utilities like grid search cv. + kfold = StratifiedGroupKFold(shuffle=False) + cross_val_score(ranker, df, y, cv=kfold, groups=df.qid) + +The above snippets build a model using ``LambdaMART`` with the ``NDCG@8`` metric. The outputs of a ranker are relevance scores: + +.. code-block:: python + + scores = ranker.predict(X) + sorted_idx = np.argsort(scores)[::-1] + # Sort the relevance scores from most relevant to least relevant + scores = scores[sorted_idx] + + +************* +Position Bias +************* +Real relevance degree for query result is difficult to obtain as it often requires human judegs to examine the content of query results. When such labeled data is absent, we might want to train the model on ground truth data like user clicks. Another upside of using click data directly is that it can relect the up-to-date relevance status `[1] <#references>`__. However, user clicks are often nosiy and biased as users tend to choose results displayed in higher position. To ameliorate this issue, XGBoost implements the ``Unbiased LambdaMART`` `[4] <#references>`__ algorithm to debias the position-dependent click data. The feature can be enabled by the ``lambdarank_unbiased`` parameter, see :ref:`ltr-param` for related options and :ref:`sphx_glr_python_examples_learning_to_rank.py` for a worked example with simulated user clicks. + +**** +Loss +**** + +XGBoost implements different ``LambdaMART`` objectives based on different metrics. We list them here as a reference. Other than those used as objective function, XGBoost also implements metrics like ``pre`` (for precision) for evaluation. See :doc:`parameters ` for available options and the following sections for how to choose these objectives based of the amount of effective pairs. + +* NDCG +`Normalized Discounted Cumulative Gain` ``NDCG`` can be used with both binary relevance and multi-level relevance. If you are not sure about your data, this metric can be used as the default. The name for the objective is ``rank:ndcg``. + + +* MAP +`Mean average precision` ``MAP`` is a binary measure. It can be used when the relevance label is 0 or 1. The name for the objective is ``rank:map``. + + +* Pairwise + +The `LambdaMART` algorithm scales the logistic loss with learning to rank metrics like ``NDCG`` in the hope of including ranking infomation into the loss function. The ``rank:pairwise`` loss is the orginal version of the pairwise loss, also known as the `RankNet loss` `[7] <#references>`__ or the `pairwise logistic loss`. Unlike the ``rank:map`` and the ``rank:ndcg``, no scaling is applied (:math:`|\Delta Z_{ij}| = 1`). + +Whether scaling with a LTR metric is actually more effective is still up for debate, `[8] <#references>`__ provides a theoretical foundation for general lambda loss functions and some insights into the framework. + +****************** +Constructing Pairs +****************** + +There are two implemented strategies for constructing document pairs for :math:`\lambda`-gradient calculation. The first one is the ``mean`` method, another one is the ``topk`` method. The preferred strategy can be specified by the ``lambdarank_pair_method`` parameter. + +For the ``mean`` strategy, XGBoost samples ``lambdarank_num_pair_per_sample`` pairs for each document in a query list. For example, given a list of 3 documents and ``lambdarank_num_pair_per_sample`` is set to 2, XGBoost will randomly sample 6 pairs assuming the labels for these documents are different. On the other hand, if the pair method is set to ``topk``, XGBoost constructs about :math:`k \times |query|` number of pairs with :math:`|query|` pairs for each sample at the top :math:`k = lambdarank\_num\_pair` position. The number of pairs counted here is an approximation since we skip pairs that have the sample label. + +********************* +Obtaining Good Result +********************* + +Learning to rank is a sophisticated task and a field of heated research. It's not trivial to train a model that generalizes well. There are multiple loss functions available in XGBoost along with a set of hyper-parameters. This section contains some hints for how to choose those parameters as a starting point. One can further optimize the model by tuning these parameters. + +The first question would be how to choose an objective that matches the task at hand. If your input data is multi-level relevance degree, then either ``rank:ndcg`` or ``rank:pairwise`` should be used. However, when the input is binary we have multiple options based on the target metric. `[6] <#references>`__ provides some guidelines on this topic and users are encouraged to see the analysis done in their work. The choice should be based on the number of `effective pairs`, which refers to the number of pairs that can generate non-zero gradient and contribute to training. `LambdaMART` with ``MRR`` has the least amount of effective pairs as the :math:`\lambda`-gradient is only non-zero when the pair contains a non-relevant document ranked higher than the top relevant document. As a result, it's not implemented in XGBoost. Since ``NDCG`` is a multi-level metric, it usually generate more effective pairs than ``MAP``. + +However, when there's a sufficient amount of effective pairs, it's shown in `[6] <#references>`__ that matching the target metric with the objective is of significance. When the target metric is ``MAP`` and you are using a large dataset that can provide a sufficient amount of effective pairs, ``rank:map`` can in theory yield higher ``MAP`` value than the ``rank:ndcg``. + +The choice of pair method (``lambdarank_pair_method``) and the number of pairs for each sample (``lambdarank_num_pair_per_sample``) is similar, as the mean-``NDCG`` considers more pairs than ``NDCG@10``, it can generate more effective pairs and provide more granularity. Also, using the ``mean`` strategy can help the model generalize with random sampling. However, one might want to focus the training on the top :math:`k` documents instead of using all pairs in practice, the tradeoff should be made based on the user's goal. + +When using mean value instead of targeting a specific position by calculating the target metric (like ``NDCG``) over the whole query list, user can specify how many pairs they want in each query by setting the ``lambdarank_num_pair_per_sample`` and XGBoost will randomly sample this amount of pairs for each element in the query group (:math:`|pairs| = |query| \times num\_pairsample`). Often time, setting it to 1 can produce reasonable result, with higher value producing more pairs (with the hope that a reasonable amount of them being effective). On the other hand, if you are prioritizing the top :math:`k` documents, the ``lambdarank_num_pair_per_sample`` should be set to slightly higher than :math:`k` (with a few more documents) to obtain a good training result. + +In summary, to start off the training, if you have a large dataset, consider using the target-matching objective, otherwise ``NDCG`` or the RankNet loss (``rank:pairwise``) might be preferred. With the same target metric, use the ``lambdarank_num_pair_per_sample`` to specify the top :math:`k` documents for training if your dataset is large, and use the mean value version otherwise. Lastly, ``lambdarank_num_pair_per_sample`` can be used to control the amount of pairs for both methods. + +******************** +Distributed Training +******************** +XGBoost implements distributed learning-to-rank with integration of multiple frameworks including dask, spark, and pyspark. The interface is similar to single node. Please refer to document of the respective XGBoost interface for details. Scattering a query group onto multiple workers is theoretically sound but can affect the model accuracy. For most of the use cases, the small discrepancy is not an issue since when distributed training is involved the dataset is usually large. As a result, users don't need to partition the data based on group information. Given the dataset is correctly sorted, XGBoost can aggregate sample gradients accordingly. + +******************* +Reproducbile Result +******************* + +Like any other tasks, XGBoost should generate reproducbile results given the same hardware and software environments, along with data partitions if distributed interface is used. Even when the underlying environment has changed, the result should still be consistent. However, when the ``lambdarank_pair_method`` is set to ``mean``, XGBoost uses sampling, and the random number generator used on Windows (MSVC) is different from the one used on other platforms like Linux (GCC, Clang), the output varies significantly between these platforms. + +********** +References +********** + +[1] Tie-Yan Liu. 2009. "`Learning to Rank for Information Retrieval`_". Found. Trends Inf. Retr. 3, 3 (March 2009), 225–331. + +[2] Christopher J. C. Burges, Robert Ragno, and Quoc Viet Le. 2006. "`Learning to rank with nonsmooth cost functions`_". In Proceedings of the 19th International Conference on Neural Information Processing Systems (NIPS'06). MIT Press, Cambridge, MA, USA, 193–200. + +[3] Wu, Q., Burges, C.J.C., Svore, K.M. et al. "`Adapting boosting for information retrieval measures`_". Inf Retrieval 13, 254–270 (2010). + +[4] Ziniu Hu, Yang Wang, Qu Peng, Hang Li. "`Unbiased LambdaMART: An Unbiased Pairwise Learning-to-Rank Algorithm`_". Proceedings of the 2019 World Wide Web Conference. + +[5] Burges, Chris J.C. "`From RankNet to LambdaRank to LambdaMART: An Overview`_". MSR-TR-2010-82 + +[6] Pinar Donmez, Krysta M. Svore, and Christopher J.C. Burges. 2009. "`On the local optimality of LambdaRank`_". In Proceedings of the 32nd international ACM SIGIR conference on Research and development in information retrieval (SIGIR '09). Association for Computing Machinery, New York, NY, USA, 460–467. + +[7] Chris Burges, Tal Shaked, Erin Renshaw, Ari Lazier, Matt Deeds, Nicole Hamilton, and Greg Hullender. 2005. "`Learning to rank using gradient descent`_". In Proceedings of the 22nd international conference on Machine learning (ICML '05). Association for Computing Machinery, New York, NY, USA, 89–96. + +[8] Xuanhui Wang and Cheng Li and Nadav Golbandi and Mike Bendersky and Marc Najork. 2018. "`The LambdaLoss Framework for Ranking Metric Optimization`_". Proceedings of The 27th ACM International Conference on Information and Knowledge Management (CIKM '18). + +.. _`Learning to Rank for Information Retrieval`: https://doi.org/10.1561/1500000016 +.. _`Learning to rank with nonsmooth cost functions`: https://dl.acm.org/doi/10.5555/2976456.2976481 +.. _`Adapting boosting for information retrieval measures`: https://doi.org/10.1007/s10791-009-9112-1 +.. _`Unbiased LambdaMART: An Unbiased Pairwise Learning-to-Rank Algorithm`: https://dl.acm.org/doi/10.1145/3308558.3313447 +.. _`From RankNet to LambdaRank to LambdaMART: An Overview`: https://www.microsoft.com/en-us/research/publication/from-ranknet-to-lambdarank-to-lambdamart-an-overview/ +.. _`On the local optimality of LambdaRank`: https://doi.org/10.1145/1571941.1572021 +.. _`Learning to rank using gradient descent`: https://doi.org/10.1145/1102351.1102363 diff --git a/include/xgboost/cache.h b/include/xgboost/cache.h index 781f45b1c474..66dec55f96a4 100644 --- a/include/xgboost/cache.h +++ b/include/xgboost/cache.h @@ -15,6 +15,7 @@ #include // for move #include // for vector + namespace xgboost { class DMatrix; /** @@ -149,6 +150,26 @@ class DMatrixCache { } return container_.at(key).value; } + /** + * \brief Re-initialize the item in cache. + * + * Since the shared_ptr is used to hold the item, any reference that lives outside of + * the cache can no-longer be reached from the cache. + * + * We use reset instead of erase to avoid walking through the whole cache for renewing + * a single item. (the cache is FIFO, needs to maintain the order). + */ + template + std::shared_ptr ResetItem(std::shared_ptr m, Args const&... args) { + std::lock_guard guard{lock_}; + CheckConsistent(); + auto key = Key{m.get(), std::this_thread::get_id()}; + auto it = container_.find(key); + CHECK(it != container_.cend()); + it->second = {m, std::make_shared(args...)}; + CheckConsistent(); + return it->second.value; + } /** * \brief Get a const reference to the underlying hash map. Clear expired caches before * returning. diff --git a/include/xgboost/data.h b/include/xgboost/data.h index ec78c588d5d9..d7abb9db39bd 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -1,5 +1,5 @@ -/*! - * Copyright (c) 2015-2022 by XGBoost Contributors +/** + * Copyright 2015-2023 by XGBoost Contributors * \file data.h * \brief The input data structure of xgboost. * \author Tianqi Chen @@ -17,6 +17,8 @@ #include #include +#include // std::size_t +#include // std::uint64_t #include #include #include @@ -60,9 +62,8 @@ class MetaInfo { linalg::Tensor labels; /*! \brief data split mode */ DataSplitMode data_split_mode{DataSplitMode::kRow}; - /*! - * \brief the index of begin and end of a group - * needed when the learning task is ranking. + /** + * \brief the index of begin and end of a group, needed when the learning task is ranking. */ std::vector group_ptr_; // NOLINT /*! \brief weights of each instance, optional */ diff --git a/include/xgboost/objective.h b/include/xgboost/objective.h index a04d2e453df7..4fecf56884b2 100644 --- a/include/xgboost/objective.h +++ b/include/xgboost/objective.h @@ -11,6 +11,7 @@ #include #include #include +#include // for Json, Null #include #include diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala index 5342aa563621..b8dca5d7040e 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala @@ -220,7 +220,7 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite { test("Ranking: train with Group") { withGpuSparkSession(enableCsvConf()) { spark => - val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "rank:pairwise", + val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "rank:ndcg", "num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist", "features_cols" -> featureNames, "label_col" -> labelName) val Array(trainingDf, testDf) = spark.read.option("header", "true").schema(schema) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala index f63865fabc2d..6aec4d36ed6f 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala @@ -25,7 +25,7 @@ private[spark] trait LearningTaskParams extends Params { /** * Specify the learning task and the corresponding learning objective. * options: reg:squarederror, reg:squaredlogerror, reg:logistic, binary:logistic, binary:logitraw, - * count:poisson, multi:softmax, multi:softprob, rank:pairwise, reg:gamma. + * count:poisson, multi:softmax, multi:softprob, rank:ndcg, reg:gamma. * default: reg:squarederror */ final val objective = new Param[String](this, "objective", diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala index 0bf8c2fbb426..3e6879d83148 100755 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala @@ -201,7 +201,7 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest { sc, buildTrainingRDD, List("eta" -> "1", "max_depth" -> "6", - "objective" -> "rank:pairwise", "num_round" -> 5, "num_workers" -> numWorkers, + "objective" -> "rank:ndcg", "num_round" -> 5, "num_workers" -> numWorkers, "custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false, "missing" -> Float.NaN).toMap) @@ -268,7 +268,7 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest { val training = buildDataFrameWithGroup(Ranking.train, 5) val Array(train, eval1, eval2) = training.randomSplit(Array(0.6, 0.2, 0.2), 0) val paramMap1 = Map("eta" -> "1", "max_depth" -> "6", - "objective" -> "rank:pairwise", + "objective" -> "rank:ndcg", "num_round" -> 5, "num_workers" -> numWorkers, "group_col" -> "group") val xgb1 = new XGBoostRegressor(paramMap1).setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2)) val model1 = xgb1.fit(train) @@ -281,7 +281,7 @@ class XGBoostGeneralSuite extends FunSuite with TmpFolderPerSuite with PerTest { assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(1)) val paramMap2 = Map("eta" -> "1", "max_depth" -> "6", - "objective" -> "rank:pairwise", + "objective" -> "rank:ndcg", "num_round" -> 5, "num_workers" -> numWorkers, "group_col" -> "group", "eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2)) val xgb2 = new XGBoostRegressor(paramMap2) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala index 4e3d59b25e57..b8b6aeefa0e0 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala @@ -121,7 +121,7 @@ class XGBoostRegressorSuite extends FunSuite with PerTest with TmpFolderPerSuite test("ranking: use group data") { val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "rank:pairwise", "num_workers" -> numWorkers, "num_round" -> 5, + "objective" -> "rank:ndcg", "num_workers" -> numWorkers, "num_round" -> 5, "group_col" -> "group", "tree_method" -> treeMethod) val trainingDF = buildDataFrameWithGroup(Ranking.train) diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 5a0cfb3a2ece..4b1e8be1b5af 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -288,10 +288,10 @@ def _check_call(ret: int) -> None: def build_info() -> dict: - """Build information of XGBoost. The returned value format is not stable. Also, please - note that build time dependency is not the same as runtime dependency. For instance, - it's possible to build XGBoost with older CUDA version but run it with the lastest - one. + """Build information of XGBoost. The returned value format is not stable. Also, + please note that build time dependency is not the same as runtime dependency. For + instance, it's possible to build XGBoost with older CUDA version but run it with the + lastest one. .. versionadded:: 1.6.0 @@ -675,28 +675,28 @@ def __init__( data : Data source of DMatrix. See :ref:`py-data` for a list of supported input types. - label : array_like + label : Label of the training data. - weight : array_like + weight : Weight for each instance. - .. note:: For ranking task, weights are per-group. + .. note:: - In ranking task, one weight is assigned to each group (not each - data point). This is because we only care about the relative - ordering of data points within each group, so it doesn't make - sense to assign weights to individual data points. + For ranking task, weights are per-group. In ranking task, one weight + is assigned to each group (not each data point). This is because we + only care about the relative ordering of data points within each group, + so it doesn't make sense to assign weights to individual data points. - base_margin: array_like + base_margin : Base margin used for boosting from existing model. - missing : float, optional - Value in the input data which needs to be present as a missing - value. If None, defaults to np.nan. - silent : boolean, optional + missing : + Value in the input data which needs to be present as a missing value. If + None, defaults to np.nan. + silent : Whether print messages during construction - feature_names : list, optional + feature_names : Set names for features. - feature_types : FeatureTypes + feature_types : Set types for features. When `enable_categorical` is set to `True`, string "c" represents categorical data type while "q" represents numerical feature @@ -706,20 +706,20 @@ def __init__( `.cat.codes` method. This is useful when users want to specify categorical features without having to construct a dataframe as input. - nthread : integer, optional + nthread : Number of threads to use for loading data when parallelization is applicable. If -1, uses maximum threads available on the system. - group : array_like + group : Group size for all ranking group. - qid : array_like + qid : Query ID for data samples, used for ranking. - label_lower_bound : array_like + label_lower_bound : Lower bound for survival training. - label_upper_bound : array_like + label_upper_bound : Upper bound for survival training. - feature_weights : array_like, optional + feature_weights : Set feature weights for column sampling. - enable_categorical: boolean, optional + enable_categorical : .. versionadded:: 1.3.0 @@ -1760,6 +1760,7 @@ def save_config(self) -> str: string. .. versionadded:: 1.0.0 + """ json_string = ctypes.c_char_p() length = c_bst_ulong() diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 3204f5a2a61e..75433ca5b320 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -1830,7 +1830,11 @@ def _get_qid( @xgboost_model_doc( - """Implementation of the Scikit-Learn API for XGBoost Ranking.""", + """Implementation of the Scikit-Learn API for XGBoost Ranking. + +See :doc:`Learning to Rank ` for an introducion. + + """, ["estimators", "model"], end_note=""" .. note:: @@ -1879,7 +1883,7 @@ def _get_qid( class XGBRanker(XGBModel, XGBRankerMixIn): # pylint: disable=missing-docstring,too-many-arguments,invalid-name @_deprecate_positional_args - def __init__(self, *, objective: str = "rank:pairwise", **kwargs: Any): + def __init__(self, *, objective: str = "rank:ndcg", **kwargs: Any): super().__init__(objective=objective, **kwargs) if callable(self.objective): raise ValueError("custom objective function not supported by XGBRanker") diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py index 3b33e87749f3..bb13b5523ed2 100644 --- a/python-package/xgboost/testing/__init__.py +++ b/python-package/xgboost/testing/__init__.py @@ -14,6 +14,7 @@ from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager from io import StringIO +from pathlib import Path from platform import system from typing import ( Any, @@ -443,7 +444,7 @@ def get_mq2008( from sklearn.datasets import load_svmlight_files src = "https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip" - target = dpath + "/MQ2008.zip" + target = os.path.join(os.path.expanduser(dpath), "MQ2008.zip") if not os.path.exists(target): request.urlretrieve(url=src, filename=target) @@ -462,9 +463,9 @@ def get_mq2008( qid_valid, ) = load_svmlight_files( ( - dpath + "MQ2008/Fold1/train.txt", - dpath + "MQ2008/Fold1/test.txt", - dpath + "MQ2008/Fold1/vali.txt", + Path(dpath) / "MQ2008" / "Fold1" / "train.txt", + Path(dpath) / "MQ2008" / "Fold1" / "test.txt", + Path(dpath) / "MQ2008" / "Fold1" / "vali.txt", ), query_id=True, zero_based=False, diff --git a/python-package/xgboost/testing/params.py b/python-package/xgboost/testing/params.py index 3af3306da40e..fe730f74f51f 100644 --- a/python-package/xgboost/testing/params.py +++ b/python-package/xgboost/testing/params.py @@ -47,3 +47,15 @@ "max_cat_threshold": strategies.integers(1, 128), } ) + +lambdarank_parameter_strategy = strategies.fixed_dictionaries( + { + "lambdarank_unbiased": strategies.sampled_from([True, False]), + "lambdarank_pair_method": strategies.sampled_from(["topk", "mean"]), + "lambdarank_num_pair_per_sample": strategies.integers(1, 8), + "lambdarank_bias_norm": strategies.floats(0.5, 2.0), + "objective": strategies.sampled_from( + ["rank:ndcg", "rank:map", "rank:pairwise"] + ), + } +) diff --git a/src/common/algorithm.h b/src/common/algorithm.h index 739a84968d45..2f5fccedd592 100644 --- a/src/common/algorithm.h +++ b/src/common/algorithm.h @@ -3,14 +3,17 @@ */ #ifndef XGBOOST_COMMON_ALGORITHM_H_ #define XGBOOST_COMMON_ALGORITHM_H_ -#include // upper_bound, stable_sort, sort, max -#include // size_t -#include // less -#include // iterator_traits, distance -#include // vector +#include // for upper_bound, stable_sort, sort, max, all_of, none_of, min +#include // for size_t +#include // for less +#include // for iterator_traits, distance +#include // for is_same +#include // for vector -#include "numeric.h" // Iota -#include "xgboost/context.h" // Context +#include "common.h" // for DivRoundUp +#include "numeric.h" // for Iota +#include "threading_utils.h" // for MemStackAllocator, DefaultMaxThreads, ParallelFor +#include "xgboost/context.h" // for Context // clang with libstdc++ works as well #if defined(__GNUC__) && (__GNUC__ >= 4) && !defined(__sun) && !defined(sun) && \ @@ -83,6 +86,55 @@ std::vector ArgSort(Context const *ctx, Iter begin, Iter end, Comp comp = s StableSort(ctx, result.begin(), result.end(), op); return result; } + +namespace detail { +template +bool Logical(Context const *ctx, It first, It last, Op op) { + auto n = std::distance(first, last); + auto n_threads = + std::max(std::min(n, static_cast(ctx->Threads())), static_cast(1)); + common::MemStackAllocator tloc{ + static_cast(n_threads), false}; + CHECK_GE(n, 0); + CHECK_GE(ctx->Threads(), 1); + static_assert(std::is_same::value, ""); + auto const n_per_thread = common::DivRoundUp(n, ctx->Threads()); + common::ParallelFor(static_cast(n_threads), n_threads, [&](auto t) { + auto begin = t * n_per_thread; + auto end = std::min(begin + n_per_thread, n); + + auto first_tloc = first + begin; + auto last_tloc = first + end; + if (first_tloc >= last_tloc) { + tloc[t] = true; + return; + } + bool result = op(first_tloc, last_tloc); + tloc[t] = result; + }); + return std::all_of(tloc.cbegin(), tloc.cend(), [](auto v) { return v; }); +} +} // namespace detail + +/** + * \brief Parallel version of std::none_of + */ +template +bool NoneOf(Context const *ctx, It first, It last, Pred predicate) { + return detail::Logical(ctx, first, last, [&predicate](It first, It last) { + return std::none_of(first, last, predicate); + }); +} + +/** + * \brief Parallel version of std::all_of + */ +template +bool AllOf(Context const *ctx, It first, It last, Pred predicate) { + return detail::Logical(ctx, first, last, [&predicate](It first, It last) { + return std::all_of(first, last, predicate); + }); +} } // namespace common } // namespace xgboost diff --git a/src/common/error_msg.h b/src/common/error_msg.h new file mode 100644 index 000000000000..c6f1e3d0d41c --- /dev/null +++ b/src/common/error_msg.h @@ -0,0 +1,26 @@ +/** + * Copyright 2023 by XGBoost contributors + * + * \brief Common error message for various checks. + */ +#ifndef XGBOOST_COMMON_ERROR_MSG_H_ +#define XGBOOST_COMMON_ERROR_MSG_H_ + +#include "xgboost/string_view.h" // for StringView + +namespace xgboost { +namespace error { +constexpr StringView GroupWeight() { + return "Size of weight must equal to the number of query groups when ranking group is used."; +} + +constexpr StringView GroupSize() { + return "Invalid query group structure. The number of rows obtained from group doesn't equal to "; +} + +constexpr StringView LabelScoreSize() { + return "The size of label doesn't match the size of prediction."; +} +} // namespace error +} // namespace xgboost +#endif // XGBOOST_COMMON_ERROR_MSG_H_ diff --git a/src/common/math.h b/src/common/math.h index 71a494544be1..25befdbef218 100644 --- a/src/common/math.h +++ b/src/common/math.h @@ -1,5 +1,5 @@ -/*! - * Copyright 2015 by Contributors +/** + * Copyright 2015-2023 by XGBoost Contributors * \file math.h * \brief additional math utils * \author Tianqi Chen @@ -7,16 +7,19 @@ #ifndef XGBOOST_COMMON_MATH_H_ #define XGBOOST_COMMON_MATH_H_ -#include +#include // for XGBOOST_DEVICE -#include -#include -#include -#include -#include +#include // for max +#include // for exp, abs, log, lgamma +#include // for numeric_limits +#include // for is_floating_point, conditional, is_signed, is_same, declval, enable_if +#include // for pair namespace xgboost { namespace common { + +template XGBOOST_DEVICE T Sqr(T const &w) { return w * w; } + /*! * \brief calculate the sigmoid of the input. * \param x input parameter @@ -30,9 +33,12 @@ XGBOOST_DEVICE inline float Sigmoid(float x) { return y; } -template -XGBOOST_DEVICE inline static T Sqr(T a) { return a * a; } - +XGBOOST_DEVICE inline double Sigmoid(double x) { + double constexpr kEps = 1e-16; // avoid 0 div + auto denom = std::exp(-x) + 1.0 + kEps; + auto y = 1.0 / denom; + return y; +} /*! * \brief Equality test for both integer and floating point. */ @@ -134,10 +140,6 @@ inline static bool CmpFirst(const std::pair &a, const std::pair &b) { return a.first > b.first; } -inline static bool CmpSecond(const std::pair &a, - const std::pair &b) { - return a.second > b.second; -} // Redefined here to workaround a VC bug that doesn't support overloading for integer // types. diff --git a/src/common/numeric.h b/src/common/numeric.h index 6a1c15fd08b4..687134c122d4 100644 --- a/src/common/numeric.h +++ b/src/common/numeric.h @@ -6,8 +6,10 @@ #include // OMPException -#include // std::max -#include // std::iterator_traits +#include // for std::max +#include // for size_t +#include // for int32_t +#include // for iterator_traits #include #include "common.h" // AssertGPUSupport @@ -111,11 +113,11 @@ inline double Reduce(Context const*, HostDeviceVector const&) { namespace cpu_impl { template V Reduce(Context const* ctx, It first, It second, V const& init) { - size_t n = std::distance(first, second); - common::MemStackAllocator result_tloc(ctx->Threads(), init); - common::ParallelFor(n, ctx->Threads(), - [&](auto i) { result_tloc[omp_get_thread_num()] += first[i]; }); - auto result = std::accumulate(result_tloc.cbegin(), result_tloc.cbegin() + ctx->Threads(), init); + std::size_t n = std::distance(first, second); + auto n_threads = static_cast(std::min(n, static_cast(ctx->Threads()))); + common::MemStackAllocator result_tloc(n_threads, init); + common::ParallelFor(n, n_threads, [&](auto i) { result_tloc[omp_get_thread_num()] += first[i]; }); + auto result = std::accumulate(result_tloc.cbegin(), result_tloc.cbegin() + n_threads, init); return result; } } // namespace cpu_impl diff --git a/src/common/ranking_utils.cc b/src/common/ranking_utils.cc index f0b1c1a5ee77..cae564480a0d 100644 --- a/src/common/ranking_utils.cc +++ b/src/common/ranking_utils.cc @@ -3,15 +3,134 @@ */ #include "ranking_utils.h" -#include // std::uint32_t -#include // std::ostringstream -#include // std::string,std::sscanf +#include // for copy_n, max, min +#include // for size_t +#include // for sscanf +#include // for exception +#include // for greater +#include // for reverse_iterator +#include // for char_traits, string -#include "xgboost/string_view.h" // StringView +#include "algorithm.h" // for ArgSort, NoneOf, AllOf +#include "linalg_op.h" // for cbegin, cend +#include "optional_weight.h" // for MakeOptionalWeights +#include "threading_utils.h" // for ParallelFor +#include "xgboost/base.h" // for bst_group_t +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for MetaInfo +#include "xgboost/linalg.h" // for All, TensorView, Range, Tensor, Vector +#include "xgboost/logging.h" // for Error, LogCheck_EQ, CHECK_EQ -namespace xgboost { -namespace ltr { -std::string MakeMetricName(StringView name, StringView param, std::uint32_t* topn, bool* minus) { +namespace xgboost::ltr { +void RankingCache::InitOnCPU(Context const* ctx, MetaInfo const& info) { + if (info.group_ptr_.empty()) { + group_ptr_.Resize(2, 0); + group_ptr_.HostVector()[1] = info.num_row_; + } else { + group_ptr_.HostVector() = info.group_ptr_; + } + + auto const& gptr = group_ptr_.ConstHostVector(); + for (std::size_t i = 1; i < gptr.size(); ++i) { + std::size_t n = gptr[i] - gptr[i - 1]; + max_group_size_ = std::max(max_group_size_, n); + } + + double sum_weights = 0; + auto n_groups = Groups(); + auto weight = common::MakeOptionalWeights(ctx, info.weights_); + for (bst_omp_uint k = 0; k < n_groups; ++k) { + sum_weights += weight[k]; + } + weight_norm_ = static_cast(n_groups) / sum_weights; + + auto const& h_label = info.labels.HostView().Slice(linalg::All(), 0); + is_binary_ = IsBinaryRel( + h_label, [ctx](auto beg, auto end, auto op) { return common::AllOf(ctx, beg, end, op); }); +} + +common::Span RankingCache::MakeRankOnCPU(Context const* ctx, + common::Span predt) { + auto gptr = this->DataGroupPtr(ctx); + auto rank = this->sorted_idx_cache_.HostSpan(); + CHECK_EQ(rank.size(), predt.size()); + + common::ParallelFor(this->Groups(), ctx->Threads(), [&](auto g) { + auto cnt = gptr[g + 1] - gptr[g]; + auto g_predt = predt.subspan(gptr[g], cnt); + auto g_rank = rank.subspan(gptr[g], cnt); + auto sorted_idx = common::ArgSort( + ctx, g_predt.data(), g_predt.data() + g_predt.size(), std::greater<>{}); + CHECK_EQ(g_rank.size(), sorted_idx.size()); + std::copy_n(sorted_idx.data(), sorted_idx.size(), g_rank.data()); + }); + + return rank; +} + +#if !defined(XGBOOST_USE_CUDA) +void RankingCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUSupport(); } +common::Span RankingCache::MakeRankOnCUDA(Context const*, + common::Span) { + common::AssertGPUSupport(); + return {}; +} +#endif // !defined() + +void NDCGCache::InitOnCPU(Context const* ctx, MetaInfo const& info) { + auto const h_group_ptr = this->DataGroupPtr(ctx); + + discounts_.Resize(MaxGroupSize(), 0); + auto& h_discounts = discounts_.HostVector(); + for (std::size_t i = 0; i < MaxGroupSize(); ++i) { + h_discounts[i] = CalcDCGDiscount(i); + } + + auto n_groups = h_group_ptr.size() - 1; + auto h_labels = info.labels.HostView().Slice(linalg::All(), 0); + + CheckNDCGLabels(this->Param(), h_labels, + [&](auto beg, auto end, auto op) { return common::NoneOf(ctx, beg, end, op); }); + + inv_idcg_.Reshape(n_groups); + auto h_inv_idcg = inv_idcg_.HostView(); + std::size_t topk = this->Param().TopK(); + auto const exp_gain = this->Param().ndcg_exp_gain; + + common::ParallelFor(n_groups, ctx->Threads(), [&](auto g) { + auto g_labels = h_labels.Slice(linalg::Range(h_group_ptr[g], h_group_ptr[g + 1])); + auto sorted_idx = common::ArgSort(ctx, linalg::cbegin(g_labels), + linalg::cend(g_labels), std::greater<>{}); + + double idcg{0.0}; + for (std::size_t i = 0; i < std::min(g_labels.Size(), topk); ++i) { + if (exp_gain) { + idcg += h_discounts[i] * CalcDCGGain(g_labels(sorted_idx[i])); + } else { + idcg += h_discounts[i] * g_labels(sorted_idx[i]); + } + } + h_inv_idcg(g) = CalcInvIDCG(idcg); + }); +} + +#if !defined(XGBOOST_USE_CUDA) +void NDCGCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUSupport(); } +#endif // !defined(XGBOOST_USE_CUDA) + +DMLC_REGISTER_PARAMETER(LambdaRankParam); + +void MAPCache::InitOnCPU(Context const* ctx, MetaInfo const& info) { + auto const& h_label = info.labels.HostView().Slice(linalg::All(), 0); + CheckMapLabels(h_label, + [ctx](auto beg, auto end, auto op) { return common::AllOf(ctx, beg, end, op); }); +} + +#if !defined(XGBOOST_USE_CUDA) +void MAPCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUSupport(); } +#endif // !defined(XGBOOST_USE_CUDA) + +std::string ParseMetricName(StringView name, StringView param, position_t* topn, bool* minus) { std::string out_name; if (!param.empty()) { std::ostringstream os; @@ -30,5 +149,18 @@ std::string MakeMetricName(StringView name, StringView param, std::uint32_t* top } return out_name; } -} // namespace ltr -} // namespace xgboost + +std::string MakeMetricName(StringView name, position_t topn, bool minus) { + std::ostringstream ss; + if (topn == LambdaRankParam::NotSet()) { + ss << name; + } else { + ss << name << "@" << topn; + } + if (minus) { + ss << "-"; + } + std::string out_name = ss.str(); + return out_name; +} +} // namespace xgboost::ltr diff --git a/src/common/ranking_utils.cu b/src/common/ranking_utils.cu new file mode 100644 index 000000000000..5ee1ed36c367 --- /dev/null +++ b/src/common/ranking_utils.cu @@ -0,0 +1,213 @@ +/** + * Copyright 2023 by XGBoost Contributors + */ +#include // for make_counting_iterator +#include // for none_of, all_off +#include // for pair, make_pair +#include // for reduce +#include // for inclusive_scan + +#include // for size_t + +#include "algorithm.cuh" // for SegmentedArgSort +#include "cuda_context.cuh" // for CUDAContext +#include "device_helpers.cuh" // for MakeTransformIterator, LaunchN +#include "optional_weight.h" // for MakeOptionalWeights, OptionalWeights +#include "ranking_utils.cuh" // for ThreadsForMean +#include "ranking_utils.h" +#include "threading_utils.cuh" // for SegmentedTrapezoidThreads +#include "xgboost/base.h" // for XGBOOST_DEVICE +#include "xgboost/context.h" // for Context +#include "xgboost/linalg.h" // for VectorView +#include "xgboost/span.h" // for Span + +namespace xgboost::ltr { +namespace cuda_impl { +void CalcQueriesDCG(Context const* ctx, linalg::VectorView d_labels, + common::Span d_sorted_idx, bool exp_gain, + common::Span d_group_ptr, std::size_t k, + linalg::VectorView out_dcg) { + CHECK_EQ(d_group_ptr.size() - 1, out_dcg.Size()); + using IdxGroup = thrust::pair; + auto group_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ull), [=] XGBOOST_DEVICE(std::size_t idx) { + return thrust::make_pair(idx, dh::SegmentId(d_group_ptr, idx)); // NOLINT + }); + auto value_it = dh::MakeTransformIterator( + group_it, + [exp_gain, d_labels, d_group_ptr, k, + d_sorted_idx] XGBOOST_DEVICE(IdxGroup const& l) -> double { + auto g_begin = d_group_ptr[l.second]; + auto g_size = d_group_ptr[l.second + 1] - g_begin; + + auto idx_in_group = l.first - g_begin; + if (idx_in_group >= k) { + return 0.0; + } + double gain{0.0}; + auto g_sorted_idx = d_sorted_idx.subspan(g_begin, g_size); + auto g_labels = d_labels.Slice(linalg::Range(g_begin, g_begin + g_size)); + + if (exp_gain) { + gain = ltr::CalcDCGGain(g_labels(g_sorted_idx[idx_in_group])); + } else { + gain = g_labels(g_sorted_idx[idx_in_group]); + } + double discount = CalcDCGDiscount(idx_in_group); + return gain * discount; + }); + + CHECK(out_dcg.Contiguous()); + std::size_t bytes; + cub::DeviceSegmentedReduce::Sum(nullptr, bytes, value_it, out_dcg.Values().data(), + d_group_ptr.size() - 1, d_group_ptr.data(), + d_group_ptr.data() + 1, ctx->CUDACtx()->Stream()); + dh::TemporaryArray temp(bytes); + cub::DeviceSegmentedReduce::Sum(temp.data().get(), bytes, value_it, out_dcg.Values().data(), + d_group_ptr.size() - 1, d_group_ptr.data(), + d_group_ptr.data() + 1, ctx->CUDACtx()->Stream()); +} + +void CalcQueriesInvIDCG(Context const* ctx, linalg::VectorView d_labels, + common::Span d_group_ptr, + linalg::VectorView out_inv_IDCG, ltr::LambdaRankParam const& p) { + CHECK_GE(d_group_ptr.size(), 2ul); + size_t n_groups = d_group_ptr.size() - 1; + CHECK_EQ(out_inv_IDCG.Size(), n_groups); + dh::device_vector sorted_idx(d_labels.Size()); + auto d_sorted_idx = dh::ToSpan(sorted_idx); + common::SegmentedArgSort(ctx, d_labels.Values(), d_group_ptr, d_sorted_idx); + CalcQueriesDCG(ctx, d_labels, d_sorted_idx, p.ndcg_exp_gain, d_group_ptr, p.TopK(), out_inv_IDCG); + dh::LaunchN(out_inv_IDCG.Size(), ctx->CUDACtx()->Stream(), + [out_inv_IDCG] XGBOOST_DEVICE(size_t idx) mutable { + double idcg = out_inv_IDCG(idx); + out_inv_IDCG(idx) = CalcInvIDCG(idcg); + }); +} +} // namespace cuda_impl + +namespace { +struct CheckNDCGOp { + CUDAContext const* cuctx; + template + bool operator()(It beg, It end, Op op) { + return thrust::none_of(cuctx->CTP(), beg, end, op); + } +}; +struct CheckMAPOp { + CUDAContext const* cuctx; + template + bool operator()(It beg, It end, Op op) { + return thrust::all_of(cuctx->CTP(), beg, end, op); + } +}; + +struct ThreadGroupOp { + common::Span d_group_ptr; + std::size_t n_pairs; + + common::Span out_thread_group_ptr; + + XGBOOST_DEVICE void operator()(std::size_t i) { + out_thread_group_ptr[i + 1] = + cuda_impl::ThreadsForMean(d_group_ptr[i + 1] - d_group_ptr[i], n_pairs); + } +}; + +struct GroupSizeOp { + common::Span d_group_ptr; + + XGBOOST_DEVICE auto operator()(std::size_t i) -> std::size_t { + return d_group_ptr[i + 1] - d_group_ptr[i]; + } +}; + +struct WeightOp { + common::OptionalWeights d_weight; + XGBOOST_DEVICE auto operator()(std::size_t i) -> double { return d_weight[i]; } +}; +} // anonymous namespace + +void RankingCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) { + CUDAContext const* cuctx = ctx->CUDACtx(); + + group_ptr_.SetDevice(ctx->gpu_id); + if (info.group_ptr_.empty()) { + group_ptr_.Resize(2, 0); + group_ptr_.HostVector()[1] = info.num_row_; + } else { + auto const& h_group_ptr = info.group_ptr_; + group_ptr_.Resize(h_group_ptr.size()); + auto d_group_ptr = group_ptr_.DeviceSpan(); + dh::safe_cuda(cudaMemcpyAsync(d_group_ptr.data(), h_group_ptr.data(), d_group_ptr.size_bytes(), + cudaMemcpyHostToDevice, cuctx->Stream())); + } + + auto d_group_ptr = DataGroupPtr(ctx); + std::size_t n_groups = Groups(); + + auto it = dh::MakeTransformIterator(thrust::make_counting_iterator(0ul), + GroupSizeOp{d_group_ptr}); + max_group_size_ = + thrust::reduce(cuctx->CTP(), it, it + n_groups, 0ul, thrust::maximum{}); + + threads_group_ptr_.SetDevice(ctx->gpu_id); + threads_group_ptr_.Resize(n_groups + 1, 0); + auto d_threads_group_ptr = threads_group_ptr_.DeviceSpan(); + if (param_.HasTruncation()) { + n_cuda_threads_ = + common::SegmentedTrapezoidThreads(d_group_ptr, d_threads_group_ptr, Param().NumPair()); + } else { + auto n_pairs = Param().NumPair(); + dh::LaunchN(n_groups, cuctx->Stream(), + ThreadGroupOp{d_group_ptr, n_pairs, d_threads_group_ptr}); + thrust::inclusive_scan(cuctx->CTP(), dh::tcbegin(d_threads_group_ptr), + dh::tcend(d_threads_group_ptr), dh::tbegin(d_threads_group_ptr)); + n_cuda_threads_ = info.num_row_ * param_.NumPair(); + } + + sorted_idx_cache_.SetDevice(ctx->gpu_id); + sorted_idx_cache_.Resize(info.labels.Size(), 0); + + auto weight = common::MakeOptionalWeights(ctx, info.weights_); + auto w_it = + dh::MakeTransformIterator(thrust::make_counting_iterator(0ul), WeightOp{weight}); + weight_norm_ = static_cast(n_groups) / thrust::reduce(w_it, w_it + n_groups); + + auto const d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0); + is_binary_ = IsBinaryRel(d_label, CheckMAPOp{ctx->CUDACtx()}); +} + +common::Span RankingCache::MakeRankOnCUDA(Context const* ctx, + common::Span predt) { + auto d_sorted_idx = sorted_idx_cache_.DeviceSpan(); + auto d_group_ptr = DataGroupPtr(ctx); + common::SegmentedArgSort(ctx, predt, d_group_ptr, d_sorted_idx); + return d_sorted_idx; +} + +void NDCGCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) { + CUDAContext const* cuctx = ctx->CUDACtx(); + auto labels = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0); + CheckNDCGLabels(this->Param(), labels, CheckNDCGOp{cuctx}); + + auto d_group_ptr = this->DataGroupPtr(ctx); + + std::size_t n_groups = d_group_ptr.size() - 1; + inv_idcg_ = linalg::Zeros(ctx, n_groups); + auto d_inv_idcg = inv_idcg_.View(ctx->gpu_id); + cuda_impl::CalcQueriesInvIDCG(ctx, labels, d_group_ptr, d_inv_idcg, this->Param()); + CHECK_GE(this->Param().NumPair(), 1ul); + + discounts_.SetDevice(ctx->gpu_id); + discounts_.Resize(MaxGroupSize()); + auto d_discount = discounts_.DeviceSpan(); + dh::LaunchN(MaxGroupSize(), cuctx->Stream(), + [=] XGBOOST_DEVICE(std::size_t i) { d_discount[i] = CalcDCGDiscount(i); }); +} + +void MAPCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) { + auto const d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0); + CheckMapLabels(d_label, CheckMAPOp{ctx->CUDACtx()}); +} +} // namespace xgboost::ltr diff --git a/src/common/ranking_utils.cuh b/src/common/ranking_utils.cuh new file mode 100644 index 000000000000..297f5157ecfb --- /dev/null +++ b/src/common/ranking_utils.cuh @@ -0,0 +1,40 @@ +/** + * Copyright 2023 by XGBoost Contributors + */ +#ifndef XGBOOST_COMMON_RANKING_UTILS_CUH_ +#define XGBOOST_COMMON_RANKING_UTILS_CUH_ + +#include // for size_t + +#include "ranking_utils.h" // for LambdaRankParam +#include "xgboost/base.h" // for bst_group_t, XGBOOST_DEVICE +#include "xgboost/context.h" // for Context +#include "xgboost/linalg.h" // for VectorView +#include "xgboost/span.h" // for Span + +namespace xgboost { +namespace ltr { +namespace cuda_impl { +void CalcQueriesDCG(Context const *ctx, linalg::VectorView d_labels, + common::Span d_sorted_idx, bool exp_gain, + common::Span d_group_ptr, std::size_t k, + linalg::VectorView out_dcg); + +void CalcQueriesInvIDCG(Context const *ctx, linalg::VectorView d_labels, + common::Span d_group_ptr, + linalg::VectorView out_inv_IDCG, ltr::LambdaRankParam const &p); + +// Functions for creating number of threads for CUDA, and getting back the number of pairs +// from the number of threads. +XGBOOST_DEVICE __forceinline__ std::size_t ThreadsForMean(std::size_t group_size, + std::size_t n_pairs) { + return group_size * n_pairs; +} +XGBOOST_DEVICE __forceinline__ std::size_t PairsForGroup(std::size_t n_threads, + std::size_t group_size) { + return n_threads / group_size; +} +} // namespace cuda_impl +} // namespace ltr +} // namespace xgboost +#endif // XGBOOST_COMMON_RANKING_UTILS_CUH_ diff --git a/src/common/ranking_utils.h b/src/common/ranking_utils.h index 35ee36c2185d..b3259f3f154e 100644 --- a/src/common/ranking_utils.h +++ b/src/common/ranking_utils.h @@ -3,17 +3,434 @@ */ #ifndef XGBOOST_COMMON_RANKING_UTILS_H_ #define XGBOOST_COMMON_RANKING_UTILS_H_ +#include // for min +#include // for log2, fabs, floor +#include // for size_t +#include // for uint32_t, uint8_t, int32_t +#include // for numeric_limits +#include // for char_traits, string +#include // for vector -#include // std::size_t -#include // std::uint32_t -#include // std::string +#include "./math.h" // for CloseTo +#include "dmlc/parameter.h" // for FieldEntry, DMLC_DECLARE_FIELD +#include "error_msg.h" // for GroupWeight, GroupSize +#include "xgboost/base.h" // for XGBOOST_DEVICE, bst_group_t +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for MetaInfo +#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/linalg.h" // for Vector, VectorView, Tensor +#include "xgboost/logging.h" // for LogCheck_EQ, CHECK_EQ, CHECK +#include "xgboost/parameter.h" // for XGBoostParameter +#include "xgboost/span.h" // for Span +#include "xgboost/string_view.h" // for StringView -#include "xgboost/string_view.h" // StringView +namespace xgboost::ltr { +/** + * \brief Relevance degree + */ +using rel_degree_t = std::uint32_t; // NOLINT +/** + * \brief top-k position + */ +using position_t = std::uint32_t; // NOLINT + +/** + * \brief Maximum relevance degree for NDCG + */ +constexpr std::size_t MaxRel() { return sizeof(rel_degree_t) * 8 - 1; } +static_assert(MaxRel() == 31); + +XGBOOST_DEVICE inline double CalcDCGGain(rel_degree_t label) { + return static_cast((1u << label) - 1); +} + +XGBOOST_DEVICE inline double CalcDCGDiscount(std::size_t idx) { + return 1.0 / std::log2(static_cast(idx) + 2.0); +} + +XGBOOST_DEVICE inline double CalcInvIDCG(double idcg) { + auto inv_idcg = (idcg == 0.0 ? 0.0 : (1.0 / idcg)); // handle irrelevant document + return inv_idcg; +} + +enum class PairMethod : std::int32_t { + kTopK = 0, + kMean = 1, +}; +} // namespace xgboost::ltr + +DECLARE_FIELD_ENUM_CLASS(xgboost::ltr::PairMethod); + +namespace xgboost::ltr { +struct LambdaRankParam : public XGBoostParameter { + private: + static constexpr position_t DefaultK() { return 32; } + static constexpr position_t DefaultSamplePairs() { return 1; } // fixme: better dft + + protected: + // pairs + // should be accessed by getter for auto configuration. + // nolint so that we can keep the string name. + PairMethod lambdarank_pair_method; // NOLINT + std::size_t lambdarank_num_pair_per_sample; // NOLINT + + public: + static constexpr position_t NotSet() { return std::numeric_limits::max(); } + + // unbiased + bool lambdarank_unbiased; + double lambdarank_bias_norm; + // ndcg + bool ndcg_exp_gain; + + bool operator==(LambdaRankParam const& that) const { + return lambdarank_pair_method == that.lambdarank_pair_method && + lambdarank_num_pair_per_sample == that.lambdarank_num_pair_per_sample && + lambdarank_unbiased == that.lambdarank_unbiased && + lambdarank_bias_norm == that.lambdarank_bias_norm && ndcg_exp_gain == that.ndcg_exp_gain; + } + bool operator!=(LambdaRankParam const& that) const { return !(*this == that); } + + [[nodiscard]] double Regularizer() const { return 1.0 / (1.0 + this->lambdarank_bias_norm); } + + /** + * \brief Get number of pairs for each sample + */ + [[nodiscard]] position_t NumPair() const { + if (lambdarank_num_pair_per_sample == NotSet()) { + switch (lambdarank_pair_method) { + case PairMethod::kMean: + return DefaultSamplePairs(); + case PairMethod::kTopK: + return DefaultK(); + default: + LOG(FATAL) << "Unreachable."; + } + } else { + return lambdarank_num_pair_per_sample; + } + LOG(FATAL) << "Unreachable."; + return 0; + } + + [[nodiscard]] bool HasTruncation() const { return lambdarank_pair_method == PairMethod::kTopK; } + + // Used for evaluation metric and cache initialization, iterate through top-k or the whole list + [[nodiscard]] auto TopK() const { + if (HasTruncation()) { + return NumPair(); + } else { + return NotSet(); + } + } + + DMLC_DECLARE_PARAMETER(LambdaRankParam) { + DMLC_DECLARE_FIELD(lambdarank_pair_method) + .set_default(PairMethod::kMean) + .add_enum("mean", PairMethod::kMean) + .add_enum("topk", PairMethod::kTopK) + .describe("Method for constructing pairs."); + DMLC_DECLARE_FIELD(lambdarank_num_pair_per_sample) + .set_default(NotSet()) + .set_lower_bound(1) + .describe("Number of pairs for each sample in the list."); + DMLC_DECLARE_FIELD(lambdarank_unbiased) + .set_default(false) + .describe("Unbiased lambda mart. Use IPW to debias click position"); + DMLC_DECLARE_FIELD(lambdarank_bias_norm) + .set_default(2.0) + .set_lower_bound(0.0) + .describe("Lp regularization for unbiased lambdarank."); + DMLC_DECLARE_FIELD(ndcg_exp_gain) + .set_default(true) + .describe("When set to true, the label gain is 2^rel - 1, otherwise it's rel."); + } +}; + +/** + * \brief Common cached items for ranking tasks. + */ +class RankingCache { + private: + void InitOnCPU(Context const* ctx, MetaInfo const& info); + void InitOnCUDA(Context const* ctx, MetaInfo const& info); + // Cached parameter + LambdaRankParam param_; + // offset to data groups. + HostDeviceVector group_ptr_; + // store the sorted index of prediction. + HostDeviceVector sorted_idx_cache_; + // Maximum size of group + std::size_t max_group_size_{0}; + // Normalization for weight + double weight_norm_{1.0}; + // Whether label is binary + bool is_binary_{false}; + /** + * CUDA cache + */ + // offset to threads assigned to each group for gradient calculation + HostDeviceVector threads_group_ptr_; + // Sorted index of label for finding buckets. + HostDeviceVector y_sorted_idx_cache_; + // Cached labels sorted by the model + HostDeviceVector y_ranked_by_model_; + // store rounding factor for objective for each group + linalg::Vector roundings_; + // rounding factor for cost + HostDeviceVector cost_rounding_; + // temporary storage for creating rounding factors. Stored as byte to avoid having cuda + // data structure in here. + HostDeviceVector max_lambdas_; + // total number of cuda threads used for gradient calculation + std::size_t n_cuda_threads_{0}; + + // Create model rank list on GPU + common::Span MakeRankOnCUDA(Context const* ctx, + common::Span predt); + // Create model rank list on CPU + common::Span MakeRankOnCPU(Context const* ctx, + common::Span predt); -namespace xgboost { -namespace ltr { + protected: + [[nodiscard]] std::size_t MaxGroupSize() const { return max_group_size_; } + + public: + RankingCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p) : param_{p} { + CHECK(param_.GetInitialised()); + if (!info.group_ptr_.empty()) { + CHECK_EQ(info.group_ptr_.back(), info.labels.Size()) + << error::GroupSize() << "the size of label."; + } + if (ctx->IsCPU()) { + this->InitOnCPU(ctx, info); + } else { + this->InitOnCUDA(ctx, info); + } + if (!info.weights_.Empty()) { + CHECK_EQ(Groups(), info.weights_.Size()) << error::GroupWeight(); + } + } + [[nodiscard]] std::size_t MaxPositionSize() const { + // Use truncation level as bound. + if (param_.HasTruncation()) { + return param_.NumPair(); + } + // Hardcoded maximum size of positions to track. We don't need too many of them as the + // bias decreases exponentially. + return std::min(max_group_size_, static_cast(32)); + } + // Constructed as [1, n_samples] if group ptr is not supplied by the user + common::Span DataGroupPtr(Context const* ctx) const { + group_ptr_.SetDevice(ctx->gpu_id); + return ctx->IsCPU() ? group_ptr_.ConstHostSpan() : group_ptr_.ConstDeviceSpan(); + } + + [[nodiscard]] auto const& Param() const { return param_; } + [[nodiscard]] std::size_t Groups() const { return group_ptr_.Size() - 1; } + [[nodiscard]] double WeightNorm() const { return weight_norm_; } + [[nodiscard]] bool IsBinary() const { return is_binary_; } + + // Create a rank list by model prediction + common::Span SortedIdx(Context const* ctx, common::Span predt) { + if (sorted_idx_cache_.Empty()) { + sorted_idx_cache_.SetDevice(ctx->gpu_id); + sorted_idx_cache_.Resize(predt.size()); + } + if (ctx->IsCPU()) { + return this->MakeRankOnCPU(ctx, predt); + } else { + return this->MakeRankOnCUDA(ctx, predt); + } + } + // The function simply returns a uninitialized buffer as this is only used by the + // objective for creating pairs. + common::Span SortedIdxY(Context const* ctx, std::size_t n_samples) { + CHECK(ctx->IsCUDA()); + if (y_sorted_idx_cache_.Empty()) { + y_sorted_idx_cache_.SetDevice(ctx->gpu_id); + y_sorted_idx_cache_.Resize(n_samples); + } + return y_sorted_idx_cache_.DeviceSpan(); + } + common::Span RankedY(Context const* ctx, std::size_t n_samples) { + CHECK(ctx->IsCUDA()); + if (y_ranked_by_model_.Empty()) { + y_ranked_by_model_.SetDevice(ctx->gpu_id); + y_ranked_by_model_.Resize(n_samples); + } + return y_ranked_by_model_.DeviceSpan(); + } + + // CUDA cache getters, the cache is shared between metric and objective, some of these + // fields are lazy initialized to avoid unnecessary allocation. + [[nodiscard]] common::Span CUDAThreadsGroupPtr() const { + CHECK(!threads_group_ptr_.Empty()); + return threads_group_ptr_.ConstDeviceSpan(); + } + [[nodiscard]] std::size_t CUDAThreads() const { return n_cuda_threads_; } + + linalg::VectorView CUDARounding(Context const* ctx) { + if (roundings_.Size() == 0) { + roundings_.SetDevice(ctx->gpu_id); + roundings_.Reshape(Groups()); + } + return roundings_.View(ctx->gpu_id); + } + common::Span CUDACostRounding(Context const* ctx) { + if (cost_rounding_.Size() == 0) { + cost_rounding_.SetDevice(ctx->gpu_id); + cost_rounding_.Resize(1); + } + return cost_rounding_.DeviceSpan(); + } + template + common::Span MaxLambdas(Context const* ctx, std::size_t n) { + max_lambdas_.SetDevice(ctx->gpu_id); + std::size_t bytes = n * sizeof(Type); + if (bytes != max_lambdas_.Size()) { + max_lambdas_.Resize(bytes); + } + return common::Span{reinterpret_cast(max_lambdas_.DevicePointer()), n}; + } +}; + +class NDCGCache : public RankingCache { + // NDCG discount + HostDeviceVector discounts_; + // 1.0 / IDCG + linalg::Vector inv_idcg_; + /** + * CUDA cache + */ + // store the intermediate DCG calculation result for metric + linalg::Vector dcg_; + + public: + void InitOnCPU(Context const* ctx, MetaInfo const& info); + void InitOnCUDA(Context const* ctx, MetaInfo const& info); + + public: + NDCGCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p) + : RankingCache{ctx, info, p} { + if (ctx->IsCPU()) { + this->InitOnCPU(ctx, info); + } else { + this->InitOnCUDA(ctx, info); + } + } + + linalg::VectorView InvIDCG(Context const* ctx) const { + return inv_idcg_.View(ctx->gpu_id); + } + common::Span Discount(Context const* ctx) const { + return ctx->IsCPU() ? discounts_.ConstHostSpan() : discounts_.ConstDeviceSpan(); + } + linalg::VectorView Dcg(Context const* ctx) { + if (dcg_.Size() == 0) { + dcg_.SetDevice(ctx->gpu_id); + dcg_.Reshape(this->Groups()); + } + return dcg_.View(ctx->gpu_id); + } +}; + +/** + * \brief Validate label for NDCG + * + * \tparam NoneOf Implementation of std::none_of. Specified as a parameter to reuse the + * check for both CPU and GPU. + */ +template +void CheckNDCGLabels(ltr::LambdaRankParam const& p, linalg::VectorView labels, + NoneOf none_of) { + auto d_labels = labels.Values(); + if (p.ndcg_exp_gain) { + auto label_is_integer = + none_of(d_labels.data(), d_labels.data() + d_labels.size(), [] XGBOOST_DEVICE(float v) { + auto l = std::floor(v); + return std::fabs(l - v) > kRtEps || v < 0.0f; + }); + CHECK(label_is_integer) + << "When using relevance degree as target, label must be either 0 or positive integer."; + } + + if (p.ndcg_exp_gain) { + auto label_is_valid = none_of(d_labels.data(), d_labels.data() + d_labels.size(), + [] XGBOOST_DEVICE(ltr::rel_degree_t v) { return v > MaxRel(); }); + CHECK(label_is_valid) << "Relevance degress must be lesser than or equal to " << MaxRel() + << " when the exponential NDCG gain function is used. " + << "Set `ndcg_exp_gain` to false to use custom DCG gain."; + } +} + +template +bool IsBinaryRel(linalg::VectorView label, AllOf all_of) { + auto s_label = label.Values(); + return all_of(s_label.data(), s_label.data() + s_label.size(), [] XGBOOST_DEVICE(float y) { + return std::abs(y - 1.0f) < kRtEps || std::abs(y - 0.0f) < kRtEps; + }); +} /** - * \brief Construct name for ranking metric given parameters. + * \brief Validate label for MAP + * + * \tparam Implementation of std::all_of. Specified as a parameter to reuse the check for + * both CPU and GPU. + */ +template +void CheckMapLabels(linalg::VectorView label, AllOf all_of) { + auto s_label = label.Values(); + auto is_binary = IsBinaryRel(label, all_of); + CHECK(is_binary) << "MAP can only be used with binary labels."; +} + +class MAPCache : public RankingCache { + // Total number of relevant documents for each group + HostDeviceVector n_rel_; + // \sum l_k/k + HostDeviceVector acc_; + HostDeviceVector map_; + // Number of samples in this dataset. + std::size_t n_samples_{0}; + + void InitOnCPU(Context const* ctx, MetaInfo const& info); + void InitOnCUDA(Context const* ctx, MetaInfo const& info); + + public: + MAPCache(Context const* ctx, MetaInfo const& info, LambdaRankParam const& p) + : RankingCache{ctx, info, p}, n_samples_{static_cast(info.num_row_)} { + if (ctx->IsCPU()) { + this->InitOnCPU(ctx, info); + } else { + this->InitOnCUDA(ctx, info); + } + } + + common::Span NumRelevant(Context const* ctx) { + if (n_rel_.Empty()) { + n_rel_.SetDevice(ctx->gpu_id); + n_rel_.Resize(n_samples_); + } + return ctx->IsCPU() ? n_rel_.HostSpan() : n_rel_.DeviceSpan(); + } + common::Span Acc(Context const* ctx) { + if (acc_.Empty()) { + acc_.SetDevice(ctx->gpu_id); + acc_.Resize(n_samples_); + } + return ctx->IsCPU() ? acc_.HostSpan() : acc_.DeviceSpan(); + } + common::Span Map(Context const* ctx) { + if (map_.Empty()) { + map_.SetDevice(ctx->gpu_id); + map_.Resize(this->Groups()); + } + return ctx->IsCPU() ? map_.HostSpan() : map_.DeviceSpan(); + } +}; + +/** + * \brief Parse name for ranking metric given parameters. * * \param [in] name Null terminated string for metric name * \param [in] param Null terminated string for parameter like the `3-` in `ndcg@3-`. @@ -23,7 +440,11 @@ namespace ltr { * * \return The name of the metric. */ -std::string MakeMetricName(StringView name, StringView param, std::uint32_t* topn, bool* minus); -} // namespace ltr -} // namespace xgboost +std::string ParseMetricName(StringView name, StringView param, position_t* topn, bool* minus); + +/** + * \brief Parse name for ranking metric given parameters. + */ +std::string MakeMetricName(StringView name, position_t topn, bool minus); +} // namespace xgboost::ltr #endif // XGBOOST_COMMON_RANKING_UTILS_H_ diff --git a/src/common/threading_utils.cuh b/src/common/threading_utils.cuh index c21d312d2e03..db5fe82f94ac 100644 --- a/src/common/threading_utils.cuh +++ b/src/common/threading_utils.cuh @@ -43,36 +43,33 @@ XGBOOST_DEVICE inline std::size_t DiscreteTrapezoidArea(std::size_t n, std::size * with h <= n */ template -inline size_t -SegmentedTrapezoidThreads(xgboost::common::Span group_ptr, - xgboost::common::Span out_group_threads_ptr, - size_t h) { +std::size_t SegmentedTrapezoidThreads(xgboost::common::Span group_ptr, + xgboost::common::Span out_group_threads_ptr, + std::size_t h) { CHECK_GE(group_ptr.size(), 1); CHECK_EQ(group_ptr.size(), out_group_threads_ptr.size()); - dh::LaunchN( - group_ptr.size(), [=] XGBOOST_DEVICE(size_t idx) { - if (idx == 0) { - out_group_threads_ptr[0] = 0; - return; - } + dh::LaunchN(group_ptr.size(), [=] XGBOOST_DEVICE(std::size_t idx) { + if (idx == 0) { + out_group_threads_ptr[0] = 0; + return; + } - size_t cnt = static_cast(group_ptr[idx] - group_ptr[idx - 1]); - out_group_threads_ptr[idx] = DiscreteTrapezoidArea(cnt, h); - }); + std::size_t cnt = static_cast(group_ptr[idx] - group_ptr[idx - 1]); + out_group_threads_ptr[idx] = DiscreteTrapezoidArea(cnt, h); + }); dh::InclusiveSum(out_group_threads_ptr.data(), out_group_threads_ptr.data(), out_group_threads_ptr.size()); - size_t total = 0; - dh::safe_cuda(cudaMemcpy( - &total, out_group_threads_ptr.data() + out_group_threads_ptr.size() - 1, - sizeof(total), cudaMemcpyDeviceToHost)); + std::size_t total = 0; + dh::safe_cuda(cudaMemcpy(&total, out_group_threads_ptr.data() + out_group_threads_ptr.size() - 1, + sizeof(total), cudaMemcpyDeviceToHost)); return total; } /** * Called inside kernel to obtain coordinate from trapezoid grid. */ -XGBOOST_DEVICE inline void UnravelTrapeziodIdx(size_t i_idx, size_t n, - size_t *out_i, size_t *out_j) { +XGBOOST_DEVICE inline void UnravelTrapeziodIdx(std::size_t i_idx, std::size_t n, std::size_t *out_i, + std::size_t *out_j) { auto &i = *out_i; auto &j = *out_j; double idx = static_cast(i_idx); diff --git a/src/common/threading_utils.h b/src/common/threading_utils.h index a52695e02590..d80008cc0809 100644 --- a/src/common/threading_utils.h +++ b/src/common/threading_utils.h @@ -8,9 +8,11 @@ #include #include -#include // std::int32_t +#include // for int32_t +#include // for malloc, free #include -#include // std::is_signed +#include // for bad_alloc +#include // for is_signed #include #include "xgboost/logging.h" @@ -266,7 +268,7 @@ class MemStackAllocator { if (MaxStackSize >= required_size_) { ptr_ = stack_mem_; } else { - ptr_ = reinterpret_cast(malloc(required_size_ * sizeof(T))); + ptr_ = reinterpret_cast(std::malloc(required_size_ * sizeof(T))); } if (!ptr_) { throw std::bad_alloc{}; @@ -278,7 +280,7 @@ class MemStackAllocator { ~MemStackAllocator() { if (required_size_ > MaxStackSize) { - free(ptr_); + std::free(ptr_); } } T& operator[](size_t i) { return ptr_[i]; } diff --git a/src/data/data.cc b/src/data/data.cc index d24048a2ab23..238aaefd47ce 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -7,11 +7,13 @@ #include #include +#include #include #include "../collective/communicator-inl.h" #include "../common/algorithm.h" // StableSort #include "../common/api_entry.h" // XGBAPIThreadLocalEntry +#include "../common/error_msg.h" // for GroupWeight, GroupSize #include "../common/group_data.h" #include "../common/io.h" #include "../common/linalg_op.h" @@ -32,6 +34,7 @@ #include "xgboost/context.h" #include "xgboost/host_device_vector.h" #include "xgboost/learner.h" +#include "xgboost/linalg.h" // Vector #include "xgboost/logging.h" #include "xgboost/string_view.h" #include "xgboost/version_config.h" @@ -485,7 +488,7 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) { } // uint info if (key == "group") { - linalg::Tensor t; + linalg::Vector t; CopyTensorInfoImpl(ctx, arr, &t); auto const& h_groups = t.Data()->HostVector(); group_ptr_.clear(); @@ -510,6 +513,7 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) { data::ValidateQueryGroup(group_ptr_); return; } + // float info linalg::Tensor t; CopyTensorInfoImpl<1>(ctx, arr, &t); @@ -700,58 +704,63 @@ void MetaInfo::Extend(MetaInfo const& that, bool accumulate_rows, bool check_col } } +namespace { +template +void CheckDevice(std::int32_t device, HostDeviceVector const& v) { + CHECK(v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device) + << "Data is resided on a different device than `gpu_id`. " + << "Device that data is on: " << v.DeviceIdx() << ", " + << "`gpu_id` for XGBoost: " << device; +} +template +void CheckDevice(std::int32_t device, linalg::Tensor const& v) { + CheckDevice(device, *v.Data()); +} +} // anonymous namespace + void MetaInfo::Validate(std::int32_t device) const { if (group_ptr_.size() != 0 && weights_.Size() != 0) { - CHECK_EQ(group_ptr_.size(), weights_.Size() + 1) - << "Size of weights must equal to number of groups when ranking " - "group is used."; + CHECK_EQ(group_ptr_.size(), weights_.Size() + 1) << error::GroupWeight(); return; } if (group_ptr_.size() != 0) { CHECK_EQ(group_ptr_.back(), num_row_) - << "Invalid group structure. Number of rows obtained from groups " - "doesn't equal to actual number of rows given by data."; + << error::GroupSize() << "the actual number of rows given by data."; } - auto check_device = [device](HostDeviceVector const& v) { - CHECK(v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device) - << "Data is resided on a different device than `gpu_id`. " - << "Device that data is on: " << v.DeviceIdx() << ", " - << "`gpu_id` for XGBoost: " << device; - }; if (weights_.Size() != 0) { CHECK_EQ(weights_.Size(), num_row_) << "Size of weights must equal to number of rows."; - check_device(weights_); + CheckDevice(device, weights_); return; } if (labels.Size() != 0) { CHECK_EQ(labels.Shape(0), num_row_) << "Size of labels must equal to number of rows."; - check_device(*labels.Data()); + CheckDevice(device, labels); return; } if (labels_lower_bound_.Size() != 0) { CHECK_EQ(labels_lower_bound_.Size(), num_row_) << "Size of label_lower_bound must equal to number of rows."; - check_device(labels_lower_bound_); + CheckDevice(device, labels_lower_bound_); return; } if (feature_weights.Size() != 0) { CHECK_EQ(feature_weights.Size(), num_col_) << "Size of feature_weights must equal to number of columns."; - check_device(feature_weights); + CheckDevice(device, feature_weights); } if (labels_upper_bound_.Size() != 0) { CHECK_EQ(labels_upper_bound_.Size(), num_row_) << "Size of label_upper_bound must equal to number of rows."; - check_device(labels_upper_bound_); + CheckDevice(device, labels_upper_bound_); return; } CHECK_LE(num_nonzero_, num_col_ * num_row_); if (base_margin_.Size() != 0) { CHECK_EQ(base_margin_.Size() % num_row_, 0) << "Size of base margin must be a multiple of number of rows."; - check_device(*base_margin_.Data()); + CheckDevice(device, base_margin_); } } @@ -1025,6 +1034,8 @@ SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const { bool SparsePage::IsIndicesSorted(int32_t n_threads) const { auto& h_offset = this->offset.HostVector(); auto& h_data = this->data.HostVector(); + n_threads = std::max(std::min(static_cast(n_threads), this->Size()), + static_cast(1)); std::vector is_sorted_tloc(n_threads, 0); common::ParallelFor(this->Size(), n_threads, [&](auto i) { auto beg = h_offset[i]; diff --git a/src/learner.cc b/src/learner.cc index 0e47c694cc92..27f8973a6c4b 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -113,7 +113,7 @@ struct LearnerModelParamLegacy : public dmlc::Parameter } // Skip other legacy fields. - Json ToJson() const { + [[nodiscard]] Json ToJson() const { Json obj{Object{}}; char floats[NumericLimits::kToCharsSize]; auto ret = to_chars(floats, floats + NumericLimits::kToCharsSize, base_score); @@ -163,7 +163,7 @@ struct LearnerModelParamLegacy : public dmlc::Parameter from_chars(str.c_str(), str.c_str() + str.size(), base_score); } - LearnerModelParamLegacy ByteSwap() const { + [[nodiscard]] LearnerModelParamLegacy ByteSwap() const { LearnerModelParamLegacy x = *this; dmlc::ByteSwap(&x.base_score, sizeof(x.base_score), 1); dmlc::ByteSwap(&x.num_feature, sizeof(x.num_feature), 1); diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index e06be9730e8f..728d8c230dd3 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -485,9 +485,13 @@ class QuantileError : public MetricNoCache { const char* Name() const override { return "quantile"; } void LoadConfig(Json const& in) override { - auto const& name = get(in["name"]); - CHECK_EQ(name, "quantile"); - FromJson(in["quantile_loss_param"], ¶m_); + auto const& obj = get(in); + auto it = obj.find("quantile_loss_param"); + if (it != obj.cend()) { + FromJson(it->second, ¶m_); + auto const& name = get(in["name"]); + CHECK_EQ(name, "quantile"); + } } void SaveConfig(Json* p_out) const override { auto& out = *p_out; diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index d39c7302ad54..9cac452bee7d 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -20,23 +20,47 @@ // corresponding headers that brings in those function declaration can't be included with CUDA). // This precludes the CPU and GPU logic to coexist inside a .cu file +#include "rank_metric.h" + +#include #include -#include -#include -#include +#include // for min, stable_sort +#include // for array +#include // for abs, log, sqrt +#include // for size_t +#include // for uint32_t +#include // for greater +#include // for numeric_limits +#include // for make_shared, shared_ptr +#include // for accumulate +#include // for to_string +#include // for make_pair +#include // for vector #include "../collective/communicator-inl.h" -#include "../common/algorithm.h" // Sort -#include "../common/math.h" -#include "../common/ranking_utils.h" // MakeMetricName -#include "../common/threading_utils.h" -#include "metric_common.h" -#include "xgboost/host_device_vector.h" +#include "../collective/communicator.h" // for Operation +#include "../common/algorithm.h" // for ArgSort +#include "../common/linalg_op.h" // for cbegin,cend +#include "../common/math.h" // for CmpFirst +#include "../common/optional_weight.h" // for OptionalWeights +#include "../common/ranking_utils.h" // for ParseMetricName, rel_degree_t +#include "../common/ranking_utils.h" // for LambdaRankParam +#include "../common/threading_utils.h" // for ParallelFor +#include "metric_common.h" // for MetricNoCache, EvalRankConfig, GPUMetric +#include "xgboost/base.h" // for Args, bst_group_t +#include "xgboost/cache.h" // for DMatrixCache +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for MetaInfo +#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/json.h" // for Json, String, ToJson, FromJson, get +#include "xgboost/linalg.h" // for MakeTensorView, Vector +#include "xgboost/metric.h" // for Metric, XGBOOST_REGISTER_METRIC +#include "xgboost/span.h" // for Span namespace { -using PredIndPair = std::pair; +using PredIndPair = std::pair; using PredIndPairContainer = std::vector; /* @@ -84,7 +108,6 @@ class PerGroupWeightPolicy { return info.GetWeight(group_id); } }; - } // anonymous namespace namespace xgboost { @@ -234,7 +257,7 @@ struct EvalRank : public MetricNoCache, public EvalRankConfig { protected: explicit EvalRank(const char* name, const char* param) { - this->name = ltr::MakeMetricName(name, param, &topn, &minus); + this->name = ltr::ParseMetricName(name, param, &topn, &minus); } virtual double EvalGroup(PredIndPairContainer *recptr) const = 0; @@ -257,71 +280,6 @@ struct EvalPrecision : public EvalRank { } }; -/*! \brief NDCG: Normalized Discounted Cumulative Gain at N */ -struct EvalNDCG : public EvalRank { - private: - double CalcDCG(const PredIndPairContainer &rec) const { - double sumdcg = 0.0; - for (size_t i = 0; i < rec.size() && i < this->topn; ++i) { - const unsigned rel = rec[i].second; - if (rel != 0) { - sumdcg += ((1 << rel) - 1) / std::log2(i + 2.0); - } - } - return sumdcg; - } - - public: - explicit EvalNDCG(const char* name, const char* param) : EvalRank(name, param) {} - - double EvalGroup(PredIndPairContainer *recptr) const override { - PredIndPairContainer &rec(*recptr); - std::stable_sort(rec.begin(), rec.end(), common::CmpFirst); - double dcg = CalcDCG(rec); - std::stable_sort(rec.begin(), rec.end(), common::CmpSecond); - double idcg = CalcDCG(rec); - if (idcg == 0.0f) { - if (this->minus) { - return 0.0f; - } else { - return 1.0f; - } - } - return dcg/idcg; - } -}; - -/*! \brief Mean Average Precision at N, for both classification and rank */ -struct EvalMAP : public EvalRank { - public: - explicit EvalMAP(const char* name, const char* param) : EvalRank(name, param) {} - - double EvalGroup(PredIndPairContainer *recptr) const override { - PredIndPairContainer &rec(*recptr); - std::stable_sort(rec.begin(), rec.end(), common::CmpFirst); - unsigned nhits = 0; - double sumap = 0.0; - for (size_t i = 0; i < rec.size(); ++i) { - if (rec[i].second != 0) { - nhits += 1; - if (i < this->topn) { - sumap += static_cast(nhits) / (i + 1); - } - } - } - if (nhits != 0) { - sumap /= nhits; - return sumap; - } else { - if (this->minus) { - return 0.0; - } else { - return 1.0; - } - } - } -}; - /*! \brief Cox: Partial likelihood of the Cox proportional hazards model */ struct EvalCox : public MetricNoCache { public: @@ -377,16 +335,218 @@ XGBOOST_REGISTER_METRIC(Precision, "pre") .describe("precision@k for rank.") .set_body([](const char* param) { return new EvalPrecision("pre", param); }); -XGBOOST_REGISTER_METRIC(NDCG, "ndcg") -.describe("ndcg@k for rank.") -.set_body([](const char* param) { return new EvalNDCG("ndcg", param); }); - -XGBOOST_REGISTER_METRIC(MAP, "map") -.describe("map@k for rank.") -.set_body([](const char* param) { return new EvalMAP("map", param); }); - XGBOOST_REGISTER_METRIC(Cox, "cox-nloglik") .describe("Negative log partial likelihood of Cox proportional hazards model.") .set_body([](const char*) { return new EvalCox(); }); + +// ranking metrics that requires cache +template +class EvalRankWithCache : public Metric { + protected: + ltr::LambdaRankParam param_; + bool minus_{false}; + std::string name_; + + DMatrixCache cache_{DMatrixCache::DefaultSize()}; + + public: + EvalRankWithCache(StringView name, const char* param) { + auto constexpr kMax = ltr::LambdaRankParam::NotSet(); + std::uint32_t topn{kMax}; + this->name_ = ltr::ParseMetricName(name, param, &topn, &minus_); + if (topn != kMax) { + param_.UpdateAllowUnknown(Args{{"lambdarank_num_pair_per_sample", std::to_string(topn)}, + {"lambdarank_pair_method", "topk"}}); + } + param_.UpdateAllowUnknown(Args{}); + } + void Configure(Args const&) override { + // do not configure, otherwise the ndcg param will be forced into the same as the one in + // objective. + } + void LoadConfig(Json const& in) override { + if (IsA(in)) { + return; + } + auto const& obj = get(in); + auto it = obj.find("lambdarank_param"); + if (it != obj.cend()) { + FromJson(it->second, ¶m_); + } + } + + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["name"] = String{this->Name()}; + out["lambdarank_param"] = ToJson(param_); + } + + double Evaluate(HostDeviceVector const& preds, std::shared_ptr p_fmat) override { + auto const& info = p_fmat->Info(); + auto p_cache = cache_.CacheItem(p_fmat, ctx_, info, param_); + if (p_cache->Param() != param_) { + p_cache = cache_.ResetItem(p_fmat, ctx_, info, param_); + } + CHECK(p_cache->Param() == param_); + CHECK_EQ(preds.Size(), info.labels.Size()); + + return this->Eval(preds, info, p_cache); + } + + virtual double Eval(HostDeviceVector const& preds, MetaInfo const& info, + std::shared_ptr p_cache) = 0; +}; + +namespace { +double Finalize(double score, double sw) { + std::array dat{score, sw}; + collective::Allreduce(dat.data(), dat.size()); + if (sw > 0.0) { + score = score / sw; + } + + CHECK_LE(score, 1.0 + kRtEps) + << "Invalid output score, might be caused by invalid query group weight."; + score = std::min(1.0, score); + + return score; +} +} // namespace + +/** + * \brief Implement the NDCG score function for learning to rank. + * + * Ties are ignored, which can lead to different result with other implementations. + */ +class EvalNDCG : public EvalRankWithCache { + public: + using EvalRankWithCache::EvalRankWithCache; + const char* Name() const override { return name_.c_str(); } + + double Eval(HostDeviceVector const& preds, MetaInfo const& info, + std::shared_ptr p_cache) override { + if (ctx_->IsCUDA()) { + auto ndcg = cuda_impl::NDCGScore(ctx_, info, preds, minus_, p_cache); + return Finalize(ndcg.Residue(), ndcg.Weights()); + } + + // group local ndcg + auto group_ptr = p_cache->DataGroupPtr(ctx_); + bst_group_t n_groups = group_ptr.size() - 1; + auto ndcg_gloc = p_cache->Dcg(ctx_); + std::fill_n(ndcg_gloc.Values().data(), ndcg_gloc.Size(), 0.0); + + auto h_inv_idcg = p_cache->InvIDCG(ctx_); + auto p_discount = p_cache->Discount(ctx_).data(); + + auto h_label = info.labels.HostView(); + auto h_predt = + linalg::MakeTensorView(ctx_->IsCPU() ? preds.ConstHostSpan() : preds.ConstDeviceSpan(), + {preds.Size()}, ctx_->gpu_id); + auto weights = common::MakeOptionalWeights(ctx_, info.weights_); + + common::ParallelFor(n_groups, ctx_->Threads(), [&](auto g) { + auto g_predt = h_predt.Slice(linalg::Range(group_ptr[g], group_ptr[g + 1])); + auto g_labels = h_label.Slice(linalg::Range(group_ptr[g], group_ptr[g + 1]), 0); + auto sorted_idx = common::ArgSort(ctx_, linalg::cbegin(g_predt), + linalg::cend(g_predt), std::greater<>{}); + double ndcg{.0}; + double inv_idcg = h_inv_idcg(g); + if (inv_idcg <= 0.0) { + ndcg_gloc(g) = minus_ ? 0.0 : 1.0; + return; + } + std::size_t n{std::min(sorted_idx.size(), static_cast(param_.TopK()))}; + if (param_.ndcg_exp_gain) { + for (std::size_t i = 0; i < n; ++i) { + ndcg += p_discount[i] * ltr::CalcDCGGain(g_labels(sorted_idx[i])) * inv_idcg; + } + } else { + for (std::size_t i = 0; i < n; ++i) { + ndcg += p_discount[i] * g_labels(sorted_idx[i]) * inv_idcg; + } + } + ndcg_gloc(g) += ndcg * weights[g]; + }); + double sum_w{0}; + if (weights.Empty()) { + sum_w = n_groups; + } else { + sum_w = std::accumulate(weights.weights.cbegin(), weights.weights.cend(), 0.0); + } + auto ndcg = std::accumulate(linalg::cbegin(ndcg_gloc), linalg::cend(ndcg_gloc), 0.0); + return Finalize(ndcg, sum_w); + } +}; + +class EvalMAPScore : public EvalRankWithCache { + public: + using EvalRankWithCache::EvalRankWithCache; + const char* Name() const override { return name_.c_str(); } + + double Eval(HostDeviceVector const& predt, MetaInfo const& info, + std::shared_ptr p_cache) override { + if (ctx_->IsCUDA()) { + auto map = cuda_impl::MAPScore(ctx_, info, predt, minus_, p_cache); + return Finalize(map.Residue(), map.Weights()); + } + + auto gptr = p_cache->DataGroupPtr(ctx_); + auto h_label = info.labels.HostView().Slice(linalg::All(), 0); + auto h_predt = + linalg::MakeTensorView(ctx_->IsCPU() ? predt.ConstHostSpan() : predt.ConstDeviceSpan(), + {predt.Size()}, ctx_->gpu_id); + + auto map_gloc = p_cache->Map(ctx_); + std::fill_n(map_gloc.data(), map_gloc.size(), 0.0); + auto rank_idx = p_cache->SortedIdx(ctx_, predt.ConstHostSpan()); + + common::ParallelFor(p_cache->Groups(), ctx_->Threads(), [&](auto g) { + auto g_predt = h_predt.Slice(linalg::Range(gptr[g], gptr[g + 1])); + auto g_label = h_label.Slice(linalg::Range(gptr[g], gptr[g + 1])); + auto g_rank = rank_idx.subspan(gptr[g]); + + auto n = std::min(static_cast(param_.TopK()), g_label.Size()); + double n_hits{0.0}; + for (std::size_t i = 0; i < n; ++i) { + auto p = g_label(g_rank[i]); + n_hits += p; + map_gloc[g] += n_hits / static_cast((i + 1)) * p; + } + for (std::size_t i = n; i < g_label.Size(); ++i) { + n_hits += g_label(g_rank[i]); + } + if (n_hits > 0.0) { + map_gloc[g] /= std::min(n_hits, static_cast(param_.TopK())); + } else { + map_gloc[g] = minus_ ? 0.0 : 1.0; + } + }); + + auto sw = 0.0; + auto weight = common::MakeOptionalWeights(ctx_, info.weights_); + if (!weight.Empty()) { + CHECK_EQ(weight.weights.size(), p_cache->Groups()); + } + for (std::size_t i = 0; i < map_gloc.size(); ++i) { + map_gloc[i] = map_gloc[i] * weight[i]; + sw += weight[i]; + } + auto sum = std::accumulate(map_gloc.cbegin(), map_gloc.cend(), 0.0); + return Finalize(sum, sw); + } +}; + +XGBOOST_REGISTER_METRIC(EvalMAP, "map") + .describe("map@k for ranking.") + .set_body([](char const* param) { + return new EvalMAPScore{"map", param}; + }); + +XGBOOST_REGISTER_METRIC(EvalNDCG, "ndcg") + .describe("ndcg@k for ranking.") + .set_body([](char const* param) { + return new EvalNDCG{"ndcg", param}; + }); } // namespace metric } // namespace xgboost diff --git a/src/metric/rank_metric.cu b/src/metric/rank_metric.cu index 5f98db7a93cd..2d94a9d5c258 100644 --- a/src/metric/rank_metric.cu +++ b/src/metric/rank_metric.cu @@ -2,19 +2,27 @@ * Copyright 2020-2023 by XGBoost Contributors */ #include -#include // make_counting_iterator -#include // reduce -#include - -#include // std::size_t -#include // std::shared_ptr - -#include "../common/cuda_context.cuh" // CUDAContext +#include // for make_counting_iterator +#include // for reduce + +#include // for transform +#include // for size_t +#include // for shared_ptr +#include // for vector + +#include "../common/cuda_context.cuh" // for CUDAContext +#include "../common/device_helpers.cuh" // for MakeTransformIterator +#include "../common/optional_weight.h" // for MakeOptionalWeights +#include "../common/ranking_utils.cuh" // for CalcQueriesDCG, NDCGCache #include "metric_common.h" -#include "xgboost/base.h" // XGBOOST_DEVICE -#include "xgboost/context.h" // Context -#include "xgboost/data.h" // MetaInfo -#include "xgboost/host_device_vector.h" // HostDeviceVector +#include "rank_metric.h" +#include "xgboost/base.h" // for XGBOOST_DEVICE +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for MetaInfo +#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/linalg.h" // for MakeTensorView +#include "xgboost/logging.h" // for CHECK +#include "xgboost/metric.h" namespace xgboost { namespace metric { @@ -117,167 +125,129 @@ struct EvalPrecisionGpu { } }; -/*! \brief NDCG: Normalized Discounted Cumulative Gain at N */ -struct EvalNDCGGpu { - public: - static void ComputeDCG(const dh::SegmentSorter &pred_sorter, - const float *dlabels, - const EvalRankConfig &ecfg, - // The order in which labels have to be accessed. The order is determined - // by sorting the predictions or the labels for the entire dataset - const xgboost::common::Span &dlabels_sort_order, - dh::caching_device_vector *dcgptr) { - dh::caching_device_vector &dcgs(*dcgptr); - // Group info on device - const auto &dgroups = pred_sorter.GetGroupsSpan(); - const auto &dgroup_idx = pred_sorter.GetGroupSegmentsSpan(); - - // First, determine non zero labels in the dataset individually - auto DetermineNonTrivialLabelLambda = [=] __device__(uint32_t idx) { - return (static_cast(dlabels[dlabels_sort_order[idx]])); - }; // NOLINT - - // Find each group's DCG value - const auto nitems = pred_sorter.GetNumItems(); - auto *ddcgs = dcgs.data().get(); - - int device_id = -1; - dh::safe_cuda(cudaGetDevice(&device_id)); - - // For each group item compute the aggregated precision - dh::LaunchN(nitems, nullptr, [=] __device__(uint32_t idx) { - const auto group_idx = dgroup_idx[idx]; - const auto group_begin = dgroups[group_idx]; - const auto ridx = idx - group_begin; - auto label = DetermineNonTrivialLabelLambda(idx); - if (ridx < ecfg.topn && label) { - atomicAdd(&ddcgs[group_idx], ((1 << label) - 1) / std::log2(ridx + 2.0)); - } - }); - } - - static double EvalMetric(const dh::SegmentSorter &pred_sorter, - const float *dlabels, - const EvalRankConfig &ecfg) { - // Sort the labels and compute IDCG - dh::SegmentSorter segment_label_sorter; - segment_label_sorter.SortItems(dlabels, pred_sorter.GetNumItems(), - pred_sorter.GetGroupSegmentsSpan()); - - uint32_t ngroups = pred_sorter.GetNumGroups(); - - dh::caching_device_vector idcg(ngroups, 0); - ComputeDCG(pred_sorter, dlabels, ecfg, segment_label_sorter.GetOriginalPositionsSpan(), &idcg); - - // Compute the DCG values next - dh::caching_device_vector dcg(ngroups, 0); - ComputeDCG(pred_sorter, dlabels, ecfg, pred_sorter.GetOriginalPositionsSpan(), &dcg); - - double *ddcg = dcg.data().get(); - double *didcg = idcg.data().get(); +XGBOOST_REGISTER_GPU_METRIC(PrecisionGpu, "pre") +.describe("precision@k for rank computed on GPU.") +.set_body([](const char* param) { return new EvalRankGpu("pre", param); }); - int device_id = -1; - dh::safe_cuda(cudaGetDevice(&device_id)); - // Compute the group's DCG and reduce it across all groups - dh::LaunchN(ngroups, nullptr, [=] __device__(uint32_t gidx) { - if (didcg[gidx] == 0.0f) { - ddcg[gidx] = (ecfg.minus) ? 0.0f : 1.0f; - } else { - ddcg[gidx] /= didcg[gidx]; - } - }); +namespace cuda_impl { +PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info, + HostDeviceVector const &predt, bool minus, + std::shared_ptr p_cache) { + CHECK(p_cache); - // Allocator to be used for managing space overhead while performing reductions - dh::XGBCachingDeviceAllocator alloc; - return thrust::reduce(thrust::cuda::par(alloc), dcg.begin(), dcg.end()); + auto const &p = p_cache->Param(); + auto d_weight = common::MakeOptionalWeights(ctx, info.weights_); + if (!d_weight.Empty()) { + CHECK_EQ(d_weight.weights.size(), p_cache->Groups()); } -}; - -/*! \brief Mean Average Precision at N, for both classification and rank */ -struct EvalMAPGpu { - public: - static double EvalMetric(const dh::SegmentSorter &pred_sorter, - const float *dlabels, - const EvalRankConfig &ecfg) { - // Group info on device - const auto &dgroups = pred_sorter.GetGroupsSpan(); - const auto ngroups = pred_sorter.GetNumGroups(); - const auto &dgroup_idx = pred_sorter.GetGroupSegmentsSpan(); - - // Original positions of the predictions after they have been sorted - const auto &dpreds_orig_pos = pred_sorter.GetOriginalPositionsSpan(); - - // First, determine non zero labels in the dataset individually - const auto nitems = pred_sorter.GetNumItems(); - dh::caching_device_vector hits(nitems, 0); - auto DetermineNonTrivialLabelLambda = [=] __device__(uint32_t idx) { - return (static_cast(dlabels[dpreds_orig_pos[idx]]) != 0) ? 1 : 0; - }; // NOLINT - - thrust::transform(thrust::make_counting_iterator(static_cast(0)), - thrust::make_counting_iterator(nitems), - hits.begin(), - DetermineNonTrivialLabelLambda); - - // Allocator to be used by sort for managing space overhead while performing prefix scans - dh::XGBCachingDeviceAllocator alloc; + auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0); + predt.SetDevice(ctx->gpu_id); + auto d_predt = linalg::MakeTensorView(predt.ConstDeviceSpan(), {predt.Size()}, ctx->gpu_id); - // Next, prefix scan the nontrivial labels that are segmented to accumulate them. - // This is required for computing the metric sum - // Data segmented into different groups... - thrust::inclusive_scan_by_key(thrust::cuda::par(alloc), - dh::tcbegin(dgroup_idx), dh::tcend(dgroup_idx), - hits.begin(), // Input value - hits.begin()); // In-place scan + auto d_group_ptr = p_cache->DataGroupPtr(ctx); + auto n_groups = info.group_ptr_.size() - 1; - // Find each group's metric sum - dh::caching_device_vector sumap(ngroups, 0); - auto *dsumap = sumap.data().get(); - const auto *dhits = hits.data().get(); + auto d_inv_idcg = p_cache->InvIDCG(ctx); + auto d_sorted_idx = p_cache->SortedIdx(ctx, d_predt.Values()); + auto d_out_dcg = p_cache->Dcg(ctx); - int device_id = -1; - dh::safe_cuda(cudaGetDevice(&device_id)); - // For each group item compute the aggregated precision - dh::LaunchN(nitems, nullptr, [=] __device__(uint32_t idx) { - if (DetermineNonTrivialLabelLambda(idx)) { - const auto group_idx = dgroup_idx[idx]; - const auto group_begin = dgroups[group_idx]; - const auto ridx = idx - group_begin; - if (ridx < ecfg.topn) { - atomicAdd(&dsumap[group_idx], - static_cast(dhits[idx]) / (ridx + 1)); - } - } - }); + ltr::cuda_impl::CalcQueriesDCG(ctx, d_label, d_sorted_idx, p.ndcg_exp_gain, d_group_ptr, p.TopK(), + d_out_dcg); - // Aggregate the group's item precisions - dh::LaunchN(ngroups, nullptr, [=] __device__(uint32_t gidx) { - auto nhits = dgroups[gidx + 1] ? dhits[dgroups[gidx + 1] - 1] : 0; - if (nhits != 0) { - dsumap[gidx] /= nhits; - } else { - if (ecfg.minus) { - dsumap[gidx] = 0; - } else { - dsumap[gidx] = 1; + auto it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) { + if (d_inv_idcg(i) <= 0.0) { + return PackedReduceResult{minus ? 0.0 : 1.0, static_cast(d_weight[i])}; } - } - }); - - return thrust::reduce(thrust::cuda::par(alloc), sumap.begin(), sumap.end()); + return PackedReduceResult{d_out_dcg(i) * d_inv_idcg(i) * d_weight[i], + static_cast(d_weight[i])}; + }); + auto pair = thrust::reduce(ctx->CUDACtx()->CTP(), it, it + d_out_dcg.Size(), + PackedReduceResult{0.0, 0.0}); + return pair; +} + +PackedReduceResult MAPScore(Context const *ctx, MetaInfo const &info, + HostDeviceVector const &predt, bool minus, + std::shared_ptr p_cache) { + auto d_group_ptr = p_cache->DataGroupPtr(ctx); + auto n_groups = info.group_ptr_.size() - 1; + auto d_label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0); + + predt.SetDevice(ctx->gpu_id); + auto d_rank_idx = p_cache->SortedIdx(ctx, predt.ConstDeviceSpan()); + auto key_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), + [=] XGBOOST_DEVICE(std::size_t i) { return dh::SegmentId(d_group_ptr, i); }); + + auto get_label = [=] XGBOOST_DEVICE(std::size_t i) { + auto g = key_it[i]; + auto g_begin = d_group_ptr[g]; + auto g_end = d_group_ptr[g + 1]; + i -= g_begin; + auto g_label = d_label.Slice(linalg::Range(g_begin, g_end)); + auto g_rank = d_rank_idx.subspan(g_begin, g_end - g_begin); + return g_label(g_rank[i]); + }; + auto it = dh::MakeTransformIterator(thrust::make_counting_iterator(0ul), get_label); + + auto cuctx = ctx->CUDACtx(); + auto n_rel = p_cache->NumRelevant(ctx); + thrust::inclusive_scan_by_key(cuctx->CTP(), key_it, key_it + d_label.Size(), it, n_rel.data()); + + double topk = p_cache->Param().TopK(); + auto map = p_cache->Map(ctx); + thrust::fill_n(cuctx->CTP(), map.data(), map.size(), 0.0); + { + auto val_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) { + auto g = key_it[i]; + auto g_begin = d_group_ptr[g]; + auto g_end = d_group_ptr[g + 1]; + i -= g_begin; + if (i >= topk) { + return 0.0; + } + + auto g_label = d_label.Slice(linalg::Range(g_begin, g_end)); + auto g_rank = d_rank_idx.subspan(g_begin, g_end - g_begin); + auto label = g_label(g_rank[i]); + + auto g_n_rel = n_rel.subspan(g_begin, g_end - g_begin); + auto nhits = g_n_rel[i]; + return nhits / static_cast(i + 1) * label; + }); + + std::size_t bytes; + cub::DeviceSegmentedReduce::Sum(nullptr, bytes, val_it, map.data(), p_cache->Groups(), + d_group_ptr.data(), d_group_ptr.data() + 1, cuctx->Stream()); + dh::TemporaryArray temp(bytes); + cub::DeviceSegmentedReduce::Sum(temp.data().get(), bytes, val_it, map.data(), p_cache->Groups(), + d_group_ptr.data(), d_group_ptr.data() + 1, cuctx->Stream()); } -}; - -XGBOOST_REGISTER_GPU_METRIC(PrecisionGpu, "pre") -.describe("precision@k for rank computed on GPU.") -.set_body([](const char* param) { return new EvalRankGpu("pre", param); }); - -XGBOOST_REGISTER_GPU_METRIC(NDCGGpu, "ndcg") -.describe("ndcg@k for rank computed on GPU.") -.set_body([](const char* param) { return new EvalRankGpu("ndcg", param); }); -XGBOOST_REGISTER_GPU_METRIC(MAPGpu, "map") -.describe("map@k for rank computed on GPU.") -.set_body([](const char* param) { return new EvalRankGpu("map", param); }); + PackedReduceResult result{0.0, 0.0}; + { + auto d_weight = common::MakeOptionalWeights(ctx, info.weights_); + if (!d_weight.Empty()) { + CHECK_EQ(d_weight.weights.size(), p_cache->Groups()); + } + auto val_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t g) { + auto g_begin = d_group_ptr[g]; + auto g_end = d_group_ptr[g + 1]; + auto g_n_rel = n_rel.subspan(g_begin, g_end - g_begin); + if (!g_n_rel.empty() && g_n_rel.back() > 0.0) { + return PackedReduceResult{map[g] * d_weight[g] / std::min(g_n_rel.back(), topk), + static_cast(d_weight[g])}; + } + return PackedReduceResult{minus ? 0.0 : 1.0, static_cast(d_weight[g])}; + }); + result = + thrust::reduce(cuctx->CTP(), val_it, val_it + map.size(), PackedReduceResult{0.0, 0.0}); + } + return result; +} +} // namespace cuda_impl } // namespace metric } // namespace xgboost diff --git a/src/metric/rank_metric.h b/src/metric/rank_metric.h new file mode 100644 index 000000000000..27897c7a759b --- /dev/null +++ b/src/metric/rank_metric.h @@ -0,0 +1,44 @@ +#ifndef XGBOOST_METRIC_RANK_METRIC_H_ +#define XGBOOST_METRIC_RANK_METRIC_H_ +/** + * Copyright 2023 by XGBoost Contributors + */ +#include // for shared_ptr + +#include "../common/common.h" // for AssertGPUSupport +#include "../common/ranking_utils.h" // for NDCGCache, MAPCache +#include "metric_common.h" // for PackedReduceResult +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for MetaInfo +#include "xgboost/host_device_vector.h" // for HostDeviceVector + +namespace xgboost { +namespace metric { +namespace cuda_impl { +PackedReduceResult NDCGScore(Context const *ctx, MetaInfo const &info, + HostDeviceVector const &predt, bool minus, + std::shared_ptr p_cache); + +PackedReduceResult MAPScore(Context const *ctx, MetaInfo const &info, + HostDeviceVector const &predt, bool minus, + std::shared_ptr p_cache); + +#if !defined(XGBOOST_USE_CUDA) +inline PackedReduceResult NDCGScore(Context const *, MetaInfo const &, + HostDeviceVector const &, bool, + std::shared_ptr) { + common::AssertGPUSupport(); + return {}; +} + +inline PackedReduceResult MAPScore(Context const *, MetaInfo const &, + HostDeviceVector const &, bool, + std::shared_ptr) { + common::AssertGPUSupport(); + return {}; +} +#endif +} // namespace cuda_impl +} // namespace metric +} // namespace xgboost +#endif // XGBOOST_METRIC_RANK_METRIC_H_ diff --git a/src/objective/init_estimation.cc b/src/objective/init_estimation.cc index 96fd5d65379c..0d0280437f04 100644 --- a/src/objective/init_estimation.cc +++ b/src/objective/init_estimation.cc @@ -14,8 +14,7 @@ #include "xgboost/linalg.h" // Tensor,Vector #include "xgboost/task.h" // ObjInfo -namespace xgboost { -namespace obj { +namespace xgboost::obj { void FitIntercept::InitEstimation(MetaInfo const& info, linalg::Vector* base_score) const { if (this->Task().task == ObjInfo::kRegression) { CheckInitInputs(info); @@ -31,14 +30,13 @@ void FitIntercept::InitEstimation(MetaInfo const& info, linalg::Vector* b ObjFunction::Create(get(config["name"]), this->ctx_)}; new_obj->LoadConfig(config); new_obj->GetGradient(dummy_predt, info, 0, &gpair); + bst_target_t n_targets = this->Targets(info); linalg::Vector leaf_weight; tree::FitStump(this->ctx_, gpair, n_targets, &leaf_weight); - // workaround, we don't support multi-target due to binary model serialization for // base margin. common::Mean(this->ctx_, leaf_weight, base_score); this->PredTransform(base_score->Data()); } -} // namespace obj -} // namespace xgboost +} // namespace xgboost::obj diff --git a/src/objective/init_estimation.h b/src/objective/init_estimation.h index b0a91d8c3ec7..0ac5c52065dc 100644 --- a/src/objective/init_estimation.h +++ b/src/objective/init_estimation.h @@ -7,8 +7,7 @@ #include "xgboost/linalg.h" // Tensor #include "xgboost/objective.h" // ObjFunction -namespace xgboost { -namespace obj { +namespace xgboost::obj { class FitIntercept : public ObjFunction { void InitEstimation(MetaInfo const& info, linalg::Vector* base_score) const override; }; @@ -20,6 +19,5 @@ inline void CheckInitInputs(MetaInfo const& info) { << "Number of weights should be equal to number of data points."; } } -} // namespace obj -} // namespace xgboost +} // namespace xgboost::obj #endif // XGBOOST_OBJECTIVE_INIT_ESTIMATION_H_ diff --git a/src/objective/lambdarank_obj.cc b/src/objective/lambdarank_obj.cc new file mode 100644 index 000000000000..6320ac14f8f3 --- /dev/null +++ b/src/objective/lambdarank_obj.cc @@ -0,0 +1,603 @@ +/** + * Copyright 2023 by XGBoost contributors + */ +#include "lambdarank_obj.h" + +#include // for DMLC_REGISTRY_FILE_TAG + +#include // for min, fill_n, transform +#include // for pow, log2, isinf +#include // for size_t +#include // for int32_t +#include // for shared_ptr, make_shared, static_pointer_cast +#include // for string +#include // for make_tuple, apply +#include // for is_floating_point +#include // for swap +#include // for vector + +#include "../common/algorithm.h" // for ArgSort +#include "../common/common.h" // for AssertGPUSupport +#include "../common/error_msg.h" // for GroupWeight, LabelScoreSize +#include "../common/optional_weight.h" // for MakeOptionalWeights +#include "../common/ranking_utils.h" // for LambdaRankParam, NDCGCache +#include "../common/threading_utils.h" // for ParallelFor +#include "init_estimation.h" // for FitIntercept +#include "xgboost/base.h" // for GradientPair, bst_group_t +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for MetaInfo +#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/json.h" // for Json, get, ToJson, FromJson +#include "xgboost/linalg.h" // for Vector, VectorView, Range, Zeros, Constant +#include "xgboost/logging.h" // for CHECK_LE,CHECK_EQ,CHECK_GE +#include "xgboost/objective.h" // for ObjFunction +#include "xgboost/span.h" // for Span, operator!= +#include "xgboost/task.h" // for ObjInfo + +namespace xgboost::obj { +namespace { +void Normalize(double sum_lambda, common::Span g_gpair) { + if (sum_lambda > 0.0) { + double norm = std::log2(1.0 + sum_lambda) / sum_lambda; + std::transform(g_gpair.data(), g_gpair.data() + g_gpair.size(), g_gpair.data(), + [norm](GradientPair const& g) { return g * norm; }); + } +} + +void UpdateBiasLoss(double cost, // cross entropy loss + linalg::VectorView ti_plus, + linalg::VectorView tj_minus, + std::size_t rank_high, // position of the doc that ranked higher + std::size_t rank_low, // position of the doc that ranked lower + linalg::VectorView li, linalg::VectorView lj) { + auto k = ti_plus.Size() - 1; + // fixme: We should probably use all the positions. If we skip the update due to having + // high/low > k, we might be lossing out too many pairs, if we cap the position, then we + // might be accumulating too many tail bias into the last tracked position. + if (rank_high < k && rank_low < k) { + li(rank_high) += cost / (tj_minus(rank_low) + kRtEps); + lj(rank_low) += cost / (ti_plus(rank_high) + kRtEps); + } +} +} // anonymous namespace + +template +class LambdaRankObj : public FitIntercept { + MetaInfo const* p_info_{nullptr}; + + // Update position biased for unbiased click data + void UpdatePositionBias() { + auto n_groups = p_cache_->Groups(); + auto gptr = p_cache_->DataGroupPtr(ctx_); + + li_full_.SetDevice(ctx_->gpu_id); + lj_full_.SetDevice(ctx_->gpu_id); + li_.SetDevice(ctx_->gpu_id); + lj_.SetDevice(ctx_->gpu_id); + + if (ctx_->IsCPU()) { + auto ti_plus = ti_plus_.HostView(); + auto tj_minus = tj_minus_.HostView(); + auto li = li_.HostView(); + auto lj = lj_.HostView(); + + auto regularizer = param_.Regularizer(); + + auto li_full = li_full_.HostView(); + auto lj_full = lj_full_.HostView(); + + for (bst_group_t g{0}; g < n_groups; ++g) { + auto begin = gptr[g]; + auto end = gptr[g + 1]; + std::size_t group_size = end - begin; + auto n = std::min(group_size, p_cache_->MaxPositionSize()); + + auto g_li = li_full.Slice(linalg::Range(begin, end)); + auto g_lj = lj_full.Slice(linalg::Range(begin, end)); + + for (std::size_t i{0}; i < n; ++i) { + li(i) += g_li(i); + lj(i) += g_lj(i); + } + } + + for (std::size_t i = 0; i < ti_plus.Size(); ++i) { + ti_plus(i) = std::pow(li(i) / (li(0) + kRtEps), regularizer); + tj_minus(i) = std::pow(lj(i) / (lj(0) + kRtEps), regularizer); + assert(!std::isinf(ti_plus(i))); + assert(!std::isinf(tj_minus(i))); + } + } else { + cuda_impl::LambdaRankUpdatePositionBias(ctx_, li_full_.View(ctx_->gpu_id), + lj_full_.View(ctx_->gpu_id), &ti_plus_, &tj_minus_, + &li_, &lj_, p_cache_); + } + + li_full_.Data()->Fill(0.0); + lj_full_.Data()->Fill(0.0); + + li_.Data()->Fill(0.0); + lj_.Data()->Fill(0.0); + } + + protected: + // L / tj-* (eq. 30) + linalg::Vector li_; + // L / ti+* + linalg::Vector lj_; + // position bias ratio for relevant doc, ti+ (eq. 30) + linalg::Vector ti_plus_; + // position bias ratio for irrelevant doc, tj- (eq. 31) + linalg::Vector tj_minus_; + // li buffer for all samples + linalg::Vector li_full_; + // lj buffer for all samples + linalg::Vector lj_full_; + + ltr::LambdaRankParam param_; + // cache + std::shared_ptr p_cache_; + + [[nodiscard]] std::shared_ptr GetCache() const { + auto ptr = std::static_pointer_cast(p_cache_); + CHECK(ptr); + return ptr; + } + + // get group view for li/lj + linalg::VectorView GroupLoss(bst_group_t g, linalg::Vector* v) const { + auto gptr = p_cache_->DataGroupPtr(ctx_); + auto begin = gptr[g]; + auto end = gptr[g + 1]; + if (param_.lambdarank_unbiased) { + return v->HostView().Slice(linalg::Range(begin, end)); + } + return v->HostView(); + } + + // Calculate lambda gradient for each group on CPU. + template + void CalcLambdaForGroup(std::int32_t iter, common::Span g_predt, + linalg::VectorView g_label, float w, + common::Span g_rank, bst_group_t g, Delta delta, + common::Span g_gpair) { + std::fill_n(g_gpair.data(), g_gpair.size(), GradientPair{}); + auto p_gpair = g_gpair.data(); + + auto ti_plus = ti_plus_.HostView(); + auto tj_minus = tj_minus_.HostView(); + + auto li = GroupLoss(g, &li_full_); + auto lj = GroupLoss(g, &lj_full_); + + // Normalization, first used by LightGBM. + // https://github.com/microsoft/LightGBM/pull/2331#issuecomment-523259298 + double sum_lambda{0}; + + auto delta_op = [&](auto const&... args) { return delta(args..., g); }; + + auto loop = [&](std::size_t i, std::size_t j) { + // higher/lower on the ranked list + std::size_t rank_high = i, rank_low = j; + if (g_label(g_rank[rank_high]) == g_label(g_rank[rank_low])) { + return; + } + if (g_label(g_rank[rank_high]) < g_label(g_rank[rank_low])) { + std::swap(rank_high, rank_low); + } + + double cost; + auto pg = LambdaGrad(g_label, g_predt, g_rank, rank_high, rank_low, + delta_op, ti_plus, tj_minus, &cost); + auto ng = Repulse(pg); + + std::size_t idx_high = g_rank[rank_high]; + std::size_t idx_low = g_rank[rank_low]; + p_gpair[idx_high] += pg; + p_gpair[idx_low] += ng; + + if (unbiased) { + UpdateBiasLoss(cost, ti_plus, tj_minus, rank_high, rank_low, li, lj); + } + + sum_lambda += -2.0 * static_cast(pg.GetGrad()); + }; + + MakePairs(ctx_, iter, p_cache_, g, g_label, g_rank, loop); + + if (normalize) { + Normalize(sum_lambda, g_gpair); + } + + auto w_norm = p_cache_->WeightNorm(); + std::transform(g_gpair.begin(), g_gpair.end(), g_gpair.begin(), + [&](GradientPair const& gpair) { return gpair * w * w_norm; }); + } + + public: + void Configure(Args const& args) override { param_.UpdateAllowUnknown(args); } + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["name"] = String(Loss::Name()); + out["lambdarank_param"] = ToJson(param_); + } + void LoadConfig(Json const& in) override { + auto const& obj = get(in); + if (obj.find("lambdarank_param") != obj.cend()) { + FromJson(in["lambdarank_param"], ¶m_); + } + } + [[nodiscard]] ObjInfo Task() const override { return ObjInfo{ObjInfo::kRanking}; } + + [[nodiscard]] bst_target_t Targets(MetaInfo const& info) const override { + CHECK_LE(info.labels.Shape(1), 1) << "multi-output for LTR is not yet supported."; + return 1; + } + + [[nodiscard]] const char* RankEvalMetric(StringView metric) const { + static thread_local std::string name; + if (param_.HasTruncation()) { + name = ltr::MakeMetricName(metric, param_.NumPair(), false); + } else { + name = ltr::MakeMetricName(metric, param_.NotSet(), false); + } + return name.c_str(); + } + + void GetGradient(HostDeviceVector const& predt, MetaInfo const& info, std::int32_t iter, + HostDeviceVector* out_gpair) override { + CHECK_EQ(info.labels.Size(), predt.Size()) << error::LabelScoreSize(); + + // init/renew cache + if (!p_cache_ || p_info_ != &info || p_cache_->Param() != param_) { + p_cache_ = std::make_shared(ctx_, info, param_); + p_info_ = &info; + } + std::size_t n_groups = p_cache_->Groups(); + if (!info.weights_.Empty()) { + CHECK_EQ(info.weights_.Size(), n_groups) << error::GroupWeight(); + } + + if (ti_plus_.Size() == 0 && param_.lambdarank_unbiased) { + CHECK_EQ(iter, 0); + ti_plus_ = linalg::Constant(ctx_, 1.0, p_cache_->MaxPositionSize()); + tj_minus_ = linalg::Constant(ctx_, 1.0, p_cache_->MaxPositionSize()); + + li_ = linalg::Zeros(ctx_, p_cache_->MaxPositionSize()); + lj_ = linalg::Zeros(ctx_, p_cache_->MaxPositionSize()); + + li_full_ = linalg::Zeros(ctx_, info.num_row_); + lj_full_ = linalg::Zeros(ctx_, info.num_row_); + } + static_cast(this)->GetGradientImpl(iter, predt, info, out_gpair); + + if (param_.lambdarank_unbiased) { + this->UpdatePositionBias(); + } + } +}; + +class LambdaRankNDCG : public LambdaRankObj { + public: + template + void CalcLambdaForGroupNDCG(std::int32_t iter, common::Span g_predt, + linalg::VectorView g_label, float w, + common::Span g_rank, + common::Span g_gpair, + linalg::VectorView inv_IDCG, + common::Span discount, bst_group_t g) { + auto delta = [&](auto y_high, auto y_low, std::size_t rank_high, std::size_t rank_low, + bst_group_t g) { + static_assert(std::is_floating_point::value); + return DeltaNDCG(y_high, y_low, rank_high, rank_low, inv_IDCG(g), discount); + }; + + if (p_cache_->IsBinary()) { + this->CalcLambdaForGroup(iter, g_predt, g_label, w, g_rank, g, delta, + g_gpair); + } else { + this->CalcLambdaForGroup(iter, g_predt, g_label, w, g_rank, g, delta, + g_gpair); + } + } + + void GetGradientImpl(std::int32_t iter, const HostDeviceVector& predt, + const MetaInfo& info, HostDeviceVector* out_gpair) { + if (ctx_->IsCUDA()) { + cuda_impl::LambdaRankGetGradientNDCG( + ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->gpu_id), + tj_minus_.View(ctx_->gpu_id), li_full_.View(ctx_->gpu_id), lj_full_.View(ctx_->gpu_id), + out_gpair); + return; + } + + bst_group_t n_groups = p_cache_->Groups(); + auto gptr = p_cache_->DataGroupPtr(ctx_); + + out_gpair->Resize(info.num_row_); + auto h_gpair = out_gpair->HostSpan(); + auto h_predt = predt.ConstHostSpan(); + auto h_label = info.labels.HostView(); + auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_); + auto make_range = [&](bst_group_t g) { return linalg::Range(gptr[g], gptr[g + 1]); }; + + auto dct = GetCache()->Discount(ctx_); + auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt); + auto inv_IDCG = GetCache()->InvIDCG(ctx_); + + common::ParallelFor(n_groups, ctx_->Threads(), [&](auto g) { + std::size_t cnt = gptr[g + 1] - gptr[g]; + auto w = h_weight[g]; + auto g_predt = h_predt.subspan(gptr[g], cnt); + auto g_gpair = h_gpair.subspan(gptr[g], cnt); + auto g_label = h_label.Slice(make_range(g), 0); + auto g_rank = rank_idx.subspan(gptr[g], cnt); + + auto args = + std::make_tuple(this, iter, g_predt, g_label, w, g_rank, g_gpair, inv_IDCG, dct, g); + + if (param_.lambdarank_unbiased) { + if (param_.ndcg_exp_gain) { + std::apply(&LambdaRankNDCG::CalcLambdaForGroupNDCG, args); + } else { + std::apply(&LambdaRankNDCG::CalcLambdaForGroupNDCG, args); + } + } else { + if (param_.ndcg_exp_gain) { + std::apply(&LambdaRankNDCG::CalcLambdaForGroupNDCG, args); + } else { + std::apply(&LambdaRankNDCG::CalcLambdaForGroupNDCG, args); + } + } + }); + } + + static char const* Name() { return "rank:ndcg"; } + [[nodiscard]] const char* DefaultEvalMetric() const override { + return this->RankEvalMetric("ndcg"); + } + [[nodiscard]] Json DefaultMetricConfig() const override { + Json config{Object{}}; + config["name"] = String{DefaultEvalMetric()}; + config["lambdarank_param"] = ToJson(param_); + return config; + } +}; + +namespace cuda_impl { +#if !defined(XGBOOST_USE_CUDA) +void LambdaRankGetGradientNDCG(Context const*, std::int32_t, HostDeviceVector const&, + const MetaInfo&, std::shared_ptr, + linalg::VectorView, // input bias ratio + linalg::VectorView, // input bias ratio + linalg::VectorView, linalg::VectorView, + HostDeviceVector*) { + common::AssertGPUSupport(); +} + +void LambdaRankUpdatePositionBias(Context const*, linalg::VectorView, + linalg::VectorView, linalg::Vector*, + linalg::Vector*, linalg::Vector*, + linalg::Vector*, std::shared_ptr) { + common::AssertGPUSupport(); +} +#endif // !defined(XGBOOST_USE_CUDA) +} // namespace cuda_impl + +namespace cpu_impl { +void MAPStat(Context const* ctx, linalg::VectorView label, + common::Span rank_idx, std::shared_ptr p_cache) { + auto h_n_rel = p_cache->NumRelevant(ctx); + auto gptr = p_cache->DataGroupPtr(ctx); + + CHECK_EQ(h_n_rel.size(), gptr.back()); + CHECK_EQ(h_n_rel.size(), label.Size()); + + auto h_acc = p_cache->Acc(ctx); + + common::ParallelFor(p_cache->Groups(), ctx->Threads(), [&](auto g) { + auto cnt = gptr[g + 1] - gptr[g]; + auto g_n_rel = h_n_rel.subspan(gptr[g], cnt); + auto g_rank = rank_idx.subspan(gptr[g], cnt); + auto g_label = label.Slice(linalg::Range(gptr[g], gptr[g + 1])); + + // The number of relevant documents at each position + g_n_rel[0] = g_label(g_rank[0]); + for (std::size_t k = 1; k < g_rank.size(); ++k) { + g_n_rel[k] = g_n_rel[k - 1] + g_label(g_rank[k]); + } + + // \sum l_k/k + auto g_acc = h_acc.subspan(gptr[g], cnt); + g_acc[0] = g_label(g_rank[0]) / 1.0; + + for (std::size_t k = 1; k < g_rank.size(); ++k) { + g_acc[k] = g_acc[k - 1] + (g_label(g_rank[k]) / static_cast(k + 1)); + } + }); +} +} // namespace cpu_impl + +class LambdaRankMAP : public LambdaRankObj { + public: + void GetGradientImpl(std::int32_t iter, const HostDeviceVector& predt, + const MetaInfo& info, HostDeviceVector* out_gpair) { + CHECK(param_.ndcg_exp_gain) << "NDCG gain can not be set for the MAP objective."; + if (ctx_->IsCUDA()) { + return cuda_impl::LambdaRankGetGradientMAP( + ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->gpu_id), + tj_minus_.View(ctx_->gpu_id), li_full_.View(ctx_->gpu_id), lj_full_.View(ctx_->gpu_id), + out_gpair); + } + + auto gptr = p_cache_->DataGroupPtr(ctx_).data(); + bst_group_t n_groups = p_cache_->Groups(); + + out_gpair->Resize(info.num_row_); + auto h_gpair = out_gpair->HostSpan(); + auto h_label = info.labels.HostView().Slice(linalg::All(), 0); + auto h_predt = predt.ConstHostSpan(); + auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt); + auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_); + + auto make_range = [&](bst_group_t g) { return linalg::Range(gptr[g], gptr[g + 1]); }; + + cpu_impl::MAPStat(ctx_, h_label, rank_idx, GetCache()); + auto n_rel = GetCache()->NumRelevant(ctx_); + auto acc = GetCache()->Acc(ctx_); + bool is_binary = p_cache_->IsBinary(); + + auto delta_map = [&](auto y_high, auto y_low, std::size_t rank_high, std::size_t rank_low, + bst_group_t g) { + if (rank_high > rank_low) { + std::swap(rank_high, rank_low); + std::swap(y_high, y_low); + } + auto cnt = gptr[g + 1] - gptr[g]; + // In a hot loop + auto g_n_rel = common::Span{n_rel.data() + gptr[g], cnt}; + auto g_acc = common::Span{acc.data() + gptr[g], cnt}; + auto d = DeltaMAP(y_high, y_low, rank_high, rank_low, g_n_rel, g_acc); + return d; + }; + using D = decltype(delta_map); + + common::ParallelFor(n_groups, ctx_->Threads(), [&](auto g) { + auto cnt = gptr[g + 1] - gptr[g]; + auto w = h_weight[g]; + auto g_predt = h_predt.subspan(gptr[g], cnt); + auto g_gpair = h_gpair.subspan(gptr[g], cnt); + auto g_label = h_label.Slice(make_range(g)); + auto g_rank = rank_idx.subspan(gptr[g], cnt); + + auto args = std::make_tuple(this, iter, g_predt, g_label, w, g_rank, g, delta_map, g_gpair); + + if (param_.lambdarank_unbiased) { + if (is_binary) { + std::apply(&LambdaRankMAP::CalcLambdaForGroup, args); + } else { + std::apply(&LambdaRankMAP::CalcLambdaForGroup, args); + } + } else { + if (is_binary) { + std::apply(&LambdaRankMAP::CalcLambdaForGroup, args); + } else { + std::apply(&LambdaRankMAP::CalcLambdaForGroup, args); + } + } + }); + } + static char const* Name() { return "rank:map"; } + [[nodiscard]] const char* DefaultEvalMetric() const override { + return this->RankEvalMetric("map"); + } +}; + +#if !defined(XGBOOST_USE_CUDA) +namespace cuda_impl { +void MAPStat(Context const*, MetaInfo const&, common::Span, + std::shared_ptr) { + common::AssertGPUSupport(); +} + +void LambdaRankGetGradientMAP(Context const*, std::int32_t, HostDeviceVector const&, + const MetaInfo&, std::shared_ptr, + linalg::VectorView, // input bias ratio + linalg::VectorView, // input bias ratio + linalg::VectorView, linalg::VectorView, + HostDeviceVector*) { + common::AssertGPUSupport(); +} +} // namespace cuda_impl +#endif // !defined(XGBOOST_USE_CUDA) + +/** + * \brief The RankNet loss. + */ +class LambdaRankPairwise : public LambdaRankObj { + public: + void GetGradientImpl(std::int32_t iter, const HostDeviceVector& predt, + const MetaInfo& info, HostDeviceVector* out_gpair) { + CHECK(param_.ndcg_exp_gain) << "NDCG gain can not be set for the pairwise objective."; + if (ctx_->IsCUDA()) { + return cuda_impl::LambdaRankGetGradientPairwise( + ctx_, iter, predt, info, GetCache(), ti_plus_.View(ctx_->gpu_id), + tj_minus_.View(ctx_->gpu_id), li_full_.View(ctx_->gpu_id), lj_full_.View(ctx_->gpu_id), + out_gpair); + } + + auto gptr = p_cache_->DataGroupPtr(ctx_); + bst_group_t n_groups = p_cache_->Groups(); + + out_gpair->Resize(info.num_row_); + auto h_gpair = out_gpair->HostSpan(); + auto h_label = info.labels.HostView().Slice(linalg::All(), 0); + auto h_predt = predt.ConstHostSpan(); + auto h_weight = common::MakeOptionalWeights(ctx_, info.weights_); + + auto make_range = [&](bst_group_t g) { return linalg::Range(gptr[g], gptr[g + 1]); }; + auto rank_idx = p_cache_->SortedIdx(ctx_, h_predt); + + auto delta = [](auto...) { return 1.0; }; + using D = decltype(delta); + + common::ParallelFor(n_groups, ctx_->Threads(), [&](auto g) { + auto cnt = gptr[g + 1] - gptr[g]; + auto w = h_weight[g]; + auto g_predt = h_predt.subspan(gptr[g], cnt); + auto g_gpair = h_gpair.subspan(gptr[g], cnt); + auto g_label = h_label.Slice(make_range(g)); + auto g_rank = rank_idx.subspan(gptr[g], cnt); + + auto args = std::make_tuple(this, iter, g_predt, g_label, w, g_rank, g, delta, g_gpair); + + if (p_cache_->IsBinary()) { + if (param_.lambdarank_unbiased) { + std::apply(&LambdaRankPairwise::CalcLambdaForGroup, args); + } else { + std::apply(&LambdaRankPairwise::CalcLambdaForGroup, args); + } + } else { + if (param_.lambdarank_unbiased) { + std::apply(&LambdaRankPairwise::CalcLambdaForGroup, args); + } else { + std::apply(&LambdaRankPairwise::CalcLambdaForGroup, args); + } + } + }); + } + + static char const* Name() { return "rank:pairwise"; } + [[nodiscard]] const char* DefaultEvalMetric() const override { + return this->RankEvalMetric("ndcg"); + } +}; + +#if !defined(XGBOOST_USE_CUDA) +namespace cuda_impl { +void LambdaRankGetGradientPairwise(Context const*, std::int32_t, HostDeviceVector const&, + const MetaInfo&, std::shared_ptr, + linalg::VectorView, // input bias ratio + linalg::VectorView, // input bias ratio + linalg::VectorView, linalg::VectorView, + HostDeviceVector*) { + common::AssertGPUSupport(); +} +} // namespace cuda_impl +#endif // !defined(XGBOOST_USE_CUDA) + +XGBOOST_REGISTER_OBJECTIVE(LambdaRankNDCG, LambdaRankNDCG::Name()) + .describe("LambdaRank with NDCG loss as objective") + .set_body([]() { return new LambdaRankNDCG{}; }); + +XGBOOST_REGISTER_OBJECTIVE(LambdaRankPairwise, LambdaRankPairwise::Name()) + .describe("LambdaRank with RankNet loss as objective") + .set_body([]() { return new LambdaRankPairwise{}; }); + +XGBOOST_REGISTER_OBJECTIVE(LambdaRankMAP, LambdaRankMAP::Name()) + .describe("LambdaRank with MAP loss as objective.") + .set_body([]() { return new LambdaRankMAP{}; }); + +DMLC_REGISTRY_FILE_TAG(lambdarank_obj); + +} // namespace xgboost::obj diff --git a/src/objective/lambdarank_obj.cu b/src/objective/lambdarank_obj.cu new file mode 100644 index 000000000000..7b42a8786341 --- /dev/null +++ b/src/objective/lambdarank_obj.cu @@ -0,0 +1,597 @@ +/** + * Copyright 2015-2023 by XGBoost contributors + * + * \brief CUDA implementation of lambdarank. + */ +#include // for fill_n +#include // for for_each_n +#include // for make_counting_iterator +#include // for make_zip_iterator +#include // for make_tuple, tuple, tie, get + +#include // for min +#include // for assert +#include // for abs, log2, isinf +#include // for size_t +#include // for int32_t +#include // for shared_ptr +#include + +#include "../common/algorithm.cuh" // for SegmentedArgSort +#include "../common/cuda_context.cuh" // for CUDAContext +#include "../common/deterministic.cuh" // for CreateRoundingFactor, TruncateWithRounding +#include "../common/device_helpers.cuh" // for SegmentId, TemporaryArray, AtomicAddGpair +#include "../common/optional_weight.h" // for MakeOptionalWeights +#include "../common/ranking_utils.h" // for NDCGCache, LambdaRankParam, rel_degree_t +#include "lambdarank_obj.cuh" +#include "lambdarank_obj.h" +#include "xgboost/base.h" // for bst_group_t, XGBOOST_DEVICE, GradientPair +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for MetaInfo +#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/linalg.h" // for VectorView, Range, Vector +#include "xgboost/logging.h" +#include "xgboost/span.h" // for Span + +namespace xgboost::obj { +DMLC_REGISTRY_FILE_TAG(lambdarank_obj_cu); + +namespace cuda_impl { +namespace { +/** + * \brief Calculate minimum value of bias for floating point truncation. + */ +void MinBias(Context const* ctx, std::shared_ptr p_cache, + linalg::VectorView t_plus, linalg::VectorView tj_minus, + common::Span d_min) { + CHECK_EQ(d_min.size(), 2); + auto cuctx = ctx->CUDACtx(); + + auto k = t_plus.Size(); + auto const& p = p_cache->Param(); + CHECK_GT(k, 0); + CHECK_EQ(k, p_cache->MaxPositionSize()); + + auto key_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) { return i * k; }); + auto val_it = dh::MakeTransformIterator(thrust::make_counting_iterator(0ul), + [=] XGBOOST_DEVICE(std::size_t i) { + if (i >= k) { + return std::abs(tj_minus(i - k)); + } + return std::abs(t_plus(i)); + }); + std::size_t bytes; + cub::DeviceSegmentedReduce::Min(nullptr, bytes, val_it, d_min.data(), 2, key_it, key_it + 1, + cuctx->Stream()); + dh::TemporaryArray temp(bytes); + cub::DeviceSegmentedReduce::Min(temp.data().get(), bytes, val_it, d_min.data(), 2, key_it, + key_it + 1, cuctx->Stream()); +} + +/** + * \brief Type for gradient statistic. (Gradient, cost for unbiased LTR, normalization factor) + */ +using GradCostNorm = thrust::tuple; + +/** + * \brief Obtain and update the gradient for one pair. + */ +template +struct GetGradOp { + MakePairsOp make_pair; + Delta delta; + + bool need_update; + + auto __device__ operator()(std::size_t idx) -> GradCostNorm { + auto const& args = make_pair.args; + auto g = dh::SegmentId(args.d_threads_group_ptr, idx); + + auto data_group_begin = static_cast(args.d_group_ptr[g]); + std::size_t n_data = args.d_group_ptr[g + 1] - data_group_begin; + // obtain group segment data. + auto g_label = args.labels.Slice(linalg::Range(data_group_begin, data_group_begin + n_data), 0); + auto g_predt = args.predts.subspan(data_group_begin, n_data); + auto g_gpair = args.gpairs.subspan(data_group_begin, n_data).data(); + auto g_sorted_idx = args.d_sorted_idx.subspan(data_group_begin, n_data); + + auto [i, j] = make_pair(idx, g); + + std::size_t rank_high = i, rank_low = j; + if (g_label(g_sorted_idx[i]) == g_label(g_sorted_idx[j])) { + return thrust::make_tuple(GradientPair{}, 0.0, 0.0); + } + if (g_label(g_sorted_idx[i]) < g_label(g_sorted_idx[j])) { + thrust::swap(rank_high, rank_low); + } + + double cost{0}; + + auto delta_op = [&](auto const&... args) { return delta(args..., g); }; + GradientPair pg = + LambdaGrad(g_label, g_predt, g_sorted_idx, rank_high, rank_low, + delta_op, args.ti_plus, args.tj_minus, &cost); + + if (need_update) { + // second run, update the gradient + std::size_t idx_high = g_sorted_idx[rank_high]; + std::size_t idx_low = g_sorted_idx[rank_low]; + auto ng = Repulse(pg); + + auto gr = args.d_roundings(g); + // positive gradient truncated + auto pgt = GradientPair{common::TruncateWithRounding(gr.GetGrad(), pg.GetGrad()), + common::TruncateWithRounding(gr.GetHess(), pg.GetHess())}; + // negative gradient truncated + auto ngt = GradientPair{common::TruncateWithRounding(gr.GetGrad(), ng.GetGrad()), + common::TruncateWithRounding(gr.GetHess(), ng.GetHess())}; + + dh::AtomicAddGpair(g_gpair + idx_high, pgt); + dh::AtomicAddGpair(g_gpair + idx_low, ngt); + } + + if (unbiased && need_update) { + // second run, update the cost + auto k = args.ti_plus.Size() - 1; + assert(args.tj_minus.Size() - 1 == k && "Invalid size of position bias"); + + auto g_li = args.li.Slice(linalg::Range(data_group_begin, data_group_begin + n_data)); + auto g_lj = args.lj.Slice(linalg::Range(data_group_begin, data_group_begin + n_data)); + + if (rank_high <= args.ti_plus.Size() && rank_low <= args.ti_plus.Size()) { + auto cost_high = cost / (args.ti_plus(rank_high) + kRtEps); + auto cost_low = cost / (args.tj_minus(rank_low) + kRtEps); + + atomicAdd(&g_li(rank_high), + common::TruncateWithRounding(args.d_cost_rounding[0], cost_low)); + atomicAdd(&g_lj(rank_low), + common::TruncateWithRounding(args.d_cost_rounding[0], cost_high)); + } + } + return thrust::make_tuple(GradientPair{std::abs(pg.GetGrad()), std::abs(pg.GetHess())}, + std::abs(cost), -2.0 * static_cast(pg.GetGrad())); + } +}; + +template +struct MakeGetGrad { + MakePairsOp make_pair; + Delta delta; + + [[nodiscard]] KernelInputs const& Args() const { return make_pair.args; } + + MakeGetGrad(KernelInputs args, Delta d) : make_pair{args}, delta{std::move(d)} {} + + GetGradOp operator()(bool need_update) { + return GetGradOp{make_pair, delta, need_update}; + } +}; + +/** + * \brief Calculate gradient for all pairs using update op created by make_get_grad. + * + * We need to run gradient calculation twice, the first time gathers infomation like + * maximum gradient, maximum cost, and the normalization term using reduction. The second + * time performs the actual update. + * + * Without normalization, we only need to run it once since we can manually calculate + * the bounds of gradient (NDCG \in [0, 1], delta_NDCG \in [0, 1], ti+/tj- are from the + * previous iteration so the bound can be calculated for current iteration). However, if + * normalization is used, the delta score is un-bounded and we need to obtain the sum + * gradient. As a tradeoff, we simply run the kernel twice, once as reduction, second + * one as for_each. + * + * Alternatively, we can bound the delta score by limiting the output of the model using + * sigmoid for binary output and some normalization for multi-level. But effect to the + * accuracy is not known yet, and it's only used by GPU. + * + * For performance, the segmented sort for sorted scores is the bottleneck and takes up + * about half of the time, while the reduction and for_each takes up the second half. + */ +template +void CalcGrad(Context const* ctx, MetaInfo const& info, std::shared_ptr p_cache, + MakeGetGrad make_get_grad) { + auto n_groups = p_cache->Groups(); + auto d_threads_group_ptr = p_cache->CUDAThreadsGroupPtr(); + auto d_group_ptr = p_cache->DataGroupPtr(ctx); + auto d_gpair = make_get_grad.Args().gpairs; + + /** + * First pass, gather info for normalization and rounding factor. + */ + auto val_it = dh::MakeTransformIterator(thrust::make_counting_iterator(0ul), + make_get_grad(false)); + auto reduction_op = [] XGBOOST_DEVICE(GradCostNorm const& l, + GradCostNorm const& r) -> GradCostNorm { + // get maximum gradient for each group, along with cost and the normalization term + auto const& lg = thrust::get<0>(l); + auto const& rg = thrust::get<0>(r); + auto grad = std::max(lg.GetGrad(), rg.GetGrad()); + auto hess = std::max(lg.GetHess(), rg.GetHess()); + auto cost = std::max(thrust::get<1>(l), thrust::get<1>(r)); + double sum_lambda = thrust::get<2>(l) + thrust::get<2>(r); + return thrust::make_tuple(GradientPair{std::abs(grad), std::abs(hess)}, cost, sum_lambda); + }; + auto init = thrust::make_tuple(GradientPair{0.0f, 0.0f}, 0.0, 0.0); + common::Span d_max_lambdas = p_cache->MaxLambdas(ctx, n_groups); + CHECK_EQ(n_groups * sizeof(GradCostNorm), d_max_lambdas.size_bytes()); + + std::size_t bytes; + cub::DeviceSegmentedReduce::Reduce(nullptr, bytes, val_it, d_max_lambdas.data(), n_groups, + d_threads_group_ptr.data(), d_threads_group_ptr.data() + 1, + reduction_op, init, ctx->CUDACtx()->Stream()); + dh::TemporaryArray temp(bytes); + cub::DeviceSegmentedReduce::Reduce( + temp.data().get(), bytes, val_it, d_max_lambdas.data(), n_groups, d_threads_group_ptr.data(), + d_threads_group_ptr.data() + 1, reduction_op, init, ctx->CUDACtx()->Stream()); + + dh::TemporaryArray min_bias(2); + auto d_min_bias = dh::ToSpan(min_bias); + if (unbiased) { + MinBias(ctx, p_cache, make_get_grad.Args().ti_plus, make_get_grad.Args().tj_minus, d_min_bias); + } + /** + * Create rounding factors + */ + auto d_cost_rounding = p_cache->CUDACostRounding(ctx); + auto d_rounding = p_cache->CUDARounding(ctx); + dh::LaunchN(n_groups, ctx->CUDACtx()->Stream(), [=] XGBOOST_DEVICE(std::size_t g) mutable { + auto group_size = d_group_ptr[g + 1] - d_group_ptr[g]; + auto const& max_grad = thrust::get<0>(d_max_lambdas[g]); + // float group size + auto fgs = static_cast(group_size); + auto grad = common::CreateRoundingFactor(fgs * max_grad.GetGrad(), group_size); + auto hess = common::CreateRoundingFactor(fgs * max_grad.GetHess(), group_size); + d_rounding(g) = GradientPair{grad, hess}; + + auto cost = thrust::get<1>(d_max_lambdas[g]); + if (unbiased) { + cost /= std::min(d_min_bias[0], d_min_bias[1]); + d_cost_rounding[0] = common::CreateRoundingFactor(fgs * cost, group_size); + } + }); + + /** + * Second pass, actual update to gradient and bias. + */ + thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), + p_cache->CUDAThreads(), make_get_grad(true)); + + /** + * Lastly, normalization and weight. + */ + auto d_weights = common::MakeOptionalWeights(ctx, info.weights_); + auto w_norm = p_cache->WeightNorm(); + thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), d_gpair.size(), + [=] XGBOOST_DEVICE(std::size_t i) { + auto g = dh::SegmentId(d_group_ptr, i); + auto sum_lambda = thrust::get<2>(d_max_lambdas[g]); + // Normalization + if (sum_lambda > 0.0) { + double norm = std::log2(1 + sum_lambda) / sum_lambda; + d_gpair[i] *= norm; + } + d_gpair[i] *= (d_weights[g] * w_norm); + }); +} + +/** + * \brief Handles boilerplate code like getting device span. + */ +template +void Launch(Context const* ctx, std::int32_t iter, HostDeviceVector const& preds, + const MetaInfo& info, std::shared_ptr p_cache, Delta delta, + linalg::VectorView ti_plus, // input bias ratio + linalg::VectorView tj_minus, // input bias ratio + linalg::VectorView li, linalg::VectorView lj, + HostDeviceVector* out_gpair) { + // boilerplate + std::int32_t device_id = ctx->gpu_id; + dh::safe_cuda(cudaSetDevice(device_id)); + auto n_groups = p_cache->Groups(); + + info.labels.SetDevice(device_id); + preds.SetDevice(device_id); + out_gpair->SetDevice(device_id); + out_gpair->Resize(preds.Size()); + + CHECK(p_cache); + + auto d_rounding = p_cache->CUDARounding(ctx); + auto d_cost_rounding = p_cache->CUDACostRounding(ctx); + + CHECK_NE(d_rounding.Size(), 0); + + auto label = info.labels.View(ctx->gpu_id); + auto predts = preds.ConstDeviceSpan(); + auto gpairs = out_gpair->DeviceSpan(); + thrust::fill_n(ctx->CUDACtx()->CTP(), gpairs.data(), gpairs.size(), GradientPair{0.0f, 0.0f}); + + auto const d_threads_group_ptr = p_cache->CUDAThreadsGroupPtr(); + auto const d_group_ptr = p_cache->DataGroupPtr(ctx); + auto const rank_idx = p_cache->SortedIdx(ctx, predts); + + auto const unbiased = p_cache->Param().lambdarank_unbiased; + + common::Span d_y_sorted_idx; + if (!p_cache->Param().HasTruncation()) { + d_y_sorted_idx = SortY(ctx, info, rank_idx, p_cache); + } + + KernelInputs args{ti_plus, tj_minus, li, lj, d_group_ptr, d_threads_group_ptr, + rank_idx, label, predts, gpairs, d_rounding, d_cost_rounding.data(), + d_y_sorted_idx, iter}; + + // dispatch based on unbiased and truncation + if (p_cache->IsBinary()) { + if (p_cache->Param().HasTruncation()) { + if (unbiased) { + CalcGrad(ctx, info, p_cache, MakeGetGrad{args, delta}); + } else { + CalcGrad(ctx, info, p_cache, MakeGetGrad{args, delta}); + } + } else { + if (unbiased) { + CalcGrad(ctx, info, p_cache, MakeGetGrad{args, delta}); + } else { + CalcGrad(ctx, info, p_cache, MakeGetGrad{args, delta}); + } + } + } else { + if (p_cache->Param().HasTruncation()) { + if (unbiased) { + CalcGrad(ctx, info, p_cache, MakeGetGrad{args, delta}); + } else { + CalcGrad(ctx, info, p_cache, MakeGetGrad{args, delta}); + } + } else { + if (unbiased) { + CalcGrad(ctx, info, p_cache, MakeGetGrad{args, delta}); + } else { + CalcGrad(ctx, info, p_cache, MakeGetGrad{args, delta}); + } + } + } +} +} // anonymous namespace + +common::Span SortY(Context const* ctx, MetaInfo const& info, + common::Span d_rank, + std::shared_ptr p_cache) { + auto const d_group_ptr = p_cache->DataGroupPtr(ctx); + auto label = info.labels.View(ctx->gpu_id); + // The buffer for ranked y is necessary as cub segmented sort accepts only pointer. + auto d_y_ranked = p_cache->RankedY(ctx, info.num_row_); + thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), d_y_ranked.size(), + [=] XGBOOST_DEVICE(std::size_t i) { + auto g = dh::SegmentId(d_group_ptr, i); + auto g_label = + label.Slice(linalg::Range(d_group_ptr[g], d_group_ptr[g + 1]), 0); + auto g_rank_idx = d_rank.subspan(d_group_ptr[g], g_label.Size()); + i -= d_group_ptr[g]; + auto g_y_ranked = d_y_ranked.subspan(d_group_ptr[g], g_label.Size()); + g_y_ranked[i] = g_label(g_rank_idx[i]); + }); + auto d_y_sorted_idx = p_cache->SortedIdxY(ctx, info.num_row_); + common::SegmentedArgSort(ctx, d_y_ranked, d_group_ptr, d_y_sorted_idx); + return d_y_sorted_idx; +} + +void LambdaRankGetGradientNDCG(Context const* ctx, std::int32_t iter, + const HostDeviceVector& preds, const MetaInfo& info, + std::shared_ptr p_cache, + linalg::VectorView ti_plus, // input bias ratio + linalg::VectorView tj_minus, // input bias ratio + linalg::VectorView li, linalg::VectorView lj, + HostDeviceVector* out_gpair) { + // boilerplate + std::int32_t device_id = ctx->gpu_id; + dh::safe_cuda(cudaSetDevice(device_id)); + auto const d_inv_IDCG = p_cache->InvIDCG(ctx); + auto const discount = p_cache->Discount(ctx); + + info.labels.SetDevice(device_id); + preds.SetDevice(device_id); + + auto const exp_gain = p_cache->Param().ndcg_exp_gain; + auto delta_ndcg = [=] XGBOOST_DEVICE(float y_high, float y_low, std::size_t rank_high, + std::size_t rank_low, bst_group_t g) { + return exp_gain ? DeltaNDCG(y_high, y_low, rank_high, rank_low, d_inv_IDCG(g), discount) + : DeltaNDCG(y_high, y_low, rank_high, rank_low, d_inv_IDCG(g), discount); + }; + Launch(ctx, iter, preds, info, p_cache, delta_ndcg, ti_plus, tj_minus, li, lj, out_gpair); +} + +void MAPStat(Context const* ctx, MetaInfo const& info, common::Span d_rank_idx, + std::shared_ptr p_cache) { + common::Span out_n_rel = p_cache->NumRelevant(ctx); + common::Span out_acc = p_cache->Acc(ctx); + + CHECK_EQ(out_n_rel.size(), info.num_row_); + CHECK_EQ(out_acc.size(), info.num_row_); + + auto group_ptr = p_cache->DataGroupPtr(ctx); + auto key_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), + [=] XGBOOST_DEVICE(std::size_t i) -> std::size_t { return dh::SegmentId(group_ptr, i); }); + auto label = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0); + auto const* cuctx = ctx->CUDACtx(); + + { + // calculate number of relevant documents + auto val_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) -> double { + auto g = dh::SegmentId(group_ptr, i); + auto g_label = label.Slice(linalg::Range(group_ptr[g], group_ptr[g + 1])); + auto idx_in_group = i - group_ptr[g]; + auto g_sorted_idx = d_rank_idx.subspan(group_ptr[g], group_ptr[g + 1] - group_ptr[g]); + return static_cast(g_label(g_sorted_idx[idx_in_group])); + }); + thrust::inclusive_scan_by_key(cuctx->CTP(), key_it, key_it + info.num_row_, val_it, + out_n_rel.data()); + } + { + // \sum l_k/k + auto val_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) -> double { + auto g = dh::SegmentId(group_ptr, i); + auto g_label = label.Slice(linalg::Range(group_ptr[g], group_ptr[g + 1])); + auto g_sorted_idx = d_rank_idx.subspan(group_ptr[g], group_ptr[g + 1] - group_ptr[g]); + auto idx_in_group = i - group_ptr[g]; + double rank_in_group = idx_in_group + 1.0; + return static_cast(g_label(g_sorted_idx[idx_in_group])) / rank_in_group; + }); + thrust::inclusive_scan_by_key(cuctx->CTP(), key_it, key_it + info.num_row_, val_it, + out_acc.data()); + } +} + +void LambdaRankGetGradientMAP(Context const* ctx, std::int32_t iter, + HostDeviceVector const& predt, const MetaInfo& info, + std::shared_ptr p_cache, + linalg::VectorView ti_plus, // input bias ratio + linalg::VectorView tj_minus, // input bias ratio + linalg::VectorView li, linalg::VectorView lj, + HostDeviceVector* out_gpair) { + auto const* cuctx = ctx->CUDACtx(); + std::int32_t device_id = ctx->gpu_id; + dh::safe_cuda(cudaSetDevice(device_id)); + + info.labels.SetDevice(device_id); + predt.SetDevice(device_id); + + CHECK(p_cache); + + auto d_predt = predt.ConstDeviceSpan(); + auto const d_sorted_idx = p_cache->SortedIdx(ctx, d_predt); + + MAPStat(ctx, info, d_sorted_idx, p_cache); + auto d_n_rel = p_cache->NumRelevant(ctx); + auto d_acc = p_cache->Acc(ctx); + auto d_gptr = p_cache->DataGroupPtr(ctx).data(); + + auto delta_map = [=] XGBOOST_DEVICE(float y_high, float y_low, std::size_t rank_high, + std::size_t rank_low, bst_group_t g) { + if (rank_high > rank_low) { + thrust::swap(rank_high, rank_low); + thrust::swap(y_high, y_low); + } + auto cnt = d_gptr[g + 1] - d_gptr[g]; + auto g_n_rel = d_n_rel.subspan(d_gptr[g], cnt); + auto g_acc = d_acc.subspan(d_gptr[g], cnt); + auto d = DeltaMAP(y_high, y_low, rank_high, rank_low, g_n_rel, g_acc); + return d; + }; + + Launch(ctx, iter, predt, info, p_cache, delta_map, ti_plus, tj_minus, li, lj, out_gpair); +} + +void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter, + HostDeviceVector const& predt, const MetaInfo& info, + std::shared_ptr p_cache, + linalg::VectorView ti_plus, // input bias ratio + linalg::VectorView tj_minus, // input bias ratio + linalg::VectorView li, linalg::VectorView lj, + HostDeviceVector* out_gpair) { + auto const* cuctx = ctx->CUDACtx(); + std::int32_t device_id = ctx->gpu_id; + dh::safe_cuda(cudaSetDevice(device_id)); + + info.labels.SetDevice(device_id); + predt.SetDevice(device_id); + + auto d_predt = predt.ConstDeviceSpan(); + auto const d_sorted_idx = p_cache->SortedIdx(ctx, d_predt); + + auto delta = [] XGBOOST_DEVICE(float, float, std::size_t, std::size_t, bst_group_t) { + return 1.0; + }; + + Launch(ctx, iter, predt, info, p_cache, delta, ti_plus, tj_minus, li, lj, out_gpair); +} + +namespace { +struct ReduceOp { + template + Tup XGBOOST_DEVICE operator()(Tup const& l, Tup const& r) { + return thrust::make_tuple(thrust::get<0>(l) + thrust::get<0>(r), + thrust::get<1>(l) + thrust::get<1>(r)); + } +}; +} // namespace + +void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView li_full, + linalg::VectorView lj_full, + linalg::Vector* p_ti_plus, + linalg::Vector* p_tj_minus, + linalg::Vector* p_li, // loss + linalg::Vector* p_lj, + std::shared_ptr p_cache) { + auto const d_group_ptr = p_cache->DataGroupPtr(ctx); + auto n_groups = d_group_ptr.size() - 1; + + auto ti_plus = p_ti_plus->View(ctx->gpu_id); + auto tj_minus = p_tj_minus->View(ctx->gpu_id); + + auto li = p_li->View(ctx->gpu_id); + auto lj = p_lj->View(ctx->gpu_id); + CHECK_EQ(li.Size(), ti_plus.Size()); + + auto const& param = p_cache->Param(); + auto regularizer = param.Regularizer(); + std::size_t k = p_cache->MaxPositionSize(); + + CHECK_EQ(li.Size(), k); + CHECK_EQ(lj.Size(), k); + // reduce li_full to li for each group. + auto make_iter = [&](linalg::VectorView l_full) { + auto l_it = [=] XGBOOST_DEVICE(std::size_t i) { + // group index + auto g = i % n_groups; + // rank is the position within a group, also the segment index + auto r = i / n_groups; + + auto begin = d_group_ptr[g]; + std::size_t group_size = d_group_ptr[g + 1] - begin; + auto n = std::min(group_size, k); + // r can be greater than n since we allocate threads based on truncation level + // instead of actual group size. + if (r >= n) { + return 0.0; + } + return l_full(r + begin); + }; + return l_it; + }; + auto li_it = + dh::MakeTransformIterator(thrust::make_counting_iterator(0ul), make_iter(li_full)); + auto lj_it = + dh::MakeTransformIterator(thrust::make_counting_iterator(0ul), make_iter(lj_full)); + // k segments, each segment has size n_groups. + auto key_it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), + [=] XGBOOST_DEVICE(std::size_t i) { return i * n_groups; }); + auto val_it = thrust::make_zip_iterator(thrust::make_tuple(li_it, lj_it)); + auto out_it = + thrust::make_zip_iterator(thrust::make_tuple(li.Values().data(), lj.Values().data())); + + auto init = thrust::make_tuple(0.0, 0.0); + std::size_t bytes; + cub::DeviceSegmentedReduce::Reduce(nullptr, bytes, val_it, out_it, k, key_it, key_it + 1, + ReduceOp{}, init, ctx->CUDACtx()->Stream()); + dh::TemporaryArray temp(bytes); + cub::DeviceSegmentedReduce::Reduce(temp.data().get(), bytes, val_it, out_it, k, key_it, + key_it + 1, ReduceOp{}, init, ctx->CUDACtx()->Stream()); + + thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), li.Size(), + [=] XGBOOST_DEVICE(std::size_t i) mutable { + ti_plus(i) = std::pow(li(i) / (li(0) + kRtEps), regularizer); + assert(!std::isinf(ti_plus(i))); + + tj_minus(i) = std::pow(lj(i) / (lj(0) + kRtEps), regularizer); + assert(!std::isinf(tj_minus(i))); + }); +} +} // namespace cuda_impl +} // namespace xgboost::obj diff --git a/src/objective/lambdarank_obj.cuh b/src/objective/lambdarank_obj.cuh new file mode 100644 index 000000000000..be9f479cea3b --- /dev/null +++ b/src/objective/lambdarank_obj.cuh @@ -0,0 +1,172 @@ +/** + * Copyright 2023 XGBoost contributors + */ +#ifndef XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_CUH_ +#define XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_CUH_ + +#include // for lower_bound, upper_bound +#include // for greater +#include // for make_counting_iterator +#include // for minstd_rand +#include // for uniform_int_distribution + +#include // for cassert +#include // for size_t +#include // for int32_t +#include // for make_tuple, tuple + +#include "../common/device_helpers.cuh" // for MakeTransformIterator +#include "../common/ranking_utils.cuh" // for PairsForGroup +#include "../common/ranking_utils.h" // for RankingCache +#include "../common/threading_utils.cuh" // for UnravelTrapeziodIdx +#include "xgboost/base.h" // for bst_group_t, GradientPair, XGBOOST_DEVICE +#include "xgboost/data.h" // for MetaInfo +#include "xgboost/linalg.h" // for VectorView, Range, UnravelIndex +#include "xgboost/span.h" // for Span + +namespace xgboost::obj::cuda_impl { +/** + * \brief Find number of elements left to the label bucket + */ +template ::value_type> +XGBOOST_DEVICE __forceinline__ std::size_t CountNumItemsToTheLeftOf(It items, std::size_t n, T v) { + return thrust::lower_bound(thrust::seq, items, items + n, v, thrust::greater{}) - items; +} +/** + * \brief Find number of elements right to the label bucket + */ +template ::value_type> +XGBOOST_DEVICE __forceinline__ std::size_t CountNumItemsToTheRightOf(It items, std::size_t n, T v) { + return n - (thrust::upper_bound(thrust::seq, items, items + n, v, thrust::greater{}) - items); +} +/** + * \brief Sort labels according to rank list for making pairs. + */ +common::Span SortY(Context const *ctx, MetaInfo const &info, + common::Span d_rank, + std::shared_ptr p_cache); + +/** + * \brief Parameters needed for calculating gradient + */ +struct KernelInputs { + linalg::VectorView ti_plus; // input bias ratio + linalg::VectorView tj_minus; // input bias ratio + linalg::VectorView li; + linalg::VectorView lj; + + common::Span d_group_ptr; + common::Span d_threads_group_ptr; + common::Span d_sorted_idx; + + linalg::MatrixView labels; + common::Span predts; + common::Span gpairs; + + linalg::VectorView d_roundings; + double const *d_cost_rounding; + + common::Span d_y_sorted_idx; + + std::int32_t iter; +}; +/** + * \brief Functor for generating pairs + */ +template +struct MakePairsOp { + KernelInputs args; + /** + * \brief Make pair for the topk pair method. + */ + XGBOOST_DEVICE std::tuple WithTruncation(std::size_t idx, + bst_group_t g) const { + auto thread_group_begin = args.d_threads_group_ptr[g]; + auto idx_in_thread_group = idx - thread_group_begin; + + auto data_group_begin = static_cast(args.d_group_ptr[g]); + std::size_t n_data = args.d_group_ptr[g + 1] - data_group_begin; + // obtain group segment data. + auto g_label = args.labels.Slice(linalg::Range(data_group_begin, data_group_begin + n_data), 0); + auto g_sorted_idx = args.d_sorted_idx.subspan(data_group_begin, n_data); + + std::size_t i = 0, j = 0; + common::UnravelTrapeziodIdx(idx_in_thread_group, n_data, &i, &j); + + std::size_t rank_high = i, rank_low = j; + return std::make_tuple(rank_high, rank_low); + } + /** + * \brief Make pair for the mean pair method + */ + XGBOOST_DEVICE std::tuple WithSampling(std::size_t idx, + bst_group_t g) const { + std::size_t n_samples = args.labels.Size(); + assert(n_samples == args.predts.size()); + // Constructed from ranking cache. + std::size_t n_pairs = + ltr::cuda_impl::PairsForGroup(args.d_threads_group_ptr[g + 1] - args.d_threads_group_ptr[g], + args.d_group_ptr[g + 1] - args.d_group_ptr[g]); + + assert(n_pairs > 0); + auto [sample_idx, sample_pair_idx] = linalg::UnravelIndex(idx, {n_samples, n_pairs}); + + auto g_begin = static_cast(args.d_group_ptr[g]); + std::size_t n_data = args.d_group_ptr[g + 1] - g_begin; + + auto g_label = args.labels.Slice(linalg::Range(g_begin, g_begin + n_data)); + auto g_rank_idx = args.d_sorted_idx.subspan(args.d_group_ptr[g], n_data); + auto g_y_sorted_idx = args.d_y_sorted_idx.subspan(g_begin, n_data); + + std::size_t const i = sample_idx - g_begin; + assert(sample_pair_idx < n_samples); + assert(i <= sample_idx); + + auto g_sorted_label = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), + [&](std::size_t i) { return g_label(g_rank_idx[g_y_sorted_idx[i]]); }); + + // Are the labels diverse enough? If they are all the same, then there is nothing to pick + // from another group - bail sooner + if (g_label.Size() == 0 || g_sorted_label[0] == g_sorted_label[n_data - 1]) { + auto z = static_cast(0ul); + return std::make_tuple(z, z); + } + + std::size_t n_lefts = CountNumItemsToTheLeftOf(g_sorted_label, i + 1, g_sorted_label[i]); + std::size_t n_rights = + CountNumItemsToTheRightOf(g_sorted_label + i, n_data - i, g_sorted_label[i]); + // The index pointing to the first element of the next bucket + std::size_t right_bound = n_data - n_rights; + + thrust::minstd_rand rng(args.iter); + auto pair_idx = i; + rng.discard(sample_pair_idx * n_data + g + pair_idx); // fixme + thrust::uniform_int_distribution dist(0, n_lefts + n_rights - 1); + auto ridx = dist(rng); + SPAN_CHECK(ridx < n_lefts + n_rights); + if (ridx >= n_lefts) { + ridx = ridx - n_lefts + right_bound; // fixme + } + + auto idx0 = g_y_sorted_idx[pair_idx]; + auto idx1 = g_y_sorted_idx[ridx]; + + return std::make_tuple(idx0, idx1); + } + /** + * \brief Generate a single pair. + * + * \param idx Pair index (CUDA thread index). + * \param g Query group index. + */ + XGBOOST_DEVICE auto operator()(std::size_t idx, bst_group_t g) const { + if (has_truncation) { + return this->WithTruncation(idx, g); + } else { + return this->WithSampling(idx, g); + } + } +}; +} // namespace xgboost::obj::cuda_impl +#endif // XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_CUH_ diff --git a/src/objective/lambdarank_obj.h b/src/objective/lambdarank_obj.h new file mode 100644 index 000000000000..54005492503c --- /dev/null +++ b/src/objective/lambdarank_obj.h @@ -0,0 +1,262 @@ +/** + * Copyright 2023 XGBoost contributors + */ +#ifndef XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_H_ +#define XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_H_ +#include // for min, max +#include // for assert +#include // for log, abs +#include // for size_t +#include // for greater +#include // for shared_ptr +#include // for minstd_rand, uniform_int_distribution +#include // for vector + +#include "../common/algorithm.h" // for ArgSort +#include "../common/math.h" // for Sigmoid +#include "../common/ranking_utils.h" // for CalcDCGGain +#include "../common/transform_iterator.h" // for MakeIndexTransformIter +#include "xgboost/base.h" // for GradientPair, XGBOOST_DEVICE, kRtEps +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for MetaInfo +#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/linalg.h" // for VectorView, Vector +#include "xgboost/logging.h" // for CHECK_EQ +#include "xgboost/span.h" // for Span + +namespace xgboost::obj { +template +XGBOOST_DEVICE double DeltaNDCG(float y_high, float y_low, std::size_t r_high, std::size_t r_low, + double inv_IDCG, common::Span discount) { + double gain_high = exp ? ltr::CalcDCGGain(y_high) : y_high; + double discount_high = discount[r_high]; + + double gain_low = exp ? ltr::CalcDCGGain(y_low) : y_low; + double discount_low = discount[r_low]; + + double original = gain_high * discount_high + gain_low * discount_low; + double changed = gain_low * discount_high + gain_high * discount_low; + + double delta_NDCG = (original - changed) * inv_IDCG; + assert(delta_NDCG >= -1.0); + assert(delta_NDCG <= 1.0); + return delta_NDCG; +} + +XGBOOST_DEVICE inline double DeltaMAP(float y_high, float y_low, std::size_t rank_high, + std::size_t rank_low, common::Span n_rel, + common::Span acc) { + double r_h = static_cast(rank_high) + 1.0; + double r_l = static_cast(rank_low) + 1.0; + double delta{0.0}; + double n_total_relevances = n_rel.back(); + assert(n_total_relevances > 0.0); + auto m = n_rel[rank_low]; + double n = n_rel[rank_high]; + + if (y_high < y_low) { + auto a = m / r_l - (n + 1.0) / r_h; + auto b = acc[rank_low - 1] - acc[rank_high]; + delta = (a - b) / n_total_relevances; + } else { + auto a = n / r_h - m / r_l; + auto b = acc[rank_low - 1] - acc[rank_high]; + delta = (a + b) / n_total_relevances; + } + return delta; +} + +template +XGBOOST_DEVICE GradientPair +LambdaGrad(linalg::VectorView labels, common::Span predts, + common::Span sorted_idx, + std::size_t rank_high, // cordiniate + std::size_t rank_low, // cordiniate + Delta delta, // delta score + linalg::VectorView t_plus, // input bias ratio + linalg::VectorView t_minus, // input bias ratio + double* p_cost) { + assert(sorted_idx.size() > 0 && "Empty sorted idx for a group."); + std::size_t idx_high = sorted_idx[rank_high]; + std::size_t idx_low = sorted_idx[rank_low]; + + if (labels(idx_high) == labels(idx_low)) { + *p_cost = 0; + return {0.0f, 0.0f}; + } + + auto best_score = predts[sorted_idx.front()]; + auto worst_score = predts[sorted_idx.back()]; + + auto y_high = labels(idx_high); + float s_high = predts[idx_high]; + auto y_low = labels(idx_low); + float s_low = predts[idx_low]; + + // Use double whenever possible as we are working on the exp space. + double delta_score = std::abs(s_high - s_low); + double sigmoid = common::Sigmoid(s_high - s_low); + // Change in metric score like \delta NDCG or \delta MAP + double delta_metric = std::abs(delta(y_high, y_low, rank_high, rank_low)); + + auto k = t_plus.Size() - 1; + assert(t_minus.Size() - 1 == k && "Invalid size of position bias"); + + if (normalize && best_score != worst_score) { + delta_metric /= (delta_score + kRtEps); + } + + if (unbiased) { + *p_cost = std::log(1.0 / (1.0 - sigmoid)) * delta_metric; + } + + constexpr double kEps = 1e-16; + auto lambda_ij = (sigmoid - 1.0) * delta_metric; + auto hessian_ij = std::max(sigmoid * (1.0 - sigmoid), kEps) * delta_metric * 2.0; + + if (unbiased) { + auto position_high = static_cast(std::min(rank_high, k)); + auto position_low = static_cast(std::min(rank_low, k)); + lambda_ij /= (t_minus(position_low) * t_plus(position_high) + kRtEps); + hessian_ij /= (t_minus(position_low) * t_plus(position_high) + kRtEps); + } + + auto pg = GradientPair{static_cast(lambda_ij), static_cast(hessian_ij)}; + return pg; +} + +XGBOOST_DEVICE inline GradientPair Repulse(GradientPair pg) { + auto ng = GradientPair{-pg.GetGrad(), pg.GetHess()}; + return ng; +} + +namespace cuda_impl { +void LambdaRankGetGradientNDCG(Context const* ctx, std::int32_t iter, + HostDeviceVector const& preds, MetaInfo const& info, + std::shared_ptr p_cache, + linalg::VectorView t_plus, // input bias ratio + linalg::VectorView t_minus, // input bias ratio + linalg::VectorView li, linalg::VectorView lj, + HostDeviceVector* out_gpair); + +/** + * \brief Generate statistic for MAP used for calculating \Delta Z in lambda mart. + */ +void MAPStat(Context const* ctx, MetaInfo const& info, common::Span d_rank_idx, + std::shared_ptr p_cache); + +void LambdaRankGetGradientMAP(Context const* ctx, std::int32_t iter, + HostDeviceVector const& predt, MetaInfo const& info, + std::shared_ptr p_cache, + linalg::VectorView t_plus, // input bias ratio + linalg::VectorView t_minus, // input bias ratio + linalg::VectorView li, linalg::VectorView lj, + HostDeviceVector* out_gpair); + +void LambdaRankGetGradientPairwise(Context const* ctx, std::int32_t iter, + HostDeviceVector const& predt, const MetaInfo& info, + std::shared_ptr p_cache, + linalg::VectorView ti_plus, // input bias ratio + linalg::VectorView tj_minus, // input bias ratio + linalg::VectorView li, linalg::VectorView lj, + HostDeviceVector* out_gpair); + +void LambdaRankUpdatePositionBias(Context const* ctx, linalg::VectorView li_full, + linalg::VectorView lj_full, + linalg::Vector* p_ti_plus, + linalg::Vector* p_tj_minus, linalg::Vector* p_li, + linalg::Vector* p_lj, + std::shared_ptr p_cache); +} // namespace cuda_impl + +namespace cpu_impl { +/** + * \brief Generate statistic for MAP used for calculating \Delta Z in lambda mart. + * + * \param label Ground truth relevance label. + * \param rank_idx Sorted index of prediction. + * \param p_cache An initialized MAPCache. + */ +void MAPStat(Context const* ctx, linalg::VectorView label, + common::Span rank_idx, std::shared_ptr p_cache); +} // namespace cpu_impl + +/** + * \param Construct pairs on CPU + * + * \tparam Op Functor for upgrading a pair of gradients. + * + * \param ctx The global context. + * \param iter The boosting iteration. + * \param cache ltr cache. + * \param g The current query group + * \param g_label label The labels for the current query group + * \param g_rank Sorted index of model scores for the current query group. + * \param op A callable that accepts two index for a pair of documents. The index is for + * the ranked list (labels sorted according to model scores). + */ +template +void MakePairs(Context const* ctx, std::int32_t iter, + std::shared_ptr const cache, bst_group_t g, + linalg::VectorView g_label, common::Span g_rank, + Op op) { + auto group_ptr = cache->DataGroupPtr(ctx); + ltr::position_t cnt = group_ptr[g + 1] - group_ptr[g]; + + if (cache->Param().HasTruncation()) { + for (std::size_t i = 0; i < std::min(cnt, cache->Param().NumPair()); ++i) { + for (std::size_t j = i + 1; j < cnt; ++j) { + op(i, j); + } + } + } else { + CHECK_EQ(g_rank.size(), g_label.Size()); + std::minstd_rand rnd(iter); + rnd.discard(g); // fixme(jiamingy): honor the global seed + // sort label according to the rank list + auto it = common::MakeIndexTransformIter( + [&g_rank, &g_label](std::size_t idx) { return g_label(g_rank[idx]); }); + std::vector y_sorted_idx = + common::ArgSort(ctx, it, it + cnt, std::greater<>{}); + // permutation iterator to get the original label + auto rev_it = common::MakeIndexTransformIter( + [&](std::size_t idx) { return g_label(g_rank[y_sorted_idx[idx]]); }); + + for (std::size_t i = 0; i < cnt;) { + std::size_t j = i + 1; + // find the bucket boundary + while (j < cnt && rev_it[i] == rev_it[j]) { + ++j; + } + // Bucket [i,j), construct n_samples pairs for each sample inside the bucket with + // another sample outside the bucket. + // + // n elements left to the bucket, and n elements right to the bucket + std::size_t n_lefts = i, n_rights = static_cast(cnt - j); + if (n_lefts + n_rights == 0) { + i = j; + continue; + } + + auto n_samples = cache->Param().NumPair(); + // for each pair specifed by the user + while (n_samples--) { + // for each sample in the bucket + for (std::size_t pair_idx = i; pair_idx < j; ++pair_idx) { + std::size_t ridx = std::uniform_int_distribution( + static_cast(0), n_lefts + n_rights - 1)(rnd); + if (ridx >= n_lefts) { + ridx = ridx - i + j; // shift to the right of the bucket + } + // index that points to the rank list. + auto idx0 = y_sorted_idx[pair_idx]; + auto idx1 = y_sorted_idx[ridx]; + op(idx0, idx1); + } + } + i = j; + } + } +} +} // namespace xgboost::obj +#endif // XGBOOST_OBJECTIVE_LAMBDARANK_OBJ_H_ diff --git a/src/objective/objective.cc b/src/objective/objective.cc index d3b01d80bf27..85cd9803d4ef 100644 --- a/src/objective/objective.cc +++ b/src/objective/objective.cc @@ -47,13 +47,14 @@ DMLC_REGISTRY_LINK_TAG(regression_obj_gpu); DMLC_REGISTRY_LINK_TAG(quantile_obj_gpu); DMLC_REGISTRY_LINK_TAG(hinge_obj_gpu); DMLC_REGISTRY_LINK_TAG(multiclass_obj_gpu); -DMLC_REGISTRY_LINK_TAG(rank_obj_gpu); +DMLC_REGISTRY_LINK_TAG(lambdarank_obj); +DMLC_REGISTRY_LINK_TAG(lambdarank_obj_cu); #else DMLC_REGISTRY_LINK_TAG(regression_obj); DMLC_REGISTRY_LINK_TAG(quantile_obj); DMLC_REGISTRY_LINK_TAG(hinge_obj); DMLC_REGISTRY_LINK_TAG(multiclass_obj); -DMLC_REGISTRY_LINK_TAG(rank_obj); +DMLC_REGISTRY_LINK_TAG(lambdarank_obj); #endif // XGBOOST_USE_CUDA } // namespace obj } // namespace xgboost diff --git a/src/objective/rank_obj.cc b/src/objective/rank_obj.cc deleted file mode 100644 index 25cd9e643eff..000000000000 --- a/src/objective/rank_obj.cc +++ /dev/null @@ -1,17 +0,0 @@ -/*! - * Copyright 2019 XGBoost contributors - */ - -// Dummy file to keep the CUDA conditional compile trick. -#include -namespace xgboost { -namespace obj { - -DMLC_REGISTRY_FILE_TAG(rank_obj); - -} // namespace obj -} // namespace xgboost - -#ifndef XGBOOST_USE_CUDA -#include "rank_obj.cu" -#endif // XGBOOST_USE_CUDA diff --git a/src/objective/rank_obj.cu b/src/objective/rank_obj.cu deleted file mode 100644 index f1c8702102df..000000000000 --- a/src/objective/rank_obj.cu +++ /dev/null @@ -1,961 +0,0 @@ -/*! - * Copyright 2015-2022 XGBoost contributors - */ -#include -#include -#include -#include -#include -#include -#include - -#include "xgboost/json.h" -#include "xgboost/parameter.h" - -#include "../common/math.h" -#include "../common/random.h" - -#if defined(__CUDACC__) -#include -#include -#include -#include -#include - -#include - -#include "../common/device_helpers.cuh" -#endif - -namespace xgboost { -namespace obj { - -#if defined(XGBOOST_USE_CUDA) && !defined(GTEST_TEST) -DMLC_REGISTRY_FILE_TAG(rank_obj_gpu); -#endif // defined(XGBOOST_USE_CUDA) - -struct LambdaRankParam : public XGBoostParameter { - size_t num_pairsample; - float fix_list_weight; - // declare parameters - DMLC_DECLARE_PARAMETER(LambdaRankParam) { - DMLC_DECLARE_FIELD(num_pairsample).set_lower_bound(1).set_default(1) - .describe("Number of pair generated for each instance."); - DMLC_DECLARE_FIELD(fix_list_weight).set_lower_bound(0.0f).set_default(0.0f) - .describe("Normalize the weight of each list by this value," - " if equals 0, no effect will happen"); - } -}; - -#if defined(__CUDACC__) -// Helper functions - -template -XGBOOST_DEVICE __forceinline__ uint32_t -CountNumItemsToTheLeftOf(const T *__restrict__ items, uint32_t n, T v) { - return thrust::lower_bound(thrust::seq, items, items + n, v, - thrust::greater()) - - items; -} - -template -XGBOOST_DEVICE __forceinline__ uint32_t -CountNumItemsToTheRightOf(const T *__restrict__ items, uint32_t n, T v) { - return n - (thrust::upper_bound(thrust::seq, items, items + n, v, - thrust::greater()) - - items); -} -#endif - -/*! \brief helper information in a list */ -struct ListEntry { - /*! \brief the predict score we in the data */ - bst_float pred; - /*! \brief the actual label of the entry */ - bst_float label; - /*! \brief row index in the data matrix */ - unsigned rindex; - // constructor - ListEntry(bst_float pred, bst_float label, unsigned rindex) - : pred(pred), label(label), rindex(rindex) {} - // comparator by prediction - inline static bool CmpPred(const ListEntry &a, const ListEntry &b) { - return a.pred > b.pred; - } - // comparator by label - inline static bool CmpLabel(const ListEntry &a, const ListEntry &b) { - return a.label > b.label; - } -}; - -/*! \brief a pair in the lambda rank */ -struct LambdaPair { - /*! \brief positive index: this is a position in the list */ - unsigned pos_index; - /*! \brief negative index: this is a position in the list */ - unsigned neg_index; - /*! \brief weight to be filled in */ - bst_float weight; - // constructor - LambdaPair(unsigned pos_index, unsigned neg_index) - : pos_index(pos_index), neg_index(neg_index), weight(1.0f) {} - // constructor - LambdaPair(unsigned pos_index, unsigned neg_index, bst_float weight) - : pos_index(pos_index), neg_index(neg_index), weight(weight) {} -}; - -class PairwiseLambdaWeightComputer { - public: - /*! - * \brief get lambda weight for existing pairs - for pairwise objective - * \param list a list that is sorted by pred score - * \param io_pairs record of pairs, containing the pairs to fill in weights - */ - static void GetLambdaWeight(const std::vector&, - std::vector*) {} - - static char const* Name() { - return "rank:pairwise"; - } - -#if defined(__CUDACC__) - PairwiseLambdaWeightComputer(const bst_float*, - const bst_float*, - const dh::SegmentSorter&) {} - - class PairwiseLambdaWeightMultiplier { - public: - // Adjust the items weight by this value - __device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const { - return 1.0f; - } - }; - - inline const PairwiseLambdaWeightMultiplier GetWeightMultiplier() const { - return {}; - } -#endif -}; - -#if defined(__CUDACC__) -class BaseLambdaWeightMultiplier { - public: - BaseLambdaWeightMultiplier(const dh::SegmentSorter &segment_label_sorter, - const dh::SegmentSorter &segment_pred_sorter) - : dsorted_labels_(segment_label_sorter.GetItemsSpan()), - dorig_pos_(segment_label_sorter.GetOriginalPositionsSpan()), - dgroups_(segment_label_sorter.GetGroupsSpan()), - dindexable_sorted_preds_pos_(segment_pred_sorter.GetIndexableSortedPositionsSpan()) {} - - protected: - const common::Span dsorted_labels_; // Labels sorted within a group - const common::Span dorig_pos_; // Original indices of the labels - // before they are sorted - const common::Span dgroups_; // The group indices - // Where can a prediction for a label be found in the original array, when they are sorted - const common::Span dindexable_sorted_preds_pos_; -}; - -// While computing the weight that needs to be adjusted by this ranking objective, we need -// to figure out where positive and negative labels chosen earlier exists, if the group -// were to be sorted by its predictions. To accommodate this, we employ the following algorithm. -// For a given group, let's assume the following: -// labels: 1 5 9 2 4 8 0 7 6 3 -// predictions: 1 9 0 8 2 7 3 6 5 4 -// position: 0 1 2 3 4 5 6 7 8 9 -// -// After label sort: -// labels: 9 8 7 6 5 4 3 2 1 0 -// position: 2 5 7 8 1 4 9 3 0 6 -// -// After prediction sort: -// predictions: 9 8 7 6 5 4 3 2 1 0 -// position: 1 3 5 7 8 9 6 4 0 2 -// -// If a sorted label at position 'x' is chosen, then we need to find out where the prediction -// for this label 'x' exists, if the group were to be sorted by predictions. -// We first take the sorted prediction positions: -// position: 1 3 5 7 8 9 6 4 0 2 -// at indices: 0 1 2 3 4 5 6 7 8 9 -// -// We create a sorted prediction positional array, such that value at position 'x' gives -// us the position in the sorted prediction array where its related prediction lies. -// dindexable_sorted_preds_pos_: 8 0 9 1 7 2 6 3 4 5 -// at indices: 0 1 2 3 4 5 6 7 8 9 -// Basically, swap the previous 2 arrays, sort the indices and reorder positions -// for an O(1) lookup using the position where the sorted label exists. -// -// This type does that using the SegmentSorter -class IndexablePredictionSorter { - public: - IndexablePredictionSorter(const bst_float *dpreds, - const dh::SegmentSorter &segment_label_sorter) { - // Sort the predictions first - segment_pred_sorter_.SortItems(dpreds, segment_label_sorter.GetNumItems(), - segment_label_sorter.GetGroupSegmentsSpan()); - - // Create an index for the sorted prediction positions - segment_pred_sorter_.CreateIndexableSortedPositions(); - } - - inline const dh::SegmentSorter &GetPredictionSorter() const { - return segment_pred_sorter_; - } - - private: - dh::SegmentSorter segment_pred_sorter_; // For sorting the predictions -}; -#endif - -// beta version: NDCG lambda rank -class NDCGLambdaWeightComputer -#if defined(__CUDACC__) - : public IndexablePredictionSorter -#endif -{ - public: -#if defined(__CUDACC__) - // This function object computes the item's DCG value - class ComputeItemDCG : public thrust::unary_function { - public: - XGBOOST_DEVICE ComputeItemDCG(const common::Span &dsorted_labels, - const common::Span &dgroups, - const common::Span &gidxs) - : dsorted_labels_(dsorted_labels), - dgroups_(dgroups), - dgidxs_(gidxs) {} - - // Compute DCG for the item at 'idx' - __device__ __forceinline__ float operator()(uint32_t idx) const { - return ComputeItemDCGWeight(dsorted_labels_[idx], idx - dgroups_[dgidxs_[idx]]); - } - - private: - const common::Span dsorted_labels_; // Labels sorted within a group - const common::Span dgroups_; // The group indices - where each group - // begins and ends - const common::Span dgidxs_; // The group each items belongs to - }; - - // Type containing device pointers that can be cheaply copied on the kernel - class NDCGLambdaWeightMultiplier : public BaseLambdaWeightMultiplier { - public: - NDCGLambdaWeightMultiplier(const dh::SegmentSorter &segment_label_sorter, - const NDCGLambdaWeightComputer &lwc) - : BaseLambdaWeightMultiplier(segment_label_sorter, lwc.GetPredictionSorter()), - dgroup_dcgs_(lwc.GetGroupDcgsSpan()) {} - - // Adjust the items weight by this value - __device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const { - if (dgroup_dcgs_[gidx] == 0.0) return 0.0f; - - uint32_t group_begin = dgroups_[gidx]; - - auto pos_lab_orig_posn = dorig_pos_[pidx]; - auto neg_lab_orig_posn = dorig_pos_[nidx]; - KERNEL_CHECK(pos_lab_orig_posn != neg_lab_orig_posn); - - // Note: the label positive and negative indices are relative to the entire dataset. - // Hence, scale them back to an index within the group - auto pos_pred_pos = dindexable_sorted_preds_pos_[pos_lab_orig_posn] - group_begin; - auto neg_pred_pos = dindexable_sorted_preds_pos_[neg_lab_orig_posn] - group_begin; - return NDCGLambdaWeightComputer::ComputeDeltaWeight( - pos_pred_pos, neg_pred_pos, - static_cast(dsorted_labels_[pidx]), static_cast(dsorted_labels_[nidx]), - dgroup_dcgs_[gidx]); - } - - private: - const common::Span dgroup_dcgs_; // Group DCG values - }; - - NDCGLambdaWeightComputer(const bst_float *dpreds, - const bst_float*, - const dh::SegmentSorter &segment_label_sorter) - : IndexablePredictionSorter(dpreds, segment_label_sorter), - dgroup_dcg_(segment_label_sorter.GetNumGroups(), 0.0f), - weight_multiplier_(segment_label_sorter, *this) { - const auto &group_segments = segment_label_sorter.GetGroupSegmentsSpan(); - - // Allocator to be used for managing space overhead while performing transformed reductions - dh::XGBCachingDeviceAllocator alloc; - - // Compute each elements DCG values and reduce them across groups concurrently. - auto end_range = - thrust::reduce_by_key(thrust::cuda::par(alloc), - dh::tcbegin(group_segments), dh::tcend(group_segments), - thrust::make_transform_iterator( - // The indices need not be sequential within a group, as we care only - // about the sum of items DCG values within a group - dh::tcbegin(segment_label_sorter.GetOriginalPositionsSpan()), - ComputeItemDCG(segment_label_sorter.GetItemsSpan(), - segment_label_sorter.GetGroupsSpan(), - group_segments)), - thrust::make_discard_iterator(), // We don't care for the group indices - dgroup_dcg_.begin()); // Sum of the item's DCG values in the group - CHECK_EQ(static_cast(end_range.second - dgroup_dcg_.begin()), dgroup_dcg_.size()); - } - - inline const common::Span GetGroupDcgsSpan() const { - return { dgroup_dcg_.data().get(), dgroup_dcg_.size() }; - } - - inline const NDCGLambdaWeightMultiplier GetWeightMultiplier() const { - return weight_multiplier_; - } -#endif - - static void GetLambdaWeight(const std::vector &sorted_list, - std::vector *io_pairs) { - std::vector &pairs = *io_pairs; - float IDCG; // NOLINT - { - std::vector labels(sorted_list.size()); - for (size_t i = 0; i < sorted_list.size(); ++i) { - labels[i] = sorted_list[i].label; - } - std::stable_sort(labels.begin(), labels.end(), std::greater<>()); - IDCG = ComputeGroupDCGWeight(&labels[0], labels.size()); - } - if (IDCG == 0.0) { - for (auto & pair : pairs) { - pair.weight = 0.0f; - } - } else { - for (auto & pair : pairs) { - unsigned pos_idx = pair.pos_index; - unsigned neg_idx = pair.neg_index; - pair.weight *= ComputeDeltaWeight(pos_idx, neg_idx, - sorted_list[pos_idx].label, sorted_list[neg_idx].label, - IDCG); - } - } - } - - static char const* Name() { - return "rank:ndcg"; - } - - inline static bst_float ComputeGroupDCGWeight(const float *sorted_labels, uint32_t size) { - double sumdcg = 0.0; - for (uint32_t i = 0; i < size; ++i) { - sumdcg += ComputeItemDCGWeight(sorted_labels[i], i); - } - - return static_cast(sumdcg); - } - - private: - XGBOOST_DEVICE inline static bst_float ComputeItemDCGWeight(unsigned label, uint32_t idx) { - return (label != 0) ? (((1 << label) - 1) / std::log2(static_cast(idx + 2))) : 0; - } - - // Compute the weight adjustment for an item within a group: - // pos_pred_pos => Where does the positive label live, had the list been sorted by prediction - // neg_pred_pos => Where does the negative label live, had the list been sorted by prediction - // pos_label => positive label value from sorted label list - // neg_label => negative label value from sorted label list - XGBOOST_DEVICE inline static bst_float ComputeDeltaWeight(uint32_t pos_pred_pos, - uint32_t neg_pred_pos, - int pos_label, int neg_label, - float idcg) { - float pos_loginv = 1.0f / std::log2(pos_pred_pos + 2.0f); - float neg_loginv = 1.0f / std::log2(neg_pred_pos + 2.0f); - bst_float original = ((1 << pos_label) - 1) * pos_loginv + ((1 << neg_label) - 1) * neg_loginv; - float changed = ((1 << neg_label) - 1) * pos_loginv + ((1 << pos_label) - 1) * neg_loginv; - bst_float delta = (original - changed) * (1.0f / idcg); - if (delta < 0.0f) delta = - delta; - return delta; - } - -#if defined(__CUDACC__) - dh::caching_device_vector dgroup_dcg_; - // This computes the adjustment to the weight - const NDCGLambdaWeightMultiplier weight_multiplier_; -#endif -}; - -class MAPLambdaWeightComputer -#if defined(__CUDACC__) - : public IndexablePredictionSorter -#endif -{ - public: - struct MAPStats { - /*! \brief the accumulated precision */ - float ap_acc{0.0f}; - /*! - * \brief the accumulated precision, - * assuming a positive instance is missing - */ - float ap_acc_miss{0.0f}; - /*! - * \brief the accumulated precision, - * assuming that one more positive instance is inserted ahead - */ - float ap_acc_add{0.0f}; - /* \brief the accumulated positive instance count */ - float hits{0.0f}; - - XGBOOST_DEVICE MAPStats() {} // NOLINT - XGBOOST_DEVICE MAPStats(float ap_acc, float ap_acc_miss, float ap_acc_add, float hits) - : ap_acc(ap_acc), ap_acc_miss(ap_acc_miss), ap_acc_add(ap_acc_add), hits(hits) {} - - // For prefix scan - XGBOOST_DEVICE MAPStats operator +(const MAPStats &v1) const { - return {ap_acc + v1.ap_acc, ap_acc_miss + v1.ap_acc_miss, - ap_acc_add + v1.ap_acc_add, hits + v1.hits}; - } - - // For test purposes - compare for equality - XGBOOST_DEVICE bool operator ==(const MAPStats &rhs) const { - return ap_acc == rhs.ap_acc && ap_acc_miss == rhs.ap_acc_miss && - ap_acc_add == rhs.ap_acc_add && hits == rhs.hits; - } - }; - - private: - template - XGBOOST_DEVICE inline static void Swap(T &v0, T &v1) { -#if defined(__CUDACC__) - thrust::swap(v0, v1); -#else - std::swap(v0, v1); -#endif - } - - /*! - * \brief Obtain the delta MAP by trying to switch the positions of labels in pos_pred_pos or - * neg_pred_pos when sorted by predictions - * \param pos_pred_pos positive label's prediction value position when the groups prediction - * values are sorted - * \param neg_pred_pos negative label's prediction value position when the groups prediction - * values are sorted - * \param pos_label, neg_label the chosen positive and negative labels - * \param p_map_stats a vector containing the accumulated precisions for each position in a list - * \param map_stats_size size of the accumulated precisions vector - */ - XGBOOST_DEVICE inline static bst_float GetLambdaMAP( - int pos_pred_pos, int neg_pred_pos, - bst_float pos_label, bst_float neg_label, - const MAPStats *p_map_stats, uint32_t map_stats_size) { - if (pos_pred_pos == neg_pred_pos || p_map_stats[map_stats_size - 1].hits == 0) { - return 0.0f; - } - if (pos_pred_pos > neg_pred_pos) { - Swap(pos_pred_pos, neg_pred_pos); - Swap(pos_label, neg_label); - } - bst_float original = p_map_stats[neg_pred_pos].ap_acc; - if (pos_pred_pos != 0) original -= p_map_stats[pos_pred_pos - 1].ap_acc; - bst_float changed = 0; - bst_float label1 = pos_label > 0.0f ? 1.0f : 0.0f; - bst_float label2 = neg_label > 0.0f ? 1.0f : 0.0f; - if (label1 == label2) { - return 0.0; - } else if (label1 < label2) { - changed += p_map_stats[neg_pred_pos - 1].ap_acc_add - p_map_stats[pos_pred_pos].ap_acc_add; - changed += (p_map_stats[pos_pred_pos].hits + 1.0f) / (pos_pred_pos + 1); - } else { - changed += p_map_stats[neg_pred_pos - 1].ap_acc_miss - p_map_stats[pos_pred_pos].ap_acc_miss; - changed += p_map_stats[neg_pred_pos].hits / (neg_pred_pos + 1); - } - bst_float ans = (changed - original) / (p_map_stats[map_stats_size - 1].hits); - if (ans < 0) ans = -ans; - return ans; - } - - public: - /* - * \brief obtain preprocessing results for calculating delta MAP - * \param sorted_list the list containing entry information - * \param map_stats a vector containing the accumulated precisions for each position in a list - */ - inline static void GetMAPStats(const std::vector &sorted_list, - std::vector *p_map_acc) { - std::vector &map_acc = *p_map_acc; - map_acc.resize(sorted_list.size()); - bst_float hit = 0, acc1 = 0, acc2 = 0, acc3 = 0; - for (size_t i = 1; i <= sorted_list.size(); ++i) { - if (sorted_list[i - 1].label > 0.0f) { - hit++; - acc1 += hit / i; - acc2 += (hit - 1) / i; - acc3 += (hit + 1) / i; - } - map_acc[i - 1] = MAPStats(acc1, acc2, acc3, hit); - } - } - - static char const* Name() { - return "rank:map"; - } - - static void GetLambdaWeight(const std::vector &sorted_list, - std::vector *io_pairs) { - std::vector &pairs = *io_pairs; - std::vector map_stats; - GetMAPStats(sorted_list, &map_stats); - for (auto & pair : pairs) { - pair.weight *= - GetLambdaMAP(pair.pos_index, pair.neg_index, - sorted_list[pair.pos_index].label, sorted_list[pair.neg_index].label, - &map_stats[0], map_stats.size()); - } - } - -#if defined(__CUDACC__) - MAPLambdaWeightComputer(const bst_float *dpreds, - const bst_float *dlabels, - const dh::SegmentSorter &segment_label_sorter) - : IndexablePredictionSorter(dpreds, segment_label_sorter), - dmap_stats_(segment_label_sorter.GetNumItems(), MAPStats()), - weight_multiplier_(segment_label_sorter, *this) { - this->CreateMAPStats(dlabels, segment_label_sorter); - } - - void CreateMAPStats(const bst_float *dlabels, - const dh::SegmentSorter &segment_label_sorter) { - // For each group, go through the sorted prediction positions, and look up its corresponding - // label from the unsorted labels (from the original label list) - - // For each item in the group, compute its MAP stats. - // Interleave the computation of map stats amongst different groups. - - // First, determine postive labels in the dataset individually - auto nitems = segment_label_sorter.GetNumItems(); - dh::caching_device_vector dhits(nitems, 0); - // Original positions of the predictions after they have been sorted - const auto &pred_original_pos = this->GetPredictionSorter().GetOriginalPositionsSpan(); - // Unsorted labels - const float *unsorted_labels = dlabels; - auto DeterminePositiveLabelLambda = [=] __device__(uint32_t idx) { - return (unsorted_labels[pred_original_pos[idx]] > 0.0f) ? 1 : 0; - }; // NOLINT - - thrust::transform(thrust::make_counting_iterator(static_cast(0)), - thrust::make_counting_iterator(nitems), - dhits.begin(), - DeterminePositiveLabelLambda); - - // Allocator to be used by sort for managing space overhead while performing prefix scans - dh::XGBCachingDeviceAllocator alloc; - - // Next, prefix scan the positive labels that are segmented to accumulate them. - // This is required for computing the accumulated precisions - const auto &group_segments = segment_label_sorter.GetGroupSegmentsSpan(); - // Data segmented into different groups... - thrust::inclusive_scan_by_key(thrust::cuda::par(alloc), - dh::tcbegin(group_segments), dh::tcend(group_segments), - dhits.begin(), // Input value - dhits.begin()); // In-place scan - - // Compute accumulated precisions for each item, assuming positive and - // negative instances are missing. - // But first, compute individual item precisions - const auto *dhits_arr = dhits.data().get(); - // Group info on device - const auto &dgroups = segment_label_sorter.GetGroupsSpan(); - auto ComputeItemPrecisionLambda = [=] __device__(uint32_t idx) { - if (unsorted_labels[pred_original_pos[idx]] > 0.0f) { - auto idx_within_group = (idx - dgroups[group_segments[idx]]) + 1; - return MAPStats{static_cast(dhits_arr[idx]) / idx_within_group, - static_cast(dhits_arr[idx] - 1) / idx_within_group, - static_cast(dhits_arr[idx] + 1) / idx_within_group, - 1.0f}; - } - return MAPStats{}; - }; // NOLINT - - thrust::transform(thrust::make_counting_iterator(static_cast(0)), - thrust::make_counting_iterator(nitems), - this->dmap_stats_.begin(), - ComputeItemPrecisionLambda); - - // Lastly, compute the accumulated precisions for all the items segmented by groups. - // The precisions are accumulated within each group - thrust::inclusive_scan_by_key(thrust::cuda::par(alloc), - dh::tcbegin(group_segments), dh::tcend(group_segments), - this->dmap_stats_.begin(), // Input map stats - this->dmap_stats_.begin()); // In-place scan and output here - } - - inline const common::Span GetMapStatsSpan() const { - return { dmap_stats_.data().get(), dmap_stats_.size() }; - } - - // Type containing device pointers that can be cheaply copied on the kernel - class MAPLambdaWeightMultiplier : public BaseLambdaWeightMultiplier { - public: - MAPLambdaWeightMultiplier(const dh::SegmentSorter &segment_label_sorter, - const MAPLambdaWeightComputer &lwc) - : BaseLambdaWeightMultiplier(segment_label_sorter, lwc.GetPredictionSorter()), - dmap_stats_(lwc.GetMapStatsSpan()) {} - - // Adjust the items weight by this value - __device__ __forceinline__ bst_float GetWeight(uint32_t gidx, int pidx, int nidx) const { - uint32_t group_begin = dgroups_[gidx]; - uint32_t group_end = dgroups_[gidx + 1]; - - auto pos_lab_orig_posn = dorig_pos_[pidx]; - auto neg_lab_orig_posn = dorig_pos_[nidx]; - KERNEL_CHECK(pos_lab_orig_posn != neg_lab_orig_posn); - - // Note: the label positive and negative indices are relative to the entire dataset. - // Hence, scale them back to an index within the group - auto pos_pred_pos = dindexable_sorted_preds_pos_[pos_lab_orig_posn] - group_begin; - auto neg_pred_pos = dindexable_sorted_preds_pos_[neg_lab_orig_posn] - group_begin; - return MAPLambdaWeightComputer::GetLambdaMAP( - pos_pred_pos, neg_pred_pos, - dsorted_labels_[pidx], dsorted_labels_[nidx], - &dmap_stats_[group_begin], group_end - group_begin); - } - - private: - common::Span dmap_stats_; // Start address of the map stats for every sorted - // prediction value - }; - - inline const MAPLambdaWeightMultiplier GetWeightMultiplier() const { return weight_multiplier_; } - - private: - dh::caching_device_vector dmap_stats_; - // This computes the adjustment to the weight - const MAPLambdaWeightMultiplier weight_multiplier_; -#endif -}; - -#if defined(__CUDACC__) -class SortedLabelList : dh::SegmentSorter { - private: - const LambdaRankParam ¶m_; // Objective configuration - - public: - explicit SortedLabelList(const LambdaRankParam ¶m) - : param_(param) {} - - // Sort the labels that are grouped by 'groups' - void Sort(const HostDeviceVector &dlabels, const std::vector &groups) { - this->SortItems(dlabels.ConstDevicePointer(), dlabels.Size(), groups); - } - - // This kernel can only run *after* the kernel in sort is completed, as they - // use the default stream - template - void ComputeGradients(const bst_float *dpreds, // Unsorted predictions - const bst_float *dlabels, // Unsorted labels - const HostDeviceVector &weights, - int iter, - GradientPair *out_gpair, - float weight_normalization_factor) { - // Group info on device - const auto &dgroups = this->GetGroupsSpan(); - uint32_t ngroups = this->GetNumGroups() + 1; - - uint32_t total_items = this->GetNumItems(); - uint32_t niter = param_.num_pairsample * total_items; - - float fix_list_weight = param_.fix_list_weight; - - const auto &original_pos = this->GetOriginalPositionsSpan(); - - uint32_t num_weights = weights.Size(); - auto dweights = num_weights ? weights.ConstDevicePointer() : nullptr; - - const auto &sorted_labels = this->GetItemsSpan(); - - // This is used to adjust the weight of different elements based on the different ranking - // objective function policies - LambdaWeightComputerT weight_computer(dpreds, dlabels, *this); - auto wmultiplier = weight_computer.GetWeightMultiplier(); - - int device_id = -1; - dh::safe_cuda(cudaGetDevice(&device_id)); - // For each instance in the group, compute the gradient pair concurrently - dh::LaunchN(niter, nullptr, [=] __device__(uint32_t idx) { - // First, determine the group 'idx' belongs to - uint32_t item_idx = idx % total_items; - uint32_t group_idx = - thrust::upper_bound(thrust::seq, dgroups.begin(), - dgroups.begin() + ngroups, item_idx) - - dgroups.begin(); - // Span of this group within the larger labels/predictions sorted tuple - uint32_t group_begin = dgroups[group_idx - 1]; - uint32_t group_end = dgroups[group_idx]; - uint32_t total_group_items = group_end - group_begin; - - // Are the labels diverse enough? If they are all the same, then there is nothing to pick - // from another group - bail sooner - if (sorted_labels[group_begin] == sorted_labels[group_end - 1]) return; - - // Find the number of labels less than and greater than the current label - // at the sorted index position item_idx - uint32_t nleft = CountNumItemsToTheLeftOf( - sorted_labels.data() + group_begin, item_idx - group_begin + 1, sorted_labels[item_idx]); - uint32_t nright = CountNumItemsToTheRightOf( - sorted_labels.data() + item_idx, group_end - item_idx, sorted_labels[item_idx]); - - // Create a minstd_rand object to act as our source of randomness - thrust::minstd_rand rng((iter + 1) * 1111); - rng.discard(((idx / total_items) * total_group_items) + item_idx - group_begin); - // Create a uniform_int_distribution to produce a sample from outside of the - // present label group - thrust::uniform_int_distribution dist(0, nleft + nright - 1); - - int sample = dist(rng); - int pos_idx = -1; // Bigger label - int neg_idx = -1; // Smaller label - // Are we picking a sample to the left/right of the current group? - if (sample < nleft) { - // Go left - pos_idx = sample + group_begin; - neg_idx = item_idx; - } else { - pos_idx = item_idx; - uint32_t items_in_group = total_group_items - nleft - nright; - neg_idx = sample + items_in_group + group_begin; - } - - // Compute and assign the gradients now - const float eps = 1e-16f; - bst_float p = common::Sigmoid(dpreds[original_pos[pos_idx]] - dpreds[original_pos[neg_idx]]); - bst_float g = p - 1.0f; - bst_float h = thrust::max(p * (1.0f - p), eps); - - // Rescale each gradient and hessian so that the group has a weighted constant - float scale = __frcp_ru(niter / total_items); - if (fix_list_weight != 0.0f) { - scale *= fix_list_weight / total_group_items; - } - - float weight = num_weights ? dweights[group_idx - 1] : 1.0f; - weight *= weight_normalization_factor; - weight *= wmultiplier.GetWeight(group_idx - 1, pos_idx, neg_idx); - weight *= scale; - // Accumulate gradient and hessian in both positive and negative indices - const GradientPair in_pos_gpair(g * weight, 2.0f * weight * h); - dh::AtomicAddGpair(&out_gpair[original_pos[pos_idx]], in_pos_gpair); - - const GradientPair in_neg_gpair(-g * weight, 2.0f * weight * h); - dh::AtomicAddGpair(&out_gpair[original_pos[neg_idx]], in_neg_gpair); - }); - - // Wait until the computations done by the kernel is complete - dh::safe_cuda(cudaStreamSynchronize(nullptr)); - } -}; -#endif - -// objective for lambda rank -template -class LambdaRankObj : public ObjFunction { - public: - void Configure(Args const &args) override { param_.UpdateAllowUnknown(args); } - ObjInfo Task() const override { return ObjInfo::kRanking; } - - void GetGradient(const HostDeviceVector& preds, - const MetaInfo& info, - int iter, - HostDeviceVector* out_gpair) override { - CHECK_EQ(preds.Size(), info.labels.Size()) << "label size predict size not match"; - - // quick consistency when group is not available - std::vector tgptr(2, 0); tgptr[1] = static_cast(info.labels.Size()); - const std::vector &gptr = info.group_ptr_.size() == 0 ? tgptr : info.group_ptr_; - CHECK(gptr.size() != 0 && gptr.back() == info.labels.Size()) - << "group structure not consistent with #rows" << ", " - << "group ponter size: " << gptr.size() << ", " - << "labels size: " << info.labels.Size() << ", " - << "group pointer back: " << (gptr.size() == 0 ? 0 : gptr.back()); - -#if defined(__CUDACC__) - // Check if we have a GPU assignment; else, revert back to CPU - auto device = ctx_->gpu_id; - if (device >= 0) { - ComputeGradientsOnGPU(preds, info, iter, out_gpair, gptr); - } else { - // Revert back to CPU -#endif - ComputeGradientsOnCPU(preds, info, iter, out_gpair, gptr); -#if defined(__CUDACC__) - } -#endif - } - - const char* DefaultEvalMetric() const override { - return "map"; - } - - void SaveConfig(Json* p_out) const override { - auto& out = *p_out; - out["name"] = String(LambdaWeightComputerT::Name()); - out["lambda_rank_param"] = ToJson(param_); - } - - void LoadConfig(Json const& in) override { - FromJson(in["lambda_rank_param"], ¶m_); - } - - private: - bst_float ComputeWeightNormalizationFactor(const MetaInfo& info, - const std::vector &gptr) { - const auto ngroup = static_cast(gptr.size() - 1); - bst_float sum_weights = 0; - for (bst_omp_uint k = 0; k < ngroup; ++k) { - sum_weights += info.GetWeight(k); - } - return ngroup / sum_weights; - } - - void ComputeGradientsOnCPU(const HostDeviceVector& preds, - const MetaInfo& info, - int iter, - HostDeviceVector* out_gpair, - const std::vector &gptr) { - LOG(DEBUG) << "Computing " << LambdaWeightComputerT::Name() << " gradients on CPU."; - - bst_float weight_normalization_factor = ComputeWeightNormalizationFactor(info, gptr); - - const auto& preds_h = preds.HostVector(); - const auto& labels = info.labels.HostView(); - std::vector& gpair = out_gpair->HostVector(); - const auto ngroup = static_cast(gptr.size() - 1); - out_gpair->Resize(preds.Size()); - - dmlc::OMPException exc; -#pragma omp parallel num_threads(ctx_->Threads()) - { - exc.Run([&]() { - // parallel construct, declare random number generator here, so that each - // thread use its own random number generator, seed by thread id and current iteration - std::minstd_rand rnd((iter + 1) * 1111); - std::vector pairs; - std::vector lst; - std::vector< std::pair > rec; - - #pragma omp for schedule(static) - for (bst_omp_uint k = 0; k < ngroup; ++k) { - exc.Run([&]() { - lst.clear(); pairs.clear(); - for (unsigned j = gptr[k]; j < gptr[k+1]; ++j) { - lst.emplace_back(preds_h[j], labels(j), j); - gpair[j] = GradientPair(0.0f, 0.0f); - } - std::stable_sort(lst.begin(), lst.end(), ListEntry::CmpPred); - rec.resize(lst.size()); - for (unsigned i = 0; i < lst.size(); ++i) { - rec[i] = std::make_pair(lst[i].label, i); - } - std::stable_sort(rec.begin(), rec.end(), common::CmpFirst); - // enumerate buckets with same label - // for each item in the lst, grab another sample randomly - for (unsigned i = 0; i < rec.size(); ) { - unsigned j = i + 1; - while (j < rec.size() && rec[j].first == rec[i].first) ++j; - // bucket in [i,j), get a sample outside bucket - unsigned nleft = i, nright = static_cast(rec.size() - j); - if (nleft + nright != 0) { - int nsample = param_.num_pairsample; - while (nsample --) { - for (unsigned pid = i; pid < j; ++pid) { - unsigned ridx = - std::uniform_int_distribution(0, nleft + nright - 1)(rnd); - if (ridx < nleft) { - pairs.emplace_back(rec[ridx].second, rec[pid].second, - info.GetWeight(k) * weight_normalization_factor); - } else { - pairs.emplace_back(rec[pid].second, rec[ridx+j-i].second, - info.GetWeight(k) * weight_normalization_factor); - } - } - } - } - i = j; - } - // get lambda weight for the pairs - LambdaWeightComputerT::GetLambdaWeight(lst, &pairs); - // rescale each gradient and hessian so that the lst have constant weighted - float scale = 1.0f / param_.num_pairsample; - if (param_.fix_list_weight != 0.0f) { - scale *= param_.fix_list_weight / (gptr[k + 1] - gptr[k]); - } - for (auto & pair : pairs) { - const ListEntry &pos = lst[pair.pos_index]; - const ListEntry &neg = lst[pair.neg_index]; - const bst_float w = pair.weight * scale; - const float eps = 1e-16f; - bst_float p = common::Sigmoid(pos.pred - neg.pred); - bst_float g = p - 1.0f; - bst_float h = std::max(p * (1.0f - p), eps); - // accumulate gradient and hessian in both pid, and nid - gpair[pos.rindex] += GradientPair(g * w, 2.0f*w*h); - gpair[neg.rindex] += GradientPair(-g * w, 2.0f*w*h); - } - }); - } - }); - } - exc.Rethrow(); - } - -#if defined(__CUDACC__) - void ComputeGradientsOnGPU(const HostDeviceVector& preds, - const MetaInfo& info, - int iter, - HostDeviceVector* out_gpair, - const std::vector &gptr) { - LOG(DEBUG) << "Computing " << LambdaWeightComputerT::Name() << " gradients on GPU."; - - auto device = ctx_->gpu_id; - dh::safe_cuda(cudaSetDevice(device)); - - bst_float weight_normalization_factor = ComputeWeightNormalizationFactor(info, gptr); - - // Set the device ID and copy them to the device - out_gpair->SetDevice(device); - info.labels.SetDevice(device); - preds.SetDevice(device); - info.weights_.SetDevice(device); - - out_gpair->Resize(preds.Size()); - - auto d_preds = preds.ConstDevicePointer(); - auto d_gpair = out_gpair->DevicePointer(); - auto d_labels = info.labels.View(device); - - SortedLabelList slist(param_); - - // Sort the labels within the groups on the device - slist.Sort(*info.labels.Data(), gptr); - - // Initialize the gradients next - out_gpair->Fill(GradientPair(0.0f, 0.0f)); - - // Finally, compute the gradients - slist.ComputeGradients(d_preds, d_labels.Values().data(), info.weights_, - iter, d_gpair, weight_normalization_factor); - } -#endif - - LambdaRankParam param_; -}; - -#if !defined(GTEST_TEST) -// register the objective functions -DMLC_REGISTER_PARAMETER(LambdaRankParam); - -XGBOOST_REGISTER_OBJECTIVE(PairwiseRankObj, PairwiseLambdaWeightComputer::Name()) -.describe("Pairwise rank objective.") -.set_body([]() { return new LambdaRankObj(); }); - -XGBOOST_REGISTER_OBJECTIVE(LambdaRankNDCG, NDCGLambdaWeightComputer::Name()) -.describe("LambdaRank with NDCG as objective.") -.set_body([]() { return new LambdaRankObj(); }); - -XGBOOST_REGISTER_OBJECTIVE(LambdaRankObjMAP, MAPLambdaWeightComputer::Name()) -.describe("LambdaRank with MAP as objective.") -.set_body([]() { return new LambdaRankObj(); }); -#endif - -} // namespace obj -} // namespace xgboost diff --git a/tests/ci_build/lint_python.py b/tests/ci_build/lint_python.py index 8d601f3550bc..86859e390530 100644 --- a/tests/ci_build/lint_python.py +++ b/tests/ci_build/lint_python.py @@ -155,6 +155,7 @@ def main(args: argparse.Namespace) -> None: "demo/guide-python/spark_estimator_examples.py", "demo/guide-python/individual_trees.py", "demo/guide-python/quantile_regression.py", + "demo/guide-python/learning_to_rank.py", # CI "tests/ci_build/lint_python.py", "tests/ci_build/test_r_package.py", diff --git a/tests/cpp/common/test_algorithm.cc b/tests/cpp/common/test_algorithm.cc index 630460714e37..c7bcc19d4e5d 100644 --- a/tests/cpp/common/test_algorithm.cc +++ b/tests/cpp/common/test_algorithm.cc @@ -2,10 +2,11 @@ * Copyright 2020-2023 by XGBoost Contributors */ #include -#include // Context -#include +#include // for Context -#include // is_sorted +#include // for is_sorted +#include // for int32_t +#include // for vector #include "../../../src/common/algorithm.h" @@ -31,5 +32,59 @@ TEST(Algorithm, Sort) { StableSort(&ctx, inputs.begin(), inputs.end(), std::less<>{}); ASSERT_TRUE(std::is_sorted(inputs.cbegin(), inputs.cend())); } + +TEST(Algorithm, AllOf) { + Context ctx; + auto is_zero = [](auto v) { return v == 0; }; + + for (std::size_t n : {0, 3, 16, 128}) { + std::vector data(n, 0); + for (std::int32_t n_threads : {0, 1, 3, 7}) { + ctx.nthread = n_threads; + auto ret = AllOf(&ctx, data.cbegin(), data.cend(), is_zero); + ASSERT_TRUE(ret); + // same result as std for empty case. + ASSERT_TRUE(std::all_of(data.cbegin(), data.cend(), is_zero)); + } + + if (n == 0) { + continue; + } + + data[n / 2] = 1; + for (std::int32_t n_threads : {0, 1, 3, 7}) { + ctx.nthread = n_threads; + auto ret = AllOf(&ctx, data.cbegin(), data.cend(), is_zero); + ASSERT_FALSE(ret); + } + } +} + +TEST(Algorithm, NoneOf) { + Context ctx; + auto is_one = [](auto v) { return v == 1; }; + + for (std::size_t n : {0, 3, 16, 128}) { + std::vector data(n, 0); + for (std::int32_t n_threads : {0, 1, 3, 7}) { + ctx.nthread = n_threads; + auto ret = NoneOf(&ctx, data.cbegin(), data.cend(), is_one); + ASSERT_TRUE(ret); + // same result as std for empty case. + ASSERT_TRUE(std::none_of(data.cbegin(), data.cend(), is_one)); + } + + if (n == 0) { + continue; + } + + data[n / 2] = 1; + for (std::int32_t n_threads : {1, 3, 7}) { + ctx.nthread = n_threads; + auto ret = NoneOf(&ctx, data.cbegin(), data.cend(), is_one); + ASSERT_FALSE(ret); + } + } +} } // namespace common } // namespace xgboost diff --git a/tests/cpp/common/test_ranking_utils.cc b/tests/cpp/common/test_ranking_utils.cc index ea72edd9fdb7..bd35f2b27af8 100644 --- a/tests/cpp/common/test_ranking_utils.cc +++ b/tests/cpp/common/test_ranking_utils.cc @@ -1,38 +1,85 @@ /** * Copyright 2023 by XGBoost Contributors */ -#include +#include "test_ranking_utils.h" -#include // std::uint32_t +#include // for Test, AssertionResult, Message, TestPartR... +#include // for ASSERT_NEAR, ASSERT_T... +#include // for Args +#include // for Context +#include // for StringView -#include "../../../src/common/ranking_utils.h" +#include // for uint32_t +#include // for pair + +#include "../../../src/common/ranking_utils.h" // for LambdaRankParam, ParseMetricName, MakeMet... +#include "test_ranking_utils.h" namespace xgboost { namespace ltr { -TEST(RankingUtils, MakeMetricName) { +TEST(RankingUtils, LambdaRankParam) { + // make sure no memory is shared in dmlc parameter. + LambdaRankParam p0; + p0.UpdateAllowUnknown(Args{{"lambdarank_num_pair_per_sample", "3"}}); + ASSERT_EQ(p0.NumPair(), 3); + + LambdaRankParam p1; + p1.UpdateAllowUnknown(Args{{"lambdarank_num_pair_per_sample", "8"}}); + + ASSERT_EQ(p0.NumPair(), 3); + ASSERT_EQ(p1.NumPair(), 8); + + p0.UpdateAllowUnknown(Args{{"lambdarank_num_pair_per_sample", "17"}}); + ASSERT_EQ(p0.NumPair(), 17); + ASSERT_EQ(p1.NumPair(), 8); +} + +TEST(RankingUtils, ParseMetricName) { std::uint32_t topn{32}; bool minus{false}; - auto name = MakeMetricName("ndcg", "3-", &topn, &minus); + auto name = ParseMetricName("ndcg", "3-", &topn, &minus); ASSERT_EQ(name, "ndcg@3-"); ASSERT_EQ(topn, 3); ASSERT_TRUE(minus); - name = MakeMetricName("ndcg", "6", &topn, &minus); + name = ParseMetricName("ndcg", "6", &topn, &minus); ASSERT_EQ(topn, 6); ASSERT_TRUE(minus); // unchanged minus = false; - name = MakeMetricName("ndcg", "-", &topn, &minus); + name = ParseMetricName("ndcg", "-", &topn, &minus); ASSERT_EQ(topn, 6); // unchanged ASSERT_TRUE(minus); - name = MakeMetricName("ndcg", nullptr, &topn, &minus); + name = ParseMetricName("ndcg", nullptr, &topn, &minus); ASSERT_EQ(topn, 6); // unchanged ASSERT_TRUE(minus); // unchanged - name = MakeMetricName("ndcg", StringView{}, &topn, &minus); + name = ParseMetricName("ndcg", StringView{}, &topn, &minus); ASSERT_EQ(topn, 6); // unchanged ASSERT_TRUE(minus); // unchanged } + +TEST(RankingUtils, MakeMetricName) { + auto name = MakeMetricName("map", LambdaRankParam::NotSet(), true); + ASSERT_EQ(name, "map-"); + name = MakeMetricName("map", LambdaRankParam::NotSet(), false); + ASSERT_EQ(name, "map"); + name = MakeMetricName("map", 2, true); + ASSERT_EQ(name, "map@2-"); + name = MakeMetricName("map", 2, false); + ASSERT_EQ(name, "map@2"); +} + +TEST(NDCGCache, InitFromCPU) { + Context ctx; + TestNDCGCache(&ctx); +} + +TEST(MAPCache, InitFromCPU) { + Context ctx; + ctx.Init(Args{}); + TestMAPCache(&ctx); +} } // namespace ltr } // namespace xgboost diff --git a/tests/cpp/common/test_ranking_utils.cu b/tests/cpp/common/test_ranking_utils.cu new file mode 100644 index 000000000000..6281a82f3f2e --- /dev/null +++ b/tests/cpp/common/test_ranking_utils.cu @@ -0,0 +1,62 @@ +/** + * Copyright 2023 by XGBoost Contributors + */ +#include +#include // for Context +#include // for MakeTensorView, Vector +#include // for Args + +#include // for size_t + +#include "../../../src/common/algorithm.cuh" // for SegmentedSequence +#include "../../../src/common/cuda_context.cuh" // for CUDAContext +#include "../../../src/common/device_helpers.cuh" // for device_vector, LaunchN, ToSpan +#include "../../../src/common/ranking_utils.cuh" +#include "../../../src/common/ranking_utils.h" // for LambdaRankParam +#include "test_ranking_utils.h" + +namespace xgboost { +namespace ltr { +void TestCalcQueriesInvIDCG() { + Context ctx; + ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}}); + std::size_t n_groups = 5, n_samples_per_group = 32; + + dh::device_vector scores(n_samples_per_group * n_groups); + dh::device_vector group_ptr(n_groups + 1); + auto d_group_ptr = dh::ToSpan(group_ptr); + dh::LaunchN(d_group_ptr.size(), ctx.CUDACtx()->Stream(), + [=] XGBOOST_DEVICE(std::size_t i) { d_group_ptr[i] = i * n_samples_per_group; }); + + auto d_scores = dh::ToSpan(scores); + common::SegmentedSequence(&ctx, d_group_ptr, d_scores); + + linalg::Vector inv_IDCG({n_groups}, ctx.gpu_id); + + ltr::LambdaRankParam p; + p.UpdateAllowUnknown(Args{{"ndcg_exp_gain", "false"}}); + + cuda_impl::CalcQueriesInvIDCG(&ctx, + linalg::MakeTensorView(d_scores, {d_scores.size()}, ctx.gpu_id), + dh::ToSpan(group_ptr), inv_IDCG.View(ctx.gpu_id), p); + for (std::size_t i = 0; i < n_groups; ++i) { + double inv_idcg = inv_IDCG(i); + ASSERT_NEAR(inv_idcg, 0.00551782, kRtEps); + } +} + +TEST(RankingUtils, CalcQueriesInvIDCG) { TestCalcQueriesInvIDCG(); } + +TEST(NDCGCache, InitFromGPU) { + Context ctx; + ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}}); + TestNDCGCache(&ctx); +} + +TEST(MAPCache, InitFromGPU) { + Context ctx; + ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}}); + TestMAPCache(&ctx); +} +} // namespace ltr +} // namespace xgboost diff --git a/tests/cpp/common/test_ranking_utils.h b/tests/cpp/common/test_ranking_utils.h new file mode 100644 index 000000000000..173b865cc459 --- /dev/null +++ b/tests/cpp/common/test_ranking_utils.h @@ -0,0 +1,114 @@ +/** + * Copyright 2023 by XGBoost Contributors + */ +#pragma once +#include // for ASSERT_NEAR, ASSERT_T... +#include // for Args, bst_group_t +#include // for Context +#include // for MetaInfo, DMatrix +#include // for Error + +#include // for size_t +#include // for pair, move +#include // for vector + +#include "../../../src/common/numeric.h" // for Iota +#include "../../../src/common/ranking_utils.h" // for LambdaRankParam, NDCG... +#include "../helpers.h" // for EmptyDMatrix +#include "gtest/gtest.h" // for Message, TestPartResult +#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/linalg.h" // for Tensor, VectorView +#include "xgboost/span.h" // for Span + +namespace xgboost { +namespace ltr { +inline void TestNDCGCache(Context const* ctx) { + auto p_fmat = EmptyDMatrix(); + MetaInfo& info = p_fmat->Info(); + LambdaRankParam param; + param.UpdateAllowUnknown(Args{}); + + { + // empty + NDCGCache cache{ctx, info, param}; + ASSERT_EQ(cache.DataGroupPtr(ctx).size(), 2); + } + + info.num_row_ = 3; + info.group_ptr_ = {static_cast(0), static_cast(info.num_row_)}; + + { + auto fail = [&]() { NDCGCache cache{ctx, info, param}; }; + // empty label + ASSERT_THROW(fail(), dmlc::Error); + info.labels = linalg::Matrix{{0.0f, 0.1f, 0.2f}, {3}, Context::kCpuId}; + // invalid label + ASSERT_THROW(fail(), dmlc::Error); + auto h_labels = info.labels.HostView(); + for (std::size_t i = 0; i < h_labels.Size(); ++i) { + h_labels(i) *= 10; + } + param.UpdateAllowUnknown(Args{{"ndcg_exp_gain", "false"}}); + NDCGCache cache{ctx, info, param}; + Context cpuctx; + auto inv_idcg = cache.InvIDCG(&cpuctx); + ASSERT_EQ(inv_idcg.Size(), 1); + ASSERT_NEAR(1.0 / inv_idcg(0), 2.63093, kRtEps); + } + + { + param.UpdateAllowUnknown(Args{{"lambdarank_unbiased", "false"}}); + + std::vector h_data(32); + + common::Iota(ctx, h_data.begin(), h_data.end(), 0.0f); + info.labels.Reshape(h_data.size()); + info.num_row_ = h_data.size(); + info.group_ptr_.back() = info.num_row_; + info.labels.Data()->HostVector() = std::move(h_data); + + { + NDCGCache cache{ctx, info, param}; + Context cpuctx; + auto inv_idcg = cache.InvIDCG(&cpuctx); + ASSERT_NEAR(inv_idcg(0), 0.00551782, kRtEps); + } + + param.UpdateAllowUnknown( + Args{{"lambdarank_num_pair_per_sample", "3"}, {"lambdarank_pair_method", "topk"}}); + { + NDCGCache cache{ctx, info, param}; + Context cpuctx; + auto inv_idcg = cache.InvIDCG(&cpuctx); + ASSERT_NEAR(inv_idcg(0), 0.01552123, kRtEps); + } + } +} + +inline void TestMAPCache(Context const* ctx) { + auto p_fmat = EmptyDMatrix(); + MetaInfo& info = p_fmat->Info(); + LambdaRankParam param; + param.UpdateAllowUnknown(Args{}); + + std::vector h_data(32); + + common::Iota(ctx, h_data.begin(), h_data.end(), 0.0f); + info.labels.Reshape(h_data.size()); + info.num_row_ = h_data.size(); + info.labels.Data()->HostVector() = std::move(h_data); + + auto fail = [&]() { std::make_shared(ctx, info, param); }; + // binary label + ASSERT_THROW(fail(), dmlc::Error); + + h_data = std::vector(32, 0.0f); + h_data[1] = 1.0f; + info.labels.Data()->HostVector() = h_data; + auto p_cache = std::make_shared(ctx, info, param); + + ASSERT_EQ(p_cache->Acc(ctx).size(), info.num_row_); + ASSERT_EQ(p_cache->NumRelevant(ctx).size(), info.num_row_); +} +} // namespace ltr +} // namespace xgboost diff --git a/tests/cpp/metric/test_rank_metric.cc b/tests/cpp/metric/test_rank_metric.cc index 1edbd9fc8d76..3e1028c48d7e 100644 --- a/tests/cpp/metric/test_rank_metric.cc +++ b/tests/cpp/metric/test_rank_metric.cc @@ -1,7 +1,20 @@ -// Copyright by Contributors -#include - -#include "../helpers.h" +/** + * Copyright 2016-2023 by XGBoost Contributors + */ +#include // for Test, EXPECT_NEAR, ASSERT_STREQ +#include // for Context +#include // for MetaInfo, DMatrix +#include // for Matrix +#include // for Metric + +#include // for max +#include // for unique_ptr +#include // for vector + +#include "../helpers.h" // for GetMetricEval, CreateEmptyGe... +#include "xgboost/base.h" // for bst_float, kRtEps +#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/json.h" // for Json, String, Object #if !defined(__CUDACC__) TEST(Metric, AMS) { @@ -51,15 +64,17 @@ TEST(Metric, DeclareUnifiedTest(Precision)) { delete metric; } +namespace xgboost { +namespace metric { TEST(Metric, DeclareUnifiedTest(NDCG)) { - auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX); - xgboost::Metric * metric = xgboost::Metric::Create("ndcg", &ctx); + auto ctx = CreateEmptyGenericParam(GPUIDX); + Metric * metric = xgboost::Metric::Create("ndcg", &ctx); ASSERT_STREQ(metric->Name(), "ndcg"); EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {})); - EXPECT_NEAR(GetMetricEval(metric, + ASSERT_NEAR(GetMetricEval(metric, xgboost::HostDeviceVector{}, {}), 1, 1e-10); - EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10); + ASSERT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10); EXPECT_NEAR(GetMetricEval(metric, {0.1f, 0.9f, 0.1f, 0.9f}, { 0, 0, 1, 1}), @@ -80,7 +95,7 @@ TEST(Metric, DeclareUnifiedTest(NDCG)) { EXPECT_NEAR(GetMetricEval(metric, xgboost::HostDeviceVector{}, {}), 0, 1e-10); - EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10); + ASSERT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1.f, 1e-10); EXPECT_NEAR(GetMetricEval(metric, {0.1f, 0.9f, 0.1f, 0.9f}, { 0, 0, 1, 1}), @@ -91,29 +106,30 @@ TEST(Metric, DeclareUnifiedTest(NDCG)) { EXPECT_NEAR(GetMetricEval(metric, xgboost::HostDeviceVector{}, {}), 0, 1e-10); - EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10); + EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1.f, 1e-10); EXPECT_NEAR(GetMetricEval(metric, {0.1f, 0.9f, 0.1f, 0.9f}, { 0, 0, 1, 1}), - 0.6509f, 0.001f); + 0.6509f, 0.001f); delete metric; metric = xgboost::Metric::Create("ndcg@2-", &ctx); ASSERT_STREQ(metric->Name(), "ndcg@2-"); - EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10); + EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1.f, 1e-10); EXPECT_NEAR(GetMetricEval(metric, {0.1f, 0.9f, 0.1f, 0.9f}, { 0, 0, 1, 1}), - 0.3868f, 0.001f); + 1.f - 0.3868f, 1.f - 0.001f); delete metric; } TEST(Metric, DeclareUnifiedTest(MAP)) { auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX); - xgboost::Metric * metric = xgboost::Metric::Create("map", &ctx); + Metric * metric = xgboost::Metric::Create("map", &ctx); ASSERT_STREQ(metric->Name(), "map"); - EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, 1e-10); + EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}), 1, kRtEps); + EXPECT_NEAR(GetMetricEval(metric, {0.1f, 0.9f, 0.1f, 0.9f}, { 0, 0, 1, 1}), @@ -125,7 +141,7 @@ TEST(Metric, DeclareUnifiedTest(MAP)) { // Rank metric with group info EXPECT_NEAR(GetMetricEval(metric, {0.1f, 0.9f, 0.2f, 0.8f, 0.4f, 1.7f}, - {2, 7, 1, 0, 5, 0}, // Labels + {1, 1, 1, 0, 1, 0}, // Labels {}, // Weights {0, 2, 5, 6}), // Group info 0.8611f, 0.001f); @@ -154,3 +170,39 @@ TEST(Metric, DeclareUnifiedTest(MAP)) { 0.25f, 0.001f); delete metric; } + +TEST(Metric, DeclareUnifiedTest(NDCGExpGain)) { + Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX); + + auto p_fmat = xgboost::RandomDataGenerator{0, 0, 0}.GenerateDMatrix(); + MetaInfo& info = p_fmat->Info(); + info.labels = linalg::Matrix{{10.0f, 0.0f, 0.0f, 1.0f, 5.0f}, {5}, ctx.gpu_id}; + info.num_row_ = info.labels.Shape(0); + info.group_ptr_.resize(2); + info.group_ptr_[0] = 0; + info.group_ptr_[1] = info.num_row_; + HostDeviceVector predt{{0.1f, 0.2f, 0.3f, 4.0f, 70.0f}}; + + std::unique_ptr metric{Metric::Create("ndcg", &ctx)}; + Json config{Object{}}; + config["name"] = String{"ndcg"}; + config["lambdarank_param"] = Object{}; + config["lambdarank_param"]["ndcg_exp_gain"] = String{"true"}; + config["lambdarank_param"]["lambdarank_num_pair_per_sample"] = String{"32"}; + metric->LoadConfig(config); + + auto ndcg = metric->Evaluate(predt, p_fmat); + ASSERT_NEAR(ndcg, 0.409738f, kRtEps); + + config["lambdarank_param"]["ndcg_exp_gain"] = String{"false"}; + metric->LoadConfig(config); + + ndcg = metric->Evaluate(predt, p_fmat); + ASSERT_NEAR(ndcg, 0.695694f, kRtEps); + + predt.HostVector() = info.labels.Data()->HostVector(); + ndcg = metric->Evaluate(predt, p_fmat); + ASSERT_NEAR(ndcg, 1.0, kRtEps); +} +} // namespace metric +} // namespace xgboost diff --git a/tests/cpp/objective/test_lambdarank_obj.cc b/tests/cpp/objective/test_lambdarank_obj.cc new file mode 100644 index 000000000000..98343d0b4bfa --- /dev/null +++ b/tests/cpp/objective/test_lambdarank_obj.cc @@ -0,0 +1,273 @@ +/** + * Copyright 2023 by XGBoost Contributors + */ +#include "test_lambdarank_obj.h" + +#include // for Test, Message, TestPartResult, CmpHel... + +#include // for size_t +#include // for initializer_list +#include // for map +#include // for unique_ptr, shared_ptr, make_shared +#include // for iota +#include // for char_traits, basic_string, string +#include // for vector + +#include "../../../src/common/ranking_utils.h" // for LambdaRankParam +#include "../../../src/common/ranking_utils.h" // for NDCGCache, LambdaRankParam +#include "../helpers.h" // for CheckRankingObjFunction, CheckConfigReload +#include "xgboost/base.h" // for GradientPair, bst_group_t, Args +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for MetaInfo, DMatrix +#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/linalg.h" // for Tensor, All, TensorView +#include "xgboost/objective.h" // for ObjFunction +#include "xgboost/span.h" // for Span + +namespace xgboost::obj { +TEST(LambdaRank, NDCGJsonIO) { + Context ctx; + TestNDCGJsonIO(&ctx); +} + +void TestNDCGGPair(Context const* ctx) { + std::unique_ptr obj{xgboost::ObjFunction::Create("rank:ndcg", ctx)}; + obj->Configure(Args{{"lambdarank_pair_method", "topk"}}); + CheckConfigReload(obj, "rank:ndcg"); + + // No gain in swapping 2 documents. + CheckRankingObjFunction(obj, {1, 1, 1, 1}, {1, 1, 1, 1}, {1.0f, 1.0f}, {0, 2, 4}, + {0.0f, -0.0f, 0.0f, 0.0f}, {0.0f, 0.0f, 0.0f, 0.0f}); + + HostDeviceVector predts{0, 1, 0, 1}; + MetaInfo info; + info.labels = linalg::Tensor{{0, 1, 0, 1}, {4, 1}, GPUIDX}; + info.group_ptr_ = {0, 2, 4}; + info.num_row_ = 4; + HostDeviceVector gpairs; + obj->GetGradient(predts, info, 0, &gpairs); + ASSERT_EQ(gpairs.Size(), predts.Size()); + + { + predts = {1, 0, 1, 0}; + HostDeviceVector gpairs; + obj->GetGradient(predts, info, 0, &gpairs); + for (size_t i = 0; i < gpairs.Size(); ++i) { + ASSERT_GT(gpairs.HostSpan()[i].GetHess(), 0); + } + ASSERT_LT(gpairs.HostSpan()[1].GetGrad(), 0); + ASSERT_LT(gpairs.HostSpan()[3].GetGrad(), 0); + + ASSERT_GT(gpairs.HostSpan()[0].GetGrad(), 0); + ASSERT_GT(gpairs.HostSpan()[2].GetGrad(), 0); + + info.weights_ = {2, 3}; + HostDeviceVector weighted_gpairs; + obj->GetGradient(predts, info, 0, &weighted_gpairs); + auto const& h_gpairs = gpairs.ConstHostSpan(); + auto const& h_weighted_gpairs = weighted_gpairs.ConstHostSpan(); + for (size_t i : {0ul, 1ul}) { + ASSERT_FLOAT_EQ(h_weighted_gpairs[i].GetGrad(), h_gpairs[i].GetGrad() * 2.0f); + ASSERT_FLOAT_EQ(h_weighted_gpairs[i].GetHess(), h_gpairs[i].GetHess() * 2.0f); + } + for (size_t i : {2ul, 3ul}) { + ASSERT_FLOAT_EQ(h_weighted_gpairs[i].GetGrad(), h_gpairs[i].GetGrad() * 3.0f); + ASSERT_FLOAT_EQ(h_weighted_gpairs[i].GetHess(), h_gpairs[i].GetHess() * 3.0f); + } + } + + ASSERT_NO_THROW(obj->DefaultEvalMetric()); +} + +TEST(LambdaRank, NDCGGPair) { + Context ctx; + TestNDCGGPair(&ctx); +} + +void InitMakePairTest(Context const* ctx, MetaInfo* out_info, HostDeviceVector* out_predt) { + out_predt->SetDevice(ctx->gpu_id); + MetaInfo& info = *out_info; + info.num_row_ = 128; + info.labels.ModifyInplace([&](HostDeviceVector* data, common::Span shape) { + shape[0] = info.num_row_; + shape[1] = 1; + auto& h_data = data->HostVector(); + h_data.resize(shape[0]); + for (std::size_t i = 0; i < h_data.size(); ++i) { + h_data[i] = i % 2; + } + }); + std::vector predt(info.num_row_); + std::iota(predt.rbegin(), predt.rend(), 0.0f); + out_predt->HostVector() = predt; +} + +TEST(LambdaRank, MakePair) { + Context ctx; + MetaInfo info; + HostDeviceVector predt; + + InitMakePairTest(&ctx, &info, &predt); + + ltr::LambdaRankParam param; + param.UpdateAllowUnknown(Args{{"lambdarank_pair_method", "topk"}}); + ASSERT_TRUE(param.HasTruncation()); + + std::shared_ptr p_cache = std::make_shared(&ctx, info, param); + auto const& h_predt = predt.ConstHostVector(); + { + auto rank_idx = p_cache->SortedIdx(&ctx, h_predt); + for (std::size_t i = 0; i < h_predt.size(); ++i) { + ASSERT_EQ(rank_idx[i], static_cast(*(h_predt.crbegin() + i))); + } + std::int32_t n_pairs{0}; + MakePairs(&ctx, 0, p_cache, 0, info.labels.HostView().Slice(linalg::All(), 0), rank_idx, + [&](auto i, auto j) { + ASSERT_GT(j, i); + ASSERT_LT(i, p_cache->Param().NumPair()); + ++n_pairs; + }); + ASSERT_EQ(n_pairs, 3568); + } + + auto const h_label = info.labels.HostView(); + + { + param.UpdateAllowUnknown(Args{{"lambdarank_pair_method", "mean"}}); + auto p_cache = std::make_shared(&ctx, info, param); + ASSERT_FALSE(param.HasTruncation()); + std::int32_t n_pairs = 0; + auto rank_idx = p_cache->SortedIdx(&ctx, h_predt); + MakePairs(&ctx, 0, p_cache, 0, info.labels.HostView().Slice(linalg::All(), 0), rank_idx, + [&](auto i, auto j) { + ++n_pairs; + // Not in the same bucket + ASSERT_NE(h_label(rank_idx[i]), h_label(rank_idx[j])); + }); + ASSERT_EQ(n_pairs, info.num_row_ * param.NumPair()); + } + + { + param.UpdateAllowUnknown(Args{{"lambdarank_num_pair_per_sample", "2"}}); + auto p_cache = std::make_shared(&ctx, info, param); + auto rank_idx = p_cache->SortedIdx(&ctx, h_predt); + std::int32_t n_pairs = 0; + MakePairs(&ctx, 0, p_cache, 0, info.labels.HostView().Slice(linalg::All(), 0), rank_idx, + [&](auto i, auto j) { + ++n_pairs; + // Not in the same bucket + ASSERT_NE(h_label(rank_idx[i]), h_label(rank_idx[j])); + }); + ASSERT_EQ(param.NumPair(), 2); + ASSERT_EQ(n_pairs, info.num_row_ * param.NumPair()); + } +} + +void TestMAPStat(Context const* ctx) { + auto p_fmat = EmptyDMatrix(); + MetaInfo& info = p_fmat->Info(); + ltr::LambdaRankParam param; + param.UpdateAllowUnknown(Args{}); + + { + std::vector h_data{1.0f, 1.0f, 0.0f, 1.0f, 1.0f, 1.0f}; + info.labels.Reshape(h_data.size(), 1); + info.labels.Data()->HostVector() = h_data; + info.num_row_ = h_data.size(); + + HostDeviceVector predt; + auto& h_predt = predt.HostVector(); + h_predt.resize(h_data.size()); + std::iota(h_predt.rbegin(), h_predt.rend(), 0.0f); + + auto p_cache = std::make_shared(ctx, info, param); + + predt.SetDevice(ctx->gpu_id); + auto rank_idx = + p_cache->SortedIdx(ctx, ctx->IsCPU() ? predt.ConstHostSpan() : predt.ConstDeviceSpan()); + + if (ctx->IsCPU()) { + obj::cpu_impl::MAPStat(ctx, info.labels.HostView().Slice(linalg::All(), 0), rank_idx, + p_cache); + } else { + obj::cuda_impl::MAPStat(ctx, info, rank_idx, p_cache); + } + + Context cpu_ctx; + auto n_rel = p_cache->NumRelevant(&cpu_ctx); + auto acc = p_cache->Acc(&cpu_ctx); + + ASSERT_EQ(n_rel[0], 1.0); + ASSERT_EQ(acc[0], 1.0); + + ASSERT_EQ(n_rel.back(), h_data.size() - 1.0); + ASSERT_NEAR(acc.back(), 1.95 + (1.0 / h_data.size()), kRtEps); + } + { + info.labels.Reshape(16); + auto& h_label = info.labels.Data()->HostVector(); + info.group_ptr_ = {0, 8, 16}; + info.num_row_ = info.labels.Shape(0); + + std::fill_n(h_label.begin(), 8, 1.0f); + std::fill_n(h_label.begin() + 8, 8, 0.0f); + HostDeviceVector predt; + auto& h_predt = predt.HostVector(); + h_predt.resize(h_label.size()); + std::iota(h_predt.rbegin(), h_predt.rbegin() + 8, 0.0f); + std::iota(h_predt.rbegin() + 8, h_predt.rend(), 0.0f); + + auto p_cache = std::make_shared(ctx, info, param); + + predt.SetDevice(ctx->gpu_id); + auto rank_idx = + p_cache->SortedIdx(ctx, ctx->IsCPU() ? predt.ConstHostSpan() : predt.ConstDeviceSpan()); + + if (ctx->IsCPU()) { + obj::cpu_impl::MAPStat(ctx, info.labels.HostView().Slice(linalg::All(), 0), rank_idx, + p_cache); + } else { + obj::cuda_impl::MAPStat(ctx, info, rank_idx, p_cache); + } + + Context cpu_ctx; + auto n_rel = p_cache->NumRelevant(&cpu_ctx); + ASSERT_EQ(n_rel[7], 8); // first group + ASSERT_EQ(n_rel.back(), 0); // second group + } +} + +TEST(LambdaRank, MAPStat) { + Context ctx; + TestMAPStat(&ctx); +} + +void TestMAPGPair(Context const* ctx) { + std::unique_ptr obj{xgboost::ObjFunction::Create("rank:map", ctx)}; + Args args; + obj->Configure(args); + + CheckConfigReload(obj, "rank:map"); + + CheckRankingObjFunction(obj, // obj + {0, 0.1f, 0, 0.1f}, // score + {0, 1, 0, 1}, // label + {2.0f, 2.0f}, // weight + {0, 2, 4}, // group + {0.4750208f, -0.4750208f, 0.4750208f, -0.4750208f}, // out grad + {0.4987521f, 0.4987521f, 0.4987521f, 0.4987521f}); + // disable the second query group with 0 weight + CheckRankingObjFunction(obj, // obj + {0, 0.1f, 0, 0.1f}, // score + {0, 1, 0, 1}, // label + {2.0f, 0.0f}, // weight + {0, 2, 4}, // group + {0.4750208f, -0.4750208f, 0.0f, 0.0f}, // out grad + {0.4987521f, 0.4987521f, 0.0f, 0.0f}); +} + +TEST(LambdaRank, MAPGPair) { + Context ctx; + TestMAPGPair(&ctx); +} +} // namespace xgboost::obj diff --git a/tests/cpp/objective/test_lambdarank_obj.cu b/tests/cpp/objective/test_lambdarank_obj.cu new file mode 100644 index 000000000000..ef82e1ff2a7c --- /dev/null +++ b/tests/cpp/objective/test_lambdarank_obj.cu @@ -0,0 +1,161 @@ +/** + * Copyright 2023 by XGBoost Contributors + */ +#include +#include // for Context + +#include // for uint32_t +#include // for vector + +#include "../../../src/common/cuda_context.cuh" // for CUDAContext +#include "../../../src/objective/lambdarank_obj.cuh" +#include "test_lambdarank_obj.h" + +namespace xgboost::obj { +TEST(LambdaRank, GPUNDCGJsonIO) { + Context ctx; + ctx.gpu_id = 0; + TestNDCGJsonIO(&ctx); +} + +TEST(LambdaRank, GPUMAPStat) { + Context ctx; + ctx.gpu_id = 0; + TestMAPStat(&ctx); +} + +TEST(LambdaRank, GPUNDCGGPair) { + Context ctx; + ctx.gpu_id = 0; + TestNDCGGPair(&ctx); +} + +void TestGPUMakePair() { + Context ctx; + ctx.gpu_id = 0; + + MetaInfo info; + HostDeviceVector predt; + InitMakePairTest(&ctx, &info, &predt); + + ltr::LambdaRankParam param; + + auto make_args = [&](std::shared_ptr p_cache, auto rank_idx, + common::Span y_sorted_idx) { + linalg::Vector dummy; + auto d = dummy.View(ctx.gpu_id); + linalg::Vector dgpair; + auto dg = dgpair.View(ctx.gpu_id); + cuda_impl::KernelInputs args{d, + d, + d, + d, + p_cache->DataGroupPtr(&ctx), + p_cache->CUDAThreadsGroupPtr(), + rank_idx, + info.labels.View(ctx.gpu_id), + predt.ConstDeviceSpan(), + {}, + dg, + nullptr, + y_sorted_idx, + 0}; + return args; + }; + + { + param.UpdateAllowUnknown(Args{{"lambdarank_pair_method", "topk"}}); + auto p_cache = std::make_shared(&ctx, info, param); + auto rank_idx = p_cache->SortedIdx(&ctx, predt.ConstDeviceSpan()); + + ASSERT_EQ(p_cache->CUDAThreads(), 3568); + + auto args = make_args(p_cache, rank_idx, {}); + auto n_pairs = p_cache->Param().NumPair(); + auto make_pair = cuda_impl::MakePairsOp{args}; + + dh::LaunchN(p_cache->CUDAThreads(), ctx.CUDACtx()->Stream(), + [=] XGBOOST_DEVICE(std::size_t idx) { + auto [i, j] = make_pair(idx, 0); + SPAN_CHECK(j > i); + SPAN_CHECK(i < n_pairs); + }); + } + { + param.UpdateAllowUnknown(Args{{"lambdarank_pair_method", "mean"}}); + auto p_cache = std::make_shared(&ctx, info, param); + auto rank_idx = p_cache->SortedIdx(&ctx, predt.ConstDeviceSpan()); + auto y_sorted_idx = cuda_impl::SortY(&ctx, info, rank_idx, p_cache); + + ASSERT_FALSE(param.HasTruncation()); + ASSERT_EQ(p_cache->CUDAThreads(), info.num_row_ * param.NumPair()); + + auto args = make_args(p_cache, rank_idx, y_sorted_idx); + auto make_pair = cuda_impl::MakePairsOp{args}; + auto n_pairs = p_cache->Param().NumPair(); + + dh::LaunchN( + p_cache->CUDAThreads(), ctx.CUDACtx()->Stream(), [=] XGBOOST_DEVICE(std::size_t idx) { + idx = 97; + auto [i, j] = make_pair(idx, 0); + // Not in the same bucket + SPAN_CHECK(make_pair.args.labels(rank_idx[i]) != make_pair.args.labels(rank_idx[j])); + }); + } + { + param.UpdateAllowUnknown(Args{{"lambdarank_num_pair_per_sample", "2"}}); + auto p_cache = std::make_shared(&ctx, info, param); + auto rank_idx = p_cache->SortedIdx(&ctx, predt.ConstDeviceSpan()); + auto y_sorted_idx = cuda_impl::SortY(&ctx, info, rank_idx, p_cache); + + auto args = make_args(p_cache, rank_idx, y_sorted_idx); + auto make_pair = cuda_impl::MakePairsOp{args}; + + dh::LaunchN( + p_cache->CUDAThreads(), ctx.CUDACtx()->Stream(), [=] XGBOOST_DEVICE(std::size_t idx) { + auto [i, j] = make_pair(idx, 0); + // Not in the same bucket + SPAN_CHECK(make_pair.args.labels(rank_idx[i]) != make_pair.args.labels(rank_idx[j])); + }); + ASSERT_EQ(param.NumPair(), 2); + ASSERT_EQ(p_cache->CUDAThreads(), info.num_row_ * param.NumPair()); + } +} + +TEST(LambdaRank, GPUMakePair) { TestGPUMakePair(); } + +template +void RankItemCountImpl(std::vector const &sorted_items, CountFunctor f, + std::uint32_t find_val, std::uint32_t exp_val) { + EXPECT_NE(std::find(sorted_items.begin(), sorted_items.end(), find_val), sorted_items.end()); + EXPECT_EQ(f(&sorted_items[0], sorted_items.size(), find_val), exp_val); +} + +TEST(LambdaRank, RankItemCountOnLeft) { + // Items sorted descendingly + std::vector sorted_items{10, 10, 6, 4, 4, 4, 4, 1, 1, 1, 1, 1, 0}; + auto wrapper = [](auto const &...args) { return cuda_impl::CountNumItemsToTheLeftOf(args...); }; + RankItemCountImpl(sorted_items, wrapper, 10, static_cast(0)); + RankItemCountImpl(sorted_items, wrapper, 6, static_cast(2)); + RankItemCountImpl(sorted_items, wrapper, 4, static_cast(3)); + RankItemCountImpl(sorted_items, wrapper, 1, static_cast(7)); + RankItemCountImpl(sorted_items, wrapper, 0, static_cast(12)); +} + +TEST(LambdaRank, RankItemCountOnRight) { + // Items sorted descendingly + std::vector sorted_items{10, 10, 6, 4, 4, 4, 4, 1, 1, 1, 1, 1, 0}; + auto wrapper = [](auto const &...args) { return cuda_impl::CountNumItemsToTheRightOf(args...); }; + RankItemCountImpl(sorted_items, wrapper, 10, static_cast(11)); + RankItemCountImpl(sorted_items, wrapper, 6, static_cast(10)); + RankItemCountImpl(sorted_items, wrapper, 4, static_cast(6)); + RankItemCountImpl(sorted_items, wrapper, 1, static_cast(1)); + RankItemCountImpl(sorted_items, wrapper, 0, static_cast(0)); +} + +TEST(LambdaRank, GPUMAPGPair) { + Context ctx; + ctx.gpu_id = 0; + TestMAPGPair(&ctx); +} +} // namespace xgboost::obj diff --git a/tests/cpp/objective/test_lambdarank_obj.h b/tests/cpp/objective/test_lambdarank_obj.h new file mode 100644 index 000000000000..2a6e5a3ba039 --- /dev/null +++ b/tests/cpp/objective/test_lambdarank_obj.h @@ -0,0 +1,46 @@ +/** + * Copyright 2023 by XGBoost Contributors + */ +#ifndef XGBOOST_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_ +#define XGBOOST_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_ +#include +#include // for MetaInfo +#include // for HostDeviceVector +#include // for All +#include // for ObjFunction + +#include // for shared_ptr, make_shared +#include // for iota +#include // for vector + +#include "../../../src/common/ranking_utils.h" // for LambdaRankParam, MAPCache +#include "../../../src/objective/lambdarank_obj.h" // for MAPStat +#include "../helpers.h" // for EmptyDMatrix + +namespace xgboost::obj { +void TestMAPStat(Context const* ctx); + +inline void TestNDCGJsonIO(Context const* ctx) { + std::unique_ptr obj{ObjFunction::Create("rank:ndcg", ctx)}; + + obj->Configure(Args{}); + Json j_obj{Object()}; + obj->SaveConfig(&j_obj); + + ASSERT_EQ(get(j_obj["name"]), "rank:ndcg"); + auto const& j_param = j_obj["lambdarank_param"]; + + ASSERT_EQ(get(j_param["ndcg_exp_gain"]), "1"); + ASSERT_EQ(get(j_param["lambdarank_num_pair_per_sample"]), + std::to_string(ltr::LambdaRankParam::NotSet())); +} + +void TestNDCGGPair(Context const* ctx); +void TestMAPGPair(Context const* ctx); + +/** + * \brief Initialize test data for make pair tests. + */ +void InitMakePairTest(Context const* ctx, MetaInfo* out_info, HostDeviceVector* out_predt); +} // namespace xgboost::obj +#endif // XGBOOST_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_ diff --git a/tests/cpp/objective/test_ranking_obj.cc b/tests/cpp/objective/test_ranking_obj.cc deleted file mode 100644 index a007750e3d81..000000000000 --- a/tests/cpp/objective/test_ranking_obj.cc +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright by Contributors -#include -#include -#include - -#include "../helpers.h" - -namespace xgboost { - -TEST(Objective, DeclareUnifiedTest(PairwiseRankingGPair)) { - std::vector> args; - xgboost::Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX); - - std::unique_ptr obj{xgboost::ObjFunction::Create("rank:pairwise", &ctx)}; - obj->Configure(args); - CheckConfigReload(obj, "rank:pairwise"); - - // Test with setting sample weight to second query group - CheckRankingObjFunction(obj, - {0, 0.1f, 0, 0.1f}, - {0, 1, 0, 1}, - {2.0f, 0.0f}, - {0, 2, 4}, - {1.9f, -1.9f, 0.0f, 0.0f}, - {1.995f, 1.995f, 0.0f, 0.0f}); - - CheckRankingObjFunction(obj, - {0, 0.1f, 0, 0.1f}, - {0, 1, 0, 1}, - {1.0f, 1.0f}, - {0, 2, 4}, - {0.95f, -0.95f, 0.95f, -0.95f}, - {0.9975f, 0.9975f, 0.9975f, 0.9975f}); - - ASSERT_NO_THROW(obj->DefaultEvalMetric()); -} - -TEST(Objective, DeclareUnifiedTest(NDCG_JsonIO)) { - xgboost::Context ctx; - ctx.UpdateAllowUnknown(Args{}); - - std::unique_ptr obj{xgboost::ObjFunction::Create("rank:ndcg", &ctx)}; - - obj->Configure(Args{}); - Json j_obj {Object()}; - obj->SaveConfig(&j_obj); - - ASSERT_EQ(get(j_obj["name"]), "rank:ndcg");; - - auto const& j_param = j_obj["lambda_rank_param"]; - - ASSERT_EQ(get(j_param["num_pairsample"]), "1"); - ASSERT_EQ(get(j_param["fix_list_weight"]), "0"); -} - -TEST(Objective, DeclareUnifiedTest(PairwiseRankingGPairSameLabels)) { - std::vector> args; - xgboost::Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX); - - std::unique_ptr obj{ObjFunction::Create("rank:pairwise", &ctx)}; - obj->Configure(args); - // No computation of gradient/hessian, as there is no diversity in labels - CheckRankingObjFunction(obj, - {0, 0.1f, 0, 0.1f}, - {1, 1, 1, 1}, - {2.0f, 0.0f}, - {0, 2, 4}, - {0.0f, 0.0f, 0.0f, 0.0f}, - {0.0f, 0.0f, 0.0f, 0.0f}); - - ASSERT_NO_THROW(obj->DefaultEvalMetric()); -} - -TEST(Objective, DeclareUnifiedTest(NDCGRankingGPair)) { - std::vector> args; - xgboost::Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX); - - std::unique_ptr obj{xgboost::ObjFunction::Create("rank:ndcg", &ctx)}; - obj->Configure(args); - CheckConfigReload(obj, "rank:ndcg"); - - // Test with setting sample weight to second query group - CheckRankingObjFunction(obj, - {0, 0.1f, 0, 0.1f}, - {0, 1, 0, 1}, - {2.0f, 0.0f}, - {0, 2, 4}, - {0.7f, -0.7f, 0.0f, 0.0f}, - {0.74f, 0.74f, 0.0f, 0.0f}); - - CheckRankingObjFunction(obj, - {0, 0.1f, 0, 0.1f}, - {0, 1, 0, 1}, - {1.0f, 1.0f}, - {0, 2, 4}, - {0.35f, -0.35f, 0.35f, -0.35f}, - {0.368f, 0.368f, 0.368f, 0.368f}); - ASSERT_NO_THROW(obj->DefaultEvalMetric()); -} - -TEST(Objective, DeclareUnifiedTest(MAPRankingGPair)) { - std::vector> args; - xgboost::Context ctx = xgboost::CreateEmptyGenericParam(GPUIDX); - - std::unique_ptr obj{xgboost::ObjFunction::Create("rank:map", &ctx)}; - obj->Configure(args); - CheckConfigReload(obj, "rank:map"); - - // Test with setting sample weight to second query group - CheckRankingObjFunction(obj, - {0, 0.1f, 0, 0.1f}, - {0, 1, 0, 1}, - {2.0f, 0.0f}, - {0, 2, 4}, - {0.95f, -0.95f, 0.0f, 0.0f}, - {0.9975f, 0.9975f, 0.0f, 0.0f}); - - CheckRankingObjFunction(obj, - {0, 0.1f, 0, 0.1f}, - {0, 1, 0, 1}, - {1.0f, 1.0f}, - {0, 2, 4}, - {0.475f, -0.475f, 0.475f, -0.475f}, - {0.4988f, 0.4988f, 0.4988f, 0.4988f}); - ASSERT_NO_THROW(obj->DefaultEvalMetric()); -} - -} // namespace xgboost diff --git a/tests/cpp/objective/test_ranking_obj_gpu.cu b/tests/cpp/objective/test_ranking_obj_gpu.cu deleted file mode 100644 index 02286ab46b10..000000000000 --- a/tests/cpp/objective/test_ranking_obj_gpu.cu +++ /dev/null @@ -1,268 +0,0 @@ -/*! - * Copyright 2019-2021 by XGBoost Contributors - */ -#include - -#include "test_ranking_obj.cc" -#include "../../../src/objective/rank_obj.cu" - -namespace xgboost { - -template > -std::unique_ptr> -RankSegmentSorterTestImpl(const std::vector &group_indices, - const std::vector &hlabels, - const std::vector &expected_sorted_hlabels, - const std::vector &expected_orig_pos - ) { - std::unique_ptr> seg_sorter_ptr(new dh::SegmentSorter); - dh::SegmentSorter &seg_sorter(*seg_sorter_ptr); - - // Create a bunch of unsorted labels on the device and sort it via the segment sorter - dh::device_vector dlabels(hlabels); - seg_sorter.SortItems(dlabels.data().get(), dlabels.size(), group_indices, Comparator()); - - auto num_items = seg_sorter.GetItemsSpan().size(); - EXPECT_EQ(num_items, group_indices.back()); - EXPECT_EQ(seg_sorter.GetNumGroups(), group_indices.size() - 1); - - // Check the labels - dh::device_vector sorted_dlabels(num_items); - sorted_dlabels.assign(dh::tcbegin(seg_sorter.GetItemsSpan()), - dh::tcend(seg_sorter.GetItemsSpan())); - thrust::host_vector sorted_hlabels(sorted_dlabels); - EXPECT_EQ(expected_sorted_hlabels, sorted_hlabels); - - // Check the indices - dh::device_vector dorig_pos(num_items); - dorig_pos.assign(dh::tcbegin(seg_sorter.GetOriginalPositionsSpan()), - dh::tcend(seg_sorter.GetOriginalPositionsSpan())); - dh::device_vector horig_pos(dorig_pos); - EXPECT_EQ(expected_orig_pos, horig_pos); - - return seg_sorter_ptr; -} - -TEST(Objective, RankSegmentSorterTest) { - RankSegmentSorterTestImpl({0, 2, 4, 7, 10, 14, 18, 22, 26}, // Groups - {1, 1, // Labels - 1, 2, - 3, 2, 1, - 1, 2, 1, - 1, 3, 4, 2, - 1, 2, 1, 1, - 1, 2, 2, 3, - 3, 3, 1, 2}, - {1, 1, // Expected sorted labels - 2, 1, - 3, 2, 1, - 2, 1, 1, - 4, 3, 2, 1, - 2, 1, 1, 1, - 3, 2, 2, 1, - 3, 3, 2, 1}, - {0, 1, // Expected original positions - 3, 2, - 4, 5, 6, - 8, 7, 9, - 12, 11, 13, 10, - 15, 14, 16, 17, - 21, 19, 20, 18, - 22, 23, 25, 24}); -} - -TEST(Objective, RankSegmentSorterSingleGroupTest) { - RankSegmentSorterTestImpl({0, 7}, // Groups - {6, 1, 4, 3, 0, 5, 2}, // Labels - {6, 5, 4, 3, 2, 1, 0}, // Expected sorted labels - {0, 5, 2, 3, 6, 1, 4}); // Expected original positions -} - -TEST(Objective, RankSegmentSorterAscendingTest) { - RankSegmentSorterTestImpl>( - {0, 4, 7}, // Groups - {3, 1, 4, 2, // Labels - 6, 5, 7}, - {1, 2, 3, 4, // Expected sorted labels - 5, 6, 7}, - {1, 3, 0, 2, // Expected original positions - 5, 4, 6}); -} - -using CountFunctor = uint32_t (*)(const int *, uint32_t, int); -void RankItemCountImpl(const std::vector &sorted_items, CountFunctor f, - int find_val, uint32_t exp_val) { - EXPECT_NE(std::find(sorted_items.begin(), sorted_items.end(), find_val), sorted_items.end()); - EXPECT_EQ(f(&sorted_items[0], sorted_items.size(), find_val), exp_val); -} - -TEST(Objective, RankItemCountOnLeft) { - // Items sorted descendingly - std::vector sorted_items{10, 10, 6, 4, 4, 4, 4, 1, 1, 1, 1, 1, 0}; - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf, - 10, static_cast(0)); - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf, - 6, static_cast(2)); - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf, - 4, static_cast(3)); - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf, - 1, static_cast(7)); - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheLeftOf, - 0, static_cast(12)); -} - -TEST(Objective, RankItemCountOnRight) { - // Items sorted descendingly - std::vector sorted_items{10, 10, 6, 4, 4, 4, 4, 1, 1, 1, 1, 1, 0}; - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf, - 10, static_cast(11)); - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf, - 6, static_cast(10)); - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf, - 4, static_cast(6)); - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf, - 1, static_cast(1)); - RankItemCountImpl(sorted_items, &xgboost::obj::CountNumItemsToTheRightOf, - 0, static_cast(0)); -} - -TEST(Objective, NDCGLambdaWeightComputerTest) { - std::vector hlabels = {3.1f, 1.2f, 2.3f, 4.4f, // Labels - 7.8f, 5.01f, 6.96f, - 10.3f, 8.7f, 11.4f, 9.45f, 11.4f}; - dh::device_vector dlabels(hlabels); - - auto segment_label_sorter = RankSegmentSorterTestImpl( - {0, 4, 7, 12}, // Groups - hlabels, - {4.4f, 3.1f, 2.3f, 1.2f, // Expected sorted labels - 7.8f, 6.96f, 5.01f, - 11.4f, 11.4f, 10.3f, 9.45f, 8.7f}, - {3, 0, 2, 1, // Expected original positions - 4, 6, 5, - 9, 11, 7, 10, 8}); - - // Created segmented predictions for the labels from above - std::vector hpreds{-9.78f, 24.367f, 0.908f, -11.47f, - -1.03f, -2.79f, -3.1f, - 104.22f, 103.1f, -101.7f, 100.5f, 45.1f}; - dh::device_vector dpreds(hpreds); - - xgboost::obj::NDCGLambdaWeightComputer ndcg_lw_computer(dpreds.data().get(), - dlabels.data().get(), - *segment_label_sorter); - - // Where will the predictions move from its current position, if they were sorted - // descendingly? - auto dsorted_pred_pos = ndcg_lw_computer.GetPredictionSorter().GetIndexableSortedPositionsSpan(); - std::vector hsorted_pred_pos(segment_label_sorter->GetNumItems()); - dh::CopyDeviceSpanToVector(&hsorted_pred_pos, dsorted_pred_pos); - std::vector expected_sorted_pred_pos{2, 0, 1, 3, - 4, 5, 6, - 7, 8, 11, 9, 10}; - EXPECT_EQ(expected_sorted_pred_pos, hsorted_pred_pos); - - // Check group DCG values - std::vector hgroup_dcgs(segment_label_sorter->GetNumGroups()); - dh::CopyDeviceSpanToVector(&hgroup_dcgs, ndcg_lw_computer.GetGroupDcgsSpan()); - std::vector hgroups(segment_label_sorter->GetNumGroups() + 1); - dh::CopyDeviceSpanToVector(&hgroups, segment_label_sorter->GetGroupsSpan()); - EXPECT_EQ(hgroup_dcgs.size(), segment_label_sorter->GetNumGroups()); - std::vector hsorted_labels(segment_label_sorter->GetNumItems()); - dh::CopyDeviceSpanToVector(&hsorted_labels, segment_label_sorter->GetItemsSpan()); - for (size_t i = 0; i < hgroup_dcgs.size(); ++i) { - // Compute group DCG value on CPU and compare - auto gbegin = hgroups[i]; - auto gend = hgroups[i + 1]; - EXPECT_NEAR( - hgroup_dcgs[i], - xgboost::obj::NDCGLambdaWeightComputer::ComputeGroupDCGWeight(&hsorted_labels[gbegin], - gend - gbegin), - 0.01f); - } -} - -TEST(Objective, IndexableSortedItemsTest) { - std::vector hlabels = {3.1f, 1.2f, 2.3f, 4.4f, // Labels - 7.8f, 5.01f, 6.96f, - 10.3f, 8.7f, 11.4f, 9.45f, 11.4f}; - dh::device_vector dlabels(hlabels); - - auto segment_label_sorter = RankSegmentSorterTestImpl( - {0, 4, 7, 12}, // Groups - hlabels, - {4.4f, 3.1f, 2.3f, 1.2f, // Expected sorted labels - 7.8f, 6.96f, 5.01f, - 11.4f, 11.4f, 10.3f, 9.45f, 8.7f}, - {3, 0, 2, 1, // Expected original positions - 4, 6, 5, - 9, 11, 7, 10, 8}); - - segment_label_sorter->CreateIndexableSortedPositions(); - std::vector sorted_indices(segment_label_sorter->GetNumItems()); - dh::CopyDeviceSpanToVector(&sorted_indices, - segment_label_sorter->GetIndexableSortedPositionsSpan()); - std::vector expected_sorted_indices = { - 1, 3, 2, 0, - 4, 6, 5, - 9, 11, 7, 10, 8}; - EXPECT_EQ(expected_sorted_indices, sorted_indices); -} - -TEST(Objective, ComputeAndCompareMAPStatsTest) { - std::vector hlabels = {3.1f, 0.0f, 2.3f, 4.4f, // Labels - 0.0f, 5.01f, 0.0f, - 10.3f, 0.0f, 11.4f, 9.45f, 11.4f}; - dh::device_vector dlabels(hlabels); - - auto segment_label_sorter = RankSegmentSorterTestImpl( - {0, 4, 7, 12}, // Groups - hlabels, - {4.4f, 3.1f, 2.3f, 0.0f, // Expected sorted labels - 5.01f, 0.0f, 0.0f, - 11.4f, 11.4f, 10.3f, 9.45f, 0.0f}, - {3, 0, 2, 1, // Expected original positions - 5, 4, 6, - 9, 11, 7, 10, 8}); - - // Create MAP stats on the device first using the objective - std::vector hpreds{-9.78f, 24.367f, 0.908f, -11.47f, - -1.03f, -2.79f, -3.1f, - 104.22f, 103.1f, -101.7f, 100.5f, 45.1f}; - dh::device_vector dpreds(hpreds); - - xgboost::obj::MAPLambdaWeightComputer map_lw_computer(dpreds.data().get(), - dlabels.data().get(), - *segment_label_sorter); - - // Get the device MAP stats on host - std::vector dmap_stats( - segment_label_sorter->GetNumItems()); - dh::CopyDeviceSpanToVector(&dmap_stats, map_lw_computer.GetMapStatsSpan()); - - // Compute the MAP stats on host next to compare - std::vector hgroups(segment_label_sorter->GetNumGroups() + 1); - dh::CopyDeviceSpanToVector(&hgroups, segment_label_sorter->GetGroupsSpan()); - - for (size_t i = 0; i < hgroups.size() - 1; ++i) { - auto gbegin = hgroups[i]; - auto gend = hgroups[i + 1]; - std::vector lst_entry; - for (auto j = gbegin; j < gend; ++j) { - lst_entry.emplace_back(hpreds[j], hlabels[j], j); - } - std::stable_sort(lst_entry.begin(), lst_entry.end(), xgboost::obj::ListEntry::CmpPred); - - // Compute the MAP stats with this list and compare with the ones computed on the device - std::vector hmap_stats; - xgboost::obj::MAPLambdaWeightComputer::GetMAPStats(lst_entry, &hmap_stats); - for (auto j = gbegin; j < gend; ++j) { - EXPECT_EQ(dmap_stats[j].hits, hmap_stats[j - gbegin].hits); - EXPECT_NEAR(dmap_stats[j].ap_acc, hmap_stats[j - gbegin].ap_acc, 0.01f); - EXPECT_NEAR(dmap_stats[j].ap_acc_miss, hmap_stats[j - gbegin].ap_acc_miss, 0.01f); - EXPECT_NEAR(dmap_stats[j].ap_acc_add, hmap_stats[j - gbegin].ap_acc_add, 0.01f); - } - } -} - -} // namespace xgboost diff --git a/tests/python-gpu/test_gpu_ranking.py b/tests/python-gpu/test_gpu_ranking.py index d86c1aa142af..d3c661d5f91e 100644 --- a/tests/python-gpu/test_gpu_ranking.py +++ b/tests/python-gpu/test_gpu_ranking.py @@ -3,192 +3,132 @@ import shutil import urllib.request import zipfile +from typing import Dict import numpy as np +import pytest import xgboost from xgboost import testing as tm -pytestmark = tm.timeout(10) - - -class TestRanking: - @classmethod - def setup_class(cls): - """ - Download and setup the test fixtures - """ - from sklearn.datasets import load_svmlight_files - - # download the test data - cls.dpath = os.path.join(tm.demo_dir(__file__), "rank/") - src = 'https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.zip' - target = os.path.join(cls.dpath, "MQ2008.zip") - - if os.path.exists(cls.dpath) and os.path.exists(target): - print("Skipping dataset download...") - else: - urllib.request.urlretrieve(url=src, filename=target) - with zipfile.ZipFile(target, 'r') as f: - f.extractall(path=cls.dpath) - - (x_train, y_train, qid_train, x_test, y_test, qid_test, - x_valid, y_valid, qid_valid) = load_svmlight_files( - (cls.dpath + "MQ2008/Fold1/train.txt", - cls.dpath + "MQ2008/Fold1/test.txt", - cls.dpath + "MQ2008/Fold1/vali.txt"), - query_id=True, zero_based=False) - # instantiate the matrices - cls.dtrain = xgboost.DMatrix(x_train, y_train) - cls.dvalid = xgboost.DMatrix(x_valid, y_valid) - cls.dtest = xgboost.DMatrix(x_test, y_test) - # set the group counts from the query IDs - cls.dtrain.set_group([len(list(items)) - for _key, items in itertools.groupby(qid_train)]) - cls.dtest.set_group([len(list(items)) - for _key, items in itertools.groupby(qid_test)]) - cls.dvalid.set_group([len(list(items)) - for _key, items in itertools.groupby(qid_valid)]) - # save the query IDs for testing - cls.qid_train = qid_train - cls.qid_test = qid_test - cls.qid_valid = qid_valid - - def setup_weighted(x, y, groups): - # Setup weighted data - data = xgboost.DMatrix(x, y) - groups_segment = [len(list(items)) - for _key, items in itertools.groupby(groups)] - data.set_group(groups_segment) - n_groups = len(groups_segment) - weights = np.ones((n_groups,)) - data.set_weight(weights) - return data - - cls.dtrain_w = setup_weighted(x_train, y_train, qid_train) - cls.dtest_w = setup_weighted(x_test, y_test, qid_test) - cls.dvalid_w = setup_weighted(x_valid, y_valid, qid_valid) - - # model training parameters - cls.params = {'booster': 'gbtree', - 'tree_method': 'gpu_hist', - 'gpu_id': 0, - 'predictor': 'gpu_predictor'} - cls.cpu_params = {'booster': 'gbtree', - 'tree_method': 'hist', - 'gpu_id': -1, - 'predictor': 'cpu_predictor'} - - @classmethod - def teardown_class(cls): - """ - Cleanup test artifacts from download and unpacking - :return: - """ - os.remove(os.path.join(cls.dpath, "MQ2008.zip")) - shutil.rmtree(os.path.join(cls.dpath, "MQ2008")) - - @classmethod - def __test_training_with_rank_objective(cls, rank_objective, metric_name, tolerance=1e-02): - """ - Internal method that trains the dataset using the rank objective on GPU and CPU, evaluates - the metric and determines if the delta between the metric is within the tolerance level - :return: - """ - # specify validations set to watch performance - watchlist = [(cls.dtest, 'eval'), (cls.dtrain, 'train')] - - num_trees = 100 - check_metric_improvement_rounds = 10 - - evals_result = {} - cls.params['objective'] = rank_objective - cls.params['eval_metric'] = metric_name - bst = xgboost.train( - cls.params, cls.dtrain, num_boost_round=num_trees, - early_stopping_rounds=check_metric_improvement_rounds, - evals=watchlist, evals_result=evals_result) - gpu_map_metric = evals_result['train'][metric_name][-1] - - evals_result = {} - cls.cpu_params['objective'] = rank_objective - cls.cpu_params['eval_metric'] = metric_name - bstc = xgboost.train( - cls.cpu_params, cls.dtrain, num_boost_round=num_trees, - early_stopping_rounds=check_metric_improvement_rounds, - evals=watchlist, evals_result=evals_result) - cpu_map_metric = evals_result['train'][metric_name][-1] - - assert np.allclose(gpu_map_metric, cpu_map_metric, tolerance, - tolerance) - assert np.allclose(bst.best_score, bstc.best_score, tolerance, - tolerance) - - evals_result_weighted = {} - watchlist = [(cls.dtest_w, 'eval'), (cls.dtrain_w, 'train')] - bst_w = xgboost.train( - cls.params, cls.dtrain_w, num_boost_round=num_trees, - early_stopping_rounds=check_metric_improvement_rounds, - evals=watchlist, evals_result=evals_result_weighted) - weighted_metric = evals_result_weighted['train'][metric_name][-1] - # GPU Ranking is not deterministic due to `AtomicAddGpair`, - # remove tolerance once the issue is resolved. - # https://github.com/dmlc/xgboost/issues/5561 - assert np.allclose(bst_w.best_score, bst.best_score, - tolerance, tolerance) - assert np.allclose(weighted_metric, gpu_map_metric, - tolerance, tolerance) - - def test_training_rank_pairwise_map_metric(self): - """ - Train an XGBoost ranking model with pairwise objective function and compare map metric - """ - self.__test_training_with_rank_objective('rank:pairwise', 'map') - - def test_training_rank_pairwise_auc_metric(self): - """ - Train an XGBoost ranking model with pairwise objective function and compare auc metric - """ - self.__test_training_with_rank_objective('rank:pairwise', 'auc') - - def test_training_rank_pairwise_ndcg_metric(self): - """ - Train an XGBoost ranking model with pairwise objective function and compare ndcg metric - """ - self.__test_training_with_rank_objective('rank:pairwise', 'ndcg') - - def test_training_rank_ndcg_map(self): - """ - Train an XGBoost ranking model with ndcg objective function and compare map metric - """ - self.__test_training_with_rank_objective('rank:ndcg', 'map') - - def test_training_rank_ndcg_auc(self): - """ - Train an XGBoost ranking model with ndcg objective function and compare auc metric - """ - self.__test_training_with_rank_objective('rank:ndcg', 'auc') - - def test_training_rank_ndcg_ndcg(self): - """ - Train an XGBoost ranking model with ndcg objective function and compare ndcg metric - """ - self.__test_training_with_rank_objective('rank:ndcg', 'ndcg') - - def test_training_rank_map_map(self): - """ - Train an XGBoost ranking model with map objective function and compare map metric - """ - self.__test_training_with_rank_objective('rank:map', 'map') - - def test_training_rank_map_auc(self): - """ - Train an XGBoost ranking model with map objective function and compare auc metric - """ - self.__test_training_with_rank_objective('rank:map', 'auc') - - def test_training_rank_map_ndcg(self): - """ - Train an XGBoost ranking model with map objective function and compare ndcg metric - """ - self.__test_training_with_rank_objective('rank:map', 'ndcg') +pytestmark = tm.timeout(30) + + +def comp_training_with_rank_objective( + dtrain: xgboost.DMatrix, + dtest: xgboost.DMatrix, + rank_objective: str, + metric_name: str, + tolerance: float = 1e-02, +) -> None: + """Internal method that trains the dataset using the rank objective on GPU and CPU, + evaluates the metric and determines if the delta between the metric is within the + tolerance level. + + """ + # specify validations set to watch performance + watchlist = [(dtest, "eval"), (dtrain, "train")] + + params = { + "booster": "gbtree", + "tree_method": "gpu_hist", + "gpu_id": 0, + "predictor": "gpu_predictor", + } + + num_trees = 32 + check_metric_improvement_rounds = 4 + + evals_result: Dict[str, Dict] = {} + params["objective"] = rank_objective + params["eval_metric"] = metric_name + bst = xgboost.train( + params, + dtrain, + num_boost_round=num_trees, + early_stopping_rounds=check_metric_improvement_rounds, + evals=watchlist, + evals_result=evals_result, + ) + gpu_map_metric = evals_result["train"][metric_name][-1] + + evals_result = {} + + cpu_params = { + "booster": "gbtree", + "tree_method": "hist", + "gpu_id": -1, + "predictor": "cpu_predictor", + } + cpu_params["objective"] = rank_objective + cpu_params["eval_metric"] = metric_name + bstc = xgboost.train( + cpu_params, + dtrain, + num_boost_round=num_trees, + early_stopping_rounds=check_metric_improvement_rounds, + evals=watchlist, + evals_result=evals_result, + ) + cpu_map_metric = evals_result["train"][metric_name][-1] + + info = (rank_objective, metric_name) + assert np.allclose(gpu_map_metric, cpu_map_metric, tolerance, tolerance), info + assert np.allclose(bst.best_score, bstc.best_score, tolerance, tolerance), info + + evals_result_weighted: Dict[str, Dict] = {} + dtest.set_weight(np.ones((dtest.get_group().size,))) + dtrain.set_weight(np.ones((dtrain.get_group().size,))) + watchlist = [(dtest, "eval"), (dtrain, "train")] + bst_w = xgboost.train( + params, + dtrain, + num_boost_round=num_trees, + early_stopping_rounds=check_metric_improvement_rounds, + evals=watchlist, + evals_result=evals_result_weighted, + ) + weighted_metric = evals_result_weighted["train"][metric_name][-1] + + tolerance = 1e-5 + assert np.allclose(bst_w.best_score, bst.best_score, tolerance, tolerance) + assert np.allclose(weighted_metric, gpu_map_metric, tolerance, tolerance) + + +@pytest.mark.parametrize( + "objective,metric", + [ + ("rank:pairwise", "auc"), + ("rank:pairwise", "ndcg"), + ("rank:pairwise", "map"), + ("rank:ndcg", "auc"), + ("rank:ndcg", "ndcg"), + ("rank:ndcg", "map"), + ("rank:map", "auc"), + ("rank:map", "ndcg"), + ("rank:map", "map"), + ], +) +def test_with_mq2008(objective, metric) -> None: + ( + x_train, + y_train, + qid_train, + x_test, + y_test, + qid_test, + x_valid, + y_valid, + qid_valid, + ) = tm.get_mq2008(os.path.join(os.path.join(tm.demo_dir(__file__), "rank"))) + + if metric.find("map") != -1 or objective.find("map") != -1: + y_train[y_train <= 1] = 0 + y_train[y_train > 1] = 1 + y_test[y_test <= 1] = 0 + y_test[y_test > 1] = 1 + + dtrain = xgboost.DMatrix(x_train, y_train, qid=qid_train) + dtest = xgboost.DMatrix(x_test, y_test, qid=qid_test) + + comp_training_with_rank_objective(dtrain, dtest, objective, metric) diff --git a/tests/python/test_eval_metrics.py b/tests/python/test_eval_metrics.py index 3b7dc5b8e616..ebd0da144d8d 100644 --- a/tests/python/test_eval_metrics.py +++ b/tests/python/test_eval_metrics.py @@ -299,7 +299,9 @@ def test_pr_auc_multi(self): def run_pr_auc_ltr(self, tree_method): from sklearn.datasets import make_classification X, y = make_classification(128, 4, n_classes=2, random_state=1994) - ltr = xgb.XGBRanker(tree_method=tree_method, n_estimators=16) + ltr = xgb.XGBRanker( + tree_method=tree_method, n_estimators=16, objective="rank:pairwise" + ) groups = np.array([32, 32, 64]) ltr.fit( X, diff --git a/tests/python/test_ranking.py b/tests/python/test_ranking.py index 239271ec71bc..ccbdeff5dcd3 100644 --- a/tests/python/test_ranking.py +++ b/tests/python/test_ranking.py @@ -3,10 +3,13 @@ import shutil import numpy as np +import pytest +from hypothesis import given, note, settings from scipy.sparse import csr_matrix import xgboost from xgboost import testing as tm +from xgboost.testing.params import lambdarank_parameter_strategy def test_ranking_with_unweighted_data(): @@ -73,8 +76,30 @@ def test_ranking_with_weighted_data(): assert all(p <= q for p, q in zip(is_sorted, is_sorted[1:])) -class TestRanking: +def test_error_msg() -> None: + X, y, qid, w = tm.make_ltr(10, 2, 2, 2) + ranker = xgboost.XGBRanker() + with pytest.raises(ValueError, match=r"equal to the number of query groups"): + ranker.fit(X, y, qid=qid, sample_weight=y) + + +@given(lambdarank_parameter_strategy) +@settings(deadline=None, print_blob=True) +def test_lambdarank_parameters(params): + if params["objective"] == "rank:map": + rel = 2 + else: + rel = 5 + X, y, q, w = tm.make_ltr(4096, 3, 13, rel) + ranker = xgboost.XGBRanker(tree_method="hist", n_estimators=64, **params) + ranker.fit(X, y, qid=q, sample_weight=w, eval_set=[(X, y)], eval_qid=[q]) + for k, v in ranker.evals_result()["validation_0"].items(): + note(v) + assert v[-1] > v[0] + assert ranker.n_features_in_ == 3 + +class TestRanking: @classmethod def setup_class(cls): """ diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index baef690ee32e..62f5b8391a8b 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -128,12 +128,23 @@ def test_ranking(): x_test = np.random.rand(100, 10) - params = {'tree_method': 'exact', 'objective': 'rank:pairwise', - 'learning_rate': 0.1, 'gamma': 1.0, 'min_child_weight': 0.1, - 'max_depth': 6, 'n_estimators': 4} + params = { + "tree_method": "exact", + "objective": "rank:pairwise", + "learning_rate": 0.1, + "gamma": 1.0, + "min_child_weight": 0.1, + "max_depth": 6, + "n_estimators": 4, + } model = xgb.sklearn.XGBRanker(**params) - model.fit(x_train, y_train, group=train_group, - eval_set=[(x_valid, y_valid)], eval_group=[valid_group]) + model.fit( + x_train, + y_train, + group=train_group, + eval_set=[(x_valid, y_valid)], + eval_group=[valid_group], + ) assert model.evals_result() pred = model.predict(x_test) @@ -145,11 +156,17 @@ def test_ranking(): assert train_data.get_label().shape[0] == x_train.shape[0] valid_data.set_group(valid_group) - params_orig = {'tree_method': 'exact', 'objective': 'rank:pairwise', - 'eta': 0.1, 'gamma': 1.0, - 'min_child_weight': 0.1, 'max_depth': 6} - xgb_model_orig = xgb.train(params_orig, train_data, num_boost_round=4, - evals=[(valid_data, 'validation')]) + params_orig = { + "tree_method": "exact", + "objective": "rank:pairwise", + "eta": 0.1, + "gamma": 1.0, + "min_child_weight": 0.1, + "max_depth": 6, + } + xgb_model_orig = xgb.train( + params_orig, train_data, num_boost_round=4, evals=[(valid_data, "validation")] + ) pred_orig = xgb_model_orig.predict(test_data) np.testing.assert_almost_equal(pred, pred_orig) @@ -165,7 +182,11 @@ def test_ranking_metric() -> None: # sklearn compares the number of mis-classified docs, while the one in xgboost # compares the number of mis-classified pairs. ltr = xgb.XGBRanker( - eval_metric=roc_auc_score, n_estimators=10, tree_method="hist", max_depth=2 + eval_metric=roc_auc_score, + n_estimators=10, + tree_method="hist", + max_depth=2, + objective="rank:pairwise", ) ltr.fit( X, diff --git a/tests/test_distributed/test_with_spark/test_spark_local.py b/tests/test_distributed/test_with_spark/test_spark_local.py index a8c64713f949..33d60badce3f 100644 --- a/tests/test_distributed/test_with_spark/test_spark_local.py +++ b/tests/test_distributed/test_with_spark/test_spark_local.py @@ -1337,61 +1337,66 @@ def test_unsupported_params(self): SparkXGBClassifier(evals_result={}) -class XgboostRankerLocalTest(SparkTestCase): - def setUp(self): - self.session.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "8") - self.ranker_df_train = self.session.createDataFrame( - [ - (Vectors.dense(1.0, 2.0, 3.0), 0, 0), - (Vectors.dense(4.0, 5.0, 6.0), 1, 0), - (Vectors.dense(9.0, 4.0, 8.0), 2, 0), - (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1), - (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1), - (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1), - ], - ["features", "label", "qid"], - ) - self.ranker_df_test = self.session.createDataFrame( - [ - (Vectors.dense(1.5, 2.0, 3.0), 0, -1.87988), - (Vectors.dense(4.5, 5.0, 6.0), 0, 0.29556), - (Vectors.dense(9.0, 4.5, 8.0), 0, 2.36570), - (Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, -1.87988), - (Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, -0.30612), - (Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, 2.44826), - ], - ["features", "qid", "expected_prediction"], - ) - self.ranker_df_train_1 = self.session.createDataFrame( - [ - (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 9), - (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 9), - (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 9), - (Vectors.dense(1.0, 2.0, 3.0), 0, 8), - (Vectors.dense(4.0, 5.0, 6.0), 1, 8), - (Vectors.dense(9.0, 4.0, 8.0), 2, 8), - (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 7), - (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 7), - (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 7), - (Vectors.dense(1.0, 2.0, 3.0), 0, 6), - (Vectors.dense(4.0, 5.0, 6.0), 1, 6), - (Vectors.dense(9.0, 4.0, 8.0), 2, 6), - ] - * 4, - ["features", "label", "qid"], - ) +LTRData = namedtuple("LTRData", ("df_train", "df_test", "df_train_1")) - def test_ranker(self): - ranker = SparkXGBRanker(qid_col="qid") - assert ranker.getOrDefault(ranker.objective) == "rank:pairwise" - model = ranker.fit(self.ranker_df_train) - pred_result = model.transform(self.ranker_df_test).collect() +@pytest.fixture +def ltr_data(spark: SparkSession) -> Generator[LTRData, None, None]: + spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "8") + ranker_df_train = spark.createDataFrame( + [ + (Vectors.dense(1.0, 2.0, 3.0), 0, 0), + (Vectors.dense(4.0, 5.0, 6.0), 1, 0), + (Vectors.dense(9.0, 4.0, 8.0), 2, 0), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1), + (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1), + (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1), + ], + ["features", "label", "qid"], + ) + ranker_df_test = spark.createDataFrame( + [ + (Vectors.dense(1.5, 2.0, 3.0), 0, -1.75218), + (Vectors.dense(4.5, 5.0, 6.0), 0, -0.34192949533462524), + (Vectors.dense(9.0, 4.5, 8.0), 0, 1.7251298427581787), + (Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, -1.7521828413009644), + (Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, -1.0988065004348755), + (Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, 1.632674217224121), + ], + ["features", "qid", "expected_prediction"], + ) + ranker_df_train_1 = spark.createDataFrame( + [ + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 9), + (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 9), + (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 9), + (Vectors.dense(1.0, 2.0, 3.0), 0, 8), + (Vectors.dense(4.0, 5.0, 6.0), 1, 8), + (Vectors.dense(9.0, 4.0, 8.0), 2, 8), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 7), + (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 7), + (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 7), + (Vectors.dense(1.0, 2.0, 3.0), 0, 6), + (Vectors.dense(4.0, 5.0, 6.0), 1, 6), + (Vectors.dense(9.0, 4.0, 8.0), 2, 6), + ] + * 4, + ["features", "label", "qid"], + ) + yield LTRData(ranker_df_train, ranker_df_test, ranker_df_train_1) + + +class TestPySparkLocalLETOR: + def test_ranker(self, ltr_data: LTRData) -> None: + ranker = SparkXGBRanker(qid_col="qid", objective="rank:pairwise") + assert ranker.getOrDefault(ranker.objective) == "rank:pairwise" + model = ranker.fit(ltr_data.df_train) + pred_result = model.transform(ltr_data.df_test).collect() for row in pred_result: assert np.isclose(row.prediction, row.expected_prediction, rtol=1e-3) - def test_ranker_qid_sorted(self): + def test_ranker_qid_sorted(self, ltr_data: LTRData) -> None: ranker = SparkXGBRanker(qid_col="qid", num_workers=4) - assert ranker.getOrDefault(ranker.objective) == "rank:pairwise" - model = ranker.fit(self.ranker_df_train_1) - model.transform(self.ranker_df_test).collect() + assert ranker.getOrDefault(ranker.objective) == "rank:ndcg" + model = ranker.fit(ltr_data.df_train_1) + model.transform(ltr_data.df_test).collect()