From 7eba285a1e60197dd5fe8f61f8b17f7fc746afa2 Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Tue, 7 Mar 2023 00:22:08 +0800
Subject: [PATCH] Support sklearn cross validation for ranker. (#8859)

* Support sklearn cross validation for ranker.

- Add a convention for `X` to include a special `qid` column. sklearn
  utilities consider only `X`, `y`, and `sample_weight` for supervised
  learning algorithms, but ranking needs an additional qid array.

It's important to support sklearn's cross-validation function, since all
the other tuning utilities, such as grid search, are built on top of
cross validation.
---
 python-package/xgboost/callback.py         |  14 +-
 python-package/xgboost/collective.py       |   2 +-
 python-package/xgboost/core.py             |  10 ++
 python-package/xgboost/rabit.py            |   2 +-
 python-package/xgboost/sklearn.py          | 159 +++++++++++++++++-----
 python-package/xgboost/testing/ranking.py  |  72 ++++++++++
 tests/python-gpu/test_gpu_with_sklearn.py  |   8 ++
 tests/python/test_with_sklearn.py          |   8 ++
 8 files changed, 232 insertions(+), 43 deletions(-)
 create mode 100644 python-package/xgboost/testing/ranking.py

diff --git a/python-package/xgboost/callback.py b/python-package/xgboost/callback.py
index 76350d839dd1..5be6a058ac8e 100644
--- a/python-package/xgboost/callback.py
+++ b/python-package/xgboost/callback.py
@@ -23,7 +23,13 @@
 import numpy

 from . import collective
-from .core import Booster, DMatrix, XGBoostError, _get_booster_layer_trees
+from .core import (
+    Booster,
+    DMatrix,
+    XGBoostError,
+    _get_booster_layer_trees,
+    _parse_eval_str,
+)

 __all__ = [
     "TrainingCallback",
@@ -250,11 +256,7 @@ def after_iteration(
         for _, name in evals:
             assert name.find("-") == -1, "Dataset name should not contain `-`"
         score: str = model.eval_set(evals, epoch, self.metric, self._output_margin)
-        splited = score.split()[1:]  # into datasets
-        # split up `test-error:0.1234`
-        metric_score_str = [tuple(s.split(":")) for s in splited]
-        # convert to float
-        metric_score = [(n, float(s)) for n, s in metric_score_str]
+        metric_score = _parse_eval_str(score)
         self._update_history(metric_score, epoch)
         ret = any(c.after_iteration(model, epoch, self.history) for c in self.callbacks)
         return ret
diff --git a/python-package/xgboost/collective.py b/python-package/xgboost/collective.py
index 7c586cba71d3..4c67ccbfcad7 100644
--- a/python-package/xgboost/collective.py
+++ b/python-package/xgboost/collective.py
@@ -231,7 +231,7 @@ def allreduce(data: np.ndarray, op: Op) -> np.ndarray:  # pylint:disable=invalid-name
     if buf.base is data.base:
         buf = buf.copy()
     if buf.dtype not in DTYPE_ENUM__:
-        raise Exception(f"data type {buf.dtype} not supported")
+        raise TypeError(f"data type {buf.dtype} not supported")
     _check_call(
         _LIB.XGCommunicatorAllreduce(
             buf.ctypes.data_as(ctypes.c_void_p),
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index a186dc3963dc..5a0cfb3a2ece 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -111,6 +111,16 @@ def make_jcargs(**kwargs: Any) -> bytes:
     return from_pystr_to_cstr(json.dumps(kwargs))


+def _parse_eval_str(result: str) -> List[Tuple[str, float]]:
+    """Parse an eval result string from the booster."""
+    splits = result.split()[1:]  # into datasets
+    # split up `test-error:0.1234`
+    metric_score_str = [tuple(s.split(":")) for s in splits]
+    # convert to float
+    metric_score = [(n, float(s)) for n, s in metric_score_str]
+    return metric_score
+
+
 IterRange = TypeVar("IterRange", Optional[Tuple[int, int]], Tuple[int, int])
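
For reference, a rough sketch of what the new private helper consumes and
produces (the eval string below is only an illustration of the booster's
output format):

    from xgboost.core import _parse_eval_str

    # "[0]" is the iteration marker; the remaining whitespace-separated
    # tokens are `name:score` pairs.
    pairs = _parse_eval_str("[0]\ttrain-error:0.1234\ttest-error:0.2345")
    assert pairs == [("train-error", 0.1234), ("test-error", 0.2345)]
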
diff --git a/python-package/xgboost/rabit.py b/python-package/xgboost/rabit.py
index 0b8f143ecd35..132d721787b1 100644
--- a/python-package/xgboost/rabit.py
+++ b/python-package/xgboost/rabit.py
@@ -136,7 +136,7 @@ def allreduce(  # pylint:disable=invalid-name
     """
     if prepare_fun is None:
         return collective.allreduce(data, collective.Op(op))
-    raise Exception("preprocessing function is no longer supported")
+    raise ValueError("preprocessing function is no longer supported")


 def version_number() -> int:
diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index 69bcac38d01a..3204f5a2a61e 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -43,8 +43,9 @@
     XGBoostError,
     _convert_ntree_limit,
     _deprecate_positional_args,
+    _parse_eval_str,
 )
-from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array
+from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array, _is_pandas_df
 from .training import train
@@ -1812,32 +1813,43 @@ def fit(
         return self


+def _get_qid(
+    X: ArrayLike, qid: Optional[ArrayLike]
+) -> Tuple[ArrayLike, Optional[ArrayLike]]:
+    """Get the special qid column from X if it exists."""
+    if (_is_pandas_df(X) or _is_cudf_df(X)) and hasattr(X, "qid"):
+        if qid is not None:
+            raise ValueError(
+                "Found both the special column `qid` in `X` and the `qid` from"
+                " the `fit` method. Please remove one of them."
+            )
+        q_x = X.qid
+        X = X.drop("qid", axis=1)
+        return X, q_x
+    return X, qid
+
+
 @xgboost_model_doc(
-    "Implementation of the Scikit-Learn API for XGBoost Ranking.",
+    """Implementation of the Scikit-Learn API for XGBoost Ranking.""",
     ["estimators", "model"],
     end_note="""
-        .. note::
-
-            The default objective for XGBRanker is "rank:pairwise"
-
         .. note::

            A custom objective function is currently not supported by XGBRanker.
-            Likewise, a custom metric function is not supported either.

        .. note::

-            Query group information is required for ranking tasks by either using the
-            `group` parameter or `qid` parameter in `fit` method. This information is
-            not required in 'predict' method and multiple groups can be predicted on
-            a single call to `predict`.
+            Query group information is only required for ranking training but not
+            prediction. Multiple groups can be predicted on a single call to
+            :py:meth:`predict`.

        When fitting the model with the `group` parameter, your data need to be sorted
-        by query group first. `group` must be an array that contains the size of each
+        by the query group first. `group` is an array that contains the size of each
        query group.
-        When fitting the model with the `qid` parameter, your data does not need
-        sorting. `qid` must be an array that contains the group of each training
-        sample.
+
+        Similarly, when fitting the model with the `qid` parameter, the data should be
+        sorted according to the query index, and `qid` is an array that contains the
+        query index for each training sample.

        For example, if your original data look like:
@@ -1859,9 +1871,10 @@ def fit(
        |   2   |     1     |      x_7      |
        +-------+-----------+---------------+

-        then `fit` method can be called with either `group` array as ``[3, 4]``
-        or with `qid` as ``[1, 1, 1, 2, 2, 2, 2]``, that is the qid column.
-""",
+        then the :py:meth:`fit` method can be called with either the `group` array
+        as ``[3, 4]`` or with `qid` as ``[1, 1, 1, 2, 2, 2, 2]``, that is the qid
+        column.
+        Also, the `qid` can be a special column of the input `X` instead of a
+        separate parameter; see :py:meth:`fit` for more info.""",
 )
 class XGBRanker(XGBModel, XGBRankerMixIn):
     # pylint: disable=missing-docstring,too-many-arguments,invalid-name
@@ -1873,6 +1886,16 @@ def __init__(self, *, objective: str = "rank:pairwise", **kwargs: Any):
         if "rank:" not in objective:
             raise ValueError("please use XGBRanker for ranking task")

+    def _create_ltr_dmatrix(
+        self, ref: Optional[DMatrix], data: ArrayLike, qid: ArrayLike, **kwargs: Any
+    ) -> DMatrix:
+        data, qid = _get_qid(data, qid)
+
+        if kwargs.get("group", None) is None and qid is None:
+            raise ValueError("Either `group` or `qid` is required for ranking task")
+
+        return super()._create_dmatrix(ref=ref, data=data, qid=qid, **kwargs)
+
     @_deprecate_positional_args
     def fit(
         self,
@@ -1907,6 +1930,23 @@ def fit(
         X :
             Feature matrix. See :ref:`py-data` for a list of supported types.

+            When this is a :py:class:`pandas.DataFrame` or a :py:class:`cudf.DataFrame`,
+            it may contain a special column called ``qid`` for specifying the query
+            index. Using a special column is the same as using the `qid` parameter,
+            except for being compatible with sklearn utility functions like
+            :py:func:`sklearn.model_selection.cross_validate`. The same convention
+            applies to :py:meth:`XGBRanker.score` and :py:meth:`XGBRanker.predict`.
+
+            +-----+----------------+----------------+
+            | qid | feat_0         | feat_1         |
+            +-----+----------------+----------------+
+            | 0   | :math:`x_{00}` | :math:`x_{01}` |
+            +-----+----------------+----------------+
+            | 1   | :math:`x_{10}` | :math:`x_{11}` |
+            +-----+----------------+----------------+
+            | 1   | :math:`x_{20}` | :math:`x_{21}` |
+            +-----+----------------+----------------+
+
             When the ``tree_method`` is set to ``hist`` or ``gpu_hist``, internally, the
             :py:class:`QuantileDMatrix` will be used instead of the :py:class:`DMatrix`
             for conserving memory. However, this has performance implications when the
@@ -1916,12 +1956,12 @@ def fit(
         y :
             Labels
         group :
-            Size of each query group of training data. Should have as many elements as the
-            query groups in the training data. If this is set to None, then user must
-            provide qid.
+            Size of each query group of training data. Should have as many elements as
+            the query groups in the training data. If this is set to None, then the
+            user must provide `qid`.
         qid :
             Query ID for each training sample. Should have the size of n_samples. If
-            this is set to None, then user must provide group.
+            this is set to None, then the user must provide `group` or the special
+            column in `X`.
         sample_weight :
             Query group weights
@@ -1929,8 +1969,9 @@ def fit(

             In ranking task, one weight is assigned to each query group/id (not each
             data point). This is because we only care about the relative ordering of
-            data points within each group, so it doesn't make sense to assign weights
-            to individual data points.
+            data points within each group, so it doesn't make sense to assign
+            weights to individual data points.
+
         base_margin :
             Global bias for each instance.
         eval_set :
             A list of (X, y) tuple pairs to use as validation sets, for which
             metrics will be computed.
         eval_group :
             A list in which ``eval_group[i]`` is the list containing the sizes of all
@@ -1942,7 +1983,8 @@ def fit(
             query groups in the ``i``-th pair in **eval_set**.
         eval_qid :
             A list in which ``eval_qid[i]`` is the array containing query ID of ``i``-th
-            pair in **eval_set**.
+            pair in **eval_set**. The special column convention in `X` applies to
+            validation datasets as well.
         eval_metric : str, list of str, optional

             .. deprecated:: 1.6.0
@@ -1985,16 +2027,7 @@ def fit(
             Use `callbacks` in :py:meth:`__init__` or :py:meth:`set_params`
             instead.
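+
+        A minimal sketch of the special `qid` column convention (the data and
+        parameters here are illustrative only):
+
+        .. code-block:: python
+
+            import pandas as pd
+            import xgboost as xgb
+
+            X = pd.DataFrame({"qid": [1, 1, 2, 2], "f0": [0.1, 0.3, 0.2, 0.4]})
+            y = [0, 1, 1, 0]
+            ranker = xgb.XGBRanker(n_estimators=2, tree_method="hist")
+            # `qid` is split out of `X` automatically; no explicit qid needed.
+            ranker.fit(X, y)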
""" - # check if group information is provided with config_context(verbosity=self.verbosity): - if group is None and qid is None: - raise ValueError("group or qid is required for ranking task") - - if eval_set is not None: - if eval_group is None and eval_qid is None: - raise ValueError( - "eval_group or eval_qid is required if eval_set is not None" - ) train_dmatrix, evals = _wrap_evaluation_matrices( missing=self.missing, X=X, @@ -2009,7 +2042,7 @@ def fit( base_margin_eval_set=base_margin_eval_set, eval_group=eval_group, eval_qid=eval_qid, - create_dmatrix=self._create_dmatrix, + create_dmatrix=self._create_ltr_dmatrix, enable_categorical=self.enable_categorical, feature_types=self.feature_types, ) @@ -2044,3 +2077,59 @@ def fit( self._set_evaluation_result(evals_result) return self + + def predict( + self, + X: ArrayLike, + output_margin: bool = False, + ntree_limit: Optional[int] = None, + validate_features: bool = True, + base_margin: Optional[ArrayLike] = None, + iteration_range: Optional[Tuple[int, int]] = None, + ) -> ArrayLike: + X, _ = _get_qid(X, None) + return super().predict( + X, + output_margin, + ntree_limit, + validate_features, + base_margin, + iteration_range, + ) + + def apply( + self, + X: ArrayLike, + ntree_limit: int = 0, + iteration_range: Optional[Tuple[int, int]] = None, + ) -> ArrayLike: + X, _ = _get_qid(X, None) + return super().apply(X, ntree_limit, iteration_range) + + def score(self, X: ArrayLike, y: ArrayLike) -> float: + """Evaluate score for data using the last evaluation metric. + + Parameters + ---------- + X : pd.DataFrame|cudf.DataFrame + Feature matrix. A DataFrame with a special `qid` column. + + y : + Labels + + Returns + ------- + score : + The result of the first evaluation metric for the ranker. + + """ + X, qid = _get_qid(X, None) + Xyq = DMatrix(X, y, qid=qid) + if callable(self.eval_metric): + metric = ltr_metric_decorator(self.eval_metric, self.n_jobs) + result_str = self.get_booster().eval_set([(Xyq, "eval")], feval=metric) + else: + result_str = self.get_booster().eval(Xyq) + + metric_score = _parse_eval_str(result_str) + return metric_score[-1][1] diff --git a/python-package/xgboost/testing/ranking.py b/python-package/xgboost/testing/ranking.py new file mode 100644 index 000000000000..fe4fc8404567 --- /dev/null +++ b/python-package/xgboost/testing/ranking.py @@ -0,0 +1,72 @@ +# pylint: disable=too-many-locals +"""Tests for learning to rank.""" +from types import ModuleType +from typing import Any + +import numpy as np +import pytest + +import xgboost as xgb +from xgboost import testing as tm + + +def run_ranking_qid_df(impl: ModuleType, tree_method: str) -> None: + """Test ranking with qid packed into X.""" + import scipy.sparse + from sklearn.metrics import mean_squared_error + from sklearn.model_selection import StratifiedGroupKFold, cross_val_score + + X, y, q, _ = tm.make_ltr(n_samples=128, n_features=2, n_query_groups=8, max_rel=3) + + # pack qid into x using dataframe + df = impl.DataFrame(X) + df["qid"] = q + ranker = xgb.XGBRanker(n_estimators=3, eval_metric="ndcg", tree_method=tree_method) + ranker.fit(df, y) + s = ranker.score(df, y) + assert s > 0.7 + + # works with validation datasets as well + valid_df = df.copy() + valid_df.iloc[0, 0] = 3.0 + ranker.fit(df, y, eval_set=[(valid_df, y)]) + + # same as passing qid directly + ranker = xgb.XGBRanker(n_estimators=3, eval_metric="ndcg", tree_method=tree_method) + ranker.fit(X, y, qid=q) + s1 = ranker.score(df, y) + assert np.isclose(s, s1) + + # Works with standard sklearn cv + if 
diff --git a/python-package/xgboost/testing/ranking.py b/python-package/xgboost/testing/ranking.py
new file mode 100644
index 000000000000..fe4fc8404567
--- /dev/null
+++ b/python-package/xgboost/testing/ranking.py
@@ -0,0 +1,72 @@
+# pylint: disable=too-many-locals
+"""Tests for learning to rank."""
+from types import ModuleType
+from typing import Any
+
+import numpy as np
+import pytest
+
+import xgboost as xgb
+from xgboost import testing as tm
+
+
+def run_ranking_qid_df(impl: ModuleType, tree_method: str) -> None:
+    """Test ranking with qid packed into X."""
+    import scipy.sparse
+    from sklearn.metrics import mean_squared_error
+    from sklearn.model_selection import StratifiedGroupKFold, cross_val_score
+
+    X, y, q, _ = tm.make_ltr(n_samples=128, n_features=2, n_query_groups=8, max_rel=3)
+
+    # pack qid into X using a dataframe
+    df = impl.DataFrame(X)
+    df["qid"] = q
+    ranker = xgb.XGBRanker(n_estimators=3, eval_metric="ndcg", tree_method=tree_method)
+    ranker.fit(df, y)
+    s = ranker.score(df, y)
+    assert s > 0.7
+
+    # works with validation datasets as well
+    valid_df = df.copy()
+    valid_df.iloc[0, 0] = 3.0
+    ranker.fit(df, y, eval_set=[(valid_df, y)])
+
+    # same as passing qid directly
+    ranker = xgb.XGBRanker(n_estimators=3, eval_metric="ndcg", tree_method=tree_method)
+    ranker.fit(X, y, qid=q)
+    s1 = ranker.score(df, y)
+    assert np.isclose(s, s1)
+
+    # works with the standard sklearn cv
+    if tree_method != "gpu_hist":
+        # we need cuML for this.
+        kfold = StratifiedGroupKFold(shuffle=False)
+        results = cross_val_score(ranker, df, y, cv=kfold, groups=df.qid)
+        assert len(results) == 5
+
+    # works with a custom metric
+    def neg_mse(*args: Any, **kwargs: Any) -> float:
+        return -float(mean_squared_error(*args, **kwargs))
+
+    ranker = xgb.XGBRanker(n_estimators=3, eval_metric=neg_mse, tree_method=tree_method)
+    ranker.fit(df, y, eval_set=[(valid_df, y)])
+    score = ranker.score(valid_df, y)
+    assert np.isclose(score, ranker.evals_result()["validation_0"]["neg_mse"][-1])
+
+    # works with sparse data
+    if tree_method != "gpu_hist":
+        # no sparse with cuDF
+        X_csr = scipy.sparse.csr_matrix(X)
+        df = impl.DataFrame.sparse.from_spmatrix(
+            X_csr, columns=[str(i) for i in range(X.shape[1])]
+        )
+        df["qid"] = q
+        ranker = xgb.XGBRanker(
+            n_estimators=3, eval_metric="ndcg", tree_method=tree_method
+        )
+        ranker.fit(df, y)
+        s2 = ranker.score(df, y)
+        assert np.isclose(s2, s)
+
+    with pytest.raises(ValueError, match="Either `group` or `qid`."):
+        ranker.fit(df, y, eval_set=[(X, y)])
diff --git a/tests/python-gpu/test_gpu_with_sklearn.py b/tests/python-gpu/test_gpu_with_sklearn.py
index 8ecb4bdc77cc..c9d3ab4ebff7 100644
--- a/tests/python-gpu/test_gpu_with_sklearn.py
+++ b/tests/python-gpu/test_gpu_with_sklearn.py
@@ -8,6 +8,7 @@

 import xgboost as xgb
 from xgboost import testing as tm
+from xgboost.testing.ranking import run_ranking_qid_df

 sys.path.append("tests/python")
 import test_with_sklearn as twskl  # noqa
@@ -153,3 +154,10 @@ def test_classififer():
     y *= 10
     with pytest.raises(ValueError, match=r"Invalid classes.*"):
         clf.fit(X, y)
+
+
+@pytest.mark.skipif(**tm.no_pandas())
+def test_ranking_qid_df():
+    import cudf
+
+    run_ranking_qid_df(cudf, "gpu_hist")
diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py
index bc7a3e94e437..baef690ee32e 100644
--- a/tests/python/test_with_sklearn.py
+++ b/tests/python/test_with_sklearn.py
@@ -11,6 +11,7 @@

 import xgboost as xgb
 from xgboost import testing as tm
+from xgboost.testing.ranking import run_ranking_qid_df
 from xgboost.testing.shared import get_feature_weights, validate_data_initialization
 from xgboost.testing.updater import get_basescore

@@ -180,6 +181,13 @@ def test_ranking_metric() -> None:
     assert results["validation_0"]["roc_auc_score"][-1] > 0.6


+@pytest.mark.skipif(**tm.no_pandas())
+def test_ranking_qid_df():
+    import pandas as pd
+
+    run_ranking_qid_df(pd, "hist")
+
+
 def test_stacking_regression():
     from sklearn.datasets import load_diabetes
     from sklearn.ensemble import RandomForestRegressor, StackingRegressor
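
Taken together, the changes make the standard sklearn cross-validation loop
work for ranking. A sketch of the end-to-end usage (the dataset here is
randomly generated for illustration; `StratifiedGroupKFold` requires
sklearn >= 1.0):

    import numpy as np
    import pandas as pd
    import xgboost as xgb
    from sklearn.model_selection import StratifiedGroupKFold, cross_val_score

    rng = np.random.default_rng(2023)
    n = 128
    df = pd.DataFrame({"f0": rng.random(n), "f1": rng.random(n)})
    # Pack the query ids into X as the special `qid` column and keep the
    # rows of each query contiguous.
    df["qid"] = np.sort(rng.integers(0, 8, size=n))
    y = rng.integers(0, 4, size=n)  # relevance labels in [0, 3]

    ranker = xgb.XGBRanker(n_estimators=3, eval_metric="ndcg", tree_method="hist")
    # Grouped splitting keeps every query inside a single fold; `score`
    # returns the ndcg of each fold.
    kfold = StratifiedGroupKFold(shuffle=False)
    scores = cross_val_score(ranker, df, y, cv=kfold, groups=df.qid)
    print(scores)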