From aa88d82c246f2a52a7d0889cd2dc811ce2cf8b20 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 19 Sep 2023 23:17:22 -0700 Subject: [PATCH] Speed up StateReachability pass for large state machines --- dace/transformation/passes/analysis.py | 64 ++++++++++++++++---------- 1 file changed, 39 insertions(+), 25 deletions(-) diff --git a/dace/transformation/passes/analysis.py b/dace/transformation/passes/analysis.py index 1ca92d5ffd..b59bfee5d1 100644 --- a/dace/transformation/passes/analysis.py +++ b/dace/transformation/passes/analysis.py @@ -14,6 +14,7 @@ Set[Tuple[SDFGState, Union[nd.AccessNode, InterstateEdge]]]]] SymbolScopeDict = Dict[str, Dict[Edge[InterstateEdge], Set[Union[Edge[InterstateEdge], SDFGState]]]] + @properties.make_properties class StateReachability(ppl.Pass): """ @@ -35,10 +36,20 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Set[SDFGSta """ reachable: Dict[int, Dict[SDFGState, Set[SDFGState]]] = {} for sdfg in top_sdfg.all_sdfgs_recursive(): - reachable[sdfg.sdfg_id] = {} - tc: nx.DiGraph = nx.transitive_closure(sdfg.nx) - for state in sdfg.nodes(): - reachable[sdfg.sdfg_id][state] = set(tc.successors(state)) + result: Dict[SDFGState, Set[SDFGState]] = {} + + # In networkx this is currently implemented naively for directed graphs. + # The implementation below is faster + # tc: nx.DiGraph = nx.transitive_closure(sdfg.nx) + + for n, v in nx.all_pairs_shortest_path_length(sdfg.nx): + result[n] = set(t for t, l in v.items() if l > 0) + # Add self-edges + if n in sdfg.successors(n): + result[n].add(n) + + reachable[sdfg.sdfg_id] = result + return reachable @@ -57,9 +68,8 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: # If anything was modified, reapply return modified & ppl.Modifies.States | ppl.Modifies.Edges | ppl.Modifies.Symbols | ppl.Modifies.Nodes - def apply_pass( - self, top_sdfg: SDFG, _ - ) -> Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]]: + def apply_pass(self, top_sdfg: SDFG, + _) -> Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]]: """ :return: A dictionary mapping each state to a tuple of its (read, written) data descriptors. """ @@ -216,9 +226,8 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: def depends_on(self): return {SymbolAccessSets, StateReachability} - def _find_dominating_write( - self, sym: str, read: Union[SDFGState, Edge[InterstateEdge]], state_idom: Dict[SDFGState, SDFGState] - ) -> Optional[Edge[InterstateEdge]]: + def _find_dominating_write(self, sym: str, read: Union[SDFGState, Edge[InterstateEdge]], + state_idom: Dict[SDFGState, SDFGState]) -> Optional[Edge[InterstateEdge]]: last_state: SDFGState = read if isinstance(read, SDFGState) else read.src in_edges = last_state.parent.in_edges(last_state) @@ -257,9 +266,9 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[int, idom = nx.immediate_dominators(sdfg.nx, sdfg.start_state) all_doms = cfg.all_dominators(sdfg, idom) - symbol_access_sets: Dict[ - Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]] - ] = pipeline_results[SymbolAccessSets.__name__][sdfg.sdfg_id] + symbol_access_sets: Dict[Union[SDFGState, Edge[InterstateEdge]], + Tuple[Set[str], + Set[str]]] = pipeline_results[SymbolAccessSets.__name__][sdfg.sdfg_id] state_reach: Dict[SDFGState, Set[SDFGState]] = pipeline_results[StateReachability.__name__][sdfg.sdfg_id] for read_loc, (reads, _) in symbol_access_sets.items(): @@ -321,12 +330,14 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: def depends_on(self): return {AccessSets, FindAccessNodes, StateReachability} - def _find_dominating_write( - self, desc: str, state: SDFGState, read: Union[nd.AccessNode, InterstateEdge], - access_nodes: Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]], - state_idom: Dict[SDFGState, SDFGState], access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]], - no_self_shadowing: bool = False - ) -> Optional[Tuple[SDFGState, nd.AccessNode]]: + def _find_dominating_write(self, + desc: str, + state: SDFGState, + read: Union[nd.AccessNode, InterstateEdge], + access_nodes: Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]], + state_idom: Dict[SDFGState, SDFGState], + access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]], + no_self_shadowing: bool = False) -> Optional[Tuple[SDFGState, nd.AccessNode]]: if isinstance(read, nd.AccessNode): # If the read is also a write, it shadows itself. iedges = state.in_edges(read) @@ -408,18 +419,21 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i for oedge in out_edges: syms = oedge.data.free_symbols & anames if desc in syms: - write = self._find_dominating_write( - desc, state, oedge.data, access_nodes, idom, access_sets - ) + write = self._find_dominating_write(desc, state, oedge.data, access_nodes, idom, + access_sets) result[desc][write].add((state, oedge.data)) # Take care of any write nodes that have not been assigned to a scope yet, i.e., writes that are not # dominating any reads and are thus not part of the results yet. for state in desc_states_with_nodes: for write_node in access_nodes[desc][state][1]: if not (state, write_node) in result[desc]: - write = self._find_dominating_write( - desc, state, write_node, access_nodes, idom, access_sets, no_self_shadowing=True - ) + write = self._find_dominating_write(desc, + state, + write_node, + access_nodes, + idom, + access_sets, + no_self_shadowing=True) result[desc][write].add((state, write_node)) # If any write A is dominated by another write B and any reads in B's scope are also reachable by A,