diff --git a/src/treelearner/serial_tree_learner.cpp b/src/treelearner/serial_tree_learner.cpp
index 304c712f0723..9ea50e148901 100644
--- a/src/treelearner/serial_tree_learner.cpp
+++ b/src/treelearner/serial_tree_learner.cpp
@@ -683,7 +683,8 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
 
   // init the leaves that used on next iteration
   if (best_split_info.left_count < best_split_info.right_count) {
-    CHECK_GT(best_split_info.left_count, 0);
+    if (best_split_info.left_count == 0)
+      Log::Warning("Best split left count is 0 for leaf %d", *left_leaf);
     smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
                                best_split_info.left_sum_gradient,
                                best_split_info.left_sum_hessian,
@@ -693,7 +694,8 @@ void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
                               best_split_info.right_sum_hessian,
                               best_split_info.right_output);
   } else {
-    CHECK_GT(best_split_info.right_count, 0);
+    if (best_split_info.right_count == 0)
+      Log::Warning("Best split right count is 0 for leaf %d", *right_leaf);
     smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
                                best_split_info.right_sum_gradient,
                                best_split_info.right_sum_hessian,
@@ -735,7 +737,8 @@ void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj
         const double new_output = obj->RenewTreeOutput(output, residual_getter, index_mapper, bag_mapper, cnt_leaf_data);
         tree->SetLeafOutput(i, new_output);
       } else {
-        CHECK_GT(num_machines, 1);
+        if (num_machines <= 1)
+          Log::Warning("num_machines is less than or equal to 1 for leaf %d (num_machines = %d)", i, num_machines);
         tree->SetLeafOutput(i, 0.0);
         n_nozeroworker_perleaf[i] = 0;
       }
diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index d44e5848fe7c..0cdaa8988868 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -9,6 +9,7 @@
 from pathlib import Path
 
 import numpy as np
+import pandas as pd
 import psutil
 import pytest
 from scipy.sparse import csr_matrix, isspmatrix_csc, isspmatrix_csr
@@ -3458,3 +3459,54 @@ def test_boost_from_average_with_single_leaf_trees():
     preds = model.predict(X)
     mean_preds = np.mean(preds)
     assert y.min() <= mean_preds <= y.max()
+
+
+@pytest.mark.parametrize('device', ['cpu'])
+def test_training_leaf_count_zero(device):
+    # Test data is prepared to produce one of the following errors (without the fix):
+    #     Check failed: (best_split_info.left_count) > (0)
+    #     Check failed: (best_split_info.right_count) > (0)
+    # The issue related to this test is:
+    # https://github.com/microsoft/LightGBM/issues/4946
+
+    # Make random data with a fixed seed
+    R, C = 100000, 10
+    if device == 'cpu':
+        np.random.seed(0)
+    else:
+        np.random.seed(50)
+    data = pd.DataFrame(np.random.randn(R, C), dtype=np.float32)
+    for i in range(1, C):
+        data[i] += data[0] * np.random.randn()
+
+    # Split train/test = 60/40
+    N = int(0.6 * len(data))
+    train_data = data.loc[:N]
+    test_data = data.loc[N:]
+
+    train = lgb.Dataset(train_data.iloc[:, 1:], train_data.iloc[:, 0], free_raw_data=True)
+    test = lgb.Dataset(test_data.iloc[:, 1:], test_data.iloc[:, 0], free_raw_data=True, reference=train)
+
+    # The test can be run on both cpu and gpu (only cpu is enabled in the parametrization above)
+    params = {
+        'device': device,
+        'boosting_type': 'gbdt',
+        'objective': 'regression',
+        'max_tree_output': 0.03,
+        'max_bin': 20,
+        'max_depth': 10,
+        'num_leaves': 127,
+        'seed': 8,
+        'learning_rate': 0.01,
+        'bagging_fraction': 0.5,
+        'bagging_freq': 1,
+        'min_data_in_leaf': 0,
+        'verbose': -1,
+        'min_split_gain': 0.1,
+        'cegb_penalty_feature_coupled': 5 * np.ones(C - 1),
+        'cegb_penalty_split': 0.0000002,
+    }
+
+    # Without the fix, training breaks on the following line
+    gbm = lgb.train(params, train, num_boost_round=5000, valid_sets=test)
+    assert True
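
A quick way to exercise only the new regression test locally (a minimal sketch, not part of the patch; it assumes a LightGBM repository checkout as the working directory, with the lightgbm Python package, pandas, and pytest installed):

    # run_leaf_count_zero_test.py -- hypothetical helper script, not included in this diff.
    # Selects the new test by its pytest node id and exits with pytest's return code.
    import sys

    import pytest

    ret = pytest.main([
        "tests/python_package_test/test_engine.py::test_training_leaf_count_zero",
        "-v",
    ])
    sys.exit(int(ret))

Equivalently, the same node id can be passed directly to pytest on the command line.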