GPU binning and compression. (#3319)
* GPU binning and compression.

- binning and index compression are done inside the DeviceShard constructor (see the sketch below)
- in case of a DMatrix with multiple row batches, it is first converted into a single row batch
canonizer authored and RAMitchell committed Jun 5, 2018
1 parent 3f7696f commit 286dccb
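For orientation: the quantile cuts (hmat_) are still built on the host, each DeviceShard copies its slice of the raw CSR row batch to the GPU in chunks, and a kernel bins every entry against its feature's cut points and atomically packs the resulting bin index into the compressed ELLPACK buffer. Below is a minimal host-side sketch of the binning step, in plain C++ with no atomics or bit packing; BinToEllpack is an illustrative helper, not part of the commit, but it follows the cut layout and Entry definition used in the diff.

#include <algorithm>
#include <cstddef>
#include <vector>

struct Entry { unsigned index; float fvalue; };  // feature id + value, as in data.h

// Bin every CSR entry against its feature's cut points and store the result in a
// dense ELLPACK layout: row_stride slots per row, padded with null_gidx_value.
std::vector<unsigned> BinToEllpack(const std::vector<std::size_t>& row_ptr,
                                   const std::vector<Entry>& entries,
                                   const std::vector<float>& cuts,            // all cut values, feature-major
                                   const std::vector<std::size_t>& cut_rows,  // cuts of feature f: [cut_rows[f], cut_rows[f+1])
                                   std::size_t row_stride, unsigned null_gidx_value) {
  std::size_t n_rows = row_ptr.size() - 1;
  std::vector<unsigned> ellpack(n_rows * row_stride, null_gidx_value);
  for (std::size_t i = 0; i < n_rows; ++i) {
    for (std::size_t j = row_ptr[i]; j < row_ptr[i + 1]; ++j) {
      const Entry& e = entries[j];
      const float* beg = cuts.data() + cut_rows[e.index];
      const float* end = cuts.data() + cut_rows[e.index + 1];
      // first cut greater than the value, clamped to the last cut of this feature
      std::size_t bin = std::upper_bound(beg, end, e.fvalue) - beg;
      if (bin == static_cast<std::size_t>(end - beg)) bin -= 1;
      ellpack[i * row_stride + (j - row_ptr[i])] = static_cast<unsigned>(bin + cut_rows[e.index]);
    }
  }
  return ellpack;
}

On the device, compress_bin_ellpack_k does the same work with one thread per (row, slot) pair and writes each bin index through CompressedBufferWriter::AtomicWriteSymbol instead of storing a full integer per slot.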
Showing 10 changed files with 304 additions and 69 deletions.
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
@@ -73,3 +73,4 @@ List of Contributors
* [Gideon Whitehead](https://github.com/gaw89)
* [Yi-Lin Juang](https://github.com/frankyjuang)
* [Andrew Hannigan](https://github.com/andrewhannigan)
* [Andy Adinets](https://github.com/canonizer)
6 changes: 3 additions & 3 deletions include/xgboost/data.h
@@ -126,15 +126,15 @@ struct SparseBatch {
/*! \brief feature value */
bst_float fvalue;
/*! \brief default constructor */
Entry() = default;
XGBOOST_DEVICE Entry() {}
/*!
* \brief constructor with index and value
* \param index The feature or row index.
* \param fvalue The feature value.
*/
Entry(bst_uint index, bst_float fvalue) : index(index), fvalue(fvalue) {}
XGBOOST_DEVICE Entry(bst_uint index, bst_float fvalue) : index(index), fvalue(fvalue) {}
/*! \brief reversely compare feature values */
inline static bool CmpValue(const Entry& a, const Entry& b) {
XGBOOST_DEVICE inline static bool CmpValue(const Entry& a, const Entry& b) {
return a.fvalue < b.fvalue;
}
};
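The only functional change above is the XGBOOST_DEVICE qualifier, which lets Entry be constructed and compared inside CUDA kernels such as compress_bin_ellpack_k introduced below. The macro itself lives in xgboost's base.h (not shown in this diff) and follows the usual host/device pattern, roughly:

#if defined(__CUDACC__)
#define XGBOOST_DEVICE __host__ __device__   // callable from both host and device code
#else
#define XGBOOST_DEVICE                       // expands to nothing in pure host builds
#endif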
21 changes: 21 additions & 0 deletions src/common/compressed_iterator.h
@@ -8,6 +8,10 @@
#include <cstddef>
#include <algorithm>

#ifdef __CUDACC__
#include "device_helpers.cuh"
#endif

namespace xgboost {
namespace common {

@@ -96,6 +100,23 @@ class CompressedBufferWriter {
}
}
}

#ifdef __CUDACC__
__device__ void AtomicWriteSymbol
(CompressedByteT* buffer, uint64_t symbol, size_t offset) {
size_t ibit_start = offset * symbol_bits_;
size_t ibit_end = (offset + 1) * symbol_bits_ - 1;
size_t ibyte_start = ibit_start / 8, ibyte_end = ibit_end / 8;

symbol <<= 7 - ibit_end % 8;
for (ptrdiff_t ibyte = ibyte_end; ibyte >= (ptrdiff_t)ibyte_start; --ibyte) {
dh::AtomicOrByte(reinterpret_cast<unsigned int*>(buffer + detail::kPadding),
ibyte, symbol & 0xff);
symbol >>= 8;
}
}
#endif

template <typename IterT>
void Write(CompressedByteT *buffer, IterT input_begin, IterT input_end) {
uint64_t tmp = 0;
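AtomicWriteSymbol writes one symbol_bits_-wide symbol at symbol index offset, MSB-first within the byte stream, and ORs it into the buffer one byte at a time so threads packing neighbouring symbols cannot clobber each other's bits. A host-only sketch of the same packing may make the bit arithmetic easier to follow; it drops the atomics and the detail::kPadding header, and WriteSymbolMsbFirst is an illustrative name, not part of the commit.

#include <cstddef>
#include <cstdint>

// Pack `symbol` (symbol_bits wide) at symbol index `offset` into a byte buffer,
// MSB-first: bit i of the stream is bit (7 - i % 8) of byte i / 8.
void WriteSymbolMsbFirst(uint8_t* buffer, uint64_t symbol,
                         size_t offset, int symbol_bits) {
  size_t ibit_start = offset * symbol_bits;
  size_t ibit_end = (offset + 1) * symbol_bits - 1;
  size_t ibyte_start = ibit_start / 8, ibyte_end = ibit_end / 8;
  symbol <<= 7 - ibit_end % 8;  // align the symbol's last bit inside its last byte
  for (ptrdiff_t ibyte = ibyte_end; ibyte >= static_cast<ptrdiff_t>(ibyte_start); --ibyte) {
    buffer[ibyte] |= static_cast<uint8_t>(symbol & 0xff);  // device version uses dh::AtomicOrByte here
    symbol >>= 8;
  }
}

For example, with symbol_bits = 10 and offset = 1 the symbol occupies stream bits 10..19, i.e. the low 6 bits of byte 1 and the high 4 bits of byte 2 of the buffer.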
14 changes: 14 additions & 0 deletions src/common/device_helpers.cuh
@@ -122,6 +122,14 @@ inline size_t AvailableMemory(int device_idx) {
return device_free;
}

inline size_t TotalMemory(int device_idx) {
size_t device_free = 0;
size_t device_total = 0;
safe_cuda(cudaSetDevice(device_idx));
dh::safe_cuda(cudaMemGetInfo(&device_free, &device_total));
return device_total;
}

/**
* \fn inline int max_shared_memory(int device_idx)
*
@@ -155,6 +163,12 @@ inline void CheckComputeCapability() {
}
}


DEV_INLINE void AtomicOrByte(unsigned int* __restrict__ buffer, size_t ibyte, unsigned char b) {
atomicOr(&buffer[ibyte / sizeof(unsigned int)], (unsigned int)b << (ibyte % (sizeof(unsigned int)) * 8));
}


/*
* Range iterator
*/
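AtomicOrByte ORs a byte into the 32-bit word that contains it, because atomicOr on the device operates on whole words; the (ibyte % sizeof(unsigned int)) * 8 shift places the byte at its little-endian position inside that word. A small host-side check of that equivalence, assuming a little-endian host (which matches NVIDIA GPUs):

#include <cassert>
#include <cstddef>
#include <cstring>

int main() {
  unsigned int words[2] = {0, 0};
  std::size_t ibyte = 5;      // lives in words[1], byte position 1 of that word
  unsigned char b = 0xAB;
  // Word-level OR, as AtomicOrByte does with atomicOr on the device:
  words[ibyte / sizeof(unsigned int)] |=
      static_cast<unsigned int>(b) << (ibyte % sizeof(unsigned int)) * 8;
  // On a little-endian machine this equals writing the byte directly:
  unsigned char bytes[sizeof(words)];
  std::memcpy(bytes, words, sizeof(words));
  assert(bytes[ibyte] == b);
  return 0;
}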
1 change: 1 addition & 0 deletions src/tree/param.h
@@ -183,6 +183,7 @@ struct TrainParam : public dmlc::Parameter<TrainParam> {
.set_lower_bound(-1)
.set_default(1)
.describe("Number of GPUs to use for multi-gpu algorithms: -1=use all GPUs");

// add alias of parameters
DMLC_DECLARE_ALIAS(reg_lambda, lambda);
DMLC_DECLARE_ALIAS(reg_alpha, alpha);
219 changes: 160 additions & 59 deletions src/tree/updater_gpu_hist.cu
@@ -2,6 +2,9 @@
* Copyright 2017 XGBoost contributors
*/
#include <thrust/execution_policy.h>
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/reduce.h>
#include <thrust/sequence.h>
#include <xgboost/tree_updater.h>
@@ -224,6 +227,53 @@ struct CalcWeightTrainParam {
learning_rate(p.learning_rate) {}
};

// index of the first element in cuts greater than v, or n if none;
// cuts are ordered, and binary search is used
__device__ int upper_bound(const float* __restrict__ cuts, int n, float v) {
if (n == 0)
return 0;
if (cuts[n - 1] <= v)
return n;
if (cuts[0] > v)
return 0;
int left = 0, right = n - 1;
while (right - left > 1) {
int middle = left + (right - left) / 2;
if (cuts[middle] > v)
right = middle;
else
left = middle;
}
return right;
}

__global__ void compress_bin_ellpack_k
(common::CompressedBufferWriter wr, common::CompressedByteT* __restrict__ buffer,
const size_t* __restrict__ row_ptrs,
const RowBatch::Entry* __restrict__ entries,
const float* __restrict__ cuts, const size_t* __restrict__ cut_rows,
size_t base_row, size_t n_rows, size_t row_ptr_begin, size_t row_stride,
unsigned int null_gidx_value) {
size_t irow = threadIdx.x + size_t(blockIdx.x) * blockDim.x;
int ifeature = threadIdx.y + blockIdx.y * blockDim.y;
if (irow >= n_rows || ifeature >= row_stride)
return;
int row_size = static_cast<int>(row_ptrs[irow + 1] - row_ptrs[irow]);
unsigned int bin = null_gidx_value;
if (ifeature < row_size) {
RowBatch::Entry entry = entries[row_ptrs[irow] - row_ptr_begin + ifeature];
int feature = entry.index;
float fvalue = entry.fvalue;
const float *feature_cuts = &cuts[cut_rows[feature]];
int ncuts = cut_rows[feature + 1] - cut_rows[feature];
bin = upper_bound(feature_cuts, ncuts, fvalue);
if (bin >= ncuts)
bin = ncuts - 1;
bin += cut_rows[feature];
}
wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature);
}
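Each thread handles one (row, slot) pair: slots past the end of a row keep null_gidx_value (which equals n_bins), and a present entry is binned with the binary search above, clamped to the last cut of its feature, then offset by cut_rows[feature] so bin indices are globally unique across features. A worked example with made-up cut values:

// Illustrative only: suppose cut_rows = {0, 3, 5, 8, ...}, so feature 2 owns
// cuts_d[5..7], say {0.1f, 0.5f, 2.0f}.
//   fvalue = 0.3f -> upper_bound over {0.1, 0.5, 2.0} returns 1 -> bin = 5 + 1 = 6
//   fvalue = 9.0f -> upper_bound returns 3 (== ncuts), clamped to 2 -> bin = 5 + 2 = 7
//   slot beyond the row's length -> bin = null_gidx_value = n_bins
// The bin is then packed by wr.AtomicWriteSymbol at position
// (irow + base_row) * row_stride + ifeature of the compressed buffer.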

// Manage memory for a single GPU
struct DeviceShard {
struct Segment {
@@ -271,74 +321,117 @@ struct DeviceShard {
dh::CubMemory temp_memory;

DeviceShard(int device_idx, int normalised_device_idx,
const common::GHistIndexMatrix& gmat, bst_uint row_begin,
bst_uint row_end, int n_bins, TrainParam param)
: device_idx(device_idx),
normalised_device_idx(normalised_device_idx),
row_begin_idx(row_begin),
row_end_idx(row_end),
n_rows(row_end - row_begin),
n_bins(n_bins),
null_gidx_value(n_bins),
param(param),
prediction_cache_initialised(false) {
// Convert to ELLPACK matrix representation
int max_elements_row = 0;
for (auto i = row_begin; i < row_end; i++) {
max_elements_row =
(std::max)(max_elements_row,
static_cast<int>(gmat.row_ptr[i + 1] - gmat.row_ptr[i]));
}
row_stride = max_elements_row;
std::vector<int> ellpack_matrix(row_stride * n_rows, null_gidx_value);

for (auto i = row_begin; i < row_end; i++) {
int row_count = 0;
for (auto j = gmat.row_ptr[i]; j < gmat.row_ptr[i + 1]; j++) {
ellpack_matrix[(i - row_begin) * row_stride + row_count] =
gmat.index[j];
row_count++;
}
}

// Allocate
bst_uint row_begin, bst_uint row_end, int n_bins, TrainParam param)
: device_idx(device_idx),
normalised_device_idx(normalised_device_idx),
row_begin_idx(row_begin),
row_end_idx(row_end),
n_rows(row_end - row_begin),
n_bins(n_bins),
null_gidx_value(n_bins),
param(param),
prediction_cache_initialised(false) {}

void Init(const common::HistCutMatrix& hmat, const RowBatch& row_batch) {
// copy cuts to the GPU
dh::safe_cuda(cudaSetDevice(device_idx));
thrust::device_vector<float> cuts_d(hmat.cut);
thrust::device_vector<size_t> cut_row_ptrs_d(hmat.row_ptr);

// find the maximum row size
thrust::device_vector<size_t> row_ptr_d(
row_batch.ind_ptr + row_begin_idx, row_batch.ind_ptr + row_end_idx + 1);

auto row_iter = row_ptr_d.begin();
auto get_size = [=] __device__(size_t row) {
return row_iter[row + 1] - row_iter[row];
}; // NOLINT
auto counting = thrust::make_counting_iterator(size_t(0));
using TransformT = thrust::transform_iterator<decltype(get_size),
decltype(counting), size_t>;
TransformT row_size_iter = TransformT(counting, get_size);
row_stride = thrust::reduce(row_size_iter, row_size_iter + n_rows, 0,
thrust::maximum<size_t>());

// allocate compressed bin data
int num_symbols = n_bins + 1;
size_t compressed_size_bytes =
common::CompressedBufferWriter::CalculateBufferSize(
ellpack_matrix.size(), num_symbols);
common::CompressedBufferWriter::CalculateBufferSize(row_stride * n_rows,
num_symbols);

CHECK(!(param.max_leaves == 0 && param.max_depth == 0))
<< "Max leaves and max depth cannot both be unconstrained for "
"gpu_hist.";
ba.Allocate(device_idx, param.silent, &gidx_buffer, compressed_size_bytes);

gidx_buffer.Fill(0);

// bin and compress entries in batches of rows
// use no more than 1/16th of GPU memory per batch
size_t gpu_batch_nrows = dh::TotalMemory(device_idx) /
(16 * row_stride * sizeof(RowBatch::Entry));
if (gpu_batch_nrows > n_rows) {
gpu_batch_nrows = n_rows;
}
thrust::device_vector<RowBatch::Entry> entries_d(gpu_batch_nrows * row_stride);
size_t gpu_nbatches = dh::DivRoundUp(n_rows, gpu_batch_nrows);
for (size_t gpu_batch = 0; gpu_batch < gpu_nbatches; ++gpu_batch) {
size_t batch_row_begin = gpu_batch * gpu_batch_nrows;
size_t batch_row_end = (gpu_batch + 1) * gpu_batch_nrows;
if (batch_row_end > n_rows) {
batch_row_end = n_rows;
}
size_t batch_nrows = batch_row_end - batch_row_begin;
size_t n_entries =
row_batch.ind_ptr[row_begin_idx + batch_row_end] -
row_batch.ind_ptr[row_begin_idx + batch_row_begin];
dh::safe_cuda
(cudaMemcpy
(entries_d.data().get(),
&row_batch.data_ptr[row_batch.ind_ptr[row_begin_idx + batch_row_begin]],
n_entries * sizeof(RowBatch::Entry), cudaMemcpyDefault));
dim3 block3(32, 8, 1);
dim3 grid3(dh::DivRoundUp(n_rows, block3.x),
dh::DivRoundUp(row_stride, block3.y), 1);
compress_bin_ellpack_k<<<grid3, block3>>>
(common::CompressedBufferWriter(num_symbols), gidx_buffer.Data(),
row_ptr_d.data().get() + batch_row_begin,
entries_d.data().get(), cuts_d.data().get(), cut_row_ptrs_d.data().get(),
batch_row_begin, batch_nrows,
row_batch.ind_ptr[row_begin_idx + batch_row_begin],
row_stride, null_gidx_value);

dh::safe_cuda(cudaGetLastError());
dh::safe_cuda(cudaDeviceSynchronize());
}

// free the memory that is no longer needed
row_ptr_d.resize(0);
row_ptr_d.shrink_to_fit();
entries_d.resize(0);
entries_d.shrink_to_fit();

gidx = common::CompressedIterator<uint32_t>(gidx_buffer.Data(), num_symbols);

// allocate the rest
int max_nodes =
param.max_leaves > 0 ? param.max_leaves * 2 : MaxNodesDepth(param.max_depth);
ba.Allocate(device_idx, param.silent, &gidx_buffer, compressed_size_bytes,
ba.Allocate(device_idx, param.silent,
&gpair, n_rows, &ridx, n_rows, &position, n_rows,
&prediction_cache, n_rows, &node_sum_gradients_d, max_nodes,
&feature_segments, gmat.cut->row_ptr.size(), &gidx_fvalue_map,
gmat.cut->cut.size(), &min_fvalue, gmat.cut->min_val.size(),
&feature_segments, hmat.row_ptr.size(), &gidx_fvalue_map,
hmat.cut.size(), &min_fvalue, hmat.min_val.size(),
&monotone_constraints, param.monotone_constraints.size());
gidx_fvalue_map = gmat.cut->cut;
min_fvalue = gmat.cut->min_val;
feature_segments = gmat.cut->row_ptr;
gidx_fvalue_map = hmat.cut;
min_fvalue = hmat.min_val;
feature_segments = hmat.row_ptr;
monotone_constraints = param.monotone_constraints;

node_sum_gradients.resize(max_nodes);
ridx_segments.resize(max_nodes);

// Compress gidx
common::CompressedBufferWriter cbw(num_symbols);
std::vector<common::CompressedByteT> host_buffer(gidx_buffer.Size());
cbw.Write(host_buffer.data(), ellpack_matrix.begin(), ellpack_matrix.end());
gidx_buffer = host_buffer;
gidx =
common::CompressedIterator<uint32_t>(gidx_buffer.Data(), num_symbols);

common::CompressedIterator<uint32_t> ci_host(host_buffer.data(),
num_symbols);

// Init histogram
hist.Init(device_idx, max_nodes, gmat.cut->row_ptr.back(), param.silent);
hist.Init(device_idx, max_nodes, hmat.row_ptr.back(), param.silent);

dh::safe_cuda(cudaMallocHost(&tmp_pinned, sizeof(int64_t)));
}
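row_stride is the longest row in the shard, computed on the device by the transform_iterator/reduce above, and the staging buffer for raw entries is capped at roughly 1/16 of total device memory, so the whole row batch never has to reside on the GPU at once. Illustrative sizing, with numbers that are not from the commit:

// A 16 GB device, row_stride = 128, sizeof(RowBatch::Entry) = 8 bytes:
//   gpu_batch_nrows = 16e9 / (16 * 128 * 8) ≈ 976,562 rows per batch,
// so a shard of 10M rows is copied and compressed in ceil(10e6 / 976,562) = 11
// launches of compress_bin_ellpack_k, each followed by cudaDeviceSynchronize().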
@@ -579,8 +672,6 @@ class GPUHistMaker : public TreeUpdater {
info_ = &dmat->Info();
monitor_.Start("Quantiles", device_list_);
hmat_.Init(dmat, param_.max_bin);
gmat_.cut = &hmat_;
gmat_.Init(dmat);
monitor_.Stop("Quantiles", device_list_);
n_bins_ = hmat_.row_ptr.back();

@@ -609,12 +700,22 @@ class GPUHistMaker : public TreeUpdater {
row_begin = row_end;
}

// Create device shards
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
shard = std::unique_ptr<DeviceShard>(
new DeviceShard(device_list_[i], i, gmat_,
row_segments[i], row_segments[i + 1], n_bins_, param_));
});
monitor_.Start("BinningCompression", device_list_);
{
dmlc::DataIter<RowBatch>* iter = dmat->RowIterator();
iter->BeforeFirst();
CHECK(iter->Next()) << "Empty batches are not supported";
const RowBatch& batch = iter->Value();
// Create device shards
dh::ExecuteIndexShards(&shards_, [&](int i, std::unique_ptr<DeviceShard>& shard) {
shard = std::unique_ptr<DeviceShard>
(new DeviceShard(device_list_[i], i,
row_segments[i], row_segments[i + 1], n_bins_, param_));
shard->Init(hmat_, batch);
});
CHECK(!iter->Next()) << "External memory not supported";
}
monitor_.Stop("BinningCompression", device_list_);

p_last_fmat_ = dmat;
initialised_ = true;