From bb39bc995f92d3ef19dab1a18dcb2cb8f3fae817 Mon Sep 17 00:00:00 2001 From: Weston King-Leatham <71475274+WestonKing-Leatham@users.noreply.github.com> Date: Sun, 27 Jun 2021 09:20:59 -0400 Subject: [PATCH] [python] Add type hints to python-package/lightgbm/plotting.py (#4367) * Change to f-strings in test_plotting.py * Implemented type hinting in python-package/lightgbm/plotting.py * Apply suggestions from code review Co-authored-by: James Lamb * Update python-package/lightgbm/plotting.py Co-authored-by: James Lamb * Delete test_plotting.py File was mistakenly added to pull request * Revert "Delete test_plotting.py" This reverts commit df095b612af3abafcc87df4f95e8b523a49ebde5. * default argument * Apply suggestions from code review Co-authored-by: James Lamb * Removed errant bracket * Update python-package/lightgbm/plotting.py * Apply suggestions from code review Co-authored-by: Nikita Titov * Fixing tuples from ints to floats * Apply suggestions from code review Co-authored-by: Nikita Titov * Update python-package/lightgbm/plotting.py Co-authored-by: Nikita Titov Co-authored-by: James Lamb Co-authored-by: Nikita Titov --- python-package/lightgbm/plotting.py | 103 +++++++++++++++++++++------- 1 file changed, 79 insertions(+), 24 deletions(-) diff --git a/python-package/lightgbm/plotting.py b/python-package/lightgbm/plotting.py index 0ac9198d7acd..28dff0843a05 100644 --- a/python-package/lightgbm/plotting.py +++ b/python-package/lightgbm/plotting.py @@ -2,6 +2,7 @@ """Plotting library.""" from copy import deepcopy from io import BytesIO +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -10,24 +11,36 @@ from .sklearn import LGBMModel -def _check_not_tuple_of_2_elements(obj, obj_name='obj'): +def _check_not_tuple_of_2_elements(obj: Any, obj_name: str = 'obj') -> None: """Check object is not tuple or does not have 2 elements.""" if not isinstance(obj, tuple) or len(obj) != 2: raise TypeError(f"{obj_name} must be a tuple of 2 elements.") -def _float2str(value, precision=None): +def _float2str(value: float, precision: Optional[int] = None) -> str: return (f"{value:.{precision}f}" if precision is not None and not isinstance(value, str) else str(value)) -def plot_importance(booster, ax=None, height=0.2, - xlim=None, ylim=None, title='Feature importance', - xlabel='Feature importance', ylabel='Features', - importance_type='split', max_num_features=None, - ignore_zero=True, figsize=None, dpi=None, grid=True, - precision=3, **kwargs): +def plot_importance( + booster: Union[Booster, LGBMModel], + ax=None, + height: float = 0.2, + xlim: Optional[Tuple[float, float]] = None, + ylim: Optional[Tuple[float, float]] = None, + title: Optional[str] = 'Feature importance', + xlabel: Optional[str] = 'Feature importance', + ylabel: Optional[str] = 'Features', + importance_type: str = 'split', + max_num_features: Optional[int] = None, + ignore_zero: bool = True, + figsize: Optional[Tuple[float, float]] = None, + dpi: Optional[int] = None, + grid: bool = True, + precision: Optional[int] = 3, + **kwargs: Any +) -> Any: """Plot model's feature importances. Parameters @@ -138,11 +151,22 @@ def plot_importance(booster, ax=None, height=0.2, return ax -def plot_split_value_histogram(booster, feature, bins=None, ax=None, width_coef=0.8, - xlim=None, ylim=None, - title='Split value histogram for feature with @index/name@ @feature@', - xlabel='Feature split value', ylabel='Count', - figsize=None, dpi=None, grid=True, **kwargs): +def plot_split_value_histogram( + booster: Union[Booster, LGBMModel], + feature: Union[int, str], + bins: Union[int, str, None] = None, + ax=None, + width_coef: float = 0.8, + xlim: Optional[Tuple[float, float]] = None, + ylim: Optional[Tuple[float, float]] = None, + title: Optional[str] = 'Split value histogram for feature with @index/name@ @feature@', + xlabel: Optional[str] = 'Feature split value', + ylabel: Optional[str] = 'Count', + figsize: Optional[Tuple[float, float]] = None, + dpi: Optional[int] = None, + grid: bool = True, + **kwargs: Any +) -> Any: """Plot split value histogram for the specified feature of the model. Parameters @@ -244,11 +268,20 @@ def plot_split_value_histogram(booster, feature, bins=None, ax=None, width_coef= return ax -def plot_metric(booster, metric=None, dataset_names=None, - ax=None, xlim=None, ylim=None, - title='Metric during training', - xlabel='Iterations', ylabel='auto', - figsize=None, dpi=None, grid=True): +def plot_metric( + booster: Union[Dict, LGBMModel], + metric: Optional[str] = None, + dataset_names: Optional[List[str]] = None, + ax=None, + xlim: Optional[Tuple[float, float]] = None, + ylim: Optional[Tuple[float, float]] = None, + title: Optional[str] = 'Metric during training', + xlabel: Optional[str] = 'Iterations', + ylabel: Optional[str] = 'auto', + figsize: Optional[Tuple[float, float]] = None, + dpi: Optional[int] = None, + grid: bool = True +) -> Any: """Plot one metric during training. Parameters @@ -369,8 +402,15 @@ def plot_metric(booster, metric=None, dataset_names=None, return ax -def _to_graphviz(tree_info, show_info, feature_names, precision=3, - orientation='horizontal', constraints=None, **kwargs): +def _to_graphviz( + tree_info: Dict[str, Any], + show_info: List[str], + feature_names: Union[List[str], None], + precision: Optional[int] = 3, + orientation: str = 'horizontal', + constraints: Optional[List[int]] = None, + **kwargs: Any +) -> Any: """Convert specified tree to graphviz instance. See: @@ -465,8 +505,14 @@ def add(root, total_count, parent=None, decision=None): return graph -def create_tree_digraph(booster, tree_index=0, show_info=None, precision=3, - orientation='horizontal', **kwargs): +def create_tree_digraph( + booster: Union[Booster, LGBMModel], + tree_index: int = 0, + show_info: Optional[List[str]] = None, + precision: Optional[int] = 3, + orientation: str = 'horizontal', + **kwargs: Any +) -> Any: """Create a digraph representation of specified tree. Each node in the graph represents a node in the tree. @@ -542,8 +588,17 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=3, return graph -def plot_tree(booster, ax=None, tree_index=0, figsize=None, dpi=None, - show_info=None, precision=3, orientation='horizontal', **kwargs): +def plot_tree( + booster: Union[Booster, LGBMModel], + ax=None, + tree_index: int = 0, + figsize: Optional[Tuple[float, float]] = None, + dpi: Optional[int] = None, + show_info: Optional[List[str]] = None, + precision: Optional[int] = 3, + orientation: str = 'horizontal', + **kwargs: Any +) -> Any: """Plot specified tree. Each node in the graph represents a node in the tree.