diff --git a/.github/workflows/pace-build-ci.yml b/.github/workflows/pace-build-ci.yml deleted file mode 100644 index 672c891a55..0000000000 --- a/.github/workflows/pace-build-ci.yml +++ /dev/null @@ -1,75 +0,0 @@ -name: NASA/NOAA Pace repository build test - -on: - workflow_dispatch: - -defaults: - run: - shell: bash - -jobs: - build_pace: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: [3.8.10] - - steps: - - uses: actions/checkout@v2 - with: - repository: 'git@github.com:GEOS-ESM/pace.git' - ref: 'ci/DaCe' - submodules: 'recursive' - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies & pull correct DaCe - run: | - cd pace - python -m pip install --upgrade pip wheel setuptools - cd external/dace - git checkout ${{ github.sha }} - cd ../.. - pip install -e external/gt4py - pip install -e external/dace - pip install -r requirements_dev.txt - - name: Download data - run: | - cd pace - mkdir -p test_data - cd test_data - wget https://portal.nccs.nasa.gov/datashare/astg/smt/pace-regression-data/8.1.3_c12_6_ranks_standard.D_SW.tar.gz - tar -xzvf 8.1.3_c12_6_ranks_standard.D_SW.tar.gz - wget https://portal.nccs.nasa.gov/datashare/astg/smt/pace-regression-data/8.1.3_c12_6_ranks_standard.RiemSolverC.tar.gz - tar -xzvf 8.1.3_c12_6_ranks_standard.RiemSolverC.tar.gz - wget https://portal.nccs.nasa.gov/datashare/astg/smt/pace-regression-data/8.1.3_c12_6_ranks_standard.Remapping.tar.gz - tar -xzvf 8.1.3_c12_6_ranks_standard.Remapping.tar.gz - cd ../.. - - name: "Regression test: Riemman Solver on C-grid" - run: | - export FV3_DACEMODE=BuildAndRun - export PACE_CONSTANTS=GFS - cd pace - pytest -v -s --data_path=./test_data/8.1.3/c12_6ranks_standard/dycore \ - --backend=dace:cpu --which_modules=Riem_Solver_C \ - --threshold_overrides_file=./fv3core/tests/savepoint/translate/overrides/standard.yaml \ - ./fv3core/tests/savepoint - - name: "Regression test: D-grid shallow water lagrangian dynamics (D_SW)" - run: | - export FV3_DACEMODE=BuildAndRun - export PACE_CONSTANTS=GFS - cd pace - pytest -v -s --data_path=./test_data/8.1.3/c12_6ranks_standard/dycore \ - --backend=dace:cpu --which_modules=D_SW \ - --threshold_overrides_file=./fv3core/tests/savepoint/translate/overrides/standard.yaml \ - ./fv3core/tests/savepoint - - name: "Regression test: Remapping (on rank 0 only)" - run: | - export FV3_DACEMODE=BuildAndRun - export PACE_CONSTANTS=GFS - cd pace - pytest -v -s --data_path=./test_data/8.1.3/c12_6ranks_standard/dycore \ - --backend=dace:cpu --which_modules=Remapping --which_rank=0 \ - --threshold_overrides_file=./fv3core/tests/savepoint/translate/overrides/standard.yaml \ - ./fv3core/tests/savepoint diff --git a/.github/workflows/pyFV3-ci.yml b/.github/workflows/pyFV3-ci.yml new file mode 100644 index 0000000000..f50f424bb8 --- /dev/null +++ b/.github/workflows/pyFV3-ci.yml @@ -0,0 +1,94 @@ +name: NASA/NOAA pyFV3 repository build test + +on: + push: + branches: [ master, ci-fix ] + pull_request: + branches: [ master, ci-fix ] + +defaults: + run: + shell: bash + +jobs: + build_and_validate_pyFV3: + if: "!contains(github.event.pull_request.labels.*.name, 'no-ci')" + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.11.7] + + steps: + - uses: actions/checkout@v2 + with: + repository: 'NOAA-GFDL/PyFV3' + ref: 'ci/DaCe' + submodules: 'recursive' + path: 'pyFV3' + - uses: actions/checkout@v2 + with: + path: 'dace' + submodules: 'recursive' + - name: Set up 
Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Install library dependencies
+        run: |
+          sudo apt-get install libopenmpi-dev libboost-all-dev gcc-13
+          sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-13 13
+          gcc --version
+      # Because GitHub doesn't allow us to do a git checkout in code,
+      # we use a trick: check out DaCe first (not using the external submodule),
+      # install the full suite via requirements_dev, then re-install the correct DaCe
+      - name: Install Python packages
+        run: |
+          python -m pip install --upgrade pip wheel setuptools
+          pip install -e ./pyFV3[develop]
+          pip install -e ./dace
+      - name: Download data
+        run: |
+          cd pyFV3
+          mkdir -p test_data
+          cd test_data
+          wget https://portal.nccs.nasa.gov/datashare/astg/smt/pace-regression-data/8.1.3_c12_6ranks_standard.D_SW.tar.gz
+          tar -xzvf 8.1.3_c12_6ranks_standard.D_SW.tar.gz
+          wget https://portal.nccs.nasa.gov/datashare/astg/smt/pace-regression-data/8.1.3_c12_6ranks_standard.RiemSolver3.tar.gz
+          tar -xzvf 8.1.3_c12_6ranks_standard.RiemSolver3.tar.gz
+          wget https://portal.nccs.nasa.gov/datashare/astg/smt/pace-regression-data/8.1.3_c12_6ranks_standard.Remapping.tar.gz
+          tar -xzvf 8.1.3_c12_6ranks_standard.Remapping.tar.gz
+          cd ../..
+      # Clean up caches between runs so that stale un-expanded SDFGs don't trip the build system (NDSL-side issue)
+      - name: "Regression test: Riemann Solver on D-grid (RiemSolver3)"
+        env:
+          FV3_DACEMODE: BuildAndRun
+          PACE_CONSTANTS: GFS
+          PACE_LOGLEVEL: Debug
+        run: |
+          pytest -v -s --data_path=./pyFV3/test_data/8.1.3/c12_6ranks_standard/dycore \
+            --backend=dace:cpu --which_modules=Riem_Solver3 \
+            --threshold_overrides_file=./pyFV3/tests/savepoint/translate/overrides/standard.yaml \
+            ./pyFV3/tests/savepoint
+          rm -r ./.gt_cache_FV3_A
+      - name: "Regression test: Shallow water Lagrangian dynamics on D-grid (D_SW) (on rank 0 only)"
+        env:
+          FV3_DACEMODE: BuildAndRun
+          PACE_CONSTANTS: GFS
+          PACE_LOGLEVEL: Debug
+        run: |
+          pytest -v -s --data_path=./pyFV3/test_data/8.1.3/c12_6ranks_standard/dycore \
+            --backend=dace:cpu --which_modules=D_SW --which_rank=0 \
+            --threshold_overrides_file=./pyFV3/tests/savepoint/translate/overrides/standard.yaml \
+            ./pyFV3/tests/savepoint
+          rm -r ./.gt_cache_FV3_A
+      - name: "Regression test: Remapping (on rank 0 only)"
+        env:
+          FV3_DACEMODE: BuildAndRun
+          PACE_CONSTANTS: GFS
+          PACE_LOGLEVEL: Debug
+        run: |
+          pytest -v -s --data_path=./pyFV3/test_data/8.1.3/c12_6ranks_standard/dycore \
+            --backend=dace:cpu --which_modules=Remapping --which_rank=0 \
+            --threshold_overrides_file=./pyFV3/tests/savepoint/translate/overrides/standard.yaml \
+            ./pyFV3/tests/savepoint
+          rm -r ./.gt_cache_FV3_A
diff --git a/dace/builtin_hooks.py b/dace/builtin_hooks.py
index 6af346e033..65b5c9b1a2 100644
--- a/dace/builtin_hooks.py
+++ b/dace/builtin_hooks.py
@@ -15,7 +15,12 @@
 @contextmanager
-def profile(repetitions: int = 100, warmup: int = 0):
+def profile(
+    repetitions: int = 100,
+    warmup: int = 0,
+    tqdm_leave: bool = True,
+    print_results: bool = True,
+):
     """
     Context manager that enables profiling of each called DaCe program. If repetitions is greater than 1, the
     program is run multiple times and the average execution time is reported.
@@ -35,6 +40,10 @@
     :param repetitions: The number of times to run each DaCe program.
     :param warmup: Number of additional repetitions to run the program without measuring time.
+    :param tqdm_leave: Sets the ``leave`` parameter of the ``tqdm`` progress bar (useful
+                       for nested progress bars). Ignored if the tqdm progress bar is not used.
+    :param print_results: Whether to print the median execution time after
+                          all repetitions.
     :note: Running functions multiple times may affect the results of the program.
     """
    from dace.frontend.operations import CompiledSDFGProfiler  # Avoid circular import
@@ -51,7 +60,7 @@ def profile(repetitions: int = 100, warmup: int = 0):
             yield hook
         return

-    profiler = CompiledSDFGProfiler(repetitions, warmup)
+    profiler = CompiledSDFGProfiler(repetitions, warmup, tqdm_leave, print_results)

     with on_compiled_sdfg_call(context_manager=profiler):
         yield profiler
diff --git a/dace/cli/sdfv.py b/dace/cli/sdfv.py
index f503775814..49255a1e7e 100644
--- a/dace/cli/sdfv.py
+++ b/dace/cli/sdfv.py
@@ -23,7 +23,7 @@ class NewCls(cls):
     return NewCls


-def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None):
+def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None, verbose: bool = True):
    """
    View an sdfg in the system's HTML viewer
@@ -33,6 +33,7 @@ def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None):
           the generated HTML and related sources will be served using a basic web
           server on that port, blocking the current thread.
+   :param verbose: Be verbose.
    """
    # If vscode is open, try to open it inside vscode
    if filename is None:
@@ -71,7 +72,8 @@ def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None):
        with open(html_filename, "w") as f:
            f.write(html)

-       print("File saved at %s" % html_filename)
+       if verbose:
+           print("File saved at %s" % html_filename)

        if fd is not None:
            os.close(fd)
@@ -83,7 +85,8 @@ def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None):
        # start the web server
        handler = partialclass(http.server.SimpleHTTPRequestHandler, directory=dirname)
        httpd = http.server.HTTPServer(('localhost', filename), handler)
-       print(f"Serving at localhost:{filename}, press enter to stop...")
+       if verbose:
+           print(f"Serving at localhost:{filename}, press enter to stop...")

        # start the server in a different thread
        def serve():
diff --git a/dace/codegen/targets/cpu.py b/dace/codegen/targets/cpu.py
index 8054448ff9..4037d92992 100644
--- a/dace/codegen/targets/cpu.py
+++ b/dace/codegen/targets/cpu.py
@@ -19,6 +19,7 @@
 from dace.sdfg import (ScopeSubgraphView, SDFG, scope_contains_scope, is_array_stream_view, NodeNotExpandedError,
                        dynamic_map_inputs, local_transients)
 from dace.sdfg.scope import is_devicelevel_gpu, is_devicelevel_fpga, is_in_scope
+from dace.sdfg.validation import validate_memlet_data
 from typing import Union
 from dace.codegen.targets import fpga
@@ -40,7 +41,7 @@ def _visit_structure(struct: data.Structure, args: dict, prefix: str = ''):
                 _visit_structure(v, args, f'{prefix}->{k}')
             elif isinstance(v, data.ContainerArray):
                 _visit_structure(v.stype, args, f'{prefix}->{k}')
-            elif isinstance(v, data.Data):
+            if isinstance(v, data.Data):
                 args[f'{prefix}->{k}'] = v

 # Keeps track of generated connectors, so we know how to access them in nested scopes
@@ -624,6 +625,7 @@ def copy_memory(
             callsite_stream,
         )

+
     def _emit_copy(
         self,
         sdfg,
@@ -641,9 +643,9 @@ def _emit_copy(
         orig_vconn = vconn

         # Determine memlet directionality
-        if isinstance(src_node, nodes.AccessNode) and memlet.data == src_node.data:
+        if isinstance(src_node, nodes.AccessNode) and validate_memlet_data(memlet.data, src_node.data):
             write = True
-        elif isinstance(dst_node, nodes.AccessNode) and memlet.data ==
dst_node.data: + elif isinstance(dst_node, nodes.AccessNode) and validate_memlet_data(memlet.data, dst_node.data): write = False elif isinstance(src_node, nodes.CodeNode) and isinstance(dst_node, nodes.CodeNode): # Code->Code copy (not read nor write) diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 7587f84f54..c1abf82b69 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -155,6 +155,8 @@ def generate_fileheader(self, sdfg: SDFG, global_stream: CodeIOStream, backend: if arr is not None: datatypes.add(arr.dtype) + emitted = set() + def _emit_definitions(dtype: dtypes.typeclass, wrote_something: bool) -> bool: if isinstance(dtype, dtypes.pointer): wrote_something = _emit_definitions(dtype._typeclass, wrote_something) @@ -164,7 +166,10 @@ def _emit_definitions(dtype: dtypes.typeclass, wrote_something: bool) -> bool: if hasattr(dtype, 'emit_definition'): if not wrote_something: global_stream.write("", sdfg) - global_stream.write(dtype.emit_definition(), sdfg) + if dtype not in emitted: + global_stream.write(dtype.emit_definition(), sdfg) + wrote_something = True + emitted.add(dtype) return wrote_something # Emit unique definitions diff --git a/dace/config_schema.yml b/dace/config_schema.yml index 737862cacc..b26e96e920 100644 --- a/dace/config_schema.yml +++ b/dace/config_schema.yml @@ -945,10 +945,10 @@ required: serialize_all_fields: type: bool - default: true + default: false title: Serialize all unmodified fields in SDFG files description: > - If False, saving an SDFG keeps only the modified non-default properties. If True, + If False (default), saving an SDFG keeps only the modified non-default properties. If True, saves all fields. ############################################# diff --git a/dace/dtypes.py b/dace/dtypes.py index f3f27368a5..f04200e63b 100644 --- a/dace/dtypes.py +++ b/dace/dtypes.py @@ -1216,6 +1216,7 @@ def isconstant(var): int16 = typeclass(numpy.int16) int32 = typeclass(numpy.int32) int64 = typeclass(numpy.int64) +uintp = typeclass(numpy.uintp) uint8 = typeclass(numpy.uint8) uint16 = typeclass(numpy.uint16) uint32 = typeclass(numpy.uint32) @@ -1449,8 +1450,10 @@ def validate_name(name): return False if name in {'True', 'False', 'None'}: return False - if namere.match(name) is None: - return False + tokens = name.split('.') + for token in tokens: + if namere.match(token) is None: + return False return True diff --git a/dace/frontend/operations.py b/dace/frontend/operations.py index 98dff2ba1e..11df057ee5 100644 --- a/dace/frontend/operations.py +++ b/dace/frontend/operations.py @@ -1,6 +1,7 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. from __future__ import print_function from functools import partial +from itertools import chain, repeat from contextlib import contextmanager from timeit import default_timer as timer @@ -10,6 +11,7 @@ import sympy import os import sys +import warnings from dace import dtypes from dace.config import Config @@ -28,12 +30,20 @@ class CompiledSDFGProfiler: times: List[Tuple['SDFG', List[float]]] #: The list of SDFGs and times for each SDFG called within the context. 
- def __init__(self, repetitions: int = 0, warmup: int = 0) -> None: + def __init__( + self, + repetitions: int = 0, + warmup: int = 0, + tqdm_leave: bool = True, + print_results: bool = True, + ) -> None: # Avoid import loop from dace.codegen.instrumentation import report self.repetitions = repetitions or int(Config.get('treps')) self.warmup = warmup + self.tqdm_leave = tqdm_leave + self.print_results = print_results if self.repetitions < 1: raise ValueError('Number of repetitions must be at least 1') if self.warmup < 0: @@ -47,34 +57,45 @@ def __init__(self, repetitions: int = 0, warmup: int = 0) -> None: def __call__(self, compiled_sdfg: 'CompiledSDFG', args: Tuple[Any, ...]): from dace.codegen.instrumentation import report # Avoid import loop - start = timer() + # zeros to overwrite start time, followed by indices for each repetition + iterator = chain(repeat(0, self.warmup), range(1, self.repetitions + 1)) - times = [start] * (self.repetitions + 1) - ret = None - print('\nProfiling...') - - iterator = range(self.warmup + self.repetitions) if Config.get_bool('profiling_status'): try: from tqdm import tqdm - iterator = tqdm(iterator, desc="Profiling", file=sys.stdout) + + iterator = tqdm( + iterator, + desc='Profiling', + total=(self.warmup + self.repetitions), + file=sys.stdout, + leave=self.tqdm_leave, + ) except ImportError: - print('WARNING: Cannot show profiling progress, missing optional ' - 'dependency tqdm...\n\tTo see a live progress bar please install ' - 'tqdm (`pip install tqdm`)\n\tTo disable this feature (and ' - 'this warning) set `profiling_status` to false in the dace ' - 'config (~/.dace.conf).') + warnings.warn( + 'Cannot show profiling progress, missing optional dependency ' + 'tqdm...\n\tTo see a live progress bar please install tqdm ' + '(`pip install tqdm`)\n\tTo disable this feature (and this ' + 'warning) set `profiling_status` to false in the dace config ' + '(~/.dace.conf).' 
+ ) + print('\nProfiling...') + else: + print('\nProfiling...') - offset = 1 - self.warmup start_time = int(time.time()) + + times = np.ndarray(self.repetitions + 1, dtype=np.float64) times[0] = timer() + for i in iterator: # Call function compiled_sdfg._cfunc(compiled_sdfg._libhandle, *args) - if i >= self.warmup: - times[i + offset] = timer() - diffs = np.array([(times[i] - times[i - 1])*1e3 for i in range(1, self.repetitions + 1)]) + times[i] = timer() + + # compute pairwise differences and convert to milliseconds + diffs = np.diff(times) * 1e3 # Add entries to the instrumentation report self.report.name = self.report.name or start_time @@ -88,8 +109,9 @@ def __call__(self, compiled_sdfg: 'CompiledSDFG', args: Tuple[Any, ...]): self.report.durations[(0, -1, -1)][f'Python call to {compiled_sdfg.sdfg.name}'][-1].extend(diffs) # Print profiling results - time_msecs = np.median(diffs) - print(compiled_sdfg.sdfg.name, time_msecs, 'ms') + if self.print_results: + time_msecs = np.median(diffs) + print(compiled_sdfg.sdfg.name, time_msecs, 'ms') # Save every call separately self.times.append((compiled_sdfg.sdfg, diffs)) @@ -105,7 +127,7 @@ def __call__(self, compiled_sdfg: 'CompiledSDFG', args: Tuple[Any, ...]): # Restore state after skipping contents compiled_sdfg.do_not_execute = old_dne - return ret + return None def detect_reduction_type(wcr_str, openmp=False): diff --git a/dace/frontend/python/interface.py b/dace/frontend/python/interface.py index 69e650beaa..ecd0b164d6 100644 --- a/dace/frontend/python/interface.py +++ b/dace/frontend/python/interface.py @@ -42,6 +42,7 @@ def program(f: F, recreate_sdfg: bool = True, regenerate_code: bool = True, recompile: bool = True, + distributed_compilation: bool = False, constant_functions=False, **kwargs) -> Callable[..., parser.DaceProgram]: """ @@ -60,6 +61,9 @@ def program(f: F, it. :param recompile: Whether to recompile the code. If False, the library in the build folder will be used if it exists, without recompiling it. + :param distributed_compilation: Whether to compile the code from rank 0, and broadcast it to all the other ranks. + If False, every rank performs the compilation. In this case, make sure to check the ``cache`` configuration entry + such that no caching or clashes can happen between different MPI processes. :param constant_functions: If True, assumes all external functions that do not depend on internal variables are constant. 
This will hardcode their return values into the @@ -78,7 +82,8 @@ def program(f: F, constant_functions, recreate_sdfg=recreate_sdfg, regenerate_code=regenerate_code, - recompile=recompile) + recompile=recompile, + distributed_compilation=distributed_compilation) function = program diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 3d2ec5c09d..fda2bd2e23 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -823,7 +823,7 @@ def _add_access( arr_type = type(parent_array) if arr_type == data.Scalar: self.sdfg.add_scalar(var_name, dtype) - elif arr_type in (data.Array, data.View): + elif issubclass(arr_type, data.Array): self.sdfg.add_array(var_name, shape, dtype, strides=strides) elif arr_type == data.Stream: self.sdfg.add_stream(var_name, dtype) @@ -3116,7 +3116,7 @@ def _add_access( arr_type = data.Scalar if arr_type == data.Scalar: self.sdfg.add_scalar(var_name, dtype) - elif arr_type in (data.Array, data.View): + elif issubclass(arr_type, data.Array): if non_squeezed: strides = [parent_array.strides[d] for d in non_squeezed] else: diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index 1b6817a7d0..34cb8fb4ad 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -21,6 +21,12 @@ except ImportError: from typing_compat import get_origin, get_args +try: + import mpi4py + from dace.sdfg.utils import distributed_compile +except ImportError: + mpi4py = None + ArgTypes = Dict[str, Data] @@ -145,6 +151,7 @@ def __init__(self, recreate_sdfg: bool = True, regenerate_code: bool = True, recompile: bool = True, + distributed_compilation: bool = False, method: bool = False): from dace.codegen import compiled_sdfg # Avoid import loops @@ -165,6 +172,7 @@ def __init__(self, self.recreate_sdfg = recreate_sdfg self.regenerate_code = regenerate_code self.recompile = recompile + self.distributed_compilation = distributed_compilation self.global_vars = _get_locals_and_globals(f) self.signature = inspect.signature(f) @@ -443,9 +451,12 @@ def __call__(self, *args, **kwargs): sdfg.simplify() with hooks.invoke_sdfg_call_hooks(sdfg) as sdfg: - # Compile SDFG (note: this is done after symbol inference due to shape - # altering transformations such as Vectorization) - binaryobj = sdfg.compile(validate=self.validate) + if self.distributed_compilation and mpi4py: + binaryobj = distributed_compile(sdfg, mpi4py.MPI.COMM_WORLD, validate=self.validate) + else: + # Compile SDFG (note: this is done after symbol inference due to shape + # altering transformations such as Vectorization) + binaryobj = sdfg.compile(validate=self.validate) # Recreate key and add to cache cachekey = self._cache.make_key(argtypes, specified, self.closure_array_keys, self.closure_constant_keys, diff --git a/dace/frontend/python/preprocessing.py b/dace/frontend/python/preprocessing.py index 90ef506bcd..420346ca88 100644 --- a/dace/frontend/python/preprocessing.py +++ b/dace/frontend/python/preprocessing.py @@ -752,7 +752,7 @@ def visit_Subscript(self, node: ast.Subscript) -> Any: return self.generic_visit(node) # Then query for the right value - if isinstance(node.value, ast.Dict): + if isinstance(node.value, ast.Dict): # Dict for k, v in zip(node.value.keys, node.value.values): try: gkey = astutils.evalnode(k, self.globals) @@ -760,8 +760,20 @@ def visit_Subscript(self, node: ast.Subscript) -> Any: continue if gkey == gslice: return self._visit_potential_constant(v, True) - else: # List or Tuple - return 
self._visit_potential_constant(node.value.elts[gslice], True) + elif isinstance(node.value, (ast.List, ast.Tuple)): # List & Tuple + # Loop over the list if slicing makes it a list + if isinstance(node.value.elts[gslice], List): + visited_list = astutils.copy_tree(node.value) + visited_list.elts.clear() + for v in node.value.elts[gslice]: + visited_cst = self._visit_potential_constant(v, True) + visited_list.elts.append(visited_cst) + node.value = visited_list + return node + else: + return self._visit_potential_constant(node.value.elts[gslice], True) + else: # Catch-all + return self._visit_potential_constant(node, True) return self._visit_potential_constant(node, True) diff --git a/dace/properties.py b/dace/properties.py index 5fc9b8dcbe..d4a66476b2 100644 --- a/dace/properties.py +++ b/dace/properties.py @@ -24,35 +24,6 @@ ############################################################################### -def set_property_from_string(prop, obj, string, sdfg=None, from_json=False): - """ Interface function that guarantees that a property will always be - correctly set, if possible, by accepting all possible input arguments to - from_string. """ - - # If the property is a string (property name), obtain it from the object - if isinstance(prop, str): - prop = type(obj).__properties__[prop] - - if isinstance(prop, CodeProperty): - if from_json: - val = prop.from_json(string) - else: - val = prop.from_string(string, obj.language) - elif isinstance(prop, (ReferenceProperty, DataProperty)): - if sdfg is None: - raise ValueError("You cannot pass sdfg=None when editing a ReferenceProperty!") - if from_json: - val = prop.from_json(string, sdfg) - else: - val = prop.from_string(string, sdfg) - else: - if from_json: - val = prop.from_json(string, sdfg) - else: - val = prop.from_string(string) - setattr(obj, prop.attr_name, val) - - ############################################################################### # Property base implementation ############################################################################### @@ -74,8 +45,6 @@ def __init__( setter=None, dtype: Type[T] = None, default=None, - from_string=None, - to_string=None, from_json=None, to_json=None, meta_to_json=None, @@ -114,35 +83,8 @@ def __init__( if not isinstance(choice, dtype): raise TypeError("All choices must be an instance of dtype") - if from_string is not None: - self._from_string = from_string - elif choices is not None: - self._from_string = lambda s: choices[s] - else: - self._from_string = self.dtype - - if to_string is not None: - self._to_string = to_string - elif choices is not None: - self._to_string = lambda val: val.__name__ - else: - self._to_string = str - if from_json is None: - if self._from_string is not None: - - def fs(obj, *args, **kwargs): - if isinstance(obj, str): - # The serializer does not know about this property, so if - # we can convert using our to_string method, do that here - return self._from_string(obj) - # Otherwise ship off to the serializer, telling it which type - # it's dealing with as a sanity check - return dace.serialize.from_json(obj, *args, known_type=dtype, **kwargs) - - self._from_json = fs - else: - self._from_json = lambda *args, **kwargs: dace.serialize.from_json(*args, known_type=dtype, **kwargs) + self._from_json = lambda *args, **kwargs: dace.serialize.from_json(*args, known_type=dtype, **kwargs) else: self._from_json = from_json if self.from_json != from_json: @@ -226,7 +168,6 @@ def __set__(self, obj, val): if (self.dtype is not None and not isinstance(val, self.dtype) and 
not (val is None and self.allow_none)): if isinstance(val, str): raise TypeError("Received str for property {} of type {}. Use " - "dace.properties.set_property_from_string or the " "from_string method of the property.".format(self.attr_name, self.dtype)) raise TypeError("Invalid type \"{}\" for property {}: expected {}".format( type(val).__name__, self.attr_name, self.dtype.__name__)) @@ -296,14 +237,6 @@ def allow_none(self): def desc(self): return self._desc - @property - def from_string(self): - return self._from_string - - @property - def to_string(self): - return self._to_string - @property def from_json(self): return self._from_json @@ -853,8 +786,6 @@ def __init__( getter=None, setter=None, default=None, - from_string=None, - to_string=None, from_json=None, to_json=None, unmapped=False, # Don't enforce 1:1 mapping with a member variable @@ -867,8 +798,6 @@ def __init__( setter=setter, dtype=set, default=default, - from_string=from_string, - to_string=to_string, from_json=from_json, to_json=to_json, choices=None, diff --git a/dace/sdfg/graph.py b/dace/sdfg/graph.py index 5c93149529..91ed698896 100644 --- a/dace/sdfg/graph.py +++ b/dace/sdfg/graph.py @@ -526,10 +526,10 @@ def edges(self): return [DiGraph._from_nx(e) for e in self._nx.edges()] def in_edges(self, node): - return [DiGraph._from_nx(e) for e in self._nx.in_edges()] + return [DiGraph._from_nx(e) for e in self._nx.in_edges(node, True)] def out_edges(self, node): - return [DiGraph._from_nx(e) for e in self._nx.out_edges()] + return [DiGraph._from_nx(e) for e in self._nx.out_edges(node, True)] def add_node(self, node): return self._nx.add_node(node) diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index f2f30d06a4..f10e728607 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -122,25 +122,6 @@ def _replace_dict_values(d, old, new): d[k] = new -def _assignments_from_string(astr): - """ Returns a dictionary of assignments from a semicolon-delimited - string of expressions. """ - - result = {} - for aitem in astr.split(';'): - aitem = aitem.strip() - m = re.search(r'([^=\s]+)\s*=\s*([^=]+)', aitem) - result[m.group(1)] = m.group(2) - - return result - - -def _assignments_to_string(assdict): - """ Returns a semicolon-delimited string from a dictionary of assignment - expressions. """ - return '; '.join(['%s=%s' % (k, v) for k, v in assdict.items()]) - - def memlets_in_ast(node: ast.AST, arrays: Dict[str, dt.Data]) -> List[mm.Memlet]: """ Generates a list of memlets from each of the subscripts that appear in the Python AST. @@ -199,9 +180,7 @@ class InterstateEdge(object): """ assignments = Property(dtype=dict, - desc="Assignments to perform upon transition (e.g., 'x=x+1; y = 0')", - from_string=_assignments_from_string, - to_string=_assignments_to_string) + desc="Assignments to perform upon transition (e.g., 'x=x+1; y = 0')") condition = CodeProperty(desc="Transition condition", default=CodeBlock("1")) def __init__(self, condition: CodeBlock = None, assignments=None): @@ -482,8 +461,8 @@ def __init__(self, :param name: Name for the SDFG (also used as the filename for the compiled shared library). - :param symbols: Additional dictionary of symbol names -> types that the SDFG - defines, apart from symbolic data sizes. + :param constants: Additional dictionary of compile-time constants + {name (str): tuple(type (dace.data.Data), value (Any))}. :param propagate: If False, disables automatic propagation of memlet subsets from scopes outwards. 
Saves processing time but disallows certain @@ -749,7 +728,9 @@ def replace_dict(self, # Replace in arrays and symbols (if a variable name) if replace_keys: - for name, new_name in repldict.items(): + # Filter out nested data names, as we cannot and do not want to replace names in nested data descriptors + repldict_filtered = {k: v for k, v in repldict.items() if '.' not in k} + for name, new_name in repldict_filtered.items(): if validate_name(new_name): _replace_dict_keys(self._arrays, name, new_name) _replace_dict_keys(self.symbols, name, new_name) @@ -1566,14 +1547,15 @@ def save(self, filename: str, use_pickle=False, hash=None, exception=None, compr return None - def view(self, filename=None): + def view(self, filename=None, verbose=False): """ View this sdfg in the system's HTML viewer :param filename: the filename to write the HTML to. If `None`, a temporary file will be created. + :param verbose: Be verbose, `False` by default. """ from dace.cli.sdfv import view - view(self, filename=filename) + view(self, filename=filename, verbose=verbose) @staticmethod def _from_file(fp: BinaryIO) -> 'SDFG': diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index a9f7071b0f..cafea3d754 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -389,7 +389,9 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto # Prepend incoming edges until reaching the source node curedge = edge + visited = set() while not isinstance(curedge.src, (nd.CodeNode, nd.AccessNode)): + visited.add(curedge) # Trace through scopes using OUT_# -> IN_# if isinstance(curedge.src, (nd.EntryNode, nd.ExitNode)): if curedge.src_conn is None: @@ -398,10 +400,14 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto next_edge = next(e for e in state.in_edges(curedge.src) if e.dst_conn == "IN_" + curedge.src_conn[4:]) result.insert(0, next_edge) curedge = next_edge + if curedge in visited: + raise ValueError('Cycle encountered while reading memlet path') # Append outgoing edges until reaching the sink node curedge = edge + visited.clear() while not isinstance(curedge.dst, (nd.CodeNode, nd.AccessNode)): + visited.add(curedge) # Trace through scope entry using IN_# -> OUT_# if isinstance(curedge.dst, (nd.EntryNode, nd.ExitNode)): if curedge.dst_conn is None: @@ -411,6 +417,8 @@ def memlet_path(self, edge: MultiConnectorEdge[mm.Memlet]) -> List[MultiConnecto next_edge = next(e for e in state.out_edges(curedge.dst) if e.src_conn == "OUT_" + curedge.dst_conn[3:]) result.append(next_edge) curedge = next_edge + if curedge in visited: + raise ValueError('Cycle encountered while reading memlet path') return result @@ -434,16 +442,23 @@ def memlet_tree(self, edge: MultiConnectorEdge) -> mm.MemletTree: # Find tree root curedge = edge + visited = set() if propagate_forward: while (isinstance(curedge.src, nd.EntryNode) and curedge.src_conn is not None): + visited.add(curedge) assert curedge.src_conn.startswith('OUT_') cname = curedge.src_conn[4:] curedge = next(e for e in state.in_edges(curedge.src) if e.dst_conn == 'IN_%s' % cname) + if curedge in visited: + raise ValueError('Cycle encountered while reading memlet path') elif propagate_backward: while (isinstance(curedge.dst, nd.ExitNode) and curedge.dst_conn is not None): + visited.add(curedge) assert curedge.dst_conn.startswith('IN_') cname = curedge.dst_conn[3:] curedge = next(e for e in state.out_edges(curedge.dst) if e.src_conn == 'OUT_%s' % cname) + if curedge in visited: + raise ValueError('Cycle encountered while reading 
memlet path') tree_root = mm.MemletTree(curedge, downwards=propagate_forward) # Collect children (recursively) @@ -2477,38 +2492,56 @@ def add_state(self, label=None, is_start_block=False, *, is_start_state: bool=No self.add_node(state, is_start_block=start_block) return state - def add_state_before(self, state: SDFGState, label=None, is_start_state=False) -> SDFGState: + def add_state_before(self, + state: SDFGState, + label=None, + is_start_block=False, + condition: CodeBlock = None, + assignments=None, + *, + is_start_state: bool=None) -> SDFGState: """ Adds a new SDFG state before an existing state, reconnecting predecessors to it instead. :param state: The state to prepend the new state before. :param label: State label. - :param is_start_state: If True, resets scope block starting state to this state. + :param is_start_block: If True, resets scope block starting state to this state. + :param condition: Transition condition of the newly created edge between state and the new state. + :param assignments: Assignments to perform upon transition. :return: A new SDFGState object. """ - new_state = self.add_state(label, is_start_state) + new_state = self.add_state(label, is_start_block=is_start_block, is_start_state=is_start_state) # Reconnect for e in self.in_edges(state): self.remove_edge(e) self.add_edge(e.src, new_state, e.data) - # Add unconditional connection between the new state and the current - self.add_edge(new_state, state, dace.sdfg.InterstateEdge()) + # Add the new edge + self.add_edge(new_state, state, dace.sdfg.InterstateEdge(condition=condition, assignments=assignments)) return new_state - def add_state_after(self, state: SDFGState, label=None, is_start_state=False) -> SDFGState: + def add_state_after(self, + state: SDFGState, + label=None, + is_start_block=False, + condition: CodeBlock = None, + assignments=None, + *, + is_start_state: bool=None) -> SDFGState: """ Adds a new SDFG state after an existing state, reconnecting it to the successors instead. :param state: The state to append the new state after. :param label: State label. - :param is_start_state: If True, resets SDFG starting state to this state. + :param is_start_block: If True, resets scope block starting state to this state. + :param condition: Transition condition of the newly created edge between state and the new state. + :param assignments: Assignments to perform upon transition. :return: A new SDFGState object. 
""" - new_state = self.add_state(label, is_start_state) + new_state = self.add_state(label, is_start_block=is_start_block, is_start_state=is_start_state) # Reconnect for e in self.out_edges(state): self.remove_edge(e) self.add_edge(new_state, e.dst, e.data) - # Add unconditional connection between the current and the new state - self.add_edge(state, new_state, dace.sdfg.InterstateEdge()) + # Add the new edge + self.add_edge(state, new_state, dace.sdfg.InterstateEdge(condition=condition, assignments=assignments)) return new_state @abc.abstractmethod diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 68980c3b10..7311f4f028 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -585,11 +585,14 @@ def consolidate_edges_scope(state: SDFGState, scope_node: Union[nd.EntryNode, nd conn_to_remove = prefix + conn[offset:] remove_outer_connector(conn_to_remove) if isinstance(scope_node, nd.EntryNode): - out_edge = next(ed for ed in outer_edges(scope_node) if ed.dst_conn == target_conn) - edge_to_remove = next(ed for ed in outer_edges(scope_node) if ed.dst_conn == conn_to_remove) + out_edges = [ed for ed in outer_edges(scope_node) if ed.dst_conn == target_conn] + edges_to_remove = [ed for ed in outer_edges(scope_node) if ed.dst_conn == conn_to_remove] else: - out_edge = next(ed for ed in outer_edges(scope_node) if ed.src_conn == target_conn) - edge_to_remove = next(ed for ed in outer_edges(scope_node) if ed.src_conn == conn_to_remove) + out_edges = [ed for ed in outer_edges(scope_node) if ed.src_conn == target_conn] + edges_to_remove = [ed for ed in outer_edges(scope_node) if ed.src_conn == conn_to_remove] + assert len(edges_to_remove) == 1 and len(out_edges) == 1 + edge_to_remove = edges_to_remove[0] + out_edge = out_edges[0] out_edge.data.subset = sbs.union(out_edge.data.subset, edge_to_remove.data.subset) # Check if dangling connectors have been created and remove them, @@ -627,9 +630,9 @@ def remove_edge_and_dangling_path(state: SDFGState, edge: MultiConnectorEdge): e = curedge.edge state.remove_edge(e) if inwards: - neighbors = [neighbor for neighbor in state.out_edges(e.src) if e.src_conn == neighbor.src_conn] + neighbors = [] if not e.src_conn else [neighbor for neighbor in state.out_edges_by_connector(e.src, e.src_conn)] else: - neighbors = [neighbor for neighbor in state.in_edges(e.dst) if e.dst_conn == neighbor.dst_conn] + neighbors = [] if not e.dst_conn else [neighbor for neighbor in state.in_edges_by_connector(e.dst, e.dst_conn)] if len(neighbors) > 0: # There are still edges connected, leave as-is break @@ -641,7 +644,7 @@ def remove_edge_and_dangling_path(state: SDFGState, edge: MultiConnectorEdge): else: if e.dst_conn: e.dst.remove_in_connector(e.dst_conn) - e.src.remove_out_connector('OUT' + e.dst_conn[2:]) + e.dst.remove_out_connector('OUT' + e.dst_conn[2:]) # Continue traversing upwards curedge = curedge.parent @@ -1295,32 +1298,43 @@ def inline_sdfgs(sdfg: SDFG, permissive: bool = False, progress: bool = None, mu from dace.transformation.interstate import InlineSDFG, InlineMultistateSDFG counter = 0 - nsdfgs = [(n, p) for n, p in sdfg.all_nodes_recursive() if isinstance(n, NestedSDFG)] - - for node, state in optional_progressbar(reversed(nsdfgs), title='Inlining SDFGs', n=len(nsdfgs), progress=progress): - id = node.sdfg.cfg_id - sd = state.parent + nsdfgs = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, NestedSDFG)] + for nsdfg_node in optional_progressbar(reversed(nsdfgs), title='Inlining SDFGs', n=len(nsdfgs), progress=progress): # We have to reevaluate 
every time due to changing IDs - state_id = sd.node_id(state) + # e.g., InlineMultistateSDFG may fission states + parent_state = nsdfg_node.sdfg.parent + parent_sdfg = parent_state.parent + parent_state_id = parent_sdfg.node_id(parent_state) + if multistate: candidate = { - InlineMultistateSDFG.nested_sdfg: node, + InlineMultistateSDFG.nested_sdfg: nsdfg_node, } inliner = InlineMultistateSDFG() - inliner.setup_match(sd, id, state_id, candidate, 0, override=True) - if inliner.can_be_applied(state, 0, sd, permissive=permissive): - inliner.apply(state, sd) + inliner.setup_match(sdfg=parent_sdfg, + cfg_id=parent_sdfg.sdfg_id, + state_id=parent_state_id, + subgraph=candidate, + expr_index=0, + override=True) + if inliner.can_be_applied(parent_state, 0, parent_sdfg, permissive=permissive): + inliner.apply(parent_state, parent_sdfg) counter += 1 continue candidate = { - InlineSDFG.nested_sdfg: node, + InlineSDFG.nested_sdfg: nsdfg_node, } inliner = InlineSDFG() - inliner.setup_match(sd, id, state_id, candidate, 0, override=True) - if inliner.can_be_applied(state, 0, sd, permissive=permissive): - inliner.apply(state, sd) + inliner.setup_match(sdfg=parent_sdfg, + cfg_id=parent_sdfg.sdfg_id, + state_id=parent_state_id, + subgraph=candidate, + expr_index=0, + override=True) + if inliner.can_be_applied(parent_state, 0, parent_sdfg, permissive=permissive): + inliner.apply(parent_state, parent_sdfg) counter += 1 return counter @@ -1341,7 +1355,7 @@ def load_precompiled_sdfg(folder: str): csdfg.ReloadableDLL(os.path.join(folder, 'build', f'lib{sdfg.name}.{suffix}'), sdfg.name)) -def distributed_compile(sdfg: SDFG, comm) -> csdfg.CompiledSDFG: +def distributed_compile(sdfg: SDFG, comm, validate: bool = True) -> csdfg.CompiledSDFG: """ Compiles an SDFG in rank 0 of MPI communicator ``comm``. Then, the compiled SDFG is loaded in all other ranks. @@ -1357,7 +1371,7 @@ def distributed_compile(sdfg: SDFG, comm) -> csdfg.CompiledSDFG: # Rank 0 compiles SDFG. if rank == 0: - func = sdfg.compile() + func = sdfg.compile(validate=validate) folder = sdfg.build_folder # Broadcasts build folder. diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index 299ffc96fa..660e45e574 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -981,3 +981,20 @@ def __str__(self): locinfo += f'\nInvalid SDFG saved for inspection in {os.path.abspath(self.path)}' return f'{self.message} (at state {state.label}{edgestr}){locinfo}' + + +def validate_memlet_data(memlet_data: str, access_data: str) -> bool: + """ Validates that the src/dst access node data matches the memlet data. + + :param memlet_data: The data of the memlet. + :param access_data: The data of the access node. + :return: True if the memlet data matches the access node data. 
+ """ + if memlet_data == access_data: + return True + if memlet_data is None or access_data is None: + return False + access_tokens = access_data.split('.') + memlet_tokens = memlet_data.split('.') + mem_root = '.'.join(memlet_tokens[:len(access_tokens)]) + return mem_root == access_data diff --git a/dace/transformation/dataflow/map_dim_shuffle.py b/dace/transformation/dataflow/map_dim_shuffle.py index 7b4114b188..ad17a5ddac 100644 --- a/dace/transformation/dataflow/map_dim_shuffle.py +++ b/dace/transformation/dataflow/map_dim_shuffle.py @@ -27,18 +27,19 @@ def expressions(cls): return [sdutil.node_path_graph(cls.map_entry)] def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + map_entry: nodes.MapEntry = self.map_entry + if self.parameters is None: + return False + if len(self.parameters) != len(map_entry.map.params): + return False + if set(self.parameters) != set(map_entry.map.params): + return False return True def apply(self, graph: SDFGState, sdfg: SDFG): - map_entry = self.map_entry - if self.parameters is None: - return - - if set(self.parameters) != set(map_entry.map.params): - return + map_entry: nodes.MapEntry = self.map_entry + new_map_order: list[int] = [map_entry.map.params.index(param) for param in self.parameters] - map_entry.range.ranges = [ - r for list_param in self.parameters for map_param, r in zip(map_entry.map.params, map_entry.range.ranges) - if list_param == map_param - ] - map_entry.map.params = self.parameters + map_entry.range.ranges = [map_entry.range.ranges[new_pos] for new_pos in new_map_order] + map_entry.range.tile_sizes = [map_entry.range.tile_sizes[new_pos] for new_pos in new_map_order] + map_entry.map.params = [map_entry.map.params[new_pos] for new_pos in new_map_order] diff --git a/dace/transformation/dataflow/map_expansion.py b/dace/transformation/dataflow/map_expansion.py index 9d89ec7c09..8bc14213b0 100644 --- a/dace/transformation/dataflow/map_expansion.py +++ b/dace/transformation/dataflow/map_expansion.py @@ -6,7 +6,7 @@ import copy import dace from dace import dtypes, subsets, symbolic -from dace.properties import EnumProperty, make_properties +from dace.properties import EnumProperty, make_properties, Property from dace.sdfg import nodes from dace.sdfg import utils as sdutil from dace.sdfg.graph import OrderedMultiDiConnectorGraph @@ -18,8 +18,9 @@ class MapExpansion(pm.SingleStateTransformation): """ Implements the map-expansion pattern. - Map-expansion takes an N-dimensional map and expands it to N - unidimensional maps. + Map-expansion takes an N-dimensional map and expands it. + It will generate the k nested unidimensional map and a (N-k)-dimensional inner most map. + If k is not specified all maps are expanded. New edges abide by the following rules: 1. If there are no edges coming from the outside, use empty memlets @@ -33,6 +34,11 @@ class MapExpansion(pm.SingleStateTransformation): dtype=dtypes.ScheduleType, default=dtypes.ScheduleType.Sequential, allow_none=True) + expansion_limit = Property(desc="How many unidimensional maps will be creaed, known as k. " + "If None, the default no limit is in place.", + dtype=int, + allow_none=True, + default=None) @classmethod def expressions(cls): @@ -43,22 +49,77 @@ def can_be_applied(self, graph: dace.SDFGState, expr_index: int, sdfg: dace.SDFG # includes an N-dimensional map, with N greater than one. 
        return self.map_entry.map.get_param_num() > 1

+    def generate_new_maps(self,
+                          current_map: nodes.Map):
+        if self.expansion_limit is None:
+            full_expand = True
+        elif isinstance(self.expansion_limit, int):
+            full_expand = False
+            if self.expansion_limit <= 0:  # These are invalid, so we make a full expansion
+                full_expand = True
+            elif (self.map_entry.map.get_param_num() - self.expansion_limit) <= 1:
+                full_expand = True
+        else:
+            raise TypeError(f"Do not know how to handle type {type(self.expansion_limit).__name__}")
+
+        inner_schedule = self.inner_schedule or current_map.schedule
+        if full_expand:
+            new_maps = [
+                nodes.Map(
+                    current_map.label + '_' + str(param), [param],
+                    subsets.Range([param_range]),
+                    schedule=inner_schedule if dim != 0 else current_map.schedule)
+                for dim, param, param_range in zip(range(len(current_map.params)), current_map.params, current_map.range)
+            ]
+            for i, new_map in enumerate(new_maps):
+                new_map.range.tile_sizes[0] = current_map.range.tile_sizes[i]
+
+        else:
+            k = self.expansion_limit
+            new_maps: list[nodes.Map] = []
+
+            # Unidimensional maps
+            for dim in range(0, k):
+                dim_param = current_map.params[dim]
+                dim_range = current_map.range.ranges[dim]
+                dim_tile = current_map.range.tile_sizes[dim]
+                new_maps.append(
+                    nodes.Map(
+                        current_map.label + '_' + str(dim_param),
+                        [dim_param],
+                        subsets.Range([dim_range]),
+                        schedule=inner_schedule if dim != 0 else current_map.schedule))
+                new_maps[-1].range.tile_sizes[0] = dim_tile
+
+            # Multidimensional maps
+            mdim_params = current_map.params[k:]
+            mdim_ranges = current_map.range.ranges[k:]
+            mdim_tiles = current_map.range.tile_sizes[k:]
+            new_maps.append(
+                nodes.Map(
+                    current_map.label,  # The original name
+                    mdim_params,
+                    mdim_ranges,
+                    schedule=inner_schedule))
+            new_maps[-1].range.tile_sizes = mdim_tiles
+        return new_maps
+
     def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
         # Extract the map and its entry and exit nodes.
         map_entry = self.map_entry
         map_exit = graph.exit_node(map_entry)
         current_map = map_entry.map

-        # Create new maps
-        inner_schedule = self.inner_schedule or current_map.schedule
-        new_maps = [
-            nodes.Map(current_map.label + '_' + str(param), [param],
-                      subsets.Range([param_range]),
-                      schedule=inner_schedule)
-            for param, param_range in zip(current_map.params[1:], current_map.range[1:])
-        ]
-        current_map.params = [current_map.params[0]]
-        current_map.range = subsets.Range([current_map.range[0]])
+        # Generate the new maps that we should use.
+        new_maps = self.generate_new_maps(current_map)
+
+        if not new_maps:  # No changes should be made -> no-op
+            return
+
+        # Reuse the already existing map for the first one.
+ current_map.params = new_maps[0].params + current_map.range = new_maps[0].range + new_maps.pop(0) # Create new map entries and exits entries = [nodes.MapEntry(new_map) for new_map in new_maps] diff --git a/dace/transformation/dataflow/map_fusion.py b/dace/transformation/dataflow/map_fusion.py index 9a0dd0e313..186ea32acc 100644 --- a/dace/transformation/dataflow/map_fusion.py +++ b/dace/transformation/dataflow/map_fusion.py @@ -481,6 +481,12 @@ def fuse_nodes(self, sdfg, graph, edge, new_dst, new_dst_conn, other_edges=None) local_node = edge.src src_connector = edge.src_conn + # update edge data in case source or destination is a scalar access node + test_data = [node.data for node in (edge.src, edge.dst) if isinstance(node, nodes.AccessNode)] + for new_data in test_data: + if isinstance(sdfg.arrays[new_data], data.Scalar): + edge.data.data = new_data + # If destination of edge leads to multiple destinations, redirect all through an access node. if other_edges: # NOTE: If a new local node was already created, reuse it. diff --git a/dace/transformation/dataflow/prune_connectors.py b/dace/transformation/dataflow/prune_connectors.py index 865f28f7d9..36352fef0d 100644 --- a/dace/transformation/dataflow/prune_connectors.py +++ b/dace/transformation/dataflow/prune_connectors.py @@ -57,40 +57,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): nsdfg = self.nsdfg # Fission subgraph around nsdfg into its own state to avoid data races - predecessors = set() - for inedge in state.in_edges(nsdfg): - if inedge.data is None: - continue - - pred = state.memlet_path(inedge)[0].src - if state.in_degree(pred) == 0: - continue - - predecessors.add(pred) - for e in state.bfs_edges(pred, reverse=True): - predecessors.add(e.src) - - subgraph = StateSubgraphView(state, predecessors) - pred_state = helpers.state_fission(sdfg, subgraph) - - subgraph_nodes = set() - subgraph_nodes.add(nsdfg) - for inedge in state.in_edges(nsdfg): - if inedge.data is None: - continue - path = state.memlet_path(inedge) - for edge in path: - subgraph_nodes.add(edge.src) - - for oedge in state.out_edges(nsdfg): - if oedge.data is None: - continue - path = state.memlet_path(oedge) - for edge in path: - subgraph_nodes.add(edge.dst) - - subgraph = StateSubgraphView(state, subgraph_nodes) - nsdfg_state = helpers.state_fission(sdfg, subgraph) + nsdfg_state = helpers.state_fission_after(sdfg, state, nsdfg) read_set, write_set = nsdfg.sdfg.read_and_write_sets() prune_in = nsdfg.in_connectors.keys() - read_set diff --git a/dace/transformation/dataflow/wcr_conversion.py b/dace/transformation/dataflow/wcr_conversion.py index 3ef508f7e5..1a0ecf6bc4 100644 --- a/dace/transformation/dataflow/wcr_conversion.py +++ b/dace/transformation/dataflow/wcr_conversion.py @@ -77,8 +77,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # If in map, only match if the subset is independent of any # map indices (otherwise no conflict) - if not permissive and len(outedge.data.subset.free_symbols & set(me.map.params)) == len( - me.map.params): + if not permissive and len(outedge.data.subset.free_symbols & set(me.map.params)) == len(me.map.params): return False # Get relevant output connector @@ -151,18 +150,16 @@ def apply(self, state: SDFGState, sdfg: SDFG): # If state fission is necessary to keep semantics, do it first if state.in_degree(input) > 0: - subgraph_nodes = set([e.src for e in state.bfs_edges(input, reverse=True)]) - subgraph_nodes.add(input) - - subgraph = StateSubgraphView(state, subgraph_nodes) - helpers.state_fission(sdfg, 
subgraph) + new_state = helpers.state_fission_after(sdfg, state, tasklet) + else: + new_state = state if self.expr_index == 0: - inedges = state.edges_between(input, tasklet) - outedge = state.edges_between(tasklet, output)[0] + inedges = new_state.edges_between(input, tasklet) + outedge = new_state.edges_between(tasklet, output)[0] else: - inedges = state.edges_between(me, tasklet) - outedge = state.edges_between(tasklet, mx)[0] + inedges = new_state.edges_between(me, tasklet) + outedge = new_state.edges_between(tasklet, mx)[0] # Get relevant output connector outconn = outedge.src_conn @@ -253,8 +250,8 @@ def apply(self, state: SDFGState, sdfg: SDFG): outedge.data.wcr = f'lambda a,b: a {op} b' # Remove input node and connector - state.remove_memlet_path(inedge) - propagate_memlets_state(sdfg, state) + new_state.remove_memlet_path(inedge) + propagate_memlets_state(sdfg, new_state) # If outedge leads to non-transient, and this is a nested SDFG, # propagate outwards diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index c39d744c39..cd73b96a68 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -687,6 +687,85 @@ def state_fission(sdfg: SDFG, subgraph: graph.SubgraphView, label: Optional[str] return newstate +def state_fission_after(sdfg: SDFG, state: SDFGState, node: nodes.Node, label: Optional[str] = None) -> SDFGState: + """ + """ + newstate = sdfg.add_state_after(state, label=label) + + # Bookkeeping + nodes_to_move = set([node]) + boundary_nodes = set() + orig_edges = set() + + # Collect predecessors + if not isinstance(node, nodes.AccessNode): + for edge in state.in_edges(node): + for e in state.memlet_path(edge): + nodes_to_move.add(e.src) + orig_edges.add(e) + + # Collect nodes_to_move + for edge in state.bfs_edges(node): + nodes_to_move.add(edge.dst) + orig_edges.add(edge) + + if not isinstance(edge.dst, nodes.AccessNode): + for iedge in state.in_edges(edge.dst): + if iedge == edge: + continue + + for e in state.memlet_path(iedge): + nodes_to_move.add(e.src) + orig_edges.add(e) + + # Define boundary nodes + for node in set(nodes_to_move): + if isinstance(node, nodes.AccessNode): + for iedge in state.in_edges(node): + if iedge.src not in nodes_to_move: + boundary_nodes.add(node) + break + + if node in boundary_nodes: + continue + + for oedge in state.out_edges(node): + if oedge.dst not in nodes_to_move: + boundary_nodes.add(node) + break + + # Duplicate boundary nodes + new_nodes = {} + for node in boundary_nodes: + node_ = copy.deepcopy(node) + state.add_node(node_) + new_nodes[node] = node_ + + for edge in state.edges(): + if edge.src in boundary_nodes and edge.dst in boundary_nodes: + state.add_edge(new_nodes[edge.src], edge.src_conn, new_nodes[edge.dst], edge.dst_conn, + copy.deepcopy(edge.data)) + elif edge.src in boundary_nodes: + state.add_edge(new_nodes[edge.src], edge.src_conn, edge.dst, edge.dst_conn, copy.deepcopy(edge.data)) + elif edge.dst in boundary_nodes: + state.add_edge(edge.src, edge.src_conn, new_nodes[edge.dst], edge.dst_conn, copy.deepcopy(edge.data)) + + # Move nodes + state.remove_nodes_from(nodes_to_move) + + for n in nodes_to_move: + if isinstance(n, nodes.NestedSDFG): + # Set the new parent state + n.sdfg.parent = newstate + + newstate.add_nodes_from(nodes_to_move) + + for e in orig_edges: + newstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, e.data) + + return newstate + + def _get_internal_subset(internal_memlet: Memlet, external_memlet: Memlet, use_src_subset: bool = False, diff --git 
a/dace/transformation/interstate/multistate_inline.py b/dace/transformation/interstate/multistate_inline.py index 8623bdf468..0e4f1b4852 100644 --- a/dace/transformation/interstate/multistate_inline.py +++ b/dace/transformation/interstate/multistate_inline.py @@ -20,6 +20,7 @@ from dace.transformation import transformation, helpers from dace.properties import make_properties, Property from dace import data +from dace.sdfg.state import StateSubgraphView @make_properties @@ -85,56 +86,48 @@ def can_be_applied(self, state: SDFGState, expr_index, sdfg, permissive=False): if nested_sdfg.schedule == dtypes.ScheduleType.FPGA_Device: return False - # Ensure the state only contains a nested SDFG and input/output access - # nodes - for node in state.nodes(): - if isinstance(node, nodes.NestedSDFG): - if node is not nested_sdfg: - return False - elif isinstance(node, nodes.AccessNode): - # Must be connected to nested SDFG - # if nested_sdfg in state.predecessors(nested_sdfg): - # if state.in_degree(node) > 0: - # return False - found = False - for e in state.out_edges(node): - if e.dst is not nested_sdfg: - return False - if state.in_degree(node) > 0: - return False - # Only accept full ranges for now. TODO(later): Improve - if e.data.subset != subsets.Range.from_array(sdfg.arrays[node.data]): - return False - if e.dst_conn in nested_sdfg.sdfg.arrays: - # Do not accept views. TODO(later): Improve - outer_desc = sdfg.arrays[node.data] - inner_desc = nested_sdfg.sdfg.arrays[e.dst_conn] - if (outer_desc.shape != inner_desc.shape or outer_desc.strides != inner_desc.strides): - return False - found = True - - for e in state.in_edges(node): - if e.src is not nested_sdfg: - return False - if state.out_degree(node) > 0: - return False - # Only accept full ranges for now. TODO(later): Improve - if e.data.subset != subsets.Range.from_array(sdfg.arrays[node.data]): - return False - if e.src_conn in nested_sdfg.sdfg.arrays: - # Do not accept views. 
TODO(later): Improve - outer_desc = sdfg.arrays[node.data] - inner_desc = nested_sdfg.sdfg.arrays[e.src_conn] - if (outer_desc.shape != inner_desc.shape or outer_desc.strides != inner_desc.strides): - return False - found = True - - # elif nested_sdfg in state.successors(nested_sdfg): - # if state.out_degree(node) > 0: - # return False - if not found: - return False - else: + # Not nested in scope + if state.entry_node(nested_sdfg) is not None: + return False + + # Must be + # - connected to access nodes only + # - read full subsets + # - not use views inside + for edge in state.in_edges(nested_sdfg): + if edge.data.data is None: + return False + + if not isinstance(edge.src, nodes.AccessNode): + return False + + if edge.data.subset != subsets.Range.from_array(sdfg.arrays[edge.data.data]): + return False + + outer_desc = sdfg.arrays[edge.data.data] + if isinstance(outer_desc, data.View): + return False + + inner_desc = nested_sdfg.sdfg.arrays[edge.dst_conn] + if (outer_desc.shape != inner_desc.shape or outer_desc.strides != inner_desc.strides): + return False + + for edge in state.out_edges(nested_sdfg): + if edge.data.data is None: + return False + + if not isinstance(edge.dst, nodes.AccessNode): + return False + + if edge.data.subset != subsets.Range.from_array(sdfg.arrays[edge.data.data]): + return False + + outer_desc = sdfg.arrays[edge.data.data] + if isinstance(outer_desc, data.View): + return False + + inner_desc = nested_sdfg.sdfg.arrays[edge.src_conn] + if (outer_desc.shape != inner_desc.shape or outer_desc.strides != inner_desc.strides): return False return True @@ -168,16 +161,27 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): for ise in sdfg.edges(): outer_symbols.update(ise.data.new_symbols(sdfg, outer_symbols)) + # Isolate nsdfg in a separate state + # 1. Push nsdfg node plus dependencies down into new state + nsdfg_state = helpers.state_fission_after(sdfg, outer_state, nsdfg_node) + # 2. 
Push successors of nsdfg node into a later state + direct_subgraph = set() + direct_subgraph.add(nsdfg_node) + direct_subgraph.update(nsdfg_state.predecessors(nsdfg_node)) + direct_subgraph.update(nsdfg_state.successors(nsdfg_node)) + direct_subgraph = StateSubgraphView(nsdfg_state, direct_subgraph) + nsdfg_state = helpers.state_fission(sdfg, direct_subgraph) + # Find original source/destination edges (there is only one edge per # connector, according to match) inputs: Dict[str, MultiConnectorEdge] = {} outputs: Dict[str, MultiConnectorEdge] = {} input_set: Dict[str, str] = {} output_set: Dict[str, str] = {} - for e in outer_state.in_edges(nsdfg_node): + for e in nsdfg_state.in_edges(nsdfg_node): inputs[e.dst_conn] = e input_set[e.data.data] = e.dst_conn - for e in outer_state.out_edges(nsdfg_node): + for e in nsdfg_state.out_edges(nsdfg_node): outputs[e.src_conn] = e output_set[e.data.data] = e.src_conn @@ -260,7 +264,6 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): name = sdfg.add_datadesc(new_name, datadesc, find_new_name=True) transients[edge.data.data] = name - # All constants (and associated transients) become constants of the parent for cstname, (csttype, cstval) in nsdfg.constants_prop.items(): if cstname in sdfg.constants: @@ -273,7 +276,6 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): else: sdfg.constants_prop[cstname] = (csttype, cstval) - ####################################################### # Replace data on inlined SDFG nodes/edges @@ -352,9 +354,9 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): sinks = nsdfg.sink_nodes() # Reconnect state machine - for e in sdfg.in_edges(outer_state): + for e in sdfg.in_edges(nsdfg_state): sdfg.add_edge(e.src, source, e.data) - for e in sdfg.out_edges(outer_state): + for e in sdfg.out_edges(nsdfg_state): for sink in sinks: sdfg.add_edge(sink, e.dst, dc(e.data)) # Redirect sink incoming edges with a `False` condition to e.dst (return statements) @@ -363,7 +365,7 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): sdfg.add_edge(e2.src, e.dst, InterstateEdge()) # Modify start state as necessary - if outer_start_state is outer_state: + if outer_start_state is nsdfg_state: sdfg.start_state = sdfg.node_id(source) # TODO: Modify memlets by offsetting @@ -418,7 +420,7 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): ####################################################### # Remove nested SDFG and state - sdfg.remove_node(outer_state) + sdfg.remove_node(nsdfg_state) sdfg._cfg_list = sdfg.reset_cfg_list() diff --git a/requirements.txt b/requirements.txt index f06f3421cd..e98e33fe74 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,9 +5,9 @@ charset-normalizer==3.1.0 click==8.1.3 dill==0.3.6 fparser==0.1.3 -idna==3.4 +idna==3.7 importlib-metadata==6.6.0 -Jinja2==3.1.3 +Jinja2==3.1.4 MarkupSafe==2.1.3 mpmath==1.3.0 networkx==3.1 diff --git a/tests/inlining_test.py b/tests/inlining_test.py index d207aa6c2c..7c3510daed 100644 --- a/tests/inlining_test.py +++ b/tests/inlining_test.py @@ -127,15 +127,16 @@ def outerprog(A: dace.float64[20]): nested(A) sdfg = outerprog.to_sdfg(simplify=True) - from dace.transformation.interstate import InlineMultistateSDFG - sdfg.apply_transformations(InlineMultistateSDFG) - assert sdfg.number_of_nodes() in (4, 5) A = np.random.rand(20) expected = np.copy(A) outerprog.f(expected) - outerprog(A) + from dace.transformation.interstate import InlineMultistateSDFG + sdfg.apply_transformations(InlineMultistateSDFG) + assert sdfg.number_of_nodes() in (4, 5) + + sdfg(A) assert 
np.allclose(A, expected) @@ -152,18 +153,105 @@ def outerprog(A: dace.float64[20]): nested(A) sdfg = outerprog.to_sdfg(simplify=True) - from dace.transformation.interstate import InlineMultistateSDFG - sdfg.apply_transformations(InlineMultistateSDFG) - assert sdfg.number_of_nodes() in (7, 8) A = np.random.rand(20) expected = np.copy(A) outerprog.f(expected) - outerprog(A) + from dace.transformation.interstate import InlineMultistateSDFG + sdfg.apply_transformations(InlineMultistateSDFG) + assert sdfg.number_of_nodes() in (7, 8) + + sdfg(A) assert np.allclose(A, expected) +def test_multistate_inline_outer_dependencies(): + + @dace.program + def nested(A: dace.float64[20]): + for i in range(1, 20): + A[i] += A[i - 1] + + @dace.program + def outerprog(A: dace.float64[20], B: dace.float64[20]): + for i in dace.map[0:20]: + with dace.tasklet: + a >> A[i] + b >> B[i] + + a = 0 + b = 1 + + nested(A) + + for i in dace.map[0:20]: + with dace.tasklet: + a << A[i] + b >> A[i] + + b = 2 * a + + sdfg = outerprog.to_sdfg(simplify=False) + sdfg.apply_transformations_repeated((StateFusion, InlineSDFG)) + assert len(sdfg.states()) == 1 + + A = np.random.rand(20) + B = np.random.rand(20) + expected_a = np.copy(A) + expected_b = np.copy(B) + outerprog.f(expected_a, expected_b) + + from dace.transformation.interstate import InlineMultistateSDFG + sdfg.apply_transformations(InlineMultistateSDFG) + + sdfg(A, B) + assert np.allclose(A, expected_a) + assert np.allclose(B, expected_b) + + +def test_multistate_inline_concurrent_subgraphs(): + + @dace.program + def nested(A: dace.float64[10], B: dace.float64[10]): + for i in range(1, 10): + B[i] = A[i] + + @dace.program + def outerprog(A: dace.float64[10], B: dace.float64[10], C: dace.float64[10]): + nested(A, B) + + for i in dace.map[0:10]: + with dace.tasklet: + a << A[i] + c >> C[i] + + c = 2 * a + + sdfg = outerprog.to_sdfg(simplify=False) + dace.propagate_memlets_sdfg(sdfg) + sdfg.apply_transformations_repeated((StateFusion, InlineSDFG)) + assert len(sdfg.states()) == 1 + assert len([node for node in sdfg.start_state.data_nodes()]) == 3 + + A = np.random.rand(10) + B = np.random.rand(10) + C = np.random.rand(10) + expected_a = np.copy(A) + expected_b = np.copy(B) + expected_c = np.copy(C) + outerprog.f(expected_a, expected_b, expected_c) + + from dace.transformation.interstate import InlineMultistateSDFG + applied = sdfg.apply_transformations(InlineMultistateSDFG) + assert applied == 1 + + sdfg(A, B, C) + assert np.allclose(A, expected_a) + assert np.allclose(B, expected_b) + assert np.allclose(C, expected_c) + + def test_inline_symexpr(): nsdfg = dace.SDFG('inner') nsdfg.add_array('a', [20], dace.float64) @@ -372,6 +460,8 @@ def test(A: dace.float64[96, 32], B: dace.float64[42, 32]): # test_regression_reshape_unsqueeze() test_empty_memlets() test_multistate_inline() + test_multistate_inline_outer_dependencies() + test_multistate_inline_concurrent_subgraphs() test_multistate_inline_samename() test_inline_symexpr() test_inline_unsqueeze() diff --git a/tests/python_frontend/unroll_test.py b/tests/python_frontend/unroll_test.py index 98c81156a0..bf2b1e7c91 100644 --- a/tests/python_frontend/unroll_test.py +++ b/tests/python_frontend/unroll_test.py @@ -169,6 +169,52 @@ def tounroll(A: dace.float64[3]): assert np.allclose(a, np.array([1, 2, 3])) +def test_list_global_enumerate(): + tracer_variables = ["vapor", "rain", "nope"] + + @dace.program + def enumerate_parsing( + A, + tracers: dace.compiletime, # Dict[str, np.float64] + ): + for i, q in 
enumerate(tracer_variables[0:2]): + tracers[q][:] = A # type:ignore + + a = np.ones([3]) + q = { + "vapor": np.zeros([3]), + "rain": np.zeros([3]), + "nope": np.zeros([3]), + } + enumerate_parsing(a, q) + assert np.allclose(q["vapor"], np.array([1, 1, 1])) + assert np.allclose(q["rain"], np.array([1, 1, 1])) + assert np.allclose(q["nope"], np.array([0, 0, 0])) + + +def test_tuple_global_enumerate(): + tracer_variables = ("vapor", "rain", "nope") + + @dace.program + def enumerate_parsing( + A, + tracers: dace.compiletime, # Dict[str, np.float64] + ): + for i, q in enumerate(tracer_variables[0:2]): + tracers[q][:] = A # type:ignore + + a = np.ones([3]) + q = { + "vapor": np.zeros([3]), + "rain": np.zeros([3]), + "nope": np.zeros([3]), + } + enumerate_parsing(a, q) + assert np.allclose(q["vapor"], np.array([1, 1, 1])) + assert np.allclose(q["rain"], np.array([1, 1, 1])) + assert np.allclose(q["nope"], np.array([0, 0, 0])) + + def test_tuple_elements_zip(): a1 = [2, 3, 4] a2 = (4, 5, 6) diff --git a/tests/sdfg/cycles_test.py b/tests/sdfg/cycles_test.py index 5e94db2eb4..480392ab2d 100644 --- a/tests/sdfg/cycles_test.py +++ b/tests/sdfg/cycles_test.py @@ -1,3 +1,4 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import pytest import dace @@ -13,3 +14,21 @@ def test_cycles(): state.add_edge(access, None, access, None, dace.Memlet.simple("A", "0")) sdfg.validate() + + +def test_cycles_memlet_path(): + with pytest.raises(ValueError, match="Found cycles.*"): + sdfg = dace.SDFG("foo") + state = sdfg.add_state() + sdfg.add_array("bla", shape=(10, ), dtype=dace.float32) + mentry_3, _ = state.add_map("map_3", dict(i="0:9")) + mentry_3.add_in_connector("IN_0") + mentry_3.add_out_connector("OUT_0") + state.add_edge(mentry_3, "OUT_0", mentry_3, "IN_0", dace.Memlet(data="bla", subset='0:9')) + + sdfg.validate() + + +if __name__ == '__main__': + test_cycles() + test_cycles_memlet_path() diff --git a/tests/sdfg/data/container_array_test.py b/tests/sdfg/data/container_array_test.py index 7685361d0f..091bb487d8 100644 --- a/tests/sdfg/data/container_array_test.py +++ b/tests/sdfg/data/container_array_test.py @@ -258,8 +258,67 @@ def test_two_levels(): assert np.allclose(ref, B[0]) +def test_multi_nested_containers(): + + M, N = dace.symbol('M'), dace.symbol('N') + sdfg = dace.SDFG('tester') + float_desc = dace.data.Scalar(dace.float32) + E_desc = dace.data.Structure({'F': dace.float32[N], 'G':float_desc}, 'InnerStruct') + B_desc = dace.data.ContainerArray(E_desc, [M]) + A_desc = dace.data.Structure({'B': B_desc, 'C': dace.float32[M], 'D': float_desc}, 'OuterStruct') + sdfg.add_datadesc('A', A_desc) + sdfg.add_datadesc_view('vB', B_desc) + sdfg.add_datadesc_view('vE', E_desc) + sdfg.add_array('out', [M, N], dace.float32) + + state = sdfg.add_state() + rA = state.add_read('A') + vB = state.add_access('vB') + vE = state.add_access('vE') + wout = state.add_write('out') + + me, mx = state.add_map('outer_product', dict(i='0:M', j='0:N')) + tasklet = state.add_tasklet('outer_product', {'__in_A_B_E_F', '__in_A_B_E_G', '__in_A_C', '__in_A_D'}, {'__out'}, + '__out = (__in_A_B_E_F + __in_A_B_E_G) * (__in_A_C + __in_A_D)') + + state.add_edge(rA, None, vB, 'views', dace.Memlet('A.B')) + state.add_memlet_path(vB, me, vE, dst_conn='views', memlet=dace.Memlet('vB[i]')) + state.add_edge(vE, None, tasklet, '__in_A_B_E_F', dace.Memlet('vE.F[j]')) + state.add_edge(vE, None, tasklet, '__in_A_B_E_G', dace.Memlet(data='vE.G', subset='0')) + state.add_memlet_path(rA, me, tasklet, dst_conn='__in_A_C', 
memlet=dace.Memlet('A.C[i]')) + state.add_memlet_path(rA, me, tasklet, dst_conn='__in_A_D', memlet=dace.Memlet(data='A.D', subset='0')) + state.add_memlet_path(tasklet, mx, wout, src_conn='__out', memlet=dace.Memlet('out[i, j]')) + + c_data = np.arange(5, dtype=np.float32) + f_data = np.arange(5 * 3, dtype=np.float32).reshape(5, 3) + + e_class = E_desc.dtype._typeclass.as_ctypes() + b_obj = [] + b_data = np.ndarray((5, ), dtype=ctypes.c_void_p) + for i in range(5): + f_obj = f_data[i].__array_interface__['data'][0] + e_obj = e_class(F=f_obj, G=ctypes.c_float(0.1)) + b_obj.append(e_obj) # NOTE: This is needed to keep the object alive ... + b_data[i] = ctypes.addressof(e_obj) + a_dace = A_desc.dtype._typeclass.as_ctypes()(B=b_data.__array_interface__['data'][0], + C=c_data.__array_interface__['data'][0], + D=ctypes.c_float(0.2)) + + + + + out_dace = np.empty((5, 3), dtype=np.float32) + ref = np.empty((5, 3), dtype=np.float32) + for i in range(5): + ref[i] = (f_data[i] + 0.1) * (c_data[i] + 0.2) + + sdfg(A=a_dace, out=out_dace, M=5, N=3) + assert np.allclose(out_dace, ref) + + if __name__ == '__main__': test_read_struct_array() test_write_struct_array() test_jagged_container_array() test_two_levels() + test_multi_nested_containers() diff --git a/tests/transformations/map_dim_shuffle_test.py b/tests/transformations/map_dim_shuffle_test.py index e0eb3f4311..1d9c73e5a2 100644 --- a/tests/transformations/map_dim_shuffle_test.py +++ b/tests/transformations/map_dim_shuffle_test.py @@ -36,6 +36,9 @@ def test_map_dim_shuffle(): sdfg(A=A, B=B) assert np.allclose(B, expected) + assert sdfg.apply_transformations_repeated(MapDimShuffle, options={"parameters": ["k", "i"]}) == 0 + assert sdfg.apply_transformations_repeated(MapDimShuffle, options={"parameters": ["k", "i", "l"]}) == 0 + if __name__ == '__main__': test_map_dim_shuffle() diff --git a/tests/transformations/map_expansion_test.py b/tests/transformations/map_expansion_test.py index 1f9a97f810..6e4b965ba2 100644 --- a/tests/transformations/map_expansion_test.py +++ b/tests/transformations/map_expansion_test.py @@ -73,7 +73,7 @@ def toexpand(B: dace.float64[4, 4]): continue # (Fast) MapExpansion should not add memlet paths for each memlet to a tasklet - if sdfg.start_state.entry_node(node) is None: + if state.entry_node(node) is None: assert state.in_degree(node) == 0 assert state.out_degree(node) == 1 assert len(node.out_connectors) == 0 @@ -113,7 +113,58 @@ def mymap(i: _[0:20], j: _[rng[0]:rng[1]], k: _[0:5]): print('Difference:', diff2) assert (diff <= 1e-5) and (diff2 <= 1e-5) +def test_expand_with_limits(): + @dace.program + def expansion(A: dace.float32[20, 30, 5]): + @dace.map + def mymap(i: _[0:20], j: _[0:30], k: _[0:5]): + a << A[i, j, k] + b >> A[i, j, k] + b = a * 2 + + A = np.random.rand(20, 30, 5).astype(np.float32) + expected = A.copy() + expected *= 2 + + sdfg = expansion.to_sdfg() + sdfg.simplify() + sdfg(A=A) + diff = np.linalg.norm(A - expected) + print('Difference (before transformation):', diff) + + sdfg.apply_transformations(MapExpansion, options=dict(expansion_limit=1)) + + map_entries = set() + state = sdfg.start_state + for node in state.nodes(): + if not isinstance(node, dace.nodes.MapEntry): + continue + + if state.entry_node(node) is None: + assert state.in_degree(node) == 1 + assert state.out_degree(node) == 1 + assert len(node.out_connectors) == 1 + assert len(node.map.range.ranges) == 1 + assert node.map.range.ranges[0][1] - node.map.range.ranges[0][0] + 1 == 20 + else: + assert state.in_degree(node) == 1 + assert 
state.out_degree(node) == 1 + assert len(node.out_connectors) == 1 + assert len(node.map.range.ranges) == 2 + assert list(map(lambda x: x[1] - x[0] + 1, node.map.range.ranges)) == [30, 5] + + map_entries.add(node) + + sdfg(A=A) + expected *= 2 + diff2 = np.linalg.norm(A - expected) + print('Difference:', diff2) + assert (diff <= 1e-5) and (diff2 <= 1e-5) + assert len(map_entries) == 2 + + if __name__ == '__main__': test_expand_with_inputs() test_expand_without_inputs() test_expand_without_dynamic_inputs() + test_expand_with_limits() diff --git a/tests/transformations/mapfusion_test.py b/tests/transformations/mapfusion_test.py index 653fb9d120..724c8c97ee 100644 --- a/tests/transformations/mapfusion_test.py +++ b/tests/transformations/mapfusion_test.py @@ -163,6 +163,43 @@ def test_fusion_with_transient(): assert np.allclose(A, expected) +def test_fusion_with_transient_scalar(): + N = 10 + K = 4 + + def build_sdfg(): + sdfg = dace.SDFG("map_fusion_with_transient_scalar") + state = sdfg.add_state() + sdfg.add_array("A", (N,K), dace.float64) + sdfg.add_array("B", (N,), dace.float64) + sdfg.add_array("T", (N,), dace.float64, transient=True) + t_node = state.add_access("T") + sdfg.add_scalar("V", dace.float64, transient=True) + v_node = state.add_access("V") + + me1, mx1 = state.add_map("map1", dict(i=f"0:{N}")) + tlet1 = state.add_tasklet("select", {"_v"}, {"_out"}, f"_out = _v[i, {K-1}]") + state.add_memlet_path(state.add_access("A"), me1, tlet1, dst_conn="_v", memlet=dace.Memlet.from_array("A", sdfg.arrays["A"])) + state.add_edge(tlet1, "_out", v_node, None, dace.Memlet("V[0]")) + state.add_memlet_path(v_node, mx1, t_node, memlet=dace.Memlet("T[i]")) + + me2, mx2 = state.add_map("map2", dict(j=f"0:{N}")) + tlet2 = state.add_tasklet("numeric", {"_inp"}, {"_out"}, f"_out = _inp + 1") + state.add_memlet_path(t_node, me2, tlet2, dst_conn="_inp", memlet=dace.Memlet("T[j]")) + state.add_memlet_path(tlet2, mx2, state.add_access("B"), src_conn="_out", memlet=dace.Memlet("B[j]")) + + return sdfg + + sdfg = build_sdfg() + sdfg.apply_transformations(MapFusion) + + A = np.random.rand(N, K) + B = np.repeat(np.nan, N) + sdfg(A=A, B=B) + + assert np.allclose(B, (A[:, K-1] + 1)) + + def test_fusion_with_inverted_indices(): @dace.program @@ -278,6 +315,7 @@ def fusion_with_nested_sdfg_1(A: dace.int32[10], B: dace.int32[10], C: dace.int3 test_multiple_fusions() test_fusion_chain() test_fusion_with_transient() + test_fusion_with_transient_scalar() test_fusion_with_inverted_indices() test_fusion_with_empty_memlet() test_fusion_with_nested_sdfg_0() diff --git a/tests/transformations/prune_connectors_test.py b/tests/transformations/prune_connectors_test.py index e9c7e34a83..59e1b125ff 100644 --- a/tests/transformations/prune_connectors_test.py +++ b/tests/transformations/prune_connectors_test.py @@ -307,7 +307,7 @@ def test_prune_connectors_with_dependencies(): applied = sdfg.apply_transformations_repeated(PruneConnectors) assert applied == 1 - assert len(sdfg.states()) == 3 + assert len(sdfg.states()) == 2 assert "B1" not in nsdfg_node.in_connectors assert "B2" not in nsdfg_node.out_connectors diff --git a/tests/uintptr_t_test.py b/tests/uintptr_t_test.py new file mode 100644 index 0000000000..2b1941340d --- /dev/null +++ b/tests/uintptr_t_test.py @@ -0,0 +1,37 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
+import dace
+import ctypes
+import numpy as np
+
+
+def test_uintp_size():
+    # c_void_p: C type -> void*
+    size = ctypes.sizeof(ctypes.c_void_p)
+    # numpy.uintp: unsigned integer large enough to fit a pointer, compatible with C uintptr_t
+    size_of_np_uintp = np.uintp().itemsize
+    # dace.uintp: DaCe's pointer-sized unsigned integer type
+    size_of_dace_uintp = dace.uintp.bytes
+
+    assert size == size_of_np_uintp == size_of_dace_uintp
+
+
+def test_uintp_use():
+
+    @dace.program
+    def tester(arr: dace.float64[20], pointer: dace.uintp[1]):
+        with dace.tasklet(dace.Language.CPP):
+            a << arr(-1)
+            """
+            out = decltype(out)(a);
+            """
+            out >> pointer[0]
+
+    ptr = np.empty([1], dtype=np.uintp)
+    arr = np.random.rand(20)
+    tester(arr, ptr)
+    assert arr.__array_interface__['data'][0] == ptr[0]
+
+
+if __name__ == '__main__':
+    test_uintp_size()
+    test_uintp_use()
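
A few illustrative sketches follow; none of them are part of the changeset. First, the user-facing effect of the rewritten `can_be_applied` in `multistate_inline.py`: `InlineMultistateSDFG` now only requires that the nested SDFG connects to plain access nodes over full subsets, with no views and matching shapes/strides, instead of requiring the state to contain nothing else. A minimal usage sketch, mirroring the `test_multistate_inline` pattern above:

import numpy as np
import dace
from dace.transformation.interstate import InlineMultistateSDFG

@dace.program
def nested(A: dace.float64[20]):
    for i in range(1, 20):
        A[i] += A[i - 1]

@dace.program
def outerprog(A: dace.float64[20]):
    nested(A)

sdfg = outerprog.to_sdfg(simplify=True)
# Full-subset access through plain access nodes, no views: the new checks pass
sdfg.apply_transformations(InlineMultistateSDFG)

A = np.random.rand(20)
expected = A.copy()
outerprog.f(expected)  # reference result from the pure-Python version
sdfg(A)
assert np.allclose(A, expected)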
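
Second, the isolation step added to `apply`: before inlining, the nested SDFG node is fissioned into a dedicated state so that unrelated producers and consumers sharing the original state keep their ordering (this is what the new `test_multistate_inline_outer_dependencies` and `test_multistate_inline_concurrent_subgraphs` tests validate). The two fission calls, extracted into a standalone sketch with names following the diff:

from dace.sdfg.state import StateSubgraphView
from dace.transformation import helpers

def isolate_nsdfg_node(sdfg, outer_state, nsdfg_node):
    # 1. Split the state so the nsdfg node and its dependencies land in a new state
    nsdfg_state = helpers.state_fission_after(sdfg, outer_state, nsdfg_node)
    # 2. Fission the node plus its direct neighbors into their own state,
    #    pushing any remaining computation into a later state
    subgraph = {nsdfg_node}
    subgraph.update(nsdfg_state.predecessors(nsdfg_node))
    subgraph.update(nsdfg_state.successors(nsdfg_node))
    return helpers.state_fission(sdfg, StateSubgraphView(nsdfg_state, subgraph))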
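
Third, the frontend behavior pinned down by the new `unroll_test.py` cases: `enumerate` over a slice of a module-level list or tuple is unrolled at parse time, so the loop variable can be used as a string key into a `dace.compiletime` dict argument. A condensed sketch with hypothetical names (`fields`, `copy_first_two`):

import numpy as np
import dace

fields = ("u", "v", "w")

@dace.program
def copy_first_two(A, out: dace.compiletime):  # out: Dict[str, np.ndarray]
    for i, name in enumerate(fields[0:2]):
        out[name][:] = A

a = np.ones(3)
out = {k: np.zeros(3) for k in fields}
copy_first_two(a, out)
assert np.allclose(out["u"], 1) and np.allclose(out["v"], 1)
assert np.allclose(out["w"], 0)  # untouched: the slice stops before "w"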
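
Fourth, the `expansion_limit` option of `MapExpansion` exercised by `test_expand_with_limits`: only the outermost `expansion_limit` dimensions are peeled into separate maps, while the remaining dimensions stay collapsed in a single inner map. A minimal usage sketch, assuming the option behaves as the test asserts:

import numpy as np
import dace
from dace.transformation.dataflow import MapExpansion

@dace.program
def scale(A: dace.float32[20, 30, 5]):
    for i, j, k in dace.map[0:20, 0:30, 0:5]:
        A[i, j, k] *= 2

sdfg = scale.to_sdfg()
sdfg.simplify()
# Expand only one dimension: an outer map over i, one collapsed inner map over (j, k)
sdfg.apply_transformations(MapExpansion, options=dict(expansion_limit=1))

A = np.random.rand(20, 30, 5).astype(np.float32)
expected = A * 2
sdfg(A=A)
assert np.allclose(A, expected)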
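
Finally, the sizing guarantee behind the new `uintptr_t` test: `dace.uintp` matches `numpy.uintp` and `void*` in width, so an address produced on the DaCe side can be consumed from Python with `ctypes`. A round-trip sketch in plain numpy/ctypes, independent of any DaCe program:

import ctypes
import numpy as np

arr = np.random.rand(20)
addr = arr.__array_interface__['data'][0]  # base address as a Python int
ptr = np.uintp(addr)                       # fits: np.uintp is pointer-sized
# Reinterpret the address as double* and read the first element back
first = ctypes.cast(int(ptr), ctypes.POINTER(ctypes.c_double)).contents.value
assert first == arr[0]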