Skip to content

Commit

Permalink
Moved names_in_ast from symbolic to astutils
Browse files Browse the repository at this point in the history
  • Loading branch information
BenWeber42 committed Nov 14, 2023
1 parent 659814e commit 8db5e9a
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 19 deletions.
14 changes: 14 additions & 0 deletions dace/frontend/python/astutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,20 @@ def visit_Subscript(self, node: ast.Subscript):
return self.generic_visit(node)


def names_in_ast(tree: ast.AST):
""" Walks an AST and finds all names, excluding function names. """
symbols = []
skip = set()
for node in ast.walk(tree):
if node in skip:
continue
if isinstance(node, ast.Call):
skip.add(node.func)
if isinstance(node, ast.Name):
symbols.append(node.id)
return dtypes.deduplicate(symbols)


class TaskletFreeSymbolVisitor(ast.NodeVisitor):
"""
Simple Python AST visitor to find free symbols in a code, not including
Expand Down
6 changes: 3 additions & 3 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def read_symbols(self) -> Set[str]:
Returns a set of symbols read in this edge (including symbols in the condition and assignment values).
"""
# Symbols in conditions and assignments
result = set(map(str, dace.symbolic.names_in_ast(self.condition.code[0])))
result = set(map(str, astutils.names_in_ast(self.condition.code[0])))
for assign in self.assignments.values():
result |= symbolic.free_symbols_and_functions(assign)

Expand All @@ -266,14 +266,14 @@ def used_symbols(self, all_symbols: bool) -> Set[str]:
# exlcuding keys from being considered "defined" if they have been already read.

# Symbols in conditions are always free, because the condition is executed before the assignments
cond_symbols = set(map(str, dace.symbolic.names_in_ast(self.condition.code[0])))
cond_symbols = set(map(str, astutils.names_in_ast(self.condition.code[0])))
# Symbols in assignment keys are candidate defined symbols
lhs_symbols = set()
# Symbols in assignment values are candidate free symbols
rhs_symbols = set()
for lhs, rhs in self.assignments.items():
# Always add LHS symbols to the set of candidate free symbols
rhs_symbols |= set(map(str, dace.symbolic.names_in_ast(ast.parse(rhs))))
rhs_symbols |= set(map(str, astutils.names_in_ast(ast.parse(rhs))))
# Add the RHS to the set of candidate defined symbols ONLY if it has not been read yet
# This also solves the ordering issue that may arise in cases like the 3rd example above
if lhs not in cond_symbols and lhs not in rhs_symbols:
Expand Down
14 changes: 0 additions & 14 deletions dace/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,20 +455,6 @@ def resolve_symbol_to_constant(symb, start_sdfg):
return None


def names_in_ast(tree: ast.AST):
""" Walks an AST and finds all names, excluding function names. """
symbols = []
skip = set()
for node in ast.walk(tree):
if node in skip:
continue
if isinstance(node, ast.Call):
skip.add(node.func)
if isinstance(node, ast.Name):
symbols.append(node.id)
return dtypes.deduplicate(symbols)


def symbol_name_or_value(val):
""" Returns the symbol name if symbol, otherwise the value as a string. """
if isinstance(val, symbol):
Expand Down
5 changes: 3 additions & 2 deletions dace/transformation/dataflow/prune_connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dace.transformation import transformation as pm, helpers
from dace.sdfg import nodes, utils
from dace.sdfg.analysis import cfg
from dace.frontend.python import astutils


@properties.make_properties
Expand Down Expand Up @@ -158,7 +159,7 @@ def _candidates(nsdfg: nodes.NestedSDFG) -> Set[str]:
local_ignore = None
for e in nsdfg.sdfg.out_edges(nstate):
# Look for symbols in condition
candidates -= (set(map(str, symbolic.names_in_ast(e.data.condition.code[0]))) - ignore)
candidates -= (set(map(str, astutils.names_in_ast(e.data.condition.code[0]))) - ignore)

for assign in e.data.assignments.values():
candidates -= (symbolic.free_symbols_and_functions(assign) - ignore)
Expand Down Expand Up @@ -258,7 +259,7 @@ def _candidates(cls, nsdfg: nodes.NestedSDFG) -> Tuple[Set[str], Set[Tuple[SDFGS

# Any array that is used in interstate edges is removed
for e in nsdfg.sdfg.edges():
candidates -= (set(map(str, symbolic.names_in_ast(e.data.condition.code[0]))))
candidates -= (set(map(str, astutils.names_in_ast(e.data.condition.code[0]))))
for assign in e.data.assignments.values():
candidates -= (symbolic.free_symbols_and_functions(assign))

Expand Down

0 comments on commit 8db5e9a

Please sign in to comment.