diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index fc77ff7a9d4a..0d278b21fc49 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -4,6 +4,7 @@ import copy from operator import attrgetter from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -11,14 +12,34 @@ from .basic import Booster, Dataset, LightGBMError, _ConfigAliases, _InnerPredictor, _log_warning from .compat import SKLEARN_INSTALLED, _LGBMGroupKFold, _LGBMStratifiedKFold - -def train(params, train_set, num_boost_round=100, - valid_sets=None, valid_names=None, - fobj=None, feval=None, init_model=None, - feature_name='auto', categorical_feature='auto', - early_stopping_rounds=None, evals_result=None, - verbose_eval=True, learning_rates=None, - keep_training_booster=False, callbacks=None): +_LGBM_CustomObjectiveFunction = Callable[ + [Union[List, np.ndarray], Dataset], + Tuple[Union[List, np.ndarray], Union[List, np.ndarray]] +] +_LGBM_CustomMetricFunction = Callable[ + [Union[List, np.ndarray], Dataset], + Tuple[str, float, bool] +] + + +def train( + params: Dict[str, Any], + train_set: Dataset, + num_boost_round: int = 100, + valid_sets: Optional[List[Dataset]] = None, + valid_names: Optional[List[str]] = None, + fobj: Optional[_LGBM_CustomObjectiveFunction] = None, + feval: Optional[Union[_LGBM_CustomMetricFunction, List[_LGBM_CustomMetricFunction]]] = None, + init_model: Optional[Union[str, Path, Booster]] = None, + feature_name: Union[List[str], str] = 'auto', + categorical_feature: Union[List[str], List[int], str] = 'auto', + early_stopping_rounds: Optional[int] = None, + evals_result: Optional[Dict[str, Any]] = None, + verbose_eval: Union[bool, int] = True, + learning_rates: Optional[Union[List[float], Callable[[int], float]]] = None, + keep_training_booster: bool = False, + callbacks: Optional[List[Callable]] = None +) -> Booster: """Perform the training with given parameters. Parameters @@ -101,7 +122,9 @@ def train(params, train_set, num_boost_round=100, The index of iteration that has the best performance will be saved in the ``best_iteration`` field if early stopping logic is enabled by setting ``early_stopping_rounds``. evals_result: dict or None, optional (default=None) - This dictionary used to store all evaluation results of all the items in ``valid_sets``. + Dictionary used to store all evaluation results of all the items in ``valid_sets``. + This should be initialized outside of your call to ``train()`` and should be empty. + Any initial contents of the dictionary will be deleted by ``train()``. .. rubric:: Example