[python-package] ignore training set on early stopping callback (fixes #5354) (#5412)

* ignore training set on early stopping callback

* fixes

* lint

* Apply suggestions from code review

Co-authored-by: James Lamb <[email protected]>

* trigger ci

Co-authored-by: James Lamb <[email protected]>
jmoralez and jameslamb authored Aug 28, 2022
1 parent 581d53c commit e063dad
Showing 4 changed files with 66 additions and 22 deletions.
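
In short: the early_stopping callback now skips the training set when deciding whether to stop, and disables itself with a warning when the training set is the only evaluation set. A minimal sketch of the new behavior (not part of the commit; synthetic data, mirroring the test added below):

import numpy as np
import lightgbm as lgb

X = np.linspace(-1, 1, 100).reshape(-1, 1)
y = X.ravel() ** 2
train_ds = lgb.Dataset(X, y)

# The only evaluation set is the training data, so after this commit the
# callback warns "Only training set found, disabling early stopping." and
# training runs for the full num_boost_round.
bst = lgb.train(
    {'num_leaves': 5},
    train_ds,
    num_boost_round=10,
    valid_sets=[train_ds],
    valid_names=['train'],
    callbacks=[lgb.early_stopping(stopping_rounds=1)],
)
assert bst.best_iteration == 0  # never set, as asserted in the new test below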
23 changes: 17 additions & 6 deletions python-package/lightgbm/callback.py
@@ -258,11 +258,24 @@ def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
     def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
         return curr_score < best_score - delta

+    def _is_train_set(self, ds_name: str, eval_name: str, train_name: str) -> bool:
+        return (ds_name == "cv_agg" and eval_name == "train") or ds_name == train_name
+
     def _init(self, env: CallbackEnv) -> None:
-        self.enabled = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
-                               in _ConfigAliases.get("boosting"))
+        is_dart = any(env.params.get(alias, "") == 'dart' for alias in _ConfigAliases.get("boosting"))
+        only_train_set = (
+            len(env.evaluation_result_list) == 1
+            and self._is_train_set(
+                ds_name=env.evaluation_result_list[0][0],
+                eval_name=env.evaluation_result_list[0][1].split(" ")[0],
+                train_name=env.model._train_data_name)
+        )
+        self.enabled = not is_dart and not only_train_set
         if not self.enabled:
-            _log_warning('Early stopping is not available in dart mode')
+            if is_dart:
+                _log_warning('Early stopping is not available in dart mode')
+            elif only_train_set:
+                _log_warning('Only training set found, disabling early stopping.')
             return
         if not env.evaluation_result_list:
             raise ValueError('For early stopping, '
@@ -339,9 +352,7 @@ def __call__(self, env: CallbackEnv) -> None:
             eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
             if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
                 continue  # use only the first metric for early stopping
-            if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
-                 or env.evaluation_result_list[i][0] == env.model._train_data_name)):
-                self._final_iteration_check(env, eval_name_splitted, i)
+            if self._is_train_set(env.evaluation_result_list[i][0], eval_name_splitted[0], env.model._train_data_name):
                 continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
             elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
                 if self.verbose:
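
The new _is_train_set helper centralizes the two ways the training split can appear in env.evaluation_result_list: as the "cv_agg"/"train" aggregate produced by lgb.cv, or under the Booster's _train_data_name for lgb.train and the sklearn wrapper. A standalone sketch of the same predicate (illustrative, outside the callback class):

def is_train_set(ds_name: str, eval_name: str, train_name: str) -> bool:
    # lgb.cv aggregates results under the dataset name "cv_agg" and
    # prefixes the split name ("train"/"valid") to the metric name.
    # lgb.train and the sklearn wrapper report the training split under
    # Booster._train_data_name (default "training").
    return (ds_name == "cv_agg" and eval_name == "train") or ds_name == train_name

assert is_train_set("cv_agg", "train", "training")    # lgb.cv train aggregate
assert is_train_set("training", "l2", "training")     # lgb.train's own training set
assert not is_train_set("valid_0", "l2", "training")  # a real validation set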
37 changes: 37 additions & 0 deletions tests/python_package_test/test_engine.py
@@ -765,6 +765,43 @@ def test_early_stopping():
     assert 'binary_logloss' in gbm.best_score[valid_set_name]


+@pytest.mark.parametrize('use_valid', [True, False])
+def test_early_stopping_ignores_training_set(use_valid):
+    x = np.linspace(-1, 1, 100)
+    X = x.reshape(-1, 1)
+    y = x**2
+    X_train, X_valid = X[:80], X[80:]
+    y_train, y_valid = y[:80], y[80:]
+    train_ds = lgb.Dataset(X_train, y_train)
+    valid_ds = lgb.Dataset(X_valid, y_valid)
+    valid_sets = [train_ds]
+    valid_names = ['train']
+    if use_valid:
+        valid_sets.append(valid_ds)
+        valid_names.append('valid')
+    eval_result = {}
+
+    def train_fn():
+        return lgb.train(
+            {'num_leaves': 5},
+            train_ds,
+            num_boost_round=2,
+            valid_sets=valid_sets,
+            valid_names=valid_names,
+            callbacks=[lgb.early_stopping(1), lgb.record_evaluation(eval_result)]
+        )
+    if use_valid:
+        bst = train_fn()
+        assert bst.best_iteration == 1
+        assert eval_result['train']['l2'][1] < eval_result['train']['l2'][0]  # train improved
+        assert eval_result['valid']['l2'][1] > eval_result['valid']['l2'][0]  # valid didn't
+    else:
+        with pytest.warns(UserWarning, match='Only training set found, disabling early stopping.'):
+            bst = train_fn()
+        assert bst.current_iteration() == 2
+        assert bst.best_iteration == 0
+
+
 @pytest.mark.parametrize('first_metric_only', [True, False])
 def test_early_stopping_via_global_params(first_metric_only):
     X, y = load_breast_cancer(return_X_y=True)
5 changes: 0 additions & 5 deletions tests/python_package_test/test_sklearn.py
@@ -1124,11 +1124,6 @@ def fit_and_check(eval_set_names, metric_names, assumed_iteration, first_metric_
     iter_min = min([iter_min_l1, iter_min_l2])
     iter_min_valid1 = min([iter_valid1_l1, iter_valid1_l2])

-    # training data as eval_set
-    params_fit['eval_set'] = (X_train, y_train)
-    fit_and_check(['training'], ['l2'], 30, False)
-    fit_and_check(['training'], ['l2'], 30, True)
-
     # feval
     params['metric'] = 'None'
     params_fit['eval_metric'] = lambda preds, train_data: [decreasing_metric(preds, train_data),
23 changes: 12 additions & 11 deletions tests/python_package_test/test_utilities.py
@@ -29,17 +29,18 @@ def dummy_metric(_, __):
                   [1, 2, 3]],
                  dtype=np.float32)
     y = np.array([0, 1, 1, 0])
-    lgb_data = lgb.Dataset(X, y)
+    lgb_train = lgb.Dataset(X, y)
+    lgb_valid = lgb.Dataset(X, y)  # different object for early-stopping

     eval_records = {}
     callbacks = [
         lgb.record_evaluation(eval_records),
         lgb.log_evaluation(2),
-        lgb.early_stopping(4)
+        lgb.early_stopping(10)
     ]
     lgb.train({'objective': 'binary', 'metric': ['auc', 'binary_error']},
-              lgb_data, num_boost_round=10, feval=dummy_metric,
-              valid_sets=[lgb_data], categorical_feature=[1], callbacks=callbacks)
+              lgb_train, num_boost_round=10, feval=dummy_metric,
+              valid_sets=[lgb_valid], categorical_feature=[1], callbacks=callbacks)

     lgb.plot_metric(eval_records)

@@ -51,32 +52,32 @@ def dummy_metric(_, __):
 INFO | [LightGBM] [Info] [binary:BoostFromScore]: pavg=0.500000 -> initscore=0.000000
 INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
 DEBUG | In dummy_metric
-INFO | Training until validation scores don't improve for 4 rounds
+INFO | Training until validation scores don't improve for 10 rounds
 INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
 DEBUG | In dummy_metric
-INFO | [2] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1
+INFO | [2] valid_0's auc: 0.5 valid_0's binary_error: 0.5 valid_0's dummy_metric: 1
 INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
 DEBUG | In dummy_metric
 INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
 DEBUG | In dummy_metric
-INFO | [4] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1
+INFO | [4] valid_0's auc: 0.5 valid_0's binary_error: 0.5 valid_0's dummy_metric: 1
 INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
 DEBUG | In dummy_metric
 INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
 DEBUG | In dummy_metric
-INFO | [6] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1
+INFO | [6] valid_0's auc: 0.5 valid_0's binary_error: 0.5 valid_0's dummy_metric: 1
 INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
 DEBUG | In dummy_metric
 INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
 DEBUG | In dummy_metric
-INFO | [8] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1
+INFO | [8] valid_0's auc: 0.5 valid_0's binary_error: 0.5 valid_0's dummy_metric: 1
 INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
 DEBUG | In dummy_metric
 INFO | [LightGBM] [Warning] Stopped training because there are no more leaves that meet the split requirements
 DEBUG | In dummy_metric
-INFO | [10] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1
+INFO | [10] valid_0's auc: 0.5 valid_0's binary_error: 0.5 valid_0's dummy_metric: 1
 INFO | Did not meet early stopping. Best iteration is:
-[1] training's auc: 0.5 training's binary_error: 0.5 training's dummy_metric: 1
+[1] valid_0's auc: 0.5 valid_0's binary_error: 0.5 valid_0's dummy_metric: 1
 WARNING | More than one metric available, picking one to plot.
 """.strip()

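A note on the test_utilities.py change: the test previously passed the training Dataset itself as the evaluation set, which LightGBM reports under the booster's train data name ("training"); after this commit that lone entry would disable early stopping and change the expected log. Building a second Dataset over the same arrays gives it the default name "valid_0", so the callback still monitors it. A sketch of the distinction (illustrative, not part of the commit):

import numpy as np
import lightgbm as lgb

X = np.random.rand(50, 2)
y = np.random.rand(50)
params = {'objective': 'regression', 'verbose': -1}
train_ds = lgb.Dataset(X, y)

# Same object: reported as "training", matched by _is_train_set(),
# so the callback warns and disables itself.
lgb.train(params, train_ds, num_boost_round=5,
          valid_sets=[train_ds],
          callbacks=[lgb.early_stopping(10)])

# Distinct Dataset over identical data: reported as "valid_0",
# so the early-stopping callback monitors it as a validation set.
lgb.train(params, train_ds, num_boost_round=5,
          valid_sets=[lgb.Dataset(X, y)],
          callbacks=[lgb.early_stopping(10)])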
