Skip to content

Commit

Permalink
Fix return type.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Apr 22, 2023
1 parent aa6befa commit 4d20b56
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 13 deletions.
1 change: 1 addition & 0 deletions python-package/xgboost/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ def make_position_biased_ltr(
(X_train, y_train, qid_train), (X_test, y_test, qid_test), max_rel=max_rel
)
clicks = simulate_clicks(data)
assert isinstance(clicks, ClickFold)
return clicks


Expand Down
20 changes: 11 additions & 9 deletions python-package/xgboost/testing/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
"""Utilities for data generation."""
import os
import zipfile
from collections import namedtuple
from dataclasses import dataclass
from typing import Any, Generator, List, NamedTuple, Tuple, Union
from typing import Any, Generator, List, NamedTuple, Optional, Tuple, Union
from urllib import request

import numpy as np
Expand Down Expand Up @@ -351,9 +350,8 @@ def get_mq2008(

@dataclass
class ClickFold:
"""A structure containing information about generated user-click data.
"""A structure containing information about generated user-click data."""

"""
X: sparse.csr_matrix
y: npt.NDArray[np.int32]
qid: npt.NDArray[np.int32]
Expand Down Expand Up @@ -513,7 +511,7 @@ def simulate_one_fold(


# pylint: disable=too-many-locals
def simulate_clicks(cv_data: RelDataCV) -> Tuple[ClickFold, ClickFold]:
def simulate_clicks(cv_data: RelDataCV) -> Tuple[ClickFold, Optional[ClickFold]]:
"""Simulate click data using position biased model (PBM)."""
X, y, qid = list(zip(cv_data.train, cv_data.test))

Expand Down Expand Up @@ -544,10 +542,14 @@ def simulate_clicks(cv_data: RelDataCV) -> Tuple[ClickFold, ClickFold]:
for i in range(2):
assert (scores_check_1[i] == scores[i]).all()

train, test = [
ClickFold(X_lst[i], y_lst[i], q_lst[i], s_lst[i], c_lst[i], p_lst[i])
for i in range(2)
]
if len(X_lst) == 1:
train = ClickFold(X_lst[0], y_lst[0], q_lst[0], s_lst[0], c_lst[0], p_lst[0])
test = None
else:
train, test = (
ClickFold(X_lst[i], y_lst[i], q_lst[i], s_lst[i], c_lst[i], p_lst[i])
for i in range(len(X_lst))
)
return train, test


Expand Down
8 changes: 4 additions & 4 deletions tests/python/test_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,10 @@ def test_unbiased() -> None:
Xe = csr_matrix(Xe)
data = RelDataCV((X, y, q), (Xe, ye, qe), max_rel=5)

sim = simulate_clicks(data)
train = [pack[0] for pack in sim]
x, y, q, s, c, p = train
x, c, y, q = sort_ltr_samples(x, y, q, c, p)
train, _ = simulate_clicks(data)
x, c, y, q = sort_ltr_samples(
train.X, train.y, train.qid, train.click, train.pos
)
df: Optional[pd.DataFrame] = None

class Position(xgboost.callback.TrainingCallback):
Expand Down

0 comments on commit 4d20b56

Please sign in to comment.