From 4d20b56591841f481b3d699da3a04168bde875a8 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 22 Apr 2023 22:14:14 +0800 Subject: [PATCH] Fix return type. --- python-package/xgboost/testing/__init__.py | 1 + python-package/xgboost/testing/data.py | 20 +++++++++++--------- tests/python/test_ranking.py | 8 ++++---- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py index 5a23d859f5b2..6cccee4d3d0d 100644 --- a/python-package/xgboost/testing/__init__.py +++ b/python-package/xgboost/testing/__init__.py @@ -464,6 +464,7 @@ def make_position_biased_ltr( (X_train, y_train, qid_train), (X_test, y_test, qid_test), max_rel=max_rel ) clicks = simulate_clicks(data) + assert isinstance(clicks, ClickFold) return clicks diff --git a/python-package/xgboost/testing/data.py b/python-package/xgboost/testing/data.py index a9bb2674c2a0..1aca269f60b3 100644 --- a/python-package/xgboost/testing/data.py +++ b/python-package/xgboost/testing/data.py @@ -2,9 +2,8 @@ """Utilities for data generation.""" import os import zipfile -from collections import namedtuple from dataclasses import dataclass -from typing import Any, Generator, List, NamedTuple, Tuple, Union +from typing import Any, Generator, List, NamedTuple, Optional, Tuple, Union from urllib import request import numpy as np @@ -351,9 +350,8 @@ def get_mq2008( @dataclass class ClickFold: - """A structure containing information about generated user-click data. + """A structure containing information about generated user-click data.""" - """ X: sparse.csr_matrix y: npt.NDArray[np.int32] qid: npt.NDArray[np.int32] @@ -513,7 +511,7 @@ def simulate_one_fold( # pylint: disable=too-many-locals -def simulate_clicks(cv_data: RelDataCV) -> Tuple[ClickFold, ClickFold]: +def simulate_clicks(cv_data: RelDataCV) -> Tuple[ClickFold, Optional[ClickFold]]: """Simulate click data using position biased model (PBM).""" X, y, qid = list(zip(cv_data.train, cv_data.test)) @@ -544,10 +542,14 @@ def simulate_clicks(cv_data: RelDataCV) -> Tuple[ClickFold, ClickFold]: for i in range(2): assert (scores_check_1[i] == scores[i]).all() - train, test = [ - ClickFold(X_lst[i], y_lst[i], q_lst[i], s_lst[i], c_lst[i], p_lst[i]) - for i in range(2) - ] + if len(X_lst) == 1: + train = ClickFold(X_lst[0], y_lst[0], q_lst[0], s_lst[0], c_lst[0], p_lst[0]) + test = None + else: + train, test = ( + ClickFold(X_lst[i], y_lst[i], q_lst[i], s_lst[i], c_lst[i], p_lst[i]) + for i in range(len(X_lst)) + ) return train, test diff --git a/tests/python/test_ranking.py b/tests/python/test_ranking.py index f7816f98740c..7c30618af4de 100644 --- a/tests/python/test_ranking.py +++ b/tests/python/test_ranking.py @@ -114,10 +114,10 @@ def test_unbiased() -> None: Xe = csr_matrix(Xe) data = RelDataCV((X, y, q), (Xe, ye, qe), max_rel=5) - sim = simulate_clicks(data) - train = [pack[0] for pack in sim] - x, y, q, s, c, p = train - x, c, y, q = sort_ltr_samples(x, y, q, c, p) + train, _ = simulate_clicks(data) + x, c, y, q = sort_ltr_samples( + train.X, train.y, train.qid, train.click, train.pos + ) df: Optional[pd.DataFrame] = None class Position(xgboost.callback.TrainingCallback):