-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #389 from es-ude/388-add-humble-base-graph-and-ite…
…rators feat(ir): add graph delegate and iterators
- Loading branch information
Showing
4 changed files
with
154 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from collections.abc import Hashable, Iterable, Iterator | ||
from typing import Generic, TypeVar | ||
|
||
HashableT = TypeVar("HashableT", bound=Hashable) | ||
|
||
|
||
class GraphDelegate(Generic[HashableT]): | ||
def __init__(self) -> None: | ||
"""We keep successor and predecessor nodes just to allow for easier implementation. | ||
Currently, this implementation is not optimized for performance. | ||
""" | ||
self.successors: dict[HashableT, dict[HashableT, None]] = dict() | ||
self.predecessors: dict[HashableT, dict[HashableT, None]] = dict() | ||
|
||
@staticmethod | ||
def from_dict(d: dict[HashableT, Iterable[HashableT]]): | ||
g = GraphDelegate() | ||
for node, successors in d.items(): | ||
for s in successors: | ||
g.add_edge(node, s) | ||
return g | ||
|
||
def as_dict(self) -> dict[HashableT, set[HashableT]]: | ||
return self.successors.copy() | ||
|
||
def add_edge(self, _from: HashableT, _to: HashableT): | ||
self.add_node(_from) | ||
self.add_node(_to) | ||
self.predecessors[_to][_from] = None | ||
self.successors[_from][_to] = None | ||
return self | ||
|
||
def add_node(self, node: HashableT): | ||
if node not in self.predecessors: | ||
self.predecessors[node] = dict() | ||
if node not in self.successors: | ||
self.successors[node] = dict() | ||
return self | ||
|
||
def iter_nodes(self) -> Iterator[HashableT]: | ||
yield from self.predecessors.keys() | ||
|
||
def get_edges(self) -> Iterator[tuple[HashableT, HashableT]]: | ||
for _from, _tos in self.successors.items(): | ||
for _to in _tos: | ||
yield _from, _to | ||
|
||
def get_successors(self, node: HashableT) -> Iterator[HashableT]: | ||
yield from self.successors[node] | ||
|
||
def get_predecessors(self, node: HashableT) -> Iterator[HashableT]: | ||
yield from self.predecessors[node] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from .graph_delegate import GraphDelegate | ||
from .graph_iterators import bfs_iter_up, dfs_pre_order | ||
|
||
|
||
def test_iterating_breadth_first_upwards(): | ||
g = GraphDelegate() | ||
""" | ||
0 | ||
| | ||
/-----\ | ||
| | | ||
1 2 | ||
| /---+ | ||
|/ | | ||
3 4 | ||
| | | ||
| 6 | ||
|/----+ | ||
5 | ||
""" | ||
g = GraphDelegate.from_dict( | ||
{ | ||
"0": ["1", "2"], | ||
"1": ["3"], | ||
"2": ["3", "4"], | ||
"3": ["5"], | ||
"4": ["6"], | ||
"6": ["5"], | ||
} | ||
) | ||
|
||
actual = tuple(bfs_iter_up(g.get_predecessors, "5")) | ||
assert actual == ("3", "6", "1", "2", "4", "0") | ||
|
||
|
||
def test_iterating_depth_first_preorder(): | ||
g = GraphDelegate() | ||
""" | ||
0 | ||
| | ||
/-----\ | ||
| | | ||
1 2 | ||
| /---+ | ||
|/ | | ||
3 4 | ||
| | | ||
| 6 | ||
|/----+ | ||
5 | ||
""" | ||
g = GraphDelegate.from_dict( | ||
{ | ||
"0": ["1", "2"], | ||
"1": ["3"], | ||
"2": ["3", "4"], | ||
"3": ["5"], | ||
"4": ["6"], | ||
"6": ["5"], | ||
} | ||
) | ||
|
||
actual = tuple(dfs_pre_order(g.get_successors, "0")) | ||
print(tuple(g.get_successors("0"))) | ||
expected = ("0", "1", "3", "5", "2", "4", "6") | ||
assert actual == expected |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from collections.abc import Callable, Iterable, Iterator | ||
from typing import Hashable, TypeAlias, TypeVar | ||
|
||
HashableT = TypeVar("HashableT", bound=Hashable) | ||
|
||
NodeNeighbourFn: TypeAlias = Callable[[HashableT], Iterable[HashableT]] | ||
|
||
|
||
def dfs_pre_order(successors: NodeNeighbourFn, start: HashableT) -> Iterator[HashableT]: | ||
visited: set[HashableT] = set() | ||
|
||
def visit(nodes: tuple[HashableT, ...]): | ||
for n in nodes: | ||
if n not in visited: | ||
yield n | ||
visited.add(n) | ||
yield from visit(tuple(successors(n))) | ||
|
||
yield from visit((start,)) | ||
|
||
|
||
def bfs_iter_down(successors: NodeNeighbourFn, start: HashableT) -> Iterator[HashableT]: | ||
visited: set[HashableT] = set() | ||
visit_next = [start] | ||
while len(visit_next) > 0: | ||
current = visit_next.pop(0) | ||
for p in successors(current): | ||
if p not in visited: | ||
yield p | ||
visited.add(p) | ||
visit_next.append(p) | ||
visited.add(current) | ||
|
||
|
||
def bfs_iter_up(predecessors: NodeNeighbourFn, start: HashableT) -> Iterator[HashableT]: | ||
return bfs_iter_down(predecessors, start) |