-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #8 from YosefLab/neighbor-distance
added neighbor distance
- Loading branch information
Showing
6 changed files
with
172 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
from .ancestral_states import ancestral_states | ||
from .clades import clades | ||
from .distance import compare_distance, distance | ||
from .neighbor_distance import neighbor_distance | ||
from .sort import sort | ||
from .tree_distance import tree_distance | ||
from .tree_neighbors import tree_neighbors |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
from __future__ import annotations | ||
|
||
from collections.abc import Callable | ||
|
||
import numpy as np | ||
import pandas as pd | ||
import scipy as sp | ||
import treedata as td | ||
|
||
from ._utils import _csr_data_mask, _format_keys | ||
|
||
|
||
def _get_agg_func(method): | ||
"""Returns aggregation function.""" | ||
agg_funcs = {"mean": np.mean, "max": np.max, "min": np.min, "median": np.median} | ||
if method in agg_funcs: | ||
return agg_funcs[method] | ||
elif callable(method): | ||
return method | ||
else: | ||
raise ValueError(f"Invalid method: {method}") | ||
|
||
|
||
def _assert_distance_specified(dist, mask): | ||
"""Asserts that distance is specified for where connected""" | ||
if isinstance(dist, sp.sparse.csr_matrix): | ||
dist_mask = _csr_data_mask(dist) | ||
if not dist_mask[mask].sum() == mask.sum(): | ||
raise ValueError("Distance must be specified for all connected observations.") | ||
return | ||
|
||
|
||
def neighbor_distance(
    tdata: td.TreeData,
    connect_key: str | None = None,
    dist_key: str | None = None,
    method: str | Callable = "mean",
    key_added: str = "neighbor_distances",
    copy: bool = False,
) -> None | pd.Series:
    """Aggregates distance to neighboring observations.

    Parameters
    ----------
    tdata
        The TreeData object.
    connect_key
        `tdata.obsp` connectivity key specifying set of neighbors for each observation.
    dist_key
        `tdata.obsp` distances key specifying distances between observations.
    method
        Method to calculate neighbor distances:

        * 'mean' : The mean distance to neighboring observations.
        * 'median' : The median distance to neighboring observations.
        * 'min' : The minimum distance to neighboring observations.
        * 'max' : The maximum distance to neighboring observations.
        * Any function that takes a list of values and returns a single value.
    key_added
        `tdata.obs` key to store neighbor distances.
    copy
        If True, returns a :class:`Series <pandas.Series>` with neighbor distances.

    Returns
    -------
    Returns `None` if `copy=False`, else returns a :class:`Series <pandas.Series>`.

    Sets the following fields:

    * `tdata.obs[key_added]` : :class:`Series <pandas.Series>` (dtype `float`)
        - Neighbor distances for each observation. Observations without
          neighbors get `NaN`.

    Raises
    ------
    ValueError
        If `connect_key` or `dist_key` is `None`, if `method` is not a known
        name or a callable, or if a sparse distance matrix lacks an entry for
        a connected pair of observations.
    """
    # Setup
    if connect_key is None:
        raise ValueError("connect_key must be specified.")
    if dist_key is None:
        raise ValueError("dist_key must be specified.")
    # NOTE(review): the return values of `_format_keys` are discarded; if the
    # helper returns a normalized key instead of validating, these two calls
    # have no effect — confirm against `._utils`.
    _format_keys(connect_key, "connectivities")
    _format_keys(dist_key, "distances")
    agg_func = _get_agg_func(method)
    mask = tdata.obsp[connect_key] > 0
    dist = tdata.obsp[dist_key]
    _assert_distance_specified(dist, mask)
    # The storage format of the mask is the same for every row, so test it once.
    mask_is_csr = isinstance(mask, sp.sparse.csr_matrix)
    # Calculate neighbor distances
    agg_dist = []
    for i in range(dist.shape[0]):
        if mask_is_csr:
            indices = mask[i].indices
        else:
            indices = np.nonzero(mask[i])[0]
        row_dist = dist[i, indices]
        if row_dist.size == 0:
            # No neighbors (or no stored distances): nothing to aggregate.
            agg_dist.append(np.nan)
            continue
        if sp.sparse.issparse(row_dist):
            # Hand the aggregation function a plain 1-D array of values, as
            # the `method` contract promises, rather than a sparse matrix.
            row_dist = row_dist.toarray().ravel()
        agg_dist.append(agg_func(row_dist))
    # Update tdata and return
    tdata.obs[key_added] = agg_dist
    if copy:
        return tdata.obs[key_added]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
import scipy as sp | ||
import treedata as td | ||
|
||
from pycea.tl.neighbor_distance import neighbor_distance | ||
|
||
|
||
@pytest.fixture
def tdata():
    """Three-observation TreeData with dense and sparse copies of each matrix.

    Connections are A->B, A->C and B->C; C has no neighbors.
    """
    dist_matrix = np.array([[1, 2, 3], [2, 1, 2], [3, 2, 1]])
    adjacency = np.array([[0, 1, 1], [0, 0, 1], [0, 0, 0]])
    obs = pd.DataFrame({"group": ["1", "1", "2"]}, index=["A", "B", "C"])
    obsp = {
        "connectivities": adjacency,
        "sparse_connectivities": sp.sparse.csr_matrix(adjacency),
        "distances": dist_matrix,
        "sparse_distances": sp.sparse.csr_matrix(dist_matrix),
    }
    yield td.TreeData(obs=obs, obsp=obsp)
|
||
|
||
@pytest.mark.parametrize("connect_key", ["connectivities", "sparse_connectivities"])
@pytest.mark.parametrize("dist_key", ["distances", "sparse_distances"])
def test_neighbor_distance(tdata, connect_key, dist_key):
    """Default aggregation gives the same result for dense and sparse inputs."""
    result = neighbor_distance(tdata, connect_key=connect_key, dist_key=dist_key, copy=True)
    stored = tdata.obs["neighbor_distances"]
    # The returned Series must match what was written into `tdata.obs`.
    assert stored.equals(result)
    assert isinstance(result, pd.Series)
    expected = [2.5, 2, np.nan]
    assert np.allclose(result.values.tolist(), expected, equal_nan=True)
|
||
|
||
def test_neighbor_distance_methods(tdata):
    """Each named method, and a user-supplied callable, aggregates as expected."""
    cases = [
        ("min", [2, 2, np.nan]),
        ("max", [3, 2, np.nan]),
        ("median", [2.5, 2, np.nan]),
        (np.mean, [2.5, 2, np.nan]),
    ]
    for method, expected in cases:
        result = neighbor_distance(
            tdata,
            connect_key="connectivities",
            dist_key="distances",
            method=method,
            copy=True,
        )
        assert np.allclose(result.values.tolist(), expected, equal_nan=True)
|
||
|
||
def test_neighbor_distance_missing(tdata):
    """A sparse distance matrix lacking entries for connected pairs is rejected."""
    # Distances are stored only at (0, 0) and (0, 1), while the fixture's
    # connectivities also link (0, 2) and (1, 2).
    stored_entries = ([1, 1], ([0, 0], [0, 1]))
    tdata.obsp["missing_distances"] = sp.sparse.csr_matrix(stored_entries, shape=(3, 3))
    for connect_key in ("connectivities", "sparse_connectivities"):
        with pytest.raises(ValueError):
            neighbor_distance(tdata, connect_key=connect_key, dist_key="missing_distances", copy=True)
|
||
|
||
def test_neighbor_distance_invalid(tdata):
    """Invalid arguments raise ValueError; unknown `obsp` keys raise KeyError."""
    value_error_cases = [
        {"connect_key": None, "dist_key": "distances"},
        {"connect_key": "connectivities", "dist_key": None},
        {"connect_key": "connectivities", "dist_key": "distances", "method": "invalid"},
    ]
    key_error_cases = [
        {"connect_key": "invalid", "dist_key": "distances"},
        {"connect_key": "connectivities", "dist_key": "invalid"},
    ]
    for kwargs in value_error_cases:
        with pytest.raises(ValueError):
            neighbor_distance(tdata, copy=True, **kwargs)
    for kwargs in key_error_cases:
        with pytest.raises(KeyError):
            neighbor_distance(tdata, copy=True, **kwargs)
|
||
|
||
# Allow running this test module directly as a script.
if __name__ == "__main__":
    # Propagate pytest's return code so a failing run exits non-zero.
    raise SystemExit(pytest.main(["-v", __file__]))