Skip to content

Commit

Permalink
Merge pull request #8 from YosefLab/neighbor-distance
Browse files Browse the repository at this point in the history
added neighbor distance
  • Loading branch information
colganwi authored Aug 19, 2024
2 parents 995f12a + 108d08a commit 7cc4e02
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 2 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ To learn more about pycea, please refer to the [documentation][link-docs] or the

## Installation

You need to have Python 3.9 or newer installed on your system. If you don't have
You need to have Python 3.10 or newer installed on your system. If you don't have
Python installed, we recommend installing [Mambaforge](https://github.com/conda-forge/miniforge#mambaforge).

There are several alternative options to install pycea:
Expand Down
1 change: 1 addition & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
tl.clades
tl.compare_distance
tl.distance
tl.neighbor_distance
tl.sort
tl.tree_distance
tl.tree_neighbors
Expand Down
1 change: 1 addition & 0 deletions src/pycea/tl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .ancestral_states import ancestral_states
from .clades import clades
from .distance import compare_distance, distance
from .neighbor_distance import neighbor_distance
from .sort import sort
from .tree_distance import tree_distance
from .tree_neighbors import tree_neighbors
100 changes: 100 additions & 0 deletions src/pycea/tl/neighbor_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from __future__ import annotations

from collections.abc import Callable

import numpy as np
import pandas as pd
import scipy as sp
import treedata as td

from ._utils import _csr_data_mask, _format_keys


def _get_agg_func(method):
"""Returns aggregation function."""
agg_funcs = {"mean": np.mean, "max": np.max, "min": np.min, "median": np.median}
if method in agg_funcs:
return agg_funcs[method]
elif callable(method):
return method
else:
raise ValueError(f"Invalid method: {method}")


def _assert_distance_specified(dist, mask):
"""Asserts that distance is specified for where connected"""
if isinstance(dist, sp.sparse.csr_matrix):
dist_mask = _csr_data_mask(dist)
if not dist_mask[mask].sum() == mask.sum():
raise ValueError("Distance must be specified for all connected observations.")
return


def neighbor_distance(
tdata: td.TreeData,
connect_key: str | None = None,
dist_key: str | None = None,
method: str | Callable = "mean",
key_added: str = "neighbor_distances",
copy: bool = False,
) -> None | pd.Series:
"""Aggregates distance to neighboring observations.
Parameters
----------
tdata
The TreeData object.
connect_key
`tdata.obsp` connectivity key specifying set of neighbors for each observation.
dist_key
`tdata.obsp` distances key specifying distances between observations.
method
Method to calculate neighbor distances:
* 'mean' : The mean distance to neighboring observations.
* 'median' : The median distance to neighboring observations.
* 'min' : The minimum distance to neighboring observations.
* 'max' : The maximum distance to neighboring observations.
* Any function that takes a list of values and returns a single value.
key_added
`tdata.obs` key to store neighbor distances.
copy
If True, returns a :class:`Series <pandas.Series>` with neighbor distances.
Returns
-------
Returns `None` if `copy=False`, else returns a :class:`Series <pandas.Series>.
Sets the following fields:
* `tdata.obs[key_added]` : :class:`Series <pandas.Series>` (dtype `float`)
- Neighbor distances for each observation.
"""
# Setup
if connect_key is None:
raise ValueError("connect_key must be specified.")
if dist_key is None:
raise ValueError("dist_key must be specified.")
_format_keys(connect_key, "connectivities")
_format_keys(dist_key, "distances")
agg_func = _get_agg_func(method)
mask = tdata.obsp[connect_key] > 0
dist = tdata.obsp[dist_key]
_assert_distance_specified(dist, mask)
# Calculate neighbor distances
agg_dist = []
for i in range(dist.shape[0]):
if isinstance(mask, sp.sparse.csr_matrix):
indices = mask[i].indices
else:
indices = np.nonzero(mask[i])[0]
row_dist = dist[i, indices]
if row_dist.size > 0:
agg_dist.append(agg_func(row_dist))
else:
agg_dist.append(np.nan)
# Update tdata and return
tdata.obs[key_added] = agg_dist
if copy:
return tdata.obs[key_added]
2 changes: 1 addition & 1 deletion src/pycea/tl/tree_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def tree_neighbors(
tree: str | Sequence[str] | None = None,
copy: bool = False,
) -> None | tuple[sp.sparse.csr_matrix, sp.sparse.csr_matrix]:
"""Identify neighbors in the tree.
"""Identifies neighbors in the tree.
Parameters
----------
Expand Down
68 changes: 68 additions & 0 deletions tests/test_neighbor_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import numpy as np
import pandas as pd
import pytest
import scipy as sp
import treedata as td

from pycea.tl.neighbor_distance import neighbor_distance


@pytest.fixture
def tdata():
distances = np.array([[1, 2, 3], [2, 1, 2], [3, 2, 1]])
neighbors = np.array([[0, 1, 1], [0, 0, 1], [0, 0, 0]])
tdata = td.TreeData(
obs=pd.DataFrame({"group": ["1", "1", "2"]}, index=["A", "B", "C"]),
obsp={
"connectivities": neighbors,
"sparse_connectivities": sp.sparse.csr_matrix(neighbors),
"distances": distances,
"sparse_distances": sp.sparse.csr_matrix(distances),
},
)
yield tdata


@pytest.mark.parametrize("connect_key", ["connectivities", "sparse_connectivities"])
@pytest.mark.parametrize("dist_key", ["distances", "sparse_distances"])
def test_neighbor_distance(tdata, connect_key, dist_key):
distances = neighbor_distance(tdata, connect_key=connect_key, dist_key=dist_key, copy=True)
assert tdata.obs["neighbor_distances"].equals(distances)
assert isinstance(distances, pd.Series)
assert np.allclose(distances.values.tolist(), [2.5, 2, np.nan], equal_nan=True)


def test_neighbor_distance_methods(tdata):
distances = neighbor_distance(tdata, connect_key="connectivities", dist_key="distances", method="min", copy=True)
assert np.allclose(distances.values.tolist(), [2, 2, np.nan], equal_nan=True)
distances = neighbor_distance(tdata, connect_key="connectivities", dist_key="distances", method="max", copy=True)
assert np.allclose(distances.values.tolist(), [3, 2, np.nan], equal_nan=True)
distances = neighbor_distance(tdata, connect_key="connectivities", dist_key="distances", method="median", copy=True)
assert np.allclose(distances.values.tolist(), [2.5, 2, np.nan], equal_nan=True)
distances = neighbor_distance(tdata, connect_key="connectivities", dist_key="distances", method=np.mean, copy=True)
assert np.allclose(distances.values.tolist(), [2.5, 2, np.nan], equal_nan=True)


def test_neighbor_distance_missing(tdata):
tdata.obsp["missing_distances"] = sp.sparse.csr_matrix(([1, 1], ([0, 0], [0, 1])), shape=(3, 3))
with pytest.raises(ValueError):
neighbor_distance(tdata, connect_key="connectivities", dist_key="missing_distances", copy=True)
with pytest.raises(ValueError):
neighbor_distance(tdata, connect_key="sparse_connectivities", dist_key="missing_distances", copy=True)


def test_neighbor_distance_invalid(tdata):
with pytest.raises(ValueError):
neighbor_distance(tdata, connect_key=None, dist_key="distances", copy=True)
with pytest.raises(ValueError):
neighbor_distance(tdata, connect_key="connectivities", dist_key=None, copy=True)
with pytest.raises(ValueError):
neighbor_distance(tdata, connect_key="connectivities", dist_key="distances", method="invalid", copy=True)
with pytest.raises(KeyError):
neighbor_distance(tdata, connect_key="invalid", dist_key="distances", copy=True)
with pytest.raises(KeyError):
neighbor_distance(tdata, connect_key="connectivities", dist_key="invalid", copy=True)


if __name__ == "__main__":
pytest.main(["-v", __file__])

0 comments on commit 7cc4e02

Please sign in to comment.