From 39d3a2d98f6366c39f40bc946be82cabd8a8c25c Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 2 Mar 2023 04:31:06 +0800 Subject: [PATCH] Update. --- tests/python/test_with_sklearn.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index a026c55cc3c4..3cbeb2e10c7d 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -184,9 +184,9 @@ def test_ranking_metric() -> None: def test_ranking_qid_df(): import pandas as pd import scipy.sparse - from sklearn.model_selection import cross_val_score + from sklearn.model_selection import StratifiedGroupKFold, cross_val_score - X, y, q, w = tm.make_ltr(n_samples=128, n_features=2, n_query_groups=3, max_rel=3) + X, y, q, w = tm.make_ltr(n_samples=128, n_features=2, n_query_groups=8, max_rel=3) # pack qid into x using dataframe df = pd.DataFrame(X) @@ -207,6 +207,11 @@ def test_ranking_qid_df(): s1 = ranker.score(df, y) assert np.isclose(s, s1) + # Works with standard sklearn cv + kfold = StratifiedGroupKFold(shuffle=False) + results = cross_val_score(ranker, df, y, cv=kfold, groups=df.qid) + assert len(results) == 5 + # Works with sparse data X_csr = scipy.sparse.csr_matrix(X) df = pd.DataFrame.sparse.from_spmatrix( @@ -218,10 +223,6 @@ def test_ranking_qid_df(): s2 = ranker.score(df, y) assert np.isclose(s2, s) - # Works with standard sklearn cv - results = cross_val_score(ranker, df, y) - assert len(results) == 5 - with pytest.raises(ValueError, match="Either `group` or `qid`."): ranker.fit(df, y, eval_set=[(X, y)])