From 44014015536574743b5d0eb7b4791258268090b7 Mon Sep 17 00:00:00 2001
From: James Lamb
Date: Wed, 5 Jun 2024 07:54:16 -0500
Subject: [PATCH] [python-package] add a few type hints in LGBMModel.fit()
 (#6470)

---
 python-package/lightgbm/sklearn.py | 59 ++++++++++++++++++++++--------
 1 file changed, 44 insertions(+), 15 deletions(-)

diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py
index cb577c18c265..8fb998984720
--- a/python-package/lightgbm/sklearn.py
+++ b/python-package/lightgbm/sklearn.py
@@ -454,6 +454,30 @@ def __call__(
         """
 
 
+def _extract_evaluation_meta_data(
+    *,
+    collection: Optional[Union[Dict[Any, Any], List[Any]]],
+    name: str,
+    i: int,
+) -> Optional[Any]:
+    """Try to extract the ith element of one of the ``eval_*`` inputs."""
+    if collection is None:
+        return None
+    elif isinstance(collection, list):
+        # It's possible, for example, to pass 3 eval sets through `eval_set`,
+        # but only 1 init_score through `eval_init_score`.
+        #
+        # This if-else accounts for that possibility.
+        if len(collection) > i:
+            return collection[i]
+        else:
+            return None
+    elif isinstance(collection, dict):
+        return collection.get(i, None)
+    else:
+        raise TypeError(f"{name} should be dict or list")
+
+
 class LGBMModel(_LGBMModelBase):
     """Implementation of the scikit-learn API for LightGBM."""
 
@@ -869,17 +893,6 @@ def fit(
 
         valid_sets: List[Dataset] = []
         if eval_set is not None:
-
-            def _get_meta_data(collection, name, i):
-                if collection is None:
-                    return None
-                elif isinstance(collection, list):
-                    return collection[i] if len(collection) > i else None
-                elif isinstance(collection, dict):
-                    return collection.get(i, None)
-                else:
-                    raise TypeError(f"{name} should be dict or list")
-
             if isinstance(eval_set, tuple):
                 eval_set = [eval_set]
             for i, valid_data in enumerate(eval_set):
@@ -887,8 +900,16 @@ def _get_meta_data(collection, name, i):
                 if valid_data[0] is X and valid_data[1] is y:
                     valid_set = train_set
                 else:
-                    valid_weight = _get_meta_data(eval_sample_weight, "eval_sample_weight", i)
-                    valid_class_weight = _get_meta_data(eval_class_weight, "eval_class_weight", i)
+                    valid_weight = _extract_evaluation_meta_data(
+                        collection=eval_sample_weight,
+                        name="eval_sample_weight",
+                        i=i,
+                    )
+                    valid_class_weight = _extract_evaluation_meta_data(
+                        collection=eval_class_weight,
+                        name="eval_class_weight",
+                        i=i,
+                    )
                     if valid_class_weight is not None:
                         if isinstance(valid_class_weight, dict) and self._class_map is not None:
                             valid_class_weight = {self._class_map[k]: v for k, v in valid_class_weight.items()}
@@ -897,8 +918,16 @@ def _get_meta_data(collection, name, i):
                         valid_weight = valid_class_sample_weight
                     else:
                         valid_weight = np.multiply(valid_weight, valid_class_sample_weight)
-                    valid_init_score = _get_meta_data(eval_init_score, "eval_init_score", i)
-                    valid_group = _get_meta_data(eval_group, "eval_group", i)
+                    valid_init_score = _extract_evaluation_meta_data(
+                        collection=eval_init_score,
+                        name="eval_init_score",
+                        i=i,
+                    )
+                    valid_group = _extract_evaluation_meta_data(
+                        collection=eval_group,
+                        name="eval_group",
+                        i=i,
+                    )
                     valid_set = Dataset(
                         data=valid_data[0],
                         label=valid_data[1],
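
A minimal usage sketch of the extracted helper follows, assuming the patch above has
been applied. `_extract_evaluation_meta_data` is a private helper in `lightgbm.sklearn`
(not public API), so importing it directly is purely illustrative, and the metadata
values below are made up for the example.

    # Illustrative only: _extract_evaluation_meta_data is a private helper added by
    # the patch above; importing it directly is not part of LightGBM's public API.
    from lightgbm.sklearn import _extract_evaluation_meta_data

    # Three eval sets, but sample weights supplied only for the first one (list form)
    # and init scores supplied only for the third one (dict keyed by eval-set index).
    eval_sample_weight = [[1.0, 2.0, 0.5]]
    eval_init_score = {2: [0.1, 0.2, 0.3]}

    for i in range(3):
        weight = _extract_evaluation_meta_data(
            collection=eval_sample_weight, name="eval_sample_weight", i=i
        )
        init_score = _extract_evaluation_meta_data(
            collection=eval_init_score, name="eval_init_score", i=i
        )
        print(i, weight, init_score)
    # 0 [1.0, 2.0, 0.5] None   (list is long enough for this index)
    # 1 None None              (list too short, dict key missing: both fall back to None)
    # 2 None [0.1, 0.2, 0.3]   (dict lookup by eval-set index succeeds)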