diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py
index c1412e424e8a..eca68c0d9306 100644
--- a/python-package/lightgbm/callback.py
+++ b/python-package/lightgbm/callback.py
@@ -258,11 +258,24 @@ def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
     def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
         return curr_score < best_score - delta
 
+    def _is_train_set(self, ds_name: str, eval_name: str, train_name: str) -> bool:
+        return (ds_name == "cv_agg" and eval_name == "train") or ds_name == train_name
+
     def _init(self, env: CallbackEnv) -> None:
-        self.enabled = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
-                               in _ConfigAliases.get("boosting"))
+        is_dart = any(env.params.get(alias, "") == 'dart' for alias in _ConfigAliases.get("boosting"))
+        only_train_set = (
+            len(env.evaluation_result_list) == 1
+            and self._is_train_set(
+                ds_name=env.evaluation_result_list[0][0],
+                eval_name=env.evaluation_result_list[0][1].split(" ")[0],
+                train_name=env.model._train_data_name)
+        )
+        self.enabled = not is_dart and not only_train_set
         if not self.enabled:
-            _log_warning('Early stopping is not available in dart mode')
+            if is_dart:
+                _log_warning('Early stopping is not available in dart mode')
+            elif only_train_set:
+                _log_warning('Only training set found, disabling early stopping.')
             return
         if not env.evaluation_result_list:
             raise ValueError('For early stopping, '
@@ -339,9 +352,7 @@ def __call__(self, env: CallbackEnv) -> None:
             eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
             if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
                 continue  # use only the first metric for early stopping
-            if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
-                 or env.evaluation_result_list[i][0] == env.model._train_data_name)):
-                self._final_iteration_check(env, eval_name_splitted, i)
+            if self._is_train_set(env.evaluation_result_list[i][0], eval_name_splitted[0], env.model._train_data_name):
                 continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
             elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
                 if self.verbose:
diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index e2877a76a549..32ed21c2337a 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -765,6 +765,43 @@ def test_early_stopping():
     assert 'binary_logloss' in gbm.best_score[valid_set_name]
 
 
+@pytest.mark.parametrize('use_valid', [True, False])
+def test_early_stopping_ignores_training_set(use_valid):
+    x = np.linspace(-1, 1, 100)
+    X = x.reshape(-1, 1)
+    y = x**2
+    X_train, X_valid = X[:80], X[80:]
+    y_train, y_valid = y[:80], y[80:]
+    train_ds = lgb.Dataset(X_train, y_train)
+    valid_ds = lgb.Dataset(X_valid, y_valid)
+    valid_sets = [train_ds]
+    valid_names = ['train']
+    if use_valid:
+        valid_sets.append(valid_ds)
+        valid_names.append('valid')
+    eval_result = {}
+
+    def train_fn():
+        return lgb.train(
+            {'num_leaves': 5},
+            train_ds,
+            num_boost_round=2,
+            valid_sets=valid_sets,
+            valid_names=valid_names,
+            callbacks=[lgb.early_stopping(1), lgb.record_evaluation(eval_result)]
+        )
+    if use_valid:
+        bst = train_fn()
+        assert bst.best_iteration == 1
+        assert eval_result['train']['l2'][1] < eval_result['train']['l2'][0]  # train improved
+        assert eval_result['valid']['l2'][1] > eval_result['valid']['l2'][0]  # valid didn't
+    else:
+        with pytest.warns(UserWarning, match='Only training set found, disabling early stopping.'):
+            bst = train_fn()
+        assert bst.current_iteration() == 2
+        assert bst.best_iteration == 0
+
+
 @pytest.mark.parametrize('first_metric_only', [True, False])
 def test_early_stopping_via_global_params(first_metric_only):
     X, y = load_breast_cancer(return_X_y=True)
diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py
index 4fe65cd8645a..c09be27f1adb 100644
--- a/tests/python_package_test/test_sklearn.py
+++ b/tests/python_package_test/test_sklearn.py
@@ -1124,11 +1124,6 @@ def fit_and_check(eval_set_names, metric_names, assumed_iteration, first_metric_only):
     iter_min = min([iter_min_l1, iter_min_l2])
     iter_min_valid1 = min([iter_valid1_l1, iter_valid1_l2])
 
-    # training data as eval_set
-    params_fit['eval_set'] = (X_train, y_train)
-    fit_and_check(['training'], ['l2'], 30, False)
-    fit_and_check(['training'], ['l2'], 30, True)
-
     # feval
     params['metric'] = 'None'
     params_fit['eval_metric'] = lambda preds, train_data: [decreasing_metric(preds, train_data),
diff --git a/tests/python_package_test/test_utilities.py b/tests/python_package_test/test_utilities.py
index 9c8cd23519fc..1b16550c5d11 100644
--- a/tests/python_package_test/test_utilities.py
+++ b/tests/python_package_test/test_utilities.py
@@ -29,17 +29,18 @@ def dummy_metric(_, __):
                   [1, 2, 3]], dtype=np.float32)
     y = np.array([0, 1, 1, 0])
 
-    lgb_data = lgb.Dataset(X, y)
+    lgb_train = lgb.Dataset(X, y)
+    lgb_valid = lgb.Dataset(X, y)  # different object for early-stopping
 
     eval_records = {}
     callbacks = [
         lgb.record_evaluation(eval_records),
         lgb.log_evaluation(2),
-        lgb.early_stopping(4)
+        lgb.early_stopping(10)
     ]
     lgb.train({'objective': 'binary', 'metric': ['auc', 'binary_error']},
-              lgb_data, num_boost_round=10, feval=dummy_metric,
-              valid_sets=[lgb_data], categorical_feature=[1], callbacks=callbacks)
+              lgb_train, num_boost_round=10, feval=dummy_metric,
+              valid_sets=[lgb_valid], categorical_feature=[1], callbacks=callbacks)
 
     lgb.plot_metric(eval_records)
 
@@ -51,32 +52,32 @@ def dummy_metric(_, __):
 INFO | [LightGBM] [Info] [binary:BoostFromScore]: pavg=0.500000 -> initscore=0.000000
 INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
 DEBUG | In dummy_metric
-INFO | Training until validation scores don't improve for 4 rounds
+INFO | Training until validation scores don't improve for 10 rounds
 INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
 DEBUG | In dummy_metric
-INFO | [2] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1
+INFO | [2] valid_0's auc: 0.5 valid_0's binary_error: 0.5 valid_0's dummy_metric: 1
 INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
 DEBUG | In dummy_metric
 INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
 DEBUG | In dummy_metric
-INFO | [4] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1
+INFO | [4] valid_0's auc: 0.5 valid_0's binary_error: 0.5 valid_0's dummy_metric: 1
 INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
 DEBUG | In dummy_metric
 INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
 DEBUG | In dummy_metric
-INFO | [6] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1
+INFO | [6] valid_0's auc: 0.5 valid_0's binary_error: 0.5 valid_0's dummy_metric: 1
 INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
 DEBUG | In dummy_metric
 INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
 DEBUG | In dummy_metric
-INFO | [8] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1
+INFO | [8] valid_0's auc: 0.5 valid_0's binary_error: 0.5 valid_0's dummy_metric: 1
 INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
 DEBUG | In dummy_metric
 INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
 DEBUG | In dummy_metric
-INFO | [10] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1
+INFO | [10] valid_0's auc: 0.5 valid_0's binary_error: 0.5 valid_0's dummy_metric: 1
 INFO | Did not meet early stopping. Best iteration is:
-[1] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1
+[1] valid_0's auc: 0.5 valid_0's binary_error: 0.5 valid_0's dummy_metric: 1
 WARNING | More than one metric available, picking one to plot.
 """.strip()
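
Note for reviewers (not part of the patch): a minimal sketch of the user-visible behavior this change introduces, mirroring the new test above. It assumes a lightgbm build that includes the patched early_stopping callback; the parameter values and data shapes are illustrative only.

import numpy as np
import lightgbm as lgb

X = np.linspace(-1, 1, 100).reshape(-1, 1)
y = X.ravel() ** 2
train_ds = lgb.Dataset(X, y)

# Only the training set is registered for evaluation, so the patched
# callback warns "Only training set found, disabling early stopping."
# and all num_boost_round iterations run.
bst = lgb.train(
    {'num_leaves': 5, 'verbose': -1},  # illustrative params
    train_ds,
    num_boost_round=10,
    valid_sets=[train_ds],
    valid_names=['train'],
    callbacks=[lgb.early_stopping(stopping_rounds=2)],
)
assert bst.current_iteration() == 10  # training was not cut short
assert bst.best_iteration == 0        # early stopping never engaged

Adding a genuine held-out Dataset to valid_sets re-enables the callback; stopping is then driven by the held-out scores alone, while training-set metrics continue to be logged and recorded.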