Skip to content

Commit

Permalink
Reworked code to avoid deprecation warnings and errors.
Browse files Browse the repository at this point in the history
  • Loading branch information
alexnick83 committed Oct 7, 2023
1 parent dcbfd2a commit a8d7431
Showing 1 changed file with 36 additions and 17 deletions.
53 changes: 36 additions & 17 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,29 @@
Shape = Union[ShapeTuple, ShapeList]
DependencyType = Dict[str, Tuple[SDFGState, Union[Memlet, nodes.Tasklet], Tuple[int]]]


if sys.version_info < (3, 8):
_simple_ast_nodes = (ast.Constant, ast.Name, ast.NameConstant, ast.Num)
BytesConstant = ast.Bytes
EllipsisConstant = ast.Ellipsis
NameConstant = ast.NameConstant
NumConstant = ast.Num
StrConstant = ast.Str
else:
_simple_ast_nodes = (ast.Constant, ast.Name)
BytesConstant = ast.Constant
EllipsisConstant = ast.Constant
NameConstant = ast.Constant
NumConstant = ast.Constant
StrConstant = ast.Constant


if sys.version_info < (3, 9):
Index = ast.Index
ExtSlice = ast.ExtSlice
else:
Index = type(None)
ExtSlice = type(None)


class SkipCall(Exception):
Expand Down Expand Up @@ -986,13 +1005,13 @@ def visit_TopLevelExpr(self, node):
raise DaceSyntaxError(self, node, 'Local variable is already a tasklet input or output')
self.outputs[connector] = memlet
return None # Remove from final tasklet code
elif isinstance(node.value, ast.Str):
elif isinstance(node.value, StrConstant):
return self.visit_TopLevelStr(node.value)

return self.generic_visit(node)

# Detect external tasklet code
def visit_TopLevelStr(self, node: ast.Str):
def visit_TopLevelStr(self, node: StrConstant):
if self.extcode != None:
raise DaceSyntaxError(self, node, 'Cannot provide more than one intrinsic implementation ' + 'for tasklet')
self.extcode = node.s
Expand Down Expand Up @@ -1616,7 +1635,7 @@ def _parse_for_indices(self, node: ast.Expr):

return indices

def _parse_value(self, node: Union[ast.Name, ast.Num, ast.Constant]):
def _parse_value(self, node: Union[ast.Name, NumConstant, ast.Constant]):
"""Parses a value
Arguments:
Expand All @@ -1631,7 +1650,7 @@ def _parse_value(self, node: Union[ast.Name, ast.Num, ast.Constant]):

if isinstance(node, ast.Name):
return node.id
elif isinstance(node, ast.Num):
elif sys.version_info < (3.8) and isinstance(node, ast.Num):
return str(node.n)
elif isinstance(node, ast.Constant):
return str(node.value)
Expand All @@ -1651,14 +1670,14 @@ def _parse_slice(self, node: ast.Slice):
return (self._parse_value(node.lower), self._parse_value(node.upper),
self._parse_value(node.step) if node.step is not None else "1")

def _parse_index_as_range(self, node: Union[ast.Index, ast.Tuple]):
def _parse_index_as_range(self, node: Union[Index, ast.Tuple]):
"""
Parses an index as range
:param node: Index node
:return: Range in (from, to, step) format
"""
if isinstance(node, ast.Index):
if sys.version_info < (3.9) and isinstance(node, ast.Index):
val = self._parse_value(node.value)
elif isinstance(node, ast.Tuple):
val = self._parse_value(node.elts)
Expand Down Expand Up @@ -1765,7 +1784,7 @@ def visit_ast_or_value(arg):
iterator = 'dace.map'
else:
ranges = []
if isinstance(node.slice, (ast.Tuple, ast.ExtSlice)):
if isinstance(node.slice, (ast.Tuple, ExtSlice)):
for s in node.slice.dims:
ranges.append(self._parse_slice(s))
elif isinstance(node.slice, ast.Slice):
Expand Down Expand Up @@ -4297,7 +4316,7 @@ def visit_Call(self, node: ast.Call, create_callbacks=False):
func = None
funcname = None
# If the call directly refers to an SDFG or dace-compatible program
if isinstance(node.func, ast.Num):
if sys.version_info < (3, 8) and isinstance(node.func, ast.Num):
if self._has_sdfg(node.func.n):
func = node.func.n
elif isinstance(node.func, ast.Constant):
Expand Down Expand Up @@ -4620,11 +4639,11 @@ def visit_Str(self, node: ast.Str):
# A string constant returns a string literal
return StringLiteral(node.s)

def visit_Bytes(self, node: ast.Bytes):
def visit_Bytes(self, node: BytesConstant):
# A bytes constant returns a string literal
return StringLiteral(node.s)

def visit_Num(self, node: ast.Num):
def visit_Num(self, node: NumConstant):
if isinstance(node.n, bool):
return dace.bool_(node.n)
if isinstance(node.n, (int, float, complex)):
Expand All @@ -4644,7 +4663,7 @@ def visit_Name(self, node: ast.Name):
# If visiting a name, check if it is a defined variable or a global
return self._visitname(node.id, node)

def visit_NameConstant(self, node: ast.NameConstant):
def visit_NameConstant(self, node: NameConstant):
return self.visit_Constant(node)

def visit_Attribute(self, node: ast.Attribute):
Expand Down Expand Up @@ -4919,7 +4938,7 @@ def _promote(node: ast.AST) -> Union[Any, str, symbolic.symbol]:
res = self.visit(s)
else:
res = self._visit_ast_or_value(s)
elif isinstance(s, ast.Index):
elif sys.version_info < (3.9) and isinstance(s, ast.Index):
res = self._parse_subscript_slice(s.value)
elif isinstance(s, ast.Slice):
lower = s.lower
Expand All @@ -4937,7 +4956,7 @@ def _promote(node: ast.AST) -> Union[Any, str, symbolic.symbol]:
res = ((lower, upper, step), )
elif isinstance(s, ast.Tuple):
res = tuple(self._parse_subscript_slice(d, multidim=True) for d in s.elts)
elif isinstance(s, ast.ExtSlice):
elif sys.version_info < (3, 9) and isinstance(s, ast.ExtSlice):
res = tuple(self._parse_subscript_slice(d, multidim=True) for d in s.dims)
else:
res = _promote(s)
Expand Down Expand Up @@ -4999,8 +5018,8 @@ def visit_Subscript(self, node: ast.Subscript, inference: bool = False):
# If the value is a tuple of constants (e.g., array.shape) and the
# slice is constant, return the value itself
nslice = self.visit(node.slice)
if isinstance(nslice, (ast.Index, Number)):
if isinstance(nslice, ast.Index):
if isinstance(nslice, (Index, Number)):
if sys.version_info < (3, 9) and isinstance(nslice, ast.Index):
v = self._parse_value(nslice.value)
else:
v = nslice
Expand Down Expand Up @@ -5064,15 +5083,15 @@ def _visit_ast_or_value(self, node: ast.AST) -> Any:
out = out[0]
return out

def visit_Index(self, node: ast.Index) -> Any:
def visit_Index(self, node: Index) -> Any:
if isinstance(node.value, ast.Tuple):
for i, elt in enumerate(node.value.elts):
node.value.elts[i] = self._visit_ast_or_value(elt)
return node
node.value = self._visit_ast_or_value(node.value)
return node

def visit_ExtSlice(self, node: ast.ExtSlice) -> Any:
def visit_ExtSlice(self, node: ExtSlice) -> Any:
for i, dim in enumerate(node.dims):
node.dims[i] = self._visit_ast_or_value(dim)

Expand Down

0 comments on commit a8d7431

Please sign in to comment.