diff --git a/docs/Python-Intro.rst b/docs/Python-Intro.rst
index e7c44bf141bf..8cbd0d42c7e8 100644
--- a/docs/Python-Intro.rst
+++ b/docs/Python-Intro.rst
@@ -244,6 +244,10 @@ This works with both metrics to minimize (L2, log loss, etc.) and to maximize (N
 Note that if you specify more than one evaluation metric, all of them will be used for early stopping.
 However, you can change this behavior and make LightGBM check only the first metric for early stopping by passing ``first_metric_only=True`` in ``early_stopping`` callback constructor.
 
+In the scikit-learn API of LightGBM, early stopping can also be enabled by setting the parameter ``early_stopping`` to ``True``.
+When early stopping is enabled and no validation set is provided, a portion of the training data will be used as a validation set.
+The amount of data to use for validation is controlled by the parameter ``validation_fraction``, which defaults to 0.1.
+
 Prediction
 ----------
 
diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py
index dcdacba7366c..6f3b8f2f548a 100644
--- a/python-package/lightgbm/dask.py
+++ b/python-package/lightgbm/dask.py
@@ -1135,6 +1135,9 @@ def __init__(
         n_jobs: Optional[int] = None,
         importance_type: str = "split",
         client: Optional[Client] = None,
+        *,
+        early_stopping: bool = False,
+        validation_fraction: float = 0.1,
         **kwargs: Any,
     ):
         """Docstring is inherited from the lightgbm.LGBMClassifier.__init__."""
@@ -1338,6 +1341,9 @@ def __init__(
         n_jobs: Optional[int] = None,
         importance_type: str = "split",
         client: Optional[Client] = None,
+        *,
+        early_stopping: bool = False,
+        validation_fraction: float = 0.1,
         **kwargs: Any,
     ):
         """Docstring is inherited from the lightgbm.LGBMRegressor.__init__."""
@@ -1505,6 +1511,9 @@ def __init__(
         n_jobs: Optional[int] = None,
         importance_type: str = "split",
         client: Optional[Client] = None,
+        *,
+        early_stopping: bool = False,
+        validation_fraction: float = 0.1,
         **kwargs: Any,
     ):
         """Docstring is inherited from the lightgbm.LGBMRanker.__init__."""
diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py
index 20dfc62b8856..1bba09ab3f0f 100644
--- a/python-package/lightgbm/engine.py
+++ b/python-package/lightgbm/engine.py
@@ -510,11 +510,9 @@ def _make_n_folds(
     nfold: int,
     params: Dict[str, Any],
     seed: int,
-    fpreproc: Optional[_LGBM_PreprocFunction],
     stratified: bool,
     shuffle: bool,
-    eval_train_metric: bool,
-) -> CVBooster:
-    """Make a n-fold list of Booster from random indices."""
+) -> Iterable[Tuple[np.ndarray, np.ndarray]]:
+    """Make an n-fold iterable of train/test index pairs from random indices."""
     full_data = full_data.construct()
     num_data = full_data.num_data()
@@ -559,7 +557,16 @@ def _make_n_folds(
         test_id = [randidx[i : i + kstep] for i in range(0, num_data, kstep)]
         train_id = [np.concatenate([test_id[i] for i in range(nfold) if k != i]) for k in range(nfold)]
         folds = zip(train_id, test_id)
+    return folds
+
 
+def _make_cvbooster(
+    full_data: Dataset,
+    params: Dict[str, Any],
+    folds: Iterable[Tuple[np.ndarray, np.ndarray]],
+    fpreproc: Optional[_LGBM_PreprocFunction],
+    eval_train_metric: bool,
+) -> CVBooster:
     ret = CVBooster()
     for train_idx, test_idx in folds:
         train_set = full_data.subset(sorted(train_idx))
@@ -758,16 +765,17 @@ def cv(
     train_set._update_params(params)._set_predictor(predictor)
 
     results = defaultdict(list)
-    cvbooster = _make_n_folds(
+    cvfolds = _make_n_folds(
         full_data=train_set,
         folds=folds,
         nfold=nfold,
         params=params,
         seed=seed,
-        fpreproc=fpreproc,
         stratified=stratified,
         shuffle=shuffle,
-        eval_train_metric=eval_train_metric,
+    )
+    cvbooster = _make_cvbooster(
+        full_data=train_set, params=params, folds=cvfolds, fpreproc=fpreproc, eval_train_metric=eval_train_metric
     )
 
     # setup callbacks
diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py
index 108ef1e14498..20714f7cae4e 100644
--- a/python-package/lightgbm/sklearn.py
+++ b/python-package/lightgbm/sklearn.py
@@ -44,7 +44,7 @@
     dt_DataTable,
     pd_DataFrame,
 )
-from .engine import train
+from .engine import _make_n_folds, train
 
 if TYPE_CHECKING:
     from .compat import _sklearn_Tags
@@ -507,6 +507,9 @@ def __init__(
         random_state: Optional[Union[int, np.random.RandomState, np.random.Generator]] = None,
         n_jobs: Optional[int] = None,
         importance_type: str = "split",
+        *,
+        early_stopping: Union[bool, int] = False,
+        validation_fraction: float = 0.1,
         **kwargs: Any,
     ):
         r"""Construct a gradient boosting model.
@@ -587,6 +590,17 @@ def __init__(
             The type of feature importance to be filled into ``feature_importances_``.
             If 'split', result contains numbers of times the feature is used in a model.
             If 'gain', result contains total gains of splits which use the feature.
+        early_stopping : bool, optional (default=False)
+            Whether to enable scikit-learn-style early stopping. If set to ``True`` and no other validation
+            set is passed to ``fit()``, a new validation set will be created by randomly sampling
+            ``validation_fraction`` rows from the training data ``X`` passed to ``fit()``. Training will stop
+            if the validation score does not improve for a specific number of rounds (controlled by
+            ``n_iter_no_change``). This parameter is here for compatibility with ``scikit-learn``'s
+            ``HistGradientBoosting`` estimators. It does not affect other ``lightgbm``-specific early
+            stopping mechanisms, like passing the ``lgb.early_stopping`` callback and validation sets to
+            the ``eval_set`` argument of ``fit()``.
+        validation_fraction : float, optional (default=0.1)
+            Proportion of training data to set aside as validation data for early stopping.
         **kwargs
             Other parameters for the model.
             Check http://lightgbm.readthedocs.io/en/latest/Parameters.html for more parameters.
@@ -651,6 +665,8 @@ def __init__(
         self.random_state = random_state
         self.n_jobs = n_jobs
         self.importance_type = importance_type
+        self.early_stopping = early_stopping
+        self.validation_fraction = validation_fraction
         self._Booster: Optional[Booster] = None
         self._evals_result: _EvalResultDict = {}
         self._best_score: _LGBM_BoosterBestScoreType = {}
@@ -816,6 +832,7 @@ def _process_params(self, stage: str) -> Dict[str, Any]:
         params.pop("importance_type", None)
         params.pop("n_estimators", None)
         params.pop("class_weight", None)
+        params.pop("validation_fraction", None)
 
         if isinstance(params["random_state"], np.random.RandomState):
             params["random_state"] = params["random_state"].randint(np.iinfo(np.int32).max)
@@ -853,7 +870,28 @@ def _process_params(self, stage: str) -> Dict[str, Any]:
         params = _choose_param_value("num_threads", params, self.n_jobs)
         params["num_threads"] = self._process_n_jobs(params["num_threads"])
 
-        return params
+        if not isinstance(self.early_stopping, bool) and isinstance(self.early_stopping, int):
+            _log_warning(
+                f"Found 'early_stopping={self.early_stopping}' passed through keyword arguments. "
+                "Future versions of 'lightgbm' will not allow this, as scikit-learn expects keyword argument "
+                "'early_stopping' to be a boolean indicating whether or not to perform early stopping with "
+                "a randomly-sampled validation set. 
To set the number of early stopping rounds, and suppress " + f"this warning, pass early_stopping_rounds={self.early_stopping} instead." + ) + params = _choose_param_value( + main_param_name="early_stopping_round", params=params, default_value=self.early_stopping + ) + + params.pop("early_stopping", None) + + if isinstance(self.early_stopping, bool) and self.early_stopping is True: + default_early_stopping_round = 10 + else: + default_early_stopping_round = None + + return _choose_param_value( + main_param_name="early_stopping_round", params=params, default_value=default_early_stopping_round + ) def _process_n_jobs(self, n_jobs: Optional[int]) -> int: """Convert special values of n_jobs to their actual values according to the formulas that apply. @@ -1006,6 +1044,24 @@ def fit( valid_sets.append(valid_set) + elif self.early_stopping is True: + n_splits = max(int(np.ceil(1 / self.validation_fraction)), 2) + stratified = isinstance(self, LGBMClassifier) + cvfolds = _make_n_folds( + full_data=train_set, + folds=None, + nfold=n_splits, + params=params, + seed=self.random_state, + stratified=stratified, + shuffle=True, + ) + train_idx, val_idx = next(cvfolds) + valid_set = train_set.subset(sorted(val_idx)) + train_set = train_set.subset(sorted(train_idx)) + valid_set = valid_set.construct() + valid_sets = [valid_set] + if isinstance(init_model, LGBMModel): init_model = init_model.booster_ diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index 1cdd047f1857..869cd449afd4 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -5,6 +5,7 @@ from functools import partial from os import getenv from pathlib import Path +from unittest.mock import patch import joblib import numpy as np @@ -278,6 +279,175 @@ def test_binary_classification_with_custom_objective(): assert ret < 0.05 +def test_auto_early_stopping_binary_classification(): + X, y = load_breast_cancer(return_X_y=True) + n_estimators = 200 + gbm = lgb.LGBMClassifier(n_estimators=n_estimators, random_state=42, verbose=-1, early_stopping=True, num_leaves=5) + gbm.fit(X, y) + assert gbm._Booster.params["early_stopping_round"] == 10 + assert gbm._Booster.num_trees() < n_estimators + assert gbm.best_iteration_ < n_estimators + + +def test_auto_early_stopping_regression(): + X, y = make_synthetic_regression(n_samples=30) + n_estimators = 20 + early_stopping_rounds = 2 + gbm = lgb.LGBMRegressor( + n_estimators=n_estimators, + random_state=42, + verbose=-1, + early_stopping=True, + num_leaves=5, + early_stopping_rounds=early_stopping_rounds, + ) + gbm.fit(X, y) + assert gbm._Booster.params["early_stopping_round"] == early_stopping_rounds + assert gbm._Booster.num_trees() < n_estimators + assert gbm.best_iteration_ < n_estimators + + +def test_auto_early_stopping_check_validation_fraction_default_value(): + n_samples = 30 + X, y = make_synthetic_regression(n_samples=n_samples) + n_estimators = 20 + early_stopping_rounds = 2 + gbm = lgb.LGBMRegressor( + n_estimators=n_estimators, + random_state=42, + verbose=-1, + early_stopping=True, + num_leaves=5, + early_stopping_rounds=early_stopping_rounds, + ) + with patch("lightgbm.sklearn.train", side_effect=lgb.sklearn.train) as mock_train: + gbm.fit(X, y) + + valid_sets = mock_train.call_args.kwargs["valid_sets"] + assert len(valid_sets) == 1 + assert valid_sets[0].num_data() == n_samples * 0.1 + assert mock_train.call_args.kwargs["train_set"].num_data() == n_samples * 0.9 + + +def 
test_auto_early_stopping_check_set_validation_fraction(): + n_samples = 30 + validation_fraction = 0.2 + X, y = make_synthetic_regression(n_samples=n_samples) + n_estimators = 20 + early_stopping_rounds = 2 + gbm = lgb.LGBMRegressor( + n_estimators=n_estimators, + random_state=42, + verbose=-1, + early_stopping=True, + num_leaves=5, + early_stopping_rounds=early_stopping_rounds, + validation_fraction=validation_fraction, + ) + with patch("lightgbm.sklearn.train", side_effect=lgb.sklearn.train) as mock_train: + gbm.fit(X, y) + + valid_sets = mock_train.call_args.kwargs["valid_sets"] + assert len(valid_sets) == 1 + assert valid_sets[0].num_data() == n_samples * validation_fraction + assert mock_train.call_args.kwargs["train_set"].num_data() == n_samples * (1 - validation_fraction) + + +@pytest.mark.skipif( + getenv("TASK", "") == "cuda", reason="Skip due to differences in implementation details of CUDA version" +) +def test_auto_early_stopping_lambdarank(): + rank_example_dir = Path(__file__).absolute().parents[2] / "examples" / "lambdarank" + X_train, y_train = load_svmlight_file(str(rank_example_dir / "rank.train")) + q_train = np.loadtxt(str(rank_example_dir / "rank.train.query")) + n_estimators = 5 + gbm = lgb.LGBMRanker(n_estimators=n_estimators, random_state=42, early_stopping=True, num_leaves=5) + gbm.fit( + X_train, + y_train, + group=q_train, + eval_at=[1, 3], + ) + assert gbm._Booster.params["early_stopping_round"] == 10 + assert gbm._Booster.num_trees() < n_estimators + assert gbm.best_iteration_ < n_estimators + + +def test_auto_early_stopping_n_iter_no_change(): + X, y = load_breast_cancer(return_X_y=True) + n_estimators = 200 + n_iter_no_change = 5 + gbm = lgb.LGBMClassifier( + n_estimators=n_estimators, + random_state=42, + verbose=-1, + early_stopping=True, + num_leaves=5, + n_iter_no_change=n_iter_no_change, + ) + gbm.fit(X, y) + assert gbm._Booster.params["early_stopping_round"] == n_iter_no_change + assert gbm._Booster.num_trees() < n_estimators + assert gbm.best_iteration_ < n_estimators + + +def test_auto_early_stopping_categorical_features_set_during_fit(rng_fixed_seed): + pd = pytest.importorskip("pandas") + X = pd.DataFrame( + { + "A": pd.Categorical( + rng_fixed_seed.permutation(["z", "y", "x", "w", "v"] * 60), ordered=True + ), # str and ordered categorical + "B": rng_fixed_seed.permutation([1, 2, 3] * 100), # int + "C": rng_fixed_seed.permutation([0.1, 0.2, -0.1, -0.1, 0.2] * 60), # float + "D": rng_fixed_seed.permutation([True, False] * 150), # bool + } + ) + cat_cols_actual = ["A", "B", "C", "D"] + y = rng_fixed_seed.permutation([0, 1] * 150) + n_estimators = 5 + gbm = lgb.LGBMClassifier(n_estimators=n_estimators, random_state=42, verbose=-1, early_stopping=True, num_leaves=5) + gbm.fit(X, y, categorical_feature=cat_cols_actual) + assert gbm._Booster.params["early_stopping_round"] == 10 + assert gbm._Booster.num_trees() < n_estimators + assert gbm.best_iteration_ < n_estimators + + +def test_early_stopping_is_deactivated_by_default_regression(): + X, y = make_synthetic_regression(n_samples=10_001) + n_estimators = 5 + gbm = lgb.LGBMRegressor(n_estimators=n_estimators, random_state=42, verbose=-1) + gbm.fit(X, y) + + # Check that early stopping did not kick in + assert gbm._Booster.params.get("early_stopping_round") is None + assert gbm._Booster.num_trees() == n_estimators + + +def test_early_stopping_is_deactivated_by_default_classification(): + X, y = load_breast_cancer(return_X_y=True) + n_estimators = 5 + gbm = lgb.LGBMClassifier(n_estimators=n_estimators, 
random_state=42, verbose=-1) + gbm.fit(X, y) + + # Check that early stopping did not kick in + assert gbm._Booster.params.get("early_stopping_round") is None + assert gbm._Booster.num_trees() == n_estimators + + +def test_early_stopping_is_deactivated_by_default_lambdarank(): + rank_example_dir = Path(__file__).absolute().parents[2] / "examples" / "lambdarank" + X_train, y_train = load_svmlight_file(str(rank_example_dir / "rank.train")) + q_train = np.loadtxt(str(rank_example_dir / "rank.train.query")) + n_estimators = 5 + gbm = lgb.LGBMRanker(n_estimators=n_estimators, random_state=42, verbose=-1) + gbm.fit(X_train, y_train, group=q_train) + + # Check that early stopping did not kick in + assert gbm._Booster.params.get("early_stopping_round") is None + assert gbm._Booster.num_trees() == n_estimators + + def test_dart(): X, y = make_synthetic_regression() X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
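For illustration only (not part of the patch): a minimal usage sketch of the scikit-learn-style early stopping described in the docs/Python-Intro.rst hunk above, assuming this diff is applied. ``n_iter_no_change`` is the existing alias of ``early_stopping_round`` that the new tests also rely on; everything else is existing LightGBM API.

# usage_sketch.py -- hypothetical example, not included in the diff
import lightgbm as lgb
from sklearn.datasets import load_breast_cancer

X, y = load_breast_cancer(return_X_y=True)

clf = lgb.LGBMClassifier(
    n_estimators=200,
    num_leaves=5,
    early_stopping=True,        # new keyword-only parameter added by this diff
    validation_fraction=0.2,    # new parameter; 20% of X is held out internally
    n_iter_no_change=5,         # existing alias of early_stopping_round
    random_state=42,
    verbose=-1,
)

# No eval_set is passed, so fit() samples the validation set from X itself and
# stops once the validation score has not improved for 5 consecutive rounds.
clf.fit(X, y)
print(clf.best_iteration_)

This mirrors what the new tests (e.g. test_auto_early_stopping_n_iter_no_change and test_auto_early_stopping_check_set_validation_fraction) exercise.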