From 3c8762d2cebf1ce1dda9209b7f2ce6e583347325 Mon Sep 17 00:00:00 2001 From: colganwi Date: Thu, 24 Oct 2024 18:40:21 -0400 Subject: [PATCH] branch color vmax --- .gitignore | 1 + src/pycea/pl/plot_tree.py | 8 +++++++- tests/test_plot_tree.py | 2 +- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 7b6c63b..846a601 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ .DS_Store *~ buck-out/ +.ipynb_checkpoints/ # Compiled files .venv/ diff --git a/src/pycea/pl/plot_tree.py b/src/pycea/pl/plot_tree.py index b95d99f..5c3c3fa 100644 --- a/src/pycea/pl/plot_tree.py +++ b/src/pycea/pl/plot_tree.py @@ -40,6 +40,8 @@ def branches( tree: str | Sequence[str] | None = None, cmap: str | mcolors.Colormap = "viridis", palette: cycler.Cycler | mcolors.ListedColormap | Sequence[str] | Mapping[str] | None = None, + vmax: int | float | None = None, + vmin: int | float | None = None, na_color: str = "lightgrey", na_linewidth: int | float = 1, ax: Axes | None = None, @@ -107,7 +109,11 @@ def branches( if len(color_data) == 0: raise ValueError(f"Key {color!r} is not present in any edge.") if color_data.dtype.kind in ["i", "f"]: - norm = plt.Normalize(vmin=color_data.min(), vmax=color_data.max()) + if not vmin: + vmin = color_data.min() + if not vmax: + vmax = color_data.max() + norm = plt.Normalize(vmin=vmin, vmax=vmax) cmap = plt.get_cmap(cmap) colors = [cmap(norm(color_data[edge])) if edge in color_data.index else na_color for edge in edges] kwargs.update({"color": colors}) diff --git a/tests/test_plot_tree.py b/tests/test_plot_tree.py index 941da4d..9604c49 100755 --- a/tests/test_plot_tree.py +++ b/tests/test_plot_tree.py @@ -21,7 +21,7 @@ def test_polar_with_clades(tdata): def test_angled_numeric_annotations(tdata): pycea.pl.branches( - tdata, polar=False, color="length", cmap="hsv", linewidth="length", depth_key="time", angled_branches=True + tdata, polar=False, color="length", cmap="hsv", linewidth="length", depth_key="time", angled_branches=True, vmax = 2, ) pycea.pl.nodes(tdata, nodes="all", color="time", style="s", size=20) pycea.pl.nodes(tdata, nodes=["2"], tree="1", color="black", style="*", size=200)