diff --git a/dace/symbolic.py b/dace/symbolic.py index 9737080c52..98ffa008d3 100644 --- a/dace/symbolic.py +++ b/dace/symbolic.py @@ -1,6 +1,7 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import ast from functools import lru_cache +import sys import sympy import pickle import re @@ -982,6 +983,32 @@ def _process_is(elem: Union[Is, IsNot]): return expr +# Depending on the Python version we need to handle different AST nodes to correctly interpret and detect falsy / truthy +# values. +if sys.version_info < (3, 8): + _SimpleASTNode = (ast.Constant, ast.Name, ast.NameConstant, ast.Num) + _SimpleASTNodeT = Union[ast.Constant, ast.Name, ast.NameConstant, ast.Num] + + def __comp_convert_truthy_falsy(node: _SimpleASTNodeT): + if isinstance(node, ast.Num): + node_val = node.n + elif isinstance(node, ast.Name): + node_val = node.id + else: + node_val = node.value + return ast.copy_location(ast.NameConstant(bool(node_val)), node) +else: + _SimpleASTNode = (ast.Constant, ast.Name) + _SimpleASTNodeT = Union[ast.Constant, ast.Name] + + def __comp_convert_truthy_falsy(node: _SimpleASTNodeT): + return ast.copy_location(ast.Constant(bool(node.value)), node) + +# Convert simple AST node (constant) into a falsy / truthy. Anything other than 0, None, and an empty string '' is +# considered a truthy, while the listed exceptions are considered falsy values - following the semantics of Python's +# bool() builtin. +_convert_truthy_falsy = __comp_convert_truthy_falsy + class PythonOpToSympyConverter(ast.NodeTransformer): """ Replaces various operations with the appropriate SymPy functions to avoid non-symbolic evaluation. @@ -1067,6 +1094,13 @@ def visit_Compare(self, node: ast.Compare): raise NotImplementedError op = node.ops[0] arguments = [node.left, node.comparators[0]] + + # Ensure constant values in boolean comparisons are interpreted als booleans. + if isinstance(node.left, ast.Compare) and isinstance(node.comparators[0], _SimpleASTNode): + arguments[1] = _convert_truthy_falsy(node.comparators[0]) + elif isinstance(node.left, _SimpleASTNode) and isinstance(node.comparators[0], ast.Compare): + arguments[0] = _convert_truthy_falsy(node.left) + func_node = ast.copy_location(ast.Name(id=self._ast_to_sympy_comparators[type(op)], ctx=ast.Load()), node) new_node = ast.Call(func=func_node, args=[self.visit(arg) for arg in arguments], keywords=[]) return ast.copy_location(new_node, node) diff --git a/tests/passes/dead_code_elimination_test.py b/tests/passes/dead_code_elimination_test.py index a41a11c4d6..1832ad8321 100644 --- a/tests/passes/dead_code_elimination_test.py +++ b/tests/passes/dead_code_elimination_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Various tests for dead code elimination passes. """ import numpy as np @@ -45,6 +45,26 @@ def test_dse_unconditional(): assert set(sdfg.states()) == {s, s2, e} +def test_dse_edge_condition_with_integer_as_boolean_regression(): + """ + This is a regression test for issue #1129, which describes dead state elimination incorrectly eliminating interstate + edges when integers are used as boolean values in interstate edge conditions. Code taken from issue #1129. + """ + sdfg = dace.SDFG('dse_edge_condition_with_integer_as_boolean_regression') + sdfg.add_scalar('N', dtype=dace.int32, transient=True) + sdfg.add_scalar('result', dtype=dace.int32) + state_init = sdfg.add_state() + state_middle = sdfg.add_state() + state_end = sdfg.add_state() + sdfg.add_edge(state_init, state_end, dace.InterstateEdge(condition='(not ((N > 20) != 0))', + assignments={'result': 'N'})) + sdfg.add_edge(state_init, state_middle, dace.InterstateEdge(condition='((N > 20) != 0)')) + sdfg.add_edge(state_middle, state_end, dace.InterstateEdge(assignments={'result': '20'})) + + res = DeadStateElimination().apply_pass(sdfg, {}) + assert res is None + + def test_dde_simple(): @dace.program @@ -307,6 +327,7 @@ def test_dce_add_type_hint_of_variable(dtype): if __name__ == '__main__': test_dse_simple() test_dse_unconditional() + test_dse_edge_condition_with_integer_as_boolean_regression() test_dde_simple() test_dde_libnode() test_dde_access_node_in_scope(False)