diff --git a/src/pycea/tl/tree_neighbors.py b/src/pycea/tl/tree_neighbors.py index a5cb0f2..b693694 100755 --- a/src/pycea/tl/tree_neighbors.py +++ b/src/pycea/tl/tree_neighbors.py @@ -71,8 +71,8 @@ def _tree_neighbors(tree, n_neighbors, max_dist, depth_key, metric, leaves=None) leaves = [node for node in tree.nodes() if tree.out_degree(node) == 0] for leaf in leaves: neighbors, neighbor_distances = _bfs_by_distance(tree, leaf, n_neighbors, max_dist, metric, depth_key) - rows.extend(neighbors) - cols.extend([leaf] * len(neighbors)) + rows.extend([leaf] * len(neighbors)) + cols.extend(neighbors) distances.extend(neighbor_distances) return rows, cols, distances diff --git a/tests/test_tree_neighbors.py b/tests/test_tree_neighbors.py index e7ce3de..b44df77 100755 --- a/tests/test_tree_neighbors.py +++ b/tests/test_tree_neighbors.py @@ -44,7 +44,7 @@ def test_tree_neighbors_n(tdata): def test_select_tree_neighbors(tdata): tree_neighbors(tdata, n_neighbors=2, metric="path", obs="C") - assert tdata.obs.query("tree_neighbors").index.tolist() == ["D", "E"] + assert tdata.obs.query("tree_neighbors").index.tolist() == ["C"] tree_neighbors(tdata, n_neighbors=3, metric="path", obs=["C", "D"], random_state=0) assert tdata.obsp["tree_connectivities"].sum() == 2