[python-package] add scikit-learn-style API for early stopping #5808
base: master
Changes from 25 commits
```diff
@@ -44,7 +44,7 @@
     dt_DataTable,
     pd_DataFrame,
 )
-from .engine import train
+from .engine import _make_n_folds, train

 if TYPE_CHECKING:
     from .compat import _sklearn_Tags
```
```diff
@@ -507,7 +507,10 @@ def __init__(
     random_state: Optional[Union[int, np.random.RandomState, np.random.Generator]] = None,
     n_jobs: Optional[int] = None,
     importance_type: str = "split",
-    **kwargs: Any,
+    early_stopping: bool = False,
+    n_iter_no_change: int = 10,
+    validation_fraction: Optional[float] = 0.1,
+    **kwargs,
```
**Review comment (jameslamb):** Why was the `**kwargs: Any` type hint removed? If it was just an accident, please put it back to reduce the size of the diff.
```diff
 ):
     r"""Construct a gradient boosting model.
```
```diff
@@ -587,6 +590,16 @@ def __init__(
         The type of feature importance to be filled into ``feature_importances_``.
         If 'split', result contains numbers of times the feature is used in a model.
         If 'gain', result contains total gains of splits which use the feature.
+    early_stopping : bool, optional (default=False)
+        Whether to enable early stopping. If set to True, training will stop if the
+        validation score does not improve for a specified number of rounds
+        (controlled by ``n_iter_no_change``).
+    n_iter_no_change : int, optional (default=10)
+        If early stopping is enabled, this parameter specifies the number of iterations
+        with no improvement after which training will be stopped.
+    validation_fraction : float or None, optional (default=0.1)
```
**Review comment (jameslamb):** There aren't any tests covering this new behavior. I don't know the exact code paths off the top of my head and would appreciate it if you could investigate, but I think it should be possible to test this by checking the size of the datasets added to `valid_sets`. If that's not observable through the public API, try to use mocking/patching to observe it instead of adding any additional properties to the Booster / estimators' public API. Comment in-thread here if you have questions or need help with that.

**Reply (ClaudioSalvatoreArcidiacono):** Thanks for your comment. I have added a couple of tests for that, using patching. Let me know if you think there is more that can be improved in those tests.
```diff
+        Proportion of training data to set aside as
+        validation data for early stopping. If None, early stopping is done on
+        the training data. Only used if early stopping is performed.
     **kwargs
         Other parameters for the model.
         Check http://lightgbm.readthedocs.io/en/latest/Parameters.html for more parameters.
```
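The stopping rule described by the new docstring entries can be summarized with a small standalone sketch. This illustrates the general `n_iter_no_change` idea only; it is not LightGBM's actual implementation, and `should_stop` is a hypothetical helper:

```python
def should_stop(scores, n_iter_no_change=10):
    """Stop when the best score has not improved for n_iter_no_change rounds.

    `scores` holds one validation score per boosting iteration (higher is better).
    """
    if len(scores) <= n_iter_no_change:
        return False
    best_before = max(scores[:-n_iter_no_change])
    # no score in the trailing window beats the earlier best -> stop
    return max(scores[-n_iter_no_change:]) <= best_before

# A score curve that plateaus triggers stopping once the plateau
# is n_iter_no_change rounds long.
plateau = [0.5, 0.6, 0.7] + [0.7] * 10
print(should_stop(plateau))             # True: no improvement in the last 10 rounds
print(should_stop([0.1, 0.2, 0.3], 2))  # False: still improving
```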
```diff
@@ -651,6 +664,9 @@ def __init__(
     self.random_state = random_state
     self.n_jobs = n_jobs
     self.importance_type = importance_type
+    self.early_stopping = early_stopping
+    self.n_iter_no_change = n_iter_no_change
+    self.validation_fraction = validation_fraction
     self._Booster: Optional[Booster] = None
     self._evals_result: _EvalResultDict = {}
     self._best_score: _LGBM_BoosterBestScoreType = {}
```
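The plain assignments above follow scikit-learn's convention that `__init__` stores every constructor argument unmodified, which is what makes `get_params`/`set_params` (and therefore cloning and grid search) work. A minimal standalone illustration of that convention, using a hypothetical `MiniEstimator` rather than LightGBM code:

```python
import inspect

class MiniEstimator:
    def __init__(self, early_stopping=False, n_iter_no_change=10, validation_fraction=0.1):
        # scikit-learn convention: store each constructor argument as-is,
        # with no validation or transformation in __init__
        self.early_stopping = early_stopping
        self.n_iter_no_change = n_iter_no_change
        self.validation_fraction = validation_fraction

    def get_params(self):
        # parameters are discoverable straight from the __init__ signature
        names = [p for p in inspect.signature(type(self).__init__).parameters if p != "self"]
        return {name: getattr(self, name) for name in names}

print(MiniEstimator(early_stopping=True).get_params())
```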
```diff
@@ -816,11 +832,19 @@ def _process_params(self, stage: str) -> Dict[str, Any]:
     params.pop("importance_type", None)
     params.pop("n_estimators", None)
     params.pop("class_weight", None)
+    params.pop("validation_fraction", None)
+    params.pop("early_stopping", None)
+    params.pop("n_iter_no_change", None)

     if isinstance(params["random_state"], np.random.RandomState):
         params["random_state"] = params["random_state"].randint(np.iinfo(np.int32).max)
     elif isinstance(params["random_state"], np.random.Generator):
         params["random_state"] = int(params["random_state"].integers(np.iinfo(np.int32).max))

+    params = _choose_param_value("early_stopping_round", params, self.n_iter_no_change)
+    if self.early_stopping is not True:
+        params["early_stopping_round"] = None
```
**Review comment (jameslamb):** This looks to me like it might turn off early stopping that was enabled other ways (like passing `early_stopping_rounds`).
```diff
     if self._n_classes > 2:
         for alias in _ConfigAliases.get("num_class"):
             params.pop(alias, None)
```
```diff
@@ -1006,6 +1030,27 @@ def fit(

                 valid_sets.append(valid_set)

+        elif self.early_stopping is True:
+            if self.validation_fraction is not None:
+                n_splits = max(int(np.ceil(1 / self.validation_fraction)), 2)
+                stratified = isinstance(self, LGBMClassifier)
```
**Review comment (ClaudioSalvatoreArcidiacono):** I am not a huge fan of how the validation set is created from the train set using `_make_n_folds`: if `1 / validation_fraction` is not an integer, the actual validation set size will not match the validation fraction specified by the user. For example, a validation fraction of 0.4 gives `n_splits = 3` here, which results in an actual fraction of 1/3 instead of 0.4. Using something like `train_test_split` from scikit-learn would solve the issue for the classification and regression cases, but for ranking tasks our best option is `GroupShuffleSplit`, which will inevitably suffer from the same issue. I considered a couple of options to solve this; I would lean towards option 2, but it will make the MR bigger. @jameslamb I would like to hear your opinion on it; do you perhaps already have something else in mind?
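The rounding mismatch the comment describes can be checked directly by mirroring the `n_splits` computation from the diff in a standalone sketch (`effective_validation_fraction` is a hypothetical helper, not LightGBM code):

```python
import math

def effective_validation_fraction(validation_fraction):
    # mirrors n_splits = max(int(np.ceil(1 / validation_fraction)), 2) from the diff;
    # one fold out of n_splits becomes the validation set, so the actual
    # held-out fraction is 1 / n_splits
    n_splits = max(math.ceil(1 / validation_fraction), 2)
    return n_splits, 1 / n_splits

for requested in (0.1, 0.25, 0.3, 0.4):
    n_splits, actual = effective_validation_fraction(requested)
    print(f"requested={requested} n_splits={n_splits} actual={actual:.3f}")
```

Whenever `1 / validation_fraction` is not an integer (0.3 and 0.4 above), the requested and actual fractions diverge; for fractions like 0.1 and 0.25 they agree exactly.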
```diff
+                cvfolds = _make_n_folds(
+                    full_data=train_set,
+                    folds=None,
+                    nfold=n_splits,
+                    params=params,
+                    seed=self.random_state,
+                    stratified=stratified,
+                    shuffle=True,
+                )
+                train_idx, val_idx = next(cvfolds)
+                valid_set = train_set.subset(sorted(val_idx))
+                train_set = train_set.subset(sorted(train_idx))
+            else:
+                valid_set = train_set
+            valid_set = valid_set.construct()
+            valid_sets = [valid_set]
```
```diff

         if isinstance(init_model, LGBMModel):
             init_model = init_model.booster_
```
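One alternative to the `_make_n_folds` approach discussed above is an exact-size holdout split. A minimal sketch of that idea with NumPy (`holdout_indices` is a hypothetical helper, not part of this PR):

```python
import numpy as np

def holdout_indices(n_samples, validation_fraction, seed=0):
    """Split row indices so the validation set has round(n * fraction) rows exactly."""
    rng = np.random.default_rng(seed)
    permuted = rng.permutation(n_samples)
    n_valid = max(1, round(n_samples * validation_fraction))
    # return sorted indices, matching how Dataset.subset() is called in the diff above
    return np.sort(permuted[n_valid:]), np.sort(permuted[:n_valid])

train_idx, valid_idx = holdout_indices(100, 0.4)
print(len(train_idx), len(valid_idx))  # 60 40
```

Unlike the `1 / n_splits` fold-based split, this honours fractions like 0.4 exactly, though it does not handle stratification or query groups for ranking.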
```diff
@@ -278,6 +278,112 @@ def test_binary_classification_with_custom_objective():
     assert ret < 0.05


+def test_auto_early_stopping_binary_classification():
+    X, y = load_breast_cancer(return_X_y=True)
+    n_estimators = 200
+    gbm = lgb.LGBMClassifier(n_estimators=n_estimators, random_state=42, verbose=-1, early_stopping=True, num_leaves=5)
+    gbm.fit(X, y)
+    assert gbm._Booster.params["early_stopping_round"] == 10
+    assert gbm._Booster.num_trees() < n_estimators
+    assert gbm.best_iteration_ < n_estimators
+
+
+def test_auto_early_stopping_compatibility_with_histgradientboostingclassifier():
+    X, y = load_breast_cancer(return_X_y=True)
+    n_estimators = 200
+    n_iter_no_change = 5
+    gbm = lgb.LGBMClassifier(
+        n_estimators=n_estimators,
+        random_state=42,
+        verbose=-1,
+        early_stopping=True,
+        num_leaves=5,
+        n_iter_no_change=n_iter_no_change,
+    )
+    gbm.fit(X, y)
+    assert gbm._Booster.params["early_stopping_round"] == n_iter_no_change
+    assert gbm._Booster.num_trees() < n_estimators
+    assert gbm.best_iteration_ < n_estimators
+
+
+def test_auto_early_stopping_categorical_features_set_during_fit(rng_fixed_seed):
+    pd = pytest.importorskip("pandas")
+    X = pd.DataFrame(
+        {
+            "A": pd.Categorical(
+                rng_fixed_seed.permutation(["z", "y", "x", "w", "v"] * 60), ordered=True
+            ),  # str and ordered categorical
+            "B": rng_fixed_seed.permutation([1, 2, 3] * 100),  # int
+            "C": rng_fixed_seed.permutation([0.1, 0.2, -0.1, -0.1, 0.2] * 60),  # float
+            "D": rng_fixed_seed.permutation([True, False] * 150),  # bool
+        }
+    )
+    cat_cols_actual = ["A", "B", "C", "D"]
+    y = rng_fixed_seed.permutation([0, 1] * 150)
+    n_estimators = 5
+    gbm = lgb.LGBMClassifier(n_estimators=n_estimators, random_state=42, verbose=-1, early_stopping=True, num_leaves=5)
+    gbm.fit(X, y, categorical_feature=cat_cols_actual)
+    assert gbm._Booster.params["early_stopping_round"] == 10
+    assert gbm._Booster.num_trees() < 5
+    assert gbm.best_iteration_ < 5
+
+
+def test_early_stopping_is_deactivated_by_default_regression():
+    X, y = make_synthetic_regression(n_samples=10_001)
+    n_estimators = 5
+    gbm = lgb.LGBMRegressor(n_estimators=n_estimators, random_state=42, verbose=-1)
+    gbm.fit(X, y)
+
+    # Check that early stopping did not kick in
+    assert gbm._Booster.params.get("early_stopping_round") is None
+    assert gbm._Booster.num_trees() == n_estimators
+
+
+def test_early_stopping_is_deactivated_by_default_classification():
+    X, y = load_breast_cancer(return_X_y=True)
+    n_estimators = 5
+    gbm = lgb.LGBMClassifier(n_estimators=n_estimators, random_state=42, verbose=-1)
+    gbm.fit(X, y)
+
+    # Check that early stopping did not kick in
+    assert gbm._Booster.params.get("early_stopping_round") is None
+    assert gbm._Booster.num_trees() == n_estimators
+
+
+def test_early_stopping_is_deactivated_by_default_lambdarank():
+    rank_example_dir = Path(__file__).absolute().parents[2] / "examples" / "lambdarank"
+    X_train, y_train = load_svmlight_file(str(rank_example_dir / "rank.train"))
+    q_train = np.loadtxt(str(rank_example_dir / "rank.train.query"))
+    n_estimators = 5
+    gbm = lgb.LGBMRanker(n_estimators=n_estimators, random_state=42, verbose=-1)
+    gbm.fit(X_train, y_train, group=q_train)  # Assuming 10 samples in one group
```
**Review comment (jameslamb):** How does this code comment relate to the code? If it's left over from some earlier debugging, please remove it.
```diff
+
+    # Check that early stopping did not kick in
+    assert gbm._Booster.params.get("early_stopping_round") is None
+    assert gbm._Booster.num_trees() == n_estimators
+
+
+@pytest.mark.skipif(
+    getenv("TASK", "") == "cuda", reason="Skip due to differences in implementation details of CUDA version"
+)
+def test_auto_early_stopping_lambdarank():
+    rank_example_dir = Path(__file__).absolute().parents[2] / "examples" / "lambdarank"
+    X_train, y_train = load_svmlight_file(str(rank_example_dir / "rank.train"))
+    q_train = np.loadtxt(str(rank_example_dir / "rank.train.query"))
+    n_estimators = 5
+    gbm = lgb.LGBMRanker(n_estimators=n_estimators, random_state=42, early_stopping=True, num_leaves=5)
+    gbm.fit(
+        X_train,
+        y_train,
+        group=q_train,
+        eval_at=[1, 3],
+        callbacks=[lgb.reset_parameter(learning_rate=lambda x: max(0.01, 0.1 - 0.01 * x))],
```
**Review comment (jameslamb):** Why is modifying the learning rate necessary for this test? If it is unnecessary and just copied from somewhere else, please remove it.
```diff
+    )
+    assert gbm._Booster.params["early_stopping_round"] == 10
+    assert gbm._Booster.num_trees() < n_estimators
+    assert gbm.best_iteration_ < n_estimators
+
+
 def test_dart():
     X, y = make_synthetic_regression()
     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
```
```diff
@@ -1168,6 +1274,7 @@ def fit_and_check(eval_set_names, metric_names, assumed_iteration, first_metric_
     "verbose": -1,
     "seed": 123,
     "early_stopping_rounds": 5,
+    "early_stopping": True,
```
**Review comment (jameslamb):** Was this necessary to make this test pass, or is it just being done for convenience? If it's necessary, could you explain why? If it was just for convenience, please introduce a new test after this one that tests what you want to test. This PR should only be adding functionality without breaking any existing user code. I'd be more confident that that's true if the tests showed only additions, with no modifications to existing tests.

**Reply (ClaudioSalvatoreArcidiacono):** Yes, this was a necessary change in order to make the test pass. With the code in its current state, passing `early_stopping_rounds` alone is no longer enough to enable early stopping; the new `early_stopping` parameter must also be set to True.

**Review comment (jameslamb):** Ok, that's a problem, and exactly why I have been so strict in asking that we only see new tests added in this PR. I think that also answers my questions from #5808 (comment). If the current state of this branch were merged, with this mix of behaviors, it would turn off early stopping for all existing code that enables it with `early_stopping_rounds`. That's a very big user-facing breaking change, and one I do not support. Please, let's keep the scope of this PR limited to what I've written in the title: "add scikit-learn-style API for early stopping". I think this is also coming from you trying to tie the new parameters to the existing validation-set handling in python-package/lightgbm/engine.py (lines 275 to 288 at 4feee28); I don't think we should do that. Then the only thing that has to be figured out is how many rounds to train without improvement before stopping. I gave a proposal for doing that here: #5808 (comment).

**Review comment (jameslamb):** @ClaudioSalvatoreArcidiacono, I'd like to propose that we take a different approach to getting this PR done. Could I take over this PR and push to it, to get it into a state that I'd feel comfortable merging? This has been taking a really large amount of both your and my time and energy, and I think this approach would get us to a resolution faster. I would only push new commits and merge commits (no force-pushing), so you could see everything I changed and revert any of it. Can we do that?

**Reply (ClaudioSalvatoreArcidiacono):** Hey @jameslamb, that statement is not correct: if the current state of this branch were merged, it would not enable early stopping automatically. The only non-backward-compatible change introduced at this stage is that passing `early_stopping_rounds` is no longer sufficient to enable early stopping (if the `early_stopping` parameter is left at False). I am definitely open to changing that. Regarding your last suggestion: I am extremely thankful for the effort you have invested in this PR, and I am sorry to hear it is taking more effort than you anticipated. I am putting my best intentions into this PR and definitely do not want to cause any harm to the library or waste any of your precious time unnecessarily. I would like to give this PR one last attempt, addressing all of your comments and making sure everything is spot on; I will tag you once I think the PR is ready for a last review. If you think that after this last attempt the gap to be filled is still too big, I will step aside and happily let you take over. Are you fine with that? :)

**Reply (ClaudioSalvatoreArcidiacono):** @jameslamb, do you think we should create an extra validation set even if a validation set has already been provided in `eval_set`? I think that if a validation set has been provided there, we should use it rather than split off another one.
```diff
     }  # early stop should be supported via global LightGBM parameter
     params_fit = {"X": X_train, "y": y_train}
```
**Review comment (jameslamb):** I think we should make these keyword-only arguments, as scikit-learn does: https://github.com/scikit-learn/scikit-learn/blob/6cccd99aee3483eb0f7562afdd3179ccccab0b1d/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py#L1689. Could you please try that (in these estimators and the Dask ones)? I don't want to do that for other existing parameters, to prevent breaking existing user code, but since these are new parameters, it's safe to be stricter.
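The suggestion amounts to placing the new parameters after a bare `*` in the signature. A minimal standalone illustration (hypothetical `Example` class, not the actual estimator):

```python
from typing import Optional

class Example:
    def __init__(
        self,
        importance_type: str = "split",
        *,  # parameters after the bare * can only be passed by keyword
        early_stopping: bool = False,
        n_iter_no_change: int = 10,
        validation_fraction: Optional[float] = 0.1,
    ):
        self.importance_type = importance_type
        self.early_stopping = early_stopping
        self.n_iter_no_change = n_iter_no_change
        self.validation_fraction = validation_fraction

Example(early_stopping=True)  # OK: passed by keyword
try:
    Example("split", True)  # passing early_stopping positionally
except TypeError as exc:
    print(exc)  # rejected: too many positional arguments
```

Being strict here from the start means positional calls can never creep into user code, so the parameters can be reordered or extended later without breakage.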