Skip to content

Commit

Permalink
Merge branch '3313-enable-auto-early-stopping' of https://github.com/…
Browse files Browse the repository at this point in the history
…ClaudioSalvatoreArcidiacono/LightGBM into 3313-enable-auto-early-stopping
  • Loading branch information
ClaudioSalvatoreArcidiacono committed Jun 5, 2023
2 parents 2ca8cc1 + 1910076 commit f2ccd52
Showing 1 changed file with 38 additions and 23 deletions.
61 changes: 38 additions & 23 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,16 @@
]
_LGBM_ScikitCustomEvalSetSplitter = Union[
Callable[
[np.ndarray, np.ndarray],
Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
[_LGBM_ScikitMatrixLike, _LGBM_LabelType],
Tuple[_LGBM_ScikitMatrixLike, _LGBM_ScikitMatrixLike, _LGBM_LabelType, _LGBM_LabelType]
],
Callable[
[np.ndarray, np.ndarray, np.ndarray],
Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray]]
[_LGBM_ScikitMatrixLike, _LGBM_LabelType, Optional[np.ndarray]],
Tuple[_LGBM_ScikitMatrixLike, _LGBM_ScikitMatrixLike, _LGBM_LabelType, _LGBM_LabelType, Optional[np.ndarray], Optional[np.ndarray]]
],
Callable[
[np.ndarray, np.ndarray, np.ndarray],
Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]
[_LGBM_ScikitMatrixLike, _LGBM_LabelType, Optional[np.ndarray], _LGBM_GroupType],
Tuple[_LGBM_ScikitMatrixLike, _LGBM_ScikitMatrixLike, _LGBM_LabelType, _LGBM_LabelType, Optional[np.ndarray], Optional[np.ndarray], _LGBM_GroupType, _LGBM_GroupType]
],
]
_LGBM_ScikitValidSet = Tuple[_LGBM_ScikitMatrixLike, _LGBM_LabelType]
Expand Down Expand Up @@ -256,17 +256,17 @@ def __call__(


def _train_test_split(
X,
y,
X: _LGBM_ScikitMatrixLike,
y: _LGBM_LabelType,
weight,
test_size: float,
random_state: Optional[Union[int, np.random.RandomState]],
stratified: bool,
) -> Tuple[
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
_LGBM_ScikitMatrixLike,
_LGBM_ScikitMatrixLike,
_LGBM_LabelType,
_LGBM_LabelType,
Optional[np.ndarray],
Optional[np.ndarray],
]:
Expand Down Expand Up @@ -319,7 +319,22 @@ def _train_test_split(
return X_train, X_val, y_train, y_val, None, None


def _train_test_group_split(X, y, weight, group, n_splits: int):
def _train_test_group_split(
X: _LGBM_ScikitMatrixLike,
y: _LGBM_LabelType,
weight,
group: _LGBM_GroupType,
n_splits: int
) -> Tuple[
_LGBM_ScikitMatrixLike,
_LGBM_ScikitMatrixLike,
_LGBM_LabelType,
_LGBM_LabelType,
Optional[np.ndarray],
Optional[np.ndarray],
_LGBM_GroupType,
_LGBM_GroupType,
]:
"""Split X, y, weights and group into train and test subsets.
Parameters
Expand Down Expand Up @@ -390,20 +405,20 @@ def _train_test_group_split(X, y, weight, group, n_splits: int):


def _train_test_split_custom_splitter(
custom_splitter,
X,
y,
custom_splitter: _LGBM_ScikitCustomEvalSetSplitter,
X: _LGBM_ScikitMatrixLike,
y: _LGBM_LabelType,
weight,
group
group: Optional[_LGBM_GroupType]
) -> Tuple[
np.ndarray,
np.ndarray,
np.ndarray,
np.ndarray,
Optional[np.ndarray],
Optional[np.ndarray],
_LGBM_ScikitMatrixLike,
_LGBM_ScikitMatrixLike,
_LGBM_LabelType,
_LGBM_LabelType,
Optional[np.ndarray],
Optional[np.ndarray],
Optional[_LGBM_GroupType],
Optional[_LGBM_GroupType],
]:
"""Call passed custom_splitter with appropriate arguments.
Expand Down

0 comments on commit f2ccd52

Please sign in to comment.