From e6a9063344e429a3417adbfdfd9746664c9005a2 Mon Sep 17 00:00:00 2001
From: Rory Mitchell
Date: Fri, 8 Sep 2017 15:07:29 +1200
Subject: [PATCH] Integer gradient summation for GPU histogram algorithm.
 (#2681)

---
 include/xgboost/base.h                    | 101 +++++++++++++++-----
 src/common/hist_util.h                    |   4 +-
 src/gbm/gblinear.cc                       |  19 ++--
 src/objective/rank_obj.cc                 |   6 +-
 src/tree/param.h                          |   4 +-
 src/tree/updater_basemaker-inl.h          |   4 +-
 src/tree/updater_colmaker.cc              |   4 +-
 src/tree/updater_fast_hist.cc             |   4 +-
 src/tree/updater_gpu_common.cuh           |  12 +--
 src/tree/updater_gpu_hist.cu              | 109 +++++++++++-----------
 src/tree/updater_histmaker.cc             |  10 +-
 src/tree/updater_skmaker.cc               |  24 ++---
 tests/cpp/helpers.cc                      |   4 +-
 tests/cpp/predictor/test_gpu_predictor.cu |   2 +-
 tests/cpp/xgboost_test.mk                 |   1 -
 15 files changed, 181 insertions(+), 127 deletions(-)

diff --git a/include/xgboost/base.h b/include/xgboost/base.h
index 3497d19d7bf4..cc08bfebf3eb 100644
--- a/include/xgboost/base.h
+++ b/include/xgboost/base.h
@@ -87,65 +87,116 @@ typedef uint64_t bst_ulong;  // NOLINT(*)
 typedef float bst_float;
 
-/*! \brief Implementation of gradient statistics pair */
+namespace detail {
+/*! \brief Implementation of gradient statistics pair. Template specialisation
+ * may be used to overload different gradients types e.g. low precision, high
+ * precision, integer, floating point. */
 template <typename T>
-struct bst_gpair_internal {
+class bst_gpair_internal {
   /*! \brief gradient statistics */
-  T grad;
+  T grad_;
   /*! \brief second order gradient statistics */
-  T hess;
+  T hess_;
 
-  XGBOOST_DEVICE bst_gpair_internal() : grad(0), hess(0) {}
+  XGBOOST_DEVICE void SetGrad(float g) { grad_ = g; }
+  XGBOOST_DEVICE void SetHess(float h) { hess_ = h; }
 
-  XGBOOST_DEVICE bst_gpair_internal(T grad, T hess)
-      : grad(grad), hess(hess) {}
+ public:
+  typedef T value_t;
 
+  XGBOOST_DEVICE bst_gpair_internal() : grad_(0), hess_(0) {}
+
+  XGBOOST_DEVICE bst_gpair_internal(float grad, float hess) {
+    SetGrad(grad);
+    SetHess(hess);
+  }
+
+  // Copy constructor if of same value type
+  XGBOOST_DEVICE bst_gpair_internal(const bst_gpair_internal<T> &g)
+      : grad_(g.grad_), hess_(g.hess_) {}
+
+  // Copy constructor if different value type - use getters and setters to
+  // perform conversion
   template <typename T2>
-  XGBOOST_DEVICE bst_gpair_internal(bst_gpair_internal<T2> &g)
-      : grad(g.grad), hess(g.hess) {}
+  XGBOOST_DEVICE bst_gpair_internal(const bst_gpair_internal<T2> &g) {
+    SetGrad(g.GetGrad());
+    SetHess(g.GetHess());
+  }
+
+  XGBOOST_DEVICE float GetGrad() const { return grad_; }
+  XGBOOST_DEVICE float GetHess() const { return hess_; }
 
-  XGBOOST_DEVICE bst_gpair_internal &operator+=(const bst_gpair_internal &rhs) {
-    grad += rhs.grad;
-    hess += rhs.hess;
+  XGBOOST_DEVICE bst_gpair_internal<T> &operator+=(
+      const bst_gpair_internal<T> &rhs) {
+    grad_ += rhs.grad_;
+    hess_ += rhs.hess_;
     return *this;
   }
 
-  XGBOOST_DEVICE bst_gpair_internal operator+(const bst_gpair_internal &rhs) const {
+  XGBOOST_DEVICE bst_gpair_internal<T> operator+(
+      const bst_gpair_internal<T> &rhs) const {
     bst_gpair_internal<T> g;
-    g.grad = grad + rhs.grad;
-    g.hess = hess + rhs.hess;
+    g.grad_ = grad_ + rhs.grad_;
+    g.hess_ = hess_ + rhs.hess_;
     return g;
   }
 
-  XGBOOST_DEVICE bst_gpair_internal &operator-=(const bst_gpair_internal &rhs) {
-    grad -= rhs.grad;
-    hess -= rhs.hess;
+  XGBOOST_DEVICE bst_gpair_internal<T> &operator-=(
+      const bst_gpair_internal<T> &rhs) {
+    grad_ -= rhs.grad_;
+    hess_ -= rhs.hess_;
    return *this;
   }
 
-  XGBOOST_DEVICE bst_gpair_internal operator-(const bst_gpair_internal &rhs) const {
+  XGBOOST_DEVICE bst_gpair_internal<T> operator-(
+      const bst_gpair_internal<T> &rhs) const {
     bst_gpair_internal<T> g;
-    g.grad = grad - rhs.grad;
-    g.hess = hess - rhs.hess;
+    g.grad_ = grad_ - rhs.grad_;
+    g.hess_ = hess_ - rhs.hess_;
     return g;
   }
 
   XGBOOST_DEVICE bst_gpair_internal(int value) {
-    *this = bst_gpair_internal(static_cast<T>(value), static_cast<T>(value));
+    *this = bst_gpair_internal(static_cast<float>(value),
+                               static_cast<float>(value));
   }
 
   friend std::ostream &operator<<(std::ostream &os,
                                   const bst_gpair_internal &g) {
-    os << g.grad << "/" << g.hess;
+    os << g.grad_ << "/" << g.hess_;
     return os;
   }
 };
 
+template<>
+inline XGBOOST_DEVICE float bst_gpair_internal<int64_t>::GetGrad() const {
+  return grad_ * 1e-5;
+}
+template<>
+inline XGBOOST_DEVICE float bst_gpair_internal<int64_t>::GetHess() const {
+  return hess_ * 1e-5;
+}
+template<>
+inline XGBOOST_DEVICE void bst_gpair_internal<int64_t>::SetGrad(float g) {
+  grad_ = g * 1e5;
+}
+template<>
+inline XGBOOST_DEVICE void bst_gpair_internal<int64_t>::SetHess(float h) {
+  hess_ = h * 1e5;
+}
+
+}  // namespace detail
+
 /*! \brief gradient statistics pair usually needed in gradient boosting */
-typedef bst_gpair_internal<float> bst_gpair;
+typedef detail::bst_gpair_internal<float> bst_gpair;
 /*! \brief High precision gradient statistics pair */
-typedef bst_gpair_internal<double> bst_gpair_precise;
+typedef detail::bst_gpair_internal<double> bst_gpair_precise;
+
+/*! \brief High precision gradient statistics pair with integer backed
+ * storage. Operators are associative where floating point versions are not
+ * associative. */
+typedef detail::bst_gpair_internal<int64_t> bst_gpair_integer;
 
 /*! \brief small eps gap for minimum split decision. */
 const bst_float rt_eps = 1e-6f;
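The int64_t specialisations above implement a simple fixed-point encoding: floats are scaled by 1e5 when stored and by 1e-5 when read, so all accumulation happens in integers, where addition is associative. A minimal standalone sketch of that scheme (FixedPointPair is a hypothetical stand-in for bst_gpair_integer; the 1e5 scale factor is the one used in this patch):

// Standalone illustration of the fixed-point encoding used by
// bst_gpair_internal<int64_t>: floats are quantised to 1/100000 steps on
// the way in and rescaled on the way out, so sums are computed in integers.
#include <cstdint>
#include <iostream>

struct FixedPointPair {  // hypothetical name, mirrors the specialisation
  int64_t grad_ = 0, hess_ = 0;
  void Add(float g, float h) {
    grad_ += static_cast<int64_t>(g * 1e5f);  // quantise
    hess_ += static_cast<int64_t>(h * 1e5f);
  }
  float GetGrad() const { return grad_ * 1e-5f; }
  float GetHess() const { return hess_ * 1e-5f; }
};

int main() {
  // Integer addition is associative, so any summation order of the
  // quantised gradients yields bit-identical results across runs.
  FixedPointPair sum;
  sum.Add(0.5f, 1.0f);
  sum.Add(-0.25f, 2.0f);
  std::cout << sum.GetGrad() << "/" << sum.GetHess() << "\n";  // 0.25/3
}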
diff --git a/src/common/hist_util.h b/src/common/hist_util.h
index 4d5456e8523c..65f367aa8fc9 100644
--- a/src/common/hist_util.h
+++ b/src/common/hist_util.h
@@ -33,8 +33,8 @@ struct GHistEntry {
 
   /*! \brief add a bst_gpair to the sum */
   inline void Add(const bst_gpair& e) {
-    sum_grad += e.grad;
-    sum_hess += e.hess;
+    sum_grad += e.GetGrad();
+    sum_hess += e.GetHess();
   }
 
   /*! \brief add a GHistEntry to the sum */

diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc
index 5c1f2474a48b..d839532adb97 100644
--- a/src/gbm/gblinear.cc
+++ b/src/gbm/gblinear.cc
@@ -120,8 +120,9 @@ class GBLinear : public GradientBooster {
 #pragma omp parallel for schedule(static) reduction(+: sum_grad, sum_hess)
         for (bst_omp_uint i = 0; i < ndata; ++i) {
           bst_gpair &p = gpair[rowset[i] * ngroup + gid];
-          if (p.hess >= 0.0f) {
-            sum_grad += p.grad; sum_hess += p.hess;
+          if (p.GetHess() >= 0.0f) {
+            sum_grad += p.GetGrad();
+            sum_hess += p.GetHess();
           }
         }
         // remove bias effect
@@ -132,8 +133,8 @@ class GBLinear : public GradientBooster {
 #pragma omp parallel for schedule(static)
         for (bst_omp_uint i = 0; i < ndata; ++i) {
           bst_gpair &p = gpair[rowset[i] * ngroup + gid];
-          if (p.hess >= 0.0f) {
-            p.grad += p.hess * dw;
+          if (p.GetHess() >= 0.0f) {
+            p += bst_gpair(p.GetHess() * dw, 0);
           }
         }
       }
@@ -151,9 +152,9 @@ class GBLinear : public GradientBooster {
         for (bst_uint j = 0; j < col.length; ++j) {
           const bst_float v = col[j].fvalue;
           bst_gpair &p = gpair[col[j].index * ngroup + gid];
-          if (p.hess < 0.0f) continue;
-          sum_grad += p.grad * v;
-          sum_hess += p.hess * v * v;
+          if (p.GetHess() < 0.0f) continue;
+          sum_grad += p.GetGrad() * v;
+          sum_hess += p.GetHess() * v * v;
         }
         bst_float &w = model[fid][gid];
         bst_float dw = static_cast<bst_float>(param.learning_rate *
@@ -162,8 +163,8 @@ class GBLinear : public GradientBooster {
         // update grad value
         for (bst_uint j = 0; j < col.length; ++j) {
           bst_gpair &p = gpair[col[j].index * ngroup + gid];
-          if (p.hess < 0.0f) continue;
-          p.grad += p.hess * col[j].fvalue * dw;
+          if (p.GetHess() < 0.0f) continue;
+          p += bst_gpair(p.GetHess() * col[j].fvalue * dw, 0);
         }
       }
     }
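Because grad and hess are now private, call sites such as gblinear express in-place gradient updates through operator+= with a (delta, 0) pair instead of writing p.grad directly. A small self-contained sketch of the equivalence (Pair is a stand-in for bst_gpair, not part of this patch):

// The old p.grad += p.hess * dw becomes p += Pair(p.GetHess() * dw, 0):
// adding a pair whose hessian component is zero updates only the gradient.
#include <iostream>

struct Pair {  // hypothetical stand-in for bst_gpair
  float grad, hess;
  Pair(float g, float h) : grad(g), hess(h) {}
  Pair& operator+=(const Pair& rhs) {
    grad += rhs.grad;
    hess += rhs.hess;
    return *this;
  }
  float GetGrad() const { return grad; }
  float GetHess() const { return hess; }
};

int main() {
  float dw = 0.1f;
  Pair p(0.5f, 2.0f);
  p += Pair(p.GetHess() * dw, 0.0f);  // same effect as p.grad += p.hess * dw
  std::cout << p.GetGrad() << "\n";   // 0.7
}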
diff --git a/src/objective/rank_obj.cc b/src/objective/rank_obj.cc
index bf1ef0a7a780..76ce3ad72e3f 100644
--- a/src/objective/rank_obj.cc
+++ b/src/objective/rank_obj.cc
@@ -109,10 +109,8 @@ class LambdaRankObj : public ObjFunction {
         bst_float g = p - 1.0f;
         bst_float h = std::max(p * (1.0f - p), eps);
         // accumulate gradient and hessian in both pid, and nid
-        gpair[pos.rindex].grad += g * w;
-        gpair[pos.rindex].hess += 2.0f * w * h;
-        gpair[neg.rindex].grad -= g * w;
-        gpair[neg.rindex].hess += 2.0f * w * h;
+        gpair[pos.rindex] += bst_gpair(g * w, 2.0f*w*h);
+        gpair[neg.rindex] += bst_gpair(-g * w, 2.0f*w*h);
       }
     }
   }

diff --git a/src/tree/param.h b/src/tree/param.h
index 8995c9ee9674..646955397289 100644
--- a/src/tree/param.h
+++ b/src/tree/param.h
@@ -313,7 +313,7 @@ struct XGBOOST_ALIGNAS(16) GradStats {
    * \brief accumulate statistics
    * \param p the gradient pair
    */
-  inline void Add(bst_gpair p) { this->Add(p.grad, p.hess); }
+  inline void Add(bst_gpair p) { this->Add(p.GetGrad(), p.GetHess()); }
   /*!
    * \brief accumulate statistics, more complicated version
    * \param gpair the vector storing the gradient statistics
@@ -323,7 +323,7 @@ struct XGBOOST_ALIGNAS(16) GradStats {
   inline void Add(const std::vector<bst_gpair>& gpair, const MetaInfo& info,
                   bst_uint ridx) {
     const bst_gpair& b = gpair[ridx];
-    this->Add(b.grad, b.hess);
+    this->Add(b.GetGrad(), b.GetHess());
   }
   /*! \brief calculate leaf weight */
   inline double CalcWeight(const TrainParam& param) const {
diff --git a/src/tree/updater_basemaker-inl.h b/src/tree/updater_basemaker-inl.h
index f70f63ad09ee..9f4ae48eae0d 100644
--- a/src/tree/updater_basemaker-inl.h
+++ b/src/tree/updater_basemaker-inl.h
@@ -140,14 +140,14 @@ class BaseMaker: public TreeUpdater {
     }
     // mark delete for the deleted datas
     for (size_t i = 0; i < position.size(); ++i) {
-      if (gpair[i].hess < 0.0f) position[i] = ~position[i];
+      if (gpair[i].GetHess() < 0.0f) position[i] = ~position[i];
     }
     // mark subsample
     if (param.subsample < 1.0f) {
       std::bernoulli_distribution coin_flip(param.subsample);
       auto& rnd = common::GlobalRandom();
       for (size_t i = 0; i < position.size(); ++i) {
-        if (gpair[i].hess < 0.0f) continue;
+        if (gpair[i].GetHess() < 0.0f) continue;
         if (!coin_flip(rnd)) position[i] = ~position[i];
       }
     }

diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc
index 18746808f183..bd4de564b8db 100644
--- a/src/tree/updater_colmaker.cc
+++ b/src/tree/updater_colmaker.cc
@@ -136,7 +136,7 @@ class ColMaker: public TreeUpdater {
       // mark delete for the deleted datas
       for (size_t i = 0; i < rowset.size(); ++i) {
         const bst_uint ridx = rowset[i];
-        if (gpair[ridx].hess < 0.0f) position[ridx] = ~position[ridx];
+        if (gpair[ridx].GetHess() < 0.0f) position[ridx] = ~position[ridx];
       }
       // mark subsample
       if (param.subsample < 1.0f) {
@@ -144,7 +144,7 @@ class ColMaker: public TreeUpdater {
         auto& rnd = common::GlobalRandom();
         for (size_t i = 0; i < rowset.size(); ++i) {
           const bst_uint ridx = rowset[i];
-          if (gpair[ridx].hess < 0.0f) continue;
+          if (gpair[ridx].GetHess() < 0.0f) continue;
           if (!coin_flip(rnd)) position[ridx] = ~position[ridx];
         }
       }

diff --git a/src/tree/updater_fast_hist.cc b/src/tree/updater_fast_hist.cc
index 3f1c6c5ee98c..70d39b60baf7 100644
--- a/src/tree/updater_fast_hist.cc
+++ b/src/tree/updater_fast_hist.cc
@@ -372,13 +372,13 @@ class FastHistMaker: public TreeUpdater {
       std::bernoulli_distribution coin_flip(param.subsample);
       auto& rnd = common::GlobalRandom();
       for (size_t i = 0; i < info.num_row; ++i) {
-        if (gpair[i].hess >= 0.0f && coin_flip(rnd)) {
+        if (gpair[i].GetHess() >= 0.0f && coin_flip(rnd)) {
          row_indices.push_back(i);
         }
       }
     } else {
       for (size_t i = 0; i < info.num_row; ++i) {
-        if (gpair[i].hess >= 0.0f) {
+        if (gpair[i].GetHess() >= 0.0f) {
           row_indices.push_back(i);
         }
       }

diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh
index dba91dea6f9b..0b203558832e 100644
--- a/src/tree/updater_gpu_common.cuh
+++ b/src/tree/updater_gpu_common.cuh
@@ -82,8 +82,8 @@ struct DeviceDenseNode {
         fvalue(0.f),
         fidx(UNUSED_NODE),
         idx(nidx) {
-    this->root_gain = CalcGain(param, sum_gradients.grad, sum_gradients.hess);
-    this->weight = CalcWeight(param, sum_gradients.grad, sum_gradients.hess);
+    this->root_gain = CalcGain(param, sum_gradients.GetGrad(), sum_gradients.GetHess());
+    this->weight = CalcWeight(param, sum_gradients.GetGrad(), sum_gradients.GetHess());
   }
 
   HOST_DEV_INLINE void SetSplit(float fvalue, int fidx, DefaultDirection dir) {
@@ -113,8 +113,8 @@ __device__ inline float device_calc_loss_chg(
 
   gpair_t right = parent_sum - left;
 
-  float left_gain = CalcGain(param, left.grad, left.hess);
-  float right_gain = CalcGain(param, right.grad, right.hess);
+  float left_gain = CalcGain(param, left.GetGrad(), left.GetHess());
+  float right_gain = CalcGain(param, right.GetGrad(), right.GetHess());
   return left_gain + right_gain - parent_gain;
 }
 
@@ -181,13 +181,13 @@ inline void dense2sparse_tree(RegTree* p_tree,
       tree[nid].set_split(n.fidx, n.fvalue, n.dir == LeftDir);
       tree.stat(nid).loss_chg = n.root_gain;
       tree.stat(nid).base_weight = n.weight;
-      tree.stat(nid).sum_hess = n.sum_gradients.hess;
+      tree.stat(nid).sum_hess = n.sum_gradients.GetHess();
       tree[tree[nid].cleft()].set_leaf(0);
       tree[tree[nid].cright()].set_leaf(0);
       nid++;
     } else if (n.IsLeaf()) {
       tree[nid].set_leaf(n.weight * param.learning_rate);
-      tree.stat(nid).sum_hess = n.sum_gradients.hess;
+      tree.stat(nid).sum_hess = n.sum_gradients.GetHess();
       nid++;
     }
   }
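For reference, the CalcGain and CalcWeight calls above are the usual second-order approximations, gain = G^2 / (H + lambda) and w* = -G / (H + lambda); this hunk only swaps direct member access for GetGrad()/GetHess(). A minimal sketch of those formulas (simplified: no min_child_weight clamp or L1 term):

// Simplified versions of the gain/weight formulas fed by GetGrad()/GetHess().
// lambda is the L2 regulariser (reg_lambda); real xgboost adds more terms.
#include <cstdio>

float CalcGain(float lambda, float G, float H) { return G * G / (H + lambda); }
float CalcWeight(float lambda, float G, float H) { return -G / (H + lambda); }

int main() {
  float G = -4.0f, H = 10.0f, lambda = 1.0f;
  printf("gain=%f weight=%f\n", CalcGain(lambda, G, H),
         CalcWeight(lambda, G, H));  // gain=1.454545 weight=0.363636
}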
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index be152a3dd80d..a18c5708670f 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -5,17 +5,20 @@
 #include
 #include
 #include
-#include "param.h"
 #include "../common/compressed_iterator.h"
+#include "../common/device_helpers.cuh"
 #include "../common/hist_util.h"
+#include "param.h"
 #include "updater_gpu_common.cuh"
-#include "../common/device_helpers.cuh"
 
 namespace xgboost {
 namespace tree {
 
 DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
 
+typedef bst_gpair_integer gpair_sum_t;
+static const ncclDataType_t nccl_sum_t = ncclInt64;
+
 // Helper for explicit template specialisation
 template <int N>
 struct Int {};
@@ -50,27 +53,29 @@ struct DeviceGMat {
 };
 
 struct HistHelper {
-  bst_gpair* d_hist;
+  gpair_sum_t* d_hist;
   int n_bins;
-  __host__ __device__ HistHelper(bst_gpair* ptr, int n_bins)
+  __host__ __device__ HistHelper(gpair_sum_t* ptr, int n_bins)
       : d_hist(ptr), n_bins(n_bins) {}
+
   __device__ void Add(bst_gpair gpair, int gidx, int nidx) const {
     int hist_idx = nidx * n_bins + gidx;
-    atomicAdd(&(d_hist[hist_idx].grad), gpair.grad);  // OPTMARK: This and below
-                                                      // line lead to about 3X
-                                                      // slowdown due to memory
-                                                      // dependency and access
-                                                      // pattern issues.
-    atomicAdd(&(d_hist[hist_idx].hess), gpair.hess);
+
+    auto dst_ptr = reinterpret_cast<unsigned long long int*>(&d_hist[hist_idx]);  // NOLINT
+    gpair_sum_t tmp(gpair.GetGrad(), gpair.GetHess());
+    auto src_ptr = reinterpret_cast<gpair_sum_t::value_t*>(&tmp);
+
+    atomicAdd(dst_ptr, static_cast<unsigned long long int>(*src_ptr));  // NOLINT
+    atomicAdd(dst_ptr + 1, static_cast<unsigned long long int>(*(src_ptr + 1)));  // NOLINT
   }
-  __device__ bst_gpair Get(int gidx, int nidx) const {
+  __device__ gpair_sum_t Get(int gidx, int nidx) const {
     return d_hist[nidx * n_bins + gidx];
   }
 };
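HistHelper::Add above leans on the fact that CUDA provides a native 64-bit atomicAdd only for unsigned long long int; reinterpreting the signed int64_t bins is safe because two's-complement addition produces the same bit pattern either way. A minimal CUDA sketch of that trick (accumulate is a hypothetical kernel, not part of this patch):

// Signed 64-bit atomic accumulation via the unsigned long long overload,
// the same pattern HistHelper::Add uses per histogram component.
#include <cstdint>
#include <cstdio>

__global__ void accumulate(int64_t* counter, int64_t value) {
  auto dst = reinterpret_cast<unsigned long long int*>(counter);
  atomicAdd(dst, static_cast<unsigned long long int>(value));
}

int main() {
  int64_t* d_counter;
  cudaMalloc(&d_counter, sizeof(int64_t));
  cudaMemset(d_counter, 0, sizeof(int64_t));
  accumulate<<<32, 256>>>(d_counter, -3);  // 8192 atomic adds of -3
  int64_t h_counter = 0;
  cudaMemcpy(&h_counter, d_counter, sizeof(int64_t), cudaMemcpyDeviceToHost);
  printf("%lld\n", static_cast<long long>(h_counter));  // -24576
  cudaFree(d_counter);
  return 0;
}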
 
 struct DeviceHist {
   int n_bins;
-  dh::dvec<bst_gpair> data;
+  dh::dvec<gpair_sum_t> data;
 
   void Init(int n_bins_in) {
     this->n_bins = n_bins_in;
@@ -79,12 +84,12 @@ struct DeviceHist {
 
   void Reset(int device_idx) {
     cudaSetDevice(device_idx);
-    data.fill(bst_gpair());
+    data.fill(gpair_sum_t());
   }
 
   HistHelper GetBuilder() { return HistHelper(data.data(), n_bins); }
 
-  bst_gpair* GetLevelPtr(int depth) {
+  gpair_sum_t* GetLevelPtr(int depth) {
     return data.data() + n_nodes(depth - 1) * n_bins;
   }
 
@@ -96,18 +101,19 @@ struct SplitCandidate {
   bool missing_left;
   float fvalue;
   int findex;
-  bst_gpair left_sum;
-  bst_gpair right_sum;
+  gpair_sum_t left_sum;
+  gpair_sum_t right_sum;
 
   __host__ __device__ SplitCandidate()
       : loss_chg(-FLT_MAX), missing_left(true), fvalue(0), findex(-1) {}
 
   __device__ void Update(float loss_chg_in, bool missing_left_in,
-                         float fvalue_in, int findex_in, bst_gpair left_sum_in,
-                         bst_gpair right_sum_in,
+                         float fvalue_in, int findex_in,
+                         gpair_sum_t left_sum_in, gpair_sum_t right_sum_in,
                          const GPUTrainingParam& param) {
-    if (loss_chg_in > loss_chg && left_sum_in.hess >= param.min_child_weight &&
-        right_sum_in.hess >= param.min_child_weight) {
+    if (loss_chg_in > loss_chg &&
+        left_sum_in.GetHess() >= param.min_child_weight &&
+        right_sum_in.GetHess() >= param.min_child_weight) {
       loss_chg = loss_chg_in;
       missing_left = missing_left_in;
       fvalue = fvalue_in;
@@ -121,11 +127,11 @@ struct SplitCandidate {
 
 struct GpairCallbackOp {
   // Running prefix
-  bst_gpair running_total;
+  gpair_sum_t running_total;
   // Constructor
-  __device__ GpairCallbackOp() : running_total(bst_gpair()) {}
+  __device__ GpairCallbackOp() : running_total(gpair_sum_t()) {}
   __device__ bst_gpair operator()(bst_gpair block_aggregate) {
-    bst_gpair old_prefix = running_total;
+    gpair_sum_t old_prefix = running_total;
     running_total += block_aggregate;
     return old_prefix;
   }
 };
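GpairCallbackOp follows cub's block-prefix-callback contract: BlockScan invokes the functor once per tile with that tile's aggregate, and the functor returns the running total from before the tile, which seeds the next tile's exclusive scan. A host-side sketch of that contract with plain integers (RunningPrefix is a hypothetical stand-in):

// Host-side model of the prefix-callback semantics GpairCallbackOp
// implements on the device: return the pre-tile total, then advance it.
#include <iostream>

struct RunningPrefix {  // hypothetical stand-in for GpairCallbackOp
  long long running_total = 0;
  long long operator()(long long tile_aggregate) {
    long long old_prefix = running_total;
    running_total += tile_aggregate;
    return old_prefix;
  }
};

int main() {
  RunningPrefix prefix_op;
  long long tiles[3] = {5, 7, 2};        // per-tile aggregates
  for (long long t : tiles)
    std::cout << prefix_op(t) << " ";    // exclusive prefixes: 0 5 12
}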
 
 template <int BLOCK_THREADS>
 __global__ void find_split_kernel(
-    const bst_gpair* d_level_hist, int* d_feature_segments, int depth,
+    const gpair_sum_t* d_level_hist, int* d_feature_segments, int depth,
     int n_features, int n_bins, DeviceDenseNode* d_nodes,
     int nodes_offset_device, float* d_fidx_min_map, float* d_gidx_fvalue_map,
     GPUTrainingParam gpu_param, bool* d_left_child_smallest_temp,
     bool colsample, int* d_feature_flags) {
   typedef cub::KeyValuePair<int, float> ArgMaxT;
-  typedef cub::BlockScan<bst_gpair, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>
+  typedef cub::BlockScan<gpair_sum_t, BLOCK_THREADS, cub::BLOCK_SCAN_WARP_SCANS>
       BlockScanT;
   typedef cub::BlockReduce<ArgMaxT, BLOCK_THREADS> MaxReduceT;
 
-  typedef cub::BlockReduce<bst_gpair, BLOCK_THREADS> SumReduceT;
+  typedef cub::BlockReduce<gpair_sum_t, BLOCK_THREADS> SumReduceT;
 
   union TempStorage {
     typename BlockScanT::TempStorage scan;
     typename MaxReduceT::TempStorage max_reduce;
     typename SumReduceT::TempStorage sum_reduce;
   };
 
   __shared__ cub::Uninitialized<SplitCandidate> uninitialized_split;
   SplitCandidate& split = uninitialized_split.Alias();
-  __shared__ cub::Uninitialized<bst_gpair> uninitialized_sum;
-  bst_gpair& shared_sum = uninitialized_sum.Alias();
+  __shared__ cub::Uninitialized<gpair_sum_t> uninitialized_sum;
+  gpair_sum_t& shared_sum = uninitialized_sum.Alias();
   __shared__ ArgMaxT block_max;
   __shared__ TempStorage temp_storage;
@@ -175,14 +180,13 @@ __global__ void find_split_kernel(
     int begin = d_feature_segments[level_node_idx * n_features + fidx];
     int end = d_feature_segments[level_node_idx * n_features + fidx + 1];
 
-    bst_gpair feature_sum = bst_gpair();
+    gpair_sum_t feature_sum = gpair_sum_t();
     for (int reduce_begin = begin; reduce_begin < end;
          reduce_begin += BLOCK_THREADS) {
       bool thread_active = reduce_begin + threadIdx.x < end;
       // Scan histogram
-      bst_gpair bin = thread_active
-                          ? d_level_hist[reduce_begin + threadIdx.x]
-                          : bst_gpair();
+      gpair_sum_t bin = thread_active ? d_level_hist[reduce_begin + threadIdx.x]
+                                      : gpair_sum_t();
 
       feature_sum +=
           SumReduceT(temp_storage.sum_reduce).Reduce(bin, cub::Sum());
@@ -197,18 +201,17 @@ __global__ void find_split_kernel(
     for (int scan_begin = begin; scan_begin < end;
          scan_begin += BLOCK_THREADS) {
       bool thread_active = scan_begin + threadIdx.x < end;
 
-      bst_gpair bin = thread_active
-                          ? d_level_hist[scan_begin + threadIdx.x]
-                          : bst_gpair();
+      gpair_sum_t bin = thread_active ? d_level_hist[scan_begin + threadIdx.x]
+                                      : gpair_sum_t();
 
       BlockScanT(temp_storage.scan)
           .ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
 
       // Calculate gain
-      bst_gpair parent_sum = d_nodes[node_idx].sum_gradients;
+      gpair_sum_t parent_sum = gpair_sum_t(d_nodes[node_idx].sum_gradients);
 
       float parent_gain = d_nodes[node_idx].root_gain;
 
-      bst_gpair missing = parent_sum - shared_sum;
+      gpair_sum_t missing = parent_sum - shared_sum;
 
       bool missing_left;
       float gain = thread_active
@@ -239,8 +242,8 @@ __global__ void find_split_kernel(
           fvalue = d_gidx_fvalue_map[gidx - 1];
         }
 
-        bst_gpair left = missing_left ? bin + missing : bin;
-        bst_gpair right = parent_sum - left;
+        gpair_sum_t left = missing_left ? bin + missing : bin;
+        gpair_sum_t right = parent_sum - left;
 
         split.Update(gain, missing_left, fvalue, fidx, left, right, gpu_param);
       }
@@ -263,7 +266,7 @@ __global__ void find_split_kernel(
             DeviceDenseNode(split.right_sum, right_child_nidx(node_idx), gpu_param);
 
         // Record smallest node
-        if (split.left_sum.hess <= split.right_sum.hess) {
+        if (split.left_sum.GetHess() <= split.right_sum.GetHess()) {
           left_child_smallest = true;
         } else {
           left_child_smallest = false;
@@ -595,6 +598,7 @@ class GPUHistMaker : public TreeUpdater {
 
     initialised = true;
   }
+
   void BuildHist(int depth) {
     for (int d_idx = 0; d_idx < n_devices; d_idx++) {
       int device_idx = dList[d_idx];
@@ -650,9 +654,9 @@ class GPUHistMaker : public TreeUpdater {
       dh::safe_nccl(ncclAllReduce(
           reinterpret_cast<const void*>(hist_vec[d_idx].GetLevelPtr(depth)),
           reinterpret_cast<void*>(hist_vec[d_idx].GetLevelPtr(depth)),
-          hist_vec[d_idx].LevelSize(depth) * sizeof(bst_gpair) /
-              sizeof(float),
-          ncclFloat, ncclSum, comms[d_idx], *(streams[d_idx])));
+          hist_vec[d_idx].LevelSize(depth) * sizeof(gpair_sum_t) /
+              sizeof(gpair_sum_t::value_t),
+          nccl_sum_t, ncclSum, comms[d_idx], *(streams[d_idx])));
     }
 
     for (int d_idx = 0; d_idx < n_devices; d_idx++) {
@@ -683,11 +687,12 @@ class GPUHistMaker : public TreeUpdater {
         }
 
         int gidx = idx % hist_builder.n_bins;
-        bst_gpair parent = hist_builder.Get(gidx, parent_nidx(nidx));
+        gpair_sum_t parent = hist_builder.Get(gidx, parent_nidx(nidx));
         int other_nidx = left_smallest ? nidx - 1 : nidx + 1;
-        bst_gpair other = hist_builder.Get(gidx, other_nidx);
+        gpair_sum_t other = hist_builder.Get(gidx, other_nidx);
+        gpair_sum_t sub = parent - other;
         hist_builder.Add(
-            parent - other, gidx,
+            bst_gpair(sub.GetGrad(), sub.GetHess()), gidx,
             nidx);  // OPTMARK: This is slow, could use shared
                     // memory or cache results intead of writing to
                     // global memory every time in atomic way.
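The loop above is the standard histogram subtraction trick: only the smaller child's histogram is accumulated from data, and its sibling is derived bin-by-bin as parent minus other. With integer-backed bins the derived histogram is exact rather than subject to float cancellation. A minimal sketch:

// Histogram subtraction trick with quantised int64_t bins: the larger
// child's histogram is recovered exactly as parent - smaller_child.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> parent = {100, 250, 75};       // quantised grad per bin
  std::vector<int64_t> small_child = {40, 100, 25};   // built from data
  std::vector<int64_t> large_child(parent.size());
  for (size_t i = 0; i < parent.size(); ++i)
    large_child[i] = parent[i] - small_child[i];      // exact in integers
  for (int64_t v : large_child) std::cout << v << " ";  // 60 150 50
}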
@@ -737,11 +742,11 @@ class GPUHistMaker : public TreeUpdater {
     int nodes_offset_device = 0;
     find_split_kernel<BLOCK_THREADS><<<GRID_SIZE, BLOCK_THREADS>>>(
-        (const bst_gpair*)(hist_vec[d_idx].GetLevelPtr(depth)),
-        feature_segments[d_idx].data(), depth, (info->num_col),
-        (hmat_.row_ptr.back()), nodes[d_idx].data(), nodes_offset_device,
-        fidx_min_map[d_idx].data(), gidx_fvalue_map[d_idx].data(),
-        GPUTrainingParam(param), left_child_smallest[d_idx].data(), colsample,
+        hist_vec[d_idx].GetLevelPtr(depth), feature_segments[d_idx].data(),
+        depth, (info->num_col), (hmat_.row_ptr.back()), nodes[d_idx].data(),
+        nodes_offset_device, fidx_min_map[d_idx].data(),
+        gidx_fvalue_map[d_idx].data(), GPUTrainingParam(param),
+        left_child_smallest[d_idx].data(), colsample,
         feature_flags[d_idx].data());
   }

diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc
index 80e8b5495d38..d4f011d06e0c 100644
--- a/src/tree/updater_histmaker.cc
+++ b/src/tree/updater_histmaker.cc
@@ -568,7 +568,7 @@ class CQHistMaker: public HistMaker<TStats> {
         const bst_uint ridx = c[j].index;
         const int nid = this->position[ridx];
         if (nid >= 0) {
-          sbuilder[nid].sum_total += gpair[ridx].hess;
+          sbuilder[nid].sum_total += gpair[ridx].GetHess();
         }
       }
       // if only one value, no need to do second pass
@@ -595,7 +595,7 @@ class CQHistMaker: public HistMaker<TStats> {
         for (bst_uint i = 0; i < kBuffer; ++i) {
           bst_uint ridx = c[j + i].index;
           buf_position[i] = this->position[ridx];
-          buf_hess[i] = gpair[ridx].hess;
+          buf_hess[i] = gpair[ridx].GetHess();
         }
         for (bst_uint i = 0; i < kBuffer; ++i) {
           const int nid = buf_position[i];
@@ -608,7 +608,7 @@ class CQHistMaker: public HistMaker<TStats> {
           const bst_uint ridx = c[j].index;
           const int nid = this->position[ridx];
           if (nid >= 0) {
-            sbuilder[nid].Push(c[j].fvalue, gpair[ridx].hess, max_size);
+            sbuilder[nid].Push(c[j].fvalue, gpair[ridx].GetHess(), max_size);
           }
         }
       } else {
@@ -616,7 +616,7 @@ class CQHistMaker: public HistMaker<TStats> {
         const bst_uint ridx = c[j].index;
         const int nid = this->position[ridx];
         if (nid >= 0) {
-          sbuilder[nid].Push(c[j].fvalue, gpair[ridx].hess, max_size);
+          sbuilder[nid].Push(c[j].fvalue, gpair[ridx].GetHess(), max_size);
         }
       }
     }
@@ -818,7 +818,7 @@ class QuantileHistMaker: public HistMaker<TStats> {
       for (size_t i = col_ptr[k]; i < col_ptr[k+1]; ++i) {
         const SparseBatch::Entry &e = col_data[i];
         const int wid = this->node2workindex[e.index];
-        sketchs[wid * tree.param.num_feature + k].Push(e.fvalue, gpair[e.index].hess);
+        sketchs[wid * tree.param.num_feature + k].Push(e.fvalue, gpair[e.index].GetHess());
       }
     }
   }
diff --git a/src/tree/updater_skmaker.cc b/src/tree/updater_skmaker.cc
index daf4e1e839e1..1994cb6d3f69 100644
--- a/src/tree/updater_skmaker.cc
+++ b/src/tree/updater_skmaker.cc
@@ -98,12 +98,12 @@ class SketchMaker: public BaseMaker {
                     const MetaInfo &info,
                     bst_uint ridx) {
       const bst_gpair &b = gpair[ridx];
-      if (b.grad >= 0.0f) {
-        pos_grad += b.grad;
+      if (b.GetGrad() >= 0.0f) {
+        pos_grad += b.GetGrad();
       } else {
-        neg_grad -= b.grad;
+        neg_grad -= b.GetGrad();
       }
-      sum_hess += b.hess;
+      sum_hess += b.GetHess();
     }
     /*! \brief calculate gain of the solution */
     inline double CalcGain(const TrainParam &param) const {
@@ -199,12 +199,12 @@ class SketchMaker: public BaseMaker {
         const int nid = this->position[ridx];
         if (nid >= 0) {
           const bst_gpair &e = gpair[ridx];
-          if (e.grad >= 0.0f) {
-            sbuilder[3 * nid + 0].sum_total += e.grad;
+          if (e.GetGrad() >= 0.0f) {
+            sbuilder[3 * nid + 0].sum_total += e.GetGrad();
           } else {
-            sbuilder[3 * nid + 1].sum_total -= e.grad;
+            sbuilder[3 * nid + 1].sum_total -= e.GetGrad();
           }
-          sbuilder[3 * nid + 2].sum_total += e.hess;
+          sbuilder[3 * nid + 2].sum_total += e.GetHess();
         }
       }
     } else {
@@ -241,12 +241,12 @@ class SketchMaker: public BaseMaker {
         const int nid = this->position[ridx];
         if (nid >= 0) {
           const bst_gpair &e = gpair[ridx];
-          if (e.grad >= 0.0f) {
-            sbuilder[3 * nid + 0].Push(c[j].fvalue, e.grad, max_size);
+          if (e.GetGrad() >= 0.0f) {
+            sbuilder[3 * nid + 0].Push(c[j].fvalue, e.GetGrad(), max_size);
           } else {
-            sbuilder[3 * nid + 1].Push(c[j].fvalue, -e.grad, max_size);
+            sbuilder[3 * nid + 1].Push(c[j].fvalue, -e.GetGrad(), max_size);
           }
-          sbuilder[3 * nid + 2].Push(c[j].fvalue, e.hess, max_size);
+          sbuilder[3 * nid + 2].Push(c[j].fvalue, e.GetHess(), max_size);
         }
       }
       for (size_t i = 0; i < this->qexpand.size(); ++i) {

diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc
index 425fb91a3d3f..7f46e43b6355 100644
--- a/tests/cpp/helpers.cc
+++ b/tests/cpp/helpers.cc
@@ -43,10 +43,10 @@ void CheckObjFunction(xgboost::ObjFunction * obj,
   ASSERT_EQ(gpair.size(), preds.size());
   for (int i = 0; i < static_cast<int>(gpair.size()); ++i) {
-    EXPECT_NEAR(gpair[i].grad, out_grad[i], 0.01)
+    EXPECT_NEAR(gpair[i].GetGrad(), out_grad[i], 0.01)
       << "Unexpected grad for pred=" << preds[i] << " label=" << labels[i]
       << " weight=" << weights[i];
-    EXPECT_NEAR(gpair[i].hess, out_hess[i], 0.01)
+    EXPECT_NEAR(gpair[i].GetHess(), out_hess[i], 0.01)
       << "Unexpected hess for pred=" << preds[i] << " label=" << labels[i]
       << " weight=" << weights[i];
   }

diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu
index 7afa5b0cdbd9..4a7399460dcf 100644
--- a/tests/cpp/predictor/test_gpu_predictor.cu
+++ b/tests/cpp/predictor/test_gpu_predictor.cu
@@ -16,7 +16,7 @@ TEST(gpu_predictor, Test) {
       std::unique_ptr<Predictor>(Predictor::Create("cpu_predictor"));
 
   std::vector<std::unique_ptr<RegTree>> trees;
-  trees.push_back(std::make_unique<RegTree>());
+  trees.push_back(std::unique_ptr<RegTree>(new RegTree()));
   trees.back()->InitModel();
   (*trees.back())[0].set_leaf(1.5f);
   gbm::GBTreeModel model(0.5);

diff --git a/tests/cpp/xgboost_test.mk b/tests/cpp/xgboost_test.mk
index 7d280673b8c6..2276ca6faac6 100644
--- a/tests/cpp/xgboost_test.mk
+++ b/tests/cpp/xgboost_test.mk
@@ -14,7 +14,6 @@ UNITTEST_DEPS=lib/libxgboost.a $(DMLC_CORE)/libdmlc.a $(RABIT)/lib/$(LIB_RABIT)
 
 COVER_OBJ=$(patsubst %.o, %.gcda, $(ALL_OBJ)) $(patsubst %.o, %.gcda, $(UNITTEST_OBJ))
 
-# the order of the below targets matter!
 $(UTEST_OBJ_ROOT)/$(GTEST_PATH)/%.o: $(GTEST_PATH)/%.cc
 	@mkdir -p $(@D)
 	$(CXX) $(UNITTEST_CFLAGS) -I$(GTEST_INC) -I$(GTEST_PATH) -o $@ -c $<
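A host-only sketch of the motivation for the integer path: float summation is not associative, so a reduction's result can depend on evaluation order (and therefore on GPU scheduling), while sums of the 1e5-scaled int64_t values are order-independent. The values below are chosen to expose the effect:

// Float addition order changes the result; the quantised int64_t sums do not.
#include <cstdint>
#include <cstdio>

int main() {
  float values[3] = {1e8f, -1e8f, 0.00001f};
  float a = (values[0] + values[1]) + values[2];  // 1e-05
  float b = values[0] + (values[1] + values[2]);  // 0: the 1e-05 is absorbed
  printf("float: %g vs %g\n", a, b);

  int64_t q[3];
  for (int i = 0; i < 3; ++i) q[i] = static_cast<int64_t>(values[i] * 1e5f);
  int64_t c = (q[0] + q[1]) + q[2];
  int64_t d = q[0] + (q[1] + q[2]);
  printf("int64: %lld vs %lld\n", static_cast<long long>(c),
         static_cast<long long>(d));  // identical either way
}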