diff --git a/src/plottools/remove.py b/src/plottools/remove.py index 1513189..a5110a0 100644 --- a/src/plottools/remove.py +++ b/src/plottools/remove.py @@ -7,6 +7,8 @@ - `remove_lines()`: remove all line artists. - `remove_markers()`: remove all line artists with markers that are not connected by lines. - `remove_style()`: remove all line artists that match a style. +- `remove_texts()`: remove text artists. +- `remove_arrows()`: remove arrow, i.e. annotation artists. ## Install/uninstall remove functions @@ -91,6 +93,64 @@ def remove_style(ax, **style): line.remove() +def remove_texts(ax, *indices): + """Remove text artists. + + Parameters + ---------- + ax: matplotlib axes + Axes from which texts should be removed. + indices: list of int or str + If specified, remove only the text elements at the specified indices + or with the specified text. + """ + texts = [] + for i in indices: + if not isinstance(i, int): + texts.append(i) + remove_text = [] + count = 0 + for a in ax.get_children(): + if type(a) is mpl.text.Text: + if len(indices) == 0 or count in indices or a.get_text() in texts: + remove_text.append(a) + count += 1 + for text in remove_text: + try: + text.remove() + except NotImplementedError: + text.set_visible(False) + + +def remove_arrows(ax, *indices): + """Remove arrow, i.e. annotation artists. + + Parameters + ---------- + ax: matplotlib axes + Axes from which arrows should be removed. + indices: list of int or str + If specified, remove only the annotation elements at the specified indices + or with the specified text. + """ + texts = [] + for i in indices: + if not isinstance(i, int): + texts.append(i) + remove_text = [] + count = 0 + for a in ax.get_children(): + if isinstance(a, mpl.text.Annotation): + if len(indices) == 0 or count in indices or a.get_text() in texts: + remove_text.append(a) + count += 1 + for text in remove_text: + try: + text.remove() + except NotImplementedError: + text.set_visible(False) + + def install_remove(): """ Install remove functions on matplotlib axes. ``` @@ -107,6 +167,10 @@ def install_remove(): mpl.axes.Axes.remove_markers = remove_markers if not hasattr(mpl.axes.Axes, 'remove_style'): mpl.axes.Axes.remove_style = remove_style + if not hasattr(mpl.axes.Axes, 'remove_texts'): + mpl.axes.Axes.remove_texts = remove_texts + if not hasattr(mpl.axes.Axes, 'remove_arrows'): + mpl.axes.Axes.remove_arrows = remove_arrows def uninstall_remove(): @@ -124,6 +188,10 @@ def uninstall_remove(): delattr(mpl.axes.Axes, 'remove_markers') if hasattr(mpl.axes.Axes, 'remove_style'): delattr(mpl.axes.Axes, 'remove_style') + if hasattr(mpl.axes.Axes, 'remove_texts'): + delattr(mpl.axes.Axes, 'remove_texts') + if hasattr(mpl.axes.Axes, 'remove_arrows'): + delattr(mpl.axes.Axes, 'remove_arrows') install_remove()