Skip to content

Commit

Permalink
Adapt dead code elimination
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Sep 21, 2023
1 parent 2067554 commit d543b2a
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 65 deletions.
8 changes: 8 additions & 0 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ def nodes(self) -> List[SomeNodeT]:
def edges(self) -> List[SomeEdgeT]:
...

@overload
def in_degree(self, node: SomeNodeT) -> int:
...

@overload
def out_degree(self, node: SomeNodeT) -> int:
...

###################################################################
# Traversal methods

Expand Down
4 changes: 3 additions & 1 deletion dace/transformation/pass_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ def apply_pass(
self,
sdfg: SDFG,
pipeline_results: Dict[str, Any],
**kwargs,
) -> Optional[Dict[nodes.EntryNode, Optional[Any]]]:
"""
Applies the pass to the CFGs of the given SDFG by calling ``apply`` on each CFG.
Expand All @@ -335,7 +336,8 @@ def apply_pass(
"""
result = {}
for scope_block in sdfg.all_state_scopes_recursive(recurse_into_sdfgs=False):
retval = self.apply(scope_block, scope_block.sdfg, pipeline_results)
retval = self.apply(scope_block, scope_block if isinstance(scope_block, SDFG) else scope_block.sdfg,
pipeline_results, **kwargs)
if retval is not None:
result[scope_block] = retval

Expand Down
51 changes: 32 additions & 19 deletions dace/transformation/passes/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dace import SDFG, SDFGState, properties, InterstateEdge
from dace.sdfg.graph import Edge
from dace.sdfg import nodes as nd
from dace.sdfg.state import ControlFlowBlock
from dace.sdfg.state import ControlFlowBlock, ScopeBlock
from dace.sdfg.analysis import cfg
from typing import Dict, Set, Tuple, Any, Optional, Union
import networkx as nx
Expand Down Expand Up @@ -37,9 +37,16 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Set[SDFGSta
reachable: Dict[int, Dict[SDFGState, Set[SDFGState]]] = {}
for sdfg in top_sdfg.all_sdfgs_recursive():
reachable[sdfg.sdfg_id] = {}
tc: nx.DiGraph = nx.transitive_closure(sdfg.nx)
for state in sdfg.nodes():
reachable[sdfg.sdfg_id][state] = set(tc.successors(state))
for scope in sdfg.all_state_scopes_recursive(recurse_into_sdfgs=False):
tc: nx.DiGraph = nx.transitive_closure(scope.nx)
for block in scope.nodes():
reach_set = set()
for reached in tc.successors(block):
if isinstance(reached, ScopeBlock):
reach_set.update(reached.all_states_recursive())
else:
reach_set.add(reached)
reachable[sdfg.sdfg_id][block] = reach_set
return reachable


Expand Down Expand Up @@ -84,7 +91,7 @@ def apply_pass(
@properties.make_properties
class AccessSets(ppl.Pass):
"""
Evaluates memory access sets (which arrays/data descriptors are read/written in each state).
Evaluates memory access sets (which arrays/data descriptors are read/written in each control flow block).
"""

CATEGORY: str = 'Analysis'
Expand All @@ -96,26 +103,32 @@ def should_reapply(self, modified: ppl.Modifies) -> bool:
# If anything was modified, reapply
return modified & ppl.Modifies.AccessNodes

def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Tuple[Set[str], Set[str]]]]:
def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]]]:
"""
:return: A dictionary mapping each state to a tuple of its (read, written) data descriptors.
:return: A dictionary mapping each control flow block to a tuple of its (read, written) data descriptors.
"""
top_result: Dict[int, Dict[SDFGState, Tuple[Set[str], Set[str]]]] = {}
top_result: Dict[int, Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]]] = {}
for sdfg in top_sdfg.all_sdfgs_recursive():
result: Dict[SDFGState, Tuple[Set[str], Set[str]]] = {}
for state in sdfg.nodes():
result: Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]] = {}
for state in sdfg.states():
readset, writeset = set(), set()
for anode in state.data_nodes():
if state.in_degree(anode) > 0:
writeset.add(anode.data)
if state.out_degree(anode) > 0:
readset.add(anode.data)

result[state] = (readset, writeset)

for scope in sdfg.all_state_scopes_recursive(recurse_into_sdfgs=False):
readset, writeset = set(), set()
for substate in scope.all_states_recursive():
readset.update(result[substate][0])
writeset.update(result[substate][1])
result[scope] = (readset, writeset)

# Edges that read from arrays add to both ends' access sets
anames = sdfg.arrays.keys()
for e in sdfg.edges():
for e in sdfg.all_interstate_edges_recursive(recurse_into_sdfgs=False):
fsyms = e.data.free_symbols & anames
if fsyms:
result[e.src][0].update(fsyms)
Expand Down Expand Up @@ -148,13 +161,13 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[SDFGState]]]:

for sdfg in top_sdfg.all_sdfgs_recursive():
result: Dict[str, Set[SDFGState]] = defaultdict(set)
for state in sdfg.nodes():
for state in sdfg.states():
for anode in state.data_nodes():
result[anode.data].add(state)

# Edges that read from arrays add to both ends' access sets
anames = sdfg.arrays.keys()
for e in sdfg.edges():
for e in sdfg.all_interstate_edges_recursive(recurse_into_sdfgs=False):
fsyms = e.data.free_symbols & anames
for access in fsyms:
result[access].update({e.src, e.dst})
Expand Down Expand Up @@ -189,7 +202,7 @@ def apply_pass(self, top_sdfg: SDFG,
for sdfg in top_sdfg.all_sdfgs_recursive():
result: Dict[str, Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]]] = defaultdict(
lambda: defaultdict(lambda: [set(), set()]))
for state in sdfg.nodes():
for state in sdfg.states():
for anode in state.data_nodes():
if state.in_degree(anode) > 0:
result[anode.data][state][1].add(anode)
Expand All @@ -200,10 +213,10 @@ def apply_pass(self, top_sdfg: SDFG,


@properties.make_properties
class SymbolWriteScopes(ppl.Pass):
class SymbolWriteScopes(ppl.Pass): # TODO: adapt
"""
For each symbol, create a dictionary mapping each writing interstate edge to that symbol to the set of interstate
edges and states reading that symbol that are dominated by that write.
edges and control flow blocks reading that symbol that are dominated by that write.
"""

CATEGORY: str = 'Analysis'
Expand Down Expand Up @@ -260,7 +273,7 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[int,
idom = nx.immediate_dominators(sdfg.nx, sdfg.start_state)
all_doms = cfg.all_dominators(sdfg, idom)
symbol_access_sets: Dict[
Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]
Union[ControlFlowBlock, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]
] = pipeline_results[SymbolAccessSets.__name__][sdfg.sdfg_id]
state_reach: Dict[SDFGState, Set[SDFGState]] = pipeline_results[StateReachability.__name__][sdfg.sdfg_id]

Expand Down Expand Up @@ -305,7 +318,7 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[int,


@properties.make_properties
class ScalarWriteShadowScopes(ppl.Pass):
class ScalarWriteShadowScopes(ppl.Pass): # TODO: Adapt
"""
For each scalar or array of size 1, create a dictionary mapping writes to that data container to the set of reads
and writes that are dominated by that write.
Expand Down
Loading

0 comments on commit d543b2a

Please sign in to comment.