Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Leaf data patch #13

Merged
merged 2 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,6 @@ __pycache__/

# Node modules
node_modules/

# Environment
environment.yml
15 changes: 15 additions & 0 deletions src/pycea/pl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,18 @@ def _series_to_rgb_array(series, colors, vmin=None, vmax=None, na_color="#808080
else:
raise ValueError("cmap must be either a dictionary or a ListedColormap.")
return rgb_array


def _check_tree_overlap(tdata, tree_keys):
"""Check single tree is requested when allow_overlap is True"""
if tree_keys is None:
tree_keys = tdata.obst.keys()
if tdata.allow_overlap and len(tree_keys) > 1:
raise ValueError("Must specify a tree when tdata.allow_overlap is True.")
elif isinstance(tree_keys, str):
pass
elif isinstance(tree_keys, Sequence):
if tdata.allow_overlap:
raise ValueError("Cannot request multiple trees when tdata.allow_overlap is True.")
else:
raise ValueError("Tree keys must be a string, list of strings, or None.")
6 changes: 4 additions & 2 deletions src/pycea/pl/plot_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from ._docs import _doc_params, doc_common_plot_args
from ._utils import (
_check_tree_overlap,
_get_categorical_colors,
_get_categorical_markers,
_series_to_rgb_array,
Expand All @@ -34,7 +35,7 @@ def branches(
extend_branches: bool = False,
angled_branches: bool = False,
color: str = "black",
linewidth: int | float | str = .5,
linewidth: int | float | str = 0.5,
depth_key: str = "depth",
tree: str | Sequence[str] | None = None,
cmap: str | mcolors.Colormap = "viridis",
Expand Down Expand Up @@ -79,6 +80,7 @@ def branches(
""" # noqa: D205
# Setup
tree_keys = tree
_check_tree_overlap(tdata, tree_keys)
if ax is None:
fig, ax = plt.subplots(subplot_kw={"projection": "polar"} if polar else None)
elif (ax.name == "polar" and not polar) or (ax.name != "polar" and polar):
Expand Down Expand Up @@ -498,7 +500,7 @@ def tree(
angled_branches: bool = False,
depth_key: str = "depth",
branch_color: str = "black",
branch_linewidth: int | float | str = .5,
branch_linewidth: int | float | str = 0.5,
node_color: str = "black",
node_style: str = "o",
node_size: int | float = 10,
Expand Down
18 changes: 17 additions & 1 deletion src/pycea/pp/setup_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,21 @@ def _add_depth(tree, depth_key):
nx.set_node_attributes(tree, depths, depth_key)


def _check_tree_overlap(tdata, tree_keys):
"""Check single tree is requested when allow_overlap is True"""
if tree_keys is None:
tree_keys = tdata.obst.keys()
if tdata.allow_overlap and len(tree_keys) > 1:
raise ValueError("Must specify a tree when tdata.allow_overlap is True.")
elif isinstance(tree_keys, str):
pass
elif isinstance(tree_keys, Sequence):
if tdata.allow_overlap:
raise ValueError("Cannot request multiple trees when tdata.allow_overlap is True.")
else:
raise ValueError("Tree keys must be a string, list of strings, or None.")


def add_depth(
tdata: td.TreeData, key_added: str = "depth", tree: str | Sequence[str] | None = None, copy: bool = False
) -> None | pd.DataFrame:
Expand Down Expand Up @@ -44,9 +59,10 @@ def add_depth(
- Distance from the root node.
"""
tree_keys = tree
_check_tree_overlap(tdata, tree_keys)
trees = get_trees(tdata, tree_keys)
for _, tree in trees.items():
_add_depth(tree, key_added)
tdata.obs[key_added] = get_keyed_leaf_data(tdata, key_added)[key_added]
tdata.obs[key_added] = get_keyed_leaf_data(tdata, key_added, tree_keys)[key_added]
if copy:
return get_keyed_node_data(tdata, key_added, tree_keys)
15 changes: 15 additions & 0 deletions src/pycea/tl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,18 @@ def _assert_param_xor(params):
if n_set == 0:
raise ValueError(f"At least one of {param_text} must be set.")
return None


def _check_tree_overlap(tdata, tree_keys):
"""Check single tree is requested when allow_overlap is True"""
if tree_keys is None:
tree_keys = tdata.obst.keys()
if tdata.allow_overlap and len(tree_keys) > 1:
raise ValueError("Must specify a tree when tdata.allow_overlap is True.")
elif isinstance(tree_keys, str):
pass
elif isinstance(tree_keys, Sequence):
if tdata.allow_overlap:
raise ValueError("Cannot request multiple trees when tdata.allow_overlap is True.")
else:
raise ValueError("Tree keys must be a string, list of strings, or None.")
3 changes: 3 additions & 0 deletions src/pycea/tl/ancestral_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from pycea.utils import get_keyed_node_data, get_keyed_obs_data, get_root, get_trees

from ._utils import _check_tree_overlap


def _most_common(arr):
"""Finds the most common element in a list."""
Expand Down Expand Up @@ -256,6 +258,7 @@ def ancestral_states(
if len(keys) != len(keys_added):
raise ValueError("Length of keys must match length of keys_added.")
tree_keys = tree
_check_tree_overlap(tdata, tree_keys)
trees = get_trees(tdata, tree_keys)
for _, tree in trees.items():
data, is_array = get_keyed_obs_data(tdata, keys)
Expand Down
3 changes: 3 additions & 0 deletions src/pycea/tl/clades.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from pycea.utils import check_tree_has_key, get_keyed_leaf_data, get_root, get_trees

from ._utils import _check_tree_overlap


def _nodes_at_depth(tree, parent, nodes, depth, depth_key):
"""Recursively finds nodes at a given depth."""
Expand Down Expand Up @@ -100,6 +102,7 @@ def clades(
"""
# Setup
tree_keys = tree
_check_tree_overlap(tdata, tree_keys)
trees = get_trees(tdata, tree_keys)
if clades and len(trees) > 1:
raise ValueError("Multiple trees are present. Must specify a single tree if clades are given.")
Expand Down
2 changes: 2 additions & 0 deletions src/pycea/tl/tree_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ._metrics import _get_tree_metric, _TreeMetric
from ._utils import (
_check_previous_params,
_check_tree_overlap,
_csr_data_mask,
_format_keys,
_set_distances_and_connectivities,
Expand Down Expand Up @@ -164,6 +165,7 @@ def tree_distance(
key_added = key_added or "tree"
connect_key = _format_keys(connect_key, "connectivities")
tree_keys = tree
_check_tree_overlap(tdata, tree_keys)
trees = get_trees(tdata, tree_keys)
metric_fn = _get_tree_metric(metric)
single_obs = False
Expand Down
2 changes: 2 additions & 0 deletions src/pycea/tl/tree_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ._utils import (
_assert_param_xor,
_check_previous_params,
_check_tree_overlap,
_csr_data_mask,
_set_distances_and_connectivities,
_set_random_state,
Expand Down Expand Up @@ -142,6 +143,7 @@ def tree_neighbors(
_assert_param_xor({"n_neighbors": n_neighbors, "max_dist": max_dist})
_ = _get_tree_metric(metric)
tree_keys = tree
_check_tree_overlap(tdata, tree_keys)
if update:
_check_previous_params(tdata, {"metric": metric}, key_added, ["neighbors", "distances"])
# Neighbors of a single leaf
Expand Down
14 changes: 8 additions & 6 deletions src/pycea/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ def check_tree_has_key(tree: nx.DiGraph, key: str):


def get_keyed_edge_data(
tdata: td.TreeData, keys: str | Sequence[str], tree_keys: str | Sequence[str] = None
tdata: td.TreeData, keys: str | Sequence[str], tree: str | Sequence[str] = None
) -> pd.DataFrame:
"""Gets edge data for a given key from a tree or set of trees."""
tree_keys = tree
if isinstance(tree_keys, str):
tree_keys = [tree_keys]
if isinstance(keys, str):
Expand All @@ -65,9 +66,10 @@ def get_keyed_edge_data(


def get_keyed_node_data(
tdata: td.TreeData, keys: str | Sequence[str], tree_keys: str | Sequence[str] = None
tdata: td.TreeData, keys: str | Sequence[str], tree: str | Sequence[str] = None
) -> pd.DataFrame:
"""Gets node data for a given key from a tree or set of trees."""
tree_keys = tree
if isinstance(tree_keys, str):
tree_keys = [tree_keys]
if isinstance(keys, str):
Expand All @@ -86,9 +88,10 @@ def get_keyed_node_data(


def get_keyed_leaf_data(
tdata: td.TreeData, keys: str | Sequence[str], tree_keys: str | Sequence[str] = None
tdata: td.TreeData, keys: str | Sequence[str], tree: str | Sequence[str] = None
) -> pd.DataFrame:
"""Gets node data for a given key from a tree or set of trees."""
tree_keys = tree
if isinstance(tree_keys, str):
tree_keys = [tree_keys]
if isinstance(keys, str):
Expand Down Expand Up @@ -156,16 +159,15 @@ def get_keyed_obsm_data(tdata: td.TreeData, key: str) -> sp.sparse.csr_matrix:
return X


def get_trees(tdata: td.TreeData, tree_keys: str | Sequence[str] | None) -> Mapping[str, nx.DiGraph]:
def get_trees(tdata: td.TreeData, tree: str | Sequence[str] | None) -> Mapping[str, nx.DiGraph]:
"""Gets tree data for a given key from a tree."""
trees = {}
tree_keys = tree
if tree_keys is None:
tree_keys = tdata.obst.keys()
elif isinstance(tree_keys, str):
tree_keys = [tree_keys]
elif isinstance(tree_keys, Sequence):
if tdata.allow_overlap:
raise ValueError("Cannot request multiple trees when tdata.allow_overlap is True.")
tree_keys = list(tree_keys)
else:
raise ValueError("Tree keys must be a string, list of strings, or None.")
Expand Down
18 changes: 18 additions & 0 deletions tests/test_setup_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ def tdata():
yield tdata


@pytest.fixture
def tdata_with_overlap():
    """Yield a TreeData holding two trees with identical topology and ``allow_overlap=True``.

    Each key is given its own ``DiGraph`` instance, so node attributes written to
    one tree cannot show up under the other key merely because both keys refer to
    the same graph object.
    """
    edges = [("root", "A"), ("root", "B"), ("B", "C"), ("B", "D")]
    tdata = td.TreeData(
        obs=pd.DataFrame(index=["A", "C", "D"]),
        obst={"tree1": nx.DiGraph(edges), "tree2": nx.DiGraph(edges)},
        allow_overlap=True,
    )
    yield tdata


def test_add_depth(tdata):
depths = add_depth(tdata, key_added="depth", copy=True)
assert depths.loc[("tree1", "root"), "depth"] == 0
Expand All @@ -25,5 +34,14 @@ def test_add_depth(tdata):
assert tdata.obs.loc["C", "depth"] == 2


def test_add_depth_overlap(tdata_with_overlap):
    """add_depth needs an explicit tree when trees overlap, and then works per tree."""
    # Without ``tree`` every tree would be used, which the overlap check rejects.
    # ``match`` pins the specific error so an unrelated ValueError cannot pass the test.
    with pytest.raises(ValueError, match="Must specify a tree"):
        add_depth(tdata_with_overlap, key_added="depth", copy=True)
    # Naming a single tree is accepted; the returned frame is indexed by (tree, node).
    depths = add_depth(tdata_with_overlap, key_added="depth", tree="tree1", copy=True)
    assert depths.loc[("tree1", "C"), "depth"] == 2
    depths = add_depth(tdata_with_overlap, key_added="depth", tree="tree2", copy=True)
    assert depths.loc[("tree2", "C"), "depth"] == 2


if __name__ == "__main__":
pytest.main(["-v", __file__])
16 changes: 11 additions & 5 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ def tree():
def tdata(tree):
tdata = td.TreeData(
obs=pd.DataFrame({"value": ["1", "2"]}, index=["D", "E"]),
obst={"tree": tree},
obst={"tree": tree, "tree2": tree},
obsm={"spatial": pd.DataFrame([[0, 0], [1, 1]], index=["D", "E"])},
allow_overlap=True,
)
yield tdata

Expand Down Expand Up @@ -64,22 +65,27 @@ def test_get_subtree_leaves(tree):


def test_get_keyed_edge_data(tdata):
    """Edge data can be fetched for one named tree or, with no tree given, for all trees."""
    # Single tree selected by key.
    data = get_keyed_edge_data(tdata, ["weight", "color"], tree="tree")
    assert data.columns.tolist() == ["weight", "color"]
    assert data.index.names == ["tree", "edge"]
    assert data["weight"].to_list() == [5, 3, 4]
    # No tree given: rows for both trees in the fixture are returned.
    data = get_keyed_edge_data(tdata, ["weight", "color"])
    assert data.shape[0] == 6
    assert data.index.get_level_values("tree").unique().tolist() == ["tree", "tree2"]


def test_get_keyed_node_data(tdata):
    """Node data can be fetched for one named tree or, with no tree given, for all trees."""
    # Single tree selected by key.
    data = get_keyed_node_data(tdata, ["value", "color"], tree="tree")
    assert data.columns.tolist() == ["value", "color"]
    assert data.index.names == ["tree", "node"]
    assert data["value"].to_list() == [1, 2, 2, 4, 4]
    # No tree given: rows for both trees in the fixture are returned.
    data = get_keyed_node_data(tdata, ["value", "color"])
    assert data.shape[0] == 10
    assert data.index.get_level_values("tree").unique().tolist() == ["tree", "tree2"]


def test_get_keyed_leaf_data(tdata):
data = get_keyed_leaf_data(tdata, ["value", "color"])
print(data)
data = get_keyed_leaf_data(tdata, ["value", "color"], tree="tree")
assert data.columns.tolist() == ["value", "color"]
assert data["value"].tolist() == [4, 4]
assert data["color"].tolist() == ["blue", "blue"]
Expand Down
Loading