Skip to content

Commit

Permalink
Cleanup MissingType enum constants (#2931)
Browse files Browse the repository at this point in the history
* [refactor] Cleanup MissingType enum constants

* Update tree.cpp

Co-authored-by: Alberto Ferreira <[email protected]>
  • Loading branch information
AlbertoEAF and Alberto Ferreira authored Apr 2, 2020
1 parent 2d4f390 commit 51f37e9
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 28 deletions.
16 changes: 7 additions & 9 deletions include/LightGBM/tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,13 +257,11 @@ class Tree {

inline int NumericalDecision(double fval, int node) const {
uint8_t missing_type = GetMissingType(decision_type_[node]);
if (std::isnan(fval)) {
if (missing_type != 2) {
fval = 0.0f;
}
if (std::isnan(fval) && missing_type != MissingType::NaN) {
fval = 0.0f;
}
if ((missing_type == 1 && IsZero(fval))
|| (missing_type == 2 && std::isnan(fval))) {
if ((missing_type == MissingType::Zero && IsZero(fval))
|| (missing_type == MissingType::NaN && std::isnan(fval))) {
if (GetDecisionType(decision_type_[node], kDefaultLeftMask)) {
return left_child_[node];
} else {
Expand All @@ -279,8 +277,8 @@ class Tree {

inline int NumericalDecisionInner(uint32_t fval, int node, uint32_t default_bin, uint32_t max_bin) const {
uint8_t missing_type = GetMissingType(decision_type_[node]);
if ((missing_type == 1 && fval == default_bin)
|| (missing_type == 2 && fval == max_bin)) {
if ((missing_type == MissingType::Zero && fval == default_bin)
|| (missing_type == MissingType::NaN && fval == max_bin)) {
if (GetDecisionType(decision_type_[node], kDefaultLeftMask)) {
return left_child_[node];
} else {
Expand All @@ -301,7 +299,7 @@ class Tree {
return right_child_[node];;
} else if (std::isnan(fval)) {
// NaN is always in the right
if (missing_type == 2) {
if (missing_type == MissingType::NaN) {
return right_child_[node];
}
int_fval = 0;
Expand Down
27 changes: 8 additions & 19 deletions src/io/tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,7 @@ int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
decision_type_[new_node_idx] = 0;
SetDecisionType(&decision_type_[new_node_idx], false, kCategoricalMask);
SetDecisionType(&decision_type_[new_node_idx], default_left, kDefaultLeftMask);
if (missing_type == MissingType::None) {
SetMissingType(&decision_type_[new_node_idx], 0);
} else if (missing_type == MissingType::Zero) {
SetMissingType(&decision_type_[new_node_idx], 1);
} else if (missing_type == MissingType::NaN) {
SetMissingType(&decision_type_[new_node_idx], 2);
}
SetMissingType(&decision_type_[new_node_idx], missing_type);
threshold_in_bin_[new_node_idx] = threshold_bin;
threshold_[new_node_idx] = threshold_double;
++num_leaves_;
Expand All @@ -77,13 +71,7 @@ int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32
int new_node_idx = num_leaves_ - 1;
decision_type_[new_node_idx] = 0;
SetDecisionType(&decision_type_[new_node_idx], true, kCategoricalMask);
if (missing_type == MissingType::None) {
SetMissingType(&decision_type_[new_node_idx], 0);
} else if (missing_type == MissingType::Zero) {
SetMissingType(&decision_type_[new_node_idx], 1);
} else if (missing_type == MissingType::NaN) {
SetMissingType(&decision_type_[new_node_idx], 2);
}
SetMissingType(&decision_type_[new_node_idx], missing_type);
threshold_in_bin_[new_node_idx] = num_cat_;
threshold_[new_node_idx] = num_cat_;
++num_cat_;
Expand Down Expand Up @@ -316,9 +304,9 @@ std::string Tree::NodeToJSON(int index) const {
str_buf << "\"default_left\":false," << '\n';
}
uint8_t missing_type = GetMissingType(decision_type_[index]);
if (missing_type == 0) {
if (missing_type == MissingType::None) {
str_buf << "\"missing_type\":\"None\"," << '\n';
} else if (missing_type == 1) {
} else if (missing_type == MissingType::Zero) {
str_buf << "\"missing_type\":\"Zero\"," << '\n';
} else {
str_buf << "\"missing_type\":\"NaN\"," << '\n';
Expand Down Expand Up @@ -347,9 +335,10 @@ std::string Tree::NumericalDecisionIfElse(int node) const {
std::stringstream str_buf;
uint8_t missing_type = GetMissingType(decision_type_[node]);
bool default_left = GetDecisionType(decision_type_[node], kDefaultLeftMask);
if (missing_type == 0 || (missing_type == 1 && default_left && kZeroThreshold < threshold_[node])) {
if (missing_type == MissingType::None
|| (missing_type == MissingType::Zero && default_left && kZeroThreshold < threshold_[node])) {
str_buf << "if (fval <= " << threshold_[node] << ") {";
} else if (missing_type == 1) {
} else if (missing_type == MissingType::Zero) {
if (default_left) {
str_buf << "if (fval <= " << threshold_[node] << " || Tree::IsZero(fval)" << " || std::isnan(fval)) {";
} else {
Expand All @@ -368,7 +357,7 @@ std::string Tree::NumericalDecisionIfElse(int node) const {
std::string Tree::CategoricalDecisionIfElse(int node) const {
uint8_t missing_type = GetMissingType(decision_type_[node]);
std::stringstream str_buf;
if (missing_type == 2) {
if (missing_type == MissingType::NaN) {
str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast<int>(fval); }";
} else {
str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast<int>(fval); }";
Expand Down

0 comments on commit 51f37e9

Please sign in to comment.