From 51f37e9bc43a7ff3e0773e411168a01b19bbc801 Mon Sep 17 00:00:00 2001 From: Alberto Ferreira Date: Thu, 2 Apr 2020 17:35:27 +0100 Subject: [PATCH] Cleanup MissingType enum constants (#2931) * [refactor] Cleanup MissingType enum constants * Update tree.cpp Co-authored-by: Alberto Ferreira --- include/LightGBM/tree.h | 16 +++++++--------- src/io/tree.cpp | 27 ++++++++------------------- 2 files changed, 15 insertions(+), 28 deletions(-) diff --git a/include/LightGBM/tree.h b/include/LightGBM/tree.h index 55568e41f544..047215231fc6 100644 --- a/include/LightGBM/tree.h +++ b/include/LightGBM/tree.h @@ -257,13 +257,11 @@ class Tree { inline int NumericalDecision(double fval, int node) const { uint8_t missing_type = GetMissingType(decision_type_[node]); - if (std::isnan(fval)) { - if (missing_type != 2) { - fval = 0.0f; - } + if (std::isnan(fval) && missing_type != MissingType::NaN) { + fval = 0.0f; } - if ((missing_type == 1 && IsZero(fval)) - || (missing_type == 2 && std::isnan(fval))) { + if ((missing_type == MissingType::Zero && IsZero(fval)) + || (missing_type == MissingType::NaN && std::isnan(fval))) { if (GetDecisionType(decision_type_[node], kDefaultLeftMask)) { return left_child_[node]; } else { @@ -279,8 +277,8 @@ class Tree { inline int NumericalDecisionInner(uint32_t fval, int node, uint32_t default_bin, uint32_t max_bin) const { uint8_t missing_type = GetMissingType(decision_type_[node]); - if ((missing_type == 1 && fval == default_bin) - || (missing_type == 2 && fval == max_bin)) { + if ((missing_type == MissingType::Zero && fval == default_bin) + || (missing_type == MissingType::NaN && fval == max_bin)) { if (GetDecisionType(decision_type_[node], kDefaultLeftMask)) { return left_child_[node]; } else { @@ -301,7 +299,7 @@ class Tree { return right_child_[node];; } else if (std::isnan(fval)) { // NaN is always in the right - if (missing_type == 2) { + if (missing_type == MissingType::NaN) { return right_child_[node]; } int_fval = 0; diff --git a/src/io/tree.cpp b/src/io/tree.cpp index 5b5e24a2321c..4c8fb4eb0e20 100644 --- a/src/io/tree.cpp +++ b/src/io/tree.cpp @@ -57,13 +57,7 @@ int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin, decision_type_[new_node_idx] = 0; SetDecisionType(&decision_type_[new_node_idx], false, kCategoricalMask); SetDecisionType(&decision_type_[new_node_idx], default_left, kDefaultLeftMask); - if (missing_type == MissingType::None) { - SetMissingType(&decision_type_[new_node_idx], 0); - } else if (missing_type == MissingType::Zero) { - SetMissingType(&decision_type_[new_node_idx], 1); - } else if (missing_type == MissingType::NaN) { - SetMissingType(&decision_type_[new_node_idx], 2); - } + SetMissingType(&decision_type_[new_node_idx], missing_type); threshold_in_bin_[new_node_idx] = threshold_bin; threshold_[new_node_idx] = threshold_double; ++num_leaves_; @@ -77,13 +71,7 @@ int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32 int new_node_idx = num_leaves_ - 1; decision_type_[new_node_idx] = 0; SetDecisionType(&decision_type_[new_node_idx], true, kCategoricalMask); - if (missing_type == MissingType::None) { - SetMissingType(&decision_type_[new_node_idx], 0); - } else if (missing_type == MissingType::Zero) { - SetMissingType(&decision_type_[new_node_idx], 1); - } else if (missing_type == MissingType::NaN) { - SetMissingType(&decision_type_[new_node_idx], 2); - } + SetMissingType(&decision_type_[new_node_idx], missing_type); threshold_in_bin_[new_node_idx] = num_cat_; threshold_[new_node_idx] = num_cat_; ++num_cat_; @@ -316,9 +304,9 @@ std::string Tree::NodeToJSON(int index) const { str_buf << "\"default_left\":false," << '\n'; } uint8_t missing_type = GetMissingType(decision_type_[index]); - if (missing_type == 0) { + if (missing_type == MissingType::None) { str_buf << "\"missing_type\":\"None\"," << '\n'; - } else if (missing_type == 1) { + } else if (missing_type == MissingType::Zero) { str_buf << "\"missing_type\":\"Zero\"," << '\n'; } else { str_buf << "\"missing_type\":\"NaN\"," << '\n'; @@ -347,9 +335,10 @@ std::string Tree::NumericalDecisionIfElse(int node) const { std::stringstream str_buf; uint8_t missing_type = GetMissingType(decision_type_[node]); bool default_left = GetDecisionType(decision_type_[node], kDefaultLeftMask); - if (missing_type == 0 || (missing_type == 1 && default_left && kZeroThreshold < threshold_[node])) { + if (missing_type == MissingType::None + || (missing_type == MissingType::Zero && default_left && kZeroThreshold < threshold_[node])) { str_buf << "if (fval <= " << threshold_[node] << ") {"; - } else if (missing_type == 1) { + } else if (missing_type == MissingType::Zero) { if (default_left) { str_buf << "if (fval <= " << threshold_[node] << " || Tree::IsZero(fval)" << " || std::isnan(fval)) {"; } else { @@ -368,7 +357,7 @@ std::string Tree::NumericalDecisionIfElse(int node) const { std::string Tree::CategoricalDecisionIfElse(int node) const { uint8_t missing_type = GetMissingType(decision_type_[node]); std::stringstream str_buf; - if (missing_type == 2) { + if (missing_type == MissingType::NaN) { str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast(fval); }"; } else { str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast(fval); }";