From 5b039bc968d82a96ec52322115687bfbdc530148 Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Fri, 2 Jun 2023 23:45:59 +0800
Subject: [PATCH] Support linalg data structures in check device.

---
 src/data/data.cc | 50 ++++++++++++++++++++++++++++++--------------------
 1 file changed, 30 insertions(+), 20 deletions(-)

diff --git a/src/data/data.cc b/src/data/data.cc
index f9886b2f0ca3..00cff8ab0929 100644
--- a/src/data/data.cc
+++ b/src/data/data.cc
@@ -7,14 +7,15 @@
 #include 
 
 #include 
+#include 
 #include 
 
 #include "../collective/communicator-inl.h"
 #include "../collective/communicator.h"
-#include "../common/common.h"
 #include "../common/algorithm.h"  // for StableSort
 #include "../common/api_entry.h"  // for XGBAPIThreadLocalEntry
-#include "../common/error_msg.h"  // for InfInData
+#include "../common/common.h"
+#include "../common/error_msg.h"  // for InfInData, GroupWeight, GroupSize
 #include "../common/group_data.h"
 #include "../common/io.h"
 #include "../common/linalg_op.h"
@@ -35,6 +36,7 @@
 #include "xgboost/context.h"
 #include "xgboost/host_device_vector.h"
 #include "xgboost/learner.h"
+#include "xgboost/linalg.h"  // Vector
 #include "xgboost/logging.h"
 #include "xgboost/string_view.h"
 #include "xgboost/version_config.h"
@@ -491,7 +493,7 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) {
   }
   // uint info
   if (key == "group") {
-    linalg::Tensor<bst_group_t, 1> t;
+    linalg::Vector<bst_group_t> t;
     CopyTensorInfoImpl(ctx, arr, &t);
     auto const& h_groups = t.Data()->HostVector();
     group_ptr_.clear();
@@ -516,6 +518,7 @@ void MetaInfo::SetInfoFromHost(Context const& ctx, StringView key, Json arr) {
     data::ValidateQueryGroup(group_ptr_);
     return;
   }
+
   // float info
   linalg::Tensor<float, 1> t;
   CopyTensorInfoImpl<1>(ctx, arr, &t);
@@ -717,58 +720,63 @@ void MetaInfo::SynchronizeNumberOfColumns() {
   }
 }
 
+namespace {
+template <typename T>
+void CheckDevice(std::int32_t device, HostDeviceVector<T> const& v) {
+  CHECK(v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device)
+      << "Data is resided on a different device than `gpu_id`. "
+      << "Device that data is on: " << v.DeviceIdx() << ", "
+      << "`gpu_id` for XGBoost: " << device;
+}
+template <typename T, std::int32_t D>
+void CheckDevice(std::int32_t device, linalg::Tensor<T, D> const& v) {
+  CheckDevice(device, *v.Data());
+}
+}  // anonymous namespace
+
 void MetaInfo::Validate(std::int32_t device) const {
   if (group_ptr_.size() != 0 && weights_.Size() != 0) {
-    CHECK_EQ(group_ptr_.size(), weights_.Size() + 1)
-        << "Size of weights must equal to number of groups when ranking "
-           "group is used.";
+    CHECK_EQ(group_ptr_.size(), weights_.Size() + 1) << error::GroupWeight();
     return;
   }
   if (group_ptr_.size() != 0) {
     CHECK_EQ(group_ptr_.back(), num_row_)
-        << "Invalid group structure. Number of rows obtained from groups "
-           "doesn't equal to actual number of rows given by data.";
+        << error::GroupSize() << "the actual number of rows given by data.";
   }
-  auto check_device = [device](HostDeviceVector<float> const& v) {
-    CHECK(v.DeviceIdx() == Context::kCpuId || device == Context::kCpuId || v.DeviceIdx() == device)
-        << "Data is resided on a different device than `gpu_id`. "
-        << "Device that data is on: " << v.DeviceIdx() << ", "
-        << "`gpu_id` for XGBoost: " << device;
-  };
   if (weights_.Size() != 0) {
     CHECK_EQ(weights_.Size(), num_row_) << "Size of weights must equal to number of rows.";
-    check_device(weights_);
+    CheckDevice(device, weights_);
     return;
   }
   if (labels.Size() != 0) {
     CHECK_EQ(labels.Shape(0), num_row_) << "Size of labels must equal to number of rows.";
-    check_device(*labels.Data());
+    CheckDevice(device, labels);
     return;
   }
   if (labels_lower_bound_.Size() != 0) {
     CHECK_EQ(labels_lower_bound_.Size(), num_row_)
         << "Size of label_lower_bound must equal to number of rows.";
-    check_device(labels_lower_bound_);
+    CheckDevice(device, labels_lower_bound_);
     return;
   }
   if (feature_weights.Size() != 0) {
     CHECK_EQ(feature_weights.Size(), num_col_)
         << "Size of feature_weights must equal to number of columns.";
-    check_device(feature_weights);
+    CheckDevice(device, feature_weights);
   }
   if (labels_upper_bound_.Size() != 0) {
     CHECK_EQ(labels_upper_bound_.Size(), num_row_)
         << "Size of label_upper_bound must equal to number of rows.";
-    check_device(labels_upper_bound_);
+    CheckDevice(device, labels_upper_bound_);
     return;
   }
   CHECK_LE(num_nonzero_, num_col_ * num_row_);
   if (base_margin_.Size() != 0) {
     CHECK_EQ(base_margin_.Size() % num_row_, 0)
         << "Size of base margin must be a multiple of number of rows.";
-    check_device(*base_margin_.Data());
+    CheckDevice(device, base_margin_);
   }
 }
@@ -1028,6 +1036,8 @@ SparsePage SparsePage::GetTranspose(int num_columns, int32_t n_threads) const {
 bool SparsePage::IsIndicesSorted(int32_t n_threads) const {
   auto& h_offset = this->offset.HostVector();
   auto& h_data = this->data.HostVector();
+  n_threads = std::max(std::min(static_cast<std::size_t>(n_threads), this->Size()),
+                       static_cast<std::size_t>(1));
   std::vector<int32_t> is_sorted_tloc(n_threads, 0);
   common::ParallelFor(this->Size(), n_threads, [&](auto i) {
     auto beg = h_offset[i];