Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up StateReachability pass for large state machines #1370

Merged
merged 3 commits into from
Sep 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 84 additions & 25 deletions dace/transformation/passes/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Set[Tuple[SDFGState, Union[nd.AccessNode, InterstateEdge]]]]]
SymbolScopeDict = Dict[str, Dict[Edge[InterstateEdge], Set[Union[Edge[InterstateEdge], SDFGState]]]]


@properties.make_properties
class StateReachability(ppl.Pass):
"""
Expand All @@ -35,13 +36,68 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Set[SDFGSta
"""
reachable: Dict[int, Dict[SDFGState, Set[SDFGState]]] = {}
for sdfg in top_sdfg.all_sdfgs_recursive():
reachable[sdfg.sdfg_id] = {}
tc: nx.DiGraph = nx.transitive_closure(sdfg.nx)
for state in sdfg.nodes():
reachable[sdfg.sdfg_id][state] = set(tc.successors(state))
result: Dict[SDFGState, Set[SDFGState]] = {}

# In networkx this is currently implemented naively for directed graphs.
# The implementation below is faster
# tc: nx.DiGraph = nx.transitive_closure(sdfg.nx)

for n, v in reachable_nodes(sdfg.nx):
result[n] = set(v)

reachable[sdfg.sdfg_id] = result

return reachable


def _single_shortest_path_length_no_self(adj, source):
"""Yields (node, level) in a breadth first search, without the first level
unless a self-edge exists.

Adapted from Shortest Path Length helper function in NetworkX.

Parameters
----------
adj : dict
Adjacency dict or view
firstlevel : dict
starting nodes, e.g. {source: 1} or {target: 1}
cutoff : int or float
level at which we stop the process
"""
firstlevel = {source: 1}

seen = {} # level (number of hops) when seen in BFS
level = 0 # the current level
nextlevel = set(firstlevel) # set of nodes to check at next level
n = len(adj)
while nextlevel:
thislevel = nextlevel # advance to next level
nextlevel = set() # and start a new set (fringe)
found = []
for v in thislevel:
if v not in seen:
if level == 0 and v is source: # Skip 0-length path to self
found.append(v)
continue
seen[v] = level # set the level of vertex v
found.append(v)
yield (v, level)
if len(seen) == n:
return
for v in found:
nextlevel.update(adj[v])
level += 1
del seen


def reachable_nodes(G):
"""Computes the reachable nodes in G."""
adj = G.adj
for n in G:
yield (n, dict(_single_shortest_path_length_no_self(adj, n)))


@properties.make_properties
class SymbolAccessSets(ppl.Pass):
"""
Expand All @@ -57,9 +113,8 @@ def should_reapply(self, modified: ppl.Modifies) -> bool:
# If anything was modified, reapply
return modified & ppl.Modifies.States | ppl.Modifies.Edges | ppl.Modifies.Symbols | ppl.Modifies.Nodes

def apply_pass(
self, top_sdfg: SDFG, _
) -> Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]]:
def apply_pass(self, top_sdfg: SDFG,
_) -> Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]]:
"""
:return: A dictionary mapping each state to a tuple of its (read, written) data descriptors.
"""
Expand Down Expand Up @@ -216,9 +271,8 @@ def should_reapply(self, modified: ppl.Modifies) -> bool:
def depends_on(self):
return {SymbolAccessSets, StateReachability}

def _find_dominating_write(
self, sym: str, read: Union[SDFGState, Edge[InterstateEdge]], state_idom: Dict[SDFGState, SDFGState]
) -> Optional[Edge[InterstateEdge]]:
def _find_dominating_write(self, sym: str, read: Union[SDFGState, Edge[InterstateEdge]],
state_idom: Dict[SDFGState, SDFGState]) -> Optional[Edge[InterstateEdge]]:
last_state: SDFGState = read if isinstance(read, SDFGState) else read.src

in_edges = last_state.parent.in_edges(last_state)
Expand Down Expand Up @@ -257,9 +311,9 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[int,

idom = nx.immediate_dominators(sdfg.nx, sdfg.start_state)
all_doms = cfg.all_dominators(sdfg, idom)
symbol_access_sets: Dict[
Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]
] = pipeline_results[SymbolAccessSets.__name__][sdfg.sdfg_id]
symbol_access_sets: Dict[Union[SDFGState, Edge[InterstateEdge]],
Tuple[Set[str],
Set[str]]] = pipeline_results[SymbolAccessSets.__name__][sdfg.sdfg_id]
state_reach: Dict[SDFGState, Set[SDFGState]] = pipeline_results[StateReachability.__name__][sdfg.sdfg_id]

for read_loc, (reads, _) in symbol_access_sets.items():
Expand Down Expand Up @@ -321,12 +375,14 @@ def should_reapply(self, modified: ppl.Modifies) -> bool:
def depends_on(self):
return {AccessSets, FindAccessNodes, StateReachability}

def _find_dominating_write(
self, desc: str, state: SDFGState, read: Union[nd.AccessNode, InterstateEdge],
access_nodes: Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]],
state_idom: Dict[SDFGState, SDFGState], access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]],
no_self_shadowing: bool = False
) -> Optional[Tuple[SDFGState, nd.AccessNode]]:
def _find_dominating_write(self,
desc: str,
state: SDFGState,
read: Union[nd.AccessNode, InterstateEdge],
access_nodes: Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]],
state_idom: Dict[SDFGState, SDFGState],
access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]],
no_self_shadowing: bool = False) -> Optional[Tuple[SDFGState, nd.AccessNode]]:
if isinstance(read, nd.AccessNode):
# If the read is also a write, it shadows itself.
iedges = state.in_edges(read)
Expand Down Expand Up @@ -408,18 +464,21 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i
for oedge in out_edges:
syms = oedge.data.free_symbols & anames
if desc in syms:
write = self._find_dominating_write(
desc, state, oedge.data, access_nodes, idom, access_sets
)
write = self._find_dominating_write(desc, state, oedge.data, access_nodes, idom,
access_sets)
result[desc][write].add((state, oedge.data))
# Take care of any write nodes that have not been assigned to a scope yet, i.e., writes that are not
# dominating any reads and are thus not part of the results yet.
for state in desc_states_with_nodes:
for write_node in access_nodes[desc][state][1]:
if not (state, write_node) in result[desc]:
write = self._find_dominating_write(
desc, state, write_node, access_nodes, idom, access_sets, no_self_shadowing=True
)
write = self._find_dominating_write(desc,
state,
write_node,
access_nodes,
idom,
access_sets,
no_self_shadowing=True)
result[desc][write].add((state, write_node))

# If any write A is dominated by another write B and any reads in B's scope are also reachable by A,
Expand Down
Loading