Implement NDCG cache. (#8893)
trivialfis authored Mar 13, 2023
1 parent 9bade72 commit 8be6095
Showing 7 changed files with 798 additions and 11 deletions.
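
For reference (a note added for review, not part of the commit), the cached quantities follow the standard NDCG definitions, matching the CalcDCGGain, CalcDCGDiscount, and CalcInvIDCG helpers used below:

    DCG@k   = \sum_{i=1}^{k} gain(l_i) / \log_2(i + 1)
    gain(l) = 2^l - 1 if ndcg_exp_gain is set, else gain(l) = l
    NDCG@k  = DCG@k / IDCG@k

where IDCG@k is the DCG of the labels sorted in descending order. The cache stores the discount table 1/\log_2(i + 1) up to the largest group size, plus the per-query inverse IDCG, so metric evaluation reduces to a dot product and a multiply.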
src/common/ranking_utils.cc (94 additions, 4 deletions)
@@ -6,9 +6,7 @@
#include <algorithm> // for copy_n, max, min, none_of, all_of
#include <cstddef> // for size_t
#include <cstdio> // for sscanf
-#include <exception>  // for exception
#include <functional>  // for greater
-#include <iterator>   // for reverse_iterator
#include <string> // for char_traits, string

#include "algorithm.h" // for ArgSort
@@ -18,10 +16,102 @@
#include "xgboost/base.h" // for bst_group_t
#include "xgboost/context.h" // for Context
#include "xgboost/data.h" // for MetaInfo
#include "xgboost/linalg.h" // for All, TensorView, Range, Tensor, Vector
#include "xgboost/logging.h" // for Error, LogCheck_EQ, CHECK_EQ
#include "xgboost/linalg.h" // for All, TensorView, Range
#include "xgboost/logging.h" // for CHECK_EQ

namespace xgboost::ltr {
void RankingCache::InitOnCPU(Context const* ctx, MetaInfo const& info) {
if (info.group_ptr_.empty()) {
group_ptr_.Resize(2, 0);
group_ptr_.HostVector()[1] = info.num_row_;
} else {
group_ptr_.HostVector() = info.group_ptr_;
}

auto const& gptr = group_ptr_.ConstHostVector();
for (std::size_t i = 1; i < gptr.size(); ++i) {
std::size_t n = gptr[i] - gptr[i - 1];
max_group_size_ = std::max(max_group_size_, n);
}

double sum_weights = 0;
auto n_groups = Groups();
auto weight = common::MakeOptionalWeights(ctx, info.weights_);
for (bst_omp_uint k = 0; k < n_groups; ++k) {
sum_weights += weight[k];
}
weight_norm_ = static_cast<double>(n_groups) / sum_weights;
}
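
// Reviewer note (not in the commit): weight_norm_ rescales query weights so that
// the rescaled weights sum to n_groups across groups. With unit weights,
// sum_weights == n_groups and weight_norm_ == 1.0, so a weighted average over
// query groups reduces to the plain mean.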

common::Span<std::size_t const> RankingCache::MakeRankOnCPU(Context const* ctx,
common::Span<float const> predt) {
auto gptr = this->DataGroupPtr(ctx);
auto rank = this->sorted_idx_cache_.HostSpan();
CHECK_EQ(rank.size(), predt.size());

common::ParallelFor(this->Groups(), ctx->Threads(), [&](auto g) {
auto cnt = gptr[g + 1] - gptr[g];
auto g_predt = predt.subspan(gptr[g], cnt);
auto g_rank = rank.subspan(gptr[g], cnt);
auto sorted_idx = common::ArgSort<std::size_t>(
ctx, g_predt.data(), g_predt.data() + g_predt.size(), std::greater<>{});
CHECK_EQ(g_rank.size(), sorted_idx.size());
std::copy_n(sorted_idx.data(), sorted_idx.size(), g_rank.data());
});

return rank;
}
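
// Illustration (not in the source): for predictions {0.1, 0.9, 0.4} within one group,
// ArgSort with std::greater<> returns {1, 2, 0}. g_rank[j] therefore holds the
// within-group index of the j-th highest prediction (a descending argsort), not the
// rank of document j.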

#if !defined(XGBOOST_USE_CUDA)
void RankingCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUSupport(); }
common::Span<std::size_t const> RankingCache::MakeRankOnCUDA(Context const*,
common::Span<float const>) {
common::AssertGPUSupport();
return {};
}
#endif  // !defined(XGBOOST_USE_CUDA)

void NDCGCache::InitOnCPU(Context const* ctx, MetaInfo const& info) {
auto const h_group_ptr = this->DataGroupPtr(ctx);

discounts_.Resize(MaxGroupSize(), 0);
auto& h_discounts = discounts_.HostVector();
for (std::size_t i = 0; i < MaxGroupSize(); ++i) {
h_discounts[i] = CalcDCGDiscount(i);
}

auto n_groups = h_group_ptr.size() - 1;
auto h_labels = info.labels.HostView().Slice(linalg::All(), 0);

CheckNDCGLabels(this->Param(), h_labels,
[](auto beg, auto end, auto op) { return std::none_of(beg, end, op); });

inv_idcg_.Reshape(n_groups);
auto h_inv_idcg = inv_idcg_.HostView();
std::size_t topk = this->Param().TopK();
auto const exp_gain = this->Param().ndcg_exp_gain;

common::ParallelFor(n_groups, ctx->Threads(), [&](auto g) {
auto g_labels = h_labels.Slice(linalg::Range(h_group_ptr[g], h_group_ptr[g + 1]));
auto sorted_idx = common::ArgSort<std::size_t>(ctx, linalg::cbegin(g_labels),
linalg::cend(g_labels), std::greater<>{});

double idcg{0.0};
for (std::size_t i = 0; i < std::min(g_labels.Size(), topk); ++i) {
if (exp_gain) {
idcg += h_discounts[i] * CalcDCGGain(g_labels(sorted_idx[i]));
} else {
idcg += h_discounts[i] * g_labels(sorted_idx[i]);
}
}
h_inv_idcg(g) = CalcInvIDCG(idcg);
});
}
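
// Worked example (illustrative): labels {3, 2, 3, 0}, exponential gain, topk >= 4:
//   sorted labels    : 3, 3, 2, 0
//   gains (2^l - 1)  : 7, 7, 3, 0
//   discounts        : 1/log2(2), 1/log2(3), 1/log2(4), 1/log2(5)
//   IDCG ≈ 7*1.0 + 7*0.6309 + 3*0.5 + 0*0.4307 ≈ 12.92, so h_inv_idcg(g) ≈ 0.0774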

#if !defined(XGBOOST_USE_CUDA)
void NDCGCache::InitOnCUDA(Context const*, MetaInfo const&) { common::AssertGPUSupport(); }
#endif // !defined(XGBOOST_USE_CUDA)

DMLC_REGISTER_PARAMETER(LambdaRankParam);

std::string ParseMetricName(StringView name, StringView param, position_t* topn, bool* minus) {
src/common/ranking_utils.cu (207 additions, 0 deletions)
@@ -0,0 +1,207 @@
/**
* Copyright 2023 by XGBoost Contributors
*/
#include <thrust/functional.h> // for maximum
#include <thrust/iterator/counting_iterator.h> // for make_counting_iterator
#include <thrust/logical.h> // for none_of, all_of
#include <thrust/pair.h> // for pair, make_pair
#include <thrust/reduce.h> // for reduce
#include <thrust/scan.h> // for inclusive_scan

#include <cstddef> // for size_t

#include "algorithm.cuh" // for SegmentedArgSort
#include "cuda_context.cuh" // for CUDAContext
#include "device_helpers.cuh" // for MakeTransformIterator, LaunchN
#include "optional_weight.h" // for MakeOptionalWeights, OptionalWeights
#include "ranking_utils.cuh" // for ThreadsForMean
#include "ranking_utils.h"
#include "threading_utils.cuh" // for SegmentedTrapezoidThreads
#include "xgboost/base.h" // for XGBOOST_DEVICE, bst_group_t
#include "xgboost/context.h" // for Context
#include "xgboost/linalg.h" // for VectorView, All, Range
#include "xgboost/logging.h" // for CHECK
#include "xgboost/span.h" // for Span

namespace xgboost::ltr {
namespace cuda_impl {
void CalcQueriesDCG(Context const* ctx, linalg::VectorView<float const> d_labels,
common::Span<std::size_t const> d_sorted_idx, bool exp_gain,
common::Span<bst_group_t const> d_group_ptr, std::size_t k,
linalg::VectorView<double> out_dcg) {
CHECK_EQ(d_group_ptr.size() - 1, out_dcg.Size());
using IdxGroup = thrust::pair<std::size_t, std::size_t>;
auto group_it = dh::MakeTransformIterator<IdxGroup>(
thrust::make_counting_iterator(0ull), [=] XGBOOST_DEVICE(std::size_t idx) {
return thrust::make_pair(idx, dh::SegmentId(d_group_ptr, idx)); // NOLINT
});
auto value_it = dh::MakeTransformIterator<double>(
group_it,
[exp_gain, d_labels, d_group_ptr, k,
d_sorted_idx] XGBOOST_DEVICE(IdxGroup const& l) -> double {
auto g_begin = d_group_ptr[l.second];
auto g_size = d_group_ptr[l.second + 1] - g_begin;

auto idx_in_group = l.first - g_begin;
if (idx_in_group >= k) {
return 0.0;
}
double gain{0.0};
auto g_sorted_idx = d_sorted_idx.subspan(g_begin, g_size);
auto g_labels = d_labels.Slice(linalg::Range(g_begin, g_begin + g_size));

if (exp_gain) {
gain = ltr::CalcDCGGain(g_labels(g_sorted_idx[idx_in_group]));
} else {
gain = g_labels(g_sorted_idx[idx_in_group]);
}
double discount = CalcDCGDiscount(idx_in_group);
return gain * discount;
});

CHECK(out_dcg.Contiguous());
std::size_t bytes;
cub::DeviceSegmentedReduce::Sum(nullptr, bytes, value_it, out_dcg.Values().data(),
d_group_ptr.size() - 1, d_group_ptr.data(),
d_group_ptr.data() + 1, ctx->CUDACtx()->Stream());
dh::TemporaryArray<char> temp(bytes);
cub::DeviceSegmentedReduce::Sum(temp.data().get(), bytes, value_it, out_dcg.Values().data(),
d_group_ptr.size() - 1, d_group_ptr.data(),
d_group_ptr.data() + 1, ctx->CUDACtx()->Stream());
}
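
// Note on the paired cub::DeviceSegmentedReduce::Sum calls above: passing nullptr as
// the temporary-storage pointer is CUB's size-query convention. The first call only
// writes the required byte count into bytes; the second call performs the actual
// per-group sum of the gain * discount contributions produced by value_it.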

void CalcQueriesInvIDCG(Context const* ctx, linalg::VectorView<float const> d_labels,
common::Span<bst_group_t const> d_group_ptr,
linalg::VectorView<double> out_inv_IDCG, ltr::LambdaRankParam const& p) {
CHECK_GE(d_group_ptr.size(), 2ul);
size_t n_groups = d_group_ptr.size() - 1;
CHECK_EQ(out_inv_IDCG.Size(), n_groups);
dh::device_vector<std::size_t> sorted_idx(d_labels.Size());
auto d_sorted_idx = dh::ToSpan(sorted_idx);
common::SegmentedArgSort<false, true>(ctx, d_labels.Values(), d_group_ptr, d_sorted_idx);
CalcQueriesDCG(ctx, d_labels, d_sorted_idx, p.ndcg_exp_gain, d_group_ptr, p.TopK(), out_inv_IDCG);
dh::LaunchN(out_inv_IDCG.Size(), ctx->CUDACtx()->Stream(),
[out_inv_IDCG] XGBOOST_DEVICE(size_t idx) mutable {
double idcg = out_inv_IDCG(idx);
out_inv_IDCG(idx) = CalcInvIDCG(idcg);
});
}
} // namespace cuda_impl

namespace {
struct CheckNDCGOp {
CUDAContext const* cuctx;
template <typename It, typename Op>
bool operator()(It beg, It end, Op op) {
return thrust::none_of(cuctx->CTP(), beg, end, op);
}
};
struct CheckMAPOp {
CUDAContext const* cuctx;
template <typename It, typename Op>
bool operator()(It beg, It end, Op op) {
return thrust::all_of(cuctx->CTP(), beg, end, op);
}
};

struct ThreadGroupOp {
common::Span<bst_group_t const> d_group_ptr;
std::size_t n_pairs;

common::Span<std::size_t> out_thread_group_ptr;

XGBOOST_DEVICE void operator()(std::size_t i) {
out_thread_group_ptr[i + 1] =
cuda_impl::ThreadsForMean(d_group_ptr[i + 1] - d_group_ptr[i], n_pairs);
}
};

struct GroupSizeOp {
common::Span<bst_group_t const> d_group_ptr;

XGBOOST_DEVICE auto operator()(std::size_t i) -> std::size_t {
return d_group_ptr[i + 1] - d_group_ptr[i];
}
};

struct WeightOp {
common::OptionalWeights d_weight;
XGBOOST_DEVICE auto operator()(std::size_t i) -> double { return d_weight[i]; }
};
} // anonymous namespace

void RankingCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
CUDAContext const* cuctx = ctx->CUDACtx();

group_ptr_.SetDevice(ctx->gpu_id);
if (info.group_ptr_.empty()) {
group_ptr_.Resize(2, 0);
group_ptr_.HostVector()[1] = info.num_row_;
} else {
auto const& h_group_ptr = info.group_ptr_;
group_ptr_.Resize(h_group_ptr.size());
auto d_group_ptr = group_ptr_.DeviceSpan();
dh::safe_cuda(cudaMemcpyAsync(d_group_ptr.data(), h_group_ptr.data(), d_group_ptr.size_bytes(),
cudaMemcpyHostToDevice, cuctx->Stream()));
}

auto d_group_ptr = DataGroupPtr(ctx);
std::size_t n_groups = Groups();

auto it = dh::MakeTransformIterator<std::size_t>(thrust::make_counting_iterator(0ul),
GroupSizeOp{d_group_ptr});
max_group_size_ =
thrust::reduce(cuctx->CTP(), it, it + n_groups, 0ul, thrust::maximum<std::size_t>{});

threads_group_ptr_.SetDevice(ctx->gpu_id);
threads_group_ptr_.Resize(n_groups + 1, 0);
auto d_threads_group_ptr = threads_group_ptr_.DeviceSpan();
if (param_.HasTruncation()) {
n_cuda_threads_ =
common::SegmentedTrapezoidThreads(d_group_ptr, d_threads_group_ptr, Param().NumPair());
} else {
auto n_pairs = Param().NumPair();
dh::LaunchN(n_groups, cuctx->Stream(),
ThreadGroupOp{d_group_ptr, n_pairs, d_threads_group_ptr});
thrust::inclusive_scan(cuctx->CTP(), dh::tcbegin(d_threads_group_ptr),
dh::tcend(d_threads_group_ptr), dh::tbegin(d_threads_group_ptr));
n_cuda_threads_ = info.num_row_ * param_.NumPair();
}

sorted_idx_cache_.SetDevice(ctx->gpu_id);
sorted_idx_cache_.Resize(info.labels.Size(), 0);

auto weight = common::MakeOptionalWeights(ctx, info.weights_);
auto w_it =
dh::MakeTransformIterator<double>(thrust::make_counting_iterator(0ul), WeightOp{weight});
weight_norm_ = static_cast<double>(n_groups) / thrust::reduce(w_it, w_it + n_groups);
}
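
// Reviewer note (not in the commit): when truncation is enabled, per-group thread
// counts follow common::SegmentedTrapezoidThreads; for the mean strategy each group
// gets ThreadsForMean(group_size, n_pairs) == group_size * n_pairs threads, turned
// into offsets by the inclusive scan above.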

common::Span<std::size_t const> RankingCache::MakeRankOnCUDA(Context const* ctx,
common::Span<float const> predt) {
auto d_sorted_idx = sorted_idx_cache_.DeviceSpan();
auto d_group_ptr = DataGroupPtr(ctx);
common::SegmentedArgSort<false, true>(ctx, predt, d_group_ptr, d_sorted_idx);
return d_sorted_idx;
}
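
// Note (assumption, not stated in the diff): SegmentedArgSort<false, true> is taken
// to produce a descending argsort within each query group, mirroring the
// std::greater<> ordering used by MakeRankOnCPU.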

void NDCGCache::InitOnCUDA(Context const* ctx, MetaInfo const& info) {
CUDAContext const* cuctx = ctx->CUDACtx();
auto labels = info.labels.View(ctx->gpu_id).Slice(linalg::All(), 0);
CheckNDCGLabels(this->Param(), labels, CheckNDCGOp{cuctx});

auto d_group_ptr = this->DataGroupPtr(ctx);

std::size_t n_groups = d_group_ptr.size() - 1;
inv_idcg_ = linalg::Zeros<double>(ctx, n_groups);
auto d_inv_idcg = inv_idcg_.View(ctx->gpu_id);
cuda_impl::CalcQueriesInvIDCG(ctx, labels, d_group_ptr, d_inv_idcg, this->Param());
CHECK_GE(this->Param().NumPair(), 1ul);

discounts_.SetDevice(ctx->gpu_id);
discounts_.Resize(MaxGroupSize());
auto d_discount = discounts_.DeviceSpan();
dh::LaunchN(MaxGroupSize(), cuctx->Stream(),
[=] XGBOOST_DEVICE(std::size_t i) { d_discount[i] = CalcDCGDiscount(i); });
}
} // namespace xgboost::ltr
src/common/ranking_utils.cuh (40 additions, 0 deletions)
@@ -0,0 +1,40 @@
/**
* Copyright 2023 by XGBoost Contributors
*/
#ifndef XGBOOST_COMMON_RANKING_UTILS_CUH_
#define XGBOOST_COMMON_RANKING_UTILS_CUH_

#include <cstddef> // for size_t

#include "ranking_utils.h" // for LambdaRankParam
#include "xgboost/base.h" // for bst_group_t, XGBOOST_DEVICE
#include "xgboost/context.h" // for Context
#include "xgboost/linalg.h" // for VectorView
#include "xgboost/span.h" // for Span

namespace xgboost {
namespace ltr {
namespace cuda_impl {
void CalcQueriesDCG(Context const *ctx, linalg::VectorView<float const> d_labels,
common::Span<std::size_t const> d_sorted_idx, bool exp_gain,
common::Span<bst_group_t const> d_group_ptr, std::size_t k,
linalg::VectorView<double> out_dcg);

void CalcQueriesInvIDCG(Context const *ctx, linalg::VectorView<float const> d_labels,
common::Span<bst_group_t const> d_group_ptr,
linalg::VectorView<double> out_inv_IDCG, ltr::LambdaRankParam const &p);

// Helpers for sizing CUDA kernels: compute the number of threads needed for a group
// under the mean strategy, and recover the number of pairs from a thread count.
XGBOOST_DEVICE __forceinline__ std::size_t ThreadsForMean(std::size_t group_size,
std::size_t n_pairs) {
return group_size * n_pairs;
}
XGBOOST_DEVICE __forceinline__ std::size_t PairsForGroup(std::size_t n_threads,
std::size_t group_size) {
return n_threads / group_size;
}
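// Round-trip check (illustrative): PairsForGroup(ThreadsForMean(g, p), g) == p for
// any g > 0, since the mean strategy allocates exactly group_size * n_pairs threads
// per group.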
} // namespace cuda_impl
} // namespace ltr
} // namespace xgboost
#endif // XGBOOST_COMMON_RANKING_UTILS_CUH_
(The remaining 4 changed files are not shown.)
