[python-package] add scikit-learn-style API for early stopping #5808

Status: Open. Wants to merge 33 commits into base: master.
Changes from 25 commits (33 commits total).

Commits
ad43e17
Enable Auto Early Stopping
ClaudioSalvatoreArcidiacono Sep 15, 2023
f05e5e0
Relax test conditions
ClaudioSalvatoreArcidiacono Sep 18, 2023
76f3c19
Merge branch 'master' into 3313-enable-auto-early-stopping
ClaudioSalvatoreArcidiacono Sep 20, 2023
457c7f6
Merge master
ClaudioSalvatoreArcidiacono Jan 16, 2024
0db1941
Revert "Merge master"
ClaudioSalvatoreArcidiacono Jan 16, 2024
10fac65
Merge remote-tracking branch 'lgbm/master' into 3313-enable-auto-earl…
ClaudioSalvatoreArcidiacono Jan 16, 2024
d10ca54
Add missing import
ClaudioSalvatoreArcidiacono Jan 16, 2024
3b8eb0a
Remove added extra new line
ClaudioSalvatoreArcidiacono Jan 17, 2024
e47acc0
Merge branch 'master' into 3313-enable-auto-early-stopping
ClaudioSalvatoreArcidiacono Jan 17, 2024
66701ac
Merge branch 'master' into 3313-enable-auto-early-stopping
ClaudioSalvatoreArcidiacono Jan 25, 2024
39d333e
Merge branch 'master' into 3313-enable-auto-early-stopping
ClaudioSalvatoreArcidiacono Feb 2, 2024
cad7eb6
Merge branch 'master' into 3313-enable-auto-early-stopping
ClaudioSalvatoreArcidiacono Feb 6, 2024
1234ccf
Merge master
ClaudioSalvatoreArcidiacono Nov 28, 2024
d54c96a
Improve documentation, check default behavior of early stopping
ClaudioSalvatoreArcidiacono Nov 28, 2024
9c1c8b4
Solve python 3.8 compatibility issue
ClaudioSalvatoreArcidiacono Nov 28, 2024
724c7fe
Remove default to auto
ClaudioSalvatoreArcidiacono Nov 29, 2024
c957fce
Revert changes in fit top part
ClaudioSalvatoreArcidiacono Nov 29, 2024
2d7da78
Make interface as similar as possible to sklearn
ClaudioSalvatoreArcidiacono Nov 29, 2024
069a84e
Add parameters to dask interface
ClaudioSalvatoreArcidiacono Nov 29, 2024
c430ec1
Improve documentation
ClaudioSalvatoreArcidiacono Nov 29, 2024
416323a
Linting
ClaudioSalvatoreArcidiacono Nov 29, 2024
73562ff
Check for exact value equal true for early stopping
ClaudioSalvatoreArcidiacono Nov 29, 2024
38edc42
Merge branch 'master' into 3313-enable-auto-early-stopping
jameslamb Dec 15, 2024
9a32376
Switch if/else conditions order in fit
ClaudioSalvatoreArcidiacono Dec 18, 2024
f33ebd3
Merge remote-tracking branch 'origin/master' into 3313-enable-auto-ea…
ClaudioSalvatoreArcidiacono Dec 18, 2024
a61726f
fix issues in engine.py
ClaudioSalvatoreArcidiacono Dec 18, 2024
44316d7
make new early stopping parameters keyword-only
ClaudioSalvatoreArcidiacono Dec 18, 2024
4cbfc84
Remove n_iter_no_change parameter
ClaudioSalvatoreArcidiacono Dec 18, 2024
93acf6a
Address comments in tests
ClaudioSalvatoreArcidiacono Dec 18, 2024
2b049c9
Improve tests
ClaudioSalvatoreArcidiacono Dec 18, 2024
61371cb
Add tests to check for validation fraction
ClaudioSalvatoreArcidiacono Dec 18, 2024
65c4e2f
Remove validation_fraction=None option
ClaudioSalvatoreArcidiacono Dec 18, 2024
0a8e843
Remove validation_fraction=None option also in dask
ClaudioSalvatoreArcidiacono Dec 18, 2024
4 changes: 4 additions & 0 deletions docs/Python-Intro.rst
@@ -244,6 +244,10 @@ This works with both metrics to minimize (L2, log loss, etc.) and to maximize (N
Note that if you specify more than one evaluation metric, all of them will be used for early stopping.
However, you can change this behavior and make LightGBM check only the first metric for early stopping by passing ``first_metric_only=True`` in ``early_stopping`` callback constructor.

In the scikit-learn API of lightgbm, early stopping can also be enabled by setting the parameter ``early_stopping`` to ``True``.
When early stopping is enabled and no validation set is provided, a portion of the training data will be set aside and used as a validation set.
The fraction of training data to use for validation is controlled by the parameter ``validation_fraction``, which defaults to 0.1.
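
For example, a minimal usage sketch of this option (parameter names as proposed in this PR; the dataset and values are only illustrative):

    import lightgbm as lgb
    from sklearn.datasets import load_breast_cancer

    X, y = load_breast_cancer(return_X_y=True)
    # 10% of X is sampled internally and used as the validation set for early stopping
    clf = lgb.LGBMClassifier(n_estimators=200, early_stopping=True, validation_fraction=0.1)
    clf.fit(X, y)
    print(clf.best_iteration_)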

Prediction
----------

9 changes: 9 additions & 0 deletions python-package/lightgbm/dask.py
@@ -1134,6 +1134,9 @@ def __init__(
random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None,
n_jobs: Optional[int] = None,
importance_type: str = "split",
early_stopping: bool = False,
n_iter_no_change: int = 10,
validation_fraction: Optional[float] = 0.1,
client: Optional[Client] = None,
**kwargs: Any,
):
@@ -1337,6 +1340,9 @@ def __init__(
random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None,
n_jobs: Optional[int] = None,
importance_type: str = "split",
early_stopping: bool = False,
n_iter_no_change: int = 10,
validation_fraction: Optional[float] = 0.1,
client: Optional[Client] = None,
**kwargs: Any,
):
@@ -1504,6 +1510,9 @@ def __init__(
random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None,
n_jobs: Optional[int] = None,
importance_type: str = "split",
early_stopping: bool = False,
n_iter_no_change: int = 10,
validation_fraction: Optional[float] = 0.1,
client: Optional[Client] = None,
**kwargs: Any,
):
18 changes: 13 additions & 5 deletions python-package/lightgbm/engine.py
@@ -510,11 +510,9 @@ def _make_n_folds(
nfold: int,
params: Dict[str, Any],
seed: int,
fpreproc: Optional[_LGBM_PreprocFunction],
stratified: bool,
shuffle: bool,
eval_train_metric: bool,
) -> CVBooster:
) -> Iterable[Tuple[np.ndarray, np.ndarray]]:
"""Make a n-fold list of Booster from random indices."""
full_data = full_data.construct()
num_data = full_data.num_data()
@@ -559,7 +557,16 @@ def _make_n_folds(
test_id = [randidx[i : i + kstep] for i in range(0, num_data, kstep)]
train_id = [np.concatenate([test_id[i] for i in range(nfold) if k != i]) for k in range(nfold)]
folds = zip(train_id, test_id)
return folds


def _make_cvbooster(
full_data: Dataset,
params: Dict[str, Any],
folds: Iterable[Tuple[np.ndarray, np.ndarray]],
fpreproc: Optional[_LGBM_PreprocFunction],
eval_train_metric: bool,
) -> CVBooster:
ret = CVBooster()
for train_idx, test_idx in folds:
train_set = full_data.subset(sorted(train_idx))
@@ -764,10 +771,11 @@ def cv(
nfold=nfold,
params=params,
seed=seed,
fpreproc=fpreproc,
stratified=stratified,
shuffle=shuffle,
eval_train_metric=eval_train_metric,
)
cvbooster = _make_cvbooster(
full_data=train_set, params=params, folds=cvfolds, fpreproc=fpreproc, eval_train_metric=eval_train_metric
)

# setup callbacks
49 changes: 47 additions & 2 deletions python-package/lightgbm/sklearn.py
@@ -44,7 +44,7 @@
dt_DataTable,
pd_DataFrame,
)
from .engine import train
from .engine import _make_n_folds, train

if TYPE_CHECKING:
from .compat import _sklearn_Tags
@@ -507,7 +507,10 @@ def __init__(
random_state: Optional[Union[int, np.random.RandomState, np.random.Generator]] = None,
n_jobs: Optional[int] = None,
importance_type: str = "split",
**kwargs: Any,
early_stopping: bool = False,
Collaborator

Suggested change
early_stopping: bool = False,
*,
early_stopping: bool = False,

I think we should make these keyword-only arguments, as scikit-learn does:

https://github.com/scikit-learn/scikit-learn/blob/6cccd99aee3483eb0f7562afdd3179ccccab0b1d/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py#L1689

Could you please try that (in these estimators and the Dask ones)?

I don't want to do that for other existing parameters, to prevent breaking existing user code, but since these are new parameters, it's safe to be stricter.

n_iter_no_change: int = 10,
validation_fraction: Optional[float] = 0.1,
**kwargs,
Collaborator

Suggested change
**kwargs,
**kwargs: Any,

Why was this type hint removed? If it was just an accident, please put it back to reduce the size of the diff.

):
r"""Construct a gradient boosting model.

@@ -587,6 +590,16 @@ def __init__(
The type of feature importance to be filled into ``feature_importances_``.
If 'split', result contains numbers of times the feature is used in a model.
If 'gain', result contains total gains of splits which use the feature.
early_stopping : bool, optional (default=False)
Whether to enable early stopping. If set to ``True``, training will stop if the validation score does not improve
for a specified number of rounds (controlled by ``n_iter_no_change``).
n_iter_no_change : int, optional (default=10)
If early stopping is enabled, this parameter specifies the number of iterations with no
improvement after which training will be stopped.
validation_fraction : float or None, optional (default=0.1)
Collaborator

There are not any tests in test_sklearn.py that pass validation_fraction. Please add some, covering both the default behavior and that passing a non-default value (like 0.4) works as expected.

I don't know the exact code paths off the top of my head and would appreciate it if you could investigate... but I think it should be possible to test this by checking the size of the datasets added to valid_sets and confirming that they're as expected (e.g. that the automatically-added validation set has 4,000 rows if the input data X has 40,000 rows and validation_fraction=0.1 is passed).

If that's not observable through the public API, try to use mocking/patching to observe it instead of adding any additional properties to the Booster / estimators' public API.

Comment in-thread here if you have questions or need help with that.
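
For illustration, a rough sketch of such a patching-based test (the test name and the exact call-site details are assumptions, not part of this PR; it relies on fit() routing through the module-level train() imported in sklearn.py):

    from unittest import mock

    import numpy as np
    import lightgbm as lgb

    def test_validation_fraction_controls_valid_set_size():
        X = np.random.rand(40_000, 5)
        y = np.random.rand(40_000)
        gbm = lgb.LGBMRegressor(n_estimators=5, early_stopping=True, validation_fraction=0.1)
        # patch the train() name used inside lightgbm.sklearn so the test can
        # inspect the valid_sets that fit() constructed, while still training
        with mock.patch("lightgbm.sklearn.train", side_effect=lgb.train) as mocked_train:
            gbm.fit(X, y)
        valid_sets = mocked_train.call_args.kwargs["valid_sets"]
        assert valid_sets[0].construct().num_data() == 4_000  # 0.1 * 40,000 rows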

Contributor Author

Thanks for your comment. I have added a couple of tests for that using patching; let me know if you think there is more that can be improved in those tests.

Proportion of training data to set aside as
validation data for early stopping. If None, early stopping is done on
the training data. Only used if early stopping is performed.
**kwargs
Other parameters for the model.
Check http://lightgbm.readthedocs.io/en/latest/Parameters.html for more parameters.
@@ -651,6 +664,9 @@ def __init__(
self.random_state = random_state
self.n_jobs = n_jobs
self.importance_type = importance_type
self.early_stopping = early_stopping
self.n_iter_no_change = n_iter_no_change
self.validation_fraction = validation_fraction
self._Booster: Optional[Booster] = None
self._evals_result: _EvalResultDict = {}
self._best_score: _LGBM_BoosterBestScoreType = {}
@@ -816,11 +832,19 @@ def _process_params(self, stage: str) -> Dict[str, Any]:
params.pop("importance_type", None)
params.pop("n_estimators", None)
params.pop("class_weight", None)
params.pop("validation_fraction", None)
params.pop("early_stopping", None)
params.pop("n_iter_no_change", None)

if isinstance(params["random_state"], np.random.RandomState):
params["random_state"] = params["random_state"].randint(np.iinfo(np.int32).max)
elif isinstance(params["random_state"], np.random.Generator):
params["random_state"] = int(params["random_state"].integers(np.iinfo(np.int32).max))

params = _choose_param_value("early_stopping_round", params, self.n_iter_no_change)
if self.early_stopping is not True:
params["early_stopping_round"] = None
Collaborator

Suggested change
if self.early_stopping is not True:
params["early_stopping_round"] = None

This looks to me like it might turn off early stopping enabled in other ways (like passing early_stopping_round=15, or the lgb.early_stopping() callback plus some valid_sets) whenever the keyword argument early_stopping=False. Since early_stopping=False is the default, that would be a backwards-incompatible change.

The early_stopping keyword argument in this PR is not intended to control ALL early stopping, right? I think it should be limited to controlling the scikit-learn-style early stopping, but that the other mechanisms that people have been using with lightgbm for years should continue to work.
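
For example, the long-standing callback-based mechanism should keep working unchanged (variable names here are only illustrative, not part of this diff):

    clf = lgb.LGBMClassifier(n_estimators=500)
    clf.fit(
        X_train,
        y_train,
        eval_set=[(X_valid, y_valid)],
        callbacks=[lgb.early_stopping(stopping_rounds=15)],
    )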


if self._n_classes > 2:
for alias in _ConfigAliases.get("num_class"):
params.pop(alias, None)
@@ -1006,6 +1030,27 @@ def fit(

valid_sets.append(valid_set)

elif self.early_stopping is True:
if self.validation_fraction is not None:
n_splits = max(int(np.ceil(1 / self.validation_fraction)), 2)
stratified = isinstance(self, LGBMClassifier)
Contributor Author

I am not a huge fan of how the validation set is created from the train set using the _make_n_folds function.

If 1/validation_fraction is not an integer, the actual validation set size will not match the validation fraction specified by the user.

For example, if the validation fraction is 0.4 the number of splits calculated here will be 2, which will result in a fraction of 0.5, instead of 0.4.

Using something like train_test_split from scikit-learn would solve the issue for the classification and regression cases, but for ranking tasks our best option is GroupShuffleSplit, which will inevitably suffer from the same issue described above. The options I have considered to solve this issue are:

  1. Leave the code as-is and raise a warning when 1/validation_fraction is not an integer.
  2. Use train_test_split for creating the validation set in the classification and regression cases; Raise a warning when 1/validation_fraction is not an integer in the ranking case.

I would lean more towards option 2 (a rough sketch follows at the end of this comment), but it would make the MR bigger.

@jameslamb I would like to hear your opinion on it, do you perhaps already have something else in mind?
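
A rough sketch of what option 2 could look like (the helper name and exact wiring are assumptions for illustration, not this PR's implementation):

    import numpy as np
    from sklearn.model_selection import GroupShuffleSplit, train_test_split

    def _split_for_early_stopping(X, y, group, validation_fraction, seed):
        """Hypothetical helper, not part of this PR."""
        if group is None:
            # classification/regression: honors the exact fraction requested
            return train_test_split(X, y, test_size=validation_fraction, random_state=seed)
        # ranking: split whole query groups; LightGBM's `group` holds group sizes,
        # so expand it into per-row query ids for GroupShuffleSplit. The realized
        # fraction may deviate from validation_fraction, so a warning could be raised.
        query_ids = np.repeat(np.arange(len(group)), group)
        splitter = GroupShuffleSplit(n_splits=1, test_size=validation_fraction, random_state=seed)
        train_idx, val_idx = next(splitter.split(X, y, groups=query_ids))
        return X[train_idx], X[val_idx], y[train_idx], y[val_idx]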

cvfolds = _make_n_folds(
full_data=train_set,
folds=None,
nfold=n_splits,
params=params,
seed=self.random_state,
stratified=stratified,
shuffle=True,
)
train_idx, val_idx = next(cvfolds)
valid_set = train_set.subset(sorted(val_idx))
train_set = train_set.subset(sorted(train_idx))
else:
valid_set = train_set
valid_set = valid_set.construct()
valid_sets = [valid_set]

if isinstance(init_model, LGBMModel):
init_model = init_model.booster_

107 changes: 107 additions & 0 deletions tests/python_package_test/test_sklearn.py
@@ -278,6 +278,112 @@ def test_binary_classification_with_custom_objective():
assert ret < 0.05


def test_auto_early_stopping_binary_classification():
X, y = load_breast_cancer(return_X_y=True)
n_estimators = 200
gbm = lgb.LGBMClassifier(n_estimators=n_estimators, random_state=42, verbose=-1, early_stopping=True, num_leaves=5)
gbm.fit(X, y)
assert gbm._Booster.params["early_stopping_round"] == 10
assert gbm._Booster.num_trees() < n_estimators
assert gbm.best_iteration_ < n_estimators


def test_auto_early_stopping_compatibility_with_histgradientboostingclassifier():
X, y = load_breast_cancer(return_X_y=True)
n_estimators = 200
n_iter_no_change = 5
gbm = lgb.LGBMClassifier(
n_estimators=n_estimators,
random_state=42,
verbose=-1,
early_stopping=True,
num_leaves=5,
n_iter_no_change=n_iter_no_change,
)
gbm.fit(X, y)
assert gbm._Booster.params["early_stopping_round"] == n_iter_no_change
assert gbm._Booster.num_trees() < n_estimators
assert gbm.best_iteration_ < n_estimators


def test_auto_early_stopping_categorical_features_set_during_fit(rng_fixed_seed):
pd = pytest.importorskip("pandas")
X = pd.DataFrame(
{
"A": pd.Categorical(
rng_fixed_seed.permutation(["z", "y", "x", "w", "v"] * 60), ordered=True
), # str and ordered categorical
"B": rng_fixed_seed.permutation([1, 2, 3] * 100), # int
"C": rng_fixed_seed.permutation([0.1, 0.2, -0.1, -0.1, 0.2] * 60), # float
"D": rng_fixed_seed.permutation([True, False] * 150), # bool
}
)
cat_cols_actual = ["A", "B", "C", "D"]
y = rng_fixed_seed.permutation([0, 1] * 150)
n_estimators = 5
gbm = lgb.LGBMClassifier(n_estimators=n_estimators, random_state=42, verbose=-1, early_stopping=True, num_leaves=5)
gbm.fit(X, y, categorical_feature=cat_cols_actual)
assert gbm._Booster.params["early_stopping_round"] == 10
assert gbm._Booster.num_trees() < 5
assert gbm.best_iteration_ < 5


def test_early_stopping_is_deactivated_by_default_regression():
X, y = make_synthetic_regression(n_samples=10_001)
n_estimators = 5
gbm = lgb.LGBMRegressor(n_estimators=n_estimators, random_state=42, verbose=-1)
gbm.fit(X, y)

# Check that early stopping did not kick in
assert gbm._Booster.params.get("early_stopping_round") is None
assert gbm._Booster.num_trees() == n_estimators


def test_early_stopping_is_deactivated_by_default_classification():
X, y = load_breast_cancer(return_X_y=True)
n_estimators = 5
gbm = lgb.LGBMClassifier(n_estimators=n_estimators, random_state=42, verbose=-1)
gbm.fit(X, y)

# Check that early stopping did not kick in
assert gbm._Booster.params.get("early_stopping_round") is None
assert gbm._Booster.num_trees() == n_estimators


def test_early_stopping_is_deactivated_by_default_lambdarank():
rank_example_dir = Path(__file__).absolute().parents[2] / "examples" / "lambdarank"
X_train, y_train = load_svmlight_file(str(rank_example_dir / "rank.train"))
q_train = np.loadtxt(str(rank_example_dir / "rank.train.query"))
n_estimators = 5
gbm = lgb.LGBMRanker(n_estimators=n_estimators, random_state=42, verbose=-1)
gbm.fit(X_train, y_train, group=q_train) # Assuming 10 samples in one group
Collaborator

Suggested change
gbm.fit(X_train, y_train, group=q_train) # Assuming 10 samples in one group
gbm.fit(X_train, y_train, group=q_train)

How does this code comment relate to the code? If it's left over from some earlier debugging, please remove it.


# Check that early stopping did not kick in
assert gbm._Booster.params.get("early_stopping_round") is None
assert gbm._Booster.num_trees() == n_estimators


@pytest.mark.skipif(
getenv("TASK", "") == "cuda", reason="Skip due to differences in implementation details of CUDA version"
)
def test_auto_early_stopping_lambdarank():
rank_example_dir = Path(__file__).absolute().parents[2] / "examples" / "lambdarank"
X_train, y_train = load_svmlight_file(str(rank_example_dir / "rank.train"))
q_train = np.loadtxt(str(rank_example_dir / "rank.train.query"))
n_estimators = 5
gbm = lgb.LGBMRanker(n_estimators=n_estimators, random_state=42, early_stopping=True, num_leaves=5)
gbm.fit(
X_train,
y_train,
group=q_train,
eval_at=[1, 3],
callbacks=[lgb.reset_parameter(learning_rate=lambda x: max(0.01, 0.1 - 0.01 * x))],
Collaborator

Suggested change
callbacks=[lgb.reset_parameter(learning_rate=lambda x: max(0.01, 0.1 - 0.01 * x))],

Why is modifying the learning rate necessary for this test? If this is unnecessary and just copied from somewhere else, please remove it.

)
assert gbm._Booster.params["early_stopping_round"] == 10
assert gbm._Booster.num_trees() < n_estimators
assert gbm.best_iteration_ < n_estimators


def test_dart():
X, y = make_synthetic_regression()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
@@ -1168,6 +1274,7 @@ def fit_and_check(eval_set_names, metric_names, assumed_iteration, first_metric_
"verbose": -1,
"seed": 123,
"early_stopping_rounds": 5,
"early_stopping": True,
Collaborator

Was this necessary to make this test pass, or is it just being done for convenience?

If it's necessary, could you explain why? If it was just for convenience, please introduce a new test after this one that tests what you want to test.

This PR should only be adding functionality without breaking any existing user code. I'd be more confident that that's true if the tests showed only additions, with no modifications to existing tests.

Contributor Author

Yes, this was a necessary change in order to make the test pass.

With the code in its current state, passing early_stopping_rounds while leaving early_stopping at its default value (False) does not enable early stopping. Before, it did.

If we want to keep the previous behaviour, and also enable auto early stopping when early_stopping_rounds is set while early_stopping is left at its default value, we would need to change the default value of early_stopping_rounds to None and set it to 10 in _process_params when early_stopping is True. That would make the actual default value of early_stopping_rounds a bit hidden in the code. Let me know what you prefer; both options have their own drawbacks.

Collaborator

Ok, that's a problem. And exactly why I have been so strict in asking that we only see new tests added in this PR.

I think that also answers my questions from #5808 (comment)

If the current state of this branch were merged, with this mix of behaviors:

  • early_stopping keyword argument defaulting to False
  • early_stopping=False meaning "do not perform any early stopping"

it would turn off early stopping for all existing code using early stopping with lightgbm.sklearn estimators.

That's a very big user-facing breaking change, and one I do not support. Please, let's keep the scope of this PR limited to what I've written in the title... "add scikit-learn-style API for early stopping".

Specifically, this is the behavior I'm expecting:

  • early_stopping=True:
    • adds a new validation set to eval_set which is a random sample of {validation_fraction} * num_rows rows from the training data
    • enables early stopping, with early_stopping_rounds defaulting to 10 if not provided by any other mechanism
    • does not affect any other relevant early stopping behavior... e.g. if you also passed a dataset to eval_set in fit(...), then that dataset is ALSO used for early stopping
  • early_stopping=False
    • everything lightgbm does is identical to what it does today
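
Concretely, a hedged sketch of that expected behavior (variable names are illustrative only):

    # early_stopping=True: the user-supplied eval_set AND an internally sampled
    # validation_fraction split of the training data would both be monitored,
    # with early_stopping_round defaulting to 10 if not set any other way
    clf = lgb.LGBMClassifier(n_estimators=500, early_stopping=True, validation_fraction=0.1)
    clf.fit(X_train, y_train, eval_set=[(X_valid, y_valid)])

    # early_stopping=False (default): behavior identical to lightgbm today
    clf = lgb.LGBMClassifier(n_estimators=500)
    clf.fit(X_train, y_train)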

The docstring for the early_stopping keyword argument should also be updated to clearly explain this. For example:

        early_stopping : bool, optional (default=False)
            Whether to enable scikit-learn-style early stopping. If set to ``True``,
            a new validation set will be created by randomly sampling ``validation_fraction`` rows
            from the training data ``X`` passed to ``fit()``. Training will stop if the validation score
            does not improve for a specific number of rounds (controlled by ``n_iter_no_change``).
            This parameter is here for compatibility with ``scikit-learn``'s ``HistGradientBoosting``
            estimators... it does not affect other ``lightgbm``-specific early stopping mechanisms,
            like passing the ``lgb.early_stopping`` callback and validation sets to the ``eval_set``
            argument of ``fit()``.

If we want to keep the previous behaviour and we want to enable auto early stopping when early_stopping_rounds is set to a value and early_stopping is left to its default value we need to set the default value of early_stopping_rounds to None

I think this statement is also coming from you trying to tie this new .early_stopping property of the scikit-learn estimators to "whether or not any early stopping happens", and I guess you are referring to this:

    if callback._should_enable_early_stopping(params.get("early_stopping_round", 0)):
        callbacks_set.add(
            callback.early_stopping(
                stopping_rounds=params["early_stopping_round"],  # type: ignore[arg-type]
                first_metric_only=first_metric_only,
                min_delta=params.get("early_stopping_min_delta", 0.0),
                verbose=_choose_param_value(
                    main_param_name="verbosity",
                    params=params,
                    default_value=1,
                ).pop("verbosity")
                > 0,
            )
        )

I don't think we should do that. Let's limit the early_stopping keyword argument to the narrow purpose of "whether or not to add 1 randomly-sampled validation set and perform early stopping on it".

Then, the only thing that has to be figured out is how many rounds to train without improvement before stopping. I gave a proposal for doing that here: #5808 (comment)

Collaborator

@ClaudioSalvatoreArcidiacono, I'd like to propose that we take a different approach to getting this PR done.

Could I take over this PR and push to it, to get it into a state that I'd feel comfortable merging?

Then:

  1. you review the state and tell me if anything that's left concerns you
  2. after you and I agree on the state, we ask other maintainers to review

This has been taking a really large amount of both your and my time and energy. I think this approach would get us to a resolution faster.

I would only push new commits and merge commits (no force-pushing), so you could see everything I changed and revert any of it.

Can we do that?

Contributor Author

Hey @jameslamb, the statement

If the current state of this branch were merged, with this mix of behaviours:

  • early_stopping keyword argument defaulting to False
  • early_stopping=False meaning "do not perform any early stopping"

it would turn off early stopping for all existing code using early stopping with lightgbm.sklearn estimators.

is not correct. If the current state of this branch were merged, early stopping would not be enabled automatically, and passing an early_stopping callback while leaving the early_stopping keyword argument at its default of False would still enable early stopping, as it always has. See for example this other test, which is left unchanged and still passes.

The only backwards-incompatible change introduced at this stage is that passing early_stopping_rounds is no longer sufficient to enable early stopping (if the early_stopping parameter is left at False). I am definitely open to changing that.

Regarding your last suggestion, I am extremely thankful for the effort you have invested in this PR and I am sorry to hear that it is taking more effort than you anticipated. I am putting my best intentions into this PR and I definitely do not want to cause any harm to the library or waste any of your precious time unnecessarily.

I would like to give this PR one last attempt, addressing all of your comments and making sure everything is spot on. I will tag you once I think the PR is ready for a last review. If you think that after this last attempt the gap to be filled is still too big, I will step aside and I will happily let you take over and review your code.

Are you fine with it :)?

Contributor Author

@jameslamb, in your previous comment you mentioned:

  • early_stopping=True:
    • adds a new validation set to eval_set which is a random sample of {validation_fraction} * num_rows rows from the training data
    • enables early stopping, with early_stopping_rounds defaulting to 10 if not provided by any other mechanism
    • does not affect any other relevant early stopping behavior... e.g. if you also passed a dataset to eval_set in fit(...), then that dataset is ALSO used for early stopping.

Do you think that we should create an extra validation set even if a validation set has already been provided in the fit()?

I think that if a validation set has been provided in the fit() and the parameter early_stopping=True then we do not need to create a new validation set for early stopping. What do you think?

} # early stop should be supported via global LightGBM parameter
params_fit = {"X": X_train, "y": y_train}
