Skip to content

Commit

Permalink
Fix Subscript literal evaluation for List (#1570)
Browse files Browse the repository at this point in the history
Looking at: #1568

The code was blindly calling down to a `_visit_potential_constant` which
is written for single element rather collection of them. Unroll the
list, like the `dict` is done in the `if` above.
  • Loading branch information
FlorianDeconinck authored May 8, 2024
1 parent 5339c71 commit 63adbd7
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 3 deletions.
18 changes: 15 additions & 3 deletions dace/frontend/python/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,16 +752,28 @@ def visit_Subscript(self, node: ast.Subscript) -> Any:
return self.generic_visit(node)

# Then query for the right value
if isinstance(node.value, ast.Dict):
if isinstance(node.value, ast.Dict): # Dict
for k, v in zip(node.value.keys, node.value.values):
try:
gkey = astutils.evalnode(k, self.globals)
except SyntaxError:
continue
if gkey == gslice:
return self._visit_potential_constant(v, True)
else: # List or Tuple
return self._visit_potential_constant(node.value.elts[gslice], True)
elif isinstance(node.value, (ast.List, ast.Tuple)): # List & Tuple
# Loop over the list if slicing makes it a list
if isinstance(node.value.elts[gslice], List):
visited_list = astutils.copy_tree(node.value)
visited_list.elts.clear()
for v in node.value.elts[gslice]:
visited_cst = self._visit_potential_constant(v, True)
visited_list.elts.append(visited_cst)
node.value = visited_list
return node
else:
return self._visit_potential_constant(node.value.elts[gslice], True)
else: # Catch-all
return self._visit_potential_constant(node, True)

return self._visit_potential_constant(node, True)

Expand Down
46 changes: 46 additions & 0 deletions tests/python_frontend/unroll_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,52 @@ def tounroll(A: dace.float64[3]):
assert np.allclose(a, np.array([1, 2, 3]))


def test_list_global_enumerate():
tracer_variables = ["vapor", "rain", "nope"]

@dace.program
def enumerate_parsing(
A,
tracers: dace.compiletime, # Dict[str, np.float64]
):
for i, q in enumerate(tracer_variables[0:2]):
tracers[q][:] = A # type:ignore

a = np.ones([3])
q = {
"vapor": np.zeros([3]),
"rain": np.zeros([3]),
"nope": np.zeros([3]),
}
enumerate_parsing(a, q)
assert np.allclose(q["vapor"], np.array([1, 1, 1]))
assert np.allclose(q["rain"], np.array([1, 1, 1]))
assert np.allclose(q["nope"], np.array([0, 0, 0]))


def test_tuple_global_enumerate():
tracer_variables = ("vapor", "rain", "nope")

@dace.program
def enumerate_parsing(
A,
tracers: dace.compiletime, # Dict[str, np.float64]
):
for i, q in enumerate(tracer_variables[0:2]):
tracers[q][:] = A # type:ignore

a = np.ones([3])
q = {
"vapor": np.zeros([3]),
"rain": np.zeros([3]),
"nope": np.zeros([3]),
}
enumerate_parsing(a, q)
assert np.allclose(q["vapor"], np.array([1, 1, 1]))
assert np.allclose(q["rain"], np.array([1, 1, 1]))
assert np.allclose(q["nope"], np.array([0, 0, 0]))


def test_tuple_elements_zip():
a1 = [2, 3, 4]
a2 = (4, 5, 6)
Expand Down

0 comments on commit 63adbd7

Please sign in to comment.