Skip to content

Commit

Permalink
add filter method to LinkGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
CunliangGeng committed Jul 12, 2024
1 parent 80527e8 commit 1a02e5a
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 0 deletions.
59 changes: 59 additions & 0 deletions src/nplinker/scoring/link_graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from collections.abc import Sequence
from functools import wraps
from typing import Union
from networkx import Graph
Expand Down Expand Up @@ -209,3 +210,61 @@ def get_link_data(
{"metcalf": Score("metcalf", 1.0, {"cutoff": 0.5})}
"""
return self._g.get_edge_data(u, v) # type: ignore

def filter(self, u_nodes: Sequence[Entity], v_nodes: Sequence[Entity] = (), /) -> LinkGraph:
    """Return a new LinkGraph object with the filtered links between the given objects.

    The new LinkGraph object will only contain the links between `u_nodes` and `v_nodes`.

    If `u_nodes` or `v_nodes` is empty, the new LinkGraph object will contain the links for
    the given objects in `v_nodes` or `u_nodes`, respectively. If both are empty, return an
    empty LinkGraph object.

    Note that not all objects in `u_nodes` and `v_nodes` need to be present in the original
    LinkGraph.

    Args:
        u_nodes: a sequence of objects used as the first object in the links
        v_nodes: a sequence of objects used as the second object in the links

    Returns:
        A new LinkGraph object with the filtered links between the given objects.

    Examples:
        Filter the links for `gcf1` and `gcf2`:
        >>> new_lg = lg.filter([gcf1, gcf2])

        Filter the links for `spectrum1` and `spectrum2`:
        >>> new_lg = lg.filter([spectrum1, spectrum2])

        Filter the links between two lists of objects:
        >>> new_lg = lg.filter([gcf1, gcf2], [spectrum1, spectrum2])
    """
    lg = LinkGraph()

    # `len()` is used instead of truthiness so that any Sequence implementation
    # (including ones without a meaningful `__bool__`) is handled the same way.
    u_is_empty = len(u_nodes) == 0
    v_is_empty = len(v_nodes) == 0

    if u_is_empty or v_is_empty:
        # Only one side (or neither) was given: collect every link that touches
        # the nodes of the non-empty side. When both sides are empty the loop
        # body never runs and the empty graph is returned.
        single_side = v_nodes if u_is_empty else u_nodes
        for node in single_side:
            self._filter_one_node(node, lg)
        return lg

    # Both sides were given: keep only the links that join a `u` to a `v`.
    for u in u_nodes:
        for v in v_nodes:
            self._filter_two_nodes(u, v, lg)

    return lg

@validate_u
def _filter_one_node(self, u: Entity, lg: LinkGraph) -> None:
    """Filter the links for a given object and add them to the new LinkGraph object.

    Args:
        u: the object whose links are copied.
        lg: the LinkGraph object that receives the links; it is modified in place.
    """
    try:
        links = self[u]
    except KeyError:
        # `filter` documents that the given objects need not be present in the
        # original LinkGraph, so an unknown object contributes no links.
        # NOTE(review): assumes `self[u]` signals a missing object with KeyError,
        # as mapping-style lookups conventionally do — confirm against `__getitem__`.
        return

    for node2, value in links.items():
        lg.add_link(u, node2, **value)

@validate_uv
def _filter_two_nodes(self, u: Entity, v: Entity, lg: LinkGraph) -> None:
    """Copy the link between `u` and `v`, if there is one, into `lg`.

    Args:
        u: the first object of the link.
        v: the second object of the link.
        lg: the LinkGraph object that receives the link; it is modified in place.
    """
    data = self.get_link_data(u, v)
    if data is None:
        # No link data for this pair, so there is nothing to copy.
        return
    lg.add_link(u, v, **data)
29 changes: 29 additions & 0 deletions tests/unit/scoring/test_link_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,32 @@ def test_has_link(lg, gcfs, spectra):
def test_get_link_data(lg, gcfs, spectra, score):
    """Check `get_link_data` for a linked pair and for an unlinked pair."""
    linked_pair_data = lg.get_link_data(gcfs[0], spectra[0])
    assert linked_pair_data == {"metcalf": score}

    unlinked_pair_data = lg.get_link_data(gcfs[0], spectra[1])
    assert unlinked_pair_data is None


def test_filter(gcfs, spectra, score):
    """Check the number of nodes kept by `LinkGraph.filter` for each argument pattern."""
    graph = LinkGraph()
    gcf_nodes = [gcfs[0], gcfs[1]]
    spectrum_nodes = [spectra[0], spectra[1]]
    for gcf, spectrum in zip(gcf_nodes, spectrum_nodes):
        graph.add_link(gcf, spectrum, metcalf=score)

    # Each case: (positional arguments passed to `filter`, expected number of nodes)
    cases = [
        ((gcf_nodes,), 4),  # filtering with GCFs
        ((spectrum_nodes,), 4),  # filtering with Spectra
        (([], spectrum_nodes), 4),  # empty `u_nodes` argument
        (([], []), 0),  # empty `u_nodes` and `v_nodes` arguments
        ((gcf_nodes, spectrum_nodes), 4),  # filtering with GCFs and Spectra
    ]
    for filter_args, expected_node_count in cases:
        filtered = graph.filter(*filter_args)
        assert len(filtered) == expected_node_count

0 comments on commit 1a02e5a

Please sign in to comment.