diff --git a/dace/transformation/passes/analysis.py b/dace/transformation/passes/analysis.py index 1ca92d5ffd..86e1cde062 100644 --- a/dace/transformation/passes/analysis.py +++ b/dace/transformation/passes/analysis.py @@ -14,6 +14,7 @@ Set[Tuple[SDFGState, Union[nd.AccessNode, InterstateEdge]]]]] SymbolScopeDict = Dict[str, Dict[Edge[InterstateEdge], Set[Union[Edge[InterstateEdge], SDFGState]]]] + @properties.make_properties class StateReachability(ppl.Pass): """ @@ -35,13 +36,68 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Set[SDFGSta """ reachable: Dict[int, Dict[SDFGState, Set[SDFGState]]] = {} for sdfg in top_sdfg.all_sdfgs_recursive(): - reachable[sdfg.sdfg_id] = {} - tc: nx.DiGraph = nx.transitive_closure(sdfg.nx) - for state in sdfg.nodes(): - reachable[sdfg.sdfg_id][state] = set(tc.successors(state)) + result: Dict[SDFGState, Set[SDFGState]] = {} + + # In networkx this is currently implemented naively for directed graphs. + # The implementation below is faster + # tc: nx.DiGraph = nx.transitive_closure(sdfg.nx) + + for n, v in reachable_nodes(sdfg.nx): + result[n] = set(v) + + reachable[sdfg.sdfg_id] = result + return reachable +def _single_shortest_path_length_no_self(adj, source): + """Yields (node, level) in a breadth first search, without the first level + unless a self-edge exists. + + Adapted from Shortest Path Length helper function in NetworkX. + + Parameters + ---------- + adj : dict + Adjacency dict or view + firstlevel : dict + starting nodes, e.g. {source: 1} or {target: 1} + cutoff : int or float + level at which we stop the process + """ + firstlevel = {source: 1} + + seen = {} # level (number of hops) when seen in BFS + level = 0 # the current level + nextlevel = set(firstlevel) # set of nodes to check at next level + n = len(adj) + while nextlevel: + thislevel = nextlevel # advance to next level + nextlevel = set() # and start a new set (fringe) + found = [] + for v in thislevel: + if v not in seen: + if level == 0 and v is source: # Skip 0-length path to self + found.append(v) + continue + seen[v] = level # set the level of vertex v + found.append(v) + yield (v, level) + if len(seen) == n: + return + for v in found: + nextlevel.update(adj[v]) + level += 1 + del seen + + +def reachable_nodes(G): + """Computes the reachable nodes in G.""" + adj = G.adj + for n in G: + yield (n, dict(_single_shortest_path_length_no_self(adj, n))) + + @properties.make_properties class SymbolAccessSets(ppl.Pass): """ @@ -57,9 +113,8 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: # If anything was modified, reapply return modified & ppl.Modifies.States | ppl.Modifies.Edges | ppl.Modifies.Symbols | ppl.Modifies.Nodes - def apply_pass( - self, top_sdfg: SDFG, _ - ) -> Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]]: + def apply_pass(self, top_sdfg: SDFG, + _) -> Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]]: """ :return: A dictionary mapping each state to a tuple of its (read, written) data descriptors. """ @@ -216,9 +271,8 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: def depends_on(self): return {SymbolAccessSets, StateReachability} - def _find_dominating_write( - self, sym: str, read: Union[SDFGState, Edge[InterstateEdge]], state_idom: Dict[SDFGState, SDFGState] - ) -> Optional[Edge[InterstateEdge]]: + def _find_dominating_write(self, sym: str, read: Union[SDFGState, Edge[InterstateEdge]], + state_idom: Dict[SDFGState, SDFGState]) -> Optional[Edge[InterstateEdge]]: last_state: SDFGState = read if isinstance(read, SDFGState) else read.src in_edges = last_state.parent.in_edges(last_state) @@ -257,9 +311,9 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[int, idom = nx.immediate_dominators(sdfg.nx, sdfg.start_state) all_doms = cfg.all_dominators(sdfg, idom) - symbol_access_sets: Dict[ - Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]] - ] = pipeline_results[SymbolAccessSets.__name__][sdfg.sdfg_id] + symbol_access_sets: Dict[Union[SDFGState, Edge[InterstateEdge]], + Tuple[Set[str], + Set[str]]] = pipeline_results[SymbolAccessSets.__name__][sdfg.sdfg_id] state_reach: Dict[SDFGState, Set[SDFGState]] = pipeline_results[StateReachability.__name__][sdfg.sdfg_id] for read_loc, (reads, _) in symbol_access_sets.items(): @@ -321,12 +375,14 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: def depends_on(self): return {AccessSets, FindAccessNodes, StateReachability} - def _find_dominating_write( - self, desc: str, state: SDFGState, read: Union[nd.AccessNode, InterstateEdge], - access_nodes: Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]], - state_idom: Dict[SDFGState, SDFGState], access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]], - no_self_shadowing: bool = False - ) -> Optional[Tuple[SDFGState, nd.AccessNode]]: + def _find_dominating_write(self, + desc: str, + state: SDFGState, + read: Union[nd.AccessNode, InterstateEdge], + access_nodes: Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]], + state_idom: Dict[SDFGState, SDFGState], + access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]], + no_self_shadowing: bool = False) -> Optional[Tuple[SDFGState, nd.AccessNode]]: if isinstance(read, nd.AccessNode): # If the read is also a write, it shadows itself. iedges = state.in_edges(read) @@ -408,18 +464,21 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i for oedge in out_edges: syms = oedge.data.free_symbols & anames if desc in syms: - write = self._find_dominating_write( - desc, state, oedge.data, access_nodes, idom, access_sets - ) + write = self._find_dominating_write(desc, state, oedge.data, access_nodes, idom, + access_sets) result[desc][write].add((state, oedge.data)) # Take care of any write nodes that have not been assigned to a scope yet, i.e., writes that are not # dominating any reads and are thus not part of the results yet. for state in desc_states_with_nodes: for write_node in access_nodes[desc][state][1]: if not (state, write_node) in result[desc]: - write = self._find_dominating_write( - desc, state, write_node, access_nodes, idom, access_sets, no_self_shadowing=True - ) + write = self._find_dominating_write(desc, + state, + write_node, + access_nodes, + idom, + access_sets, + no_self_shadowing=True) result[desc][write].add((state, write_node)) # If any write A is dominated by another write B and any reads in B's scope are also reachable by A,