Skip to content

Commit

Permalink
Cleanup GPU ranking metric.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed May 30, 2023
1 parent c3a15c2 commit 3005f21
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 423 deletions.
170 changes: 0 additions & 170 deletions src/common/device_helpers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -825,176 +825,6 @@ XGBOOST_DEVICE auto tcrend(xgboost::common::Span<T> const &span) { // NOLINT
return tcrbegin(span) + span.size();
}

// This type sorts an array which is divided into multiple groups. The sorting is influenced
// by the function object 'Comparator'
//
// Typical use: call one of the SortItems() overloads, then read the results through the
// Get*Span() accessors. CreateIndexableSortedPositions() is an optional extra step that
// builds the inverse of the original-position map.
template <typename T>
class SegmentSorter {
private:
// Items sorted within the group
caching_device_vector<T> ditems_;

// Original position of the items before they are sorted descending within their groups
caching_device_vector<uint32_t> doriginal_pos_;

// Segments within the original list that delineates the different groups
caching_device_vector<uint32_t> group_segments_;

// Need this on the device as it is used in the kernels
caching_device_vector<uint32_t> dgroups_; // Group information on device

// Where did the item that was originally present at position 'x' move to after they are sorted
caching_device_vector<uint32_t> dindexable_sorted_pos_;

// Initialize everything but the segments
// Sizes ditems_ and doriginal_pos_ to num_elems and fills doriginal_pos_ with the
// sequence 0, 1, 2, ... (the identity permutation, see the layout sketch in SortItems).
void Init(uint32_t num_elems) {
ditems_.resize(num_elems);

doriginal_pos_.resize(num_elems);
thrust::sequence(doriginal_pos_.begin(), doriginal_pos_.end());
}

// Initialize all with group info
// The total number of elements is taken from the last entry of 'groups'.
// NOTE(review): groups.back() requires a non-empty vector; an empty 'groups' is
// undefined behaviour here. Confirm that callers always pass at least one entry.
void Init(const std::vector<uint32_t> &groups) {
uint32_t num_elems = groups.back();
this->Init(num_elems);
this->CreateGroupSegments(groups);
}

public:
// This needs to be public due to device lambda
// Copies 'groups' into dgroups_ and fills group_segments_ with one group id per element.
// 'groups' is read as a list of boundaries (e.g. [0, 3, 5, 8, 10]); element idx gets the
// index of the last boundary that is <= idx. With a single group the segments are left
// at their zero-initialized value and no transform is run.
// NOTE(review): like Init(groups), this calls groups.back() and so needs a non-empty vector.
void CreateGroupSegments(const std::vector<uint32_t> &groups) {
uint32_t num_elems = groups.back();
group_segments_.resize(num_elems, 0);

dgroups_ = groups;

if (GetNumGroups() == 1) return; // There are no segments; hence, no need to compute them

// Define the segments by assigning a group ID to each element
const uint32_t *dgroups = dgroups_.data().get();
uint32_t ngroups = dgroups_.size();
// upper_bound yields the first boundary strictly greater than idx; stepping back by one
// gives the group that contains idx.
auto ComputeGroupIDLambda = [=] __device__(uint32_t idx) {
return thrust::upper_bound(thrust::seq, dgroups, dgroups + ngroups, idx) -
dgroups - 1;
}; // NOLINT

thrust::transform(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
thrust::make_counting_iterator(num_elems),
group_segments_.begin(),
ComputeGroupIDLambda);
}

// Accessors that return read-only spans built from the vectors' raw data pointers and sizes
inline uint32_t GetNumItems() const { return ditems_.size(); }
inline const xgboost::common::Span<const T> GetItemsSpan() const {
return { ditems_.data().get(), ditems_.size() };
}

inline const xgboost::common::Span<const uint32_t> GetOriginalPositionsSpan() const {
return { doriginal_pos_.data().get(), doriginal_pos_.size() };
}

inline const xgboost::common::Span<const uint32_t> GetGroupSegmentsSpan() const {
return { group_segments_.data().get(), group_segments_.size() };
}

// Number of groups is one less than the number of stored boundaries.
// NOTE(review): when dgroups_ is empty, size() - 1 wraps around (unsigned arithmetic)
// instead of yielding 0. dgroups_ is only filled by CreateGroupSegments().
inline uint32_t GetNumGroups() const { return dgroups_.size() - 1; }
inline const xgboost::common::Span<const uint32_t> GetGroupsSpan() const {
return { dgroups_.data().get(), dgroups_.size() };
}

inline const xgboost::common::Span<const uint32_t> GetIndexableSortedPositionsSpan() const {
return { dindexable_sorted_pos_.data().get(), dindexable_sorted_pos_.size() };
}

// Sort an array that is divided into multiple groups. The array is sorted within each group.
// This version provides the group information that is on the host.
// The array is sorted based on an adaptable binary predicate. By default a stateless predicate
// is used.
// Builds the group segments from 'groups' first, then forwards to the overload below using
// this object's own group_segments_.
// NOTE(review): presumably item_size == groups.back(); the two values size different
// buffers (group_segments_ vs. ditems_/doriginal_pos_) -- confirm against callers.
template <typename Comparator = thrust::greater<T>>
void SortItems(const T *ditems, uint32_t item_size, const std::vector<uint32_t> &groups,
const Comparator &comp = Comparator()) {
this->Init(groups);
this->SortItems(ditems, item_size, this->GetGroupSegmentsSpan(), comp);
}

// Sort an array that is divided into multiple groups. The array is sorted within each group.
// This version provides the group information that is on the device.
// The array is sorted based on an adaptable binary predicate. By default a stateless predicate
// is used.
// 'group_segments' holds one group id per item (same layout as group_segments_).
// NOTE(review): this overload does not populate dgroups_ itself, so the GetNumGroups()
// check below reflects whatever an earlier CreateGroupSegments() call stored.
template <typename Comparator = thrust::greater<T>>
void SortItems(const T *ditems, uint32_t item_size,
const xgboost::common::Span<const uint32_t> &group_segments,
const Comparator &comp = Comparator()) {
this->Init(item_size);

// Sort the items that are grouped. We would like to avoid using predicates to perform the sort,
// as thrust resorts to using a merge sort as opposed to a much much faster radix sort
// when comparators are used. Hence, the following algorithm is used. This is done so that
// we can grab the appropriate related values from the original list later, after the
// items are sorted.
//
// Here is the internal representation:
// dgroups_: [ 0, 3, 5, 8, 10 ]
// group_segments_: 0 0 0 | 1 1 | 2 2 2 | 3 3
// doriginal_pos_: 0 1 2 | 3 4 | 5 6 7 | 8 9
// ditems_: 1 0 1 | 2 1 | 1 3 3 | 4 4 (from original items)
//
// Sort the items first and make a note of the original positions in doriginal_pos_
// based on the sort
// ditems_: 4 4 3 3 2 1 1 1 1 0
// doriginal_pos_: 8 9 6 7 3 0 2 4 5 1
// NOTE: This consumes space, but is much faster than some of the other approaches - sorting
// in kernel, sorting using predicates etc.

// Work on a private copy of the input; 'ditems' itself is only read.
ditems_.assign(thrust::device_ptr<const T>(ditems),
thrust::device_ptr<const T>(ditems) + item_size);

// Allocator to be used by sort for managing space overhead while sorting
dh::XGBCachingDeviceAllocator<char> alloc;

// Stable sort keeps equal items in their original relative order.
thrust::stable_sort_by_key(thrust::cuda::par(alloc),
ditems_.begin(), ditems_.end(),
doriginal_pos_.begin(), comp);

if (GetNumGroups() == 1) return; // The entire array is sorted, as it isn't segmented

// Next, gather the segments based on the doriginal_pos_. This is to reflect the
// holistic item sort order on the segments
// group_segments_c_: 3 3 2 2 1 0 0 1 2 0
// doriginal_pos_: 8 9 6 7 3 0 2 4 5 1 (stays the same)
caching_device_vector<uint32_t> group_segments_c(item_size);
thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(),
dh::tcbegin(group_segments), group_segments_c.begin());

// Now, sort the group segments so that you may bring the items within the group together,
// in the process also noting the relative changes to the doriginal_pos_ while that happens
// group_segments_c_: 0 0 0 1 1 2 2 2 3 3
// doriginal_pos_: 0 2 1 3 4 6 7 5 8 9
// This second sort must also be stable so the per-group item order from the first sort survives.
thrust::stable_sort_by_key(thrust::cuda::par(alloc),
group_segments_c.begin(), group_segments_c.end(),
doriginal_pos_.begin(), thrust::less<uint32_t>());

// Finally, gather the original items based on doriginal_pos_ to sort the input and
// to store them in ditems_
// doriginal_pos_: 0 2 1 3 4 6 7 5 8 9 (stays the same)
// ditems_: 1 1 0 2 1 3 3 1 4 4 (from unsorted items - ditems)
thrust::gather(doriginal_pos_.begin(), doriginal_pos_.end(),
thrust::device_ptr<const T>(ditems), ditems_.begin());
}

// Determine where an item that was originally present at position 'x' has been relocated to
// after a sort. Creation of such an index has to be explicitly requested after a sort
// In effect: dindexable_sorted_pos_[doriginal_pos_[i]] = i for every sorted position i,
// i.e. the inverse of the doriginal_pos_ mapping.
void CreateIndexableSortedPositions() {
dindexable_sorted_pos_.resize(GetNumItems());
thrust::scatter(thrust::make_counting_iterator(static_cast<uint32_t>(0)),
thrust::make_counting_iterator(GetNumItems()), // Rearrange indices...
// ...based on this map
dh::tcbegin(GetOriginalPositionsSpan()),
dindexable_sorted_pos_.begin()); // Write results into this
}
};

// Atomic add function for gradients
template <typename OutputGradientT, typename InputGradientT>
XGBOOST_DEV_INLINE void AtomicAddGpair(OutputGradientT* dest,
Expand Down
24 changes: 2 additions & 22 deletions src/metric/metric.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,32 +52,13 @@ Metric::Create(const std::string& name, Context const* ctx) {
metric->ctx_ = ctx;
return metric;
}

// Look up `name` among the GPU metric builders and return the resulting metric with its
// context set to `ctx`. When no GPU builder is found a warning is logged and nullptr is
// returned so that the caller can fall back to the CPU builder.
GPUMetric* GPUMetric::CreateGPUMetric(const std::string& name, Context const* ctx) {
  auto created = CreateMetricImpl<MetricGPUReg>(name);
  if (created != nullptr) {
    // The cast only narrows the static type so that the base-class member `ctx_` can be
    // assigned through it. Touching derived-only members through this pointer would be
    // illegal, which does not matter here because GPUMetric carries no state of its own.
    auto gpu_metric = static_cast<GPUMetric*>(created);
    CHECK(gpu_metric);
    gpu_metric->ctx_ = ctx;
    return gpu_metric;
  }

  LOG(WARNING) << "Cannot find a GPU metric builder for metric " << name
               << ". Resorting to the CPU builder";
  return nullptr;
}
} // namespace xgboost

// Enable the dmlc registries for the CPU metric entries (MetricReg) and the GPU metric
// entries (MetricGPUReg). The macro is invoked inside namespace dmlc.
namespace dmlc {
DMLC_REGISTRY_ENABLE(::xgboost::MetricReg);
DMLC_REGISTRY_ENABLE(::xgboost::MetricGPUReg);
}

namespace xgboost {
namespace metric {
namespace xgboost::metric {
// List of files that will be force linked in static links.
DMLC_REGISTRY_LINK_TAG(auc);
DMLC_REGISTRY_LINK_TAG(elementwise_metric);
Expand All @@ -88,5 +69,4 @@ DMLC_REGISTRY_LINK_TAG(rank_metric);
DMLC_REGISTRY_LINK_TAG(auc_gpu);
DMLC_REGISTRY_LINK_TAG(rank_metric_gpu);
#endif
} // namespace metric
} // namespace xgboost
} // namespace xgboost::metric
56 changes: 8 additions & 48 deletions src/metric/metric_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,53 +23,14 @@ class MetricNoCache : public Metric {

double Evaluate(HostDeviceVector<float> const &predts, std::shared_ptr<DMatrix> p_fmat) final {
double result{0.0};
auto const& info = p_fmat->Info();
collective::ApplyWithLabels(info, &result, sizeof(double), [&] {
result = this->Eval(predts, info);
});
auto const &info = p_fmat->Info();
collective::ApplyWithLabels(info, &result, sizeof(double),
[&] { result = this->Eval(predts, info); });
return result;
}
};

// This creates a GPU metric instance dynamically and adds it to the GPU metric registry, if not
// present already. This is created when there is a device ordinal present and if xgboost
// is compiled with CUDA support
// NOTE(review): the CreateGPUMetric definition shown in metric.cc only looks the metric up
// by name (returning nullptr with a warning when no GPU builder exists) and assigns the
// context; confirm whether the "adds it to the registry" wording above still applies.
struct GPUMetric : public MetricNoCache {
// Factory: `name` is the metric name to look up, `tparam` is the context stored on the result.
static GPUMetric *CreateGPUMetric(const std::string &name, Context const *tparam);
};

/*!
 * \brief Internal registry entries for GPU Metric factory functions.
 * The additional parameter const char* param gives the value after @, can be null.
 * For example, metric map@3, then: param == "3".
 *
 * Each entry wraps a factory of type std::function<Metric*(const char*)>; the struct
 * itself adds no members beyond those of the dmlc::FunctionRegEntryBase base.
 */
struct MetricGPUReg
: public dmlc::FunctionRegEntryBase<MetricGPUReg,
std::function<Metric * (const char*)> > {
};

/*!
 * \brief Macro to register metric computed on GPU.
 *
 * \param UniqueId identifier pasted into the name of the generated registry-entry reference
 * \param Name     metric name passed to the registry's __REGISTER__ call
 *
 * \code
 * // example of registering a metric ndcg@k
 * XGBOOST_REGISTER_GPU_METRIC(NDCG_GPU, "ndcg")
 * .describe("NDCG metric computed on GPU.")
 * .set_body([](const char* param) {
 * int at_k = atoi(param);
 * return new NDCG(at_k);
 * });
 * \endcode
 */

// Note: Metric names registered in the GPU registry should follow this convention:
// - GPU metric types should be registered with the same name as the non GPU metric types
#define XGBOOST_REGISTER_GPU_METRIC(UniqueId, Name) \
::xgboost::MetricGPUReg& __make_ ## MetricGPUReg ## _ ## UniqueId ## __ = \
::dmlc::Registry< ::xgboost::MetricGPUReg>::Get()->__REGISTER__(Name)

namespace metric {

// Ranking config to be used on device and host
struct EvalRankConfig {
public:
Expand All @@ -81,8 +42,8 @@ struct EvalRankConfig {
};

class PackedReduceResult {
double residue_sum_ { 0 };
double weights_sum_ { 0 };
double residue_sum_{0};
double weights_sum_{0};

public:
XGBOOST_DEVICE PackedReduceResult() {} // NOLINT
Expand All @@ -91,16 +52,15 @@ class PackedReduceResult {

XGBOOST_DEVICE
PackedReduceResult operator+(PackedReduceResult const &other) const {
return PackedReduceResult{residue_sum_ + other.residue_sum_,
weights_sum_ + other.weights_sum_};
return PackedReduceResult{residue_sum_ + other.residue_sum_, weights_sum_ + other.weights_sum_};
}
PackedReduceResult &operator+=(PackedReduceResult const &other) {
this->residue_sum_ += other.residue_sum_;
this->weights_sum_ += other.weights_sum_;
return *this;
}
double Residue() const { return residue_sum_; }
double Weights() const { return weights_sum_; }
[[nodiscard]] double Residue() const { return residue_sum_; }
[[nodiscard]] double Weights() const { return weights_sum_; }
};

} // namespace metric
Expand Down
Loading

0 comments on commit 3005f21

Please sign in to comment.