Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad authored Sep 25, 2023
2 parents ef79124 + 13d1397 commit f29ddb0
Show file tree
Hide file tree
Showing 10 changed files with 374 additions and 265 deletions.
7 changes: 3 additions & 4 deletions dace/frontend/fortran/ast_components.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
from fparser.two.Fortran2008 import Fortran2008 as f08
from fparser.two import Fortran2008
from fparser.two import Fortran2008 as f08
from fparser.two import Fortran2003 as f03
from fparser.two import symbol_table

Expand Down Expand Up @@ -608,7 +607,7 @@ def type_declaration_stmt(self, node: FASTNode):
if i.string.lower() == "parameter":
symbol = True

if isinstance(i, Fortran2008.Attr_Spec_List):
if isinstance(i, f08.Attr_Spec_List):

dimension_spec = get_children(i, "Dimension_Attr_Spec")
if len(dimension_spec) == 0:
Expand Down Expand Up @@ -1052,7 +1051,7 @@ def specification_part(self, node: FASTNode):

decls = [self.create_ast(i) for i in node.children if isinstance(i, f08.Type_Declaration_Stmt)]

uses = [self.create_ast(i) for i in node.children if isinstance(i, f08.Use_Stmt)]
uses = [self.create_ast(i) for i in node.children if isinstance(i, f03.Use_Stmt)]
tmp = [self.create_ast(i) for i in node.children]
typedecls = [i for i in tmp if isinstance(i, ast_internal_classes.Type_Decl_Node)]
symbols = []
Expand Down
14 changes: 14 additions & 0 deletions dace/frontend/python/astutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,3 +705,17 @@ def escape_string(value: Union[bytes, str]):
return value.encode("unicode_escape").decode("utf-8")
# Python 2.x
return value.encode('string_escape')


def parse_function_arguments(node: ast.Call, argnames: List[str]) -> Dict[str, ast.AST]:
    """
    Parses function arguments (both positional and keyword) from a Call node,
    based on the function's argument names. If an argument was not given, it will
    not be in the result.

    :param node: The AST ``Call`` node to parse.
    :param argnames: Names of the callee's positional parameters, in order.
    :return: A dictionary mapping each given argument name to the AST node of
             its value.
    """
    result: Dict[str, ast.AST] = {}
    # Positional arguments are matched to parameter names by position; any
    # excess positional arguments are ignored (zip stops at the shorter input).
    for arg, aname in zip(node.args, argnames):
        result[aname] = arg
    for kw in node.keywords:
        # A keyword entry with ``arg is None`` is a ``**kwargs`` expansion,
        # which has no single parameter name to map to -- skip it rather than
        # storing a ``None`` key in the result.
        if kw.arg is not None:
            result[kw.arg] = kw.value
    return result
3 changes: 2 additions & 1 deletion dace/frontend/python/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,11 @@ class tasklet(metaclass=TaskletMetaclass):
The DaCe framework cannot analyze these tasklets for optimization.
"""

def __init__(self, language: Union[str, dtypes.Language] = dtypes.Language.Python):
def __init__(self, language: Union[str, dtypes.Language] = dtypes.Language.Python, side_effects: bool = False):
if isinstance(language, str):
language = dtypes.Language[language]
self.language = language
self.side_effects = side_effects

def __enter__(self):
if self.language != dtypes.Language.Python:
Expand Down
6 changes: 5 additions & 1 deletion dace/frontend/python/memlet_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,11 @@ def ParseMemlet(visitor,
if len(node.value.args) >= 2:
write_conflict_resolution = node.value.args[1]

subset, new_axes, arrdims = parse_memlet_subset(array, node, das, parsed_slice)
try:
subset, new_axes, arrdims = parse_memlet_subset(array, node, das, parsed_slice)
except IndexError:
raise DaceSyntaxError(visitor, node, 'Failed to parse memlet expression due to dimensionality. '
f'Array dimensions: {array.shape}, expression in code: {astutils.unparse(node)}')

# If undefined, default number of accesses is the slice size
if num_accesses is None:
Expand Down
17 changes: 17 additions & 0 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2510,6 +2510,7 @@ def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None):

# Looking for the first argument in a tasklet annotation: @dace.tasklet(STRING HERE)
langInf = None
side_effects = None
if isinstance(node, ast.FunctionDef) and \
hasattr(node, 'decorator_list') and \
isinstance(node.decorator_list, list) and \
Expand All @@ -2522,6 +2523,19 @@ def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None):
langArg = node.decorator_list[0].args[0].value
langInf = dtypes.Language[langArg]

# Extract arguments from with statement
if isinstance(node, ast.With):
expr = node.items[0].context_expr
if isinstance(expr, ast.Call):
args = astutils.parse_function_arguments(expr, ['language', 'side_effects'])
langArg = args.get('language', None)
side_effects = args.get('side_effects', None)
langInf = astutils.evalnode(langArg, {**self.globals, **self.defined})
if isinstance(langInf, str):
langInf = dtypes.Language[langInf]

side_effects = astutils.evalnode(side_effects, {**self.globals, **self.defined})

ttrans = TaskletTransformer(self,
self.defined,
self.sdfg,
Expand All @@ -2536,6 +2550,9 @@ def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None):
symbols=self.symbols)
node, inputs, outputs, self.accesses = ttrans.parse_tasklet(node, name)

if side_effects is not None:
node.side_effects = side_effects

# Convert memlets to their actual data nodes
for i in inputs.values():
if not isinstance(i, tuple) and i.data in self.scope_vars.keys():
Expand Down
2 changes: 1 addition & 1 deletion dace/frontend/python/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1268,7 +1268,7 @@ def _convert_to_ast(contents: Any):
node)
else:
# Augment closure with new value
newnode = self.resolver.global_value_to_node(e, node, f'inlined_{id(contents)}', True, keep_object=True)
newnode = self.resolver.global_value_to_node(contents, node, f'inlined_{id(contents)}', True, keep_object=True)
return newnode

return _convert_to_ast(contents)
Expand Down
109 changes: 84 additions & 25 deletions dace/transformation/passes/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Set[Tuple[SDFGState, Union[nd.AccessNode, InterstateEdge]]]]]
SymbolScopeDict = Dict[str, Dict[Edge[InterstateEdge], Set[Union[Edge[InterstateEdge], SDFGState]]]]


@properties.make_properties
class StateReachability(ppl.Pass):
"""
Expand All @@ -35,13 +36,68 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Set[SDFGSta
"""
reachable: Dict[int, Dict[SDFGState, Set[SDFGState]]] = {}
for sdfg in top_sdfg.all_sdfgs_recursive():
reachable[sdfg.sdfg_id] = {}
tc: nx.DiGraph = nx.transitive_closure(sdfg.nx)
for state in sdfg.nodes():
reachable[sdfg.sdfg_id][state] = set(tc.successors(state))
result: Dict[SDFGState, Set[SDFGState]] = {}

# In networkx this is currently implemented naively for directed graphs.
# The implementation below is faster
# tc: nx.DiGraph = nx.transitive_closure(sdfg.nx)

for n, v in reachable_nodes(sdfg.nx):
result[n] = set(v)

reachable[sdfg.sdfg_id] = result

return reachable


def _single_shortest_path_length_no_self(adj, source):
"""Yields (node, level) in a breadth first search, without the first level
unless a self-edge exists.
Adapted from Shortest Path Length helper function in NetworkX.
Parameters
----------
adj : dict
Adjacency dict or view
firstlevel : dict
starting nodes, e.g. {source: 1} or {target: 1}
cutoff : int or float
level at which we stop the process
"""
firstlevel = {source: 1}

seen = {} # level (number of hops) when seen in BFS
level = 0 # the current level
nextlevel = set(firstlevel) # set of nodes to check at next level
n = len(adj)
while nextlevel:
thislevel = nextlevel # advance to next level
nextlevel = set() # and start a new set (fringe)
found = []
for v in thislevel:
if v not in seen:
if level == 0 and v is source: # Skip 0-length path to self
found.append(v)
continue
seen[v] = level # set the level of vertex v
found.append(v)
yield (v, level)
if len(seen) == n:
return
for v in found:
nextlevel.update(adj[v])
level += 1
del seen


def reachable_nodes(G):
    """For each node in ``G``, yield a ``(node, reachable)`` pair, where
    ``reachable`` maps every node reachable from ``node`` to its BFS distance
    (the zero-length path from a node to itself is excluded)."""
    adjacency = G.adj
    for source in G:
        levels = dict(_single_shortest_path_length_no_self(adjacency, source))
        yield (source, levels)


@properties.make_properties
class SymbolAccessSets(ppl.Pass):
"""
Expand All @@ -57,9 +113,8 @@ def should_reapply(self, modified: ppl.Modifies) -> bool:
# If anything was modified, reapply
return modified & ppl.Modifies.States | ppl.Modifies.Edges | ppl.Modifies.Symbols | ppl.Modifies.Nodes

def apply_pass(
self, top_sdfg: SDFG, _
) -> Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]]:
def apply_pass(self, top_sdfg: SDFG,
_) -> Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]]:
"""
:return: A dictionary mapping each state to a tuple of its (read, written) data descriptors.
"""
Expand Down Expand Up @@ -216,9 +271,8 @@ def should_reapply(self, modified: ppl.Modifies) -> bool:
def depends_on(self):
return {SymbolAccessSets, StateReachability}

def _find_dominating_write(
self, sym: str, read: Union[SDFGState, Edge[InterstateEdge]], state_idom: Dict[SDFGState, SDFGState]
) -> Optional[Edge[InterstateEdge]]:
def _find_dominating_write(self, sym: str, read: Union[SDFGState, Edge[InterstateEdge]],
state_idom: Dict[SDFGState, SDFGState]) -> Optional[Edge[InterstateEdge]]:
last_state: SDFGState = read if isinstance(read, SDFGState) else read.src

in_edges = last_state.parent.in_edges(last_state)
Expand Down Expand Up @@ -257,9 +311,9 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[int,

idom = nx.immediate_dominators(sdfg.nx, sdfg.start_state)
all_doms = cfg.all_dominators(sdfg, idom)
symbol_access_sets: Dict[
Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]
] = pipeline_results[SymbolAccessSets.__name__][sdfg.sdfg_id]
symbol_access_sets: Dict[Union[SDFGState, Edge[InterstateEdge]],
Tuple[Set[str],
Set[str]]] = pipeline_results[SymbolAccessSets.__name__][sdfg.sdfg_id]
state_reach: Dict[SDFGState, Set[SDFGState]] = pipeline_results[StateReachability.__name__][sdfg.sdfg_id]

for read_loc, (reads, _) in symbol_access_sets.items():
Expand Down Expand Up @@ -321,12 +375,14 @@ def should_reapply(self, modified: ppl.Modifies) -> bool:
def depends_on(self):
return {AccessSets, FindAccessNodes, StateReachability}

def _find_dominating_write(
self, desc: str, state: SDFGState, read: Union[nd.AccessNode, InterstateEdge],
access_nodes: Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]],
state_idom: Dict[SDFGState, SDFGState], access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]],
no_self_shadowing: bool = False
) -> Optional[Tuple[SDFGState, nd.AccessNode]]:
def _find_dominating_write(self,
desc: str,
state: SDFGState,
read: Union[nd.AccessNode, InterstateEdge],
access_nodes: Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]],
state_idom: Dict[SDFGState, SDFGState],
access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]],
no_self_shadowing: bool = False) -> Optional[Tuple[SDFGState, nd.AccessNode]]:
if isinstance(read, nd.AccessNode):
# If the read is also a write, it shadows itself.
iedges = state.in_edges(read)
Expand Down Expand Up @@ -408,18 +464,21 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i
for oedge in out_edges:
syms = oedge.data.free_symbols & anames
if desc in syms:
write = self._find_dominating_write(
desc, state, oedge.data, access_nodes, idom, access_sets
)
write = self._find_dominating_write(desc, state, oedge.data, access_nodes, idom,
access_sets)
result[desc][write].add((state, oedge.data))
# Take care of any write nodes that have not been assigned to a scope yet, i.e., writes that are not
# dominating any reads and are thus not part of the results yet.
for state in desc_states_with_nodes:
for write_node in access_nodes[desc][state][1]:
if not (state, write_node) in result[desc]:
write = self._find_dominating_write(
desc, state, write_node, access_nodes, idom, access_sets, no_self_shadowing=True
)
write = self._find_dominating_write(desc,
state,
write_node,
access_nodes,
idom,
access_sets,
no_self_shadowing=True)
result[desc][write].add((state, write_node))

# If any write A is dominated by another write B and any reads in B's scope are also reachable by A,
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ charset-normalizer==3.1.0
click==8.1.3
dill==0.3.6
Flask==2.3.2
fparser==0.1.2
fparser==0.1.3
idna==3.4
importlib-metadata==6.6.0
itsdangerous==2.1.2
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
include_package_data=True,
install_requires=[
'numpy', 'networkx >= 2.5', 'astunparse', 'sympy<=1.9', 'pyyaml', 'ply', 'websockets', 'requests', 'flask',
'fparser >= 0.1.2', 'aenum >= 3.1', 'dataclasses; python_version < "3.7"', 'dill',
'fparser >= 0.1.3', 'aenum >= 3.1', 'dataclasses; python_version < "3.7"', 'dill',
'pyreadline;platform_system=="Windows"', 'typing-compat; python_version < "3.8"'
] + cmake_requires,
extras_require={
Expand Down
Loading

0 comments on commit f29ddb0

Please sign in to comment.