Use static dim check for the array interface handler. (#11069)
trivialfis authored Dec 7, 2024
1 parent 6d55fac commit 4b2001e
Showing 9 changed files with 75 additions and 69 deletions.
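The diff below is mechanical, but the idea behind it is simple: ArrayInterface<D> used to expose Shape(size_t i) and Stride(size_t i), so an out-of-range dimension index compiled cleanly and indexed past the shape/strides arrays at runtime. With the index as a non-type template parameter, static_assert(i < D) rejects it at compile time. A minimal sketch of the pattern (a simplified stand-in; the real class in src/data/array_interface.h also carries the data pointer, validity mask, and type handling):

#include <cstddef>

// Simplified stand-in for xgboost::ArrayInterface<D>.
template <std::size_t D>
struct ArrayInterface {
  std::size_t shape[D];

  // Old style: the index is a runtime value, so shape[i] can be read
  // out of bounds without any diagnostic.
  std::size_t ShapeRuntime(std::size_t i) const { return shape[i]; }

  // New style: the index is validated when the call site is compiled.
  template <std::size_t i>
  std::size_t Shape() const {
    static_assert(i < D, "dimension index out of range");
    return shape[i];
  }
};

int main() {
  ArrayInterface<2> arr{{3, 4}};
  std::size_t n = arr.Shape<0>() * arr.Shape<1>();  // fine: 0 and 1 < 2
  // arr.Shape<2>();  // would not compile: the static_assert fires
  return n == 12 ? 0 : 1;
}

In the hunks that follow, removed lines are prefixed with - and added lines with +.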
4 changes: 2 additions & 2 deletions jvm-packages/xgboost4j/src/native/xgboost4j-gpu.cu
@@ -97,7 +97,7 @@ void CopyInterface(std::vector<xgboost::ArrayInterface<1>> &interface_arr,
Json{Boolean{false}}};

out["data"] = Array(std::move(j_data));
out["shape"] = Array(std::vector<Json>{Json(Integer(interface.Shape(0)))});
out["shape"] = Array(std::vector<Json>{Json(Integer(interface.Shape<0>()))});

if (interface.valid.Data()) {
CopyColumnMask(interface, columns, kind, c, &mask, &out, stream);
@@ -113,7 +113,7 @@ void CopyMetaInfo(Json *p_interface, dh::device_vector<T> *out, cudaStream_t str
CHECK_EQ(get<Array const>(j_interface).size(), 1);
auto object = get<Object>(get<Array>(j_interface)[0]);
ArrayInterface<1> interface(object);
-out->resize(interface.Shape(0));
+out->resize(interface.Shape<0>());
size_t element_size = interface.ElementSize();
size_t size = element_size * interface.n;
dh::safe_cuda(cudaMemcpyAsync(RawPtr(*out), interface.data, size,
10 changes: 5 additions & 5 deletions jvm-packages/xgboost4j/src/native/xgboost4j.cpp
@@ -1520,20 +1520,20 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetQuanti

ArrayInterface<1> indptr{StringView{str_indptr}};
ArrayInterface<1> data{StringView{str_data}};
-CHECK_GE(indptr.Shape(0), 2);
+CHECK_GE(indptr.Shape<0>(), 2);

// Cut ptr
-auto j_indptr_array = jenv->NewLongArray(indptr.Shape(0));
+auto j_indptr_array = jenv->NewLongArray(indptr.Shape<0>());
CHECK_EQ(indptr.type, ArrayInterfaceHandler::Type::kU8);
-CHECK_LT(indptr(indptr.Shape(0) - 1),
+CHECK_LT(indptr(indptr.Shape<0>() - 1),
static_cast<std::uint64_t>(std::numeric_limits<std::int64_t>::max()));
static_assert(sizeof(jlong) == sizeof(std::uint64_t));
-jenv->SetLongArrayRegion(j_indptr_array, 0, indptr.Shape(0),
+jenv->SetLongArrayRegion(j_indptr_array, 0, indptr.Shape<0>(),
static_cast<jlong const *>(indptr.data));
jenv->SetObjectArrayElement(j_indptr, 0, j_indptr_array);

// Cut values
-auto n_cuts = indptr(indptr.Shape(0) - 1);
+auto n_cuts = indptr(indptr.Shape<0>() - 1);
jfloatArray jcuts_array = jenv->NewFloatArray(n_cuts);
CHECK_EQ(data.type, ArrayInterfaceHandler::Type::kF4);
jenv->SetFloatArrayRegion(jcuts_array, 0, n_cuts, static_cast<float const *>(data.data));
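A note on the indptr arithmetic above: for a CSR layout with n rows, indptr has n + 1 entries and its last entry is the total number of stored values, which is why the code requires indptr.Shape<0>() >= 2 and reads the cut count from indptr(indptr.Shape<0>() - 1). A small standalone illustration (hypothetical values, plain C++):

#include <cstdint>
#include <vector>

int main() {
  // CSR row pointer for 3 rows: row i occupies [indptr[i], indptr[i+1]).
  std::vector<std::uint64_t> indptr{0, 2, 2, 5};
  auto n_rows = indptr.size() - 1;  // 3; a valid indptr has at least 2 entries
  auto n_values = indptr.back();    // 5 stored values in total, i.e. the cut count
  return (n_rows == 3 && n_values == 5) ? 0 : 1;
}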
8 changes: 4 additions & 4 deletions src/c_api/c_api.cc
@@ -1098,18 +1098,18 @@ XGB_DLL int XGBoosterTrainOneIter(BoosterHandle handle, DMatrixHandle dtrain, in
ArrayInterface<2, false> i_grad{StringView{grad}};
ArrayInterface<2, false> i_hess{StringView{hess}};
StringView msg{"Mismatched shape between the gradient and hessian."};
-CHECK_EQ(i_grad.Shape(0), i_hess.Shape(0)) << msg;
-CHECK_EQ(i_grad.Shape(1), i_hess.Shape(1)) << msg;
+CHECK_EQ(i_grad.Shape<0>(), i_hess.Shape<0>()) << msg;
+CHECK_EQ(i_grad.Shape<1>(), i_hess.Shape<1>()) << msg;
linalg::Matrix<GradientPair> gpair;
auto grad_is_cuda = ArrayInterfaceHandler::IsCudaPtr(i_grad.data);
auto hess_is_cuda = ArrayInterfaceHandler::IsCudaPtr(i_hess.data);
-CHECK_EQ(i_grad.Shape(0), p_fmat->Info().num_row_)
+CHECK_EQ(i_grad.Shape<0>(), p_fmat->Info().num_row_)
<< "Mismatched size between the gradient and training data.";
CHECK_EQ(grad_is_cuda, hess_is_cuda) << "gradient and hessian should be on the same device.";
auto *learner = static_cast<Learner *>(handle);
auto ctx = learner->Ctx();
if (!grad_is_cuda) {
-gpair.Reshape(i_grad.Shape(0), i_grad.Shape(1));
+gpair.Reshape(i_grad.Shape<0>(), i_grad.Shape<1>());
auto h_gpair = gpair.HostView();
DispatchDType(i_grad, DeviceOrd::CPU(), [&](auto &&t_grad) {
DispatchDType(i_hess, DeviceOrd::CPU(), [&](auto &&t_hess) {
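For context, the grad and hess arguments here arrive as JSON strings in the NumPy __array_interface__ format, which ArrayInterface<2, false> then parses. A hedged sketch of building such a description for a row-major host buffer (the field set follows the array interface protocol; MakeArrayInterface is a hypothetical helper, not an XGBoost API):

#include <cstddef>
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Minimal 2-D array-interface JSON for a row-major float32 buffer.
std::string MakeArrayInterface(float const* data, std::size_t rows, std::size_t cols) {
  std::ostringstream oss;
  oss << R"({"data": [)" << reinterpret_cast<std::uintptr_t>(data) << R"(, true], )"
      << R"("shape": [)" << rows << ", " << cols << "], "
      << R"("typestr": "<f4", "version": 3})";
  return oss.str();
}

int main() {
  std::vector<float> grad(100, 0.25f);  // e.g. 100 samples x 1 target
  std::string j = MakeArrayInterface(grad.data(), 100, 1);
  return j.empty() ? 1 : 0;
}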
4 changes: 2 additions & 2 deletions src/c_api/c_api.cu
@@ -1,5 +1,5 @@
/**
-* Copyright 2019-2023, XGBoost Contributors
+* Copyright 2019-2024, XGBoost Contributors
*/
#include <thrust/transform.h> // for transform

@@ -78,7 +78,7 @@ void CopyGradientFromCUDAArrays(Context const *ctx, ArrayInterface<2, false> con
CHECK_EQ(grad_dev, hess_dev) << "gradient and hessian should be on the same device.";
auto &gpair = *out_gpair;
gpair.SetDevice(DeviceOrd::CUDA(grad_dev));
-gpair.Reshape(grad.Shape(0), grad.Shape(1));
+gpair.Reshape(grad.Shape<0>(), grad.Shape<1>());
auto d_gpair = gpair.View(DeviceOrd::CUDA(grad_dev));
auto cuctx = ctx->CUDACtx();

48 changes: 24 additions & 24 deletions src/data/adapter.h
@@ -1,22 +1,22 @@
/**
-* Copyright 2019-2023, XGBoost Contributors
+* Copyright 2019-2024, XGBoost Contributors
* \file adapter.h
*/
#ifndef XGBOOST_DATA_ADAPTER_H_
#define XGBOOST_DATA_ADAPTER_H_
#include <dmlc/data.h>

-#include <algorithm>
-#include <cstddef>  // for size_t
-#include <functional>
-#include <limits>
-#include <map>
-#include <memory>
-#include <string>
-#include <utility>  // std::move
-#include <vector>
-
-#include "../common/error_msg.h"  // for MaxFeatureSize
+#include <algorithm>  // for transform, all_of
+#include <cmath>      // for isfinite
+#include <cstddef>    // for size_t
+#include <cstdint>    // for uint8_t
+#include <iterator>   // for back_inserter
+#include <limits>     // for numeric_limits
+#include <memory>     // for unique_ptr, make_unique
+#include <string>     // for string
+#include <utility>    // for move
+#include <vector>     // for vector

#include "../common/math.h"
#include "array_interface.h"
#include "xgboost/base.h"
@@ -256,7 +256,7 @@ class ArrayAdapterBatch : public detail::NoMetaInfo {
Line(ArrayInterface<2> array_interface, size_t ridx)
: array_interface_{std::move(array_interface)}, ridx_{ridx} {}

-size_t Size() const { return array_interface_.Shape(1); }
+size_t Size() const { return array_interface_.Shape<1>(); }

COOTuple GetElement(size_t idx) const {
return {ridx_, idx, array_interface_(ridx_, idx)};
@@ -269,8 +269,8 @@ class ArrayAdapterBatch : public detail::NoMetaInfo {
return Line{array_interface_, idx};
}

-[[nodiscard]] std::size_t NumRows() const { return array_interface_.Shape(0); }
-[[nodiscard]] std::size_t NumCols() const { return array_interface_.Shape(1); }
+[[nodiscard]] std::size_t NumRows() const { return array_interface_.Shape<0>(); }
+[[nodiscard]] std::size_t NumCols() const { return array_interface_.Shape<1>(); }
[[nodiscard]] std::size_t Size() const { return this->NumRows(); }

explicit ArrayAdapterBatch(ArrayInterface<2> array_interface)
@@ -290,8 +290,8 @@ class ArrayAdapter : public detail::SingleBatchDataIter<ArrayAdapterBatch> {
batch_ = ArrayAdapterBatch{array_interface_};
}
[[nodiscard]] ArrayAdapterBatch const& Value() const override { return batch_; }
-[[nodiscard]] std::size_t NumRows() const { return array_interface_.Shape(0); }
-[[nodiscard]] std::size_t NumColumns() const { return array_interface_.Shape(1); }
+[[nodiscard]] std::size_t NumRows() const { return array_interface_.Shape<0>(); }
+[[nodiscard]] std::size_t NumColumns() const { return array_interface_.Shape<1>(); }

private:
ArrayAdapterBatch batch_;
@@ -321,7 +321,7 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
}

[[nodiscard]] std::size_t Size() const {
-return values_.Shape(0);
+return values_.Shape<0>();
}
};

@@ -339,7 +339,7 @@ class CSRArrayAdapterBatch : public detail::NoMetaInfo {
}

size_t NumRows() const {
-size_t size = indptr_.Shape(0);
+size_t size = indptr_.Shape<0>();
size = size == 0 ? 0 : size - 1;
return size;
}
@@ -381,9 +381,9 @@ class CSRArrayAdapter : public detail::SingleBatchDataIter<CSRArrayAdapterBatch>
return batch_;
}
size_t NumRows() const {
-size_t size = indptr_.Shape(0);
+size_t size = indptr_.Shape<0>();
size = size == 0 ? 0 : size - 1;
return size;
}
size_t NumColumns() const { return num_cols_; }

@@ -479,7 +479,7 @@ class CSCArrayAdapterBatch : public detail::NoMetaInfo {
values_{std::move(values)},
offset_{offset} {}

-std::size_t Size() const { return values_.Shape(0); }
+std::size_t Size() const { return values_.Shape<0>(); }
COOTuple GetElement(std::size_t idx) const {
return {TypedIndex<std::size_t, 1>{row_idx_}(offset_ + idx), column_idx_,
values_(offset_ + idx)};
@@ -684,7 +684,7 @@ class ColumnarAdapterBatch : public detail::NoMetaInfo {
: columns_{columns} {}
[[nodiscard]] Line GetLine(std::size_t ridx) const { return Line{columns_, ridx}; }
[[nodiscard]] std::size_t Size() const {
-return columns_.empty() ? 0 : columns_.front().Shape(0);
+return columns_.empty() ? 0 : columns_.front().Shape<0>();
}
[[nodiscard]] std::size_t NumCols() const { return columns_.empty() ? 0 : columns_.size(); }
[[nodiscard]] std::size_t NumRows() const { return this->Size(); }
@@ -707,7 +707,7 @@ class ColumnarAdapter : public detail::SingleBatchDataIter<ColumnarAdapterBatch>
bool consistent =
columns_.empty() ||
std::all_of(columns_.cbegin(), columns_.cend(), [&](ArrayInterface<1, false> const& array) {
-return array.Shape(0) == columns_[0].Shape(0);
+return array.Shape<0>() == columns_[0].Shape<0>();
});
CHECK(consistent) << "Size of columns should be the same.";
batch_ = ColumnarAdapterBatch{columns_};
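The ColumnarAdapter constructor above simply requires every 1-D column to report the same Shape<0>() as the first. The same invariant in isolation (hypothetical helper, plain C++):

#include <algorithm>
#include <cstddef>
#include <vector>

// Every column must be as long as the first; an empty table passes trivially.
bool ColumnsConsistent(std::vector<std::size_t> const& column_sizes) {
  return column_sizes.empty() ||
         std::all_of(column_sizes.cbegin(), column_sizes.cend(),
                     [&](std::size_t n) { return n == column_sizes.front(); });
}

int main() {
  bool ok = ColumnsConsistent({4, 4, 4});   // true
  bool bad = ColumnsConsistent({4, 3, 4});  // false
  return (ok && !bad) ? 0 : 1;
}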
12 changes: 10 additions & 2 deletions src/data/array_interface.h
@@ -501,8 +501,16 @@ class ArrayInterface {
}
}

-[[nodiscard]] XGBOOST_DEVICE std::size_t Shape(size_t i) const { return shape[i]; }
-[[nodiscard]] XGBOOST_DEVICE std::size_t Stride(size_t i) const { return strides[i]; }
+template <std::size_t i>
+[[nodiscard]] XGBOOST_DEVICE std::size_t Shape() const {
+static_assert(i < D);
+return shape[i];
+}
+template <std::size_t i>
+[[nodiscard]] XGBOOST_DEVICE std::size_t Stride() const {
+static_assert(i < D);
+return strides[i];
+}

template <typename Fn>
XGBOOST_HOST_DEV_INLINE decltype(auto) DispatchCall(Fn func) const {
16 changes: 8 additions & 8 deletions src/data/data.cu
@@ -69,12 +69,12 @@ void CopyGroupInfoImpl(ArrayInterface<1> column, std::vector<bst_group_t>* out)

auto ptr_device = SetDeviceToPtr(column.data);
CHECK_EQ(ptr_device, dh::CurrentDevice());
-dh::TemporaryArray<bst_group_t> temp(column.Shape(0));
+dh::TemporaryArray<bst_group_t> temp(column.Shape<0>());
auto d_tmp = temp.data().get();

-dh::LaunchN(column.Shape(0),
+dh::LaunchN(column.Shape<0>(),
[=] __device__(size_t idx) { d_tmp[idx] = TypedIndex<size_t, 1>{column}(idx); });
-auto length = column.Shape(0);
+auto length = column.Shape<0>();
out->resize(length + 1);
out->at(0) = 0;
thrust::copy(temp.data(), temp.data() + length, out->begin() + 1);
@@ -93,7 +93,7 @@ void CopyQidImpl(Context const* ctx, ArrayInterface<1> array_interface,
auto d = DeviceOrd::CUDA(SetDeviceToPtr(array_interface.data));
auto cuctx = ctx->CUDACtx();
dh::LaunchN(1, cuctx->Stream(), [=] __device__(size_t) { d_flag[0] = true; });
-dh::LaunchN(array_interface.Shape(0) - 1, cuctx->Stream(), [=] __device__(size_t i) {
+dh::LaunchN(array_interface.Shape<0>() - 1, cuctx->Stream(), [=] __device__(size_t i) {
auto typed = TypedIndex<uint32_t, 1>{array_interface};
if (typed(i) > typed(i + 1)) {
d_flag[0] = false;
@@ -104,15 +104,15 @@
cudaMemcpyDeviceToHost));
CHECK(non_dec) << "`qid` must be sorted in increasing order along with data.";
size_t bytes = 0;
-dh::caching_device_vector<uint32_t> out(array_interface.Shape(0));
-dh::caching_device_vector<uint32_t> cnt(array_interface.Shape(0));
+dh::caching_device_vector<uint32_t> out(array_interface.Shape<0>());
+dh::caching_device_vector<uint32_t> cnt(array_interface.Shape<0>());
HostDeviceVector<int> d_num_runs_out(1, 0, d);
cub::DeviceRunLengthEncode::Encode(nullptr, bytes, it, out.begin(), cnt.begin(),
-d_num_runs_out.DevicePointer(), array_interface.Shape(0),
+d_num_runs_out.DevicePointer(), array_interface.Shape<0>(),
cuctx->Stream());
dh::CachingDeviceUVector<char> tmp(bytes);
cub::DeviceRunLengthEncode::Encode(tmp.data(), bytes, it, out.begin(), cnt.begin(),
-d_num_runs_out.DevicePointer(), array_interface.Shape(0),
+d_num_runs_out.DevicePointer(), array_interface.Shape<0>(),
cuctx->Stream());

auto h_num_runs_out = d_num_runs_out.HostSpan()[0];
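CopyQidImpl first verifies on the device that qid is non-decreasing, then run-length encodes it: each run of equal query ids is one group, and the run lengths are the per-group row counts. The same logic on the host, as a hedged sketch (plain C++ in place of the cub::DeviceRunLengthEncode path):

#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
  // Sorted query ids; each run of equal ids is one query group.
  std::vector<std::uint32_t> qid{0, 0, 0, 1, 1, 7, 7, 7, 7};
  std::vector<std::uint32_t> counts;  // run lengths, like cub's Encode output
  for (std::size_t i = 0; i < qid.size(); ++i) {
    if (i == 0 || qid[i] != qid[i - 1]) {
      counts.push_back(0);
    }
    ++counts.back();
  }
  // counts == {3, 2, 4}; a prefix sum yields the group pointer {0, 3, 5, 9}.
  return counts.size() == 3 ? 0 : 1;
}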
34 changes: 16 additions & 18 deletions src/data/device_adapter.cuh
@@ -16,9 +16,7 @@
#include "adapter.h"
#include "array_interface.h"

-namespace xgboost {
-namespace data {
-
+namespace xgboost::data {
class CudfAdapterBatch : public detail::NoMetaInfo {
friend class CudfAdapter;

@@ -114,7 +112,7 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
CHECK_EQ(typestr.size(), 3) << ArrayInterfaceErrors::TypestrFormat();
std::vector<ArrayInterface<1>> columns;
auto first_column = ArrayInterface<1>(get<Object const>(json_columns[0]));
-num_rows_ = first_column.Shape(0);
+num_rows_ = first_column.Shape<0>();
if (num_rows_ == 0) {
return;
}
@@ -124,12 +122,12 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
dh::safe_cuda(cudaSetDevice(device_.ordinal));
for (auto& json_col : json_columns) {
auto column = ArrayInterface<1>(get<Object const>(json_col));
-n_bytes_ += column.ElementSize() * column.Shape(0);
+n_bytes_ += column.ElementSize() * column.Shape<0>();
columns.push_back(column);
-num_rows_ = std::max(num_rows_, column.Shape(0));
+num_rows_ = std::max(num_rows_, column.Shape<0>());
CHECK_EQ(device_.ordinal, dh::CudaGetPointerDevice(column.data))
<< "All columns should use the same device.";
-CHECK_EQ(num_rows_, column.Shape(0))
+CHECK_EQ(num_rows_, column.Shape<0>())
<< "All columns should have same number of rows.";
}
columns_ = columns;
@@ -161,12 +159,13 @@ class CupyAdapterBatch : public detail::NoMetaInfo {
CupyAdapterBatch() = default;
explicit CupyAdapterBatch(ArrayInterface<2> array_interface)
: array_interface_(std::move(array_interface)) {}
+// The total number of elements.
[[nodiscard]] std::size_t Size() const {
-return array_interface_.Shape(0) * array_interface_.Shape(1);
+return array_interface_.Shape<0>() * array_interface_.Shape<1>();
}
[[nodiscard]]__device__ COOTuple GetElement(size_t idx) const {
-size_t column_idx = idx % array_interface_.Shape(1);
-size_t row_idx = idx / array_interface_.Shape(1);
+size_t column_idx = idx % array_interface_.Shape<1>();
+size_t row_idx = idx / array_interface_.Shape<1>();
float value = array_interface_(row_idx, column_idx);
return {row_idx, column_idx, value};
}
@@ -175,8 +174,8 @@ class CupyAdapterBatch : public detail::NoMetaInfo {
return value;
}

-[[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return array_interface_.Shape(0); }
-[[nodiscard]] XGBOOST_DEVICE bst_idx_t NumCols() const { return array_interface_.Shape(1); }
+[[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return array_interface_.Shape<0>(); }
+[[nodiscard]] XGBOOST_DEVICE bst_idx_t NumCols() const { return array_interface_.Shape<1>(); }

private:
ArrayInterface<2> array_interface_;
@@ -188,20 +187,20 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
Json json_array_interface = Json::Load(cuda_interface_str);
array_interface_ = ArrayInterface<2>(get<Object const>(json_array_interface));
batch_ = CupyAdapterBatch(array_interface_);
-if (array_interface_.Shape(0) == 0) {
+if (array_interface_.Shape<0>() == 0) {
return;
}
device_ = DeviceOrd::CUDA(dh::CudaGetPointerDevice(array_interface_.data));
this->n_bytes_ =
-array_interface_.Shape(0) * array_interface_.Shape(1) * array_interface_.ElementSize();
+array_interface_.Shape<0>() * array_interface_.Shape<1>() * array_interface_.ElementSize();
CHECK(device_.IsCUDA());
}
explicit CupyAdapter(std::string cuda_interface_str)
: CupyAdapter{StringView{cuda_interface_str}} {}
[[nodiscard]] const CupyAdapterBatch& Value() const override { return batch_; }

-[[nodiscard]] std::size_t NumRows() const { return array_interface_.Shape(0); }
-[[nodiscard]] std::size_t NumColumns() const { return array_interface_.Shape(1); }
+[[nodiscard]] std::size_t NumRows() const { return array_interface_.Shape<0>(); }
+[[nodiscard]] std::size_t NumColumns() const { return array_interface_.Shape<1>(); }
[[nodiscard]] DeviceOrd Device() const { return device_; }
[[nodiscard]] bst_idx_t SizeBytes() const { return this->n_bytes_; }

@@ -279,6 +278,5 @@ bool NoInfInData(Context const* ctx, AdapterBatchT const& batch, IsValidFunctor
thrust::logical_and<>{});
return valid;
}
-}  // namespace data
-}  // namespace xgboost
+}  // namespace xgboost::data
#endif // XGBOOST_DATA_DEVICE_ADAPTER_H_
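CupyAdapterBatch above treats the 2-D array as one flat, row-major sequence: Size() is rows x cols, and GetElement recovers the coordinates by dividing and taking the remainder against Shape<1>(). The same arithmetic as a host-side sketch (standalone illustration, not the __device__ path):

#include <cstddef>
#include <utility>

// Row-major coordinate recovery, as in CupyAdapterBatch::GetElement;
// n_cols plays the role of array_interface_.Shape<1>().
std::pair<std::size_t, std::size_t> Unflatten(std::size_t idx, std::size_t n_cols) {
  return {idx / n_cols, idx % n_cols};
}

int main() {
  // In a 3x4 row-major matrix, flat index 7 lands at row 1, column 3.
  auto [row, col] = Unflatten(7, 4);
  return (row == 1 && col == 3) ? 0 : 1;
}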
8 changes: 4 additions & 4 deletions tests/cpp/data/test_array_interface.cc
@@ -14,8 +14,8 @@ TEST(ArrayInterface, Initialize) {
HostDeviceVector<float> storage;
auto array = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage);
auto arr_interface = ArrayInterface<2>(StringView{array});
-ASSERT_EQ(arr_interface.Shape(0), kRows);
-ASSERT_EQ(arr_interface.Shape(1), kCols);
+ASSERT_EQ(arr_interface.Shape<0>(), kRows);
+ASSERT_EQ(arr_interface.Shape<1>(), kCols);
ASSERT_EQ(arr_interface.data, storage.ConstHostPointer());
ASSERT_EQ(arr_interface.ElementSize(), 4);
ASSERT_EQ(arr_interface.type, ArrayInterfaceHandler::kF4);
Expand Down Expand Up @@ -106,15 +106,15 @@ TEST(ArrayInterface, TrivialDim) {
{
ArrayInterface<1> arr_i{interface_str};
ASSERT_EQ(arr_i.n, kRows);
-ASSERT_EQ(arr_i.Shape(0), kRows);
+ASSERT_EQ(arr_i.Shape<0>(), kRows);
}

std::swap(kRows, kCols);
interface_str = RandomDataGenerator{kRows, kCols, 0}.GenerateArrayInterface(&storage);
{
ArrayInterface<1> arr_i{interface_str};
ASSERT_EQ(arr_i.n, kCols);
-ASSERT_EQ(arr_i.Shape(0), kCols);
+ASSERT_EQ(arr_i.Shape<0>(), kCols);
}
}

