diff --git a/.github/workflows/pace-build-ci.yml b/.github/workflows/pace-build-ci.yml
deleted file mode 100644
index 672c891a55..0000000000
--- a/.github/workflows/pace-build-ci.yml
+++ /dev/null
@@ -1,75 +0,0 @@
-name: NASA/NOAA Pace repository build test
-
-on:
-  workflow_dispatch:
-
-defaults:
-  run:
-    shell: bash
-
-jobs:
-  build_pace:
-    runs-on: ubuntu-latest
-    strategy:
-      matrix:
-        python-version: [3.8.10]
-
-    steps:
-    - uses: actions/checkout@v2
-      with:
-        repository: 'git@github.com:GEOS-ESM/pace.git'
-        ref: 'ci/DaCe'
-        submodules: 'recursive'
-    - name: Set up Python ${{ matrix.python-version }}
-      uses: actions/setup-python@v2
-      with:
-        python-version: ${{ matrix.python-version }}
-    - name: Install dependencies & pull correct DaCe
-      run: |
-        cd pace
-        python -m pip install --upgrade pip wheel setuptools
-        cd external/dace
-        git checkout ${{ github.sha }}
-        cd ../..
-        pip install -e external/gt4py
-        pip install -e external/dace
-        pip install -r requirements_dev.txt
-    - name: Download data
-      run: |
-        cd pace
-        mkdir -p test_data
-        cd test_data
-        wget https://portal.nccs.nasa.gov/datashare/astg/smt/pace-regression-data/8.1.3_c12_6_ranks_standard.D_SW.tar.gz
-        tar -xzvf 8.1.3_c12_6_ranks_standard.D_SW.tar.gz
-        wget https://portal.nccs.nasa.gov/datashare/astg/smt/pace-regression-data/8.1.3_c12_6_ranks_standard.RiemSolverC.tar.gz
-        tar -xzvf 8.1.3_c12_6_ranks_standard.RiemSolverC.tar.gz
-        wget https://portal.nccs.nasa.gov/datashare/astg/smt/pace-regression-data/8.1.3_c12_6_ranks_standard.Remapping.tar.gz
-        tar -xzvf 8.1.3_c12_6_ranks_standard.Remapping.tar.gz
-        cd ../..
-    - name: "Regression test: Riemman Solver on C-grid"
-      run: |
-        export FV3_DACEMODE=BuildAndRun
-        export PACE_CONSTANTS=GFS
-        cd pace
-        pytest -v -s --data_path=./test_data/8.1.3/c12_6ranks_standard/dycore \
-          --backend=dace:cpu --which_modules=Riem_Solver_C \
-          --threshold_overrides_file=./fv3core/tests/savepoint/translate/overrides/standard.yaml \
-          ./fv3core/tests/savepoint
-    - name: "Regression test: D-grid shallow water lagrangian dynamics (D_SW)"
-      run: |
-        export FV3_DACEMODE=BuildAndRun
-        export PACE_CONSTANTS=GFS
-        cd pace
-        pytest -v -s --data_path=./test_data/8.1.3/c12_6ranks_standard/dycore \
-          --backend=dace:cpu --which_modules=D_SW \
-          --threshold_overrides_file=./fv3core/tests/savepoint/translate/overrides/standard.yaml \
-          ./fv3core/tests/savepoint
-    - name: "Regression test: Remapping (on rank 0 only)"
-      run: |
-        export FV3_DACEMODE=BuildAndRun
-        export PACE_CONSTANTS=GFS
-        cd pace
-        pytest -v -s --data_path=./test_data/8.1.3/c12_6ranks_standard/dycore \
-          --backend=dace:cpu --which_modules=Remapping --which_rank=0 \
-          --threshold_overrides_file=./fv3core/tests/savepoint/translate/overrides/standard.yaml \
-          ./fv3core/tests/savepoint
diff --git a/.github/workflows/pyFV3-ci.yml b/.github/workflows/pyFV3-ci.yml
new file mode 100644
index 0000000000..2b98327381
--- /dev/null
+++ b/.github/workflows/pyFV3-ci.yml
@@ -0,0 +1,96 @@
+name: NASA/NOAA pyFV3 repository build test
+
+on:
+  push:
+    branches: [ master, ci-fix ]
+  pull_request:
+    branches: [ master, ci-fix ]
+  merge_group:
+    branches: [ master, ci-fix ]
+
+defaults:
+  run:
+    shell: bash
+
+jobs:
+  build_and_validate_pyFV3:
+    if: "!contains(github.event.pull_request.labels.*.name, 'no-ci')"
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: [3.11.7]
+
+    steps:
+    - uses: actions/checkout@v2
+      with:
+        repository: 'NOAA-GFDL/PyFV3'
+        ref: 'ci/DaCe'
+        submodules: 'recursive'
+        path: 'pyFV3'
+    - uses: actions/checkout@v2
+      with:
+        path: 'dace'
+        submodules: 'recursive'
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install library dependencies
+      run: |
+        sudo apt-get update
+        sudo apt-get install -y libopenmpi-dev libboost-all-dev
+        gcc --version
+    # Because GitHub doesn't allow us to do a git checkout in code,
+    # we check out DaCe first (not using the external submodule),
+    # install the full suite via requirements_dev, then re-install the correct DaCe
+    - name: Install Python packages
+      run: |
+        python -m pip install --upgrade pip wheel setuptools
+        pip install -e ./pyFV3[develop]
+        pip install -e ./dace
+    - name: Download data
+      run: |
+        cd pyFV3
+        mkdir -p test_data
+        cd test_data
+        wget https://portal.nccs.nasa.gov/datashare/astg/smt/pace-regression-data/8.1.3_c12_6ranks_standard.D_SW.tar.gz
+        tar -xzvf 8.1.3_c12_6ranks_standard.D_SW.tar.gz
+        wget https://portal.nccs.nasa.gov/datashare/astg/smt/pace-regression-data/8.1.3_c12_6ranks_standard.RiemSolver3.tar.gz
+        tar -xzvf 8.1.3_c12_6ranks_standard.RiemSolver3.tar.gz
+        wget https://portal.nccs.nasa.gov/datashare/astg/smt/pace-regression-data/8.1.3_c12_6ranks_standard.Remapping.tar.gz
+        tar -xzvf 8.1.3_c12_6ranks_standard.Remapping.tar.gz
+        cd ../..
+    # Clean up caches between runs: stale un-expanded SDFGs trip the build system (NDSL-side issue)
+    - name: "Regression test: Riemann Solver on D-grid (RiemSolver3)"
+      env:
+        FV3_DACEMODE: BuildAndRun
+        PACE_CONSTANTS: GFS
+        PACE_LOGLEVEL: Debug
+      run: |
+        pytest -v -s --data_path=./pyFV3/test_data/8.1.3/c12_6ranks_standard/dycore \
+          --backend=dace:cpu --which_modules=Riem_Solver3 \
+          --threshold_overrides_file=./pyFV3/tests/savepoint/translate/overrides/standard.yaml \
+          ./pyFV3/tests/savepoint
+        rm -r ./.gt_cache_FV3_A
+    - name: "Regression test: Shallow water Lagrangian dynamics on D-grid (D_SW) (on rank 0 only)"
+      env:
+        FV3_DACEMODE: BuildAndRun
+        PACE_CONSTANTS: GFS
+        PACE_LOGLEVEL: Debug
+      run: |
+        pytest -v -s --data_path=./pyFV3/test_data/8.1.3/c12_6ranks_standard/dycore \
+          --backend=dace:cpu --which_modules=D_SW --which_rank=0 \
+          --threshold_overrides_file=./pyFV3/tests/savepoint/translate/overrides/standard.yaml \
+          ./pyFV3/tests/savepoint
+        rm -r ./.gt_cache_FV3_A
+    - name: "Regression test: Remapping (on rank 0 only)"
+      env:
+        FV3_DACEMODE: BuildAndRun
+        PACE_CONSTANTS: GFS
+        PACE_LOGLEVEL: Debug
+      run: |
+        pytest -v -s --data_path=./pyFV3/test_data/8.1.3/c12_6ranks_standard/dycore \
+          --backend=dace:cpu --which_modules=Remapping --which_rank=0 \
+          --threshold_overrides_file=./pyFV3/tests/savepoint/translate/overrides/standard.yaml \
+          ./pyFV3/tests/savepoint
+        rm -r ./.gt_cache_FV3_A
diff --git a/dace/cli/sdfv.py b/dace/cli/sdfv.py
index f503775814..49255a1e7e 100644
--- a/dace/cli/sdfv.py
+++ b/dace/cli/sdfv.py
@@ -23,7 +23,7 @@ class NewCls(cls):
     return NewCls
 
 
-def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None):
+def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None, verbose: bool = True):
     """
     View an sdfg in the system's HTML viewer
 
@@ -33,6 +33,7 @@ def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None):
                      the generated HTML and related sources will be
                      served using a basic web server on that port,
                      blocking the current thread.
+    :param verbose: Be verbose.
    """
    # If vscode is open, try to open it inside vscode
    if filename is None:
@@ -71,7 +72,8 @@ def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None):
     with open(html_filename, "w") as f:
         f.write(html)
 
-    print("File saved at %s" % html_filename)
+    if verbose:
+        print("File saved at %s" % html_filename)
 
     if fd is not None:
         os.close(fd)
@@ -83,7 +85,8 @@ def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None):
         # start the web server
         handler = partialclass(http.server.SimpleHTTPRequestHandler, directory=dirname)
         httpd = http.server.HTTPServer(('localhost', filename), handler)
-        print(f"Serving at localhost:{filename}, press enter to stop...")
+        if verbose:
+            print(f"Serving at localhost:{filename}, press enter to stop...")
 
         # start the server in a different thread
         def serve():
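For context, the `verbose` flag threads through both entry points: `SDFG.view` stays quiet by default, while the CLI-facing `dace.cli.sdfv.view` keeps printing unless told otherwise. A minimal sketch of the Python-side usage (the `double` program is a made-up example):

```python
import dace
import numpy as np

@dace.program
def double(A: dace.float64[10]):
    return A * 2

sdfg = double.to_sdfg()
sdfg.view(filename='double.html')                # quiet: SDFG.view defaults to verbose=False
sdfg.view(filename='double.html', verbose=True)  # prints the "File saved at ..." message
```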
""" # If vscode is open, try to open it inside vscode if filename is None: @@ -71,7 +72,8 @@ def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None): with open(html_filename, "w") as f: f.write(html) - print("File saved at %s" % html_filename) + if(verbose): + print("File saved at %s" % html_filename) if fd is not None: os.close(fd) @@ -83,7 +85,8 @@ def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None): # start the web server handler = partialclass(http.server.SimpleHTTPRequestHandler, directory=dirname) httpd = http.server.HTTPServer(('localhost', filename), handler) - print(f"Serving at localhost:{filename}, press enter to stop...") + if(verbose): + print(f"Serving at localhost:{filename}, press enter to stop...") # start the server in a different thread def serve(): diff --git a/dace/config_schema.yml b/dace/config_schema.yml index 737862cacc..b26e96e920 100644 --- a/dace/config_schema.yml +++ b/dace/config_schema.yml @@ -945,10 +945,10 @@ required: serialize_all_fields: type: bool - default: true + default: false title: Serialize all unmodified fields in SDFG files description: > - If False, saving an SDFG keeps only the modified non-default properties. If True, + If False (default), saving an SDFG keeps only the modified non-default properties. If True, saves all fields. ############################################# diff --git a/dace/dtypes.py b/dace/dtypes.py index 76e6db8397..f04200e63b 100644 --- a/dace/dtypes.py +++ b/dace/dtypes.py @@ -1216,6 +1216,7 @@ def isconstant(var): int16 = typeclass(numpy.int16) int32 = typeclass(numpy.int32) int64 = typeclass(numpy.int64) +uintp = typeclass(numpy.uintp) uint8 = typeclass(numpy.uint8) uint16 = typeclass(numpy.uint16) uint32 = typeclass(numpy.uint32) diff --git a/dace/frontend/python/interface.py b/dace/frontend/python/interface.py index 69e650beaa..ecd0b164d6 100644 --- a/dace/frontend/python/interface.py +++ b/dace/frontend/python/interface.py @@ -42,6 +42,7 @@ def program(f: F, recreate_sdfg: bool = True, regenerate_code: bool = True, recompile: bool = True, + distributed_compilation: bool = False, constant_functions=False, **kwargs) -> Callable[..., parser.DaceProgram]: """ @@ -60,6 +61,9 @@ def program(f: F, it. :param recompile: Whether to recompile the code. If False, the library in the build folder will be used if it exists, without recompiling it. + :param distributed_compilation: Whether to compile the code from rank 0, and broadcast it to all the other ranks. + If False, every rank performs the compilation. In this case, make sure to check the ``cache`` configuration entry + such that no caching or clashes can happen between different MPI processes. :param constant_functions: If True, assumes all external functions that do not depend on internal variables are constant. 
                               resulting program.
@@ -78,7 +82,8 @@ def program(f: F,
                            constant_functions,
                            recreate_sdfg=recreate_sdfg,
                            regenerate_code=regenerate_code,
-                           recompile=recompile)
+                           recompile=recompile,
+                           distributed_compilation=distributed_compilation)
 
 
 function = program
diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py
index 3d2ec5c09d..fda2bd2e23 100644
--- a/dace/frontend/python/newast.py
+++ b/dace/frontend/python/newast.py
@@ -823,7 +823,7 @@ def _add_access(
             arr_type = type(parent_array)
             if arr_type == data.Scalar:
                 self.sdfg.add_scalar(var_name, dtype)
-            elif arr_type in (data.Array, data.View):
+            elif issubclass(arr_type, data.Array):
                 self.sdfg.add_array(var_name, shape, dtype, strides=strides)
             elif arr_type == data.Stream:
                 self.sdfg.add_stream(var_name, dtype)
@@ -3116,7 +3116,7 @@ def _add_access(
                 arr_type = data.Scalar
             if arr_type == data.Scalar:
                 self.sdfg.add_scalar(var_name, dtype)
-            elif arr_type in (data.Array, data.View):
+            elif issubclass(arr_type, data.Array):
                 if non_squeezed:
                     strides = [parent_array.strides[d] for d in non_squeezed]
                 else:
diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py
index 14377c4fe2..34cb8fb4ad 100644
--- a/dace/frontend/python/parser.py
+++ b/dace/frontend/python/parser.py
@@ -151,6 +151,7 @@ def __init__(self,
                  recreate_sdfg: bool = True,
                  regenerate_code: bool = True,
                  recompile: bool = True,
+                 distributed_compilation: bool = False,
                  method: bool = False):
         from dace.codegen import compiled_sdfg  # Avoid import loops
 
@@ -171,6 +172,7 @@ def __init__(self,
         self.recreate_sdfg = recreate_sdfg
         self.regenerate_code = regenerate_code
         self.recompile = recompile
+        self.distributed_compilation = distributed_compilation
         self.global_vars = _get_locals_and_globals(f)
         self.signature = inspect.signature(f)
 
@@ -449,12 +451,12 @@ def __call__(self, *args, **kwargs):
                 sdfg.simplify()
 
         with hooks.invoke_sdfg_call_hooks(sdfg) as sdfg:
-            if not mpi4py:
+            if self.distributed_compilation and mpi4py:
+                binaryobj = distributed_compile(sdfg, mpi4py.MPI.COMM_WORLD, validate=self.validate)
+            else:
                 # Compile SDFG (note: this is done after symbol inference due to shape
                 # altering transformations such as Vectorization)
                 binaryobj = sdfg.compile(validate=self.validate)
-            else:
-                binaryobj = distributed_compile(sdfg, mpi4py.MPI.COMM_WORLD, validate=self.validate)
 
             # Recreate key and add to cache
             cachekey = self._cache.make_key(argtypes, specified, self.closure_array_keys, self.closure_constant_keys,
diff --git a/dace/frontend/python/preprocessing.py b/dace/frontend/python/preprocessing.py
index 90ef506bcd..420346ca88 100644
--- a/dace/frontend/python/preprocessing.py
+++ b/dace/frontend/python/preprocessing.py
@@ -752,7 +752,7 @@ def visit_Subscript(self, node: ast.Subscript) -> Any:
                 return self.generic_visit(node)
 
         # Then query for the right value
-        if isinstance(node.value, ast.Dict):
+        if isinstance(node.value, ast.Dict):  # Dict
             for k, v in zip(node.value.keys, node.value.values):
                 try:
                     gkey = astutils.evalnode(k, self.globals)
@@ -760,8 +760,20 @@ def visit_Subscript(self, node: ast.Subscript) -> Any:
                     continue
                 if gkey == gslice:
                     return self._visit_potential_constant(v, True)
-        else:  # List or Tuple
-            return self._visit_potential_constant(node.value.elts[gslice], True)
+        elif isinstance(node.value, (ast.List, ast.Tuple)):  # List & Tuple
+            # Loop over the list if slicing makes it a list
+            if isinstance(node.value.elts[gslice], List):
+                visited_list = astutils.copy_tree(node.value)
+                visited_list.elts.clear()
+                for v in node.value.elts[gslice]:
+                    visited_cst = self._visit_potential_constant(v, True)
+                    visited_list.elts.append(visited_cst)
+                node.value = visited_list
+                return node
+            else:
+                return self._visit_potential_constant(node.value.elts[gslice], True)
+        else:  # Catch-all
+            return self._visit_potential_constant(node, True)
 
         return self._visit_potential_constant(node, True)
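The preprocessing change above makes slices of compile-time lists and tuples resolvable during parsing, which is what the new `unroll_test` cases further down exercise. A condensed sketch of the now-supported pattern (names here are illustrative):

```python
import dace
import numpy as np

fields = ("vapor", "rain", "nope")  # compile-time global tuple

@dace.program
def fill_first_two(A, tracers: dace.compiletime):
    # fields[0:2] is sliced at parse time; the loop unrolls over ("vapor", "rain")
    for i, name in enumerate(fields[0:2]):
        tracers[name][:] = A

data = {k: np.zeros(3) for k in fields}
fill_first_two(np.ones(3), data)
assert np.allclose(data["nope"], 0)  # the third tracer stays untouched
```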
diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py
index 5017a6ff86..b43ff2a7bf 100644
--- a/dace/sdfg/sdfg.py
+++ b/dace/sdfg/sdfg.py
@@ -461,8 +461,8 @@ def __init__(self,
         :param name: Name for the SDFG (also used as the filename for
                      the compiled shared library).
-        :param symbols: Additional dictionary of symbol names -> types that the SDFG
-                        defines, apart from symbolic data sizes.
+        :param constants: Additional dictionary of compile-time constants
+                          {name (str): tuple(type (dace.data.Data), value (Any))}.
         :param propagate: If False, disables automatic propagation of memlet
                           subsets from scopes outwards. Saves processing time
                           but disallows certain
@@ -1520,6 +1520,8 @@ def save(self, filename: str, use_pickle=False, hash=None, exception=None, compr
         :param compress: If True, uses gzip to compress the file upon saving.
         :return: The hash of the SDFG, or None if failed/not requested.
         """
+        filename = os.path.expanduser(filename)
+
         if compress:
             fileopen = lambda file, mode: gzip.open(file, mode + 't')
         else:
@@ -1547,14 +1549,15 @@ def save(self, filename: str, use_pickle=False, hash=None, exception=None, compr
 
         return None
 
-    def view(self, filename=None):
+    def view(self, filename=None, verbose=False):
         """
         View this sdfg in the system's HTML viewer
 
         :param filename: the filename to write the HTML to. If `None`, a temporary
                          file will be created.
+        :param verbose: Be verbose, `False` by default.
         """
         from dace.cli.sdfv import view
-        view(self, filename=filename)
+        view(self, filename=filename, verbose=verbose)
 
     @staticmethod
     def _from_file(fp: BinaryIO) -> 'SDFG':
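`save` now runs the path through `os.path.expanduser`, so home-relative paths work directly. A quick sketch:

```python
import dace

sdfg = dace.SDFG('example')
sdfg.add_state()
# '~' is expanded to the user's home directory before the file is written
sdfg.save('~/example.sdfg')
```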
diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py
index a9f7071b0f..cafea3d754 100644
--- a/dace/sdfg/state.py
+++ b/dace/sdfg/state.py
@@ -389,7 +389,9 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto
 
         # Prepend incoming edges until reaching the source node
         curedge = edge
+        visited = set()
         while not isinstance(curedge.src, (nd.CodeNode, nd.AccessNode)):
+            visited.add(curedge)
             # Trace through scopes using OUT_# -> IN_#
             if isinstance(curedge.src, (nd.EntryNode, nd.ExitNode)):
                 if curedge.src_conn is None:
@@ -398,10 +400,14 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto
                 next_edge = next(e for e in state.in_edges(curedge.src) if e.dst_conn == "IN_" + curedge.src_conn[4:])
                 result.insert(0, next_edge)
                 curedge = next_edge
+                if curedge in visited:
+                    raise ValueError('Cycle encountered while reading memlet path')
 
         # Append outgoing edges until reaching the sink node
         curedge = edge
+        visited.clear()
         while not isinstance(curedge.dst, (nd.CodeNode, nd.AccessNode)):
+            visited.add(curedge)
             # Trace through scope entry using IN_# -> OUT_#
             if isinstance(curedge.dst, (nd.EntryNode, nd.ExitNode)):
                 if curedge.dst_conn is None:
@@ -411,6 +417,8 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto
                 next_edge = next(e for e in state.out_edges(curedge.dst) if e.src_conn == "OUT_" + curedge.dst_conn[3:])
                 result.append(next_edge)
                 curedge = next_edge
+                if curedge in visited:
+                    raise ValueError('Cycle encountered while reading memlet path')
 
         return result
 
@@ -434,16 +442,23 @@ def memlet_tree(self, edge: MultiConnectorEdge) -> mm.MemletTree:
 
         # Find tree root
         curedge = edge
+        visited = set()
         if propagate_forward:
             while (isinstance(curedge.src, nd.EntryNode) and curedge.src_conn is not None):
+                visited.add(curedge)
                 assert curedge.src_conn.startswith('OUT_')
                 cname = curedge.src_conn[4:]
                 curedge = next(e for e in state.in_edges(curedge.src) if e.dst_conn == 'IN_%s' % cname)
+                if curedge in visited:
+                    raise ValueError('Cycle encountered while reading memlet path')
         elif propagate_backward:
             while (isinstance(curedge.dst, nd.ExitNode) and curedge.dst_conn is not None):
+                visited.add(curedge)
                 assert curedge.dst_conn.startswith('IN_')
                 cname = curedge.dst_conn[3:]
                 curedge = next(e for e in state.out_edges(curedge.dst) if e.src_conn == 'OUT_%s' % cname)
+                if curedge in visited:
+                    raise ValueError('Cycle encountered while reading memlet path')
         tree_root = mm.MemletTree(curedge, downwards=propagate_forward)
 
         # Collect children (recursively)
@@ -2477,38 +2492,56 @@ def add_state(self, label=None, is_start_block=False, *, is_start_state: bool=None):
         self.add_node(state, is_start_block=start_block)
         return state
 
-    def add_state_before(self, state: SDFGState, label=None, is_start_state=False) -> SDFGState:
+    def add_state_before(self,
+                         state: SDFGState,
+                         label=None,
+                         is_start_block=False,
+                         condition: CodeBlock = None,
+                         assignments=None,
+                         *,
+                         is_start_state: bool=None) -> SDFGState:
         """
         Adds a new SDFG state before an existing state, reconnecting predecessors to it instead.
 
         :param state: The state to prepend the new state before.
         :param label: State label.
-        :param is_start_state: If True, resets scope block starting state to this state.
+        :param is_start_block: If True, resets scope block starting state to this state.
+        :param condition: Transition condition of the newly created edge between state and the new state.
+        :param assignments: Assignments to perform upon transition.
         :return: A new SDFGState object.
         """
-        new_state = self.add_state(label, is_start_state)
+        new_state = self.add_state(label, is_start_block=is_start_block, is_start_state=is_start_state)
         # Reconnect
         for e in self.in_edges(state):
             self.remove_edge(e)
             self.add_edge(e.src, new_state, e.data)
-        # Add unconditional connection between the new state and the current
-        self.add_edge(new_state, state, dace.sdfg.InterstateEdge())
+        # Add the new edge
+        self.add_edge(new_state, state, dace.sdfg.InterstateEdge(condition=condition, assignments=assignments))
         return new_state
 
-    def add_state_after(self, state: SDFGState, label=None, is_start_state=False) -> SDFGState:
+    def add_state_after(self,
+                        state: SDFGState,
+                        label=None,
+                        is_start_block=False,
+                        condition: CodeBlock = None,
+                        assignments=None,
+                        *,
+                        is_start_state: bool=None) -> SDFGState:
         """
         Adds a new SDFG state after an existing state, reconnecting it to the successors instead.
 
         :param state: The state to append the new state after.
         :param label: State label.
-        :param is_start_state: If True, resets SDFG starting state to this state.
+        :param is_start_block: If True, resets scope block starting state to this state.
+        :param condition: Transition condition of the newly created edge between state and the new state.
+        :param assignments: Assignments to perform upon transition.
         :return: A new SDFGState object.
         """
-        new_state = self.add_state(label, is_start_state)
+        new_state = self.add_state(label, is_start_block=is_start_block, is_start_state=is_start_state)
         # Reconnect
         for e in self.out_edges(state):
             self.remove_edge(e)
             self.add_edge(new_state, e.dst, e.data)
-        # Add unconditional connection between the current and the new state
-        self.add_edge(state, new_state, dace.sdfg.InterstateEdge())
+        # Add the new edge
+        self.add_edge(state, new_state, dace.sdfg.InterstateEdge(condition=condition, assignments=assignments))
         return new_state
 
     @abc.abstractmethod
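With the extended signatures, a guarded transition can be created in one call instead of adding a state and then rewiring the interstate edge by hand. A minimal sketch, assuming a trivial SDFG:

```python
import dace
from dace.properties import CodeBlock

sdfg = dace.SDFG('guarded')
sdfg.add_symbol('i', dace.int32)
first = sdfg.add_state('first', is_start_block=True)
# The edge first -> body now carries a condition and an assignment,
# rather than the previous unconditional InterstateEdge.
body = sdfg.add_state_after(first, 'body',
                            condition=CodeBlock('i < 10'),
                            assignments={'i': 'i + 1'})
```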
""" - new_state = self.add_state(label, is_start_state) + new_state = self.add_state(label, is_start_block=is_start_block, is_start_state=is_start_state) # Reconnect for e in self.out_edges(state): self.remove_edge(e) self.add_edge(new_state, e.dst, e.data) - # Add unconditional connection between the current and the new state - self.add_edge(state, new_state, dace.sdfg.InterstateEdge()) + # Add the new edge + self.add_edge(state, new_state, dace.sdfg.InterstateEdge(condition=condition, assignments=assignments)) return new_state @abc.abstractmethod diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index 9a0dd0e313..186ea32acc 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -481,6 +481,12 @@ def fuse_nodes(self, sdfg, graph, edge, new_dst, new_dst_conn, other_edges=None) local_node = edge.src src_connector = edge.src_conn + # update edge data in case source or destination is a scalar access node + test_data = [node.data for node in (edge.src, edge.dst) if isinstance(node, nodes.AccessNode)] + for new_data in test_data: + if isinstance(sdfg.arrays[new_data], data.Scalar): + edge.data.data = new_data + # If destination of edge leads to multiple destinations, redirect all through an access node. if other_edges: # NOTE: If a new local node was already created, reuse it. diff --git a/requirements.txt b/requirements.txt index f06f3421cd..e98e33fe74 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,9 +5,9 @@ charset-normalizer==3.1.0 click==8.1.3 dill==0.3.6 fparser==0.1.3 -idna==3.4 +idna==3.7 importlib-metadata==6.6.0 -Jinja2==3.1.3 +Jinja2==3.1.4 MarkupSafe==2.1.3 mpmath==1.3.0 networkx==3.1 diff --git a/tests/python_frontend/unroll_test.py b/tests/python_frontend/unroll_test.py index 98c81156a0..bf2b1e7c91 100644 --- a/tests/python_frontend/unroll_test.py +++ b/tests/python_frontend/unroll_test.py @@ -169,6 +169,52 @@ def tounroll(A: dace.float64[3]): assert np.allclose(a, np.array([1, 2, 3])) +def test_list_global_enumerate(): + tracer_variables = ["vapor", "rain", "nope"] + + @dace.program + def enumerate_parsing( + A, + tracers: dace.compiletime, # Dict[str, np.float64] + ): + for i, q in enumerate(tracer_variables[0:2]): + tracers[q][:] = A # type:ignore + + a = np.ones([3]) + q = { + "vapor": np.zeros([3]), + "rain": np.zeros([3]), + "nope": np.zeros([3]), + } + enumerate_parsing(a, q) + assert np.allclose(q["vapor"], np.array([1, 1, 1])) + assert np.allclose(q["rain"], np.array([1, 1, 1])) + assert np.allclose(q["nope"], np.array([0, 0, 0])) + + +def test_tuple_global_enumerate(): + tracer_variables = ("vapor", "rain", "nope") + + @dace.program + def enumerate_parsing( + A, + tracers: dace.compiletime, # Dict[str, np.float64] + ): + for i, q in enumerate(tracer_variables[0:2]): + tracers[q][:] = A # type:ignore + + a = np.ones([3]) + q = { + "vapor": np.zeros([3]), + "rain": np.zeros([3]), + "nope": np.zeros([3]), + } + enumerate_parsing(a, q) + assert np.allclose(q["vapor"], np.array([1, 1, 1])) + assert np.allclose(q["rain"], np.array([1, 1, 1])) + assert np.allclose(q["nope"], np.array([0, 0, 0])) + + def test_tuple_elements_zip(): a1 = [2, 3, 4] a2 = (4, 5, 6) diff --git a/tests/sdfg/cycles_test.py b/tests/sdfg/cycles_test.py index 5e94db2eb4..480392ab2d 100644 --- a/tests/sdfg/cycles_test.py +++ b/tests/sdfg/cycles_test.py @@ -1,3 +1,4 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
diff --git a/tests/sdfg/cycles_test.py b/tests/sdfg/cycles_test.py
index 5e94db2eb4..480392ab2d 100644
--- a/tests/sdfg/cycles_test.py
+++ b/tests/sdfg/cycles_test.py
@@ -1,3 +1,4 @@
+# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
 import pytest
 
 import dace
@@ -13,3 +14,21 @@ def test_cycles():
             state.add_edge(access, None, access, None, dace.Memlet.simple("A", "0"))
 
         sdfg.validate()
+
+
+def test_cycles_memlet_path():
+    with pytest.raises(ValueError, match="Found cycles.*"):
+        sdfg = dace.SDFG("foo")
+        state = sdfg.add_state()
+        sdfg.add_array("bla", shape=(10, ), dtype=dace.float32)
+        mentry_3, _ = state.add_map("map_3", dict(i="0:9"))
+        mentry_3.add_in_connector("IN_0")
+        mentry_3.add_out_connector("OUT_0")
+        state.add_edge(mentry_3, "OUT_0", mentry_3, "IN_0", dace.Memlet(data="bla", subset='0:9'))
+
+        sdfg.validate()
+
+
+if __name__ == '__main__':
+    test_cycles()
+    test_cycles_memlet_path()
diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py
index 653fb9d120..724c8c97ee 100644
--- a/tests/transformations/mapfusion_test.py
+++ b/tests/transformations/mapfusion_test.py
@@ -163,6 +163,43 @@ def test_fusion_with_transient():
     assert np.allclose(A, expected)
 
 
+def test_fusion_with_transient_scalar():
+    N = 10
+    K = 4
+
+    def build_sdfg():
+        sdfg = dace.SDFG("map_fusion_with_transient_scalar")
+        state = sdfg.add_state()
+        sdfg.add_array("A", (N, K), dace.float64)
+        sdfg.add_array("B", (N, ), dace.float64)
+        sdfg.add_array("T", (N, ), dace.float64, transient=True)
+        t_node = state.add_access("T")
+        sdfg.add_scalar("V", dace.float64, transient=True)
+        v_node = state.add_access("V")
+
+        me1, mx1 = state.add_map("map1", dict(i=f"0:{N}"))
+        tlet1 = state.add_tasklet("select", {"_v"}, {"_out"}, f"_out = _v[i, {K-1}]")
+        state.add_memlet_path(state.add_access("A"), me1, tlet1, dst_conn="_v", memlet=dace.Memlet.from_array("A", sdfg.arrays["A"]))
+        state.add_edge(tlet1, "_out", v_node, None, dace.Memlet("V[0]"))
+        state.add_memlet_path(v_node, mx1, t_node, memlet=dace.Memlet("T[i]"))
+
+        me2, mx2 = state.add_map("map2", dict(j=f"0:{N}"))
+        tlet2 = state.add_tasklet("numeric", {"_inp"}, {"_out"}, "_out = _inp + 1")
+        state.add_memlet_path(t_node, me2, tlet2, dst_conn="_inp", memlet=dace.Memlet("T[j]"))
+        state.add_memlet_path(tlet2, mx2, state.add_access("B"), src_conn="_out", memlet=dace.Memlet("B[j]"))
+
+        return sdfg
+
+    sdfg = build_sdfg()
+    sdfg.apply_transformations(MapFusion)
+
+    A = np.random.rand(N, K)
+    B = np.repeat(np.nan, N)
+    sdfg(A=A, B=B)
+
+    assert np.allclose(B, (A[:, K - 1] + 1))
+
+
 def test_fusion_with_inverted_indices():
 
     @dace.program
@@ -278,6 +315,7 @@ def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int3
     test_multiple_fusions()
     test_fusion_chain()
     test_fusion_with_transient()
+    test_fusion_with_transient_scalar()
     test_fusion_with_inverted_indices()
     test_fusion_with_empty_memlet()
     test_fusion_with_nested_sdfg_0()
diff --git a/tests/uintptr_t_test.py b/tests/uintptr_t_test.py
new file mode 100644
index 0000000000..2b1941340d
--- /dev/null
+++ b/tests/uintptr_t_test.py
@@ -0,0 +1,37 @@
+# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
+import dace
+import ctypes
+import numpy as np
+
+
+def test_uintp_size():
+    # c_void_p: C type -> void*
+    size = ctypes.sizeof(ctypes.c_void_p)
+    # numpy.uintp: unsigned integer large enough to fit a pointer, compatible with C uintptr_t
+    size_of_np_uintp = np.uintp().itemsize
+    # DaCe uintptr_t representation
+    size_of_dace_uintp = dace.uintp.bytes
+
+    assert size == size_of_np_uintp == size_of_dace_uintp
+
+
+def test_uintp_use():
+
+    @dace.program
+    def tester(arr: dace.float64[20], pointer: dace.uintp[1]):
+        with dace.tasklet(dace.Language.CPP):
+            a << arr(-1)
+            """
+            out = decltype(out)(a);
+            """
+            out >> pointer[0]
+
+    ptr = np.empty([1], dtype=np.uintp)
+    arr = np.random.rand(20)
+    tester(arr, ptr)
+    assert arr.__array_interface__['data'][0] == ptr[0]
+
+
+if __name__ == '__main__':
+    test_uintp_size()
+    test_uintp_use()