Skip to content

Commit

Permalink
Merge pull request #1370 from spcl/faster-reachability
Browse files Browse the repository at this point in the history
Speed up StateReachability pass for large state machines
  • Loading branch information
acalotoiu authored Sep 21, 2023
2 parents 66010ec + aac7013 commit a4fa6ed
Showing 1 changed file with 84 additions and 25 deletions.
109 changes: 84 additions & 25 deletions dace/transformation/passes/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Set[Tuple[SDFGState, Union[nd.AccessNode, InterstateEdge]]]]]
SymbolScopeDict = Dict[str, Dict[Edge[InterstateEdge], Set[Union[Edge[InterstateEdge], SDFGState]]]]


@properties.make_properties
class StateReachability(ppl.Pass):
"""
Expand All @@ -35,13 +36,68 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Set[SDFGSta
"""
reachable: Dict[int, Dict[SDFGState, Set[SDFGState]]] = {}
for sdfg in top_sdfg.all_sdfgs_recursive():
reachable[sdfg.sdfg_id] = {}
tc: nx.DiGraph = nx.transitive_closure(sdfg.nx)
for state in sdfg.nodes():
reachable[sdfg.sdfg_id][state] = set(tc.successors(state))
result: Dict[SDFGState, Set[SDFGState]] = {}

# In networkx this is currently implemented naively for directed graphs.
# The implementation below is faster
# tc: nx.DiGraph = nx.transitive_closure(sdfg.nx)

for n, v in reachable_nodes(sdfg.nx):
result[n] = set(v)

reachable[sdfg.sdfg_id] = result

return reachable


def _single_shortest_path_length_no_self(adj, source):
"""Yields (node, level) in a breadth first search, without the first level
unless a self-edge exists.
Adapted from Shortest Path Length helper function in NetworkX.
Parameters
----------
adj : dict
Adjacency dict or view
firstlevel : dict
starting nodes, e.g. {source: 1} or {target: 1}
cutoff : int or float
level at which we stop the process
"""
firstlevel = {source: 1}

seen = {} # level (number of hops) when seen in BFS
level = 0 # the current level
nextlevel = set(firstlevel) # set of nodes to check at next level
n = len(adj)
while nextlevel:
thislevel = nextlevel # advance to next level
nextlevel = set() # and start a new set (fringe)
found = []
for v in thislevel:
if v not in seen:
if level == 0 and v is source: # Skip 0-length path to self
found.append(v)
continue
seen[v] = level # set the level of vertex v
found.append(v)
yield (v, level)
if len(seen) == n:
return
for v in found:
nextlevel.update(adj[v])
level += 1
del seen


def reachable_nodes(G):
"""Computes the reachable nodes in G."""
adj = G.adj
for n in G:
yield (n, dict(_single_shortest_path_length_no_self(adj, n)))


@properties.make_properties
class SymbolAccessSets(ppl.Pass):
"""
Expand All @@ -57,9 +113,8 @@ def should_reapply(self, modified: ppl.Modifies) -> bool:
# If anything was modified, reapply
return modified & ppl.Modifies.States | ppl.Modifies.Edges | ppl.Modifies.Symbols | ppl.Modifies.Nodes

def apply_pass(
self, top_sdfg: SDFG, _
) -> Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]]:
def apply_pass(self, top_sdfg: SDFG,
_) -> Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]]:
"""
:return: A dictionary mapping each state to a tuple of its (read, written) data descriptors.
"""
Expand Down Expand Up @@ -216,9 +271,8 @@ def should_reapply(self, modified: ppl.Modifies) -> bool:
def depends_on(self):
return {SymbolAccessSets, StateReachability}

def _find_dominating_write(
self, sym: str, read: Union[SDFGState, Edge[InterstateEdge]], state_idom: Dict[SDFGState, SDFGState]
) -> Optional[Edge[InterstateEdge]]:
def _find_dominating_write(self, sym: str, read: Union[SDFGState, Edge[InterstateEdge]],
state_idom: Dict[SDFGState, SDFGState]) -> Optional[Edge[InterstateEdge]]:
last_state: SDFGState = read if isinstance(read, SDFGState) else read.src

in_edges = last_state.parent.in_edges(last_state)
Expand Down Expand Up @@ -257,9 +311,9 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[int,

idom = nx.immediate_dominators(sdfg.nx, sdfg.start_state)
all_doms = cfg.all_dominators(sdfg, idom)
symbol_access_sets: Dict[
Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]
] = pipeline_results[SymbolAccessSets.__name__][sdfg.sdfg_id]
symbol_access_sets: Dict[Union[SDFGState, Edge[InterstateEdge]],
Tuple[Set[str],
Set[str]]] = pipeline_results[SymbolAccessSets.__name__][sdfg.sdfg_id]
state_reach: Dict[SDFGState, Set[SDFGState]] = pipeline_results[StateReachability.__name__][sdfg.sdfg_id]

for read_loc, (reads, _) in symbol_access_sets.items():
Expand Down Expand Up @@ -321,12 +375,14 @@ def should_reapply(self, modified: ppl.Modifies) -> bool:
def depends_on(self):
return {AccessSets, FindAccessNodes, StateReachability}

def _find_dominating_write(
self, desc: str, state: SDFGState, read: Union[nd.AccessNode, InterstateEdge],
access_nodes: Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]],
state_idom: Dict[SDFGState, SDFGState], access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]],
no_self_shadowing: bool = False
) -> Optional[Tuple[SDFGState, nd.AccessNode]]:
def _find_dominating_write(self,
desc: str,
state: SDFGState,
read: Union[nd.AccessNode, InterstateEdge],
access_nodes: Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]],
state_idom: Dict[SDFGState, SDFGState],
access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]],
no_self_shadowing: bool = False) -> Optional[Tuple[SDFGState, nd.AccessNode]]:
if isinstance(read, nd.AccessNode):
# If the read is also a write, it shadows itself.
iedges = state.in_edges(read)
Expand Down Expand Up @@ -408,18 +464,21 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i
for oedge in out_edges:
syms = oedge.data.free_symbols & anames
if desc in syms:
write = self._find_dominating_write(
desc, state, oedge.data, access_nodes, idom, access_sets
)
write = self._find_dominating_write(desc, state, oedge.data, access_nodes, idom,
access_sets)
result[desc][write].add((state, oedge.data))
# Take care of any write nodes that have not been assigned to a scope yet, i.e., writes that are not
# dominating any reads and are thus not part of the results yet.
for state in desc_states_with_nodes:
for write_node in access_nodes[desc][state][1]:
if not (state, write_node) in result[desc]:
write = self._find_dominating_write(
desc, state, write_node, access_nodes, idom, access_sets, no_self_shadowing=True
)
write = self._find_dominating_write(desc,
state,
write_node,
access_nodes,
idom,
access_sets,
no_self_shadowing=True)
result[desc][write].add((state, write_node))

# If any write A is dominated by another write B and any reads in B's scope are also reachable by A,
Expand Down

0 comments on commit a4fa6ed

Please sign in to comment.