diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 7284a7df04..83304a9c62 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -15,7 +15,7 @@ Delete this comment and add a proper description of the changes contained in thi - test: Adding missing tests or correcting existing tests : cartesian | eve | next | storage - # ONLY if changes are limited to a specific subsytem + # ONLY if changes are limited to a specific subsystem - PR Description: @@ -27,7 +27,7 @@ Delete this comment and add a proper description of the changes contained in thi ## Requirements - [ ] All fixes and/or new features come with corresponding tests. -- [ ] Important design decisions have been documented in the approriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. +- [ ] Important design decisions have been documented in the appropriate ADR inside the [docs/development/ADRs/](docs/development/ADRs/Index.md) folder. If this PR contains code authored by new contributors please make sure: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1c3b6e693f..7e1870c67f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -102,6 +102,7 @@ repos: - devtools==0.12.2 - diskcache==5.6.3 - factory-boy==3.3.1 + - filelock==3.16.1 - frozendict==2.4.6 - gridtools-cpp==2.3.8 - importlib-resources==6.4.5 diff --git a/constraints.txt b/constraints.txt index b4b8bc00d4..f039fa2125 100644 --- a/constraints.txt +++ b/constraints.txt @@ -49,7 +49,7 @@ executing==2.1.0 # via devtools, stack-data factory-boy==3.3.1 # via gt4py (pyproject.toml), pytest-factoryboy faker==33.0.0 # via factory-boy fastjsonschema==2.20.0 # via nbformat -filelock==3.16.1 # via tox, virtualenv +filelock==3.16.1 # via gt4py (pyproject.toml), tox, virtualenv fonttools==4.55.0 # via matplotlib fparser==0.1.4 # via dace frozendict==2.4.6 # via gt4py (pyproject.toml) @@ -113,8 +113,8 @@ psutil==6.1.0 # via -r requirements-dev.in, ipykernel, pytest-xdist ptyprocess==0.7.0 # via pexpect pure-eval==0.2.3 # via stack-data pybind11==2.13.6 # via gt4py (pyproject.toml) -pydantic==2.9.2 # via bump-my-version, pydantic-settings -pydantic-core==2.23.4 # via pydantic +pydantic==2.10.0 # via bump-my-version, pydantic-settings +pydantic-core==2.27.0 # via pydantic pydantic-settings==2.6.1 # via bump-my-version pydot==3.0.2 # via tach pygments==2.18.0 # via -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx @@ -159,7 +159,7 @@ stack-data==0.6.3 # via ipython stdlib-list==0.10.0 # via tach sympy==1.13.3 # via dace tabulate==0.9.0 # via gt4py (pyproject.toml) -tach==0.14.3 # via -r requirements-dev.in +tach==0.14.4 # via -r requirements-dev.in tomli==2.1.0 ; python_version < "3.11" # via -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via tach tomlkit==0.13.2 # via bump-my-version diff --git a/docs/user/next/advanced/HackTheToolchain.md b/docs/user/next/advanced/HackTheToolchain.md index 029833cb7d..358f6e8d0d 100644 --- a/docs/user/next/advanced/HackTheToolchain.md +++ b/docs/user/next/advanced/HackTheToolchain.md @@ -15,7 +15,7 @@ from gt4py import eve ```python cached_lowering_toolchain = gtx.backend.DEFAULT_TRANSFORMS.replace( - past_to_itir=gtx.ffront.past_to_itir.past_to_itir_factory(cached=False) + past_to_itir=gtx.ffront.past_to_itir.past_to_gtir_factory(cached=False) ) ``` diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index 57c0d3969d..d7679a1f0f 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -67,6 +67,7 @@ deepdiff==5.6.0 devtools==0.6 diskcache==5.6.3 factory-boy==3.3.0 +filelock==3.0.0 frozendict==2.3 gridtools-cpp==2.3.8 hypothesis==6.0.0 diff --git a/min-requirements-test.txt b/min-requirements-test.txt index 81a1c2dea3..cf505e88d6 100644 --- a/min-requirements-test.txt +++ b/min-requirements-test.txt @@ -63,6 +63,7 @@ deepdiff==5.6.0 devtools==0.6 diskcache==5.6.3 factory-boy==3.3.0 +filelock==3.0.0 frozendict==2.3 gridtools-cpp==2.3.8 hypothesis==6.0.0 diff --git a/pyproject.toml b/pyproject.toml index 02d301957c..e859c9b4f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ 'devtools>=0.6', 'diskcache>=5.6.3', 'factory-boy>=3.3.0', + 'filelock>=3.0.0', 'frozendict>=2.3', 'gridtools-cpp>=2.3.8,==2.*', "importlib-resources>=5.0;python_version<'3.9'", @@ -239,17 +240,14 @@ markers = [ 'requires_atlas: tests that require `atlas4py` bindings package', 'requires_dace: tests that require `dace` package', 'requires_gpu: tests that require a NVidia GPU (`cupy` and `cudatoolkit` are required)', - 'starts_from_gtir_program: tests that require backend to start lowering from GTIR program', 'uses_applied_shifts: tests that require backend support for applied-shifts', 'uses_constant_fields: tests that require backend support for constant fields', 'uses_dynamic_offsets: tests that require backend support for dynamic offsets', 'uses_floordiv: tests that require backend support for floor division', 'uses_if_stmts: tests that require backend support for if-statements', 'uses_index_fields: tests that require backend support for index fields', - 'uses_lift_expressions: tests that require backend support for lift expressions', 'uses_negative_modulo: tests that require backend support for modulo on negative numbers', 'uses_origin: tests that require backend support for domain origin', - 'uses_reduction_over_lift_expressions: tests that require backend support for reduction over lift expressions', 'uses_reduction_with_only_sparse_fields: tests that require backend support for with sparse fields', 'uses_scan: tests that uses scan', 'uses_scan_in_field_operator: tests that require backend support for scan in field operator', diff --git a/requirements-dev.txt b/requirements-dev.txt index 9f95779fd5..6542be36f1 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -49,7 +49,7 @@ executing==2.1.0 # via -c constraints.txt, devtools, stack-data factory-boy==3.3.1 # via -c constraints.txt, gt4py (pyproject.toml), pytest-factoryboy faker==33.0.0 # via -c constraints.txt, factory-boy fastjsonschema==2.20.0 # via -c constraints.txt, nbformat -filelock==3.16.1 # via -c constraints.txt, tox, virtualenv +filelock==3.16.1 # via -c constraints.txt, gt4py (pyproject.toml), tox, virtualenv fonttools==4.55.0 # via -c constraints.txt, matplotlib fparser==0.1.4 # via -c constraints.txt, dace frozendict==2.4.6 # via -c constraints.txt, gt4py (pyproject.toml) @@ -113,8 +113,8 @@ psutil==6.1.0 # via -c constraints.txt, -r requirements-dev.in, ipyk ptyprocess==0.7.0 # via -c constraints.txt, pexpect pure-eval==0.2.3 # via -c constraints.txt, stack-data pybind11==2.13.6 # via -c constraints.txt, gt4py (pyproject.toml) -pydantic==2.9.2 # via -c constraints.txt, bump-my-version, pydantic-settings -pydantic-core==2.23.4 # via -c constraints.txt, pydantic +pydantic==2.10.0 # via -c constraints.txt, bump-my-version, pydantic-settings +pydantic-core==2.27.0 # via -c constraints.txt, pydantic pydantic-settings==2.6.1 # via -c constraints.txt, bump-my-version pydot==3.0.2 # via -c constraints.txt, tach pygments==2.18.0 # via -c constraints.txt, -r requirements-dev.in, devtools, ipython, nbmake, rich, sphinx @@ -158,7 +158,7 @@ stack-data==0.6.3 # via -c constraints.txt, ipython stdlib-list==0.10.0 # via -c constraints.txt, tach sympy==1.13.3 # via -c constraints.txt, dace tabulate==0.9.0 # via -c constraints.txt, gt4py (pyproject.toml) -tach==0.14.3 # via -c constraints.txt, -r requirements-dev.in +tach==0.14.4 # via -c constraints.txt, -r requirements-dev.in tomli==2.1.0 ; python_version < "3.11" # via -c constraints.txt, -r requirements-dev.in, black, build, coverage, jupytext, mypy, pip-tools, pyproject-api, pytest, setuptools-scm, tach, tox tomli-w==1.0.0 # via -c constraints.txt, tach tomlkit==0.13.2 # via -c constraints.txt, bump-my-version diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index f49895a435..a6d28f5994 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -56,17 +56,17 @@ def _specialize_transient_strides(sdfg: dace.SDFG, layout_map): - repldict = replace_strides( + replacement_dictionary = replace_strides( [array for array in sdfg.arrays.values() if array.transient], layout_map ) - sdfg.replace_dict(repldict) + sdfg.replace_dict(replacement_dictionary) for state in sdfg.nodes(): for node in state.nodes(): if isinstance(node, dace.nodes.NestedSDFG): - for k, v in repldict.items(): + for k, v in replacement_dictionary.items(): if k in node.symbol_mapping: node.symbol_mapping[k] = v - for k in repldict.keys(): + for k in replacement_dictionary.keys(): if k in sdfg.symbols: sdfg.remove_symbol(k) @@ -143,7 +143,7 @@ def _to_device(sdfg: dace.SDFG, device: str) -> None: node.device = dace.DeviceType.GPU -def _pre_expand_trafos(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, layout_map): +def _pre_expand_transformations(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, layout_map): args_data = make_args_data_from_gtir(gtir_pipeline) # stencils without effect @@ -164,7 +164,7 @@ def _pre_expand_trafos(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, layout_map) return sdfg -def _post_expand_trafos(sdfg: dace.SDFG): +def _post_expand_transformations(sdfg: dace.SDFG): # DaCe "standard" clean-up transformations sdfg.simplify(validate=False) @@ -355,7 +355,7 @@ def _unexpanded_sdfg(self): sdfg = OirSDFGBuilder().visit(oir_node) _to_device(sdfg, self.builder.backend.storage_info["device"]) - _pre_expand_trafos( + _pre_expand_transformations( self.builder.gtir_pipeline, sdfg, self.builder.backend.storage_info["layout_map"], @@ -371,7 +371,7 @@ def unexpanded_sdfg(self): def _expanded_sdfg(self): sdfg = self._unexpanded_sdfg() sdfg.expand_library_nodes() - _post_expand_trafos(sdfg) + _post_expand_transformations(sdfg) return sdfg def expanded_sdfg(self): diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index ade05921ef..f155ea6209 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -1460,6 +1460,13 @@ def visit_Assign(self, node: ast.Assign) -> list: loc=nodes.Location.from_ast_node(t), ) + if self.backend_name in ["gt:gpu"]: + raise GTScriptSyntaxError( + message=f"Assignment to non-zero offsets in K is not available in {self.backend_name} as an unsolved bug remains." + "Please refer to https://github.com/GridTools/gt4py/issues/1754.", + loc=nodes.Location.from_ast_node(t), + ) + if not self._is_known(name): if name in self.temp_decls: field_decl = self.temp_decls[name] diff --git a/src/gt4py/cartesian/gtc/dace/expansion_specification.py b/src/gt4py/cartesian/gtc/dace/expansion_specification.py index c716f1a103..af9a814843 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion_specification.py +++ b/src/gt4py/cartesian/gtc/dace/expansion_specification.py @@ -107,7 +107,8 @@ def get_expansion_order_index(expansion_order, axis): for idx, item in enumerate(expansion_order): if isinstance(item, Iteration) and item.axis == axis: return idx - elif isinstance(item, Map): + + if isinstance(item, Map): for it in item.iterations: if it.kind == "contiguous" and it.axis == axis: return idx @@ -136,7 +137,9 @@ def _choose_loop_or_map(node, eo): return eo -def _order_as_spec(computation_node, expansion_order): +def _order_as_spec( + computation_node: StencilComputation, expansion_order: Union[List[str], List[ExpansionItem]] +) -> List[ExpansionItem]: expansion_order = list(_choose_loop_or_map(computation_node, eo) for eo in expansion_order) expansion_specification = [] for item in expansion_order: @@ -170,7 +173,7 @@ def _order_as_spec(computation_node, expansion_order): return expansion_specification -def _populate_strides(node, expansion_specification): +def _populate_strides(node: StencilComputation, expansion_specification: List[ExpansionItem]): """Fill in `stride` attribute of `Iteration` and `Loop` dataclasses. For loops, stride is set to either -1 or 1, based on iteration order. @@ -185,10 +188,7 @@ def _populate_strides(node, expansion_specification): for it in iterations: if isinstance(it, Loop): if it.stride is None: - if node.oir_node.loop_order == common.LoopOrder.BACKWARD: - it.stride = -1 - else: - it.stride = 1 + it.stride = -1 if node.oir_node.loop_order == common.LoopOrder.BACKWARD else 1 else: if it.stride is None: if it.kind == "tiling": @@ -204,7 +204,7 @@ def _populate_strides(node, expansion_specification): it.stride = 1 -def _populate_storages(self, expansion_specification): +def _populate_storages(expansion_specification: List[ExpansionItem]): assert all(isinstance(es, ExpansionItem) for es in expansion_specification) innermost_axes = set(dcir.Axis.dims_3d()) tiled_axes = set() @@ -222,7 +222,7 @@ def _populate_storages(self, expansion_specification): tiled_axes.remove(it.axis) -def _populate_cpu_schedules(self, expansion_specification): +def _populate_cpu_schedules(expansion_specification: List[ExpansionItem]): is_outermost = True for es in expansion_specification: if isinstance(es, Map): @@ -234,7 +234,7 @@ def _populate_cpu_schedules(self, expansion_specification): es.schedule = dace.ScheduleType.Default -def _populate_gpu_schedules(self, expansion_specification): +def _populate_gpu_schedules(expansion_specification: List[ExpansionItem]): # On GPU if any dimension is tiled and has a contiguous map in the same axis further in # pick those two maps as Device/ThreadBlock maps. If not, Make just device map with # default blocksizes @@ -267,16 +267,16 @@ def _populate_gpu_schedules(self, expansion_specification): es.schedule = dace.ScheduleType.Default -def _populate_schedules(self, expansion_specification): +def _populate_schedules(node: StencilComputation, expansion_specification: List[ExpansionItem]): assert all(isinstance(es, ExpansionItem) for es in expansion_specification) - assert hasattr(self, "_device") - if self.device == dace.DeviceType.GPU: - _populate_gpu_schedules(self, expansion_specification) + assert hasattr(node, "_device") + if node.device == dace.DeviceType.GPU: + _populate_gpu_schedules(expansion_specification) else: - _populate_cpu_schedules(self, expansion_specification) + _populate_cpu_schedules(expansion_specification) -def _collapse_maps_gpu(self, expansion_specification): +def _collapse_maps_gpu(expansion_specification: List[ExpansionItem]) -> List[ExpansionItem]: def _union_map_items(last_item, next_item): if last_item.schedule == next_item.schedule: return ( @@ -307,7 +307,7 @@ def _union_map_items(last_item, next_item): ), ) - res_items = [] + res_items: List[ExpansionItem] = [] for item in expansion_specification: if isinstance(item, Map): if not res_items or not isinstance(res_items[-1], Map): @@ -324,8 +324,8 @@ def _union_map_items(last_item, next_item): return res_items -def _collapse_maps_cpu(self, expansion_specification): - res_items = [] +def _collapse_maps_cpu(expansion_specification: List[ExpansionItem]) -> List[ExpansionItem]: + res_items: List[ExpansionItem] = [] for item in expansion_specification: if isinstance(item, Map): if ( @@ -360,12 +360,12 @@ def _collapse_maps_cpu(self, expansion_specification): return res_items -def _collapse_maps(self, expansion_specification): - assert hasattr(self, "_device") - if self.device == dace.DeviceType.GPU: - res_items = _collapse_maps_gpu(self, expansion_specification) +def _collapse_maps(node: StencilComputation, expansion_specification: List[ExpansionItem]): + assert hasattr(node, "_device") + if node.device == dace.DeviceType.GPU: + res_items = _collapse_maps_gpu(expansion_specification) else: - res_items = _collapse_maps_cpu(self, expansion_specification) + res_items = _collapse_maps_cpu(expansion_specification) expansion_specification.clear() expansion_specification.extend(res_items) @@ -387,7 +387,7 @@ def make_expansion_order( _populate_strides(node, expansion_specification) _populate_schedules(node, expansion_specification) _collapse_maps(node, expansion_specification) - _populate_storages(node, expansion_specification) + _populate_storages(expansion_specification) return expansion_specification diff --git a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py index f12c13cd0e..14448bb08e 100644 --- a/src/gt4py/cartesian/gtc/dace/oir_to_dace.py +++ b/src/gt4py/cartesian/gtc/dace/oir_to_dace.py @@ -123,6 +123,7 @@ def visit_VerticalLoop( state.add_edge( access_node, None, library_node, "__in_" + field, dace.Memlet(field, subset=subset) ) + for field in access_collection.write_fields(): access_node = state.add_access(field, debuginfo=dace.DebugInfo(0)) library_node.add_out_connector("__out_" + field) @@ -131,8 +132,6 @@ def visit_VerticalLoop( library_node, "__out_" + field, access_node, None, dace.Memlet(field, subset=subset) ) - return - def visit_Stencil(self, node: oir.Stencil, **kwargs): ctx = OirSDFGBuilder.SDFGContext(stencil=node) for param in node.params: diff --git a/src/gt4py/cartesian/gtc/dace/utils.py b/src/gt4py/cartesian/gtc/dace/utils.py index 517e80ceb3..bd65861a49 100644 --- a/src/gt4py/cartesian/gtc/dace/utils.py +++ b/src/gt4py/cartesian/gtc/dace/utils.py @@ -40,7 +40,7 @@ def array_dimensions(array: dace.data.Array): return dims -def replace_strides(arrays, get_layout_map): +def replace_strides(arrays: List[dace.data.Array], get_layout_map) -> Dict[str, str]: symbol_mapping = {} for array in arrays: dims = array_dimensions(array) diff --git a/src/gt4py/eve/.gitignore b/src/gt4py/eve/.gitignore deleted file mode 100644 index 050cda3ca5..0000000000 --- a/src/gt4py/eve/.gitignore +++ /dev/null @@ -1 +0,0 @@ -_version.py diff --git a/src/gt4py/eve/__init__.py b/src/gt4py/eve/__init__.py index 0b8cfa7d62..5adac47da3 100644 --- a/src/gt4py/eve/__init__.py +++ b/src/gt4py/eve/__init__.py @@ -24,8 +24,7 @@ """ -from __future__ import annotations # isort:skip - +from __future__ import annotations from .concepts import ( AnnexManager, @@ -89,15 +88,6 @@ "SymbolRef", "VType", "register_annex_user", - "# datamodels" "Coerced", - "DataModel", - "FrozenModel", - "GenericDataModel", - "Unchecked", - "concretize", - "datamodel", - "field", - "frozenmodel", # datamodels "Coerced", "DataModel", @@ -122,7 +112,7 @@ "pre_walk_values", "walk_items", "walk_values", - "# type_definition", + # type_definitions "NOTHING", "ConstrainedStr", "Enum", diff --git a/src/gt4py/eve/codegen.py b/src/gt4py/eve/codegen.py index 15fda4f3b4..3869ff313b 100644 --- a/src/gt4py/eve/codegen.py +++ b/src/gt4py/eve/codegen.py @@ -347,7 +347,7 @@ def __str__(self) -> str: class Template(Protocol): """Protocol (abstract base class) defining the Template interface. - Direct subclassess of this base class only need to implement the + Direct subclasses of this base class only need to implement the abstract methods to adapt different template engines to this interface. @@ -654,8 +654,8 @@ def apply( # redefinition of symbol Args: root: An IR node. - node_templates (optiona): see :class:`NodeDumper`. - dump_function (optiona): see :class:`NodeDumper`. + node_templates (optional): see :class:`NodeDumper`. + dump_function (optional): see :class:`NodeDumper`. ``**kwargs`` (optional): custom extra parameters forwarded to `visit_NODE_TYPE_NAME()`. Returns: diff --git a/src/gt4py/eve/datamodels/__init__.py b/src/gt4py/eve/datamodels/__init__.py index 68ddea2510..6fd9c7bb21 100644 --- a/src/gt4py/eve/datamodels/__init__.py +++ b/src/gt4py/eve/datamodels/__init__.py @@ -11,7 +11,7 @@ Data Models can be considered as enhanced `attrs `_ / `dataclasses `_ providing additional features like automatic run-time type validation. Values assigned to fields -at initialization can be validated with automatic type checkings using the +at initialization can be validated with automatic type checking using the field type definition. Custom field validation methods can also be added with the :func:`validator` decorator, and global instance validation methods with :func:`root_validator`. @@ -33,7 +33,7 @@ 1. ``__init__()``. a. If a custom ``__init__`` already exists in the class, it will not be overwritten. - It is your responsability to call ``__auto_init__`` from there to obtain + It is your responsibility to call ``__auto_init__`` from there to obtain the described behavior. b. If there is not custom ``__init__``, the one generated by datamodels will be called first. diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index d596f59cfb..1b0e995156 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -24,7 +24,7 @@ try: - # For perfomance reasons, try to use cytoolz when possible (using cython) + # For performance reasons, try to use cytoolz when possible (using cython) import cytoolz as toolz except ModuleNotFoundError: # Fall back to pure Python toolz @@ -270,7 +270,7 @@ def datamodel( @overload -def datamodel( # redefinion of unused symbol +def datamodel( # redefinition of unused symbol cls: Type[_T], /, *, @@ -289,7 +289,7 @@ def datamodel( # redefinion of unused symbol # TODO(egparedes): Use @dataclass_transform(eq_default=True, field_specifiers=("field",)) -def datamodel( # redefinion of unused symbol +def datamodel( # redefinition of unused symbol cls: Optional[Type[_T]] = None, /, *, @@ -867,7 +867,7 @@ def _substitute_typevars( def _make_counting_attr_from_attribute( field_attrib: Attribute, *, include_type: bool = False, **kwargs: Any -) -> Any: # attr.s lies a bit in some typing definitons +) -> Any: # attr.s lies a bit in some typing definitions args = [ "default", "validator", @@ -965,7 +965,7 @@ def _type_converter(value: Any) -> _T: return value if isinstance(value, type_annotation) else type_annotation(value) except Exception as error: raise TypeError( - f"Error during coertion of given value '{value}' for field '{name}'." + f"Error during coercion of given value '{value}' for field '{name}'." ) from error return _type_converter @@ -996,7 +996,7 @@ def _type_converter(value: Any) -> _T: return _make_type_converter(origin_type, name) raise exceptions.EveTypeError( - f"Automatic type coertion for {type_annotation} types is not supported." + f"Automatic type coercion for {type_annotation} types is not supported." ) @@ -1085,7 +1085,7 @@ def _make_datamodel( ) else: - # Create field converter if automatic coertion is enabled + # Create field converter if automatic coercion is enabled converter: TypeConverter = cast( TypeConverter, _make_type_converter(type_hint, qualified_field_name) if coerce_field else None, @@ -1099,7 +1099,7 @@ def _make_datamodel( if isinstance(attr_value_in_cls, _KNOWN_MUTABLE_TYPES): warnings.warn( f"'{attr_value_in_cls.__class__.__name__}' value used as default in '{cls.__name__}.{key}'.\n" - "Mutable types should not defbe normally used as field defaults (use 'default_factory' instead).", + "Mutable types should not be used as field defaults (use 'default_factory' instead).", stacklevel=_stacklevel_offset + 2, ) setattr( diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py index e276f3bccf..bf44824b49 100644 --- a/src/gt4py/eve/extended_typing.py +++ b/src/gt4py/eve/extended_typing.py @@ -14,12 +14,8 @@ from __future__ import annotations -import abc as _abc import array as _array -import collections.abc as _collections_abc -import ctypes as _ctypes import dataclasses as _dataclasses -import enum as _enum import functools as _functools import inspect as _inspect import mmap as _mmap diff --git a/src/gt4py/eve/trees.py b/src/gt4py/eve/trees.py index c8e8658413..8a3cc30f4b 100644 --- a/src/gt4py/eve/trees.py +++ b/src/gt4py/eve/trees.py @@ -31,14 +31,6 @@ from .type_definitions import Enum -try: - # For performance reasons, try to use cytoolz when possible (using cython) - import cytoolz as toolz -except ModuleNotFoundError: - # Fall back to pure Python toolz - import toolz # noqa: F401 [unused-import] - - TreeKey = Union[int, str] diff --git a/src/gt4py/eve/type_validation.py b/src/gt4py/eve/type_validation.py index 613eca40b2..e150832295 100644 --- a/src/gt4py/eve/type_validation.py +++ b/src/gt4py/eve/type_validation.py @@ -311,7 +311,7 @@ def __call__( # ... # # Since this can be an arbitrary type (not something regular like a collection) there is - # no way to check if the type parameter is verifed in the actual instance. + # no way to check if the type parameter is verified in the actual instance. # The only check can be done at run-time is to verify that the value is an instance of # the original type, completely ignoring the annotation. Ideally, the static type checker # can do a better job to try figure out if the type parameter is ok ... diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 8cb68845d7..2c66d39290 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -69,7 +69,7 @@ try: - # For perfomance reasons, try to use cytoolz when possible (using cython) + # For performance reasons, try to use cytoolz when possible (using cython) import cytoolz as toolz except ModuleNotFoundError: # Fall back to pure Python toolz diff --git a/src/gt4py/next/backend.py b/src/gt4py/next/backend.py index e223d7771c..e075422ca3 100644 --- a/src/gt4py/next/backend.py +++ b/src/gt4py/next/backend.py @@ -16,7 +16,6 @@ from gt4py.next import allocators as next_allocators from gt4py.next.ffront import ( foast_to_gtir, - foast_to_itir, foast_to_past, func_to_foast, func_to_past, @@ -41,7 +40,7 @@ ARGS: typing.TypeAlias = arguments.JITArgs CARG: typing.TypeAlias = arguments.CompileTimeArgs -IT_PRG: typing.TypeAlias = itir.FencilDefinition | itir.Program +IT_PRG: typing.TypeAlias = itir.Program INPUT_DATA: typing.TypeAlias = DSL_FOP | FOP | DSL_PRG | PRG | IT_PRG @@ -93,7 +92,7 @@ class Transforms(workflow.MultiWorkflow[INPUT_PAIR, stages.CompilableProgram]): ) past_to_itir: workflow.Workflow[AOT_PRG, stages.CompilableProgram] = dataclasses.field( - default_factory=past_to_itir.past_to_itir_factory + default_factory=past_to_itir.past_to_gtir_factory ) def step_order(self, inp: INPUT_PAIR) -> list[str]: @@ -126,7 +125,7 @@ def step_order(self, inp: INPUT_PAIR) -> list[str]: ) case PRG(): steps.extend(["past_lint", "field_view_prog_args_transform", "past_to_itir"]) - case itir.FencilDefinition() | itir.Program(): + case itir.Program(): pass case _: raise ValueError("Unexpected input.") @@ -135,17 +134,6 @@ def step_order(self, inp: INPUT_PAIR) -> list[str]: DEFAULT_TRANSFORMS: Transforms = Transforms() -# FIXME[#1582](havogt): remove after refactoring to GTIR -# note: this step is deliberately placed here, such that the cache is shared -_foast_to_itir_step = foast_to_itir.adapted_foast_to_itir_factory(cached=True) -LEGACY_TRANSFORMS: Transforms = Transforms( - past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=False), - foast_to_itir=_foast_to_itir_step, - field_view_op_to_prog=foast_to_past.operator_to_program_factory( - foast_to_itir_step=_foast_to_itir_step - ), -) - # TODO(tehrengruber): Rename class and `executor` & `transforms` attribute. Maybe: # `Backend` -> `Toolchain` diff --git a/src/gt4py/next/config.py b/src/gt4py/next/config.py index ed244c2932..7a19f3eb9d 100644 --- a/src/gt4py/next/config.py +++ b/src/gt4py/next/config.py @@ -11,7 +11,6 @@ import enum import os import pathlib -import tempfile from typing import Final @@ -51,25 +50,22 @@ def env_flag_to_bool(name: str, default: bool) -> bool: ) -_PREFIX: Final[str] = "GT4PY" - #: Master debug flag #: Changes defaults for all the other options to be as helpful for debugging as possible. #: Does not override values set in environment variables. -DEBUG: Final[bool] = env_flag_to_bool(f"{_PREFIX}_DEBUG", default=False) +DEBUG: Final[bool] = env_flag_to_bool("GT4PY_DEBUG", default=False) #: Verbose flag for DSL compilation errors VERBOSE_EXCEPTIONS: bool = env_flag_to_bool( - f"{_PREFIX}_VERBOSE_EXCEPTIONS", default=True if DEBUG else False + "GT4PY_VERBOSE_EXCEPTIONS", default=True if DEBUG else False ) #: Where generated code projects should be persisted. #: Only active if BUILD_CACHE_LIFETIME is set to PERSISTENT BUILD_CACHE_DIR: pathlib.Path = ( - pathlib.Path(os.environ.get(f"{_PREFIX}_BUILD_CACHE_DIR", tempfile.gettempdir())) - / "gt4py_cache" + pathlib.Path(os.environ.get("GT4PY_BUILD_CACHE_DIR", pathlib.Path.cwd())) / ".gt4py_cache" ) @@ -77,11 +73,11 @@ def env_flag_to_bool(name: str, default: bool) -> bool: #: - SESSION: generated code projects get destroyed when the interpreter shuts down #: - PERSISTENT: generated code projects are written to BUILD_CACHE_DIR and persist between runs BUILD_CACHE_LIFETIME: BuildCacheLifetime = BuildCacheLifetime[ - os.environ.get(f"{_PREFIX}_BUILD_CACHE_LIFETIME", "persistent" if DEBUG else "session").upper() + os.environ.get("GT4PY_BUILD_CACHE_LIFETIME", "persistent" if DEBUG else "session").upper() ] #: Build type to be used when CMake is used to compile generated code. #: Might have no effect when CMake is not used as part of the toolchain. CMAKE_BUILD_TYPE: CMakeBuildType = CMakeBuildType[ - os.environ.get(f"{_PREFIX}_CMAKE_BUILD_TYPE", "debug" if DEBUG else "release").upper() + os.environ.get("GT4PY_CMAKE_BUILD_TYPE", "debug" if DEBUG else "release").upper() ] diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 9ce07d01bb..d187095019 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -34,7 +34,6 @@ from gt4py.next.ffront import ( field_operator_ast as foast, foast_to_gtir, - foast_to_itir, past_process_args, signature, stages as ffront_stages, @@ -186,7 +185,7 @@ def _all_closure_vars(self) -> dict[str, Any]: return transform_utils._get_closure_vars_recursively(self.past_stage.closure_vars) @functools.cached_property - def itir(self) -> itir.FencilDefinition: + def gtir(self) -> itir.Program: no_args_past = toolchain.CompilableProgram( data=ffront_stages.PastProgramDefinition( past_node=self.past_stage.past_node, @@ -230,7 +229,7 @@ def __call__(self, *args: Any, offset_provider: common.OffsetProvider, **kwargs: if self.backend is None: warnings.warn( UserWarning( - f"Field View Program '{self.definition_stage.definition.__name__}': Using Python execution, consider selecting a perfomance backend." + f"Field View Program '{self.definition_stage.definition.__name__}': Using Python execution, consider selecting a performance backend." ), stacklevel=2, ) @@ -561,7 +560,7 @@ def with_grid_type(self, grid_type: common.GridType) -> FieldOperator: # a different backend than the one of the program that calls this field operator. Just use # the hard-coded lowering until this is cleaned up. def __gt_itir__(self) -> itir.FunctionDefinition: - return foast_to_itir.foast_to_itir(self.foast_stage) + return foast_to_gtir.foast_to_gtir(self.foast_stage) # FIXME[#1582](tehrengruber): remove after refactoring to GTIR def __gt_gtir__(self) -> itir.FunctionDefinition: diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 2c2971f49a..3c65695aec 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -359,11 +359,7 @@ def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr: obj, new_type = self.visit(node.args[0], **kwargs), node.args[1].id def create_cast(expr: itir.Expr, t: tuple[ts.TypeSpec]) -> itir.FunCall: - if isinstance(t[0], ts.FieldType): - return im.cast_as_fieldop(str(new_type))(expr) - else: - assert isinstance(t[0], ts.ScalarType) - return im.call("cast_")(expr, str(new_type)) + return _map(im.lambda_("val")(im.call("cast_")("val", str(new_type))), (expr,), t) if not isinstance(node.type, ts.TupleType): # to keep the IR simpler return create_cast(obj, (node.args[0].type,)) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py deleted file mode 100644 index 538b0f3ddb..0000000000 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ /dev/null @@ -1,512 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -# FIXME[#1582](havogt): remove after refactoring to GTIR - -import dataclasses -from typing import Any, Callable, Optional - -from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.eve.extended_typing import Never -from gt4py.eve.utils import UIDGenerator -from gt4py.next import common -from gt4py.next.ffront import ( - dialect_ast_enums, - fbuiltins, - field_operator_ast as foast, - lowering_utils, - stages as ffront_stages, - type_specifications as ts_ffront, -) -from gt4py.next.ffront.experimental import EXPERIMENTAL_FUN_BUILTIN_NAMES -from gt4py.next.ffront.fbuiltins import FUN_BUILTIN_NAMES, MATH_BUILTIN_NAMES, TYPE_BUILTIN_NAMES -from gt4py.next.ffront.foast_introspection import StmtReturnKind, deduce_stmt_return_kind -from gt4py.next.ffront.stages import AOT_FOP, FOP -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.otf import toolchain, workflow -from gt4py.next.type_system import type_info, type_specifications as ts - - -def foast_to_itir(inp: FOP) -> itir.Expr: - """ - Lower a FOAST field operator node to Iterator IR. - - See the docstring of `FieldOperatorLowering` for details. - """ - return FieldOperatorLowering.apply(inp.foast_node) - - -def foast_to_itir_factory(cached: bool = True) -> workflow.Workflow[FOP, itir.Expr]: - """Wrap `foast_to_itir` into a chainable and, optionally, cached workflow step.""" - wf = foast_to_itir - if cached: - wf = workflow.CachedStep(step=wf, hash_function=ffront_stages.fingerprint_stage) - return wf - - -def adapted_foast_to_itir_factory(**kwargs: Any) -> workflow.Workflow[AOT_FOP, itir.Expr]: - """Wrap the `foast_to_itir` workflow step into an adapter to fit into backend transform workflows.""" - return toolchain.StripArgsAdapter(foast_to_itir_factory(**kwargs)) - - -def promote_to_list(node_type: ts.TypeSpec) -> Callable[[itir.Expr], itir.Expr]: - if not type_info.contains_local_field(node_type): - return lambda x: im.promote_to_lifted_stencil("make_const_list")(x) - return lambda x: x - - -@dataclasses.dataclass -class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator): - """ - Lower FieldOperator AST (FOAST) to Iterator IR (ITIR). - - The strategy is to lower every expression to lifted stencils, - i.e. taking iterators and returning iterator. - - Examples - -------- - >>> from gt4py.next.ffront.func_to_foast import FieldOperatorParser - >>> from gt4py.next import Field, Dimension, float64 - >>> - >>> IDim = Dimension("IDim") - >>> def fieldop(inp: Field[[IDim], "float64"]): - ... return inp - >>> - >>> parsed = FieldOperatorParser.apply_to_function(fieldop) - >>> lowered = FieldOperatorLowering.apply(parsed) - >>> type(lowered) - - >>> lowered.id - SymbolName('fieldop') - >>> lowered.params # doctest: +ELLIPSIS - [Sym(id=SymbolName('inp'))] - """ - - uid_generator: UIDGenerator = dataclasses.field(default_factory=UIDGenerator) - - @classmethod - def apply(cls, node: foast.LocatedNode) -> itir.Expr: - return cls().visit(node) - - def visit_FunctionDefinition( - self, node: foast.FunctionDefinition, **kwargs: Any - ) -> itir.FunctionDefinition: - params = self.visit(node.params) - return itir.FunctionDefinition( - id=node.id, params=params, expr=self.visit_BlockStmt(node.body, inner_expr=None) - ) # `expr` is a lifted stencil - - def visit_FieldOperator( - self, node: foast.FieldOperator, **kwargs: Any - ) -> itir.FunctionDefinition: - func_definition: itir.FunctionDefinition = self.visit(node.definition, **kwargs) - - new_body = func_definition.expr - - return itir.FunctionDefinition( - id=func_definition.id, params=func_definition.params, expr=new_body - ) - - def visit_ScanOperator( - self, node: foast.ScanOperator, **kwargs: Any - ) -> itir.FunctionDefinition: - # note: we don't need the axis here as this is handled by the program - # decorator - assert isinstance(node.type, ts_ffront.ScanOperatorType) - - # We are lowering node.forward and node.init to iterators, but here we expect values -> `deref`. - # In iterator IR we didn't properly specify if this is legal, - # however after lift-inlining the expressions are transformed back to literals. - forward = im.deref(self.visit(node.forward, **kwargs)) - init = lowering_utils.process_elements( - im.deref, self.visit(node.init, **kwargs), node.init.type - ) - - # lower definition function - func_definition: itir.FunctionDefinition = self.visit(node.definition, **kwargs) - new_body = im.let( - func_definition.params[0].id, - # promote carry to iterator of tuples - # (this is the only place in the lowering were a variable is captured in a lifted lambda) - lowering_utils.to_tuples_of_iterator( - im.promote_to_const_iterator(func_definition.params[0].id), - [*node.type.definition.pos_or_kw_args.values()][0], # noqa: RUF015 [unnecessary-iterable-allocation-for-first-element] - ), - )( - # the function itself returns a tuple of iterators, deref element-wise - lowering_utils.process_elements( - im.deref, func_definition.expr, node.type.definition.returns - ) - ) - - stencil_args: list[itir.Expr] = [] - assert not node.type.definition.pos_only_args and not node.type.definition.kw_only_args - for param, arg_type in zip( - func_definition.params[1:], - [*node.type.definition.pos_or_kw_args.values()][1:], - strict=True, - ): - if isinstance(arg_type, ts.TupleType): - # convert into iterator of tuples - stencil_args.append(lowering_utils.to_iterator_of_tuples(param.id, arg_type)) - - new_body = im.let( - param.id, lowering_utils.to_tuples_of_iterator(param.id, arg_type) - )(new_body) - else: - stencil_args.append(im.ref(param.id)) - - definition = itir.Lambda(params=func_definition.params, expr=new_body) - - body = im.lift(im.call("scan")(definition, forward, init))(*stencil_args) - - return itir.FunctionDefinition(id=node.id, params=definition.params[1:], expr=body) - - def visit_Stmt(self, node: foast.Stmt, **kwargs: Any) -> Never: - raise AssertionError("Statements must always be visited in the context of a function.") - - def visit_Return( - self, node: foast.Return, *, inner_expr: Optional[itir.Expr], **kwargs: Any - ) -> itir.Expr: - return self.visit(node.value, **kwargs) - - def visit_BlockStmt( - self, node: foast.BlockStmt, *, inner_expr: Optional[itir.Expr], **kwargs: Any - ) -> itir.Expr: - for stmt in reversed(node.stmts): - inner_expr = self.visit(stmt, inner_expr=inner_expr, **kwargs) - assert inner_expr - return inner_expr - - def visit_IfStmt( - self, node: foast.IfStmt, *, inner_expr: Optional[itir.Expr], **kwargs: Any - ) -> itir.Expr: - # the lowered if call doesn't need to be lifted as the condition can only originate - # from a scalar value (and not a field) - assert ( - isinstance(node.condition.type, ts.ScalarType) - and node.condition.type.kind == ts.ScalarKind.BOOL - ) - - cond = self.visit(node.condition, **kwargs) - - return_kind: StmtReturnKind = deduce_stmt_return_kind(node) - - common_symbols: dict[str, foast.Symbol] = node.annex.propagated_symbols - - if return_kind is StmtReturnKind.NO_RETURN: - # pack the common symbols into a tuple - common_symrefs = im.make_tuple(*(im.ref(sym) for sym in common_symbols.keys())) - - # apply both branches and extract the common symbols through the prepared tuple - true_branch = self.visit(node.true_branch, inner_expr=common_symrefs, **kwargs) - false_branch = self.visit(node.false_branch, inner_expr=common_symrefs, **kwargs) - - # unpack the common symbols' tuple for `inner_expr` - for i, sym in enumerate(common_symbols.keys()): - inner_expr = im.let(sym, im.tuple_get(i, im.ref("__if_stmt_result")))(inner_expr) - - # here we assume neither branch returns - return im.let("__if_stmt_result", im.if_(im.deref(cond), true_branch, false_branch))( - inner_expr - ) - elif return_kind is StmtReturnKind.CONDITIONAL_RETURN: - common_syms = tuple(im.sym(sym) for sym in common_symbols.keys()) - common_symrefs = tuple(im.ref(sym) for sym in common_symbols.keys()) - - # wrap the inner expression in a lambda function. note that this increases the - # operation count if both branches are evaluated. - inner_expr_name = self.uid_generator.sequential_id(prefix="__inner_expr") - inner_expr_evaluator = im.lambda_(*common_syms)(inner_expr) - inner_expr = im.call(inner_expr_name)(*common_symrefs) - - true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs) - false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs) - - return im.let(inner_expr_name, inner_expr_evaluator)( - im.if_(im.deref(cond), true_branch, false_branch) - ) - - assert return_kind is StmtReturnKind.UNCONDITIONAL_RETURN - - # note that we do not duplicate `inner_expr` here since if both branches - # return, `inner_expr` is ignored. - true_branch = self.visit(node.true_branch, inner_expr=inner_expr, **kwargs) - false_branch = self.visit(node.false_branch, inner_expr=inner_expr, **kwargs) - - return im.if_(im.deref(cond), true_branch, false_branch) - - def visit_Assign( - self, node: foast.Assign, *, inner_expr: Optional[itir.Expr], **kwargs: Any - ) -> itir.Expr: - return im.let(self.visit(node.target, **kwargs), self.visit(node.value, **kwargs))( - inner_expr - ) - - def visit_Symbol(self, node: foast.Symbol, **kwargs: Any) -> itir.Sym: - return im.sym(node.id) - - def visit_Name(self, node: foast.Name, **kwargs: Any) -> itir.SymRef: - return im.ref(node.id) - - def visit_Subscript(self, node: foast.Subscript, **kwargs: Any) -> itir.Expr: - return im.tuple_get(node.index, self.visit(node.value, **kwargs)) - - def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr: - return im.make_tuple(*[self.visit(el, **kwargs) for el in node.elts]) - - def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: - # TODO(tehrengruber): extend iterator ir to support unary operators - dtype = type_info.extract_dtype(node.type) - if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]: - if dtype.kind != ts.ScalarKind.BOOL: - raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.") - return self._lower_and_map("not_", node.operand) - - return self._lower_and_map( - node.op.value, - foast.Constant(value="0", type=dtype, location=node.location), - node.operand, - ) - - def visit_BinOp(self, node: foast.BinOp, **kwargs: Any) -> itir.FunCall: - return self._lower_and_map(node.op.value, node.left, node.right) - - def visit_TernaryExpr(self, node: foast.TernaryExpr, **kwargs: Any) -> itir.FunCall: - op = "if_" - args = (node.condition, node.true_expr, node.false_expr) - lowered_args: list[itir.Expr] = [ - lowering_utils.to_iterator_of_tuples(self.visit(arg, **kwargs), arg.type) - for arg in args - ] - if any(type_info.contains_local_field(arg.type) for arg in args): - lowered_args = [ - promote_to_list(arg.type)(larg) for arg, larg in zip(args, lowered_args) - ] - op = im.call("map_")(op) - - return lowering_utils.to_tuples_of_iterator( - im.promote_to_lifted_stencil(im.call(op))(*lowered_args), node.type - ) - - def visit_Compare(self, node: foast.Compare, **kwargs: Any) -> itir.FunCall: - return self._lower_and_map(node.op.value, node.left, node.right) - - def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - current_expr = self.visit(node.func, **kwargs) - - for arg in node.args: - match arg: - # `field(Off[idx])` - case foast.Subscript(value=foast.Name(id=offset_name), index=int(offset_index)): - current_expr = im.lift( - im.lambda_("it")(im.deref(im.shift(offset_name, offset_index)("it"))) - )(current_expr) - # `field(Dim + idx)` - case foast.BinOp( - op=dialect_ast_enums.BinaryOperator.ADD - | dialect_ast_enums.BinaryOperator.SUB, - left=foast.Name(id=dimension), - right=foast.Constant(value=offset_index), - ): - if arg.op == dialect_ast_enums.BinaryOperator.SUB: - offset_index *= -1 - current_expr = im.lift( - # TODO(SF-N): we rely on the naming-convention that the cartesian dimensions - # are passed suffixed with `off`, e.g. the `K` is passed as `Koff` in the - # offset provider. This is a rather unclean solution and should be - # improved. - im.lambda_("it")( - im.deref( - im.shift( - common.dimension_to_implicit_offset(dimension), offset_index - )("it") - ) - ) - )(current_expr) - # `field(Off)` - case foast.Name(id=offset_name): - # only a single unstructured shift is supported so returning here is fine even though we - # are in a loop. - assert len(node.args) == 1 and len(arg.type.target) > 1 # type: ignore[attr-defined] # ensured by pattern - return im.lifted_neighbors(str(offset_name), self.visit(node.func, **kwargs)) - # `field(as_offset(Off, offset_field))` - case foast.Call(func=foast.Name(id="as_offset")): - func_args = arg - # TODO(tehrengruber): Use type system to deduce the offset dimension instead of - # (e.g. to allow aliasing) - offset_dim = func_args.args[0] - assert isinstance(offset_dim, foast.Name) - offset_it = self.visit(func_args.args[1], **kwargs) - current_expr = im.lift( - im.lambda_("it", "offset")( - im.deref(im.shift(offset_dim.id, im.deref("offset"))("it")) - ) - )(current_expr, offset_it) - case _: - raise FieldOperatorLoweringError("Unexpected shift arguments!") - - return current_expr - - def visit_Call(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - if type_info.type_class(node.func.type) is ts.FieldType: - return self._visit_shift(node, **kwargs) - elif isinstance(node.func, foast.Name) and node.func.id in MATH_BUILTIN_NAMES: - return self._visit_math_built_in(node, **kwargs) - elif isinstance(node.func, foast.Name) and node.func.id in ( - FUN_BUILTIN_NAMES + EXPERIMENTAL_FUN_BUILTIN_NAMES - ): - visitor = getattr(self, f"_visit_{node.func.id}") - return visitor(node, **kwargs) - elif isinstance(node.func, foast.Name) and node.func.id in TYPE_BUILTIN_NAMES: - return self._visit_type_constr(node, **kwargs) - elif isinstance( - node.func.type, - (ts.FunctionType, ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType), - ): - # ITIR has no support for keyword arguments. Instead, we concatenate both positional - # and keyword arguments and use the unique order as given in the function signature. - lowered_args, lowered_kwargs = type_info.canonicalize_arguments( - node.func.type, - self.visit(node.args, **kwargs), - self.visit(node.kwargs, **kwargs), - use_signature_ordering=True, - ) - result = im.call(self.visit(node.func, **kwargs))( - *lowered_args, *lowered_kwargs.values() - ) - - # scan operators return an iterator of tuples, transform into tuples of iterator again - if isinstance(node.func.type, ts_ffront.ScanOperatorType): - result = lowering_utils.to_tuples_of_iterator( - result, node.func.type.definition.returns - ) - - return result - - raise AssertionError( - f"Call to object of type '{type(node.func.type).__name__}' not understood." - ) - - def _visit_astype(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - assert len(node.args) == 2 and isinstance(node.args[1], foast.Name) - obj, new_type = node.args[0], node.args[1].id - return lowering_utils.process_elements( - lambda x: im.promote_to_lifted_stencil( - im.lambda_("it")(im.call("cast_")("it", str(new_type))) - )(x), - self.visit(obj, **kwargs), - obj.type, - ) - - def _visit_where(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - condition, true_value, false_value = node.args - - lowered_condition = self.visit(condition, **kwargs) - return lowering_utils.process_elements( - lambda tv, fv, types: _map( - "if_", (lowered_condition, tv, fv), (condition.type, *types) - ), - [self.visit(true_value, **kwargs), self.visit(false_value, **kwargs)], - node.type, - (node.args[1].type, node.args[2].type), - ) - - _visit_concat_where = _visit_where - - def _visit_broadcast(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - return self.visit(node.args[0], **kwargs) - - def _visit_math_built_in(self, node: foast.Call, **kwargs: Any) -> itir.FunCall: - return self._lower_and_map(self.visit(node.func, **kwargs), *node.args) - - def _make_reduction_expr( - self, node: foast.Call, op: str | itir.SymRef, init_expr: itir.Expr, **kwargs: Any - ) -> itir.Expr: - # TODO(havogt): deal with nested reductions of the form neighbor_sum(neighbor_sum(field(off1)(off2))) - it = self.visit(node.args[0], **kwargs) - assert isinstance(node.kwargs["axis"].type, ts.DimensionType) - val = im.call(im.call("reduce")(op, im.deref(init_expr))) - return im.promote_to_lifted_stencil(val)(it) - - def _visit_neighbor_sum(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - dtype = type_info.extract_dtype(node.type) - return self._make_reduction_expr(node, "plus", self._make_literal("0", dtype), **kwargs) - - def _visit_max_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - dtype = type_info.extract_dtype(node.type) - min_value, _ = type_info.arithmetic_bounds(dtype) - init_expr = self._make_literal(str(min_value), dtype) - return self._make_reduction_expr(node, "maximum", init_expr, **kwargs) - - def _visit_min_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - dtype = type_info.extract_dtype(node.type) - _, max_value = type_info.arithmetic_bounds(dtype) - init_expr = self._make_literal(str(max_value), dtype) - return self._make_reduction_expr(node, "minimum", init_expr, **kwargs) - - def _visit_type_constr(self, node: foast.Call, **kwargs: Any) -> itir.Expr: - el = node.args[0] - node_kind = self.visit(node.type).kind.name.lower() - source_type = {**fbuiltins.BUILTINS, "string": str}[el.type.__str__().lower()] - target_type = fbuiltins.BUILTINS[node_kind] - - if isinstance(el, foast.Constant): - val = source_type(el.value) - elif isinstance(el, foast.UnaryOp) and isinstance(el.operand, foast.Constant): - operand = source_type(el.operand.value) - val = eval(f"lambda arg: {el.op}arg")(operand) - else: - raise FieldOperatorLoweringError( - f"Type cast only supports literal arguments, {node.type} not supported." - ) - val = target_type(val) - - return im.promote_to_const_iterator(im.literal(str(val), node_kind)) - - def _make_literal(self, val: Any, type_: ts.TypeSpec) -> itir.Expr: - # TODO(havogt): lifted nullary lambdas are not supported in iterator.embedded due to an implementation detail; - # the following constructs work if they are removed by inlining. - if isinstance(type_, ts.TupleType): - return im.make_tuple( - *(self._make_literal(val, type_) for val, type_ in zip(val, type_.types)) - ) - elif isinstance(type_, ts.ScalarType): - typename = type_.kind.name.lower() - return im.promote_to_const_iterator(im.literal(str(val), typename)) - raise ValueError(f"Unsupported literal type '{type_}'.") - - def visit_Constant(self, node: foast.Constant, **kwargs: Any) -> itir.Expr: - return self._make_literal(node.value, node.type) - - def _lower_and_map(self, op: itir.Expr | str, *args: Any, **kwargs: Any) -> itir.FunCall: - return _map( - op, tuple(self.visit(arg, **kwargs) for arg in args), tuple(arg.type for arg in args) - ) - - -def _map( - op: itir.Expr | str, - lowered_args: tuple, - original_arg_types: tuple[ts.TypeSpec, ...], -) -> itir.FunCall: - """ - Mapping includes making the operation an lifted stencil (first kind of mapping), but also `itir.map_`ing lists. - """ - if any(type_info.contains_local_field(arg_type) for arg_type in original_arg_types): - lowered_args = tuple( - promote_to_list(arg_type)(larg) - for arg_type, larg in zip(original_arg_types, lowered_args) - ) - op = im.call("map_")(op) - - return im.promote_to_lifted_stencil(im.call(op))(*lowered_args) - - -class FieldOperatorLoweringError(Exception): ... diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index c0348bb5c6..4ec12bb76b 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -9,7 +9,6 @@ from __future__ import annotations import dataclasses -import functools from typing import Any, Optional, cast import devtools @@ -19,7 +18,6 @@ from gt4py.next.ffront import ( fbuiltins, gtcallable, - lowering_utils, program_ast as past, stages as ffront_stages, transform_utils, @@ -32,10 +30,9 @@ from gt4py.next.type_system import type_info, type_specifications as ts -# FIXME[#1582](havogt): remove `to_gtir` arg after refactoring to GTIR # FIXME[#1582](tehrengruber): This should only depend on the program not the arguments. Remove # dependency as soon as column axis can be deduced from ITIR in consumers of the CompilableProgram. -def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgram: +def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: """ Lower a PAST program definition to Iterator IR. @@ -59,7 +56,7 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra ... column_axis=None, ... ) - >>> itir_copy = past_to_itir( + >>> itir_copy = past_to_gtir( ... toolchain.CompilableProgram(copy_program.past_stage, compile_time_args) ... ) @@ -67,7 +64,7 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra copy_program >>> print(type(itir_copy.data)) - + """ all_closure_vars = transform_utils._get_closure_vars_recursively(inp.data.closure_vars) offsets_and_dimensions = transform_utils._filter_closure_vars_by_type( @@ -88,13 +85,10 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra # making this step aware of the toolchain it is called by (it can be part of multiple). lowered_funcs = [] for gt_callable in gt_callables: - if to_gtir: - lowered_funcs.append(gt_callable.__gt_gtir__()) - else: - lowered_funcs.append(gt_callable.__gt_itir__()) + lowered_funcs.append(gt_callable.__gt_gtir__()) itir_program = ProgramLowering.apply( - inp.data.past_node, function_definitions=lowered_funcs, grid_type=grid_type, to_gtir=to_gtir + inp.data.past_node, function_definitions=lowered_funcs, grid_type=grid_type ) if config.DEBUG or inp.data.debug: @@ -106,11 +100,10 @@ def past_to_itir(inp: AOT_PRG, to_gtir: bool = False) -> stages.CompilableProgra ) -# FIXME[#1582](havogt): remove `to_gtir` arg after refactoring to GTIR -def past_to_itir_factory( - cached: bool = True, to_gtir: bool = True +def past_to_gtir_factory( + cached: bool = True, ) -> workflow.Workflow[AOT_PRG, stages.CompilableProgram]: - wf = workflow.make_step(functools.partial(past_to_itir, to_gtir=to_gtir)) + wf = workflow.make_step(past_to_gtir) if cached: wf = workflow.CachedStep(wf, hash_function=ffront_stages.fingerprint_stage) return wf @@ -190,7 +183,7 @@ class ProgramLowering( ... parsed, [fieldop_def], grid_type=common.GridType.CARTESIAN ... ) # doctest: +SKIP >>> type(lowered) # doctest: +SKIP - + >>> lowered.id # doctest: +SKIP SymbolName('program') >>> lowered.params # doctest: +SKIP @@ -198,7 +191,6 @@ class ProgramLowering( """ grid_type: common.GridType - to_gtir: bool = False # FIXME[#1582](havogt): remove after refactoring to GTIR # TODO(tehrengruber): enable doctests again. For unknown / obscure reasons # the above doctest fails when executed using `pytest --doctest-modules`. @@ -209,11 +201,8 @@ def apply( node: past.Program, function_definitions: list[itir.FunctionDefinition], grid_type: common.GridType, - to_gtir: bool = False, # FIXME[#1582](havogt): remove after refactoring to GTIR - ) -> itir.FencilDefinition: - return cls(grid_type=grid_type, to_gtir=to_gtir).visit( - node, function_definitions=function_definitions - ) + ) -> itir.Program: + return cls(grid_type=grid_type).visit(node, function_definitions=function_definitions) def _gen_size_params_from_program(self, node: past.Program) -> list[itir.Sym]: """Generate symbols for each field param and dimension.""" @@ -246,7 +235,7 @@ def visit_Program( *, function_definitions: list[itir.FunctionDefinition], **kwargs: Any, - ) -> itir.FencilDefinition | itir.Program: + ) -> itir.Program: # The ITIR does not support dynamically getting the size of a field. As # a workaround we add additional arguments to the fencil definition # containing the size of all fields. The caller of a program is (e.g. @@ -259,27 +248,17 @@ def visit_Program( params = params + self._gen_size_params_from_program(node) implicit_domain = True - if self.to_gtir: - set_ats = [self._visit_stencil_call_as_set_at(stmt, **kwargs) for stmt in node.body] - return itir.Program( - id=node.id, - function_definitions=function_definitions, - params=params, - declarations=[], - body=set_ats, - implicit_domain=implicit_domain, - ) - else: - closures = [self._visit_stencil_call_as_closure(stmt, **kwargs) for stmt in node.body] - return itir.FencilDefinition( - id=node.id, - function_definitions=function_definitions, - params=params, - closures=closures, - implicit_domain=implicit_domain, - ) + set_ats = [self._visit_field_operator_call(stmt, **kwargs) for stmt in node.body] + return itir.Program( + id=node.id, + function_definitions=function_definitions, + params=params, + declarations=[], + body=set_ats, + implicit_domain=implicit_domain, + ) - def _visit_stencil_call_as_set_at(self, node: past.Call, **kwargs: Any) -> itir.SetAt: + def _visit_field_operator_call(self, node: past.Call, **kwargs: Any) -> itir.SetAt: assert isinstance(node.kwargs["out"].type, ts.TypeSpec) assert type_info.is_type_or_tuple_of_type(node.kwargs["out"].type, ts.FieldType) @@ -303,56 +282,6 @@ def _visit_stencil_call_as_set_at(self, node: past.Call, **kwargs: Any) -> itir. target=output, ) - # FIXME[#1582](havogt): remove after refactoring to GTIR - def _visit_stencil_call_as_closure(self, node: past.Call, **kwargs: Any) -> itir.StencilClosure: - assert isinstance(node.kwargs["out"].type, ts.TypeSpec) - assert type_info.is_type_or_tuple_of_type(node.kwargs["out"].type, ts.FieldType) - - node_kwargs = {**node.kwargs} - domain = node_kwargs.pop("domain", None) - output, lowered_domain = self._visit_stencil_call_out_arg( - node_kwargs.pop("out"), domain, **kwargs - ) - - assert isinstance(node.func.type, (ts_ffront.FieldOperatorType, ts_ffront.ScanOperatorType)) - - args, node_kwargs = type_info.canonicalize_arguments( - node.func.type, node.args, node_kwargs, use_signature_ordering=True - ) - - lowered_args, lowered_kwargs = self.visit(args, **kwargs), self.visit(node_kwargs, **kwargs) - - stencil_params = [] - stencil_args: list[itir.Expr] = [] - for i, arg in enumerate([*args, *node_kwargs]): - stencil_params.append(f"__stencil_arg{i}") - if isinstance(arg.type, ts.TupleType): - # convert into tuple of iterators - stencil_args.append( - lowering_utils.to_tuples_of_iterator(f"__stencil_arg{i}", arg.type) - ) - else: - stencil_args.append(im.ref(f"__stencil_arg{i}")) - - if isinstance(node.func.type, ts_ffront.ScanOperatorType): - # scan operators return an iterator of tuples, just deref directly - stencil_body = im.deref(im.call(node.func.id)(*stencil_args)) - else: - # field operators return a tuple of iterators, deref element-wise - stencil_body = lowering_utils.process_elements( - im.deref, - im.call(node.func.id)(*stencil_args), - node.func.type.definition.returns, - ) - - return itir.StencilClosure( - domain=lowered_domain, - stencil=im.lambda_(*stencil_params)(stencil_body), - inputs=[*lowered_args, *lowered_kwargs.values()], - output=output, - location=node.location, - ) - def _visit_slice_bound( self, slice_bound: Optional[past.Constant], diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 3c63ffef30..13c64e264e 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1706,8 +1706,10 @@ def impl(*iters: ItIterator): return impl -def _dimension_to_tag(domain: Domain) -> dict[Tag, range]: - return {k.value if isinstance(k, common.Dimension) else k: v for k, v in domain.items()} +def _dimension_to_tag( + domain: runtime.CartesianDomain | runtime.UnstructuredDomain, +) -> dict[Tag, range]: + return {k.value: v for k, v in domain.items()} def _validate_domain(domain: Domain, offset_provider_type: common.OffsetProviderType) -> None: @@ -1828,7 +1830,7 @@ def impl(*args): # TODO(havogt): after updating all tests to use the new program, # we should get rid of closure and move the implementation to this function - closure(_dimension_to_tag(domain), fun, out, list(args)) + closure(domain, fun, out, list(args)) return out return impl @@ -1839,9 +1841,8 @@ def index(axis: common.Dimension) -> common.Field: return IndexField(axis) -@runtime.closure.register(EMBEDDED) def closure( - domain_: Domain, + domain_: runtime.CartesianDomain | runtime.UnstructuredDomain, sten: Callable[..., Any], out, #: MutableLocatedField, ins: list[common.Field | Scalar | tuple[common.Field | Scalar | tuple, ...]], diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index 7098e9fa2e..e875709631 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -9,7 +9,7 @@ from typing import ClassVar, List, Optional, Union import gt4py.eve as eve -from gt4py.eve import Coerced, SymbolName, SymbolRef, datamodels +from gt4py.eve import Coerced, SymbolName, SymbolRef from gt4py.eve.concepts import SourceLocation from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait from gt4py.eve.utils import noninstantiable @@ -19,10 +19,6 @@ DimensionKind = common.DimensionKind -# TODO(havogt): -# After completion of refactoring to GTIR, FencilDefinition and StencilClosure should be removed everywhere. -# During transition, we lower to FencilDefinitions and apply a transformation to GTIR-style afterwards. - @noninstantiable class Node(eve.Node): @@ -97,23 +93,6 @@ class FunctionDefinition(Node, SymbolTableTrait): expr: Expr -class StencilClosure(Node): - domain: FunCall - stencil: Expr - output: Union[SymRef, FunCall] - inputs: List[Union[SymRef, FunCall]] - - @datamodels.validator("output") - def _output_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value): - if isinstance(value, FunCall) and value.fun != SymRef(id="make_tuple"): - raise ValueError("Only FunCall to 'make_tuple' allowed.") - - @datamodels.validator("inputs") - def _input_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value): - if any(isinstance(v, FunCall) and v.fun != SymRef(id="index") for v in value): - raise ValueError("Only FunCall to 'index' allowed.") - - UNARY_MATH_NUMBER_BUILTINS = {"abs"} UNARY_LOGICAL_BUILTINS = {"not_"} UNARY_MATH_FP_BUILTINS = { @@ -189,29 +168,11 @@ def _input_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribu "scan", "if_", "index", # `index(dim)` creates a dim-field that has the current index at each point + "as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution) *ARITHMETIC_BUILTINS, *TYPEBUILTINS, } -# only used in `Program`` not `FencilDefinition` -# TODO(havogt): restructure after refactoring to GTIR -GTIR_BUILTINS = { - *BUILTINS, - "as_fieldop", # `as_fieldop(stencil, domain)` creates field_operator from stencil (domain is optional, but for now required for embedded execution) -} - - -class FencilDefinition(Node, ValidatedSymbolTableTrait): - id: Coerced[SymbolName] - function_definitions: List[FunctionDefinition] - params: List[Sym] - closures: List[StencilClosure] - implicit_domain: bool = False - - _NODE_SYMBOLS_: ClassVar[List[Sym]] = [ - Sym(id=name) for name in sorted(BUILTINS) - ] # sorted for serialization stability - class Stmt(Node): ... @@ -243,7 +204,7 @@ class Program(Node, ValidatedSymbolTableTrait): implicit_domain: bool = False _NODE_SYMBOLS_: ClassVar[List[Sym]] = [ - Sym(id=name) for name in sorted(GTIR_BUILTINS) + Sym(id=name) for name in sorted(BUILTINS) ] # sorted for serialization stability @@ -258,8 +219,6 @@ class Program(Node, ValidatedSymbolTableTrait): Lambda.__hash__ = Node.__hash__ # type: ignore[method-assign] FunCall.__hash__ = Node.__hash__ # type: ignore[method-assign] FunctionDefinition.__hash__ = Node.__hash__ # type: ignore[method-assign] -StencilClosure.__hash__ = Node.__hash__ # type: ignore[method-assign] -FencilDefinition.__hash__ = Node.__hash__ # type: ignore[method-assign] Program.__hash__ = Node.__hash__ # type: ignore[method-assign] SetAt.__hash__ = Node.__hash__ # type: ignore[method-assign] IfStmt.__hash__ = Node.__hash__ # type: ignore[method-assign] diff --git a/src/gt4py/next/iterator/ir_utils/domain_utils.py b/src/gt4py/next/iterator/ir_utils/domain_utils.py index f5625b509c..4a023f7535 100644 --- a/src/gt4py/next/iterator/ir_utils/domain_utils.py +++ b/src/gt4py/next/iterator/ir_utils/domain_utils.py @@ -16,6 +16,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms import trace_shifts +from gt4py.next.iterator.transforms.constant_folding import ConstantFolding def _max_domain_sizes_by_location_type(offset_provider: Mapping[str, Any]) -> dict[str, int]: @@ -168,6 +169,8 @@ def domain_union(*domains: SymbolicDomain) -> SymbolicDomain: lambda current_expr, el_expr: im.call("maximum")(current_expr, el_expr), [domain.ranges[dim].stop for domain in domains], ) + # constant fold expression to keep the tree small + start, stop = ConstantFolding.apply(start), ConstantFolding.apply(stop) # type: ignore[assignment] # always an itir.Expr new_domain_ranges[dim] = SymbolicRange(start, stop) return SymbolicDomain(domains[0].grid_type, new_domain_ranges) diff --git a/src/gt4py/next/iterator/pretty_parser.py b/src/gt4py/next/iterator/pretty_parser.py index b4a673772f..29b30beae1 100644 --- a/src/gt4py/next/iterator/pretty_parser.py +++ b/src/gt4py/next/iterator/pretty_parser.py @@ -216,10 +216,6 @@ def function_definition(self, *args: ir.Node) -> ir.FunctionDefinition: fid, *params, expr = args return ir.FunctionDefinition(id=fid, params=params, expr=expr) - def stencil_closure(self, *args: ir.Expr) -> ir.StencilClosure: - output, stencil, *inputs, domain = args - return ir.StencilClosure(domain=domain, stencil=stencil, output=output, inputs=inputs) - def if_stmt(self, cond: ir.Expr, *args): found_else_seperator = False true_branch = [] @@ -249,23 +245,6 @@ def set_at(self, *args: ir.Expr) -> ir.SetAt: target, domain, expr = args return ir.SetAt(expr=expr, domain=domain, target=target) - # TODO(havogt): remove after refactoring. - def fencil_definition(self, fid: str, *args: ir.Node) -> ir.FencilDefinition: - params = [] - function_definitions = [] - closures = [] - for arg in args: - if isinstance(arg, ir.Sym): - params.append(arg) - elif isinstance(arg, ir.FunctionDefinition): - function_definitions.append(arg) - else: - assert isinstance(arg, ir.StencilClosure) - closures.append(arg) - return ir.FencilDefinition( - id=fid, function_definitions=function_definitions, params=params, closures=closures - ) - def program(self, fid: str, *args: ir.Node) -> ir.Program: params = [] function_definitions = [] diff --git a/src/gt4py/next/iterator/pretty_printer.py b/src/gt4py/next/iterator/pretty_printer.py index 99287f8a11..a25f99356c 100644 --- a/src/gt4py/next/iterator/pretty_printer.py +++ b/src/gt4py/next/iterator/pretty_printer.py @@ -248,28 +248,6 @@ def visit_FunctionDefinition(self, node: ir.FunctionDefinition, prec: int) -> li vbody = self._vmerge(params, self._indent(expr)) return self._optimum(hbody, vbody) - def visit_StencilClosure(self, node: ir.StencilClosure, *, prec: int) -> list[str]: - assert prec == 0 - domain = self.visit(node.domain, prec=0) - stencil = self.visit(node.stencil, prec=0) - output = self.visit(node.output, prec=0) - inputs = self.visit(node.inputs, prec=0) - - hinputs = self._hmerge(["("], *self._hinterleave(inputs, ", "), [")"]) - vinputs = self._vmerge(["("], *self._hinterleave(inputs, ",", indent=True), [")"]) - inputs = self._optimum(hinputs, vinputs) - - head = self._hmerge(output, [" ← "]) - foot = self._hmerge(inputs, [" @ "], domain, [";"]) - - h = self._hmerge(head, ["("], stencil, [")"], foot) - v = self._vmerge( - self._hmerge(head, ["("]), - self._indent(self._indent(stencil)), - self._indent(self._hmerge([")"], foot)), - ) - return self._optimum(h, v) - def visit_Temporary(self, node: ir.Temporary, *, prec: int) -> list[str]: start, end = [node.id + " = temporary("], [");"] args = [] @@ -312,25 +290,6 @@ def visit_IfStmt(self, node: ir.IfStmt, *, prec: int) -> list[str]: head, self._indent(true_branch), ["} else {"], self._indent(false_branch), ["}"] ) - def visit_FencilDefinition(self, node: ir.FencilDefinition, *, prec: int) -> list[str]: - assert prec == 0 - function_definitions = self.visit(node.function_definitions, prec=0) - closures = self.visit(node.closures, prec=0) - params = self.visit(node.params, prec=0) - - hparams = self._hmerge([node.id + "("], *self._hinterleave(params, ", "), [") {"]) - vparams = self._vmerge( - [node.id + "("], *self._hinterleave(params, ",", indent=True), [") {"] - ) - params = self._optimum(hparams, vparams) - - function_definitions = self._vmerge(*function_definitions) - closures = self._vmerge(*closures) - - return self._vmerge( - params, self._indent(function_definitions), self._indent(closures), ["}"] - ) - def visit_Program(self, node: ir.Program, *, prec: int) -> list[str]: assert prec == 0 function_definitions = self.visit(node.function_definitions, prec=0) diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py index d42f961202..e47a6886ad 100644 --- a/src/gt4py/next/iterator/runtime.py +++ b/src/gt4py/next/iterator/runtime.py @@ -26,7 +26,7 @@ # TODO(tehrengruber): remove cirular dependency and import unconditionally from gt4py.next import backend as next_backend -__all__ = ["offset", "fundef", "fendef", "closure", "set_at", "if_stmt"] +__all__ = ["offset", "fundef", "fendef", "set_at", "if_stmt"] @dataclass(frozen=True) @@ -163,7 +163,7 @@ def impl(out, *inps): # if passed as a dict, we need to convert back to builtins for interpretation by the backends assert offset_provider is not None dom = _deduce_domain(dom, common.offset_provider_to_type(offset_provider)) - closure(dom, self.fundef_dispatcher, out, [*inps]) + set_at(builtins.as_fieldop(self.fundef_dispatcher, dom)(*inps), dom, out) return impl @@ -208,11 +208,6 @@ def fundef(fun): return FundefDispatcher(fun) -@builtin_dispatch -def closure(*args): # TODO remove - return BackendNotSelectedError() - - @builtin_dispatch def set_at(*args): return BackendNotSelectedError() diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index 6772d4b507..12c86680b5 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -23,7 +23,6 @@ Lambda, NoneLiteral, OffsetLiteral, - StencilClosure, Sym, SymRef, ) @@ -202,9 +201,6 @@ def __bool__(self): class TracerContext: fundefs: ClassVar[List[FunctionDefinition]] = [] - closures: ClassVar[ - List[StencilClosure] - ] = [] # TODO(havogt): remove after refactoring to `Program` is complete, currently handles both programs and fencils body: ClassVar[List[itir.Stmt]] = [] @classmethod @@ -212,10 +208,6 @@ def add_fundef(cls, fun): if fun not in cls.fundefs: cls.fundefs.append(fun) - @classmethod - def add_closure(cls, closure): - cls.closures.append(closure) - @classmethod def add_stmt(cls, stmt): cls.body.append(stmt) @@ -225,23 +217,10 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, exc_traceback): type(self).fundefs = [] - type(self).closures = [] type(self).body = [] iterator.builtins.builtin_dispatch.pop_key() -@iterator.runtime.closure.register(TRACING) -def closure(domain, stencil, output, inputs): - if hasattr(stencil, "__name__") and stencil.__name__ in iterator.builtins.__all__: - stencil = _s(stencil.__name__) - else: - stencil(*(_s(param) for param in inspect.signature(stencil).parameters)) - stencil = make_node(stencil) - TracerContext.add_closure( - StencilClosure(domain=domain, stencil=stencil, output=output, inputs=inputs) - ) - - @iterator.runtime.set_at.register(TRACING) def set_at(expr: itir.Expr, domain: itir.Expr, target: itir.Expr) -> None: TracerContext.add_stmt(itir.SetAt(expr=expr, domain=domain, target=target)) @@ -279,7 +258,7 @@ def _contains_tuple_dtype_field(arg): return isinstance(arg, common.Field) and any(dim is None for dim in arg.domain.dims) -def _make_fencil_params(fun, args) -> list[Sym]: +def _make_program_params(fun, args) -> list[Sym]: params: list[Sym] = [] param_infos = list(inspect.signature(fun).parameters.values()) @@ -314,33 +293,22 @@ def _make_fencil_params(fun, args) -> list[Sym]: return params -def trace_fencil_definition( - fun: typing.Callable, args: typing.Iterable -) -> itir.FencilDefinition | itir.Program: +def trace_fencil_definition(fun: typing.Callable, args: typing.Iterable) -> itir.Program: """ - Transform fencil given as a callable into `itir.FencilDefinition` using tracing. + Transform fencil given as a callable into `itir.Program` using tracing. Arguments: - fun: The fencil / callable to trace. + fun: The program / callable to trace. args: A list of arguments, e.g. fields, scalars, composites thereof, or directly a type. """ with TracerContext() as _: - params = _make_fencil_params(fun, args) + params = _make_program_params(fun, args) trace_function_call(fun, args=(_s(param.id) for param in params)) - if TracerContext.closures: - return itir.FencilDefinition( - id=fun.__name__, - function_definitions=TracerContext.fundefs, - params=params, - closures=TracerContext.closures, - ) - else: - assert TracerContext.body - return itir.Program( - id=fun.__name__, - function_definitions=TracerContext.fundefs, - params=params, - declarations=[], # TODO - body=TracerContext.body, - ) + return itir.Program( + id=fun.__name__, + function_definitions=TracerContext.fundefs, + params=params, + declarations=[], # TODO + body=TracerContext.body, + ) diff --git a/src/gt4py/next/iterator/transforms/__init__.py b/src/gt4py/next/iterator/transforms/__init__.py index aeccb5f26d..d0afc610e7 100644 --- a/src/gt4py/next/iterator/transforms/__init__.py +++ b/src/gt4py/next/iterator/transforms/__init__.py @@ -7,10 +7,10 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py.next.iterator.transforms.pass_manager import ( - ITIRTransform, + GTIRTransform, apply_common_transforms, apply_fieldview_transforms, ) -__all__ = ["apply_common_transforms", "apply_fieldview_transforms", "ITIRTransform"] +__all__ = ["apply_common_transforms", "apply_fieldview_transforms", "GTIRTransform"] diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 01ccbc8ab6..ea7aad890c 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -173,7 +173,7 @@ def apply( offset_provider_type = offset_provider_type or {} uids = uids or eve_utils.UIDGenerator() - if isinstance(node, (ir.Program, ir.FencilDefinition)): + if isinstance(node, ir.Program): within_stencil = False assert within_stencil in [ True, diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 824adfdd8d..4f3fcbfdd5 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -376,7 +376,7 @@ def extract_subexpression( return _NodeReplacer(expr_map).visit(node), extracted, ignored_children -ProgramOrExpr = TypeVar("ProgramOrExpr", bound=itir.Program | itir.FencilDefinition | itir.Expr) +ProgramOrExpr = TypeVar("ProgramOrExpr", bound=itir.Program | itir.Expr) @dataclasses.dataclass(frozen=True) @@ -413,7 +413,7 @@ def apply( within_stencil: bool | None = None, offset_provider_type: common.OffsetProviderType | None = None, ) -> ProgramOrExpr: - is_program = isinstance(node, (itir.Program, itir.FencilDefinition)) + is_program = isinstance(node, itir.Program) if is_program: assert within_stencil is None within_stencil = False diff --git a/src/gt4py/next/iterator/transforms/fencil_to_program.py b/src/gt4py/next/iterator/transforms/fencil_to_program.py deleted file mode 100644 index 4ad91645d4..0000000000 --- a/src/gt4py/next/iterator/transforms/fencil_to_program.py +++ /dev/null @@ -1,31 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from gt4py import eve -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im - - -class FencilToProgram(eve.NodeTranslator): - @classmethod - def apply(cls, node: itir.FencilDefinition | itir.Program) -> itir.Program: - return cls().visit(node) - - def visit_StencilClosure(self, node: itir.StencilClosure) -> itir.SetAt: - as_fieldop = im.call(im.call("as_fieldop")(node.stencil, node.domain))(*node.inputs) - return itir.SetAt(expr=as_fieldop, domain=node.domain, target=node.output) - - def visit_FencilDefinition(self, node: itir.FencilDefinition) -> itir.Program: - return itir.Program( - id=node.id, - function_definitions=node.function_definitions, - params=node.params, - declarations=[], - body=self.visit(node.closures), - implicit_domain=node.implicit_domain, - ) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 9076bf2d3f..e8a221b814 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -53,7 +53,7 @@ def _canonicalize_as_fieldop(expr: itir.FunCall) -> itir.FunCall: if cpm.is_ref_to(stencil, "deref"): stencil = im.lambda_("arg")(im.deref("arg")) new_expr = im.as_fieldop(stencil, domain)(*expr.args) - type_inference.copy_type(from_=expr, to=new_expr) + type_inference.copy_type(from_=expr, to=new_expr, allow_untyped=True) return new_expr @@ -68,6 +68,107 @@ def _is_tuple_expr_of_literals(expr: itir.Expr): return isinstance(expr, itir.Literal) +def _inline_as_fieldop_arg( + arg: itir.Expr, *, uids: eve_utils.UIDGenerator +) -> tuple[itir.Expr, dict[str, itir.Expr]]: + assert cpm.is_applied_as_fieldop(arg) + arg = _canonicalize_as_fieldop(arg) + + stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop` + inner_args: list[itir.Expr] = arg.args + extracted_args: dict[str, itir.Expr] = {} # mapping from outer-stencil param to arg + + stencil_params: list[itir.Sym] = [] + stencil_body: itir.Expr = stencil.expr + + for inner_param, inner_arg in zip(stencil.params, inner_args, strict=True): + if isinstance(inner_arg, itir.SymRef): + stencil_params.append(inner_param) + extracted_args[inner_arg.id] = inner_arg + elif isinstance(inner_arg, itir.Literal): + # note: only literals, not all scalar expressions are required as it doesn't make sense + # for them to be computed per grid point. + stencil_body = im.let(inner_param, im.promote_to_const_iterator(inner_arg))( + stencil_body + ) + else: + # a scalar expression, a previously not inlined `as_fieldop` call or an opaque + # expression e.g. containing a tuple + stencil_params.append(inner_param) + new_outer_stencil_param = uids.sequential_id(prefix="__iasfop") + extracted_args[new_outer_stencil_param] = inner_arg + + return im.lift(im.lambda_(*stencil_params)(stencil_body))( + *extracted_args.keys() + ), extracted_args + + +def fuse_as_fieldop( + expr: itir.Expr, eligible_args: list[bool], *, uids: eve_utils.UIDGenerator +) -> itir.Expr: + assert cpm.is_applied_as_fieldop(expr) and isinstance(expr.fun.args[0], itir.Lambda) # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + + stencil: itir.Lambda = expr.fun.args[0] # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + domain = expr.fun.args[1] if len(expr.fun.args) > 1 else None # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + + args: list[itir.Expr] = expr.args + + new_args: dict[str, itir.Expr] = {} + new_stencil_body: itir.Expr = stencil.expr + + for eligible, stencil_param, arg in zip(eligible_args, stencil.params, args, strict=True): + if eligible: + if cpm.is_applied_as_fieldop(arg): + pass + elif cpm.is_call_to(arg, "if_"): + # TODO(tehrengruber): revisit if we want to inline if_ + type_ = arg.type + arg = im.op_as_fieldop("if_")(*arg.args) + arg.type = type_ + elif _is_tuple_expr_of_literals(arg): + arg = im.op_as_fieldop(im.lambda_()(arg))() + else: + raise NotImplementedError() + + inline_expr, extracted_args = _inline_as_fieldop_arg(arg, uids=uids) + + new_stencil_body = im.let(stencil_param, inline_expr)(new_stencil_body) + + new_args = _merge_arguments(new_args, extracted_args) + else: + # just a safety check if typing information is available + if arg.type and not isinstance(arg.type, ts.DeferredType): + assert isinstance(arg.type, ts.TypeSpec) + dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) + assert not isinstance(dtype, it_ts.ListType) + new_param: str + if isinstance( + arg, itir.SymRef + ): # use name from outer scope (optional, just to get a nice IR) + new_param = arg.id + new_stencil_body = im.let(stencil_param.id, arg.id)(new_stencil_body) + else: + new_param = stencil_param.id + new_args = _merge_arguments(new_args, {new_param: arg}) + + new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)( + *new_args.values() + ) + + # simplify stencil directly to keep the tree small + new_node = inline_center_deref_lift_vars.InlineCenterDerefLiftVars.apply( + new_node + ) # to keep the tree small + new_node = inline_lambdas.InlineLambdas.apply( + new_node, opcount_preserving=True, force_inline_lift_args=True + ) + new_node = inline_lifts.InlineLifts().visit(new_node) + + type_inference.copy_type(from_=expr, to=new_node, allow_untyped=True) + + return new_node + + @dataclasses.dataclass class FuseAsFieldOp(eve.NodeTranslator): """ @@ -98,38 +199,6 @@ class FuseAsFieldOp(eve.NodeTranslator): uids: eve_utils.UIDGenerator - def _inline_as_fieldop_arg(self, arg: itir.Expr) -> tuple[itir.Expr, dict[str, itir.Expr]]: - assert cpm.is_applied_as_fieldop(arg) - arg = _canonicalize_as_fieldop(arg) - - stencil, *_ = arg.fun.args # type: ignore[attr-defined] # ensured by `is_applied_as_fieldop` - inner_args: list[itir.Expr] = arg.args - extracted_args: dict[str, itir.Expr] = {} # mapping from outer-stencil param to arg - - stencil_params: list[itir.Sym] = [] - stencil_body: itir.Expr = stencil.expr - - for inner_param, inner_arg in zip(stencil.params, inner_args, strict=True): - if isinstance(inner_arg, itir.SymRef): - stencil_params.append(inner_param) - extracted_args[inner_arg.id] = inner_arg - elif isinstance(inner_arg, itir.Literal): - # note: only literals, not all scalar expressions are required as it doesn't make sense - # for them to be computed per grid point. - stencil_body = im.let(inner_param, im.promote_to_const_iterator(inner_arg))( - stencil_body - ) - else: - # a scalar expression, a previously not inlined `as_fieldop` call or an opaque - # expression e.g. containing a tuple - stencil_params.append(inner_param) - new_outer_stencil_param = self.uids.sequential_id(prefix="__iasfop") - extracted_args[new_outer_stencil_param] = inner_arg - - return im.lift(im.lambda_(*stencil_params)(stencil_body))( - *extracted_args.keys() - ), extracted_args - @classmethod def apply( cls, @@ -158,72 +227,26 @@ def visit_FunCall(self, node: itir.FunCall): if cpm.is_call_to(node.fun, "as_fieldop") and isinstance(node.fun.args[0], itir.Lambda): stencil: itir.Lambda = node.fun.args[0] - domain = node.fun.args[1] if len(node.fun.args) > 1 else None - - shifts = trace_shifts.trace_stencil(stencil) - args: list[itir.Expr] = node.args + shifts = trace_shifts.trace_stencil(stencil) - new_args: dict[str, itir.Expr] = {} - new_stencil_body: itir.Expr = stencil.expr - - for stencil_param, arg, arg_shifts in zip(stencil.params, args, shifts, strict=True): + eligible_args = [] + for arg, arg_shifts in zip(args, shifts, strict=True): assert isinstance(arg.type, ts.TypeSpec) dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type) # TODO(tehrengruber): make this configurable - should_inline = _is_tuple_expr_of_literals(arg) or ( - isinstance(arg, itir.FunCall) - and ( - cpm.is_call_to(arg.fun, "as_fieldop") - and isinstance(arg.fun.args[0], itir.Lambda) - or cpm.is_call_to(arg, "if_") + eligible_args.append( + _is_tuple_expr_of_literals(arg) + or ( + isinstance(arg, itir.FunCall) + and ( + cpm.is_call_to(arg.fun, "as_fieldop") + and isinstance(arg.fun.args[0], itir.Lambda) + or cpm.is_call_to(arg, "if_") + ) + and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1) ) - and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1) ) - if should_inline: - if cpm.is_applied_as_fieldop(arg): - pass - elif cpm.is_call_to(arg, "if_"): - # TODO(tehrengruber): revisit if we want to inline if_ - type_ = arg.type - arg = im.op_as_fieldop("if_")(*arg.args) - arg.type = type_ - elif _is_tuple_expr_of_literals(arg): - arg = im.op_as_fieldop(im.lambda_()(arg))() - else: - raise NotImplementedError() - - inline_expr, extracted_args = self._inline_as_fieldop_arg(arg) - - new_stencil_body = im.let(stencil_param, inline_expr)(new_stencil_body) - - new_args = _merge_arguments(new_args, extracted_args) - else: - assert not isinstance(dtype, it_ts.ListType) - new_param: str - if isinstance( - arg, itir.SymRef - ): # use name from outer scope (optional, just to get a nice IR) - new_param = arg.id - new_stencil_body = im.let(stencil_param.id, arg.id)(new_stencil_body) - else: - new_param = stencil_param.id - new_args = _merge_arguments(new_args, {new_param: arg}) - - new_node = im.as_fieldop(im.lambda_(*new_args.keys())(new_stencil_body), domain)( - *new_args.values() - ) - - # simplify stencil directly to keep the tree small - new_node = inline_center_deref_lift_vars.InlineCenterDerefLiftVars.apply( - new_node - ) # to keep the tree small - new_node = inline_lambdas.InlineLambdas.apply( - new_node, opcount_preserving=True, force_inline_lift_args=True - ) - new_node = inline_lifts.InlineLifts().visit(new_node) - - type_inference.copy_type(from_=node, to=new_node) - return new_node + return fuse_as_fieldop(node, eligible_args, uids=self.uids) return node diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index a6d39883e3..334fb330d7 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -74,7 +74,7 @@ def _transform_by_pattern( # or a tuple thereof) # - one `SetAt` statement that materializes the expression into the temporary for tmp_sym, tmp_expr in extracted_fields.items(): - domain = tmp_expr.annex.domain + domain: infer_domain.DomainAccess = tmp_expr.annex.domain # TODO(tehrengruber): Implement. This happens when the expression is a combination # of an `if_` call with a tuple, e.g., `if_(cond, {a, b}, {c, d})`. As long as we are @@ -186,7 +186,7 @@ def create_global_tmps( This pass looks at all `as_fieldop` calls and transforms field-typed subexpressions of its arguments into temporaries. """ - program = infer_domain.infer_program(program, offset_provider) + program = infer_domain.infer_program(program, offset_provider=offset_provider) program = type_inference.infer( program, offset_provider_type=common.offset_provider_to_type(offset_provider) ) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 6852b47a7a..f26d3f9ec2 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -10,10 +10,10 @@ import itertools import typing -from typing import Callable, Optional, TypeAlias from gt4py import eve from gt4py.eve import utils as eve_utils +from gt4py.eve.extended_typing import Callable, Optional, TypeAlias, Unpack from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ( @@ -25,8 +25,35 @@ from gt4py.next.utils import flatten_nested_tuple, tree_map -DOMAIN: TypeAlias = domain_utils.SymbolicDomain | None | tuple["DOMAIN", ...] -ACCESSED_DOMAINS: TypeAlias = dict[str, DOMAIN] +class DomainAccessDescriptor(eve.StrEnum): + """ + Descriptor for domains that could not be inferred. + """ + + # TODO(tehrengruber): Revisit this concept. It is strange that we don't have a descriptor + # `KNOWN`, but since we don't need it, it wasn't added. + + #: The access is unknown because of a dynamic shift.whose extent is not known. + #: E.g.: `(⇑(λ(arg0, arg1) → ·⟪Ioffₒ, ·arg1⟫(arg0)))(in_field1, in_field2)` + UNKNOWN = "unknown" + #: The domain is never accessed. + #: E.g.: `{in_field1, in_field2}[0]` + NEVER = "never" + + +NonTupleDomainAccess: TypeAlias = domain_utils.SymbolicDomain | DomainAccessDescriptor +#: The domain can also be a tuple of domains, usually this only occurs for scan operators returning +#: a tuple since other occurrences for tuples are removed before domain inference. This is +#: however not a requirement of the pass and `make_tuple(vertex_field, edge_field)` infers just +#: fine to a tuple of a vertex and an edge domain. +DomainAccess: TypeAlias = NonTupleDomainAccess | tuple["DomainAccess", ...] +AccessedDomains: TypeAlias = dict[str, DomainAccess] + + +class InferenceOptions(typing.TypedDict): + offset_provider: common.OffsetProvider + symbolic_domain_sizes: Optional[dict[str, str]] + allow_uninferred: bool class DomainAnnexDebugger(eve.NodeVisitor): @@ -57,43 +84,58 @@ def _split_dict_by_key(pred: Callable, d: dict): # TODO(tehrengruber): Revisit whether we want to move this behaviour to `domain_utils.domain_union`. -def _domain_union_with_none( - *domains: domain_utils.SymbolicDomain | None, -) -> domain_utils.SymbolicDomain | None: - filtered_domains: list[domain_utils.SymbolicDomain] = [d for d in domains if d is not None] +def _domain_union( + *domains: domain_utils.SymbolicDomain | DomainAccessDescriptor, +) -> domain_utils.SymbolicDomain | DomainAccessDescriptor: + if any(d == DomainAccessDescriptor.UNKNOWN for d in domains): + return DomainAccessDescriptor.UNKNOWN + + filtered_domains: list[domain_utils.SymbolicDomain] = [ + d # type: ignore[misc] # domain can never be unknown as these cases are filtered above + for d in domains + if d != DomainAccessDescriptor.NEVER + ] if len(filtered_domains) == 0: - return None + return DomainAccessDescriptor.NEVER return domain_utils.domain_union(*filtered_domains) -def _canonicalize_domain_structure(d1: DOMAIN, d2: DOMAIN) -> tuple[DOMAIN, DOMAIN]: +def _canonicalize_domain_structure( + d1: DomainAccess, d2: DomainAccess +) -> tuple[DomainAccess, DomainAccess]: """ Given two domains or composites thereof, canonicalize their structure. If one of the arguments is a tuple the other one will be promoted to a tuple of same structure - unless it already is a tuple. Missing values are replaced by None, meaning no domain is - specified. + unless it already is a tuple. Missing values are filled by :ref:`DomainAccessDescriptor.NEVER`. >>> domain = im.domain(common.GridType.CARTESIAN, {}) >>> _canonicalize_domain_structure((domain,), (domain, domain)) == ( - ... (domain, None), + ... (domain, DomainAccessDescriptor.NEVER), ... (domain, domain), ... ) True - >>> _canonicalize_domain_structure((domain, None), None) == ((domain, None), (None, None)) + >>> _canonicalize_domain_structure( + ... (domain, DomainAccessDescriptor.NEVER), DomainAccessDescriptor.NEVER + ... ) == ( + ... (domain, DomainAccessDescriptor.NEVER), + ... (DomainAccessDescriptor.NEVER, DomainAccessDescriptor.NEVER), + ... ) True """ - if d1 is None and isinstance(d2, tuple): - return _canonicalize_domain_structure((None,) * len(d2), d2) - if d2 is None and isinstance(d1, tuple): - return _canonicalize_domain_structure(d1, (None,) * len(d1)) + if d1 is DomainAccessDescriptor.NEVER and isinstance(d2, tuple): + return _canonicalize_domain_structure((DomainAccessDescriptor.NEVER,) * len(d2), d2) + if d2 is DomainAccessDescriptor.NEVER and isinstance(d1, tuple): + return _canonicalize_domain_structure(d1, (DomainAccessDescriptor.NEVER,) * len(d1)) if isinstance(d1, tuple) and isinstance(d2, tuple): return tuple( zip( *( _canonicalize_domain_structure(el1, el2) - for el1, el2 in itertools.zip_longest(d1, d2, fillvalue=None) + for el1, el2 in itertools.zip_longest( + d1, d2, fillvalue=DomainAccessDescriptor.NEVER + ) ) ) ) # type: ignore[return-value] # mypy not smart enough @@ -101,16 +143,16 @@ def _canonicalize_domain_structure(d1: DOMAIN, d2: DOMAIN) -> tuple[DOMAIN, DOMA def _merge_domains( - original_domains: ACCESSED_DOMAINS, - additional_domains: ACCESSED_DOMAINS, -) -> ACCESSED_DOMAINS: + original_domains: AccessedDomains, + additional_domains: AccessedDomains, +) -> AccessedDomains: new_domains = {**original_domains} for key, domain in additional_domains.items(): original_domain, domain = _canonicalize_domain_structure( - original_domains.get(key, None), domain + original_domains.get(key, DomainAccessDescriptor.NEVER), domain ) - new_domains[key] = tree_map(_domain_union_with_none)(original_domain, domain) + new_domains[key] = tree_map(_domain_union)(original_domain, domain) return new_domains @@ -118,44 +160,52 @@ def _merge_domains( def _extract_accessed_domains( stencil: itir.Expr, input_ids: list[str], - target_domain: domain_utils.SymbolicDomain, + target_domain: NonTupleDomainAccess, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], -) -> ACCESSED_DOMAINS: - accessed_domains: dict[str, domain_utils.SymbolicDomain | None] = {} +) -> dict[str, NonTupleDomainAccess]: + accessed_domains: dict[str, NonTupleDomainAccess] = {} shifts_results = trace_shifts.trace_stencil(stencil, num_args=len(input_ids)) for in_field_id, shifts_list in zip(input_ids, shifts_results, strict=True): + # TODO(tehrengruber): Dynamic shifts are not supported by `SymbolicDomain.translate`. Use + # special `UNKNOWN` marker for them until we have implemented a proper solution. + if any(s == trace_shifts.Sentinel.VALUE for shift in shifts_list for s in shift): + accessed_domains[in_field_id] = DomainAccessDescriptor.UNKNOWN + continue + new_domains = [ domain_utils.SymbolicDomain.translate( target_domain, shift, offset_provider, symbolic_domain_sizes ) + if not isinstance(target_domain, DomainAccessDescriptor) + else target_domain for shift in shifts_list ] - # `None` means field is never accessed - accessed_domains[in_field_id] = _domain_union_with_none( - accessed_domains.get(in_field_id, None), *new_domains + accessed_domains[in_field_id] = _domain_union( + accessed_domains.get(in_field_id, DomainAccessDescriptor.NEVER), *new_domains ) - return typing.cast(ACCESSED_DOMAINS, accessed_domains) + return accessed_domains def _infer_as_fieldop( applied_fieldop: itir.FunCall, - target_domain: DOMAIN, + target_domain: DomainAccess, + *, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]], -) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: + allow_uninferred: bool, +) -> tuple[itir.FunCall, AccessedDomains]: assert isinstance(applied_fieldop, itir.FunCall) assert cpm.is_call_to(applied_fieldop.fun, "as_fieldop") - if target_domain is None: - raise ValueError("'target_domain' cannot be 'None'.") + if not allow_uninferred and target_domain is DomainAccessDescriptor.NEVER: + raise ValueError("'target_domain' cannot be 'NEVER' unless `allow_uninferred=True`.") # FIXME[#1582](tehrengruber): Temporary solution for `tuple_get` on scan result. See `test_solve_triag`. if isinstance(target_domain, tuple): - target_domain = _domain_union_with_none(*flatten_nested_tuple(target_domain)) - if not isinstance(target_domain, domain_utils.SymbolicDomain): - raise ValueError("'target_domain' needs to be a 'domain_utils.SymbolicDomain'.") + target_domain = _domain_union(*flatten_nested_tuple(target_domain)) # type: ignore[arg-type] # mypy not smart enough + assert isinstance(target_domain, (domain_utils.SymbolicDomain, DomainAccessDescriptor)) # `as_fieldop(stencil)(inputs...)` stencil, inputs = applied_fieldop.fun.args[0], applied_fieldop.args @@ -177,22 +227,29 @@ def _infer_as_fieldop( raise ValueError(f"Unsupported expression of type '{type(in_field)}'.") input_ids.append(id_) - inputs_accessed_domains: ACCESSED_DOMAINS = _extract_accessed_domains( + inputs_accessed_domains: dict[str, NonTupleDomainAccess] = _extract_accessed_domains( stencil, input_ids, target_domain, offset_provider, symbolic_domain_sizes ) # Recursively infer domain of inputs and update domain arg of nested `as_fieldop`s - accessed_domains: ACCESSED_DOMAINS = {} + accessed_domains: AccessedDomains = {} transformed_inputs: list[itir.Expr] = [] for in_field_id, in_field in zip(input_ids, inputs): transformed_input, accessed_domains_tmp = infer_expr( - in_field, inputs_accessed_domains[in_field_id], offset_provider, symbolic_domain_sizes + in_field, + inputs_accessed_domains[in_field_id], + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + allow_uninferred=allow_uninferred, ) transformed_inputs.append(transformed_input) accessed_domains = _merge_domains(accessed_domains, accessed_domains_tmp) - target_domain_expr = domain_utils.SymbolicDomain.as_expr(target_domain) + if not isinstance(target_domain, DomainAccessDescriptor): + target_domain_expr = domain_utils.SymbolicDomain.as_expr(target_domain) + else: + target_domain_expr = None transformed_call = im.as_fieldop(stencil, target_domain_expr)(*transformed_inputs) accessed_domains_without_tmp = { @@ -206,17 +263,15 @@ def _infer_as_fieldop( def _infer_let( let_expr: itir.FunCall, - input_domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], -) -> tuple[itir.FunCall, ACCESSED_DOMAINS]: + input_domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.FunCall, AccessedDomains]: assert cpm.is_let(let_expr) assert isinstance(let_expr.fun, itir.Lambda) # just to make mypy happy - transformed_calls_expr, accessed_domains = infer_expr( - let_expr.fun.expr, input_domain, offset_provider, symbolic_domain_sizes - ) - let_params = {param_sym.id for param_sym in let_expr.fun.params} + + transformed_calls_expr, accessed_domains = infer_expr(let_expr.fun.expr, input_domain, **kwargs) + accessed_domains_let_args, accessed_domains_outer = _split_dict_by_key( lambda k: k in let_params, accessed_domains ) @@ -227,10 +282,9 @@ def _infer_let( arg, accessed_domains_let_args.get( param.id, - None, + DomainAccessDescriptor.NEVER, ), - offset_provider, - symbolic_domain_sizes, + **kwargs, ) accessed_domains_outer = _merge_domains(accessed_domains_outer, accessed_domains_arg) transformed_calls_args.append(transformed_calls_arg) @@ -247,13 +301,12 @@ def _infer_let( def _infer_make_tuple( expr: itir.Expr, - domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "make_tuple") infered_args_expr = [] - actual_domains: ACCESSED_DOMAINS = {} + actual_domains: AccessedDomains = {} if not isinstance(domain, tuple): # promote domain to a tuple of domains such that it has the same structure as # the expression @@ -261,13 +314,12 @@ def _infer_make_tuple( # out @ c⟨ IDimₕ: [0, __out_size_0) ⟩ ← {__sym_1, __sym_2}; domain = (domain,) * len(expr.args) assert len(expr.args) >= len(domain) - # There may be less domains than tuple args, pad the domain with `None` in that case. - # e.g. `im.tuple_get(0, im.make_tuple(a, b), domain=domain)` - domain = (*domain, *(None for _ in range(len(expr.args) - len(domain)))) + # There may be fewer domains than tuple args, pad the domain with `NEVER` + # in that case. + # e.g. `im.tuple_get(0, im.make_tuple(a, b), domain=domain)` + domain = (*domain, *(DomainAccessDescriptor.NEVER for _ in range(len(expr.args) - len(domain)))) for i, arg in enumerate(expr.args): - infered_arg_expr, actual_domains_arg = infer_expr( - arg, domain[i], offset_provider, symbolic_domain_sizes - ) + infered_arg_expr, actual_domains_arg = infer_expr(arg, domain[i], **kwargs) infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) result_expr = im.call(expr.fun)(*infered_args_expr) @@ -276,19 +328,18 @@ def _infer_make_tuple( def _infer_tuple_get( expr: itir.Expr, - domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "tuple_get") - actual_domains: ACCESSED_DOMAINS = {} + actual_domains: AccessedDomains = {} idx_expr, tuple_arg = expr.args assert isinstance(idx_expr, itir.Literal) idx = int(idx_expr.value) - tuple_domain = tuple(None if i != idx else domain for i in range(idx + 1)) - infered_arg_expr, actual_domains_arg = infer_expr( - tuple_arg, tuple_domain, offset_provider, symbolic_domain_sizes + tuple_domain = tuple( + DomainAccessDescriptor.NEVER if i != idx else domain for i in range(idx + 1) ) + infered_arg_expr, actual_domains_arg = infer_expr(tuple_arg, tuple_domain, **kwargs) infered_args_expr = im.tuple_get(idx, infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) @@ -297,18 +348,15 @@ def _infer_tuple_get( def _infer_if( expr: itir.Expr, - domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: assert cpm.is_call_to(expr, "if_") infered_args_expr = [] - actual_domains: ACCESSED_DOMAINS = {} + actual_domains: AccessedDomains = {} cond, true_val, false_val = expr.args for arg in [true_val, false_val]: - infered_arg_expr, actual_domains_arg = infer_expr( - arg, domain, offset_provider, symbolic_domain_sizes - ) + infered_arg_expr, actual_domains_arg = infer_expr(arg, domain, **kwargs) infered_args_expr.append(infered_arg_expr) actual_domains = _merge_domains(actual_domains, actual_domains_arg) result_expr = im.call(expr.fun)(cond, *infered_args_expr) @@ -317,24 +365,23 @@ def _infer_if( def _infer_expr( expr: itir.Expr, - domain: DOMAIN, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + domain: DomainAccess, + **kwargs: Unpack[InferenceOptions], +) -> tuple[itir.Expr, AccessedDomains]: if isinstance(expr, itir.SymRef): return expr, {str(expr.id): domain} elif isinstance(expr, itir.Literal): return expr, {} elif cpm.is_applied_as_fieldop(expr): - return _infer_as_fieldop(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_as_fieldop(expr, domain, **kwargs) elif cpm.is_let(expr): - return _infer_let(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_let(expr, domain, **kwargs) elif cpm.is_call_to(expr, "make_tuple"): - return _infer_make_tuple(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_make_tuple(expr, domain, **kwargs) elif cpm.is_call_to(expr, "tuple_get"): - return _infer_tuple_get(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_tuple_get(expr, domain, **kwargs) elif cpm.is_call_to(expr, "if_"): - return _infer_if(expr, domain, offset_provider, symbolic_domain_sizes) + return _infer_if(expr, domain, **kwargs) elif ( cpm.is_call_to(expr, itir.ARITHMETIC_BUILTINS) or cpm.is_call_to(expr, itir.TYPEBUILTINS) @@ -347,10 +394,12 @@ def _infer_expr( def infer_expr( expr: itir.Expr, - domain: DOMAIN, + domain: DomainAccess, + *, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, -) -> tuple[itir.Expr, ACCESSED_DOMAINS]: + allow_uninferred: bool = False, +) -> tuple[itir.Expr, AccessedDomains]: """ Infer the domain of all field subexpressions of `expr`. @@ -362,30 +411,35 @@ def infer_expr( - domain: The domain `expr` is read at. - symbolic_domain_sizes: A dictionary mapping axes names, e.g., `I`, `Vertex`, to a symbol name that evaluates to the length of that axis. + - allow_uninferred: Allow `as_fieldop` expressions whose domain is either unknown (e.g. + because of a dynamic shift) or never accessed. Returns: A tuple containing the inferred expression with all applied `as_fieldop` (that are accessed) having a domain argument now, and a dictionary mapping symbol names referenced in `expr` to domain they are accessed at. """ - # this is just a small wrapper that populates the `domain` annex - expr, accessed_domains = _infer_expr(expr, domain, offset_provider, symbolic_domain_sizes) + expr, accessed_domains = _infer_expr( + expr, + domain, + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + allow_uninferred=allow_uninferred, + ) expr.annex.domain = domain + return expr, accessed_domains def _infer_stmt( stmt: itir.Stmt, - offset_provider: common.OffsetProvider, - symbolic_domain_sizes: Optional[dict[str, str]], + **kwargs: Unpack[InferenceOptions], ): if isinstance(stmt, itir.SetAt): - transformed_call, _unused_domain = infer_expr( - stmt.expr, - domain_utils.SymbolicDomain.from_expr(stmt.domain), - offset_provider, - symbolic_domain_sizes, + transformed_call, _ = infer_expr( + stmt.expr, domain_utils.SymbolicDomain.from_expr(stmt.domain), **kwargs ) + return itir.SetAt( expr=transformed_call, domain=stmt.domain, @@ -394,20 +448,18 @@ def _infer_stmt( elif isinstance(stmt, itir.IfStmt): return itir.IfStmt( cond=stmt.cond, - true_branch=[ - _infer_stmt(c, offset_provider, symbolic_domain_sizes) for c in stmt.true_branch - ], - false_branch=[ - _infer_stmt(c, offset_provider, symbolic_domain_sizes) for c in stmt.false_branch - ], + true_branch=[_infer_stmt(c, **kwargs) for c in stmt.true_branch], + false_branch=[_infer_stmt(c, **kwargs) for c in stmt.false_branch], ) raise ValueError(f"Unsupported stmt: {stmt}") def infer_program( program: itir.Program, + *, offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, + allow_uninferred: bool = False, ) -> itir.Program: """ Infer the domain of all field subexpressions inside a program. @@ -423,5 +475,13 @@ def infer_program( function_definitions=program.function_definitions, params=program.params, declarations=program.declarations, - body=[_infer_stmt(stmt, offset_provider, symbolic_domain_sizes) for stmt in program.body], + body=[ + _infer_stmt( + stmt, + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + allow_uninferred=allow_uninferred, + ) + for stmt in program.body + ], ) diff --git a/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py b/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py new file mode 100644 index 0000000000..0af9d9dab9 --- /dev/null +++ b/src/gt4py/next/iterator/transforms/inline_dynamic_shifts.py @@ -0,0 +1,73 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import dataclasses +from typing import Optional + +import gt4py.next.iterator.ir_utils.common_pattern_matcher as cpm +from gt4py import eve +from gt4py.eve import utils as eve_utils +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.transforms import fuse_as_fieldop, inline_lambdas, trace_shifts +from gt4py.next.iterator.transforms.symbol_ref_utils import collect_symbol_refs + + +def _dynamic_shift_args(node: itir.Expr) -> None | list[bool]: + if not cpm.is_applied_as_fieldop(node): + return None + params_shifts = trace_shifts.trace_stencil( + node.fun.args[0], # type: ignore[attr-defined] # ensured by is_applied_as_fieldop + num_args=len(node.args), + save_to_annex=True, + ) + dynamic_shifts = [ + any(trace_shifts.Sentinel.VALUE in shifts for shifts in param_shifts) + for param_shifts in params_shifts + ] + return dynamic_shifts + + +@dataclasses.dataclass +class InlineDynamicShifts(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): + uids: eve_utils.UIDGenerator + + @classmethod + def apply(cls, node: itir.Program, uids: Optional[eve_utils.UIDGenerator] = None): + if not uids: + uids = eve_utils.UIDGenerator() + + return cls(uids=uids).visit(node) + + def visit_FunCall(self, node: itir.FunCall, **kwargs): + node = self.generic_visit(node, **kwargs) + + if cpm.is_let(node) and ( + dynamic_shift_args := _dynamic_shift_args(let_body := node.fun.expr) # type: ignore[attr-defined] # ensured by is_let + ): + inline_let_params = {p.id: False for p in node.fun.params} # type: ignore[attr-defined] # ensured by is_let + + for inp, is_dynamic_shift_arg in zip(let_body.args, dynamic_shift_args, strict=True): + for ref in collect_symbol_refs(inp): + if ref in inline_let_params and is_dynamic_shift_arg: + inline_let_params[ref] = True + + if any(inline_let_params): + node = inline_lambdas.inline_lambda( + node, eligible_params=list(inline_let_params.values()) + ) + + if dynamic_shift_args := _dynamic_shift_args(node): + assert len(node.fun.args) in [1, 2] # type: ignore[attr-defined] # ensured by is_applied_as_fieldop in _dynamic_shift_args + fuse_args = [ + not isinstance(inp, itir.SymRef) and dynamic_shift_arg + for inp, dynamic_shift_arg in zip(node.args, dynamic_shift_args, strict=True) + ] + if any(fuse_args): + return fuse_as_fieldop.fuse_as_fieldop(node, fuse_args, uids=self.uids) + + return node diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index 5ec9ec5d0b..9053214b39 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -97,7 +97,6 @@ def new_name(name): if all(eligible_params): new_expr.location = node.location - return new_expr else: new_expr = ir.FunCall( fun=ir.Lambda( @@ -111,11 +110,11 @@ def new_name(name): args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible], location=node.location, ) - for attr in ("type", "recorded_shifts", "domain"): - if hasattr(node.annex, attr): - setattr(new_expr.annex, attr, getattr(node.annex, attr)) - itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True) - return new_expr + for attr in ("type", "recorded_shifts", "domain"): + if hasattr(node.annex, attr): + setattr(new_expr.annex, attr, getattr(node.annex, attr)) + itir_inference.copy_type(from_=node, to=new_expr, allow_untyped=True) + return new_expr @dataclasses.dataclass diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 2ed6e93f2d..f3cb0cc468 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -6,16 +6,16 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from typing import Callable, Optional, Protocol +from typing import Optional, Protocol from gt4py.eve import utils as eve_utils from gt4py.next import common from gt4py.next.iterator import ir as itir from gt4py.next.iterator.transforms import ( - fencil_to_program, fuse_as_fieldop, global_tmps, infer_domain, + inline_dynamic_shifts, inline_fundefs, inline_lifts, ) @@ -32,16 +32,16 @@ from gt4py.next.iterator.type_system.inference import infer -class ITIRTransform(Protocol): +class GTIRTransform(Protocol): def __call__( - self, _: itir.Program | itir.FencilDefinition, *, offset_provider: common.OffsetProvider + self, _: itir.Program, *, offset_provider: common.OffsetProvider ) -> itir.Program: ... # TODO(tehrengruber): Revisit interface to configure temporary extraction. We currently forward # `extract_temporaries` and `temporary_extraction_heuristics` which is inconvenient. def apply_common_transforms( - ir: itir.Program | itir.FencilDefinition, + ir: itir.Program, *, offset_provider=None, # TODO(havogt): should be replaced by offset_provider_type, but global_tmps currently relies on runtime info extract_temporaries=False, @@ -49,10 +49,6 @@ def apply_common_transforms( common_subexpression_elimination=True, force_inline_lambda_args=False, unconditionally_collapse_tuples=False, - # FIXME[#1582](tehrengruber): Revisit and cleanup after new GTIR temporary pass is in place - temporary_extraction_heuristics: Optional[ - Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None, #: A dictionary mapping axes names to their length. See :func:`infer_domain.infer_expr` for #: more details. symbolic_domain_sizes: Optional[dict[str, str]] = None, @@ -62,9 +58,6 @@ def apply_common_transforms( if offset_provider_type is None: offset_provider_type = common.offset_provider_to_type(offset_provider) - # FIXME[#1582](tehrengruber): Rewrite iterator tests with itir.Program and remove this - if isinstance(ir, itir.FencilDefinition): - ir = fencil_to_program.FencilToProgram.apply(ir) assert isinstance(ir, itir.Program) tmp_uids = eve_utils.UIDGenerator(prefix="__tmp") @@ -74,7 +67,7 @@ def apply_common_transforms( ir = MergeLet().visit(ir) ir = inline_fundefs.InlineFundefs().visit(ir) - ir = inline_fundefs.prune_unreferenced_fundefs(ir) # type: ignore[arg-type] # all previous passes return itir.Program + ir = inline_fundefs.prune_unreferenced_fundefs(ir) ir = NormalizeShifts().visit(ir) # note: this increases the size of the tree @@ -87,8 +80,11 @@ def apply_common_transforms( uids=collapse_tuple_uids, offset_provider_type=offset_provider_type, ) # type: ignore[assignment] # always an itir.Program + ir = inline_dynamic_shifts.InlineDynamicShifts.apply( + ir + ) # domain inference does not support dynamic offsets yet ir = infer_domain.infer_program( - ir, # type: ignore[arg-type] # always an itir.Program + ir, offset_provider=offset_provider, symbolic_domain_sizes=symbolic_domain_sizes, ) @@ -130,7 +126,7 @@ def apply_common_transforms( if extract_temporaries: ir = infer(ir, inplace=True, offset_provider_type=offset_provider_type) - ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider, uids=tmp_uids) # type: ignore[arg-type] # always an itir.Program + ir = global_tmps.create_global_tmps(ir, offset_provider=offset_provider, uids=tmp_uids) # Since `CollapseTuple` relies on the type inference which does not support returning tuples # larger than the number of closure outputs as given by the unconditional collapse, we can @@ -182,5 +178,8 @@ def apply_fieldview_transforms( flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, offset_provider_type=common.offset_provider_to_type(offset_provider), ) # type: ignore[assignment] # type is still `itir.Program` + ir = inline_dynamic_shifts.InlineDynamicShifts.apply( + ir + ) # domain inference does not support dynamic offsets yet ir = infer_domain.infer_program(ir, offset_provider=offset_provider) return ir diff --git a/src/gt4py/next/iterator/transforms/pass_manager_legacy.py b/src/gt4py/next/iterator/transforms/pass_manager_legacy.py deleted file mode 100644 index 94c962e92d..0000000000 --- a/src/gt4py/next/iterator/transforms/pass_manager_legacy.py +++ /dev/null @@ -1,181 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause -# FIXME[#1582](tehrengruber): file should be removed after refactoring to GTIR -import enum -from typing import Callable, Optional - -from gt4py.eve import utils as eve_utils -from gt4py.next import common -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import fencil_to_program, inline_fundefs -from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet -from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple -from gt4py.next.iterator.transforms.constant_folding import ConstantFolding -from gt4py.next.iterator.transforms.cse import CommonSubexpressionElimination -from gt4py.next.iterator.transforms.eta_reduction import EtaReduction -from gt4py.next.iterator.transforms.fuse_maps import FuseMaps -from gt4py.next.iterator.transforms.inline_center_deref_lift_vars import InlineCenterDerefLiftVars -from gt4py.next.iterator.transforms.inline_into_scan import InlineIntoScan -from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas -from gt4py.next.iterator.transforms.inline_lifts import InlineLifts -from gt4py.next.iterator.transforms.merge_let import MergeLet -from gt4py.next.iterator.transforms.normalize_shifts import NormalizeShifts -from gt4py.next.iterator.transforms.propagate_deref import PropagateDeref -from gt4py.next.iterator.transforms.scan_eta_reduction import ScanEtaReduction -from gt4py.next.iterator.transforms.unroll_reduce import UnrollReduce - - -@enum.unique -class LiftMode(enum.Enum): - FORCE_INLINE = enum.auto() - USE_TEMPORARIES = enum.auto() - - -def _inline_lifts(ir, lift_mode): - if lift_mode == LiftMode.FORCE_INLINE: - return InlineLifts().visit(ir) - elif lift_mode == LiftMode.USE_TEMPORARIES: - return InlineLifts( - flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT - | InlineLifts.Flag.INLINE_DEREF_LIFT # some tuple exprs found in FVM don't work yet. - ).visit(ir) - else: - raise ValueError() - - return ir - - -def _inline_into_scan(ir, *, max_iter=10): - for _ in range(10): - # in case there are multiple levels of lambdas around the scan we have to do multiple iterations - inlined = InlineIntoScan().visit(ir) - inlined = InlineLambdas.apply(inlined, opcount_preserving=True, force_inline_lift_args=True) - if inlined == ir: - break - ir = inlined - else: - raise RuntimeError(f"Inlining into 'scan' did not converge within {max_iter} iterations.") - return ir - - -def apply_common_transforms( - ir: itir.Node, - *, - lift_mode=None, - offset_provider=None, - unroll_reduce=False, - common_subexpression_elimination=True, - force_inline_lambda_args=False, - unconditionally_collapse_tuples=False, - temporary_extraction_heuristics: Optional[ - Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None, - symbolic_domain_sizes: Optional[dict[str, str]] = None, - offset_provider_type: Optional[common.OffsetProviderType] = None, -) -> itir.Program: - assert isinstance(ir, itir.FencilDefinition) - # TODO(havogt): if the runtime `offset_provider` is not passed, we cannot run global_tmps - if offset_provider_type is None: - offset_provider_type = common.offset_provider_to_type(offset_provider) - - ir = fencil_to_program.FencilToProgram().apply(ir) - icdlv_uids = eve_utils.UIDGenerator() - - if lift_mode is None: - lift_mode = LiftMode.FORCE_INLINE - assert isinstance(lift_mode, LiftMode) - ir = MergeLet().visit(ir) - ir = inline_fundefs.InlineFundefs().visit(ir) - - ir = inline_fundefs.prune_unreferenced_fundefs(ir) # type: ignore[arg-type] # all previous passes return itir.Program - ir = PropagateDeref.apply(ir) - ir = NormalizeShifts().visit(ir) - - for _ in range(10): - inlined = ir - - inlined = InlineCenterDerefLiftVars.apply(inlined, uids=icdlv_uids) # type: ignore[arg-type] # always a fencil - inlined = _inline_lifts(inlined, lift_mode) - - inlined = InlineLambdas.apply( - inlined, - opcount_preserving=True, - force_inline_lift_args=(lift_mode == LiftMode.FORCE_INLINE), - # If trivial lifts are not inlined we might create temporaries for constants. In all - # other cases we want it anyway. - force_inline_trivial_lift_args=True, - ) - inlined = ConstantFolding.apply(inlined) - # This pass is required to be in the loop such that when an `if_` call with tuple arguments - # is constant-folded the surrounding tuple_get calls can be removed. - inlined = CollapseTuple.apply( - inlined, - offset_provider_type=offset_provider_type, - # TODO(tehrengruber): disabled since it increases compile-time too much right now - flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, - ) - # This pass is required such that a deref outside of a - # `tuple_get(make_tuple(let(...), ...))` call is propagated into the let after the - # `tuple_get` is removed by the `CollapseTuple` pass. - inlined = PropagateDeref.apply(inlined) - - if inlined == ir: - break - ir = inlined - else: - raise RuntimeError("Inlining 'lift' and 'lambdas' did not converge.") - - if lift_mode != LiftMode.FORCE_INLINE: - raise NotImplementedError() - - # Since `CollapseTuple` relies on the type inference which does not support returning tuples - # larger than the number of closure outputs as given by the unconditional collapse, we can - # only run the unconditional version here instead of in the loop above. - if unconditionally_collapse_tuples: - ir = CollapseTuple.apply( - ir, - ignore_tuple_size=True, - offset_provider_type=offset_provider_type, - # TODO(tehrengruber): disabled since it increases compile-time too much right now - flags=~CollapseTuple.Flag.PROPAGATE_TO_IF_ON_TUPLES, - ) - - if lift_mode == LiftMode.FORCE_INLINE: - ir = _inline_into_scan(ir) - - ir = NormalizeShifts().visit(ir) - - ir = FuseMaps().visit(ir) - ir = CollapseListGet().visit(ir) - - if unroll_reduce: - for _ in range(10): - unrolled = UnrollReduce.apply(ir, offset_provider_type=offset_provider_type) - if unrolled == ir: - break - ir = unrolled - ir = CollapseListGet().visit(ir) - ir = NormalizeShifts().visit(ir) - ir = _inline_lifts(ir, LiftMode.FORCE_INLINE) - ir = NormalizeShifts().visit(ir) - else: - raise RuntimeError("Reduction unrolling failed.") - - ir = EtaReduction().visit(ir) - ir = ScanEtaReduction().visit(ir) - - if common_subexpression_elimination: - ir = CommonSubexpressionElimination.apply(ir, offset_provider_type=offset_provider_type) # type: ignore[type-var] # always an itir.Program - ir = MergeLet().visit(ir) - - ir = InlineLambdas.apply( - ir, opcount_preserving=True, force_inline_lambda_args=force_inline_lambda_args - ) - - assert isinstance(ir, itir.Program) - return ir diff --git a/src/gt4py/next/iterator/transforms/program_to_fencil.py b/src/gt4py/next/iterator/transforms/program_to_fencil.py deleted file mode 100644 index 4411dda74f..0000000000 --- a/src/gt4py/next/iterator/transforms/program_to_fencil.py +++ /dev/null @@ -1,31 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm - - -def program_to_fencil(node: itir.Program) -> itir.FencilDefinition: - assert not node.declarations - closures = [] - for stmt in node.body: - assert isinstance(stmt, itir.SetAt) - assert isinstance(stmt.expr, itir.FunCall) and cpm.is_call_to(stmt.expr.fun, "as_fieldop") - stencil, domain = stmt.expr.fun.args - inputs = stmt.expr.args - assert all(isinstance(inp, itir.SymRef) for inp in inputs) - closures.append( - itir.StencilClosure(domain=domain, stencil=stencil, output=stmt.target, inputs=inputs) - ) - - return itir.FencilDefinition( - id=node.id, - function_definitions=node.function_definitions, - params=node.params, - closures=closures, - ) diff --git a/src/gt4py/next/iterator/transforms/prune_closure_inputs.py b/src/gt4py/next/iterator/transforms/prune_closure_inputs.py deleted file mode 100644 index 5058a91216..0000000000 --- a/src/gt4py/next/iterator/transforms/prune_closure_inputs.py +++ /dev/null @@ -1,44 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from gt4py.eve import NodeTranslator, PreserveLocationVisitor -from gt4py.next.iterator import ir - - -class PruneClosureInputs(PreserveLocationVisitor, NodeTranslator): - """Removes all unused input arguments from a stencil closure.""" - - def visit_StencilClosure(self, node: ir.StencilClosure) -> ir.StencilClosure: - if not isinstance(node.stencil, ir.Lambda): - return node - - unused: set[str] = {p.id for p in node.stencil.params} - expr = self.visit(node.stencil.expr, unused=unused, shadowed=set[str]()) - params = [] - inputs = [] - for param, inp in zip(node.stencil.params, node.inputs): - if param.id not in unused: - params.append(param) - inputs.append(inp) - - return ir.StencilClosure( - domain=node.domain, - stencil=ir.Lambda(params=params, expr=expr), - output=node.output, - inputs=inputs, - ) - - def visit_SymRef(self, node: ir.SymRef, *, unused: set[str], shadowed: set[str]) -> ir.SymRef: - if node.id not in shadowed: - unused.discard(node.id) - return node - - def visit_Lambda(self, node: ir.Lambda, *, unused: set[str], shadowed: set[str]) -> ir.Lambda: - return self.generic_visit( - node, unused=unused, shadowed=shadowed | {p.id for p in node.params} - ) diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index 08d896121d..fb909dc5d0 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -10,6 +10,7 @@ from gt4py.eve import NodeTranslator, PreserveLocationVisitor, SymbolTableTrait from gt4py.next.iterator import ir +from gt4py.next.iterator.type_system import inference as type_inference class RemapSymbolRefs(PreserveLocationVisitor, NodeTranslator): @@ -46,7 +47,9 @@ def visit_SymRef( self, node: ir.SymRef, *, name_map: Dict[str, str], active: Optional[Set[str]] = None ): if active and node.id in active: - return ir.SymRef(id=name_map.get(node.id, node.id)) + new_ref = ir.SymRef(id=name_map.get(node.id, node.id)) + type_inference.copy_type(from_=node, to=new_ref, allow_untyped=True) + return new_ref return node def generic_visit( # type: ignore[override] diff --git a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py index 05163a3630..2903201083 100644 --- a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py +++ b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py @@ -69,7 +69,7 @@ def apply( Counter({SymRef(id=SymbolRef('x')): 2, SymRef(id=SymbolRef('y')): 2, SymRef(id=SymbolRef('z')): 1}) """ if ignore_builtins: - inactive_refs = {str(n.id) for n in itir.FencilDefinition._NODE_SYMBOLS_} + inactive_refs = {str(n.id) for n in itir.Program._NODE_SYMBOLS_} else: inactive_refs = set() @@ -140,6 +140,4 @@ def collect_symbol_refs( def get_user_defined_symbols(symtable: dict[eve.SymbolName, itir.Sym]) -> set[str]: - return {str(sym) for sym in symtable.keys()} - { - str(n.id) for n in itir.FencilDefinition._NODE_SYMBOLS_ - } + return {str(sym) for sym in symtable.keys()} - {str(n.id) for n in itir.Program._NODE_SYMBOLS_} diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index bc437ba44c..9f7a14b0b8 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -95,14 +95,17 @@ def _set_node_type(node: itir.Node, type_: ts.TypeSpec) -> None: node.type = type_ -def copy_type(from_: itir.Node, to: itir.Node, allow_untyped=False) -> None: +def copy_type(from_: itir.Node, to: itir.Node, allow_untyped: bool = False) -> None: """ Copy type from one node to another. This function mainly exists for readability reasons. """ assert allow_untyped is not None or isinstance(from_.type, ts.TypeSpec) - _set_node_type(to, from_.type) # type: ignore[arg-type] + if from_.type is None: + assert allow_untyped + return + _set_node_type(to, from_.type) def on_inferred(callback: Callable, *args: Union[ts.TypeSpec, ObservableTypeSynthesizer]) -> None: @@ -362,7 +365,7 @@ def apply( Preconditions: - All parameters in :class:`itir.Program` and :class:`itir.FencilDefinition` must have a type + All parameters in :class:`itir.Program` must have a type defined, as they are the starting point for type propagation. Design decisions: @@ -411,9 +414,9 @@ def apply( # parts of a program. node = SanitizeTypes().visit(node) - if isinstance(node, (itir.FencilDefinition, itir.Program)): + if isinstance(node, itir.Program): assert all(isinstance(param.type, ts.DataType) for param in node.params), ( - "All parameters in 'itir.Program' and 'itir.FencilDefinition' must have a type " + "All parameters in 'itir.Program' must have a type " "defined, as they are the starting point for type propagation.", ) @@ -490,20 +493,6 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: ) return result - # TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere - def visit_FencilDefinition(self, node: itir.FencilDefinition, *, ctx) -> it_ts.FencilType: - params: dict[str, ts.DataType] = {} - for param in node.params: - assert isinstance(param.type, ts.DataType) - params[param.id] = param.type - - function_definitions: dict[str, type_synthesizer.TypeSynthesizer] = {} - for fun_def in node.function_definitions: - function_definitions[fun_def.id] = self.visit(fun_def, ctx=ctx | function_definitions) - - closures = self.visit(node.closures, ctx=ctx | params | function_definitions) - return it_ts.FencilType(params=params, closures=closures) - def visit_Program(self, node: itir.Program, *, ctx) -> it_ts.ProgramType: params: dict[str, ts.DataType] = {} for param in node.params: @@ -562,37 +551,6 @@ def visit_SetAt(self, node: itir.SetAt, *, ctx) -> None: and target_type.dtype == expr_type.dtype ) - # TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere - def visit_StencilClosure(self, node: itir.StencilClosure, *, ctx) -> it_ts.StencilClosureType: - domain: it_ts.DomainType = self.visit(node.domain, ctx=ctx) - inputs: list[ts.FieldType] = self.visit(node.inputs, ctx=ctx) - output: ts.FieldType = self.visit(node.output, ctx=ctx) - - assert isinstance(domain, it_ts.DomainType) - for output_el in type_info.primitive_constituents(output): - assert isinstance(output_el, ts.FieldType) - - stencil_type_synthesizer = self.visit(node.stencil, ctx=ctx) - stencil_args = [ - type_synthesizer._convert_as_fieldop_input_to_iterator(domain, input_) - for input_ in inputs - ] - stencil_returns = stencil_type_synthesizer( - *stencil_args, offset_provider_type=self.offset_provider_type - ) - - return it_ts.StencilClosureType( - domain=domain, - stencil=ts.FunctionType( - pos_only_args=stencil_args, - pos_or_kw_args={}, - kw_only_args={}, - returns=stencil_returns, - ), - output=output, - inputs=inputs, - ) - def visit_AxisLiteral(self, node: itir.AxisLiteral, **kwargs) -> ts.DimensionType: return ts.DimensionType(dim=common.Dimension(value=node.value, kind=node.kind)) diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index edb56f5659..eef8c75d0f 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -43,30 +43,6 @@ class IteratorType(ts.DataType, ts.CallableType): element_type: ts.DataType -@dataclasses.dataclass(frozen=True) -class StencilClosureType(ts.TypeSpec): - domain: DomainType - stencil: ts.FunctionType - output: ts.FieldType | ts.TupleType - inputs: list[ts.FieldType] - - def __post_init__(self): - # local import to avoid importing type_info from a type_specification module - from gt4py.next.type_system import type_info - - for i, el_type in enumerate(type_info.primitive_constituents(self.output)): - assert isinstance( - el_type, ts.FieldType - ), f"All constituent types must be field types, but the {i}-th element is of type '{el_type}'." - - -# TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere -@dataclasses.dataclass(frozen=True) -class FencilType(ts.TypeSpec): - params: dict[str, ts.DataType] - closures: list[StencilClosureType] - - @dataclasses.dataclass(frozen=True) class ProgramType(ts.TypeSpec): params: dict[str, ts.DataType] diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py index 85838d9c76..22326c7e87 100644 --- a/src/gt4py/next/otf/stages.py +++ b/src/gt4py/next/otf/stages.py @@ -26,9 +26,7 @@ SettingT_co = TypeVar("SettingT_co", bound=languages.LanguageSettings, covariant=True) -CompilableProgram: TypeAlias = toolchain.CompilableProgram[ - itir.FencilDefinition | itir.Program, arguments.CompileTimeArgs -] +CompilableProgram: TypeAlias = toolchain.CompilableProgram[itir.Program, arguments.CompileTimeArgs] @dataclasses.dataclass(frozen=True) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index f1649112a7..020b1f55ea 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -10,7 +10,7 @@ import dataclasses import functools -from typing import Any, Callable, Final, Optional +from typing import Any, Final, Optional import factory import numpy as np @@ -53,9 +53,6 @@ class GTFNTranslationStep( use_imperative_backend: bool = False device_type: core_defs.DeviceType = core_defs.DeviceType.CPU symbolic_domain_sizes: Optional[dict[str, str]] = None - temporary_extraction_heuristics: Optional[ - Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSettings: match self.device_type: @@ -80,7 +77,7 @@ def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSetting def _process_regular_arguments( self, - program: itir.FencilDefinition | itir.Program, + program: itir.Program, arg_types: tuple[ts.TypeSpec, ...], offset_provider_type: common.OffsetProviderType, ) -> tuple[list[interface.Parameter], list[str]]: @@ -157,7 +154,7 @@ def _process_connectivity_args( def _preprocess_program( self, - program: itir.FencilDefinition | itir.Program, + program: itir.Program, offset_provider: common.OffsetProvider, ) -> itir.Program: apply_common_transforms = functools.partial( @@ -167,7 +164,6 @@ def _preprocess_program( # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements unconditionally_collapse_tuples=True, symbolic_domain_sizes=self.symbolic_domain_sizes, - temporary_extraction_heuristics=self.temporary_extraction_heuristics, ) new_program = apply_common_transforms( @@ -186,7 +182,7 @@ def _preprocess_program( def generate_stencil_source( self, - program: itir.FencilDefinition | itir.Program, + program: itir.Program, offset_provider: common.OffsetProvider, column_axis: Optional[common.Dimension], ) -> str: @@ -214,7 +210,7 @@ def __call__( self, inp: stages.CompilableProgram ) -> stages.ProgramSource[languages.NanobindSrcL, languages.LanguageWithHeaderFilesSettings]: """Generate GTFN C++ code from the ITIR definition.""" - program: itir.FencilDefinition | itir.Program = inp.data + program: itir.Program = inp.data # handle regular parameters and arguments of the program (i.e. what the user defined in # the program) diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 129d81d6f9..d5b34fd5b9 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -108,7 +108,7 @@ def _get_gridtype(body: list[itir.Stmt]) -> common.GridType: grid_types = {_extract_grid_type(d) for d in domains} if len(grid_types) != 1: raise ValueError( - f"Found 'StencilClosures' with more than one 'GridType': '{grid_types}'. This is currently not supported." + f"Found 'set_at' with more than one 'GridType': '{grid_types}'. This is currently not supported." ) return grid_types.pop() @@ -198,6 +198,9 @@ def _collect_offset_definitions( "Mapping an offset to a horizontal dimension in unstructured is not allowed." ) # create alias from vertical offset to vertical dimension + offset_definitions[dim.value] = TagDefinition( + name=Sym(id=dim.value), alias=_vertical_dimension + ) offset_definitions[offset_name] = TagDefinition( name=Sym(id=offset_name), alias=SymRef(id=dim.value) ) diff --git a/src/gt4py/next/program_processors/formatters/gtfn.py b/src/gt4py/next/program_processors/formatters/gtfn.py index db1242e2a4..5f32eaa2bb 100644 --- a/src/gt4py/next/program_processors/formatters/gtfn.py +++ b/src/gt4py/next/program_processors/formatters/gtfn.py @@ -15,7 +15,7 @@ @program_formatter.program_formatter -def format_cpp(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: +def format_cpp(program: itir.Program, *args: Any, **kwargs: Any) -> str: # TODO(tehrengruber): This is a little ugly. Revisit. gtfn_translation = gtfn.GTFNBackendFactory().executor.translation assert isinstance(gtfn_translation, GTFNTranslationStep) diff --git a/src/gt4py/next/program_processors/formatters/lisp.py b/src/gt4py/next/program_processors/formatters/lisp.py deleted file mode 100644 index 0a8253595e..0000000000 --- a/src/gt4py/next/program_processors/formatters/lisp.py +++ /dev/null @@ -1,67 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from typing import Any - -from gt4py.eve.codegen import FormatTemplate as as_fmt, TemplatedGenerator -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import apply_common_transforms -from gt4py.next.program_processors import program_formatter - - -class ToLispLike(TemplatedGenerator): - Sym = as_fmt("{id}") - FunCall = as_fmt("({fun} {' '.join(args)})") - Literal = as_fmt("{value}") - OffsetLiteral = as_fmt("{value}") - SymRef = as_fmt("{id}") - StencilClosure = as_fmt( - """( - :domain {domain} - :stencil {stencil} - :output {output} - :inputs {' '.join(inputs)} - ) - """ - ) - FencilDefinition = as_fmt( - """ - ({' '.join(function_definitions)}) - (defen {id}({' '.join(params)}) - {''.join(closures)}) - """ - ) - FunctionDefinition = as_fmt( - """(defun {id}({' '.join(params)}) - {expr} - ) - -""" - ) - Lambda = as_fmt( - """(lambda ({' '.join(params)}) - {expr} - )""" - ) - - @classmethod - def apply(cls, root: itir.FencilDefinition, **kwargs: Any) -> str: # type: ignore[override] - transformed = apply_common_transforms(root, offset_provider=kwargs["offset_provider"]) - generated_code = super().apply(transformed, **kwargs) - try: - from yasi import indent_code - - indented = indent_code(generated_code, "--dialect lisp") - return "".join(indented["indented_code"]) - except ImportError: - return generated_code - - -@program_formatter.program_formatter -def format_lisp(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: - return ToLispLike.apply(program, **kwargs) diff --git a/src/gt4py/next/program_processors/formatters/pretty_print.py b/src/gt4py/next/program_processors/formatters/pretty_print.py index f14ac5653f..cbf9fd1978 100644 --- a/src/gt4py/next/program_processors/formatters/pretty_print.py +++ b/src/gt4py/next/program_processors/formatters/pretty_print.py @@ -15,7 +15,7 @@ @program_formatter.program_formatter -def format_itir_and_check(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: +def format_itir_and_check(program: itir.Program, *args: Any, **kwargs: Any) -> str: pretty = pretty_printer.pformat(program) parsed = pretty_parser.pparse(pretty) assert parsed == program diff --git a/src/gt4py/next/program_processors/program_formatter.py b/src/gt4py/next/program_processors/program_formatter.py index f77e7f32ee..321c09668c 100644 --- a/src/gt4py/next/program_processors/program_formatter.py +++ b/src/gt4py/next/program_processors/program_formatter.py @@ -10,7 +10,7 @@ Interface for program processors. Program processors are functions which operate on a program paired with the input -arguments for the program. Programs are represented by an ``iterator.ir.itir.FencilDefinition`` +arguments for the program. Programs are represented by an ``iterator.ir.Program`` node. Program processors that execute the program with the given arguments (possibly by generating code along the way) are program executors. Those that generate any kind of string based on the program and (optionally) input values are program formatters. @@ -30,14 +30,14 @@ class ProgramFormatter: @abc.abstractmethod - def __call__(self, program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: ... + def __call__(self, program: itir.Program, *args: Any, **kwargs: Any) -> str: ... @dataclasses.dataclass(frozen=True) class WrappedProgramFormatter(ProgramFormatter): formatter: Callable[..., str] - def __call__(self, program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: + def __call__(self, program: itir.Program, *args: Any, **kwargs: Any) -> str: return self.formatter(program, *args, **kwargs) @@ -47,7 +47,7 @@ def program_formatter(func: Callable[..., str]) -> ProgramFormatter: Examples: >>> @program_formatter - ... def format_foo(fencil: itir.FencilDefinition, *args, **kwargs) -> str: + ... def format_foo(fencil: itir.Program, *args, **kwargs) -> str: ... '''A very useless fencil formatter.''' ... return "foo" diff --git a/src/gt4py/next/program_processors/runners/dace.py b/src/gt4py/next/program_processors/runners/dace.py index 95186e0b5d..1b3b930818 100644 --- a/src/gt4py/next/program_processors/runners/dace.py +++ b/src/gt4py/next/program_processors/runners/dace.py @@ -8,45 +8,34 @@ import factory +import gt4py._core.definitions as core_defs +import gt4py.next.allocators as next_allocators from gt4py.next import backend +from gt4py.next.otf import workflow from gt4py.next.program_processors.runners.dace_fieldview import workflow as dace_fieldview_workflow -from gt4py.next.program_processors.runners.dace_iterator import workflow as dace_iterator_workflow from gt4py.next.program_processors.runners.gtfn import GTFNBackendFactory -class DaCeIteratorBackendFactory(GTFNBackendFactory): +class DaCeFieldviewBackendFactory(GTFNBackendFactory): + class Meta: + model = backend.Backend + class Params: - otf_workflow = factory.SubFactory( - dace_iterator_workflow.DaCeWorkflowFactory, - device_type=factory.SelfAttribute("..device_type"), - use_field_canonical_representation=factory.SelfAttribute( - "..use_field_canonical_representation" - ), + name_device = "cpu" + name_cached = "" + name_postfix = "" + gpu = factory.Trait( + allocator=next_allocators.StandardGPUFieldBufferAllocator(), + device_type=next_allocators.CUPY_DEVICE or core_defs.DeviceType.CUDA, + name_device="gpu", ) - auto_optimize = factory.Trait( - otf_workflow__translation__auto_optimize=True, name_postfix="_opt" + cached = factory.Trait( + executor=factory.LazyAttribute( + lambda o: workflow.CachedStep(o.otf_workflow, hash_function=o.hash_function) + ), + name_cached="_cached", ) - use_field_canonical_representation: bool = False - - name = factory.LazyAttribute( - lambda o: f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}.itir" - ) - - transforms = backend.LEGACY_TRANSFORMS - - -run_dace_cpu = DaCeIteratorBackendFactory(cached=True, auto_optimize=True) -run_dace_cpu_noopt = DaCeIteratorBackendFactory(cached=True, auto_optimize=False) - -run_dace_gpu = DaCeIteratorBackendFactory(gpu=True, cached=True, auto_optimize=True) -run_dace_gpu_noopt = DaCeIteratorBackendFactory(gpu=True, cached=True, auto_optimize=False) - -itir_cpu = run_dace_cpu -itir_gpu = run_dace_gpu - - -class DaCeFieldviewBackendFactory(GTFNBackendFactory): - class Params: + device_type = core_defs.DeviceType.CPU otf_workflow = factory.SubFactory( dace_fieldview_workflow.DaCeWorkflowFactory, device_type=factory.SelfAttribute("..device_type"), @@ -55,11 +44,16 @@ class Params: auto_optimize = factory.Trait(name_postfix="_opt") name = factory.LazyAttribute( - lambda o: f"run_dace_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}.gtir" + lambda o: f"run_dace_{o.name_device}{o.name_cached}{o.name_postfix}" ) + executor = factory.LazyAttribute(lambda o: o.otf_workflow) + allocator = next_allocators.StandardCPUFieldBufferAllocator() transforms = backend.DEFAULT_TRANSFORMS -gtir_cpu = DaCeFieldviewBackendFactory(cached=True, auto_optimize=False) -gtir_gpu = DaCeFieldviewBackendFactory(gpu=True, cached=True, auto_optimize=False) +run_dace_cpu = DaCeFieldviewBackendFactory(cached=True, auto_optimize=True) +run_dace_cpu_noopt = DaCeFieldviewBackendFactory(cached=True, auto_optimize=False) + +run_dace_gpu = DaCeFieldviewBackendFactory(gpu=True, cached=True, auto_optimize=True) +run_dace_gpu_noopt = DaCeFieldviewBackendFactory(gpu=True, cached=True, auto_optimize=False) diff --git a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py index 56ba08015b..90e7e07ad5 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py +++ b/src/gt4py/next/program_processors/runners/dace_common/dace_backend.py @@ -24,7 +24,7 @@ cp = None -def _convert_arg(arg: Any, sdfg_param: str, use_field_canonical_representation: bool) -> Any: +def _convert_arg(arg: Any, sdfg_param: str) -> Any: if not isinstance(arg, gtx_common.Field): return arg if len(arg.domain.dims) == 0: @@ -41,26 +41,14 @@ def _convert_arg(arg: Any, sdfg_param: str, use_field_canonical_representation: raise RuntimeError( f"Field '{sdfg_param}' passed as array slice with offset {dim_range.start} on dimension {dim.value}." ) - if not use_field_canonical_representation: - return arg.ndarray - # the canonical representation requires alphabetical ordering of the dimensions in field domain definition - sorted_dims = dace_utils.get_sorted_dims(arg.domain.dims) - ndim = len(sorted_dims) - dim_indices = [dim_index for dim_index, _ in sorted_dims] - if isinstance(arg.ndarray, np.ndarray): - return np.moveaxis(arg.ndarray, range(ndim), dim_indices) - else: - assert cp is not None and isinstance(arg.ndarray, cp.ndarray) - return cp.moveaxis(arg.ndarray, range(ndim), dim_indices) - - -def _get_args( - sdfg: dace.SDFG, args: Sequence[Any], use_field_canonical_representation: bool -) -> dict[str, Any]: + return arg.ndarray + + +def _get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]: sdfg_params: Sequence[str] = sdfg.arg_names flat_args: Iterable[Any] = gtx_utils.flatten_nested_tuple(tuple(args)) return { - sdfg_param: _convert_arg(arg, sdfg_param, use_field_canonical_representation) + sdfg_param: _convert_arg(arg, sdfg_param) for sdfg_param, arg in zip(sdfg_params, flat_args, strict=True) } @@ -154,10 +142,10 @@ def get_sdfg_conn_args( def get_sdfg_args( sdfg: dace.SDFG, + offset_provider: gtx_common.OffsetProvider, *args: Any, check_args: bool = False, on_gpu: bool = False, - use_field_canonical_representation: bool = True, **kwargs: Any, ) -> dict[str, Any]: """Extracts the arguments needed to call the SDFG. @@ -166,10 +154,10 @@ def get_sdfg_args( Args: sdfg: The SDFG for which we want to get the arguments. + offset_provider: Offset provider. """ - offset_provider = kwargs["offset_provider"] - dace_args = _get_args(sdfg, args, use_field_canonical_representation) + dace_args = _get_args(sdfg, args) dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} dace_conn_args = get_sdfg_conn_args(sdfg, offset_provider, on_gpu) dace_shapes = _get_shape_args(sdfg.arrays, dace_field_args) diff --git a/src/gt4py/next/program_processors/runners/dace_common/utility.py b/src/gt4py/next/program_processors/runners/dace_common/utility.py index 3e96ef3cec..ac15bc1cbf 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_common/utility.py @@ -9,7 +9,7 @@ from __future__ import annotations import re -from typing import Final, Literal, Optional, Sequence +from typing import Final, Literal, Optional import dace @@ -96,10 +96,3 @@ def filter_connectivity_types( for offset, conn in offset_provider_type.items() if isinstance(conn, gtx_common.NeighborConnectivityType) } - - -def get_sorted_dims( - dims: Sequence[gtx_common.Dimension], -) -> Sequence[tuple[int, gtx_common.Dimension]]: - """Sort list of dimensions in alphabetical order.""" - return sorted(enumerate(dims), key=lambda v: v[1].value) diff --git a/src/gt4py/next/program_processors/runners/dace_common/workflow.py b/src/gt4py/next/program_processors/runners/dace_common/workflow.py index 91e83dba9d..5d9ac863c5 100644 --- a/src/gt4py/next/program_processors/runners/dace_common/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_common/workflow.py @@ -150,9 +150,9 @@ def decorated_program( sdfg_args = dace_backend.get_sdfg_args( sdfg, + offset_provider, *args, check_args=False, - offset_provider=offset_provider, on_gpu=on_gpu, use_field_canonical_representation=use_field_canonical_representation, ) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 94ab3a6f76..ff011c4193 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -18,7 +18,11 @@ from gt4py.next import common as gtx_common, utils as gtx_utils from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.iterator import ir as gtir -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, domain_utils +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( @@ -229,40 +233,75 @@ def _get_field_layout( return list(domain_dims), list(domain_lbs), domain_sizes -def _create_temporary_field( +def _create_field_operator( sdfg: dace.SDFG, state: dace.SDFGState, domain: FieldopDomain, node_type: ts.FieldType, - dataflow_output: gtir_dataflow.DataflowOutputEdge, + sdfg_builder: gtir_sdfg.SDFGBuilder, + input_edges: Sequence[gtir_dataflow.DataflowInputEdge], + output_edge: gtir_dataflow.DataflowOutputEdge, ) -> FieldopData: - """Helper method to allocate a temporary field where to write the output of a field operator.""" + """ + Helper method to allocate a temporary field to store the output of a field operator. + + Args: + sdfg: The SDFG that represents the scope of the field data. + state: The SDFG state where to create an access node to the field data. + domain: The domain of the field operator that computes the field. + node_type: The GT4Py type of the IR node that produces this field. + sdfg_builder: The object used to build the map scope in the provided SDFG. + input_edges: List of edges to pass input data into the dataflow. + output_edge: Edge representing the dataflow output data. + + Returns: + The field data descriptor, which includes the field access node in the given `state` + and the field domain offset. + """ field_dims, field_offset, field_shape = _get_field_layout(domain) + field_indices = _get_domain_indices(field_dims, field_offset) + + dataflow_output_desc = output_edge.result.dc_node.desc(sdfg) - output_desc = dataflow_output.result.dc_node.desc(sdfg) - if isinstance(output_desc, dace.data.Array): + field_subset = sbs.Range.from_indices(field_indices) + if isinstance(output_edge.result.gt_dtype, ts.ScalarType): + assert output_edge.result.gt_dtype == node_type.dtype + assert isinstance(dataflow_output_desc, dace.data.Scalar) + assert dataflow_output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) + field_dtype = output_edge.result.gt_dtype + else: assert isinstance(node_type.dtype, itir_ts.ListType) - assert isinstance(node_type.dtype.element_type, ts.ScalarType) - assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype.element_type) + assert output_edge.result.gt_dtype.element_type == node_type.dtype.element_type + assert isinstance(dataflow_output_desc, dace.data.Array) + assert isinstance(output_edge.result.gt_dtype.element_type, ts.ScalarType) + field_dtype = output_edge.result.gt_dtype.element_type # extend the array with the local dimensions added by the field operator (e.g. `neighbors`) - field_offset.extend(output_desc.offset) - field_shape.extend(output_desc.shape) - elif isinstance(output_desc, dace.data.Scalar): - assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) - else: - raise ValueError(f"Cannot create field for dace type {output_desc}.") + assert output_edge.result.gt_dtype.offset_type is not None + field_dims.append(output_edge.result.gt_dtype.offset_type) + field_shape.extend(dataflow_output_desc.shape) + field_offset.extend(dataflow_output_desc.offset) + field_subset = field_subset + sbs.Range.from_array(dataflow_output_desc) # allocate local temporary storage - temp_name, _ = sdfg.add_temp_transient(field_shape, output_desc.dtype) - field_node = state.add_access(temp_name) + field_name, _ = sdfg.add_temp_transient(field_shape, dataflow_output_desc.dtype) + field_node = state.add_access(field_name) - if isinstance(dataflow_output.result.gt_dtype, ts.ScalarType): - field_dtype = dataflow_output.result.gt_dtype - else: - assert isinstance(dataflow_output.result.gt_dtype.element_type, ts.ScalarType) - field_dtype = dataflow_output.result.gt_dtype.element_type - assert dataflow_output.result.gt_dtype.offset_type is not None - field_dims.append(dataflow_output.result.gt_dtype.offset_type) + # create map range corresponding to the field operator domain + me, mx = sdfg_builder.add_map( + "fieldop", + state, + ndrange={ + dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" + for dim, lower_bound, upper_bound in domain + }, + ) + + # here we setup the edges passing through the map entry node + for edge in input_edges: + edge.connect(me) + + # and here the edge writing the dataflow result data through the map exit node + output_edge.connect(mx, field_node, field_subset) return FieldopData( field_node, @@ -341,7 +380,8 @@ def translate_as_fieldop( # Special usage of 'deref' as argument to fieldop expression, to pass a scalar # value to 'as_fieldop' function. It results in broadcasting the scalar value # over the field domain. - return translate_broadcast_scalar(node, sdfg, state, sdfg_builder) + stencil_expr = im.lambda_("a")(im.deref("a")) + stencil_expr.expr.type = node.type.dtype # type: ignore[attr-defined] else: raise NotImplementedError( f"Expression type '{type(stencil_expr)}' not supported as argument to 'as_fieldop' node." @@ -349,117 +389,18 @@ def translate_as_fieldop( # parse the domain of the field operator domain = extract_domain(domain_expr) - domain_dims, domain_offsets, _ = zip(*domain) - domain_indices = _get_domain_indices(domain_dims, domain_offsets) # visit the list of arguments to be passed to the lambda expression stencil_args = [_parse_fieldop_arg(arg, sdfg, state, sdfg_builder, domain) for arg in node.args] # represent the field operator as a mapped tasklet graph, which will range over the field domain taskgen = gtir_dataflow.LambdaToDataflow(sdfg, state, sdfg_builder) - input_edges, output = taskgen.visit(stencil_expr, args=stencil_args) - output_desc = output.result.dc_node.desc(sdfg) - - if isinstance(node.type.dtype, itir_ts.ListType): - assert isinstance(output_desc, dace.data.Array) - # additional local dimension for neighbors - # TODO(phimuell): Investigate if we should swap the two. - output_subset = sbs.Range.from_indices(domain_indices) + sbs.Range.from_array(output_desc) - else: - assert isinstance(output_desc, dace.data.Scalar) - output_subset = sbs.Range.from_indices(domain_indices) - - # create map range corresponding to the field operator domain - me, mx = sdfg_builder.add_map( - "fieldop", - state, - ndrange={ - dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" - for dim, lower_bound, upper_bound in domain - }, - ) - - # allocate local temporary storage for the result field - result_field = _create_temporary_field(sdfg, state, domain, node.type, output) - - # here we setup the edges from the map entry node - for edge in input_edges: - edge.connect(me) - - # and here the edge writing the result data through the map exit node - output.connect(mx, result_field.dc_node, output_subset) - - return result_field - + input_edges, output_edge = taskgen.visit(stencil_expr, args=stencil_args) -def translate_broadcast_scalar( - node: gtir.Node, - sdfg: dace.SDFG, - state: dace.SDFGState, - sdfg_builder: gtir_sdfg.SDFGBuilder, -) -> FieldopResult: - """ - Generates the dataflow subgraph for the 'as_fieldop' builtin function for the - special case where the argument to 'as_fieldop' is a 'deref' scalar expression, - rather than a lambda function. This case corresponds to broadcasting the scalar - value over the field domain. Therefore, it is lowered to a mapped tasklet that - just writes the scalar value out to all elements of the result field. - """ - assert isinstance(node, gtir.FunCall) - assert cpm.is_call_to(node.fun, "as_fieldop") - assert isinstance(node.type, ts.FieldType) - - fun_node = node.fun - assert len(fun_node.args) == 2 - stencil_expr, domain_expr = fun_node.args - assert cpm.is_ref_to(stencil_expr, "deref") - - domain = extract_domain(domain_expr) - output_dims, output_offset, output_shape = _get_field_layout(domain) - output_subset = sbs.Range.from_indices(_get_domain_indices(output_dims, output_offset)) - - assert len(node.args) == 1 - scalar_expr = _parse_fieldop_arg(node.args[0], sdfg, state, sdfg_builder, domain) - - if isinstance(node.args[0].type, ts.ScalarType): - assert isinstance(scalar_expr, (gtir_dataflow.MemletExpr, gtir_dataflow.ValueExpr)) - input_subset = ( - str(scalar_expr.subset) if isinstance(scalar_expr, gtir_dataflow.MemletExpr) else "0" - ) - input_node = scalar_expr.dc_node - gt_dtype = node.args[0].type - elif isinstance(node.args[0].type, ts.FieldType): - assert isinstance(scalar_expr, gtir_dataflow.IteratorExpr) - if len(node.args[0].type.dims) == 0: # zero-dimensional field - input_subset = "0" - else: - input_subset = scalar_expr.get_memlet_subset(sdfg) - - input_node = scalar_expr.field - gt_dtype = node.args[0].type.dtype - else: - raise ValueError(f"Unexpected argument {node.args[0]} in broadcast expression.") - - output, _ = sdfg.add_temp_transient(output_shape, input_node.desc(sdfg).dtype) - output_node = state.add_access(output) - - sdfg_builder.add_mapped_tasklet( - "broadcast", - state, - map_ranges={ - dace_gtir_utils.get_map_variable(dim): f"{lower_bound}:{upper_bound}" - for dim, lower_bound, upper_bound in domain - }, - inputs={"__inp": dace.Memlet(data=input_node.data, subset=input_subset)}, - code="__val = __inp", - outputs={"__val": dace.Memlet(data=output_node.data, subset=output_subset)}, - input_nodes={input_node.data: input_node}, - output_nodes={output_node.data: output_node}, - external_edges=True, + return _create_field_operator( + sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge ) - return FieldopData(output_node, ts.FieldType(output_dims, gt_dtype), output_offset) - def translate_if( node: gtir.Node, @@ -567,38 +508,44 @@ def translate_index( index values to a transient array. The extent of the index range is taken from the domain information that should be present in the node annex. """ + assert cpm.is_call_to(node, "index") + assert isinstance(node.type, ts.FieldType) + assert "domain" in node.annex domain = extract_domain(node.annex.domain) assert len(domain) == 1 - dim, lower_bound, upper_bound = domain[0] + dim, _, _ = domain[0] dim_index = dace_gtir_utils.get_map_variable(dim) - field_dims, field_offset, field_shape = _get_field_layout(domain) - field_type = ts.FieldType(field_dims, dace_utils.as_itir_type(INDEX_DTYPE)) - - output, _ = sdfg.add_temp_transient(field_shape, INDEX_DTYPE) - output_node = state.add_access(output) - - sdfg_builder.add_mapped_tasklet( + index_data = sdfg.temp_data_name() + sdfg.add_scalar(index_data, INDEX_DTYPE, transient=True) + index_node = state.add_access(index_data) + index_value = gtir_dataflow.ValueExpr( + dc_node=index_node, + gt_dtype=dace_utils.as_itir_type(INDEX_DTYPE), + ) + index_write_tasklet = sdfg_builder.add_tasklet( "index", state, - map_ranges={ - dim_index: f"{lower_bound}:{upper_bound}", - }, inputs={}, + outputs={"__val"}, code=f"__val = {dim_index}", - outputs={ - "__val": dace.Memlet( - data=output_node.data, - subset=sbs.Range.from_indices(_get_domain_indices(field_dims, field_offset)), - ) - }, - input_nodes={}, - output_nodes={output_node.data: output_node}, - external_edges=True, + ) + state.add_edge( + index_write_tasklet, + "__val", + index_node, + None, + dace.Memlet(data=index_data, subset="0"), ) - return FieldopData(output_node, field_type, field_offset) + input_edges = [ + gtir_dataflow.EmptyInputEdge(state, index_write_tasklet), + ] + output_edge = gtir_dataflow.DataflowOutputEdge(state, index_value) + return _create_field_operator( + sdfg, state, domain, node.type, sdfg_builder, input_edges, output_edge + ) def _get_data_nodes( @@ -831,7 +778,6 @@ def translate_symbol_ref( # Use type-checking to assert that all translator functions implement the `PrimitiveTranslator` protocol __primitive_translators: list[PrimitiveTranslator] = [ translate_as_fieldop, - translate_broadcast_scalar, translate_if, translate_index, translate_literal, diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py index 6aee33c56e..4bdb602f5f 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_python_codegen.py @@ -118,29 +118,41 @@ class PythonCodegen(codegen.TemplatedGenerator): as in the case of field domain definitions, for sybolic array shape and map range. """ - SymRef = as_fmt("{id}") Literal = as_fmt("{value}") - def _visit_deref(self, node: gtir.FunCall) -> str: - assert len(node.args) == 1 - if isinstance(node.args[0], gtir.SymRef): - return self.visit(node.args[0]) - raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") - - def visit_FunCall(self, node: gtir.FunCall) -> str: - if cpm.is_call_to(node, "deref"): - return self._visit_deref(node) + def visit_FunCall(self, node: gtir.FunCall, args_map: dict[str, gtir.Node]) -> str: + if isinstance(node.fun, gtir.Lambda): + # update the mapping from lambda parameters to corresponding argument expressions + lambda_args_map = args_map | { + p.id: arg for p, arg in zip(node.fun.params, node.args, strict=True) + } + return self.visit(node.fun.expr, args_map=lambda_args_map) + elif cpm.is_call_to(node, "deref"): + assert len(node.args) == 1 + if not isinstance(node.args[0], gtir.SymRef): + # shift expressions are not expected in this visitor context + raise NotImplementedError(f"Unexpected deref with arg type '{type(node.args[0])}'.") + return self.visit(node.args[0], args_map=args_map) elif isinstance(node.fun, gtir.SymRef): - args = self.visit(node.args) + args = self.visit(node.args, args_map=args_map) builtin_name = str(node.fun.id) return format_builtin(builtin_name, *args) raise NotImplementedError(f"Unexpected 'FunCall' node ({node}).") + def visit_SymRef(self, node: gtir.SymRef, args_map: dict[str, gtir.Node]) -> str: + symbol = str(node.id) + if symbol in args_map: + return self.visit(args_map[symbol], args_map=args_map) + return symbol + -get_source = PythonCodegen.apply -""" -Specialized visit method for symbolic expressions. +def get_source(node: gtir.Node) -> str: + """ + Specialized visit method for symbolic expressions. -Returns: - A string containing the Python code corresponding to a symbolic expression -""" + The visitor uses `args_map` to map lambda parameters to the corresponding argument expressions. + + Returns: + A string containing the Python code corresponding to a symbolic expression + """ + return PythonCodegen.apply(node, args_map={}) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 8852dd6d2d..2232bcef01 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -12,32 +12,56 @@ that explains the general structure and requirements on the SDFGs. """ -from .auto_opt import ( +from .auto_optimize import gt_auto_optimize +from .gpu_utils import ( + GPUSetBlockSize, + gt_gpu_transform_non_standard_memlet, + gt_gpu_transformation, + gt_set_gpu_blocksize, +) +from .local_double_buffering import gt_create_local_double_buffering +from .loop_blocking import LoopBlocking +from .map_fusion_parallel import MapFusionParallel +from .map_fusion_serial import MapFusionSerial +from .map_orderer import MapIterationOrder, gt_set_iteration_order +from .map_promoter import SerialMapPromoter +from .simplify import ( GT_SIMPLIFY_DEFAULT_SKIP_SET, - gt_auto_optimize, + GT4PyGlobalSelfCopyElimination, + GT4PyMapBufferElimination, + GT4PyMoveTaskletIntoMap, gt_inline_nested_sdfg, - gt_set_iteration_order, + gt_reduce_distributed_buffering, gt_simplify, + gt_substitute_compiletime_symbols, ) -from .gpu_utils import GPUSetBlockSize, gt_gpu_transformation, gt_set_gpu_blocksize -from .loop_blocking import LoopBlocking -from .map_orderer import MapIterationOrder -from .map_promoter import SerialMapPromoter -from .map_serial_fusion import SerialMapFusion +from .strides import gt_change_transient_strides +from .util import gt_find_constant_arguments, gt_make_transients_persistent __all__ = [ "GT_SIMPLIFY_DEFAULT_SKIP_SET", "GPUSetBlockSize", + "GT4PyGlobalSelfCopyElimination", + "GT4PyMoveTaskletIntoMap", + "GT4PyMapBufferElimination", "LoopBlocking", "MapIterationOrder", - "SerialMapFusion", + "MapFusionParallel", + "MapFusionSerial", "SerialMapPromoter", "SerialMapPromoterGPU", "gt_auto_optimize", + "gt_change_transient_strides", + "gt_create_local_double_buffering", "gt_gpu_transformation", "gt_inline_nested_sdfg", "gt_set_iteration_order", "gt_set_gpu_blocksize", "gt_simplify", + "gt_make_transients_persistent", + "gt_reduce_distributed_buffering", + "gt_find_constant_arguments", + "gt_substitute_compiletime_symbols", + "gt_gpu_transform_non_standard_memlet", ] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py similarity index 67% rename from src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py rename to src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py index e070cdfe4e..4a06d2f416 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_optimize.py @@ -8,10 +8,10 @@ """Fast access to the auto optimization on DaCe.""" -from typing import Any, Final, Iterable, Optional, Sequence +from typing import Any, Optional, Sequence, Union import dace -from dace.transformation import dataflow as dace_dataflow, passes as dace_passes +from dace.transformation import dataflow as dace_dataflow from dace.transformation.auto import auto_optimize as dace_aoptimize from gt4py.next import common as gtx_common @@ -20,155 +20,24 @@ ) -GT_SIMPLIFY_DEFAULT_SKIP_SET: Final[set[str]] = {"ScalarToSymbolPromotion", "ConstantPropagation"} -"""Set of simplify passes `gt_simplify()` skips by default. - -The following passes are included: -- `ScalarToSymbolPromotion`: The lowering has sometimes to turn a scalar into a - symbol or vice versa and at a later point to invert this again. However, this - pass has some problems with this pattern so for the time being it is disabled. -- `ConstantPropagation`: Same reasons as `ScalarToSymbolPromotion`. -""" - - -def gt_simplify( - sdfg: dace.SDFG, - validate: bool = True, - validate_all: bool = False, - skip: Optional[Iterable[str]] = None, -) -> Any: - """Performs simplifications on the SDFG in place. - - Instead of calling `sdfg.simplify()` directly, you should use this function, - as it is specially tuned for GridTool based SDFGs. - - This function runs the DaCe simplification pass, but the following passes are - replaced: - - `InlineSDFGs`: Instead `gt_inline_nested_sdfg()` will be called. - - Furthermore, by default, or if `None` is passed fro `skip` the passes listed in - `GT_SIMPLIFY_DEFAULT_SKIP_SET` will be skipped. - - Args: - sdfg: The SDFG to optimize. - validate: Perform validation after the pass has run. - validate_all: Perform extensive validation. - skip: List of simplify passes that should not be applied, defaults - to `GT_SIMPLIFY_DEFAULT_SKIP_SET`. - """ - # Ensure that `skip` is a `set` - skip = GT_SIMPLIFY_DEFAULT_SKIP_SET if skip is None else set(skip) - - if "InlineSDFGs" not in skip: - gt_inline_nested_sdfg( - sdfg=sdfg, - multistate=True, - permissive=False, - validate=validate, - validate_all=validate_all, - ) - - return dace_passes.SimplifyPass( - validate=validate, - validate_all=validate_all, - verbose=False, - skip=(skip | {"InlineSDFGs"}), - ).apply_pass(sdfg, {}) - - -def gt_set_iteration_order( - sdfg: dace.SDFG, - leading_dim: gtx_common.Dimension, - validate: bool = True, - validate_all: bool = False, -) -> Any: - """Set the iteration order of the Maps correctly. - - Modifies the order of the Map parameters such that `leading_dim` - is the fastest varying one, the order of the other dimensions in - a Map is unspecific. `leading_dim` should be the dimensions were - the stride is one. - - Args: - sdfg: The SDFG to process. - leading_dim: The leading dimensions. - validate: Perform validation during the steps. - validate_all: Perform extensive validation. - """ - return sdfg.apply_transformations_once_everywhere( - gtx_transformations.MapIterationOrder( - leading_dim=leading_dim, - ), - validate=validate, - validate_all=validate_all, - ) - - -def gt_inline_nested_sdfg( - sdfg: dace.SDFG, - multistate: bool = True, - permissive: bool = False, - validate: bool = True, - validate_all: bool = False, -) -> dace.SDFG: - """Perform inlining of nested SDFG into their parent SDFG. - - The function uses DaCe's `InlineSDFG` transformation, the same used in simplify. - However, before the inline transformation is run the function will run some - cleaning passes that allows inlining nested SDFGs. - As a side effect, the function will split stages into more states. - - Args: - sdfg: The SDFG that should be processed, will be modified in place and returned. - multistate: Allow inlining of multistate nested SDFG, defaults to `True`. - permissive: Be less strict on the accepted SDFGs. - validate: Perform validation after the transformation has finished. - validate_all: Performs extensive validation. - """ - first_iteration = True - i = 0 - while True: - print(f"ITERATION: {i}") - nb_preproccess = sdfg.apply_transformations_repeated( - [dace_dataflow.PruneSymbols, dace_dataflow.PruneConnectors], - validate=False, - validate_all=validate_all, - ) - if (nb_preproccess == 0) and (not first_iteration): - break - - # Create and configure the inline pass - inline_sdfg = dace_passes.InlineSDFGs() - inline_sdfg.progress = False - inline_sdfg.permissive = permissive - inline_sdfg.multistate = multistate - - # Apply the inline pass - nb_inlines = inline_sdfg.apply_pass(sdfg, {}) - - # Check result, if needed and test if we can stop - if validate_all or validate: - sdfg.validate() - if nb_inlines == 0: - break - first_iteration = False - - return sdfg - - def gt_auto_optimize( sdfg: dace.SDFG, gpu: bool, - leading_dim: Optional[gtx_common.Dimension] = None, + leading_dim: Optional[ + Union[str, gtx_common.Dimension, list[Union[str, gtx_common.Dimension]]] + ] = None, aggressive_fusion: bool = True, max_optimization_rounds_p2: int = 100, make_persistent: bool = True, gpu_block_size: Optional[Sequence[int | str] | str] = None, blocking_dim: Optional[gtx_common.Dimension] = None, blocking_size: int = 10, + blocking_only_if_independent_nodes: Optional[bool] = None, reuse_transients: bool = False, gpu_launch_bounds: Optional[int | str] = None, gpu_launch_factor: Optional[int] = None, + constant_symbols: Optional[dict[str, Any]] = None, + assume_pointwise: bool = True, validate: bool = True, validate_all: bool = False, **kwargs: Any, @@ -184,6 +53,9 @@ def gt_auto_optimize( different aspects of the SDFG. The initial SDFG is assumed to have a very large number of rather simple Maps. + Note, because of how `gt_auto_optimizer()` works it is not save to call + it twice on the same SDFG. + 1. Some general simplification transformations, beyond classical simplify, are applied to the SDFG. 2. Tries to create larger kernels by fusing smaller ones, see @@ -219,24 +91,37 @@ def gt_auto_optimize( one for all. blocking_dim: On which dimension blocking should be applied. blocking_size: How many elements each block should process. + blocking_only_if_independent_nodes: If `True` only apply loop blocking if + there are independent nodes in the Map, see the `require_independent_nodes` + option of the `LoopBlocking` transformation. reuse_transients: Run the `TransientReuse` transformation, might reduce memory footprint. gpu_launch_bounds: Use this value as `__launch_bounds__` for _all_ GPU Maps. gpu_launch_factor: Use the number of threads times this value as `__launch_bounds__` for _all_ GPU Maps. + constant_symbols: Symbols listed in this `dict` will be replaced by the + respective value inside the SDFG. This might increase performance. + assume_pointwise: Assume that the SDFG has no risk for race condition in + global data access. See the `GT4PyMapBufferElimination` transformation for more. validate: Perform validation during the steps. validate_all: Perform extensive validation. + Note: + For identifying symbols that can be treated as compile time constants + `gt_find_constant_arguments()` function can be used. + Todo: - - Make sure that `SDFG.simplify()` is not called indirectly, by temporarily - overwriting it with `gt_simplify()`. + - Update the description. The Phases are nice, but they have lost their + link to reality a little bit. + - Improve the determination of the strides and iteration order of the + transients. + - Set padding of transients, i.e. alignment, the DaCe datadescriptor + can do that. + - Handle nested SDFGs better. - Specify arguments to set the size of GPU thread blocks depending on the dimensions. I.e. be able to use a different size for 1D than 2D Maps. - - Add a parallel version of Map fusion. - Implement some model to further guide to determine what we want to fuse. Something along the line "Fuse if operational intensity goes up, but not if we have too much internal space (register pressure). - - Create a custom array elimination pass that honors rule 1. - - Check if a pipeline could be used to speed up some computations. """ device = dace.DeviceType.GPU if gpu else dace.DeviceType.CPU @@ -249,20 +134,25 @@ def gt_auto_optimize( # to internal serial maps, such that they do not block fusion? # Phase 1: Initial Cleanup - gt_simplify( + gtx_transformations.gt_simplify( sdfg=sdfg, validate=validate, validate_all=validate_all, ) + gtx_transformations.gt_reduce_distributed_buffering(sdfg) + + if constant_symbols: + gtx_transformations.gt_substitute_compiletime_symbols( + sdfg=sdfg, + repl=constant_symbols, + validate=validate, + validate_all=validate_all, + ) + gtx_transformations.gt_simplify(sdfg) + sdfg.apply_transformations_repeated( [ dace_dataflow.TrivialMapElimination, - # TODO(phimuell): The transformation are interesting, but they have - # a bug as they assume that they are not working inside a map scope. - # Before we use them we have to fix them. - # https://chat.spcl.inf.ethz.ch/spcl/pl/8mtgtqjb378hfy7h9a96sy3nhc - # dace_dataflow.MapReduceFusion, - # dace_dataflow.MapWCRFusion, ], validate=validate, validate_all=validate_all, @@ -278,34 +168,69 @@ def gt_auto_optimize( validate_all=validate_all, ) - # Phase 3: Optimizing the kernels, i.e. the larger maps, themselves. - # Currently this only applies fusion inside Maps. + # After we have created big kernels, we will perform some post cleanup. + gtx_transformations.gt_reduce_distributed_buffering(sdfg) sdfg.apply_transformations_repeated( - gtx_transformations.SerialMapFusion( - only_inner_maps=True, - ), + [ + gtx_transformations.GT4PyMoveTaskletIntoMap, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=assume_pointwise), + ], validate=validate, validate_all=validate_all, ) - gt_simplify(sdfg) + + # TODO(phimuell): The `MapReduceFusion` transformation is interesting as + # it moves the initialization of the accumulator at the top, which allows + # further fusing of the accumulator loop. However the transformation has + # a bug, so we can not use it. Furthermore, I have looked at the assembly + # and the compiler is already doing that. + # https://chat.spcl.inf.ethz.ch/spcl/pl/8mtgtqjb378hfy7h9a96sy3nhc + + # After we have created large kernels we run `dace_dataflow.MapReduceFusion`. + + # Phase 3: Optimizing the kernels, i.e. the larger maps, themselves. + # Currently this only applies fusion inside Maps. + gtx_transformations.gt_simplify(sdfg) + while True: + nb_applied = sdfg.apply_transformations_repeated( + [ + gtx_transformations.MapFusionSerial( + only_inner_maps=True, + ), + gtx_transformations.MapFusionParallel( + only_inner_maps=True, + only_if_common_ancestor=False, # TODO(phimuell): Should we? + ), + ], + validate=validate, + validate_all=validate_all, + ) + if not nb_applied: + break + gtx_transformations.gt_simplify(sdfg) # Phase 4: Iteration Space # This essentially ensures that the stride 1 dimensions are handled # by the inner most loop nest (CPU) or x-block (GPU) if leading_dim is not None: - gt_set_iteration_order( + gtx_transformations.gt_set_iteration_order( sdfg=sdfg, leading_dim=leading_dim, validate=validate, validate_all=validate_all, ) + # We now ensure that point wise computations are properly double buffered. + # The main reason is to ensure that rule 3 of ADR18 is maintained. + gtx_transformations.gt_create_local_double_buffering(sdfg) + # Phase 5: Apply blocking if blocking_dim is not None: sdfg.apply_transformations_once_everywhere( gtx_transformations.LoopBlocking( blocking_size=blocking_size, blocking_parameter=blocking_dim, + require_independent_nodes=blocking_only_if_independent_nodes, ), validate=validate, validate_all=validate_all, @@ -342,9 +267,23 @@ def gt_auto_optimize( dace_aoptimize.set_fast_implementations(sdfg, device) # TODO(phimuell): Fix the bug, it uses the tile value and not the stack array value. dace_aoptimize.move_small_arrays_to_stack(sdfg) + + # Now we modify the strides. + gtx_transformations.gt_change_transient_strides(sdfg, gpu=gpu) + if make_persistent: - # TODO(phimuell): Allow to also to set the lifetime to `SDFG`. - dace_aoptimize.make_transients_persistent(sdfg, device) + gtx_transformations.gt_make_transients_persistent(sdfg=sdfg, device=device) + + if device == dace.DeviceType.GPU: + # NOTE: For unknown reasons the counterpart of the + # `gt_make_transients_persistent()` function in DaCe, resets the + # `wcr_nonatomic` property of every memlet, i.e. makes it atomic. + # However, it does this only for edges on the top level and on GPU. + # For compatibility with DaCe (and until we found out why) the GT4Py + # auto optimizer will emulate this behaviour. + for state in sdfg.states(): + for edge in state.edges(): + edge.data.wcr_nonatomic = False return sdfg @@ -395,9 +334,17 @@ def gt_auto_fuse_top_level_maps( # TODO(phimuell): Add parallel fusion transformation. Should it run after # or with the serial one? sdfg.apply_transformations_repeated( - gtx_transformations.SerialMapFusion( - only_toplevel_maps=True, - ), + [ + gtx_transformations.MapFusionSerial( + only_toplevel_maps=True, + ), + gtx_transformations.MapFusionParallel( + only_toplevel_maps=True, + # This will lead to the creation of big probably unrelated maps. + # However, it might be good. + only_if_common_ancestor=False, + ), + ], validate=validate, validate_all=validate_all, ) @@ -437,7 +384,7 @@ def gt_auto_fuse_top_level_maps( # The SDFG was modified by the transformations above. The SDFG was # modified. Call Simplify and try again to further optimize. - gt_simplify(sdfg, validate=validate, validate_all=validate_all) + gtx_transformations.gt_simplify(sdfg, validate=validate, validate_all=validate_all) else: raise RuntimeWarning("Optimization of the SDFG did not converge.") diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py index 16c9600a3a..2cd3020180 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/gpu_utils.py @@ -11,10 +11,15 @@ from __future__ import annotations import copy -from typing import Any, Optional, Sequence, Union +from typing import Any, Callable, Final, Optional, Sequence, Union import dace -from dace import properties as dace_properties, transformation as dace_transformation +from dace import ( + dtypes as dace_dtypes, + properties as dace_properties, + transformation as dace_transformation, +) +from dace.codegen.targets import cpp as dace_cpp from dace.sdfg import nodes as dace_nodes from gt4py.next.program_processors.runners.dace_fieldview import ( @@ -51,7 +56,9 @@ def gt_gpu_transformation( will avoid the data copy from host to GPU memory. gpu_block_size: The size of a thread block on the GPU. gpu_launch_bounds: Use this value as `__launch_bounds__` for _all_ GPU Maps. + Will only take effect if `gpu_block_size` is specified. gpu_launch_factor: Use the number of threads times this value as `__launch_bounds__` + Will only take effect if `gpu_block_size` is specified. validate: Perform validation during the steps. validate_all: Perform extensive validation. @@ -82,39 +89,197 @@ def gt_gpu_transformation( validate_all=validate_all, simplify=False, ) + # The documentation recommends to run simplify afterwards gtx_transformations.gt_simplify(sdfg) if try_removing_trivial_maps: - # A Tasklet, outside of a Map, that writes into an array on GPU can not work - # `sdfg.appyl_gpu_transformations()` puts Map around it (if said Tasklet - # would write into a Scalar that then goes into a GPU Map, nothing would - # happen. So we might end up with lot of these trivial Maps, that results - # in a single kernel launch. To prevent this we will try to fuse them. - # NOTE: The current implementation has a bug, because promotion and fusion - # are two different steps. Because of this the function will implicitly - # fuse everything together it can find. - # TODO(phimuell): Fix the issue described above. + # In DaCe a Tasklet, outside of a Map, can not write into an _array_ that is on + # GPU. `sdfg.appyl_gpu_transformations()` will wrap such Tasklets in a Map. So + # we might end up with lots of these trivial Maps, each requiring a separate + # kernel launch. To prevent this we will combine these trivial maps, if + # possible, with their downstream maps. sdfg.apply_transformations_once_everywhere( - TrivialGPUMapPromoter(), + TrivialGPUMapElimination(), validate=False, validate_all=False, ) - sdfg.apply_transformations_repeated( - gtx_transformations.SerialMapFusion( - only_toplevel_maps=True, - ), - validate=validate, - validate_all=validate_all, - ) + gtx_transformations.gt_simplify(sdfg, validate=validate, validate_all=validate_all) + + # TODO(phimuell): Fixing the stride problem. + sdfg = gt_gpu_transform_non_standard_memlet( + sdfg=sdfg, + map_postprocess=True, + validate=validate, + validate_all=validate_all, + ) # Set the GPU block size if it is known. if gpu_block_size is not None: gt_set_gpu_blocksize( sdfg=sdfg, - gpu_block_size=gpu_block_size, - gpu_launch_bounds=gpu_launch_bounds, - gpu_launch_factor=gpu_launch_factor, + block_size=gpu_block_size, + launch_bounds=gpu_launch_bounds, + launch_factor=gpu_launch_factor, + ) + + if validate_all or validate: + sdfg.validate() + + return sdfg + + +def gt_gpu_transform_non_standard_memlet( + sdfg: dace.SDFG, + map_postprocess: bool, + validate: bool = True, + validate_all: bool = False, +) -> dace.SDFG: + """Transform some non standard Melets to Maps. + + The GPU code generator is not able to handle certain sets of Memlets. To + handle them, the code generator transforms them into copy Maps. The main + issue is that this transformation happens after the auto optimizer, thus + the copy-Maps will most likely have the wrong iteration order. + + This function allows to perform the preprocessing step before the actual + code generation. The function will perform the expansion. If + `map_postprocess` is `True` then the function will also apply MapFusion, + to these newly created copy-Maps and set their iteration order correctly. + + A user should not call this function directly, instead this function is + called by the `gt_gpu_transformation()` function. + + Args: + sdfg: The SDFG that we process. + map_postprocess: Enable post processing of the maps that are created. + See the Note section below. + validate: Perform validation at the end of the function. + validate_all: Perform validation also on intermediate steps. + + Note: + - Currently the function applies some crude heuristic to determine the + correct loop order. + - This function should be called after `gt_set_iteration_order()` has run. + """ + new_maps: set[dace_nodes.MapEntry] = set() + + # This code is is copied from DaCe's code generator. + for e, state in list(sdfg.all_edges_recursive()): + nsdfg = state.parent + if ( + isinstance(e.src, dace_nodes.AccessNode) + and isinstance(e.dst, dace_nodes.AccessNode) + and e.src.desc(nsdfg).storage == dace_dtypes.StorageType.GPU_Global + and e.dst.desc(nsdfg).storage == dace_dtypes.StorageType.GPU_Global + ): + a: dace_nodes.AccessNode = e.src + b: dace_nodes.AccessNode = e.dst + + copy_shape, src_strides, dst_strides, _, _ = dace_cpp.memlet_copy_to_absolute_strides( + None, nsdfg, state, e, a, b + ) + dims = len(copy_shape) + if dims == 1: + continue + elif dims == 2: + if src_strides[-1] != 1 or dst_strides[-1] != 1: + try: + is_src_cont = src_strides[0] / src_strides[1] == copy_shape[1] + is_dst_cont = dst_strides[0] / dst_strides[1] == copy_shape[1] + except (TypeError, ValueError): + is_src_cont = False + is_dst_cont = False + if is_src_cont and is_dst_cont: + continue + else: + continue + elif dims > 2: + if not (src_strides[-1] != 1 or dst_strides[-1] != 1): + continue + + # For identifying the new map, we first store all neighbors of `a`. + old_neighbors_of_a: list[dace_nodes.AccessNode] = [ + edge.dst for edge in state.out_edges(a) + ] + + # Turn unsupported copy to a map + try: + dace_transformation.dataflow.CopyToMap.apply_to( + nsdfg, save=False, annotate=False, a=a, b=b + ) + except ValueError: # If transformation doesn't match, continue normally + continue + + # We find the new map by comparing the new neighborhood of `a` with the old one. + new_nodes: set[dace_nodes.MapEntry] = { + edge.dst for edge in state.out_edges(a) if edge.dst not in old_neighbors_of_a + } + assert any(isinstance(new_node, dace_nodes.MapEntry) for new_node in new_nodes) + assert len(new_nodes) == 1 + new_maps.update(new_nodes) + + # If there are no Memlets that are translated to copy-Maps, then we have nothing to do. + if len(new_maps) == 0: + return sdfg + + # This function allows to restrict any fusion operation to the maps + # that we have just created. + def restrict_fusion_to_newly_created_maps( + self: gtx_transformations.map_fusion_helper.MapFusionHelper, + map_entry_1: dace_nodes.MapEntry, + map_entry_2: dace_nodes.MapEntry, + graph: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, + permissive: bool, + ) -> bool: + return any(new_entry in new_maps for new_entry in [map_entry_1, map_entry_2]) + + # Using the callback to restrict the fusing + sdfg.apply_transformations_repeated( + [ + gtx_transformations.MapFusionSerial( + only_toplevel_maps=True, + apply_fusion_callback=restrict_fusion_to_newly_created_maps, + ), + gtx_transformations.MapFusionParallel( + only_toplevel_maps=True, + apply_fusion_callback=restrict_fusion_to_newly_created_maps, + ), + ], + validate=validate, + validate_all=validate_all, + ) + + # Now we have to find the maps that were not fused. We rely here on the fact + # that at least one of the map that is involved in fusing still exists. + maps_to_modify: set[dace_nodes.MapEntry] = set() + for nsdfg in sdfg.all_sdfgs_recursive(): + for state in nsdfg.states(): + for map_entry in state.nodes(): + if not isinstance(map_entry, dace_nodes.MapEntry): + continue + if map_entry in new_maps: + maps_to_modify.add(map_entry) + assert 0 < len(maps_to_modify) <= len(new_maps) + + # This is a gross hack, but it is needed, for the following reasons: + # - The transients have C order while the non-transients have (most + # likely) FORTRAN order. So there is not an unique stride dimension. + # - The newly created maps have names that does not reflect GT4Py dimensions, + # thus we can not use `gt_set_iteration_order()`. + # For these reasons we do the simplest thing, which is assuming that the maps + # are created in C order and we must make them in FORTRAN order, which means + # just swapping the order of the map parameters. + # TODO(phimuell): Do it properly. + for me_to_modify in maps_to_modify: + map_to_modify: dace_nodes.Map = me_to_modify.map + map_to_modify.params = list(reversed(map_to_modify.params)) + map_to_modify.range = dace.subsets.Range( + (r1, r2, r3, t) + for (r1, r2, r3), t in zip( + reversed(map_to_modify.range.ranges), reversed(map_to_modify.range.tile_sizes) + ) ) return sdfg @@ -122,131 +287,214 @@ def gt_gpu_transformation( def gt_set_gpu_blocksize( sdfg: dace.SDFG, - gpu_block_size: Optional[Sequence[int | str] | str], - gpu_launch_bounds: Optional[int | str] = None, - gpu_launch_factor: Optional[int] = None, + block_size: Optional[Sequence[int | str] | str], + launch_bounds: Optional[int | str] = None, + launch_factor: Optional[int] = None, + **kwargs: Any, ) -> Any: """Set the block size related properties of _all_ Maps. - See `GPUSetBlockSize` for more information. + It supports the same arguments as `GPUSetBlockSize`, however it also has + versions without `_Xd`, these are used as default for the other maps. + If a version with `_Xd` is specified then it takes precedence. Args: sdfg: The SDFG to process. - gpu_block_size: The size of a thread block on the GPU. + block_size: The size of a thread block on the GPU. launch_bounds: The value for the launch bound that should be used. launch_factor: If no `launch_bounds` was given use the number of threads in a block multiplied by this number. """ - xform = GPUSetBlockSize( - block_size=gpu_block_size, - launch_bounds=gpu_launch_bounds, - launch_factor=gpu_launch_factor, - ) - return sdfg.apply_transformations_once_everywhere([xform]) - - -def _gpu_block_parser( - self: GPUSetBlockSize, - val: Any, -) -> None: - """Used by the setter of `GPUSetBlockSize.block_size`.""" - org_val = val - if isinstance(val, (tuple | list)): - pass - elif isinstance(val, str): - val = tuple(x.strip() for x in val.split(",")) - elif isinstance(val, int): - val = (val,) - else: - raise TypeError( - f"Does not know how to transform '{type(org_val).__name__}' into a proper GPU block size." - ) - if 0 < len(val) <= 3: - val = [*val, *([1] * (3 - len(val)))] - else: - raise ValueError(f"Can not parse block size '{org_val}': wrong length") - try: - val = [int(x) for x in val] - except ValueError: - raise TypeError( - f"Currently only block sizes convertible to int are supported, you passed '{val}'." - ) from None - self._block_size = val - + for dim in [1, 2, 3]: + for arg, val in { + "block_size": block_size, + "launch_bounds": launch_bounds, + "launch_factor": launch_factor, + }.items(): + if f"{arg}_{dim}d" not in kwargs: + kwargs[f"{arg}_{dim}d"] = val + return sdfg.apply_transformations_once_everywhere(GPUSetBlockSize(**kwargs)) + + +def _make_gpu_block_parser_for( + dim: int, +) -> Callable[["GPUSetBlockSize", Any], None]: + """Generates a parser for GPU blocks for dimension `dim`. + + The returned function can be used as parser for the `GPUSetBlockSize.block_size_*d` + properties. + """ -def _gpu_block_getter( - self: "GPUSetBlockSize", -) -> tuple[int, int, int]: - """Used as getter in the `GPUSetBlockSize.block_size` property.""" - assert isinstance(self._block_size, (tuple, list)) and len(self._block_size) == 3 - assert all(isinstance(x, int) for x in self._block_size) - return tuple(self._block_size) + def _gpu_block_parser( + self: GPUSetBlockSize, + val: Any, + ) -> None: + """Used by the setter of `GPUSetBlockSize.block_size`.""" + org_val = val + if isinstance(val, (tuple | list)): + pass + elif isinstance(val, str): + val = tuple(x.strip() for x in val.split(",")) + elif isinstance(val, int): + val = (val,) + else: + raise TypeError( + f"Does not know how to transform '{type(org_val).__name__}' into a proper GPU block size." + ) + if len(val) < dim: + raise ValueError( + f"The passed block size only covers {len(val)} dimensions, but dimension was {dim}." + ) + if 0 < len(val) <= 3: + val = [*val, *([1] * (3 - len(val)))] + else: + raise ValueError(f"Can not parse block size '{org_val}': wrong length") + try: + val = [int(x) for x in val] + except ValueError: + raise TypeError( + f"Currently only block sizes convertible to int are supported, you passed '{val}'." + ) from None + + # Remove over specification. + for i in range(dim, 3): + val[i] = 1 + setattr(self, f"_block_size_{dim}d", tuple(val)) + + return _gpu_block_parser + + +def _make_gpu_block_getter_for( + dim: int, +) -> Callable[["GPUSetBlockSize"], tuple[int, int, int]]: + """Makes the getter for the block size of dimension `dim`.""" + + def _gpu_block_getter( + self: "GPUSetBlockSize", + ) -> tuple[int, int, int]: + """Used as getter in the `GPUSetBlockSize.block_size` property.""" + return getattr(self, f"_block_size_{dim}d") + + return _gpu_block_getter + + +def _gpu_launch_bound_parser( + block_size: tuple[int, int, int], + launch_bounds: int | str | None, + launch_factor: int | None = None, +) -> str | None: + """Used by the `GPUSetBlockSize.__init__()` method to parse the launch bounds.""" + if launch_bounds is None and launch_factor is None: + return None + elif launch_bounds is None and launch_factor is not None: + return str(int(launch_factor) * block_size[0] * block_size[1] * block_size[2]) + elif launch_bounds is not None and launch_factor is None: + assert isinstance(launch_bounds, (str, int)) + return str(launch_bounds) + else: + raise ValueError("Specified both `launch_bounds` and `launch_factor`.") @dace_properties.make_properties class GPUSetBlockSize(dace_transformation.SingleStateTransformation): """Sets the GPU block size on GPU Maps. - The transformation will apply to all Maps that have a GPU schedule, regardless - of their dimensionality. + The `block_size` is either a sequence, of up to three integers or a string + of up to three numbers, separated by comma (`,`). The first number is the size + of the block in `x` direction, the second for the `y` direction and the third + for the `z` direction. Missing values will be filled with `1`. - The `gpu_block_size` is either a sequence, of up to three integers or a string - of up to three numbers, separated by comma (`,`). - The first number is the size of the block in `x` direction, the second for the - `y` direction and the third for the `z` direction. Missing values will be filled - with `1`. + A different value for the GPU block size and launch bound can be specified for + maps of dimension 1, 2 or 3 (all maps with higher dimensions are considered + three dimensional). If no value is specified then the block size `(32, 1, 1)` + will be used an no launch bound will be be emitted. Args: - block_size: The size of a thread block on the GPU. - launch_bounds: The value for the launch bound that should be used. - launch_factor: If no `launch_bounds` was given use the number of threads - in a block multiplied by this number. + block_size_Xd: The size of a thread block on the GPU for `X` dimensional maps. + launch_bounds_Xd: The value for the launch bound that should be used for `X` + dimensional maps. + launch_factor_Xd: If no `launch_bounds` was given use the number of threads + in a block multiplied by this number, for maps of dimension `X`. - Todo: - Add the possibility to specify other bounds for 1, 2, or 3 dimensional maps. + Note: + - You should use the `gt_set_gpu_blocksize()` function. + - "Over specification" is ignored, i.e. if `(32, 3, 1)` is passed as block + size for 1 dimensional maps, then it is changed to `(32, 1, 1)`. """ - block_size = dace_properties.Property( - dtype=None, - allow_none=False, - default=(32, 1, 1), - setter=_gpu_block_parser, - getter=_gpu_block_getter, - desc="Size of the block size a GPU Map should have.", - ) + _block_size_default: Final[tuple[int, int, int]] = (32, 1, 1) - launch_bounds = dace_properties.Property( + block_size_1d = dace_properties.Property( + dtype=tuple[int, int, int], + default=_block_size_default, + setter=_make_gpu_block_parser_for(1), + getter=_make_gpu_block_getter_for(1), + desc="Block size for 1 dimensional GPU maps.", + ) + launch_bounds_1d = dace_properties.Property( + dtype=str, + allow_none=True, + default=None, + desc="Set the launch bound property for 1 dimensional map.", + ) + block_size_2d = dace_properties.Property( + dtype=tuple[int, int, int], + default=_block_size_default, + setter=_make_gpu_block_parser_for(2), + getter=_make_gpu_block_getter_for(2), + desc="Block size for 2 dimensional GPU maps.", + ) + launch_bounds_2d = dace_properties.Property( + dtype=str, + allow_none=True, + default=None, + desc="Set the launch bound property for 2 dimensional map.", + ) + block_size_3d = dace_properties.Property( + dtype=tuple[int, int, int], + default=_block_size_default, + setter=_make_gpu_block_parser_for(3), + getter=_make_gpu_block_getter_for(3), + desc="Block size for 3 dimensional GPU maps.", + ) + launch_bounds_3d = dace_properties.Property( dtype=str, allow_none=True, default=None, - desc="Set the launch bound property of the map.", + desc="Set the launch bound property for 3 dimensional map.", ) - map_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + # Pattern matching + map_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) def __init__( self, - block_size: Sequence[int | str] | str | None = None, - launch_bounds: int | str | None = None, - launch_factor: int | None = None, + block_size_1d: Sequence[int | str] | str | None = None, + block_size_2d: Sequence[int | str] | str | None = None, + block_size_3d: Sequence[int | str] | str | None = None, + launch_bounds_1d: int | str | None = None, + launch_bounds_2d: int | str | None = None, + launch_bounds_3d: int | str | None = None, + launch_factor_1d: int | None = None, + launch_factor_2d: int | None = None, + launch_factor_3d: int | None = None, ) -> None: super().__init__() - if block_size is not None: - self.block_size = block_size - - if launch_factor is not None: - assert launch_bounds is None - self.launch_bounds = str( - int(launch_factor) * self.block_size[0] * self.block_size[1] * self.block_size[2] - ) - elif launch_bounds is None: - self.launch_bounds = None - elif isinstance(launch_bounds, (str, int)): - self.launch_bounds = str(launch_bounds) - else: - raise TypeError( - f"Does not know how to parse '{launch_bounds}' as 'launch_bounds' argument." - ) + if block_size_1d is not None: + self.block_size_1d = block_size_1d + if block_size_2d is not None: + self.block_size_2d = block_size_2d + if block_size_3d is not None: + self.block_size_3d = block_size_3d + self.launch_bounds_1d = _gpu_launch_bound_parser( + self.block_size_1d, launch_bounds_1d, launch_factor_1d + ) + self.launch_bounds_2d = _gpu_launch_bound_parser( + self.block_size_2d, launch_bounds_2d, launch_factor_2d + ) + self.launch_bounds_3d = _gpu_launch_bound_parser( + self.block_size_3d, launch_bounds_3d, launch_factor_3d + ) @classmethod def expressions(cls) -> Any: @@ -266,7 +514,6 @@ def can_be_applied( - If the map is at global scope. - If if the schedule of the map is correct. """ - scope = graph.scope_dict() if scope[self.map_entry] is not None: return False @@ -282,35 +529,69 @@ def apply( sdfg: dace.SDFG, ) -> None: """Modify the map as requested.""" - self.map_entry.map.gpu_block_size = self.block_size - if self.launch_bounds is not None: # Note empty string has a meaning in DaCe - self.map_entry.map.gpu_launch_bounds = self.launch_bounds + gpu_map: dace_nodes.Map = self.map_entry.map + if len(gpu_map.params) == 1: + block_size = self.block_size_1d + launch_bounds = self.launch_bounds_1d + elif len(gpu_map.params) == 2: + block_size = self.block_size_2d + launch_bounds = self.launch_bounds_2d + else: + block_size = self.block_size_3d + launch_bounds = self.launch_bounds_3d + gpu_map.gpu_block_size = block_size + if launch_bounds is not None: # Note: empty string has a meaning in DaCe + gpu_map.gpu_launch_bounds = launch_bounds @dace_properties.make_properties -class TrivialGPUMapPromoter(dace_transformation.SingleStateTransformation): - """Serial Map promoter for empty GPU maps. +class TrivialGPUMapElimination(dace_transformation.SingleStateTransformation): + """Eliminate certain kind of trivial GPU maps. - In CPU mode a Tasklet can be outside of a map, however, this is not - possible in GPU mode. For this reason DaCe wraps such Tasklets in a - trivial Map. - This transformation will look for such Maps and promote them, such - that they can be fused with downstream maps. + A tasklet outside of map can not write to GPU memory, this can only be done + from within a map (a scalar is possible). For that reason DaCe's GPU + transformation wraps such tasklets in trivial maps. + Under certain condition the transformation will fuse the trivial tasklet with + a downstream (serial) map. + + Args: + do_not_fuse: If `True` then the maps are not fused together. + only_gpu_maps: Only apply to GPU maps; `True` by default. Note: - This transformation should not be run on its own, instead it is run within the context of `gt_gpu_transformation()`. - This transformation must be run after the GPU Transformation. - - Currently the transformation does not do the fusion on its own. - Instead map fusion must be run afterwards. - - The transformation assumes that the upper Map is a trivial Tasklet. - Which should be the majority of all cases. """ + only_gpu_maps = dace_properties.Property( + dtype=bool, + default=True, + desc="Only promote maps that are GPU maps (debug option).", + ) + do_not_fuse = dace_properties.Property( + dtype=bool, + default=False, + desc="Only perform the promotion, do not fuse.", + ) + # Pattern Matching - trivial_map_exit = dace_transformation.transformation.PatternNode(dace_nodes.MapExit) - access_node = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) - second_map_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + trivial_map_exit = dace_transformation.PatternNode(dace_nodes.MapExit) + access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) + second_map_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) + + def __init__( + self, + do_not_fuse: Optional[bool] = None, + only_gpu_maps: Optional[bool] = None, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + if only_gpu_maps is not None: + self.only_gpu_maps = only_gpu_maps + if do_not_fuse is not None: + self.do_not_fuse = do_not_fuse @classmethod def expressions(cls) -> Any: @@ -332,63 +613,118 @@ def can_be_applied( The tests includes: - Schedule of the maps. - If the map is trivial. - - If the trivial map was not used to define a symbol. - - Intermediate access node can only have in and out degree of 1. - - The trivial map exit can only have one output. + - Tests if the maps can be fused. """ trivial_map_exit: dace_nodes.MapExit = self.trivial_map_exit trivial_map: dace_nodes.Map = trivial_map_exit.map trivial_map_entry: dace_nodes.MapEntry = graph.entry_node(trivial_map_exit) second_map: dace_nodes.Map = self.second_map_entry.map - access_node: dace_nodes.AccessNode = self.access_node # The kind of maps we are interested only have one parameter. if len(trivial_map.params) != 1: return False - - # Check if it is a GPU map - for map_to_check in [trivial_map, second_map]: - if map_to_check.schedule not in [ - dace.dtypes.ScheduleType.GPU_Device, - dace.dtypes.ScheduleType.GPU_Default, - ]: - return False - - # Check if the map is trivial. for rng in trivial_map.range.ranges: if rng[0] != rng[1]: return False - # Now we have to ensure that the symbol is not used inside the scope of the - # map, if it is, then the symbol is just there to define a symbol. - scope_view = graph.scope_subgraph( - trivial_map_entry, - include_entry=False, - include_exit=False, - ) - if any(map_param in scope_view.free_symbols for map_param in trivial_map.params): - return False + # If we do not not fuse, then the second map can not be trivial. + # If we would not prevent that case then we would match these two + # maps again and again. + if self.do_not_fuse and len(second_map.params) <= 1: + for rng in second_map.range.ranges: + if rng[0] == rng[1]: + return False + + # We now check that the Memlets do not depend on the map parameter. + # This is important for the `can_be_applied_to()` check we do below + # because we can avoid calling the replace function. + scope = graph.scope_subgraph(trivial_map_entry) + trivial_map_param: str = trivial_map.params[0] + for edge in scope.edges(): + if trivial_map_param in edge.data.free_symbols: + return False - # ensuring that the trivial map exit and the intermediate node have degree - # one is a cheap way to ensure that the map can be merged into the - # second map. - if graph.in_degree(access_node) != 1: - return False - if graph.out_degree(access_node) != 1: - return False - if graph.out_degree(trivial_map_exit) != 1: - return False + # Check if only GPU maps are involved (this is more a testing debug feature). + if self.only_gpu_maps: + for map_to_check in [trivial_map, second_map]: + if map_to_check.schedule not in [ + dace.dtypes.ScheduleType.GPU_Device, + dace.dtypes.ScheduleType.GPU_Default, + ]: + return False + + # Now we check if the two maps can be fused together. For that we have to + # do a temporary promotion, it is important that we do not perform the + # renaming. If the old symbol is still used, it is used inside a tasklet + # so it would show up (temporarily) as free symbol. + org_trivial_map_params = copy.deepcopy(trivial_map.params) + org_trivial_map_range = copy.deepcopy(trivial_map.range) + try: + self._promote_map(graph, replace_trivail_map_parameter=False) + if not gtx_transformations.MapFusionSerial.can_be_applied_to( + sdfg=sdfg, + map_exit_1=trivial_map_exit, + intermediate_access_node=self.access_node, + map_entry_2=self.second_map_entry, + ): + return False + finally: + trivial_map.params = org_trivial_map_params + trivial_map.range = org_trivial_map_range return True def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: """Performs the Map Promoting. - The function essentially copies the parameters and the ranges from the - bottom map to the top one. + The function will first perform the promotion of the trivial map and then + perform the merging of the two maps in one go. """ + trivial_map_exit: dace_nodes.MapExit = self.trivial_map_exit + second_map_entry: dace_nodes.MapEntry = self.second_map_entry + access_node: dace_nodes.AccessNode = self.access_node + + # Promote the maps. + self._promote_map(graph) + + # Perform the fusing if requested. + if not self.do_not_fuse: + gtx_transformations.MapFusionSerial.apply_to( + sdfg=sdfg, + map_exit_1=trivial_map_exit, + intermediate_access_node=access_node, + map_entry_2=second_map_entry, + verify=True, + ) + + def _promote_map( + self, + state: dace.SDFGState, + replace_trivail_map_parameter: bool = True, + ) -> None: + """Performs the map promoting. + + Essentially this function will copy the parameters and the range from + the non trivial map (`self.second_map_entry.map`) to the trivial map + (`self.trivial_map_exit.map`). + + If `replace_trivail_map_parameter` is `True` (the default value), then the + function will also remove the trivial map parameter with its value. + """ + assert isinstance(self.trivial_map_exit, dace_nodes.MapExit) + assert isinstance(self.second_map_entry, dace_nodes.MapEntry) + assert isinstance(self.access_node, dace_nodes.AccessNode) + + trivial_map_exit: dace_nodes.MapExit = self.trivial_map_exit trivial_map: dace_nodes.Map = self.trivial_map_exit.map + trivial_map_entry: dace_nodes.MapEntry = state.entry_node(trivial_map_exit) second_map: dace_nodes.Map = self.second_map_entry.map + # If requested then replace the map variable with its value. + if replace_trivail_map_parameter: + scope = state.scope_subgraph(trivial_map_entry) + scope.replace(trivial_map.params[0], trivial_map.range[0][0]) + + # Now copy parameter and the ranges from the second to the trivial map. trivial_map.params = copy.deepcopy(second_map.params) trivial_map.range = copy.deepcopy(second_map.range) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/local_double_buffering.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/local_double_buffering.py new file mode 100644 index 0000000000..52f1de3d0c --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/local_double_buffering.py @@ -0,0 +1,393 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + + +import copy + +import dace +from dace import ( + data as dace_data, + dtypes as dace_dtypes, + symbolic as dace_symbolic, + transformation as dace_transformation, +) +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + + +def gt_create_local_double_buffering( + sdfg: dace.SDFG, +) -> int: + """Modifies the SDFG such that point wise data dependencies are stable. + + Rule 3 of the ADR18, guarantees that if data is input and output to a map, + then it must be a non transient array and it must only have point wise + dependency. This means that every index that is read is also written by + the same thread and no other thread reads or writes to the same location. + However, because the dataflow inside a map is partially asynchronous + it might happen if something is read multiple times, i.e. Tasklets, + the data might already be overwritten. + This function will scan the SDFG for potential cases and insert an + access node to cache this read. This is essentially a double buffer, but + it is not needed that the whole data is stored, but only the working set + of a single thread. + """ + + processed_maps = 0 + for nsdfg in sdfg.all_sdfgs_recursive(): + processed_maps += _create_local_double_buffering_non_recursive(nsdfg) + return processed_maps + + +def _create_local_double_buffering_non_recursive( + sdfg: dace.SDFG, +) -> int: + """Implementation of the point wise transformation. + + This function does not handle nested SDFGs. + """ + # First we call `EdgeConsolidation`, because of that we know that + # every incoming edge of a `MapEntry` refers to distinct data. + # We do this to simplify our implementation. + edge_consolidation = dace_transformation.passes.ConsolidateEdges() + edge_consolidation.apply_pass(sdfg, None) + + processed_maps = 0 + for state in sdfg.states(): + scope_dict = state.scope_dict() + for node in state.nodes(): + if not isinstance(node, dace_nodes.MapEntry): + continue + if scope_dict[node] is not None: + continue + inout_nodes = _check_if_map_must_be_handled( + map_entry=node, + state=state, + sdfg=sdfg, + ) + if inout_nodes is not None: + processed_maps += _add_local_double_buffering_to( + map_entry=node, + inout_nodes=inout_nodes, + state=state, + sdfg=sdfg, + ) + return processed_maps + + +def _add_local_double_buffering_to( + inout_nodes: dict[str, tuple[dace_nodes.AccessNode, dace_nodes.AccessNode]], + map_entry: dace_nodes.MapEntry, + state: dace.SDFGState, + sdfg: dace.SDFG, +) -> int: + """Adds the double buffering to `map_entry` for `inout_nodes`. + + The function assumes that there is only in incoming edge per data + descriptor at the map entry. If the data is needed multiple times, + then the distribution must be done inside the map. + + The function will now channel all reads to the data descriptor + through an access node, this ensures that the read happens + before the write. + """ + processed_maps = 0 + for inout_node in inout_nodes.values(): + _add_local_double_buffering_to_single_data( + map_entry=map_entry, + inout_node=inout_node, + state=state, + sdfg=sdfg, + ) + processed_maps += 1 + return processed_maps + + +def _add_local_double_buffering_to_single_data( + inout_node: tuple[dace_nodes.AccessNode, dace_nodes.AccessNode], + map_entry: dace_nodes.MapEntry, + state: dace.SDFGState, + sdfg: dace.SDFG, +) -> None: + """Adds the local double buffering for a single data.""" + map_exit: dace_nodes.MapExit = state.exit_node(map_entry) + input_node, output_node = inout_node + input_edges = state.edges_between(input_node, map_entry) + output_edges = state.edges_between(map_exit, output_node) + assert len(input_edges) == 1 + assert len(output_edges) == 1 + inner_read_edges = _get_inner_edges(input_edges[0], map_entry, state, False) + inner_write_edges = _get_inner_edges(output_edges[0], map_exit, state, True) + + # For now we assume that all read the same, which is checked below. + new_double_inner_buff_shape_raw = dace_symbolic.overapproximate( + inner_read_edges[0].data.get_src_subset(inner_read_edges[0], state).size() + ) + + # Over approximation will leave us with some unneeded size one dimensions. + # If they are removed some dace transformations (especially auto optimization) + # will have problems. + squeezed_dims: list[int] = [] # These are the dimensions we removed. + new_double_inner_buff_shape: list[int] = [] # This is the final shape of the new intermediate. + for dim, (proposed_dim_size, full_dim_size) in enumerate( + zip(new_double_inner_buff_shape_raw, input_node.desc(sdfg).shape) + ): + if full_dim_size == 1: # Must be kept! + new_double_inner_buff_shape.append(proposed_dim_size) + elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. + squeezed_dims.append(dim) + else: + new_double_inner_buff_shape.append(proposed_dim_size) + + new_double_inner_buff_name: str = f"__inner_double_buffer_for_{input_node.data}" + # Now generate the intermediate data container. + if len(new_double_inner_buff_shape) == 0: + new_double_inner_buff_name, new_double_inner_buff_desc = sdfg.add_scalar( + new_double_inner_buff_name, + dtype=input_node.desc(sdfg).dtype, + transient=True, + storage=dace_dtypes.StorageType.Register, + find_new_name=True, + ) + else: + new_double_inner_buff_name, new_double_inner_buff_desc = sdfg.add_transient( + new_double_inner_buff_name, + shape=new_double_inner_buff_shape, + dtype=input_node.desc(sdfg).dtype, + find_new_name=True, + storage=dace_dtypes.StorageType.Register, + ) + new_double_inner_buff_node = state.add_access(new_double_inner_buff_name) + + # Now reroute the data flow through the new access node. + for old_inner_read_edge in inner_read_edges: + # To do handle the case the memlet is "fancy" + state.add_edge( + new_double_inner_buff_node, + None, + old_inner_read_edge.dst, + old_inner_read_edge.dst_conn, + dace.Memlet( + data=new_double_inner_buff_name, + subset=dace.subsets.Range.from_array(new_double_inner_buff_desc), + other_subset=copy.deepcopy( + old_inner_read_edge.data.get_dst_subset(old_inner_read_edge, state) + ), + ), + ) + state.remove_edge(old_inner_read_edge) + + # Now create a connection between the map entry and the intermediate node. + state.add_edge( + map_entry, + inner_read_edges[0].src_conn, + new_double_inner_buff_node, + None, + dace.Memlet( + data=input_node.data, + subset=copy.deepcopy( + inner_read_edges[0].data.get_src_subset(inner_read_edges[0], state) + ), + other_subset=dace.subsets.Range.from_array(new_double_inner_buff_desc), + ), + ) + + # To really ensure that a read happens before a write, we have to sequence + # the read first. We do this by connecting the double buffer node with + # empty Memlets to the last row of nodes that writes to the global buffer. + # This is needed to handle the case that some other data path performs the + # write. + # TODO(phimuell): Add a test that only performs this when it is really needed. + for inner_write_edge in inner_write_edges: + state.add_nedge( + new_double_inner_buff_node, + inner_write_edge.src, + dace.Memlet(), + ) + + +def _check_if_map_must_be_handled_classify_adjacent_access_node( + data_node: dace_nodes.AccessNode, + sdfg: dace.SDFG, + known_nodes: dict[str, dace_nodes.AccessNode], +) -> bool: + """Internal function used by `_check_if_map_must_be_handled()` to classify nodes. + + If the function returns `True` it means that the input/output, does not + violates an internal constraint, i.e. can be handled by + `_ensure_that_map_is_pointwise()`. If appropriate the function will add the + node to `known_nodes`. I.e. in case of a transient the function will return + `True` but will not add it to `known_nodes`. + """ + + # This case is indicating that the `ConsolidateEdges` has not fully worked. + # Currently the transformation implementation assumes that this is the + # case, so we can not handle this case. + # TODO(phimuell): Implement this case. + if data_node.data in known_nodes: + return False + data_desc: dace_data.Data = data_node.desc(sdfg) + + # The conflict can only occur for global data, because transients + # are only written once. + if data_desc.transient: + return False + + # Currently we do not handle view, as they need to be traced. + # TODO(phimuell): Implement + if gtx_transformations.util.is_view(data_desc, sdfg): + return False + + # TODO(phimuell): Check if there is a access node on the inner side, then we do not have to do it. + + # Now add the node to the list. + assert all(data_node is not known_node for known_node in known_nodes.values()) + known_nodes[data_node.data] = data_node + return True + + +def _get_inner_edges( + outer_edge: dace.sdfg.graph.MultiConnectorEdge, + scope_node: dace_nodes.MapExit | dace_nodes.MapEntry, + state: dace.SDFG, + outgoing_edge: bool, +) -> list[dace.sdfg.graph.MultiConnectorEdge]: + """Gets the edges on the inside of a map.""" + if outgoing_edge: + assert isinstance(scope_node, dace_nodes.MapExit) + conn_name = outer_edge.src_conn[4:] + return list(state.in_edges_by_connector(scope_node, connector="IN_" + conn_name)) + else: + assert isinstance(scope_node, dace_nodes.MapEntry) + conn_name = outer_edge.dst_conn[3:] + return list(state.out_edges_by_connector(scope_node, connector="OUT_" + conn_name)) + + +def _check_if_map_must_be_handled( + map_entry: dace_nodes.MapEntry, + state: dace.SDFGState, + sdfg: dace.SDFG, +) -> None | dict[str, tuple[dace_nodes.AccessNode, dace_nodes.AccessNode]]: + """Check if the map should be processed to uphold rule 3. + + Essentially the function will check if there is a potential read-write + conflict. The function assumes that `ConsolidateEdges` has already run. + + If there is a possible data race the function will return a `dict`, that + maps the name of the data to the access nodes that are used as input and + output to the Map. + + Otherwise, the function returns `None`. It is, however, important that + `None` does not means that there is no possible race condition. It could + also means that the function that implements the buffering, i.e. + `_ensure_that_map_is_pointwise()`, is unable to handle this case. + + Todo: + Improve the function + """ + map_exit: dace_nodes.MapExit = state.exit_node(map_entry) + + # Find all the data that is accessed. Views are resolved. + input_datas: dict[str, dace_nodes.AccessNode] = {} + output_datas: dict[str, dace_nodes.AccessNode] = {} + + # Determine which nodes are possible conflicting. + for in_edge in state.in_edges(map_entry): + if in_edge.data.is_empty(): + continue + if not isinstance(in_edge.src, dace_nodes.AccessNode): + # TODO(phiumuell): Figuring out what this case means + continue + if in_edge.dst_conn and not in_edge.dst_conn.startswith("IN_"): + # TODO(phimuell): It is very unlikely that a Dynamic Map Range causes + # this particular data race, so we ignore it for the time being. + continue + if not _check_if_map_must_be_handled_classify_adjacent_access_node( + data_node=in_edge.src, + sdfg=sdfg, + known_nodes=input_datas, + ): + continue + for out_edge in state.out_edges(map_exit): + if out_edge.data.is_empty(): + continue + if not isinstance(out_edge.dst, dace_nodes.AccessNode): + # TODO(phiumuell): Figuring out what this case means + continue + if not _check_if_map_must_be_handled_classify_adjacent_access_node( + data_node=out_edge.dst, + sdfg=sdfg, + known_nodes=output_datas, + ): + continue + + # Double buffering is only needed if there inout arguments. + inout_datas: dict[str, tuple[dace_nodes.AccessNode, dace_nodes.AccessNode]] = { + dname: (input_datas[dname], output_datas[dname]) + for dname in input_datas + if dname in output_datas + } + if len(inout_datas) == 0: + return None + + # TODO(phimuell): What about the case that some data descriptor needs double + # buffering, but some do not? + for inout_data_name in list(inout_datas.keys()): + input_node, output_node = inout_datas[inout_data_name] + input_edges = state.edges_between(input_node, map_entry) + output_edges = state.edges_between(map_exit, output_node) + assert ( + len(input_edges) == 1 + ), f"Expected a single connection between input node and map entry, but found {len(input_edges)}." + assert ( + len(output_edges) == 1 + ), f"Expected a single connection between map exit and write back node, but found {len(output_edges)}." + + # If there is only one edge on the inside of the map, that goes into an + # AccessNode, then we assume it is double buffered. + inner_read_edges = _get_inner_edges(input_edges[0], map_entry, state, False) + if ( + len(inner_read_edges) == 1 + and isinstance(inner_read_edges[0].dst, dace_nodes.AccessNode) + and not gtx_transformations.util.is_view(inner_read_edges[0].dst, sdfg) + ): + inout_datas.pop(inout_data_name) + continue + + inner_read_subsets = [ + inner_read_edge.data.get_src_subset(inner_read_edge, state) + for inner_read_edge in inner_read_edges + ] + assert all(inner_read_subset is not None for inner_read_subset in inner_read_subsets) + inner_write_subsets = [ + inner_write_edge.data.get_dst_subset(inner_write_edge, state) + for inner_write_edge in _get_inner_edges(output_edges[0], map_exit, state, True) + ] + # TODO(phimuell): Also implement a check that the volume equals the size of the subset. + assert all(inner_write_subset is not None for inner_write_subset in inner_write_subsets) + + # For being point wise the subsets must be compatible. The correct check would be: + # - The write sets are unique. + # - For every read subset there exists one matching write subset. It could + # be that there are many equivalent read subsets. + # - For every write subset there exists at least one matching read subset. + # The current implementation only checks if all are the same. + # TODO(phimuell): Implement the real check. + all_inner_subsets = inner_read_subsets + inner_write_subsets + if not all( + all_inner_subsets[0] == all_inner_subsets[i] for i in range(1, len(all_inner_subsets)) + ): + return None + + if len(inout_datas) == 0: + return None + + return inout_datas diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py index d7326e1131..27b6c68072 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/loop_blocking.py @@ -36,12 +36,16 @@ class LoopBlocking(dace_transformation.SingleStateTransformation): What makes this transformation different from simple blocking, is that the inner map will not just be inserted right after the outer Map. Instead the transformation will first identify all nodes that does not depend - on the blocking parameter `I` and relocate them between the outer and inner map. - Thus these operations will only be performed once, per inner loop. + on the blocking parameter `I`, called independent nodes and relocate them + between the outer and inner map. Note that an independent node must be connected + to the MapEntry or another independent node. + Thus these operations will only be performed once, per outer loop iteration. Args: blocking_size: The size of the block, denoted as `B` above. blocking_parameter: On which parameter should we block. + require_independent_nodes: If `True` only apply loop blocking if the Map + actually contains independent nodes. Defaults to `False`. Todo: - Modify the inner map such that it always starts at zero. @@ -59,25 +63,23 @@ class LoopBlocking(dace_transformation.SingleStateTransformation): desc="Name of the iteration variable on which to block (must be an exact match);" " 'I' in the above description.", ) - independent_nodes = dace_properties.Property( - dtype=set, - allow_none=True, - default=None, - desc="Set of nodes that are independent of the blocking parameter.", - ) - dependent_nodes = dace_properties.Property( - dtype=set, - allow_none=True, - default=None, - desc="Set of nodes that are dependent on the blocking parameter.", + require_independent_nodes = dace_properties.Property( + dtype=bool, + default=False, + desc="If 'True' then blocking is only applied if there are independent nodes.", ) - outer_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + # Set of nodes that are independent of the blocking parameter. + _independent_nodes: Optional[set[dace_nodes.AccessNode]] + _dependent_nodes: Optional[set[dace_nodes.AccessNode]] + + outer_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) def __init__( self, blocking_size: Optional[int] = None, blocking_parameter: Optional[Union[gtx_common.Dimension, str]] = None, + require_independent_nodes: Optional[bool] = None, ) -> None: super().__init__() if isinstance(blocking_parameter, gtx_common.Dimension): @@ -86,6 +88,10 @@ def __init__( self.blocking_parameter = blocking_parameter if blocking_size is not None: self.blocking_size = blocking_size + if require_independent_nodes is not None: + self.require_independent_nodes = require_independent_nodes + self._independent_nodes = None + self._dependent_nodes = None @classmethod def expressions(cls) -> Any: @@ -125,6 +131,8 @@ def can_be_applied( return False if not self.partition_map_output(graph, sdfg): return False + self._independent_nodes = None + self._dependent_nodes = None return True @@ -137,7 +145,6 @@ def apply( Performs the operation described in the doc string. """ - # Now compute the partitions of the nodes. self.partition_map_output(graph, sdfg) @@ -153,10 +160,8 @@ def apply( state=graph, sdfg=sdfg, ) - - # Clear the old partitions - self.independent_nodes = None - self.dependent_nodes = None + self._independent_nodes = None + self._dependent_nodes = None def _prepare_inner_outer_maps( self, @@ -258,6 +263,9 @@ def partition_map_output( member variables are updated. If the partition does not exists the function will return `False` and the respective member variables will be `None`. + The function will honor `self.require_independent_nodes`. Thus if no independent + nodes were found the function behaves as if the partition does not exist. + Args: state: The state on which we operate. sdfg: The SDFG in which we operate on. @@ -269,8 +277,8 @@ def partition_map_output( """ # Clear the previous partition. - self.independent_nodes = set() - self.dependent_nodes = None + self._independent_nodes = set() + self._dependent_nodes = None while True: # Find all the nodes that we have to classify in this iteration. @@ -279,9 +287,9 @@ def partition_map_output( nodes_to_classify: set[dace_nodes.Node] = { edge.dst for edge in state.out_edges(self.outer_entry) } - for independent_node in self.independent_nodes: + for independent_node in self._independent_nodes: nodes_to_classify.update({edge.dst for edge in state.out_edges(independent_node)}) - nodes_to_classify.difference_update(self.independent_nodes) + nodes_to_classify.difference_update(self._independent_nodes) # Now classify each node found_new_independent_node = False @@ -294,7 +302,7 @@ def partition_map_output( # Check if the partition exists. if class_res is None: - self.independent_nodes = None + self._independent_nodes = None return False if class_res is True: found_new_independent_node = True @@ -303,12 +311,16 @@ def partition_map_output( if not found_new_independent_node: break + if self.require_independent_nodes and len(self._independent_nodes) == 0: + self._independent_nodes = None + return False + # After the independent set is computed compute the set of dependent nodes # as the set of all nodes adjacent to `outer_entry` that are not dependent. - self.dependent_nodes = { + self._dependent_nodes = { edge.dst for edge in state.out_edges(self.outer_entry) - if edge.dst not in self.independent_nodes + if edge.dst not in self._independent_nodes } return True @@ -333,7 +345,7 @@ def _classify_node( Returns: The function returns `True` if `node_to_classify` is considered independent. - In this case the function will add the node to `self.independent_nodes`. + In this case the function will add the node to `self._independent_nodes`. If the function returns `False` the node was classified as a dependent node. The function will return `None` if the node can not be classified, in this case the partition does not exist. @@ -343,23 +355,50 @@ def _classify_node( state: The state containing the map. sdfg: The SDFG that is processed. """ + assert self._independent_nodes is not None # silence MyPy outer_entry: dace_nodes.MapEntry = self.outer_entry # for caching. + outer_exit: dace_nodes.MapExit = state.exit_node(outer_entry) + + # The node needs to have an input and output. + if state.in_degree(node_to_classify) == 0 or state.out_degree(node_to_classify) == 0: + return None # We are only able to handle certain kind of nodes, so screening them. if isinstance(node_to_classify, dace_nodes.Tasklet): if node_to_classify.side_effects: - # TODO(phimuell): Think of handling it. return None + + # A Tasklet must write to an AccessNode, because otherwise there would + # be nothing that could be used to cache anything. Furthermore, this + # AccessNode must be outside of the inner loop, i.e. be independent. + # TODO: Make this check stronger to ensure that there is always an + # AccessNode that is independent. + if not all( + isinstance(out_edge.dst, dace_nodes.AccessNode) + for out_edge in state.out_edges(node_to_classify) + if not out_edge.data.is_empty() + ): + return False + elif isinstance(node_to_classify, dace_nodes.AccessNode): # AccessNodes need to have some special properties. node_desc: dace.data.Data = node_to_classify.desc(sdfg) - if isinstance(node_desc, dace.data.View): # Views are forbidden. return None - if node_desc.lifetime != dace.dtypes.AllocationLifetime.Scope: - # The access node has to life fully within the scope. + + # The access node inside either has scope lifetime or is a scalar. + if isinstance(node_desc, dace.data.Scalar): + pass + elif node_desc.lifetime != dace.dtypes.AllocationLifetime.Scope: return None + + elif isinstance(node_to_classify, dace_nodes.MapEntry): + # We classify `MapEntries` as dependent nodes, we could now start + # looking if the whole map is independent, but it is currently an + # overkill. + return False + else: # Any other node type we can not handle, so the partition can not exist. # TODO(phimuell): Try to handle certain kind of library nodes. @@ -376,29 +415,12 @@ def _classify_node( # for these classification to make sense the partition has to exist in the # first place. - # Either all incoming edges of a node are empty or none of them. If it has - # empty edges, they are only allowed to come from the map entry. - found_empty_edges, found_nonempty_edges = False, False - for in_edge in in_edges: - if in_edge.data.is_empty(): - found_empty_edges = True - if in_edge.src is not outer_entry: - # TODO(phimuell): Lift this restriction. - return None - else: - found_nonempty_edges = True - - # Test if we found a mixture of empty and nonempty edges. - if found_empty_edges and found_nonempty_edges: - return None - assert ( - found_empty_edges or found_nonempty_edges - ), f"Node '{node_to_classify}' inside '{outer_entry}' without an input connection." - - # Requiring that all output Memlets are non empty implies, because we are - # inside a scope, that there exists an output. - if any(out_edge.data.is_empty() for out_edge in state.out_edges(node_to_classify)): - return None + # There are some very small requirements that we impose on the output edges. + for out_edge in state.out_edges(node_to_classify): + # We consider nodes that are directly connected to the outer map exit as + # dependent. This is an implementation detail to avoid some hard cases. + if out_edge.dst is outer_exit: + return False # Now we have ensured that the partition exists, thus we will now evaluate # if the node is independent or dependent. @@ -413,7 +435,7 @@ def _classify_node( # Now we have to look at incoming edges individually. # We will inspect the subset of the Memlet to see if they depend on the # block variable. If this loop ends normally, then we classify the node - # as independent and the node is added to `independent_nodes`. + # as independent and the node is added to `_independent_nodes`. for in_edge in in_edges: memlet: dace.Memlet = in_edge.data src_subset: dace_subsets.Subset | None = memlet.src_subset @@ -436,11 +458,11 @@ def _classify_node( # The edge must either originate from `outer_entry` or from an independent # node if not it is dependent. - if not (in_edge.src is outer_entry or in_edge.src in self.independent_nodes): + if not (in_edge.src is outer_entry or in_edge.src in self._independent_nodes): return False # Loop ended normally, thus we classify the node as independent. - self.independent_nodes.add(node_to_classify) + self._independent_nodes.add(node_to_classify) return True def _rewire_map_scope( @@ -467,116 +489,138 @@ def _rewire_map_scope( state: The state of the map. sdfg: The SDFG we operate on. """ + assert self._independent_nodes is not None and self._dependent_nodes is not None # Contains the nodes that are already have been handled. relocated_nodes: set[dace_nodes.Node] = set() # We now handle all independent nodes, this means that all of their - # _output_ edges have to go through the new inner map and the Memlets need - # modifications, because of the block parameter. - for independent_node in self.independent_nodes: - for out_edge in state.out_edges(independent_node): + # _output_ edges have to go through the new inner map and the Memlets + # need modifications, because of the block parameter. + for independent_node in self._independent_nodes: + for out_edge in list(state.out_edges(independent_node)): edge_dst: dace_nodes.Node = out_edge.dst relocated_nodes.add(edge_dst) # If destination of this edge is also independent we do not need # to handle it, because that node will also be before the new # inner serial map. - if edge_dst in self.independent_nodes: + if edge_dst in self._independent_nodes: continue # Now split `out_edge` such that it passes through the new inner entry. # We do not need to modify the subsets, i.e. replacing the variable # on which we block, because the node is independent and the outgoing # new inner map entry iterate over the blocked variable. - new_map_conn = inner_entry.next_connector() - dace_helpers.redirect_edge( - state=state, - edge=out_edge, - new_dst=inner_entry, - new_dst_conn="IN_" + new_map_conn, + if out_edge.data.is_empty(): + # `out_edge` is an empty Memlet that ensures its source, which is + # independent, is sequenced before its destination, which is + # dependent. We now have to split it into two. + # TODO(phimuell): Can we remove this edge? Is the map enough to + # ensure proper sequencing? + new_in_conn = None + new_out_conn = None + new_memlet_outside = dace.Memlet() + + elif not isinstance(independent_node, dace_nodes.AccessNode): + # For syntactical reasons there must be an access node on the + # outside of the (inner) scope, that acts as cache. The + # classification and this preconditions on SDFG should ensure + # that, but there are a few super hard edge cases. + # TODO(phimuell): Add an intermediate here in this case + raise NotImplementedError() + + else: + # NOTE: This creates more connections that are ultimately + # necessary. However, figuring out which one to use and if + # it would be valid, is very complicated, so we don't do it. + new_map_conn = inner_entry.next_connector(try_name=out_edge.data.data) + new_in_conn = "IN_" + new_map_conn + new_out_conn = "OUT_" + new_map_conn + new_memlet_outside = dace.Memlet.from_array( + out_edge.data.data, sdfg.arrays[out_edge.data.data] + ) + inner_entry.add_in_connector(new_in_conn) + inner_entry.add_out_connector(new_out_conn) + + state.add_edge( + out_edge.src, + out_edge.src_conn, + inner_entry, + new_in_conn, + new_memlet_outside, ) - # TODO(phimuell): Check if there might be a subset error. state.add_edge( inner_entry, - "OUT_" + new_map_conn, + new_out_conn, out_edge.dst, out_edge.dst_conn, copy.deepcopy(out_edge.data), ) - inner_entry.add_in_connector("IN_" + new_map_conn) - inner_entry.add_out_connector("OUT_" + new_map_conn) + state.remove_edge(out_edge) # Now we handle the dependent nodes, they differ from the independent nodes - # in that they _after_ the new inner map entry. Thus, we will modify incoming edges. - for dependent_node in self.dependent_nodes: + # in that they _after_ the new inner map entry. Thus, we have to modify + # their incoming edges. + for dependent_node in self._dependent_nodes: for in_edge in state.in_edges(dependent_node): edge_src: dace_nodes.Node = in_edge.src - # Since the independent nodes were already processed, and they process - # their output we have to check for this. We do this by checking if - # the source of the edge is the new inner map entry. + # The incoming edge of a dependent node (before any processing) either + # starts at: + # - The outer map. + # - An other dependent node. + # - An independent node. + # The last case was already handled by the loop above. if edge_src is inner_entry: + # Edge originated originally at an independent node, but was + # already handled by the loop above. assert dependent_node in relocated_nodes - continue - - # A dependent node has at least one connection to the outer map entry. - # And these are the only connections that we must handle, since other - # connections come from independent nodes, and were already handled - # or are inner nodes. - if edge_src is not outer_entry: - continue - # If we encounter an empty Memlet we just just attach it to the - # new inner map entry. Note the partition function ensures that - # either all edges are empty or non. - if in_edge.data.is_empty(): - assert ( - edge_src is outer_entry - ), f"Found an empty edge that does not go to the outer map entry, but to '{edge_src}'." + elif edge_src is not outer_entry: + # Edge originated at an other dependent node. There is nothing + # that we have to do. + # NOTE: We can not test if `edge_src` is in `self._dependent_nodes` + # because it only contains the dependent nodes that are directly + # connected to the map entry. + assert edge_src not in self._independent_nodes + + elif in_edge.data.is_empty(): + # The dependent node has an empty Memlet to the other map. + # Since the inner map is sequenced after the outer map, + # we will simply reconnect the edge to the inner map. + # TODO(phimuell): Are there situations where this makes problems. dace_helpers.redirect_edge(state=state, edge=in_edge, new_src=inner_entry) - continue - - # Because of the definition of a dependent node and the processing - # order, their incoming edges either point to the outer map or - # are already handled. - assert ( - edge_src is outer_entry - ), f"Expected to find source '{outer_entry}' but found '{edge_src}'." - edge_conn: str = in_edge.src_conn[4:] - - # Must be before the handling of the modification below - # Note that this will remove the original edge from the SDFG. - dace_helpers.redirect_edge( - state=state, - edge=in_edge, - new_src=inner_entry, - new_src_conn="OUT_" + edge_conn, - ) - # In a valid SDFG only one edge can go into an input connector of a Map. - if "IN_" + edge_conn in inner_entry.in_connectors: - # We have found this edge multiple times already. - # To ensure that there is no error, we will create a new - # Memlet that reads the whole array. - piping_edge = next(state.in_edges_by_connector(inner_entry, "IN_" + edge_conn)) - data_name = piping_edge.data.data - piping_edge.data = dace.Memlet.from_array( - data_name, sdfg.arrays[data_name], piping_edge.data.wcr + elif edge_src is outer_entry: + # This dependent node originated at the outer map. Thus we have to + # split the edge, such that it now passes through the inner map. + new_map_conn = inner_entry.next_connector(try_name=in_edge.src_conn[4:]) + new_in_conn = "IN_" + new_map_conn + new_out_conn = "OUT_" + new_map_conn + new_memlet_inner = dace.Memlet.from_array( + in_edge.data.data, sdfg.arrays[in_edge.data.data] ) - - else: - # This is the first time we found this connection. - # so we just create the edge. state.add_edge( - outer_entry, - "OUT_" + edge_conn, + in_edge.src, + in_edge.src_conn, inner_entry, - "IN_" + edge_conn, + new_in_conn, + new_memlet_inner, + ) + state.add_edge( + inner_entry, + new_out_conn, + in_edge.dst, + in_edge.dst_conn, copy.deepcopy(in_edge.data), ) - inner_entry.add_in_connector("IN_" + edge_conn) - inner_entry.add_out_connector("OUT_" + edge_conn) + inner_entry.add_in_connector(new_in_conn) + inner_entry.add_out_connector(new_out_conn) + state.remove_edge(in_edge) + + else: + raise NotImplementedError("Unknown node configuration.") # In certain cases it might happen that we need to create an empty # Memlet between the outer map entry and the inner one. @@ -593,7 +637,7 @@ def _rewire_map_scope( # This is simple reconnecting, there would be possibilities for improvements # but we do not use them for now. for in_edge in state.in_edges(outer_exit): - edge_conn = in_edge.dst_conn[3:] + edge_conn = inner_exit.next_connector(in_edge.dst_conn[3:]) dace_helpers.redirect_edge( state=state, edge=in_edge, @@ -610,5 +654,9 @@ def _rewire_map_scope( inner_exit.add_in_connector("IN_" + edge_conn) inner_exit.add_out_connector("OUT_" + edge_conn) + # There is an invalid cache state in the SDFG, that makes the memlet + # propagation fail, to clear the cache we call the hash function. + # See: https://github.com/spcl/dace/issues/1703 + _ = sdfg.hash_sdfg() # TODO(phimuell): Use a less expensive method. dace.sdfg.propagation.propagate_memlets_state(sdfg, state) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py index ec33e7ea63..eceb07ed82 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_helper.py @@ -6,89 +6,106 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -"""Implements helper functions for the map fusion transformations. +# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +"""Implements Helper functionaliyies for map fusion -Note: - After DaCe [PR#1629](https://github.com/spcl/dace/pull/1629), that implements - a better map fusion transformation is merged, this file will be deleted. +THIS FILE WAS COPIED FROM DACE TO FACILITATE DEVELOPMENT UNTIL THE PR#1625 IN +DACE IS MERGED AND THE VERSION WAS UPGRADED. """ -import functools -import itertools -from typing import Any, Optional, Sequence, Union + +# ruff: noqa + +import copy +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union, Callable, TypeAlias import dace -from dace import ( - data as dace_data, - properties as dace_properties, - subsets as dace_subsets, - transformation as dace_transformation, -) -from dace.sdfg import graph as dace_graph, nodes as dace_nodes, validation as dace_validation -from dace.transformation import helpers as dace_helpers - -from gt4py.next.program_processors.runners.dace_fieldview.transformations import util - - -@dace_properties.make_properties -class MapFusionHelper(dace_transformation.SingleStateTransformation): - """Contains common part of the fusion for parallel and serial Map fusion. - - The transformation assumes that the SDFG obeys the principals outlined in - [ADR0018](https://github.com/GridTools/gt4py/tree/main/docs/development/ADRs/0018-Canonical_SDFG_in_GT4Py_Transformations.md). - The main advantage of this structure is, that it is rather easy to determine - if a transient is used anywhere else. This check, performed by - `is_interstate_transient()`. It is further speeded up by cashing some computation, - thus such an object should not be used after interstate optimizations were applied - to the SDFG. +from dace import data, properties, subsets, symbolic, transformation +from dace.sdfg import SDFG, SDFGState, nodes, validation +from dace.transformation import helpers + +FusionCallback: TypeAlias = Callable[ + ["MapFusionHelper", nodes.MapEntry, nodes.MapEntry, dace.SDFGState, dace.SDFG, bool], bool +] +"""Callback for the map fusion transformation to check if a fusion should be performed. +""" + + +@properties.make_properties +class MapFusionHelper(transformation.SingleStateTransformation): + """Common parts of the parallel and serial map fusion transformation. Args: only_inner_maps: Only match Maps that are internal, i.e. inside another Map. only_toplevel_maps: Only consider Maps that are at the top. + strict_dataflow: If `True`, the transformation ensures a more + stricter version of the data flow. + apply_fusion_callback: A user supplied function, same signature as `can_be_fused()`, + to check if a fusion should be performed. + + Note: + If `strict_dataflow` mode is enabled then the transformation will not remove + _direct_ data flow dependency from the graph. Furthermore, the transformation + will not remove size 1 dimensions of intermediate it creates. + This is a compatibility mode, that will limit the applicability of the + transformation, but might help transformations that do not fully analyse + the graph. """ - only_toplevel_maps = dace_properties.Property( + only_toplevel_maps = properties.Property( dtype=bool, default=False, - allow_none=False, desc="Only perform fusing if the Maps are in the top level.", ) - only_inner_maps = dace_properties.Property( + only_inner_maps = properties.Property( dtype=bool, default=False, - allow_none=False, desc="Only perform fusing if the Maps are inner Maps, i.e. does not have top level scope.", ) - shared_transients = dace_properties.DictProperty( - key_type=dace.SDFG, - value_type=set[str], - default=None, - allow_none=True, - desc="Maps SDFGs to the set of array transients that can not be removed. " - "The variable acts as a cache, and is managed by 'is_interstate_transient()'.", + strict_dataflow = properties.Property( + dtype=bool, + default=False, + desc="If `True` then the transformation will ensure a more stricter data flow.", ) + # Callable that can be specified by the user, if it is specified, it should be + # a callable with the same signature as `can_be_fused()`. If the function returns + # `False` then the fusion will be rejected. + _apply_fusion_callback: Optional[FusionCallback] + + # Maps SDFGs to the set of data that can not be removed, + # because they transmit data _between states_, such data will be made 'shared'. + # This variable acts as a cache, and is managed by 'is_shared_data()'. + _shared_data: Dict[SDFG, Set[str]] + def __init__( self, only_inner_maps: Optional[bool] = None, only_toplevel_maps: Optional[bool] = None, + strict_dataflow: Optional[bool] = None, + apply_fusion_callback: Optional[FusionCallback] = None, **kwargs: Any, ) -> None: super().__init__(**kwargs) + self._shared_data = {} + self._apply_fusion_callback = None if only_toplevel_maps is not None: self.only_toplevel_maps = bool(only_toplevel_maps) if only_inner_maps is not None: self.only_inner_maps = bool(only_inner_maps) - self.shared_transients = {} + if strict_dataflow is not None: + self.strict_dataflow = bool(strict_dataflow) + if apply_fusion_callback is not None: + self._apply_fusion_callback = apply_fusion_callback @classmethod def expressions(cls) -> bool: - raise RuntimeError("The `_MapFusionHelper` is not a transformation on its own.") + raise RuntimeError("The `MapFusionHelper` is not a transformation on its own.") def can_be_fused( self, - map_entry_1: dace_nodes.MapEntry, - map_entry_2: dace_nodes.MapEntry, + map_entry_1: nodes.MapEntry, + map_entry_2: nodes.MapEntry, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG, permissive: bool = False, @@ -97,13 +114,11 @@ def can_be_fused( This function only checks constrains that are common between serial and parallel map fusion process, which includes: + - The registered callback, if specified. - The scope of the maps. - The scheduling of the maps. - The map parameters. - However, for performance reasons, the function does not check if the node - decomposition exists. - Args: map_entry_1: The entry of the first (in serial case the top) map. map_exit_2: The entry of the second (in serial case the bottom) map. @@ -111,6 +126,13 @@ def can_be_fused( sdfg: The SDFG itself. permissive: Currently unused. """ + # Consult the callback if defined. + if self._apply_fusion_callback is not None: + if not self._apply_fusion_callback( + self, map_entry_1, map_entry_2, graph, sdfg, permissive + ): + return False + if self.only_inner_maps and self.only_toplevel_maps: raise ValueError("You specified both `only_inner_maps` and `only_toplevel_maps`.") @@ -128,26 +150,22 @@ def can_be_fused( elif self.only_toplevel_maps: if scope[map_entry_1] is not None: return False - # TODO(phimuell): Figuring out why this is here. - elif util.is_nested_sdfg(sdfg): - return False - # We will now check if there exists a "remapping" that we can use. - # NOTE: The serial map promoter depends on the fact that this is the - # last check. - if not self.map_parameter_compatible( - map_1=map_entry_1.map, map_2=map_entry_2.map, state=graph, sdfg=sdfg + # We will now check if there exists a remapping of the map parameter + if ( + self.find_parameter_remapping(first_map=map_entry_1.map, second_map=map_entry_2.map) + is None ): return False return True - @staticmethod def relocate_nodes( - from_node: Union[dace_nodes.MapExit, dace_nodes.MapEntry], - to_node: Union[dace_nodes.MapExit, dace_nodes.MapEntry], - state: dace.SDFGState, - sdfg: dace.SDFG, + self, + from_node: Union[nodes.MapExit, nodes.MapEntry], + to_node: Union[nodes.MapExit, nodes.MapEntry], + state: SDFGState, + sdfg: SDFG, ) -> None: """Move the connectors and edges from `from_node` to `to_nodes` node. @@ -156,6 +174,7 @@ def relocate_nodes( once for the entry and then for the exit. While it does not remove the node themselves if guarantees that the `from_node` has degree zero. + The function assumes that the parameter renaming was already done. Args: from_node: Node from which the edges should be removed. @@ -165,22 +184,23 @@ def relocate_nodes( """ # Now we relocate empty Memlets, from the `from_node` to the `to_node` - for empty_edge in filter(lambda e: e.data.is_empty(), state.out_edges(from_node)): - dace_helpers.redirect_edge(state, empty_edge, new_src=to_node) - for empty_edge in filter(lambda e: e.data.is_empty(), state.in_edges(from_node)): - dace_helpers.redirect_edge(state, empty_edge, new_dst=to_node) + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.out_edges(from_node))): + helpers.redirect_edge(state, empty_edge, new_src=to_node) + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.in_edges(from_node))): + helpers.redirect_edge(state, empty_edge, new_dst=to_node) # We now ensure that there is only one empty Memlet from the `to_node` to any other node. # Although it is allowed, we try to prevent it. - empty_targets: set[dace_nodes.Node] = set() - for empty_edge in filter(lambda e: e.data.is_empty(), state.all_edges(to_node)): + empty_targets: Set[nodes.Node] = set() + for empty_edge in list(filter(lambda e: e.data.is_empty(), state.all_edges(to_node))): if empty_edge.dst in empty_targets: state.remove_edge(empty_edge) empty_targets.add(empty_edge.dst) # We now determine which edges we have to migrate, for this we are looking at # the incoming edges, because this allows us also to detect dynamic map ranges. - for edge_to_move in state.in_edges(from_node): + # TODO(phimuell): If there is already a connection to the node, reuse this. + for edge_to_move in list(state.in_edges(from_node)): assert isinstance(edge_to_move.dst_conn, str) if not edge_to_move.dst_conn.startswith("IN_"): @@ -200,36 +220,32 @@ def relocate_nodes( raise RuntimeError( # Might fail because of out connectors. f"Failed to add the dynamic map range symbol '{dmr_symbol}' to '{to_node}'." ) - dace_helpers.redirect_edge(state=state, edge=edge_to_move, new_dst=to_node) + helpers.redirect_edge(state=state, edge=edge_to_move, new_dst=to_node) from_node.remove_in_connector(dmr_symbol) - # There is no other edge that we have to consider, so we just end here - continue - - # We have a Passthrough connection, i.e. there exists a matching `OUT_`. - old_conn = edge_to_move.dst_conn[3:] # The connection name without prefix - new_conn = to_node.next_connector(old_conn) - - to_node.add_in_connector("IN_" + new_conn) - for e in state.in_edges_by_connector(from_node, "IN_" + old_conn): - dace_helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) - to_node.add_out_connector("OUT_" + new_conn) - for e in state.out_edges_by_connector(from_node, "OUT_" + old_conn): - dace_helpers.redirect_edge( - state, e, new_src=to_node, new_src_conn="OUT_" + new_conn - ) - from_node.remove_in_connector("IN_" + old_conn) - from_node.remove_out_connector("OUT_" + old_conn) + else: + # We have a Passthrough connection, i.e. there exists a matching `OUT_`. + old_conn = edge_to_move.dst_conn[3:] # The connection name without prefix + new_conn = to_node.next_connector(old_conn) + + to_node.add_in_connector("IN_" + new_conn) + for e in list(state.in_edges_by_connector(from_node, "IN_" + old_conn)): + helpers.redirect_edge(state, e, new_dst=to_node, new_dst_conn="IN_" + new_conn) + to_node.add_out_connector("OUT_" + new_conn) + for e in list(state.out_edges_by_connector(from_node, "OUT_" + old_conn)): + helpers.redirect_edge(state, e, new_src=to_node, new_src_conn="OUT_" + new_conn) + from_node.remove_in_connector("IN_" + old_conn) + from_node.remove_out_connector("OUT_" + old_conn) # Check if we succeeded. if state.out_degree(from_node) != 0: - raise dace_validation.InvalidSDFGError( + raise validation.InvalidSDFGError( f"Failed to relocate the outgoing edges from `{from_node}`, there are still `{state.out_edges(from_node)}`", sdfg, sdfg.node_id(state), ) if state.in_degree(from_node) != 0: - raise dace_validation.InvalidSDFGError( + raise validation.InvalidSDFGError( f"Failed to relocate the incoming edges from `{from_node}`, there are still `{state.in_edges(from_node)}`", sdfg, sdfg.node_id(state), @@ -237,330 +253,442 @@ def relocate_nodes( assert len(from_node.in_connectors) == 0 assert len(from_node.out_connectors) == 0 - @staticmethod - def map_parameter_compatible( - map_1: dace_nodes.Map, - map_2: dace_nodes.Map, - state: Union[dace.SDFGState, dace.SDFG], - sdfg: dace.SDFG, - ) -> bool: - """Checks if the parameters of `map_1` are compatible with `map_2`. + def find_parameter_remapping( + self, first_map: nodes.Map, second_map: nodes.Map + ) -> Union[Dict[str, str], None]: + """Computes the parameter remapping for the parameters of the _second_ map. + + The returned `dict` maps the parameters of the second map (keys) to parameter + names of the first map (values). Because of how the replace function works + the `dict` describes how to replace the parameters of the second map + with parameters of the first map. + Parameters that already have the correct name and compatible range, are not + included in the return value, thus the keys and values are always different. + If no renaming at all is _needed_, i.e. all parameter have the same name and + range, then the function returns an empty `dict`. + If no remapping exists, then the function will return `None`. - The check follows the following rules: - - The names of the map variables must be the same, i.e. no renaming - is performed. - - The ranges must be the same. + Args: + first_map: The first map (these parameters will be replaced). + second_map: The second map, these parameters acts as source. """ - range_1: dace_subsets.Range = map_1.range - params_1: Sequence[str] = map_1.params - range_2: dace_subsets.Range = map_2.range - params_2: Sequence[str] = map_2.params - - # The maps are only fuseable if we have an exact match in the parameter names - # this is because we do not do any renaming. This is in accordance with the - # rules. - if set(params_1) != set(params_2): - return False - # Maps the name of a parameter to the dimension index - param_dim_map_1: dict[str, int] = {pname: i for i, pname in enumerate(params_1)} - param_dim_map_2: dict[str, int] = {pname: i for i, pname in enumerate(params_2)} + # The parameter names + first_params: List[str] = first_map.params + second_params: List[str] = second_map.params + + if len(first_params) != len(second_params): + return None + + # The ranges, however, we apply some post processing to them. + simp = lambda e: symbolic.simplify_ext(symbolic.simplify(e)) # noqa: E731 + first_rngs: Dict[str, Tuple[Any, Any, Any]] = { + param: tuple(simp(r) for r in rng) for param, rng in zip(first_params, first_map.range) + } + second_rngs: Dict[str, Tuple[Any, Any, Any]] = { + param: tuple(simp(r) for r in rng) + for param, rng in zip(second_params, second_map.range) + } + + # Parameters of the second map that have not yet been matched to a parameter + # of the first map and vice versa. + unmapped_second_params: Set[str] = set(second_params) + unused_first_params: Set[str] = set(first_params) + + # This is the result (`second_param -> first_param`), note that if no renaming + # is needed then the parameter is not present in the mapping. + final_mapping: Dict[str, str] = {} + + # First we identify the parameters that already have the correct name. + for param in set(first_params).intersection(second_params): + first_rng = first_rngs[param] + second_rng = second_rngs[param] + + if first_rng == second_rng: + # They have the same name and the same range, this is already a match. + # Because the names are already the same, we do not have to enter them + # in the `final_mapping` + unmapped_second_params.discard(param) + unused_first_params.discard(param) + + # Check if no remapping is needed. + if len(unmapped_second_params) == 0: + return {} + + # Now we go through all the parameters that we have not mapped yet. + # All of them will result in a remapping. + for unmapped_second_param in unmapped_second_params: + second_rng = second_rngs[unmapped_second_param] + assert unmapped_second_param not in final_mapping + + # Now look in all not yet used parameters of the first map which to use. + for candidate_param in unused_first_params: + candidate_rng = first_rngs[candidate_param] + if candidate_rng == second_rng: + final_mapping[unmapped_second_param] = candidate_param + unused_first_params.discard(candidate_param) + break + else: + # We did not find a candidate, so the remapping does not exist + return None - # To fuse the two maps the ranges must have the same ranges - for pname in params_1: - idx_1 = param_dim_map_1[pname] - idx_2 = param_dim_map_2[pname] - # TODO(phimuell): do we need to call simplify? - if range_1[idx_1] != range_2[idx_2]: - return False + assert len(unused_first_params) == 0 + assert len(final_mapping) == len(unmapped_second_params) + return final_mapping - return True + def rename_map_parameters( + self, + first_map: nodes.Map, + second_map: nodes.Map, + second_map_entry: nodes.MapEntry, + state: SDFGState, + ) -> None: + """Replaces the map parameters of the second map with names from the first. + + The replacement is done in a safe way, thus `{'i': 'j', 'j': 'i'}` is + handled correct. The function assumes that a proper replacement exists. + The replacement is computed by calling `self.find_parameter_remapping()`. + + Args: + first_map: The first map (these are the final parameter). + second_map: The second map, this map will be replaced. + second_map_entry: The entry node of the second map. + state: The SDFGState on which we operate. + """ + # Compute the replacement dict. + repl_dict: Dict[str, str] = self.find_parameter_remapping( # type: ignore[assignment] + first_map=first_map, second_map=second_map + ) + + if repl_dict is None: + raise RuntimeError("The replacement does not exist") + if len(repl_dict) == 0: + return + + second_map_scope = state.scope_subgraph(entry_node=second_map_entry) + # Why is this thing is symbolic and not in replace? + symbolic.safe_replace( + mapping=repl_dict, + replace_callback=second_map_scope.replace_dict, + ) - def is_interstate_transient( + # For some odd reason the replace function does not modify the range and + # parameter of the map, so we will do it the hard way. + second_map.params = copy.deepcopy(first_map.params) + second_map.range = copy.deepcopy(first_map.range) + + def is_shared_data( self, - transient: Union[str, dace_nodes.AccessNode], + data: nodes.AccessNode, sdfg: dace.SDFG, - state: dace.SDFGState, ) -> bool: - """Tests if `transient` is an interstate transient, an can not be removed. - - Essentially this function checks if a transient might be needed in a - different state in the SDFG, because it transmit information from - one state to the other. - If only the name of the data container is passed the function will - first look for an corresponding access node. + """Tests if `data` is interstate data, an can not be removed. - The set of these "interstate transients" is computed once per SDFG. - The result is then cached internally for later reuse. + Interstate data is used to transmit data between multiple state or by + extension within the state. Thus it must be classified as a shared output. + This function will go through the SDFG to and collect the names of all data + container that should be classified as shared. Note that this is an over + approximation as it does not take the location into account, i.e. "is no longer + used". Args: transient: The transient that should be checked. sdfg: The SDFG containing the array. - state: If given the state the node is located in. + + Note: + The function computes the this set once for every SDFG and then caches it. + There is no mechanism to detect if the cache must be evicted. However, + as long as no additional data is added, there is no problem. """ + if sdfg not in self._shared_data: + self._compute_shared_data(sdfg) + return data.data in self._shared_data[sdfg] - # The following builds upon the HACK MD document and not on ADR0018. - # Therefore the numbers are slightly different, but both documents - # essentially describes the same SDFG. - # According to [rule 6](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG) - # the set of such transients is partially given by all source access dace_nodes. - # Because of rule 3 we also include all scalars in this set, as an over - # approximation. Furthermore, because simplify might violate rule 3, - # we also include the sink dace_nodes. - - # See if we have already computed the set - if sdfg in self.shared_transients: - shared_sdfg_transients: set[str] = self.shared_transients[sdfg] - else: - # SDFG is not known so we have to compute the set. - shared_sdfg_transients = set() - for state_to_scan in sdfg.all_states(): - # TODO(phimuell): Use `all_nodes_recursive()` once it is available. - shared_sdfg_transients.update( - [ - node.data - for node in itertools.chain( - state_to_scan.source_nodes(), state_to_scan.sink_nodes() - ) - if isinstance(node, dace_nodes.AccessNode) - and sdfg.arrays[node.data].transient - ] + def _compute_shared_data( + self, + sdfg: dace.SDFG, + ) -> None: + """Updates the internal set of shared data/interstate data of `self` for `sdfg`. + + See the documentation for `self.is_shared_data()` for a description. + + Args: + sdfg: The SDFG for which the set of shared data should be computed. + """ + # Shared data of this SDFG. + shared_data: Set[str] = set() + + # All global data can not be removed, so it must always be shared. + for data_name, data_desc in sdfg.arrays.items(): + if not data_desc.transient: + shared_data.add(data_name) + elif isinstance(data_desc, dace.data.Scalar): + shared_data.add(data_name) + + # We go through all states and classify the nodes/data: + # - Data is referred to in different states. + # - The access node is a view (both have to survive). + # - Transient sink or source node. + # - The access node has output degree larger than 1 (input degrees larger + # than one, will always be partitioned as shared anyway). + prevously_seen_data: Set[str] = set() + interstate_read_symbols: Set[str] = set() + for state in sdfg.nodes(): + for access_node in state.data_nodes(): + if access_node.data in shared_data: + # The data was already classified to be shared data + pass + + elif access_node.data in prevously_seen_data: + # We have seen this data before, either in this state or in + # a previous one, but we did not classifies it as shared back then + shared_data.add(access_node.data) + + if state.in_degree(access_node) == 0: + # (Transient) sink nodes are used in other states, or simplify + # will get rid of them. + shared_data.add(access_node.data) + + elif ( + state.out_degree(access_node) != 1 + ): # state.out_degree() == 0 or state.out_degree() > 1 + # The access node is either a source node (it is shared in another + # state) or the node has a degree larger than one, so it is used + # in this state somewhere else. + shared_data.add(access_node.data) + + elif self.is_view(node=access_node, sdfg=sdfg): + # To ensure that the write to the view happens, both have to be shared. + viewed_data: str = self.track_view( + view=access_node, state=state, sdfg=sdfg + ).data + shared_data.update([access_node.data, viewed_data]) + prevously_seen_data.update([access_node.data, viewed_data]) + + else: + # The node was not classified as shared data, so we record that + # we saw it. Note that a node that was immediately classified + # as shared node will never be added to this set, but a data + # that was found twice will be inside this list. + prevously_seen_data.add(access_node.data) + + # Now we are collecting all symbols that interstate edges read from. + for edge in sdfg.edges(): + interstate_read_symbols.update(edge.data.read_symbols()) + + # We also have to keep everything the edges referrers to and is an array. + shared_data.update(interstate_read_symbols.intersection(prevously_seen_data)) + + # Update the internal cache + self._shared_data[sdfg] = shared_data + + def _compute_multi_write_data( + self, + state: SDFGState, + sdfg: SDFG, + ) -> Set[str]: + """Computes data inside a _single_ state, that is written multiple times. + + Essentially this function computes the set of data that does not follow + the single static assignment idiom. The function also resolves views. + If an access node, refers to a view, not only the view itself, but also + the data it refers to is added to the set. + + Args: + state: The state that should be examined. + sdfg: The SDFG object. + + Note: + This information is used by the partition function (in case strict data + flow mode is enabled), in strict data flow mode only. The current + implementation is rather simple as it only checks if a data is written + to multiple times in the same state. + """ + data_written_to: Set[str] = set() + multi_write_data: Set[str] = set() + + for access_node in state.data_nodes(): + if state.in_degree(access_node) == 0: + continue + if access_node.data in data_written_to: + multi_write_data.add(access_node.data) + elif self.is_view(access_node, sdfg): + # This is an over approximation. + multi_write_data.update( + [access_node.data, self.track_view(access_node, state, sdfg).data] ) - self.shared_transients[sdfg] = shared_sdfg_transients - - if isinstance(transient, str): - name = transient - matching_access_nodes = [node for node in state.data_nodes() if node.data == name] - # Rule 8: There is only one access node per state for data. - assert len(matching_access_nodes) == 1 - transient = matching_access_nodes[0] - else: - assert isinstance(transient, dace_nodes.AccessNode) - name = transient.data + data_written_to.add(access_node.data) + return multi_write_data - desc: dace_data.Data = sdfg.arrays[name] - if not desc.transient: - return True - if isinstance(desc, dace_data.Scalar): - return True # Scalars can not be removed by fusion anyway. + def is_node_reachable_from( + self, + graph: Union[dace.SDFG, dace.SDFGState], + begin: nodes.Node, + end: nodes.Node, + ) -> bool: + """Test if the node `end` can be reached from `begin`. + + Essentially the function starts a DFS at `begin`. If an edge is found that lead + to `end` the function returns `True`. If the node is never found `False` is + returned. + + Args: + graph: The graph to operate on. + begin: The start of the DFS. + end: The node that should be located. + """ - # Rule 8: If degree larger than one then it is used within the state. - if state.out_degree(transient) > 1: - return True + def next_nodes(node: nodes.Node) -> Iterable[nodes.Node]: + return (edge.dst for edge in graph.out_edges(node)) - # Now we check if it is used in a different state. - return name in shared_sdfg_transients + to_visit: List[nodes.Node] = [begin] + seen: Set[nodes.Node] = set() - def partition_first_outputs( + while len(to_visit) > 0: + node: nodes.Node = to_visit.pop() + if node == end: + return True + elif node not in seen: + to_visit.extend(next_nodes(node)) + seen.add(node) + + # We never found `end` + return False + + def get_access_set( self, - state: dace.SDFGState, - sdfg: dace.SDFG, - map_exit_1: dace_nodes.MapExit, - map_entry_2: dace_nodes.MapEntry, - ) -> Union[ - tuple[ - set[dace_graph.MultiConnectorEdge[dace.Memlet]], - set[dace_graph.MultiConnectorEdge[dace.Memlet]], - set[dace_graph.MultiConnectorEdge[dace.Memlet]], - ], - None, - ]: - """Partition the output edges of `map_exit_1` for serial map fusion. - - The output edges of the first map are partitioned into three distinct sets, - defined as follows: - - - Pure Output Set `\mathbb{P}`: - These edges exits the first map and does not enter the second map. These - outputs will be simply be moved to the output of the second map. - - Exclusive Intermediate Set `\mathbb{E}`: - Edges in this set leaves the first map exit, enters an access node, from - where a Memlet then leads immediately to the second map. The memory - referenced by this access node is not used anywhere else, thus it can - be removed. - - Shared Intermediate Set `\mathbb{S}`: - These edges are very similar to the one in `\mathbb{E}` except that they - are used somewhere else, thus they can not be removed and have to be - recreated as output of the second map. - - Returns: - If such a decomposition exists the function will return the three sets - mentioned above in the same order. - In case the decomposition does not exist, i.e. the maps can not be fused - the function returns `None`. + scope_node: Union[nodes.MapEntry, nodes.MapExit], + state: SDFGState, + ) -> Set[nodes.AccessNode]: + """Computes the access set of a "scope node". + + If `scope_node` is a `MapEntry` it will operate on the set of incoming edges + and if it is an `MapExit` on the set of outgoing edges. The function will + then determine all access nodes that have a connection through these edges + to the scope nodes (edges that does not lead to access nodes are ignored). + The function returns a set that contains all access nodes that were found. + It is important that this set will also contain views. Args: - state: The in which the two maps are located. - sdfg: The full SDFG in whcih we operate. - map_exit_1: The exit node of the first map. - map_entry_2: The entry node of the second map. + scope_node: The scope node that should be evaluated. + state: The state in which we operate. """ - # The three outputs set. - pure_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() - exclusive_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() - shared_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() + if isinstance(scope_node, nodes.MapEntry): + get_edges = lambda node: state.in_edges(node) # noqa: E731 + other_node = lambda e: e.src # noqa: E731 + else: + get_edges = lambda node: state.out_edges(node) # noqa: E731 + other_node = lambda e: e.dst # noqa: E731 + access_set: Set[nodes.AccessNode] = { + node + for node in map(other_node, get_edges(scope_node)) + if isinstance(node, nodes.AccessNode) + } - # Set of intermediate nodes that we have already processed. - processed_inter_nodes: set[dace_nodes.Node] = set() + return access_set - # Now scan all output edges of the first exit and classify them - for out_edge in state.out_edges(map_exit_1): - intermediate_node: dace_nodes.Node = out_edge.dst + def find_subsets( + self, + node: nodes.AccessNode, + scope_node: Union[nodes.MapExit, nodes.MapEntry], + state: SDFGState, + sdfg: SDFG, + repl_dict: Optional[Dict[str, str]], + ) -> List[subsets.Subset]: + """Finds all subsets that access `node` within `scope_node`. + + The function will not start a search for all consumer/producers. + Instead it will locate the edges which is immediately inside the + map scope. - # We already processed the node, this should indicate that we should - # run simplify again, or we should start implementing this case. - if intermediate_node in processed_inter_nodes: - return None - processed_inter_nodes.add(intermediate_node) - - # Now let's look at all nodes that are downstream of the intermediate node. - # This, among other things, will tell us, how we have to handle this node. - downstream_nodes = util.all_nodes_between( - graph=state, - begin=intermediate_node, - end=map_entry_2, + Args: + node: The access node that should be examined. + scope_node: We are only interested in data that flows through this node. + state: The state in which we operate. + sdfg: The SDFG object. + """ + + # Is the node used for reading or for writing. + # This influences how we have to proceed. + if isinstance(scope_node, nodes.MapEntry): + outer_edges_to_inspect = [e for e in state.in_edges(scope_node) if e.src == node] + get_subset = lambda e: e.data.src_subset # noqa: E731 + get_inner_edges = lambda e: state.out_edges_by_connector( + scope_node, "OUT_" + e.dst_conn[3:] + ) + else: + outer_edges_to_inspect = [e for e in state.out_edges(scope_node) if e.dst == node] + get_subset = lambda e: e.data.dst_subset # noqa: E731 + get_inner_edges = lambda e: state.in_edges_by_connector( + scope_node, "IN_" + e.src_conn[4:] ) - # If `downstream_nodes` is `None` this means that `map_entry_2` was never - # reached, thus `intermediate_node` does not enter the second map and - # the node is a pure output node. - if downstream_nodes is None: - pure_outputs.add(out_edge) - continue + found_subsets: List[subsets.Subset] = [] + for edge in outer_edges_to_inspect: + found_subsets.extend(get_subset(e) for e in get_inner_edges(edge)) + assert len(found_subsets) > 0, "Could not find any subsets." + assert not any(subset is None for subset in found_subsets) - # The following tests are _after_ we have determined if we have a pure - # output node, because this allows us to handle more exotic pure node - # cases, as handling them is essentially rerouting an edge, whereas - # handling intermediate nodes is much more complicated. + found_subsets = copy.deepcopy(found_subsets) + if repl_dict: + for subset in found_subsets: + # Replace happens in place + symbolic.safe_replace(repl_dict, subset.replace) - # Empty Memlets are only allowed if they are in `\mathbb{P}`, which - # is also the only place they really make sense (for a map exit). - # Thus if we now found an empty Memlet we reject it. - if out_edge.data.is_empty(): - return None + return found_subsets - # In case the intermediate has more than one entry, all must come from the - # first map, otherwise we can not fuse them. Currently we restrict this - # even further by saying that it has only one incoming Memlet. - if state.in_degree(intermediate_node) != 1: - return None + def is_view( + self, + node: nodes.AccessNode, + sdfg: SDFG, + ) -> bool: + """Tests if `node` points to a view or not.""" + node_desc: data.Data = node.desc(sdfg) + return isinstance(node_desc, data.View) - # It can happen that multiple edges converges at the `IN_` connector - # of the first map exit, but there is only one edge leaving the exit. - # It is complicate to handle this, so for now we ignore it. - # TODO(phimuell): Handle this case properly. - inner_collector_edges = list( - state.in_edges_by_connector(intermediate_node, "IN_" + out_edge.src_conn[3:]) - ) - if len(inner_collector_edges) > 1: - return None + def track_view( + self, + view: nodes.AccessNode, + state: SDFGState, + sdfg: SDFG, + ) -> nodes.AccessNode: + """Find the original data of a View. - # For us an intermediate node must always be an access node, because - # everything else we do not know how to handle. It is important that - # we do not test for non transient data here, because they can be - # handled has shared intermediates. - if not isinstance(intermediate_node, dace_nodes.AccessNode): - return None - intermediate_desc: dace_data.Data = intermediate_node.desc(sdfg) - if isinstance(intermediate_desc, dace_data.View): - return None + Given the View `view`, the function will trace the view back to the original + access node. For convenience, if `view` is not a `View` the argument will be + returned. - # There are some restrictions we have on intermediate dace_nodes. The first one - # is that we do not allow WCR, this is because they need special handling - # which is currently not implement (the DaCe transformation has this - # restriction as well). The second one is that we can reduce the - # intermediate node and only feed a part into the second map, consider - # the case `b = a + 1; return b + 2`, where we have arrays. In this - # example only a single element must be available to the second map. - # However, this is hard to check so we will make a simplification. - # First, we will not check it at the producer, but at the consumer point. - # There we assume if the consumer does _not consume the whole_ - # intermediate array, then we can decompose the intermediate, by setting - # the map iteration index to zero and recover the shape, see - # implementation in the actual fusion routine. - # This is an assumption that is in most cases correct, but not always. - # However, doing it correctly is extremely complex. - for _, produce_edge in util.find_upstream_producers(state, out_edge): - if produce_edge.data.wcr is not None: - return None - - if len(downstream_nodes) == 0: - # There is nothing between intermediate node and the entry of the - # second map, thus the edge belongs either in `\mathbb{S}` or - # `\mathbb{E}`. - - # This is a very special situation, i.e. the access node has many - # different connections to the second map entry, this is a special - # case that we do not handle. - # TODO(phimuell): Handle this case. - if state.out_degree(intermediate_node) != 1: - return None - - # Certain nodes need more than one element as input. As explained - # above, in this situation we assume that we can naturally decompose - # them iff the node does not consume that whole intermediate. - # Furthermore, it can not be a dynamic map range or a library node. - intermediate_size = functools.reduce(lambda a, b: a * b, intermediate_desc.shape) - consumers = util.find_downstream_consumers(state=state, begin=intermediate_node) - for consumer_node, feed_edge in consumers: - # TODO(phimuell): Improve this approximation. - if ( - intermediate_size != 1 - ) and feed_edge.data.num_elements() == intermediate_size: - return None - if consumer_node is map_entry_2: # Dynamic map range. - return None - if isinstance(consumer_node, dace_nodes.LibraryNode): - # TODO(phimuell): Allow some library dace_nodes. - return None - - # Note that "remove" has a special meaning here, regardless of the - # output of the check function, from within the second map we remove - # the intermediate, it has more the meaning of "do we need to - # reconstruct it after the second map again?" - if self.is_interstate_transient(intermediate_node, sdfg, state): - shared_outputs.add(out_edge) - else: - exclusive_outputs.add(out_edge) - continue + Args: + view: The view that should be traced. + state: The state in which we operate. + sdfg: The SDFG on which we operate. + """ - else: - # There is not only a single connection from the intermediate node to - # the second map, but the intermediate has more connections, thus - # the node might belong to the shared output. Of the many different - # possibilities, we only consider a single case: - # - The intermediate has a single connection to the second map, that - # fulfills the restriction outlined above. - # - All other connections have no connection to the second map. - found_second_entry = False - intermediate_size = functools.reduce(lambda a, b: a * b, intermediate_desc.shape) - for edge in state.out_edges(intermediate_node): - if edge.dst is map_entry_2: - if found_second_entry: # The second map was found again. - return None - found_second_entry = True - consumers = util.find_downstream_consumers(state=state, begin=edge) - for consumer_node, feed_edge in consumers: - if feed_edge.data.num_elements() == intermediate_size: - return None - if consumer_node is map_entry_2: # Dynamic map range - return None - if isinstance(consumer_node, dace_nodes.LibraryNode): - # TODO(phimuell): Allow some library dace_nodes. - return None - else: - # Ensure that there is no path that leads to the second map. - after_intermdiate_node = util.all_nodes_between( - graph=state, begin=edge.dst, end=map_entry_2 - ) - if after_intermdiate_node is not None: - return None - # If we are here, then we know that the node is a shared output - shared_outputs.add(out_edge) - continue + # Test if it is a view at all, if not return the passed node as source. + if not self.is_view(view, sdfg): + return view + + # First determine if the view is used for reading or writing. + curr_edge = dace.sdfg.utils.get_view_edge(state, view) + if curr_edge is None: + raise RuntimeError(f"Failed to determine the direction of the view '{view}'.") + if curr_edge.dst_conn == "views": + # The view is used for reading. + next_node = lambda curr_edge: curr_edge.src # noqa: E731 + elif curr_edge.src_conn == "views": + # The view is used for writing. + next_node = lambda curr_edge: curr_edge.dst # noqa: E731 + else: + raise RuntimeError( + f"Failed to determine the direction of the view '{view}' | {curr_edge}." + ) - assert exclusive_outputs or shared_outputs or pure_outputs - assert len(processed_inter_nodes) == sum( - len(x) for x in [pure_outputs, exclusive_outputs, shared_outputs] - ) - return (pure_outputs, exclusive_outputs, shared_outputs) + # Now trace the view back. + org_view = view + view = next_node(curr_edge) + while self.is_view(view, sdfg): + curr_edge = dace.sdfg.utils.get_view_edge(state, view) + if curr_edge is None: + raise RuntimeError(f"View tracing of '{org_view}' failed at note '{view}'.") + view = next_node(curr_edge) + return view diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_parallel.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_parallel.py new file mode 100644 index 0000000000..19412b9dfa --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_parallel.py @@ -0,0 +1,170 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +"""Implements the parallel map fusing transformation. + +THIS FILE WAS COPIED FROM DACE TO FACILITATE DEVELOPMENT UNTIL THE PR#1625 IN +DACE IS MERGED AND THE VERSION WAS UPGRADED. +""" + +from typing import Any, Optional, Set, Union + +import dace +from dace import properties, transformation +from dace.sdfg import SDFG, SDFGState, graph, nodes + +from . import map_fusion_helper as mfh + + +@properties.make_properties +class MapFusionParallel(mfh.MapFusionHelper): + """The `MapFusionParallel` transformation allows to merge two parallel maps. + + While the `MapFusionSerial` transformation fuses maps that are sequentially + connected through an intermediate node, this transformation is able to fuse any + two maps that are not sequential and in the same scope. + + Args: + only_if_common_ancestor: Only perform fusion if both Maps share at least one + node as direct ancestor. This will increase the locality of the merge. + only_inner_maps: Only match Maps that are internal, i.e. inside another Map. + only_toplevel_maps: Only consider Maps that are at the top. + apply_fusion_callback: A user supplied function, same signature as `can_be_fused()`, + to check if a fusion should be performed. + + Note: + This transformation only matches the entry nodes of the Map, but will also + modify the exit nodes of the Maps. + """ + + map_entry_1 = transformation.transformation.PatternNode(nodes.MapEntry) + map_entry_2 = transformation.transformation.PatternNode(nodes.MapEntry) + + only_if_common_ancestor = properties.Property( + dtype=bool, + default=False, + allow_none=False, + desc="Only perform fusing if the Maps share a node as parent.", + ) + + def __init__( + self, + only_if_common_ancestor: Optional[bool] = None, + **kwargs: Any, + ) -> None: + if only_if_common_ancestor is not None: + self.only_if_common_ancestor = only_if_common_ancestor + super().__init__(**kwargs) + + @classmethod + def expressions(cls) -> Any: + # This just matches _any_ two Maps inside a state. + state = graph.OrderedMultiDiConnectorGraph() + state.add_nodes_from([cls.map_entry_1, cls.map_entry_2]) + return [state] + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Checks if the fusion can be done. + + The function checks the general fusing conditions and if the maps are parallel. + """ + map_entry_1: nodes.MapEntry = self.map_entry_1 + map_entry_2: nodes.MapEntry = self.map_entry_2 + + # Check the structural properties of the maps, this will also ensure that + # the two maps are in the same scope and the parameters can be renamed + if not self.can_be_fused( + map_entry_1=map_entry_1, + map_entry_2=map_entry_2, + graph=graph, + sdfg=sdfg, + permissive=permissive, + ): + return False + + # Since the match expression matches any two Maps, we have to ensure that + # the maps are parallel. The `can_be_fused()` function already verified + # if they are in the same scope. + if not self.is_parallel(graph=graph, node1=map_entry_1, node2=map_entry_2): + return False + + # Test if they have they share a node as direct ancestor. + if self.only_if_common_ancestor: + # This assumes that there is only one access node per data container in the state. + ancestors_1: Set[nodes.Node] = {e1.src for e1 in graph.in_edges(map_entry_1)} + if not any(e2.src in ancestors_1 for e2 in graph.in_edges(map_entry_2)): + return False + + return True + + def is_parallel( + self, + graph: SDFGState, + node1: nodes.Node, + node2: nodes.Node, + ) -> bool: + """Tests if `node1` and `node2` are parallel. + + The nodes are parallel if `node2` can not be reached from `node1` and vice versa. + + Args: + graph: The graph to traverse. + node1: The first node to check. + node2: The second node to check. + """ + + # In order to be parallel they must be in the same scope. + scope = graph.scope_dict() + if scope[node1] != scope[node2]: + return False + + # The `all_nodes_between()` function traverse the graph and returns `None` if + # `end` was not found. We have to call it twice, because we do not know + # which node is upstream if they are not parallel. + if self.is_node_reachable_from(graph=graph, begin=node1, end=node2): + return False + elif self.is_node_reachable_from(graph=graph, begin=node2, end=node1): + return False + return True + + def apply(self, graph: Union[SDFGState, SDFG], sdfg: SDFG) -> None: + """Performs the Map fusing. + + Essentially, the function relocate all edges from the scope nodes (`MapEntry` + and `MapExit`) of the second map to the scope nodes of the first map. + """ + + map_entry_1: nodes.MapEntry = self.map_entry_1 + map_exit_1: nodes.MapExit = graph.exit_node(map_entry_1) + map_entry_2: nodes.MapEntry = self.map_entry_2 + map_exit_2: nodes.MapExit = graph.exit_node(map_entry_2) + + # Before we do anything we perform the renaming. + self.rename_map_parameters( + first_map=map_entry_1.map, + second_map=map_entry_2.map, + second_map_entry=map_entry_2, + state=graph, + ) + + for to_node, from_node in zip((map_entry_1, map_exit_1), (map_entry_2, map_exit_2)): + self.relocate_nodes( + from_node=from_node, + to_node=to_node, + state=graph, + sdfg=sdfg, + ) + # The relocate function does not remove the node, so we must do it. + graph.remove_node(from_node) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_serial.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_serial.py new file mode 100644 index 0000000000..2cdcc455d4 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_fusion_serial.py @@ -0,0 +1,1007 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +"""Implements the serial map fusing transformation. + +THIS FILE WAS COPIED FROM DACE TO FACILITATE DEVELOPMENT UNTIL THE PR#1625 IN +DACE IS MERGED AND THE VERSION WAS UPGRADED. +""" + +import copy +from typing import Any, Dict, List, Optional, Set, Tuple, Union + +import dace +from dace import data, dtypes, properties, subsets, symbolic, transformation +from dace.sdfg import SDFG, SDFGState, graph, nodes + +from . import map_fusion_helper as mfh + + +@properties.make_properties +class MapFusionSerial(mfh.MapFusionHelper): + """Fuse two serial maps together. + + The transformation combines two maps into one that are connected through some + access nodes. Conceptually this transformation removes the exit of the first + or upper map and the entry of the lower or second map and then rewrites the + connections appropriately. Depending on the situation the transformation will + either fully remove or make the intermediate a new output of the second map. + + By default, the transformation does not use the strict data flow mode, see + `MapFusionHelper` for more, however, it might be useful in come cases to enable + it, especially in the context of DaCe's auto optimizer. + + Args: + only_inner_maps: Only match Maps that are internal, i.e. inside another Map. + only_toplevel_maps: Only consider Maps that are at the top. + strict_dataflow: If `True`, the transformation ensures a more + stricter version of the data flow. + apply_fusion_callback: A user supplied function, same signature as `can_be_fused()`, + to check if a fusion should be performed. + + Notes: + - This transformation modifies more nodes than it matches. + - After the transformation has been applied simplify should be run to remove + some dead data flow, that was introduced to ensure validity. + - A `MapFusionSerial` object can be initialized and be reused. However, + after new access nodes are added to any state, it is no longer valid + to use the object. + + Todo: + - Consider the case that only shared nodes are created (thus no inspection of + the graph is needed) and make all shared. Then use the dead dataflow + elimination transformation to get rid of the ones we no longer need. + - Increase the applicability. + """ + + map_exit_1 = transformation.transformation.PatternNode(nodes.MapExit) + intermediate_access_node = transformation.transformation.PatternNode(nodes.AccessNode) + map_entry_2 = transformation.transformation.PatternNode(nodes.MapEntry) + + def __init__( + self, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + @classmethod + def expressions(cls) -> Any: + """Get the match expression. + + The transformation matches the exit node of the top Map that is connected to + an access node that again is connected to the entry node of the second Map. + An important note is, that the transformation operates not just on the + matched nodes, but more or less on anything that has an incoming connection + from the first Map or an outgoing connection to the second Map entry. + """ + return [ + dace.sdfg.utils.node_path_graph( + cls.map_exit_1, cls.intermediate_access_node, cls.map_entry_2 + ) + ] + + def can_be_applied( + self, + graph: Union[SDFGState, SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + """Tests if the matched Maps can be merged. + + The two Maps are mergeable iff: + - Satisfies general requirements, see `MapFusionHelper.can_be_fused()`. + - Tests if the decomposition exists. + - Tests if there are read write dependencies. + """ + map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit_1) + map_exit_1: nodes.MapExit = self.map_exit_1 + map_entry_2: nodes.MapEntry = self.map_entry_2 + + # This essentially test the structural properties of the two Maps. + if not self.can_be_fused( + map_entry_1=map_entry_1, map_entry_2=map_entry_2, graph=graph, sdfg=sdfg + ): + return False + + # Test for read-write conflicts + if self.has_read_write_dependency( + map_entry_1=map_entry_1, + map_entry_2=map_entry_2, + state=graph, + sdfg=sdfg, + ): + return False + + # Two maps can be serially fused if the node decomposition exists and + # at least one of the intermediate output sets is not empty. The state + # of the pure outputs is irrelevant for serial map fusion. + output_partition = self.partition_first_outputs( + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + ) + if output_partition is None: + return False + _, exclusive_outputs, shared_outputs = output_partition + if not (exclusive_outputs or shared_outputs): + return False + return True + + def has_read_write_dependency( + self, + map_entry_1: nodes.MapEntry, + map_entry_2: nodes.MapEntry, + state: SDFGState, + sdfg: SDFG, + ) -> bool: + """Test if there is a read write dependency between the two maps to be fused. + + The function checks two different things. + - The function will make sure that there is no read write dependency between + the input and output of the fused maps. For that it will inspect the + respective subsets. + - The second part partially checks the intermediate nodes, it mostly ensures + that there are not views and that they are not used as inputs or outputs + at the same time. However, the function will not check for read write + conflicts in this set, this is done in the partition function. + + Returns: + `True` if there is a conflict between the maps that can not be handled. + If there is no conflict or if the conflict can be handled `False` + is returned. + + Args: + map_entry_1: The entry node of the first map. + map_entry_2: The entry node of the second map. + state: The state on which we operate. + sdfg: The SDFG on which we operate. + """ + map_exit_1: nodes.MapExit = state.exit_node(map_entry_1) + map_exit_2: nodes.MapExit = state.exit_node(map_entry_2) + + # Get the read and write sets of the different maps, note that Views + # are not resolved yet. + access_sets: List[Dict[str, nodes.AccessNode]] = [] + for scope_node in [map_entry_1, map_exit_1, map_entry_2, map_exit_2]: + access_set: Set[nodes.AccessNode] = self.get_access_set(scope_node, state) + access_sets.append({node.data: node for node in access_set}) + # If two different access nodes of the same scoping node refers to the + # same data, then we consider this as a dependency we can not handle. + # It is only a problem for the intermediate nodes and might be possible + # to handle, but doing so is hard, so we just forbid it. + if len(access_set) != len(access_sets[-1]): + return True + read_map_1, write_map_1, read_map_2, write_map_2 = access_sets + + # It might be possible that there are views, so we have to resolve them. + # We also already get the name of the data container. + # Note that `len(real_read_map_1) <= len(read_map_1)` holds because of Views. + resolved_sets: List[Set[str]] = [] + for unresolved_set in [read_map_1, write_map_1, read_map_2, write_map_2]: + resolved_sets.append( + { + self.track_view(node, state, sdfg).data + if self.is_view(node, sdfg) + else node.data + for node in unresolved_set.values() + } + ) + # If the resolved and unresolved names do not have the same length. + # Then different views point to the same location, which we forbid + if len(unresolved_set) != len(resolved_sets[-1]): + return False + real_read_map_1, real_write_map_1, real_read_map_2, real_write_map_2 = resolved_sets + + # We do not allow that the first and second map each write to the same data. + if not real_write_map_1.isdisjoint(real_write_map_2): + return True + + # If there is no overlap in what is (totally) read and written, there will be no conflict. + # This must come before the check of disjoint write. + if (real_read_map_1 | real_read_map_2).isdisjoint(real_write_map_1 | real_write_map_2): + return False + + # These are the names (unresolved) and the access nodes of the data that is used + # to transmit information between the maps. The partition function ensures that + # these nodes are directly connected to the two maps. + exchange_names: Set[str] = set(write_map_1.keys()).intersection(read_map_2.keys()) + exchange_nodes: Set[nodes.AccessNode] = set(write_map_1.values()).intersection( + read_map_2.values() + ) + + # If the number are different then a data is accessed through multiple nodes. + if len(exchange_names) != len(exchange_nodes): + return True + assert all(exchange_node.data in exchange_names for exchange_node in exchange_nodes) + + # For simplicity we assume that the nodes used for exchange are not views. + if any(self.is_view(exchange_node, sdfg) for exchange_node in exchange_nodes): + return True + + # This is the names of the node that are used as input of the first map and + # as output of the second map. We have to ensure that there is no data + # dependency between these nodes. + fused_inout_data_names: Set[str] = set(read_map_1.keys()).intersection(write_map_2.keys()) + + # If a data container is used as input and output then it can not be a view (simplicity) + if any(self.is_view(read_map_1[name], sdfg) for name in fused_inout_data_names): + return True + + # A data container can be used as input and output. But we do not allow that + # it is also used as intermediate or exchange data. This is an important check. + if not fused_inout_data_names.isdisjoint(exchange_names): + return True + + # Get the replacement dict for changing the map variables from the subsets of + # the second map. + repl_dict = self.find_parameter_remapping(map_entry_1.map, map_exit_2.map) + + # Now we inspect if there is a read write dependency, between data that is + # used as input and output of the fused map. There is no problem is they + # are pointwise, i.e. in each iteration the same locations are accessed. + # Essentially they all boil down to `a += 1`. + for inout_data_name in fused_inout_data_names: + all_subsets: List[subsets.Subset] = [] + # The subsets that define reading are given by the first map's entry node + all_subsets.extend( + self.find_subsets( + node=read_map_1[inout_data_name], + scope_node=map_entry_1, + state=state, + sdfg=sdfg, + repl_dict=None, + ) + ) + # While the subsets defining writing are given by the second map's exit + # node, there we also have to apply renaming. + all_subsets.extend( + self.find_subsets( + node=write_map_2[inout_data_name], + scope_node=map_exit_2, + state=state, + sdfg=sdfg, + repl_dict=repl_dict, + ) + ) + # Now we can test if these subsets are point wise + if not self.test_if_subsets_are_point_wise(all_subsets): + return True + + # No read write dependency was found. + return False + + def test_if_subsets_are_point_wise(self, subsets_to_check: List[subsets.Subset]) -> bool: + """Point wise means that they are all the same. + + If a series of subsets are point wise it means that all Memlets, access + the same data. This is an important property because the whole map fusion + is build upon this. + If the subsets originates from different maps, then they must have been + renamed. + + Args: + subsets_to_check: The list of subsets that should be checked. + """ + assert len(subsets_to_check) > 1 + + # We will check everything against the master subset. + master_subset = subsets_to_check[0] + for ssidx in range(1, len(subsets_to_check)): + subset = subsets_to_check[ssidx] + if isinstance(subset, subsets.Indices): + subset = subsets.Range.from_indices(subset) + # Do we also need the reverse? See below why. + if any(r != (0, 0, 1) for r in subset.offset_new(master_subset, negative=True)): + return False + else: + # The original code used `Range.offset` here, but that one had trouble + # for `r1 = 'j, 0:10'` and `r2 = 'j, 0`. The solution would be to test + # symmetrically, i.e. `r1 - r2` and `r2 - r1`. However, if we would + # have `r2_1 = 'j, 0:10'` it consider it as failing, which is not + # what we want. Thus we will use symmetric cover. + if not master_subset.covers(subset): + return False + if not subset.covers(master_subset): + return False + + # All subsets are equal to the master subset, thus they are equal to each other. + # This means that the data accesses, described by this transformation is + # point wise + return True + + def compute_offset_subset( + self, + original_subset: subsets.Range, + intermediate_desc: data.Data, + map_params: List[str], + producer_offset: Optional[subsets.Range] = None, + ) -> subsets.Range: + """Computes the memlet to correct read and writes of the intermediate. + + Args: + original_subset: The original subset that was used to write into the + intermediate, must be renamed to the final map parameter. + intermediate_desc: The original intermediate data descriptor. + map_params: The parameter of the final map. + """ + assert not isinstance(intermediate_desc, data.View) + final_offset: subsets.Range = None + if isinstance(intermediate_desc, data.Scalar): + final_offset = subsets.Range.from_string("0") + + elif isinstance(intermediate_desc, data.Array): + basic_offsets = original_subset.min_element() + offset_list = [] + for d in range(original_subset.dims()): + d_range = subsets.Range([original_subset[d]]) + if d_range.free_symbols.intersection(map_params): + offset_list.append(d_range[0]) + else: + offset_list.append((basic_offsets[d], basic_offsets[d], 1)) + final_offset = subsets.Range(offset_list) + + else: + raise TypeError( + f"Does not know how to compute the subset offset for '{type(intermediate_desc).__name__}'." + ) + + if producer_offset is not None: + # Here we are correcting some parts that over approximate (which partially + # does under approximate) might screw up. Consider two maps, the first + # map only writes the subset `[:, 2:6]`, thus the new intermediate will + # have shape `(1, 4)`. Now also imagine that the second map only reads + # the elements `[:, 3]`. From this we see that we can only correct the + # consumer side if we also take the producer side into consideration! + # See also the `transformations/mapfusion_test.py::test_offset_correction_*` + # tests for more. + final_offset.offset( + final_offset.offset_new( + producer_offset, + negative=True, + ), + negative=True, + ) + return final_offset + + def partition_first_outputs( + self, + state: SDFGState, + sdfg: SDFG, + map_exit_1: nodes.MapExit, + map_entry_2: nodes.MapEntry, + ) -> Union[ + Tuple[ + Set[graph.MultiConnectorEdge[dace.Memlet]], + Set[graph.MultiConnectorEdge[dace.Memlet]], + Set[graph.MultiConnectorEdge[dace.Memlet]], + ], + None, + ]: + """Partition the output edges of `map_exit_1` for serial map fusion. + + The output edges of the first map are partitioned into three distinct sets, + defined as follows: + - Pure Output Set `\mathbb{P}`: + These edges exits the first map and does not enter the second map. These + outputs will be simply be moved to the output of the second map. + - Exclusive Intermediate Set `\mathbb{E}`: + Edges in this set leaves the first map exit, enters an access node, from + where a Memlet then leads immediately to the second map. The memory + referenced by this access node is not used anywhere else, thus it can + be removed. + - Shared Intermediate Set `\mathbb{S}`: + These edges are very similar to the one in `\mathbb{E}` except that they + are used somewhere else, thus they can not be removed and have to be + recreated as output of the second map. + + If strict data flow mode is enabled the function is rather strict if an + output can be added to either intermediate set and might fail to compute + the partition, even if it would exist. + + Returns: + If such a decomposition exists the function will return the three sets + mentioned above in the same order. + In case the decomposition does not exist, i.e. the maps can not be fused + the function returns `None`. + + Args: + state: The in which the two maps are located. + sdfg: The full SDFG in whcih we operate. + map_exit_1: The exit node of the first map. + map_entry_2: The entry node of the second map. + """ + # The three outputs set. + pure_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + exclusive_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + shared_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set() + + # Compute the renaming that for translating the parameter of the _second_ + # map to the ones used by the first map. + repl_dict: Dict[str, str] = self.find_parameter_remapping( # type: ignore[assignment] + first_map=map_exit_1.map, + second_map=map_entry_2.map, + ) + assert repl_dict is not None + + # Set of intermediate nodes that we have already processed. + processed_inter_nodes: Set[nodes.Node] = set() + + # These are the data that is written to multiple times in _this_ state. + # If a data is written to multiple time in a state, it could be + # classified as shared. However, it might happen that the node has zero + # degree. This is not a problem as the maps also induced a before-after + # relationship. But some DaCe transformations do not catch this. + # Thus we will never modify such intermediate nodes and fail instead. + if self.strict_dataflow: + multi_write_data: Set[str] = self._compute_multi_write_data(state, sdfg) + else: + multi_write_data = set() + + # Now scan all output edges of the first exit and classify them + for out_edge in state.out_edges(map_exit_1): + intermediate_node: nodes.Node = out_edge.dst + + # We already processed the node, this should indicate that we should + # run simplify again, or we should start implementing this case. + # TODO(phimuell): Handle this case, already partially handled here. + if intermediate_node in processed_inter_nodes: + return None + processed_inter_nodes.add(intermediate_node) + + # The intermediate can only have one incoming degree. It might be possible + # to handle multiple incoming edges, if they all come from the top map. + # However, the resulting SDFG might be invalid. + # NOTE: Allow this to happen (under certain cases) if the only producer + # is the top map. + if state.in_degree(intermediate_node) != 1: + return None + + # If the second map is not reachable from the intermediate node, then + # the output is pure and we can end here. + if not self.is_node_reachable_from( + graph=state, + begin=intermediate_node, + end=map_entry_2, + ): + pure_outputs.add(out_edge) + continue + + # The following tests are _after_ we have determined if we have a pure + # output node, because this allows us to handle more exotic pure node + # cases, as handling them is essentially rerouting an edge, whereas + # handling intermediate nodes is much more complicated. + + # For us an intermediate node must always be an access node, because + # everything else we do not know how to handle. It is important that + # we do not test for non transient data here, because they can be + # handled has shared intermediates. + if not isinstance(intermediate_node, nodes.AccessNode): + return None + if self.is_view(intermediate_node, sdfg): + return None + + # Checks if the intermediate node refers to data that is accessed by + # _other_ access nodes in _this_ state. If this is the case then never + # touch this intermediate node. + # TODO(phimuell): Technically it would be enough to turn the node into + # a shared output node, because this will still fulfil the dependencies. + # However, some DaCe transformation can not handle this properly, so we + # are _forced_ to reject this node. + if intermediate_node.data in multi_write_data: + return None + + # Empty Memlets are only allowed if they are in `\mathbb{P}`, which + # is also the only place they really make sense (for a map exit). + # Thus if we now found an empty Memlet we reject it. + if out_edge.data.is_empty(): + return None + + # It can happen that multiple edges converges at the `IN_` connector + # of the first map exit, but there is only one edge leaving the exit. + # It is complicate to handle this, so for now we ignore it. + # TODO(phimuell): Handle this case properly. + # To handle this we need to associate a consumer edge (the outgoing edges + # of the second map) with exactly one producer. + producer_edges: List[graph.MultiConnectorEdge[dace.Memlet]] = list( + state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:]) + ) + if len(producer_edges) > 1: + return None + + # Now check the constraints we have on the producers. + # - The source of the producer can not be a view (we do not handle this) + # - The edge shall also not be a reduction edge. + # - Defined location to where they write. + # - No dynamic Memlets. + # Furthermore, we will also extract the subsets, i.e. the location they + # modify inside the intermediate array. + # Since we do not allow for WCR, we do not check if the producer subsets intersects. + producer_subsets: List[subsets.Subset] = [] + for producer_edge in producer_edges: + if isinstance(producer_edge.src, nodes.AccessNode) and self.is_view( + producer_edge.src, sdfg + ): + return None + if producer_edge.data.dynamic: + return None + if producer_edge.data.wcr is not None: + return None + if producer_edge.data.dst_subset is None: + return None + producer_subsets.append(producer_edge.data.dst_subset) + + # Check if the producer do not intersect + if len(producer_subsets) == 1: + pass + elif len(producer_subsets) == 2: + if producer_subsets[0].intersects(producer_subsets[1]): + return None + else: + for i, psbs1 in enumerate(producer_subsets): + for j, psbs2 in enumerate(producer_subsets): + if i == j: + continue + if psbs1.intersects(psbs2): + return None + + # Now we determine the consumer of nodes. For this we are using the edges + # leaves the second map entry. It is not necessary to find the actual + # consumer nodes, as they might depend on symbols of nested Maps. + # For the covering test we only need their subsets, but we will perform + # some scan and filtering on them. + found_second_map = False + consumer_subsets: List[subsets.Subset] = [] + for intermediate_node_out_edge in state.out_edges(intermediate_node): + # If the second map entry is not immediately reachable from the intermediate + # node, then ensure that there is not path that goes to it. + if intermediate_node_out_edge.dst is not map_entry_2: + if self.is_node_reachable_from( + graph=state, begin=intermediate_node_out_edge.dst, end=map_entry_2 + ): + return None + continue + + # Ensure that the second map is found exactly once. + # TODO(phimuell): Lift this restriction. + if found_second_map: + return None + found_second_map = True + + # The output of the top map can not define a dynamic map range in the + # second map. + if not intermediate_node_out_edge.dst_conn.startswith("IN_"): + return None + + # Now we look at all edges that leave the second map entry, i.e. the + # edges that feeds the consumer and define what is read inside the map. + # We do not check them, but collect them and inspect them. + # NOTE: The subset still uses the old iteration variables. + for inner_consumer_edge in state.out_edges_by_connector( + map_entry_2, "OUT_" + intermediate_node_out_edge.dst_conn[3:] + ): + if inner_consumer_edge.data.src_subset is None: + return None + if inner_consumer_edge.data.dynamic: + # TODO(phimuell): Is this restriction necessary, I am not sure. + return None + consumer_subsets.append(inner_consumer_edge.data.src_subset) + assert ( + found_second_map + ), f"Found '{intermediate_node}' which looked like a pure node, but is not one." + assert len(consumer_subsets) != 0 + + # The consumer still uses the original symbols of the second map, so we must rename them. + if repl_dict: + consumer_subsets = copy.deepcopy(consumer_subsets) + for consumer_subset in consumer_subsets: + symbolic.safe_replace( + mapping=repl_dict, replace_callback=consumer_subset.replace + ) + + # Now we are checking if a single iteration of the first (top) map + # can satisfy all data requirements of the second (bottom) map. + # For this we look if the producer covers the consumer. A consumer must + # be covered by exactly one producer. + for consumer_subset in consumer_subsets: + nb_coverings = sum( + producer_subset.covers(consumer_subset) for producer_subset in producer_subsets + ) + if nb_coverings != 1: + return None + + # After we have ensured coverage, we have to decide if the intermediate + # node can be removed (`\mathbb{E}`) or has to be restored (`\mathbb{S}`). + # Note that "removed" here means that it is reconstructed by a new + # output of the second map. + if self.is_shared_data(intermediate_node, sdfg): + # The intermediate data is used somewhere else, either in this or another state. + shared_outputs.add(out_edge) + else: + # The intermediate can be removed, as it is not used anywhere else. + exclusive_outputs.add(out_edge) + + assert len(processed_inter_nodes) == sum( + len(x) for x in [pure_outputs, exclusive_outputs, shared_outputs] + ) + return (pure_outputs, exclusive_outputs, shared_outputs) + + def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: + """Performs the serial Map fusing. + + The function first computes the map decomposition and then handles the + three sets. The pure outputs are handled by `relocate_nodes()` while + the two intermediate sets are handled by `handle_intermediate_set()`. + + By assumption we do not have to rename anything. + + Args: + graph: The SDFG state we are operating on. + sdfg: The SDFG we are operating on. + """ + # NOTE: `self.map_*` actually stores the ID of the node. + # once we start adding and removing nodes it seems that their ID changes. + # Thus we have to save them here, this is a known behaviour in DaCe. + assert isinstance(graph, dace.SDFGState) + assert isinstance(self.map_exit_1, nodes.MapExit) + assert isinstance(self.map_entry_2, nodes.MapEntry) + + map_exit_1: nodes.MapExit = self.map_exit_1 + map_entry_2: nodes.MapEntry = self.map_entry_2 + map_exit_2: nodes.MapExit = graph.exit_node(self.map_entry_2) + map_entry_1: nodes.MapEntry = graph.entry_node(self.map_exit_1) + + # Before we do anything we perform the renaming. + self.rename_map_parameters( + first_map=map_exit_1.map, + second_map=map_entry_2.map, + second_map_entry=map_entry_2, + state=graph, + ) + + output_partition = self.partition_first_outputs( + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + ) + assert output_partition is not None # Make MyPy happy. + pure_outputs, exclusive_outputs, shared_outputs = output_partition + + if len(exclusive_outputs) != 0: + self.handle_intermediate_set( + intermediate_outputs=exclusive_outputs, + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + map_exit_2=map_exit_2, + is_exclusive_set=True, + ) + if len(shared_outputs) != 0: + self.handle_intermediate_set( + intermediate_outputs=shared_outputs, + state=graph, + sdfg=sdfg, + map_exit_1=map_exit_1, + map_entry_2=map_entry_2, + map_exit_2=map_exit_2, + is_exclusive_set=False, + ) + assert pure_outputs == set(graph.out_edges(map_exit_1)) + if len(pure_outputs) != 0: + self.relocate_nodes( + from_node=map_exit_1, + to_node=map_exit_2, + state=graph, + sdfg=sdfg, + ) + + # Above we have handled the input of the second map and moved them + # to the first map, now we must move the output of the first map + # to the second one, as this one is used. + self.relocate_nodes( + from_node=map_entry_2, + to_node=map_entry_1, + state=graph, + sdfg=sdfg, + ) + + for node_to_remove in [map_exit_1, map_entry_2]: + assert graph.degree(node_to_remove) == 0 + graph.remove_node(node_to_remove) + + # Now turn the second output node into the output node of the first Map. + map_exit_2.map = map_entry_1.map + + def handle_intermediate_set( + self, + intermediate_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]], + state: SDFGState, + sdfg: SDFG, + map_exit_1: nodes.MapExit, + map_entry_2: nodes.MapEntry, + map_exit_2: nodes.MapExit, + is_exclusive_set: bool, + ) -> None: + """This function handles the intermediate sets. + + The function is able to handle both the shared and exclusive intermediate + output set, see `partition_first_outputs()`. The main difference is that + in exclusive mode the intermediate nodes will be fully removed from + the SDFG. While in shared mode the intermediate node will be preserved. + The function assumes that the parameter renaming was already done. + + Args: + intermediate_outputs: The set of outputs, that should be processed. + state: The state in which the map is processed. + sdfg: The SDFG that should be optimized. + map_exit_1: The exit of the first/top map. + map_entry_2: The entry of the second map. + map_exit_2: The exit of the second map. + is_exclusive_set: If `True` `intermediate_outputs` is the exclusive set. + + Notes: + Before the transformation the `state` does not have to be valid and + after this function has run the state is (most likely) invalid. + """ + + map_params = map_exit_1.map.params.copy() + + # Now we will iterate over all intermediate edges and process them. + # If not stated otherwise the comments assume that we run in exclusive mode. + for out_edge in intermediate_outputs: + # This is the intermediate node that, that we want to get rid of. + # In shared mode we want to recreate it after the second map. + inter_node: nodes.AccessNode = out_edge.dst + inter_name = inter_node.data + inter_desc = inter_node.desc(sdfg) + inter_shape = inter_desc.shape + + # Now we will determine the shape of the new intermediate. This size of + # this temporary is given by the Memlet that goes into the first map exit. + pre_exit_edges = list( + state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:]) + ) + if len(pre_exit_edges) != 1: + raise NotImplementedError() + pre_exit_edge = pre_exit_edges[0] + new_inter_shape_raw = symbolic.overapproximate(pre_exit_edge.data.subset.size()) + + # Over approximation will leave us with some unneeded size one dimensions. + # If they are removed some dace transformations (especially auto optimization) + # will have problems. + if not self.strict_dataflow: + squeezed_dims: List[int] = [] # These are the dimensions we removed. + new_inter_shape: List[int] = [] # This is the final shape of the new intermediate. + for dim, (proposed_dim_size, full_dim_size) in enumerate( + zip(new_inter_shape_raw, inter_shape) + ): + if full_dim_size == 1: # Must be kept! + new_inter_shape.append(proposed_dim_size) + elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. + squeezed_dims.append(dim) + else: + new_inter_shape.append(proposed_dim_size) + else: + squeezed_dims = [] + new_inter_shape = list(new_inter_shape_raw) + + # This is the name of the new "intermediate" node that we will create. + # It will only have the shape `new_inter_shape` which is basically its + # output within one Map iteration. + # NOTE: The insertion process might generate a new name. + new_inter_name: str = f"__s{sdfg.node_id(state)}_n{state.node_id(out_edge.src)}{out_edge.src_conn}_n{state.node_id(out_edge.dst)}{out_edge.dst_conn}" + + # Now generate the intermediate data container. + if len(new_inter_shape) == 0: + assert pre_exit_edge.data.subset.num_elements() == 1 + is_scalar = True + new_inter_name, new_inter_desc = sdfg.add_scalar( + new_inter_name, + dtype=inter_desc.dtype, + transient=True, + storage=dtypes.StorageType.Register, + find_new_name=True, + ) + + else: + assert (pre_exit_edge.data.subset.num_elements() > 1) or all( + x == 1 for x in new_inter_shape + ) + is_scalar = False + new_inter_name, new_inter_desc = sdfg.add_transient( + new_inter_name, + shape=new_inter_shape, + dtype=inter_desc.dtype, + find_new_name=True, + storage=dtypes.StorageType.Register, + ) + new_inter_node: nodes.AccessNode = state.add_access(new_inter_name) + + # Get the subset that defined into which part of the old intermediate + # the old output edge wrote to. We need that to adjust the producer + # Memlets, since they now write into the new (smaller) intermediate. + assert pre_exit_edge.data.data == inter_name + assert pre_exit_edge.data.dst_subset is not None + producer_offset = self.compute_offset_subset( + original_subset=pre_exit_edge.data.dst_subset, + intermediate_desc=inter_desc, + map_params=map_params, + ) + + # Memlets have a lot of additional informations, such as dynamic. + # To ensure that we get all of them, we will now copy them and modify + # the one that was originally there. We also hope that propagate will + # set the rest for us correctly. + new_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) + new_pre_exit_memlet.data = new_inter_name + new_pre_exit_memlet.dst_subset = subsets.Range.from_array(new_inter_desc) + + # New we will reroute the output Memlet, thus it will no longer pass + # through the Map exit but through the newly created intermediate. + # NOTE: We will delete the previous edge later. + new_pre_exit_edge = state.add_edge( + pre_exit_edge.src, + pre_exit_edge.src_conn, + new_inter_node, + None, + new_pre_exit_memlet, + ) + + # We now handle the MemletTree defined by this edge. + # The newly created edge, only handled the last collection step. + for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children( + include_self=False + ): + producer_edge = producer_tree.edge + + # Associate the (already existing) Memlet with the new data. + # TODO(phimuell): Improve the code below to remove the check. + assert producer_edge.data.data == inter_name + producer_edge.data.data = new_inter_name + + if is_scalar: + producer_edge.data.dst_subset = "0" + elif producer_edge.data.dst_subset is not None: + # Since we now write into a smaller memory patch, we must + # compensate for that. We do this by substracting where the write + # originally had begun. + producer_edge.data.dst_subset.offset(producer_offset, negative=True) + producer_edge.data.dst_subset.pop(squeezed_dims) + + # Now after we have handled the input of the new intermediate node, + # we must handle its output. For this we have to "inject" the newly + # created intermediate into the second map. We do this by finding + # the input connectors on the map entry, such that we know where we + # have to reroute inside the Map. + # NOTE: Assumes that map (if connected is the direct neighbour). + conn_names: Set[str] = set() + for inter_node_out_edge in state.out_edges(inter_node): + if inter_node_out_edge.dst == map_entry_2: + assert inter_node_out_edge.dst_conn.startswith("IN_") + conn_names.add(inter_node_out_edge.dst_conn) + else: + # If we found another target than the second map entry from the + # intermediate node it means that the node _must_ survive, + # i.e. we are not in exclusive mode. + assert not is_exclusive_set + + # Now we will reroute the connections inside the second map, i.e. + # instead of consuming the old intermediate node, they will now + # consume the new intermediate node. + for in_conn_name in conn_names: + out_conn_name = "OUT_" + in_conn_name[3:] + + for inner_edge in state.out_edges_by_connector(map_entry_2, out_conn_name): + assert inner_edge.data.data == inter_name # DIRECTION!! + + # As for the producer side, we now read from a smaller array, + # So we must offset them, we use the original edge for this. + assert inner_edge.data.src_subset is not None + consumer_offset = self.compute_offset_subset( + original_subset=inner_edge.data.src_subset, + intermediate_desc=inter_desc, + map_params=map_params, + producer_offset=producer_offset, + ) + + # Now we create a new connection that instead reads from the new + # intermediate, instead of the old one. For this we use the + # old Memlet as template. However it is not fully initialized. + new_inner_memlet = copy.deepcopy(inner_edge.data) + new_inner_memlet.data = new_inter_name + + # Now we replace the edge from the SDFG. + state.remove_edge(inner_edge) + new_inner_edge = state.add_edge( + new_inter_node, + None, + inner_edge.dst, + inner_edge.dst_conn, + new_inner_memlet, + ) + + # Now modifying the Memlet, we do it after the insertion to make + # sure that the Memlet was properly initialized. + if is_scalar: + new_inner_memlet.subset = "0" + elif new_inner_memlet.src_subset is not None: + new_inner_memlet.src_subset.offset(consumer_offset, negative=True) + new_inner_memlet.src_subset.pop(squeezed_dims) + + # Now we have to make sure that all consumers are properly updated. + for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children( + include_self=False + ): + assert consumer_tree.edge.data.data == inter_name + + consumer_edge = consumer_tree.edge + consumer_edge.data.data = new_inter_name + if is_scalar: + consumer_edge.data.src_subset = "0" + elif consumer_edge.data.src_subset is not None: + consumer_edge.data.src_subset.offset(consumer_offset, negative=True) + consumer_edge.data.src_subset.pop(squeezed_dims) + + # The edge that leaves the second map entry was already deleted. We now delete + # the edges that connected the intermediate node with the second map entry. + for edge in list(state.in_edges_by_connector(map_entry_2, in_conn_name)): + assert edge.src == inter_node + state.remove_edge(edge) + map_entry_2.remove_in_connector(in_conn_name) + map_entry_2.remove_out_connector(out_conn_name) + + if is_exclusive_set: + # In exclusive mode the old intermediate node is no longer needed. + # This will also remove `out_edge` from the SDFG. + assert state.degree(inter_node) == 1 + state.remove_edge_and_connectors(out_edge) + state.remove_node(inter_node) + + state.remove_edge(pre_exit_edge) + map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) + map_exit_1.remove_out_connector(out_edge.src_conn) + del sdfg.arrays[inter_name] + + else: + # This is the shared mode, so we have to recreate the intermediate + # node, but this time it is at the exit of the second map. + state.remove_edge(pre_exit_edge) + map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) + + # This is the Memlet that goes from the map internal intermediate + # temporary node to the Map output. This will essentially restore + # or preserve the output for the intermediate node. It is important + # that we use the data that `preExitEdge` was used. + final_pre_exit_memlet = copy.deepcopy(pre_exit_edge.data) + assert pre_exit_edge.data.data == inter_name + final_pre_exit_memlet.other_subset = subsets.Range.from_array(new_inter_desc) + + new_pre_exit_conn = map_exit_2.next_connector() + state.add_edge( + new_inter_node, + None, + map_exit_2, + "IN_" + new_pre_exit_conn, + final_pre_exit_memlet, + ) + state.add_edge( + map_exit_2, + "OUT_" + new_pre_exit_conn, + inter_node, + out_edge.dst_conn, + copy.deepcopy(out_edge.data), + ) + map_exit_2.add_in_connector("IN_" + new_pre_exit_conn) + map_exit_2.add_out_connector("OUT_" + new_pre_exit_conn) + + map_exit_1.remove_out_connector(out_edge.src_conn) + state.remove_edge(out_edge) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py index 4b34dd6adc..8fb41c7d0a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_orderer.py @@ -16,12 +16,42 @@ from gt4py.next.program_processors.runners.dace_fieldview import utility as gtx_dace_fieldview_util +def gt_set_iteration_order( + sdfg: dace.SDFG, + leading_dim: Optional[ + Union[str, gtx_common.Dimension, list[Union[str, gtx_common.Dimension]]] + ] = None, + validate: bool = True, + validate_all: bool = False, +) -> Any: + """Set the iteration order of the Maps correctly. + + Modifies the order of the Map parameters such that `leading_dim` + is the fastest varying one, the order of the other dimensions in + a Map is unspecific. `leading_dim` should be the dimensions were + the stride is one. + + Args: + sdfg: The SDFG to process. + leading_dim: The leading dimensions. + validate: Perform validation at the end of the function. + validate_all: Perform validation also on intermediate steps. + """ + return sdfg.apply_transformations_once_everywhere( + MapIterationOrder( + leading_dims=leading_dim, + ), + validate=validate, + validate_all=validate_all, + ) + + @dace_properties.make_properties class MapIterationOrder(dace_transformation.SingleStateTransformation): """Modify the order of the iteration variables. The iteration order, while irrelevant from an SDFG point of view, is highly - relevant in code, and the fastest varying index ("inner most loop" in CPU or + relevant in code and the fastest varying index ("inner most loop" in CPU or "x block dimension" in GPU) should be associated with the stride 1 dimension of the array. This transformation will reorder the map indexes such that this is the case. @@ -29,9 +59,18 @@ class MapIterationOrder(dace_transformation.SingleStateTransformation): While the place of the leading dimension is clearly defined, the order of the other loop indexes, after this transformation is unspecified. + The transformation accepts either a single dimension or a list of dimensions. + In case a list is passed this is interpreted as priorities. + Assuming we have the `leading_dim=[EdgeDim, VertexDim]`, then we have the + following: + - `Map[EdgeDim, KDim, VertexDim]` -> `Map[KDim, VertexDim, EdgeDim]`. + - `Map[VertexDim, KDim]` -> `Map[KDim, VertexDim]`. + - `Map[EdgeDim, KDim]` -> `Map[KDim, EdgeDim]`. + - `Map[CellDim, KDim]` -> `Map[CellDim, KDim]` (no modification). + Args: - leading_dim: A GT4Py dimension object that identifies the dimension that - is supposed to have stride 1. + leading_dim: GT4Py dimensions that are associated with the dimension that is + supposed to have stride 1. If it is a list it is used as a ranking. Note: The transformation does follow the rules outlines in @@ -44,25 +83,33 @@ class MapIterationOrder(dace_transformation.SingleStateTransformation): - Maybe also process the parameters to bring them in a canonical order. """ - leading_dim = dace_properties.Property( - dtype=str, + leading_dims = dace_properties.ListProperty( + element_type=str, allow_none=True, - desc="Dimension that should become the leading dimension.", + default=None, + desc="Dimensions that should become the leading dimension.", ) - map_entry = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + map_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) def __init__( self, - leading_dim: Optional[Union[gtx_common.Dimension, str]] = None, + leading_dims: Optional[ + Union[str, gtx_common.Dimension, list[Union[str, gtx_common.Dimension]]] + ] = None, *args: Any, **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) - if isinstance(leading_dim, gtx_common.Dimension): - self.leading_dim = gtx_dace_fieldview_util.get_map_variable(leading_dim) - elif leading_dim is not None: - self.leading_dim = leading_dim + if isinstance(leading_dims, (gtx_common.Dimension, str)): + leading_dims = [leading_dims] + if isinstance(leading_dims, list): + self.leading_dims = [ + leading_dim + if isinstance(leading_dim, str) + else gtx_dace_fieldview_util.get_map_variable(leading_dim) + for leading_dim in leading_dims + ] @classmethod def expressions(cls) -> Any: @@ -80,16 +127,15 @@ def can_be_applied( Essentially the function checks if the selected dimension is inside the map, and if so, if it is on the right place. """ - - if self.leading_dim is None: + if self.leading_dims is None: return False map_entry: dace_nodes.MapEntry = self.map_entry map_params: Sequence[str] = map_entry.map.params - map_var: str = self.leading_dim + processed_dims: set[str] = set(self.leading_dims) - if map_var not in map_params: + if not any(map_param in processed_dims for map_param in map_params): return False - if map_params[-1] == map_var: # Already at the correct location + if self.compute_map_param_order() is None: return False return True @@ -104,22 +150,52 @@ def apply( `self.leading_dim` the last map variable (this is given by the structure of DaCe's code generator). """ + map_object: dace_nodes.Map = self.map_entry.map + new_map_params_order: list[int] = self.compute_map_param_order() # type: ignore[assignment] # Guaranteed to be not `None`. + + def reorder(what: list[Any]) -> list[Any]: + assert isinstance(what, list) + return [what[new_pos] for new_pos in new_map_params_order] + + map_object.params = reorder(map_object.params) + map_object.range.ranges = reorder(map_object.range.ranges) + map_object.range.tile_sizes = reorder(map_object.range.tile_sizes) + + def compute_map_param_order(self) -> Optional[list[int]]: + """Computes the new iteration order of the matched map. + + The function returns a list, the value at index `i` indicates the old dimension + that should be put at the new location. If the order is already correct then + `None` is returned. + """ map_entry: dace_nodes.MapEntry = self.map_entry map_params: list[str] = map_entry.map.params - map_var: str = self.leading_dim - - # This implementation will just swap the variable that is currently the last - # with the one that should be the last. - dst_idx = -1 - src_idx = map_params.index(map_var) - - for to_process in [ - map_entry.map.params, - map_entry.map.range.ranges, - map_entry.map.range.tile_sizes, - ]: - assert isinstance(to_process, list) - src_val = to_process[src_idx] - dst_val = to_process[dst_idx] - to_process[dst_idx] = src_val - to_process[src_idx] = dst_val + org_mapping: dict[str, int] = {map_param: i for i, map_param in enumerate(map_params)} + leading_dims: list[str] = self.leading_dims + + # We divide the map parameters into two groups, the one we care and the others. + map_params_to_order: set[str] = { + map_param for map_param in map_params if map_param in leading_dims + } + + # If there is nothing to order, then we are done. + if not map_params_to_order: + return None + + # We start with all parameters that we ignore/do not care about. + new_map_params: list[str] = [ + map_param for map_param in map_params if map_param not in leading_dims + ] + + # Because how code generation works the leading dimension must be the most + # left one. Because this is also `self.leading_dims[0]` we have to process + # then in reverse order. + for map_param_to_check in reversed(leading_dims): + if map_param_to_check in map_params_to_order: + new_map_params.append(map_param_to_check) + assert len(map_params) == len(new_map_params) + + if map_params == new_map_params: + return None + + return [org_mapping[new_map_param] for new_map_param in new_map_params] diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py index 19818fd3d1..46d46c4bbe 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_promoter.py @@ -299,9 +299,9 @@ class SerialMapPromoter(BaseMapPromoter): """ # Pattern Matching - exit_first_map = dace_transformation.transformation.PatternNode(dace_nodes.MapExit) - access_node = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) - entry_second_map = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) + exit_first_map = dace_transformation.PatternNode(dace_nodes.MapExit) + access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) + entry_second_map = dace_transformation.PatternNode(dace_nodes.MapEntry) @classmethod def expressions(cls) -> Any: @@ -346,17 +346,11 @@ def _test_if_promoted_maps_can_be_fused( ) -> bool: """This function checks if the promoted maps can be fused by map fusion. - This function assumes that `self.can_be_applied()` returned `True`. + This function assumes that `super().self.can_be_applied()` returned `True`. Args: state: The state in which we operate. sdfg: The SDFG we process. - - Note: - The current implementation uses a very hacky way to test this. - - Todo: - Find a better way of doing it. """ first_map_exit: dace_nodes.MapExit = self.exit_first_map access_node: dace_nodes.AccessNode = self.access_node @@ -373,23 +367,17 @@ def _test_if_promoted_maps_can_be_fused( # This will lead to a promotion of the map, this is needed that # Map fusion can actually inspect them. self.apply(graph=state, sdfg=sdfg) - - # Now create the map fusion object that we can then use to check if - # the fusion is possible or not. - serial_fuser = gtx_transformations.SerialMapFusion( - only_inner_maps=self.only_inner_maps, - only_toplevel_maps=self.only_toplevel_maps, - ) - candidate = { - type(serial_fuser).map_exit1: first_map_exit, - type(serial_fuser).access_node: access_node, - type(serial_fuser).map_entry2: second_map_entry, - } - state_id = sdfg.node_id(state) - serial_fuser.setup_match(sdfg, sdfg.cfg_id, state_id, candidate, 0, override=True) - - # Now use the serial fuser to see if fusion would succeed - if not serial_fuser.can_be_applied(state, 0, sdfg): + if not gtx_transformations.MapFusionSerial.can_be_applied_to( + sdfg=sdfg, + expr_index=0, + options={ + "only_inner_maps": self.only_inner_maps, + "only_toplevel_maps": self.only_toplevel_maps, + }, + map_exit_1=first_map_exit, + intermediate_access_node=access_node, + map_entry_2=second_map_entry, + ): return False finally: diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py deleted file mode 100644 index bca5aa2268..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/map_serial_fusion.py +++ /dev/null @@ -1,483 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -"""Implements the serial map fusing transformation. - -Note: - After DaCe [PR#1629](https://github.com/spcl/dace/pull/1629), that implements - a better map fusion transformation is merged, this file will be deleted. -""" - -import copy -from typing import Any, Union - -import dace -from dace import ( - dtypes as dace_dtypes, - properties as dace_properties, - subsets as dace_subsets, - symbolic as dace_symbolic, - transformation as dace_transformation, -) -from dace.sdfg import graph as dace_graph, nodes as dace_nodes - -from gt4py.next.program_processors.runners.dace_fieldview.transformations import map_fusion_helper - - -@dace_properties.make_properties -class SerialMapFusion(map_fusion_helper.MapFusionHelper): - """Specialized replacement for the map fusion transformation that is provided by DaCe. - - As its name is indicating this transformation is only able to handle Maps that - are in sequence. Compared to the native DaCe transformation, this one is able - to handle more complex cases of connection between the maps. In that sense, it - is much more similar to DaCe's `SubgraphFusion` transformation. - - Things that are improved, compared to the native DaCe implementation: - - Nested Maps. - - Temporary arrays and the correct propagation of their Memlets. - - Top Maps that have multiple outputs. - - Conceptually this transformation removes the exit of the first or upper map - and the entry of the lower or second map and then rewrites the connections - appropriately. - - This transformation assumes that an SDFG obeys the structure that is outlined - [here](https://hackmd.io/klvzLnzMR6GZBWtRU8HbDg#Requirements-on-SDFG). For that - reason it is not true replacement of the native DaCe transformation. - - Args: - only_inner_maps: Only match Maps that are internal, i.e. inside another Map. - only_toplevel_maps: Only consider Maps that are at the top. - - Notes: - - This transformation modifies more nodes than it matches! - """ - - map_exit1 = dace_transformation.transformation.PatternNode(dace_nodes.MapExit) - access_node = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) - map_entry2 = dace_transformation.transformation.PatternNode(dace_nodes.MapEntry) - - def __init__( - self, - **kwargs: Any, - ) -> None: - super().__init__(**kwargs) - - @classmethod - def expressions(cls) -> Any: - """Get the match expression. - - The transformation matches the exit node of the top Map that is connected to - an access node that again is connected to the entry node of the second Map. - An important note is, that the transformation operates not just on the - matched nodes, but more or less on anything that has an incoming connection - from the first Map or an outgoing connection to the second Map entry. - """ - return [dace.sdfg.utils.node_path_graph(cls.map_exit1, cls.access_node, cls.map_entry2)] - - def can_be_applied( - self, - graph: Union[dace.SDFGState, dace.SDFG], - expr_index: int, - sdfg: dace.SDFG, - permissive: bool = False, - ) -> bool: - """Tests if the matched Maps can be merged. - - The two Maps are mergeable iff: - - The `can_be_fused()` of the base succeed, which checks some basic constraints. - - The decomposition exists and at least one of the intermediate sets - is not empty. - """ - assert isinstance(self.map_exit1, dace_nodes.MapExit) - assert isinstance(self.map_entry2, dace_nodes.MapEntry) - map_entry_1: dace_nodes.MapEntry = graph.entry_node(self.map_exit1) - map_entry_2: dace_nodes.MapEntry = self.map_entry2 - - # This essentially test the structural properties of the two Maps. - if not self.can_be_fused( - map_entry_1=map_entry_1, map_entry_2=map_entry_2, graph=graph, sdfg=sdfg - ): - return False - - # Two maps can be serially fused if the node decomposition exists and - # at least one of the intermediate output sets is not empty. The state - # of the pure outputs is irrelevant for serial map fusion. - output_partition = self.partition_first_outputs( - state=graph, - sdfg=sdfg, - map_exit_1=self.map_exit1, - map_entry_2=self.map_entry2, - ) - if output_partition is None: - return False - _, exclusive_outputs, shared_outputs = output_partition - if not (exclusive_outputs or shared_outputs): - return False - return True - - def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None: - """Performs the serial Map fusing. - - The function first computes the map decomposition and then handles the - three sets. The pure outputs are handled by `relocate_nodes()` while - the two intermediate sets are handled by `handle_intermediate_set()`. - - By assumption we do not have to rename anything. - - Args: - graph: The SDFG state we are operating on. - sdfg: The SDFG we are operating on. - """ - # NOTE: `self.map_*` actually stores the ID of the node. - # once we start adding and removing nodes it seems that their ID changes. - # Thus we have to save them here, this is a known behaviour in DaCe. - assert isinstance(graph, dace.SDFGState) - assert isinstance(self.map_exit1, dace_nodes.MapExit) - assert isinstance(self.map_entry2, dace_nodes.MapEntry) - assert self.map_parameter_compatible(self.map_exit1.map, self.map_entry2.map, graph, sdfg) - - map_exit_1: dace_nodes.MapExit = self.map_exit1 - map_entry_2: dace_nodes.MapEntry = self.map_entry2 - map_exit_2: dace_nodes.MapExit = graph.exit_node(self.map_entry2) - map_entry_1: dace_nodes.MapEntry = graph.entry_node(self.map_exit1) - - output_partition = self.partition_first_outputs( - state=graph, - sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, - ) - assert output_partition is not None # Make MyPy happy. - pure_outputs, exclusive_outputs, shared_outputs = output_partition - - if len(exclusive_outputs) != 0: - self.handle_intermediate_set( - intermediate_outputs=exclusive_outputs, - state=graph, - sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, - map_exit_2=map_exit_2, - is_exclusive_set=True, - ) - if len(shared_outputs) != 0: - self.handle_intermediate_set( - intermediate_outputs=shared_outputs, - state=graph, - sdfg=sdfg, - map_exit_1=map_exit_1, - map_entry_2=map_entry_2, - map_exit_2=map_exit_2, - is_exclusive_set=False, - ) - assert pure_outputs == set(graph.out_edges(map_exit_1)) - if len(pure_outputs) != 0: - self.relocate_nodes( - from_node=map_exit_1, - to_node=map_exit_2, - state=graph, - sdfg=sdfg, - ) - - # Above we have handled the input of the second map and moved them - # to the first map, now we must move the output of the first map - # to the second one, as this one is used. - self.relocate_nodes( - from_node=map_entry_2, - to_node=map_entry_1, - state=graph, - sdfg=sdfg, - ) - - for node_to_remove in [map_exit_1, map_entry_2]: - assert graph.degree(node_to_remove) == 0 - graph.remove_node(node_to_remove) - - # Now turn the second output node into the output node of the first Map. - map_exit_2.map = map_entry_1.map - - @staticmethod - def handle_intermediate_set( - intermediate_outputs: set[dace_graph.MultiConnectorEdge[dace.Memlet]], - state: dace.SDFGState, - sdfg: dace.SDFG, - map_exit_1: dace_nodes.MapExit, - map_entry_2: dace_nodes.MapEntry, - map_exit_2: dace_nodes.MapExit, - is_exclusive_set: bool, - ) -> None: - """This function handles the intermediate sets. - - The function is able to handle both the shared and exclusive intermediate - output set, see `partition_first_outputs()`. The main difference is that - in exclusive mode the intermediate nodes will be fully removed from - the SDFG. While in shared mode the intermediate node will be preserved. - - Args: - intermediate_outputs: The set of outputs, that should be processed. - state: The state in which the map is processed. - sdfg: The SDFG that should be optimized. - map_exit_1: The exit of the first/top map. - map_entry_2: The entry of the second map. - map_exit_2: The exit of the second map. - is_exclusive_set: If `True` `intermediate_outputs` is the exclusive set. - - Notes: - Before the transformation the `state` does not have to be valid and - after this function has run the state is (most likely) invalid. - - Todo: - Rewrite using `MemletTree`. - """ - - # Essentially this function removes the AccessNode between the two maps. - # However, we still need some temporary memory that we can use, which is - # just much smaller, i.e. a scalar. But all Memlets inside the second map - # assumes that the intermediate memory has the bigger shape. - # To fix that we will create this replacement dict that will replace all - # occurrences of the iteration variables of the second map with zero. - # Note that this is still not enough as the dimensionality might be different. - memlet_repl: dict[str, int] = {str(param): 0 for param in map_entry_2.map.params} - - # Now we will iterate over all intermediate edges and process them. - # If not stated otherwise the comments assume that we run in exclusive mode. - for out_edge in intermediate_outputs: - # This is the intermediate node that, that we want to get rid of. - # In shared mode we want to recreate it after the second map. - inter_node: dace_nodes.AccessNode = out_edge.dst - inter_name = inter_node.data - inter_desc = inter_node.desc(sdfg) - inter_shape = inter_desc.shape - - # Now we will determine the shape of the new intermediate. This size of - # this temporary is given by the Memlet that goes into the first map exit. - pre_exit_edges = list( - state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:]) - ) - if len(pre_exit_edges) != 1: - raise NotImplementedError() - pre_exit_edge = pre_exit_edges[0] - new_inter_shape_raw = dace_symbolic.overapproximate(pre_exit_edge.data.subset.size()) - - # Over approximation will leave us with some unneeded size one dimensions. - # That are known to cause some troubles, so we will now remove them. - squeezed_dims: list[int] = [] # These are the dimensions we removed. - new_inter_shape: list[int] = [] # This is the final shape of the new intermediate. - for dim, (proposed_dim_size, full_dim_size) in enumerate( - zip(new_inter_shape_raw, inter_shape) - ): - # Order of checks is important! - if full_dim_size == 1: # Must be kept! - new_inter_shape.append(proposed_dim_size) - elif proposed_dim_size == 1: # This dimension was reduced, so we can remove it. - squeezed_dims.append(dim) - else: - new_inter_shape.append(proposed_dim_size) - - # This is the name of the new "intermediate" node that we will create. - # It will only have the shape `new_inter_shape` which is basically its - # output within one Map iteration. - # NOTE: The insertion process might generate a new name. - new_inter_name: str = f"__s{sdfg.node_id(state)}_n{state.node_id(out_edge.src)}{out_edge.src_conn}_n{state.node_id(out_edge.dst)}{out_edge.dst_conn}" - - # Now generate the intermediate data container. - if len(new_inter_shape) == 0: - assert pre_exit_edge.data.subset.num_elements() == 1 - is_scalar = True - new_inter_name, new_inter_desc = sdfg.add_scalar( - new_inter_name, - dtype=inter_desc.dtype, - transient=True, - storage=dace_dtypes.StorageType.Register, - find_new_name=True, - ) - - else: - assert (pre_exit_edge.data.subset.num_elements() > 1) or all( - x == 1 for x in new_inter_shape - ) - is_scalar = False - new_inter_name, new_inter_desc = sdfg.add_transient( - new_inter_name, - shape=new_inter_shape, - dtype=inter_desc.dtype, - find_new_name=True, - ) - new_inter_node: dace_nodes.AccessNode = state.add_access(new_inter_name) - - # New we will reroute the output Memlet, thus it will no longer pass - # through the Map exit but through the newly created intermediate. - # we will delete the previous edge later. - pre_exit_memlet: dace.Memlet = pre_exit_edge.data - new_pre_exit_memlet = copy.deepcopy(pre_exit_memlet) - - # We might operate on a different array, but the check below, ensures - # that we do not change the direction of the Memlet. - assert pre_exit_memlet.data == inter_name - new_pre_exit_memlet.data = new_inter_name - - # Now we have to modify the subset of the Memlet. - # Before the subset of the Memlet was dependent on the Map variables, - # however, this is no longer the case, as we removed them. This change - # has to be reflected in the Memlet. - # NOTE: Assert above ensures that the below is correct. - new_pre_exit_memlet.replace(memlet_repl) - if is_scalar: - new_pre_exit_memlet.subset = "0" - new_pre_exit_memlet.other_subset = None - else: - new_pre_exit_memlet.subset.pop(squeezed_dims) - - # Now we create the new edge between the producer and the new output - # (the new intermediate node). We will remove the old edge further down. - new_pre_exit_edge = state.add_edge( - pre_exit_edge.src, - pre_exit_edge.src_conn, - new_inter_node, - None, - new_pre_exit_memlet, - ) - - # We just have handled the last Memlet, but we must actually handle the - # whole producer side, i.e. the scope of the top Map. - for producer_tree in state.memlet_tree(new_pre_exit_edge).traverse_children(): - producer_edge = producer_tree.edge - - # Ensure the correctness of the rerouting below. - # TODO(phimuell): Improve the code below to remove the check. - assert producer_edge.data.data == inter_name - - # Will not change the direction, because of test above! - producer_edge.data.data = new_inter_name - producer_edge.data.replace(memlet_repl) - if is_scalar: - producer_edge.data.dst_subset = "0" - elif producer_edge.data.dst_subset is not None: - producer_edge.data.dst_subset.pop(squeezed_dims) - - # Now after we have handled the input of the new intermediate node, - # we must handle its output. For this we have to "inject" the newly - # created intermediate into the second map. We do this by finding - # the input connectors on the map entry, such that we know where we - # have to reroute inside the Map. - # NOTE: Assumes that map (if connected is the direct neighbour). - conn_names: set[str] = set() - for inter_node_out_edge in state.out_edges(inter_node): - if inter_node_out_edge.dst == map_entry_2: - assert inter_node_out_edge.dst_conn.startswith("IN_") - conn_names.add(inter_node_out_edge.dst_conn) - else: - # If we found another target than the second map entry from the - # intermediate node it means that the node _must_ survive, - # i.e. we are not in exclusive mode. - assert not is_exclusive_set - - # Now we will reroute the connections inside the second map, i.e. - # instead of consuming the old intermediate node, they will now - # consume the new intermediate node. - for in_conn_name in conn_names: - out_conn_name = "OUT_" + in_conn_name[3:] - - for inner_edge in state.out_edges_by_connector(map_entry_2, out_conn_name): - assert inner_edge.data.data == inter_name # DIRECTION!! - - # The create the first Memlet to transmit information, within - # the second map, we do this again by copying and modifying - # the original Memlet. - # NOTE: Test above is important to ensure the direction of the - # Memlet and the correctness of the code below. - new_inner_memlet = copy.deepcopy(inner_edge.data) - new_inner_memlet.replace(memlet_repl) - new_inner_memlet.data = new_inter_name # Because of the assert above, this will not change the direction. - - # Now remove the old edge, that started the second map entry. - # Also add the new edge that started at the new intermediate. - state.remove_edge(inner_edge) - new_inner_edge = state.add_edge( - new_inter_node, - None, - inner_edge.dst, - inner_edge.dst_conn, - new_inner_memlet, - ) - - # Now we do subset modification to ensure that nothing failed. - if is_scalar: - new_inner_memlet.src_subset = "0" - elif new_inner_memlet.src_subset is not None: - new_inner_memlet.src_subset.pop(squeezed_dims) - - # Now clean the Memlets of that tree to use the new intermediate node. - for consumer_tree in state.memlet_tree(new_inner_edge).traverse_children(): - consumer_edge = consumer_tree.edge - assert consumer_edge.data.data == inter_name - consumer_edge.data.data = new_inter_name - consumer_edge.data.replace(memlet_repl) - if is_scalar: - consumer_edge.data.src_subset = "0" - elif consumer_edge.data.subset is not None: - consumer_edge.data.subset.pop(squeezed_dims) - - # The edge that leaves the second map entry was already deleted. - # We will now delete the edges that brought the data. - for edge in state.in_edges_by_connector(map_entry_2, in_conn_name): - assert edge.src == inter_node - state.remove_edge(edge) - map_entry_2.remove_in_connector(in_conn_name) - map_entry_2.remove_out_connector(out_conn_name) - - if is_exclusive_set: - # In exclusive mode the old intermediate node is no longer needed. - assert state.degree(inter_node) == 1 - state.remove_edge_and_connectors(out_edge) - state.remove_node(inter_node) - - state.remove_edge(pre_exit_edge) - map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) - map_exit_1.remove_out_connector(out_edge.src_conn) - del sdfg.arrays[inter_name] - - else: - # This is the shared mode, so we have to recreate the intermediate - # node, but this time it is at the exit of the second map. - state.remove_edge(pre_exit_edge) - map_exit_1.remove_in_connector(pre_exit_edge.dst_conn) - - # This is the Memlet that goes from the map internal intermediate - # temporary node to the Map output. This will essentially restore - # or preserve the output for the intermediate node. It is important - # that we use the data that `preExitEdge` was used. - new_exit_memlet = copy.deepcopy(pre_exit_edge.data) - assert new_exit_memlet.data == inter_name - new_exit_memlet.subset = pre_exit_edge.data.dst_subset - new_exit_memlet.other_subset = ( - "0" if is_scalar else dace_subsets.Range.from_array(inter_desc) - ) - - new_pre_exit_conn = map_exit_2.next_connector() - state.add_edge( - new_inter_node, - None, - map_exit_2, - "IN_" + new_pre_exit_conn, - new_exit_memlet, - ) - state.add_edge( - map_exit_2, - "OUT_" + new_pre_exit_conn, - inter_node, - out_edge.dst_conn, - copy.deepcopy(out_edge.data), - ) - map_exit_2.add_in_connector("IN_" + new_pre_exit_conn) - map_exit_2.add_out_connector("OUT_" + new_pre_exit_conn) - - map_exit_1.remove_out_connector(out_edge.src_conn) - state.remove_edge(out_edge) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py new file mode 100644 index 0000000000..6b7bd1b6d5 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -0,0 +1,1010 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +"""The GT4Py specific simplification pass.""" + +import collections +import copy +import uuid +from typing import Any, Final, Iterable, Optional, TypeAlias + +import dace +from dace import ( + data as dace_data, + properties as dace_properties, + subsets as dace_subsets, + transformation as dace_transformation, +) +from dace.sdfg import nodes as dace_nodes +from dace.transformation import ( + dataflow as dace_dataflow, + pass_pipeline as dace_ppl, + passes as dace_passes, +) + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + + +GT_SIMPLIFY_DEFAULT_SKIP_SET: Final[set[str]] = {"ScalarToSymbolPromotion", "ConstantPropagation"} +"""Set of simplify passes `gt_simplify()` skips by default. + +The following passes are included: +- `ScalarToSymbolPromotion`: The lowering has sometimes to turn a scalar into a + symbol or vice versa and at a later point to invert this again. However, this + pass has some problems with this pattern so for the time being it is disabled. +- `ConstantPropagation`: Same reasons as `ScalarToSymbolPromotion`. +""" + + +def gt_simplify( + sdfg: dace.SDFG, + validate: bool = True, + validate_all: bool = False, + skip: Optional[Iterable[str]] = None, +) -> Optional[dict[str, Any]]: + """Performs simplifications on the SDFG in place. + + Instead of calling `sdfg.simplify()` directly, you should use this function, + as it is specially tuned for GridTool based SDFGs. + + This function runs the DaCe simplification pass, but the following passes are + replaced: + - `InlineSDFGs`: Instead `gt_inline_nested_sdfg()` will be called. + + Further, the function will run the following passes in addition to DaCe simplify: + - `GT4PyGlobalSelfCopyElimination`: Special copy pattern that in the context + of GT4Py based SDFG behaves as a no op. + + Furthermore, by default, or if `None` is passed for `skip` the passes listed in + `GT_SIMPLIFY_DEFAULT_SKIP_SET` will be skipped. + + Args: + sdfg: The SDFG to optimize. + validate: Perform validation after the pass has run. + validate_all: Perform extensive validation. + skip: List of simplify passes that should not be applied, defaults + to `GT_SIMPLIFY_DEFAULT_SKIP_SET`. + + Note: + Currently DaCe does not provide a way to inject or exchange sub passes in + simplify. The custom inline pass is run at the beginning and the array + elimination at the end. The whole process is run inside a loop that ensures + that `gt_simplify()` results in a fix point. + """ + # Ensure that `skip` is a `set` + skip = GT_SIMPLIFY_DEFAULT_SKIP_SET if skip is None else set(skip) + + result: Optional[dict[str, Any]] = None + + at_least_one_xtrans_run = True + + while at_least_one_xtrans_run: + at_least_one_xtrans_run = False + + if "InlineSDFGs" not in skip: + inline_res = gt_inline_nested_sdfg( + sdfg=sdfg, + multistate=True, + permissive=False, + validate=validate, + validate_all=validate_all, + ) + if inline_res is not None: + at_least_one_xtrans_run = True + result = result or {} + result.update(inline_res) + + simplify_res = dace_passes.SimplifyPass( + validate=validate, + validate_all=validate_all, + verbose=False, + skip=(skip | {"InlineSDFGs"}), + ).apply_pass(sdfg, {}) + + if simplify_res is not None: + at_least_one_xtrans_run = True + result = result or {} + result.update(simplify_res) + + if "GT4PyGlobalSelfCopyElimination" not in skip: + self_copy_removal_result = sdfg.apply_transformations_repeated( + GT4PyGlobalSelfCopyElimination(), + validate=validate, + validate_all=validate_all, + ) + if self_copy_removal_result > 0: + at_least_one_xtrans_run = True + result = result or {} + result.setdefault("GT4PyGlobalSelfCopyElimination", 0) + result["GT4PyGlobalSelfCopyElimination"] += self_copy_removal_result + + return result + + +def gt_inline_nested_sdfg( + sdfg: dace.SDFG, + multistate: bool = True, + permissive: bool = False, + validate: bool = True, + validate_all: bool = False, +) -> Optional[dict[str, int]]: + """Perform inlining of nested SDFG into their parent SDFG. + + The function uses DaCe's `InlineSDFG` transformation, the same used in simplify. + However, before the inline transformation is run the function will run some + cleaning passes that allows inlining nested SDFGs. + As a side effect, the function will split stages into more states. + + Args: + sdfg: The SDFG that should be processed, will be modified in place and returned. + multistate: Allow inlining of multistate nested SDFG, defaults to `True`. + permissive: Be less strict on the accepted SDFGs. + validate: Perform validation after the transformation has finished. + validate_all: Performs extensive validation. + """ + first_iteration = True + nb_preproccess_total = 0 + nb_inlines_total = 0 + while True: + nb_preproccess = sdfg.apply_transformations_repeated( + [dace_dataflow.PruneSymbols, dace_dataflow.PruneConnectors], + validate=False, + validate_all=validate_all, + ) + nb_preproccess_total += nb_preproccess + if (nb_preproccess == 0) and (not first_iteration): + break + + # Create and configure the inline pass + inline_sdfg = dace_passes.InlineSDFGs() + inline_sdfg.progress = False + inline_sdfg.permissive = permissive + inline_sdfg.multistate = multistate + + # Apply the inline pass + # The pass returns `None` no indicate "nothing was done" + nb_inlines = inline_sdfg.apply_pass(sdfg, {}) or 0 + nb_inlines_total += nb_inlines + + # Check result, if needed and test if we can stop + if validate_all or validate: + sdfg.validate() + if nb_inlines == 0: + break + first_iteration = False + + result: dict[str, int] = {} + if nb_inlines_total != 0: + result["InlineSDFGs"] = nb_inlines_total + if nb_preproccess_total != 0: + result["PruneSymbols|PruneConnectors"] = nb_preproccess_total + return result if result else None + + +def gt_substitute_compiletime_symbols( + sdfg: dace.SDFG, + repl: dict[str, Any], + validate: bool = False, + validate_all: bool = False, +) -> None: + """Substitutes symbols that are known at compile time with their value. + + Some symbols are known to have a constant value. This function will remove these + symbols from the SDFG and replace them with the value. + An example where this makes sense are strides that are known to be one. + + Args: + sdfg: The SDFG to process. + repl: Maps the name of the symbol to the value it should be replaced with. + validate: Perform validation at the end of the function. + validate_all: Perform validation also on intermediate steps. + """ + + # We will use the `replace` function of the top SDFG, however, lower levels + # are handled using ConstantPropagation. + sdfg.replace_dict(repl) + + const_prop = dace_passes.ConstantPropagation() + const_prop.recursive = True + const_prop.progress = False + + const_prop.apply_pass( + sdfg=sdfg, + initial_symbols=repl, + _=None, + ) + gt_simplify( + sdfg=sdfg, + validate=validate, + validate_all=validate_all, + ) + dace.sdfg.propagation.propagate_memlets_sdfg(sdfg) + + +def gt_reduce_distributed_buffering( + sdfg: dace.SDFG, +) -> Optional[dict[dace.SDFG, dict[dace.SDFGState, set[str]]]]: + """Removes distributed write back buffers.""" + pipeline = dace_ppl.Pipeline([DistributedBufferRelocator()]) + all_result = {} + + for rsdfg in sdfg.all_sdfgs_recursive(): + ret = pipeline.apply_pass(sdfg, {}) + if ret is not None: + all_result[rsdfg] = ret + + return all_result + + +@dace_properties.make_properties +class GT4PyGlobalSelfCopyElimination(dace_transformation.SingleStateTransformation): + """Remove global self copy. + + This transformation matches the following case `(G) -> (T) -> (G)`, i.e. `G` + is read from and written too at the same time, however, in between is `T` + used as a buffer. In the example above `G` is a global memory and `T` is a + temporary. This situation is generated by the lowering if the data node is + not needed (because the computation on it is only conditional). + + In case `G` refers to global memory rule 3 of ADR-18 guarantees that we can + only have a point wise dependency of the output on the input. + This transformation will remove the write into `G`, i.e. we thus only have + `(G) -> (T)`. The read of `G` and the definition of `T`, will only be removed + if `T` is not used downstream. If it is used `T` will be maintained. + """ + + node_read_g = dace_transformation.PatternNode(dace_nodes.AccessNode) + node_tmp = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) + node_write_g = dace_transformation.PatternNode(dace_nodes.AccessNode) + + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.node_read_g, cls.node_tmp, cls.node_write_g)] + + def can_be_applied( + self, + graph: dace.SDFGState | dace.SDFG, + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + read_g = self.node_read_g + write_g = self.node_write_g + tmp_node = self.node_tmp + g_desc = read_g.desc(sdfg) + tmp_desc = tmp_node.desc(sdfg) + + # NOTE: We do not check if `G` is read downstream. + if read_g.data != write_g.data: + return False + if g_desc.transient: + return False + if not tmp_desc.transient: + return False + if graph.in_degree(read_g) != 0: + return False + if graph.out_degree(read_g) != 1: + return False + if graph.degree(tmp_node) != 2: + return False + if graph.in_degree(write_g) != 1: + return False + if graph.out_degree(write_g) != 0: + return False + if graph.scope_dict()[read_g] is not None: + return False + + return True + + def _is_read_downstream( + self, + start_state: dace.SDFGState, + sdfg: dace.SDFG, + data_to_look: str, + ) -> bool: + """Scans for reads to `data_to_look`. + + The function will go through states that are reachable from `start_state` + (including) and test if there is a read to the data container `data_to_look`. + It will return `True` the first time it finds such a node. + It is important that the matched nodes, i.e. `self.node_{read_g, write_g, tmp}` + are ignored. + + Args: + start_state: The state where the scanning starts. + sdfg: The SDFG on which we operate. + data_to_look: The data that we want to look for. + + Todo: + Port this function to use DaCe pass pipeline. + """ + read_g: dace_nodes.AccessNode = self.node_read_g + write_g: dace_nodes.AccessNode = self.node_write_g + tmp_node: dace_nodes.AccessNode = self.node_tmp + + return gtx_transformations.util.is_accessed_downstream( + start_state=start_state, + sdfg=sdfg, + data_to_look=data_to_look, + nodes_to_ignore={read_g, write_g, tmp_node}, + ) + + def apply( + self, + graph: dace.SDFGState | dace.SDFG, + sdfg: dace.SDFG, + ) -> None: + read_g: dace_nodes.AccessNode = self.node_read_g + write_g: dace_nodes.AccessNode = self.node_write_g + tmp_node: dace_nodes.AccessNode = self.node_tmp + + # We first check if `T`, the intermediate is not used downstream. In this + # case we can remove the read to `G` and `T` itself from the SDFG. + # We have to do this check before, because the matching is not fully stable. + is_tmp_used_downstream = self._is_read_downstream( + start_state=graph, sdfg=sdfg, data_to_look=tmp_node.data + ) + + # The write to `G` can always be removed. + graph.remove_node(write_g) + + # Also remove the read to `G` and `T` from the SDFG if possible. + if not is_tmp_used_downstream: + graph.remove_node(read_g) + graph.remove_node(tmp_node) + # It could still be used in a parallel branch. + try: + sdfg.remove_data(tmp_node.data, validate=True) + except ValueError as e: + if not str(e).startswith(f"Cannot remove data descriptor {tmp_node.data}:"): + raise + + +AccessLocation: TypeAlias = tuple[dace.SDFGState, dace_nodes.AccessNode] +"""Describes an access node and the state in which it is located. +""" + + +@dace_properties.make_properties +class DistributedBufferRelocator(dace_transformation.Pass): + """Moves the final write back of the results to where it is needed. + + In certain cases, especially in case where we have `if` the result is computed + in each branch and then in the join state written back. Thus there is some + additional storage needed. + The transformation will look for the following situation: + - A transient data container, called `src_cont`, is written into another + container, called `dst_cont`, which is not transient. + - The access node of `src_cont` has an in degree of zero and an out degree of one. + - The access node of `dst_cont` has an in degree of of one and an + out degree of zero (this might be lifted). + - `src_cont` is not used afterwards. + - `dst_cont` is only used to implement the buffering. + + The function will relocate the writing of `dst_cont` to where `src_cont` is + written, which might be multiple locations. + It will also remove the writing back. + It is advised that after this transformation simplify is run again. + + Note: + Essentially this transformation removes the double buffering of `dst_cont`. + Because we ensure that that `dst_cont` is non transient this is okay, as our + rule guarantees this. + + Todo: + - Allow that `dst_cont` can also be transient. + - Allow that `dst_cont` does not need to be a sink node, this is most + likely most relevant if it is transient. + - Check if `dst_cont` is used between where we want to place it and + where it is currently used. + """ + + def modifies(self) -> dace_ppl.Modifies: + return dace_ppl.Modifies.Memlets | dace_ppl.Modifies.AccessNodes + + def should_reapply(self, modified: dace_ppl.Modifies) -> bool: + return modified & (dace_ppl.Modifies.Memlets | dace_ppl.Modifies.AccessNodes) + + def depends_on(self) -> set[type[dace_transformation.Pass]]: + return { + dace_transformation.passes.StateReachability, + dace_transformation.passes.AccessSets, + } + + def apply_pass( + self, sdfg: dace.SDFG, pipeline_results: dict[str, Any] + ) -> Optional[dict[dace.SDFGState, set[str]]]: + reachable: dict[dace.SDFGState, set[dace.SDFGState]] = pipeline_results[ + "StateReachability" + ][sdfg.cfg_id] + access_sets: dict[dace.SDFGState, tuple[set[str], set[str]]] = pipeline_results[ + "AccessSets" + ][sdfg.cfg_id] + result: dict[dace.SDFGState, set[str]] = collections.defaultdict(set) + + to_relocate = self._find_candidates(sdfg, reachable, access_sets) + if len(to_relocate) == 0: + return None + self._relocate_write_backs(sdfg, to_relocate) + + for (wb_an, wb_state), _ in to_relocate: + result[wb_state].add(wb_an.data) + + return result + + def _relocate_write_backs( + self, + sdfg: dace.SDFG, + to_relocate: list[tuple[AccessLocation, list[AccessLocation]]], + ) -> None: + """Perform the actual relocation.""" + for (wb_an, wb_state), def_locations in to_relocate: + # Get the memlet that we have to replicate. + wb_edge = next(iter(wb_state.out_edges(wb_an))) + wb_memlet: dace.Memlet = wb_edge.data + final_dest_name: str = wb_edge.dst.data + + for def_an, def_state in def_locations: + def_state.add_edge( + def_an, + wb_edge.src_conn, + def_state.add_access(final_dest_name), + wb_edge.dst_conn, + copy.deepcopy(wb_memlet), + ) + + # Now remove the old node and if the old target become isolated + # remove that as well. + old_dst = wb_edge.dst + wb_state.remove_node(wb_an) + if wb_state.degree(old_dst) == 0: + wb_state.remove_node(old_dst) + + def _find_candidates( + self, + sdfg: dace.SDFG, + reachable: dict[dace.SDFGState, set[dace.SDFGState]], + access_sets: dict[dace.SDFGState, tuple[set[str], set[str]]], + ) -> list[tuple[AccessLocation, list[AccessLocation]]]: + """Determines all temporaries that have to be relocated. + + Returns: + A list of tuples. The first element element of the tuple is an + `AccessLocation` that describes where the temporary is read. + The second element is a list of `AccessLocation`s that describes + where the temporary is defined. + """ + # All nodes that are used as distributed buffers. + candidate_src_cont: list[AccessLocation] = [] + + # Which `src_cont` access node is written back to which global memory. + src_cont_to_global: dict[dace_nodes.AccessNode, str] = {} + + for state in sdfg.states(): + # These are the possible targets we want to write into. + candidate_dst_nodes: set[dace_nodes.AccessNode] = { + node + for node in state.sink_nodes() + if ( + isinstance(node, dace_nodes.AccessNode) + and state.in_degree(node) == 1 + and (not node.desc(sdfg).transient) + ) + } + if len(candidate_dst_nodes) == 0: + continue + + for src_cont in state.source_nodes(): + if not isinstance(src_cont, dace_nodes.AccessNode): + continue + if not src_cont.desc(sdfg).transient: + continue + if state.out_degree(src_cont) != 1: + continue + dst_candidate: dace_nodes.AccessNode = next( + iter(edge.dst for edge in state.out_edges(src_cont)) + ) + if dst_candidate not in candidate_dst_nodes: + continue + candidate_src_cont.append((src_cont, state)) + src_cont_to_global[src_cont] = dst_candidate.data + + if len(candidate_src_cont) == 0: + return [] + + # Now we have to find the places where the temporary sources are defined. + # I.e. This is also the location where the original value is defined. + result_candidates: list[tuple[AccessLocation, list[AccessLocation]]] = [] + + def find_upstream_states(dst_state: dace.SDFGState) -> set[dace.SDFGState]: + return { + src_state + for src_state in sdfg.states() + if dst_state in reachable[src_state] and dst_state is not src_state + } + + for src_cont in candidate_src_cont: + def_locations: list[AccessLocation] = [] + for upstream_state in find_upstream_states(src_cont[1]): + if src_cont[0].data in access_sets[upstream_state][1]: + def_locations.extend( + (data_node, upstream_state) + for data_node in upstream_state.data_nodes() + if data_node.data == src_cont[0].data + ) + if len(def_locations) != 0: + result_candidates.append((src_cont, def_locations)) + + # This transformation removes `src_cont` by writing its content directly + # to `dst_cont`, at the point where it is defined. + # For this transformation to be valid the following conditions have to be met: + # - Between the definition of `src_cont` and the write back to `dst_cont`, + # `dst_cont` can not be accessed. + # - Between the definitions of `src_cont` and the point where it is written + # back, `src_cont` can only be accessed in the range that is written back. + # - After the write back point, `src_cont` shall not be accessed. This + # restriction could be lifted. + # + # To keep the implementation simple, we use the conditions: + # - `src_cont` is only accessed were it is defined and at the write back + # point. + # - Between the definitions of `src_cont` and the write back point, + # `dst_cont` is not used. + + result: list[tuple[AccessLocation, list[AccessLocation]]] = [] + + for wb_localation, def_locations in result_candidates: + for def_node, def_state in def_locations: + # Test if `src_cont` is only accessed where it is defined and + # where it is written back. + if gtx_transformations.util.is_accessed_downstream( + start_state=def_state, + sdfg=sdfg, + data_to_look=wb_localation[0].data, + nodes_to_ignore={def_node, wb_localation[0]}, + ): + break + # check if the global data is not used between the definition of + # `dst_cont` and where its written back. We allow one exception, + # if the global data is used in the state the distributed temporary + # is defined is used only for reading then it is ignored. This is + # allowed because of rule 3 of ADR0018. + glob_nodes_in_def_state = { + dnode + for dnode in def_state.data_nodes() + if dnode.data == src_cont_to_global[wb_localation[0]] + } + if any(def_state.in_degree(gdnode) != 0 for gdnode in glob_nodes_in_def_state): + break + if gtx_transformations.util.is_accessed_downstream( + start_state=def_state, + sdfg=sdfg, + data_to_look=src_cont_to_global[wb_localation[0]], + nodes_to_ignore=glob_nodes_in_def_state, + states_to_ignore={wb_localation[1]}, + ): + break + else: + result.append((wb_localation, def_locations)) + + return result + + +@dace_properties.make_properties +class GT4PyMoveTaskletIntoMap(dace_transformation.SingleStateTransformation): + """Moves a Tasklet, with no input into a map. + + Tasklets without inputs, are mostly used to generate constants. + However, if they are outside a Map, then this constant value is an + argument to the kernel, and can not be used by the compiler. + + This transformation moves such Tasklets into a Map scope. + """ + + tasklet = dace_transformation.PatternNode(dace_nodes.Tasklet) + access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) + map_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) + + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.tasklet, cls.access_node, cls.map_entry)] + + def can_be_applied( + self, + graph: dace.SDFGState | dace.SDFG, + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + tasklet: dace_nodes.Tasklet = self.tasklet + access_node: dace_nodes.AccessNode = self.access_node + access_desc: dace_data.Data = access_node.desc(sdfg) + map_entry: dace_nodes.MapEntry = self.map_entry + + if graph.in_degree(tasklet) != 0: + return False + if graph.out_degree(tasklet) != 1: + return False + if tasklet.has_side_effects(sdfg): + return False + if tasklet.code_init.as_string: + return False + if tasklet.code_exit.as_string: + return False + if tasklet.code_global.as_string: + return False + if tasklet.state_fields: + return False + if not isinstance(access_desc, dace_data.Scalar): + return False + if not access_desc.transient: + return False + if not any( + edge.dst_conn and edge.dst_conn.startswith("IN_") + for edge in graph.out_edges(access_node) + if edge.dst is map_entry + ): + return False + # NOTE: We allow that the access node is used in multiple places. + + return True + + def apply( + self, + graph: dace.SDFGState | dace.SDFG, + sdfg: dace.SDFG, + ) -> None: + tasklet: dace_nodes.Tasklet = self.tasklet + access_node: dace_nodes.AccessNode = self.access_node + access_desc: dace_data.Scalar = access_node.desc(sdfg) + map_entry: dace_nodes.MapEntry = self.map_entry + + # Find _a_ connection that leads from the access node to the map. + edge_to_map = next( + iter( + edge + for edge in graph.out_edges(access_node) + if edge.dst is map_entry and edge.dst_conn.startswith("IN_") + ) + ) + connector_name: str = edge_to_map.dst_conn[3:] + + # This is the tasklet that we will put inside the map, note we have to do it + # this way to avoid some name clash stuff. + inner_tasklet: dace_nodes.Tasklet = graph.add_tasklet( + name=f"{tasklet.label}__clone_{str(uuid.uuid1()).replace('-', '_')}", + outputs=tasklet.out_connectors.keys(), + inputs=set(), + code=tasklet.code, + language=tasklet.language, + debuginfo=tasklet.debuginfo, + ) + inner_desc: dace_data.Scalar = access_desc.clone() + inner_data_name: str = sdfg.add_datadesc(access_node.data, inner_desc, find_new_name=True) + inner_an: dace_nodes.AccessNode = graph.add_access(inner_data_name) + + # Connect the tasklet with the map entry and the access node. + graph.add_nedge(map_entry, inner_tasklet, dace.Memlet()) + graph.add_edge( + inner_tasklet, + next(iter(inner_tasklet.out_connectors.keys())), + inner_an, + None, + dace.Memlet(f"{inner_data_name}[0]"), + ) + + # Now we will reroute the edges went through the inner map, through the + # inner access node instead. + for old_inner_edge in list( + graph.out_edges_by_connector(map_entry, "OUT_" + connector_name) + ): + # We now modify the downstream data. This is because we no longer refer + # to the data outside but the one inside. + self._modify_downstream_memlets( + state=graph, + edge=old_inner_edge, + old_data=access_node.data, + new_data=inner_data_name, + ) + + # After we have changed the properties of the MemletTree of `edge` + # we will now reroute it, such that the inner access node is used. + graph.add_edge( + inner_an, + None, + old_inner_edge.dst, + old_inner_edge.dst_conn, + old_inner_edge.data, + ) + graph.remove_edge(old_inner_edge) + map_entry.remove_in_connector("IN_" + connector_name) + map_entry.remove_out_connector("OUT_" + connector_name) + + # Now we can remove the map connection between the outer/old access + # node and the map. + graph.remove_edge(edge_to_map) + + # The data is no longer referenced in this state, so we can potentially + # remove + if graph.out_degree(access_node) == 0: + if not gtx_transformations.util.is_accessed_downstream( + start_state=graph, + sdfg=sdfg, + data_to_look=access_node.data, + nodes_to_ignore={access_node}, + ): + graph.remove_nodes_from([tasklet, access_node]) + # Needed if data is accessed in a parallel branch. + try: + sdfg.remove_data(access_node.data, validate=True) + except ValueError as e: + if not str(e).startswith(f"Cannot remove data descriptor {access_node.data}:"): + raise + + def _modify_downstream_memlets( + self, + state: dace.SDFGState, + edge: dace.sdfg.graph.MultiConnectorEdge, + old_data: str, + new_data: str, + ) -> None: + """Replaces the data along on the tree defined by `edge`. + + The function will traverse the MemletTree defined by `edge`. + Any Memlet that refers to `old_data` will be replaced with + `new_data`. + + Args: + state: The sate in which we operate. + edge: The edge defining the MemletTree. + old_data: The name of the data that should be replaced. + new_data: The name of the new data the Memlet should refer to. + """ + mtree: dace.memlet.MemletTree = state.memlet_tree(edge) + for tedge in mtree.traverse_children(True): + # Because we only change the name of the data, we do not change the + # direction of the Memlet, so `{src, dst}_subset` will remain the same. + if tedge.edge.data.data == old_data: + tedge.edge.data.data = new_data + + +@dace_properties.make_properties +class GT4PyMapBufferElimination(dace_transformation.SingleStateTransformation): + """Allows to remove unneeded buffering at map output. + + The transformation matches the case `MapExit -> (T) -> (G)`, where `T` is an + AccessNode referring to a transient and `G` an AccessNode that refers to non + transient memory. + If the following conditions are met then `T` is removed. + - `T` is not used to filter computations, i.e. what is written into `G` + is covered by what is written into `T`. + - `T` is not used anywhere else. + - `G` is not also an input to the map, except there is only a pointwise + dependency in `G`, see the note below. + - Everything needs to be at top scope. + + Notes: + - Rule 3 of ADR18 should guarantee that any valid GT4Py program meets the + point wise dependency in `G`, for that reason it is possible to disable + this test by specifying `assume_pointwise`. + + Todo: + - Implement a real pointwise test. + """ + + map_exit = dace_transformation.PatternNode(dace_nodes.MapExit) + tmp_ac = dace_transformation.transformation.PatternNode(dace_nodes.AccessNode) + glob_ac = dace_transformation.PatternNode(dace_nodes.AccessNode) + + assume_pointwise = dace_properties.Property( + dtype=bool, + default=False, + desc="Dimensions that should become the leading dimension.", + ) + + def __init__( + self, + assume_pointwise: Optional[bool] = None, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + if assume_pointwise is not None: + self.assume_pointwise = assume_pointwise + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.map_exit, cls.tmp_ac, cls.glob_ac)] + + def depends_on(self) -> set[type[dace_transformation.Pass]]: + return {dace_transformation.passes.ConsolidateEdges} + + def can_be_applied( + self, + graph: dace.SDFGState | dace.SDFG, + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + tmp_ac: dace_nodes.AccessNode = self.tmp_ac + glob_ac: dace_nodes.AccessNode = self.glob_ac + tmp_desc: dace_data.Data = tmp_ac.desc(sdfg) + glob_desc: dace_data.Data = glob_ac.desc(sdfg) + + if not tmp_desc.transient: + return False + if glob_desc.transient: + return False + if graph.in_degree(tmp_ac) != 1: + return False + if any(gtx_transformations.util.is_view(ac, sdfg) for ac in [tmp_ac, glob_ac]): + return False + if len(glob_desc.shape) != len(tmp_desc.shape): + return False + + # Test if we are on the top scope (it is likely). + if graph.scope_dict()[glob_ac] is not None: + return False + + # Now perform if we are point wise + if not self._perform_pointwise_test(graph, sdfg): + return False + + # Test if `tmp` is only anywhere else, this is important for removing it. + if graph.out_degree(tmp_ac) != 1: + return False + if gtx_transformations.util.is_accessed_downstream( + start_state=graph, + sdfg=sdfg, + data_to_look=tmp_ac.data, + nodes_to_ignore={tmp_ac}, + ): + return False + + # Now we ensure that `tmp` is not used to filter out some computations. + map_to_tmp_edge = next(edge for edge in graph.in_edges(tmp_ac)) + tmp_to_glob_edge = next(edge for edge in graph.out_edges(tmp_ac)) + + tmp_in_subset = map_to_tmp_edge.data.get_dst_subset(map_to_tmp_edge, graph) + tmp_out_subset = tmp_to_glob_edge.data.get_src_subset(tmp_to_glob_edge, graph) + glob_in_subset = tmp_to_glob_edge.data.get_dst_subset(tmp_to_glob_edge, graph) + if tmp_in_subset is None: + tmp_in_subset = dace_subsets.Range.from_array(tmp_desc) + if tmp_out_subset is None: + tmp_out_subset = dace_subsets.Range.from_array(tmp_desc) + if glob_in_subset is None: + return False + + # TODO(phimuell): Do we need simplify in the check. + # TODO(phimuell): Restrict this to having the same size. + if tmp_out_subset != tmp_in_subset: + return False + return True + + def _perform_pointwise_test( + self, + state: dace.SDFGState, + sdfg: dace.SDFG, + ) -> bool: + """Test if `G` is only point wise accessed. + + This function will also consider the `assume_pointwise` property. + """ + map_exit: dace_nodes.MapExit = self.map_exit + map_entry: dace_nodes.MapEntry = state.entry_node(map_exit) + glob_ac: dace_nodes.AccessNode = self.glob_ac + glob_data: str = glob_ac.data + + # First we check if `G` is also an input to this map. + conflicting_inputs: set[dace_nodes.AccessNode] = set() + for in_edge in state.in_edges(map_entry): + if not isinstance(in_edge.src, dace_nodes.AccessNode): + continue + + # Find the source of this data, if it is a view we trace it to + # its origin. + src_node: dace_nodes.AccessNode = gtx_transformations.util.track_view( + in_edge.src, state, sdfg + ) + + # Test if there is a conflict; We do not store the source but the + # actual node that is adjacent. + if src_node.data == glob_data: + conflicting_inputs.add(in_edge.src) + + # If there are no conflicting inputs, then we are point wise. + # This is an implementation detail that make life simpler. + if len(conflicting_inputs) == 0: + return True + + # If we can assume pointwise computations, then we do not have to do + # anything. + if self.assume_pointwise: + return True + + # Currently the only test that we do is, if we have a view, then we + # are not point wise. + # TODO(phimuell): Improve/implement this. + return any(gtx_transformations.util.is_view(node, sdfg) for node in conflicting_inputs) + + def apply( + self, + graph: dace.SDFGState | dace.SDFG, + sdfg: dace.SDFG, + ) -> None: + # Removal + # Propagation ofthe shift. + map_exit: dace_nodes.MapExit = self.map_exit + tmp_ac: dace_nodes.AccessNode = self.tmp_ac + tmp_desc: dace_data.Data = tmp_ac.desc(sdfg) + tmp_data = tmp_ac.data + glob_ac: dace_nodes.AccessNode = self.glob_ac + glob_data = glob_ac.data + + map_to_tmp_edge = next(edge for edge in graph.in_edges(tmp_ac)) + tmp_to_glob_edge = next(edge for edge in graph.out_edges(tmp_ac)) + + glob_in_subset = tmp_to_glob_edge.data.get_dst_subset(tmp_to_glob_edge, graph) + tmp_out_subset = tmp_to_glob_edge.data.get_src_subset(tmp_to_glob_edge, graph) + if tmp_out_subset is None: + tmp_out_subset = dace_subsets.Range.from_array(tmp_desc) + assert glob_in_subset is not None + + # We now remove the `tmp` node, and create a new connection between + # the global node and the map exit. + new_map_to_glob_edge = graph.add_edge( + map_exit, + map_to_tmp_edge.src_conn, + glob_ac, + tmp_to_glob_edge.dst_conn, + dace.Memlet( + data=glob_ac.data, + subset=copy.deepcopy(glob_in_subset), + ), + ) + graph.remove_edge(map_to_tmp_edge) + graph.remove_edge(tmp_to_glob_edge) + graph.remove_node(tmp_ac) + + # We can not unconditionally remove the data `tmp` refers to, because + # it could be that in a parallel branch the `tmp` is also defined. + try: + sdfg.remove_data(tmp_ac.data, validate=True) + except ValueError as e: + if not str(e).startswith(f"Cannot remove data descriptor {tmp_ac.data}:"): + raise + + # Now we must modify the memlets inside the map scope, because + # they now write into `G` instead of `tmp`, which has a different + # offset. + # NOTE: Assumes that `tmp_out_subset` and `tmp_in_subset` are the same. + correcting_offset = glob_in_subset.offset_new(tmp_out_subset, negative=True) + mtree = graph.memlet_tree(new_map_to_glob_edge) + for tree in mtree.traverse_children(include_self=False): + curr_edge = tree.edge + curr_dst_subset = curr_edge.data.get_dst_subset(curr_edge, graph) + if curr_edge.data.data == tmp_data: + curr_edge.data.data = glob_data + if curr_dst_subset is not None: + curr_dst_subset.offset(correcting_offset, negative=False) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py new file mode 100644 index 0000000000..4e254f2880 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/strides.py @@ -0,0 +1,99 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import dace +from dace import data as dace_data + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + + +def gt_change_transient_strides( + sdfg: dace.SDFG, + gpu: bool, +) -> dace.SDFG: + """Modifies the strides of transients. + + The function will analyse the access patterns and set the strides of + transients in the optimal way. + The function should run after all maps have been created. + + Args: + sdfg: The SDFG to process. + gpu: If the SDFG is supposed to run on the GPU. + + Note: + Currently the function will not scan the access pattern. Instead it will + either use FORTRAN order for GPU or C order (which is assumed to be the + default, so it is a no ops). + + Todo: + - Implement the estimation correctly. + - Handle the case of nested SDFGs correctly; on the outside a transient, + but on the inside a non transient. + """ + # TODO(phimeull): Implement this function correctly. + + # We assume that by default we have C order which is already correct, + # so in this case we have a no ops + if not gpu: + return sdfg + + for nsdfg in sdfg.all_sdfgs_recursive(): + # TODO(phimuell): Handle the case when transient goes into nested SDFG + # on the inside it is a non transient, so it is ignored. + _gt_change_transient_strides_non_recursive_impl(nsdfg) + + +def _gt_change_transient_strides_non_recursive_impl( + sdfg: dace.SDFG, +) -> None: + """Essentially this function just changes the stride to FORTRAN order.""" + for top_level_transient in _find_toplevel_transients(sdfg, only_arrays=True): + desc: dace_data.Array = sdfg.arrays[top_level_transient] + ndim = len(desc.shape) + if ndim <= 1: + continue + # We assume that everything is in C order initially, to get FORTRAN order + # we simply have to reverse the order. + new_stride_order = list(range(ndim)) + desc.set_strides_from_layout(*new_stride_order) + + +def _find_toplevel_transients( + sdfg: dace.SDFG, + only_arrays: bool = False, +) -> set[str]: + """Find all top level transients in the SDFG. + + The function will scan the SDFG, ignoring nested one, and return the + name of all transients that have an access node at the top level. + However, it will ignore access nodes that refers to registers. + """ + top_level_transients: set[str] = set() + for state in sdfg.states(): + scope_dict = state.scope_dict() + for dnode in state.data_nodes(): + data: str = dnode.data + if scope_dict[dnode] is not None: + if data in top_level_transients: + top_level_transients.remove(data) + continue + elif data in top_level_transients: + continue + elif gtx_transformations.util.is_view(dnode, sdfg): + continue + desc: dace_data.Data = dnode.desc(sdfg) + + if not desc.transient: + continue + elif only_arrays and not isinstance(desc, dace_data.Array): + continue + top_level_transients.add(data) + return top_level_transients diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py index 29bae7bbe0..29c099eecf 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/util.py @@ -8,153 +8,220 @@ """Common functionality for the transformations/optimization pipeline.""" -from typing import Iterable, Union +from typing import Any, Container, Optional, Union import dace -from dace.sdfg import graph as dace_graph, nodes as dace_nodes +from dace import data as dace_data +from dace.sdfg import nodes as dace_nodes +from gt4py.next.program_processors.runners.dace_common import utility as dace_utils -def is_nested_sdfg( - sdfg: Union[dace.SDFG, dace.SDFGState, dace_nodes.NestedSDFG], -) -> bool: - """Tests if `sdfg` is a NestedSDFG.""" - if isinstance(sdfg, dace.SDFGState): - sdfg = sdfg.parent - if isinstance(sdfg, dace_nodes.NestedSDFG): - return True - elif isinstance(sdfg, dace.SDFG): - return sdfg.parent_nsdfg_node is not None - raise TypeError(f"Does not know how to handle '{type(sdfg).__name__}'.") - - -def all_nodes_between( - graph: dace.SDFG | dace.SDFGState, - begin: dace_nodes.Node, - end: dace_nodes.Node, - reverse: bool = False, -) -> set[dace_nodes.Node] | None: - """Find all nodes that are reachable from `begin` but bound by `end`. - - Essentially the function starts a DFS at `begin`. If an edge is found that lead - to `end`, this edge is ignored. It will thus found any node that is reachable - from `begin` by a path that does not involve `end`. The returned set will - never contain `end` nor `begin`. In case `end` is never found the function - will return `None`. - - If `reverse` is set to `True` the function will start exploring at `end` and - follows the outgoing edges, i.e. the meaning of `end` and `begin` are swapped. + +def gt_make_transients_persistent( + sdfg: dace.SDFG, + device: dace.DeviceType, +) -> dict[int, set[str]]: + """ + Changes the lifetime of certain transients to `Persistent`. + + A persistent lifetime means that the transient is allocated only the very first + time the SDFG is executed and only deallocated if the underlying `CompiledSDFG` + object goes out of scope. The main advantage is, that memory must not be + allocated every time the SDFG is run. The downside is that the SDFG can not be + called by different threads. Args: - graph: The graph to operate on. - begin: The start of the DFS. - end: The terminator node of the DFS. - reverse: Perform a backward DFS. - - Notes: - - The returned set will also contain the nodes of path that starts at - `begin` and ends at a node that is not `end`. + sdfg: The SDFG to process. + device: The device type. + + Returns: + A `dict` mapping SDFG IDs to a set of transient arrays that + were made persistent. + + Note: + This function is based on a similar function in DaCe. However, the DaCe + function does, for unknown reasons, also reset the `wcr_nonatomic` property, + but only for GPU. """ + result: dict[int, set[str]] = {} + for nsdfg in sdfg.all_sdfgs_recursive(): + fsyms: set[str] = nsdfg.free_symbols + modify_lifetime: set[str] = set() + not_modify_lifetime: set[str] = set() + + for state in nsdfg.states(): + for dnode in state.data_nodes(): + if dnode.data in not_modify_lifetime: + continue - def next_nodes(node: dace_nodes.Node) -> Iterable[dace_nodes.Node]: - return ( - (edge.src for edge in graph.in_edges(node)) - if reverse - else (edge.dst for edge in graph.out_edges(node)) - ) + if dnode.data in nsdfg.constants_prop: + not_modify_lifetime.add(dnode.data) + continue - if reverse: - begin, end = end, begin + desc = dnode.desc(nsdfg) + if not desc.transient or type(desc) not in {dace.data.Array, dace.data.Scalar}: + not_modify_lifetime.add(dnode.data) + continue + if desc.storage == dace.StorageType.Register: + not_modify_lifetime.add(dnode.data) + continue - to_visit: list[dace_nodes.Node] = [begin] - seen: set[dace_nodes.Node] = set() + if desc.lifetime == dace.AllocationLifetime.External: + not_modify_lifetime.add(dnode.data) + continue - while len(to_visit) > 0: - node: dace_nodes.Node = to_visit.pop() - if node != end and node not in seen: - to_visit.extend(next_nodes(node)) - seen.add(node) + try: + # The symbols describing the total size must be a subset of the + # free symbols of the SDFG (symbols passed as argument). + # NOTE: This ignores the renaming of symbols through the + # `symbol_mapping` property of nested SDFGs. + if not set(map(str, desc.total_size.free_symbols)).issubset(fsyms): + not_modify_lifetime.add(dnode.data) + continue + except AttributeError: # total_size is an integer / has no free symbols + pass - # If `end` was not found we have to return `None` to indicate this. - if end not in seen: - return None + # Make it persistent. + modify_lifetime.add(dnode.data) - # `begin` and `end` are not included in the output set. - return seen - {begin, end} + # Now setting the lifetime. + result[nsdfg.cfg_id] = modify_lifetime - not_modify_lifetime + for aname in result[nsdfg.cfg_id]: + nsdfg.arrays[aname].lifetime = dace.AllocationLifetime.Persistent + return result -def find_downstream_consumers( - state: dace.SDFGState, - begin: dace_nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet], - only_tasklets: bool = False, - reverse: bool = False, -) -> set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: - """Find all downstream connectors of `begin`. - - A consumer, in for this function, is any node that is neither an entry nor - an exit node. The function returns a set of pairs, the first element is the - node that acts as consumer and the second is the edge that leads to it. - By setting `only_tasklets` the nodes the function finds are only Tasklets. - - To find this set the function starts a search at `begin`, however, it is also - possible to pass an edge as `begin`. - If `reverse` is `True` the function essentially finds the producers that are - upstream. + +def gt_find_constant_arguments( + call_args: dict[str, Any], + include: Optional[Container[str]] = None, +) -> dict[str, Any]: + """Scans the calling arguments for compile time constants. + + The output of this function can be used as input to + `gt_substitute_compiletime_symbols()`, which then removes these symbols. + + By specifying `include` it is possible to force the function to include + additional arguments, that would not be matched otherwise. Importantly, + their value is not checked. Args: - state: The state in which to look for the consumers. - begin: The initial node that from which the search starts. - only_tasklets: Return only Tasklets. - reverse: Follow the reverse direction. + call_args: The full list of arguments that will be passed to the SDFG. + include: List of arguments that should be included. """ - if isinstance(begin, dace_graph.MultiConnectorEdge): - to_visit: list[dace_graph.MultiConnectorEdge[dace.Memlet]] = [begin] - else: - to_visit = state.in_edges(begin) if reverse else state.out_edges(begin) + if include is None: + include = set() + ret_value: dict[str, Any] = {} - seen: set[dace_graph.MultiConnectorEdge[dace.Memlet]] = set() - found: set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]] = set() + for name, value in call_args.items(): + if name in include or (dace_utils.is_field_symbol(name) and value == 1): + ret_value[name] = value - while len(to_visit) > 0: - curr_edge: dace_graph.MultiConnectorEdge[dace.Memlet] = to_visit.pop() - next_node: dace_nodes.Node = curr_edge.src if reverse else curr_edge.dst - - if curr_edge in seen: - continue - seen.add(curr_edge) - - if isinstance(next_node, (dace_nodes.MapEntry, dace_nodes.MapExit)): - if not reverse: - # In forward mode a Map entry could also mean the definition of a - # dynamic map range. - if isinstance(next_node, dace_nodes.MapEntry) and ( - not curr_edge.dst_conn.startswith("IN_") - ): - if not only_tasklets: - found.add((next_node, curr_edge)) - continue - target_conn = curr_edge.dst_conn[3:] - new_edges = state.out_edges_by_connector(curr_edge.dst, "OUT_" + target_conn) - else: - target_conn = curr_edge.src_conn[4:] - new_edges = state.in_edges_by_connector(curr_edge.src, "IN_" + target_conn) - to_visit.extend(new_edges) + return ret_value - elif isinstance(next_node, dace_nodes.Tasklet) or not only_tasklets: - # We have found a consumer. - found.add((next_node, curr_edge)) - return found +def is_accessed_downstream( + start_state: dace.SDFGState, + sdfg: dace.SDFG, + data_to_look: str, + nodes_to_ignore: Optional[set[dace_nodes.AccessNode]] = None, + states_to_ignore: Optional[set[dace.SDFGState]] = None, +) -> bool: + """Scans for accesses to the data container `data_to_look`. + The function will go through states that are reachable from `start_state` + (included) and test if there is an AccessNode that refers to `data_to_look`. + It will return `True` the first time it finds such a node. -def find_upstream_producers( + The function will ignore all nodes that are listed in `nodes_to_ignore`. + Furthermore, states listed in `states_to_ignore` will be ignored, i.e. + handled as they did not exist. + + Args: + start_state: The state where the scanning starts. + sdfg: The SDFG on which we operate. + data_to_look: The data that we want to look for. + nodes_to_ignore: Ignore these nodes. + states_to_ignore: Ignore these states. + """ + seen_states: set[dace.SDFGState] = set() + to_visit: list[dace.SDFGState] = [start_state] + ign_dnodes: set[dace_nodes.AccessNode] = nodes_to_ignore or set() + ign_states: set[dace.SDFGState] = states_to_ignore or set() + + while len(to_visit) > 0: + state = to_visit.pop() + seen_states.add(state) + for dnode in state.data_nodes(): + if dnode.data != data_to_look: + continue + if dnode in ign_dnodes: + continue + if state.out_degree(dnode) != 0: + return True # There is a read operation + + # Look for new states, also scan the interstate edges. + for out_edge in sdfg.out_edges(state): + if out_edge.dst in ign_states: + continue + if data_to_look in out_edge.data.read_symbols(): + return True + if out_edge.dst in seen_states: + continue + to_visit.append(out_edge.dst) + + return False + + +def is_view( + node: Union[dace_nodes.AccessNode, dace_data.Data], + sdfg: dace.SDFG, +) -> bool: + """Tests if `node` points to a view or not.""" + node_desc: dace_data.Data = node.desc(sdfg) if isinstance(node, dace_nodes.AccessNode) else node + return isinstance(node_desc, dace_data.View) + + +def track_view( + view: dace_nodes.AccessNode, state: dace.SDFGState, - begin: dace_nodes.Node | dace_graph.MultiConnectorEdge[dace.Memlet], - only_tasklets: bool = False, -) -> set[tuple[dace_nodes.Node, dace_graph.MultiConnectorEdge[dace.Memlet]]]: - """Same as `find_downstream_consumers()` but with `reverse` set to `True`.""" - return find_downstream_consumers( - state=state, - begin=begin, - only_tasklets=only_tasklets, - reverse=True, - ) + sdfg: dace.SDFG, +) -> dace_nodes.AccessNode: + """Find the original data of a View. + + Given the View `view`, the function will trace the view back to the original + access node. For convenience, if `view` is not a `View` the argument will be + returned. + + Args: + view: The view that should be traced. + state: The state in which we operate. + sdfg: The SDFG on which we operate. + """ + + # Test if it is a view at all, if not return the passed node as source. + if not is_view(view, sdfg): + return view + + # First determine if the view is used for reading or writing. + curr_edge = dace.sdfg.utils.get_view_edge(state, view) + if curr_edge is None: + raise RuntimeError(f"Failed to determine the direction of the view '{view}'.") + if curr_edge.dst_conn == "views": + # The view is used for reading. + next_node = lambda curr_edge: curr_edge.src # noqa: E731 + elif curr_edge.src_conn == "views": + # The view is used for writing. + next_node = lambda curr_edge: curr_edge.dst # noqa: E731 + else: + raise RuntimeError(f"Failed to determine the direction of the view '{view}' | {curr_edge}.") + + # Now trace the view back. + org_view = view + view = next_node(curr_edge) + while is_view(view, sdfg): + curr_edge = dace.sdfg.utils.get_view_edge(state, view) + if curr_edge is None: + raise RuntimeError(f"View tracing of '{org_view}' failed at note '{view}'.") + view = next_node(curr_edge) + return view diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py index 40d44f5ab0..a38a50d886 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/workflow.py @@ -72,7 +72,7 @@ def __call__( self, inp: stages.CompilableProgram ) -> stages.ProgramSource[languages.SDFG, LanguageSettings]: """Generate DaCe SDFG file from the GTIR definition.""" - program: itir.FencilDefinition | itir.Program = inp.data + program: itir.Program = inp.data assert isinstance(program, itir.Program) sdfg = self.generate_sdfg( diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py deleted file mode 100644 index ef09cf51cd..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ /dev/null @@ -1,377 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import dataclasses -import warnings -from collections import OrderedDict -from collections.abc import Callable, Sequence -from dataclasses import field -from inspect import currentframe, getframeinfo -from pathlib import Path -from typing import Any, ClassVar, Optional - -import dace -import numpy as np -from dace.sdfg import utils as sdutils -from dace.transformation.auto import auto_optimize as autoopt - -import gt4py.next.iterator.ir as itir -from gt4py.next import common -from gt4py.next.ffront import decorator -from gt4py.next.iterator import transforms as itir_transforms -from gt4py.next.iterator.ir import SymRef -from gt4py.next.iterator.transforms import ( - pass_manager_legacy as legacy_itir_transforms, - program_to_fencil, -) -from gt4py.next.iterator.type_system import inference as itir_type_inference -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils -from gt4py.next.type_system import type_specifications as ts - -from .itir_to_sdfg import ItirToSDFG - - -def preprocess_program( - program: itir.FencilDefinition, - offset_provider_type: common.OffsetProviderType, - lift_mode: legacy_itir_transforms.LiftMode, - symbolic_domain_sizes: Optional[dict[str, str]] = None, - temporary_extraction_heuristics: Optional[ - Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None, - unroll_reduce: bool = False, -): - node = legacy_itir_transforms.apply_common_transforms( - program, - common_subexpression_elimination=False, - force_inline_lambda_args=True, - lift_mode=lift_mode, - offset_provider_type=offset_provider_type, - symbolic_domain_sizes=symbolic_domain_sizes, - temporary_extraction_heuristics=temporary_extraction_heuristics, - unroll_reduce=unroll_reduce, - ) - - node = itir_type_inference.infer(node, offset_provider_type=offset_provider_type) - - if isinstance(node, itir.Program): - fencil_definition = program_to_fencil.program_to_fencil(node) - tmps = node.declarations - assert all(isinstance(tmp, itir.Temporary) for tmp in tmps) - else: - raise TypeError(f"Expected 'Program', got '{type(node).__name__}'.") - - return fencil_definition, tmps - - -def build_sdfg_from_itir( - program: itir.FencilDefinition, - arg_types: Sequence[ts.TypeSpec], - offset_provider_type: common.OffsetProviderType, - auto_optimize: bool = False, - on_gpu: bool = False, - column_axis: Optional[common.Dimension] = None, - lift_mode: legacy_itir_transforms.LiftMode = legacy_itir_transforms.LiftMode.FORCE_INLINE, - symbolic_domain_sizes: Optional[dict[str, str]] = None, - temporary_extraction_heuristics: Optional[ - Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None, - load_sdfg_from_file: bool = False, - save_sdfg: bool = True, - use_field_canonical_representation: bool = True, -) -> dace.SDFG: - """Translate a Fencil into an SDFG. - - Args: - program: The Fencil that should be translated. - arg_types: Types of the arguments passed to the fencil. - offset_provider: The set of offset providers that should be used. - auto_optimize: Apply DaCe's `auto_optimize` heuristic. - on_gpu: Performs the translation for GPU, defaults to `False`. - column_axis: The column axis to be used, defaults to `None`. - lift_mode: Which lift mode should be used, defaults `FORCE_INLINE`. - symbolic_domain_sizes: Used for generation of liskov bindings when temporaries are enabled. - load_sdfg_from_file: Allows to read the SDFG from file, instead of generating it, for debug only. - save_sdfg: If `True`, the default the SDFG is stored as a file and can be loaded, this allows to skip the lowering step, requires `load_sdfg_from_file` set to `True`. - use_field_canonical_representation: If `True`, assume that the fields dimensions are sorted alphabetically. - """ - - sdfg_filename = f"_dacegraphs/gt4py/{program.id}.sdfg" - if load_sdfg_from_file and Path(sdfg_filename).exists(): - sdfg: dace.SDFG = dace.SDFG.from_file(sdfg_filename) - sdfg.validate() - return sdfg - - # visit ITIR and generate SDFG - program, tmps = preprocess_program( - program, - offset_provider_type, - lift_mode, - symbolic_domain_sizes, - temporary_extraction_heuristics, - ) - sdfg_genenerator = ItirToSDFG( - list(arg_types), - offset_provider_type, - tmps, - use_field_canonical_representation, - column_axis, - ) - sdfg = sdfg_genenerator.visit(program) - if sdfg is None: - raise RuntimeError(f"Visit failed for program {program.id}.") - - for nested_sdfg in sdfg.all_sdfgs_recursive(): - if not nested_sdfg.debuginfo: - _, frameinfo = ( - warnings.warn( - f"{nested_sdfg.label} does not have debuginfo. Consider adding them in the corresponding nested sdfg.", - stacklevel=2, - ), - getframeinfo(currentframe()), # type: ignore[arg-type] - ) - nested_sdfg.debuginfo = dace.dtypes.DebugInfo( - start_line=frameinfo.lineno, end_line=frameinfo.lineno, filename=frameinfo.filename - ) - - # TODO(edopao): remove `inline_loop_blocks` when DaCe transformations support LoopRegion construct - sdutils.inline_loop_blocks(sdfg) - - # run DaCe transformations to simplify the SDFG - sdfg.simplify() - - # run DaCe auto-optimization heuristics - if auto_optimize: - # TODO: Investigate performance improvement from SDFG specialization with constant symbols, - # for array shape and strides, although this would imply JIT compilation. - symbols: dict[str, int] = {} - device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU - sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu) - elif on_gpu: - autoopt.apply_gpu_storage(sdfg) - - if on_gpu: - sdfg.apply_gpu_transformations() - - # Store the sdfg such that we can later reuse it. - if save_sdfg: - sdfg.save(sdfg_filename) - - return sdfg - - -@dataclasses.dataclass(frozen=True) -class Program(decorator.Program, dace.frontend.python.common.SDFGConvertible): - """Extension of GT4Py Program implementing the SDFGConvertible interface.""" - - sdfg_closure_vars: dict[str, Any] = field(default_factory=dict) - - # Being a ClassVar ensures that in an SDFG with multiple nested GT4Py Programs, - # there is no name mangling of the connectivity tables used across the nested SDFGs - # since they share the same memory address. - connectivity_tables_data_descriptors: ClassVar[ - dict[str, dace.data.Array] - ] = {} # symbolically defined - - def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG: - if "dace" not in self.backend.name.lower(): # type: ignore[union-attr] - raise ValueError("The SDFG can be generated only for the DaCe backend.") - - params = {str(p.id): p.type for p in self.itir.params} - fields = {str(p.id): p.type for p in self.itir.params if hasattr(p.type, "dims")} - arg_types = [*params.values()] - - dace_parsed_args = [*args, *kwargs.values()] - gt4py_program_args = [*params.values()] - _crosscheck_dace_parsing(dace_parsed_args, gt4py_program_args) - - if self.connectivities is None: - raise ValueError( - "[DaCe Orchestration] Connectivities -at compile time- are required to generate the SDFG. Use `with_connectivities` method." - ) - offset_provider_type = {**self.connectivities, **self._implicit_offset_provider} - - sdfg = self.backend.executor.step.translation.generate_sdfg( # type: ignore[union-attr] - self.itir, - arg_types, - offset_provider_type=offset_provider_type, - column_axis=kwargs.get("column_axis", None), - ) - self.sdfg_closure_vars["sdfg.arrays"] = sdfg.arrays # use it in __sdfg_closure__ - - # Halo exchange related metadata, i.e. gt4py_program_input_fields, gt4py_program_output_fields, offset_providers_per_input_field - # Add them as dynamic properties to the SDFG - - assert all( - isinstance(in_field, SymRef) - for closure in self.itir.closures - for in_field in closure.inputs - ) # backend only supports SymRef inputs, not `index` calls - input_fields = [ - str(in_field.id) # type: ignore[union-attr] # ensured by assert - for closure in self.itir.closures - for in_field in closure.inputs - if str(in_field.id) in fields # type: ignore[union-attr] # ensured by assert - ] - sdfg.gt4py_program_input_fields = { - in_field: dim - for in_field in input_fields - for dim in fields[in_field].dims # type: ignore[union-attr] - if dim.kind == common.DimensionKind.HORIZONTAL - } - - output_fields = [] - for closure in self.itir.closures: - output = closure.output - if isinstance(output, itir.SymRef): - if str(output.id) in fields: - output_fields.append(str(output.id)) - else: - for arg in output.args: - if str(arg.id) in fields: # type: ignore[attr-defined] - output_fields.append(str(arg.id)) # type: ignore[attr-defined] - sdfg.gt4py_program_output_fields = { - output: dim - for output in output_fields - for dim in fields[output].dims # type: ignore[union-attr] - if dim.kind == common.DimensionKind.HORIZONTAL - } - - sdfg.offset_providers_per_input_field = {} - itir_tmp = legacy_itir_transforms.apply_common_transforms( - self.itir, offset_provider_type=offset_provider_type - ) - itir_tmp_fencil = program_to_fencil.program_to_fencil(itir_tmp) - for closure in itir_tmp_fencil.closures: - params_shifts = itir_transforms.trace_shifts.trace_stencil( - closure.stencil, num_args=len(closure.inputs) - ) - for param, shifts in zip(closure.inputs, params_shifts): - assert isinstance( - param, SymRef - ) # backend only supports SymRef inputs, not `index` calls - if not isinstance(param.id, str): - continue - if param.id not in sdfg.gt4py_program_input_fields: - continue - sdfg.offset_providers_per_input_field.setdefault(param.id, []).extend(list(shifts)) - - return sdfg - - def __sdfg_closure__(self, reevaluate: Optional[dict[str, str]] = None) -> dict[str, Any]: - """ - Returns the closure arrays of the SDFG represented by this object - as a mapping between array name and the corresponding value. - - The connectivity tables are defined symbolically, i.e. table sizes & strides are DaCe symbols. - The need to define the connectivity tables in the `__sdfg_closure__` arises from the fact that - the offset providers are not part of GT4Py Program's arguments. - Keep in mind, that `__sdfg_closure__` is called after `__sdfg__` method. - """ - offset_provider_type = self.connectivities - - # Define DaCe symbols - connectivity_table_size_symbols = { - dace_utils.field_size_symbol_name( - dace_utils.connectivity_identifier(k), axis - ): dace.symbol( - dace_utils.field_size_symbol_name(dace_utils.connectivity_identifier(k), axis) - ) - for k, v in offset_provider_type.items() # type: ignore[union-attr] - for axis in [0, 1] - if isinstance(v, common.NeighborConnectivityType) - and dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"] - } - - connectivity_table_stride_symbols = { - dace_utils.field_stride_symbol_name( - dace_utils.connectivity_identifier(k), axis - ): dace.symbol( - dace_utils.field_stride_symbol_name(dace_utils.connectivity_identifier(k), axis) - ) - for k, v in offset_provider_type.items() # type: ignore[union-attr] - for axis in [0, 1] - if isinstance(v, common.NeighborConnectivityType) - and dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"] - } - - symbols = {**connectivity_table_size_symbols, **connectivity_table_stride_symbols} - - # Define the storage location (e.g. CPU, GPU) of the connectivity tables - if "storage" not in Program.connectivity_tables_data_descriptors: - for k, v in offset_provider_type.items(): # type: ignore[union-attr] - if not isinstance(v, common.NeighborConnectivityType): - continue - if dace_utils.connectivity_identifier(k) in self.sdfg_closure_vars["sdfg.arrays"]: - Program.connectivity_tables_data_descriptors["storage"] = ( - self.sdfg_closure_vars[ - "sdfg.arrays" - ][dace_utils.connectivity_identifier(k)].storage - ) - break - - # Build the closure dictionary - closure_dict = {} - for k, v in offset_provider_type.items(): # type: ignore[union-attr] - conn_id = dace_utils.connectivity_identifier(k) - if ( - isinstance(v, common.NeighborConnectivityType) - and conn_id in self.sdfg_closure_vars["sdfg.arrays"] - ): - if conn_id not in Program.connectivity_tables_data_descriptors: - Program.connectivity_tables_data_descriptors[conn_id] = dace.data.Array( - dtype=dace.int64 if v.dtype.scalar_type == np.int64 else dace.int32, - shape=[ - symbols[dace_utils.field_size_symbol_name(conn_id, 0)], - symbols[dace_utils.field_size_symbol_name(conn_id, 1)], - ], - strides=[ - symbols[dace_utils.field_stride_symbol_name(conn_id, 0)], - symbols[dace_utils.field_stride_symbol_name(conn_id, 1)], - ], - storage=Program.connectivity_tables_data_descriptors["storage"], - ) - closure_dict[conn_id] = Program.connectivity_tables_data_descriptors[conn_id] - - return closure_dict - - def __sdfg_signature__(self) -> tuple[Sequence[str], Sequence[str]]: - args = [] - for arg in self.past_stage.past_node.params: - args.append(arg.id) - return (args, []) - - -def _crosscheck_dace_parsing(dace_parsed_args: list[Any], gt4py_program_args: list[Any]) -> bool: - for dace_parsed_arg, gt4py_program_arg in zip(dace_parsed_args, gt4py_program_args): - if isinstance(dace_parsed_arg, dace.data.Scalar): - assert dace_parsed_arg.dtype == dace_utils.as_dace_type(gt4py_program_arg) - elif isinstance( - dace_parsed_arg, (bool, int, float, str, np.bool_, np.integer, np.floating, np.str_) - ): # compile-time constant scalar - assert isinstance(gt4py_program_arg, ts.ScalarType) - if isinstance(dace_parsed_arg, (bool, np.bool_)): - assert gt4py_program_arg.kind == ts.ScalarKind.BOOL - elif isinstance(dace_parsed_arg, (int, np.integer)): - assert gt4py_program_arg.kind in [ts.ScalarKind.INT32, ts.ScalarKind.INT64] - elif isinstance(dace_parsed_arg, (float, np.floating)): - assert gt4py_program_arg.kind in [ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64] - elif isinstance(dace_parsed_arg, (str, np.str_)): - assert gt4py_program_arg.kind == ts.ScalarKind.STRING - elif isinstance(dace_parsed_arg, dace.data.Array): - assert isinstance(gt4py_program_arg, ts.FieldType) - assert len(dace_parsed_arg.shape) == len(gt4py_program_arg.dims) - assert dace_parsed_arg.dtype == dace_utils.as_dace_type(gt4py_program_arg.dtype) - elif isinstance( - dace_parsed_arg, (dace.data.Structure, dict, OrderedDict) - ): # offset_provider - continue - else: - raise ValueError(f"Unresolved case for {dace_parsed_arg} (==, !=) {gt4py_program_arg}") - - return True diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py deleted file mode 100644 index 823943cfd5..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ /dev/null @@ -1,809 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import warnings -from typing import Optional, Sequence, cast - -import dace -from dace.sdfg.state import LoopRegion - -import gt4py.eve as eve -from gt4py.next import Dimension, DimensionKind, common -from gt4py.next.ffront import fbuiltins as gtx_fbuiltins -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils -from gt4py.next.type_system import type_info, type_specifications as ts, type_translation as tt - -from .itir_to_tasklet import ( - Context, - GatherOutputSymbolsPass, - PythonTaskletCodegen, - SymbolExpr, - TaskletExpr, - ValueExpr, - closure_to_tasklet_sdfg, - is_scan, -) -from .utility import ( - add_mapped_nested_sdfg, - flatten_list, - get_used_connectivities, - map_nested_sdfg_symbols, - new_array_symbols, - unique_var_name, -) - - -def _get_scan_args(stencil: Expr) -> tuple[bool, Literal]: - """ - Parse stencil expression to extract the scan arguments. - - Returns - ------- - tuple(is_forward, init_carry) - The output tuple fields verify the following semantics: - - is_forward: forward boolean flag - - init_carry: carry initial value - """ - stencil_fobj = cast(FunCall, stencil) - is_forward = stencil_fobj.args[1] - assert isinstance(is_forward, Literal) and type_info.is_logical(is_forward.type) - init_carry = stencil_fobj.args[2] - assert isinstance(init_carry, Literal) - return is_forward.value == "True", init_carry - - -def _get_scan_dim( - column_axis: Dimension, - storage_types: dict[str, ts.TypeSpec], - output: SymRef, - use_field_canonical_representation: bool, -) -> tuple[str, int, ts.ScalarType]: - """ - Extract information about the scan dimension. - - Returns - ------- - tuple(scan_dim_name, scan_dim_index, scan_dim_dtype) - The output tuple fields verify the following semantics: - - scan_dim_name: name of the scan dimension - - scan_dim_index: domain index of the scan dimension - - scan_dim_dtype: data type along the scan dimension - """ - output_type = storage_types[output.id] - assert isinstance(output_type, ts.FieldType) - sorted_dims = [ - dim - for _, dim in ( - dace_utils.get_sorted_dims(output_type.dims) - if use_field_canonical_representation - else enumerate(output_type.dims) - ) - ] - return (column_axis.value, sorted_dims.index(column_axis), output_type.dtype) - - -def _make_array_shape_and_strides( - name: str, - dims: Sequence[Dimension], - offset_provider_type: common.OffsetProviderType, - sort_dims: bool, -) -> tuple[list[dace.symbol], list[dace.symbol]]: - """ - Parse field dimensions and allocate symbols for array shape and strides. - - For local dimensions, the size is known at compile-time and therefore - the corresponding array shape dimension is set to an integer literal value. - - Returns - ------- - tuple(shape, strides) - The output tuple fields are arrays of dace symbolic expressions. - """ - dtype = dace.dtype_to_typeclass(gtx_fbuiltins.IndexType) - sorted_dims = dace_utils.get_sorted_dims(dims) if sort_dims else list(enumerate(dims)) - connectivity_types = dace_utils.filter_connectivity_types(offset_provider_type) - shape = [ - ( - connectivity_types[dim.value].max_neighbors - if dim.kind == DimensionKind.LOCAL - # we reuse the same gt4py symbol for field size passed as scalar argument which is used in closure domain - else dace.symbol(dace_utils.field_size_symbol_name(name, i), dtype) - ) - for i, dim in sorted_dims - ] - strides = [ - dace.symbol(dace_utils.field_stride_symbol_name(name, i), dtype) for i, _ in sorted_dims - ] - return shape, strides - - -def _check_no_lifts(node: itir.StencilClosure): - """ - Parse stencil closure ITIR to check that lift expressions only appear as child nodes in neighbor reductions. - - Returns - ------- - True if lifts do not appear in the ITIR exception lift expressions in neighbor reductions. False otherwise. - """ - neighbors_call_count = 0 - for fun in eve.walk_values(node).if_isinstance(itir.FunCall).getattr("fun"): - if getattr(fun, "id", "") == "neighbors": - neighbors_call_count = 3 - elif getattr(fun, "id", "") == "lift" and neighbors_call_count != 1: - return False - neighbors_call_count = max(0, neighbors_call_count - 1) - return True - - -class ItirToSDFG(eve.NodeVisitor): - param_types: list[ts.TypeSpec] - storage_types: dict[str, ts.TypeSpec] - column_axis: Optional[Dimension] - offset_provider_type: common.OffsetProviderType - unique_id: int - use_field_canonical_representation: bool - - def __init__( - self, - param_types: list[ts.TypeSpec], - offset_provider_type: common.OffsetProviderType, - tmps: list[itir.Temporary], - use_field_canonical_representation: bool, - column_axis: Optional[Dimension] = None, - ): - self.param_types = param_types - self.column_axis = column_axis - self.offset_provider_type = offset_provider_type - self.storage_types = {} - self.tmps = tmps - self.use_field_canonical_representation = use_field_canonical_representation - - def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, sort_dimensions: bool): - if isinstance(type_, ts.FieldType): - shape, strides = _make_array_shape_and_strides( - name, type_.dims, self.offset_provider_type, sort_dimensions - ) - dtype = dace_utils.as_dace_type(type_.dtype) - sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) - - elif isinstance(type_, ts.ScalarType): - dtype = dace_utils.as_dace_type(type_) - if name in sdfg.symbols: - assert sdfg.symbols[name].dtype == dtype - else: - sdfg.add_symbol(name, dtype) - - else: - raise NotImplementedError() - self.storage_types[name] = type_ - - def add_storage_for_temporaries( - self, node_params: list[Sym], defs_state: dace.SDFGState, program_sdfg: dace.SDFG - ) -> dict[str, str]: - symbol_map: dict[str, TaskletExpr] = {} - # The shape of temporary arrays might be defined based on scalar values passed as program arguments. - # Here we collect these values in a symbol map. - for sym in node_params: - if isinstance(sym.type, ts.ScalarType): - name_ = str(sym.id) - symbol_map[name_] = SymbolExpr(name_, dace_utils.as_dace_type(sym.type)) - - tmp_symbols: dict[str, str] = {} - for tmp in self.tmps: - tmp_name = str(tmp.id) - - # We visit the domain of the temporary field, passing the set of available symbols. - assert isinstance(tmp.domain, itir.FunCall) - domain_ctx = Context(program_sdfg, defs_state, symbol_map) - tmp_domain = self._visit_domain(tmp.domain, domain_ctx) - - if isinstance(tmp.type, ts.TupleType): - raise NotImplementedError("Temporaries of tuples are not supported.") - assert isinstance(tmp.type, ts.FieldType) and isinstance(tmp.dtype, ts.ScalarType) - - # We store the FieldType for this temporary array. - self.storage_types[tmp_name] = tmp.type - - # N.B.: skip generation of symbolic strides and just let dace assign default strides, for now. - # Another option, in the future, is to use symbolic strides and apply auto-tuning or some heuristics - # to assign optimal stride values. - tmp_shape, _ = new_array_symbols(tmp_name, len(tmp.type.dims)) - _, tmp_array = program_sdfg.add_array( - tmp_name, tmp_shape, dace_utils.as_dace_type(tmp.dtype), transient=True - ) - - # Loop through all dimensions to visit the symbolic expressions for array shape and offset. - # These expressions are later mapped to interstate symbols. - for (_, (begin, end)), shape_sym in zip(tmp_domain, tmp_array.shape): - # The temporary field has a dimension range defined by `begin` and `end` values. - # Therefore, the actual size is given by the difference `end.value - begin.value`. - # Instead of allocating the actual size, we allocate space to enable indexing from 0 - # because we want to avoid using dace array offsets (which will be deprecated soon). - # The result should still be valid, but the stencil will be using only a subset - # of the array. - if not (isinstance(begin, SymbolExpr) and begin.value == "0"): - warnings.warn( - f"Domain start offset for temporary {tmp_name} is ignored.", stacklevel=2 - ) - tmp_symbols[str(shape_sym)] = end.value - - return tmp_symbols - - def create_memlet_at(self, field_name: str, index: dict[str, str]): - field_type = self.storage_types[field_name] - assert isinstance(field_type, ts.FieldType) - if self.use_field_canonical_representation: - field_index = [ - index[dim.value] for _, dim in dace_utils.get_sorted_dims(field_type.dims) - ] - else: - field_index = [index[dim.value] for dim in field_type.dims] - subset = ", ".join(field_index) - return dace.Memlet(data=field_name, subset=subset) - - def get_output_nodes( - self, closure: itir.StencilClosure, sdfg: dace.SDFG, state: dace.SDFGState - ) -> dict[str, dace.nodes.AccessNode]: - # Visit output node, which could be a `make_tuple` expression, to collect the required access nodes - output_symbols_pass = GatherOutputSymbolsPass(sdfg, state) - output_symbols_pass.visit(closure.output) - # Visit output node again to generate the corresponding tasklet - context = Context(sdfg, state, output_symbols_pass.symbol_refs) - translator = PythonTaskletCodegen( - self.offset_provider_type, context, self.use_field_canonical_representation - ) - output_nodes = flatten_list(translator.visit(closure.output)) - return {node.value.data: node.value for node in output_nodes} - - def visit_FencilDefinition(self, node: itir.FencilDefinition): - program_sdfg = dace.SDFG(name=node.id) - program_sdfg.debuginfo = dace_utils.debug_info(node) - entry_state = program_sdfg.add_state("program_entry", is_start_block=True) - - # Filter neighbor tables from offset providers. - connectivity_types = get_used_connectivities(node, self.offset_provider_type) - - # Add program parameters as SDFG storages. - for param, type_ in zip(node.params, self.param_types): - self.add_storage( - program_sdfg, str(param.id), type_, self.use_field_canonical_representation - ) - - if self.tmps: - tmp_symbols = self.add_storage_for_temporaries(node.params, entry_state, program_sdfg) - # on the first interstate edge define symbols for shape and offsets of temporary arrays - last_state = program_sdfg.add_state("init_symbols_for_temporaries") - program_sdfg.add_edge( - entry_state, last_state, dace.InterstateEdge(assignments=tmp_symbols) - ) - else: - last_state = entry_state - - # Add connectivities as SDFG storages. - for offset, connectivity_type in connectivity_types.items(): - scalar_type = tt.from_dtype(connectivity_type.dtype) - type_ = ts.FieldType( - [connectivity_type.source_dim, connectivity_type.neighbor_dim], scalar_type - ) - self.add_storage( - program_sdfg, - dace_utils.connectivity_identifier(offset), - type_, - sort_dimensions=False, - ) - - # Create a nested SDFG for all stencil closures. - for closure in node.closures: - # Translate the closure and its stencil's body to an SDFG. - closure_sdfg, input_names, output_names = self.visit( - closure, array_table=program_sdfg.arrays - ) - - # Create a new state for the closure. - last_state = program_sdfg.add_state_after(last_state) - - # Create memlets to transfer the program parameters - input_mapping = { - name: dace.Memlet.from_array(name, program_sdfg.arrays[name]) - for name in input_names - } - output_mapping = { - name: dace.Memlet.from_array(name, program_sdfg.arrays[name]) - for name in output_names - } - - symbol_mapping = map_nested_sdfg_symbols(program_sdfg, closure_sdfg, input_mapping) - - # Insert the closure's SDFG as a nested SDFG of the program. - nsdfg_node = last_state.add_nested_sdfg( - sdfg=closure_sdfg, - parent=program_sdfg, - inputs=set(input_names), - outputs=set(output_names), - symbol_mapping=symbol_mapping, - debuginfo=closure_sdfg.debuginfo, - ) - - # Add access nodes for the program parameters and connect them to the nested SDFG's inputs via edges. - for inner_name, memlet in input_mapping.items(): - access_node = last_state.add_access(inner_name, debuginfo=nsdfg_node.debuginfo) - last_state.add_edge(access_node, None, nsdfg_node, inner_name, memlet) - - for inner_name, memlet in output_mapping.items(): - access_node = last_state.add_access(inner_name, debuginfo=nsdfg_node.debuginfo) - last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet) - - # Create the call signature for the SDFG. - # Only the arguments requiered by the Fencil, i.e. `node.params` are added as positional arguments. - # The implicit arguments, such as the offset providers or the arguments created by the translation process, must be passed as keywords only arguments. - program_sdfg.arg_names = [str(a) for a in node.params] - - program_sdfg.validate() - return program_sdfg - - def visit_StencilClosure( - self, node: itir.StencilClosure, array_table: dict[str, dace.data.Array] - ) -> tuple[dace.SDFG, list[str], list[str]]: - assert _check_no_lifts(node) - - # Create the closure's nested SDFG and single state. - closure_sdfg = dace.SDFG(name="closure") - closure_sdfg.debuginfo = dace_utils.debug_info(node) - closure_state = closure_sdfg.add_state("closure_entry") - closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init", True) - - assert all( - isinstance(inp, SymRef) for inp in node.inputs - ) # backend only supports SymRef inputs, not `index` calls - input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert - neighbor_tables = get_used_connectivities(node, self.offset_provider_type) - connectivity_names = [ - dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() - ] - - output_nodes = self.get_output_nodes(node, closure_sdfg, closure_state) - output_names = [k for k, _ in output_nodes.items()] - - # Add DaCe arrays for inputs, outputs and connectivities to closure SDFG. - input_transients_mapping = {} - for name in [*input_names, *connectivity_names, *output_names]: - if name in closure_sdfg.arrays: - assert name in input_names and name in output_names - # In case of closures with in/out fields, there is risk of race condition - # between read/write access nodes in the (asynchronous) map tasklet. - transient_name = unique_var_name() - closure_sdfg.add_array( - transient_name, - shape=array_table[name].shape, - strides=array_table[name].strides, - dtype=array_table[name].dtype, - transient=True, - ) - closure_init_state.add_nedge( - closure_init_state.add_access(name, debuginfo=closure_sdfg.debuginfo), - closure_init_state.add_access(transient_name, debuginfo=closure_sdfg.debuginfo), - dace.Memlet.from_array(name, closure_sdfg.arrays[name]), - ) - input_transients_mapping[name] = transient_name - elif isinstance(self.storage_types[name], ts.FieldType): - closure_sdfg.add_array( - name, - shape=array_table[name].shape, - strides=array_table[name].strides, - dtype=array_table[name].dtype, - ) - else: - assert isinstance(self.storage_types[name], ts.ScalarType) - - input_field_names = [ - input_name - for input_name in input_names - if isinstance(self.storage_types[input_name], ts.FieldType) - ] - - # Closure outputs should all be fields - assert all( - isinstance(self.storage_types[output_name], ts.FieldType) - for output_name in output_names - ) - - # Update symbol table and get output domain of the closure - program_arg_syms: dict[str, TaskletExpr] = {} - for name, type_ in self.storage_types.items(): - if isinstance(type_, ts.ScalarType): - dtype = dace_utils.as_dace_type(type_) - if name in input_names: - out_name = unique_var_name() - closure_sdfg.add_scalar(out_name, dtype, transient=True) - out_tasklet = closure_init_state.add_tasklet( - f"get_{name}", - {}, - {"__result"}, - f"__result = {name}", - debuginfo=closure_sdfg.debuginfo, - ) - access = closure_init_state.add_access( - out_name, debuginfo=closure_sdfg.debuginfo - ) - value = ValueExpr(access, dtype) - memlet = dace.Memlet(data=out_name, subset="0") - closure_init_state.add_edge(out_tasklet, "__result", access, None, memlet) - program_arg_syms[name] = value - else: - program_arg_syms[name] = SymbolExpr(name, dtype) - else: - assert isinstance(type_, ts.FieldType) - # make shape symbols (corresponding to field size) available as arguments to domain visitor - if name in input_names or name in output_names: - field_symbols = [ - val - for val in closure_sdfg.arrays[name].shape - if isinstance(val, dace.symbol) and str(val) not in input_names - ] - for sym in field_symbols: - sym_name = str(sym) - program_arg_syms[sym_name] = SymbolExpr(sym, sym.dtype) - closure_ctx = Context(closure_sdfg, closure_state, program_arg_syms) - closure_domain = self._visit_domain(node.domain, closure_ctx) - - # Map SDFG tasklet arguments to parameters - input_local_names = [ - ( - input_transients_mapping[input_name] - if input_name in input_transients_mapping - else ( - input_name - if input_name in input_field_names - else cast(ValueExpr, program_arg_syms[input_name]).value.data - ) - ) - for input_name in input_names - ] - input_memlets = [ - dace.Memlet.from_array(name, closure_sdfg.arrays[name]) - for name in [*input_local_names, *connectivity_names] - ] - - # create and write to transient that is then copied back to actual output array to avoid aliasing of - # same memory in nested SDFG with different names - output_connectors_mapping = {unique_var_name(): output_name for output_name in output_names} - # scan operator should always be the first function call in a closure - if is_scan(node.stencil): - assert len(output_connectors_mapping) == 1, "Scan does not support multiple outputs" - transient_name, output_name = next(iter(output_connectors_mapping.items())) - - nsdfg, map_ranges, scan_dim_index = self._visit_scan_stencil_closure( - node, closure_sdfg.arrays, closure_domain, transient_name - ) - results = [transient_name] - - _, (scan_lb, scan_ub) = closure_domain[scan_dim_index] - output_subset = f"{scan_lb.value}:{scan_ub.value}" - - domain_subset = { - dim: ( - f"i_{dim}" - if f"i_{dim}" in map_ranges - else f"0:{closure_sdfg.arrays[output_name].shape[scan_dim_index]}" - ) - for dim, _ in closure_domain - } - output_memlets = [self.create_memlet_at(output_name, domain_subset)] - else: - nsdfg, map_ranges, results = self._visit_parallel_stencil_closure( - node, closure_sdfg.arrays, closure_domain - ) - - output_subset = "0" - - output_memlets = [ - self.create_memlet_at(output_name, {dim: f"i_{dim}" for dim, _ in closure_domain}) - for output_name in output_connectors_mapping.values() - ] - - input_mapping = { - param: arg for param, arg in zip([*input_names, *connectivity_names], input_memlets) - } - output_mapping = {param: memlet for param, memlet in zip(results, output_memlets)} - - symbol_mapping = map_nested_sdfg_symbols(closure_sdfg, nsdfg, input_mapping) - - nsdfg_node, map_entry, map_exit = add_mapped_nested_sdfg( - closure_state, - sdfg=nsdfg, - map_ranges=map_ranges or {"__dummy": "0"}, - inputs=input_mapping, - outputs=output_mapping, - symbol_mapping=symbol_mapping, - output_nodes=output_nodes, - debuginfo=nsdfg.debuginfo, - ) - access_nodes = {edge.data.data: edge.dst for edge in closure_state.out_edges(map_exit)} - for edge in closure_state.in_edges(map_exit): - memlet = edge.data - if memlet.data not in output_connectors_mapping: - continue - transient_access = closure_state.add_access(memlet.data, debuginfo=nsdfg.debuginfo) - closure_state.add_edge( - nsdfg_node, - edge.src_conn, - transient_access, - None, - dace.Memlet(data=memlet.data, subset=output_subset, debuginfo=nsdfg.debuginfo), - ) - inner_memlet = dace.Memlet( - data=memlet.data, subset=output_subset, other_subset=memlet.subset - ) - closure_state.add_edge(transient_access, None, map_exit, edge.dst_conn, inner_memlet) - closure_state.remove_edge(edge) - access_nodes[memlet.data].data = output_connectors_mapping[memlet.data] - - return closure_sdfg, input_field_names + connectivity_names, output_names - - def _visit_scan_stencil_closure( - self, - node: itir.StencilClosure, - array_table: dict[str, dace.data.Array], - closure_domain: tuple[ - tuple[str, tuple[ValueExpr | SymbolExpr, ValueExpr | SymbolExpr]], ... - ], - output_name: str, - ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], int]: - # extract scan arguments - is_forward, init_carry_value = _get_scan_args(node.stencil) - # select the scan dimension based on program argument for column axis - assert self.column_axis - assert isinstance(node.output, SymRef) - scan_dim, scan_dim_index, scan_dtype = _get_scan_dim( - self.column_axis, - self.storage_types, - node.output, - self.use_field_canonical_representation, - ) - - assert isinstance(node.output, SymRef) - neighbor_tables = get_used_connectivities(node, self.offset_provider_type) - assert all( - isinstance(inp, SymRef) for inp in node.inputs - ) # backend only supports SymRef inputs, not `index` calls - input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert - connectivity_names = [ - dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() - ] - - # find the scan dimension, same as output dimension, and exclude it from the map domain - map_ranges = {} - for dim, (lb, ub) in closure_domain: - lb_str = lb.value.data if isinstance(lb, ValueExpr) else lb.value - ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value - if not dim == scan_dim: - map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}" - else: - scan_lb_str = lb_str - scan_ub_str = ub_str - - # the scan operator is implemented as an SDFG to be nested in the closure SDFG - scan_sdfg = dace.SDFG(name="scan") - scan_sdfg.debuginfo = dace_utils.debug_info(node) - - # the carry value of the scan operator exists only in the scope of the scan sdfg - scan_carry_name = unique_var_name() - scan_sdfg.add_scalar( - scan_carry_name, dtype=dace_utils.as_dace_type(scan_dtype), transient=True - ) - - # create a loop region for lambda call over the scan dimension - scan_loop_var = f"i_{scan_dim}" - if is_forward: - scan_loop = LoopRegion( - label="scan", - condition_expr=f"{scan_loop_var} < {scan_ub_str}", - loop_var=scan_loop_var, - initialize_expr=f"{scan_loop_var} = {scan_lb_str}", - update_expr=f"{scan_loop_var} = {scan_loop_var} + 1", - inverted=False, - ) - else: - scan_loop = LoopRegion( - label="scan", - condition_expr=f"{scan_loop_var} >= {scan_lb_str}", - loop_var=scan_loop_var, - initialize_expr=f"{scan_loop_var} = {scan_ub_str} - 1", - update_expr=f"{scan_loop_var} = {scan_loop_var} - 1", - inverted=False, - ) - scan_sdfg.add_node(scan_loop) - compute_state = scan_loop.add_state("lambda_compute", is_start_block=True) - update_state = scan_loop.add_state("lambda_update") - scan_loop.add_edge(compute_state, update_state, dace.InterstateEdge()) - - start_state = scan_sdfg.add_state("start", is_start_block=True) - scan_sdfg.add_edge(start_state, scan_loop, dace.InterstateEdge()) - - # tasklet for initialization of carry - carry_init_tasklet = start_state.add_tasklet( - "get_carry_init_value", - {}, - {"__result"}, - f"__result = {init_carry_value}", - debuginfo=scan_sdfg.debuginfo, - ) - start_state.add_edge( - carry_init_tasklet, - "__result", - start_state.add_access(scan_carry_name, debuginfo=scan_sdfg.debuginfo), - None, - dace.Memlet(data=scan_carry_name, subset="0"), - ) - - # add storage to scan SDFG for inputs - for name in [*input_names, *connectivity_names]: - assert name not in scan_sdfg.arrays - if isinstance(self.storage_types[name], ts.FieldType): - scan_sdfg.add_array( - name, - shape=array_table[name].shape, - strides=array_table[name].strides, - dtype=array_table[name].dtype, - ) - else: - scan_sdfg.add_scalar( - name, - dtype=dace_utils.as_dace_type(cast(ts.ScalarType, self.storage_types[name])), - ) - # add storage to scan SDFG for output - scan_sdfg.add_array( - output_name, - shape=(array_table[node.output.id].shape[scan_dim_index],), - strides=(array_table[node.output.id].strides[scan_dim_index],), - dtype=array_table[node.output.id].dtype, - ) - - # implement the lambda function as a nested SDFG that computes a single item in the scan dimension - lambda_domain = {dim: f"i_{dim}" for dim, _ in closure_domain} - input_arrays = [(scan_carry_name, scan_dtype)] + [ - (name, self.storage_types[name]) for name in input_names - ] - connectivity_arrays = [(scan_sdfg.arrays[name], name) for name in connectivity_names] - lambda_context, lambda_outputs = closure_to_tasklet_sdfg( - node, - self.offset_provider_type, - lambda_domain, - input_arrays, - connectivity_arrays, - self.use_field_canonical_representation, - ) - - lambda_input_names = [name for name, _ in input_arrays] - lambda_output_names = [connector.value.data for connector in lambda_outputs] - - input_memlets = [ - dace.Memlet.from_array(name, scan_sdfg.arrays[name]) for name in lambda_input_names - ] - connectivity_memlets = [ - dace.Memlet.from_array(name, scan_sdfg.arrays[name]) for name in connectivity_names - ] - input_mapping = {param: arg for param, arg in zip(lambda_input_names, input_memlets)} - connectivity_mapping = { - param: arg for param, arg in zip(connectivity_names, connectivity_memlets) - } - array_mapping = {**input_mapping, **connectivity_mapping} - symbol_mapping = map_nested_sdfg_symbols(scan_sdfg, lambda_context.body, array_mapping) - - scan_inner_node = compute_state.add_nested_sdfg( - lambda_context.body, - parent=scan_sdfg, - inputs=set(lambda_input_names) | set(connectivity_names), - outputs=set(lambda_output_names), - symbol_mapping=symbol_mapping, - debuginfo=lambda_context.body.debuginfo, - ) - - # connect scan SDFG to lambda inputs - for name, memlet in array_mapping.items(): - access_node = compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo) - compute_state.add_edge(access_node, None, scan_inner_node, name, memlet) - - output_names = [output_name] - assert len(lambda_output_names) == 1 - # connect lambda output to scan SDFG - for name, connector in zip(output_names, lambda_output_names): - compute_state.add_edge( - scan_inner_node, - connector, - compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo), - None, - dace.Memlet(data=name, subset=scan_loop_var), - ) - - update_state.add_nedge( - update_state.add_access(output_name, debuginfo=lambda_context.body.debuginfo), - update_state.add_access(scan_carry_name, debuginfo=lambda_context.body.debuginfo), - dace.Memlet(data=output_name, subset=scan_loop_var, other_subset="0"), - ) - - return scan_sdfg, map_ranges, scan_dim_index - - def _visit_parallel_stencil_closure( - self, - node: itir.StencilClosure, - array_table: dict[str, dace.data.Array], - closure_domain: tuple[ - tuple[str, tuple[ValueExpr | SymbolExpr, ValueExpr | SymbolExpr]], ... - ], - ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], list[str]]: - neighbor_tables = get_used_connectivities(node, self.offset_provider_type) - assert all( - isinstance(inp, SymRef) for inp in node.inputs - ) # backend only supports SymRef inputs, not `index` calls - input_names = [str(inp.id) for inp in node.inputs] # type: ignore[union-attr] # ensured by assert - connectivity_names = [ - dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() - ] - - # find the scan dimension, same as output dimension, and exclude it from the map domain - map_ranges = {} - for dim, (lb, ub) in closure_domain: - lb_str = lb.value.data if isinstance(lb, ValueExpr) else lb.value - ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value - map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}" - - # Create an SDFG for the tasklet that computes a single item of the output domain. - index_domain = {dim: f"i_{dim}" for dim, _ in closure_domain} - - input_arrays = [(name, self.storage_types[name]) for name in input_names] - connectivity_arrays = [(array_table[name], name) for name in connectivity_names] - - context, results = closure_to_tasklet_sdfg( - node, - self.offset_provider_type, - index_domain, - input_arrays, - connectivity_arrays, - self.use_field_canonical_representation, - ) - - return context.body, map_ranges, [r.value.data for r in results] - - def _visit_domain( - self, node: itir.FunCall, context: Context - ) -> tuple[tuple[str, tuple[SymbolExpr | ValueExpr, SymbolExpr | ValueExpr]], ...]: - assert isinstance(node.fun, itir.SymRef) - assert node.fun.id == "cartesian_domain" or node.fun.id == "unstructured_domain" - - bounds: list[tuple[str, tuple[ValueExpr, ValueExpr]]] = [] - - for named_range in node.args: - assert isinstance(named_range, itir.FunCall) - assert isinstance(named_range.fun, itir.SymRef) - assert len(named_range.args) == 3 - dimension = named_range.args[0] - assert isinstance(dimension, itir.AxisLiteral) - lower_bound = named_range.args[1] - upper_bound = named_range.args[2] - translator = PythonTaskletCodegen( - self.offset_provider_type, - context, - self.use_field_canonical_representation, - ) - lb = translator.visit(lower_bound)[0] - ub = translator.visit(upper_bound)[0] - bounds.append((dimension.value, (lb, ub))) - - return tuple(bounds) - - @staticmethod - def _check_shift_offsets_are_literals(node: itir.StencilClosure): - fun_calls = eve.walk_values(node).if_isinstance(itir.FunCall) - shifts = [nd for nd in fun_calls if getattr(nd.fun, "id", "") == "shift"] - for shift in shifts: - if not all(isinstance(arg, (itir.Literal, itir.OffsetLiteral)) for arg in shift.args): - return False - return True diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py deleted file mode 100644 index 2b2669187a..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ /dev/null @@ -1,1564 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import copy -import dataclasses -import itertools -from collections.abc import Sequence -from typing import Any, Callable, Optional, TypeAlias, cast - -import dace -import numpy as np - -import gt4py.eve.codegen -from gt4py import eve -from gt4py.next import common -from gt4py.next.common import _DEFAULT_SKIP_VALUE as neighbor_skip_value -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir import FunCall, Lambda -from gt4py.next.iterator.type_system import type_specifications as it_ts -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils -from gt4py.next.type_system import type_specifications as ts - -from .utility import ( - add_mapped_nested_sdfg, - flatten_list, - get_used_connectivities, - map_nested_sdfg_symbols, - new_array_symbols, - unique_name, - unique_var_name, -) - - -_TYPE_MAPPING = { - "float": dace.float64, - "float32": dace.float32, - "float64": dace.float64, - "int": dace.int32 if np.dtype(int).itemsize == 4 else dace.int64, - "int32": dace.int32, - "int64": dace.int64, - "bool": dace.bool_, -} - - -def itir_type_as_dace_type(type_: ts.TypeSpec): - # TODO(tehrengruber): this function just converts the scalar type of whatever it is given, - # let it be a field, iterator, or directly a scalar. The caller should take care of the - # extraction. - dtype: ts.TypeSpec - if isinstance(type_, ts.FieldType): - dtype = type_.dtype - elif isinstance(type_, it_ts.IteratorType): - dtype = type_.element_type - else: - dtype = type_ - assert isinstance(dtype, ts.ScalarType) - return _TYPE_MAPPING[dtype.kind.name.lower()] - - -def get_reduce_identity_value(op_name_: str, type_: Any): - if op_name_ == "plus": - init_value = type_(0) - elif op_name_ == "multiplies": - init_value = type_(1) - elif op_name_ == "minimum": - init_value = type_("inf") - elif op_name_ == "maximum": - init_value = type_("-inf") - else: - raise NotImplementedError() - - return init_value - - -_MATH_BUILTINS_MAPPING = { - "abs": "abs({})", - "sin": "math.sin({})", - "cos": "math.cos({})", - "tan": "math.tan({})", - "arcsin": "asin({})", - "arccos": "acos({})", - "arctan": "atan({})", - "sinh": "math.sinh({})", - "cosh": "math.cosh({})", - "tanh": "math.tanh({})", - "arcsinh": "asinh({})", - "arccosh": "acosh({})", - "arctanh": "atanh({})", - "sqrt": "math.sqrt({})", - "exp": "math.exp({})", - "log": "math.log({})", - "gamma": "tgamma({})", - "cbrt": "cbrt({})", - "isfinite": "isfinite({})", - "isinf": "isinf({})", - "isnan": "isnan({})", - "floor": "math.ifloor({})", - "ceil": "ceil({})", - "trunc": "trunc({})", - "minimum": "min({}, {})", - "maximum": "max({}, {})", - "fmod": "fmod({}, {})", - "power": "math.pow({}, {})", - "float": "dace.float64({})", - "float32": "dace.float32({})", - "float64": "dace.float64({})", - "int": "dace.int32({})" if np.dtype(int).itemsize == 4 else "dace.int64({})", - "int32": "dace.int32({})", - "int64": "dace.int64({})", - "bool": "dace.bool_({})", - "plus": "({} + {})", - "minus": "({} - {})", - "multiplies": "({} * {})", - "divides": "({} / {})", - "floordiv": "({} // {})", - "eq": "({} == {})", - "not_eq": "({} != {})", - "less": "({} < {})", - "less_equal": "({} <= {})", - "greater": "({} > {})", - "greater_equal": "({} >= {})", - "and_": "({} & {})", - "or_": "({} | {})", - "xor_": "({} ^ {})", - "mod": "({} % {})", - "not_": "(not {})", # ~ is not bitwise in numpy -} - - -# Define type of variables used for field indexing -_INDEX_DTYPE = _TYPE_MAPPING["int64"] - - -@dataclasses.dataclass -class SymbolExpr: - value: dace.symbolic.SymbolicType - dtype: dace.typeclass - - -@dataclasses.dataclass -class ValueExpr: - value: dace.nodes.AccessNode - dtype: dace.typeclass - - -@dataclasses.dataclass -class IteratorExpr: - field: dace.nodes.AccessNode - indices: dict[str, dace.nodes.AccessNode] - dtype: dace.typeclass - dimensions: list[str] - - -# Union of possible expression types -TaskletExpr: TypeAlias = IteratorExpr | SymbolExpr | ValueExpr - - -@dataclasses.dataclass -class Context: - body: dace.SDFG - state: dace.SDFGState - symbol_map: dict[str, TaskletExpr] - # if we encounter a reduction node, the reduction state needs to be pushed to child nodes - reduce_identity: Optional[SymbolExpr] - - def __init__( - self, - body: dace.SDFG, - state: dace.SDFGState, - symbol_map: dict[str, TaskletExpr], - reduce_identity: Optional[SymbolExpr] = None, - ): - self.body = body - self.state = state - self.symbol_map = symbol_map - self.reduce_identity = reduce_identity - - -def _visit_lift_in_neighbors_reduction( - transformer: PythonTaskletCodegen, - node: itir.FunCall, - node_args: Sequence[IteratorExpr | list[ValueExpr]], - connectivity_type: common.NeighborConnectivityType, - map_entry: dace.nodes.MapEntry, - map_exit: dace.nodes.MapExit, - neighbor_index_node: dace.nodes.AccessNode, - neighbor_value_node: dace.nodes.AccessNode, -) -> list[ValueExpr]: - assert transformer.context.reduce_identity is not None - neighbor_dim = connectivity_type.codomain.value - origin_dim = connectivity_type.source_dim.value - - lifted_args: list[IteratorExpr | ValueExpr] = [] - for arg in node_args: - if isinstance(arg, IteratorExpr): - if origin_dim in arg.indices: - lifted_indices = arg.indices.copy() - lifted_indices.pop(origin_dim) - lifted_indices[neighbor_dim] = neighbor_index_node - lifted_args.append( - IteratorExpr(arg.field, lifted_indices, arg.dtype, arg.dimensions) - ) - else: - lifted_args.append(arg) - else: - lifted_args.append(arg[0]) - - lift_context, inner_inputs, inner_outputs = transformer.visit(node.args[0], args=lifted_args) - assert len(inner_outputs) == 1 - inner_out_connector = inner_outputs[0].value.data - - input_nodes = {} - iterator_index_nodes = {} - lifted_index_connectors = [] - - for x, y in inner_inputs: - if isinstance(y, IteratorExpr): - field_connector, inner_index_table = x - input_nodes[field_connector] = y.field - for dim, connector in inner_index_table.items(): - if dim == neighbor_dim: - lifted_index_connectors.append(connector) - iterator_index_nodes[connector] = y.indices[dim] - else: - assert isinstance(y, ValueExpr) - input_nodes[x] = y.value - - neighbor_tables = get_used_connectivities(node.args[0], transformer.offset_provider_type) - connectivity_names = [ - dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() - ] - - parent_sdfg = transformer.context.body - parent_state = transformer.context.state - - input_mapping = { - connector: dace.Memlet.from_array(node.data, node.desc(parent_sdfg)) - for connector, node in input_nodes.items() - } - connectivity_mapping = { - name: dace.Memlet.from_array(name, parent_sdfg.arrays[name]) for name in connectivity_names - } - array_mapping = {**input_mapping, **connectivity_mapping} - symbol_mapping = map_nested_sdfg_symbols(parent_sdfg, lift_context.body, array_mapping) - - nested_sdfg_node = parent_state.add_nested_sdfg( - lift_context.body, - parent_sdfg, - inputs={*array_mapping.keys(), *iterator_index_nodes.keys()}, - outputs={inner_out_connector}, - symbol_mapping=symbol_mapping, - debuginfo=lift_context.body.debuginfo, - ) - - for connectivity_connector, memlet in connectivity_mapping.items(): - parent_state.add_memlet_path( - parent_state.add_access(memlet.data, debuginfo=lift_context.body.debuginfo), - map_entry, - nested_sdfg_node, - dst_conn=connectivity_connector, - memlet=memlet, - ) - - for inner_connector, access_node in input_nodes.items(): - parent_state.add_memlet_path( - access_node, - map_entry, - nested_sdfg_node, - dst_conn=inner_connector, - memlet=input_mapping[inner_connector], - ) - - for inner_connector, access_node in iterator_index_nodes.items(): - memlet = dace.Memlet(data=access_node.data, subset="0") - if inner_connector in lifted_index_connectors: - parent_state.add_edge(access_node, None, nested_sdfg_node, inner_connector, memlet) - else: - parent_state.add_memlet_path( - access_node, map_entry, nested_sdfg_node, dst_conn=inner_connector, memlet=memlet - ) - - parent_state.add_memlet_path( - nested_sdfg_node, - map_exit, - neighbor_value_node, - src_conn=inner_out_connector, - memlet=dace.Memlet(data=neighbor_value_node.data, subset=",".join(map_entry.params)), - ) - - if connectivity_type.has_skip_values: - # check neighbor validity on if/else inter-state edge - # use one branch for connectivity case - start_state = lift_context.body.add_state_before( - lift_context.body.start_state, - "start", - condition=f"{lifted_index_connectors[0]} != {neighbor_skip_value}", - ) - # use the other branch for skip value case - skip_neighbor_state = lift_context.body.add_state("skip_neighbor") - skip_neighbor_state.add_edge( - skip_neighbor_state.add_tasklet( - "identity", {}, {"val"}, f"val = {transformer.context.reduce_identity.value}" - ), - "val", - skip_neighbor_state.add_access(inner_outputs[0].value.data), - None, - dace.Memlet(data=inner_outputs[0].value.data, subset="0"), - ) - lift_context.body.add_edge( - start_state, - skip_neighbor_state, - dace.InterstateEdge(condition=f"{lifted_index_connectors[0]} == {neighbor_skip_value}"), - ) - - return [ValueExpr(neighbor_value_node, inner_outputs[0].dtype)] - - -def builtin_neighbors( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - sdfg: dace.SDFG = transformer.context.body - state: dace.SDFGState = transformer.context.state - - di = dace_utils.debug_info(node, default=sdfg.debuginfo) - offset_literal, data = node_args - assert isinstance(offset_literal, itir.OffsetLiteral) - offset_dim = offset_literal.value - assert isinstance(offset_dim, str) - connectivity_type = transformer.offset_provider_type[offset_dim] - if not isinstance(connectivity_type, common.NeighborConnectivityType): - raise NotImplementedError( - "Neighbor reduction only implemented for connectivity based on neighbor tables." - ) - - lift_node = None - if isinstance(data, FunCall): - assert isinstance(data.fun, itir.FunCall) - fun_node = data.fun - if isinstance(fun_node.fun, itir.SymRef) and fun_node.fun.id == "lift": - lift_node = fun_node - lift_args = transformer.visit(data.args) - iterator = next(filter(lambda x: isinstance(x, IteratorExpr), lift_args), None) - if lift_node is None: - iterator = transformer.visit(data) - assert isinstance(iterator, IteratorExpr) - field_desc = iterator.field.desc(transformer.context.body) - origin_index_node = iterator.indices[connectivity_type.source_dim.value] - - assert transformer.context.reduce_identity is not None - assert transformer.context.reduce_identity.dtype == iterator.dtype - - # gather the neighbors in a result array dimensioned for `max_neighbors` - neighbor_value_var = unique_var_name() - sdfg.add_array( - neighbor_value_var, - dtype=iterator.dtype, - shape=(connectivity_type.max_neighbors,), - transient=True, - ) - neighbor_value_node = state.add_access(neighbor_value_var, debuginfo=di) - - # allocate scalar to store index for direct addressing of neighbor field - neighbor_index_var = unique_var_name() - sdfg.add_scalar(neighbor_index_var, _INDEX_DTYPE, transient=True) - neighbor_index_node = state.add_access(neighbor_index_var, debuginfo=di) - - # generate unique map index name to avoid conflict with other maps inside same state - neighbor_map_index = unique_name(f"{offset_dim}_neighbor_map_idx") - me, mx = state.add_map( - f"{offset_dim}_neighbor_map", - ndrange={neighbor_map_index: f"0:{connectivity_type.max_neighbors}"}, - debuginfo=di, - ) - - table_name = dace_utils.connectivity_identifier(offset_dim) - shift_tasklet = state.add_tasklet( - "shift", - code=f"__result = __table[__idx, {neighbor_map_index}]", - inputs={"__table", "__idx"}, - outputs={"__result"}, - debuginfo=di, - ) - state.add_memlet_path( - state.add_access(table_name, debuginfo=di), - me, - shift_tasklet, - memlet=dace.Memlet.from_array(table_name, sdfg.arrays[table_name]), - dst_conn="__table", - ) - state.add_memlet_path( - origin_index_node, - me, - shift_tasklet, - memlet=dace.Memlet(data=origin_index_node.data, subset="0"), - dst_conn="__idx", - ) - state.add_edge( - shift_tasklet, - "__result", - neighbor_index_node, - None, - dace.Memlet(data=neighbor_index_var, subset="0"), - ) - - if lift_node is not None: - _visit_lift_in_neighbors_reduction( - transformer, - lift_node, - lift_args, - connectivity_type, - me, - mx, - neighbor_index_node, - neighbor_value_node, - ) - else: - sorted_dims = transformer.get_sorted_field_dimensions(iterator.dimensions) - data_access_index = ",".join(f"{dim}_v" for dim in sorted_dims) - connector_neighbor_dim = f"{connectivity_type.codomain.value}_v" - data_access_tasklet = state.add_tasklet( - "data_access", - code=f"__data = __field[{data_access_index}] " - + ( - f"if {connector_neighbor_dim} != {neighbor_skip_value} else {transformer.context.reduce_identity.value}" - if connectivity_type.has_skip_values - else "" - ), - inputs={"__field"} | {f"{dim}_v" for dim in iterator.dimensions}, - outputs={"__data"}, - debuginfo=di, - ) - state.add_memlet_path( - iterator.field, - me, - data_access_tasklet, - memlet=dace.Memlet.from_array(iterator.field.data, field_desc), - dst_conn="__field", - ) - for dim in iterator.dimensions: - connector = f"{dim}_v" - if dim == connectivity_type.codomain.value: - state.add_edge( - neighbor_index_node, - None, - data_access_tasklet, - connector, - dace.Memlet(data=neighbor_index_var, subset="0"), - ) - else: - state.add_memlet_path( - iterator.indices[dim], - me, - data_access_tasklet, - dst_conn=connector, - memlet=dace.Memlet(data=iterator.indices[dim].data, subset="0"), - ) - - state.add_memlet_path( - data_access_tasklet, - mx, - neighbor_value_node, - memlet=dace.Memlet(data=neighbor_value_var, subset=neighbor_map_index), - src_conn="__data", - ) - - if not connectivity_type.has_skip_values: - return [ValueExpr(neighbor_value_node, iterator.dtype)] - else: - """ - In case of neighbor tables with skip values, in addition to the array of neighbor values this function also - returns an array of booleans to indicate if the neighbor value is present or not. This node is only used - for neighbor reductions with lambda functions, a very specific case. For single input neighbor reductions, - the regular case, this node will be removed by the simplify pass. - """ - neighbor_valid_var = unique_var_name() - sdfg.add_array( - neighbor_valid_var, - dtype=dace.dtypes.bool, - shape=(connectivity_type.max_neighbors,), - transient=True, - ) - neighbor_valid_node = state.add_access(neighbor_valid_var, debuginfo=di) - - neighbor_valid_tasklet = state.add_tasklet( - f"check_valid_neighbor_{offset_dim}", - {"__idx"}, - {"__valid"}, - f"__valid = True if __idx != {neighbor_skip_value} else False", - debuginfo=di, - ) - state.add_edge( - neighbor_index_node, - None, - neighbor_valid_tasklet, - "__idx", - dace.Memlet(data=neighbor_index_var, subset="0"), - ) - state.add_memlet_path( - neighbor_valid_tasklet, - mx, - neighbor_valid_node, - memlet=dace.Memlet(data=neighbor_valid_var, subset=neighbor_map_index), - src_conn="__valid", - ) - return [ - ValueExpr(neighbor_value_node, iterator.dtype), - ValueExpr(neighbor_valid_node, dace.dtypes.bool), - ] - - -def builtin_can_deref( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=transformer.context.body.debuginfo) - # first visit shift, to get set of indices for deref - can_deref_callable = node_args[0] - assert isinstance(can_deref_callable, itir.FunCall) - shift_callable = can_deref_callable.fun - assert isinstance(shift_callable, itir.FunCall) - assert isinstance(shift_callable.fun, itir.SymRef) - assert shift_callable.fun.id == "shift" - iterator = transformer._visit_shift(can_deref_callable) - - # TODO: remove this special case when ITIR reduce-unroll pass is able to catch it - if not isinstance(iterator, IteratorExpr): - assert len(iterator) == 1 and isinstance(iterator[0], ValueExpr) - # We can always deref a value expression, therefore hard-code `can_deref` to True. - # Returning a SymbolExpr would be preferable, but it requires update to type-checking. - result_name = unique_var_name() - transformer.context.body.add_scalar(result_name, dace.dtypes.bool, transient=True) - result_node = transformer.context.state.add_access(result_name, debuginfo=di) - transformer.context.state.add_edge( - transformer.context.state.add_tasklet( - "can_always_deref", {}, {"_out"}, "_out = True", debuginfo=di - ), - "_out", - result_node, - None, - dace.Memlet(data=result_name, subset="0"), - ) - return [ValueExpr(result_node, dace.dtypes.bool)] - - # create tasklet to check that field indices are non-negative (-1 is invalid) - args = [ValueExpr(access_node, _INDEX_DTYPE) for access_node in iterator.indices.values()] - internals = [f"{arg.value.data}_v" for arg in args] - expr_code = " and ".join(f"{v} != {neighbor_skip_value}" for v in internals) - - return transformer.add_expr_tasklet( - list(zip(args, internals)), expr_code, dace.dtypes.bool, "can_deref", dace_debuginfo=di - ) - - -def builtin_if( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - assert len(node_args) == 3 - sdfg = transformer.context.body - current_state = transformer.context.state - is_start_state = sdfg.start_block == current_state - - # build an empty state to join true and false branches - join_state = sdfg.add_state_before(current_state, "join") - - def build_if_state(arg, state): - symbol_map = copy.deepcopy(transformer.context.symbol_map) - node_context = Context(sdfg, state, symbol_map) - node_taskgen = PythonTaskletCodegen( - transformer.offset_provider_type, - node_context, - transformer.use_field_canonical_representation, - ) - return node_taskgen.visit(arg) - - # represent the if-statement condition as a tasklet inside an `if_statement` state preceding `join` state - stmt_state = sdfg.add_state_before(join_state, "if_statement", is_start_state) - stmt_node = build_if_state(node_args[0], stmt_state)[0] - assert isinstance(stmt_node, ValueExpr) - assert stmt_node.dtype == dace.dtypes.bool - assert sdfg.arrays[stmt_node.value.data].shape == (1,) - - # visit true and false branches (here called `tbr` and `fbr`) as separate states, following `if_statement` state - tbr_state = sdfg.add_state("true_branch") - sdfg.add_edge( - stmt_state, tbr_state, dace.InterstateEdge(condition=f"{stmt_node.value.data} == True") - ) - sdfg.add_edge(tbr_state, join_state, dace.InterstateEdge()) - tbr_values = flatten_list(build_if_state(node_args[1], tbr_state)) - # - fbr_state = sdfg.add_state("false_branch") - sdfg.add_edge( - stmt_state, fbr_state, dace.InterstateEdge(condition=f"{stmt_node.value.data} == False") - ) - sdfg.add_edge(fbr_state, join_state, dace.InterstateEdge()) - fbr_values = flatten_list(build_if_state(node_args[2], fbr_state)) - - assert isinstance(stmt_node, ValueExpr) - assert stmt_node.dtype == dace.dtypes.bool - # make the result of the if-statement evaluation available inside current state - ctx_stmt_node = ValueExpr(current_state.add_access(stmt_node.value.data), stmt_node.dtype) - - # we distinguish between select if-statements, where both true and false branches are symbolic expressions, - # and therefore do not require exclusive branch execution, and regular if-statements where at least one branch - # is a value expression, which has to be evaluated at runtime with conditional state transition - result_values = [] - assert len(tbr_values) == len(fbr_values) - for tbr_value, fbr_value in zip(tbr_values, fbr_values): - assert isinstance(tbr_value, (SymbolExpr, ValueExpr)) - assert isinstance(fbr_value, (SymbolExpr, ValueExpr)) - assert tbr_value.dtype == fbr_value.dtype - - if all(isinstance(x, SymbolExpr) for x in (tbr_value, fbr_value)): - # both branches return symbolic expressions, therefore the if-node can be translated - # to a select-tasklet inside current state - # TODO: use select-memlet when it becomes available in dace - code = f"{tbr_value.value} if _cond else {fbr_value.value}" - if_expr = transformer.add_expr_tasklet( - [(ctx_stmt_node, "_cond")], code, tbr_value.dtype, "if_select" - )[0] - result_values.append(if_expr) - else: - # at least one of the two branches contains a value expression, which should be evaluated - # only if the corresponding true/false condition is satisfied - desc = sdfg.arrays[ - tbr_value.value.data if isinstance(tbr_value, ValueExpr) else fbr_value.value.data - ] - var = unique_var_name() - if isinstance(desc, dace.data.Scalar): - sdfg.add_scalar(var, desc.dtype, transient=True) - else: - sdfg.add_array(var, desc.shape, desc.dtype, transient=True) - - # write result to transient data container and access it in the original state - for state, expr in [(tbr_state, tbr_value), (fbr_state, fbr_value)]: - val_node = state.add_access(var) - if isinstance(expr, ValueExpr): - state.add_nedge( - expr.value, val_node, dace.Memlet.from_array(expr.value.data, desc) - ) - else: - assert desc.shape == (1,) - state.add_edge( - state.add_tasklet("write_symbol", {}, {"_out"}, f"_out = {expr.value}"), - "_out", - val_node, - None, - dace.Memlet(var, "0"), - ) - result_values.append(ValueExpr(current_state.add_access(var), desc.dtype)) - - if tbr_state.is_empty() and fbr_state.is_empty(): - # if all branches are symbolic expressions, the true/false and join states can be removed - # as well as the conditional state transition - sdfg.remove_nodes_from([join_state, tbr_state, fbr_state]) - sdfg.add_edge(stmt_state, current_state, dace.InterstateEdge()) - elif tbr_state.is_empty(): - # use direct edge from if-statement to join state for true branch - tbr_condition = sdfg.edges_between(stmt_state, tbr_state)[0].condition - sdfg.edges_between(stmt_state, join_state)[0].contition = tbr_condition - sdfg.remove_node(tbr_state) - elif fbr_state.is_empty(): - # use direct edge from if-statement to join state for false branch - fbr_condition = sdfg.edges_between(stmt_state, fbr_state)[0].condition - sdfg.edges_between(stmt_state, join_state)[0].contition = fbr_condition - sdfg.remove_node(fbr_state) - else: - # remove direct edge from if-statement to join state - sdfg.remove_edge(sdfg.edges_between(stmt_state, join_state)[0]) - # the if-statement condition is not used in current state - current_state.remove_node(ctx_stmt_node.value) - - return result_values - - -def builtin_list_get( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=transformer.context.body.debuginfo) - args = list(itertools.chain(*transformer.visit(node_args))) - assert len(args) == 2 - # index node - if isinstance(args[0], SymbolExpr): - index_value = args[0].value - result_name = unique_var_name() - transformer.context.body.add_scalar(result_name, args[1].dtype, transient=True) - result_node = transformer.context.state.add_access(result_name) - transformer.context.state.add_nedge( - args[1].value, result_node, dace.Memlet(data=args[1].value.data, subset=index_value) - ) - return [ValueExpr(result_node, args[1].dtype)] - - else: - expr_args = [(arg, f"{arg.value.data}_v") for arg in args] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[1]}[{internals[0]}]" - return transformer.add_expr_tasklet( - expr_args, expr, args[1].dtype, "list_get", dace_debuginfo=di - ) - - -def builtin_cast( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=transformer.context.body.debuginfo) - args = transformer.visit(node_args[0]) - internals = [f"{arg.value.data}_v" for arg in args] - target_type = node_args[1] - assert isinstance(target_type, itir.SymRef) - expr = _MATH_BUILTINS_MAPPING[target_type.id].format(*internals) - type_ = itir_type_as_dace_type(node.type) # type: ignore[arg-type] # ensure by type inference - return transformer.add_expr_tasklet( - list(zip(args, internals)), expr, type_, "cast", dace_debuginfo=di - ) - - -def builtin_make_const_list( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=transformer.context.body.debuginfo) - args = [transformer.visit(arg)[0] for arg in node_args] - assert all(isinstance(x, (SymbolExpr, ValueExpr)) for x in args) - args_dtype = [x.dtype for x in args] - assert len(set(args_dtype)) == 1 - dtype = args_dtype[0] - - var_name = unique_var_name() - transformer.context.body.add_array(var_name, (len(args),), dtype, transient=True) - var_node = transformer.context.state.add_access(var_name, debuginfo=di) - - for i, arg in enumerate(args): - if isinstance(arg, SymbolExpr): - transformer.context.state.add_edge( - transformer.context.state.add_tasklet( - f"get_arg{i}", {}, {"val"}, f"val = {arg.value}" - ), - "val", - var_node, - None, - dace.Memlet(data=var_name, subset=f"{i}"), - ) - else: - assert arg.value.desc(transformer.context.body).shape == (1,) - transformer.context.state.add_nedge( - arg.value, - var_node, - dace.Memlet(data=arg.value.data, subset="0", other_subset=f"{i}"), - ) - - return [ValueExpr(var_node, dtype)] - - -def builtin_make_tuple( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - args = [transformer.visit(arg) for arg in node_args] - return args - - -def builtin_tuple_get( - transformer: PythonTaskletCodegen, node: itir.Expr, node_args: list[itir.Expr] -) -> list[ValueExpr]: - elements = transformer.visit(node_args[1]) - index = node_args[0] - if isinstance(index, itir.Literal): - return [elements[int(index.value)]] - raise ValueError("Tuple can only be subscripted with compile-time constants.") - - -_GENERAL_BUILTIN_MAPPING: dict[ - str, Callable[[PythonTaskletCodegen, itir.Expr, list[itir.Expr]], list[ValueExpr]] -] = { - "can_deref": builtin_can_deref, - "cast_": builtin_cast, - "if_": builtin_if, - "list_get": builtin_list_get, - "make_const_list": builtin_make_const_list, - "make_tuple": builtin_make_tuple, - "neighbors": builtin_neighbors, - "tuple_get": builtin_tuple_get, -} - - -class GatherLambdaSymbolsPass(eve.NodeVisitor): - _sdfg: dace.SDFG - _state: dace.SDFGState - _symbol_map: dict[str, TaskletExpr | tuple[ValueExpr]] - _parent_symbol_map: dict[str, TaskletExpr] - - def __init__(self, sdfg, state, parent_symbol_map): - self._sdfg = sdfg - self._state = state - self._symbol_map = {} - self._parent_symbol_map = parent_symbol_map - - @property - def symbol_refs(self): - """Dictionary of symbols referenced from the lambda expression.""" - return self._symbol_map - - def _add_symbol(self, param, arg): - if isinstance(arg, ValueExpr): - # create storage in lambda sdfg - self._sdfg.add_scalar(param, dtype=arg.dtype) - # update table of lambda symbols - self._symbol_map[param] = ValueExpr( - self._state.add_access(param, debuginfo=self._sdfg.debuginfo), arg.dtype - ) - elif isinstance(arg, IteratorExpr): - # create storage in lambda sdfg - ndims = len(arg.dimensions) - shape, strides = new_array_symbols(param, ndims) - self._sdfg.add_array(param, shape=shape, strides=strides, dtype=arg.dtype) - index_names = {dim: f"__{param}_i_{dim}" for dim in arg.indices.keys()} - for _, index_name in index_names.items(): - self._sdfg.add_scalar(index_name, dtype=_INDEX_DTYPE) - # update table of lambda symbols - field = self._state.add_access(param, debuginfo=self._sdfg.debuginfo) - indices = { - dim: self._state.add_access(index_arg, debuginfo=self._sdfg.debuginfo) - for dim, index_arg in index_names.items() - } - self._symbol_map[param] = IteratorExpr(field, indices, arg.dtype, arg.dimensions) - else: - assert isinstance(arg, SymbolExpr) - self._symbol_map[param] = arg - - def _add_tuple(self, param, args): - nodes = [] - # create storage in lambda sdfg for each tuple element - for arg in args: - var = unique_var_name() - self._sdfg.add_scalar(var, dtype=arg.dtype) - arg_node = self._state.add_access(var, debuginfo=self._sdfg.debuginfo) - nodes.append(ValueExpr(arg_node, arg.dtype)) - # update table of lambda symbols - self._symbol_map[param] = tuple(nodes) - - def visit_SymRef(self, node: itir.SymRef): - name = str(node.id) - if name in self._parent_symbol_map and name not in self._symbol_map: - arg = self._parent_symbol_map[name] - self._add_symbol(name, arg) - - def visit_Lambda(self, node: itir.Lambda, args: Optional[Sequence[TaskletExpr]] = None): - if args is not None: - if len(node.params) == len(args): - for param, arg in zip(node.params, args): - self._add_symbol(str(param.id), arg) - else: - # implicitly make tuple - assert len(node.params) == 1 - self._add_tuple(str(node.params[0].id), args) - self.visit(node.expr) - - -class GatherOutputSymbolsPass(eve.NodeVisitor): - _sdfg: dace.SDFG - _state: dace.SDFGState - _symbol_map: dict[str, TaskletExpr] - - @property - def symbol_refs(self): - """Dictionary of symbols referenced from the output expression.""" - return self._symbol_map - - def __init__(self, sdfg, state): - self._sdfg = sdfg - self._state = state - self._symbol_map = {} - - def visit_SymRef(self, node: itir.SymRef): - param = str(node.id) - if param not in _GENERAL_BUILTIN_MAPPING and param not in self._symbol_map: - access_node = self._state.add_access(param, debuginfo=self._sdfg.debuginfo) - self._symbol_map[param] = ValueExpr( - access_node, - dtype=itir_type_as_dace_type(node.type), # type: ignore[arg-type] # ensure by type inference - ) - - -@dataclasses.dataclass -class PythonTaskletCodegen(gt4py.eve.codegen.TemplatedGenerator): - offset_provider_type: common.OffsetProviderType - context: Context - use_field_canonical_representation: bool - - def get_sorted_field_dimensions(self, dims: Sequence[str]): - return sorted(dims) if self.use_field_canonical_representation else dims - - def visit_FunctionDefinition(self, node: itir.FunctionDefinition, **kwargs): - raise NotImplementedError() - - def visit_Lambda( - self, node: itir.Lambda, args: Sequence[TaskletExpr], use_neighbor_tables: bool = True - ) -> tuple[ - Context, - list[tuple[str, ValueExpr] | tuple[tuple[str, dict], IteratorExpr]], - list[ValueExpr], - ]: - func_name = f"lambda_{abs(hash(node)):x}" - neighbor_tables = ( - get_used_connectivities(node, self.offset_provider_type) if use_neighbor_tables else {} - ) - connectivity_names = [ - dace_utils.connectivity_identifier(offset) for offset in neighbor_tables.keys() - ] - - # Create the SDFG for the lambda's body - lambda_sdfg = dace.SDFG(func_name) - lambda_sdfg.debuginfo = dace_utils.debug_info(node, default=self.context.body.debuginfo) - lambda_state = lambda_sdfg.add_state(f"{func_name}_body", is_start_block=True) - - lambda_symbols_pass = GatherLambdaSymbolsPass( - lambda_sdfg, lambda_state, self.context.symbol_map - ) - lambda_symbols_pass.visit(node, args=args) - - # Add for input nodes for lambda symbols - inputs: list[tuple[str, ValueExpr] | tuple[tuple[str, dict], IteratorExpr]] = [] - for sym, input_node in lambda_symbols_pass.symbol_refs.items(): - params = [str(p.id) for p in node.params] - try: - param_index = params.index(sym) - except ValueError: - param_index = -1 - if param_index >= 0: - outer_node = args[param_index] - else: - # the symbol is not found among lambda arguments, then it is inherited from parent scope - outer_node = self.context.symbol_map[sym] - if isinstance(input_node, IteratorExpr): - assert isinstance(outer_node, IteratorExpr) - index_params = { - dim: index_node.data for dim, index_node in input_node.indices.items() - } - inputs.append(((sym, index_params), outer_node)) - elif isinstance(input_node, ValueExpr): - assert isinstance(outer_node, ValueExpr) - inputs.append((sym, outer_node)) - elif isinstance(input_node, tuple): - assert param_index >= 0 - for i, input_node_i in enumerate(input_node): - arg_i = args[param_index + i] - assert isinstance(arg_i, ValueExpr) - assert isinstance(input_node_i, ValueExpr) - inputs.append((input_node_i.value.data, arg_i)) - - # Add connectivities as arrays - for name in connectivity_names: - shape, strides = new_array_symbols(name, ndim=2) - dtype = self.context.body.arrays[name].dtype - lambda_sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) - - # Translate the lambda's body in its own context - lambda_context = Context( - lambda_sdfg, - lambda_state, - lambda_symbols_pass.symbol_refs, - reduce_identity=self.context.reduce_identity, - ) - lambda_taskgen = PythonTaskletCodegen( - self.offset_provider_type, - lambda_context, - self.use_field_canonical_representation, - ) - - results: list[ValueExpr] = [] - # We are flattening the returned list of value expressions because the multiple outputs of a lambda - # should be a list of nodes without tuple structure. Ideally, an ITIR transformation could do this. - node.expr.location = node.location - for expr in flatten_list(lambda_taskgen.visit(node.expr)): - if isinstance(expr, ValueExpr): - result_name = unique_var_name() - lambda_sdfg.add_scalar(result_name, expr.dtype, transient=True) - result_access = lambda_state.add_access( - result_name, debuginfo=lambda_sdfg.debuginfo - ) - lambda_state.add_nedge( - expr.value, result_access, dace.Memlet(data=result_access.data, subset="0") - ) - result = ValueExpr(value=result_access, dtype=expr.dtype) - else: - # Forwarding result through a tasklet needed because empty SDFG states don't properly forward connectors - result = lambda_taskgen.add_expr_tasklet( - [], expr.value, expr.dtype, "forward", dace_debuginfo=lambda_sdfg.debuginfo - )[0] - lambda_sdfg.arrays[result.value.data].transient = False - results.append(result) - - # remove isolated access nodes for connectivity arrays not consumed by lambda - for sub_node in lambda_state.nodes(): - if isinstance(sub_node, dace.nodes.AccessNode): - if lambda_state.out_degree(sub_node) == 0 and lambda_state.in_degree(sub_node) == 0: - lambda_state.remove_node(sub_node) - - return lambda_context, inputs, results - - def visit_SymRef(self, node: itir.SymRef) -> list[ValueExpr | SymbolExpr] | IteratorExpr: - param = str(node.id) - value = self.context.symbol_map[param] - if isinstance(value, (ValueExpr, SymbolExpr)): - return [value] - return value - - def visit_Literal(self, node: itir.Literal) -> list[SymbolExpr]: - return [SymbolExpr(node.value, itir_type_as_dace_type(node.type))] - - def visit_FunCall(self, node: itir.FunCall) -> list[ValueExpr] | IteratorExpr: - node.fun.location = node.location - if isinstance(node.fun, itir.SymRef) and node.fun.id == "deref": - return self._visit_deref(node) - if isinstance(node.fun, itir.FunCall) and isinstance(node.fun.fun, itir.SymRef): - if node.fun.fun.id == "shift": - return self._visit_shift(node) - elif node.fun.fun.id == "reduce": - return self._visit_reduce(node) - - if isinstance(node.fun, itir.SymRef): - builtin_name = str(node.fun.id) - if builtin_name in _MATH_BUILTINS_MAPPING: - return self._visit_numeric_builtin(node) - elif builtin_name in _GENERAL_BUILTIN_MAPPING: - return self._visit_general_builtin(node) - else: - raise NotImplementedError(f"'{builtin_name}' not implemented.") - return self._visit_call(node) - - def _visit_call(self, node: itir.FunCall): - args = self.visit(node.args) - args = [arg if isinstance(arg, Sequence) else [arg] for arg in args] - args = list(itertools.chain(*args)) - node.fun.location = node.location - func_context, func_inputs, results = self.visit(node.fun, args=args) - - nsdfg_inputs = {} - for name, value in func_inputs: - if isinstance(value, ValueExpr): - nsdfg_inputs[name] = dace.Memlet.from_array( - value.value.data, self.context.body.arrays[value.value.data] - ) - else: - assert isinstance(value, IteratorExpr) - field = name[0] - indices = name[1] - nsdfg_inputs[field] = dace.Memlet.from_array( - value.field.data, self.context.body.arrays[value.field.data] - ) - for dim, var in indices.items(): - store = value.indices[dim].data - nsdfg_inputs[var] = dace.Memlet.from_array( - store, self.context.body.arrays[store] - ) - - neighbor_tables = get_used_connectivities(node.fun, self.offset_provider_type) - for offset in neighbor_tables.keys(): - var = dace_utils.connectivity_identifier(offset) - nsdfg_inputs[var] = dace.Memlet.from_array(var, self.context.body.arrays[var]) - - symbol_mapping = map_nested_sdfg_symbols(self.context.body, func_context.body, nsdfg_inputs) - - nsdfg_node = self.context.state.add_nested_sdfg( - func_context.body, - None, - inputs=set(nsdfg_inputs.keys()), - outputs=set(r.value.data for r in results), - symbol_mapping=symbol_mapping, - debuginfo=dace_utils.debug_info(node, default=func_context.body.debuginfo), - ) - - for name, value in func_inputs: - if isinstance(value, ValueExpr): - value_memlet = nsdfg_inputs[name] - self.context.state.add_edge(value.value, None, nsdfg_node, name, value_memlet) - else: - assert isinstance(value, IteratorExpr) - field = name[0] - indices = name[1] - field_memlet = nsdfg_inputs[field] - self.context.state.add_edge(value.field, None, nsdfg_node, field, field_memlet) - for dim, var in indices.items(): - store = value.indices[dim] - idx_memlet = nsdfg_inputs[var] - self.context.state.add_edge(store, None, nsdfg_node, var, idx_memlet) - for offset in neighbor_tables.keys(): - var = dace_utils.connectivity_identifier(offset) - memlet = nsdfg_inputs[var] - access = self.context.state.add_access(var, debuginfo=nsdfg_node.debuginfo) - self.context.state.add_edge(access, None, nsdfg_node, var, memlet) - - result_exprs = [] - for result in results: - name = unique_var_name() - self.context.body.add_scalar(name, result.dtype, transient=True) - result_access = self.context.state.add_access(name, debuginfo=nsdfg_node.debuginfo) - result_exprs.append(ValueExpr(result_access, result.dtype)) - memlet = dace.Memlet.from_array(name, self.context.body.arrays[name]) - self.context.state.add_edge(nsdfg_node, result.value.data, result_access, None, memlet) - - return result_exprs - - def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=self.context.body.debuginfo) - iterator = self.visit(node.args[0]) - if not isinstance(iterator, IteratorExpr): - # already a list of ValueExpr - return iterator - - sorted_dims = self.get_sorted_field_dimensions(iterator.dimensions) - if all([dim in iterator.indices for dim in iterator.dimensions]): - # The deref iterator has index values on all dimensions: the result will be a scalar - args = [ValueExpr(iterator.field, iterator.dtype)] + [ - ValueExpr(iterator.indices[dim], _INDEX_DTYPE) for dim in sorted_dims - ] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]}[{', '.join(internals[1:])}]" - return self.add_expr_tasklet( - list(zip(args, internals)), expr, iterator.dtype, "deref", dace_debuginfo=di - ) - - else: - dims_not_indexed = [dim for dim in iterator.dimensions if dim not in iterator.indices] - assert len(dims_not_indexed) == 1 - offset = dims_not_indexed[0] - offset_provider_type = self.offset_provider_type[offset] - assert isinstance(offset_provider_type, common.NeighborConnectivityType) - neighbor_dim = offset_provider_type.codomain.value - - result_name = unique_var_name() - self.context.body.add_array( - result_name, (offset_provider_type.max_neighbors,), iterator.dtype, transient=True - ) - result_array = self.context.body.arrays[result_name] - result_node = self.context.state.add_access(result_name, debuginfo=di) - - deref_connectors = ["_inp"] + [ - f"_i_{dim}" for dim in sorted_dims if dim in iterator.indices - ] - deref_nodes = [iterator.field] + [ - iterator.indices[dim] for dim in sorted_dims if dim in iterator.indices - ] - deref_memlets = [ - dace.Memlet.from_array(iterator.field.data, iterator.field.desc(self.context.body)) - ] + [dace.Memlet(data=node.data, subset="0") for node in deref_nodes[1:]] - - # we create a mapped tasklet for array slicing - index_name = unique_name(f"_i_{neighbor_dim}") - map_ranges = {index_name: f"0:{offset_provider_type.max_neighbors}"} - src_subset = ",".join( - [f"_i_{dim}" if dim in iterator.indices else index_name for dim in sorted_dims] - ) - self.context.state.add_mapped_tasklet( - "deref", - map_ranges, - inputs={k: v for k, v in zip(deref_connectors, deref_memlets)}, - outputs={"_out": dace.Memlet.from_array(result_name, result_array)}, - code=f"_out[{index_name}] = _inp[{src_subset}]", - external_edges=True, - input_nodes={node.data: node for node in deref_nodes}, - output_nodes={result_name: result_node}, - debuginfo=di, - ) - return [ValueExpr(result_node, iterator.dtype)] - - def _split_shift_args( - self, args: list[itir.Expr] - ) -> tuple[list[itir.Expr], Optional[list[itir.Expr]]]: - pairs = [args[i : i + 2] for i in range(0, len(args), 2)] - assert len(pairs) >= 1 - assert all(len(pair) == 2 for pair in pairs) - return pairs[-1], list(itertools.chain(*pairs[0:-1])) if len(pairs) > 1 else None - - def _make_shift_for_rest(self, rest, iterator): - return itir.FunCall( - fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=rest), - args=[iterator], - location=iterator.location, - ) - - def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: - di = dace_utils.debug_info(node, default=self.context.body.debuginfo) - shift = node.fun - assert isinstance(shift, itir.FunCall) - tail, rest = self._split_shift_args(shift.args) - if rest: - iterator = self.visit(self._make_shift_for_rest(rest, node.args[0])) - else: - iterator = self.visit(node.args[0]) - if not isinstance(iterator, IteratorExpr): - # shift cannot be applied because the argument is not iterable - # TODO: remove this special case when ITIR pass is able to catch it - assert isinstance(iterator, list) and len(iterator) == 1 - assert isinstance(iterator[0], ValueExpr) - return iterator - - assert isinstance(tail[0], itir.OffsetLiteral) - offset_dim = tail[0].value - assert isinstance(offset_dim, str) - offset_node = self.visit(tail[1])[0] - assert offset_node.dtype in dace.dtypes.INTEGER_TYPES - - if isinstance(self.offset_provider_type[offset_dim], common.NeighborConnectivityType): - offset_provider_type = cast( - common.NeighborConnectivityType, self.offset_provider_type[offset_dim] - ) # ensured by condition - connectivity = self.context.state.add_access( - dace_utils.connectivity_identifier(offset_dim), debuginfo=di - ) - - shifted_dim_tag = offset_provider_type.source_dim.value - target_dim_tag = offset_provider_type.codomain.value - args = [ - ValueExpr(connectivity, _INDEX_DTYPE), - ValueExpr(iterator.indices[shifted_dim_tag], offset_node.dtype), - offset_node, - ] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]}[{internals[1]}, {internals[2]}]" - else: - shifted_dim = self.offset_provider_type[offset_dim] - assert isinstance(shifted_dim, common.Dimension) - - shifted_dim_tag = shifted_dim.value - target_dim_tag = shifted_dim_tag - args = [ValueExpr(iterator.indices[shifted_dim_tag], offset_node.dtype), offset_node] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]} + {internals[1]}" - - shifted_value = self.add_expr_tasklet( - list(zip(args, internals)), expr, offset_node.dtype, "shift", dace_debuginfo=di - )[0].value - - shifted_index = {dim: value for dim, value in iterator.indices.items()} - del shifted_index[shifted_dim_tag] - shifted_index[target_dim_tag] = shifted_value - - return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions) - - def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]: - di = dace_utils.debug_info(node, default=self.context.body.debuginfo) - offset = node.value - assert isinstance(offset, int) - offset_var = unique_var_name() - self.context.body.add_scalar(offset_var, _INDEX_DTYPE, transient=True) - offset_node = self.context.state.add_access(offset_var, debuginfo=di) - tasklet_node = self.context.state.add_tasklet( - "get_offset", {}, {"__out"}, f"__out = {offset}", debuginfo=di - ) - self.context.state.add_edge( - tasklet_node, "__out", offset_node, None, dace.Memlet(data=offset_var, subset="0") - ) - return [ValueExpr(offset_node, self.context.body.arrays[offset_var].dtype)] - - def _visit_reduce(self, node: itir.FunCall): - di = dace_utils.debug_info(node, default=self.context.body.debuginfo) - reduce_dtype = itir_type_as_dace_type(node.type) # type: ignore[arg-type] # ensure by type inference - - if len(node.args) == 1: - assert ( - isinstance(node.args[0], itir.FunCall) - and isinstance(node.args[0].fun, itir.SymRef) - and node.args[0].fun.id == "neighbors" - ) - assert isinstance(node.fun, itir.FunCall) - op_name = node.fun.args[0] - assert isinstance(op_name, itir.SymRef) - reduce_identity = node.fun.args[1] - assert isinstance(reduce_identity, itir.Literal) - - # set reduction state - self.context.reduce_identity = SymbolExpr(reduce_identity, reduce_dtype) - - args = self.visit(node.args[0]) - - assert 1 <= len(args) <= 2 - reduce_input_node = args[0].value - - else: - assert isinstance(node.fun, itir.FunCall) - assert isinstance(node.fun.args[0], itir.Lambda) - fun_node = node.fun.args[0] - assert isinstance(fun_node.expr, itir.FunCall) - - op_name = fun_node.expr.fun - assert isinstance(op_name, itir.SymRef) - reduce_identity = get_reduce_identity_value(op_name.id, reduce_dtype) - - # set reduction state in visit context - self.context.reduce_identity = SymbolExpr(reduce_identity, reduce_dtype) - - args = self.visit(node.args) - - # clear context - self.context.reduce_identity = None - - # check that all neighbor expressions have the same shape - args_shape = [ - arg[0].value.desc(self.context.body).shape - for arg in args - if arg[0].value.desc(self.context.body).shape != (1,) - ] - assert len(set(args_shape)) == 1 - nreduce_shape = args_shape[0] - - input_args = [arg[0] for arg in args] - input_valid_args = [arg[1] for arg in args if len(arg) == 2] - - assert len(nreduce_shape) == 1 - nreduce_index = unique_name("_i") - nreduce_domain = {nreduce_index: f"0:{nreduce_shape[0]}"} - - reduce_input_name = unique_var_name() - self.context.body.add_array( - reduce_input_name, nreduce_shape, reduce_dtype, transient=True - ) - - lambda_node = itir.Lambda( - expr=fun_node.expr.args[1], params=fun_node.params[1:], location=node.location - ) - lambda_context, inner_inputs, inner_outputs = self.visit( - lambda_node, args=input_args, use_neighbor_tables=False - ) - - input_mapping = { - param: ( - dace.Memlet(data=arg.value.data, subset="0") - if arg.value.desc(self.context.body).shape == (1,) - else dace.Memlet(data=arg.value.data, subset=nreduce_index) - ) - for (param, _), arg in zip(inner_inputs, input_args) - } - output_mapping = { - inner_outputs[0].value.data: dace.Memlet( - data=reduce_input_name, subset=nreduce_index - ) - } - symbol_mapping = map_nested_sdfg_symbols( - self.context.body, lambda_context.body, input_mapping - ) - - if input_valid_args: - """ - The neighbors builtin returns an array of booleans in case the connectivity table contains skip values. - These booleans indicate whether the neighbor is present or not, and are used in a tasklet to select - the result of field access or the identity value, respectively. - If the neighbor table has full connectivity (no skip values by type definition), the input_valid node - is not built, and the construction of the select tasklet below is also skipped. - """ - input_args.append(input_valid_args[0]) - input_valid_node = input_valid_args[0].value - lambda_output_node = inner_outputs[0].value - # add input connector to nested sdfg - lambda_context.body.add_scalar("_valid_neighbor", dace.dtypes.bool) - input_mapping["_valid_neighbor"] = dace.Memlet( - data=input_valid_node.data, subset=nreduce_index - ) - # add select tasklet before writing to output node - # TODO: consider replacing it with a select-memlet once it is supported by DaCe SDFG API - output_edge = lambda_context.state.in_edges(lambda_output_node)[0] - assert isinstance( - lambda_context.body.arrays[output_edge.src.data], dace.data.Scalar - ) - select_tasklet = lambda_context.state.add_tasklet( - "neighbor_select", - {"_inp", "_valid"}, - {"_out"}, - f"_out = _inp if _valid else {reduce_identity}", - ) - lambda_context.state.add_edge( - output_edge.src, - None, - select_tasklet, - "_inp", - dace.Memlet(data=output_edge.src.data, subset="0"), - ) - lambda_context.state.add_edge( - lambda_context.state.add_access("_valid_neighbor"), - None, - select_tasklet, - "_valid", - dace.Memlet(data="_valid_neighbor", subset="0"), - ) - lambda_context.state.add_edge( - select_tasklet, - "_out", - lambda_output_node, - None, - dace.Memlet(data=lambda_output_node.data, subset="0"), - ) - lambda_context.state.remove_edge(output_edge) - - reduce_input_node = self.context.state.add_access(reduce_input_name, debuginfo=di) - - nsdfg_node, map_entry, _ = add_mapped_nested_sdfg( - self.context.state, - sdfg=lambda_context.body, - map_ranges=nreduce_domain, - inputs=input_mapping, - outputs=output_mapping, - symbol_mapping=symbol_mapping, - input_nodes={arg.value.data: arg.value for arg in input_args}, - output_nodes={reduce_input_name: reduce_input_node}, - debuginfo=di, - ) - - reduce_input_desc = reduce_input_node.desc(self.context.body) - - result_name = unique_var_name() - # we allocate an array instead of a scalar because the reduce library node is generic and expects an array node - self.context.body.add_array(result_name, (1,), reduce_dtype, transient=True) - result_access = self.context.state.add_access(result_name, debuginfo=di) - - reduce_wcr = "lambda x, y: " + _MATH_BUILTINS_MAPPING[str(op_name)].format("x", "y") - reduce_node = self.context.state.add_reduce(reduce_wcr, None, reduce_identity) - self.context.state.add_nedge( - reduce_input_node, - reduce_node, - dace.Memlet.from_array(reduce_input_node.data, reduce_input_desc), - ) - self.context.state.add_nedge( - reduce_node, result_access, dace.Memlet(data=result_name, subset="0") - ) - - return [ValueExpr(result_access, reduce_dtype)] - - def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]: - assert isinstance(node.fun, itir.SymRef) - fmt = _MATH_BUILTINS_MAPPING[str(node.fun.id)] - args = flatten_list(self.visit(node.args)) - expr_args = [ - (arg, f"{arg.value.data}_v") for arg in args if not isinstance(arg, SymbolExpr) - ] - internals = [ - arg.value if isinstance(arg, SymbolExpr) else f"{arg.value.data}_v" for arg in args - ] - expr = fmt.format(*internals) - type_ = itir_type_as_dace_type(node.type) # type: ignore[arg-type] # ensure by type inference - return self.add_expr_tasklet( - expr_args, - expr, - type_, - "numeric", - dace_debuginfo=dace_utils.debug_info(node, default=self.context.body.debuginfo), - ) - - def _visit_general_builtin(self, node: itir.FunCall) -> list[ValueExpr]: - assert isinstance(node.fun, itir.SymRef) - expr_func = _GENERAL_BUILTIN_MAPPING[str(node.fun.id)] - return expr_func(self, node, node.args) - - def add_expr_tasklet( - self, - args: list[tuple[ValueExpr, str]], - expr: str, - result_type: Any, - name: str, - dace_debuginfo: Optional[dace.dtypes.DebugInfo] = None, - ) -> list[ValueExpr]: - di = dace_debuginfo if dace_debuginfo else self.context.body.debuginfo - result_name = unique_var_name() - self.context.body.add_scalar(result_name, result_type, transient=True) - result_access = self.context.state.add_access(result_name, debuginfo=di) - - expr_tasklet = self.context.state.add_tasklet( - name=name, - inputs={internal for _, internal in args}, - outputs={"__result"}, - code=f"__result = {expr}", - debuginfo=di, - ) - - for arg, internal in args: - edges = self.context.state.in_edges(expr_tasklet) - used = False - for edge in edges: - if edge.dst_conn == internal: - used = True - break - if used: - continue - elif not isinstance(arg, SymbolExpr): - memlet = dace.Memlet.from_array( - arg.value.data, self.context.body.arrays[arg.value.data] - ) - self.context.state.add_edge(arg.value, None, expr_tasklet, internal, memlet) - - memlet = dace.Memlet(data=result_access.data, subset="0") - self.context.state.add_edge(expr_tasklet, "__result", result_access, None, memlet) - - return [ValueExpr(result_access, result_type)] - - -def is_scan(node: itir.Node) -> bool: - return isinstance(node, itir.FunCall) and node.fun == itir.SymRef(id="scan") - - -def closure_to_tasklet_sdfg( - node: itir.StencilClosure, - offset_provider_type: common.OffsetProviderType, - domain: dict[str, str], - inputs: Sequence[tuple[str, ts.TypeSpec]], - connectivities: Sequence[tuple[dace.ndarray, str]], - use_field_canonical_representation: bool, -) -> tuple[Context, Sequence[ValueExpr]]: - body = dace.SDFG("tasklet_toplevel") - body.debuginfo = dace_utils.debug_info(node) - state = body.add_state("tasklet_toplevel_entry", True) - symbol_map: dict[str, TaskletExpr] = {} - - idx_accesses = {} - for dim, idx in domain.items(): - name = f"{idx}_value" - body.add_scalar(name, dtype=_INDEX_DTYPE, transient=True) - tasklet = state.add_tasklet( - f"get_{dim}", set(), {"value"}, f"value = {idx}", debuginfo=body.debuginfo - ) - access = state.add_access(name, debuginfo=body.debuginfo) - idx_accesses[dim] = access - state.add_edge(tasklet, "value", access, None, dace.Memlet(data=name, subset="0")) - for name, ty in inputs: - if isinstance(ty, ts.FieldType): - ndim = len(ty.dims) - shape, strides = new_array_symbols(name, ndim) - dims = [dim.value for dim in ty.dims] - dtype = dace_utils.as_dace_type(ty.dtype) - body.add_array(name, shape=shape, strides=strides, dtype=dtype) - field = state.add_access(name, debuginfo=body.debuginfo) - indices = {dim: idx_accesses[dim] for dim in domain.keys()} - symbol_map[name] = IteratorExpr(field, indices, dtype, dims) - else: - assert isinstance(ty, ts.ScalarType) - dtype = dace_utils.as_dace_type(ty) - body.add_scalar(name, dtype=dtype) - symbol_map[name] = ValueExpr(state.add_access(name, debuginfo=body.debuginfo), dtype) - for arr, name in connectivities: - shape, strides = new_array_symbols(name, ndim=2) - body.add_array(name, shape=shape, strides=strides, dtype=arr.dtype) - - context = Context(body, state, symbol_map) - translator = PythonTaskletCodegen( - offset_provider_type, context, use_field_canonical_representation - ) - - args = [itir.SymRef(id=name) for name, _ in inputs] - if is_scan(node.stencil): - stencil = cast(FunCall, node.stencil) - assert isinstance(stencil.args[0], Lambda) - lambda_node = itir.Lambda( - expr=stencil.args[0].expr, params=stencil.args[0].params, location=node.location - ) - fun_node = itir.FunCall(fun=lambda_node, args=args, location=node.location) - else: - fun_node = itir.FunCall(fun=node.stencil, args=args, location=node.location) - - results = translator.visit(fun_node) - for r in results: - context.body.arrays[r.value.data].transient = False - - return context, results diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py deleted file mode 100644 index 72bb32f003..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ /dev/null @@ -1,149 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import itertools -from typing import Any - -import dace - -import gt4py.next.iterator.ir as itir -from gt4py import eve -from gt4py.next import common -from gt4py.next.ffront import fbuiltins as gtx_fbuiltins -from gt4py.next.program_processors.runners.dace_common import utility as dace_utils - - -def get_used_connectivities( - node: itir.Node, offset_provider_type: common.OffsetProviderType -) -> dict[str, common.NeighborConnectivityType]: - connectivities = dace_utils.filter_connectivity_types(offset_provider_type) - offset_dims = set(eve.walk_values(node).if_isinstance(itir.OffsetLiteral).getattr("value")) - return {offset: connectivities[offset] for offset in offset_dims if offset in connectivities} - - -def map_nested_sdfg_symbols( - parent_sdfg: dace.SDFG, nested_sdfg: dace.SDFG, array_mapping: dict[str, dace.Memlet] -) -> dict[str, str]: - symbol_mapping: dict[str, str] = {} - for param, arg in array_mapping.items(): - arg_array = parent_sdfg.arrays[arg.data] - param_array = nested_sdfg.arrays[param] - if not isinstance(param_array, dace.data.Scalar): - assert len(arg.subset.size()) == len(param_array.shape) - for arg_shape, param_shape in zip(arg.subset.size(), param_array.shape): - if isinstance(param_shape, dace.symbol): - symbol_mapping[str(param_shape)] = str(arg_shape) - assert len(arg_array.strides) == len(param_array.strides) - for arg_stride, param_stride in zip(arg_array.strides, param_array.strides): - if isinstance(param_stride, dace.symbol): - symbol_mapping[str(param_stride)] = str(arg_stride) - else: - assert arg.subset.num_elements() == 1 - for sym in nested_sdfg.free_symbols: - if str(sym) not in symbol_mapping: - symbol_mapping[str(sym)] = str(sym) - return symbol_mapping - - -def add_mapped_nested_sdfg( - state: dace.SDFGState, - map_ranges: dict[str, str | dace.subsets.Subset] | list[tuple[str, str | dace.subsets.Subset]], - inputs: dict[str, dace.Memlet], - outputs: dict[str, dace.Memlet], - sdfg: dace.SDFG, - symbol_mapping: dict[str, Any] | None = None, - schedule: Any = dace.dtypes.ScheduleType.Default, - unroll_map: bool = False, - location: Any = None, - debuginfo: Any = None, - input_nodes: dict[str, dace.nodes.AccessNode] | None = None, - output_nodes: dict[str, dace.nodes.AccessNode] | None = None, -) -> tuple[dace.nodes.NestedSDFG, dace.nodes.MapEntry, dace.nodes.MapExit]: - if not symbol_mapping: - symbol_mapping = {sym: sym for sym in sdfg.free_symbols} - - nsdfg_node = state.add_nested_sdfg( - sdfg, - None, - set(inputs.keys()), - set(outputs.keys()), - symbol_mapping, - name=sdfg.name, - schedule=schedule, - location=location, - debuginfo=debuginfo, - ) - - map_entry, map_exit = state.add_map( - f"{sdfg.name}_map", map_ranges, schedule, unroll_map, debuginfo - ) - - if input_nodes is None: - input_nodes = { - memlet.data: state.add_access(memlet.data, debuginfo=debuginfo) - for name, memlet in inputs.items() - } - if output_nodes is None: - output_nodes = { - memlet.data: state.add_access(memlet.data, debuginfo=debuginfo) - for name, memlet in outputs.items() - } - if not inputs: - state.add_edge(map_entry, None, nsdfg_node, None, dace.Memlet()) - for name, memlet in inputs.items(): - state.add_memlet_path( - input_nodes[memlet.data], - map_entry, - nsdfg_node, - memlet=memlet, - src_conn=None, - dst_conn=name, - propagate=True, - ) - if not outputs: - state.add_edge(nsdfg_node, None, map_exit, None, dace.Memlet()) - for name, memlet in outputs.items(): - state.add_memlet_path( - nsdfg_node, - map_exit, - output_nodes[memlet.data], - memlet=memlet, - src_conn=name, - dst_conn=None, - propagate=True, - ) - - return nsdfg_node, map_entry, map_exit - - -def unique_name(prefix): - unique_id = getattr(unique_name, "_unique_id", 0) # static variable - setattr(unique_name, "_unique_id", unique_id + 1) # noqa: B010 [set-attr-with-constant] - - return f"{prefix}_{unique_id}" - - -def unique_var_name(): - return unique_name("_var") - - -def new_array_symbols(name: str, ndim: int) -> tuple[list[dace.symbol], list[dace.symbol]]: - dtype = dace.dtype_to_typeclass(gtx_fbuiltins.IndexType) - shape = [dace.symbol(dace_utils.field_size_symbol_name(name, i), dtype) for i in range(ndim)] - strides = [ - dace.symbol(dace_utils.field_stride_symbol_name(name, i), dtype) for i in range(ndim) - ] - return shape, strides - - -def flatten_list(node_list: list[Any]) -> list[Any]: - return list( - itertools.chain.from_iterable( - [flatten_list(e) if isinstance(e, list) else [e] for e in node_list] - ) - ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py deleted file mode 100644 index 653ed4719d..0000000000 --- a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py +++ /dev/null @@ -1,150 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -import dataclasses -import functools -from typing import Callable, Optional, Sequence - -import dace -import factory - -from gt4py._core import definitions as core_defs -from gt4py.next import common, config -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import program_to_fencil -from gt4py.next.otf import languages, recipes, stages, step_types, workflow -from gt4py.next.otf.binding import interface -from gt4py.next.otf.languages import LanguageSettings -from gt4py.next.program_processors.runners.dace_common import workflow as dace_workflow -from gt4py.next.type_system import type_specifications as ts - -from . import build_sdfg_from_itir - - -@dataclasses.dataclass(frozen=True) -class DaCeTranslator( - workflow.ChainableWorkflowMixin[ - stages.CompilableProgram, stages.ProgramSource[languages.SDFG, languages.LanguageSettings] - ], - step_types.TranslationStep[languages.SDFG, languages.LanguageSettings], -): - auto_optimize: bool = False - device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - symbolic_domain_sizes: Optional[dict[str, str]] = None - temporary_extraction_heuristics: Optional[ - Callable[[itir.StencilClosure], Callable[[itir.Expr], bool]] - ] = None - use_field_canonical_representation: bool = False - - def _language_settings(self) -> languages.LanguageSettings: - return languages.LanguageSettings( - formatter_key="", formatter_style="", file_extension="sdfg" - ) - - def generate_sdfg( - self, - program: itir.FencilDefinition, - arg_types: Sequence[ts.TypeSpec], - offset_provider_type: common.OffsetProviderType, - column_axis: Optional[common.Dimension], - ) -> dace.SDFG: - on_gpu = ( - True - if self.device_type in [core_defs.DeviceType.CUDA, core_defs.DeviceType.ROCM] - else False - ) - - return build_sdfg_from_itir( - program, - arg_types, - offset_provider_type=offset_provider_type, - auto_optimize=self.auto_optimize, - on_gpu=on_gpu, - column_axis=column_axis, - symbolic_domain_sizes=self.symbolic_domain_sizes, - temporary_extraction_heuristics=self.temporary_extraction_heuristics, - load_sdfg_from_file=False, - save_sdfg=False, - use_field_canonical_representation=self.use_field_canonical_representation, - ) - - def __call__( - self, inp: stages.CompilableProgram - ) -> stages.ProgramSource[languages.SDFG, LanguageSettings]: - """Generate DaCe SDFG file from the ITIR definition.""" - program: itir.FencilDefinition | itir.Program = inp.data - - if isinstance(program, itir.Program): - program = program_to_fencil.program_to_fencil(program) - - sdfg = self.generate_sdfg( - program, - inp.args.args, - common.offset_provider_to_type(inp.args.offset_provider), - inp.args.column_axis, - ) - - param_types = tuple( - interface.Parameter(param, arg) for param, arg in zip(sdfg.arg_names, inp.args.args) - ) - - module: stages.ProgramSource[languages.SDFG, languages.LanguageSettings] = ( - stages.ProgramSource( - entry_point=interface.Function(program.id, param_types), - source_code=sdfg.to_json(), - library_deps=tuple(), - language=languages.SDFG, - language_settings=self._language_settings(), - implicit_domain=inp.data.implicit_domain, - ) - ) - return module - - -class DaCeTranslationStepFactory(factory.Factory): - class Meta: - model = DaCeTranslator - - -def _no_bindings(inp: stages.ProgramSource) -> stages.CompilableSource: - return stages.CompilableSource(program_source=inp, binding_source=None) - - -class DaCeWorkflowFactory(factory.Factory): - class Meta: - model = recipes.OTFCompileWorkflow - - class Params: - device_type: core_defs.DeviceType = core_defs.DeviceType.CPU - cmake_build_type: config.CMakeBuildType = factory.LazyFunction( - lambda: config.CMAKE_BUILD_TYPE - ) - use_field_canonical_representation: bool = False - - translation = factory.SubFactory( - DaCeTranslationStepFactory, - device_type=factory.SelfAttribute("..device_type"), - use_field_canonical_representation=factory.SelfAttribute( - "..use_field_canonical_representation" - ), - ) - bindings = _no_bindings - compilation = factory.SubFactory( - dace_workflow.DaCeCompilationStepFactory, - cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), - cmake_build_type=factory.SelfAttribute("..cmake_build_type"), - ) - decoration = factory.LazyAttribute( - lambda o: functools.partial( - dace_workflow.convert_args, - device=o.device_type, - use_field_canonical_representation=o.use_field_canonical_representation, - ) - ) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 1f3778f227..c0a9be9168 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -7,11 +7,14 @@ # SPDX-License-Identifier: BSD-3-Clause import functools +import pathlib +import tempfile import warnings from typing import Any, Optional import diskcache import factory +import filelock import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators @@ -122,7 +125,7 @@ def fingerprint_compilable_program(inp: stages.CompilableProgram) -> str: Generates a unique hash string for a stencil source program representing the program, sorted offset_provider, and column_axis. """ - program: itir.FencilDefinition | itir.Program = inp.data + program: itir.Program = inp.data offset_provider: common.OffsetProvider = inp.args.offset_provider column_axis: Optional[common.Dimension] = inp.args.column_axis @@ -139,13 +142,34 @@ def fingerprint_compilable_program(inp: stages.CompilableProgram) -> str: class FileCache(diskcache.Cache): """ - This class extends `diskcache.Cache` to ensure the cache is closed upon deletion, - i.e. it ensures that any resources associated with the cache are properly - released when the instance is garbage collected. + This class extends `diskcache.Cache` to ensure the cache is properly + - opened when accessed by multiple processes using a file lock. This guards the creating of the + cache object, which has been reported to cause `sqlite3.OperationalError: database is locked` + errors and slow startup times when multiple processes access the cache concurrently. While this + issue occurred frequently and was observed to be fixed on distributed file systems, the lock + does not guarantee correct behavior in particular for accesses to the cache (beyond opening) + since the underlying SQLite database is unreliable when stored on an NFS based file system. + It does however ensure correctness of concurrent cache accesses on a local file system. See + #1745 for more details. + - closed upon deletion, i.e. it ensures that any resources associated with the cache are + properly released when the instance is garbage collected. """ + def __init__(self, directory: Optional[str | pathlib.Path] = None, **settings: Any) -> None: + if directory: + lock_dir = pathlib.Path(directory).parent + else: + lock_dir = pathlib.Path(tempfile.gettempdir()) + + lock = filelock.FileLock(lock_dir / "file_cache.lock") + with lock: + super().__init__(directory=directory, **settings) + + self._init_complete = True + def __del__(self) -> None: - self.close() + if getattr(self, "_init_complete", False): # skip if `__init__` didn't finished + self.close() class GTFNCompileWorkflowFactory(factory.Factory): diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index 1dd568b95a..32c3f7a360 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -46,7 +46,6 @@ class EmbeddedDSL(codegen.TemplatedGenerator): AxisLiteral = as_fmt("{value}") FunCall = as_fmt("{fun}({','.join(args)})") Lambda = as_mako("(lambda ${','.join(params)}: ${expr})") - StencilClosure = as_mako("closure(${domain}, ${stencil}, ${output}, [${','.join(inputs)}])") FunctionDefinition = as_mako( """ @fundef @@ -91,11 +90,11 @@ def visit_Temporary(self, node: itir.Temporary, **kwargs: Any) -> str: def fencil_generator( - ir: itir.Program | itir.FencilDefinition, + ir: itir.Program, debug: bool, use_embedded: bool, offset_provider: common.OffsetProvider, - transforms: itir_transforms.ITIRTransform, + transforms: itir_transforms.GTIRTransform, ) -> stages.CompiledProgram: """ Generate a directly executable fencil from an ITIR node. @@ -198,7 +197,7 @@ class Roundtrip(workflow.Workflow[stages.CompilableProgram, stages.CompiledProgr debug: Optional[bool] = None use_embedded: bool = True dispatch_backend: Optional[next_backend.Backend] = None - transforms: itir_transforms.ITIRTransform = itir_transforms.apply_common_transforms # type: ignore[assignment] # TODO(havogt): cleanup interface of `apply_common_transforms` + transforms: itir_transforms.GTIRTransform = itir_transforms.apply_common_transforms # type: ignore[assignment] # TODO(havogt): cleanup interface of `apply_common_transforms` def __call__(self, inp: stages.CompilableProgram) -> stages.CompiledProgram: debug = config.DEBUG if self.debug is None else self.debug @@ -266,10 +265,10 @@ def decorated_fencil( gtir = next_backend.Backend( name="roundtrip_gtir", - executor=Roundtrip(transforms=itir_transforms.apply_fieldview_transforms), # type: ignore[arg-type] # on purpose doesn't support `FencilDefintion` will resolve itself later... + executor=Roundtrip(transforms=itir_transforms.apply_fieldview_transforms), # type: ignore[arg-type] # don't understand why mypy complains allocator=next_allocators.StandardCPUFieldBufferAllocator(), transforms=next_backend.Transforms( - past_to_itir=past_to_itir.past_to_itir_factory(to_gtir=True), + past_to_itir=past_to_itir.past_to_gtir_factory(), foast_to_itir=foast_to_gtir.adapted_foast_to_gtir_factory(cached=True), field_view_op_to_prog=foast_to_past.operator_to_program_factory( foast_to_itir_step=foast_to_gtir.adapted_foast_to_gtir_factory() diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index c4d07d7337..4609184547 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -421,7 +421,7 @@ def stencil(field_a: gtscript.Field[np.float_], field_b: gtscript.Field[np.int_] @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_mask_with_offset_written_in_conditional(backend): - @gtscript.stencil(backend, externals={"mord": 5}) + @gtscript.stencil(backend) def stencil(outp: gtscript.Field[np.float_]): with computation(PARALLEL), interval(...): cond = True @@ -582,13 +582,17 @@ def test_K_offset_write(backend): # Cuda generates bad code for the K offset if backend == "cuda": pytest.skip("cuda K-offset write generates bad code") - if backend in ["gt:gpu", "dace:gpu"]: + if backend in ["dace:gpu"]: import cupy as cp if cp.cuda.runtime.runtimeGetVersion() < 12000: pytest.skip( f"{backend} backend with CUDA 11 and/or GCC 10.3 is not capable of K offset write, update CUDA/GCC if possible" ) + if backend in ["gt:gpu"]: + pytest.skip( + f"{backend} backend is not capable of K offset write, bug remains unsolved: https://github.com/GridTools/gt4py/issues/1754" + ) arraylib = get_array_library(backend) array_shape = (1, 1, 4) @@ -660,13 +664,17 @@ def backward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): def test_K_offset_write_conditional(backend): if backend == "cuda": pytest.skip("Cuda backend is not capable of K offset write") - if backend in ["gt:gpu", "dace:gpu"]: + if backend in ["dace:gpu"]: import cupy as cp if cp.cuda.runtime.runtimeGetVersion() < 12000: pytest.skip( f"{backend} backend with CUDA 11 and/or GCC 10.3 is not capable of K offset write, update CUDA/GCC if possible" ) + if backend in ["gt:gpu"]: + pytest.skip( + f"{backend} backend is not capable of K offset write, bug remains unsolved: https://github.com/GridTools/gt4py/issues/1754" + ) arraylib = get_array_library(backend) array_shape = (1, 1, 4) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 349d3e9f70..bed6e89a52 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -11,11 +11,10 @@ import dataclasses import enum import importlib -from typing import Final, Optional, Protocol import pytest -from gt4py.next import allocators as next_allocators, backend as next_backend +from gt4py.next import allocators as next_allocators # Skip definitions @@ -67,10 +66,10 @@ class EmbeddedIds(_PythonObjectIdMixin, str, enum.Enum): class OptionalProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): - DACE_CPU = "gt4py.next.program_processors.runners.dace.itir_cpu" - DACE_GPU = "gt4py.next.program_processors.runners.dace.itir_gpu" - GTIR_DACE_CPU = "gt4py.next.program_processors.runners.dace.gtir_cpu" - GTIR_DACE_GPU = "gt4py.next.program_processors.runners.dace.gtir_gpu" + DACE_CPU = "gt4py.next.program_processors.runners.dace.run_dace_cpu" + DACE_GPU = "gt4py.next.program_processors.runners.dace.run_dace_gpu" + DACE_CPU_NO_OPT = "gt4py.next.program_processors.runners.dace.run_dace_cpu_noopt" + DACE_GPU_NO_OPT = "gt4py.next.program_processors.runners.dace.run_dace_gpu_noopt" class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): @@ -86,8 +85,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): # to avoid needing to mark all tests. ALL = "all" REQUIRES_ATLAS = "requires_atlas" -# TODO(havogt): Remove, skipped during refactoring to GTIR -STARTS_FROM_GTIR_PROGRAM = "starts_from_gtir_program" USES_APPLIED_SHIFTS = "uses_applied_shifts" USES_CONSTANT_FIELDS = "uses_constant_fields" USES_DYNAMIC_OFFSETS = "uses_dynamic_offsets" @@ -95,10 +92,8 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_IF_STMTS = "uses_if_stmts" USES_IR_IF_STMTS = "uses_ir_if_stmts" USES_INDEX_FIELDS = "uses_index_fields" -USES_LIFT_EXPRESSIONS = "uses_lift_expressions" USES_NEGATIVE_MODULO = "uses_negative_modulo" USES_ORIGIN = "uses_origin" -USES_REDUCTION_OVER_LIFT_EXPRESSIONS = "uses_reduction_over_lift_expressions" USES_SCAN = "uses_scan" USES_SCAN_IN_FIELD_OPERATOR = "uses_scan_in_field_operator" USES_SCAN_IN_STENCIL = "uses_scan_in_stencil" @@ -118,7 +113,6 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): USES_MESH_WITH_SKIP_VALUES = "uses_mesh_with_skip_values" USES_SCALAR_IN_DOMAIN_AND_FO = "uses_scalar_in_domain_and_fo" CHECKS_SPECIFIC_ERROR = "checks_specific_error" -USES_INDEX_BUILTIN = "uses_index_builtin" # Skip messages (available format keys: 'marker', 'backend') UNSUPPORTED_MESSAGE = "'{marker}' tests not supported by '{backend}' backend" @@ -136,24 +130,9 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): ] # Markers to skip because of missing features in the domain inference DOMAIN_INFERENCE_SKIP_LIST = [ - (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), ] -DACE_SKIP_TEST_LIST = COMMON_SKIP_TEST_LIST + [ - (USES_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), - (USES_IR_IF_STMTS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_SCALAR_IN_DOMAIN_AND_FO, XFAIL, UNSUPPORTED_MESSAGE), - (USES_INDEX_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_LIFT_EXPRESSIONS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_ORIGIN, XFAIL, UNSUPPORTED_MESSAGE), - (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), - (USES_TUPLE_ARGS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), - (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), - (STARTS_FROM_GTIR_PROGRAM, SKIP, UNSUPPORTED_MESSAGE), -] -GTIR_DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ +DACE_SKIP_TEST_LIST = DOMAIN_INFERENCE_SKIP_LIST + [ (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_SCAN, XFAIL, UNSUPPORTED_MESSAGE), (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), @@ -191,8 +170,8 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): EmbeddedIds.CUPY_EXECUTION: EMBEDDED_SKIP_LIST, OptionalProgramBackendId.DACE_CPU: DACE_SKIP_TEST_LIST, OptionalProgramBackendId.DACE_GPU: DACE_SKIP_TEST_LIST, - OptionalProgramBackendId.GTIR_DACE_CPU: GTIR_DACE_SKIP_TEST_LIST, - OptionalProgramBackendId.GTIR_DACE_GPU: GTIR_DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.DACE_CPU_NO_OPT: DACE_SKIP_TEST_LIST, + OptionalProgramBackendId.DACE_GPU_NO_OPT: DACE_SKIP_TEST_LIST, ProgramBackendId.GTFN_CPU: GTFN_SKIP_TEST_LIST + [(USES_SCAN_NESTED, XFAIL, UNSUPPORTED_MESSAGE)], ProgramBackendId.GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py index 9fb7850666..759cd1cf1f 100644 --- a/tests/next_tests/integration_tests/cases.py +++ b/tests/next_tests/integration_tests/cases.py @@ -499,13 +499,21 @@ def unstructured_case( Vertex: mesh_descriptor.num_vertices, Edge: mesh_descriptor.num_edges, Cell: mesh_descriptor.num_cells, - KDim: 10, }, grid_type=common.GridType.UNSTRUCTURED, allocator=exec_alloc_descriptor.allocator, ) +@pytest.fixture +def unstructured_case_3d(unstructured_case): + return dataclasses.replace( + unstructured_case, + default_sizes={**unstructured_case.default_sizes, KDim: 10}, + offset_provider={**unstructured_case.offset_provider, "KOff": KDim}, + ) + + def _allocate_from_type( case: Case, arg_type: ts.TypeSpec, diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py index f5646c71e4..08904c06f3 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py @@ -6,14 +6,11 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from types import ModuleType -from typing import Optional - import numpy as np import pytest import gt4py.next as gtx -from gt4py.next import backend as next_backend, common +from gt4py.next import allocators as gtx_allocators, common as gtx_common from next_tests.integration_tests import cases from next_tests.integration_tests.cases import cartesian_case, unstructured_case @@ -34,24 +31,22 @@ try: import dace - from gt4py.next.program_processors.runners.dace import ( - itir_cpu as run_dace_cpu, - itir_gpu as run_dace_gpu, - ) except ImportError: dace: Optional[ModuleType] = None # type:ignore[no-redef] - run_dace_cpu: Optional[next_backend.Backend] = None - run_dace_gpu: Optional[next_backend.Backend] = None pytestmark = pytest.mark.requires_dace def test_sdfgConvertible_laplap(cartesian_case): - # TODO(kotsaloscv): Temporary solution until the `requires_dace` marker is fully functional - if cartesian_case.backend not in [run_dace_cpu, run_dace_gpu]: + if not cartesian_case.backend or "dace" not in cartesian_case.backend.name: pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs") - if cartesian_case.backend == run_dace_gpu: + # TODO(ricoh): enable test after adding GTIR support + pytest.skip("DaCe SDFGConvertible interface does not support GTIR program.") + + allocator, backend = unstructured_case.allocator, unstructured_case.backend + + if gtx_allocators.is_field_allocator_factory_for(allocator, gtx_allocators.CUPY_DEVICE): import cupy as xp else: import numpy as xp @@ -64,13 +59,13 @@ def test_sdfgConvertible_laplap(cartesian_case): def sdfg(): tmp_field = xp.empty_like(out_field) lap_program.with_grid_type(cartesian_case.grid_type).with_backend( - cartesian_case.backend - ).with_connectivities(common.offset_provider_to_type(cartesian_case.offset_provider))( + backend + ).with_connectivities(gtx_common.offset_provider_to_type(cartesian_case.offset_provider))( in_field, tmp_field ) lap_program.with_grid_type(cartesian_case.grid_type).with_backend( - cartesian_case.backend - ).with_connectivities(common.offset_provider_to_type(cartesian_case.offset_provider))( + backend + ).with_connectivities(gtx_common.offset_provider_to_type(cartesian_case.offset_provider))( tmp_field, out_field ) @@ -94,13 +89,15 @@ def testee(a: gtx.Field[gtx.Dims[Vertex], gtx.float64], b: gtx.Field[gtx.Dims[Ed @pytest.mark.uses_unstructured_shift def test_sdfgConvertible_connectivities(unstructured_case): - # TODO(kotsaloscv): Temporary solution until the `requires_dace` marker is fully functional - if unstructured_case.backend not in [run_dace_cpu, run_dace_gpu]: + if not unstructured_case.backend or "dace" not in unstructured_case.backend.name: pytest.skip("DaCe-related test: Test SDFGConvertible interface for GT4Py programs") + # TODO(ricoh): enable test after adding GTIR support + pytest.skip("DaCe SDFGConvertible interface does not support GTIR program.") + allocator, backend = unstructured_case.allocator, unstructured_case.backend - if backend == run_dace_gpu: + if gtx_allocators.is_field_allocator_factory_for(allocator, gtx_allocators.CUPY_DEVICE): import cupy as xp dace_storage_type = dace.StorageType.GPU_Global diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py index 794dd06709..1147f4bc3e 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/ffront_test_utils.py @@ -66,11 +66,11 @@ def __gt_allocator__( marks=(pytest.mark.requires_dace, pytest.mark.requires_gpu), ), pytest.param( - next_tests.definitions.OptionalProgramBackendId.GTIR_DACE_CPU, + next_tests.definitions.OptionalProgramBackendId.DACE_CPU_NO_OPT, marks=pytest.mark.requires_dace, ), pytest.param( - next_tests.definitions.OptionalProgramBackendId.GTIR_DACE_GPU, + next_tests.definitions.OptionalProgramBackendId.DACE_GPU_NO_OPT, marks=(pytest.mark.requires_dace, pytest.mark.requires_gpu), ), ], diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py index 47419c278b..9e80dba53b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_decorator.py @@ -21,7 +21,7 @@ ) -def test_program_itir_regression(cartesian_case): +def test_program_gtir_regression(cartesian_case): @gtx.field_operator(backend=None) def testee_op(a: cases.IField) -> cases.IField: return a @@ -30,10 +30,8 @@ def testee_op(a: cases.IField) -> cases.IField: def testee(a: cases.IField, out: cases.IField): testee_op(a, out=out) - assert isinstance(testee.itir, (itir.Program, itir.FencilDefinition)) - assert isinstance( - testee.with_backend(cartesian_case.backend).itir, (itir.Program, itir.FencilDefinition) - ) + assert isinstance(testee.gtir, itir.Program) + assert isinstance(testee.with_backend(cartesian_case.backend).gtir, itir.Program) def test_frozen(cartesian_case): diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py index 1a51e3667d..9de4449ac2 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py @@ -41,6 +41,7 @@ Edge, cartesian_case, unstructured_case, + unstructured_case_3d, ) from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( exec_alloc_descriptor, @@ -93,6 +94,20 @@ def testee(a: cases.VField) -> cases.EField: ) +def test_horizontal_only_with_3d_mesh(unstructured_case_3d): + # test field operator operating only on horizontal fields while using an offset provider + # including a vertical dimension. + @gtx.field_operator + def testee(a: cases.VField) -> cases.VField: + return a + + cases.verify_with_default_data( + unstructured_case_3d, + testee, + ref=lambda a: a, + ) + + @pytest.mark.uses_unstructured_shift def test_composed_unstructured_shift(unstructured_case): @gtx.field_operator @@ -276,7 +291,6 @@ def testee(a: tuple[cases.VField, cases.EField]) -> cases.VField: ) -@pytest.mark.uses_index_fields @pytest.mark.uses_cartesian_shift def test_scalar_arg_with_field(cartesian_case): @gtx.field_operator @@ -423,6 +437,22 @@ def testee(a: cases.IFloatField) -> gtx.Field[[IDim], int64]: ) +def test_astype_int_local_field(unstructured_case): + @gtx.field_operator + def testee(a: gtx.Field[[Vertex], np.float64]) -> gtx.Field[[Edge], int64]: + tmp = astype(a(E2V), int64) + return neighbor_sum(tmp, axis=E2VDim) + + e2v_table = unstructured_case.offset_provider["E2V"].ndarray + + cases.verify_with_default_data( + unstructured_case, + testee, + ref=lambda a: np.sum(a.astype(int64)[e2v_table], axis=1, initial=0), + comparison=lambda a, b: np.all(a == b), + ) + + @pytest.mark.uses_tuple_returns def test_astype_on_tuples(cartesian_case): @gtx.field_operator @@ -571,7 +601,6 @@ def combine(a: cases.IField, b: cases.IField) -> cases.IField: @pytest.mark.uses_unstructured_shift -@pytest.mark.uses_reduction_over_lift_expressions def test_nested_reduction(unstructured_case): @gtx.field_operator def testee(a: cases.VField) -> cases.VField: @@ -691,7 +720,6 @@ def simple_scan_operator(carry: float) -> float: @pytest.mark.uses_scan -@pytest.mark.uses_lift_expressions @pytest.mark.uses_scan_nested def test_solve_triag(cartesian_case): @gtx.scan_operator(axis=KDim, forward=True, init=(0.0, 0.0)) @@ -773,7 +801,6 @@ def testee( @pytest.mark.uses_constant_fields @pytest.mark.uses_unstructured_shift -@pytest.mark.uses_reduction_over_lift_expressions def test_ternary_builtin_neighbor_sum(unstructured_case): @gtx.field_operator def testee(a: cases.EField, b: cases.EField) -> cases.VField: diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 7648d34db7..ab1c625fef 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -29,6 +29,7 @@ Vertex, cartesian_case, unstructured_case, + unstructured_case_3d, ) from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( exec_alloc_descriptor, @@ -105,10 +106,10 @@ def reduction_ke_field( @pytest.mark.parametrize( "fop", [reduction_e_field, reduction_ek_field, reduction_ke_field], ids=lambda fop: fop.__name__ ) -def test_neighbor_sum(unstructured_case, fop): - v2e_table = unstructured_case.offset_provider["V2E"].ndarray +def test_neighbor_sum(unstructured_case_3d, fop): + v2e_table = unstructured_case_3d.offset_provider["V2E"].ndarray - edge_f = cases.allocate(unstructured_case, fop, "edge_f")() + edge_f = cases.allocate(unstructured_case_3d, fop, "edge_f")() local_dim_idx = edge_f.domain.dims.index(Edge) + 1 adv_indexing = tuple( @@ -131,10 +132,10 @@ def test_neighbor_sum(unstructured_case, fop): where=broadcasted_table != common._DEFAULT_SKIP_VALUE, ) cases.verify( - unstructured_case, + unstructured_case_3d, fop, edge_f, - out=cases.allocate(unstructured_case, fop, cases.RETURN)(), + out=cases.allocate(unstructured_case_3d, fop, cases.RETURN)(), ref=ref, ) @@ -463,11 +464,13 @@ def conditional_program( ) -def test_promotion(unstructured_case): +def test_promotion(unstructured_case_3d): @gtx.field_operator def promotion( inp1: gtx.Field[[Edge, KDim], float64], inp2: gtx.Field[[KDim], float64] ) -> gtx.Field[[Edge, KDim], float64]: return inp1 / inp2 - cases.verify_with_default_data(unstructured_case, promotion, ref=lambda inp1, inp2: inp1 / inp2) + cases.verify_with_default_data( + unstructured_case_3d, promotion, ref=lambda inp1, inp2: inp1 / inp2 + ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py index 66c56c4827..7d2eec772c 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py @@ -107,12 +107,12 @@ def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, mesh def test_temporary_symbols(testee, mesh_descriptor): - itir_with_tmp = apply_common_transforms( - testee.itir, + gtir_with_tmp = apply_common_transforms( + testee.gtir, extract_temporaries=True, offset_provider=mesh_descriptor.offset_provider, ) params = ["num_vertices", "num_edges", "num_cells"] for param in params: - assert any([param == str(p) for p in itir_with_tmp.params]) + assert any([param == str(p) for p in gtir_with_tmp.params]) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index 5e3a2fcd14..c0a4cd166d 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -18,6 +18,7 @@ from gt4py.next.iterator import builtins as it_builtins from gt4py.next.iterator.builtins import ( and_, + as_fieldop, bool, can_deref, cartesian_domain, @@ -45,9 +46,8 @@ plus, shift, xor_, - as_fieldop, ) -from gt4py.next.iterator.runtime import set_at, closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import fendef, fundef, offset, set_at from gt4py.next.program_processors.runners.gtfn import run_gtfn from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py index 2ebcd0c033..fedfd83fd2 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_cartesian_offset_provider.py @@ -10,7 +10,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from gt4py.next.program_processors.runners import double_roundtrip, roundtrip @@ -27,16 +27,14 @@ def foo(inp): @fendef(offset_provider={"I": I_loc, "J": J_loc}) def fencil(output, input): - closure( - cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), foo, output, [input] - ) + domain = cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)) + set_at(as_fieldop(foo, domain)(input), domain, output) @fendef(offset_provider={"I": J_loc, "J": I_loc}) def fencil_swapped(output, input): - closure( - cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)), foo, output, [input] - ) + domain = cartesian_domain(named_range(I_loc, 0, 1), named_range(J_loc, 0, 1)) + set_at(as_fieldop(foo, domain)(input), domain, output) def test_cartesian_offset_provider(): diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py index 551c567e61..eae66d425b 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_conditional.py @@ -11,7 +11,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef +from gt4py.next.iterator.runtime import set_at, fendef, fundef from next_tests.unit_tests.conftest import program_processor, run_processor diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py index f6fd0a48d0..c79f8dbb6b 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_program.py @@ -42,7 +42,6 @@ def copy_program(inp, out, size): ) -@pytest.mark.starts_from_gtir_program def test_prog(program_processor): program_processor, validate = program_processor @@ -64,8 +63,7 @@ def index_program_simple(out, size): ) -@pytest.mark.starts_from_gtir_program -@pytest.mark.uses_index_builtin +@pytest.mark.uses_index_fields def test_index_builtin(program_processor): program_processor, validate = program_processor @@ -88,8 +86,7 @@ def index_program_shift(out, size): ) -@pytest.mark.starts_from_gtir_program -@pytest.mark.uses_index_builtin +@pytest.mark.uses_index_fields def test_index_builtin_shift(program_processor): program_processor, validate = program_processor diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py index 7bde55bfd2..68e5f9d532 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_strided_offset_provider.py @@ -10,8 +10,8 @@ import pytest import gt4py.next as gtx -from gt4py.next.iterator.builtins import deref, named_range, shift, unstructured_domain -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.builtins import deref, named_range, shift, unstructured_domain, as_fieldop +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from next_tests.unit_tests.conftest import program_processor, run_processor from gt4py.next.iterator.embedded import StridedConnectivityField @@ -36,7 +36,8 @@ def foo(inp): @fendef(offset_provider={"O": LocA2LocAB_offset_provider}) def fencil(size, out, inp): - closure(unstructured_domain(named_range(LocA, 0, size)), foo, out, [inp]) + domain = unstructured_domain(named_range(LocA, 0, size)) + set_at(as_fieldop(foo, domain)(inp), domain, out) @pytest.mark.uses_strided_neighbor_offset diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py index 5f1c70a6b3..fe89fe7c9d 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py @@ -12,7 +12,7 @@ import gt4py.next as gtx from gt4py.next.iterator import transforms from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from next_tests.integration_tests.cases import IDim, JDim, KDim from next_tests.unit_tests.conftest import program_processor, run_processor @@ -94,12 +94,8 @@ def test_shifted_arg_to_lift(program_processor): @fendef def fen_direct_deref(i_size, j_size, out, inp): - closure( - cartesian_domain(named_range(IDim, 0, i_size), named_range(JDim, 0, j_size)), - deref, - out, - [inp], - ) + domain = cartesian_domain(named_range(IDim, 0, i_size), named_range(JDim, 0, j_size)) + set_at(as_fieldop(deref, domain)(inp), domain, out) def test_direct_deref(program_processor): diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py index 2d84439c93..39d0bd69c3 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_tuple.py @@ -11,7 +11,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef +from gt4py.next.iterator.runtime import set_at, fendef, fundef from next_tests.unit_tests.conftest import program_processor, run_processor @@ -114,16 +114,10 @@ def test_tuple_of_field_output_constructed_inside(program_processor, stencil): @fendef def fencil(size0, size1, size2, inp1, inp2, out1, out2): - closure( - cartesian_domain( - named_range(IDim, 0, size0), - named_range(JDim, 0, size1), - named_range(KDim, 0, size2), - ), - stencil, - make_tuple(out1, out2), - [inp1, inp2], + domain = cartesian_domain( + named_range(IDim, 0, size0), named_range(JDim, 0, size1), named_range(KDim, 0, size2) ) + set_at(as_fieldop(stencil, domain)(inp1, inp2), domain, make_tuple(out1, out2)) shape = [5, 7, 9] rng = np.random.default_rng() @@ -159,15 +153,13 @@ def stencil(inp1, inp2, inp3): @fendef def fencil(size0, size1, size2, inp1, inp2, inp3, out1, out2, out3): - closure( - cartesian_domain( - named_range(IDim, 0, size0), - named_range(JDim, 0, size1), - named_range(KDim, 0, size2), - ), - stencil, + domain = cartesian_domain( + named_range(IDim, 0, size0), named_range(JDim, 0, size1), named_range(KDim, 0, size2) + ) + set_at( + as_fieldop(stencil, domain)(inp1, inp2, inp3), + domain, make_tuple(make_tuple(out1, out2), out3), - [inp1, inp2, inp3], ) shape = [5, 7, 9] diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index 3ce9d6b470..d0a1601816 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -10,8 +10,15 @@ import pytest import gt4py.next as gtx -from gt4py.next.iterator.builtins import cartesian_domain, deref, lift, named_range, shift -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.builtins import ( + cartesian_domain, + deref, + lift, + named_range, + shift, + as_fieldop, +) +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from gt4py.next.program_processors.runners import gtfn from next_tests.unit_tests.conftest import program_processor, run_processor @@ -85,14 +92,10 @@ def test_anton_toy(stencil, program_processor): @fendef(offset_provider={"i": IDim, "j": JDim}) def fencil(x, y, z, out, inp): - closure( - cartesian_domain( - named_range(IDim, 0, x), named_range(JDim, 0, y), named_range(KDim, 0, z) - ), - stencil, - out, - [inp], + domain = cartesian_domain( + named_range(IDim, 0, x), named_range(JDim, 0, y), named_range(KDim, 0, z) ) + set_at(as_fieldop(stencil, domain)(inp), domain, out) shape = [5, 7, 9] rng = np.random.default_rng() diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index 4487681abf..22b4d8b3c5 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -28,8 +28,9 @@ reduce, tuple_get, unstructured_domain, + as_fieldop, ) -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from next_tests.integration_tests.multi_feature_tests.fvm_nabla_setup import ( assert_close, @@ -55,7 +56,8 @@ def compute_zavgS(pp, S_M): @fendef def compute_zavgS_fencil(n_edges, out, pp, S_M): - closure(unstructured_domain(named_range(Edge, 0, n_edges)), compute_zavgS, out, [pp, S_M]) + domain = unstructured_domain(named_range(Edge, 0, n_edges)) + set_at(as_fieldop(compute_zavgS, domain)(pp, S_M), domain, out) @fundef @@ -100,12 +102,8 @@ def compute_pnabla2(pp, S_M, sign, vol): @fendef def nabla(n_nodes, out, pp, S_MXX, S_MYY, sign, vol): - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - pnabla, - out, - [pp, S_MXX, S_MYY, sign, vol], - ) + domain = unstructured_domain(named_range(Vertex, 0, n_nodes)) + set_at(as_fieldop(pnabla, domain)(pp, S_MXX, S_MYY, sign, vol), domain, out) @pytest.mark.requires_atlas @@ -145,7 +143,8 @@ def test_compute_zavgS(program_processor): @fendef def compute_zavgS2_fencil(n_edges, out, pp, S_M): - closure(unstructured_domain(named_range(Edge, 0, n_edges)), compute_zavgS2, out, [pp, S_M]) + domain = unstructured_domain(named_range(Edge, 0, n_edges)) + set_at(as_fieldop(compute_zavgS2, domain)(pp, S_M), domain, out) @pytest.mark.requires_atlas @@ -212,12 +211,8 @@ def test_nabla(program_processor): @fendef def nabla2(n_nodes, out, pp, S, sign, vol): - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - compute_pnabla2, - out, - [pp, S, sign, vol], - ) + domain = unstructured_domain(named_range(Vertex, 0, n_nodes)) + set_at(as_fieldop(compute_pnabla2, domain)(pp, S, sign, vol), domain, out) @pytest.mark.requires_atlas @@ -276,17 +271,16 @@ def compute_pnabla_sign(pp, S_M, vol, node_index, is_pole_edge): @fendef def nabla_sign(n_nodes, out_MXX, out_MYY, pp, S_MXX, S_MYY, vol, node_index, is_pole_edge): # TODO replace by single stencil which returns tuple - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - compute_pnabla_sign, + domain = unstructured_domain(named_range(Vertex, 0, n_nodes)) + set_at( + as_fieldop(compute_pnabla_sign, domain)(pp, S_MXX, vol, node_index, is_pole_edge), + domain, out_MXX, - [pp, S_MXX, vol, node_index, is_pole_edge], ) - closure( - unstructured_domain(named_range(Vertex, 0, n_nodes)), - compute_pnabla_sign, + set_at( + as_fieldop(compute_pnabla_sign, domain)(pp, S_MYY, vol, node_index, is_pole_edge), + domain, out_MYY, - [pp, S_MYY, vol, node_index, is_pole_edge], ) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py index 45793b1d3e..e44e92013f 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py @@ -11,7 +11,7 @@ import gt4py.next as gtx from gt4py.next.iterator.builtins import * -from gt4py.next.iterator.runtime import closure, fendef, fundef, offset +from gt4py.next.iterator.runtime import set_at, fendef, fundef, offset from gt4py.next.program_processors.runners import gtfn from next_tests.integration_tests.cases import IDim, JDim @@ -57,12 +57,8 @@ def hdiff_sten(inp, coeff): @fendef(offset_provider={"I": IDim, "J": JDim}) def hdiff(inp, coeff, out, x, y): - closure( - cartesian_domain(named_range(IDim, 0, x), named_range(JDim, 0, y)), - hdiff_sten, - out, - [inp, coeff], - ) + domain = cartesian_domain(named_range(IDim, 0, x), named_range(JDim, 0, y)) + set_at(as_fieldop(hdiff_sten, domain)(inp, coeff), domain, out) @pytest.mark.uses_origin diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index f1269f1ed8..03662f8dcc 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -53,22 +53,13 @@ def _program_processor(request) -> tuple[ProgramProcessor, bool]: (None, True), (next_tests.definitions.ProgramBackendId.ROUNDTRIP, True), (next_tests.definitions.ProgramBackendId.ROUNDTRIP_WITH_TEMPORARIES, True), + (next_tests.definitions.ProgramBackendId.GTIR_EMBEDDED, True), (next_tests.definitions.ProgramBackendId.DOUBLE_ROUNDTRIP, True), (next_tests.definitions.ProgramBackendId.GTFN_CPU, True), (next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, True), # pytest.param((definitions.ProgramBackendId.GTFN_GPU, True), marks=pytest.mark.requires_gpu), # TODO(havogt): update tests to use proper allocation - (next_tests.definitions.ProgramFormatterId.LISP_FORMATTER, False), (next_tests.definitions.ProgramFormatterId.ITIR_PRETTY_PRINTER, False), (next_tests.definitions.ProgramFormatterId.GTFN_CPP_FORMATTER, False), - pytest.param( - (next_tests.definitions.OptionalProgramBackendId.DACE_CPU, True), - marks=pytest.mark.requires_dace, - ), - # TODO(havogt): update tests to use proper allocation - # pytest.param( - # (next_tests.definitions.OptionalProgramBackendId.DACE_GPU, True), - # marks=(pytest.mark.requires_dace, pytest.mark.requires_gpu), - # ), ], ids=lambda p: p[0].short_id() if p[0] is not None else "None", ) diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py index 516890ea46..59a8dc961b 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_gtir.py @@ -283,9 +283,22 @@ def foo(a: gtx.Field[[TDim], float64]): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) + lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) reference = im.cast_as_fieldop("int32")("a") + assert lowered_inlined.expr == reference + + +def test_astype_local_field(): + def foo(a: gtx.Field[gtx.Dims[Vertex, V2EDim], float64]): + return astype(a, int32) + + parsed = FieldOperatorParser.apply_to_function(foo) + lowered = FieldOperatorLowering.apply(parsed) + + reference = im.op_as_fieldop(im.map_(im.lambda_("val")(im.call("cast_")("val", "int32"))))("a") + assert lowered.expr == reference @@ -295,10 +308,11 @@ def foo(a: float64): parsed = FieldOperatorParser.apply_to_function(foo) lowered = FieldOperatorLowering.apply(parsed) + lowered_inlined = inline_lambdas.InlineLambdas.apply(lowered) reference = im.call("cast_")("a", "int32") - assert lowered.expr == reference + assert lowered_inlined.expr == reference def test_astype_tuple(): diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py deleted file mode 100644 index c102df9d57..0000000000 --- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py +++ /dev/null @@ -1,598 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -# TODO(tehrengruber): The style of the tests in this file is not optimal as a single change in the -# lowering can (and often does) make all of them fail. Once we have embedded field view we want to -# switch to executing the different cases here; once with a regular backend (i.e. including -# parsing) and then with embedded field view (i.e. no parsing). If the results match the lowering -# should be correct. - -from __future__ import annotations - -from types import SimpleNamespace - -import pytest - -import gt4py.next as gtx -from gt4py.next import float32, float64, int32, int64, neighbor_sum -from gt4py.next.ffront import type_specifications as ts_ffront -from gt4py.next.ffront.ast_passes import single_static_assign as ssa -from gt4py.next.ffront.foast_to_itir import FieldOperatorLowering -from gt4py.next.ffront.func_to_foast import FieldOperatorParser -from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.ir_utils import ir_makers as im -from gt4py.next.type_system import type_specifications as ts, type_translation -from gt4py.next.iterator.type_system import type_specifications as it_ts - - -IDim = gtx.Dimension("IDim") -Edge = gtx.Dimension("Edge") -Vertex = gtx.Dimension("Vertex") -Cell = gtx.Dimension("Cell") -V2EDim = gtx.Dimension("V2E", gtx.DimensionKind.LOCAL) -V2E = gtx.FieldOffset("V2E", source=Edge, target=(Vertex, V2EDim)) -TDim = gtx.Dimension("TDim") # Meaningless dimension, used for tests. - - -def debug_itir(tree): - """Compare tree snippets while debugging.""" - from devtools import debug - - from gt4py.eve.codegen import format_python_source - from gt4py.next.program_processors import EmbeddedDSL - - debug(format_python_source(EmbeddedDSL.apply(tree))) - - -def test_copy(): - def copy_field(inp: gtx.Field[[TDim], float64]): - return inp - - parsed = FieldOperatorParser.apply_to_function(copy_field) - lowered = FieldOperatorLowering.apply(parsed) - - assert lowered.id == "copy_field" - assert lowered.expr == im.ref("inp") - - -def test_scalar_arg(): - def scalar_arg(bar: gtx.Field[[IDim], int64], alpha: int64) -> gtx.Field[[IDim], int64]: - return alpha * bar - - parsed = FieldOperatorParser.apply_to_function(scalar_arg) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("multiplies")( - "alpha", "bar" - ) # no difference to non-scalar arg - - assert lowered.expr == reference - - -def test_multicopy(): - def multicopy(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64]): - return inp1, inp2 - - parsed = FieldOperatorParser.apply_to_function(multicopy) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.make_tuple("inp1", "inp2") - - assert lowered.expr == reference - - -def test_arithmetic(): - def arithmetic(inp1: gtx.Field[[IDim], float64], inp2: gtx.Field[[IDim], float64]): - return inp1 + inp2 - - parsed = FieldOperatorParser.apply_to_function(arithmetic) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("plus")("inp1", "inp2") - - assert lowered.expr == reference - - -def test_shift(): - Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) - - def shift_by_one(inp: gtx.Field[[IDim], float64]): - return inp(Ioff[1]) - - parsed = FieldOperatorParser.apply_to_function(shift_by_one) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", 1)("it"))))("inp") - - assert lowered.expr == reference - - -def test_negative_shift(): - Ioff = gtx.FieldOffset("Ioff", source=IDim, target=(IDim,)) - - def shift_by_one(inp: gtx.Field[[IDim], float64]): - return inp(Ioff[-1]) - - parsed = FieldOperatorParser.apply_to_function(shift_by_one) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.lift(im.lambda_("it")(im.deref(im.shift("Ioff", -1)("it"))))("inp") - - assert lowered.expr == reference - - -def test_temp_assignment(): - def copy_field(inp: gtx.Field[[TDim], float64]): - tmp = inp - inp = tmp - tmp2 = inp - return tmp2 - - parsed = FieldOperatorParser.apply_to_function(copy_field) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.let(ssa.unique_name("tmp", 0), "inp")( - im.let( - ssa.unique_name("inp", 0), - ssa.unique_name("tmp", 0), - )( - im.let( - ssa.unique_name("tmp2", 0), - ssa.unique_name("inp", 0), - )(ssa.unique_name("tmp2", 0)) - ) - ) - - assert lowered.expr == reference - - -def test_unary_ops(): - def unary(inp: gtx.Field[[TDim], float64]): - tmp = +inp - tmp = -tmp - return tmp - - parsed = FieldOperatorParser.apply_to_function(unary) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.let( - ssa.unique_name("tmp", 0), - im.promote_to_lifted_stencil("plus")( - im.promote_to_const_iterator(im.literal("0", "float64")), "inp" - ), - )( - im.let( - ssa.unique_name("tmp", 1), - im.promote_to_lifted_stencil("minus")( - im.promote_to_const_iterator(im.literal("0", "float64")), ssa.unique_name("tmp", 0) - ), - )(ssa.unique_name("tmp", 1)) - ) - - assert lowered.expr == reference - - -@pytest.mark.parametrize("var, var_type", [("-1.0", "float64"), ("True", "bool")]) -def test_unary_op_type_conversion(var, var_type): - def unary_float(): - return float(-1) - - def unary_bool(): - return bool(-1) - - fun = unary_bool if var_type == "bool" else unary_float - parsed = FieldOperatorParser.apply_to_function(fun) - lowered = FieldOperatorLowering.apply(parsed) - reference = im.promote_to_const_iterator(im.literal(var, var_type)) - - assert lowered.expr == reference - - -def test_unpacking(): - """Unpacking assigns should get separated.""" - - def unpacking( - inp1: gtx.Field[[TDim], float64], inp2: gtx.Field[[TDim], float64] - ) -> gtx.Field[[TDim], float64]: - tmp1, tmp2 = inp1, inp2 # noqa - return tmp1 - - parsed = FieldOperatorParser.apply_to_function(unpacking) - lowered = FieldOperatorLowering.apply(parsed) - - tuple_expr = im.make_tuple("inp1", "inp2") - tuple_access_0 = im.tuple_get(0, "__tuple_tmp_0") - tuple_access_1 = im.tuple_get(1, "__tuple_tmp_0") - - reference = im.let("__tuple_tmp_0", tuple_expr)( - im.let( - ssa.unique_name("tmp1", 0), - tuple_access_0, - )( - im.let( - ssa.unique_name("tmp2", 0), - tuple_access_1, - )(ssa.unique_name("tmp1", 0)) - ) - ) - - assert lowered.expr == reference - - -def test_annotated_assignment(): - pytest.xfail("Annotated assignments are not properly supported at the moment.") - - def copy_field(inp: gtx.Field[[TDim], float64]): - tmp: gtx.Field[[TDim], float64] = inp - return tmp - - parsed = FieldOperatorParser.apply_to_function(copy_field) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.let(ssa.unique_name("tmp", 0), "inp")(ssa.unique_name("tmp", 0)) - - assert lowered.expr == reference - - -def test_call(): - # create something that appears to the lowering like a field operator. - # we could also create an actual field operator, but we want to avoid - # using such heavy constructs for testing the lowering. - field_type = type_translation.from_type_hint(gtx.Field[[TDim], float64]) - identity = SimpleNamespace( - __gt_type__=lambda: ts_ffront.FieldOperatorType( - definition=ts.FunctionType( - pos_only_args=[field_type], pos_or_kw_args={}, kw_only_args={}, returns=field_type - ) - ) - ) - - def call(inp: gtx.Field[[TDim], float64]) -> gtx.Field[[TDim], float64]: - return identity(inp) - - parsed = FieldOperatorParser.apply_to_function(call) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.call("identity")("inp") - - assert lowered.expr == reference - - -def test_temp_tuple(): - """Returning a temp tuple should work.""" - - def temp_tuple(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], int64]): - tmp = a, b - return tmp - - parsed = FieldOperatorParser.apply_to_function(temp_tuple) - lowered = FieldOperatorLowering.apply(parsed) - - tuple_expr = im.make_tuple("a", "b") - reference = im.let(ssa.unique_name("tmp", 0), tuple_expr)(ssa.unique_name("tmp", 0)) - - assert lowered.expr == reference - - -def test_unary_not(): - def unary_not(cond: gtx.Field[[TDim], "bool"]): - return not cond - - parsed = FieldOperatorParser.apply_to_function(unary_not) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("not_")("cond") - - assert lowered.expr == reference - - -def test_binary_plus(): - def plus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a + b - - parsed = FieldOperatorParser.apply_to_function(plus) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("plus")("a", "b") - - assert lowered.expr == reference - - -def test_add_scalar_literal_to_field(): - def scalar_plus_field(a: gtx.Field[[IDim], float64]) -> gtx.Field[[IDim], float64]: - return 2.0 + a - - parsed = FieldOperatorParser.apply_to_function(scalar_plus_field) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("plus")( - im.promote_to_const_iterator(im.literal("2.0", "float64")), "a" - ) - - assert lowered.expr == reference - - -def test_add_scalar_literals(): - def scalar_plus_scalar(a: gtx.Field[[IDim], "int32"]) -> gtx.Field[[IDim], "int32"]: - tmp = int32(1) + int32("1") - return a + tmp - - parsed = FieldOperatorParser.apply_to_function(scalar_plus_scalar) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.let( - ssa.unique_name("tmp", 0), - im.promote_to_lifted_stencil("plus")( - im.promote_to_const_iterator(im.literal("1", "int32")), - im.promote_to_const_iterator(im.literal("1", "int32")), - ), - )(im.promote_to_lifted_stencil("plus")("a", ssa.unique_name("tmp", 0))) - - assert lowered.expr == reference - - -def test_binary_mult(): - def mult(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a * b - - parsed = FieldOperatorParser.apply_to_function(mult) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("multiplies")("a", "b") - - assert lowered.expr == reference - - -def test_binary_minus(): - def minus(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a - b - - parsed = FieldOperatorParser.apply_to_function(minus) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("minus")("a", "b") - - assert lowered.expr == reference - - -def test_binary_div(): - def division(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a / b - - parsed = FieldOperatorParser.apply_to_function(division) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("divides")("a", "b") - - assert lowered.expr == reference - - -def test_binary_and(): - def bit_and(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): - return a & b - - parsed = FieldOperatorParser.apply_to_function(bit_and) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("and_")("a", "b") - - assert lowered.expr == reference - - -def test_scalar_and(): - def scalar_and(a: gtx.Field[[IDim], "bool"]) -> gtx.Field[[IDim], "bool"]: - return a & False - - parsed = FieldOperatorParser.apply_to_function(scalar_and) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("and_")( - "a", im.promote_to_const_iterator(im.literal("False", "bool")) - ) - - assert lowered.expr == reference - - -def test_binary_or(): - def bit_or(a: gtx.Field[[TDim], "bool"], b: gtx.Field[[TDim], "bool"]): - return a | b - - parsed = FieldOperatorParser.apply_to_function(bit_or) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("or_")("a", "b") - - assert lowered.expr == reference - - -def test_compare_scalars(): - def comp_scalars() -> bool: - return 3 > 4 - - parsed = FieldOperatorParser.apply_to_function(comp_scalars) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("greater")( - im.promote_to_const_iterator(im.literal("3", "int32")), - im.promote_to_const_iterator(im.literal("4", "int32")), - ) - - assert lowered.expr == reference - - -def test_compare_gt(): - def comp_gt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a > b - - parsed = FieldOperatorParser.apply_to_function(comp_gt) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("greater")("a", "b") - - assert lowered.expr == reference - - -def test_compare_lt(): - def comp_lt(a: gtx.Field[[TDim], float64], b: gtx.Field[[TDim], float64]): - return a < b - - parsed = FieldOperatorParser.apply_to_function(comp_lt) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("less")("a", "b") - - assert lowered.expr == reference - - -def test_compare_eq(): - def comp_eq(a: gtx.Field[[TDim], "int64"], b: gtx.Field[[TDim], "int64"]): - return a == b - - parsed = FieldOperatorParser.apply_to_function(comp_eq) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("eq")("a", "b") - - assert lowered.expr == reference - - -def test_compare_chain(): - def compare_chain( - a: gtx.Field[[IDim], float64], b: gtx.Field[[IDim], float64], c: gtx.Field[[IDim], float64] - ) -> gtx.Field[[IDim], bool]: - return a > b > c - - parsed = FieldOperatorParser.apply_to_function(compare_chain) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil("and_")( - im.promote_to_lifted_stencil("greater")("a", "b"), - im.promote_to_lifted_stencil("greater")("b", "c"), - ) - - assert lowered.expr == reference - - -def test_reduction_lowering_simple(): - def reduction(edge_f: gtx.Field[[Edge], float64]): - return neighbor_sum(edge_f(V2E), axis=V2EDim) - - parsed = FieldOperatorParser.apply_to_function(reduction) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.promote_to_lifted_stencil( - im.call( - im.call("reduce")( - "plus", - im.deref(im.promote_to_const_iterator(im.literal(value="0", typename="float64"))), - ) - ) - )(im.lifted_neighbors("V2E", "edge_f")) - - assert lowered.expr == reference - - -def test_reduction_lowering_expr(): - def reduction(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], float64]): - e1_nbh = e1(V2E) - return neighbor_sum(1.1 * (e1_nbh + e2), axis=V2EDim) - - parsed = FieldOperatorParser.apply_to_function(reduction) - lowered = FieldOperatorLowering.apply(parsed) - - mapped = im.promote_to_lifted_stencil(im.map_("multiplies"))( - im.promote_to_lifted_stencil("make_const_list")( - im.promote_to_const_iterator(im.literal("1.1", "float64")) - ), - im.promote_to_lifted_stencil(im.map_("plus"))(ssa.unique_name("e1_nbh", 0), "e2"), - ) - - reference = im.let( - ssa.unique_name("e1_nbh", 0), - im.lifted_neighbors("V2E", "e1"), - )( - im.promote_to_lifted_stencil( - im.call( - im.call("reduce")( - "plus", - im.deref( - im.promote_to_const_iterator(im.literal(value="0", typename="float64")) - ), - ) - ) - )(mapped) - ) - - assert lowered.expr == reference - - -def test_builtin_int_constructors(): - def int_constrs() -> tuple[int32, int32, int64, int32, int64]: - return 1, int32(1), int64(1), int32("1"), int64("1") - - parsed = FieldOperatorParser.apply_to_function(int_constrs) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.make_tuple( - im.promote_to_const_iterator(im.literal("1", "int32")), - im.promote_to_const_iterator(im.literal("1", "int32")), - im.promote_to_const_iterator(im.literal("1", "int64")), - im.promote_to_const_iterator(im.literal("1", "int32")), - im.promote_to_const_iterator(im.literal("1", "int64")), - ) - - assert lowered.expr == reference - - -def test_builtin_float_constructors(): - def float_constrs() -> tuple[float, float, float32, float64, float, float32, float64]: - return ( - 0.1, - float(0.1), - float32(0.1), - float64(0.1), - float(".1"), - float32(".1"), - float64(".1"), - ) - - parsed = FieldOperatorParser.apply_to_function(float_constrs) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.make_tuple( - im.promote_to_const_iterator(im.literal("0.1", "float64")), - im.promote_to_const_iterator(im.literal("0.1", "float64")), - im.promote_to_const_iterator(im.literal("0.1", "float32")), - im.promote_to_const_iterator(im.literal("0.1", "float64")), - im.promote_to_const_iterator(im.literal("0.1", "float64")), - im.promote_to_const_iterator(im.literal("0.1", "float32")), - im.promote_to_const_iterator(im.literal("0.1", "float64")), - ) - - assert lowered.expr == reference - - -def test_builtin_bool_constructors(): - def bool_constrs() -> tuple[bool, bool, bool, bool, bool, bool, bool, bool]: - return True, False, bool(True), bool(False), bool(0), bool(5), bool("True"), bool("False") - - parsed = FieldOperatorParser.apply_to_function(bool_constrs) - lowered = FieldOperatorLowering.apply(parsed) - - reference = im.make_tuple( - im.promote_to_const_iterator(im.literal(str(True), "bool")), - im.promote_to_const_iterator(im.literal(str(False), "bool")), - im.promote_to_const_iterator(im.literal(str(True), "bool")), - im.promote_to_const_iterator(im.literal(str(False), "bool")), - im.promote_to_const_iterator(im.literal(str(bool(0)), "bool")), - im.promote_to_const_iterator(im.literal(str(bool(5)), "bool")), - im.promote_to_const_iterator(im.literal(str(bool("True")), "bool")), - im.promote_to_const_iterator(im.literal(str(bool("False")), "bool")), - ) - - assert lowered.expr == reference diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py index a6231c22a7..c813285bd0 100644 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py +++ b/tests/next_tests/unit_tests/ffront_tests/test_past_to_gtir.py @@ -46,7 +46,6 @@ def test_copy_lowering(copy_program_def, gtir_identity_fundef): past_node, function_definitions=[gtir_identity_fundef], grid_type=gtx.GridType.CARTESIAN, - to_gtir=True, ) set_at_pattern = P( itir.SetAt, @@ -93,7 +92,6 @@ def test_copy_restrict_lowering(copy_restrict_program_def, gtir_identity_fundef) past_node, function_definitions=[gtir_identity_fundef], grid_type=gtx.GridType.CARTESIAN, - to_gtir=True, ) set_at_pattern = P( itir.SetAt, @@ -149,9 +147,7 @@ def tuple_program( make_tuple_op(inp, out=(out1[1:], out2[1:])) parsed = ProgramParser.apply_to_function(tuple_program) - ProgramLowering.apply( - parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN, to_gtir=True - ) + ProgramLowering.apply(parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN) @pytest.mark.xfail( @@ -166,9 +162,7 @@ def tuple_program( make_tuple_op(inp, out=(out1[1:], out2)) parsed = ProgramParser.apply_to_function(tuple_program) - ProgramLowering.apply( - parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN, to_gtir=True - ) + ProgramLowering.apply(parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN) @pytest.mark.xfail @@ -194,7 +188,6 @@ def test_invalid_call_sig_program(invalid_call_sig_program_def): ProgramParser.apply_to_function(invalid_call_sig_program_def), function_definitions=[], grid_type=gtx.GridType.CARTESIAN, - to_gtir=True, ) assert exc_info.match("Invalid call to 'identity'") diff --git a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py deleted file mode 100644 index fefd3c653b..0000000000 --- a/tests/next_tests/unit_tests/ffront_tests/test_past_to_itir.py +++ /dev/null @@ -1,214 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -import re - -import pytest - -import gt4py.eve as eve -import gt4py.next as gtx -from gt4py.eve.pattern_matching import ObjectPattern as P -from gt4py.next import errors -from gt4py.next.ffront.func_to_past import ProgramParser -from gt4py.next.ffront.past_to_itir import ProgramLowering -from gt4py.next.iterator import ir as itir -from gt4py.next.type_system import type_specifications as ts - -from next_tests.past_common_fixtures import ( - IDim, - copy_program_def, - copy_restrict_program_def, - float64, - identity_def, - invalid_call_sig_program_def, -) - - -@pytest.fixture -def itir_identity_fundef(): - return itir.FunctionDefinition( - id="identity", - params=[itir.Sym(id="x")], - expr=itir.FunCall(fun=itir.SymRef(id="deref"), args=[itir.SymRef(id="x")]), - ) - - -def test_copy_lowering(copy_program_def, itir_identity_fundef): - past_node = ProgramParser.apply_to_function(copy_program_def) - itir_node = ProgramLowering.apply( - past_node, function_definitions=[itir_identity_fundef], grid_type=gtx.GridType.CARTESIAN - ) - closure_pattern = P( - itir.StencilClosure, - domain=P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("cartesian_domain")), - args=[ - P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("named_range")), - args=[ - P(itir.AxisLiteral, value="IDim"), - P(itir.Literal, value="0", type=ts.ScalarType(kind=ts.ScalarKind.INT32)), - P(itir.SymRef, id=eve.SymbolRef("__out_size_0")), - ], - ) - ], - ), - stencil=P( - itir.Lambda, - params=[P(itir.Sym, id=eve.SymbolName("__stencil_arg0"))], - expr=P( - itir.FunCall, - fun=P( - itir.Lambda, - params=[P(itir.Sym)], - expr=P(itir.FunCall, fun=P(itir.SymRef, id=eve.SymbolRef("deref"))), - ), - args=[ - P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("identity")), - args=[P(itir.SymRef, id=eve.SymbolRef("__stencil_arg0"))], - ) - ], - ), - ), - inputs=[P(itir.SymRef, id=eve.SymbolRef("in_field"))], - output=P(itir.SymRef, id=eve.SymbolRef("out")), - ) - fencil_pattern = P( - itir.FencilDefinition, - id=eve.SymbolName("copy_program"), - params=[ - P(itir.Sym, id=eve.SymbolName("in_field")), - P(itir.Sym, id=eve.SymbolName("out")), - P(itir.Sym, id=eve.SymbolName("__in_field_size_0")), - P(itir.Sym, id=eve.SymbolName("__out_size_0")), - ], - closures=[closure_pattern], - ) - - fencil_pattern.match(itir_node, raise_exception=True) - - -def test_copy_restrict_lowering(copy_restrict_program_def, itir_identity_fundef): - past_node = ProgramParser.apply_to_function(copy_restrict_program_def) - itir_node = ProgramLowering.apply( - past_node, function_definitions=[itir_identity_fundef], grid_type=gtx.GridType.CARTESIAN - ) - closure_pattern = P( - itir.StencilClosure, - domain=P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("cartesian_domain")), - args=[ - P( - itir.FunCall, - fun=P(itir.SymRef, id=eve.SymbolRef("named_range")), - args=[ - P(itir.AxisLiteral, value="IDim"), - P( - itir.Literal, - value="1", - type=ts.ScalarType( - kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) - ), - ), - P( - itir.Literal, - value="2", - type=ts.ScalarType( - kind=getattr(ts.ScalarKind, itir.INTEGER_INDEX_BUILTIN.upper()) - ), - ), - ], - ) - ], - ), - ) - fencil_pattern = P( - itir.FencilDefinition, - id=eve.SymbolName("copy_restrict_program"), - params=[ - P(itir.Sym, id=eve.SymbolName("in_field")), - P(itir.Sym, id=eve.SymbolName("out")), - P(itir.Sym, id=eve.SymbolName("__in_field_size_0")), - P(itir.Sym, id=eve.SymbolName("__out_size_0")), - ], - closures=[closure_pattern], - ) - - fencil_pattern.match(itir_node, raise_exception=True) - - -def test_tuple_constructed_in_out_with_slicing(make_tuple_op): - def tuple_program( - inp: gtx.Field[[IDim], float64], - out1: gtx.Field[[IDim], float64], - out2: gtx.Field[[IDim], float64], - ): - make_tuple_op(inp, out=(out1[1:], out2[1:])) - - parsed = ProgramParser.apply_to_function(tuple_program) - ProgramLowering.apply(parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN) - - -@pytest.mark.xfail( - reason="slicing is only allowed if all fields are sliced in the same way." -) # see ADR 10 -def test_tuple_constructed_in_out_with_slicing(make_tuple_op): - def tuple_program( - inp: gtx.Field[[IDim], float64], - out1: gtx.Field[[IDim], float64], - out2: gtx.Field[[IDim], float64], - ): - make_tuple_op(inp, out=(out1[1:], out2)) - - parsed = ProgramParser.apply_to_function(tuple_program) - ProgramLowering.apply(parsed, function_definitions=[], grid_type=gtx.GridType.CARTESIAN) - - -@pytest.mark.xfail -def test_inout_prohibited(identity_def): - identity = gtx.field_operator(identity_def) - - def inout_field_program(inout_field: gtx.Field[[IDim], "float64"]): - identity(inout_field, out=inout_field) - - with pytest.raises( - ValueError, match=(r"Call to function with field as input and output not allowed.") - ): - ProgramLowering.apply( - ProgramParser.apply_to_function(inout_field_program), - function_definitions=[], - grid_type=gtx.GridType.CARTESIAN, - ) - - -def test_invalid_call_sig_program(invalid_call_sig_program_def): - with pytest.raises(errors.DSLError) as exc_info: - ProgramLowering.apply( - ProgramParser.apply_to_function(invalid_call_sig_program_def), - function_definitions=[], - grid_type=gtx.GridType.CARTESIAN, - ) - - assert exc_info.match("Invalid call to 'identity'") - # TODO(tehrengruber): re-enable again when call signature check doesn't return - # immediately after missing `out` argument - # assert ( - # re.search( - # "Function takes 1 arguments, but 2 were given.", exc_info.value.__cause__.args[0] - # ) - # is not None - # ) - assert ( - re.search(r"Missing required keyword argument 'out'", exc_info.value.__cause__.args[0]) - is not None - ) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py b/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py new file mode 100644 index 0000000000..ff7a761c5a --- /dev/null +++ b/tests/next_tests/unit_tests/iterator_tests/test_inline_dynamic_shifts.py @@ -0,0 +1,48 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause +from typing import Callable, Optional + +from gt4py import next as gtx +from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import inline_dynamic_shifts +from gt4py.next.type_system import type_specifications as ts + +IDim = gtx.Dimension("IDim") +field_type = ts.FieldType(dims=[IDim], dtype=ts.ScalarType(kind=ts.ScalarKind.INT32)) + + +def test_inline_dynamic_shift_as_fieldop_arg(): + testee = im.as_fieldop(im.lambda_("a", "b")(im.deref(im.shift("IOff", im.deref("b"))("a"))))( + im.as_fieldop("deref")("inp"), "offset_field" + ) + expected = im.as_fieldop( + im.lambda_("inp", "offset_field")( + im.deref(im.shift("IOff", im.deref("offset_field"))("inp")) + ) + )("inp", "offset_field") + + actual = inline_dynamic_shifts.InlineDynamicShifts.apply(testee) + assert actual == expected + + +def test_inline_dynamic_shift_let_var(): + testee = im.let("tmp", im.as_fieldop("deref")("inp"))( + im.as_fieldop(im.lambda_("a", "b")(im.deref(im.shift("IOff", im.deref("b"))("a"))))( + "tmp", "offset_field" + ) + ) + + expected = im.as_fieldop( + im.lambda_("inp", "offset_field")( + im.deref(im.shift("IOff", im.deref("offset_field"))("inp")) + ) + )("inp", "offset_field") + + actual = inline_dynamic_shifts.InlineDynamicShifts.apply(testee) + assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py index da4bea8874..bf47f997d6 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_parser.py @@ -7,8 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py.next.iterator import ir -from gt4py.next.iterator.pretty_parser import pparse from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.pretty_parser import pparse from gt4py.next.type_system import type_specifications as ts @@ -208,18 +208,6 @@ def test_temporary(): assert actual == expected -def test_stencil_closure(): - testee = "y ← (deref)(x) @ cartesian_domain();" - expected = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.SymRef(id="deref"), - output=ir.SymRef(id="y"), - inputs=[ir.SymRef(id="x")], - ) - actual = pparse(testee) - assert actual == expected - - def test_set_at(): testee = "y @ cartesian_domain() ← x;" expected = ir.SetAt( @@ -262,28 +250,6 @@ def test_if_stmt(): assert actual == expected -# TODO(havogt): remove after refactoring to GTIR -def test_fencil_definition(): - testee = "f(d, x, y) {\n g = λ(x) → x;\n y ← (deref)(x) @ cartesian_domain();\n}" - expected = ir.FencilDefinition( - id="f", - function_definitions=[ - ir.FunctionDefinition(id="g", params=[ir.Sym(id="x")], expr=ir.SymRef(id="x")) - ], - params=[ir.Sym(id="d"), ir.Sym(id="x"), ir.Sym(id="y")], - closures=[ - ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.SymRef(id="deref"), - output=ir.SymRef(id="y"), - inputs=[ir.SymRef(id="x")], - ) - ], - ) - actual = pparse(testee) - assert actual == expected - - def test_program(): testee = "f(d, x, y) {\n g = λ(x) → x;\n tmp = temporary(domain=cartesian_domain(), dtype=float64);\n y @ cartesian_domain() ← x;\n}" expected = ir.Program( diff --git a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py index 69a45cf128..11f50dbf6d 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_pretty_printer.py @@ -7,8 +7,8 @@ # SPDX-License-Identifier: BSD-3-Clause from gt4py.next.iterator import ir -from gt4py.next.iterator.pretty_printer import PrettyPrinter, pformat from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.pretty_printer import PrettyPrinter, pformat from gt4py.next.type_system import type_specifications as ts @@ -313,18 +313,6 @@ def test_temporary(): assert actual == expected -def test_stencil_closure(): - testee = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.SymRef(id="deref"), - output=ir.SymRef(id="y"), - inputs=[ir.SymRef(id="x")], - ) - expected = "y ← (deref)(x) @ cartesian_domain();" - actual = pformat(testee) - assert actual == expected - - def test_set_at(): testee = ir.SetAt( expr=ir.SymRef(id="x"), @@ -336,28 +324,6 @@ def test_set_at(): assert actual == expected -# TODO(havogt): remove after refactoring. -def test_fencil_definition(): - testee = ir.FencilDefinition( - id="f", - function_definitions=[ - ir.FunctionDefinition(id="g", params=[ir.Sym(id="x")], expr=ir.SymRef(id="x")) - ], - params=[ir.Sym(id="d"), ir.Sym(id="x"), ir.Sym(id="y")], - closures=[ - ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.SymRef(id="deref"), - output=ir.SymRef(id="y"), - inputs=[ir.SymRef(id="x")], - ) - ], - ) - actual = pformat(testee) - expected = "f(d, x, y) {\n g = λ(x) → x;\n y ← (deref)(x) @ cartesian_domain();\n}" - assert actual == expected - - def test_program(): testee = ir.Program( id="f", diff --git a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py index 13e8637d1a..bf2df06bf2 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_runtime_domain.py @@ -27,21 +27,20 @@ def foo(inp): dtype=None, ) +I = gtx.Dimension("I") + def test_deduce_domain(): assert isinstance(_deduce_domain({}, {}), CartesianDomain) assert isinstance(_deduce_domain(UnstructuredDomain(), {}), UnstructuredDomain) assert isinstance(_deduce_domain({}, {"foo": connectivity}), UnstructuredDomain) assert isinstance( - _deduce_domain(CartesianDomain([("I", range(1))]), {"foo": connectivity}), CartesianDomain + _deduce_domain(CartesianDomain([(I, range(1))]), {"foo": connectivity}), CartesianDomain ) -I = gtx.Dimension("I") - - def test_embedded_error_on_wrong_domain(): - dom = CartesianDomain([("I", range(1))]) + dom = CartesianDomain([(I, range(1))]) out = gtx.as_field([I], np.zeros(1)) with pytest.raises(RuntimeError, match="expected 'UnstructuredDomain'"): diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 305947ec77..e333015152 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -25,13 +25,12 @@ ) from gt4py.next.type_system import type_specifications as ts -from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import simple_mesh - from next_tests.integration_tests.cases import ( C2E, E2V, V2E, E2VDim, + Edge, IDim, Ioff, JDim, @@ -39,11 +38,12 @@ Koff, V2EDim, Vertex, - Edge, - mesh_descriptor, exec_alloc_descriptor, + mesh_descriptor, unstructured_case, ) +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import simple_mesh + bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) @@ -287,48 +287,35 @@ def test_cast_first_arg_inference(): assert result.type == float64_type -# TODO(tehrengruber): Rewrite tests to use itir.Program def test_cartesian_fencil_definition(): cartesian_domain = im.call("cartesian_domain")( im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) ) - testee = itir.FencilDefinition( + testee = itir.Program( id="f", function_definitions=[], params=[im.sym("inp", float_i_field), im.sym("out", float_i_field)], - closures=[ - itir.StencilClosure( + declarations=[], + body=[ + itir.SetAt( + expr=im.call(im.call("as_fieldop")(im.ref("deref"), cartesian_domain))( + im.ref("inp") + ), domain=cartesian_domain, - stencil=im.ref("deref"), - output=im.ref("out"), - inputs=[im.ref("inp")], + target=im.ref("out"), ), ], ) result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) - closure_type = it_ts.StencilClosureType( - domain=it_ts.DomainType(dims=[IDim]), - stencil=ts.FunctionType( - pos_only_args=[ - it_ts.IteratorType( - position_dims=[IDim], defined_dims=[IDim], element_type=float64_type - ) - ], - pos_or_kw_args={}, - kw_only_args={}, - returns=float64_type, - ), - output=float_i_field, - inputs=[float_i_field], - ) - fencil_type = it_ts.FencilType( - params={"inp": float_i_field, "out": float_i_field}, closures=[closure_type] - ) - assert result.type == fencil_type - assert result.closures[0].type == closure_type + program_type = it_ts.ProgramType(params={"inp": float_i_field, "out": float_i_field}) + assert result.type == program_type + domain_type = it_ts.DomainType(dims=[IDim]) + assert result.body[0].domain.type == domain_type + assert result.body[0].expr.type == float_i_field + assert result.body[0].target.type == float_i_field def test_unstructured_fencil_definition(): @@ -342,44 +329,34 @@ def test_unstructured_fencil_definition(): ), ) - testee = itir.FencilDefinition( + testee = itir.Program( id="f", function_definitions=[], params=[im.sym("inp", float_edge_k_field), im.sym("out", float_vertex_k_field)], - closures=[ - itir.StencilClosure( + declarations=[], + body=[ + itir.SetAt( + expr=im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.deref(im.shift("V2E", 0)("it"))), unstructured_domain + ) + )(im.ref("inp")), domain=unstructured_domain, - stencil=im.lambda_("it")(im.deref(im.shift("V2E", 0)("it"))), - output=im.ref("out"), - inputs=[im.ref("inp")], + target=im.ref("out"), ), ], ) result = itir_type_inference.infer(testee, offset_provider_type=mesh.offset_provider_type) - closure_type = it_ts.StencilClosureType( - domain=it_ts.DomainType(dims=[Vertex, KDim]), - stencil=ts.FunctionType( - pos_only_args=[ - it_ts.IteratorType( - position_dims=[Vertex, KDim], - defined_dims=[Edge, KDim], - element_type=float64_type, - ) - ], - pos_or_kw_args={}, - kw_only_args={}, - returns=float64_type, - ), - output=float_vertex_k_field, - inputs=[float_edge_k_field], - ) - fencil_type = it_ts.FencilType( - params={"inp": float_edge_k_field, "out": float_vertex_k_field}, closures=[closure_type] + program_type = it_ts.ProgramType( + params={"inp": float_edge_k_field, "out": float_vertex_k_field} ) - assert result.type == fencil_type - assert result.closures[0].type == closure_type + assert result.type == program_type + domain_type = it_ts.DomainType(dims=[Vertex, KDim]) + assert result.body[0].domain.type == domain_type + assert result.body[0].expr.type == float_vertex_k_field + assert result.body[0].target.type == float_vertex_k_field def test_function_definition(): @@ -387,45 +364,29 @@ def test_function_definition(): im.call("named_range")(itir.AxisLiteral(value="IDim"), 0, 1) ) - testee = itir.FencilDefinition( + testee = itir.Program( id="f", function_definitions=[ itir.FunctionDefinition(id="foo", params=[im.sym("it")], expr=im.deref("it")), itir.FunctionDefinition(id="bar", params=[im.sym("it")], expr=im.call("foo")("it")), ], params=[im.sym("inp", float_i_field), im.sym("out", float_i_field)], - closures=[ - itir.StencilClosure( + declarations=[], + body=[ + itir.SetAt( domain=cartesian_domain, - stencil=im.ref("bar"), - output=im.ref("out"), - inputs=[im.ref("inp")], + expr=im.call(im.call("as_fieldop")(im.ref("bar"), cartesian_domain))(im.ref("inp")), + target=im.ref("out"), ), ], ) result = itir_type_inference.infer(testee, offset_provider_type={"Ioff": IDim}) - closure_type = it_ts.StencilClosureType( - domain=it_ts.DomainType(dims=[IDim]), - stencil=ts.FunctionType( - pos_only_args=[ - it_ts.IteratorType( - position_dims=[IDim], defined_dims=[IDim], element_type=float64_type - ) - ], - pos_or_kw_args={}, - kw_only_args={}, - returns=float64_type, - ), - output=float_i_field, - inputs=[float_i_field], - ) - fencil_type = it_ts.FencilType( - params={"inp": float_i_field, "out": float_i_field}, closures=[closure_type] - ) - assert result.type == fencil_type - assert result.closures[0].type == closure_type + program_type = it_ts.ProgramType(params={"inp": float_i_field, "out": float_i_field}) + assert result.type == program_type + assert result.body[0].expr.type == float_i_field + assert result.body[0].target.type == float_i_field def test_fencil_with_nb_field_input(): @@ -439,24 +400,30 @@ def test_fencil_with_nb_field_input(): ), ) - testee = itir.FencilDefinition( + testee = itir.Program( id="f", function_definitions=[], params=[im.sym("inp", float_vertex_v2e_field), im.sym("out", float_vertex_k_field)], - closures=[ - itir.StencilClosure( + declarations=[], + body=[ + itir.SetAt( domain=unstructured_domain, - stencil=im.lambda_("it")(im.call(im.call("reduce")("plus", 0.0))(im.deref("it"))), - output=im.ref("out"), - inputs=[im.ref("inp")], + expr=im.call( + im.call("as_fieldop")( + im.lambda_("it")(im.call(im.call("reduce")("plus", 0.0))(im.deref("it"))), + unstructured_domain, + ) + )(im.ref("inp")), + target=im.ref("out"), ), ], ) result = itir_type_inference.infer(testee, offset_provider_type=mesh.offset_provider_type) - assert result.closures[0].stencil.expr.args[0].type == float64_list_type - assert result.closures[0].stencil.type.returns == float64_type + stencil = result.body[0].expr.fun.args[0] + assert stencil.expr.args[0].type == float64_list_type + assert stencil.type.returns == float64_type def test_program_tuple_setat_short_target(): diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py index 817c06e8f0..779ab738cb 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py @@ -8,21 +8,24 @@ # TODO(SF-N): test scan operator -import pytest +from typing import Iterable, Literal, Optional, Union + import numpy as np -from typing import Iterable, Optional, Literal, Union +import pytest from gt4py import eve -from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next import constructors +from gt4py.next import common, constructors, utils +from gt4py.next.common import Dimension from gt4py.next.iterator import ir as itir +from gt4py.next.iterator.ir_utils import ( + common_pattern_matcher as cpm, + domain_utils, + ir_makers as im, +) from gt4py.next.iterator.transforms import infer_domain -from gt4py.next.iterator.ir_utils import domain_utils -from gt4py.next.common import Dimension -from gt4py.next import common -from gt4py.next.type_system import type_specifications as ts from gt4py.next.iterator.transforms.constant_folding import ConstantFolding -from gt4py.next import utils +from gt4py.next.type_system import type_specifications as ts + float_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) IDim = common.Dimension(value="IDim", kind=common.DimensionKind.HORIZONTAL) @@ -73,7 +76,7 @@ def setup_test_as_fieldop( def run_test_program( testee: itir.Program, expected: itir.Program, offset_provider: common.OffsetProvider ) -> None: - actual_program = infer_domain.infer_program(testee, offset_provider) + actual_program = infer_domain.infer_program(testee, offset_provider=offset_provider) folded_program = constant_fold_domain_exprs(actual_program) assert folded_program == expected @@ -86,12 +89,14 @@ def run_test_expr( expected_domains: dict[str, itir.Expr | dict[str | Dimension, tuple[itir.Expr, itir.Expr]]], offset_provider: common.OffsetProvider, symbolic_domain_sizes: Optional[dict[str, str]] = None, + allow_uninferred: bool = False, ): actual_call, actual_domains = infer_domain.infer_expr( testee, domain_utils.SymbolicDomain.from_expr(domain), - offset_provider, - symbolic_domain_sizes, + offset_provider=offset_provider, + symbolic_domain_sizes=symbolic_domain_sizes, + allow_uninferred=allow_uninferred, ) folded_call = constant_fold_domain_exprs(actual_call) folded_domains = constant_fold_accessed_domains(actual_domains) if actual_domains else None @@ -101,10 +106,8 @@ def run_test_expr( def canonicalize_domain(d): if isinstance(d, dict): return im.domain(grid_type, d) - elif isinstance(d, itir.FunCall): + elif isinstance(d, (itir.FunCall, infer_domain.DomainAccessDescriptor)): return d - elif d is None: - return None raise AssertionError() expected_domains = {ref: canonicalize_domain(d) for ref, d in expected_domains.items()} @@ -125,10 +128,12 @@ def constant_fold_domain_exprs(arg: itir.Node) -> itir.Node: def constant_fold_accessed_domains( - domains: infer_domain.ACCESSED_DOMAINS, -) -> infer_domain.ACCESSED_DOMAINS: - def fold_domain(domain: domain_utils.SymbolicDomain | None): - if domain is None: + domains: infer_domain.AccessedDomains, +) -> infer_domain.AccessedDomains: + def fold_domain( + domain: domain_utils.SymbolicDomain | Literal[infer_domain.DomainAccessDescriptor.NEVER], + ): + if isinstance(domain, infer_domain.DomainAccessDescriptor): return domain return constant_fold_domain_exprs(domain.as_expr()) @@ -151,7 +156,7 @@ def translate_domain( shift_list = [item for sublist in shift_tuples for item in sublist] translated_domain_expr = domain_utils.SymbolicDomain.from_expr(domain).translate( - shift_list, offset_provider + shift_list, offset_provider=offset_provider ) return constant_fold_domain_exprs(translated_domain_expr.as_expr()) @@ -337,7 +342,7 @@ def test_nested_stencils(offset_provider): "in_field2": translate_domain(domain, {"Ioff": 0, "Joff": -2}, offset_provider), } actual_call, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) folded_call = constant_fold_domain_exprs(actual_call) @@ -381,7 +386,7 @@ def test_nested_stencils_n_times(offset_provider, iterations): } actual_call, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) @@ -394,7 +399,10 @@ def test_unused_input(offset_provider): stencil = im.lambda_("arg0", "arg1")(im.deref("arg0")) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) - expected_domains = {"in_field1": {IDim: (0, 11)}, "in_field2": None} + expected_domains = { + "in_field1": {IDim: (0, 11)}, + "in_field2": infer_domain.DomainAccessDescriptor.NEVER, + } testee, expected = setup_test_as_fieldop( stencil, domain, @@ -406,7 +414,7 @@ def test_let_unused_field(offset_provider): testee = im.let("a", "c")("b") domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.let("a", "c")("b") - expected_domains = {"b": {IDim: (0, 11)}, "c": None} + expected_domains = {"b": {IDim: (0, 11)}, "c": infer_domain.DomainAccessDescriptor.NEVER} run_test_expr(testee, expected, domain, expected_domains, offset_provider) @@ -519,7 +527,7 @@ def test_cond(offset_provider): expected = im.if_(cond, expected_field_1, expected_field_2) actual_call, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_domains = constant_fold_accessed_domains(actual_domains) @@ -576,7 +584,7 @@ def test_let(offset_provider): expected_domains_sym = {"in_field": translate_domain(domain, {"Ioff": 2}, offset_provider)} actual_call2, actual_domains2 = infer_domain.infer_expr( - testee2, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee2, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_domains2 = constant_fold_accessed_domains(actual_domains2) folded_call2 = constant_fold_domain_exprs(actual_call2) @@ -800,7 +808,7 @@ def test_make_tuple(offset_provider): domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -812,13 +820,13 @@ def test_tuple_get_1_make_tuple(offset_provider): domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.tuple_get(1, im.make_tuple(im.ref("a"), im.ref("b"), im.ref("c"))) expected_domains = { - "a": None, + "a": infer_domain.DomainAccessDescriptor.NEVER, "b": im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}), - "c": None, + "c": infer_domain.DomainAccessDescriptor.NEVER, } actual, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) assert expected == actual @@ -830,7 +838,7 @@ def test_tuple_get_1_nested_make_tuple(offset_provider): domain1 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) domain2 = im.domain(common.GridType.CARTESIAN, {IDim: (0, 12)}) expected = im.tuple_get(1, im.make_tuple(im.ref("a"), im.make_tuple(im.ref("b"), im.ref("c")))) - expected_domains = {"a": None, "b": domain1, "c": domain2} + expected_domains = {"a": infer_domain.DomainAccessDescriptor.NEVER, "b": domain1, "c": domain2} actual, actual_domains = infer_domain.infer_expr( testee, @@ -838,7 +846,7 @@ def test_tuple_get_1_nested_make_tuple(offset_provider): domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -849,14 +857,18 @@ def test_tuple_get_let_arg_make_tuple(offset_provider): testee = im.tuple_get(1, im.let("a", im.make_tuple(im.ref("b"), im.ref("c")))("d")) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.tuple_get(1, im.let("a", im.make_tuple(im.ref("b"), im.ref("c")))("d")) - expected_domains = {"b": None, "c": None, "d": (None, domain)} + expected_domains = { + "b": infer_domain.DomainAccessDescriptor.NEVER, + "c": infer_domain.DomainAccessDescriptor.NEVER, + "d": (infer_domain.DomainAccessDescriptor.NEVER, domain), + } actual, actual_domains = infer_domain.infer_expr( testee, domain_utils.SymbolicDomain.from_expr( im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -867,12 +879,16 @@ def test_tuple_get_let_make_tuple(offset_provider): testee = im.tuple_get(1, im.let("a", "b")(im.make_tuple(im.ref("c"), im.ref("d")))) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.tuple_get(1, im.let("a", "b")(im.make_tuple(im.ref("c"), im.ref("d")))) - expected_domains = {"c": None, "d": domain, "b": None} + expected_domains = { + "c": infer_domain.DomainAccessDescriptor.NEVER, + "d": domain, + "b": infer_domain.DomainAccessDescriptor.NEVER, + } actual, actual_domains = infer_domain.infer_expr( testee, domain_utils.SymbolicDomain.from_expr(domain), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -900,7 +916,7 @@ def test_nested_make_tuple(offset_provider): ), domain_utils.SymbolicDomain.from_expr(domain3), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -911,10 +927,10 @@ def test_tuple_get_1(offset_provider): testee = im.tuple_get(1, im.ref("a")) domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 11)}) expected = im.tuple_get(1, im.ref("a")) - expected_domains = {"a": (None, domain)} + expected_domains = {"a": (infer_domain.DomainAccessDescriptor.NEVER, domain)} actual, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) assert expected == actual @@ -934,7 +950,7 @@ def test_domain_tuple(offset_provider): domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -950,7 +966,7 @@ def test_as_fieldop_tuple_get(offset_provider): expected_domains = {"a": (domain, domain)} actual, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) assert expected == actual @@ -970,7 +986,7 @@ def test_make_tuple_2tuple_get(offset_provider): domain_utils.SymbolicDomain.from_expr(domain1), domain_utils.SymbolicDomain.from_expr(domain2), ), - offset_provider, + offset_provider=offset_provider, ) assert expected == actual @@ -987,7 +1003,7 @@ def test_make_tuple_non_tuple_domain(offset_provider): expected_domains = {"in_field1": domain, "in_field2": domain} actual, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) assert expected == actual @@ -1001,7 +1017,7 @@ def test_arithmetic_builtin(offset_provider): expected_domains = {} actual_call, actual_domains = infer_domain.infer_expr( - testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider + testee, domain_utils.SymbolicDomain.from_expr(domain), offset_provider=offset_provider ) folded_call = constant_fold_domain_exprs(actual_call) @@ -1045,3 +1061,35 @@ def test_symbolic_domain_sizes(unstructured_offset_provider): unstructured_offset_provider, symbolic_domain_sizes, ) + + +def test_unknown_domain(offset_provider): + stencil = im.lambda_("arg0", "arg1")(im.deref(im.shift("Ioff", im.deref("arg1"))("arg0"))) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10)}) + expected_domains = { + "in_field1": infer_domain.DomainAccessDescriptor.UNKNOWN, + "in_field2": {IDim: (0, 10)}, + } + testee, expected = setup_test_as_fieldop(stencil, domain) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_never_accessed_domain(offset_provider): + stencil = im.lambda_("arg0", "arg1")(im.deref("arg0")) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10)}) + expected_domains = { + "in_field1": {IDim: (0, 10)}, + "in_field2": infer_domain.DomainAccessDescriptor.NEVER, + } + testee, expected = setup_test_as_fieldop(stencil, domain) + run_test_expr(testee, expected, domain, expected_domains, offset_provider) + + +def test_never_accessed_domain_tuple(offset_provider): + testee = im.tuple_get(0, im.make_tuple("in_field1", "in_field2")) + domain = im.domain(common.GridType.CARTESIAN, {IDim: (0, 10)}) + expected_domains = { + "in_field1": {IDim: (0, 10)}, + "in_field2": infer_domain.DomainAccessDescriptor.NEVER, + } + run_test_expr(testee, testee, domain, expected_domains, offset_provider) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py index 2e0a83d33b..c10d48ad06 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_inline_lambdas.py @@ -84,3 +84,10 @@ def test_inline_lambda_args(): ) inlined = InlineLambdas.apply(testee, opcount_preserving=True, force_inline_lambda_args=True) assert inlined == expected + + +def test_type_preservation(): + testee = im.let("a", "b")("a") + testee.type = testee.annex.type = ts.ScalarType(kind=ts.ScalarKind.FLOAT32) + inlined = InlineLambdas.apply(testee) + assert inlined.type == inlined.annex.type == ts.ScalarType(kind=ts.ScalarKind.FLOAT32) diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_closure_inputs.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_closure_inputs.py deleted file mode 100644 index 407ccad924..0000000000 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_prune_closure_inputs.py +++ /dev/null @@ -1,68 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from gt4py.next.iterator import ir -from gt4py.next.iterator.transforms.prune_closure_inputs import PruneClosureInputs - - -def test_simple(): - testee = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="x"), ir.Sym(id="y"), ir.Sym(id="z")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="y")]), - ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="foo"), ir.SymRef(id="bar"), ir.SymRef(id="baz")], - ) - expected = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="y")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="y")]), - ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="bar")], - ) - actual = PruneClosureInputs().visit(testee) - assert actual == expected - - -def test_shadowing(): - testee = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="x"), ir.Sym(id="y"), ir.Sym(id="z")], - expr=ir.FunCall( - fun=ir.Lambda( - params=[ir.Sym(id="z")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="z")]), - ), - args=[ir.SymRef(id="y")], - ), - ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="foo"), ir.SymRef(id="bar"), ir.SymRef(id="baz")], - ) - expected = ir.StencilClosure( - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - stencil=ir.Lambda( - params=[ir.Sym(id="y")], - expr=ir.FunCall( - fun=ir.Lambda( - params=[ir.Sym(id="z")], - expr=ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="z")]), - ), - args=[ir.SymRef(id="y")], - ), - ), - output=ir.SymRef(id="out"), - inputs=[ir.SymRef(id="bar")], - ) - actual = PruneClosureInputs().visit(testee) - assert actual == expected diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_symbol_ref_utils.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_symbol_ref_utils.py index 0c118ff6dc..c162860c7c 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_symbol_ref_utils.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_symbol_ref_utils.py @@ -6,28 +6,23 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from dataclasses import dataclass -from typing import Optional -from gt4py import eve from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms.symbol_ref_utils import ( - collect_symbol_refs, - get_user_defined_symbols, -) +from gt4py.next.iterator.transforms.symbol_ref_utils import get_user_defined_symbols def test_get_user_defined_symbols(): - ir = itir.FencilDefinition( + domain = itir.FunCall(fun=itir.SymRef(id="cartesian_domain"), args=[]) + ir = itir.Program( id="foo", function_definitions=[], params=[itir.Sym(id="target_symbol")], - closures=[ - itir.StencilClosure( - domain=itir.FunCall(fun=itir.SymRef(id="cartesian_domain"), args=[]), - stencil=itir.SymRef(id="deref"), - output=itir.SymRef(id="target_symbol"), - inputs=[], + declarations=[], + body=[ + itir.SetAt( + expr=itir.Lambda(params=[itir.Sym(id="foo")], expr=itir.SymRef(id="foo")), + domain=domain, + target=itir.SymRef(id="target_symbol"), ) ], ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py index e64bd8a57d..0586d48703 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py +++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py @@ -6,11 +6,11 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import numpy as np -import pytest import copy -import diskcache +import diskcache +import numpy as np +import pytest import gt4py.next as gtx from gt4py.next.iterator import ir as itir @@ -19,18 +19,17 @@ from gt4py.next.program_processors.codegens.gtfn import gtfn_module from gt4py.next.program_processors.runners import gtfn from gt4py.next.type_system import type_translation -from next_tests.integration_tests import cases -from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import KDim +from next_tests.integration_tests import cases from next_tests.integration_tests.cases import cartesian_case - from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import ( + KDim, exec_alloc_descriptor, ) @pytest.fixture -def fencil_example(): +def program_example(): IDim = gtx.Dimension("I") params = [gtx.as_field([IDim], np.empty((1,), dtype=np.float32)), np.float32(3.14)] param_types = [type_translation.from_value(param) for param in params] @@ -48,7 +47,7 @@ def fencil_example(): ) ], ) - fencil = itir.FencilDefinition( + program = itir.Program( id="example", params=[im.sym(name, type_) for name, type_ in zip(("buf", "sc"), param_types)], function_definitions=[ @@ -58,20 +57,22 @@ def fencil_example(): expr=im.literal("1", "float32"), ) ], - closures=[ - itir.StencilClosure( + declarations=[], + body=[ + itir.SetAt( + expr=im.call(im.call("as_fieldop")(itir.SymRef(id="stencil"), domain))( + itir.SymRef(id="buf"), itir.SymRef(id="sc") + ), domain=domain, - stencil=itir.SymRef(id="stencil"), - output=itir.SymRef(id="buf"), - inputs=[itir.SymRef(id="buf"), itir.SymRef(id="sc")], + target=itir.SymRef(id="buf"), ) ], ) - return fencil, params + return program, params -def test_codegen(fencil_example): - fencil, parameters = fencil_example +def test_codegen(program_example): + fencil, parameters = program_example module = gtfn_module.translate_program_cpu( stages.CompilableProgram( data=fencil, @@ -85,8 +86,8 @@ def test_codegen(fencil_example): assert module.language is languages.CPP -def test_hash_and_diskcache(fencil_example, tmp_path): - fencil, parameters = fencil_example +def test_hash_and_diskcache(program_example, tmp_path): + fencil, parameters = program_example compilable_program = stages.CompilableProgram( data=fencil, args=arguments.CompileTimeArgs.from_concrete_no_size( @@ -129,8 +130,8 @@ def test_hash_and_diskcache(fencil_example, tmp_path): ) != gtfn.fingerprint_compilable_program(altered_program_column_axis) -def test_gtfn_file_cache(fencil_example): - fencil, parameters = fencil_example +def test_gtfn_file_cache(program_example): + fencil, parameters = program_example compilable_program = stages.CompilableProgram( data=fencil, args=arguments.CompileTimeArgs.from_concrete_no_size( diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py index c7466b853f..03b8e3bc15 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_gtir_to_sdfg.py @@ -36,6 +36,7 @@ from . import pytestmark + dace_backend = pytest.importorskip("gt4py.next.program_processors.runners.dace_fieldview") @@ -1984,7 +1985,7 @@ def test_gtir_index(): ) testee = gtir.Program( - id="gtir_cast", + id="gtir_index", function_definitions=[], params=[ gtir.Sym(id="x", type=ts.FieldType(dims=[IDim], dtype=SIZE_TYPE)), diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/conftest.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/conftest.py index e85ef6ad1f..0eb0bf39c2 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/conftest.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/conftest.py @@ -11,7 +11,7 @@ import pytest -@pytest.fixture() +@pytest.fixture(autouse=True) def set_dace_settings() -> Generator[None, None, None]: """Sets the common DaCe settings for the tests. @@ -24,6 +24,6 @@ def set_dace_settings() -> Generator[None, None, None]: import dace with dace.config.temporary_config(): - dace.Config.set("optimizer", "match_exception", value=False) + dace.Config.set("optimizer", "match_exception", value=True) dace.Config.set("compiler", "allow_view_arguments", value=True) yield diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py new file mode 100644 index 0000000000..04a4f098ef --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_constant_substitution.py @@ -0,0 +1,142 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + +from . import util + + +def test_constant_substitution(): + sdfg, nsdfg = _make_sdfg() + + # Ensure that `One` is present. + assert len(sdfg.symbols) == 2 + assert len(nsdfg.sdfg.symbols) == 2 + assert len(nsdfg.symbol_mapping) == 2 + assert "One" in sdfg.symbols + assert "One" in nsdfg.sdfg.symbols + assert "One" in nsdfg.symbol_mapping + assert "One" == str(nsdfg.symbol_mapping["One"]) + assert all(str(desc.strides[1]) == "One" for desc in sdfg.arrays.values()) + assert all(str(desc.strides[1]) == "One" for desc in nsdfg.sdfg.arrays.values()) + assert all(str(desc.strides[0]) == "N" for desc in sdfg.arrays.values()) + assert all(str(desc.strides[0]) == "N" for desc in nsdfg.sdfg.arrays.values()) + assert "One" in sdfg.used_symbols(True) + + # Now replace `One` with 1 + gtx_transformations.gt_substitute_compiletime_symbols(sdfg, {"One": 1}) + + assert len(sdfg.symbols) == 1 + assert len(nsdfg.sdfg.symbols) == 1 + assert len(nsdfg.symbol_mapping) == 1 + assert "One" not in sdfg.symbols + assert "One" not in nsdfg.sdfg.symbols + assert "One" not in nsdfg.symbol_mapping + assert all(desc.strides[1] == 1 and len(desc.strides) == 2 for desc in sdfg.arrays.values()) + assert all( + desc.strides[1] == 1 and len(desc.strides) == 2 for desc in nsdfg.sdfg.arrays.values() + ) + assert all(str(desc.strides[0]) == "N" for desc in sdfg.arrays.values()) + assert all(str(desc.strides[0]) == "N" for desc in nsdfg.sdfg.arrays.values()) + assert "One" not in sdfg.used_symbols(True) + + +def _make_nested_sdfg() -> dace.SDFG: + sdfg = dace.SDFG("nested") + N = dace.symbol(sdfg.add_symbol("N", dace.int32)) + One = dace.symbol(sdfg.add_symbol("One", dace.int32)) + for name in "ABC": + sdfg.add_array( + name=name, + dtype=dace.float64, + shape=(N, N), + strides=(N, One), + transient=False, + ) + state = sdfg.add_state(is_start_block=True) + state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:N", "__i1": "0:N"}, + inputs={ + "__in0": dace.Memlet("A[__i0, __i1]"), + "__in1": dace.Memlet("B[__i0, __i1]"), + }, + code="__out = __in0 + __in1", + outputs={"__out": dace.Memlet("C[__i0, __i1]")}, + external_edges=True, + ) + sdfg.validate() + return sdfg + + +def _make_sdfg() -> tuple[dace.SDFG, dace.nodes.NestedSDFG]: + sdfg = dace.SDFG("outer_sdfg") + N = dace.symbol(sdfg.add_symbol("N", dace.int32)) + One = dace.symbol(sdfg.add_symbol("One", dace.int32)) + for name in "ABCD": + sdfg.add_array( + name=name, + dtype=dace.float64, + shape=(N, N), + strides=(N, One), + transient=False, + ) + sdfg.arrays["C"].transient = True + + first_state: dace.SDFGState = sdfg.add_state(is_start_block=True) + nested_sdfg: dace.SDFG = _make_nested_sdfg() + nsdfg = first_state.add_nested_sdfg( + nested_sdfg, + parent=sdfg, + inputs={"A", "B"}, + outputs={"C"}, + symbol_mapping={"One": "One", "N": "N"}, + ) + first_state.add_edge( + first_state.add_access("A"), + None, + nsdfg, + "A", + dace.Memlet("A[0:N, 0:N]"), + ) + first_state.add_edge( + first_state.add_access("B"), + None, + nsdfg, + "B", + dace.Memlet("B[0:N, 0:N]"), + ) + first_state.add_edge( + nsdfg, + "C", + first_state.add_access("C"), + None, + dace.Memlet("C[0:N, 0:N]"), + ) + + second_state: dace.SDFGState = sdfg.add_state_after(first_state) + second_state.add_mapped_tasklet( + "outer_computation", + map_ranges={"__i0": "0:N", "__i1": "0:N"}, + inputs={ + "__in0": dace.Memlet("A[__i0, __i1]"), + "__in1": dace.Memlet("C[__i0, __i1]"), + }, + code="__out = __in0 * __in1", + outputs={"__out": dace.Memlet("D[__i0, __i1]")}, + external_edges=True, + ) + sdfg.validate() + return sdfg, nsdfg diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py new file mode 100644 index 0000000000..3d9201c603 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py @@ -0,0 +1,239 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +import numpy as np +import copy + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes +from dace import data as dace_data + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + +from . import util + +import dace + + +def _create_sdfg_double_read_part_1( + sdfg: dace.SDFG, + state: dace.SDFGState, + me: dace.nodes.MapEntry, + mx: dace.nodes.MapExit, + A_in: dace.nodes.AccessNode, + nb: int, +) -> dace.nodes.Tasklet: + tskl = state.add_tasklet( + name=f"tasklet_1", inputs={"__in1"}, outputs={"__out"}, code="__out = __in1 + 1.0" + ) + + state.add_edge(A_in, None, me, f"IN_{nb}", dace.Memlet("A[0:10]")) + state.add_edge(me, f"OUT_{nb}", tskl, "__in1", dace.Memlet("A[__i0]")) + me.add_in_connector(f"IN_{nb}") + me.add_out_connector(f"OUT_{nb}") + + state.add_edge(tskl, "__out", mx, f"IN_{nb}", dace.Memlet("A[__i0]")) + state.add_edge(mx, f"OUT_{nb}", state.add_access("A"), None, dace.Memlet("A[0:10]")) + mx.add_in_connector(f"IN_{nb}") + mx.add_out_connector(f"OUT_{nb}") + + +def _create_sdfg_double_read_part_2( + sdfg: dace.SDFG, + state: dace.SDFGState, + me: dace.nodes.MapEntry, + mx: dace.nodes.MapExit, + A_in: dace.nodes.AccessNode, + nb: int, +) -> dace.nodes.Tasklet: + tskl = state.add_tasklet( + name=f"tasklet_2", inputs={"__in1"}, outputs={"__out"}, code="__out = __in1 + 3.0" + ) + + state.add_edge(A_in, None, me, f"IN_{nb}", dace.Memlet("A[0:10]")) + state.add_edge(me, f"OUT_{nb}", tskl, "__in1", dace.Memlet("A[__i0]")) + me.add_in_connector(f"IN_{nb}") + me.add_out_connector(f"OUT_{nb}") + + state.add_edge(tskl, "__out", mx, f"IN_{nb}", dace.Memlet("B[__i0]")) + state.add_edge(mx, f"OUT_{nb}", state.add_access("B"), None, dace.Memlet("B[0:10]")) + mx.add_in_connector(f"IN_{nb}") + mx.add_out_connector(f"OUT_{nb}") + + +def _create_sdfg_double_read( + version: int, +) -> tuple[dace.SDFG]: + sdfg = dace.SDFG(util.unique_name(f"double_read_version_{version}")) + state = sdfg.add_state(is_start_block=True) + for name in "AB": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + A_in = state.add_access("A") + me, mx = state.add_map("map", ndrange={"__i0": "0:10"}) + + if version == 0: + _create_sdfg_double_read_part_1(sdfg, state, me, mx, A_in, 0) + _create_sdfg_double_read_part_2(sdfg, state, me, mx, A_in, 1) + elif version == 1: + _create_sdfg_double_read_part_1(sdfg, state, me, mx, A_in, 1) + _create_sdfg_double_read_part_2(sdfg, state, me, mx, A_in, 0) + else: + raise ValueError(f"Does not know version {version}") + sdfg.validate() + return sdfg + + +def test_local_double_buffering_double_read_sdfg(): + sdfg0 = _create_sdfg_double_read(0) + sdfg1 = _create_sdfg_double_read(1) + args0 = {name: np.array(np.random.rand(10), dtype=np.float64, copy=True) for name in "AB"} + args1 = copy.deepcopy(args0) + + count0 = gtx_transformations.gt_create_local_double_buffering(sdfg0) + assert count0 == 1 + + count1 = gtx_transformations.gt_create_local_double_buffering(sdfg1) + assert count1 == 1 + + sdfg0(**args0) + sdfg1(**args1) + for name in args0: + assert np.allclose(args0[name], args1[name]), f"Failed verification in '{name}'." + + +def test_local_double_buffering_no_connection(): + """There is no direct connection between read and write.""" + sdfg = dace.SDFG(util.unique_name("local_double_buffering_no_connection")) + state = sdfg.add_state(is_start_block=True) + for name in "AB": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + A_in, B, A_out = (state.add_access(name) for name in "ABA") + + comp_tskl, me, mx = state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("A[__i0]")}, + code="__out = __in1 + 10.0", + outputs={"__out": dace.Memlet("B[__i0]")}, + input_nodes={A_in}, + output_nodes={B}, + external_edges=True, + ) + + fill_tasklet = state.add_tasklet( + name="fill_tasklet", + inputs=set(), + code="__out = 2.", + outputs={"__out"}, + ) + state.add_nedge(me, fill_tasklet, dace.Memlet()) + state.add_edge(fill_tasklet, "__out", mx, "IN_1", dace.Memlet("A[__i0]")) + state.add_edge(mx, "OUT_1", A_out, None, dace.Memlet("A[0:10]")) + mx.add_in_connector("IN_1") + mx.add_out_connector("OUT_1") + sdfg.validate() + + count = gtx_transformations.gt_create_local_double_buffering(sdfg) + assert count == 1 + + # Ensure that a second application of the transformation does not run again. + count_again = gtx_transformations.gt_create_local_double_buffering(sdfg) + assert count_again == 0 + + # Find the newly created access node. + comp_tasklet_producers = [in_edge.src for in_edge in state.in_edges(comp_tskl)] + assert len(comp_tasklet_producers) == 1 + new_double_buffer = comp_tasklet_producers[0] + assert isinstance(new_double_buffer, dace_nodes.AccessNode) + assert not any(new_double_buffer.data == name for name in "AB") + assert isinstance(new_double_buffer.desc(sdfg), dace_data.Scalar) + assert new_double_buffer.desc(sdfg).transient + + # The newly created access node, must have an empty Memlet to the fill tasklet. + read_dependencies = [ + out_edge.dst for out_edge in state.out_edges(new_double_buffer) if out_edge.data.is_empty() + ] + assert len(read_dependencies) == 1 + assert read_dependencies[0] is fill_tasklet + + res = {name: np.array(np.random.rand(10), dtype=np.float64, copy=True) for name in "AB"} + ref = {"A": np.full_like(res["A"], 2.0), "B": res["A"] + 10.0} + sdfg(**res) + for name in res: + assert np.allclose(res[name], ref[name]), f"Failed verification in '{name}'." + + +def test_local_double_buffering_no_apply(): + """Here it does not apply, because are all distinct.""" + sdfg = dace.SDFG(util.unique_name("local_double_buffering_no_apply")) + state = sdfg.add_state(is_start_block=True) + for name in "AB": + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("A[__i0]")}, + code="__out = __in1 + 10.0", + outputs={"__out": dace.Memlet("B[__i0]")}, + external_edges=True, + ) + sdfg.validate() + + count = gtx_transformations.gt_create_local_double_buffering(sdfg) + assert count == 0 + + +def test_local_double_buffering_already_buffered(): + """It is already buffered.""" + sdfg = dace.SDFG(util.unique_name("local_double_buffering_no_apply")) + state = sdfg.add_state(is_start_block=True) + sdfg.add_array( + "A", + shape=(10,), + dtype=dace.float64, + transient=False, + ) + + tsklt, me, mx = state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:10"}, + inputs={"__in1": dace.Memlet("A[__i0]")}, + code="__out = __in1 + 10.0", + outputs={"__out": dace.Memlet("A[__i0]")}, + external_edges=True, + ) + + sdfg.add_scalar("tmp", dtype=dace.float64, transient=True) + tmp = state.add_access("tmp") + me_to_tskl_edge = next(iter(state.out_edges(me))) + + state.add_edge(me, me_to_tskl_edge.src_conn, tmp, None, dace.Memlet("A[__i0]")) + state.add_edge(tmp, None, tsklt, "__in1", dace.Memlet("tmp[0]")) + state.remove_edge(me_to_tskl_edge) + sdfg.validate() + + count = gtx_transformations.gt_create_local_double_buffering(sdfg) + assert count == 0 diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py new file mode 100644 index 0000000000..1543a048ad --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py @@ -0,0 +1,84 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +import numpy as np + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + +# from . import util + + +# dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes +import dace + + +def _mk_distributed_buffer_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: + sdfg = dace.SDFG("NAME") # util.unique_name("distributed_buffer_sdfg")) + + for name in ["a", "b", "tmp"]: + sdfg.add_array(name, shape=(10, 10), dtype=dace.float64, transient=False) + sdfg.arrays["tmp"].transient = True + sdfg.arrays["b"].shape = (100, 100) + + state1: dace.SDFGState = sdfg.add_state(is_start_block=True) + state1.add_mapped_tasklet( + "computation", + map_ranges={"__i1": "0:10", "__i2": "0:10"}, + inputs={"__in": dace.Memlet("a[__i1, __i2]")}, + code="__out = __in + 10.0", + outputs={"__out": dace.Memlet("tmp[__i1, __i2]")}, + external_edges=True, + ) + + state2 = sdfg.add_state_after(state1) + state2_tskl = state2.add_tasklet( + name="empty_blocker_tasklet", + inputs={}, + code="pass", + outputs={"__out"}, + side_effects=True, + ) + state2.add_edge( + state2_tskl, + "__out", + state2.add_access("a"), + None, + dace.Memlet("a[0, 0]"), + ) + + state3 = sdfg.add_state_after(state2) + state3.add_edge( + state3.add_access("tmp"), + None, + state3.add_access("b"), + None, + dace.Memlet("tmp[0:10, 0:10] -> [11:21, 22:32]"), + ) + sdfg.validate() + assert sdfg.number_of_nodes() == 3 + + return sdfg, state1 + + +def test_distributed_buffer_remover(): + sdfg, state1 = _mk_distributed_buffer_sdfg() + assert state1.number_of_nodes() == 5 + assert not any(dnode.data == "b" for dnode in state1.data_nodes()) + + res = gtx_transformations.gt_reduce_distributed_buffering(sdfg) + assert res is not None + + # Because the final state has now become empty + assert sdfg.number_of_nodes() == 3 + assert state1.number_of_nodes() == 6 + assert any(dnode.data == "b" for dnode in state1.data_nodes()) + assert any(dnode.data == "tmp" for dnode in state1.data_nodes()) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_global_self_copy_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_global_self_copy_elimination.py new file mode 100644 index 0000000000..4ca44d43eb --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_global_self_copy_elimination.py @@ -0,0 +1,148 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + +from . import util + +import dace + + +def _make_self_copy_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: + """Generates an SDFG that contains the self copying pattern.""" + sdfg = dace.SDFG(util.unique_name("self_copy_sdfg")) + state = sdfg.add_state(is_start_block=True) + + for name in "GT": + sdfg.add_array( + name, + shape=(10, 10), + dtype=dace.float64, + transient=True, + ) + sdfg.arrays["G"].transient = False + g_read, tmp_node, g_write = (state.add_access(name) for name in "GTG") + + state.add_nedge(g_read, tmp_node, dace.Memlet("G[0:10, 0:10]")) + state.add_nedge(tmp_node, g_write, dace.Memlet("G[0:10, 0:10]")) + sdfg.validate() + + return sdfg, state + + +def test_global_self_copy_elimination_only_pattern(): + """Contains only the pattern -> Total elimination.""" + sdfg, state = _make_self_copy_sdfg() + assert sdfg.number_of_nodes() == 1 + assert state.number_of_nodes() == 3 + assert util.count_nodes(state, dace_nodes.AccessNode) == 3 + assert state.number_of_edges() == 2 + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyGlobalSelfCopyElimination, validate=True, validate_all=True + ) + assert count != 0 + + assert sdfg.number_of_nodes() == 1 + assert ( + state.number_of_nodes() == 0 + ), f"Expected that 0 access nodes remained, but {state.number_of_nodes()} were there." + + +def test_global_self_copy_elimination_g_downstream(): + """`G` is read downstream. + + Since we ignore reads to `G` downstream, this will not influence the + transformation. + """ + sdfg, state1 = _make_self_copy_sdfg() + + # Add a read to `G` downstream. + state2 = sdfg.add_state_after(state1) + sdfg.add_array( + "output", + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + + state2.add_mapped_tasklet( + "downstream_computation", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in": dace.Memlet("G[__i0, __i1]")}, + code="__out = __in + 10.0", + outputs={"__out": dace.Memlet("output[__i0, __i1]")}, + external_edges=True, + ) + sdfg.validate() + assert state2.number_of_nodes() == 5 + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyGlobalSelfCopyElimination, validate=True, validate_all=True + ) + assert count != 0 + + assert sdfg.number_of_nodes() == 2 + assert ( + state1.number_of_nodes() == 0 + ), f"Expected that 0 access nodes remained, but {state.number_of_nodes()} were there." + assert state2.number_of_nodes() == 5 + assert util.count_nodes(state2, dace_nodes.AccessNode) == 2 + assert util.count_nodes(state2, dace_nodes.MapEntry) == 1 + + +def test_global_self_copy_elimination_tmp_downstream(): + """`T` is read downstream. + + Because `T` is read downstream, the read to `G` will be retained, but the write + will be removed. + """ + sdfg, state1 = _make_self_copy_sdfg() + + # Add a read to `G` downstream. + state2 = sdfg.add_state_after(state1) + sdfg.add_array( + "output", + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + + state2.add_mapped_tasklet( + "downstream_computation", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in": dace.Memlet("T[__i0, __i1]")}, + code="__out = __in + 10.0", + outputs={"__out": dace.Memlet("output[__i0, __i1]")}, + external_edges=True, + ) + sdfg.validate() + assert state2.number_of_nodes() == 5 + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyGlobalSelfCopyElimination, validate=True, validate_all=True + ) + assert count != 0 + + assert sdfg.number_of_nodes() == 2 + assert state1.number_of_nodes() == 2 + assert util.count_nodes(state1, dace_nodes.AccessNode) == 2 + assert all(state1.degree(node) == 1 for node in state1.nodes()) + assert next(iter(state1.source_nodes())).data == "G" + assert next(iter(state1.sink_nodes())).data == "T" + + assert state2.number_of_nodes() == 5 + assert util.count_nodes(state2, dace_nodes.AccessNode) == 2 + assert util.count_nodes(state2, dace_nodes.MapEntry) == 1 diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py index 30266d71d1..89f067e5a9 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py @@ -24,11 +24,12 @@ def _get_trivial_gpu_promotable( tasklet_code: str, + trivial_map_range: str = "0", ) -> tuple[dace.SDFG, dace_nodes.MapEntry, dace_nodes.MapEntry]: - """Returns an SDFG that is suitable to test the `TrivialGPUMapPromoter` promoter. + """Returns an SDFG that is suitable to test the `TrivialGPUMapElimination` promoter. The first map is a trivial map (`Map[__trival_gpu_it=0]`) containing a Tasklet, - that does not have an output, but writes a scalar value into `tmp` (output + that does not have an input, but writes a scalar value into `tmp` (output connector `__out`), the body of this Tasklet can be controlled through the `tasklet_code` argument. The second map (`Map[__i0=0:N]`) contains a Tasklet that computes the sum of its @@ -41,6 +42,7 @@ def _get_trivial_gpu_promotable( Args: tasklet_code: The body of the Tasklet inside the trivial map. + trivial_map_range: Range of the trivial map, defaults to `"0"`. """ sdfg = dace.SDFG(util.unique_name("gpu_promotable_sdfg")) state = sdfg.add_state("state", is_start_block=True) @@ -57,11 +59,11 @@ def _get_trivial_gpu_promotable( _, trivial_map_entry, _ = state.add_mapped_tasklet( "trivail_top_tasklet", - map_ranges={"__trivial_gpu_it": "0"}, + map_ranges={"__trivial_gpu_it": trivial_map_range}, inputs={}, code=tasklet_code, outputs={"__out": dace.Memlet("tmp[0]")}, - output_nodes={"tmp": tmp}, + output_nodes={tmp}, external_edges=True, schedule=schedule, ) @@ -74,15 +76,15 @@ def _get_trivial_gpu_promotable( }, code="__out = __in0 + __in1", outputs={"__out": dace.Memlet("b[__i0]")}, - input_nodes={"a": a, "tmp": tmp}, - output_nodes={"b": b}, + input_nodes={a, tmp}, + output_nodes={b}, external_edges=True, schedule=schedule, ) return sdfg, trivial_map_entry, second_map_entry -def test_trivial_gpu_map_promoter(): +def test_trivial_gpu_map_promoter_1(): """Tests if the GPU map promoter works. By using a body such as `__out = 3.0`, the transformation will apply. @@ -92,15 +94,15 @@ def test_trivial_gpu_map_promoter(): org_second_map_ranges = copy.deepcopy(second_map_entry.map.range) nb_runs = sdfg.apply_transformations_once_everywhere( - gtx_dace_fieldview_gpu_utils.TrivialGPUMapPromoter(), + gtx_dace_fieldview_gpu_utils.TrivialGPUMapElimination(do_not_fuse=True), validate=True, validate_all=True, ) assert ( nb_runs == 1 - ), f"Expected that 'TrivialGPUMapPromoter' applies once but it applied {nb_runs}." + ), f"Expected that 'TrivialGPUMapElimination' applies once but it applied {nb_runs}." trivial_map_params = trivial_map_entry.map.params - trivial_map_ranges = trivial_map_ranges.map.range + trivial_map_ranges = trivial_map_entry.map.range second_map_params = second_map_entry.map.params second_map_ranges = second_map_entry.map.range @@ -119,32 +121,82 @@ def test_trivial_gpu_map_promoter(): assert sdfg.is_valid() -def test_trivial_gpu_map_promoter(): +def test_trivial_gpu_map_promoter_2(): """Test if the GPU promoter does not fuse a special trivial map. By using a body such as `__out = __trivial_gpu_it` inside the - Tasklet's body, the map parameter is now used, and thus can not be fused. + Tasklet's body, the map parameter must now be replaced inside + the Tasklet's body. """ sdfg, trivial_map_entry, second_map_entry = _get_trivial_gpu_promotable( - "__out = __trivial_gpu_it" + tasklet_code="__out = __trivial_gpu_it", + trivial_map_range="2", + ) + state: dace.SDFGStae = sdfg.nodes()[0] + trivial_tasklet: dace_nodes.Tasklet = next( + iter( + out_edge.dst + for out_edge in state.out_edges(trivial_map_entry) + if isinstance(out_edge.dst, dace_nodes.Tasklet) + ) ) - org_trivial_map_params = list(trivial_map_entry.map.params) - org_second_map_params = list(second_map_entry.map.params) nb_runs = sdfg.apply_transformations_once_everywhere( - gtx_dace_fieldview_gpu_utils.TrivialGPUMapPromoter(), + gtx_dace_fieldview_gpu_utils.TrivialGPUMapElimination(do_not_fuse=True), validate=True, validate_all=True, ) - assert ( - nb_runs == 0 - ), f"Expected that 'TrivialGPUMapPromoter' does not apply but it applied {nb_runs}." - trivial_map_params = trivial_map_entry.map.params - second_map_params = second_map_entry.map.params - assert ( - trivial_map_params == org_trivial_map_params - ), f"Expected the trivial map to have parameters '{org_trivial_map_params}', but it had '{trivial_map_params}'." - assert ( - second_map_params == org_second_map_params - ), f"Expected the trivial map to have parameters '{org_trivial_map_params}', but it had '{trivial_map_params}'." - assert sdfg.is_valid() + assert nb_runs == 1 + + expected_trivial_code = "__out = 2" + assert trivial_tasklet.code == expected_trivial_code + + +def test_set_gpu_properties(): + """Tests the `gtx_dace_fieldview_gpu_utils.gt_set_gpu_blocksize()`.""" + sdfg = dace.SDFG("gpu_properties_test") + state = sdfg.add_state(is_start_block=True) + + map_entries: dict[int, dace_nodes.MapEntry] = {} + for dim in [1, 2, 3]: + shape = (10,) * dim + sdfg.add_array( + f"A_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global + ) + sdfg.add_array( + f"B_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global + ) + _, me, _ = state.add_mapped_tasklet( + f"map_{dim}", + map_ranges={f"__i{i}": f"0:{s}" for i, s in enumerate(shape)}, + inputs={"__in": dace.Memlet(f"A_{dim}[{','.join(f'__i{i}' for i in range(dim))}]")}, + code="__out = math.cos(__in)", + outputs={"__out": dace.Memlet(f"B_{dim}[{','.join(f'__i{i}' for i in range(dim))}]")}, + external_edges=True, + ) + map_entries[dim] = me + + sdfg.apply_gpu_transformations() + sdfg.validate() + + gtx_dace_fieldview_gpu_utils.gt_set_gpu_blocksize( + sdfg=sdfg, + block_size=(10, "11", 12), + launch_factor_2d=2, + block_size_2d=(2, 2, 2), + launch_bounds_3d=200, + ) + + map1, map2, map3 = (map_entries[d].map for d in [1, 2, 3]) + + assert len(map1.params) == 1 + assert map1.gpu_block_size == [10, 1, 1] + assert map1.gpu_launch_bounds == "0" + + assert len(map2.params) == 2 + assert map2.gpu_block_size == [2, 2, 1] + assert map2.gpu_launch_bounds == "8" + + assert len(map3.params) == 3 + assert map3.gpu_block_size == [10, 11, 12] + assert map3.gpu_launch_bounds == "200" diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py index c1e0ddd2f6..67bec9c09f 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py @@ -29,7 +29,7 @@ def _get_simple_sdfg() -> tuple[dace.SDFG, Callable[[np.ndarray, np.ndarray], np The k blocking transformation can be applied to the SDFG, however no node can be taken out. This is because how it is constructed. However, applying - some simplistic transformations this can be done. + some simplistic transformations will enable the transformation. """ sdfg = dace.SDFG(util.unique_name("simple_block_sdfg")) state = sdfg.add_state("state", is_start_block=True) @@ -136,6 +136,83 @@ def _get_chained_sdfg() -> tuple[dace.SDFG, Callable[[np.ndarray, np.ndarray], n return sdfg, lambda a, b: (a + (2 * b.reshape((-1, 1)) + 3)) +def _get_sdfg_with_empty_memlet( + first_tasklet_independent: bool, + only_empty_memlets: bool, +) -> tuple[ + dace.SDFG, dace_nodes.MapEntry, dace_nodes.Tasklet, dace_nodes.AccessNode, dace_nodes.Tasklet +]: + """Generates an SDFG with an empty tasklet. + + The map contains two (serial) tasklets, connected through an access node. + The first tasklet has an empty memlet that connects it to the map entry. + Depending on `first_tasklet_independent` the tasklet is either independent + or not. The second tasklet has an additional in connector that accesses an array. + + If `only_empty_memlets` is given then the second memlet will only depend + on the input of the first tasklet. However, since it is connected to the + map exit, it will be classified as dependent. + + Returns: + The function returns the SDFG, the map entry and the first tasklet (that + is either dependent or independent), the access node between the tasklets + and the second tasklet that is always dependent. + """ + sdfg = dace.SDFG(util.unique_name("empty_memlet_sdfg")) + state = sdfg.add_state("state", is_start_block=True) + sdfg.add_symbol("N", dace.int32) + sdfg.add_symbol("M", dace.int32) + sdfg.add_array("b", ("N", "M"), dace.float64, transient=False) + b = state.add_access("b") + sdfg.add_scalar("tmp", dtype=dace.float64, transient=True) + tmp = state.add_access("tmp") + + if not only_empty_memlets: + sdfg.add_array("a", ("N", "M"), dace.float64, transient=False) + a = state.add_access("a") + + # This is the first tasklet. + task1 = state.add_tasklet( + "task1", + inputs={}, + outputs={"__out0"}, + code="__out0 = 1.0" if first_tasklet_independent else "__out0 = j", + ) + + if only_empty_memlets: + task2 = state.add_tasklet( + "task2", inputs={"__in0"}, outputs={"__out0"}, code="__out0 = __in0 + 1.0" + ) + else: + task2 = state.add_tasklet( + "task2", inputs={"__in0", "__in1"}, outputs={"__out0"}, code="__out0 = __in0 + __in1" + ) + + # Now create the map + mentry, mexit = state.add_map("map", ndrange={"i": "0:N", "j": "0:M"}) + + if not only_empty_memlets: + state.add_edge(a, None, mentry, "IN_a", dace.Memlet("a[0:N, 0:M]")) + state.add_edge(mentry, "OUT_a", task2, "__in1", dace.Memlet("a[i, j]")) + + state.add_edge(task2, "__out0", mexit, "IN_b", dace.Memlet("b[i, j]")) + state.add_edge(mexit, "OUT_b", b, None, dace.Memlet("b[0:N, 0:M]")) + + state.add_edge(mentry, None, task1, None, dace.Memlet()) + state.add_edge(task1, "__out0", tmp, None, dace.Memlet("tmp[0]")) + state.add_edge(tmp, None, task2, "__in0", dace.Memlet("tmp[0]")) + + if not only_empty_memlets: + mentry.add_in_connector("IN_a") + mentry.add_out_connector("OUT_a") + mexit.add_in_connector("IN_b") + mexit.add_out_connector("OUT_b") + + sdfg.validate() + + return sdfg, mentry, task1, tmp, task2 + + def test_only_dependent(): """Just applying the transformation to the SDFG. @@ -152,11 +229,12 @@ def test_only_dependent(): ref = reff(a, b) # Apply the transformation - sdfg.apply_transformations_repeated( + count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking(blocking_size=10, blocking_parameter="j"), validate=True, validate_all=True, ) + assert count == 1 assert len(sdfg.states()) == 1 state = sdfg.states()[0] @@ -216,11 +294,12 @@ def test_intermediate_access_node(): assert np.allclose(ref, c) # Apply the transformation. - sdfg.apply_transformations_repeated( + count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking(blocking_size=10, blocking_parameter="j"), validate=True, validate_all=True, ) + assert count == 1 # Inspect if the SDFG was modified correctly. # We only inspect `tmp` which now has to be between the two maps. @@ -254,12 +333,12 @@ def test_chained_access() -> None: c[:] = 0 # Apply the transformation. - ret = sdfg.apply_transformations_repeated( + count = sdfg.apply_transformations_repeated( gtx_transformations.LoopBlocking(blocking_size=10, blocking_parameter="j"), validate=True, validate_all=True, ) - assert ret == 1, f"Expected that the transformation was applied 1 time, but it was {ret}." + assert count == 1 # Now run the SDFG to see if it is still the same sdfg(a=a, b=b, c=c, M=M, N=N) @@ -305,3 +384,471 @@ def test_chained_access() -> None: assert isinstance(inner_tasklet, dace_nodes.Tasklet) assert inner_tasklet not in first_level_tasklets + + +def test_direct_map_exit_connection() -> dace.SDFG: + """Generates a SDFG with a mapped independent tasklet connected to the map exit. + + Because the tasklet is connected to the map exit it can not be independent. + """ + sdfg = dace.SDFG(util.unique_name("mapped_tasklet_sdfg")) + state = sdfg.add_state("state", is_start_block=True) + sdfg.add_array("a", (10,), dace.float64, transient=False) + sdfg.add_array("b", (10, 30), dace.float64, transient=False) + tsklt, me, mx = state.add_mapped_tasklet( + name="comp", + map_ranges=dict(i=f"0:10", j=f"0:30"), + inputs=dict(__in0=dace.Memlet("a[i]")), + outputs=dict(__out=dace.Memlet("b[i, j]")), + code="__out = __in0 + 1", + external_edges=True, + ) + + assert all(out_edge.dst is tsklt for out_edge in state.out_edges(me)) + assert all(in_edge.src is tsklt for in_edge in state.in_edges(mx)) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="j"), + validate=True, + validate_all=True, + ) + assert count == 1 + + assert all(isinstance(out_edge.dst, dace_nodes.MapEntry) for out_edge in state.out_edges(me)) + assert all(isinstance(in_edge.src, dace_nodes.MapExit) for in_edge in state.in_edges(mx)) + + +def test_empty_memlet_1(): + sdfg, mentry, itask, tmp, task2 = _get_sdfg_with_empty_memlet( + first_tasklet_independent=True, + only_empty_memlets=False, + ) + state: dace.SDFGState = next(iter(sdfg.nodes())) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="j"), + validate=True, + validate_all=True, + ) + assert count == 1 + + scope_dict = state.scope_dict() + assert scope_dict[mentry] is None + assert scope_dict[itask] is mentry + assert scope_dict[tmp] is mentry + assert scope_dict[task2] is not mentry + assert scope_dict[task2] is not None + assert all( + isinstance(in_edge.src, dace_nodes.MapEntry) and in_edge.src is not mentry + for in_edge in state.in_edges(task2) + ) + + +def test_empty_memlet_2(): + sdfg, mentry, dtask, tmp, task2 = _get_sdfg_with_empty_memlet( + first_tasklet_independent=False, + only_empty_memlets=False, + ) + state: dace.SDFGState = next(iter(sdfg.nodes())) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="j"), + validate=True, + validate_all=True, + ) + assert count == 1 + + # Find the inner map entry + assert all( + isinstance(out_edge.dst, dace_nodes.MapEntry) for out_edge in state.out_edges(mentry) + ) + inner_mentry = next(iter(state.out_edges(mentry))).dst + + scope_dict = state.scope_dict() + assert scope_dict[mentry] is None + assert scope_dict[inner_mentry] is mentry + assert scope_dict[dtask] is inner_mentry + assert scope_dict[tmp] is inner_mentry + assert scope_dict[task2] is inner_mentry + + +def test_empty_memlet_3(): + # This is the only interesting case with only empty memlet. + sdfg, mentry, dtask, tmp, task2 = _get_sdfg_with_empty_memlet( + first_tasklet_independent=False, + only_empty_memlets=True, + ) + state: dace.SDFGState = next(iter(sdfg.nodes())) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="j"), + validate=True, + validate_all=True, + ) + assert count == 1 + + # The top map only has a single output, which is the empty edge, that is holding + # the inner map entry in the scope. + assert all(out_edge.data.is_empty() for out_edge in state.out_edges(mentry)) + assert state.in_degree(mentry) == 0 + assert state.out_degree(mentry) == 1 + assert all( + isinstance(out_edge.dst, dace_nodes.MapEntry) for out_edge in state.out_edges(mentry) + ) + + inner_mentry = next(iter(state.out_edges(mentry))).dst + + scope_dict = state.scope_dict() + assert scope_dict[mentry] is None + assert scope_dict[inner_mentry] is mentry + assert scope_dict[dtask] is inner_mentry + assert scope_dict[tmp] is inner_mentry + assert scope_dict[task2] is inner_mentry + + +def _make_loop_blocking_sdfg_with_inner_map( + add_independent_part: bool, +) -> tuple[dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.MapEntry]: + """ + Generate the SDFGs with an inner map. + + The SDFG has an inner map that is classified as dependent. If + `add_independent_part` is `True` then the SDFG has a part that is independent. + Note that everything is read from a single connector. + + Return: + The function will return the SDFG, the state and the map entry for the outer + and inner map. + """ + sdfg = dace.SDFG(util.unique_name("sdfg_with_inner_map")) + state = sdfg.add_state(is_start_block=True) + + for name in "AB": + sdfg.add_array(name, shape=(10, 10), dtype=dace.float64, transient=False) + + me_out, mx_out = state.add_map("outer_map", ndrange={"__i0": "0:10"}) + me_in, mx_in = state.add_map("inner_map", ndrange={"__i1": "0:10"}) + A, B = (state.add_access(name) for name in "AB") + tskl = state.add_tasklet( + "computation", inputs={"__in1", "__in2"}, outputs={"__out"}, code="__out = __in1 + __in2" + ) + + if add_independent_part: + sdfg.add_array("C", shape=(10,), dtype=dace.float64, transient=False) + sdfg.add_scalar("tmp", dtype=dace.float64, transient=True) + sdfg.add_scalar("tmp2", dtype=dace.float64, transient=True) + tmp, tmp2, C = (state.add_access(name) for name in ("tmp", "tmp2", "C")) + tskli = state.add_tasklet( + "independent_comp", inputs={"__field"}, outputs={"__out"}, code="__out = __field[1, 1]" + ) + + # construct the inner map of the map. + state.add_edge(A, None, me_out, "IN_A", dace.Memlet("A[0:10, 0:10]")) + me_out.add_in_connector("IN_A") + state.add_edge(me_out, "OUT_A", me_in, "IN_A", dace.Memlet("A[__i0, 0:10]")) + me_out.add_out_connector("OUT_A") + me_in.add_in_connector("IN_A") + state.add_edge(me_in, "OUT_A", tskl, "__in1", dace.Memlet("A[__i0, __i1]")) + me_in.add_out_connector("OUT_A") + + state.add_edge(me_out, "OUT_A", me_in, "IN_A1", dace.Memlet("A[__i0, 0:10]")) + me_in.add_in_connector("IN_A1") + state.add_edge(me_in, "OUT_A1", tskl, "__in2", dace.Memlet("A[__i0, 9 - __i1]")) + me_in.add_out_connector("OUT_A1") + + state.add_edge(tskl, "__out", mx_in, "IN_B", dace.Memlet("B[__i0, __i1]")) + mx_in.add_in_connector("IN_B") + state.add_edge(mx_in, "OUT_B", mx_out, "IN_B", dace.Memlet("B[__i0, 0:10]")) + mx_in.add_out_connector("OUT_B") + mx_out.add_in_connector("IN_B") + state.add_edge(mx_out, "OUT_B", B, None, dace.Memlet("B[0:10, 0:10]")) + mx_out.add_out_connector("OUT_B") + + # If requested add a part that is independent, i.e. is before the inner loop + if add_independent_part: + state.add_edge(me_out, "OUT_A", tskli, "__field", dace.Memlet("A[0:10, 0:10]")) + state.add_edge(tskli, "__out", tmp, None, dace.Memlet("tmp[0]")) + state.add_edge(tmp, None, tmp2, None, dace.Memlet("tmp2[0]")) + state.add_edge(tmp2, None, mx_out, "IN_tmp", dace.Memlet("C[__i0]")) + mx_out.add_in_connector("IN_tmp") + state.add_edge(mx_out, "OUT_tmp", C, None, dace.Memlet("C[0:10]")) + mx_out.add_out_connector("OUT_tmp") + + sdfg.validate() + return sdfg, state, me_out, me_in + + +def test_loop_blocking_inner_map(): + """ + Tests with an inner map, without an independent part. + """ + sdfg, state, outer_map, inner_map = _make_loop_blocking_sdfg_with_inner_map(False) + assert all(oedge.dst is inner_map for oedge in state.out_edges(outer_map)) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="__i0"), + validate=True, + validate_all=True, + ) + assert count == 1 + assert all( + oedge.dst is not inner_map and isinstance(oedge.dst, dace_nodes.MapEntry) + for oedge in state.out_edges(outer_map) + ) + inner_blocking_map: dace_nodes.MapEntry = next( + oedge.dst + for oedge in state.out_edges(outer_map) + if isinstance(oedge.dst, dace_nodes.MapEntry) + ) + assert inner_blocking_map is not inner_map + + assert all(oedge.dst is inner_map for oedge in state.out_edges(inner_blocking_map)) + + +def test_loop_blocking_inner_map_with_independent_part(): + """ + Tests with an inner map with an independent part. + """ + sdfg, state, outer_map, inner_map = _make_loop_blocking_sdfg_with_inner_map(True) + + # Find the parts that are independent. + itskl: dace_nodes.Tasklet = next( + oedge.dst + for oedge in state.out_edges(outer_map) + if isinstance(oedge.dst, dace_nodes.Tasklet) + ) + assert itskl.label == "independent_comp" + i_access_node: dace_nodes.AccessNode = next(oedge.dst for oedge in state.out_edges(itskl)) + assert i_access_node.data == "tmp" + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=5, blocking_parameter="__i0"), + validate=True, + validate_all=True, + ) + assert count == 1 + inner_blocking_map: dace_nodes.MapEntry = next( + oedge.dst + for oedge in state.out_edges(outer_map) + if isinstance(oedge.dst, dace_nodes.MapEntry) + ) + assert inner_blocking_map is not inner_map + + assert all(oedge.dst in {inner_blocking_map, itskl} for oedge in state.out_edges(outer_map)) + assert state.scope_dict()[i_access_node] is outer_map + assert all(oedge.dst is inner_blocking_map for oedge in state.out_edges(i_access_node)) + + +def _make_mixed_memlet_sdfg( + tskl1_independent: bool, +) -> tuple[dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.Tasklet, dace_nodes.Tasklet]: + """ + Generates the SDFGs for the mixed Memlet tests. + + The SDFG that is generated has the following structure: + - `tsklt2`, is always dependent, it has an incoming connection from the + map entry, and an incoming, but empty, connection with `tskl1`. + - `tskl1` is connected to the map entry, depending on `tskl1_independent` + it is independent or dependent, it has an empty connection to `tskl2`, + thus it is sequenced before. + - Both have connection to other nodes down stream, but they are dependent. + + Returns: + A tuple containing the following objects. + - The SDFG. + - The SDFG state. + - The outer map entry node. + - `tskl1`. + - `tskl2`. + """ + sdfg = dace.SDFG(util.unique_name("mixed_memlet_sdfg")) + state = sdfg.add_state(is_start_block=True) + names_array = ["A", "B", "C"] + names_scalar = ["tmp1", "tmp2"] + for aname in names_array: + sdfg.add_array( + aname, + shape=((10,) if aname == "A" else (10, 10)), + dtype=dace.float64, + transient=False, + ) + for sname in names_scalar: + sdfg.add_scalar( + sname, + dtype=dace.float64, + transient=True, + ) + A, B, C, tmp1, tmp2 = (state.add_access(name) for name in names_array + names_scalar) + + me, mx = state.add_map("outer_map", ndrange={"i": "0:10", "j": "0:10"}) + tskl1 = state.add_tasklet( + "tskl1", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1" if tskl1_independent else "__out = __in1 + j", + ) + tskl2 = state.add_tasklet( + "tskl2", + inputs={"__in1"}, + outputs={"__out"}, + code="__out = __in1 + 10.0", + ) + tskl3 = state.add_tasklet( + "tskl3", + inputs={"__in1", "__in2"}, + outputs={"__out"}, + code="__out = __in1 + __in2", + ) + + state.add_edge(A, None, me, "IN_A", dace.Memlet("A[0:10]")) + me.add_in_connector("IN_A") + state.add_edge(me, "OUT_A", tskl1, "__in1", dace.Memlet("A[i]")) + me.add_out_connector("OUT_A") + state.add_edge(tskl1, "__out", tmp1, None, dace.Memlet("tmp1[0]")) + + state.add_edge(B, None, me, "IN_B", dace.Memlet("B[0:10, 0:10]")) + me.add_in_connector("IN_B") + state.add_edge(me, "OUT_B", tskl2, "__in1", dace.Memlet("B[i, j]")) + me.add_out_connector("OUT_B") + state.add_edge(tskl2, "__out", tmp2, None, dace.Memlet("tmp2[0]")) + + # Add the empty Memlet that sequences `tskl1` before `tskl2`. + state.add_edge(tskl1, None, tskl2, None, dace.Memlet()) + + state.add_edge(tmp1, None, tskl3, "__in1", dace.Memlet("tmp1[0]")) + state.add_edge(tmp2, None, tskl3, "__in2", dace.Memlet("tmp2[0]")) + state.add_edge(tskl3, "__out", mx, "IN_C", dace.Memlet("C[i, j]")) + mx.add_in_connector("IN_C") + state.add_edge(mx, "OUT_C", C, None, dace.Memlet("C[0:10, 0:10]")) + mx.add_out_connector("OUT_C") + sdfg.validate() + + return (sdfg, state, me, tskl1, tskl2) + + +def _apply_and_run_mixed_memlet_sdfg(sdfg: dace.SDFG) -> None: + ref = { + "A": np.array(np.random.rand(10), dtype=np.float64, copy=True), + "B": np.array(np.random.rand(10, 10), dtype=np.float64, copy=True), + "C": np.array(np.random.rand(10, 10), dtype=np.float64, copy=True), + } + res = copy.deepcopy(ref) + sdfg(**ref) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking(blocking_size=2, blocking_parameter="j"), + validate=True, + validate_all=True, + ) + assert count == 1, f"Expected one application, but git {count}" + sdfg(**res) + assert all(np.allclose(ref[name], res[name]) for name in ref) + + +def test_loop_blocking_mixked_memlets_1(): + sdfg, state, me, tskl1, tskl2 = _make_mixed_memlet_sdfg(True) + mx = state.exit_node(me) + + _apply_and_run_mixed_memlet_sdfg(sdfg) + scope_dict = state.scope_dict() + + # Ensure that `tskl1` is independent. + assert scope_dict[tskl1] is me + + # The output of `tskl1`, which is `tmp1` should also be classified as independent. + tmp1 = next(iter(edge.dst for edge in state.out_edges(tskl1) if not edge.data.is_empty())) + assert scope_dict[tmp1] is me + assert isinstance(tmp1, dace_nodes.AccessNode) + assert tmp1.data == "tmp1" + + # Find the inner map. + inner_map_entry: dace_nodes.MapEntry = scope_dict[tskl2] + assert inner_map_entry is not me and isinstance(inner_map_entry, dace_nodes.MapEntry) + inner_map_exit: dace_nodes.MapExit = state.exit_node(inner_map_entry) + + outer_scope = {tskl1, tmp1, inner_map_entry, inner_map_exit, mx} + for node in state.nodes(): + if scope_dict[node] is None: + assert (node is me) or ( + isinstance(node, dace_nodes.AccessNode) and node.data in {"A", "B", "C"} + ) + elif scope_dict[node] is me: + assert node in outer_scope + else: + assert ( + (node is inner_map_exit) + or (isinstance(node, dace_nodes.AccessNode) and node.data == "tmp2") + or (isinstance(node, dace_nodes.Tasklet) and node.label in {"tskl2", "tskl3"}) + ) + + +def test_loop_blocking_mixked_memlets_2(): + sdfg, state, me, tskl1, tskl2 = _make_mixed_memlet_sdfg(False) + mx = state.exit_node(me) + + _apply_and_run_mixed_memlet_sdfg(sdfg) + scope_dict = state.scope_dict() + + # Because `tskl1` is now dependent, everything is now dependent. + inner_map_entry = scope_dict[tskl1] + assert isinstance(inner_map_entry, dace_nodes.MapEntry) + assert inner_map_entry is not me + + for node in state.nodes(): + if scope_dict[node] is None: + assert (node is me) or ( + isinstance(node, dace_nodes.AccessNode) and node.data in {"A", "B", "C"} + ) + elif scope_dict[node] is me: + assert isinstance(node, dace_nodes.MapEntry) or (node is mx) + else: + assert scope_dict[node] is inner_map_entry + + +def test_loop_blocking_no_independent_nodes(): + import dace + + sdfg = dace.SDFG(util.unique_name("mixed_memlet_sdfg")) + state = sdfg.add_state(is_start_block=True) + names = ["A", "B"] + for aname in names: + sdfg.add_array( + aname, + shape=(10, 10), + dtype=dace.float64, + transient=False, + ) + state.add_mapped_tasklet( + "fully_dependent_computation", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={"__in1": dace.Memlet("A[__i0, __i1]")}, + code="__out = __in1 + 10.0", + outputs={"__out": dace.Memlet("B[__i0, __i1]")}, + external_edges=True, + ) + sdfg.validate() + + # Because there is nothing that is independent the transformation will + # not apply if `require_independent_nodes` is enabled. + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking( + blocking_size=2, + blocking_parameter="__i1", + require_independent_nodes=True, + ), + validate=True, + validate_all=True, + ) + assert count == 0 + + # But it will apply once this requirement is lifted. + count = sdfg.apply_transformations_repeated( + gtx_transformations.LoopBlocking( + blocking_size=2, + blocking_parameter="__i1", + require_independent_nodes=False, + ), + validate=True, + validate_all=True, + ) + assert count == 1 diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py new file mode 100644 index 0000000000..1a4ce6d047 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py @@ -0,0 +1,264 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +import numpy as np +import copy + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + +from . import util + +import dace + + +def _make_test_data(names: list[str]) -> dict[str, np.ndarray]: + return {name: np.array(np.random.rand(10), dtype=np.float64, copy=True) for name in names} + + +def _make_test_sdfg( + output_name: str = "G", + input_name: str = "G", + tmp_name: str = "T", + array_size: int | str = 10, + tmp_size: int | str | None = None, + map_range: tuple[int | str, int | str] | None = None, + tmp_to_glob_memlet: str | None = None, + in_offset: str | None = None, + out_offset: str | None = None, +) -> dace.SDFG: + if isinstance(array_size, str): + array_size = sdfg.add_symbol(array_size, dace.int32, find_new_name=True) + if tmp_size is None: + tmp_size = array_size + if map_range is None: + map_range = (0, array_size) + if tmp_to_glob_memlet is None: + tmp_to_glob_memlet = f"{tmp_name}[0:{array_size}] -> [0:{array_size}]" + elif tmp_to_glob_memlet[0] == "[": + tmp_to_glob_memlet = tmp_name + tmp_to_glob_memlet + if in_offset is None: + in_offset = "0" + if out_offset is None: + out_offset = in_offset + + sdfg = dace.SDFG(util.unique_name("map_buffer")) + state = sdfg.add_state(is_start_block=True) + names = {input_name, tmp_name, output_name} + for name in names: + sdfg.add_array( + name, + shape=((array_size,) if name != tmp_name else (tmp_size,)), + dtype=dace.float64, + transient=False, + ) + sdfg.arrays[tmp_name].transient = True + + input_ac = state.add_access(input_name) + tmp_ac = state.add_access(tmp_name) + output_ac = state.add_access(output_name) + + state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": f"{map_range[0]}:{map_range[1]}"}, + inputs={"__in1": dace.Memlet(data=input_ac.data, subset=f"__i0 + {in_offset}")}, + code="__out = __in1 + 10.0", + outputs={"__out": dace.Memlet(data=tmp_ac.data, subset=f"__i0 + {out_offset}")}, + input_nodes={input_ac}, + output_nodes={tmp_ac}, + external_edges=True, + ) + state.add_edge( + tmp_ac, + None, + output_ac, + None, + dace.Memlet(tmp_to_glob_memlet), + ) + sdfg.validate() + return sdfg + + +def _perform_test( + sdfg: dace.SDFG, + xform: gtx_transformations.GT4PyMapBufferElimination, + exp_count: int, + array_size: int = 10, +) -> None: + ref = { + name: np.array(np.random.rand(array_size), dtype=np.float64, copy=True) + for name, desc in sdfg.arrays.items() + if not desc.transient + } + if "array_size" in sdfg.symbols: + ref["array_size"] = array_size + + res = copy.deepcopy(ref) + sdfg(**ref) + + count = sdfg.apply_transformations_repeated([xform], validate=True, validate_all=True) + assert count == exp_count, f"Expected {exp_count} applications, but got {count}" + + if count == 0: + return + + sdfg(**res) + assert all(np.allclose(ref[name], res[name]) for name in ref.keys()), f"Failed for '{name}'." + + +def test_map_buffer_elimination_simple(): + sdfg = _make_test_sdfg() + _perform_test( + sdfg, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=True), + exp_count=1, + ) + + +def test_map_buffer_elimination_simple_2(): + sdfg = _make_test_sdfg() + _perform_test( + sdfg, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False), + exp_count=0, + ) + + +def test_map_buffer_elimination_simple_3(): + sdfg = _make_test_sdfg(input_name="A", output_name="O") + _perform_test( + sdfg, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False), + exp_count=1, + ) + + +def test_map_buffer_elimination_offset_1(): + sdfg = _make_test_sdfg( + map_range=(2, 8), + tmp_to_glob_memlet="[2:8] -> [2:8]", + input_name="A", + output_name="O", + ) + _perform_test( + sdfg, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False), + exp_count=1, + ) + + +def test_map_buffer_elimination_offset_2(): + sdfg = _make_test_sdfg( + map_range=(2, 8), + in_offset="-2", + out_offset="-2", + tmp_to_glob_memlet="[0:6] -> [0:6]", + input_name="A", + output_name="O", + ) + _perform_test( + sdfg, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False), + exp_count=1, + ) + + +def test_map_buffer_elimination_offset_3(): + sdfg = _make_test_sdfg( + map_range=(2, 8), + in_offset="-2", + out_offset="-2", + tmp_to_glob_memlet="[0:6] -> [2:8]", + input_name="A", + output_name="O", + ) + _perform_test( + sdfg, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False), + exp_count=1, + ) + + +def test_map_buffer_elimination_offset_4(): + sdfg = _make_test_sdfg( + map_range=(2, 8), + in_offset="-2", + out_offset="-2", + tmp_to_glob_memlet="[1:7] -> [2:8]", + input_name="A", + output_name="O", + ) + _perform_test( + sdfg, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False), + exp_count=0, + ) + + +def test_map_buffer_elimination_offset_5(): + sdfg = _make_test_sdfg( + map_range=(2, 8), + tmp_size=6, + in_offset="0", + out_offset="-2", + tmp_to_glob_memlet="[0:6] -> [2:8]", + input_name="A", + output_name="O", + ) + _perform_test( + sdfg, + gtx_transformations.GT4PyMapBufferElimination(assume_pointwise=False), + exp_count=1, + ) + + +def test_map_buffer_elimination_not_apply(): + """Indirect accessing, because of this the double buffer is needed.""" + sdfg = dace.SDFG(util.unique_name("map_buffer")) + state = sdfg.add_state(is_start_block=True) + + names = ["A", "tmp", "idx"] + for name in names: + sdfg.add_array( + name, + shape=(10,), + dtype=dace.int32 if name == "tmp" else dace.float64, + transient=False, + ) + sdfg.arrays["tmp"].transient = True + + tmp = state.add_access("tmp") + state.add_mapped_tasklet( + "indirect_accessing", + map_ranges={"__i0": "0:10"}, + inputs={ + "__field": dace.Memlet("A[0:10]"), + "__idx": dace.Memlet("idx[__i0]"), + }, + code="__out = __field[__idx]", + outputs={"__out": dace.Memlet("tmp[__i0]")}, + output_nodes={tmp}, + external_edges=True, + ) + state.add_nedge(tmp, state.add_access("A"), dace.Memlet("tmp[0:10] -> [0:10]")) + + # TODO(phimuell): Update the transformation such that we can specify + # `assume_pointwise=True` and the test would still pass. + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyMapBufferElimination( + assume_pointwise=False, + ), + validate=True, + validate_all=True, + ) + assert count == 0 diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py index c9d467ba80..b468b80b8e 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion.py @@ -58,14 +58,14 @@ def _make_serial_sdfg_1( inputs={"__in0": dace.Memlet("a[__i0, __i1]")}, code="__out = __in0 + 1.0", outputs={"__out": dace.Memlet("tmp[__i0, __i1]")}, - output_nodes={"tmp": tmp}, + output_nodes={tmp}, external_edges=True, ) state.add_mapped_tasklet( name="second_computation", map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], - input_nodes={"tmp": tmp}, + input_nodes={tmp}, inputs={"__in0": dace.Memlet("tmp[__i0, __i1]")}, code="__out = __in0 + 3.0", outputs={"__out": dace.Memlet("b[__i0, __i1]")}, @@ -118,17 +118,14 @@ def _make_serial_sdfg_2( "__out0": dace.Memlet("tmp_1[__i0, __i1]"), "__out1": dace.Memlet("tmp_2[__i0, __i1]"), }, - output_nodes={ - "tmp_1": tmp_1, - "tmp_2": tmp_2, - }, + output_nodes={tmp_1, tmp_2}, external_edges=True, ) state.add_mapped_tasklet( name="first_computation", map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], - input_nodes={"tmp_1": tmp_1}, + input_nodes={tmp_1}, inputs={"__in0": dace.Memlet("tmp_1[__i0, __i1]")}, code="__out = __in0 + 3.0", outputs={"__out": dace.Memlet("b[__i0, __i1]")}, @@ -137,7 +134,7 @@ def _make_serial_sdfg_2( state.add_mapped_tasklet( name="second_computation", map_ranges=[("__i0", f"0:{N}"), ("__i1", f"0:{N}")], - input_nodes={"tmp_2": tmp_2}, + input_nodes={tmp_2}, inputs={"__in0": dace.Memlet("tmp_2[__i0, __i1]")}, code="__out = __in0 - 3.0", outputs={"__out": dace.Memlet("c[__i0, __i1]")}, @@ -194,14 +191,14 @@ def _make_serial_sdfg_3( }, code="__out = __in0 + __in1", outputs={"__out": dace.Memlet("tmp[__i0]")}, - output_nodes={"tmp": tmp}, + output_nodes={tmp}, external_edges=True, ) state.add_mapped_tasklet( name="indirect_access", map_ranges=[("__i0", f"0:{N_output}")], - input_nodes={"tmp": tmp}, + input_nodes={tmp}, inputs={ "__index": dace.Memlet("idx[__i0]"), "__array": dace.Memlet.simple("tmp", subset_str=f"0:{N_input}", num_accesses=1), @@ -220,19 +217,19 @@ def test_exclusive_itermediate(): sdfg = _make_serial_sdfg_1(N) # Now apply the optimizations. - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 sdfg.apply_transformations( - gtx_transformations.SerialMapFusion(), + gtx_transformations.MapFusionSerial(), validate=True, validate_all=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 1 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 1 assert "tmp" not in sdfg.arrays # Test if the intermediate is a scalar intermediate_nodes: list[dace_nodes.Node] = [ node - for node in util._count_nodes(sdfg, dace_nodes.AccessNode, True) + for node in util.count_nodes(sdfg, dace_nodes.AccessNode, True) if node.data not in ["a", "b"] ] assert len(intermediate_nodes) == 1 @@ -257,19 +254,19 @@ def test_shared_itermediate(): sdfg.arrays["tmp"].transient = False # Now apply the optimizations. - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 sdfg.apply_transformations( - gtx_transformations.SerialMapFusion(), + gtx_transformations.MapFusionSerial(), validate=True, validate_all=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 1 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 1 assert "tmp" in sdfg.arrays # Test if the intermediate is a scalar intermediate_nodes: list[dace_nodes.Node] = [ node - for node in util._count_nodes(sdfg, dace_nodes.AccessNode, True) + for node in util.count_nodes(sdfg, dace_nodes.AccessNode, True) if node.data not in ["a", "b", "tmp"] ] assert len(intermediate_nodes) == 1 @@ -291,21 +288,21 @@ def test_pure_output_node(): """Tests the path of a pure intermediate.""" N = 10 sdfg = _make_serial_sdfg_2(N) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 3 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 3 # The first fusion will only bring it down to two maps. sdfg.apply_transformations( - gtx_transformations.SerialMapFusion(), + gtx_transformations.MapFusionSerial(), validate=True, validate_all=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 sdfg.apply_transformations( - gtx_transformations.SerialMapFusion(), + gtx_transformations.MapFusionSerial(), validate=True, validate_all=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 1 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 1 a = np.random.rand(N, N) b = np.empty_like(a) @@ -327,17 +324,17 @@ def test_array_intermediate(): """ N = 10 sdfg = _make_serial_sdfg_1(N) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 sdfg.apply_transformations_repeated([dace_dataflow.MapExpansion]) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 4 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 4 # Now perform the fusion sdfg.apply_transformations( - gtx_transformations.SerialMapFusion(only_toplevel_maps=True), + gtx_transformations.MapFusionSerial(only_toplevel_maps=True), validate=True, validate_all=True, ) - map_entries = util._count_nodes(sdfg, dace_nodes.MapEntry, return_nodes=True) + map_entries = util.count_nodes(sdfg, dace_nodes.MapEntry, return_nodes=True) scope = next(iter(sdfg.states())).scope_dict() assert len(map_entries) == 3 @@ -349,7 +346,7 @@ def test_array_intermediate(): # Find the access node that is the new intermediate node. inner_access_nodes: list[dace_nodes.AccessNode] = [ node - for node in util._count_nodes(sdfg, dace_nodes.AccessNode, True) + for node in util.count_nodes(sdfg, dace_nodes.AccessNode, True) if scope[node] is not None ] assert len(inner_access_nodes) == 1 @@ -374,7 +371,7 @@ def test_interstate_transient(): """ N = 10 sdfg = _make_serial_sdfg_2(N) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 3 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 3 assert sdfg.number_of_nodes() == 1 # Now add the new state and the new output. @@ -393,15 +390,15 @@ def test_interstate_transient(): # Now apply the transformation sdfg.apply_transformations_repeated( - gtx_transformations.SerialMapFusion(), + gtx_transformations.MapFusionSerial(), validate=True, validate_all=True, ) assert "tmp_1" in sdfg.arrays assert "tmp_2" not in sdfg.arrays assert sdfg.number_of_nodes() == 2 - assert util._count_nodes(head_state, dace_nodes.MapEntry) == 1 - assert util._count_nodes(new_state, dace_nodes.MapEntry) == 1 + assert util.count_nodes(head_state, dace_nodes.MapEntry) == 1 + assert util.count_nodes(new_state, dace_nodes.MapEntry) == 1 a = np.random.rand(N, N) b = np.empty_like(a) @@ -430,7 +427,7 @@ def test_indirect_access(): c = np.empty(N_output) idx = np.random.randint(low=0, high=N_input, size=N_output, dtype=np.int32) sdfg = _make_serial_sdfg_3(N_input=N_input, N_output=N_output) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 def _ref(a, b, idx): tmp = a + b @@ -443,11 +440,11 @@ def _ref(a, b, idx): # Now "apply" the transformation sdfg.apply_transformations_repeated( - gtx_transformations.SerialMapFusion(), + gtx_transformations.MapFusionSerial(), validate=True, validate_all=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 c[:] = -1.0 sdfg(a=a, b=b, idx=idx, c=c) @@ -455,5 +452,58 @@ def _ref(a, b, idx): def test_indirect_access_2(): - # TODO(phimuell): Index should be computed and that map should be fusable. - pass + """Indirect accesses, with non point wise input dependencies. + + Because `a` is used as input and output and `a` is indirectly accessed + the access to `a` can not be point wise so, fusing is not possible. + """ + sdfg = dace.SDFG(util.unique_name("indirect_access_sdfg_2")) + state = sdfg.add_state(is_start_block=True) + + names = ["a", "b", "idx", "tmp"] + + for name in names: + sdfg.add_array( + name=name, + shape=(10,), + dtype=dace.int32 if name == "idx" else dace.float64, + transient=False, + ) + sdfg.arrays["tmp"].transient = True + + a_in, b, idx, tmp, a_out = (state.add_access(name) for name in (names + ["a"])) + + state.add_mapped_tasklet( + "indirect_access", + map_ranges={"__i0": "0:10"}, + inputs={ + "__idx": dace.Memlet("idx[__i0]"), + "__field": dace.Memlet("a[0:10]", volume=1), + }, + code="__out = __field[__idx]", + outputs={"__out": dace.Memlet("tmp[__i0]")}, + input_nodes={a_in, idx}, + output_nodes={tmp}, + external_edges=True, + ) + state.add_mapped_tasklet( + "computation", + map_ranges={"__i0": "0:10"}, + inputs={ + "__in1": dace.Memlet("tmp[__i0]"), + "__in2": dace.Memlet("b[__i0]"), + }, + code="__out = __in1 + __in2", + outputs={"__out": dace.Memlet("a[__i0]")}, + input_nodes={tmp, b}, + output_nodes={a_out}, + external_edges=True, + ) + sdfg.validate() + + count = sdfg.apply_transformations_repeated( + gtx_transformations.MapFusionSerial(), + validate=True, + validate_all=True, + ) + assert count == 0 diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py new file mode 100644 index 0000000000..72efc2fe34 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py @@ -0,0 +1,100 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +import numpy as np + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + +from . import util + + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes + + +def _perform_reorder_test( + sdfg: dace.SDFG, + leading_dim: list[str], + expected_order: list[str], +) -> None: + """Performs the reorder transformation and test it. + + If `expected_order` is the empty list, then the transformation should not apply. + """ + map_entries: list[dace.nodes.MapEntry] = util.count_nodes(sdfg, dace.nodes.MapEntry, True) + assert len(map_entries) == 1 + map_entry: dace.nodes.MapEntry = map_entries[0] + old_map_params = map_entry.map.params.copy() + + apply_count = sdfg.apply_transformations_repeated( + gtx_transformations.MapIterationOrder( + leading_dims=leading_dim, + ), + validate=True, + validate_all=True, + ) + new_map_params = map_entry.map.params.copy() + + if len(expected_order) == 0: + assert ( + apply_count == 0 + ), f"Expected that the transformation was not applied. New map order: {map_entry.map.params}" + return + else: + assert ( + apply_count > 0 + ), f"Expected that the transformation was applied. Old map order: {map_entry.map.params}; Expected order: {expected_order}" + assert len(expected_order) == len(new_map_params) + + assert ( + expected_order == new_map_params + ), f"Expected map order {expected_order} but got {new_map_params} instead." + + +def _make_test_sdfg(map_params: list[str]) -> dace.SDFG: + """Generate an SDFG for the test.""" + sdfg = dace.SDFG(util.unique_name("gpu_promotable_sdfg")) + state: dace.SDFGState = sdfg.add_state("state", is_start_block=True) + dim = len(map_params) + for aname in ["a", "b"]: + sdfg.add_array(aname, shape=((4,) * dim), dtype=dace.float64, transient=False) + + state.add_mapped_tasklet( + "mapped_tasklet", + map_ranges=[(map_param, "0:4") for map_param in map_params], + inputs={"__in": dace.Memlet("a[" + ",".join(map_params) + "]")}, + code="__out = __in + 1", + outputs={"__out": dace.Memlet("b[" + ",".join(map_params) + "]")}, + external_edges=True, + ) + sdfg.validate() + + return sdfg + + +def test_map_order_1(): + sdfg = _make_test_sdfg(["EDim", "KDim", "VDim"]) + _perform_reorder_test(sdfg, ["EDim", "VDim"], ["KDim", "VDim", "EDim"]) + + +def test_map_order_2(): + sdfg = _make_test_sdfg(["VDim", "KDim"]) + _perform_reorder_test(sdfg, ["EDim", "VDim"], ["KDim", "VDim"]) + + +def test_map_order_3(): + sdfg = _make_test_sdfg(["EDim", "KDim"]) + _perform_reorder_test(sdfg, ["EDim", "VDim"], ["KDim", "EDim"]) + + +def test_map_order_4(): + sdfg = _make_test_sdfg(["CDim", "KDim"]) + _perform_reorder_test(sdfg, ["EDim", "VDim"], []) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py new file mode 100644 index 0000000000..7b39bc4e1d --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py @@ -0,0 +1,164 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import numpy as np +import pytest + + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes, propagation as dace_propagation +from dace.transformation import dataflow as dace_dataflow + +from gt4py.next.program_processors.runners.dace_fieldview import ( + transformations as gtx_transformations, +) + +from . import util + +import dace + + +def _make_movable_tasklet( + outer_tasklet_code: str, +) -> tuple[ + dace.SDFG, dace.SDFGState, dace_nodes.Tasklet, dace_nodes.AccessNode, dace_nodes.MapEntry +]: + sdfg = dace.SDFG(util.unique_name("gpu_promotable_sdfg")) + state = sdfg.add_state("state", is_start_block=True) + + sdfg.add_scalar("outer_scalar", dtype=dace.float64, transient=True) + for name in "AB": + sdfg.add_array(name, shape=(10, 10), dtype=dace.float64, transient=False) + A, B, outer_scalar = (state.add_access(name) for name in ["A", "B", "outer_scalar"]) + + outer_tasklet = state.add_tasklet( + name="outer_tasklet", + inputs=set(), + outputs={"__out"}, + code=f"__out = {outer_tasklet_code}", + ) + state.add_edge(outer_tasklet, "__out", outer_scalar, None, dace.Memlet("outer_scalar[0]")) + + _, me, _ = state.add_mapped_tasklet( + "map", + map_ranges={"__i0": "0:10", "__i1": "0:10"}, + inputs={ + "__in0": dace.Memlet("A[__i0, __i1]"), + "__in1": dace.Memlet("outer_scalar[0]"), + }, + code="__out = __in0 + __in1", + outputs={"__out": dace.Memlet("B[__i0, __i1]")}, + external_edges=True, + input_nodes={outer_scalar, A}, + output_nodes={B}, + ) + sdfg.validate() + + return sdfg, state, outer_tasklet, outer_scalar, me + + +def test_move_tasklet_inside_trivial_memlet_tree(): + sdfg, state, outer_tasklet, outer_scalar, me = _make_movable_tasklet( + outer_tasklet_code="1.2", + ) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyMoveTaskletIntoMap, + validate_all=True, + ) + assert count == 1 + + A = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True) + B = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True) + ref = A + 1.2 + + csdfg = sdfg.compile() + csdfg(A=A, B=B) + assert np.allclose(B, ref) + + +def test_move_tasklet_inside_non_trivial_memlet_tree(): + sdfg, state, outer_tasklet, outer_scalar, me = _make_movable_tasklet( + outer_tasklet_code="1.2", + ) + # By expanding the maps, we the memlet tree is no longer trivial. + sdfg.apply_transformations_repeated(dace_dataflow.MapExpansion) + assert util.count_nodes(state, dace_nodes.MapEntry) == 2 + me = None + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyMoveTaskletIntoMap, + validate_all=True, + ) + assert count == 1 + + A = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True) + B = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True) + ref = A + 1.2 + + csdfg = sdfg.compile() + csdfg(A=A, B=B) + assert np.allclose(B, ref) + + +def test_move_tasklet_inside_two_inner_connector(): + sdfg, state, outer_tasklet, outer_scalar, me = _make_movable_tasklet( + outer_tasklet_code="32.2", + ) + mapped_tasklet = next( + iter(e.dst for e in state.out_edges(me) if isinstance(e.dst, dace_nodes.Tasklet)) + ) + + state.add_edge( + me, + f"OUT_{outer_scalar.data}", + mapped_tasklet, + "__in2", + dace.Memlet(f"{outer_scalar.data}[0]"), + ) + mapped_tasklet.add_in_connector("__in2") + mapped_tasklet.code.as_string = "__out = __in0 + __in1 + __in2" + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyMoveTaskletIntoMap, + validate_all=True, + ) + assert count == 1 + + A = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True) + B = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True) + ref = A + 2 * (32.2) + + csdfg = sdfg.compile() + csdfg(A=A, B=B) + assert np.allclose(B, ref) + + +def test_move_tasklet_inside_outer_scalar_used_outside(): + sdfg, state, outer_tasklet, outer_scalar, me = _make_movable_tasklet( + outer_tasklet_code="22.6", + ) + sdfg.add_array("C", shape=(1,), dtype=dace.float64, transient=False) + state.add_edge(outer_scalar, None, state.add_access("C"), None, dace.Memlet("C[0]")) + + count = sdfg.apply_transformations_repeated( + gtx_transformations.GT4PyMoveTaskletIntoMap, + validate_all=True, + ) + assert count == 1 + + A = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True) + B = np.array(np.random.rand(10, 10), dtype=np.float64, copy=True) + C = np.array(np.random.rand(1), dtype=np.float64, copy=True) + ref_C = 22.6 + ref_B = A + ref_C + + csdfg = sdfg.compile() + csdfg(A=A, B=B, C=C) + assert np.allclose(B, ref_B) + assert np.allclose(C, ref_C) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_serial_map_promoter.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_serial_map_promoter.py index 96584b8273..8626cb8e07 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_serial_map_promoter.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_serial_map_promoter.py @@ -68,7 +68,7 @@ def test_serial_map_promotion(): external_edges=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 assert len(map_entry_1d.map.params) == 1 assert len(map_entry_1d.map.range) == 1 assert len(map_entry_2d.map.params) == 2 @@ -83,7 +83,7 @@ def test_serial_map_promotion(): validate_all=True, ) - assert util._count_nodes(sdfg, dace_nodes.MapEntry) == 2 + assert util.count_nodes(sdfg, dace_nodes.MapEntry) == 2 assert len(map_entry_1d.map.params) == 2 assert len(map_entry_1d.map.range) == 2 assert len(map_entry_2d.map.params) == 2 diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py index ac88f4fef8..b82cecee98 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py @@ -14,7 +14,7 @@ @overload -def _count_nodes( +def count_nodes( graph: Union[dace.SDFG, dace.SDFGState], node_type: tuple[type, ...] | type, return_nodes: Literal[False], @@ -22,14 +22,14 @@ def _count_nodes( @overload -def _count_nodes( +def count_nodes( graph: Union[dace.SDFG, dace.SDFGState], node_type: tuple[type, ...] | type, return_nodes: Literal[True], ) -> list[dace_nodes.Node]: ... -def _count_nodes( +def count_nodes( graph: Union[dace.SDFG, dace.SDFGState], node_type: tuple[type, ...] | type, return_nodes: bool = False,