From d775b007b61a8daabe6352900dd885dc6fc6acdf Mon Sep 17 00:00:00 2001
From: Nikita Titov
Date: Sun, 29 Aug 2021 04:19:00 +0300
Subject: [PATCH 1/4] deprecate advanced args of `train()` and `cv()`

---
 python-package/lightgbm/callback.py | 40 ++++++++++++++++---
 python-package/lightgbm/engine.py   | 26 ++++++++++---
 python-package/lightgbm/sklearn.py  | 60 ++++++++++++++++++++---------
 3 files changed, 97 insertions(+), 29 deletions(-)

diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py
index 0585d889c7d1..2b81c16d2893 100644
--- a/python-package/lightgbm/callback.py
+++ b/python-package/lightgbm/callback.py
@@ -50,19 +50,27 @@ def _format_eval_result(value: list, show_stdv: bool = True) -> str:


 def print_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:
-    """Create a callback that prints the evaluation results.
+    """Create a callback that logs the evaluation results.
+
+    By default, standard output resource is used.
+    Use ``register_logger()`` function to register a custom logger.
+
+    Note
+    ----
+    Requires at least one validation data.

     Parameters
     ----------
     period : int, optional (default=1)
-        The period to print the evaluation results.
+        The period to log the evaluation results.
+        The last boosting stage or the boosting stage found by using ``early_stopping`` callback is also logged.
     show_stdv : bool, optional (default=True)
-        Whether to show stdv (if provided).
+        Whether to log stdv (if provided).

     Returns
     -------
     callback : function
-        The callback that prints the evaluation results every ``period`` iteration(s).
+        The callback that logs the evaluation results every ``period`` boosting iteration(s).
     """
     def _callback(env: CallbackEnv) -> None:
         if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
@@ -82,6 +90,23 @@ def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
         This should be initialized outside of your call to ``record_evaluation()`` and should be empty.
         Any initial contents of the dictionary will be deleted.

+        .. rubric:: Example
+
+        With two validation sets named 'eval' and 'train', and one evaluation metric named 'logloss'
+        this dictionary after finishing a model training process will have the following structure:
+        ```
+        {
+         'train':
+             {
+              'logloss': [0.48253, 0.35953, ...]
+             },
+         'eval':
+             {
+              'logloss': [0.480385, 0.357756, ...]
+             }
+        }
+        ```
+
     Returns
     -------
     callback : function
@@ -150,11 +175,12 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos

     Activates early stopping.
     The model will train until the validation score stops improving.
-    Validation score needs to improve at least every ``early_stopping_rounds`` round(s)
+    Validation score needs to improve at least every ``stopping_rounds`` round(s)
     to continue training.
     Requires at least one validation data and one metric.
     If there's more than one, will check all of them. But the training data is ignored anyway.
     To check only the first metric set ``first_metric_only`` to True.
+    The index of iteration that has the best performance will be saved in the ``best_iteration`` attribute of a model.

     Parameters
     ----------
@@ -163,7 +189,9 @@ def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbos
     first_metric_only : bool, optional (default=False)
         Whether to use only the first metric for early stopping.
     verbose : bool, optional (default=True)
-        Whether to print message with early stopping information.
+        Whether to log message with early stopping information.
+        By default, standard output resource is used.
+        Use ``register_logger()`` function to register a custom logger.

     Returns
     -------
diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py
index fbaac0b6a7c9..59b51bfee15f 100644
--- a/python-package/lightgbm/engine.py
+++ b/python-package/lightgbm/engine.py
@@ -35,7 +35,7 @@ def train(
     categorical_feature: Union[List[str], List[int], str] = 'auto',
     early_stopping_rounds: Optional[int] = None,
     evals_result: Optional[Dict[str, Any]] = None,
-    verbose_eval: Union[bool, int] = True,
+    verbose_eval: Union[bool, int, str] = 'warn',
     learning_rates: Optional[Union[List[float], Callable[[int], float]]] = None,
     keep_training_booster: bool = False,
     callbacks: Optional[List[Callable]] = None
@@ -121,7 +121,7 @@ def train(
         To check only the first metric, set the ``first_metric_only`` parameter to ``True`` in ``params``.
         The index of iteration that has the best performance will be saved in the ``best_iteration`` field
         if early stopping logic is enabled by setting ``early_stopping_rounds``.
-    evals_result: dict or None, optional (default=None)
+    evals_result : dict or None, optional (default=None)
         Dictionary used to store all evaluation results of all the items in ``valid_sets``.
         This should be initialized outside of your call to ``train()`` and should be empty.
         Any initial contents of the dictionary will be deleted.
@@ -176,10 +176,13 @@ def train(
             num_boost_round = params.pop(alias)
             _log_warning(f"Found `{alias}` in params. Will use it instead of argument")
     params["num_iterations"] = num_boost_round
+    # show deprecation warning only for early stop argument, setting early stop via global params should still be possible
+    if early_stopping_rounds is not None and early_stopping_rounds > 0:
+        _log_warning("'early_stopping_rounds' argument is deprecated and will be removed in 4.0.0 release. "
+                     "Pass 'early_stopping()' callback via 'callbacks' argument instead.")
     for alias in _ConfigAliases.get("early_stopping_round"):
         if alias in params:
             early_stopping_rounds = params.pop(alias)
-            _log_warning(f"Found `{alias}` in params. Will use it instead of argument")
             params["early_stopping_round"] = early_stopping_rounds
     first_metric_only = params.get('first_metric_only', False)

@@ -233,6 +236,11 @@ def train(
         callbacks = set(callbacks)

     # Most of legacy advanced options becomes callbacks
+    if verbose_eval != "warn":
+        _log_warning("'verbose_eval' argument is deprecated and will be removed in 4.0.0 release. "
+                     "Pass 'print_evaluation()' callback via 'callbacks' argument instead.")
+    if verbose_eval == "warn":
+        verbose_eval = True
     if verbose_eval is True:
         callbacks.add(callback.print_evaluation())
     elif isinstance(verbose_eval, int):
@@ -242,9 +250,13 @@ def train(
         callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=bool(verbose_eval)))

     if learning_rates is not None:
+        _log_warning("'learning_rates' argument is deprecated and will be removed in 4.0.0 release. "
+                     "Pass 'reset_parameter()' callback via 'callbacks' argument instead.")
         callbacks.add(callback.reset_parameter(learning_rate=learning_rates))

     if evals_result is not None:
+        _log_warning("'evals_result' argument is deprecated and will be removed in 4.0.0 release. "
+                     "Pass 'record_evaluation()' callback via 'callbacks' argument instead.")
         callbacks.add(callback.record_evaluation(evals_result))

     callbacks_before_iter = {cb for cb in callbacks if getattr(cb, 'before_iteration', False)}
@@ -520,7 +532,6 @@ def cv(params, train_set, num_boost_round=100,
         and returns transformed versions of those.
     verbose_eval : bool, int, or None, optional (default=None)
         Whether to display the progress.
-        If None, progress will be displayed when np.ndarray is returned.
         If True, progress will be displayed at every boosting stage.
         If int, progress will be displayed at every given ``verbose_eval`` boosting stage.
     show_stdv : bool, optional (default=True)
@@ -560,9 +571,11 @@ def cv(params, train_set, num_boost_round=100,
             _log_warning(f"Found `{alias}` in params. Will use it instead of argument")
             num_boost_round = params.pop(alias)
     params["num_iterations"] = num_boost_round
+    if early_stopping_rounds is not None and early_stopping_rounds > 0:
+        _log_warning("'early_stopping_rounds' argument is deprecated and will be removed in 4.0.0 release. "
+                     "Pass ``early_stopping()`` callback via ``callbacks`` argument instead.")
     for alias in _ConfigAliases.get("early_stopping_round"):
         if alias in params:
-            _log_warning(f"Found `{alias}` in params. Will use it instead of argument")
             early_stopping_rounds = params.pop(alias)
     params["early_stopping_round"] = early_stopping_rounds
     first_metric_only = params.get('first_metric_only', False)
@@ -601,6 +614,9 @@ def cv(params, train_set, num_boost_round=100,
         callbacks = set(callbacks)
     if early_stopping_rounds is not None and early_stopping_rounds > 0:
         callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=False))
+    if verbose_eval is not None:
+        _log_warning("'verbose_eval' argument is deprecated and will be removed in 4.0.0 release. "
+                     "Pass 'print_evaluation()' callback via 'callbacks' argument instead.")
     if verbose_eval is True:
         callbacks.add(callback.print_evaluation(show_stdv=show_stdv))
     elif isinstance(verbose_eval, int):
diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py
index 9fd2759a0e06..eb60ba726dde 100644
--- a/python-package/lightgbm/sklearn.py
+++ b/python-package/lightgbm/sklearn.py
@@ -7,6 +7,7 @@ import numpy as np

 from .basic import Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _log_warning
+from .callback import print_evaluation, record_evaluation
 from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray,
                      _LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase,
                      _LGBMComputeSampleWeight, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase, dt_DataTable,
@@ -570,7 +571,7 @@ def fit(self, X, y,
             sample_weight=None, init_score=None, group=None,
             eval_set=None, eval_names=None, eval_sample_weight=None,
             eval_class_weight=None, eval_init_score=None, eval_group=None,
-            eval_metric=None, early_stopping_rounds=None, verbose=True,
+            eval_metric=None, early_stopping_rounds=None, verbose='warn',
             feature_name='auto', categorical_feature='auto',
             callbacks=None, init_model=None):
         """Docstring is set after definition, using a template."""
@@ -587,7 +588,7 @@ def fit(self, X, y,
             self._fobj = _ObjectiveFunctionWrapper(self._objective)
         else:
             self._fobj = None
-        evals_result = {}
+
         params = self.get_params()
         # user can set verbose with kwargs, it has higher priority
         if not any(verbose_alias in params for verbose_alias in _ConfigAliases.get("verbosity")) and self.silent:
@@ -709,19 +710,42 @@ def _get_meta_data(collection, name, i):
         if isinstance(init_model, LGBMModel):
             init_model = init_model.booster_

-        self._Booster = train(params, train_set,
-                              self.n_estimators, valid_sets=valid_sets, valid_names=eval_names,
-                              early_stopping_rounds=early_stopping_rounds,
-                              evals_result=evals_result, fobj=self._fobj, feval=eval_metrics_callable,
-                              verbose_eval=verbose, feature_name=feature_name,
-                              callbacks=callbacks, init_model=init_model)
+        if early_stopping_rounds is not None and early_stopping_rounds > 0:
+            _log_warning("'early_stopping_rounds' argument is deprecated and will be removed in 4.0.0 release. "
+                         "Pass 'early_stopping()' callback via 'callbacks' argument instead.")
+            params['early_stopping_rounds'] = early_stopping_rounds
+
+        if callbacks is None:
+            callbacks = []
+
+        evals_result = {}
+        callbacks.append(record_evaluation(evals_result))
+
+        if verbose != 'warn':
+            _log_warning("'verbose' argument is deprecated and will be removed in 4.0.0 release. "
+                         "Pass 'print_evaluation()' callback via 'callbacks' argument instead.")
+        if verbose == 'warn':
+            verbose = True
+        callbacks.append(print_evaluation(int(verbose)))
+
+        self._Booster = train(
+            params,
+            train_set,
+            self.n_estimators,
+            valid_sets=valid_sets,
+            valid_names=eval_names,
+            evals_result=evals_result,
+            fobj=self._fobj,
+            feval=eval_metrics_callable,
+            feature_name=feature_name,
+            callbacks=callbacks,
+            init_model=init_model
+        )

         if evals_result:
             self._evals_result = evals_result
-        if early_stopping_rounds is not None and early_stopping_rounds > 0:
-            self._best_iteration = self._Booster.best_iteration
-
+        self._best_iteration = self._Booster.best_iteration
         self._best_score = self._Booster.best_score

         self.fitted_ = True
@@ -782,16 +806,16 @@ def n_features_in_(self):

     @property
     def best_score_(self):
-        """:obj:`dict` or :obj:`None`: The best score of fitted model."""
+        """:obj:`collections.OrderedDict`: The best score of fitted model."""
         if self._n_features is None:
             raise LGBMNotFittedError('No best_score found. Need to call fit beforehand.')
         return self._best_score

     @property
     def best_iteration_(self):
-        """:obj:`int` or :obj:`None`: The best iteration of fitted model if ``early_stopping_rounds`` has been specified."""
+        """:obj:`int`: The best iteration of fitted model if ``early_stopping()`` callback has been specified."""
         if self._n_features is None:
-            raise LGBMNotFittedError('No best_iteration found. Need to call fit with early_stopping_rounds beforehand.')
+            raise LGBMNotFittedError('No best_iteration found. Need to call fit with early_stopping callback beforehand.')
         return self._best_iteration

     @property
@@ -810,7 +834,7 @@ def booster_(self):

     @property
     def evals_result_(self):
-        """:obj:`dict` or :obj:`None`: The evaluation results if ``early_stopping_rounds`` has been specified."""
+        """:obj:`dict` or :obj:`None`: The evaluation results if validation sets have been specified."""
         if self._n_features is None:
             raise LGBMNotFittedError('No results found. Need to call fit with eval_set beforehand.')
         return self._evals_result
@@ -843,7 +867,7 @@ def fit(self, X, y,
             sample_weight=None, init_score=None, eval_set=None,
             eval_names=None, eval_sample_weight=None, eval_init_score=None,
             eval_metric=None, early_stopping_rounds=None,
-            verbose=True, feature_name='auto', categorical_feature='auto',
+            verbose='warn', feature_name='auto', categorical_feature='auto',
             callbacks=None, init_model=None):
         """Docstring is inherited from the LGBMModel."""
         super().fit(X, y, sample_weight=sample_weight, init_score=init_score,
@@ -869,7 +893,7 @@ def fit(self, X, y,
             sample_weight=None, init_score=None, eval_set=None,
             eval_names=None, eval_sample_weight=None,
             eval_class_weight=None, eval_init_score=None, eval_metric=None,
-            early_stopping_rounds=None, verbose=True,
+            early_stopping_rounds=None, verbose='warn',
             feature_name='auto', categorical_feature='auto',
             callbacks=None, init_model=None):
         """Docstring is inherited from the LGBMModel."""
@@ -997,7 +1021,7 @@ def fit(self, X, y,
             sample_weight=None, init_score=None, group=None,
             eval_set=None, eval_names=None, eval_sample_weight=None,
             eval_init_score=None, eval_group=None, eval_metric=None,
-            eval_at=(1, 2, 3, 4, 5), early_stopping_rounds=None, verbose=True,
+            eval_at=(1, 2, 3, 4, 5), early_stopping_rounds=None, verbose='warn',
             feature_name='auto', categorical_feature='auto',
             callbacks=None, init_model=None):
         """Docstring is inherited from the LGBMModel."""

From 694b57b3714109b895f03caf04eb701433577aa0 Mon Sep 17 00:00:00 2001
From: Nikita Titov
Date: Mon, 30 Aug 2021 00:07:26 +0300
Subject: [PATCH 2/4] update Dask test

---
 tests/python_package_test/test_dask.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py
index 2f84933cfd37..11e512ffd996 100644
--- a/tests/python_package_test/test_dask.py
+++ b/tests/python_package_test/test_dask.py
@@ -916,7 +916,7 @@ def test_eval_set_no_early_stopping(task, output, eval_sizes, eval_names_prefix,

     # check that early stopping was not applied.
     assert dask_model.booster_.num_trees() == model_trees
-    assert dask_model.best_iteration_ is None
+    assert dask_model.best_iteration_ == 0

     # checks that evals_result_ and best_score_ contain expected data and eval_set names.
     evals_result = dask_model.evals_result_

From bca79aaa91893cbd89d93b030cae598a60bafc23 Mon Sep 17 00:00:00 2001
From: Nikita Titov
Date: Mon, 30 Aug 2021 00:42:31 +0300
Subject: [PATCH 3/4] improve deducing

---
 python-package/lightgbm/engine.py           |  7 +++++--
 python-package/lightgbm/sklearn.py          | 11 +++++++----
 tests/python_package_test/test_utilities.py |  4 +---
 3 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py
index 59b51bfee15f..798ee33d0d18 100644
--- a/python-package/lightgbm/engine.py
+++ b/python-package/lightgbm/engine.py
@@ -240,7 +240,10 @@ def train(
         _log_warning("'verbose_eval' argument is deprecated and will be removed in 4.0.0 release. "
                      "Pass 'print_evaluation()' callback via 'callbacks' argument instead.")
     if verbose_eval == "warn":
-        verbose_eval = True
+        if callbacks:  # assume user has already specified print_evaluation callback
+            verbose_eval = False
+        else:
+            verbose_eval = True
     if verbose_eval is True:
         callbacks.add(callback.print_evaluation())
     elif isinstance(verbose_eval, int):
@@ -573,7 +576,7 @@ def cv(params, train_set, num_boost_round=100,
     params["num_iterations"] = num_boost_round
     if early_stopping_rounds is not None and early_stopping_rounds > 0:
         _log_warning("'early_stopping_rounds' argument is deprecated and will be removed in 4.0.0 release. "
-                     "Pass ``early_stopping()`` callback via ``callbacks`` argument instead.")
+                     "Pass 'early_stopping()' callback via 'callbacks' argument instead.")
     for alias in _ConfigAliases.get("early_stopping_round"):
         if alias in params:
             early_stopping_rounds = params.pop(alias)
diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py
index 599f87ca7448..b50509feb083 100644
--- a/python-package/lightgbm/sklearn.py
+++ b/python-package/lightgbm/sklearn.py
@@ -720,16 +720,19 @@ def _get_meta_data(collection, name, i):
         else:
             callbacks = copy.deepcopy(callbacks)

-        evals_result = {}
-        callbacks.append(record_evaluation(evals_result))
-
         if verbose != 'warn':
             _log_warning("'verbose' argument is deprecated and will be removed in 4.0.0 release. "
                          "Pass 'print_evaluation()' callback via 'callbacks' argument instead.")
         if verbose == 'warn':
-            verbose = True
+            if callbacks:  # assume user has already specified print_evaluation callback
+                verbose = False
+            else:
+                verbose = True
         callbacks.append(print_evaluation(int(verbose)))

+        evals_result = {}
+        callbacks.append(record_evaluation(evals_result))
+
         self._Booster = train(
             params=params,
             train_set=train_set,
diff --git a/tests/python_package_test/test_utilities.py b/tests/python_package_test/test_utilities.py
index 45bd77b73009..2fdacf8e8869 100644
--- a/tests/python_package_test/test_utilities.py
+++ b/tests/python_package_test/test_utilities.py
@@ -38,15 +38,13 @@ def dummy_metric(_, __):
     ]
     lgb.train({'objective': 'binary', 'metric': ['auc', 'binary_error']}, lgb_data,
               num_boost_round=10, feval=dummy_metric,
-              valid_sets=[lgb_data], categorical_feature=[1], verbose_eval=False,
-              callbacks=callbacks)
+              valid_sets=[lgb_data], categorical_feature=[1], callbacks=callbacks)
     lgb.plot_metric(eval_records)

     expected_log = r"""
 WARNING | categorical_feature in Dataset is overridden. New categorical_feature is [1]
-WARNING | 'verbose_eval' argument is deprecated and will be removed in 4.0.0 release. Pass 'print_evaluation()' callback via 'callbacks' argument instead.
 INFO | [LightGBM] [Warning] There are no meaningful features, as all feature values are constant.
 INFO | [LightGBM] [Info] Number of positive: 2, number of negative: 2
 INFO | [LightGBM] [Info] Total Bins 0

From 07046ecafff4f949d96883a2088487c269e27ac5 Mon Sep 17 00:00:00 2001
From: Nikita Titov
Date: Wed, 1 Sep 2021 01:11:50 +0300
Subject: [PATCH 4/4] address review comments

---
 python-package/lightgbm/engine.py      | 14 +++++++-------
 python-package/lightgbm/sklearn.py     | 16 +++++++++++-----
 tests/python_package_test/test_dask.py |  2 +-
 3 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py
index 646c1ec8505d..2bca6cbdb97b 100644
--- a/python-package/lightgbm/engine.py
+++ b/python-package/lightgbm/engine.py
@@ -178,7 +178,7 @@ def train(
     params["num_iterations"] = num_boost_round
     # show deprecation warning only for early stop argument, setting early stop via global params should still be possible
     if early_stopping_rounds is not None and early_stopping_rounds > 0:
-        _log_warning("'early_stopping_rounds' argument is deprecated and will be removed in 4.0.0 release. "
+        _log_warning("'early_stopping_rounds' argument is deprecated and will be removed in a future release of LightGBM. "
                      "Pass 'early_stopping()' callback via 'callbacks' argument instead.")
     for alias in _ConfigAliases.get("early_stopping_round"):
         if alias in params:
@@ -237,9 +237,9 @@ def train(

     # Most of legacy advanced options becomes callbacks
     if verbose_eval != "warn":
-        _log_warning("'verbose_eval' argument is deprecated and will be removed in 4.0.0 release. "
+        _log_warning("'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. "
                      "Pass 'print_evaluation()' callback via 'callbacks' argument instead.")
-    if verbose_eval == "warn":
+    else:
         if callbacks:  # assume user has already specified print_evaluation callback
             verbose_eval = False
         else:
@@ -253,12 +253,12 @@ def train(
         callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=bool(verbose_eval)))

     if learning_rates is not None:
-        _log_warning("'learning_rates' argument is deprecated and will be removed in 4.0.0 release. "
+        _log_warning("'learning_rates' argument is deprecated and will be removed in a future release of LightGBM. "
                      "Pass 'reset_parameter()' callback via 'callbacks' argument instead.")
         callbacks.add(callback.reset_parameter(learning_rate=learning_rates))

     if evals_result is not None:
-        _log_warning("'evals_result' argument is deprecated and will be removed in 4.0.0 release. "
+        _log_warning("'evals_result' argument is deprecated and will be removed in a future release of LightGBM. "
                      "Pass 'record_evaluation()' callback via 'callbacks' argument instead.")
         callbacks.add(callback.record_evaluation(evals_result))

@@ -575,7 +575,7 @@ def cv(params, train_set, num_boost_round=100,
             num_boost_round = params.pop(alias)
     params["num_iterations"] = num_boost_round
     if early_stopping_rounds is not None and early_stopping_rounds > 0:
-        _log_warning("'early_stopping_rounds' argument is deprecated and will be removed in 4.0.0 release. "
+        _log_warning("'early_stopping_rounds' argument is deprecated and will be removed in a future release of LightGBM. "
                      "Pass 'early_stopping()' callback via 'callbacks' argument instead.")
     for alias in _ConfigAliases.get("early_stopping_round"):
         if alias in params:
@@ -618,7 +618,7 @@ def cv(params, train_set, num_boost_round=100,
     if early_stopping_rounds is not None and early_stopping_rounds > 0:
         callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=False))
     if verbose_eval is not None:
-        _log_warning("'verbose_eval' argument is deprecated and will be removed in 4.0.0 release. "
+        _log_warning("'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. "
                      "Pass 'print_evaluation()' callback via 'callbacks' argument instead.")
     if verbose_eval is True:
         callbacks.add(callback.print_evaluation(show_stdv=show_stdv))
diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py
index 332f971b4c3d..5d6ea2a3f247 100644
--- a/python-package/lightgbm/sklearn.py
+++ b/python-package/lightgbm/sklearn.py
@@ -711,7 +711,7 @@ def _get_meta_data(collection, name, i):
             init_model = init_model.booster_

         if early_stopping_rounds is not None and early_stopping_rounds > 0:
-            _log_warning("'early_stopping_rounds' argument is deprecated and will be removed in 4.0.0 release. "
+            _log_warning("'early_stopping_rounds' argument is deprecated and will be removed in a future release of LightGBM. "
                          "Pass 'early_stopping()' callback via 'callbacks' argument instead.")
             params['early_stopping_rounds'] = early_stopping_rounds

@@ -721,9 +721,9 @@ def _get_meta_data(collection, name, i):
             callbacks = copy.deepcopy(callbacks)

         if verbose != 'warn':
-            _log_warning("'verbose' argument is deprecated and will be removed in 4.0.0 release. "
+            _log_warning("'verbose' argument is deprecated and will be removed in a future release of LightGBM. "
                          "Pass 'print_evaluation()' callback via 'callbacks' argument instead.")
-        if verbose == 'warn':
+        else:
             if callbacks:  # assume user has already specified print_evaluation callback
                 verbose = False
             else:
@@ -748,8 +748,14 @@ def _get_meta_data(collection, name, i):

         if evals_result:
             self._evals_result = evals_result
+        else:  # reset after previous call to fit()
+            self._evals_result = None
+
+        if self._Booster.best_iteration != 0:
+            self._best_iteration = self._Booster.best_iteration
+        else:  # reset after previous call to fit()
+            self._best_iteration = None
-        self._best_iteration = self._Booster.best_iteration
         self._best_score = self._Booster.best_score

         self.fitted_ = True
@@ -817,7 +823,7 @@ def best_score_(self):

     @property
     def best_iteration_(self):
-        """:obj:`int`: The best iteration of fitted model if ``early_stopping()`` callback has been specified."""
+        """:obj:`int` or :obj:`None`: The best iteration of fitted model if ``early_stopping()`` callback has been specified."""
         if self._n_features is None:
             raise LGBMNotFittedError('No best_iteration found. Need to call fit with early_stopping callback beforehand.')
         return self._best_iteration
diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py
index 11e512ffd996..2f84933cfd37 100644
--- a/tests/python_package_test/test_dask.py
+++ b/tests/python_package_test/test_dask.py
@@ -916,7 +916,7 @@ def test_eval_set_no_early_stopping(task, output, eval_sizes, eval_names_prefix,

     # check that early stopping was not applied.
     assert dask_model.booster_.num_trees() == model_trees
-    assert dask_model.best_iteration_ == 0
+    assert dask_model.best_iteration_ is None

     # checks that evals_result_ and best_score_ contain expected data and eval_set names.
     evals_result = dask_model.evals_result_
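
The replacement these patches point users toward is to pass callbacks through the `callbacks` argument instead of the deprecated `early_stopping_rounds`, `verbose_eval`, `evals_result` and `learning_rates` keyword arguments. A minimal sketch of that usage follows; the random data, parameter values and metric name are illustrative only and are not taken from the patches themselves.

```python
import numpy as np
import lightgbm as lgb

# Illustrative data only; any binary-classification dataset would work the same way.
rng = np.random.default_rng(0)
X = rng.random((200, 5))
y = rng.integers(0, 2, size=200)
train_data = lgb.Dataset(X[:150], label=y[:150])
valid_data = lgb.Dataset(X[150:], label=y[150:], reference=train_data)

evals_result = {}  # filled in by record_evaluation(); replaces the evals_result= argument

booster = lgb.train(
    {"objective": "binary", "metric": "binary_logloss", "verbosity": -1},
    train_data,
    num_boost_round=100,
    valid_sets=[valid_data],
    callbacks=[
        lgb.early_stopping(stopping_rounds=5),  # replaces early_stopping_rounds=5
        lgb.print_evaluation(period=10),        # replaces verbose_eval=10
        lgb.record_evaluation(evals_result),    # replaces evals_result=evals_result
    ],
)

print(booster.best_iteration)
print(evals_result["valid_0"]["binary_logloss"][:3])
```

With this style the best iteration is populated by the `early_stopping()` callback and the evaluation history ends up in the dictionary handed to `record_evaluation()`, which matches the behaviour described by the deprecation warnings and the updated `best_iteration_` and `evals_result_` docstrings.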