From 2b8fe8b4bdc00a2611442fdee4c45316f08b1c4b Mon Sep 17 00:00:00 2001 From: shiyu1994 Date: Wed, 31 Aug 2022 09:33:17 +0800 Subject: [PATCH] [CUDA] Add binary objective for cuda_exp (#5425) * add binary objective for cuda_exp * include and * exchange include ordering * fix length of score to copy in evaluation * fix EvalOneMetric * fix cuda binary objective and prediction when boosting on gpu * Add white space * fix BoostFromScore for CUDABinaryLogloss update log in test_register_logger * include * simplify shared memory buffer --- CMakeLists.txt | 2 + .../LightGBM/cuda/cuda_objective_function.hpp | 27 +++ include/LightGBM/objective_function.h | 9 + src/boosting/gbdt.cpp | 26 +- src/boosting/gbdt.h | 2 +- src/objective/binary_objective.hpp | 2 +- src/objective/cuda/cuda_binary_objective.cpp | 82 +++++++ src/objective/cuda/cuda_binary_objective.cu | 227 ++++++++++++++++++ src/objective/cuda/cuda_binary_objective.hpp | 75 ++++++ src/objective/objective_function.cpp | 5 +- tests/python_package_test/test_utilities.py | 1 - 11 files changed, 446 insertions(+), 12 deletions(-) create mode 100644 include/LightGBM/cuda/cuda_objective_function.hpp create mode 100644 src/objective/cuda/cuda_binary_objective.cpp create mode 100644 src/objective/cuda/cuda_binary_objective.cu create mode 100644 src/objective/cuda/cuda_binary_objective.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 25bfec47e477..c3a16000c20f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -398,6 +398,8 @@ endif() if(USE_CUDA_EXP) src/boosting/cuda/*.cpp src/boosting/cuda/*.cu + src/objective/cuda/*.cpp + src/objective/cuda/*.cu src/treelearner/cuda/*.cpp src/treelearner/cuda/*.cu src/io/cuda/*.cu diff --git a/include/LightGBM/cuda/cuda_objective_function.hpp b/include/LightGBM/cuda/cuda_objective_function.hpp new file mode 100644 index 000000000000..30642eaa5ac0 --- /dev/null +++ b/include/LightGBM/cuda/cuda_objective_function.hpp @@ -0,0 +1,27 @@ +/*! + * Copyright (c) 2021 Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for + * license information. + */ + +#ifndef LIGHTGBM_OBJECTIVE_CUDA_CUDA_OBJECTIVE_HPP_ +#define LIGHTGBM_OBJECTIVE_CUDA_CUDA_OBJECTIVE_HPP_ + +#ifdef USE_CUDA_EXP + +#include +#include +#include + +namespace LightGBM { + +class CUDAObjectiveInterface { + public: + virtual void ConvertOutputCUDA(const data_size_t /*num_data*/, const double* /*input*/, double* /*output*/) const {} +}; + +} // namespace LightGBM + +#endif // USE_CUDA_EXP + +#endif // LIGHTGBM_OBJECTIVE_CUDA_CUDA_OBJECTIVE_HPP_ diff --git a/include/LightGBM/objective_function.h b/include/LightGBM/objective_function.h index 78ec2835349b..249f42f623dc 100644 --- a/include/LightGBM/objective_function.h +++ b/include/LightGBM/objective_function.h @@ -93,6 +93,15 @@ class ObjectiveFunction { * \brief Whether boosting is done on CUDA */ virtual bool IsCUDAObjective() const { return false; } + + #ifdef USE_CUDA_EXP + /*! + * \brief Get output convert function for CUDA version + */ + virtual std::function GetCUDAConvertOutputFunc() const { + return [] (data_size_t, const double*, double*) {}; + } + #endif // USE_CUDA_EXP }; } // namespace LightGBM diff --git a/src/boosting/gbdt.cpp b/src/boosting/gbdt.cpp index ee29263bb7b6..98a689d4db8f 100644 --- a/src/boosting/gbdt.cpp +++ b/src/boosting/gbdt.cpp @@ -607,7 +607,11 @@ void GBDT::UpdateScore(const Tree* tree, const int cur_tree_id) { } } -std::vector GBDT::EvalOneMetric(const Metric* metric, const double* score) const { +#ifdef USE_CUDA_EXP +std::vector GBDT::EvalOneMetric(const Metric* metric, const double* score, const data_size_t num_data) const { +#else +std::vector GBDT::EvalOneMetric(const Metric* metric, const double* score, const data_size_t /*num_data*/) const { +#endif // USE_CUDA_EXP #ifdef USE_CUDA_EXP const bool evaluation_on_cuda = metric->IsCUDAMetric(); if ((boosting_on_gpu_ && evaluation_on_cuda) || (!boosting_on_gpu_ && !evaluation_on_cuda)) { @@ -615,14 +619,14 @@ std::vector GBDT::EvalOneMetric(const Metric* metric, const double* scor return metric->Eval(score, objective_function_); #ifdef USE_CUDA_EXP } else if (boosting_on_gpu_ && !evaluation_on_cuda) { - const size_t total_size = static_cast(num_data_) * static_cast(num_tree_per_iteration_); + const size_t total_size = static_cast(num_data) * static_cast(num_tree_per_iteration_); if (total_size > host_score_.size()) { host_score_.resize(total_size, 0.0f); } CopyFromCUDADeviceToHost(host_score_.data(), score, total_size, __FILE__, __LINE__); return metric->Eval(host_score_.data(), objective_function_); } else { - const size_t total_size = static_cast(num_data_) * static_cast(num_tree_per_iteration_); + const size_t total_size = static_cast(num_data) * static_cast(num_tree_per_iteration_); if (total_size > cuda_score_.Size()) { cuda_score_.Resize(total_size); } @@ -641,7 +645,7 @@ std::string GBDT::OutputMetric(int iter) { if (need_output) { for (auto& sub_metric : training_metrics_) { auto name = sub_metric->GetName(); - auto scores = EvalOneMetric(sub_metric, train_score_updater_->score()); + auto scores = EvalOneMetric(sub_metric, train_score_updater_->score(), train_score_updater_->num_data()); for (size_t k = 0; k < name.size(); ++k) { std::stringstream tmp_buf; tmp_buf << "Iteration:" << iter @@ -658,7 +662,7 @@ std::string GBDT::OutputMetric(int iter) { if (need_output || early_stopping_round_ > 0) { for (size_t i = 0; i < valid_metrics_.size(); ++i) { for (size_t j = 0; j < valid_metrics_[i].size(); ++j) { - auto test_scores = EvalOneMetric(valid_metrics_[i][j], valid_score_updater_[i]->score()); + auto test_scores = EvalOneMetric(valid_metrics_[i][j], valid_score_updater_[i]->score(), valid_score_updater_[i]->num_data()); auto name = valid_metrics_[i][j]->GetName(); for (size_t k = 0; k < name.size(); ++k) { std::stringstream tmp_buf; @@ -698,7 +702,7 @@ std::vector GBDT::GetEvalAt(int data_idx) const { std::vector ret; if (data_idx == 0) { for (auto& sub_metric : training_metrics_) { - auto scores = EvalOneMetric(sub_metric, train_score_updater_->score()); + auto scores = EvalOneMetric(sub_metric, train_score_updater_->score(), train_score_updater_->num_data()); for (auto score : scores) { ret.push_back(score); } @@ -706,7 +710,7 @@ std::vector GBDT::GetEvalAt(int data_idx) const { } else { auto used_idx = data_idx - 1; for (size_t j = 0; j < valid_metrics_[used_idx].size(); ++j) { - auto test_scores = EvalOneMetric(valid_metrics_[used_idx][j], valid_score_updater_[used_idx]->score()); + auto test_scores = EvalOneMetric(valid_metrics_[used_idx][j], valid_score_updater_[used_idx]->score(), valid_score_updater_[used_idx]->num_data()); for (auto score : test_scores) { ret.push_back(score); } @@ -760,6 +764,14 @@ void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) { num_data = valid_score_updater_[used_idx]->num_data(); *out_len = static_cast(num_data) * num_class_; } + #ifdef USE_CUDA_EXP + std::vector host_raw_scores; + if (boosting_on_gpu_) { + host_raw_scores.resize(static_cast(*out_len), 0.0); + CopyFromCUDADeviceToHost(host_raw_scores.data(), raw_scores, static_cast(*out_len), __FILE__, __LINE__); + raw_scores = host_raw_scores.data(); + } + #endif // USE_CUDA_EXP if (objective_function_ != nullptr) { #pragma omp parallel for schedule(static) for (data_size_t i = 0; i < num_data; ++i) { diff --git a/src/boosting/gbdt.h b/src/boosting/gbdt.h index 052fd0ea9051..362e4d27c1e1 100644 --- a/src/boosting/gbdt.h +++ b/src/boosting/gbdt.h @@ -443,7 +443,7 @@ class GBDT : public GBDTBase { * \brief eval results for one metric */ - virtual std::vector EvalOneMetric(const Metric* metric, const double* score) const; + virtual std::vector EvalOneMetric(const Metric* metric, const double* score, const data_size_t num_data) const; /*! * \brief Print metric result of current iteration diff --git a/src/objective/binary_objective.hpp b/src/objective/binary_objective.hpp index 2f9d93509362..0f40ffb4062f 100644 --- a/src/objective/binary_objective.hpp +++ b/src/objective/binary_objective.hpp @@ -189,7 +189,7 @@ class BinaryLogloss: public ObjectiveFunction { data_size_t NumPositiveData() const override { return num_pos_data_; } - private: + protected: /*! \brief Number of data */ data_size_t num_data_; /*! \brief Number of positive samples */ diff --git a/src/objective/cuda/cuda_binary_objective.cpp b/src/objective/cuda/cuda_binary_objective.cpp new file mode 100644 index 000000000000..35889c488ce5 --- /dev/null +++ b/src/objective/cuda/cuda_binary_objective.cpp @@ -0,0 +1,82 @@ +/*! + * Copyright (c) 2021 Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for + * license information. + */ + +#ifdef USE_CUDA_EXP + +#include "cuda_binary_objective.hpp" + +#include +#include + +namespace LightGBM { + +CUDABinaryLogloss::CUDABinaryLogloss(const Config& config): +BinaryLogloss(config), ova_class_id_(-1) { + cuda_label_ = nullptr; + cuda_ova_label_ = nullptr; + cuda_weights_ = nullptr; + cuda_boost_from_score_ = nullptr; + cuda_sum_weights_ = nullptr; + cuda_label_weights_ = nullptr; +} + +CUDABinaryLogloss::CUDABinaryLogloss(const Config& config, const int ova_class_id): +BinaryLogloss(config, [ova_class_id](label_t label) { return static_cast(label) == ova_class_id; }), ova_class_id_(ova_class_id) {} + +CUDABinaryLogloss::CUDABinaryLogloss(const std::vector& strs): BinaryLogloss(strs) {} + +CUDABinaryLogloss::~CUDABinaryLogloss() { + DeallocateCUDAMemory(&cuda_ova_label_, __FILE__, __LINE__); + DeallocateCUDAMemory(&cuda_label_weights_, __FILE__, __LINE__); + DeallocateCUDAMemory(&cuda_boost_from_score_, __FILE__, __LINE__); + DeallocateCUDAMemory(&cuda_sum_weights_, __FILE__, __LINE__); +} + +void CUDABinaryLogloss::Init(const Metadata& metadata, data_size_t num_data) { + BinaryLogloss::Init(metadata, num_data); + if (ova_class_id_ == -1) { + cuda_label_ = metadata.cuda_metadata()->cuda_label(); + cuda_ova_label_ = nullptr; + } else { + InitCUDAMemoryFromHostMemory(&cuda_ova_label_, metadata.cuda_metadata()->cuda_label(), static_cast(num_data), __FILE__, __LINE__); + LaunchResetOVACUDALableKernel(); + cuda_label_ = cuda_ova_label_; + } + cuda_weights_ = metadata.cuda_metadata()->cuda_weights(); + AllocateCUDAMemory(&cuda_boost_from_score_, 1, __FILE__, __LINE__); + SetCUDAMemory(cuda_boost_from_score_, 0, 1, __FILE__, __LINE__); + AllocateCUDAMemory(&cuda_sum_weights_, 1, __FILE__, __LINE__); + SetCUDAMemory(cuda_sum_weights_, 0, 1, __FILE__, __LINE__); + if (label_weights_[0] != 1.0f || label_weights_[1] != 1.0f) { + InitCUDAMemoryFromHostMemory(&cuda_label_weights_, label_weights_, 2, __FILE__, __LINE__); + } else { + cuda_label_weights_ = nullptr; + } +} + +void CUDABinaryLogloss::GetGradients(const double* scores, score_t* gradients, score_t* hessians) const { + LaunchGetGradientsKernel(scores, gradients, hessians); + SynchronizeCUDADevice(__FILE__, __LINE__); +} + +double CUDABinaryLogloss::BoostFromScore(int) const { + LaunchBoostFromScoreKernel(); + SynchronizeCUDADevice(__FILE__, __LINE__); + double boost_from_score = 0.0f; + CopyFromCUDADeviceToHost(&boost_from_score, cuda_boost_from_score_, 1, __FILE__, __LINE__); + double pavg = 0.0f; + CopyFromCUDADeviceToHost(&pavg, cuda_sum_weights_, 1, __FILE__, __LINE__); + Log::Info("[%s:%s]: pavg=%f -> initscore=%f", GetName(), __func__, pavg, boost_from_score); + return boost_from_score; +} + +void CUDABinaryLogloss::ConvertOutputCUDA(const data_size_t num_data, const double* input, double* output) const { + LaunchConvertOutputCUDAKernel(num_data, input, output); +} + +} // namespace LightGBM + +#endif // USE_CUDA_EXP diff --git a/src/objective/cuda/cuda_binary_objective.cu b/src/objective/cuda/cuda_binary_objective.cu new file mode 100644 index 000000000000..6f1711d64629 --- /dev/null +++ b/src/objective/cuda/cuda_binary_objective.cu @@ -0,0 +1,227 @@ +/*! + * Copyright (c) 2021 Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for + * license information. + */ + +#ifdef USE_CUDA_EXP + +#include + +#include "cuda_binary_objective.hpp" + +namespace LightGBM { + +template +__global__ void BoostFromScoreKernel_1_BinaryLogloss(const label_t* cuda_labels, const data_size_t num_data, double* out_cuda_sum_labels, + double* out_cuda_sum_weights, const label_t* cuda_weights, const int ova_class_id) { + __shared__ double shared_buffer[32]; + const uint32_t mask = 0xffffffff; + const uint32_t warpLane = threadIdx.x % warpSize; + const uint32_t warpID = threadIdx.x / warpSize; + const uint32_t num_warp = blockDim.x / warpSize; + const data_size_t index = static_cast(threadIdx.x + blockIdx.x * blockDim.x); + double label_value = 0.0; + double weight_value = 0.0; + if (index < num_data) { + if (USE_WEIGHT) { + const label_t cuda_label = cuda_labels[index]; + const double sample_weight = cuda_weights[index]; + const label_t label = IS_OVA ? (static_cast(cuda_label) == ova_class_id ? 1 : 0) : (cuda_label > 0 ? 1 : 0); + label_value = label * sample_weight; + weight_value = sample_weight; + } else { + const label_t cuda_label = cuda_labels[index]; + label_value = IS_OVA ? (static_cast(cuda_label) == ova_class_id ? 1 : 0) : (cuda_label > 0 ? 1 : 0); + } + } + for (uint32_t offset = warpSize / 2; offset >= 1; offset >>= 1) { + label_value += __shfl_down_sync(mask, label_value, offset); + } + if (warpLane == 0) { + shared_buffer[warpID] = label_value; + } + __syncthreads(); + if (warpID == 0) { + label_value = (warpLane < num_warp ? shared_buffer[warpLane] : 0); + for (uint32_t offset = warpSize / 2; offset >= 1; offset >>= 1) { + label_value += __shfl_down_sync(mask, label_value, offset); + } + } + __syncthreads(); + if (USE_WEIGHT) { + for (uint32_t offset = warpSize / 2; offset >= 1; offset >>= 1) { + weight_value += __shfl_down_sync(mask, weight_value, offset); + } + if (warpLane == 0) { + shared_buffer[warpID] = weight_value; + } + __syncthreads(); + if (warpID == 0) { + weight_value = (warpLane < num_warp ? shared_buffer[warpLane] : 0); + for (uint32_t offset = warpSize / 2; offset >= 1; offset >>= 1) { + weight_value += __shfl_down_sync(mask, weight_value, offset); + } + } + __syncthreads(); + } + if (threadIdx.x == 0) { + atomicAdd_system(out_cuda_sum_labels, label_value); + if (USE_WEIGHT) { + atomicAdd_system(out_cuda_sum_weights, weight_value); + } + } +} + +template +__global__ void BoostFromScoreKernel_2_BinaryLogloss(double* out_cuda_sum_labels, double* out_cuda_sum_weights, + const data_size_t num_data, const double sigmoid) { + const double suml = *out_cuda_sum_labels; + const double sumw = USE_WEIGHT ? *out_cuda_sum_weights : static_cast(num_data); + double pavg = suml / sumw; + pavg = min(pavg, 1.0 - kEpsilon); + pavg = max(pavg, kEpsilon); + const double init_score = log(pavg / (1.0f - pavg)) / sigmoid; + *out_cuda_sum_weights = pavg; + *out_cuda_sum_labels = init_score; +} + +void CUDABinaryLogloss::LaunchBoostFromScoreKernel() const { + const int num_blocks = (num_data_ + CALC_INIT_SCORE_BLOCK_SIZE_BINARY - 1) / CALC_INIT_SCORE_BLOCK_SIZE_BINARY; + if (ova_class_id_ == -1) { + if (cuda_weights_ == nullptr) { + BoostFromScoreKernel_1_BinaryLogloss<<>> + (cuda_label_, num_data_, cuda_boost_from_score_, cuda_sum_weights_, cuda_weights_, ova_class_id_); + } else { + BoostFromScoreKernel_1_BinaryLogloss<<>> + (cuda_label_, num_data_, cuda_boost_from_score_, cuda_sum_weights_, cuda_weights_, ova_class_id_); + } + } else { + if (cuda_weights_ == nullptr) { + BoostFromScoreKernel_1_BinaryLogloss<<>> + (cuda_label_, num_data_, cuda_boost_from_score_, cuda_sum_weights_, cuda_weights_, ova_class_id_); + } else { + BoostFromScoreKernel_1_BinaryLogloss<<>> + (cuda_label_, num_data_, cuda_boost_from_score_, cuda_sum_weights_, cuda_weights_, ova_class_id_); + } + } + SynchronizeCUDADevice(__FILE__, __LINE__); + if (cuda_weights_ == nullptr) { + BoostFromScoreKernel_2_BinaryLogloss<<<1, 1>>>(cuda_boost_from_score_, cuda_sum_weights_, num_data_, sigmoid_); + } else { + BoostFromScoreKernel_2_BinaryLogloss<<<1, 1>>>(cuda_boost_from_score_, cuda_sum_weights_, num_data_, sigmoid_); + } + SynchronizeCUDADevice(__FILE__, __LINE__); +} + +template +__global__ void GetGradientsKernel_BinaryLogloss(const double* cuda_scores, const label_t* cuda_labels, + const double* cuda_label_weights, const label_t* cuda_weights, const int ova_class_id, + const double sigmoid, const data_size_t num_data, + score_t* cuda_out_gradients, score_t* cuda_out_hessians) { + const data_size_t data_index = static_cast(blockDim.x * blockIdx.x + threadIdx.x); + if (data_index < num_data) { + const label_t cuda_label = static_cast(cuda_labels[data_index]); + const int label = IS_OVA ? (cuda_label == ova_class_id ? 1 : -1) : (cuda_label > 0 ? 1 : -1); + const double response = -label * sigmoid / (1.0f + exp(label * sigmoid * cuda_scores[data_index])); + const double abs_response = fabs(response); + if (!USE_WEIGHT) { + if (USE_LABEL_WEIGHT) { + const double label_weight = cuda_label_weights[label]; + cuda_out_gradients[data_index] = static_cast(response * label_weight); + cuda_out_hessians[data_index] = static_cast(abs_response * (sigmoid - abs_response) * label_weight); + } else { + cuda_out_gradients[data_index] = static_cast(response); + cuda_out_hessians[data_index] = static_cast(abs_response * (sigmoid - abs_response)); + } + } else { + const double sample_weight = cuda_weights[data_index]; + if (USE_LABEL_WEIGHT) { + const double label_weight = cuda_label_weights[label]; + cuda_out_gradients[data_index] = static_cast(response * label_weight * sample_weight); + cuda_out_hessians[data_index] = static_cast(abs_response * (sigmoid - abs_response) * label_weight * sample_weight); + } else { + cuda_out_gradients[data_index] = static_cast(response * sample_weight); + cuda_out_hessians[data_index] = static_cast(abs_response * (sigmoid - abs_response) * sample_weight); + } + } + } +} + +#define GetGradientsKernel_BinaryLogloss_ARGS \ + scores, \ + cuda_label_, \ + cuda_label_weights_, \ + cuda_weights_, \ + ova_class_id_, \ + sigmoid_, \ + num_data_, \ + gradients, \ + hessians + +void CUDABinaryLogloss::LaunchGetGradientsKernel(const double* scores, score_t* gradients, score_t* hessians) const { + const int num_blocks = (num_data_ + GET_GRADIENTS_BLOCK_SIZE_BINARY - 1) / GET_GRADIENTS_BLOCK_SIZE_BINARY; + if (ova_class_id_ == -1) { + if (cuda_label_weights_ == nullptr) { + if (cuda_weights_ == nullptr) { + GetGradientsKernel_BinaryLogloss<<>>(GetGradientsKernel_BinaryLogloss_ARGS); + } else { + GetGradientsKernel_BinaryLogloss<<>>(GetGradientsKernel_BinaryLogloss_ARGS); + } + } else { + if (cuda_weights_ == nullptr) { + GetGradientsKernel_BinaryLogloss<<>>(GetGradientsKernel_BinaryLogloss_ARGS); + } else { + GetGradientsKernel_BinaryLogloss<<>>(GetGradientsKernel_BinaryLogloss_ARGS); + } + } + } else { + if (cuda_label_weights_ == nullptr) { + if (cuda_weights_ == nullptr) { + GetGradientsKernel_BinaryLogloss<<>>(GetGradientsKernel_BinaryLogloss_ARGS); + } else { + GetGradientsKernel_BinaryLogloss<<>>(GetGradientsKernel_BinaryLogloss_ARGS); + } + } else { + if (cuda_weights_ == nullptr) { + GetGradientsKernel_BinaryLogloss<<>>(GetGradientsKernel_BinaryLogloss_ARGS); + } else { + GetGradientsKernel_BinaryLogloss<<>>(GetGradientsKernel_BinaryLogloss_ARGS); + } + } + } +} + +#undef GetGradientsKernel_BinaryLogloss_ARGS + +__global__ void ConvertOutputCUDAKernel_BinaryLogloss(const double sigmoid, const data_size_t num_data, const double* input, double* output) { + const data_size_t data_index = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (data_index < num_data) { + output[data_index] = 1.0f / (1.0f + exp(-sigmoid * input[data_index])); + } +} + +void CUDABinaryLogloss::LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const { + const int num_blocks = (num_data + GET_GRADIENTS_BLOCK_SIZE_BINARY - 1) / GET_GRADIENTS_BLOCK_SIZE_BINARY; + ConvertOutputCUDAKernel_BinaryLogloss<<>>(sigmoid_, num_data, input, output); +} + +__global__ void ResetOVACUDALableKernel( + const int ova_class_id, + const data_size_t num_data, + label_t* cuda_label) { + const data_size_t data_index = static_cast(threadIdx.x + blockIdx.x * blockDim.x); + if (data_index < num_data) { + const int int_label = static_cast(cuda_label[data_index]); + cuda_label[data_index] = (int_label == ova_class_id ? 1.0f : 0.0f); + } +} + +void CUDABinaryLogloss::LaunchResetOVACUDALableKernel() const { + const int num_blocks = (num_data_ + GET_GRADIENTS_BLOCK_SIZE_BINARY - 1) / GET_GRADIENTS_BLOCK_SIZE_BINARY; + ResetOVACUDALableKernel<<>>(ova_class_id_, num_data_, cuda_ova_label_); +} + +} // namespace LightGBM + +#endif // USE_CUDA_EXP diff --git a/src/objective/cuda/cuda_binary_objective.hpp b/src/objective/cuda/cuda_binary_objective.hpp new file mode 100644 index 000000000000..d11d49b5affb --- /dev/null +++ b/src/objective/cuda/cuda_binary_objective.hpp @@ -0,0 +1,75 @@ +/*! + * Copyright (c) 2021 Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See LICENSE file in the project root for + * license information. + */ + +#ifndef LIGHTGBM_OBJECTIVE_CUDA_CUDA_BINARY_OBJECTIVE_HPP_ +#define LIGHTGBM_OBJECTIVE_CUDA_CUDA_BINARY_OBJECTIVE_HPP_ + +#ifdef USE_CUDA_EXP + +#define GET_GRADIENTS_BLOCK_SIZE_BINARY (1024) +#define CALC_INIT_SCORE_BLOCK_SIZE_BINARY (1024) + +#include + +#include +#include + +#include "../binary_objective.hpp" + +namespace LightGBM { + +class CUDABinaryLogloss : public CUDAObjectiveInterface, public BinaryLogloss { + public: + explicit CUDABinaryLogloss(const Config& config); + + explicit CUDABinaryLogloss(const Config& config, const int ova_class_id); + + explicit CUDABinaryLogloss(const std::vector& strs); + + ~CUDABinaryLogloss(); + + void Init(const Metadata& metadata, data_size_t num_data) override; + + void GetGradients(const double* scores, score_t* gradients, score_t* hessians) const override; + + double BoostFromScore(int) const override; + + void ConvertOutputCUDA(const data_size_t num_data, const double* input, double* output) const override; + + std::function GetCUDAConvertOutputFunc() const override { + return [this] (data_size_t num_data, const double* input, double* output) { + ConvertOutputCUDA(num_data, input, output); + }; + } + + bool IsCUDAObjective() const override { return true; } + + private: + void LaunchGetGradientsKernel(const double* scores, score_t* gradients, score_t* hessians) const; + + void LaunchBoostFromScoreKernel() const; + + void LaunchConvertOutputCUDAKernel(const data_size_t num_data, const double* input, double* output) const; + + void LaunchResetOVACUDALableKernel() const; + + // CUDA memory, held by other objects + const label_t* cuda_label_; + label_t* cuda_ova_label_; + const label_t* cuda_weights_; + + // CUDA memory, held by this object + double* cuda_boost_from_score_; + double* cuda_sum_weights_; + double* cuda_label_weights_; + const int ova_class_id_ = -1; +}; + +} // namespace LightGBM + +#endif // USE_CUDA_EXP + +#endif // LIGHTGBM_OBJECTIVE_CUDA_CUDA_BINARY_OBJECTIVE_HPP_ diff --git a/src/objective/objective_function.cpp b/src/objective/objective_function.cpp index 2f719b44a988..51287585646f 100644 --- a/src/objective/objective_function.cpp +++ b/src/objective/objective_function.cpp @@ -10,6 +10,8 @@ #include "regression_objective.hpp" #include "xentropy_objective.hpp" +#include "cuda/cuda_binary_objective.hpp" + namespace LightGBM { ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& type, const Config& config) { @@ -34,8 +36,7 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& Log::Warning("Objective poisson is not implemented in cuda_exp version. Fall back to boosting on CPU."); return new RegressionPoissonLoss(config); } else if (type == std::string("binary")) { - Log::Warning("Objective binary is not implemented in cuda_exp version. Fall back to boosting on CPU."); - return new BinaryLogloss(config); + return new CUDABinaryLogloss(config); } else if (type == std::string("lambdarank")) { Log::Warning("Objective lambdarank is not implemented in cuda_exp version. Fall back to boosting on CPU."); return new LambdarankNDCG(config); diff --git a/tests/python_package_test/test_utilities.py b/tests/python_package_test/test_utilities.py index 08913ceb6e38..e75198bb1214 100644 --- a/tests/python_package_test/test_utilities.py +++ b/tests/python_package_test/test_utilities.py @@ -92,7 +92,6 @@ def dummy_metric(_, __): "INFO | [LightGBM] [Info] LightGBM using CUDA trainer with DP float!!" ] cuda_exp_lines = [ - "INFO | [LightGBM] [Warning] Objective binary is not implemented in cuda_exp version. Fall back to boosting on CPU.", "INFO | [LightGBM] [Warning] Metric auc is not implemented in cuda_exp version. Fall back to evaluation on CPU.", "INFO | [LightGBM] [Warning] Metric binary_error is not implemented in cuda_exp version. Fall back to evaluation on CPU.", ]