From fe23f49eb8351e0ddff20034ac2fda8b2b2205d4 Mon Sep 17 00:00:00 2001 From: Ben Weber Date: Wed, 12 Jun 2024 14:40:25 +0200 Subject: [PATCH] Small cleanup --- dace/sdfg/graph.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/dace/sdfg/graph.py b/dace/sdfg/graph.py index edb95fe4bb..567e5e84d2 100644 --- a/dace/sdfg/graph.py +++ b/dace/sdfg/graph.py @@ -6,7 +6,7 @@ import networkx as nx from dace.dtypes import deduplicate import dace.serialize -from typing import Any, Callable, Generic, Iterable, List, Sequence, TypeVar, Union +from typing import Any, Callable, Generic, Iterable, List, Optional, Sequence, TypeVar, Union class NodeNotFoundError(Exception): @@ -364,19 +364,19 @@ def sink_nodes(self) -> List[NodeT]: """Returns nodes with no outgoing edges.""" return [n for n in self.nodes() if self.out_degree(n) == 0] - def bfs_nodes(self, source: NodeT = None) -> Sequence[NodeT]: - """Returns nodes in topological order iff the graph contains exactly - one node with no incoming edges.""" + def bfs_nodes(self, source: Optional[NodeT] = None) -> Iterable[NodeT]: + """Returns an iterable over nodes traversed in breadth-first search + order starting from ``source``.""" if source is not None: sources = [source] else: sources = self.source_nodes() - if len(sources) == 0: - sources = [self.nodes()[0]] - #raise RuntimeError("No source nodes found") - if len(sources) > 1: - sources = [self.nodes()[0]] - #raise RuntimeError("Multiple source nodes found") + if len(sources) != 1: + source = next(iter(self.nodes()), None) + if source is None: + return [] # graph has no nodes + sources = [source] + seen = OrderedDict() # No OrderedSet in Python queue = deque(sources) while len(queue) > 0: