Skip to content

Commit

Permalink
Convert pyspark tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Feb 28, 2023
1 parent f13bde8 commit 40678c6
Showing 1 changed file with 41 additions and 34 deletions.
75 changes: 41 additions & 34 deletions tests/test_distributed/test_with_spark/test_spark_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -1327,32 +1327,37 @@ def test_unsupported_params(self):
SparkXGBClassifier(evals_result={})


class XgboostRankerLocalTest(SparkTestCase):
def setUp(self):
self.session.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "8")
self.ranker_df_train = self.session.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0, 0),
(Vectors.dense(4.0, 5.0, 6.0), 1, 0),
(Vectors.dense(9.0, 4.0, 8.0), 2, 0),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1),
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1),
],
["features", "label", "qid"],
)
self.ranker_df_test = self.session.createDataFrame(
[
(Vectors.dense(1.5, 2.0, 3.0), 0, -1.87988),
(Vectors.dense(4.5, 5.0, 6.0), 0, 0.29556),
(Vectors.dense(9.0, 4.5, 8.0), 0, 2.36570),
(Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, -1.87988),
(Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, -0.30612),
(Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, 2.44826),
],
["features", "qid", "expected_prediction"],
)
self.ranker_df_train_1 = self.session.createDataFrame(
# Lightweight record bundling the three learning-to-rank DataFrames the
# LETOR tests need: a training set, a test set carrying expected
# predictions, and an alternate (qid-sorted, multi-partition) training set.
LTRData = namedtuple("LTRData", ["df_train", "df_test", "df_train_1"])


@pytest.fixture
def ltr_data(spark: SparkSession) -> Generator[LTRData, None, None]:
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "8")
ranker_df_train = spark.createDataFrame(
[
(Vectors.dense(1.0, 2.0, 3.0), 0, 0),
(Vectors.dense(4.0, 5.0, 6.0), 1, 0),
(Vectors.dense(9.0, 4.0, 8.0), 2, 0),
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 1),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 1),
(Vectors.sparse(3, {1: 8.0, 2: 9.5}), 2, 1),
],
["features", "label", "qid"],
)
ranker_df_test = spark.createDataFrame(
[
(Vectors.dense(1.5, 2.0, 3.0), 0, -1.75218),
(Vectors.dense(4.5, 5.0, 6.0), 0, -0.34192949533462524),
(Vectors.dense(9.0, 4.5, 8.0), 0, 1.7251298427581787),
(Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, -1.7521828413009644),
(Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, -1.0988065004348755),
(Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, 1.632674217224121),
],
["features", "qid", "expected_prediction"],
)
ranker_df_train_1 = spark.createDataFrame(
[
(Vectors.sparse(3, {1: 1.0, 2: 5.5}), 0, 9),
(Vectors.sparse(3, {1: 6.0, 2: 7.5}), 1, 9),
Expand All @@ -1369,19 +1374,21 @@ def setUp(self):
]
* 4,
["features", "label", "qid"],
)
)
yield LTRData(ranker_df_train, ranker_df_test, ranker_df_train_1)

class TestPySparkLocalLETOR:
    """Learning-to-rank (LETOR) tests for ``SparkXGBRanker`` on local Spark."""

    def test_ranker(self, ltr_data: LTRData) -> None:
        """Fit a pairwise ranker and compare predictions to stored values."""
        ranker = SparkXGBRanker(qid_col="qid", objective="rank:pairwise")
        # The explicitly requested objective must survive parameter validation.
        assert ranker.getOrDefault(ranker.objective) == "rank:pairwise"
        model = ranker.fit(ltr_data.df_train)
        rows = model.transform(ltr_data.df_test).collect()
        for row in rows:
            # rtol=1e-3 absorbs small float drift across platforms/builds.
            assert np.isclose(row.prediction, row.expected_prediction, rtol=1e-3)

    def test_ranker_qid_sorted(self, ltr_data: LTRData) -> None:
        """Default objective is rank:ndcg; fitting qid-sorted data succeeds."""
        ranker = SparkXGBRanker(qid_col="qid", num_workers=4)
        assert ranker.getOrDefault(ranker.objective) == "rank:ndcg"
        model = ranker.fit(ltr_data.df_train_1)
        # Smoke test only: transform/collect must complete without raising.
        model.transform(ltr_data.df_test).collect()

0 comments on commit 40678c6

Please sign in to comment.