diff --git a/tests/test_distributed/test_with_spark/test_spark_local.py b/tests/test_distributed/test_with_spark/test_spark_local.py index 01f3bd12e626..3664f06e4706 100644 --- a/tests/test_distributed/test_with_spark/test_spark_local.py +++ b/tests/test_distributed/test_with_spark/test_spark_local.py @@ -1327,32 +1327,37 @@ def test_unsupported_params(self): SparkXGBClassifier(evals_result={}) -class XgboostRankerLocalTest(SparkTestCase): - def setUp(self): - self.session.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "8") - self.ranker_df_train = self.session.createDataFrame( - [ - (Vectors.dense(1.0, 2.0, 3.0), 0, 0), - (Vectors.dense(4.0, 5.0, 6.0), 1, 0), - (Vectors.dense(9.0, 4.0, 8.0), 2, 0), - (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1), - (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1), - (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1), - ], - ["features", "label", "qid"], - ) - self.ranker_df_test = self.session.createDataFrame( - [ - (Vectors.dense(1.5, 2.0, 3.0), 0, -1.87988), - (Vectors.dense(4.5, 5.0, 6.0), 0, 0.29556), - (Vectors.dense(9.0, 4.5, 8.0), 0, 2.36570), - (Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, -1.87988), - (Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, -0.30612), - (Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, 2.44826), - ], - ["features", "qid", "expected_prediction"], - ) - self.ranker_df_train_1 = self.session.createDataFrame( +LTRData = namedtuple( + "LTRData", ("df_train", "df_test", "df_train_1") +) + + +@pytest.fixture +def ltr_data(spark: SparkSession) -> Generator[LTRData, None, None]: + spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "8") + ranker_df_train = spark.createDataFrame( + [ + (Vectors.dense(1.0, 2.0, 3.0), 0, 0), + (Vectors.dense(4.0, 5.0, 6.0), 1, 0), + (Vectors.dense(9.0, 4.0, 8.0), 2, 0), + (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1), + (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1), + (Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1), + ], + ["features", "label", "qid"], + ) + ranker_df_test = spark.createDataFrame( + [ + (Vectors.dense(1.5, 2.0, 3.0), 0, -1.75218), + (Vectors.dense(4.5, 5.0, 6.0), 0, -0.34192949533462524), + (Vectors.dense(9.0, 4.5, 8.0), 0, 1.7251298427581787), + (Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, -1.7521828413009644), + (Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, -1.0988065004348755), + (Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, 1.632674217224121), + ], + ["features", "qid", "expected_prediction"], + ) + ranker_df_train_1 = spark.createDataFrame( [ (Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 9), (Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 9), @@ -1369,19 +1374,21 @@ def setUp(self): ] * 4, ["features", "label", "qid"], - ) + ) + yield LTRData(ranker_df_train, ranker_df_test, ranker_df_train_1) - def test_ranker(self): + +class TestPySparkLocalLETOR: + def test_ranker(self, ltr_data: LTRData) -> None: ranker = SparkXGBRanker(qid_col="qid", objective="rank:pairwise") assert ranker.getOrDefault(ranker.objective) == "rank:pairwise" - model = ranker.fit(self.ranker_df_train) - pred_result = model.transform(self.ranker_df_test).collect() - + model = ranker.fit(ltr_data.df_train) + pred_result = model.transform(ltr_data.df_test).collect() for row in pred_result: assert np.isclose(row.prediction, row.expected_prediction, rtol=1e-3) - def test_ranker_qid_sorted(self): + def test_ranker_qid_sorted(self, ltr_data: LTRData) -> None: ranker = SparkXGBRanker(qid_col="qid", num_workers=4) assert ranker.getOrDefault(ranker.objective) == "rank:ndcg" - model = ranker.fit(self.ranker_df_train_1) - model.transform(self.ranker_df_test).collect() + model = ranker.fit(ltr_data.df_train_1) + model.transform(ltr_data.df_test).collect()