diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index c9d92b7860..4a089be753 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -1150,6 +1150,9 @@ def __init__(self, # Indirections self.indirections = dict() + # Program variables + self.pvars = dict() # Dict[str, Any] + @classmethod def progress_count(cls) -> int: """ Returns the number of parsed SDFGs so far within this run. """ @@ -3120,6 +3123,12 @@ def visit_AnnAssign(self, node: ast.AnnAssign): self._visit_assign(node, node.target, None, dtype=dtype) def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): + + # NOTE: Assuming (for now) simple assignment with single target (LHS). + # NOTE: This should be enforced by the preprocessor. + # NOTE: There may be issues with implicit swaps (e.g., a, b = b, a). + assert isinstance(node_target, (ast.Name, ast.Subscript, ast.Attribute)) + # Get targets (elts) and results elts = None results = None @@ -3143,6 +3152,12 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): for target, (result, _) in zip(elts, results): + if not isinstance(result, (ast.Name, ast.Subscript, ast.Attribute)): + assert isinstance(target, ast.Name) + assert target.id not in self.pvars + self.pvars[target.id] = result + continue + name = rname(target) true_name = None if name in defined_vars: @@ -4277,8 +4292,20 @@ def visit_Call(self, node: ast.Call, create_callbacks=False): if self._has_sdfg(node.func.value): func = node.func.value + # https://stackoverflow.com/a/2020083 + def fullname(f): + module = f.__module__ + if module == 'builtins': + return f.__qualname__ # avoid outputs like 'builtins.str' + return module + '.' + f.__qualname__ + + if isinstance(node.func, ast.Name) and node.func.id in self.pvars: + funcname = fullname(self.pvars[node.func.id]) + print(funcname) + if func is None: - funcname = rname(node) + if funcname is None: + funcname = rname(node) # Check if the function exists as an SDFG in a different module modname = until(funcname, '.') if ('.' in funcname and len(modname) > 0 and modname in self.globals diff --git a/dace/frontend/python/preprocessing.py b/dace/frontend/python/preprocessing.py index 10a1ab120e..d428318629 100644 --- a/dace/frontend/python/preprocessing.py +++ b/dace/frontend/python/preprocessing.py @@ -20,6 +20,594 @@ from dace.frontend.python.common import (DaceSyntaxError, SDFGConvertible, SDFGClosure, StringLiteral) +_do_not_parse: Tuple[ast.AST] +NamedExpr = ast.AST +if sys.version_info >= (3, 8): + _do_not_parse = (ast.FunctionType, ast.ClassDef) + NamedExpr = ast.NamedExpr +else: + _do_not_parse = (ast.ClassDef,) + + +_only_parse_body = (ast.FunctionDef, ast.AsyncFunctionDef) + + + +class ParentSetter(ast.NodeTransformer): + """ + Sets the ``parent`` attribute of each AST node to its parent. + """ + + def __init__(self, root: ast.AST = None): + self.parent = root + + def visit(self, node): + node.parent = self.parent + self.parent = node + self.generic_visit(node) + self.parent = node.parent + return node + + +def parent_found(parent: ast.AST, child: ast.AST) -> Tuple[bool, str, int]: + result = None + attr = None + index = None + + if not hasattr(parent, 'body'): + result = False + elif child in parent.body: + result = True + attr = 'body' + index = parent.body.index(child) + elif hasattr(parent, 'orelse') and child in parent.orelse: + result = True + attr = 'orelse' + index = parent.orelse.index(child) + elif hasattr(parent, 'finalbody') and child in parent.finalbody: + result = True + attr = 'finalbody' + index = parent.finalbody.index(child) + else: + result = False + + return result, attr, index + + + +def find_parent_body(node: ast.AST) -> Tuple[ast.AST, int]: + """ + Finds the parent AST node that has a ``body`` attribute, and the index of the given node within that body. + :param node: The node to find the parent of. + :return: A tuple of the parent AST node and the index of the given node within its body. + """ + last_parent = node.parent + found, attr, idx = parent_found(last_parent, node) + while not found: + new_parent = last_parent.parent + last_parent, new_parent = new_parent, last_parent + found, attr, idx = parent_found(last_parent, new_parent) + return last_parent, attr, idx + + +class NameGetter(ast.NodeVisitor): + """ + Collects all names in an AST. + """ + + def __init__(self): + self.names = set() + + def visit_Name(self, node: ast.Name): + self.names.add(node.id) + self.generic_visit(node) + + +def find_new_name(names: Set[str], base: str = '__var_') -> str: + """ + Finds a new name that does not exist in the given set of names. + :param names: A set of names to avoid. + :return: A new name that does not exist in the given set of names. + """ + + i = 0 + name = f"{base}{i}" + while name in names: + i += 1 + name = f"{base}{i}" + return name + + +class ExpressionUnnester(ast.NodeTransformer): + """ + unnests expressions in a given AST. + """ + + def __init__(self, names: Set[str] = None): + self.names = names or set() + self.ast_nodes_to_add = [] + + def _new_val(self, old_val: Union[ast.AST, Any]) -> Union[ast.Constant, ast.Name]: + + if old_val is None: + return old_val + + if not isinstance(old_val, ast.AST): + return old_val + + if isinstance(old_val, (ast.Constant, ast.Name)): + return old_val + + if hasattr(old_val, 'ctx') and isinstance(old_val.ctx, ast.Del): + old_val.ctx = ast.Load() + + old_val = self.visit(old_val) + + if isinstance(old_val, (ast.Constant, ast.Name)): + return old_val + + new_val = self._new_name(old_val) + self._new_assign(old_val, new_val.id) + new_val.parent = old_val.parent + + return new_val + + def _new_name(self, old_node: ast.AST) -> ast.Name: + + new_id = find_new_name(self.names) + self.names.add(new_id) + new_name = ast.Name(id=new_id, ctx=ast.Load()) + ast.copy_location(new_name, old_node) + + return new_name + + def _new_assign(self, old_node: ast.AST, new_id: str) -> None: + + val = old_node + if isinstance(old_node, ast.Slice): + val = ast.Call(func=ast.Name(id='slice', ctx=ast.Load()), + args=[old_node.lower, old_node.upper, old_node.step], keywords=[]) + + parent, attr, idx = find_parent_body(old_node.parent) + assign = ast.Assign(targets=[ast.Name(id=new_id, ctx=ast.Store())], value=val) + self.ast_nodes_to_add.append((parent, attr, idx, assign)) + + def visit(self, node: ast.AST) -> ast.AST: + if isinstance(node, _do_not_parse): + return node + if isinstance(node, _only_parse_body): + for stmt in node.body: + self.visit(stmt) + return node + return super().visit(node) + + ##### Statements ##### + + def visit_Return(self, node: ast.Return) -> ast.Return: + + if node.value is not None: + if isinstance(node.value, ast.Tuple): + node.value = self.visit(node.value) + else: + node.value = self._new_val(node.value) + return node + + def visit_Delete(self, node: ast.Delete) -> ast.Delete: + + targets = [] + for target in node.targets: + if isinstance(target, ast.Subscript): + target = self.visit(target) + else: + target = self._new_val(target) + if hasattr(target, 'ctx') and not isinstance(target.ctx, ast.Del): + target.ctx = ast.Del() + targets.append(target) + node.targets = targets + return node + + def visit_Assign(self, node: ast.Assign) -> ast.Assign: + + # TODO: How to handle swaps? + return self.generic_visit(node) + + def visit_AugAssign(self, node: ast.AugAssign) -> ast.AugAssign: + + return self.generic_visit(node) + + def visit_AnnAssign(self, node: ast.AnnAssign) -> ast.AnnAssign: + + return self.generic_visit(node) + + def visit_For(self, node: ast.For) -> ast.For: + + # TODO: Do we want to break down iterator structures to their components? + return self.generic_visit(node) + + def visit_AsyncFor(self, node: ast.AsyncFor) -> ast.AsyncFor: + + return self.generic_visit(node) + + def visit_While(self, node: ast.While) -> ast.While: + + # NOTE: We cannot unnest the test expression because it has to be reevaluated every iteration. + # TODO: Rewrite the test expression as a function call? + # TODO: Unnest the test expression and repeat it at the end of the loop body? + node.body = [self.visit(stmt) for stmt in node.body] + node.orelse = [self.visit(stmt) for stmt in node.orelse] + return node + + def visit_If(self, node: ast.If) -> ast.If: + + node.test = self._new_val(node.test) + return self.generic_visit(node) + + ##### Expressions ##### + + def visit_BoolOp(self, node: ast.BoolOp) -> ast.BoolOp: + + node.values = [self._new_val(val) for val in node.values] + return node + + def visit_NamedExpr(self, node: NamedExpr) -> ast.Name: + + node.target = self._new_val(node.target) + node.value = self._new_val(node.value) + + self._new_assign(node.value, node.target.id) + + new_node = node.target + new_node.ctx = ast.Load() + new_node.parent = node.parent + + return new_node + + def visit_BinOp(self, node: ast.BinOp) -> ast.BinOp: + + node.left = self._new_val(node.left) + node.right = self._new_val(node.right) + + return node + + def visit_UnaryOp(self, node: ast.UnaryOp) -> ast.UnaryOp: + + node.operand = self._new_val(node.operand) + return node + + def visit_Lambda(self, node: ast.Lambda) -> ast.Lambda: + + # NOTE: We cannot unnest the body of a lambda, so we just return the node. + # TODO: We could unnest the body of a lambda to a proper function definition. + + return node + + def visit_IfExp(self, node: ast.IfExp) -> ast.IfExp: + + node.test = self._new_val(node.test) + node.body = self._new_val(node.body) + node.orelse = self._new_val(node.orelse) + return node + + def visit_Dict(self, node: ast.Dict) -> ast.Dict: + + node.keys = [key if isinstance(key, ast.Name) else self._new_val(key) for key in node.keys] + node.values = [val if isinstance(val, ast.Name) else self._new_val(val) for val in node.values] + return node + + def visit_Set(self, node: ast.Set) -> ast.Set: + + node.elts = [elt if isinstance(elt, ast.Name) else self._new_val(elt) for elt in node.elts] + return node + + def visit_ListComp(self, node: ast.ListComp) -> ast.ListComp: + + # NOTE: We cannot unnest a ListComp's elt because it likely depends on the generators + # TODO: Unnest to for-loops or Maps? + return node + + def visit_SetComp(self, node: ast.SetComp) -> ast.SetComp: + + # NOTE: We cannot unnest a SetComp's elt because it likely depends on the generators + # TODO: Unnest to for-loops or Maps? + return node + + def visit_DictComp(self, node: ast.DictComp) -> ast.DictComp: + + # NOTE: We cannot unnest a DictComp's key-value pair because it likely depends on the generators + # TODO: Unnest to for-loops or Maps? + return node + + def visit_GeneratorExp(self, node: ast.GeneratorExp) -> ast.GeneratorExp: + + # NOTE: We cannot unnest a GeneratorExp's elt because it likely depends on the generators + # TODO: Unnest to for-loops or Maps? + return node + + def visit_Await(self, node: ast.Await) -> ast.Await: + + if isinstance(node.value, ast.Call): + return self.generic_visit(node) + node.value = self._new_val(node.value) + return node + + def visit_Yield(self, node: ast.Yield) -> ast.Yield: + + if node.value is not None: + node.value = self._new_val(node.value) + return node + + def visit_YieldFrom(self, node: ast.YieldFrom) -> ast.YieldFrom: + + node.value = self._new_val(node.value) + return node + + def visit_Compare(self, node: ast.Compare) -> ast.Compare: + + node.left = self._new_val(node.left) + node.comparators = [self._new_val(comp) for comp in node.comparators] + return node + + def visit_Call(self, node: ast.Call) -> ast.Call: + + node.func = self._new_val(node.func) + node.args = [self._new_val(arg) for arg in node.args] + for kw in node.keywords: + kw.value = self._new_val(kw.value) + return node + + def visit_FormattedValue(self, node: ast.FormattedValue) -> ast.FormattedValue: + + node.value = self._new_val(node.value) + node.conversion = self._new_val(node.conversion) + if node.format_spec is not None: + node.format_spec = self.visit(node.format_spec) + return node + + def visit_JoinedStr(self, node: ast.JoinedStr) -> ast.JoinedStr: + + # NOTE: JoinedStr's values must be only FormattedValues and Constants + node.values = [self.visit(val) for val in node.values] + return node + + def visit_Constant(self, node: ast.Constant) -> ast.Constant: + + return node + + def visit_Attribute(self, node: ast.Attribute) -> ast.Attribute: + + node.value = self._new_val(node.value) + node.attr = self._new_val(node.attr) + return node + + def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript: + + node.value = self._new_val(node.value) + node.slice = self._new_val(node.slice) + return node + + def visit_Starred(self, node: ast.Starred) -> ast.Starred: + + # TODO: How should this be handled? + return node + + def visit_Name(self, node: ast.Name) -> ast.Name: + + return node + + def visit_List(self, node: ast.List) -> ast.List: + + node.elts = [self._new_val(elt) for elt in node.elts] + return node + + def visit_Tuple(self, node: ast.Tuple) -> ast.Tuple: + + node.elts = [self._new_val(elt) for elt in node.elts] + return node + + def visit_Slice(self, node: ast.Slice) -> ast.Slice: + + node.lower = self._new_val(node.lower) or ast.Constant(value=0) + node.upper = self._new_val(node.upper) + node.step = self._new_val(node.step) or ast.Constant(value=1) + return node + + +class AttributeTransformer(ast.NodeTransformer): + """ + Transforms indirect attribute accesses to SSA. + """ + + def __init__(self, names: Set[str] = None): + self.names = names or set() + self.ast_nodes_to_add = [] + + def visit(self, node: ast.AST) -> ast.AST: + if isinstance(node, _do_not_parse): + return node + if isinstance(node, _only_parse_body): + for stmt in node.body: + self.visit(stmt) + return node + return super().visit(node) + + def visit_Call(self, node: ast.Call) -> ast.AST: + + node = self.generic_visit(node) + + if not isinstance(node.parent, (ast.Assign, ast.AnnAssign, ast.AugAssign)): + + new_id = find_new_name(self.names) + self.names.add(new_id) + + new_node = ast.Name(id=new_id, ctx=ast.Load()) + ast.copy_location(new_node, node) + new_node = ParentSetter(root=node.parent).visit(new_node) + + parent, attr, idx = find_parent_body(node) + assign = ast.Assign(targets=[ast.Name(id=new_id, ctx=ast.Store())], value=node) + self.ast_nodes_to_add.append((parent, attr, idx, assign)) + + else: + new_node = node + + return new_node + + def visit_Attribute(self, node: ast.Attribute) -> ast.AST: + + node = self.generic_visit(node) + + if isinstance(node.value, ast.Name) and isinstance(node.parent, (ast.Assign, ast.AnnAssign, ast.AugAssign)): + return node + + if not isinstance(node.value, ast.Name): + + new_id = find_new_name(self.names) + self.names.add(new_id) + + name_node = ast.Attribute(value=ast.Name(id=new_id, ctx=ast.Load()), attr=node.attr, ctx=node.ctx) + ast.copy_location(name_node, node) + name_node = ParentSetter(root=node.parent).visit(name_node) + + parent, attr, idx = find_parent_body(node) + assign = ast.Assign(targets=[ast.Name(id=new_id, ctx=ast.Store())], value=node.value) + self.ast_nodes_to_add.append((parent, attr, idx, assign)) + + else: + name_node = node + + if not isinstance(node.parent, (ast.Assign, ast.AnnAssign, ast.AugAssign, ast.Call)): + + new_id = find_new_name(self.names) + self.names.add(new_id) + + new_node = ast.Name(id=new_id, ctx=node.ctx) + ast.copy_location(new_node, node) + new_node = ParentSetter(root=node.parent).visit(new_node) + + parent, attr, idx = find_parent_body(node) + assign = ast.Assign(targets=[ast.Name(id=new_id, ctx=ast.Store())], value=name_node) + self.ast_nodes_to_add.append((parent, attr, idx, assign)) + + else: + new_node = name_node + + return new_node + + +class NestedSubsAttrsReplacer(ast.NodeTransformer): + """ + Replaces nested subscript and attribute accesses with temporary variables. + """ + + def __init__(self, names: Set[str] = None): + self.names = names or set() + self.ast_nodes_to_add = [] + + def visit_FunctionDef(self, node: ast.FunctionDef): + for stmt in node.body: + self.visit(stmt) + return node + + def visit_Attribute(self, node: ast.Attribute): + + self.generic_visit(node) + + if isinstance(node.value, ast.Name): + return node + + new_id = find_new_name(self.names) + self.names.add(new_id) + + new_node = ast.Attribute(value=ast.Name(id=new_id, ctx=ast.Load()), attr=node.attr, ctx=node.ctx) + ast.copy_location(new_node, node) + new_node = ParentSetter(root=node.parent).visit(new_node) + + parent, body_idx = find_parent_body(node) + assign = ast.Assign(targets=[ast.Name(id=new_id, ctx=ast.Store())], value=node.value) + self.ast_nodes_to_add.append((parent, body_idx, assign)) + + return new_node + + def visit_Subscript(self, node: ast.Subscript): + + self.generic_visit(node) + + if hasattr(node.slice, 'elts'): + + new_elts = [] + for item in node.slice.elts: + + if not isinstance(item, (ast.Slice, ast.Name, ast.Constant)): + + self.generic_visit(item) + + new_id = find_new_name(self.names) + self.names.add(new_id) + + new_item = ast.Name(id=new_id, ctx=ast.Load()) + ast.copy_location(new_item, item) + new_elts.append(new_item) + + parent, body_idx = find_parent_body(item) + assign = ast.Assign(targets=[ast.Name(id=new_id, ctx=ast.Store())], value=item) + self.ast_nodes_to_add.append((parent, body_idx, assign)) + + elif isinstance(item, ast.Slice): + + self.generic_visit(item) + + for attr in ['lower', 'upper', 'step']: + if hasattr(item, attr): + + old_attr = getattr(item, attr) + if old_attr is None or isinstance(old_attr, (ast.Name, ast.Constant)): + continue + + new_id = find_new_name(self.names) + self.names.add(new_id) + + new_attr = ast.Name(id=new_id, ctx=ast.Load()) + ast.copy_location(new_attr, item) + setattr(item, attr, new_attr) + + parent, body_idx = find_parent_body(item) + assign = ast.Assign(targets=[ast.Name(id=new_id, ctx=ast.Store())], value=old_attr) + self.ast_nodes_to_add.append((parent, body_idx, assign)) + + new_elts.append(item) + + else: + new_elts.append(item) + + new_node = ast.Subscript(value=node.value, slice=ast.Tuple(elts=new_elts, ctx=ast.Load()), ctx=node.ctx) + + else: + + new_node = node.slice + + if not isinstance(node.slice, (ast.Slice, ast.Name, ast.Constant)): + + self.generic_visit(node.slice) + + new_id = find_new_name(self.names) + self.names.add(new_id) + + new_node = ast.Name(id=new_id, ctx=ast.Load()) + ast.copy_location(new_node, node.slice) + + parent, body_idx = find_parent_body(node.slice) + assign = ast.Assign(targets=[ast.Name(id=new_id, ctx=ast.Store())], value=node.slice) + self.ast_nodes_to_add.append((parent, body_idx, assign)) + + new_node = ast.Subscript(value=node.value, slice=new_node, ctx=node.ctx) + + ast.copy_location(new_node, node) + new_node = ParentSetter(root=node.parent).visit(new_node) + + return new_node + + + class DaceRecursionError(Exception): """ Exception that indicates a recursion in a data-centric parsed context. @@ -1631,6 +2219,18 @@ def preprocess_dace_program(f: Callable[..., Any], raise TypeError(f'Converting function "{f.__name__}" ({src_file}:{src_line}) to callback due to disallowed ' f'keyword: {disallowed}') + name_getter = NameGetter() + name_getter.visit(src_ast) + program_names = name_getter.names + src_ast = ParentSetter().visit(src_ast) + unnester = ExpressionUnnester(names=program_names) + src_ast = unnester.visit(src_ast) + for parent, attr, idx, node in reversed(unnester.ast_nodes_to_add): + getattr(parent, attr).insert(idx, node) + ast.fix_missing_locations(src_ast) + + print(astutils.unparse(src_ast)) + passes = int(Config.get('frontend', 'preprocessing_passes')) if passes >= 0: gen = range(passes) diff --git a/tests/python_frontend/preprocessing/expression_unnester_test.py b/tests/python_frontend/preprocessing/expression_unnester_test.py new file mode 100644 index 0000000000..14552cfe8e --- /dev/null +++ b/tests/python_frontend/preprocessing/expression_unnester_test.py @@ -0,0 +1,724 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" Tests the preprocessing functionality of unnesting expressions. """ + +import ast +import inspect +import itertools +import numpy as np +import warnings + +from collections.abc import Callable +from dace.frontend.python.astutils import _remove_outer_indentation +from dace.frontend.python import preprocessing as pr +from dataclasses import dataclass, field +from numpy import typing as npt +from typing import Any, Dict, List, Tuple + + +##### Helper functions ##### + + +def _unnest(func: Callable[..., Any], expected_var_num: int, context: Dict[str, Any] = None) -> Callable[..., Any]: + + function_ast = ast.parse(_remove_outer_indentation(inspect.getsource(func))) + + name_getter = pr.NameGetter() + name_getter.visit(function_ast) + program_names = name_getter.names + + function_ast = pr.ParentSetter().visit(function_ast) + unnester = pr.ExpressionUnnester(names=program_names) + function_ast = unnester.visit(function_ast) + for parent, attr, idx, node in reversed(unnester.ast_nodes_to_add): + getattr(parent, attr).insert(idx, node) + + ast.fix_missing_locations(function_ast) + print(ast.unparse(function_ast)) + + _validate_unnesting(function_ast, expected_var_num) + + code = compile(function_ast, filename='', mode='exec') + context = context or {} + namespace = {**globals(), **context} + exec(code, namespace) + unnested_function = namespace[func.__name__] + + return unnested_function + + +def _validate_unnesting(unnested_ast: ast.AST, expected_var_num: int) -> None: + + name_getter = pr.NameGetter() + name_getter.visit(unnested_ast) + program_names = name_getter.names + + for i in range(expected_var_num): + name = f'__var_{i}' + assert name in program_names + assert f'__var_{expected_var_num}' not in program_names + + +##### Tests for Statements ##### + + +def test_Return(): + + def original_function(a: int, b: int) -> int: + return a + b + + new_function = _unnest(original_function, 1) + + rng = np.random.default_rng(42) + randints = rng.integers(-99, 100, size=(10, 2)) + for (a, b) in randints: + assert original_function(a, b) == new_function(a, b) + + +def test_Delete(): + + def original_function(a: Dict[int, Any], b: int) -> int: + for i in range(b): + del a[b - i] + + + new_function = _unnest(original_function, 1) + + rng = np.random.default_rng(42) + randints = rng.integers(10, 100, size=(10,)) + for s in randints: + ref = {i: i for i in range(s)} + val = {i: i for i in range(s)} + original_function(ref, s - 1) + new_function(val, s - 1) + assert ref == val + + +def test_Assign(): + + def original_function(a: int, b: int) -> int: + c, d = a + b, a - b + return c, d + + new_function = _unnest(original_function, 2) + + rng = np.random.default_rng(42) + randints = rng.integers(-99, 100, size=(10, 2)) + for (a, b) in randints: + assert original_function(a, b) == new_function(a, b) + + +def test_AugAssign(): + + def original_function(a: int, b: int) -> int: + a += b + return a + + new_function = _unnest(original_function, 0) + + rng = np.random.default_rng(42) + randints = rng.integers(-99, 100, size=(10, 2)) + for (a, b) in randints: + assert original_function(a, b) == new_function(a, b) + + +def test_AnnAssign(): + + def original_function(a: int, b: int) -> int: + c: int = a + b + return c + + new_function = _unnest(original_function, 0) + + rng = np.random.default_rng(42) + randints = rng.integers(-99, 100, size=(10, 2)) + for (a, b) in randints: + assert original_function(a, b) == new_function(a, b) + + +def test_For(): + + def original_function(a: int, b: int) -> int: + for i in range(b): + a += i + return a + + new_function = _unnest(original_function, 0) + + rng = np.random.default_rng(42) + randints = rng.integers(-99, 100, size=(10, 2)) + for (a, b) in randints: + assert original_function(a, b) == new_function(a, b) + + +def test_While(): + + def original_function(a: int, b: int) -> int: + while min(a, b) < b: + a += 1 + return a + + new_function = _unnest(original_function, 0) + + rng = np.random.default_rng(42) + randints = rng.integers(-99, 100, size=(10, 2)) + for (a, b) in randints: + assert original_function(a, b) == new_function(a, b) + + +def test_If(): + + def original_function(a: int, b: int) -> int: + if a < b: + a += 1 + return a + + new_function = _unnest(original_function, 1) + + rng = np.random.default_rng(42) + randints = rng.integers(-99, 100, size=(10, 2)) + for (a, b) in randints: + assert original_function(a, b) == new_function(a, b) + + +##### Tests for Expressions ##### + + +def test_BoolOp(): + + def original_function(a: bool, b: bool, c: bool, d: bool, e: bool) -> bool: + e = (a and (b or False)) or (c and (d or e)) + return e + + new_function = _unnest(original_function, 4) + + for (a, b, c, d, e) in itertools.permutations([True, False], 5): + assert original_function(a, b, c, d, e) == new_function(a, b, c, d, e) + + +def test_NamedExpr(): + + def original_function(a: int) -> int: + y = (x := a + 7) + x + return y + + new_function = _unnest(original_function, 1) + + rng = np.random.default_rng(42) + randints = rng.integers(-99, 100, size=(10,)) + for a in randints: + assert original_function(a) == new_function(a) + + +def test_BinOp(): + + def original_function(a: int, b: int) -> int: + c = ((a + b) * (a - b)) ** 2 + return c + + new_function = _unnest(original_function, 3) + + rng = np.random.default_rng(42) + randints = rng.integers(-99, 100, size=(10, 2)) + for (a, b) in randints: + assert original_function(a, b) == new_function(a, b) + + +def test_UnaryOp(): + + def original_function(a: int, b: int) -> int: + c = - (a + b) + return b + + new_function = _unnest(original_function, 1) + + rng = np.random.default_rng(42) + randints = rng.integers(-99, 100, size=(10, 2)) + for (a, b) in randints: + assert original_function(a, b) == new_function(a, b) + + +def test_Lambda(): + + def original_function(a: int, b: int) -> int: + f = lambda x: x + a * 2 + return f(b) + + new_function = _unnest(original_function, 1) + + rng = np.random.default_rng(42) + randints = rng.integers(-99, 100, size=(10, 2)) + for (a, b) in randints: + assert original_function(a, b) == new_function(a, b) + + +def test_IfExp(): + + def original_function(a: int, b: int) -> int: + c = a - b if a > b else b - a + return c + + new_function = _unnest(original_function, 3) + + rng = np.random.default_rng(42) + randints = rng.integers(-99, 100, size=(10, 2)) + for (a, b) in randints: + assert original_function(a, b) == new_function(a, b) + + +def test_Dict(): + + def original_function(a: int, b: int) -> int: + c = {a + b: a - b, a - b: a + b} + return c + + new_function = _unnest(original_function, 4) + + rng = np.random.default_rng(42) + randints = rng.integers(-99, 100, size=(10, 2)) + for (a, b) in randints: + assert original_function(a, b) == new_function(a, b) + + +def test_Set(): + + def original_function(a: int, b: int) -> int: + c = {a + b, a - b} + return c + + new_function = _unnest(original_function, 2) + + rng = np.random.default_rng(42) + randints = rng.integers(-99, 100, size=(10, 2)) + for (a, b) in randints: + assert original_function(a, b) == new_function(a, b) + + +def test_Await(): + + import asyncio + + async def original_function(a: int) -> int: + await asyncio.sleep(0.001 * a * 2) + return a + + new_function = _unnest(original_function, 3, locals()) + + rng = np.random.default_rng(42) + randints = rng.integers(1, 10, size=(10,)) + for a in randints: + assert asyncio.run(original_function(a)) == asyncio.run(new_function(a)) + + +def test_Yield(): + + def original_function(a: int) -> int: + yield a + 3 + yield a + 2 + yield a + 1 + + new_function = _unnest(original_function, 3) + + rng = np.random.default_rng(42) + randints = rng.integers(-99, 100, size=(10,)) + for a in randints: + assert list(original_function(a)) == list(new_function(a)) + + +def test_YieldFrom(): + + def x(n: int) -> int: + for i in range(n): + yield i + + def y(n: int) -> int: + for i in reversed(range(n)): + yield i + + def original_function(n: int) -> int: + yield from itertools.chain(x(n), y(n)) + + new_function = _unnest(original_function, 4, locals()) + + rng = np.random.default_rng(42) + randints = rng.integers(1, 10, size=(10,)) + for n in randints: + assert list(original_function(n)) == list(new_function(n)) + + +def test_Compare(): + + def original_function(a: int, b: int) -> int: + c = a + b > a > a - b > b + return c + + new_function = _unnest(original_function, 2) + + rng = np.random.default_rng(42) + randints = rng.integers(-99, 100, size=(10, 2)) + for (a, b) in randints: + assert original_function(a, b) == new_function(a, b) + + +def test_Call(): + + def x(n: int) -> int: + return 2 * n + + def original_function(a: int, b: int) -> int: + c = x(a + b) * x(a - b) + return c + + new_function = _unnest(original_function, 4, locals()) + + rng = np.random.default_rng(42) + randints = rng.integers(-99, 100, size=(10, 2)) + for (a, b) in randints: + assert original_function(a, b) == new_function(a, b) + + +def test_JoinedStr(): + # NOTE: Also tests FormattedValue + + def original_function(a: int) -> str: + string = f'"sin({a}) is {np.sin(a):.3}"' + return string + + print(original_function(5)) + + new_function = _unnest(original_function, 2, locals()) + + rng = np.random.default_rng(42) + randints = rng.integers(-99, 100, size=(10,)) + for a in randints: + assert original_function(a) == new_function(a) + + +def test_Attribute(): + + def original_function(a: npt.NDArray[np.int32], b: int) -> int: + c = (a + b).T.size + return c + + new_function = _unnest(original_function, 2, locals()) + + rng = np.random.default_rng(42) + randints = rng.integers(-99, 100, size=(10, 2)) + for s, b in randints: + a = np.arange(s, dtype=np.int32) + assert original_function(a, b) == new_function(a, b) + + +def test_Attribute_1(): + + @dataclass + class MyClass: + a: npt.NDArray[np.int32] + b: npt.NDArray[np.int32] = field(default=None, init=False) + + def original_function(arr: MyClass, b: int) -> int: + arr.a += b + arr.b = MyClass(arr.a) + arr.b.b = MyClass(arr.a) + + new_function = _unnest(original_function, 3, locals()) + + a_ref = MyClass(np.arange(10, dtype=np.int32)) + a_val = MyClass(np.arange(10, dtype=np.int32)) + + original_function(a_ref, 1) + new_function(a_val, 1) + + assert np.array_equal(a_ref.a, a_val.a) + assert np.array_equal(a_ref.b.a, a_val.b.a) + assert np.array_equal(a_ref.b.b.a, a_val.b.b.a) + assert a_val.b.b.b is None + + +def test_Subscript(): + + def original_function(a: npt.NDArray[np.int32], b: int) -> int: + c = (a + b)[a[b - 1]] + return c + + new_function = _unnest(original_function, 3, locals()) + + rng = np.random.default_rng(42) + randints = rng.integers(1, 100, size=(10,)) + for b in randints: + a = np.arange(b, dtype=np.int32) + assert original_function(a, b) == new_function(a, b) + + +def test_List(): + + def original_function(a: int, b: int) -> List[int]: + c = [a + b, a - b] + return c + + new_function = _unnest(original_function, 2) + + rng = np.random.default_rng(42) + randints = rng.integers(-99, 100, size=(10, 2)) + for (a, b) in randints: + assert original_function(a, b) == new_function(a, b) + + +def test_Tuple(): + + def original_function(a: int, b: int) -> Tuple[int, int]: + c = (a + b, a - b) + return c + + new_function = _unnest(original_function, 2) + + rng = np.random.default_rng(42) + randints = rng.integers(-99, 100, size=(10, 2)) + for (a, b) in randints: + assert original_function(a, b) == new_function(a, b) + + +def test_Slice(): + + def original_function(a: npt.NDArray[np.int32], b: npt.NDArray[np.int32], c: npt.NDArray[np.int32]) -> int: + c = a[b[c[0]:c[1]]] + return c + + new_function = _unnest(original_function, 4, locals()) + + rng = np.random.default_rng(42) + randints = rng.integers(1, 100, size=(10,)) + for s in randints: + a = np.arange(s, dtype=np.int32) + b = np.arange(s, dtype=np.int32) + n = rng.integers(0, s) + c = np.array([n, min(s-1, n + 2)], dtype=np.int32) + assert np.array_equal(original_function(a, b, c), new_function(a, b, c)) + + +def test_Slice_1(): + + def original_function(a: npt.NDArray[np.int32], b: npt.NDArray[np.int32], c: npt.NDArray[np.int32]): + a[b[c[0]:c[1]]] = 1000 + + new_function = _unnest(original_function, 4, locals()) + + rng = np.random.default_rng(42) + randints = rng.integers(1, 100, size=(10,)) + for s in randints: + a_ref = np.arange(s, dtype=np.int32) + a_val = a_ref.copy() + b = np.arange(s, dtype=np.int32) + n = rng.integers(0, s) + c = np.array([n, min(s-1, n + 2)], dtype=np.int32) + original_function(a_ref, b, c) + new_function(a_val, b, c) + assert np.array_equal(a_ref, a_val) + + +##### Mixed tests ##### + + +def test_mixed(): + + try: + from scipy import sparse + except ImportError: + warnings.warn('Skipping mixed test, scipy not installed') + return + + def original_function(A: sparse.csr_matrix, B: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: + C = np.zeros((A.shape[0], B.shape[1]), dtype=A.dtype) + for i, j in itertools.product(range(A.shape[0]), range(B.shape[1])): + for k in range(A.indptr[i], A.indptr[i + 1]): + C[i, j] += A.data[k] * B[A.indices[k], j] + return C + + new_function = _unnest(original_function, 26, locals()) + + rng = np.random.default_rng(42) + for _ in range(10): + A = sparse.random(20, 10, density=0.1, format='csr', dtype=np.float32, random_state=rng) + B = rng.random((10, 5), dtype=np.float32) + assert np.allclose(original_function(A, B), new_function(A, B)) + + +def test_mixed_1(): + + try: + from scipy import sparse + except ImportError: + warnings.warn('Skipping mixed test, scipy not installed') + return + + def original_function(A: List[sparse.csr_matrix], B: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: + C = np.zeros((len(A), A[0].shape[0], B.shape[-1]), dtype=np.float32) + for l, i, j in itertools.product(range(len(A)), range(A[0].shape[0]), range(B.shape[-1])): + for k in range(A[l].indptr[i], A[l].indptr[i + 1]): + C[l, i, j] += A[l].data[k] * B[l, A[l].indices[k], j] + return C + + new_function = _unnest(original_function, 37, locals()) + + rng = np.random.default_rng(42) + for _ in range(10): + A = [sparse.random(20, 10, density=0.1, format='csr', dtype=np.float32, random_state=rng) for _ in range(5)] + B = rng.random((5, 10, 5), dtype=np.float32) + assert np.allclose(original_function(A, B), new_function(A, B)) + + +def test_mixed_2(): + + try: + from scipy import sparse + except ImportError: + warnings.warn('Skipping mixed test, scipy not installed') + return + + def original_function(A: sparse.csr_matrix, + B: npt.NDArray[np.float32], + C: npt.NDArray[np.float32]) -> sparse.csr_matrix: + D = sparse.csr_matrix((np.zeros(A.nnz, dtype=A.dtype), A.indices, A.indptr), shape=A.shape) + for i in range(A.shape[0]): + for j in range(A.indptr[i], A.indptr[i + 1]): + for k in range(B.shape[1]): + D.data[j] += A.data[j] * B[i, k] * C[k, A.indices[j]] + return D + + new_function = _unnest(original_function, 28, locals()) + + rng = np.random.default_rng(42) + for _ in range(10): + A = sparse.random(20, 10, density=0.1, format='csr', dtype=np.float32, random_state=rng) + B = rng.random((20, 5), dtype=np.float32) + C = rng.random((5, 10), dtype=np.float32) + assert np.allclose(original_function(A, B, C).todense(), new_function(A, B, C).todense()) + + +def test_mixed_3(): + + def match(b1: int, b2: int) -> int: + if b1 + b2 == 3: + return 1 + else: + return 0 + + def original_function(N: int, seq: npt.NDArray[np.int32]) -> npt.NDArray[np.int32]: + + table = np.zeros((N, N), np.int32) + + for i in range(N - 1, -1, -1): + for j in range(i + 1, N): + if j - 1 >= 0: + table[i, j] = max(table[i, j], table[i, j - 1]) + if i + 1 < N: + table[i, j] = max(table[i, j], table[i + 1, j]) + if j - 1 >= 0 and i + 1 < N: + if i < j - 1: + table[i, + j] = max(table[i, j], + table[i + 1, j - 1] + match(seq[i], seq[j])) + else: + table[i, j] = max(table[i, j], table[i + 1, j - 1]) + for k in range(i + 1, j): + table[i, j] = max(table[i, j], table[i, k] + table[k + 1, j]) + + return table + + new_function = _unnest(original_function, 58, locals()) + + rng = np.random.default_rng(42) + for _ in range(10): + N = rng.integers(10, 20) + seq = rng.integers(0, 4, N) + assert np.allclose(original_function(N, seq), new_function(N, seq)) + + +def test_mixed_4(): + + # ----------------------------------------------------------------------------- + # From Numpy to Python + # Copyright (2017) Nicolas P. Rougier - BSD license + # More information at https://github.com/rougier/numpy-book + # ----------------------------------------------------------------------------- + + def original_function(xmin, xmax, ymin, ymax, xn, yn, itermax, horizon=2.0): + # Adapted from https://thesamovar.wordpress.com/2009/03/22/fast-fractals-with-python-and-numpy/ + Xi, Yi = np.mgrid[0:xn, 0:yn] + X = np.linspace(xmin, xmax, xn, dtype=np.float64)[Xi] + Y = np.linspace(ymin, ymax, yn, dtype=np.float64)[Yi] + C = X + Y * 1j + N_ = np.zeros(C.shape, dtype=np.int64) + Z_ = np.zeros(C.shape, dtype=np.complex128) + Xi.shape = Yi.shape = C.shape = xn * yn + + Z = np.zeros(C.shape, np.complex128) + for i in range(itermax): + if not len(Z): + break + + # Compute for relevant points only + np.multiply(Z, Z, Z) + np.add(Z, C, Z) + + # Failed convergence + I = abs(Z) > horizon + N_[Xi[I], Yi[I]] = i + 1 + Z_[Xi[I], Yi[I]] = Z[I] + + # Keep going with those who have not diverged yet + np.logical_not(I, I) # np.negative(I, I) not working any longer + Z = Z[I] + Xi, Yi = Xi[I], Yi[I] + C = C[I] + return Z_.T, N_.T + + new_function = _unnest(original_function, 36, locals()) + + rng = np.random.default_rng(42) + for _ in range(10): + xmin = rng.random() + xmax = xmin + rng.random() + ymin = rng.random() + ymax = ymin + rng.random() + xn = rng.integers(10, 20) + yn = rng.integers(10, 20) + itermax = rng.integers(10, 20) + assert np.allclose(original_function(xmin, xmax, ymin, ymax, xn, yn, itermax)[0], + new_function(xmin, xmax, ymin, ymax, xn, yn, itermax)[0]) + + +if __name__ == '__main__': + test_Return() + test_Delete() + test_Assign() + test_AugAssign() + test_AnnAssign() + test_For() + test_While() + test_If() + test_BoolOp() + test_NamedExpr() + test_BinOp() + test_UnaryOp() + test_Lambda() + test_IfExp() + test_Dict() + test_Set() + test_Await() + test_Yield() + test_YieldFrom() + test_Compare() + test_Call() + test_JoinedStr() + test_Attribute() + test_Attribute_1() + test_Subscript() + test_List() + test_Tuple() + test_Slice() + test_Slice_1() + test_mixed() + test_mixed_1() + test_mixed_2() + test_mixed_3() + test_mixed_4() diff --git a/tests/python_frontend/preprocessing/nested_subscripts_attributes_test.py b/tests/python_frontend/preprocessing/nested_subscripts_attributes_test.py new file mode 100644 index 0000000000..f59a65958b --- /dev/null +++ b/tests/python_frontend/preprocessing/nested_subscripts_attributes_test.py @@ -0,0 +1,173 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" Tests the preprocessing functionality of unnesting nested subscripts and attributes. """ + +import ast +import inspect +import numpy as np + +from dace.frontend.python.astutils import _remove_outer_indentation +from dace.frontend.python import preprocessing as pr +from numpy import typing as npt + + +def test_attribute_on_op(): + + def original_function(A: npt.NDArray[np.int32], B: npt.NDArray[np.int32]) -> npt.NDArray[np.int32]: + return (A + B).T + + function_ast = ast.parse(_remove_outer_indentation(inspect.getsource(original_function))) + + name_getter = pr.NameGetter() + name_getter.visit(function_ast) + program_names = name_getter.names + + function_ast = pr.ParentSetter().visit(function_ast) + subatrr_replacer = pr.AttributeTransformer(names=program_names) + function_ast = subatrr_replacer.visit(function_ast) + for parent, attr, idx, node in reversed(subatrr_replacer.ast_nodes_to_add): + getattr(parent, attr).insert(idx, node) + + ast.fix_missing_locations(function_ast) + + name_getter_2 = pr.NameGetter() + name_getter_2.visit(function_ast) + program_names_2 = name_getter_2.names + + for i in range(2): + name = f'__var_{i}' + assert name in program_names_2 + + code = compile(function_ast, filename='', mode='exec') + print(ast.unparse(function_ast)) + + +def test_attribute_call(): + + def original_function(A: npt.NDArray[np.int32]) -> npt.NDArray[np.int32]: + return A.sum(axis=0) + + function_ast = ast.parse(_remove_outer_indentation(inspect.getsource(original_function))) + + name_getter = pr.NameGetter() + name_getter.visit(function_ast) + program_names = name_getter.names + + function_ast = pr.ParentSetter().visit(function_ast) + subatrr_replacer = pr.AttributeTransformer(names=program_names) + function_ast = subatrr_replacer.visit(function_ast) + for parent, attr, idx, node in reversed(subatrr_replacer.ast_nodes_to_add): + getattr(parent, attr).insert(idx, node) + + ast.fix_missing_locations(function_ast) + + name_getter_2 = pr.NameGetter() + name_getter_2.visit(function_ast) + program_names_2 = name_getter_2.names + + for i in range(1): + name = f'__var_{i}' + assert name in program_names_2 + + code = compile(function_ast, filename='', mode='exec') + print(ast.unparse(function_ast)) + + +def test_attribute_call_on_op(): + + def original_function(A: npt.NDArray[np.int32], B: npt.NDArray[np.int32]) -> npt.NDArray[np.int32]: + return (A + B).sum(axis=0) + + function_ast = ast.parse(_remove_outer_indentation(inspect.getsource(original_function))) + + name_getter = pr.NameGetter() + name_getter.visit(function_ast) + program_names = name_getter.names + + function_ast = pr.ParentSetter().visit(function_ast) + subatrr_replacer = pr.AttributeTransformer(names=program_names) + function_ast = subatrr_replacer.visit(function_ast) + for parent, attr, idx, node in reversed(subatrr_replacer.ast_nodes_to_add): + getattr(parent, attr).insert(idx, node) + + ast.fix_missing_locations(function_ast) + + name_getter_2 = pr.NameGetter() + name_getter_2.visit(function_ast) + program_names_2 = name_getter_2.names + + for i in range(1): + name = f'__var_{i}' + assert name in program_names_2 + + code = compile(function_ast, filename='', mode='exec') + print(ast.unparse(function_ast)) + + +def original_function(A: npt.NDArray[np.int32], + i0: npt.NDArray[np.int32], + i1: npt.NDArray[np.int32], + i2: npt.NDArray[np.int32], + i3: npt.NDArray[np.int32], + i4: npt.NDArray[np.int32]) -> npt.NDArray[np.int32]: + + B = np.zeros_like(A) + + for i in range(A.shape[0]): + for j in range(A.shape[1]): + for k in range(A.shape[2]): + for l in range(A.shape[3]): + B[i, :max(i2[j], i3[j]) - min(i2[j], i3[j]), k, l] = ( + A[i0[i1[i]], min(i2[j], i3[j]) : max(i2[j], i3[j]), k, i4[l]]) + + return (A + B).sum(axis=0) + + +def test_0(): + + function_ast = ast.parse(inspect.getsource(original_function)) + + name_getter = pr.NameGetter() + name_getter.visit(function_ast) + program_names = name_getter.names + + function_ast = pr.ParentSetter().visit(function_ast) + subatrr_replacer = pr.NestedSubsAttrsReplacer(names=program_names) + function_ast = subatrr_replacer.visit(function_ast) + for parent, idx, node in reversed(subatrr_replacer.ast_nodes_to_add): + parent.body.insert(idx, node) + + ast.fix_missing_locations(function_ast) + + name_getter_2 = pr.NameGetter() + name_getter_2.visit(function_ast) + program_names_2 = name_getter_2.names + + for i in range(7): + name = f'__var_{i}' + assert name in program_names_2 + + code = compile(function_ast, filename='', mode='exec') + namespace = {**globals()} + exec(code, namespace) + new_function = namespace['original_function'] + + rng = np.random.default_rng(42) + + A = rng.integers(0, 100, size=(5, 5, 5, 5), dtype=np.int32) + i0 = rng.integers(0, 5, size=(5,), dtype=np.int32) + i1 = rng.integers(0, 5, size=(5,), dtype=np.int32) + i2 = rng.integers(0, 5, size=(5,), dtype=np.int32) + i3 = rng.integers(0, 5, size=(5,), dtype=np.int32) + i4 = rng.integers(0, 5, size=(5,), dtype=np.int32) + + ref = original_function(A, i0, i1, i2, i3, i4) + val = new_function(A, i0, i1, i2, i3, i4) + + assert np.allclose(ref, val) + + +if __name__ == '__main__': + # test_0() + test_attribute_on_op() + test_attribute_call() + test_attribute_call_on_op()