From 0e2b75fe22d575e3089a08cc1d4c491bd3cd1554 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Wed, 8 Feb 2023 04:02:44 +0800 Subject: [PATCH 1/9] Reimplement lambdamart ndcg. * Simplify the implementation for both CPU and GPU. Fix JSON IO. Check labels. Put idx into cache. Optimize. File tag. Weights. Trivial tests. Compatibility. Lint. Fix swap. Device weight. tidy. Easier to read R failure. enum. Fix global configuration. Tidy. msvc omp. dask. Remove ndcg specific parameter. Drop label type for smaller PR. Fix rebase. Fixes. Don't mess with includes. Fixes. Format. Use omp util. Restore some old code. Revert. Port changes from the work on quantile loss. python binding. param. Cleanup. conditional parallel. types. Move doc. fix. need metric rewrite. rename ctx. extract. Work on metric. Metric Init estimation. extract tests, compute ties. cleanup. notes. extract optional weights. init. cleanup. old metric format. note. ndcg cache. nested. debug. fix. log2. Begin CUDA work. temp. Extract sort and latest cuda. truncation. dcg. dispatch. try different gain type. start looking into ub. note. consider writing a doc. check exp gain. Reimplement lambdamart ndcg. * Simplify the implementation for both CPU and GPU. Fix JSON IO. Check labels. Put idx into cache. Optimize. File tag. Weights. Trivial tests. Compatibility. Lint. Fix swap. Device weight. tidy. Easier to read R failure. enum. Fix global configuration. Tidy. msvc omp. dask. Remove ndcg specific parameter. Drop label type for smaller PR. Fix rebase. Fixes. Don't mess with includes. Fixes. Format. Use omp util. Restore some old code. Revert. Port changes from the work on quantile loss. python binding. param. Cleanup. conditional parallel. types. Move doc. fix. need metric rewrite. rename ctx. extract. Work on metric. Metric Init estimation. extract tests, compute ties. cleanup. notes. extract optional weights. init. cleanup. old metric format. note. ndcg cache. nested. debug. fix. log2. Begin CUDA work. temp. Extract sort and latest cuda. truncation. dcg. dispatch. try different gain type. start looking into ub. note. consider writing a doc. check exp gain. Start looking into unbiased. lambda. Extract the ndcg cache. header. cleanup namespace. small check. namespace. init with param. gain. extract. groups. Cleanup. disable. debug. remove. Revert "remove." This reverts commit ea025f9e8085f5624db8bffbee801dbcd60f3ff5. sigmoid. cleanup. metric name. check scores. note. check map. extract utilities. avoid inline. fix. header. extract more. note. note. note. start working on map. fix. continue map. map. matrix. Remove map. note. format. move check. cleanup. use cached discount, use double. cleanup. Add position to the Python interface. pass it into lambda. Full ratio. rank. comment. some work on GPU. compile. move cache initialization. descending. Fix arg sort. basic ndcg score. metric weight. config. extract. pass position again. Define a metric decorator. position. decorate metric.. return. note. irrelevant docs. fix weights. header. Share the bias. Use position check info. use cache for param. note. prepare to work on deterministic gpu. rounding. Extract op. cleanup. Use it. check label. ditch launchn. rounding. Move rounding into cache. fix check label. GPU fixes. Irrelevant doc. try to avoid inf. mad. Work on metric cache. Cleanup sort. use cache. cache others. revert. add test for metric. fixes. msg. note. remove reduce by key. comments. check position. stream. min. small cleanup. use atomic for now. fill. no inline. norm. remove op. start gpu. cleanup. use gpu for update. segmented reduce. revert. comments. comments. fix. comments. fix bounds. comments. cache. pointer. fixes. no spark. revert. Cleanup. cleanup. work on gain type. fix. notes. make metric name. remove. revert. revert. comment. revert. Move back into rank metric. Set name in objective. fix. Don't configure. note. merge tests. accept empty group. fixes. float. revert and fix. not mutable. prototype for cache. extract. convert to DMatrix. cache. Extract the cache. Port changes. fix & cleanup. cleanup. cleanup. Rename. restore. remove. header. revert. rename. rename. doc. cleanup. doc. cleanup. tests. tests. split up. jvm parameters. doc. Fix. Use cache in cox. Revert "Use cache in cox." This reverts commit e1cec376eab37c22d93180ea4ebecc828af9ca2e. Remove pairwise. iwyu. rename. Move. Merge. ranking utils. Fixes. rename. Comments. todos. Small cleanup. doc. Start working on demo. move some code here. rename. Update doc. Update doc. Work on demo. work on demo. demo. Demo. Specify the max rel degree. remove position. Fix. Work on demo. demo. Using only one fold. cache. demo. schema. comments. Lint. fix test. automake. macos. schema. test. schema. lint. fix tests. Implement MAP and pair sampling. revert sorting. Work on ranknet. remove. Don't upgrade cost if larger than. Extract GPU make pairs. error message. Remove. Cleanup some gpu tests. Move. Move NDCG test. fix weights. Move rest of the tests. Remove. Work on tests. fixes. Cleanup. header. cleanup. Update document. update document. fix build. cpplint. rename. Fixes and cleanup. Cleanup tests. lint. fix tests. debug macos non-openmp checks. macos. fix ndcg test. Ensure number of threads is smaller than the number of inputs. fix. Debug macos. fixes. Add weight normalization. Note on reproducible result. Don't normalize if it's binary. old ctk. Use old objective. Update doc. Convert pyspark tests. black. Fix rebase. Fix rebase. Start looking into CV. Hacky score function. extract parsing. Cleanup and tests. Lint & note. test check. Update document. Update tests & doc. Support custom metric as well. c++-17. cleanup old metrics. rename. Fixes. Fix cxx test. test cudf. start converting tests. pylint. fix data load. Cleanup the tests. Parameter tests. isort. Fix test. Specify src path for isort. 17 goodies. Fix rebase. Start working on ranking cache tests. Extract CPU impl. test debiasing. use index. ranking cache. comment. some work on debiasing. save the estimated bias. normalize by default. GPU norm. fix gpu unbiased. cleanup. cleanup. Remove workaround. Default to topk. Restore. Cleanup. Revert change in algorithm. norm. Move data generation process in testing for reuse. Move sort samples as well. cleanup. Generate data. lint. pylint. Fix. Fix spark test. avoid sampling with unbiased. Cleanup demo. Handle single group simulation. Numeric issues. More numeric issues. sigma. naming. Simple test. tests. brief description. Revert "brief description." This reverts commit 0b3817a683892e4fc66d2162f8434890d55cf09c. rebase. symbol. Rebase. disable normalization. Revert "disable normalization." This reverts commit ef3133d2b4a76714f3514808c6e2ae5937e6a8c2. unused variable. Apply suggestions from code review Co-authored-by: Philip Hyunsu Cho Use dataclass. Fix return type. doc. Minor fixes. Add test for custom gain. cleanup. wording. start working on precision. comments. initial work on precision. Cleanup GPU ranking metric. rigorous. work on test. adjust test. Tests. Work on binary classification support. cpu. mention it in document. callback. tests. --- demo/guide-python/learning_to_rank.py | 212 ++++++++++++++ doc/contrib/coding_guide.rst | 4 +- doc/parameter.rst | 1 + doc/tutorials/index.rst | 1 + doc/tutorials/learning_to_rank.rst | 199 +++++++++++++ include/xgboost/cache.h | 1 + include/xgboost/data.h | 7 +- include/xgboost/objective.h | 1 + .../spark/GpuXGBoostRegressorSuite.scala | 2 +- .../spark/params/LearningTaskParams.scala | 2 +- .../scala/spark/XGBoostGeneralSuite.scala | 6 +- .../scala/spark/XGBoostRegressorSuite.scala | 2 +- python-package/xgboost/_typing.py | 2 +- python-package/xgboost/sklearn.py | 10 +- python-package/xgboost/testing/__init__.py | 3 + python-package/xgboost/testing/data.py | 265 +++++++++++++++++- python-package/xgboost/testing/params.py | 14 + src/metric/elementwise_metric.cu | 10 +- tests/ci_build/lint_python.py | 2 + tests/cpp/common/test_ranking_utils.cc | 1 - tests/python/test_eval_metrics.py | 4 +- tests/python/test_ranking.py | 116 +++++++- tests/python/test_with_sklearn.py | 3 +- 23 files changed, 845 insertions(+), 23 deletions(-) create mode 100644 demo/guide-python/learning_to_rank.py create mode 100644 doc/tutorials/learning_to_rank.rst diff --git a/demo/guide-python/learning_to_rank.py b/demo/guide-python/learning_to_rank.py new file mode 100644 index 000000000000..37b7157f5029 --- /dev/null +++ b/demo/guide-python/learning_to_rank.py @@ -0,0 +1,212 @@ +""" +Getting started with learning to rank +===================================== + + .. versionadded:: 2.0.0 + +This is a demonstration of using XGBoost for learning to rank tasks using the +MSLR_10k_letor dataset. For more infomation about the dataset, please visit its +`description page `_. + +This is a two-part demo, the first one contains a basic example of using XGBoost to +train on relevance degree, and the second part simulates click data and enable the +position debiasing training. + +For an overview of learning to rank in XGBoost, please see +:doc:`Learning to Rank `. +""" +from __future__ import annotations + +import argparse +import json +import os +import pickle as pkl + +import numpy as np +import pandas as pd +from sklearn.datasets import load_svmlight_file + +import xgboost as xgb +from xgboost.testing.data import RelDataCV, simulate_clicks, sort_ltr_samples + + +def load_mlsr_10k(data_path: str, cache_path: str) -> RelDataCV: + """Load the MSLR10k dataset from data_path and cache a pickle object in cache_path. + + Returns + ------- + + A list of tuples [(X, y, qid), ...]. + + """ + root_path = os.path.expanduser(args.data) + cacheroot_path = os.path.expanduser(args.cache) + cache_path = os.path.join(cacheroot_path, "MSLR_10K_LETOR.pkl") + + # Use only the Fold1 for demo: + # Train, Valid, Test + # {S1,S2,S3}, S4, S5 + fold = 1 + + if not os.path.exists(cache_path): + fold_path = os.path.join(root_path, f"Fold{fold}") + train_path = os.path.join(fold_path, "train.txt") + valid_path = os.path.join(fold_path, "vali.txt") + test_path = os.path.join(fold_path, "test.txt") + X_train, y_train, qid_train = load_svmlight_file( + train_path, query_id=True, dtype=np.float32 + ) + y_train = y_train.astype(np.int32) + qid_train = qid_train.astype(np.int32) + + X_valid, y_valid, qid_valid = load_svmlight_file( + valid_path, query_id=True, dtype=np.float32 + ) + y_valid = y_valid.astype(np.int32) + qid_valid = qid_valid.astype(np.int32) + + X_test, y_test, qid_test = load_svmlight_file( + test_path, query_id=True, dtype=np.float32 + ) + y_test = y_test.astype(np.int32) + qid_test = qid_test.astype(np.int32) + + data = RelDataCV( + train=(X_train, y_train, qid_train), + test=(X_test, y_test, qid_test), + max_rel=4, + ) + + with open(cache_path, "wb") as fd: + pkl.dump(data, fd) + + with open(cache_path, "rb") as fd: + data = pkl.load(fd) + + return data + + +def ranking_demo(args: argparse.Namespace) -> None: + """Demonstration for learning to rank with relevance degree.""" + data = load_mlsr_10k(args.data, args.cache) + + # Sort data according to query index + X_train, y_train, qid_train = data.train + sorted_idx = np.argsort(qid_train) + X_train = X_train[sorted_idx] + y_train = y_train[sorted_idx] + qid_train = qid_train[sorted_idx] + + X_test, y_test, qid_test = data.test + sorted_idx = np.argsort(qid_test) + X_test = X_test[sorted_idx] + y_test = y_test[sorted_idx] + qid_test = qid_test[sorted_idx] + + ranker = xgb.XGBRanker( + tree_method="gpu_hist", + lambdarank_pair_method="topk", + lambdarank_num_pair_per_sample=13, + eval_metric=["ndcg@1", "ndcg@8"], + ) + ranker.fit( + X_train, + y_train, + qid=qid_train, + eval_set=[(X_test, y_test)], + eval_qid=[qid_test], + verbose=True, + ) + + +def click_data_demo(args: argparse.Namespace) -> None: + """Demonstration for learning to rank with click data.""" + data = load_mlsr_10k(args.data, args.cache) + train, test = simulate_clicks(data) + assert test is not None + + assert train.X.shape[0] == train.click.size + assert test.X.shape[0] == test.click.size + assert test.score.dtype == np.float32 + assert test.click.dtype == np.int32 + + X_train, clicks_train, y_train, qid_train = sort_ltr_samples( + train.X, + train.y, + train.qid, + train.click, + train.pos, + ) + X_test, clicks_test, y_test, qid_test = sort_ltr_samples( + test.X, + test.y, + test.qid, + test.click, + test.pos, + ) + + class ShowPosition(xgb.callback.TrainingCallback): + def after_iteration( + self, + model: xgb.Booster, + epoch: int, + evals_log: xgb.callback.TrainingCallback.EvalsLog, + ) -> bool: + config = json.loads(model.save_config()) + ti_plus = np.array(config["learner"]["objective"]["ti+"]) + tj_minus = np.array(config["learner"]["objective"]["tj-"]) + df = pd.DataFrame({"ti+": ti_plus, "tj-": tj_minus}) + print(df) + return False + + ranker = xgb.XGBRanker( + n_estimators=512, + tree_method="gpu_hist", + learning_rate=0.01, + reg_lambda=1.5, + subsample=0.8, + sampling_method="gradient_based", + # LTR specific parameters + objective="rank:ndcg", + # - Enable bias estimation + lambdarank_unbiased=True, + # - normalization (1 / (norm + 1)) + lambdarank_bias_norm=1, + # - Focus on the top 12 documents + lambdarank_num_pair_per_sample=12, + lambdarank_pair_method="topk", + ndcg_exp_gain=True, + eval_metric=["ndcg@1", "ndcg@3", "ndcg@5", "ndcg@10"], + callbacks=[ShowPosition()], + ) + ranker.fit( + X_train, + clicks_train, + qid=qid_train, + eval_set=[(X_test, y_test), (X_test, clicks_test)], + eval_qid=[qid_test, qid_test], + verbose=True, + ) + ranker.predict(X_test) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Demonstration of learning to rank using XGBoost." + ) + parser.add_argument( + "--data", + type=str, + help="Root directory of the MSLR-WEB10K data.", + required=True, + ) + parser.add_argument( + "--cache", + type=str, + help="Directory for caching processed data.", + required=True, + ) + args = parser.parse_args() + + ranking_demo(args) + click_data_demo(args) diff --git a/doc/contrib/coding_guide.rst b/doc/contrib/coding_guide.rst index a080c2a31541..f939a17b2a94 100644 --- a/doc/contrib/coding_guide.rst +++ b/doc/contrib/coding_guide.rst @@ -16,8 +16,10 @@ C++ Coding Guideline * Each line of text may contain up to 100 characters. * The use of C++ exceptions is allowed. -- Use C++11 features such as smart pointers, braced initializers, lambda functions, and ``std::thread``. +- Use C++17 features such as smart pointers, braced initializers, lambda functions, and ``std::thread``. - Use Doxygen to document all the interface code. +- We have some comments around symbols imported by headers, some of those are hinted by `include-what-you-use `_. It's not required. +- We use clang-tidy and clang-format. You can check their configuration in the root directory of the XGBoost source tree. - We have a series of automatic checks to ensure that all of our codebase complies with the Google style. Before submitting your pull request, you are encouraged to run the style checks on your machine. See :ref:`running_checks_locally`. *********************** diff --git a/doc/parameter.rst b/doc/parameter.rst index f6d3a06b671f..40ddb9247b83 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -425,6 +425,7 @@ Specify the learning task and the corresponding learning objective. The objectiv After XGBoost 1.6, both of the requirements and restrictions for using ``aucpr`` in classification problem are similar to ``auc``. For ranking task, only binary relevance label :math:`y \in [0, 1]` is supported. Different from ``map (mean average precision)``, ``aucpr`` calculates the *interpolated* area under precision recall curve using continuous interpolation. - ``pre``: Precision at :math:`k`. Supports only learning to rank task. + - ``ndcg``: `Normalized Discounted Cumulative Gain `_ - ``map``: `Mean Average Precision `_ diff --git a/doc/tutorials/index.rst b/doc/tutorials/index.rst index 310fd0170610..eb8c23726d56 100644 --- a/doc/tutorials/index.rst +++ b/doc/tutorials/index.rst @@ -21,6 +21,7 @@ See `Awesome XGBoost `_ for mo monotonic rf feature_interaction_constraint + learning_to_rank aft_survival_analysis c_api_tutorial input_format diff --git a/doc/tutorials/learning_to_rank.rst b/doc/tutorials/learning_to_rank.rst new file mode 100644 index 000000000000..72e2123bbec9 --- /dev/null +++ b/doc/tutorials/learning_to_rank.rst @@ -0,0 +1,199 @@ +################ +Learning to Rank +################ + +**Contents** + +.. contents:: + :local: + :backlinks: none + +******** +Overview +******** +Often in the context of information retrieval, learning-to-rank aims to train a model that arranges a set of query results into an ordered list `[1] <#references>`__. For surprivised learning-to-rank, the predictors are sample documents encoded as feature matrix, and the labels are relevance degree for each sample. Relevance degree can be multi-level (graded) or binary (relevant or not). The training samples are often grouped by their query index with each query group containing multiple query results. + +XGBoost implements learning to rank through a set of objective functions and performance metrics. The default objective is ``rank:ndcg`` based on the ``LambdaMART`` `[2] <#references>`__ algorithm, which in turn is an adaptation of the ``LambdaRank`` `[3] <#references>`__ framework to gradient boosting trees. For a history and a summary of the algorithm, see `[5] <#references>`__. The implementation in XGBoost features deterministic GPU computation, distributed training, position debiasing and two different pair construction strategies. + +************************************ +Training with the Pariwise Objective +************************************ +``LambdaMART`` is a pairwise ranking model, meaning that it compares the relevance degree for every pair of samples in a query group and calculate a proxy gradient for each pair. The default objective ``rank:ndcg`` is using the surrogate gradient derived from the ``ndcg`` metric. To train a XGBoost model, we need an additional sorted array called ``qid`` for specifying the query group of input samples. An example input would look like this: + ++-------+-----------+---------------+ +| QID | Label | Features | ++=======+===========+===============+ +| 1 | 0 | :math:`x_1` | ++-------+-----------+---------------+ +| 1 | 1 | :math:`x_2` | ++-------+-----------+---------------+ +| 1 | 0 | :math:`x_3` | ++-------+-----------+---------------+ +| 2 | 0 | :math:`x_4` | ++-------+-----------+---------------+ +| 2 | 1 | :math:`x_5` | ++-------+-----------+---------------+ +| 2 | 1 | :math:`x_6` | ++-------+-----------+---------------+ +| 2 | 1 | :math:`x_7` | ++-------+-----------+---------------+ + +Notice that the samples are sorted based on their query index in a non-decreasing order. In the above example, the first three samples belong to the first query and the next four samples belong to the second. For the sake of simplicity, we will use a synthetic binary learning-to-rank dataset in the following code snippets, with binary labels representing whether the result is relevant or not, and randomly assign the query group index to each sample. For an example that uses a real world dataset, please see :ref:`sphx_glr_python_examples_learning_to_rank.py`. + +.. code-block:: python + + from sklearn.datasets import make_classification + import numpy as np + + import xgboost as xgb + + # Make a synthetic ranking dataset for demonstration + X, y = make_classification(random_state=rng) + rng = np.random.default_rng(1994) + n_query_groups = 3 + qid = rng.integers(0, 3, size=X.shape[0]) + + # Sort the inputs based on query index + sorted_idx = np.argsort(qid) + X = X[sorted_idx, :] + y = y[sorted_idx] + +The simpliest way to train a ranking model is by using the scikit-learn estimator interface. Continuing the previous snippet, we can train a simple ranking model without tuning: + +.. code-block:: python + + ranker = xgb.XGBRanker(tree_method="hist", lambdarank_num_pair_per_sample=8, objective="rank:ndcg", lambdarank_pair_method="topk") + ranker.fit(X, y, qid=qid) + +Please note that, as of writing, there's no learning-to-rank interface in scikit-learn. As a result, the :py:class:`xgboost.XGBRanker` class does not fully conform the scikit-learn estimator guideline and can not be directly used with some of its utility functions. For instances, the ``auc_score`` and ``ndcg_score`` in scikit-learn don't consider query group information nor the pairwise loss. Most of the metrics are implemented as part of XGBoost, but to use scikit-learn utilities like :py:func:`sklearn.model_selection.cross_validation`, we need to make some adjustments in order to pass the ``qid`` as an additional parameter for :py:meth:`xgboost.XGBRanker.score`. Given a data frame ``X`` (either pandas or cuDF), add the column ``qid`` as follows: + +.. code-block:: python + + df = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])) + df["qid"] = qid + ranker.fit(df, y) # No need to pass qid as a separate argument + + from sklearn.model_selection import StratifiedGroupKFold, cross_val_score + # Works with cv in scikit-learn, along with HPO utilities like GridSearchCV + kfold = StratifiedGroupKFold(shuffle=False) + cross_val_score(ranker, df, y, cv=kfold, groups=df.qid) + +The above snippets build a model using ``LambdaMART`` with the ``NDCG@8`` metric. The outputs of a ranker are relevance scores: + +.. code-block:: python + + scores = ranker.predict(X) + sorted_idx = np.argsort(scores)[::-1] + # Sort the relevance scores from most relevant to least relevant + scores = scores[sorted_idx] + + +************* +Position Bias +************* + +.. versionadded:: 2.0.0 + +.. note:: + + The feature is considered experimental. This is a heated research area, and your input is much appreciated! + +Obtaining real relevance degrees for query results is an expensive and strenuous, requiring human labelers to label all results one by one. When such labeling task is infeasible, we might want to train the learning-to-rank model on user click data instead, as it is relatively easy to collect. Another advantage of using click data directly is that it can reflect the most up-to-date user preferences `[1] <#references>`__. However, user clicks are often biased, as users tend to choose results that are displayed in higher positions. User clicks are also noisy, where users might accidentally click on irrelevant documents. To ameliorate these issues, XGBoost implements the ``Unbiased LambdaMART`` `[4] <#references>`__ algorithm to debias the position-dependent click data. The feature can be enabled by the ``lambdarank_unbiased`` parameter; see :ref:`ltr-param` for related options and :ref:`sphx_glr_python_examples_learning_to_rank.py` for a worked example with simulated user clicks. + +**** +Loss +**** + +XGBoost implements different ``LambdaMART`` objectives based on different metrics. We list them here as a reference. Other than those used as objective function, XGBoost also implements metrics like ``pre`` (for precision) for evaluation. See :doc:`parameters ` for available options and the following sections for how to choose these objectives based of the amount of effective pairs. + +* NDCG + +`Normalized Discounted Cumulative Gain` ``NDCG`` can be used with both binary relevance and multi-level relevance. If you are not sure about your data, this metric can be used as the default. The name for the objective is ``rank:ndcg``. + + +* MAP + +`Mean average precision` ``MAP`` is a binary measure. It can be used when the relevance label is 0 or 1. The name for the objective is ``rank:map``. + + +* Pairwise + +The `LambdaMART` algorithm scales the logistic loss with learning to rank metrics like ``NDCG`` in the hope of including ranking information into the loss function. The ``rank:pairwise`` loss is the original version of the pairwise loss, also known as the `RankNet loss` `[7] <#references>`__ or the `pairwise logistic loss`. Unlike the ``rank:map`` and the ``rank:ndcg``, no scaling is applied (:math:`|\Delta Z_{ij}| = 1`). + +Whether scaling with a LTR metric is actually more effective is still up for debate; `[8] <#references>`__ provides a theoretical foundation for general lambda loss functions and some insights into the framework. + +****************** +Constructing Pairs +****************** + +There are two implemented strategies for constructing document pairs for :math:`\lambda`-gradient calculation. The first one is the ``mean`` method, another one is the ``topk`` method. The preferred strategy can be specified by the ``lambdarank_pair_method`` parameter. + +For the ``mean`` strategy, XGBoost samples ``lambdarank_num_pair_per_sample`` pairs for each document in a query list. For example, given a list of 3 documents and ``lambdarank_num_pair_per_sample`` is set to 2, XGBoost will randomly sample 6 pairs, assuming the labels for these documents are different. On the other hand, if the pair method is set to ``topk``, XGBoost constructs about :math:`k \times |query|` number of pairs with :math:`|query|` pairs for each sample at the top :math:`k = lambdarank\_num\_pair` position. The number of pairs counted here is an approximation since we skip pairs that have the same label. + +********************* +Obtaining Good Result +********************* + +Learning to rank is a sophisticated task and an active research area. It's not trivial to train a model that generalizes well. There are multiple loss functions available in XGBoost along with a set of hyperparameters. This section contains some hints for how to choose hyperparameters as a starting point. One can further optimize the model by tuning these hyperparameters. + +The first question would be how to choose an objective that matches the task at hand. If your input data has multi-level relevance degrees, then either ``rank:ndcg`` or ``rank:pairwise`` should be used. However, when the input has binary labels, we have multiple options based on the target metric. `[6] <#references>`__ provides some guidelines on this topic and users are encouraged to see the analysis done in their work. The choice should be based on the number of `effective pairs`, which refers to the number of pairs that can generate non-zero gradient and contribute to training. `LambdaMART` with ``MRR`` has the least amount of effective pairs as the :math:`\lambda`-gradient is only non-zero when the pair contains a non-relevant document ranked higher than the top relevant document. As a result, it's not implemented in XGBoost. Since ``NDCG`` is a multi-level metric, it usually generate more effective pairs than ``MAP``. + +However, when there are sufficiently many effective pairs, it's shown in `[6] <#references>`__ that matching the target metric with the objective is of significance. When the target metric is ``MAP`` and you are using a large dataset that can provide a sufficient amount of effective pairs, ``rank:map`` can in theory yield higher ``MAP`` value than ``rank:ndcg``. + +The consideration of effective pairs also applies to the choice of pair method (``lambdarank_pair_method``) and the number of pairs for each sample (``lambdarank_num_pair_per_sample``). For example, the mean-``NDCG`` considers more pairs than ``NDCG@10``, so the former generates more effective pairs and provides more granularity than the latter. Also, using the ``mean`` strategy can help the model generalize with random sampling. However, one might want to focus the training on the top :math:`k` documents instead of using all pairs, to better fit their real-world application. + +When using the mean strategy for generating pairs, where the target metric (like ``NDCG``) is computed over the whole query list, users can specify how many pairs should be generated per each document, by setting the ``lambdarank_num_pair_per_sample``. XGBoost will randomly sample ``lambdarank_num_pair_per_sample`` pairs for each element in the query group (:math:`|pairs| = |query| \times num\_pairsample`). Often, setting it to 1 can produce reasonable results. In cases where performance is inadequate due to insufficient number of effective pairs being generated, set ``lambdarank_num_pair_per_sample`` to a higher value. As more document pairs are generated, more effective pairs will be generated as well. + +On the other hand, if you are prioritizing the top :math:`k` documents, the ``lambdarank_num_pair_per_sample`` should be set slightly higher than :math:`k` (with a few more documents) to obtain a good training result. + +**Summary** If you have large amount of training data: + +* Use the target-matching objective. +* Choose the ``topk`` strategy for generating document pairs (if it's appropriate for your application). + +On the other hand, if you have comparatively small amount of training data: + +* Select ``NDCG`` or the RankNet loss (``rank:pairwise``). +* Choose the ``mean`` strategy for generating document pairs, to obtain more effective pairs. + +For any method chosen, you can modify ``lambdarank_num_pair_per_sample`` to control the amount of pairs generated. + +******************** +Distributed Training +******************** +XGBoost implements distributed learning-to-rank with integration of multiple frameworks including Dask, Spark, and PySpark. The interface is similar to the single-node counterpart. Please refer to document of the respective XGBoost interface for details. Scattering a query group onto multiple workers is theoretically sound but can affect the model accuracy. For most of the use cases, the small discrepancy is not an issue, as the amount of training data is usually large when distributed training is used. As a result, users don't need to partition the data based on query groups. As long as each data partition is correctly sorted by query IDs, XGBoost can aggregate sample gradients accordingly. + +******************* +Reproducible Result +******************* + +Like any other tasks, XGBoost should generate reproducible results given the same hardware and software environments (and data partitions, if distributed interface is used). Even when the underlying environment has changed, the result should still be consistent. However, when the ``lambdarank_pair_method`` is set to ``mean``, XGBoost uses random sampling, and results may differ depending on the platform used. The random number generator used on Windows (Microsoft Visual C++) is different from the ones used on other platforms like Linux (GCC, Clang), so the output varies significantly between these platforms. + +********** +References +********** + +[1] Tie-Yan Liu. 2009. "`Learning to Rank for Information Retrieval`_". Found. Trends Inf. Retr. 3, 3 (March 2009), 225–331. + +[2] Christopher J. C. Burges, Robert Ragno, and Quoc Viet Le. 2006. "`Learning to rank with nonsmooth cost functions`_". In Proceedings of the 19th International Conference on Neural Information Processing Systems (NIPS'06). MIT Press, Cambridge, MA, USA, 193–200. + +[3] Wu, Q., Burges, C.J.C., Svore, K.M. et al. "`Adapting boosting for information retrieval measures`_". Inf Retrieval 13, 254–270 (2010). + +[4] Ziniu Hu, Yang Wang, Qu Peng, Hang Li. "`Unbiased LambdaMART: An Unbiased Pairwise Learning-to-Rank Algorithm`_". Proceedings of the 2019 World Wide Web Conference. + +[5] Burges, Chris J.C. "`From RankNet to LambdaRank to LambdaMART: An Overview`_". MSR-TR-2010-82 + +[6] Pinar Donmez, Krysta M. Svore, and Christopher J.C. Burges. 2009. "`On the local optimality of LambdaRank`_". In Proceedings of the 32nd international ACM SIGIR conference on Research and development in information retrieval (SIGIR '09). Association for Computing Machinery, New York, NY, USA, 460–467. + +[7] Chris Burges, Tal Shaked, Erin Renshaw, Ari Lazier, Matt Deeds, Nicole Hamilton, and Greg Hullender. 2005. "`Learning to rank using gradient descent`_". In Proceedings of the 22nd international conference on Machine learning (ICML '05). Association for Computing Machinery, New York, NY, USA, 89–96. + +[8] Xuanhui Wang and Cheng Li and Nadav Golbandi and Mike Bendersky and Marc Najork. 2018. "`The LambdaLoss Framework for Ranking Metric Optimization`_". Proceedings of The 27th ACM International Conference on Information and Knowledge Management (CIKM '18). + +.. _`Learning to Rank for Information Retrieval`: https://doi.org/10.1561/1500000016 +.. _`Learning to rank with nonsmooth cost functions`: https://dl.acm.org/doi/10.5555/2976456.2976481 +.. _`Adapting boosting for information retrieval measures`: https://doi.org/10.1007/s10791-009-9112-1 +.. _`Unbiased LambdaMART: An Unbiased Pairwise Learning-to-Rank Algorithm`: https://dl.acm.org/doi/10.1145/3308558.3313447 +.. _`From RankNet to LambdaRank to LambdaMART: An Overview`: https://www.microsoft.com/en-us/research/publication/from-ranknet-to-lambdarank-to-lambdamart-an-overview/ +.. _`On the local optimality of LambdaRank`: https://doi.org/10.1145/1571941.1572021 +.. _`Learning to rank using gradient descent`: https://doi.org/10.1145/1102351.1102363 +.. _`The LambdaLoss Framework for Ranking Metric Optimization`: https://dl.acm.org/doi/10.1145/3269206.3271784 diff --git a/include/xgboost/cache.h b/include/xgboost/cache.h index 32e1b21ac3f6..279fef0907e4 100644 --- a/include/xgboost/cache.h +++ b/include/xgboost/cache.h @@ -15,6 +15,7 @@ #include // for move #include // for vector + namespace xgboost { class DMatrix; /** diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 6305abff840e..937475a9a015 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -17,6 +17,8 @@ #include #include +#include // std::size_t +#include // std::uint64_t #include #include #include @@ -60,9 +62,8 @@ class MetaInfo { linalg::Tensor labels; /*! \brief data split mode */ DataSplitMode data_split_mode{DataSplitMode::kRow}; - /*! - * \brief the index of begin and end of a group - * needed when the learning task is ranking. + /** + * \brief the index of begin and end of a group, needed when the learning task is ranking. */ std::vector group_ptr_; // NOLINT /*! \brief weights of each instance, optional */ diff --git a/include/xgboost/objective.h b/include/xgboost/objective.h index a04d2e453df7..4fecf56884b2 100644 --- a/include/xgboost/objective.h +++ b/include/xgboost/objective.h @@ -11,6 +11,7 @@ #include #include #include +#include // for Json, Null #include #include diff --git a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala index 5342aa563621..b8dca5d7040e 100644 --- a/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala +++ b/jvm-packages/xgboost4j-spark-gpu/src/test/scala/ml/dmlc/xgboost4j/scala/rapids/spark/GpuXGBoostRegressorSuite.scala @@ -220,7 +220,7 @@ class GpuXGBoostRegressorSuite extends GpuTestSuite { test("Ranking: train with Group") { withGpuSparkSession(enableCsvConf()) { spark => - val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "rank:pairwise", + val xgbParam = Map("eta" -> 0.1f, "max_depth" -> 2, "objective" -> "rank:ndcg", "num_round" -> 10, "num_workers" -> 1, "tree_method" -> "gpu_hist", "features_cols" -> featureNames, "label_col" -> labelName) val Array(trainingDf, testDf) = spark.read.option("header", "true").schema(schema) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala index f63865fabc2d..6aec4d36ed6f 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala @@ -25,7 +25,7 @@ private[spark] trait LearningTaskParams extends Params { /** * Specify the learning task and the corresponding learning objective. * options: reg:squarederror, reg:squaredlogerror, reg:logistic, binary:logistic, binary:logitraw, - * count:poisson, multi:softmax, multi:softprob, rank:pairwise, reg:gamma. + * count:poisson, multi:softmax, multi:softprob, rank:ndcg, reg:gamma. * default: reg:squarederror */ final val objective = new Param[String](this, "objective", diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala index c1e34224caca..d93b182e043e 100755 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala @@ -201,7 +201,7 @@ class XGBoostGeneralSuite extends AnyFunSuite with TmpFolderPerSuite with PerTes sc, buildTrainingRDD, List("eta" -> "1", "max_depth" -> "6", - "objective" -> "rank:pairwise", "num_round" -> 5, "num_workers" -> numWorkers, + "objective" -> "rank:ndcg", "num_round" -> 5, "num_workers" -> numWorkers, "custom_eval" -> null, "custom_obj" -> null, "use_external_memory" -> false, "missing" -> Float.NaN).toMap) @@ -268,7 +268,7 @@ class XGBoostGeneralSuite extends AnyFunSuite with TmpFolderPerSuite with PerTes val training = buildDataFrameWithGroup(Ranking.train, 5) val Array(train, eval1, eval2) = training.randomSplit(Array(0.6, 0.2, 0.2), 0) val paramMap1 = Map("eta" -> "1", "max_depth" -> "6", - "objective" -> "rank:pairwise", + "objective" -> "rank:ndcg", "num_round" -> 5, "num_workers" -> numWorkers, "group_col" -> "group") val xgb1 = new XGBoostRegressor(paramMap1).setEvalSets(Map("eval1" -> eval1, "eval2" -> eval2)) val model1 = xgb1.fit(train) @@ -281,7 +281,7 @@ class XGBoostGeneralSuite extends AnyFunSuite with TmpFolderPerSuite with PerTes assert(model1.summary.trainObjectiveHistory !== model1.summary.validationObjectiveHistory(1)) val paramMap2 = Map("eta" -> "1", "max_depth" -> "6", - "objective" -> "rank:pairwise", + "objective" -> "rank:ndcg", "num_round" -> 5, "num_workers" -> numWorkers, "group_col" -> "group", "eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2)) val xgb2 = new XGBoostRegressor(paramMap2) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala index efcb38cf62e1..1bdea7a827bd 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressorSuite.scala @@ -121,7 +121,7 @@ class XGBoostRegressorSuite extends AnyFunSuite with PerTest with TmpFolderPerSu test("ranking: use group data") { val paramMap = Map("eta" -> "1", "max_depth" -> "6", "silent" -> "1", - "objective" -> "rank:pairwise", "num_workers" -> numWorkers, "num_round" -> 5, + "objective" -> "rank:ndcg", "num_workers" -> numWorkers, "num_round" -> 5, "group_col" -> "group", "tree_method" -> treeMethod) val trainingDF = buildDataFrameWithGroup(Ranking.train) diff --git a/python-package/xgboost/_typing.py b/python-package/xgboost/_typing.py index 774681031cec..39952aca9845 100644 --- a/python-package/xgboost/_typing.py +++ b/python-package/xgboost/_typing.py @@ -31,7 +31,7 @@ PathLike = Union[str, os.PathLike] CupyT = ArrayLike # maybe need a stub for cupy arrays NumpyOrCupy = Any -NumpyDType = Union[str, Type[np.number]] +NumpyDType = Union[str, Type[np.number]] # pylint: disable=invalid-name PandasDType = Any # real type is pandas.core.dtypes.base.ExtensionDtype FloatCompatible = Union[float, np.float32, np.float64] diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 9b5949cdb8b7..43d531a9d7bc 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -1796,7 +1796,11 @@ def _get_qid( @xgboost_model_doc( - """Implementation of the Scikit-Learn API for XGBoost Ranking.""", + """Implementation of the Scikit-Learn API for XGBoost Ranking. + +See :doc:`Learning to Rank ` for an introducion. + + """, ["estimators", "model"], end_note=""" .. note:: @@ -1845,7 +1849,7 @@ def _get_qid( class XGBRanker(XGBModel, XGBRankerMixIn): # pylint: disable=missing-docstring,too-many-arguments,invalid-name @_deprecate_positional_args - def __init__(self, *, objective: str = "rank:pairwise", **kwargs: Any): + def __init__(self, *, objective: str = "rank:ndcg", **kwargs: Any): super().__init__(objective=objective, **kwargs) if callable(self.objective): raise ValueError("custom objective function not supported by XGBRanker") @@ -2029,7 +2033,7 @@ def fit( self._Booster = train( params, train_dmatrix, - self.get_num_boosting_rounds(), + num_boost_round=self.get_num_boosting_rounds(), early_stopping_rounds=early_stopping_rounds, evals=evals, evals_result=evals_result, diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py index 70e5361011b3..de44d8882959 100644 --- a/python-package/xgboost/testing/__init__.py +++ b/python-package/xgboost/testing/__init__.py @@ -36,11 +36,14 @@ from xgboost.core import ArrayLike from xgboost.sklearn import SklObjective from xgboost.testing.data import ( + ClickFold, + RelDataCV, get_california_housing, get_cancer, get_digits, get_sparse, memory, + simulate_clicks, ) hypothesis = pytest.importorskip("hypothesis") diff --git a/python-package/xgboost/testing/data.py b/python-package/xgboost/testing/data.py index 477d0cf3d6f0..b5f07d98161e 100644 --- a/python-package/xgboost/testing/data.py +++ b/python-package/xgboost/testing/data.py @@ -1,11 +1,14 @@ +# pylint: disable=invalid-name """Utilities for data generation.""" import os import zipfile -from typing import Any, Generator, List, Tuple, Union +from dataclasses import dataclass +from typing import Any, Generator, List, NamedTuple, Optional, Tuple, Union from urllib import request import numpy as np import pytest +from numpy import typing as npt from numpy.random import Generator as RNG from scipy import sparse @@ -340,3 +343,263 @@ def get_mq2008( y_valid, qid_valid, ) + + +RelData = Tuple[sparse.csr_matrix, npt.NDArray[np.int32], npt.NDArray[np.int32]] + + +@dataclass +class ClickFold: + """A structure containing information about generated user-click data.""" + + X: sparse.csr_matrix + y: npt.NDArray[np.int32] + qid: npt.NDArray[np.int32] + score: npt.NDArray[np.float32] + click: npt.NDArray[np.int32] + pos: npt.NDArray[np.int64] + + +class RelDataCV(NamedTuple): + """Simple data struct for holding a train-test split of a learning to rank dataset.""" + + train: RelData + test: RelData + max_rel: int + + def is_binary(self) -> bool: + """Whether the label consists of binary relevance degree.""" + return self.max_rel == 1 + + +class PBM: # pylint: disable=too-few-public-methods + """Simulate click data with position bias model. There are other models available in + `ULTRA `_ like the cascading model. + + References + ---------- + Unbiased LambdaMART: An Unbiased Pairwise Learning-to-Rank Algorithm + + """ + + def __init__(self, eta: float) -> None: + # click probability for each relevance degree. (from 0 to 4) + self.click_prob = np.array([0.1, 0.16, 0.28, 0.52, 1.0]) + exam_prob = np.array( + [0.68, 0.61, 0.48, 0.34, 0.28, 0.20, 0.11, 0.10, 0.08, 0.06] + ) + # Observation probability, encoding positional bias for each position + self.exam_prob = np.power(exam_prob, eta) + + def sample_clicks_for_query( + self, labels: npt.NDArray[np.int32], position: npt.NDArray[np.int64] + ) -> npt.NDArray[np.int32]: + """Sample clicks for one query based on input relevance degree and position. + + Parameters + ---------- + + labels : + relevance_degree + + """ + labels = np.array(labels, copy=True) + + click_prob = np.zeros(labels.shape) + # minimum + labels[labels < 0] = 0 + # maximum + labels[labels >= len(self.click_prob)] = -1 + click_prob = self.click_prob[labels] + + exam_prob = np.zeros(labels.shape) + assert position.size == labels.size + ranks = np.array(position, copy=True) + # maximum + ranks[ranks >= self.exam_prob.size] = -1 + exam_prob = self.exam_prob[ranks] + + rng = np.random.default_rng(1994) + prob = rng.random(size=labels.shape[0], dtype=np.float32) + + clicks: npt.NDArray[np.int32] = np.zeros(labels.shape, dtype=np.int32) + clicks[prob < exam_prob * click_prob] = 1 + return clicks + + +def rlencode(x: npt.NDArray[np.int32]) -> Tuple[npt.NDArray, npt.NDArray, npt.NDArray]: + """Run length encoding using numpy, modified from: + https://gist.github.com/nvictus/66627b580c13068589957d6ab0919e66 + + """ + x = np.asarray(x) + n = x.size + starts = np.r_[0, np.flatnonzero(~np.isclose(x[1:], x[:-1], equal_nan=True)) + 1] + lengths = np.diff(np.r_[starts, n]) + values = x[starts] + indptr = np.append(starts, np.array([x.size])) + + return indptr, lengths, values + + +def init_rank_score( + X: sparse.csr_matrix, + y: npt.NDArray[np.int32], + qid: npt.NDArray[np.int32], + sample_rate: float = 0.1, +) -> npt.NDArray[np.float32]: + """We use XGBoost to generate the initial score instead of SVMRank for + simplicity. Sample rate is set to 0.1 by default so that we can test with small + datasets. + + """ + # random sample + rng = np.random.default_rng(1994) + n_samples = int(X.shape[0] * sample_rate) + index = np.arange(0, X.shape[0], dtype=np.uint64) + rng.shuffle(index) + index = index[:n_samples] + + X_train = X[index] + y_train = y[index] + qid_train = qid[index] + + # Sort training data based on query id, required by XGBoost. + sorted_idx = np.argsort(qid_train) + X_train = X_train[sorted_idx] + y_train = y_train[sorted_idx] + qid_train = qid_train[sorted_idx] + + ltr = xgboost.XGBRanker(objective="rank:ndcg", tree_method="hist") + ltr.fit(X_train, y_train, qid=qid_train) + + # Use the original order of the data. + scores = ltr.predict(X) + return scores + + +def simulate_one_fold( + fold: Tuple[sparse.csr_matrix, npt.NDArray[np.int32], npt.NDArray[np.int32]], + scores_fold: npt.NDArray[np.float32], +) -> ClickFold: + """Simulate clicks for one fold.""" + X_fold, y_fold, qid_fold = fold + assert qid_fold.dtype == np.int32 + + qids = np.unique(qid_fold) + + position = np.empty((y_fold.size,), dtype=np.int64) + clicks = np.empty((y_fold.size,), dtype=np.int32) + pbm = PBM(eta=1.0) + + # Avoid grouping by qid as we want to preserve the original data partition by + # the dataset authors. + for q in qids: + qid_mask = q == qid_fold + qid_mask = qid_mask.reshape(qid_mask.shape[0]) + query_scores = scores_fold[qid_mask] + # Initial rank list, scores sorted to decreasing order + query_position = np.argsort(query_scores)[::-1] + position[qid_mask] = query_position + # get labels + relevance_degrees = y_fold[qid_mask] + query_clicks = pbm.sample_clicks_for_query(relevance_degrees, query_position) + clicks[qid_mask] = query_clicks + + assert X_fold.shape[0] == qid_fold.shape[0], (X_fold.shape, qid_fold.shape) + assert X_fold.shape[0] == clicks.shape[0], (X_fold.shape, clicks.shape) + + return ClickFold(X_fold, y_fold, qid_fold, scores_fold, clicks, position) + + +# pylint: disable=too-many-locals +def simulate_clicks(cv_data: RelDataCV) -> Tuple[ClickFold, Optional[ClickFold]]: + """Simulate click data using position biased model (PBM).""" + X, y, qid = list(zip(cv_data.train, cv_data.test)) + + # ptr to train-test split + indptr = np.array([0] + [v.shape[0] for v in X]) + indptr = np.cumsum(indptr) + + assert len(indptr) == 2 + 1 # train, test + X_full = sparse.vstack(X) + y_full = np.concatenate(y) + qid_full = np.concatenate(qid) + + # Obtain initial relevance score for click simulation + scores_full = init_rank_score(X_full, y_full, qid_full) + # partition it back to (train, test) tuple + scores = [scores_full[indptr[i - 1] : indptr[i]] for i in range(1, indptr.size)] + + X_lst, y_lst, q_lst, s_lst, c_lst, p_lst = [], [], [], [], [], [] + for i in range(indptr.size - 1): + fold = simulate_one_fold((X[i], y[i], qid[i]), scores[i]) + X_lst.append(fold.X) + y_lst.append(fold.y) + q_lst.append(fold.qid) + s_lst.append(fold.score) + c_lst.append(fold.click) + p_lst.append(fold.pos) + + scores_check_1 = [s_lst[i] for i in range(indptr.size - 1)] + for i in range(2): + assert (scores_check_1[i] == scores[i]).all() + + if len(X_lst) == 1: + train = ClickFold(X_lst[0], y_lst[0], q_lst[0], s_lst[0], c_lst[0], p_lst[0]) + test = None + else: + train, test = ( + ClickFold(X_lst[i], y_lst[i], q_lst[i], s_lst[i], c_lst[i], p_lst[i]) + for i in range(len(X_lst)) + ) + return train, test + + +def sort_ltr_samples( + X: sparse.csr_matrix, + y: npt.NDArray[np.int32], + qid: npt.NDArray[np.int32], + clicks: npt.NDArray[np.int32], + pos: npt.NDArray[np.int64], +) -> Tuple[ + sparse.csr_matrix, + npt.NDArray[np.int32], + npt.NDArray[np.int32], + npt.NDArray[np.int32], +]: + """Sort data based on query index and position.""" + sorted_idx = np.argsort(qid) + X = X[sorted_idx] + clicks = clicks[sorted_idx] + qid = qid[sorted_idx] + pos = pos[sorted_idx] + + indptr, _, _ = rlencode(qid) + + for i in range(1, indptr.size): + beg = indptr[i - 1] + end = indptr[i] + + assert beg < end, (beg, end) + assert np.unique(qid[beg:end]).size == 1, (beg, end) + + query_pos = pos[beg:end] + assert query_pos.min() == 0, query_pos.min() + assert query_pos.max() >= query_pos.size - 1, ( + query_pos.max(), + query_pos.size, + i, + np.unique(qid[beg:end]), + ) + sorted_idx = np.argsort(query_pos) + + X[beg:end] = X[beg:end][sorted_idx] + clicks[beg:end] = clicks[beg:end][sorted_idx] + y[beg:end] = y[beg:end][sorted_idx] + # not necessary + qid[beg:end] = qid[beg:end][sorted_idx] + + data = X, clicks, y, qid + + return data diff --git a/python-package/xgboost/testing/params.py b/python-package/xgboost/testing/params.py index e6ba73e1f541..8dc91b6017be 100644 --- a/python-package/xgboost/testing/params.py +++ b/python-package/xgboost/testing/params.py @@ -67,3 +67,17 @@ "max_cat_threshold": strategies.integers(1, 128), } ) + +lambdarank_parameter_strategy = strategies.fixed_dictionaries( + { + "lambdarank_unbiased": strategies.sampled_from([True, False]), + "lambdarank_pair_method": strategies.sampled_from(["topk", "mean"]), + "lambdarank_num_pair_per_sample": strategies.integers(1, 8), + "lambdarank_bias_norm": strategies.floats(0.5, 2.0), + "objective": strategies.sampled_from( + ["rank:ndcg", "rank:map", "rank:pairwise"] + ), + } +).filter( + lambda x: not (x["lambdarank_unbiased"] and x["lambdarank_pair_method"] == "mean") +) diff --git a/src/metric/elementwise_metric.cu b/src/metric/elementwise_metric.cu index bd1b0b2d89d3..b6888610b586 100644 --- a/src/metric/elementwise_metric.cu +++ b/src/metric/elementwise_metric.cu @@ -483,9 +483,13 @@ class QuantileError : public MetricNoCache { const char* Name() const override { return "quantile"; } void LoadConfig(Json const& in) override { - auto const& name = get(in["name"]); - CHECK_EQ(name, "quantile"); - FromJson(in["quantile_loss_param"], ¶m_); + auto const& obj = get(in); + auto it = obj.find("quantile_loss_param"); + if (it != obj.cend()) { + FromJson(it->second, ¶m_); + auto const& name = get(in["name"]); + CHECK_EQ(name, "quantile"); + } } void SaveConfig(Json* p_out) const override { auto& out = *p_out; diff --git a/tests/ci_build/lint_python.py b/tests/ci_build/lint_python.py index a6ef0b8049de..90c52aad4ab2 100644 --- a/tests/ci_build/lint_python.py +++ b/tests/ci_build/lint_python.py @@ -36,6 +36,7 @@ class LintersPaths: "demo/guide-python/individual_trees.py", "demo/guide-python/quantile_regression.py", "demo/guide-python/multioutput_regression.py", + "demo/guide-python/learning_to_rank.py", # CI "tests/ci_build/lint_python.py", "tests/ci_build/test_r_package.py", @@ -76,6 +77,7 @@ class LintersPaths: "demo/guide-python/individual_trees.py", "demo/guide-python/quantile_regression.py", "demo/guide-python/multioutput_regression.py", + "demo/guide-python/learning_to_rank.py", # CI "tests/ci_build/lint_python.py", "tests/ci_build/test_r_package.py", diff --git a/tests/cpp/common/test_ranking_utils.cc b/tests/cpp/common/test_ranking_utils.cc index 919102278b98..171d8af2eb87 100644 --- a/tests/cpp/common/test_ranking_utils.cc +++ b/tests/cpp/common/test_ranking_utils.cc @@ -99,7 +99,6 @@ void TestRankingCache(Context const* ctx) { auto rank_idx = cache.SortedIdx(ctx, ctx->IsCPU() ? predt.ConstHostSpan() : predt.ConstDeviceSpan()); - for (std::size_t i = 0; i < rank_idx.size(); ++i) { ASSERT_EQ(rank_idx[i], rank_idx.size() - i - 1); } diff --git a/tests/python/test_eval_metrics.py b/tests/python/test_eval_metrics.py index 0328765f5cbb..147c87a27922 100644 --- a/tests/python/test_eval_metrics.py +++ b/tests/python/test_eval_metrics.py @@ -299,7 +299,9 @@ def test_pr_auc_multi(self): def run_pr_auc_ltr(self, tree_method): from sklearn.datasets import make_classification X, y = make_classification(128, 4, n_classes=2, random_state=1994) - ltr = xgb.XGBRanker(tree_method=tree_method, n_estimators=16) + ltr = xgb.XGBRanker( + tree_method=tree_method, n_estimators=16, objective="rank:pairwise" + ) groups = np.array([32, 32, 64]) ltr.fit( X, diff --git a/tests/python/test_ranking.py b/tests/python/test_ranking.py index 088b681fff5a..8bdeb070ffbe 100644 --- a/tests/python/test_ranking.py +++ b/tests/python/test_ranking.py @@ -1,12 +1,57 @@ import itertools +import json import os import shutil +from typing import Optional import numpy as np +import pytest +from hypothesis import given, note, settings from scipy.sparse import csr_matrix import xgboost from xgboost import testing as tm +from xgboost.testing.data import RelDataCV, simulate_clicks, sort_ltr_samples +from xgboost.testing.params import lambdarank_parameter_strategy + + +def test_ndcg_custom_gain(): + def ndcg_gain(y: np.ndarray) -> np.ndarray: + return np.exp2(y.astype(np.float64)) - 1.0 + + X, y, q, w = tm.make_ltr(n_samples=1024, n_features=4, n_query_groups=3, max_rel=3) + y_gain = ndcg_gain(y) + + byxgb = xgboost.XGBRanker(tree_method="hist", ndcg_exp_gain=True, n_estimators=10) + byxgb.fit( + X, + y, + qid=q, + sample_weight=w, + eval_set=[(X, y)], + eval_qid=(q,), + sample_weight_eval_set=(w,), + verbose=True, + ) + byxgb_json = json.loads(byxgb.get_booster().save_raw(raw_format="json")) + + bynp = xgboost.XGBRanker(tree_method="hist", ndcg_exp_gain=False, n_estimators=10) + bynp.fit( + X, + y_gain, + qid=q, + sample_weight=w, + eval_set=[(X, y_gain)], + eval_qid=(q,), + sample_weight_eval_set=(w,), + verbose=True, + ) + bynp_json = json.loads(bynp.get_booster().save_raw(raw_format="json")) + + # Remove the difference in parameter for comparison + byxgb_json["learner"]["objective"]["lambdarank_param"]["ndcg_exp_gain"] = "0" + assert byxgb.evals_result() == bynp.evals_result() + assert byxgb_json == bynp_json def test_ranking_with_unweighted_data(): @@ -73,8 +118,77 @@ def test_ranking_with_weighted_data(): assert all(p <= q for p, q in zip(is_sorted, is_sorted[1:])) -class TestRanking: +def test_error_msg() -> None: + X, y, qid, w = tm.make_ltr(10, 2, 2, 2) + ranker = xgboost.XGBRanker() + with pytest.raises(ValueError, match=r"equal to the number of query groups"): + ranker.fit(X, y, qid=qid, sample_weight=y) + + +@given(lambdarank_parameter_strategy) +@settings(deadline=None, print_blob=True) +def test_lambdarank_parameters(params): + if params["objective"] == "rank:map": + rel = 1 + else: + rel = 4 + X, y, q, w = tm.make_ltr(4096, 3, 13, rel) + ranker = xgboost.XGBRanker(tree_method="hist", n_estimators=64, **params) + ranker.fit(X, y, qid=q, sample_weight=w, eval_set=[(X, y)], eval_qid=[q]) + for k, v in ranker.evals_result()["validation_0"].items(): + note(v) + assert v[-1] >= v[0] + assert ranker.n_features_in_ == 3 + +@pytest.mark.skipif(**tm.no_pandas()) +@pytest.mark.skipif(**tm.no_sklearn()) +def test_unbiased() -> None: + import pandas as pd + from sklearn.model_selection import train_test_split + + X, y, q, w = tm.make_ltr(8192, 2, n_query_groups=6, max_rel=4) + X, Xe, y, ye, q, qe = train_test_split(X, y, q, test_size=0.2, random_state=3) + X = csr_matrix(X) + Xe = csr_matrix(Xe) + data = RelDataCV((X, y, q), (Xe, ye, qe), max_rel=4) + + train, _ = simulate_clicks(data) + x, c, y, q = sort_ltr_samples( + train.X, train.y, train.qid, train.click, train.pos + ) + df: Optional[pd.DataFrame] = None + + class Position(xgboost.callback.TrainingCallback): + def after_training(self, model) -> bool: + nonlocal df + config = json.loads(model.save_config()) + ti_plus = np.array(config["learner"]["objective"]["ti+"]) + tj_minus = np.array(config["learner"]["objective"]["tj-"]) + df = pd.DataFrame({"ti+": ti_plus, "tj-": tj_minus}) + return model + + ltr = xgboost.XGBRanker( + n_estimators=8, + tree_method="hist", + lambdarank_unbiased=True, + lambdarank_num_pair_per_sample=12, + lambdarank_pair_method="topk", + objective="rank:ndcg", + callbacks=[Position()], + boost_from_average=0, + ) + ltr.fit(x, c, qid=q, eval_set=[(x, c)], eval_qid=[q]) + + assert df is not None + # normalized + np.testing.assert_allclose(df["ti+"].iloc[0], 1.0) + np.testing.assert_allclose(df["tj-"].iloc[0], 1.0) + # less biased on low ranks. + assert df["ti+"].iloc[-1] < df["ti+"].iloc[0] + + +class TestRanking: @classmethod def setup_class(cls): """ diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index e0d3d680be68..d1915267b966 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -130,11 +130,11 @@ def test_ranking(): params = { "tree_method": "exact", + "objective": "rank:pairwise", "learning_rate": 0.1, "gamma": 1.0, "min_child_weight": 0.1, "max_depth": 6, - "eval_metric": "ndcg", "n_estimators": 4, } model = xgb.sklearn.XGBRanker(**params) @@ -163,7 +163,6 @@ def test_ranking(): "gamma": 1.0, "min_child_weight": 0.1, "max_depth": 6, - "eval_metric": "ndcg", } xgb_model_orig = xgb.train( params_orig, train_data, num_boost_round=4, evals=[(valid_data, "validation")] From c84bb082344a47e07ce6ffd526b71cec83a8aef6 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 2 Jun 2023 23:43:56 +0800 Subject: [PATCH 2/9] Cleanup. --- doc/parameter.rst | 1 - include/xgboost/cache.h | 1 - include/xgboost/data.h | 7 +++---- include/xgboost/objective.h | 1 - tests/cpp/common/test_ranking_utils.cc | 1 + 5 files changed, 4 insertions(+), 7 deletions(-) diff --git a/doc/parameter.rst b/doc/parameter.rst index 40ddb9247b83..f6d3a06b671f 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -425,7 +425,6 @@ Specify the learning task and the corresponding learning objective. The objectiv After XGBoost 1.6, both of the requirements and restrictions for using ``aucpr`` in classification problem are similar to ``auc``. For ranking task, only binary relevance label :math:`y \in [0, 1]` is supported. Different from ``map (mean average precision)``, ``aucpr`` calculates the *interpolated* area under precision recall curve using continuous interpolation. - ``pre``: Precision at :math:`k`. Supports only learning to rank task. - - ``ndcg``: `Normalized Discounted Cumulative Gain `_ - ``map``: `Mean Average Precision `_ diff --git a/include/xgboost/cache.h b/include/xgboost/cache.h index 279fef0907e4..32e1b21ac3f6 100644 --- a/include/xgboost/cache.h +++ b/include/xgboost/cache.h @@ -15,7 +15,6 @@ #include // for move #include // for vector - namespace xgboost { class DMatrix; /** diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 937475a9a015..6305abff840e 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -17,8 +17,6 @@ #include #include -#include // std::size_t -#include // std::uint64_t #include #include #include @@ -62,8 +60,9 @@ class MetaInfo { linalg::Tensor labels; /*! \brief data split mode */ DataSplitMode data_split_mode{DataSplitMode::kRow}; - /** - * \brief the index of begin and end of a group, needed when the learning task is ranking. + /*! + * \brief the index of begin and end of a group + * needed when the learning task is ranking. */ std::vector group_ptr_; // NOLINT /*! \brief weights of each instance, optional */ diff --git a/include/xgboost/objective.h b/include/xgboost/objective.h index 4fecf56884b2..a04d2e453df7 100644 --- a/include/xgboost/objective.h +++ b/include/xgboost/objective.h @@ -11,7 +11,6 @@ #include #include #include -#include // for Json, Null #include #include diff --git a/tests/cpp/common/test_ranking_utils.cc b/tests/cpp/common/test_ranking_utils.cc index 171d8af2eb87..919102278b98 100644 --- a/tests/cpp/common/test_ranking_utils.cc +++ b/tests/cpp/common/test_ranking_utils.cc @@ -99,6 +99,7 @@ void TestRankingCache(Context const* ctx) { auto rank_idx = cache.SortedIdx(ctx, ctx->IsCPU() ? predt.ConstHostSpan() : predt.ConstDeviceSpan()); + for (std::size_t i = 0; i < rank_idx.size(); ++i) { ASSERT_EQ(rank_idx[i], rank_idx.size() - i - 1); } From 116a63accc9ee0cef9cc77745ab0e078d9ad2851 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 3 Jun 2023 03:47:25 +0800 Subject: [PATCH 3/9] try to use view. --- python-package/xgboost/sklearn.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 43d531a9d7bc..8901a06ebb21 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -1783,15 +1783,31 @@ def _get_qid( X: ArrayLike, qid: Optional[ArrayLike] ) -> Tuple[ArrayLike, Optional[ArrayLike]]: """Get the special qid column from X if exists.""" - if (_is_pandas_df(X) or _is_cudf_df(X)) and hasattr(X, "qid"): + has_qid = hasattr(X, "qid") + if (_is_pandas_df(X) or _is_cudf_df(X)) and has_qid: if qid is not None: raise ValueError( "Found both the special column `qid` in `X` and the `qid` from the" "`fit` method. Please remove one of them." ) + if _is_cudf_df(X) and has_qid: q_x = X.qid X = X.drop("qid", axis=1) return X, q_x + if _is_pandas_df(X) and has_qid: + import pandas as pd + + q_x = X.qid + series = [] + columns = X.columns.difference(["qid"]) + for c in columns: + if c == "qid": + continue + + s_view = X[c].view(X[c].dtype) + series.append(s_view) + X = pd.DataFrame(series, columns=columns) + return X, q_x return X, qid From 14bff7f68dcbff1676c95d07a157c05d7a26552d Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 3 Jun 2023 04:11:27 +0800 Subject: [PATCH 4/9] Revert "try to use view." This reverts commit 6030aa91029e66bd41871f12c41ba654fd82b4fc. --- python-package/xgboost/sklearn.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 8901a06ebb21..43d531a9d7bc 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -1783,31 +1783,15 @@ def _get_qid( X: ArrayLike, qid: Optional[ArrayLike] ) -> Tuple[ArrayLike, Optional[ArrayLike]]: """Get the special qid column from X if exists.""" - has_qid = hasattr(X, "qid") - if (_is_pandas_df(X) or _is_cudf_df(X)) and has_qid: + if (_is_pandas_df(X) or _is_cudf_df(X)) and hasattr(X, "qid"): if qid is not None: raise ValueError( "Found both the special column `qid` in `X` and the `qid` from the" "`fit` method. Please remove one of them." ) - if _is_cudf_df(X) and has_qid: q_x = X.qid X = X.drop("qid", axis=1) return X, q_x - if _is_pandas_df(X) and has_qid: - import pandas as pd - - q_x = X.qid - series = [] - columns = X.columns.difference(["qid"]) - for c in columns: - if c == "qid": - continue - - s_view = X[c].view(X[c].dtype) - series.append(s_view) - X = pd.DataFrame(series, columns=columns) - return X, q_x return X, qid From 4e195218c6687125c75ddc8bcfb003b34d7ee69b Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 3 Jun 2023 04:31:26 +0800 Subject: [PATCH 5/9] cleanup. --- python-package/xgboost/testing/__init__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py index de44d8882959..70e5361011b3 100644 --- a/python-package/xgboost/testing/__init__.py +++ b/python-package/xgboost/testing/__init__.py @@ -36,14 +36,11 @@ from xgboost.core import ArrayLike from xgboost.sklearn import SklObjective from xgboost.testing.data import ( - ClickFold, - RelDataCV, get_california_housing, get_cancer, get_digits, get_sparse, memory, - simulate_clicks, ) hypothesis = pytest.importorskip("hypothesis") From faf3a887f13596ef228a4ea00c5f491b2153d826 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sun, 4 Jun 2023 00:12:20 +0800 Subject: [PATCH 6/9] annotation. --- src/tree/updater_quantile_hist.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 68d74bea3b94..ae1c4b468617 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -419,6 +419,7 @@ class HistBuilder { CPUExpandEntry InitRoot(DMatrix *p_fmat, linalg::MatrixView gpair, RegTree *p_tree) { + monitor_->Start(__func__); CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0)); std::size_t page_id = 0; @@ -475,12 +476,14 @@ class HistBuilder { node = entries.front(); } + monitor_->Stop(__func__); return node; } void BuildHistogram(DMatrix *p_fmat, RegTree *p_tree, std::vector const &valid_candidates, linalg::MatrixView gpair) { + monitor_->Start(__func__); std::vector nodes_to_build(valid_candidates.size()); std::vector nodes_to_sub(valid_candidates.size()); @@ -508,6 +511,7 @@ class HistBuilder { nodes_to_sub, gpair.Values()); ++page_id; } + monitor_->Stop(__func__); } void UpdatePosition(DMatrix *p_fmat, RegTree const *p_tree, @@ -525,6 +529,7 @@ class HistBuilder { std::vector *p_out_position) { monitor_->Start(__func__); if (!task_->UpdateTreeLeaf()) { + monitor_->Stop(__func__); return; } for (auto const &part : partitioner_) { From fb9cf2037442b91d1176d8a550780d5b5347b475 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sun, 4 Jun 2023 00:34:57 +0800 Subject: [PATCH 7/9] remove the symmetric hessian. --- src/objective/lambdarank_obj.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/objective/lambdarank_obj.h b/src/objective/lambdarank_obj.h index c2222c028582..0771ba499145 100644 --- a/src/objective/lambdarank_obj.h +++ b/src/objective/lambdarank_obj.h @@ -123,7 +123,7 @@ LambdaGrad(linalg::VectorView labels, common::Span pre } auto lambda_ij = (sigmoid - 1.0) * delta_metric; - auto hessian_ij = std::max(sigmoid * (1.0 - sigmoid), Eps64()) * delta_metric * 2.0; + auto hessian_ij = std::max(sigmoid * (1.0 - sigmoid), Eps64()) * delta_metric; auto k = t_plus.Size(); assert(t_minus.Size() == k && "Invalid size of position bias"); From 9ac469191c4c9b1d637e04c32b28414bb902efc5 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 5 Jun 2023 08:28:54 +0800 Subject: [PATCH 8/9] footnote. --- doc/tutorials/learning_to_rank.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/doc/tutorials/learning_to_rank.rst b/doc/tutorials/learning_to_rank.rst index 72e2123bbec9..b9883d236ac2 100644 --- a/doc/tutorials/learning_to_rank.rst +++ b/doc/tutorials/learning_to_rank.rst @@ -167,7 +167,9 @@ XGBoost implements distributed learning-to-rank with integration of multiple fra Reproducible Result ******************* -Like any other tasks, XGBoost should generate reproducible results given the same hardware and software environments (and data partitions, if distributed interface is used). Even when the underlying environment has changed, the result should still be consistent. However, when the ``lambdarank_pair_method`` is set to ``mean``, XGBoost uses random sampling, and results may differ depending on the platform used. The random number generator used on Windows (Microsoft Visual C++) is different from the ones used on other platforms like Linux (GCC, Clang), so the output varies significantly between these platforms. +Like any other tasks, XGBoost should generate reproducible results given the same hardware and software environments (and data partitions, if distributed interface is used). Even when the underlying environment has changed, the result should still be consistent. However, when the ``lambdarank_pair_method`` is set to ``mean``, XGBoost uses random sampling, and results may differ depending on the platform used. The random number generator used on Windows (Microsoft Visual C++) is different from the ones used on other platforms like Linux (GCC, Clang) [#f0]_, so the output varies significantly between these platforms. + +.. [#f0] `minstd_rand` implementation is different on MSVC. The implementations from GCC and Thrust produce the same output. ********** References From ffadd8099e7d1e5e5f3b2c45fb755eafc05c39ad Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 9 Jun 2023 09:52:31 +0800 Subject: [PATCH 9/9] Revert "remove the symmetric hessian." This reverts commit e7d74cedc30177ef4e431befe268880a67c87017. --- src/objective/lambdarank_obj.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/objective/lambdarank_obj.h b/src/objective/lambdarank_obj.h index 0771ba499145..c2222c028582 100644 --- a/src/objective/lambdarank_obj.h +++ b/src/objective/lambdarank_obj.h @@ -123,7 +123,7 @@ LambdaGrad(linalg::VectorView labels, common::Span pre } auto lambda_ij = (sigmoid - 1.0) * delta_metric; - auto hessian_ij = std::max(sigmoid * (1.0 - sigmoid), Eps64()) * delta_metric; + auto hessian_ij = std::max(sigmoid * (1.0 - sigmoid), Eps64()) * delta_metric * 2.0; auto k = t_plus.Size(); assert(t_minus.Size() == k && "Invalid size of position bias");