[python-package] add scikit-learn-style API for early stopping #5808

Status: Open. Wants to merge 33 commits into base: master.

Changes shown are from 15 commits. All 33 commits:
  • ad43e17 Enable Auto Early Stopping (ClaudioSalvatoreArcidiacono, Sep 15, 2023)
  • f05e5e0 Relax test conditions (ClaudioSalvatoreArcidiacono, Sep 18, 2023)
  • 76f3c19 Merge branch 'master' into 3313-enable-auto-early-stopping (ClaudioSalvatoreArcidiacono, Sep 20, 2023)
  • 457c7f6 Merge master (ClaudioSalvatoreArcidiacono, Jan 16, 2024)
  • 0db1941 Revert "Merge master" (ClaudioSalvatoreArcidiacono, Jan 16, 2024)
  • 10fac65 Merge remote-tracking branch 'lgbm/master' into 3313-enable-auto-earl… (ClaudioSalvatoreArcidiacono, Jan 16, 2024)
  • d10ca54 Add missing import (ClaudioSalvatoreArcidiacono, Jan 16, 2024)
  • 3b8eb0a Remove added extra new line (ClaudioSalvatoreArcidiacono, Jan 17, 2024)
  • e47acc0 Merge branch 'master' into 3313-enable-auto-early-stopping (ClaudioSalvatoreArcidiacono, Jan 17, 2024)
  • 66701ac Merge branch 'master' into 3313-enable-auto-early-stopping (ClaudioSalvatoreArcidiacono, Jan 25, 2024)
  • 39d333e Merge branch 'master' into 3313-enable-auto-early-stopping (ClaudioSalvatoreArcidiacono, Feb 2, 2024)
  • cad7eb6 Merge branch 'master' into 3313-enable-auto-early-stopping (ClaudioSalvatoreArcidiacono, Feb 6, 2024)
  • 1234ccf Merge master (ClaudioSalvatoreArcidiacono, Nov 28, 2024)
  • d54c96a Improve documentation, check default behavior of early stopping (ClaudioSalvatoreArcidiacono, Nov 28, 2024)
  • 9c1c8b4 Solve python 3.8 compatibility issue (ClaudioSalvatoreArcidiacono, Nov 28, 2024)
  • 724c7fe Remove default to auto (ClaudioSalvatoreArcidiacono, Nov 29, 2024)
  • c957fce Revert changes in fit top part (ClaudioSalvatoreArcidiacono, Nov 29, 2024)
  • 2d7da78 Make interface as similar as possible to sklearn (ClaudioSalvatoreArcidiacono, Nov 29, 2024)
  • 069a84e Add parameters to dask interface (ClaudioSalvatoreArcidiacono, Nov 29, 2024)
  • c430ec1 Improve documentation (ClaudioSalvatoreArcidiacono, Nov 29, 2024)
  • 416323a Linting (ClaudioSalvatoreArcidiacono, Nov 29, 2024)
  • 73562ff Check for exact value equal true for early stopping (ClaudioSalvatoreArcidiacono, Nov 29, 2024)
  • 38edc42 Merge branch 'master' into 3313-enable-auto-early-stopping (jameslamb, Dec 15, 2024)
  • 9a32376 Switch if/else conditions order in fit (ClaudioSalvatoreArcidiacono, Dec 18, 2024)
  • f33ebd3 Merge remote-tracking branch 'origin/master' into 3313-enable-auto-ea… (ClaudioSalvatoreArcidiacono, Dec 18, 2024)
  • a61726f fix issues in engine.py (ClaudioSalvatoreArcidiacono, Dec 18, 2024)
  • 44316d7 make new early stopping parameters keyword-only (ClaudioSalvatoreArcidiacono, Dec 18, 2024)
  • 4cbfc84 Remove n_iter_no_change parameter (ClaudioSalvatoreArcidiacono, Dec 18, 2024)
  • 93acf6a Address comments in tests (ClaudioSalvatoreArcidiacono, Dec 18, 2024)
  • 2b049c9 Improve tests (ClaudioSalvatoreArcidiacono, Dec 18, 2024)
  • 61371cb Add tests to check for validation fraction (ClaudioSalvatoreArcidiacono, Dec 18, 2024)
  • 65c4e2f Remove validation_fraction=None option (ClaudioSalvatoreArcidiacono, Dec 18, 2024)
  • 0a8e843 Remove validation_fraction=None option also in dask (ClaudioSalvatoreArcidiacono, Dec 18, 2024)
5 changes: 5 additions & 0 deletions docs/Python-Intro.rst
@@ -244,6 +244,11 @@
This works with both metrics to minimize (L2, log loss, etc.) and to maximize (NDCG, AUC).
Note that if you specify more than one evaluation metric, all of them will be used for early stopping.
However, you can change this behavior and make LightGBM check only the first metric for early stopping by passing ``first_metric_only=True`` in ``early_stopping`` callback constructor.
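For example, a minimal sketch of this callback-based form (toy data; assumes the standard ``lgb.train`` / ``lgb.early_stopping`` API):

    import lightgbm as lgb
    import numpy as np

    rng = np.random.default_rng(0)
    X, y = rng.normal(size=(1_000, 5)), rng.normal(size=1_000)
    train_set = lgb.Dataset(X[:800], y[:800])
    valid_set = lgb.Dataset(X[800:], y[800:], reference=train_set)

    booster = lgb.train(
        {"objective": "regression", "metric": ["l2", "l1"]},
        train_set,
        num_boost_round=500,
        valid_sets=[valid_set],
        # stop if the first metric does not improve for 10 rounds
        callbacks=[lgb.early_stopping(stopping_rounds=10, first_metric_only=True)],
    )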

In the scikit-learn API of LightGBM, early stopping can also be enabled by setting the parameter ``early_stopping`` to ``True``
or by setting the parameter ``early_stopping_round`` to an integer greater than 0.
When early stopping is enabled and no validation set is provided, a portion of the training data is set aside and used as the validation set.
The fraction of training data used for validation is controlled by the parameter ``validation_fraction``, which defaults to 0.1.
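For illustration, a minimal sketch of the behavior described above (assuming the ``early_stopping``, ``early_stopping_round``, and ``validation_fraction`` parameters added in this PR):

    import lightgbm as lgb
    from sklearn.datasets import make_regression

    X, y = make_regression(n_samples=20_000, n_features=10, random_state=0)

    # early stopping on a random 10% holdout carved out of the training data
    model = lgb.LGBMRegressor(
        n_estimators=1_000,
        early_stopping=True,      # or: early_stopping_round=10
        validation_fraction=0.1,  # the default
    )
    model.fit(X, y)  # no eval_set needed; a validation set is split off automatically
    print(model.best_iteration_)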

Prediction
----------

3 changes: 3 additions & 0 deletions python-package/lightgbm/dask.py
@@ -1134,6 +1134,7 @@ def __init__(
random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None,
n_jobs: Optional[int] = None,
importance_type: str = "split",
validation_fraction: Optional[float] = 0.1,
client: Optional[Client] = None,
**kwargs: Any,
):
@@ -1337,6 +1338,7 @@ def __init__(
random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None,
n_jobs: Optional[int] = None,
importance_type: str = "split",
validation_fraction: Optional[float] = 0.1,
client: Optional[Client] = None,
**kwargs: Any,
):
@@ -1504,6 +1506,7 @@ def __init__(
random_state: Optional[Union[int, np.random.RandomState, "np.random.Generator"]] = None,
n_jobs: Optional[int] = None,
importance_type: str = "split",
validation_fraction: Optional[float] = 0.1,
client: Optional[Client] = None,
**kwargs: Any,
):
36 changes: 22 additions & 14 deletions python-package/lightgbm/engine.py
@@ -510,11 +510,9 @@ def _make_n_folds(
nfold: int,
params: Dict[str, Any],
seed: int,
fpreproc: Optional[_LGBM_PreprocFunction],
stratified: bool,
shuffle: bool,
eval_train_metric: bool,
) -> CVBooster:
) -> Iterable[Tuple[np.ndarray, np.ndarray]]:
"""Make a n-fold list of Booster from random indices."""
full_data = full_data.construct()
num_data = full_data.num_data()
@@ -559,7 +557,16 @@ def _make_n_folds(
    test_id = [randidx[i : i + kstep] for i in range(0, num_data, kstep)]
    train_id = [np.concatenate([test_id[i] for i in range(nfold) if k != i]) for k in range(nfold)]
    folds = zip(train_id, test_id)
    return folds


def _make_cvbooster(
    full_data: Dataset,
    params: Dict[str, Any],
    folds: Iterable[Tuple[np.ndarray, np.ndarray]],
    fpreproc: Optional[_LGBM_PreprocFunction],
    eval_train_metric: bool,
) -> CVBooster:
    ret = CVBooster()
    for train_idx, test_idx in folds:
        train_set = full_data.subset(sorted(train_idx))
@@ -764,10 +771,11 @@ def cv(
nfold=nfold,
params=params,
seed=seed,
fpreproc=fpreproc,
stratified=stratified,
shuffle=shuffle,
eval_train_metric=eval_train_metric,
)
cvbooster = _make_cvbooster(
full_data=train_set, params=params, folds=cvfolds, fpreproc=fpreproc, eval_train_metric=eval_train_metric
)

# setup callbacks
@@ -802,24 +810,24 @@
for cb in callbacks_before_iter:
cb(
callback.CallbackEnv(
model=cvfolds,
model=cvbooster,
params=params,
iteration=i,
begin_iteration=0,
end_iteration=num_boost_round,
evaluation_result_list=None,
)
)
cvfolds.update(fobj=fobj) # type: ignore[call-arg]
res = _agg_cv_result(cvfolds.eval_valid(feval)) # type: ignore[call-arg]
cvbooster.update(fobj=fobj) # type: ignore[call-arg]
res = _agg_cv_result(cvbooster.eval_valid(feval)) # type: ignore[call-arg]
for _, key, mean, _, std in res:
results[f"{key}-mean"].append(mean)
results[f"{key}-stdv"].append(std)
try:
for cb in callbacks_after_iter:
cb(
callback.CallbackEnv(
model=cvfolds,
model=cvbooster,
params=params,
iteration=i,
begin_iteration=0,
@@ -828,14 +836,14 @@
)
)
except callback.EarlyStopException as earlyStopException:
cvfolds.best_iteration = earlyStopException.best_iteration + 1
for bst in cvfolds.boosters:
bst.best_iteration = cvfolds.best_iteration
cvbooster.best_iteration = earlyStopException.best_iteration + 1
for bst in cvbooster.boosters:
bst.best_iteration = cvbooster.best_iteration
for k in results:
results[k] = results[k][: cvfolds.best_iteration]
results[k] = results[k][: cvbooster.best_iteration]
break

if return_cvbooster:
results["cvbooster"] = cvfolds # type: ignore[assignment]
results["cvbooster"] = cvbooster # type: ignore[assignment]

return dict(results)
193 changes: 122 additions & 71 deletions python-package/lightgbm/sklearn.py
@@ -46,7 +46,7 @@
dt_DataTable,
pd_DataFrame,
)
from .engine import train
from .engine import _make_n_folds, train

if TYPE_CHECKING:
from .compat import _sklearn_Tags
@@ -509,7 +509,8 @@ def __init__(
random_state: Optional[Union[int, np.random.RandomState, np.random.Generator]] = None,
n_jobs: Optional[int] = None,
importance_type: str = "split",
**kwargs: Any,
validation_fraction: Optional[float] = 0.1,
**kwargs,
Collaborator comment:

Suggested change:

    -        **kwargs,
    +        **kwargs: Any,

Why was this type hint removed? If it was just an accident, please put it back to reduce the size of the diff.

):
r"""Construct a gradient boosting model.

@@ -589,6 +590,10 @@ def __init__(
The type of feature importance to be filled into ``feature_importances_``.
If 'split', result contains numbers of times the feature is used in a model.
If 'gain', result contains total gains of splits which use the feature.
validation_fraction : float or None, optional (default=0.1)
Collaborator comment:
There are no tests in test_sklearn.py that pass validation_fraction. Please add some, covering both the default behavior and that passing a non-default value (like 0.4) works as expected.

I don't know the exact code paths off the top of my head, and would appreciate it if you could investigate... but I think it should be possible to test this by checking the size of the datasets added to valid_sets and confirming that they're as expected (e.g. that the automatically-added validation set has 4,000 rows if the input data X has 40,000 rows and validation_fraction=0.1 is passed).

If that's not observable through the public API, try to use mocking/patching to observe it instead of adding any additional properties to the Booster / estimators' public API.

Comment in-thread here if you have questions or need help with that.

Contributor Author comment:
Thanks for your comment. I have added a couple of tests for that, using patching; let me know if you think those tests can be improved further.
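A hypothetical sketch of that kind of patching-based test (the patch target, estimator settings, and assumed call order are illustrative, not necessarily the exact tests added to test_sklearn.py):

    import numpy as np
    from unittest import mock

    import lightgbm as lgb

    def test_validation_fraction_controls_holdout_size():
        X = np.random.default_rng(0).normal(size=(40_000, 5))
        y = X.sum(axis=1)
        original_subset = lgb.Dataset.subset
        # spy on Dataset.subset while keeping its real behavior
        with mock.patch.object(lgb.Dataset, "subset", autospec=True, side_effect=original_subset) as spy:
            lgb.LGBMRegressor(n_estimators=10, early_stopping=True, validation_fraction=0.1).fit(X, y)
        # assuming the first subset() call builds the validation split:
        # 10% of 40,000 rows -> 4,000 validation indices
        validation_indices = spy.call_args_list[0].args[1]
        assert len(validation_indices) == 4_000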

Proportion of training data to set aside as
validation data for early stopping. If None, early stopping is done on
the training data. Only used if early stopping is performed.
**kwargs
Other parameters for the model.
Check http://lightgbm.readthedocs.io/en/latest/Parameters.html for more parameters.
@@ -653,6 +658,7 @@ def __init__(
self.random_state = random_state
self.n_jobs = n_jobs
self.importance_type = importance_type
self.validation_fraction = validation_fraction
self._Booster: Optional[Booster] = None
self._evals_result: _EvalResultDict = {}
self._best_score: _LGBM_BoosterBestScoreType = {}
@@ -812,11 +818,29 @@ def _process_params(self, stage: str) -> Dict[str, Any]:
params.pop("importance_type", None)
params.pop("n_estimators", None)
params.pop("class_weight", None)
params.pop("validation_fraction", None)

if isinstance(params["random_state"], np.random.RandomState):
params["random_state"] = params["random_state"].randint(np.iinfo(np.int32).max)
elif isinstance(params["random_state"], np.random.Generator):
params["random_state"] = int(params["random_state"].integers(np.iinfo(np.int32).max))

        params = _choose_param_value(
            main_param_name="early_stopping_round",
            params=params,
            default_value="auto",
        )
        if params["early_stopping_round"] == "auto":
            if hasattr(self, "_n_rows_train") and self._n_rows_train > 10_000:
                params["early_stopping_round"] = 10
Collaborator comment:
In #5808 (comment), you said:

I have set early stopping to be disabled by default

But this block of code directly contradicts that... it enables early stopping by default if there are more than 10K rows in the training data.

From all the prior discussion on this PR, it seems clear to me that @borchero and @jmoralez both requested you not do that, and instead limit the scope of this PR to just providing an API for early stopping that matches how scikit-learn estimators do it. I agree with them... it shouldn't be enabled by default as part of this PR.

Please, do the following:

  • change the title of the PR (which will become a commit message and release notes item) to [python-package] add scikit-learn-style API for early stopping
  • remove this logic about early_stopping_round = "auto" (and the corresponding changes further down where you moved parameter processing to after input data validation, so it could reference self._n_rows_train)
  • match whatever API HistGradientBoostingClassifier / HistGradientBoostingRegressor currently have for enabling early stopping

Contributor Author comment:
Hey @jameslamb, thanks a lot for your review; I highly appreciate it.

You are right; it was a mistake on my side, and I have fixed it in my latest commits. Thank you for spotting it.

In the latest version I have:

  • Renamed the PR as you requested.
  • Removed the logic about early_stopping_round = "auto" and the corresponding changes where I switched parameters processing and input validation.
  • I think the APIs of HistGradientBoostingClassifier / HistGradientBoostingRegressor and the current implementation match, except that:
    • HistGradientBoostingClassifier / HistGradientBoostingRegressor support early_stopping = "auto" as their default parameter value, but as we discussed above we do not want this behaviour.
    • I have tried to support LightGBM's parameter alias logic while working around the fact that early_stopping is an alias for early_stopping_round(s) and n_iter_no_change. In the current implementation, setting early_stopping_round(s) to an integer overrides the default value for n_iter_no_change, but setting early_stopping to an integer neither enables early stopping nor overrides the default value for n_iter_no_change (sketched below). I have opted for this implementation to preserve backward compatibility as much as possible. Please let me know if you agree with it or have a different opinion; I am happy to change it.
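For example, a sketch of the backward-compatible behavior described above (this reflects the interim state of the PR; the behavior is revisited in the discussion that follows):

    import lightgbm as lgb

    # enables early stopping and overrides the default patience
    clf_a = lgb.LGBMClassifier(early_stopping_rounds=20)

    # does NOT enable early stopping in this implementation, even though
    # `early_stopping` is a LightGBM alias for `early_stopping_round`,
    # because scikit-learn gives `early_stopping` a different meaning
    clf_b = lgb.LGBMClassifier(early_stopping=20)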

Collaborator comment:
Thanks for that explanation.

Removed the logic about early_stopping_round = "auto" and the corresponding changes where I switched parameters processing and input validation.

Agreed, thank you for removing it.

early_stopping is an alias for early_stopping_round(s) and n_iter_no_change

Thanks for being careful about backwards compatibility!

It's unfortunate that we have early_stopping as an alias for early_stopping_rounds and that scikit-learn's estimators now assign a different meaning to that :/

// alias = early_stopping_rounds, early_stopping, n_iter_no_change
// desc = will stop training if one metric of one validation data doesn't improve in last ``early_stopping_round`` rounds
// desc = ``<= 0`` means disable
// desc = can be used to speed up training

https://github.com/scikit-learn/scikit-learn/blob/6cccd99aee3483eb0f7562afdd3179ccccab0b1d/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py#L1572-L1575

Seems like that has been in scikit-learn for a long time (since v0.23):

https://github.com/scikit-learn/scikit-learn/blob/6cccd99aee3483eb0f7562afdd3179ccccab0b1d/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py#L1577

I see your implementation of this in LGBMModel._process_params():

params = _choose_param_value("early_stopping_round", params, self.n_iter_no_change)

I think I see a path to resolving this. It'll be complex, but it would keep LightGBM in compliance with the rules from https://scikit-learn.org/1.5/developers/develop.html#instantiation.

  • remove keyword argument n_iter_no_change (and property self.n_iter_no_change)... but ensure there are unit tests confirming that passing keyword argument n_iter_no_change through the scikit-learn estimators correctly sets the number of early stopping rounds
  • make keyword argument early_stopping have type Union[bool, int]
  • handle that difference in LGBMModel._process_params(), roughly like this:

rewrite this block:

    # use joblib conventions for negative n_jobs, just like scikit-learn
    # at predict time, this is handled later due to the order of parameter updates
    if stage == "fit":
        params = _choose_param_value("num_threads", params, self.n_jobs)
        params["num_threads"] = self._process_n_jobs(params["num_threads"])

to something like this:

        if stage == "fit":
            # use joblib conventions for negative n_jobs, just like scikit-learn
            # at predict time, this is handled later due to the order of parameter updates
            params = _choose_param_value("num_threads", params, self.n_jobs)
            params["num_threads"] = self._process_n_jobs(params["num_threads"])

            if not isinstance(self.early_stopping, bool) and isinstance(self.early_stopping, int):
                _log_warning(
                    f"Found 'early_stopping={self.early_stopping}' passed through keyword arguments. "
                    "Future versions of 'lightgbm' will not allow this, as scikit-learn expects keyword argument "
                    "'early_stopping' to be a boolean indicating whether or not to perform early stopping with "
                    "a randomly-sampled validation set. To set the number of early stopping rounds, and suppress "
                    f"this warning, pass early_stopping_rounds={self.early_stopping} instead."
                )
                params = _choose_param_value(
                    main_param_name="early_stopping_round",
                    params=params,
                    default_value=self.early_stopping,
                )

And then any other places in the code where you have self.early_stopping is True should use isinstance(self.early_stopping, bool) and self.early_stopping is True, to avoid this SyntaxWarning:

    >>> 1 is True
    <stdin>:1: SyntaxWarning: "is" with 'int' literal. Did you mean "=="?
    False

I think that can work and provide backwards compatibility while allowing us to match scikit-learn's API. And then eventually (after several lightgbm releases), that warning could be converted to an error and the type of early_stopping could be narrowed to just bool.

What do you think?

Contributor Author comment:
It's a great suggestion! Thanks a lot for it; I have implemented it in the latest version.

            else:
                params["early_stopping_round"] = None

        if params["early_stopping_round"] is True:
            params["early_stopping_round"] = 10
        elif params["early_stopping_round"] is False:
            params["early_stopping_round"] = None

if self._n_classes > 2:
for alias in _ConfigAliases.get("num_class"):
params.pop(alias, None)
@@ -891,27 +915,6 @@ def fit(
init_model: Optional[Union[str, Path, Booster, "LGBMModel"]] = None,
) -> "LGBMModel":
"""Docstring is set after definition, using a template."""
params = self._process_params(stage="fit")

# Do not modify original args in fit function
# Refer to https://github.com/microsoft/LightGBM/pull/2619
eval_metric_list: List[Union[str, _LGBM_ScikitCustomEvalFunction]]
if eval_metric is None:
eval_metric_list = []
elif isinstance(eval_metric, list):
eval_metric_list = copy.deepcopy(eval_metric)
else:
eval_metric_list = [copy.deepcopy(eval_metric)]

# Separate built-in from callable evaluation metrics
eval_metrics_callable = [_EvalFunctionWrapper(f) for f in eval_metric_list if callable(f)]
eval_metrics_builtin = [m for m in eval_metric_list if isinstance(m, str)]

# concatenate metric from params (or default if not provided in params) and eval_metric
params["metric"] = [params["metric"]] if isinstance(params["metric"], (str, type(None))) else params["metric"]
params["metric"] = [e for e in eval_metrics_builtin if e not in params["metric"]] + params["metric"]
params["metric"] = [metric for metric in params["metric"] if metric is not None]

if not isinstance(X, (pd_DataFrame, dt_DataTable)):
_X, _y = _LGBMValidateData(
self,
@@ -933,6 +936,33 @@
# for other data types, setting n_features_in_ is handled by _LGBMValidateData() in the branch above
self.n_features_in_ = _X.shape[1]

self._n_features = _X.shape[1]
# copy for consistency
self._n_features_in = self._n_features

self._n_rows_train = _X.shape[0]

params = self._process_params(stage="fit")

# Do not modify original args in fit function
# Refer to https://github.com/microsoft/LightGBM/pull/2619
eval_metric_list: List[Union[str, _LGBM_ScikitCustomEvalFunction]]
if eval_metric is None:
eval_metric_list = []
elif isinstance(eval_metric, list):
eval_metric_list = copy.deepcopy(eval_metric)
else:
eval_metric_list = [copy.deepcopy(eval_metric)]

# Separate built-in from callable evaluation metrics
eval_metrics_callable = [_EvalFunctionWrapper(f) for f in eval_metric_list if callable(f)]
eval_metrics_builtin = [m for m in eval_metric_list if isinstance(m, str)]

# concatenate metric from params (or default if not provided in params) and eval_metric
params["metric"] = [params["metric"]] if isinstance(params["metric"], (str, type(None))) else params["metric"]
params["metric"] = [e for e in eval_metrics_builtin if e not in params["metric"]] + params["metric"]
params["metric"] = [metric for metric in params["metric"] if metric is not None]

if self._class_weight is None:
self._class_weight = self.class_weight
if self._class_weight is not None:
@@ -953,54 +983,75 @@
params=params,
)

valid_sets: List[Dataset] = []
if eval_set is not None:
if isinstance(eval_set, tuple):
eval_set = [eval_set]
for i, valid_data in enumerate(eval_set):
# reduce cost for prediction training data
if valid_data[0] is X and valid_data[1] is y:
valid_set = train_set
else:
valid_weight = _extract_evaluation_meta_data(
collection=eval_sample_weight,
name="eval_sample_weight",
i=i,
)
valid_class_weight = _extract_evaluation_meta_data(
collection=eval_class_weight,
name="eval_class_weight",
i=i,
)
if valid_class_weight is not None:
if isinstance(valid_class_weight, dict) and self._class_map is not None:
valid_class_weight = {self._class_map[k]: v for k, v in valid_class_weight.items()}
valid_class_sample_weight = _LGBMComputeSampleWeight(valid_class_weight, valid_data[1])
if valid_weight is None or len(valid_weight) == 0:
valid_weight = valid_class_sample_weight
else:
valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
valid_init_score = _extract_evaluation_meta_data(
collection=eval_init_score,
name="eval_init_score",
i=i,
)
valid_group = _extract_evaluation_meta_data(
collection=eval_group,
name="eval_group",
i=i,
)
valid_set = Dataset(
data=valid_data[0],
label=valid_data[1],
weight=valid_weight,
group=valid_group,
init_score=valid_init_score,
categorical_feature="auto",
params=params,
)

valid_sets.append(valid_set)
        if params["early_stopping_round"] is not None and eval_set is None:
            if self.validation_fraction is not None:
                n_splits = max(int(np.ceil(1 / self.validation_fraction)), 2)
                stratified = isinstance(self, LGBMClassifier)
Contributor Author comment:
I am not a huge fan of how the validation set is created from the train set using the _make_n_folds function.

If 1/validation_fraction is not an integer, the actual validation set size will not match the validation fraction specified by the user.

For example, if the validation fraction is 0.4, the number of splits calculated here will be 3, which results in an actual validation fraction of 1/3 ≈ 0.33 instead of 0.4.

Using something like train_test_split from scikit-learn would solve the issue for the classification and regression cases, but for ranking tasks our best option is GroupShuffleSplit, which inevitably suffers from the same issue described above. The options I thought of to solve this issue are:

  1. Leave the code as-is and raise a warning when 1/validation_fraction is not an integer.
  2. Use train_test_split for creating the validation set in the classification and regression cases; Raise a warning when 1/validation_fraction is not an integer in the ranking case.

I would lean more towards option 2, but this will make the MR bigger.

@jameslamb I would like to hear your opinion on it; do you perhaps already have something else in mind?
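For illustration, the actual fraction produced by the n_splits computation above:

    import numpy as np

    # mismatch whenever 1 / validation_fraction is not an integer
    for validation_fraction in (0.5, 0.4, 0.3):
        n_splits = max(int(np.ceil(1 / validation_fraction)), 2)
        print(validation_fraction, n_splits, 1 / n_splits)
    # 0.5 -> n_splits=2 -> actual fraction 0.500
    # 0.4 -> n_splits=3 -> actual fraction 0.333
    # 0.3 -> n_splits=4 -> actual fraction 0.250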

                cvfolds = _make_n_folds(
                    full_data=train_set,
                    folds=None,
                    nfold=n_splits,
                    params=params,
                    seed=self.random_state,
                    stratified=stratified,
                    shuffle=True,
                )
                train_idx, val_idx = next(cvfolds)
                valid_set = train_set.subset(sorted(val_idx))
                train_set = train_set.subset(sorted(train_idx))
            else:
                valid_set = train_set
            valid_set = valid_set.construct()
            valid_sets = [valid_set]
        else:
            valid_sets: List[Dataset] = []
            if eval_set is not None:
                if isinstance(eval_set, tuple):
                    eval_set = [eval_set]
                for i, valid_data in enumerate(eval_set):
                    # reduce cost for prediction training data
                    if valid_data[0] is X and valid_data[1] is y:
                        valid_set = train_set
                    else:
                        valid_weight = _extract_evaluation_meta_data(
                            collection=eval_sample_weight,
                            name="eval_sample_weight",
                            i=i,
                        )
                        valid_class_weight = _extract_evaluation_meta_data(
                            collection=eval_class_weight,
                            name="eval_class_weight",
                            i=i,
                        )
                        if valid_class_weight is not None:
                            if isinstance(valid_class_weight, dict) and self._class_map is not None:
                                valid_class_weight = {self._class_map[k]: v for k, v in valid_class_weight.items()}
                            valid_class_sample_weight = _LGBMComputeSampleWeight(valid_class_weight, valid_data[1])
                            if valid_weight is None or len(valid_weight) == 0:
                                valid_weight = valid_class_sample_weight
                            else:
                                valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
                        valid_init_score = _extract_evaluation_meta_data(
                            collection=eval_init_score,
                            name="eval_init_score",
                            i=i,
                        )
                        valid_group = _extract_evaluation_meta_data(
                            collection=eval_group,
                            name="eval_group",
                            i=i,
                        )
                        valid_set = Dataset(
                            data=valid_data[0],
                            label=valid_data[1],
                            weight=valid_weight,
                            group=valid_group,
                            init_score=valid_init_score,
                            categorical_feature="auto",
                            params=params,
                        )

                    valid_sets.append(valid_set)

if isinstance(init_model, LGBMModel):
init_model = init_model.booster_