Skip to content

Commit

Permalink
Use dataclass.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Apr 22, 2023
1 parent 02a29b1 commit aa6befa
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 32 deletions.
35 changes: 15 additions & 20 deletions demo/guide-python/learning_to_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,31 +121,26 @@ def ranking_demo(args: argparse.Namespace) -> None:
def click_data_demo(args: argparse.Namespace) -> None:
"""Demonstration for learning to rank with click data."""
data = load_mlsr_10k(args.data, args.cache)
folds = simulate_clicks(data)
train, test = simulate_clicks(data)

train = [pack[0] for pack in folds]
test = [pack[1] for pack in folds]

X_train, y_train, qid_train, scores_train, clicks_train, position_train = train
assert X_train.shape[0] == clicks_train.size
X_test, y_test, qid_test, scores_test, clicks_test, position_test = test
assert X_test.shape[0] == clicks_test.size
assert scores_test.dtype == np.float32
assert clicks_test.dtype == np.int32
assert train.X.shape[0] == train.click.size
assert test.X.shape[0] == test.click.size
assert test.score.dtype == np.float32
assert test.click.dtype == np.int32

X_train, clicks_train, y_train, qid_train = sort_ltr_samples(
X_train,
y_train,
qid_train,
clicks_train,
position_train,
train.X,
train.y,
train.qid,
train.click,
train.pos,
)
X_test, clicks_test, y_test, qid_test = sort_ltr_samples(
X_test,
y_test,
qid_test,
clicks_test,
position_test,
test.X,
test.y,
test.qid,
test.click,
test.pos,
)

class ShowPosition(xgb.callback.TrainingCallback):
Expand Down
40 changes: 28 additions & 12 deletions python-package/xgboost/testing/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import zipfile
from collections import namedtuple
from dataclasses import dataclass
from typing import Any, Generator, List, NamedTuple, Tuple, Union
from urllib import request

Expand Down Expand Up @@ -345,10 +346,22 @@ def get_mq2008(
)


ClickFold = namedtuple("ClickFold", ("X", "y", "q", "s", "c", "p"))
RelData = Tuple[sparse.csr_matrix, npt.NDArray[np.int32], npt.NDArray[np.int32]]


@dataclass
class ClickFold:
"""A structure containing information about generated user-click data.
"""
X: sparse.csr_matrix
y: npt.NDArray[np.int32]
qid: npt.NDArray[np.int32]
score: npt.NDArray[np.float32]
click: npt.NDArray[np.int32]
pos: npt.NDArray[np.int64]


class RelDataCV(NamedTuple):
"""Simple data struct for train-test split."""

Expand Down Expand Up @@ -499,7 +512,8 @@ def simulate_one_fold(
return ClickFold(X_fold, y_fold, qid_fold, scores_fold, clicks, position)


def simulate_clicks(cv_data: RelDataCV) -> ClickFold: # pylint: disable=too-many-locals
# pylint: disable=too-many-locals
def simulate_clicks(cv_data: RelDataCV) -> Tuple[ClickFold, ClickFold]:
"""Simulate click data using position biased model (PBM)."""
X, y, qid = list(zip(cv_data.train, cv_data.test))

Expand All @@ -518,21 +532,23 @@ def simulate_clicks(cv_data: RelDataCV) -> ClickFold: # pylint: disable=too-man

X_lst, y_lst, q_lst, s_lst, c_lst, p_lst = [], [], [], [], [], []
for i in range(indptr.size - 1):
x_, y_, q_, s_, c_, p_ = simulate_one_fold((X[i], y[i], qid[i]), scores[i])
X_lst.append(x_)
y_lst.append(y_)
q_lst.append(q_)
s_lst.append(s_)
c_lst.append(c_)
p_lst.append(p_)
fold = simulate_one_fold((X[i], y[i], qid[i]), scores[i])
X_lst.append(fold.X)
y_lst.append(fold.y)
q_lst.append(fold.qid)
s_lst.append(fold.score)
c_lst.append(fold.click)
p_lst.append(fold.pos)

scores_check_1 = [s_lst[i] for i in range(indptr.size - 1)]
for i in range(2):
assert (scores_check_1[i] == scores[i]).all()

data = ClickFold(X_lst, y_lst, q_lst, s_lst, c_lst, p_lst)

return data
train, test = [
ClickFold(X_lst[i], y_lst[i], q_lst[i], s_lst[i], c_lst[i], p_lst[i])
for i in range(2)
]
return train, test


def sort_ltr_samples(
Expand Down

0 comments on commit aa6befa

Please sign in to comment.