Skip to content

Commit

Permalink
fix bug for one-class binary (#1877)
Browse files Browse the repository at this point in the history
  • Loading branch information
guolinke authored Nov 25, 2018
1 parent e55c815 commit 0c4bb89
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/objective/binary_objective.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,11 @@ class BinaryLogloss: public ObjectiveFunction {
++cnt_negative;
}
}
need_train_ = true;
if (cnt_negative == 0 || cnt_positive == 0) {
Log::Warning("Contains only one class");
// not need to boost.
num_data_ = 0;
need_train_ = false;
}
Log::Info("Number of positive: %d, number of negative: %d", cnt_positive, cnt_negative);
// use -1 for negative class, and 1 for positive class
Expand All @@ -91,6 +92,9 @@ class BinaryLogloss: public ObjectiveFunction {
}

void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
if (!need_train_) {
return;
}
if (weights_ == nullptr) {
#pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data_; ++i) {
Expand Down Expand Up @@ -146,7 +150,7 @@ class BinaryLogloss: public ObjectiveFunction {
}

bool ClassNeedTrain(int /*class_id*/) const override {
return num_data_ > 0;
return need_train_;
}

const char* GetName() const override {
Expand Down Expand Up @@ -185,6 +189,7 @@ class BinaryLogloss: public ObjectiveFunction {
const label_t* weights_;
double scale_pos_weight_;
std::function<bool(label_t)> is_pos_;
bool need_train_;
};

} // namespace LightGBM
Expand Down

0 comments on commit 0c4bb89

Please sign in to comment.