Skip to content

Commit

Permalink
tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed May 30, 2023
1 parent 36b2723 commit 0da99bd
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
1 change: 0 additions & 1 deletion src/metric/rank_metric.cc
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,6 @@ class EvalPrecision : public EvalRankWithCache<ltr::MAPCache> {
n_hits += g_label(g_rank[i]) * weight[g];
}
}

pre[g] = n_hits / n;
});

Expand Down
24 changes: 13 additions & 11 deletions tests/cpp/metric/test_rank_metric.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,26 @@ namespace xgboost::metric {

inline void VerifyPrecision(DataSplitMode data_split_mode = DataSplitMode::kRow) {
auto ctx = xgboost::CreateEmptyGenericParam(GPUIDX);
xgboost::Metric* metric = xgboost::Metric::Create("pre", &ctx);
std::unique_ptr<xgboost::Metric> metric{Metric::Create("pre", &ctx)};
ASSERT_STREQ(metric->Name(), "pre");
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}, {}, {}, data_split_mode), 0.5, 1e-7);
EXPECT_NEAR(GetMetricEval(metric.get(), {0, 1}, {0, 1}, {}, {}, data_split_mode), 0.5, 1e-7);
EXPECT_NEAR(
GetMetricEval(metric, {0.1f, 0.9f, 0.1f, 0.9f}, {0, 0, 1, 1}, {}, {}, data_split_mode), 0.5,
1e-7);
GetMetricEval(metric.get(), {0.1f, 0.9f, 0.1f, 0.9f}, {0, 0, 1, 1}, {}, {}, data_split_mode),
0.5, 1e-7);

delete metric;
metric = xgboost::Metric::Create("pre@2", &ctx);
metric.reset(xgboost::Metric::Create("pre@2", &ctx));
ASSERT_STREQ(metric->Name(), "pre@2");
EXPECT_NEAR(GetMetricEval(metric, {0, 1}, {0, 1}, {}, {}, data_split_mode), 0.5f, 1e-7);
EXPECT_NEAR(GetMetricEval(metric.get(), {0, 1}, {0, 1}, {}, {}, data_split_mode), 0.5f, 1e-7);
EXPECT_NEAR(
GetMetricEval(metric, {0.1f, 0.9f, 0.1f, 0.9f}, {0, 0, 1, 1}, {}, {}, data_split_mode), 0.5f,
0.001f);
GetMetricEval(metric.get(), {0.1f, 0.9f, 0.1f, 0.9f}, {0, 0, 1, 1}, {}, {}, data_split_mode),
0.5f, 0.001f);

EXPECT_ANY_THROW(GetMetricEval(metric, {0, 1}, {}, {}, {}, data_split_mode));
EXPECT_ANY_THROW(GetMetricEval(metric.get(), {0, 1}, {}, {}, {}, data_split_mode));

delete metric;
metric.reset(xgboost::Metric::Create("pre@4", &ctx));
EXPECT_NEAR(GetMetricEval(metric.get(), {0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f},
{0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 1.0f}, {}, {}, data_split_mode),
0.5f, 1e-7);
}

inline void VerifyNDCG(DataSplitMode data_split_mode = DataSplitMode::kRow) {
Expand Down

0 comments on commit 0da99bd

Please sign in to comment.