From aa6befaa71dbf8755337252cd6acaffafbb06539 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 22 Apr 2023 19:26:49 +0800 Subject: [PATCH] Use dataclass. --- demo/guide-python/learning_to_rank.py | 35 ++++++++++------------ python-package/xgboost/testing/data.py | 40 ++++++++++++++++++-------- 2 files changed, 43 insertions(+), 32 deletions(-) diff --git a/demo/guide-python/learning_to_rank.py b/demo/guide-python/learning_to_rank.py index 61de1e9e79cd..f4e78e5e10fd 100644 --- a/demo/guide-python/learning_to_rank.py +++ b/demo/guide-python/learning_to_rank.py @@ -121,31 +121,26 @@ def ranking_demo(args: argparse.Namespace) -> None: def click_data_demo(args: argparse.Namespace) -> None: """Demonstration for learning to rank with click data.""" data = load_mlsr_10k(args.data, args.cache) - folds = simulate_clicks(data) + train, test = simulate_clicks(data) - train = [pack[0] for pack in folds] - test = [pack[1] for pack in folds] - - X_train, y_train, qid_train, scores_train, clicks_train, position_train = train - assert X_train.shape[0] == clicks_train.size - X_test, y_test, qid_test, scores_test, clicks_test, position_test = test - assert X_test.shape[0] == clicks_test.size - assert scores_test.dtype == np.float32 - assert clicks_test.dtype == np.int32 + assert train.X.shape[0] == train.click.size + assert test.X.shape[0] == test.click.size + assert test.score.dtype == np.float32 + assert test.click.dtype == np.int32 X_train, clicks_train, y_train, qid_train = sort_ltr_samples( - X_train, - y_train, - qid_train, - clicks_train, - position_train, + train.X, + train.y, + train.qid, + train.click, + train.pos, ) X_test, clicks_test, y_test, qid_test = sort_ltr_samples( - X_test, - y_test, - qid_test, - clicks_test, - position_test, + test.X, + test.y, + test.qid, + test.click, + test.pos, ) class ShowPosition(xgb.callback.TrainingCallback): diff --git a/python-package/xgboost/testing/data.py b/python-package/xgboost/testing/data.py index 7da03caa8d7b..a9bb2674c2a0 100644 --- a/python-package/xgboost/testing/data.py +++ b/python-package/xgboost/testing/data.py @@ -3,6 +3,7 @@ import os import zipfile from collections import namedtuple +from dataclasses import dataclass from typing import Any, Generator, List, NamedTuple, Tuple, Union from urllib import request @@ -345,10 +346,22 @@ def get_mq2008( ) -ClickFold = namedtuple("ClickFold", ("X", "y", "q", "s", "c", "p")) RelData = Tuple[sparse.csr_matrix, npt.NDArray[np.int32], npt.NDArray[np.int32]] +@dataclass +class ClickFold: + """A structure containing information about generated user-click data. + + """ + X: sparse.csr_matrix + y: npt.NDArray[np.int32] + qid: npt.NDArray[np.int32] + score: npt.NDArray[np.float32] + click: npt.NDArray[np.int32] + pos: npt.NDArray[np.int64] + + class RelDataCV(NamedTuple): """Simple data struct for train-test split.""" @@ -499,7 +512,8 @@ def simulate_one_fold( return ClickFold(X_fold, y_fold, qid_fold, scores_fold, clicks, position) -def simulate_clicks(cv_data: RelDataCV) -> ClickFold: # pylint: disable=too-many-locals +# pylint: disable=too-many-locals +def simulate_clicks(cv_data: RelDataCV) -> Tuple[ClickFold, ClickFold]: """Simulate click data using position biased model (PBM).""" X, y, qid = list(zip(cv_data.train, cv_data.test)) @@ -518,21 +532,23 @@ def simulate_clicks(cv_data: RelDataCV) -> ClickFold: # pylint: disable=too-man X_lst, y_lst, q_lst, s_lst, c_lst, p_lst = [], [], [], [], [], [] for i in range(indptr.size - 1): - x_, y_, q_, s_, c_, p_ = simulate_one_fold((X[i], y[i], qid[i]), scores[i]) - X_lst.append(x_) - y_lst.append(y_) - q_lst.append(q_) - s_lst.append(s_) - c_lst.append(c_) - p_lst.append(p_) + fold = simulate_one_fold((X[i], y[i], qid[i]), scores[i]) + X_lst.append(fold.X) + y_lst.append(fold.y) + q_lst.append(fold.qid) + s_lst.append(fold.score) + c_lst.append(fold.click) + p_lst.append(fold.pos) scores_check_1 = [s_lst[i] for i in range(indptr.size - 1)] for i in range(2): assert (scores_check_1[i] == scores[i]).all() - data = ClickFold(X_lst, y_lst, q_lst, s_lst, c_lst, p_lst) - - return data + train, test = [ + ClickFold(X_lst[i], y_lst[i], q_lst[i], s_lst[i], c_lst[i], p_lst[i]) + for i in range(2) + ] + return train, test def sort_ltr_samples(