Skip to content

Commit

Permalink
Accept numpy array view. (dmlc#4147)
Browse files Browse the repository at this point in the history
* Accept array view (slice) in metainfo.
  • Loading branch information
trivialfis authored Feb 18, 2019
1 parent 0ff84d9 commit a985a99
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 42 deletions.
36 changes: 35 additions & 1 deletion include/xgboost/c_api.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright (c) 2015 by Contributors
* Copyright (c) 2015-2019 by Contributors
* \file c_api.h
* \author Tianqi Chen
* \brief C API of XGBoost, used for interfacing to other languages.
Expand Down Expand Up @@ -283,6 +283,23 @@ XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
const char *field,
const float *array,
bst_ulong len);
/*!
* \brief `XGDMatrixSetFloatInfo' with strided array as input.
*
* Every `stride'-th element of the `len' elements starting at `array' is
* copied, so ceil(len / stride) values end up stored in the field.
*
* \param handle an instance of data matrix
* \param field field name, can be label, weight
* \param array pointer to the first element of the float vector
* \param stride step between consecutive elements to read, counted in
*   elements (not bytes); 1 means a contiguous vector
* \param len number of elements spanned by the underlying buffer, which is
*   larger than the number of elements copied when stride > 1
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixSetFloatInfoStrided(DMatrixHandle handle,
const char *field,
const float *array,
const bst_ulong stride,
bst_ulong len);

/*!
* \brief set uint32 vector to a content in info
* \param handle an instance of data matrix
Expand All @@ -295,6 +312,23 @@ XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
const char *field,
const unsigned *array,
bst_ulong len);

/*!
* \brief `XGDMatrixSetUIntInfo' with strided array as input.
*
* Every `stride'-th element of the `len' elements starting at `array' is
* copied, so ceil(len / stride) values end up stored in the field.
*
* \param handle an instance of data matrix
* \param field field name
* \param array pointer to the first element of the unsigned int vector
* \param stride step between consecutive elements to read, counted in
*   elements (not bytes); 1 means a contiguous vector
* \param len number of elements spanned by the underlying buffer, which is
*   larger than the number of elements copied when stride > 1
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixSetUIntInfoStrided(DMatrixHandle handle,
const char *field,
const unsigned *array,
const bst_ulong stride,
bst_ulong len);
/*!
* \brief set label of the training matrix
* \param handle an instance of data matrix
Expand Down
1 change: 1 addition & 0 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class MetaInfo {
* \param num Number of elements in the source array.
*/
void SetInfo(const char* key, const void* dptr, DataType dtype, size_t num);
void SetInfo(const char* key, const void* dptr, DataType dtype, size_t stride, size_t num);

private:
/*! \brief argsort of labels */
Expand Down
50 changes: 35 additions & 15 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,17 @@ def c_array(ctype, values):
return (ctype * len(values))(*values)


def _get_length_and_stride(data):
"Return length and stride of 1-D data."
if isinstance(data, np.ndarray) and data.base is not None:
length = len(data.base)
stride = data.strides[0] // data.dtype.itemsize
else:
length = len(data)
stride = 1
return length, stride


PANDAS_DTYPE_MAPPER = {'int8': 'int', 'int16': 'int', 'int32': 'int', 'int64': 'int',
'uint8': 'int', 'uint16': 'int', 'uint32': 'int', 'uint64': 'int',
'float16': 'float', 'float32': 'float', 'float64': 'float',
Expand Down Expand Up @@ -585,10 +596,13 @@ def set_float_info(self, field, data):
The array of data to be set
"""
c_data = c_array(ctypes.c_float, data)
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
c_str(field),
c_data,
c_bst_ulong(len(data))))
length, stride = _get_length_and_stride(data)
_check_call(_LIB.XGDMatrixSetFloatInfoStrided(
self.handle,
c_str(field),
c_data,
c_bst_ulong(stride),
c_bst_ulong(length)))

def set_float_info_npy2d(self, field, data):
"""Set float type property into the DMatrix
Expand All @@ -604,10 +618,13 @@ def set_float_info_npy2d(self, field, data):
"""
data = np.array(data, copy=False, dtype=np.float32)
c_data = data.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
_check_call(_LIB.XGDMatrixSetFloatInfo(self.handle,
c_str(field),
c_data,
c_bst_ulong(len(data))))
length, stride = _get_length_and_stride(data)
_check_call(_LIB.XGDMatrixSetFloatInfoStrided(
self.handle,
c_str(field),
c_data,
c_bst_ulong(stride),
c_bst_ulong(length)))

def set_uint_info(self, field, data):
"""Set uint type property into the DMatrix.
Expand All @@ -620,10 +637,15 @@ def set_uint_info(self, field, data):
data: numpy array
The array of data to be set
"""
_check_call(_LIB.XGDMatrixSetUIntInfo(self.handle,
c_str(field),
c_array(ctypes.c_uint, data),
c_bst_ulong(len(data))))
data = np.array(data, copy=False, dtype=ctypes.c_uint)
c_data = c_array(ctypes.c_uint, data)
length, stride = _get_length_and_stride(data)
_check_call(_LIB.XGDMatrixSetUIntInfoStrided(
self.handle,
c_str(field),
c_data,
c_bst_ulong(stride),
c_bst_ulong(length)))

def save_binary(self, fname, silent=True):
"""Save DMatrix to an XGBoost buffer.
Expand Down Expand Up @@ -719,9 +741,7 @@ def set_group(self, group):
group : array like
Group size of each group
"""
_check_call(_LIB.XGDMatrixSetGroup(self.handle,
c_array(ctypes.c_uint, group),
c_bst_ulong(len(group))))
self.set_uint_info('group', group)

def get_label(self):
"""Get the label of the DMatrix.
Expand Down
44 changes: 34 additions & 10 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2014 by Contributors
// Copyright (c) 2014-2019 by Contributors

#include <xgboost/data.h>
#include <xgboost/learner.h>
Expand Down Expand Up @@ -768,24 +768,48 @@ XGB_DLL int XGDMatrixSaveBinary(DMatrixHandle handle,
}

XGB_DLL int XGDMatrixSetFloatInfo(DMatrixHandle handle,
const char* field,
const bst_float* info,
xgboost::bst_ulong len) {
const char* field,
const xgboost::bst_float* info,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->Info().SetInfo(field, info, kFloat32, len);
API_END();
}

/*! \brief Strided counterpart of XGDMatrixSetFloatInfo: passes the float
 *  array together with its stride to SetInfo on the matrix's Info(). */
XGB_DLL int XGDMatrixSetFloatInfoStrided(DMatrixHandle handle,
                                         const char* field,
                                         const xgboost::bst_float* info,
                                         const xgboost::bst_ulong stride,
                                         xgboost::bst_ulong len) {
  API_BEGIN();
  CHECK_HANDLE();
  // The opaque handle is a pointer to a shared_ptr<DMatrix>.
  auto* p_fmat = static_cast<std::shared_ptr<DMatrix>*>(handle);
  (*p_fmat)->Info().SetInfo(field, info, kFloat32, stride, len);
  API_END();
}

XGB_DLL int XGDMatrixSetUIntInfo(DMatrixHandle handle,
const char* field,
const unsigned* info,
xgboost::bst_ulong len) {
const char* field,
const unsigned* array,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->Info().SetInfo(field, array, kUInt32, len);
API_END();
}

XGB_DLL int XGDMatrixSetUIntInfoStrided(DMatrixHandle handle,
const char* field,
const unsigned* array,
const xgboost::bst_ulong stride,
xgboost::bst_ulong len) {
API_BEGIN();
CHECK_HANDLE();
static_cast<std::shared_ptr<DMatrix>*>(handle)
->get()->Info().SetInfo(field, info, kUInt32, len);
->get()->Info().SetInfo(field, array, kUInt32, stride, len);
API_END();
}

Expand Down Expand Up @@ -864,8 +888,8 @@ XGB_DLL int XGDMatrixNumCol(const DMatrixHandle handle,

// xgboost implementation
XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
xgboost::bst_ulong len,
BoosterHandle *out) {
xgboost::bst_ulong len,
BoosterHandle *out) {
API_BEGIN();
std::vector<std::shared_ptr<DMatrix> > mats;
for (xgboost::bst_ulong i = 0; i < len; ++i) {
Expand Down
56 changes: 40 additions & 16 deletions src/data/data.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright 2015 by Contributors
* Copyright 2015-2019 by Contributors
* \file data.cc
*/
#include <xgboost/data.h>
Expand Down Expand Up @@ -100,53 +100,77 @@ inline bool MetaTryLoadFloatInfo(const std::string& fname,
#define DISPATCH_CONST_PTR(dtype, old_ptr, cast_ptr, proc) \
switch (dtype) { \
case kFloat32: { \
auto cast_ptr = reinterpret_cast<const float*>(old_ptr); proc; break; \
auto cast_ptr = reinterpret_cast<const float*>(old_ptr); proc; \
break; \
} \
case kDouble: { \
auto cast_ptr = reinterpret_cast<const double*>(old_ptr); proc; break; \
auto cast_ptr = reinterpret_cast<const double*>(old_ptr); proc; \
break; \
} \
case kUInt32: { \
auto cast_ptr = reinterpret_cast<const uint32_t*>(old_ptr); proc; break; \
auto cast_ptr = reinterpret_cast<const uint32_t*>(old_ptr); proc; \
break; \
} \
case kUInt64: { \
auto cast_ptr = reinterpret_cast<const uint64_t*>(old_ptr); proc; break; \
auto cast_ptr = reinterpret_cast<const uint64_t*>(old_ptr); proc; \
break; \
} \
default: LOG(FATAL) << "Unknown data type" << dtype; \
} \


void MetaInfo::SetInfo(const char* key, const void* dptr, DataType dtype, size_t num) {
this->SetInfo(key, dptr, dtype, 1, num);
}

/*!
 * \brief Copy every `stride`-th element of [in_beg, in_end) to out_beg.
 *
 * Elements at offsets 0, stride, 2 * stride, ... below (in_end - in_beg) are
 * written consecutively, so ceil((in_end - in_beg) / stride) values are
 * produced.  The input iterator is never advanced past in_end.
 *
 * \param in_beg  Random-access iterator to the first input element.
 * \param in_end  One past the last input element of the underlying span.
 * \param out_beg Destination; must have room for all produced values.
 * \param stride  Step between copied elements, in elements.  A stride of 0
 *                is invalid; nothing is copied in that case instead of
 *                looping forever, callers are expected to reject it.
 */
template <typename IterIn, typename IterOut>
void StridedCopy(IterIn in_beg, IterIn in_end, IterOut out_beg, size_t stride) {
  if (stride == 1) {
    // There can be builtin optimization in std::copy
    std::copy(in_beg, in_end, out_beg);
    return;
  }
  if (stride == 0 || !(in_beg < in_end)) {
    return;
  }
  const size_t span = static_cast<size_t>(in_end - in_beg);
  // Integer ceiling division without the overflow of (span + stride - 1).
  const size_t n_out = span / stride + (span % stride != 0 ? 1 : 0);
  IterOut out_iter = out_beg;
  for (size_t i = 0; i < n_out; ++i, ++out_iter) {
    *out_iter = in_beg[i * stride];
  }
}

void MetaInfo::SetInfo(
const char* key, const void* dptr, DataType dtype, size_t stride, size_t num) {
size_t view_length =
static_cast<size_t>(std::ceil(static_cast<bst_float>(num) / stride));
if (!std::strcmp(key, "root_index")) {
root_index_.resize(num);
root_index_.resize(view_length);
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
std::copy(cast_dptr, cast_dptr + num, root_index_.begin()));
StridedCopy(cast_dptr, cast_dptr + num, root_index_.begin(), stride));
} else if (!std::strcmp(key, "label")) {
auto& labels = labels_.HostVector();
labels.resize(num);
labels.resize(view_length);
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
std::copy(cast_dptr, cast_dptr + num, labels.begin()));
StridedCopy(cast_dptr, cast_dptr + num, labels.begin(), stride));
} else if (!std::strcmp(key, "weight")) {
auto& weights = weights_.HostVector();
weights.resize(num);
weights.resize(view_length);
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
std::copy(cast_dptr, cast_dptr + num, weights.begin()));
StridedCopy(cast_dptr, cast_dptr + num, weights.begin(), stride));
} else if (!std::strcmp(key, "base_margin")) {
auto& base_margin = base_margin_.HostVector();
base_margin.resize(num);
base_margin.resize(view_length);
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
std::copy(cast_dptr, cast_dptr + num, base_margin.begin()));
StridedCopy(cast_dptr, cast_dptr + num, base_margin.begin(), stride));
} else if (!std::strcmp(key, "group")) {
group_ptr_.resize(num + 1);
group_ptr_.resize(view_length+1);
DISPATCH_CONST_PTR(dtype, dptr, cast_dptr,
std::copy(cast_dptr, cast_dptr + num, group_ptr_.begin() + 1));
StridedCopy(cast_dptr, cast_dptr + num, group_ptr_.begin() + 1, stride));
group_ptr_[0] = 0;
for (size_t i = 1; i < group_ptr_.size(); ++i) {
group_ptr_[i] = group_ptr_[i - 1] + group_ptr_[i];
}
}
}


DMatrix* DMatrix::Load(const std::string& uri,
bool silent,
bool load_row_split,
Expand Down
6 changes: 6 additions & 0 deletions tests/python/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ def test_basic(self):
# assert they are the same
assert np.sum(np.abs(preds2 - preds)) == 0

def test_np_view(self):
    """Labels set from a strided numpy view must equal those set from a fresh array."""
    strided_labels = np.array([12, 34, 56], np.float32)[::2]
    labels_via_view = xgb.DMatrix([], label=strided_labels).get_label()
    # Adding 0 materialises the view into a new, contiguous array.
    fresh_labels = strided_labels + 0
    labels_via_array = xgb.DMatrix([], label=fresh_labels).get_label()
    assert (labels_via_view == labels_via_array).all()

def test_record_results(self):
dtrain = xgb.DMatrix(dpath + 'agaricus.txt.train')
dtest = xgb.DMatrix(dpath + 'agaricus.txt.test')
Expand Down

0 comments on commit a985a99

Please sign in to comment.