Skip to content

Commit

Permalink
Add early stopping init params to dask interface
Browse files Browse the repository at this point in the history
  • Loading branch information
ClaudioSalvatoreArcidiacono committed Jun 5, 2023
1 parent 67b1c9a commit 1910076
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 9 deletions.
33 changes: 30 additions & 3 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, LGBMNotFittedError, concat,
dask_Array, dask_array_from_delayed, dask_bag_from_delayed, dask_DataFrame, dask_Series,
default_client, delayed, pd_DataFrame, pd_Series, wait)
from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _LGBM_ScikitCustomObjectiveFunction,
_LGBM_ScikitEvalMetricType, _lgbmmodel_doc_custom_eval_note, _lgbmmodel_doc_fit,
_lgbmmodel_doc_predict)
from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _LGBM_ScikitCustomEvalSetSplitter,
_LGBM_ScikitCustomObjectiveFunction, _LGBM_ScikitEvalMetricType, _lgbmmodel_doc_custom_eval_note,
_lgbmmodel_doc_fit, _lgbmmodel_doc_predict)

__all__ = [
'DaskLGBMClassifier',
Expand Down Expand Up @@ -1170,10 +1170,19 @@ def __init__(
random_state: Optional[Union[int, np.random.RandomState]] = None,
n_jobs: Optional[int] = None,
importance_type: str = 'split',
early_stopping: bool = False,
validation_fraction: Optional[float] = 0.1,
n_iter_no_change: int = 10,
validation_set_split_strategy: Optional[Union[str, _LGBM_ScikitCustomEvalSetSplitter]] = None,
client: Optional[Client] = None,
**kwargs: Any
):
"""Docstring is inherited from the lightgbm.LGBMClassifier.__init__."""
if early_stopping:
raise NotImplementedError(
"Early Stopping is not available for the Dask interface of lightgbm "
f"(found early_stopping={early_stopping})"
)
self.client = client
super().__init__(
boosting_type=boosting_type,
Expand Down Expand Up @@ -1375,10 +1384,19 @@ def __init__(
random_state: Optional[Union[int, np.random.RandomState]] = None,
n_jobs: Optional[int] = None,
importance_type: str = 'split',
early_stopping: bool = False,
validation_fraction: Optional[float] = 0.1,
n_iter_no_change: int = 10,
validation_set_split_strategy: Optional[Union[str, _LGBM_ScikitCustomEvalSetSplitter]] = None,
client: Optional[Client] = None,
**kwargs: Any
):
"""Docstring is inherited from the lightgbm.LGBMRegressor.__init__."""
if early_stopping:
raise NotImplementedError(
"Early Stopping is not available for the Dask interface of lightgbm "
f"(found early_stopping={early_stopping})"
)
self.client = client
super().__init__(
boosting_type=boosting_type,
Expand Down Expand Up @@ -1545,10 +1563,19 @@ def __init__(
random_state: Optional[Union[int, np.random.RandomState]] = None,
n_jobs: Optional[int] = None,
importance_type: str = 'split',
early_stopping: bool = False,
validation_fraction: Optional[float] = 0.1,
n_iter_no_change: int = 10,
validation_set_split_strategy: Optional[Union[str, _LGBM_ScikitCustomEvalSetSplitter]] = None,
client: Optional[Client] = None,
**kwargs: Any
):
"""Docstring is inherited from the lightgbm.LGBMRanker.__init__."""
if early_stopping:
raise NotImplementedError(
"Early Stopping is not available for the Dask interface of lightgbm "
f"(found early_stopping={early_stopping})"
)
self.client = client
super().__init__(
boosting_type=boosting_type,
Expand Down
9 changes: 3 additions & 6 deletions tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1687,14 +1687,11 @@ def test_dask_classes_and_sklearn_equivalents_have_identical_constructors_except
assert dask_spec.kwonlyargs == sklearn_spec.kwonlyargs
assert dask_spec.kwonlydefaults == sklearn_spec.kwonlydefaults

# "client" and early stopping arguments should be the only difference, and the final arguments
# early stopping is not yet implemented for the dask interface
# Ref. https://github.com/microsoft/LightGBM/issues/3712
assert dask_spec.args[:-1] == sklearn_spec.args[:-4]
assert dask_spec.defaults[:-1] == sklearn_spec.defaults[:-4]
# "client" should be the only different, and the final argument
assert dask_spec.args[:-1] == sklearn_spec.args
assert dask_spec.defaults[:-1] == sklearn_spec.defaults
assert dask_spec.args[-1] == 'client'
assert dask_spec.defaults[-1] is None
assert sklearn_spec.args[-4:] == ['early_stopping', 'validation_fraction', 'n_iter_no_change', 'validation_set_split_strategy']


@pytest.mark.parametrize(
Expand Down

0 comments on commit 1910076

Please sign in to comment.