diff --git a/dace/frontend/fortran/ast_components.py b/dace/frontend/fortran/ast_components.py index 1e5bfb4528..d95fa87e58 100644 --- a/dace/frontend/fortran/ast_components.py +++ b/dace/frontend/fortran/ast_components.py @@ -1,6 +1,5 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -from fparser.two.Fortran2008 import Fortran2008 as f08 -from fparser.two import Fortran2008 +from fparser.two import Fortran2008 as f08 from fparser.two import Fortran2003 as f03 from fparser.two import symbol_table @@ -608,7 +607,7 @@ def type_declaration_stmt(self, node: FASTNode): if i.string.lower() == "parameter": symbol = True - if isinstance(i, Fortran2008.Attr_Spec_List): + if isinstance(i, f08.Attr_Spec_List): dimension_spec = get_children(i, "Dimension_Attr_Spec") if len(dimension_spec) == 0: @@ -1052,7 +1051,7 @@ def specification_part(self, node: FASTNode): decls = [self.create_ast(i) for i in node.children if isinstance(i, f08.Type_Declaration_Stmt)] - uses = [self.create_ast(i) for i in node.children if isinstance(i, f08.Use_Stmt)] + uses = [self.create_ast(i) for i in node.children if isinstance(i, f03.Use_Stmt)] tmp = [self.create_ast(i) for i in node.children] typedecls = [i for i in tmp if isinstance(i, ast_internal_classes.Type_Decl_Node)] symbols = [] diff --git a/dace/frontend/python/astutils.py b/dace/frontend/python/astutils.py index 4a0ec88531..faf214fdeb 100644 --- a/dace/frontend/python/astutils.py +++ b/dace/frontend/python/astutils.py @@ -705,3 +705,17 @@ def escape_string(value: Union[bytes, str]): return value.encode("unicode_escape").decode("utf-8") # Python 2.x return value.encode('string_escape') + + +def parse_function_arguments(node: ast.Call, argnames: List[str]) -> Dict[str, ast.AST]: + """ + Parses function arguments (both positional and keyword) from a Call node, + based on the function's argument names. If an argument was not given, it will + not be in the result. + """ + result = {} + for arg, aname in zip(node.args, argnames): + result[aname] = arg + for kw in node.keywords: + result[kw.arg] = kw.value + return result diff --git a/dace/frontend/python/interface.py b/dace/frontend/python/interface.py index ea1970dafd..69e650beaa 100644 --- a/dace/frontend/python/interface.py +++ b/dace/frontend/python/interface.py @@ -293,10 +293,11 @@ class tasklet(metaclass=TaskletMetaclass): The DaCe framework cannot analyze these tasklets for optimization. """ - def __init__(self, language: Union[str, dtypes.Language] = dtypes.Language.Python): + def __init__(self, language: Union[str, dtypes.Language] = dtypes.Language.Python, side_effects: bool = False): if isinstance(language, str): language = dtypes.Language[language] self.language = language + self.side_effects = side_effects def __enter__(self): if self.language != dtypes.Language.Python: diff --git a/dace/frontend/python/memlet_parser.py b/dace/frontend/python/memlet_parser.py index 6ef627a430..7cc218c4fb 100644 --- a/dace/frontend/python/memlet_parser.py +++ b/dace/frontend/python/memlet_parser.py @@ -285,7 +285,11 @@ def ParseMemlet(visitor, if len(node.value.args) >= 2: write_conflict_resolution = node.value.args[1] - subset, new_axes, arrdims = parse_memlet_subset(array, node, das, parsed_slice) + try: + subset, new_axes, arrdims = parse_memlet_subset(array, node, das, parsed_slice) + except IndexError: + raise DaceSyntaxError(visitor, node, 'Failed to parse memlet expression due to dimensionality. ' + f'Array dimensions: {array.shape}, expression in code: {astutils.unparse(node)}') # If undefined, default number of accesses is the slice size if num_accesses is None: diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index c9d92b7860..b5d27e14f4 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -2510,6 +2510,7 @@ def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None): # Looking for the first argument in a tasklet annotation: @dace.tasklet(STRING HERE) langInf = None + side_effects = None if isinstance(node, ast.FunctionDef) and \ hasattr(node, 'decorator_list') and \ isinstance(node.decorator_list, list) and \ @@ -2522,6 +2523,19 @@ def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None): langArg = node.decorator_list[0].args[0].value langInf = dtypes.Language[langArg] + # Extract arguments from with statement + if isinstance(node, ast.With): + expr = node.items[0].context_expr + if isinstance(expr, ast.Call): + args = astutils.parse_function_arguments(expr, ['language', 'side_effects']) + langArg = args.get('language', None) + side_effects = args.get('side_effects', None) + langInf = astutils.evalnode(langArg, {**self.globals, **self.defined}) + if isinstance(langInf, str): + langInf = dtypes.Language[langInf] + + side_effects = astutils.evalnode(side_effects, {**self.globals, **self.defined}) + ttrans = TaskletTransformer(self, self.defined, self.sdfg, @@ -2536,6 +2550,9 @@ def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None): symbols=self.symbols) node, inputs, outputs, self.accesses = ttrans.parse_tasklet(node, name) + if side_effects is not None: + node.side_effects = side_effects + # Convert memlets to their actual data nodes for i in inputs.values(): if not isinstance(i, tuple) and i.data in self.scope_vars.keys(): diff --git a/dace/frontend/python/preprocessing.py b/dace/frontend/python/preprocessing.py index 10a1ab120e..239875118f 100644 --- a/dace/frontend/python/preprocessing.py +++ b/dace/frontend/python/preprocessing.py @@ -1268,7 +1268,7 @@ def _convert_to_ast(contents: Any): node) else: # Augment closure with new value - newnode = self.resolver.global_value_to_node(e, node, f'inlined_{id(contents)}', True, keep_object=True) + newnode = self.resolver.global_value_to_node(contents, node, f'inlined_{id(contents)}', True, keep_object=True) return newnode return _convert_to_ast(contents) diff --git a/dace/transformation/passes/analysis.py b/dace/transformation/passes/analysis.py index 1ca92d5ffd..86e1cde062 100644 --- a/dace/transformation/passes/analysis.py +++ b/dace/transformation/passes/analysis.py @@ -14,6 +14,7 @@ Set[Tuple[SDFGState, Union[nd.AccessNode, InterstateEdge]]]]] SymbolScopeDict = Dict[str, Dict[Edge[InterstateEdge], Set[Union[Edge[InterstateEdge], SDFGState]]]] + @properties.make_properties class StateReachability(ppl.Pass): """ @@ -35,13 +36,68 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Set[SDFGSta """ reachable: Dict[int, Dict[SDFGState, Set[SDFGState]]] = {} for sdfg in top_sdfg.all_sdfgs_recursive(): - reachable[sdfg.sdfg_id] = {} - tc: nx.DiGraph = nx.transitive_closure(sdfg.nx) - for state in sdfg.nodes(): - reachable[sdfg.sdfg_id][state] = set(tc.successors(state)) + result: Dict[SDFGState, Set[SDFGState]] = {} + + # In networkx this is currently implemented naively for directed graphs. + # The implementation below is faster + # tc: nx.DiGraph = nx.transitive_closure(sdfg.nx) + + for n, v in reachable_nodes(sdfg.nx): + result[n] = set(v) + + reachable[sdfg.sdfg_id] = result + return reachable +def _single_shortest_path_length_no_self(adj, source): + """Yields (node, level) in a breadth first search, without the first level + unless a self-edge exists. + + Adapted from Shortest Path Length helper function in NetworkX. + + Parameters + ---------- + adj : dict + Adjacency dict or view + firstlevel : dict + starting nodes, e.g. {source: 1} or {target: 1} + cutoff : int or float + level at which we stop the process + """ + firstlevel = {source: 1} + + seen = {} # level (number of hops) when seen in BFS + level = 0 # the current level + nextlevel = set(firstlevel) # set of nodes to check at next level + n = len(adj) + while nextlevel: + thislevel = nextlevel # advance to next level + nextlevel = set() # and start a new set (fringe) + found = [] + for v in thislevel: + if v not in seen: + if level == 0 and v is source: # Skip 0-length path to self + found.append(v) + continue + seen[v] = level # set the level of vertex v + found.append(v) + yield (v, level) + if len(seen) == n: + return + for v in found: + nextlevel.update(adj[v]) + level += 1 + del seen + + +def reachable_nodes(G): + """Computes the reachable nodes in G.""" + adj = G.adj + for n in G: + yield (n, dict(_single_shortest_path_length_no_self(adj, n))) + + @properties.make_properties class SymbolAccessSets(ppl.Pass): """ @@ -57,9 +113,8 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: # If anything was modified, reapply return modified & ppl.Modifies.States | ppl.Modifies.Edges | ppl.Modifies.Symbols | ppl.Modifies.Nodes - def apply_pass( - self, top_sdfg: SDFG, _ - ) -> Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]]: + def apply_pass(self, top_sdfg: SDFG, + _) -> Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]]: """ :return: A dictionary mapping each state to a tuple of its (read, written) data descriptors. """ @@ -216,9 +271,8 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: def depends_on(self): return {SymbolAccessSets, StateReachability} - def _find_dominating_write( - self, sym: str, read: Union[SDFGState, Edge[InterstateEdge]], state_idom: Dict[SDFGState, SDFGState] - ) -> Optional[Edge[InterstateEdge]]: + def _find_dominating_write(self, sym: str, read: Union[SDFGState, Edge[InterstateEdge]], + state_idom: Dict[SDFGState, SDFGState]) -> Optional[Edge[InterstateEdge]]: last_state: SDFGState = read if isinstance(read, SDFGState) else read.src in_edges = last_state.parent.in_edges(last_state) @@ -257,9 +311,9 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[int, idom = nx.immediate_dominators(sdfg.nx, sdfg.start_state) all_doms = cfg.all_dominators(sdfg, idom) - symbol_access_sets: Dict[ - Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]] - ] = pipeline_results[SymbolAccessSets.__name__][sdfg.sdfg_id] + symbol_access_sets: Dict[Union[SDFGState, Edge[InterstateEdge]], + Tuple[Set[str], + Set[str]]] = pipeline_results[SymbolAccessSets.__name__][sdfg.sdfg_id] state_reach: Dict[SDFGState, Set[SDFGState]] = pipeline_results[StateReachability.__name__][sdfg.sdfg_id] for read_loc, (reads, _) in symbol_access_sets.items(): @@ -321,12 +375,14 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: def depends_on(self): return {AccessSets, FindAccessNodes, StateReachability} - def _find_dominating_write( - self, desc: str, state: SDFGState, read: Union[nd.AccessNode, InterstateEdge], - access_nodes: Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]], - state_idom: Dict[SDFGState, SDFGState], access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]], - no_self_shadowing: bool = False - ) -> Optional[Tuple[SDFGState, nd.AccessNode]]: + def _find_dominating_write(self, + desc: str, + state: SDFGState, + read: Union[nd.AccessNode, InterstateEdge], + access_nodes: Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]], + state_idom: Dict[SDFGState, SDFGState], + access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]], + no_self_shadowing: bool = False) -> Optional[Tuple[SDFGState, nd.AccessNode]]: if isinstance(read, nd.AccessNode): # If the read is also a write, it shadows itself. iedges = state.in_edges(read) @@ -408,18 +464,21 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i for oedge in out_edges: syms = oedge.data.free_symbols & anames if desc in syms: - write = self._find_dominating_write( - desc, state, oedge.data, access_nodes, idom, access_sets - ) + write = self._find_dominating_write(desc, state, oedge.data, access_nodes, idom, + access_sets) result[desc][write].add((state, oedge.data)) # Take care of any write nodes that have not been assigned to a scope yet, i.e., writes that are not # dominating any reads and are thus not part of the results yet. for state in desc_states_with_nodes: for write_node in access_nodes[desc][state][1]: if not (state, write_node) in result[desc]: - write = self._find_dominating_write( - desc, state, write_node, access_nodes, idom, access_sets, no_self_shadowing=True - ) + write = self._find_dominating_write(desc, + state, + write_node, + access_nodes, + idom, + access_sets, + no_self_shadowing=True) result[desc][write].add((state, write_node)) # If any write A is dominated by another write B and any reads in B's scope are also reachable by A, diff --git a/requirements.txt b/requirements.txt index 33cd58a0bf..ea4db45916 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ charset-normalizer==3.1.0 click==8.1.3 dill==0.3.6 Flask==2.3.2 -fparser==0.1.2 +fparser==0.1.3 idna==3.4 importlib-metadata==6.6.0 itsdangerous==2.1.2 diff --git a/setup.py b/setup.py index b1737aed5a..6f97086543 100644 --- a/setup.py +++ b/setup.py @@ -74,7 +74,7 @@ include_package_data=True, install_requires=[ 'numpy', 'networkx >= 2.5', 'astunparse', 'sympy<=1.9', 'pyyaml', 'ply', 'websockets', 'requests', 'flask', - 'fparser >= 0.1.2', 'aenum >= 3.1', 'dataclasses; python_version < "3.7"', 'dill', + 'fparser >= 0.1.3', 'aenum >= 3.1', 'dataclasses; python_version < "3.7"', 'dill', 'pyreadline;platform_system=="Windows"', 'typing-compat; python_version < "3.8"' ] + cmake_requires, extras_require={ diff --git a/tests/numpy/advanced_indexing_test.py b/tests/numpy/advanced_indexing_test.py index 48853cdf26..d2c348ce95 100644 --- a/tests/numpy/advanced_indexing_test.py +++ b/tests/numpy/advanced_indexing_test.py @@ -1,231 +1,246 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -""" -Tests for numpy advanced indexing syntax. See also: -https://numpy.org/devdocs/reference/arrays.indexing.html -""" -import dace -import numpy as np -import pytest - -N = dace.symbol('N') -M = dace.symbol('M') - - -def test_flat(): - @dace.program - def indexing_test(A: dace.float64[20, 30]): - return A.flat - - A = np.random.rand(20, 30) - res = indexing_test(A) - assert np.allclose(A.flat, res) - - -def test_flat_noncontiguous(): - with dace.config.set_temporary('compiler', 'allow_view_arguments', value=True): - - @dace.program - def indexing_test(A): - return A.flat - - A = np.random.rand(20, 30).transpose() - res = indexing_test(A) - assert np.allclose(A.flat, res) - - -def test_ellipsis(): - @dace.program - def indexing_test(A: dace.float64[5, 5, 5, 5, 5]): - return A[1:5, ..., 0] - - A = np.random.rand(5, 5, 5, 5, 5) - res = indexing_test(A) - assert np.allclose(A[1:5, ..., 0], res) - - -def test_aug_implicit(): - @dace.program - def indexing_test(A: dace.float64[5, 5, 5, 5, 5]): - A[:, 1:5][:, 0:2] += 5 - - A = np.random.rand(5, 5, 5, 5, 5) - regression = np.copy(A) - regression[:, 1:5][:, 0:2] += 5 - indexing_test(A) - assert np.allclose(A, regression) - - -def test_ellipsis_aug(): - @dace.program - def indexing_test(A: dace.float64[5, 5, 5, 5, 5]): - A[1:5, ..., 0] += 5 - - A = np.random.rand(5, 5, 5, 5, 5) - regression = np.copy(A) - regression[1:5, ..., 0] += 5 - indexing_test(A) - assert np.allclose(A, regression) - - -def test_newaxis(): - @dace.program - def indexing_test(A: dace.float64[20, 30]): - return A[:, np.newaxis, None, :] - - A = np.random.rand(20, 30) - res = indexing_test(A) - assert res.shape == (20, 1, 1, 30) - assert np.allclose(A[:, np.newaxis, None, :], res) - - -def test_multiple_newaxis(): - @dace.program - def indexing_test(A: dace.float64[10, 20, 30]): - return A[np.newaxis, :, np.newaxis, np.newaxis, :, np.newaxis, :, np.newaxis] - - A = np.random.rand(10, 20, 30) - res = indexing_test(A) - assert res.shape == (1, 10, 1, 1, 20, 1, 30, 1) - assert np.allclose(A[np.newaxis, :, np.newaxis, np.newaxis, :, np.newaxis, :, np.newaxis], res) - - -def test_index_intarr_1d(): - @dace.program - def indexing_test(A: dace.float64[N], indices: dace.int32[M]): - return A[indices] - - A = np.random.rand(20) - indices = [1, 10, 15] - res = indexing_test(A, indices, M=3) - assert np.allclose(A[indices], res) - - -def test_index_intarr_1d_literal(): - @dace.program - def indexing_test(A: dace.float64[20]): - return A[[1, 10, 15]] - - A = np.random.rand(20) - indices = [1, 10, 15] - res = indexing_test(A) - assert np.allclose(A[indices], res) - - -def test_index_intarr_1d_constant(): - indices = [1, 10, 15] - - @dace.program - def indexing_test(A: dace.float64[20]): - return A[indices] - - A = np.random.rand(20) - res = indexing_test(A) - assert np.allclose(A[indices], res) - - -def test_index_intarr_1d_multi(): - @dace.program - def indexing_test(A: dace.float64[20, 10, 30], indices: dace.int32[3]): - return A[indices, 2:7:2, [15, 10, 1]] - - A = np.random.rand(20, 10, 30) - indices = [1, 10, 15] - res = indexing_test(A, indices) - # FIXME: NumPy behavior is unclear in this case - assert np.allclose(np.diag(A[indices, 2:7:2, [15, 10, 1]]), res) - - -def test_index_intarr_nd(): - @dace.program - def indexing_test(A: dace.float64[4, 3], rows: dace.int64[2, 2], columns: dace.int64[2, 2]): - return A[rows, columns] - - A = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=np.float64) - rows = np.array([[0, 0], [3, 3]], dtype=np.intp) - columns = np.array([[0, 2], [0, 2]], dtype=np.intp) - expected = A[rows, columns] - res = indexing_test(A, rows, columns) - assert np.allclose(expected, res) - - -def test_index_boolarr_rhs(): - @dace.program - def indexing_test(A: dace.float64[20, 30]): - return A[A > 15] - - A = np.ndarray((20, 30), dtype=np.float64) - for i in range(20): - A[i, :] = np.arange(0, 30) - regression = A[A > 15] - - # Right-hand side boolean array indexing is unsupported - with pytest.raises(IndexError): - res = indexing_test(A) - assert np.allclose(regression, res) - - -def test_index_multiboolarr(): - @dace.program - def indexing_test(A: dace.float64[20, 20], B: dace.bool[20]): - A[B, B] = 2 - - A = np.ndarray((20, 20), dtype=np.float64) - for i in range(20): - A[i, :] = np.arange(0, 20) - B = A[:, 1] > 0 - - # Advanced indexing with multiple boolean arrays should be disallowed - with pytest.raises(IndexError): - indexing_test(A, B) - - -def test_index_boolarr_fixed(): - @dace.program - def indexing_test(A: dace.float64[20, 30], barr: dace.bool[20, 30]): - A[barr] += 5 - - A = np.ndarray((20, 30), dtype=np.float64) - for i in range(20): - A[i, :] = np.arange(0, 30) - barr = A > 15 - regression = np.copy(A) - regression[barr] += 5 - - indexing_test(A, barr) - - assert np.allclose(regression, A) - - -def test_index_boolarr_inline(): - @dace.program - def indexing_test(A: dace.float64[20, 30]): - A[A > 15] = 2 - - A = np.ndarray((20, 30), dtype=np.float64) - for i in range(20): - A[i, :] = np.arange(0, 30) - regression = np.copy(A) - regression[A > 15] = 2 - - indexing_test(A) - - assert np.allclose(regression, A) - - -if __name__ == '__main__': - test_flat() - test_flat_noncontiguous() - test_ellipsis() - test_aug_implicit() - test_ellipsis_aug() - test_newaxis() - test_multiple_newaxis() - test_index_intarr_1d() - test_index_intarr_1d_literal() - test_index_intarr_1d_constant() - test_index_intarr_1d_multi() - test_index_intarr_nd() - test_index_boolarr_rhs() - test_index_multiboolarr() - test_index_boolarr_fixed() - test_index_boolarr_inline() +# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +""" +Tests for numpy advanced indexing syntax. See also: +https://numpy.org/devdocs/reference/arrays.indexing.html +""" +import dace +from dace.frontend.python.common import DaceSyntaxError +import numpy as np +import pytest + +N = dace.symbol('N') +M = dace.symbol('M') + + +def test_flat(): + + @dace.program + def indexing_test(A: dace.float64[20, 30]): + return A.flat + + A = np.random.rand(20, 30) + res = indexing_test(A) + assert np.allclose(A.flat, res) + + +def test_flat_noncontiguous(): + with dace.config.set_temporary('compiler', 'allow_view_arguments', value=True): + + @dace.program + def indexing_test(A): + return A.flat + + A = np.random.rand(20, 30).transpose() + res = indexing_test(A) + assert np.allclose(A.flat, res) + + +def test_ellipsis(): + + @dace.program + def indexing_test(A: dace.float64[5, 5, 5, 5, 5]): + return A[1:5, ..., 0] + + A = np.random.rand(5, 5, 5, 5, 5) + res = indexing_test(A) + assert np.allclose(A[1:5, ..., 0], res) + + +def test_aug_implicit(): + + @dace.program + def indexing_test(A: dace.float64[5, 5, 5, 5, 5]): + A[:, 1:5][:, 0:2] += 5 + + A = np.random.rand(5, 5, 5, 5, 5) + regression = np.copy(A) + regression[:, 1:5][:, 0:2] += 5 + indexing_test(A) + assert np.allclose(A, regression) + + +def test_ellipsis_aug(): + + @dace.program + def indexing_test(A: dace.float64[5, 5, 5, 5, 5]): + A[1:5, ..., 0] += 5 + + A = np.random.rand(5, 5, 5, 5, 5) + regression = np.copy(A) + regression[1:5, ..., 0] += 5 + indexing_test(A) + assert np.allclose(A, regression) + + +def test_newaxis(): + + @dace.program + def indexing_test(A: dace.float64[20, 30]): + return A[:, np.newaxis, None, :] + + A = np.random.rand(20, 30) + res = indexing_test(A) + assert res.shape == (20, 1, 1, 30) + assert np.allclose(A[:, np.newaxis, None, :], res) + + +def test_multiple_newaxis(): + + @dace.program + def indexing_test(A: dace.float64[10, 20, 30]): + return A[np.newaxis, :, np.newaxis, np.newaxis, :, np.newaxis, :, np.newaxis] + + A = np.random.rand(10, 20, 30) + res = indexing_test(A) + assert res.shape == (1, 10, 1, 1, 20, 1, 30, 1) + assert np.allclose(A[np.newaxis, :, np.newaxis, np.newaxis, :, np.newaxis, :, np.newaxis], res) + + +def test_index_intarr_1d(): + + @dace.program + def indexing_test(A: dace.float64[N], indices: dace.int32[M]): + return A[indices] + + A = np.random.rand(20) + indices = [1, 10, 15] + res = indexing_test(A, indices, M=3) + assert np.allclose(A[indices], res) + + +def test_index_intarr_1d_literal(): + + @dace.program + def indexing_test(A: dace.float64[20]): + return A[[1, 10, 15]] + + A = np.random.rand(20) + indices = [1, 10, 15] + res = indexing_test(A) + assert np.allclose(A[indices], res) + + +def test_index_intarr_1d_constant(): + indices = [1, 10, 15] + + @dace.program + def indexing_test(A: dace.float64[20]): + return A[indices] + + A = np.random.rand(20) + res = indexing_test(A) + assert np.allclose(A[indices], res) + + +def test_index_intarr_1d_multi(): + + @dace.program + def indexing_test(A: dace.float64[20, 10, 30], indices: dace.int32[3]): + return A[indices, 2:7:2, [15, 10, 1]] + + A = np.random.rand(20, 10, 30) + indices = [1, 10, 15] + res = indexing_test(A, indices) + # FIXME: NumPy behavior is unclear in this case + assert np.allclose(np.diag(A[indices, 2:7:2, [15, 10, 1]]), res) + + +def test_index_intarr_nd(): + + @dace.program + def indexing_test(A: dace.float64[4, 3], rows: dace.int64[2, 2], columns: dace.int64[2, 2]): + return A[rows, columns] + + A = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=np.float64) + rows = np.array([[0, 0], [3, 3]], dtype=np.intp) + columns = np.array([[0, 2], [0, 2]], dtype=np.intp) + expected = A[rows, columns] + res = indexing_test(A, rows, columns) + assert np.allclose(expected, res) + + +def test_index_boolarr_rhs(): + + @dace.program + def indexing_test(A: dace.float64[20, 30]): + return A[A > 15] + + A = np.ndarray((20, 30), dtype=np.float64) + for i in range(20): + A[i, :] = np.arange(0, 30) + regression = A[A > 15] + + # Right-hand side boolean array indexing is unsupported + with pytest.raises(IndexError): + res = indexing_test(A) + assert np.allclose(regression, res) + + +def test_index_multiboolarr(): + + @dace.program + def indexing_test(A: dace.float64[20, 20], B: dace.bool[20]): + A[B, B] = 2 + + A = np.ndarray((20, 20), dtype=np.float64) + for i in range(20): + A[i, :] = np.arange(0, 20) + B = A[:, 1] > 0 + + # Advanced indexing with multiple boolean arrays should be disallowed + with pytest.raises(DaceSyntaxError): + indexing_test(A, B) + + +def test_index_boolarr_fixed(): + + @dace.program + def indexing_test(A: dace.float64[20, 30], barr: dace.bool[20, 30]): + A[barr] += 5 + + A = np.ndarray((20, 30), dtype=np.float64) + for i in range(20): + A[i, :] = np.arange(0, 30) + barr = A > 15 + regression = np.copy(A) + regression[barr] += 5 + + indexing_test(A, barr) + + assert np.allclose(regression, A) + + +def test_index_boolarr_inline(): + + @dace.program + def indexing_test(A: dace.float64[20, 30]): + A[A > 15] = 2 + + A = np.ndarray((20, 30), dtype=np.float64) + for i in range(20): + A[i, :] = np.arange(0, 30) + regression = np.copy(A) + regression[A > 15] = 2 + + indexing_test(A) + + assert np.allclose(regression, A) + + +if __name__ == '__main__': + test_flat() + test_flat_noncontiguous() + test_ellipsis() + test_aug_implicit() + test_ellipsis_aug() + test_newaxis() + test_multiple_newaxis() + test_index_intarr_1d() + test_index_intarr_1d_literal() + test_index_intarr_1d_constant() + test_index_intarr_1d_multi() + test_index_intarr_nd() + test_index_boolarr_rhs() + test_index_multiboolarr() + test_index_boolarr_fixed() + test_index_boolarr_inline()