Skip to content

Commit

Permalink
Merge pull request #389 from es-ude/388-add-humble-base-graph-and-ite…
Browse files Browse the repository at this point in the history
…rators

feat(ir): add graph delegate and iterators
  • Loading branch information
glencoe authored Nov 19, 2024
2 parents ca31f79 + 3da0dbc commit 07f5160
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 0 deletions.
Empty file.
52 changes: 52 additions & 0 deletions elasticai/creator/ir/graph_delegate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from collections.abc import Hashable, Iterable, Iterator
from typing import Generic, TypeVar

HashableT = TypeVar("HashableT", bound=Hashable)


class GraphDelegate(Generic[HashableT]):
def __init__(self) -> None:
"""We keep successor and predecessor nodes just to allow for easier implementation.
Currently, this implementation is not optimized for performance.
"""
self.successors: dict[HashableT, dict[HashableT, None]] = dict()
self.predecessors: dict[HashableT, dict[HashableT, None]] = dict()

@staticmethod
def from_dict(d: dict[HashableT, Iterable[HashableT]]):
g = GraphDelegate()
for node, successors in d.items():
for s in successors:
g.add_edge(node, s)
return g

def as_dict(self) -> dict[HashableT, set[HashableT]]:
return self.successors.copy()

def add_edge(self, _from: HashableT, _to: HashableT):
self.add_node(_from)
self.add_node(_to)
self.predecessors[_to][_from] = None
self.successors[_from][_to] = None
return self

def add_node(self, node: HashableT):
if node not in self.predecessors:
self.predecessors[node] = dict()
if node not in self.successors:
self.successors[node] = dict()
return self

def iter_nodes(self) -> Iterator[HashableT]:
yield from self.predecessors.keys()

def get_edges(self) -> Iterator[tuple[HashableT, HashableT]]:
for _from, _tos in self.successors.items():
for _to in _tos:
yield _from, _to

def get_successors(self, node: HashableT) -> Iterator[HashableT]:
yield from self.successors[node]

def get_predecessors(self, node: HashableT) -> Iterator[HashableT]:
yield from self.predecessors[node]
66 changes: 66 additions & 0 deletions elasticai/creator/ir/graph_delegate_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from .graph_delegate import GraphDelegate
from .graph_iterators import bfs_iter_up, dfs_pre_order


def test_iterating_breadth_first_upwards():
g = GraphDelegate()
"""
0
|
/-----\
| |
1 2
| /---+
|/ |
3 4
| |
| 6
|/----+
5
"""
g = GraphDelegate.from_dict(
{
"0": ["1", "2"],
"1": ["3"],
"2": ["3", "4"],
"3": ["5"],
"4": ["6"],
"6": ["5"],
}
)

actual = tuple(bfs_iter_up(g.get_predecessors, "5"))
assert actual == ("3", "6", "1", "2", "4", "0")


def test_iterating_depth_first_preorder():
g = GraphDelegate()
"""
0
|
/-----\
| |
1 2
| /---+
|/ |
3 4
| |
| 6
|/----+
5
"""
g = GraphDelegate.from_dict(
{
"0": ["1", "2"],
"1": ["3"],
"2": ["3", "4"],
"3": ["5"],
"4": ["6"],
"6": ["5"],
}
)

actual = tuple(dfs_pre_order(g.get_successors, "0"))
print(tuple(g.get_successors("0")))
expected = ("0", "1", "3", "5", "2", "4", "6")
assert actual == expected
36 changes: 36 additions & 0 deletions elasticai/creator/ir/graph_iterators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from collections.abc import Callable, Iterable, Iterator
from typing import Hashable, TypeAlias, TypeVar

HashableT = TypeVar("HashableT", bound=Hashable)

NodeNeighbourFn: TypeAlias = Callable[[HashableT], Iterable[HashableT]]


def dfs_pre_order(successors: NodeNeighbourFn, start: HashableT) -> Iterator[HashableT]:
visited: set[HashableT] = set()

def visit(nodes: tuple[HashableT, ...]):
for n in nodes:
if n not in visited:
yield n
visited.add(n)
yield from visit(tuple(successors(n)))

yield from visit((start,))


def bfs_iter_down(successors: NodeNeighbourFn, start: HashableT) -> Iterator[HashableT]:
visited: set[HashableT] = set()
visit_next = [start]
while len(visit_next) > 0:
current = visit_next.pop(0)
for p in successors(current):
if p not in visited:
yield p
visited.add(p)
visit_next.append(p)
visited.add(current)


def bfs_iter_up(predecessors: NodeNeighbourFn, start: HashableT) -> Iterator[HashableT]:
return bfs_iter_down(predecessors, start)

0 comments on commit 07f5160

Please sign in to comment.