Skip to content

Commit

Permalink
[python] Add type hints to python-package/lightgbm/plotting.py (#4367)
Browse files Browse the repository at this point in the history
* Change to f-strings in test_plotting.py

* Implemented type hinting in python-package/lightgbm/plotting.py

* Apply suggestions from code review

Co-authored-by: James Lamb <[email protected]>

* Update python-package/lightgbm/plotting.py

Co-authored-by: James Lamb <[email protected]>

* Delete test_plotting.py

File was mistakenly added to pull request

* Revert "Delete test_plotting.py"

This reverts commit df095b6.

* default argument

* Apply suggestions from code review

Co-authored-by: James Lamb <[email protected]>

* Removed errant bracket

* Update python-package/lightgbm/plotting.py

* Apply suggestions from code review

Co-authored-by: Nikita Titov <[email protected]>

* Fixing tuples from ints to floats

* Apply suggestions from code review

Co-authored-by: Nikita Titov <[email protected]>

* Update python-package/lightgbm/plotting.py

Co-authored-by: Nikita Titov <[email protected]>

Co-authored-by: James Lamb <[email protected]>
Co-authored-by: Nikita Titov <[email protected]>
  • Loading branch information
3 people authored Jun 27, 2021
1 parent 45ac271 commit bb39bc9
Showing 1 changed file with 79 additions and 24 deletions.
103 changes: 79 additions & 24 deletions python-package/lightgbm/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""Plotting library."""
from copy import deepcopy
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

Expand All @@ -10,24 +11,36 @@
from .sklearn import LGBMModel


def _check_not_tuple_of_2_elements(obj, obj_name='obj'):
def _check_not_tuple_of_2_elements(obj: Any, obj_name: str = 'obj') -> None:
"""Check object is not tuple or does not have 2 elements."""
if not isinstance(obj, tuple) or len(obj) != 2:
raise TypeError(f"{obj_name} must be a tuple of 2 elements.")


def _float2str(value, precision=None):
def _float2str(value: float, precision: Optional[int] = None) -> str:
return (f"{value:.{precision}f}"
if precision is not None and not isinstance(value, str)
else str(value))


def plot_importance(booster, ax=None, height=0.2,
xlim=None, ylim=None, title='Feature importance',
xlabel='Feature importance', ylabel='Features',
importance_type='split', max_num_features=None,
ignore_zero=True, figsize=None, dpi=None, grid=True,
precision=3, **kwargs):
def plot_importance(
booster: Union[Booster, LGBMModel],
ax=None,
height: float = 0.2,
xlim: Optional[Tuple[float, float]] = None,
ylim: Optional[Tuple[float, float]] = None,
title: Optional[str] = 'Feature importance',
xlabel: Optional[str] = 'Feature importance',
ylabel: Optional[str] = 'Features',
importance_type: str = 'split',
max_num_features: Optional[int] = None,
ignore_zero: bool = True,
figsize: Optional[Tuple[float, float]] = None,
dpi: Optional[int] = None,
grid: bool = True,
precision: Optional[int] = 3,
**kwargs: Any
) -> Any:
"""Plot model's feature importances.
Parameters
Expand Down Expand Up @@ -138,11 +151,22 @@ def plot_importance(booster, ax=None, height=0.2,
return ax


def plot_split_value_histogram(booster, feature, bins=None, ax=None, width_coef=0.8,
xlim=None, ylim=None,
title='Split value histogram for feature with @index/name@ @feature@',
xlabel='Feature split value', ylabel='Count',
figsize=None, dpi=None, grid=True, **kwargs):
def plot_split_value_histogram(
booster: Union[Booster, LGBMModel],
feature: Union[int, str],
bins: Union[int, str, None] = None,
ax=None,
width_coef: float = 0.8,
xlim: Optional[Tuple[float, float]] = None,
ylim: Optional[Tuple[float, float]] = None,
title: Optional[str] = 'Split value histogram for feature with @index/name@ @feature@',
xlabel: Optional[str] = 'Feature split value',
ylabel: Optional[str] = 'Count',
figsize: Optional[Tuple[float, float]] = None,
dpi: Optional[int] = None,
grid: bool = True,
**kwargs: Any
) -> Any:
"""Plot split value histogram for the specified feature of the model.
Parameters
Expand Down Expand Up @@ -244,11 +268,20 @@ def plot_split_value_histogram(booster, feature, bins=None, ax=None, width_coef=
return ax


def plot_metric(booster, metric=None, dataset_names=None,
ax=None, xlim=None, ylim=None,
title='Metric during training',
xlabel='Iterations', ylabel='auto',
figsize=None, dpi=None, grid=True):
def plot_metric(
booster: Union[Dict, LGBMModel],
metric: Optional[str] = None,
dataset_names: Optional[List[str]] = None,
ax=None,
xlim: Optional[Tuple[float, float]] = None,
ylim: Optional[Tuple[float, float]] = None,
title: Optional[str] = 'Metric during training',
xlabel: Optional[str] = 'Iterations',
ylabel: Optional[str] = 'auto',
figsize: Optional[Tuple[float, float]] = None,
dpi: Optional[int] = None,
grid: bool = True
) -> Any:
"""Plot one metric during training.
Parameters
Expand Down Expand Up @@ -369,8 +402,15 @@ def plot_metric(booster, metric=None, dataset_names=None,
return ax


def _to_graphviz(tree_info, show_info, feature_names, precision=3,
orientation='horizontal', constraints=None, **kwargs):
def _to_graphviz(
tree_info: Dict[str, Any],
show_info: List[str],
feature_names: Union[List[str], None],
precision: Optional[int] = 3,
orientation: str = 'horizontal',
constraints: Optional[List[int]] = None,
**kwargs: Any
) -> Any:
"""Convert specified tree to graphviz instance.
See:
Expand Down Expand Up @@ -465,8 +505,14 @@ def add(root, total_count, parent=None, decision=None):
return graph


def create_tree_digraph(booster, tree_index=0, show_info=None, precision=3,
orientation='horizontal', **kwargs):
def create_tree_digraph(
booster: Union[Booster, LGBMModel],
tree_index: int = 0,
show_info: Optional[List[str]] = None,
precision: Optional[int] = 3,
orientation: str = 'horizontal',
**kwargs: Any
) -> Any:
"""Create a digraph representation of specified tree.
Each node in the graph represents a node in the tree.
Expand Down Expand Up @@ -542,8 +588,17 @@ def create_tree_digraph(booster, tree_index=0, show_info=None, precision=3,
return graph


def plot_tree(booster, ax=None, tree_index=0, figsize=None, dpi=None,
show_info=None, precision=3, orientation='horizontal', **kwargs):
def plot_tree(
booster: Union[Booster, LGBMModel],
ax=None,
tree_index: int = 0,
figsize: Optional[Tuple[float, float]] = None,
dpi: Optional[int] = None,
show_info: Optional[List[str]] = None,
precision: Optional[int] = 3,
orientation: str = 'horizontal',
**kwargs: Any
) -> Any:
"""Plot specified tree.
Each node in the graph represents a node in the tree.
Expand Down

0 comments on commit bb39bc9

Please sign in to comment.