Support linalg data structures in check device. #9243

Merged
merged 1 commit on Jun 6, 2023
50 changes: 30 additions & 20 deletions src/data/data.cc
@@ -7,14 +7,15 @@
 #include <dmlc/registry.h>
 
 #include <array>
+#include <cstddef>
 #include <cstring>
 
 #include "../collective/communicator-inl.h"
 #include "../collective/communicator.h"
-#include "../common/common.h"
 #include "../common/algorithm.h"  // for StableSort
 #include "../common/api_entry.h"  // for XGBAPIThreadLocalEntry
-#include "../common/error_msg.h"  // for InfInData
+#include "../common/common.h"
+#include "../common/error_msg.h"  // for InfInData, GroupWeight, GroupSize
 #include "../common/group_data.h"
 #include "../common/io.h"
 #include "../common/linalg_op.h"
@@ -35,6 +36,7 @@
 #include "xgboost/context.h"
 #include "xgboost/host_device_vector.h"
 #include "xgboost/learner.h"
+#include "xgboost/linalg.h"  // Vector
 #include "xgboost/logging.h"
 #include "xgboost/string_view.h"
 #include "xgboost/version_config.h"
@@ -491,7 +493,7 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) {
   }
   // uint info
   if (key == "group") {
-    linalg::Tensor<bst_group_t, 1> t;
+    linalg::Vector<bst_group_t> t;
     CopyTensorInfoImpl(ctx, arr, &t);
     auto const& h_groups = t.Data()->HostVector();
     group_ptr_.clear();
@@ -516,6 +518,7 @@
     data::ValidateQueryGroup(group_ptr_);
     return;
   }
+
   // float info
   linalg::Tensor<float, 1> t;
   CopyTensorInfoImpl<1>(ctx, arr, &t);
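For reference, `linalg::Vector<T>` is the rank-1 alias for `linalg::Tensor<T, 1>` (hence the `// Vector` comment on the new `xgboost/linalg.h` include), so the rewrite above changes spelling, not behavior. A minimal sketch of the alias pattern, with a hypothetical stand-in for the real tensor class:

```cpp
#include <cstdint>
#include <type_traits>

// Hypothetical stand-in; the real linalg::Tensor also tracks shape and
// host/device storage.
template <typename T, std::int32_t kDim>
class Tensor {};

// Rank-1 alias: call sites can write Vector<T> instead of Tensor<T, 1>.
template <typename T>
using Vector = Tensor<T, 1>;

static_assert(std::is_same_v<Vector<std::uint32_t>, Tensor<std::uint32_t, 1>>);

int main() { return 0; }
```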
@@ -717,58 +720,63 @@ void MetaInfo::SynchronizeNumberOfColumns() {
   }
 }
 
+namespace {
+template <typename T>
+void CheckDevice(std::int32_t device, HostDeviceVector<T> const& v) {
+  CHECK(v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device)
+      << "Data is resided on a different device than `gpu_id`. "
+      << "Device that data is on: " << v.DeviceIdx() << ", "
+      << "`gpu_id` for XGBoost: " << device;
+}
+template <typename T, std::int32_t D>
+void CheckDevice(std::int32_t device, linalg::Tensor<T, D> const& v) {
+  CheckDevice(device, *v.Data());
+}
+}  // anonymous namespace
+
 void MetaInfo::Validate(std::int32_t device) const {
   if (group_ptr_.size() != 0 && weights_.Size() != 0) {
-    CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)
-        << "Size of weights must equal to number of groups when ranking "
-           "group is used.";
+    CHECK_EQ(group_ptr_.size(), weights_.Size() + 1) << error::GroupWeight();
     return;
   }
   if (group_ptr_.size() != 0) {
     CHECK_EQ(group_ptr_.back(), num_row_)
-        << "Invalid group structure. Number of rows obtained from groups "
-           "doesn't equal to actual number of rows given by data.";
+        << error::GroupSize() << "the actual number of rows given by data.";
   }
-  auto check_device = [device](HostDeviceVector<float> const& v) {
-    CHECK(v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device)
-        << "Data is resided on a different device than `gpu_id`. "
-        << "Device that data is on: " << v.DeviceIdx() << ", "
-        << "`gpu_id` for XGBoost: " << device;
-  };
 
   if (weights_.Size() != 0) {
     CHECK_EQ(weights_.Size(), num_row_)
        << "Size of weights must equal to number of rows.";
-    check_device(weights_);
+    CheckDevice(device, weights_);
     return;
   }
   if (labels.Size() != 0) {
     CHECK_EQ(labels.Shape(0), num_row_) << "Size of labels must equal to number of rows.";
-    check_device(*labels.Data());
+    CheckDevice(device, labels);
     return;
   }
   if (labels_lower_bound_.Size() != 0) {
     CHECK_EQ(labels_lower_bound_.Size(), num_row_)
         << "Size of label_lower_bound must equal to number of rows.";
-    check_device(labels_lower_bound_);
+    CheckDevice(device, labels_lower_bound_);
     return;
   }
   if (feature_weights.Size() != 0) {
     CHECK_EQ(feature_weights.Size(), num_col_)
         << "Size of feature_weights must equal to number of columns.";
-    check_device(feature_weights);
+    CheckDevice(device, feature_weights);
   }
   if (labels_upper_bound_.Size() != 0) {
     CHECK_EQ(labels_upper_bound_.Size(), num_row_)
         << "Size of label_upper_bound must equal to number of rows.";
-    check_device(labels_upper_bound_);
+    CheckDevice(device, labels_upper_bound_);
     return;
   }
   CHECK_LE(num_nonzero_, num_col_ * num_row_);
   if (base_margin_.Size() != 0) {
     CHECK_EQ(base_margin_.Size() % num_row_, 0)
         << "Size of base margin must be a multiple of number of rows.";
-    check_device(*base_margin_.Data());
+    CheckDevice(device, base_margin_);
   }
 }
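The core of the patch is in this hunk: the old `check_device` lambda accepted only `HostDeviceVector<float>`, forcing callers to unwrap linalg containers by hand (`check_device(*labels.Data())`), while the new anonymous-namespace templates accept any element type and add a `linalg::Tensor` overload that forwards to the underlying vector. A self-contained sketch of that dispatch pattern, using simplified stand-ins rather than the real XGBoost classes:

```cpp
#include <cstdint>
#include <iostream>

constexpr std::int32_t kCpuId = -1;  // stand-in for Context::kCpuId

template <typename T>
struct HostDeviceVector {
  std::int32_t device{kCpuId};
  std::int32_t DeviceIdx() const { return device; }
};

template <typename T, std::int32_t kDim>
struct Tensor {
  HostDeviceVector<T> storage;
  HostDeviceVector<T> const* Data() const { return &storage; }
};

namespace {
// Base overload: the element type is now a template parameter, not float.
template <typename T>
void CheckDevice(std::int32_t device, HostDeviceVector<T> const& v) {
  if (v.DeviceIdx() != kCpuId && device != kCpuId && v.DeviceIdx() != device) {
    std::cerr << "Data is on device " << v.DeviceIdx()
              << ", but `gpu_id` is " << device << "\n";
  }
}

// Tensor overload: unwrap the underlying vector once, here, and forward.
template <typename T, std::int32_t D>
void CheckDevice(std::int32_t device, Tensor<T, D> const& v) {
  CheckDevice(device, *v.Data());
}
}  // namespace

int main() {
  HostDeviceVector<float> weights;  // e.g. instance weights
  Tensor<float, 2> base_margin;     // e.g. a rank-2 margin tensor
  CheckDevice(0, weights);      // resolves to the HostDeviceVector overload
  CheckDevice(0, base_margin);  // resolves to the Tensor overload
  return 0;
}
```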

@@ -1028,6 +1036,8 @@ SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const {
 bool SparsePage::IsIndicesSorted(int32_t n_threads) const {
   auto& h_offset = this->offset.HostVector();
   auto& h_data = this->data.HostVector();
+  n_threads = std::max(std::min(static_cast<std::size_t>(n_threads), this->Size()),
+                       static_cast<std::size_t>(1));
   std::vector<int32_t> is_sorted_tloc(n_threads, 0);
   common::ParallelFor(this->Size(), n_threads, [&](auto i) {
     auto beg = h_offset[i];
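The two added lines clamp the worker count to [1, this->Size()], so the per-thread buffer `is_sorted_tloc` never has more slots than there are rows to write them (slots no row ever writes would otherwise feed stale zeros into the final reduction), and a non-positive `n_threads` can no longer produce an empty buffer. A small sketch of the same clamp as a standalone helper (the helper name is hypothetical):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>

// Clamp a requested thread count to [1, n_items], mirroring the logic above.
std::int32_t ClampThreads(std::int32_t n_threads, std::size_t n_items) {
  auto clamped = std::max(std::min(static_cast<std::size_t>(n_threads), n_items),
                          static_cast<std::size_t>(1));
  return static_cast<std::int32_t>(clamped);
}

int main() {
  std::cout << ClampThreads(8, 3) << "\n";    // 3: never more threads than items
  std::cout << ClampThreads(8, 100) << "\n";  // 8: request honored
  std::cout << ClampThreads(0, 100) << "\n";  // 1: always at least one thread
  return 0;
}
```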