Skip to content

Commit

Permalink
[CUDA] Add Poisson regression objective for cuda_exp and refactor obj…
Browse files Browse the repository at this point in the history
…ective functions for cuda_exp (#5486)

* add poisson regression objective for cuda_exp

* enable Poisson regression for cuda_exp

* refactor cuda objective functions

* remove useless changes

* fix linter errors

* remove redundant buffer in cuda poisson regression objective

* fix log of cuda_exp binary objective

* fix threshold of poisson objective result

* remove useless changes

* fix compilation errors

* add cuda quantile regression objective

* remove cuda quantile regression objective

Co-authored-by: James Lamb <[email protected]>
  • Loading branch information
shiyu1994 and jameslamb authored Nov 27, 2022
1 parent f7e64a8 commit 24af9fa
Show file tree
Hide file tree
Showing 20 changed files with 441 additions and 326 deletions.
34 changes: 34 additions & 0 deletions include/LightGBM/cuda/cuda_algorithms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,40 @@ __device__ __forceinline__ void GlobalMemoryPrefixSum(T* array, const size_t len
}
}

// Warp-level minimum over the first `len` lanes of the calling warp, computed
// with shuffle-down steps. After the call, lane 0 holds the minimum of the
// values supplied by lanes [0, len); other lanes hold partial results.
//
// value: this lane's contribution.
// len:   number of leading lanes that take part. Values <= 0 leave `value`
//        untouched; values above warpSize are clamped to warpSize.
//
// Lanes whose index is >= len return their own `value` unchanged and do not
// execute the shuffle: __shfl_down_sync requires every calling thread to have
// its own bit set in the mask.
// NOTE(review): the lane index is derived from threadIdx.x, i.e. a
// 1-dimensional block is assumed, matching ShuffleReduceMin - confirm that no
// caller uses a multi-dimensional block.
template <typename T>
__device__ __forceinline__ T ShuffleReduceMinWarp(T value, const data_size_t len) {
  const data_size_t warp_width = static_cast<data_size_t>(warpSize);
  const data_size_t lane = static_cast<data_size_t>(threadIdx.x % warpSize);
  // Clamp so that the shift amount below always stays within [0, warpSize - 1].
  const data_size_t num_lanes = (len < warp_width) ? len : warp_width;
  if (lane < num_lanes) {
    const uint32_t mask = (0xffffffff >> (warp_width - num_lanes));
    for (int offset = warpSize / 2; offset > 0; offset >>= 1) {
      const T other_value = __shfl_down_sync(mask, value, offset);
      // A source lane outside [0, num_lanes) is not named in the mask, so the
      // value read from it is undefined and must not enter the minimum.
      if (lane + offset < num_lanes && other_value < value) {
        value = other_value;
      }
    }
  }
  return value;
}

// Block-level minimum for a 1-dimensional block (block size must be no greater than 1024).
//
// value:             this thread's contribution.
// shared_mem_buffer: scratch space indexed by warp id; one slot per warp of the block is written.
// len:               number of threads taking part in the reduction.
//
// Each warp first reduces its own values and lane 0 of the warp stores the
// partial result in shared_mem_buffer. After a block-wide barrier, warp 0
// reduces the per-warp partial results. The value returned by thread 0 is the
// block minimum; the other threads return intermediate values.
template <typename T>
__device__ __forceinline__ T ShuffleReduceMin(T value, T* shared_mem_buffer, const size_t len) {
  const uint32_t lane_in_warp = threadIdx.x % warpSize;
  const uint32_t warp_index = threadIdx.x / warpSize;

  // Stage 1: minimum inside each warp (the last warp may be only partially filled).
  const data_size_t remaining_len = static_cast<data_size_t>(len) - static_cast<data_size_t>(warp_index * warpSize);
  const data_size_t warp_len = min(static_cast<data_size_t>(warpSize), remaining_len);
  value = ShuffleReduceMinWarp<T>(value, warp_len);
  if (lane_in_warp == 0) {
    shared_mem_buffer[warp_index] = value;
  }
  __syncthreads();

  // Stage 2: only the first warp combines the per-warp partial minima.
  const data_size_t num_warp = static_cast<data_size_t>((len + warpSize - 1) / warpSize);
  if (warp_index != 0) {
    return value;
  }
  // Lanes beyond the number of warps re-read slot 0 so that they hold a valid value.
  const uint32_t source_slot = (lane_in_warp < num_warp) ? lane_in_warp : 0;
  value = shared_mem_buffer[source_slot];
  return ShuffleReduceMinWarp<T>(value, num_warp);
}

// Host entry point for a minimum reduction over `n` values.
// values:       input array of VAL_T.
// n:            number of input values.
// block_buffer: scratch/output array of REDUCE_T.
// Only the declaration is visible here; the template is defined and explicitly
// instantiated in src/cuda/cuda_algorithms.cu.
template <typename VAL_T, typename REDUCE_T>
void ShuffleReduceMinGlobal(const VAL_T* values, size_t n, REDUCE_T* block_buffer);

template <typename VAL_T, typename INDEX_T, bool ASCENDING>
__device__ __forceinline__ void BitonicArgSort_1024(const VAL_T* scores, INDEX_T* indices, const INDEX_T num_items) {
INDEX_T depth = 1;
Expand Down
61 changes: 59 additions & 2 deletions include/LightGBM/cuda/cuda_objective_function.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,68 @@
#include <LightGBM/objective_function.h>
#include <LightGBM/meta.h>

#include <string>
#include <vector>

namespace LightGBM {

// Adapter that derives from a host objective class (HOST_OBJECTIVE) and routes
// the objective entry points (gradients, initial score, output conversion,
// leaf-output renewal) to Launch*Kernel hooks implemented by subclasses.
// NOTE(review): this listing comes from a diff view, so the pre-change
// non-template declaration and the pre-change empty ConvertOutputCUDA appear
// next to their replacements.
class CUDAObjectiveInterface {
template <typename HOST_OBJECTIVE>
class CUDAObjectiveInterface: public HOST_OBJECTIVE {
public:
virtual void ConvertOutputCUDA(const data_size_t /*num_data*/, const double* /*input*/, double* /*output*/) const {}
// Forwards the configuration to the host objective's constructor.
explicit CUDAObjectiveInterface(const Config& config): HOST_OBJECTIVE(config) {}

// Forwards the string parameters to the host objective's constructor.
explicit CUDAObjectiveInterface(const std::vector<std::string>& strs): HOST_OBJECTIVE(strs) {}

// Runs the host objective's Init, then caches the label and weight pointers
// exposed by metadata.cuda_metadata().
void Init(const Metadata& metadata, data_size_t num_data) {
HOST_OBJECTIVE::Init(metadata, num_data);
cuda_labels_ = metadata.cuda_metadata()->cuda_label();
cuda_weights_ = metadata.cuda_metadata()->cuda_weights();
}

// Converts `num_data` entries from `input` into `output` by delegating to
// LaunchConvertOutputCUDAKernel (a no-op unless a subclass overrides it).
virtual void ConvertOutputCUDA(const data_size_t num_data, const double* input, double* output) const {
LaunchConvertOutputCUDAKernel(num_data, input, output);
}

// Returns a callable that forwards to ConvertOutputCUDA. The lambda captures
// `this`, so it must not be invoked after this object is destroyed.
std::function<void(data_size_t, const double*, double*)> GetCUDAConvertOutputFunc() const override {
return [this] (data_size_t num_data, const double* input, double* output) {
ConvertOutputCUDA(num_data, input, output);
};
}

// Initial score for `class_id`, obtained from LaunchCalcInitScoreKernel.
double BoostFromScore(int class_id) const override {
return LaunchCalcInitScoreKernel(class_id);
}

bool IsCUDAObjective() const override { return true; }

// Launches the subclass gradient kernel, then synchronizes the device.
void GetGradients(const double* scores, score_t* gradients, score_t* hessians) const override {
LaunchGetGradientsKernel(scores, gradients, hessians);
SynchronizeCUDADevice(__FILE__, __LINE__);
}

// Launches the leaf-output renewal hook, synchronizes the device, and records
// the elapsed time in global_timer.
void RenewTreeOutputCUDA(const double* score, const data_size_t* data_indices_in_leaf, const data_size_t* num_data_in_leaf,
const data_size_t* data_start_in_leaf, const int num_leaves, double* leaf_value) const override {
global_timer.Start("CUDAObjectiveInterface::LaunchRenewTreeOutputCUDAKernel");
LaunchRenewTreeOutputCUDAKernel(score, data_indices_in_leaf, num_data_in_leaf, data_start_in_leaf, num_leaves, leaf_value);
SynchronizeCUDADevice(__FILE__, __LINE__);
global_timer.Stop("CUDAObjectiveInterface::LaunchRenewTreeOutputCUDAKernel");
}

protected:
// Mandatory hook: every concrete objective must provide its gradient computation.
virtual void LaunchGetGradientsKernel(const double* scores, score_t* gradients, score_t* hessians) const = 0;

// Optional hook: defaults to the host objective's BoostFromScore.
virtual double LaunchCalcInitScoreKernel(const int class_id) const {
return HOST_OBJECTIVE::BoostFromScore(class_id);
}

// Optional hook: does nothing by default.
virtual void LaunchConvertOutputCUDAKernel(const data_size_t /*num_data*/, const double* /*input*/, double* /*output*/) const {}

// Optional hook: does nothing by default.
virtual void LaunchRenewTreeOutputCUDAKernel(
const double* /*score*/, const data_size_t* /*data_indices_in_leaf*/, const data_size_t* /*num_data_in_leaf*/,
const data_size_t* /*data_start_in_leaf*/, const int /*num_leaves*/, double* /*leaf_value*/) const {}

// Pointers obtained from metadata.cuda_metadata() in Init; they are not set
// by the constructors. Presumably device pointers owned by the metadata
// object - TODO confirm.
const label_t* cuda_labels_;
const label_t* cuda_weights_;
};

} // namespace LightGBM
Expand Down
91 changes: 51 additions & 40 deletions src/cuda/cuda_algorithms.cu
Original file line number Diff line number Diff line change
Expand Up @@ -55,27 +55,16 @@ __global__ void ShufflePrefixSumGlobalAddBase(size_t len, const T* block_prefix_
}

// Host launcher that processes `values` (length `len`) with three kernels in
// sequence: ShufflePrefixSumGlobalKernel over all blocks,
// ShufflePrefixSumGlobalReduceBlockKernel on a single block over
// block_prefix_sum_buffer, and ShufflePrefixSumGlobalAddBase over all blocks.
// `len` is narrowed to int when the block count is computed.
// NOTE(review): this listing comes from a diff view, so the pre-change
// signature (ShufflePrefixSumGlobalInner) appears next to the renamed one.
template <typename T>
void ShufflePrefixSumGlobalInner(T* values, size_t len, T* block_prefix_sum_buffer) {
void ShufflePrefixSumGlobal(T* values, size_t len, T* block_prefix_sum_buffer) {
// Ceiling division: one block per GLOBAL_PREFIX_SUM_BLOCK_SIZE values.
const int num_blocks = (static_cast<int>(len) + GLOBAL_PREFIX_SUM_BLOCK_SIZE - 1) / GLOBAL_PREFIX_SUM_BLOCK_SIZE;
ShufflePrefixSumGlobalKernel<<<num_blocks, GLOBAL_PREFIX_SUM_BLOCK_SIZE>>>(values, len, block_prefix_sum_buffer);
ShufflePrefixSumGlobalReduceBlockKernel<<<1, GLOBAL_PREFIX_SUM_BLOCK_SIZE>>>(block_prefix_sum_buffer, num_blocks);
ShufflePrefixSumGlobalAddBase<<<num_blocks, GLOBAL_PREFIX_SUM_BLOCK_SIZE>>>(len, block_prefix_sum_buffer, values);
}

// Pre-change form: full specializations for uint16_t, uint32_t and uint64_t,
// each forwarding to ShufflePrefixSumGlobalInner.
template <>
void ShufflePrefixSumGlobal(uint16_t* values, size_t len, uint16_t* block_prefix_sum_buffer) {
ShufflePrefixSumGlobalInner<uint16_t>(values, len, block_prefix_sum_buffer);
}

template <>
void ShufflePrefixSumGlobal(uint32_t* values, size_t len, uint32_t* block_prefix_sum_buffer) {
ShufflePrefixSumGlobalInner<uint32_t>(values, len, block_prefix_sum_buffer);
}

template <>
void ShufflePrefixSumGlobal(uint64_t* values, size_t len, uint64_t* block_prefix_sum_buffer) {
ShufflePrefixSumGlobalInner<uint64_t>(values, len, block_prefix_sum_buffer);
}
// Post-change form: explicit instantiations of the template for the same
// three element types.
template void ShufflePrefixSumGlobal<uint16_t>(uint16_t* values, size_t len, uint16_t* block_prefix_sum_buffer);
template void ShufflePrefixSumGlobal<uint32_t>(uint32_t* values, size_t len, uint32_t* block_prefix_sum_buffer);
template void ShufflePrefixSumGlobal<uint64_t>(uint64_t* values, size_t len, uint64_t* block_prefix_sum_buffer);

__global__ void BitonicArgSortItemsGlobalKernel(const double* scores,
const int num_queries,
Expand Down Expand Up @@ -130,18 +119,52 @@ __global__ void ShuffleReduceSumGlobalKernel(const VAL_T* values, const data_siz
}

// Host launcher for a sum reduction over `n` values: ShuffleReduceSumGlobalKernel
// runs over all blocks with block_buffer as its output argument, then
// BlockReduceSum runs on a single block over block_buffer.
// `n` is narrowed to data_size_t before use.
// NOTE(review): this listing comes from a diff view, so the pre-change
// signature (ShuffleReduceSumGlobalInner) appears next to the renamed one.
template <typename VAL_T, typename REDUCE_T>
void ShuffleReduceSumGlobalInner(const VAL_T* values, size_t n, REDUCE_T* block_buffer) {
void ShuffleReduceSumGlobal(const VAL_T* values, size_t n, REDUCE_T* block_buffer) {
const data_size_t num_value = static_cast<data_size_t>(n);
// Ceiling division: one block per GLOBAL_PREFIX_SUM_BLOCK_SIZE values.
const data_size_t num_blocks = (num_value + GLOBAL_PREFIX_SUM_BLOCK_SIZE - 1) / GLOBAL_PREFIX_SUM_BLOCK_SIZE;
ShuffleReduceSumGlobalKernel<VAL_T, REDUCE_T><<<num_blocks, GLOBAL_PREFIX_SUM_BLOCK_SIZE>>>(values, num_value, block_buffer);
BlockReduceSum<REDUCE_T><<<1, GLOBAL_PREFIX_SUM_BLOCK_SIZE>>>(block_buffer, num_blocks);
}

template <>
void ShuffleReduceSumGlobal<label_t, double>(const label_t* values, size_t n, double* block_buffer) {
ShuffleReduceSumGlobalInner(values, n, block_buffer);
template void ShuffleReduceSumGlobal<label_t, double>(const label_t* values, size_t n, double* block_buffer);

// Per-block stage of the global minimum reduction.
//
// Each thread loads one element of `values`, the block reduces the loaded
// elements with ShuffleReduceMin, and thread 0 stores the block's minimum in
// block_buffer[blockIdx.x].
//
// values:       input array with `num_value` elements.
// num_value:    number of valid elements in `values`.
// block_buffer: output array with at least one slot per launched block.
//
// Threads whose index lies past the end of the input re-read the last valid
// element instead of contributing a constant: a duplicate cannot change the
// minimum, whereas a fixed filler such as 0 would win whenever all real values
// in the tail block are larger than it. When num_value <= 0 there is nothing
// to read and the block writes 0, as before.
template <typename VAL_T, typename REDUCE_T>
__global__ void ShuffleReduceMinGlobalKernel(const VAL_T* values, const data_size_t num_value, REDUCE_T* block_buffer) {
  __shared__ REDUCE_T shared_buffer[32];
  const data_size_t data_index = static_cast<data_size_t>(blockIdx.x * blockDim.x + threadIdx.x);
  REDUCE_T value = 0.0f;
  if (num_value > 0) {
    const data_size_t read_index = (data_index < num_value ? data_index : num_value - 1);
    value = static_cast<REDUCE_T>(values[read_index]);
  }
  const REDUCE_T reduce_value = ShuffleReduceMin<REDUCE_T>(value, shared_buffer, blockDim.x);
  if (threadIdx.x == 0) {
    block_buffer[blockIdx.x] = reduce_value;
  }
}

// Final stage of the global minimum reduction; intended to be launched with a
// single block.
//
// Every thread scans block_buffer with a stride of blockDim.x and keeps the
// smallest entry it sees, the block combines those candidates with
// ShuffleReduceMin, and thread 0 writes the overall minimum to block_buffer[0].
//
// block_buffer: per-block minima on input; slot 0 receives the result.
// num_blocks:   number of valid entries in block_buffer.
//
// The running minimum is seeded with block_buffer[0] rather than 0, so that
// inputs whose entries are all positive are not clamped to 0. Every read of
// block_buffer[0] precedes the barrier inside ShuffleReduceMin, and the final
// write by thread 0 follows it, so the seed read does not race with the write.
// When num_blocks <= 0 the seed falls back to 0, as before.
template <typename T>
__global__ void ShuffleBlockReduceMin(T* block_buffer, const data_size_t num_blocks) {
  __shared__ T shared_buffer[32];
  T thread_min = (num_blocks > 0 ? block_buffer[0] : static_cast<T>(0));
  for (data_size_t block_index = static_cast<data_size_t>(threadIdx.x); block_index < num_blocks; block_index += static_cast<data_size_t>(blockDim.x)) {
    const T value = block_buffer[block_index];
    if (value < thread_min) {
      thread_min = value;
    }
  }
  thread_min = ShuffleReduceMin<T>(thread_min, shared_buffer, blockDim.x);
  if (threadIdx.x == 0) {
    block_buffer[0] = thread_min;
  }
}

// Host launcher for a minimum reduction over `n` values.
//
// values:       input array of VAL_T.
// n:            number of input values (narrowed to data_size_t).
// block_buffer: scratch array with one REDUCE_T slot per launched block; the
//               second kernel stores the final result in block_buffer[0].
//
// Both kernels are launched without an explicit stream argument, and no
// synchronization is performed here.
template <typename VAL_T, typename REDUCE_T>
void ShuffleReduceMinGlobal(const VAL_T* values, size_t n, REDUCE_T* block_buffer) {
  const data_size_t value_count = static_cast<data_size_t>(n);
  // Ceiling division: one block per GLOBAL_PREFIX_SUM_BLOCK_SIZE values.
  const data_size_t block_count =
    (value_count + GLOBAL_PREFIX_SUM_BLOCK_SIZE - 1) / GLOBAL_PREFIX_SUM_BLOCK_SIZE;

  // Pass 1: every block writes its own minimum to block_buffer[blockIdx.x].
  ShuffleReduceMinGlobalKernel<VAL_T, REDUCE_T>
    <<<block_count, GLOBAL_PREFIX_SUM_BLOCK_SIZE>>>(values, value_count, block_buffer);

  // Pass 2: a single block folds the per-block minima into block_buffer[0].
  ShuffleBlockReduceMin<REDUCE_T>
    <<<1, GLOBAL_PREFIX_SUM_BLOCK_SIZE>>>(block_buffer, block_count);
}

template void ShuffleReduceMinGlobal<label_t, double>(const label_t* values, size_t n, double* block_buffer);

template <typename VAL_T, typename REDUCE_T>
__global__ void ShuffleReduceDotProdGlobalKernel(const VAL_T* values1, const VAL_T* values2, const data_size_t num_value, REDUCE_T* block_buffer) {
__shared__ REDUCE_T shared_buffer[32];
Expand All @@ -155,17 +178,14 @@ __global__ void ShuffleReduceDotProdGlobalKernel(const VAL_T* values1, const VAL
}

// Host launcher operating on two input arrays of length `n`:
// ShuffleReduceDotProdGlobalKernel runs over all blocks with block_buffer as
// its output argument, then BlockReduceSum runs on a single block over
// block_buffer. `n` is narrowed to data_size_t before use.
// NOTE(review): this listing comes from a diff view, so the pre-change
// signature (ShuffleReduceDotProdGlobalInner) appears next to the renamed one.
template <typename VAL_T, typename REDUCE_T>
void ShuffleReduceDotProdGlobalInner(const VAL_T* values1, const VAL_T* values2, size_t n, REDUCE_T* block_buffer) {
void ShuffleReduceDotProdGlobal(const VAL_T* values1, const VAL_T* values2, size_t n, REDUCE_T* block_buffer) {
const data_size_t num_value = static_cast<data_size_t>(n);
// Ceiling division: one block per GLOBAL_PREFIX_SUM_BLOCK_SIZE values.
const data_size_t num_blocks = (num_value + GLOBAL_PREFIX_SUM_BLOCK_SIZE - 1) / GLOBAL_PREFIX_SUM_BLOCK_SIZE;
ShuffleReduceDotProdGlobalKernel<VAL_T, REDUCE_T><<<num_blocks, GLOBAL_PREFIX_SUM_BLOCK_SIZE>>>(values1, values2, num_value, block_buffer);
BlockReduceSum<REDUCE_T><<<1, GLOBAL_PREFIX_SUM_BLOCK_SIZE>>>(block_buffer, num_blocks);
}

// Pre-change form: full specialization for <label_t, double> forwarding to
// ShuffleReduceDotProdGlobalInner.
template <>
void ShuffleReduceDotProdGlobal<label_t, double>(const label_t* values1, const label_t* values2, size_t n, double* block_buffer) {
ShuffleReduceDotProdGlobalInner(values1, values2, n, block_buffer);
}
// Post-change form: explicit instantiation for the same type pair.
template void ShuffleReduceDotProdGlobal<label_t, double>(const label_t* values1, const label_t* values2, size_t n, double* block_buffer);

template <typename INDEX_T, typename VAL_T, typename REDUCE_T>
__global__ void GlobalInclusiveArgPrefixSumKernel(
Expand Down Expand Up @@ -209,7 +229,7 @@ __global__ void GlobalInclusivePrefixSumAddBlockBaseKernel(const T* block_buffer
}

template <typename VAL_T, typename REDUCE_T, typename INDEX_T>
void GlobalInclusiveArgPrefixSumInner(const INDEX_T* sorted_indices, const VAL_T* in_values, REDUCE_T* out_values, REDUCE_T* block_buffer, size_t n) {
void GlobalInclusiveArgPrefixSum(const INDEX_T* sorted_indices, const VAL_T* in_values, REDUCE_T* out_values, REDUCE_T* block_buffer, size_t n) {
const data_size_t num_data = static_cast<data_size_t>(n);
const data_size_t num_blocks = (num_data + GLOBAL_PREFIX_SUM_BLOCK_SIZE - 1) / GLOBAL_PREFIX_SUM_BLOCK_SIZE;
GlobalInclusiveArgPrefixSumKernel<INDEX_T, VAL_T, REDUCE_T><<<num_blocks, GLOBAL_PREFIX_SUM_BLOCK_SIZE>>>(
Expand All @@ -223,10 +243,7 @@ void GlobalInclusiveArgPrefixSumInner(const INDEX_T* sorted_indices, const VAL_T
SynchronizeCUDADevice(__FILE__, __LINE__);
}

// Pre-change form: full specialization for <label_t, double, data_size_t>
// forwarding to GlobalInclusiveArgPrefixSumInner.
template <>
void GlobalInclusiveArgPrefixSum<label_t, double, data_size_t>(const data_size_t* sorted_indices, const label_t* in_values, double* out_values, double* block_buffer, size_t n) {
GlobalInclusiveArgPrefixSumInner<label_t, double, data_size_t>(sorted_indices, in_values, out_values, block_buffer, n);
}
// Post-change form: explicit instantiation for the same type triple.
template void GlobalInclusiveArgPrefixSum<label_t, double, data_size_t>(const data_size_t* sorted_indices, const label_t* in_values, double* out_values, double* block_buffer, size_t n);

template <typename VAL_T, typename INDEX_T, bool ASCENDING>
__global__ void BitonicArgSortGlobalKernel(const VAL_T* values, INDEX_T* indices, const int num_total_data) {
Expand Down Expand Up @@ -424,7 +441,7 @@ void BitonicArgSortGlobal<data_size_t, int, true>(const data_size_t* values, int
}

template <typename VAL_T, typename INDEX_T, typename WEIGHT_T, typename REDUCE_WEIGHT_T, bool ASCENDING, bool USE_WEIGHT>
__device__ VAL_T PercentileDeviceInner(const VAL_T* values,
__device__ VAL_T PercentileDevice(const VAL_T* values,
const WEIGHT_T* weights,
INDEX_T* indices,
REDUCE_WEIGHT_T* weights_prefix_sum,
Expand Down Expand Up @@ -472,27 +489,21 @@ __device__ VAL_T PercentileDeviceInner(const VAL_T* values,
}
}

template <>
__device__ double PercentileDevice<double, data_size_t, label_t, double, false, true>(
template __device__ double PercentileDevice<double, data_size_t, label_t, double, false, true>(
const double* values,
const label_t* weights,
data_size_t* indices,
double* weights_prefix_sum,
const double alpha,
const data_size_t len) {
return PercentileDeviceInner<double, data_size_t, label_t, double, false, true>(values, weights, indices, weights_prefix_sum, alpha, len);
}
const data_size_t len);

template <>
__device__ double PercentileDevice<double, data_size_t, label_t, double, false, false>(
template __device__ double PercentileDevice<double, data_size_t, label_t, double, false, false>(
const double* values,
const label_t* weights,
data_size_t* indices,
double* weights_prefix_sum,
const double alpha,
const data_size_t len) {
return PercentileDeviceInner<double, data_size_t, label_t, double, false, false>(values, weights, indices, weights_prefix_sum, alpha, len);
}
const data_size_t len);


} // namespace LightGBM
Expand Down
2 changes: 1 addition & 1 deletion src/objective/binary_objective.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ class BinaryLogloss: public ObjectiveFunction {
pavg = std::min(pavg, 1.0 - kEpsilon);
pavg = std::max<double>(pavg, kEpsilon);
double initscore = std::log(pavg / (1.0f - pavg)) / sigmoid_;
Log::Info("[%s:%s]: pavg=%f -> initscore=%f", GetName(), __func__, pavg, initscore);
Log::Info("[%s:%s]: pavg=%f -> initscore=%f", GetName(), __func__, pavg, initscore);
return initscore;
}

Expand Down
32 changes: 7 additions & 25 deletions src/objective/cuda/cuda_binary_objective.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
namespace LightGBM {

CUDABinaryLogloss::CUDABinaryLogloss(const Config& config):
BinaryLogloss(config), ova_class_id_(-1) {
CUDAObjectiveInterface<BinaryLogloss>(config), ova_class_id_(-1) {
cuda_label_ = nullptr;
cuda_ova_label_ = nullptr;
cuda_weights_ = nullptr;
Expand All @@ -24,9 +24,11 @@ BinaryLogloss(config), ova_class_id_(-1) {
}

// One-vs-all constructor: a label counts as positive when, cast to int, it
// equals `ova_class_id`.
// The pre-change initializer list passed the predicate to the BinaryLogloss
// constructor; the post-change form constructs
// CUDAObjectiveInterface<BinaryLogloss> from `config` and assigns the
// predicate to is_pos_ in the body.
// NOTE(review): unlike the single-argument Config constructor, this one does
// not reset the cuda_* pointer members to nullptr - confirm that they have
// default initializers in the class declaration.
CUDABinaryLogloss::CUDABinaryLogloss(const Config& config, const int ova_class_id):
BinaryLogloss(config, [ova_class_id](label_t label) { return static_cast<int>(label) == ova_class_id; }), ova_class_id_(ova_class_id) {}
CUDAObjectiveInterface<BinaryLogloss>(config), ova_class_id_(ova_class_id) {
is_pos_ = [ova_class_id](label_t label) { return static_cast<int>(label) == ova_class_id; };
}

// Constructor from string parameters; forwards `strs` to the base class
// (pre-change: BinaryLogloss, post-change: CUDAObjectiveInterface<BinaryLogloss>).
CUDABinaryLogloss::CUDABinaryLogloss(const std::vector<std::string>& strs): BinaryLogloss(strs) {}
CUDABinaryLogloss::CUDABinaryLogloss(const std::vector<std::string>& strs): CUDAObjectiveInterface<BinaryLogloss>(strs) {}

CUDABinaryLogloss::~CUDABinaryLogloss() {
DeallocateCUDAMemory<label_t>(&cuda_ova_label_, __FILE__, __LINE__);
Expand All @@ -36,13 +38,13 @@ CUDABinaryLogloss::~CUDABinaryLogloss() {
}

void CUDABinaryLogloss::Init(const Metadata& metadata, data_size_t num_data) {
BinaryLogloss::Init(metadata, num_data);
CUDAObjectiveInterface<BinaryLogloss>::Init(metadata, num_data);
if (ova_class_id_ == -1) {
cuda_label_ = metadata.cuda_metadata()->cuda_label();
cuda_ova_label_ = nullptr;
} else {
InitCUDAMemoryFromHostMemory<label_t>(&cuda_ova_label_, metadata.cuda_metadata()->cuda_label(), static_cast<size_t>(num_data), __FILE__, __LINE__);
LaunchResetOVACUDALableKernel();
LaunchResetOVACUDALabelKernel();
cuda_label_ = cuda_ova_label_;
}
cuda_weights_ = metadata.cuda_metadata()->cuda_weights();
Expand All @@ -57,26 +59,6 @@ void CUDABinaryLogloss::Init(const Metadata& metadata, data_size_t num_data) {
}
}

// Launches the gradient kernel for `scores`, writing into `gradients` and
// `hessians`, and synchronizes the device before returning.
void CUDABinaryLogloss::GetGradients(const double* scores, score_t* gradients, score_t* hessians) const {
LaunchGetGradientsKernel(scores, gradients, hessians);
SynchronizeCUDADevice(__FILE__, __LINE__);
}

// Computes the initial score on the device and returns it to the host.
// The class id argument is ignored.
//
// After the kernel has finished, one double is copied back from each of
// cuda_boost_from_score_ (the returned value) and cuda_sum_weights_ (logged
// as `pavg`).
double CUDABinaryLogloss::BoostFromScore(int) const {
  LaunchBoostFromScoreKernel();
  // The kernel must have completed before its results are read back.
  SynchronizeCUDADevice(__FILE__, __LINE__);

  double init_score = 0.0f;
  CopyFromCUDADeviceToHost<double>(&init_score, cuda_boost_from_score_, 1, __FILE__, __LINE__);

  double logged_pavg = 0.0f;
  CopyFromCUDADeviceToHost<double>(&logged_pavg, cuda_sum_weights_, 1, __FILE__, __LINE__);

  Log::Info("[%s:%s]: pavg=%f -> initscore=%f", GetName(), __func__, logged_pavg, init_score);
  return init_score;
}

// Converts `num_data` entries from `input` into `output` by delegating to
// LaunchConvertOutputCUDAKernel. No device synchronization is performed here.
void CUDABinaryLogloss::ConvertOutputCUDA(const data_size_t num_data, const double* input, double* output) const {
LaunchConvertOutputCUDAKernel(num_data, input, output);
}

} // namespace LightGBM

#endif // USE_CUDA_EXP
Loading

0 comments on commit 24af9fa

Please sign in to comment.