From 108d08a3f9a09035050e608e852e53d43c232f52 Mon Sep 17 00:00:00 2001 From: colganwi Date: Mon, 19 Aug 2024 12:54:30 -0400 Subject: [PATCH] added neighbor distance --- README.md | 2 +- docs/api.md | 1 + src/pycea/tl/__init__.py | 1 + src/pycea/tl/neighbor_distance.py | 100 ++++++++++++++++++++++++++++++ src/pycea/tl/tree_neighbors.py | 2 +- tests/test_neighbor_distance.py | 68 ++++++++++++++++++++ 6 files changed, 172 insertions(+), 2 deletions(-) create mode 100755 src/pycea/tl/neighbor_distance.py create mode 100755 tests/test_neighbor_distance.py diff --git a/README.md b/README.md index bc18613..7b7f0e0 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ To learn more about pycea, please refer to the [documentation][link-docs] or the ## Installation -You need to have Python 3.9 or newer installed on your system. If you don't have +You need to have Python 3.10 or newer installed on your system. If you don't have Python installed, we recommend installing [Mambaforge](https://github.com/conda-forge/miniforge#mambaforge). There are several alternative options to install pycea: diff --git a/docs/api.md b/docs/api.md index ee63842..82c181f 100644 --- a/docs/api.md +++ b/docs/api.md @@ -25,6 +25,7 @@ tl.clades tl.compare_distance tl.distance + tl.neighbor_distance tl.sort tl.tree_distance tl.tree_neighbors diff --git a/src/pycea/tl/__init__.py b/src/pycea/tl/__init__.py index 68ea342..7c338f2 100644 --- a/src/pycea/tl/__init__.py +++ b/src/pycea/tl/__init__.py @@ -1,6 +1,7 @@ from .ancestral_states import ancestral_states from .clades import clades from .distance import compare_distance, distance +from .neighbor_distance import neighbor_distance from .sort import sort from .tree_distance import tree_distance from .tree_neighbors import tree_neighbors diff --git a/src/pycea/tl/neighbor_distance.py b/src/pycea/tl/neighbor_distance.py new file mode 100755 index 0000000..2b4b880 --- /dev/null +++ b/src/pycea/tl/neighbor_distance.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from collections.abc import Callable + +import numpy as np +import pandas as pd +import scipy as sp +import treedata as td + +from ._utils import _csr_data_mask, _format_keys + + +def _get_agg_func(method): + """Returns aggregation function.""" + agg_funcs = {"mean": np.mean, "max": np.max, "min": np.min, "median": np.median} + if method in agg_funcs: + return agg_funcs[method] + elif callable(method): + return method + else: + raise ValueError(f"Invalid method: {method}") + + +def _assert_distance_specified(dist, mask): + """Asserts that distance is specified for where connected""" + if isinstance(dist, sp.sparse.csr_matrix): + dist_mask = _csr_data_mask(dist) + if not dist_mask[mask].sum() == mask.sum(): + raise ValueError("Distance must be specified for all connected observations.") + return + + +def neighbor_distance( + tdata: td.TreeData, + connect_key: str | None = None, + dist_key: str | None = None, + method: str | Callable = "mean", + key_added: str = "neighbor_distances", + copy: bool = False, +) -> None | pd.Series: + """Aggregates distance to neighboring observations. + + Parameters + ---------- + tdata + The TreeData object. + connect_key + `tdata.obsp` connectivity key specifying set of neighbors for each observation. + dist_key + `tdata.obsp` distances key specifying distances between observations. + method + Method to calculate neighbor distances: + + * 'mean' : The mean distance to neighboring observations. + * 'median' : The median distance to neighboring observations. + * 'min' : The minimum distance to neighboring observations. + * 'max' : The maximum distance to neighboring observations. + * Any function that takes a list of values and returns a single value. + + key_added + `tdata.obs` key to store neighbor distances. + copy + If True, returns a :class:`Series ` with neighbor distances. + + Returns + ------- + Returns `None` if `copy=False`, else returns a :class:`Series . + + Sets the following fields: + + * `tdata.obs[key_added]` : :class:`Series ` (dtype `float`) + - Neighbor distances for each observation. + """ + # Setup + if connect_key is None: + raise ValueError("connect_key must be specified.") + if dist_key is None: + raise ValueError("dist_key must be specified.") + _format_keys(connect_key, "connectivities") + _format_keys(dist_key, "distances") + agg_func = _get_agg_func(method) + mask = tdata.obsp[connect_key] > 0 + dist = tdata.obsp[dist_key] + _assert_distance_specified(dist, mask) + # Calculate neighbor distances + agg_dist = [] + for i in range(dist.shape[0]): + if isinstance(mask, sp.sparse.csr_matrix): + indices = mask[i].indices + else: + indices = np.nonzero(mask[i])[0] + row_dist = dist[i, indices] + if row_dist.size > 0: + agg_dist.append(agg_func(row_dist)) + else: + agg_dist.append(np.nan) + # Update tdata and return + tdata.obs[key_added] = agg_dist + if copy: + return tdata.obs[key_added] diff --git a/src/pycea/tl/tree_neighbors.py b/src/pycea/tl/tree_neighbors.py index ebe4a2b..5cb205d 100755 --- a/src/pycea/tl/tree_neighbors.py +++ b/src/pycea/tl/tree_neighbors.py @@ -90,7 +90,7 @@ def tree_neighbors( tree: str | Sequence[str] | None = None, copy: bool = False, ) -> None | tuple[sp.sparse.csr_matrix, sp.sparse.csr_matrix]: - """Identify neighbors in the tree. + """Identifies neighbors in the tree. Parameters ---------- diff --git a/tests/test_neighbor_distance.py b/tests/test_neighbor_distance.py new file mode 100755 index 0000000..24aa02b --- /dev/null +++ b/tests/test_neighbor_distance.py @@ -0,0 +1,68 @@ +import numpy as np +import pandas as pd +import pytest +import scipy as sp +import treedata as td + +from pycea.tl.neighbor_distance import neighbor_distance + + +@pytest.fixture +def tdata(): + distances = np.array([[1, 2, 3], [2, 1, 2], [3, 2, 1]]) + neighbors = np.array([[0, 1, 1], [0, 0, 1], [0, 0, 0]]) + tdata = td.TreeData( + obs=pd.DataFrame({"group": ["1", "1", "2"]}, index=["A", "B", "C"]), + obsp={ + "connectivities": neighbors, + "sparse_connectivities": sp.sparse.csr_matrix(neighbors), + "distances": distances, + "sparse_distances": sp.sparse.csr_matrix(distances), + }, + ) + yield tdata + + +@pytest.mark.parametrize("connect_key", ["connectivities", "sparse_connectivities"]) +@pytest.mark.parametrize("dist_key", ["distances", "sparse_distances"]) +def test_neighbor_distance(tdata, connect_key, dist_key): + distances = neighbor_distance(tdata, connect_key=connect_key, dist_key=dist_key, copy=True) + assert tdata.obs["neighbor_distances"].equals(distances) + assert isinstance(distances, pd.Series) + assert np.allclose(distances.values.tolist(), [2.5, 2, np.nan], equal_nan=True) + + +def test_neighbor_distance_methods(tdata): + distances = neighbor_distance(tdata, connect_key="connectivities", dist_key="distances", method="min", copy=True) + assert np.allclose(distances.values.tolist(), [2, 2, np.nan], equal_nan=True) + distances = neighbor_distance(tdata, connect_key="connectivities", dist_key="distances", method="max", copy=True) + assert np.allclose(distances.values.tolist(), [3, 2, np.nan], equal_nan=True) + distances = neighbor_distance(tdata, connect_key="connectivities", dist_key="distances", method="median", copy=True) + assert np.allclose(distances.values.tolist(), [2.5, 2, np.nan], equal_nan=True) + distances = neighbor_distance(tdata, connect_key="connectivities", dist_key="distances", method=np.mean, copy=True) + assert np.allclose(distances.values.tolist(), [2.5, 2, np.nan], equal_nan=True) + + +def test_neighbor_distance_missing(tdata): + tdata.obsp["missing_distances"] = sp.sparse.csr_matrix(([1, 1], ([0, 0], [0, 1])), shape=(3, 3)) + with pytest.raises(ValueError): + neighbor_distance(tdata, connect_key="connectivities", dist_key="missing_distances", copy=True) + with pytest.raises(ValueError): + neighbor_distance(tdata, connect_key="sparse_connectivities", dist_key="missing_distances", copy=True) + + +def test_neighbor_distance_invalid(tdata): + with pytest.raises(ValueError): + neighbor_distance(tdata, connect_key=None, dist_key="distances", copy=True) + with pytest.raises(ValueError): + neighbor_distance(tdata, connect_key="connectivities", dist_key=None, copy=True) + with pytest.raises(ValueError): + neighbor_distance(tdata, connect_key="connectivities", dist_key="distances", method="invalid", copy=True) + with pytest.raises(KeyError): + neighbor_distance(tdata, connect_key="invalid", dist_key="distances", copy=True) + with pytest.raises(KeyError): + neighbor_distance(tdata, connect_key="connectivities", dist_key="invalid", copy=True) + + +if __name__ == "__main__": + pytest.main(["-v", __file__])