From 0c4bb89de578a02b7822ba2bc6af2947e1c8b527 Mon Sep 17 00:00:00 2001 From: Guolin Ke Date: Mon, 26 Nov 2018 04:10:25 +0800 Subject: [PATCH] fix bug for one-class binary (#1877) --- src/objective/binary_objective.hpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/objective/binary_objective.hpp b/src/objective/binary_objective.hpp index 4bd54f2f1789..1a19676cd0e7 100644 --- a/src/objective/binary_objective.hpp +++ b/src/objective/binary_objective.hpp @@ -65,10 +65,11 @@ class BinaryLogloss: public ObjectiveFunction { ++cnt_negative; } } + need_train_ = true; if (cnt_negative == 0 || cnt_positive == 0) { Log::Warning("Contains only one class"); // not need to boost. - num_data_ = 0; + need_train_ = false; } Log::Info("Number of positive: %d, number of negative: %d", cnt_positive, cnt_negative); // use -1 for negative class, and 1 for positive class @@ -91,6 +92,9 @@ class BinaryLogloss: public ObjectiveFunction { } void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override { + if (!need_train_) { + return; + } if (weights_ == nullptr) { #pragma omp parallel for schedule(static) for (data_size_t i = 0; i < num_data_; ++i) { @@ -146,7 +150,7 @@ class BinaryLogloss: public ObjectiveFunction { } bool ClassNeedTrain(int /*class_id*/) const override { - return num_data_ > 0; + return need_train_; } const char* GetName() const override { @@ -185,6 +189,7 @@ class BinaryLogloss: public ObjectiveFunction { const label_t* weights_; double scale_pos_weight_; std::function is_pos_; + bool need_train_; }; } // namespace LightGBM