diff --git a/.flake8 b/.flake8 index 1caa5971df..1648ecdd38 100644 --- a/.flake8 +++ b/.flake8 @@ -3,25 +3,31 @@ max-line-length = 100 max-complexity = 15 doctests = true -ignore = - B008 # Do not perform function calls in argument defaults - D1 # Public code object needs docstring - DAR # Disable dargling errors by default - E203 # Whitespace before ':' (black formatter breaks this sometimes) - E501 # Line too long (using Bugbear's B950 warning) - W503 # Line break occurred before a binary operator +extend-ignore = + # Do not perform function calls in argument defaults + B008, + # Public code object needs docstring + D1, + # Disable dargling errors by default + DAR, + # Whitespace before ':' (black formatter breaks this sometimes) + E203, + # Line too long (using Bugbear's B950 warning) + E501, + # Line break occurred before a binary operator + W503 exclude = - .eggs - .gt_cache - .ipynb_checkpoints - .tox - _local_ - build - dist - docs - _external_src - tests/_disabled + .eggs, + .gt_cache, + .ipynb_checkpoints, + .tox, + _local_, + build, + dist, + docs, + _external_src, + tests/_disabled, setup.py rst-roles = diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7398abdc68..d38e84f076 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,6 +37,7 @@ repos: # args: [--remove] - id: name-tests-test args: [--pytest-test-first] + exclude: ^liskov/tests/samples - repo: https://gitlab.com/bmares/check-json5 rev: v1.0.0 diff --git a/liskov/README.md b/liskov/README.md new file mode 100644 index 0000000000..ef517a1177 --- /dev/null +++ b/liskov/README.md @@ -0,0 +1,126 @@ +# icon4py-liskov + +A preprocessor that facilitates integration of gt4py code into the ICON model. + +## Installation + +To install the icon4py-liskov package, follow the instructions in the `README.md` file located in the root of the repository. 
+ +## Description + +The icon4py-liskov package includes the `icon_liskov` CLI tool which takes a fortran file as input and processes it with the ICON-Liskov DSL Preprocessor. This preprocessor adds the necessary `USE` statements and generates OpenACC `DATA CREATE` statements and declares DSL input/output fields based on directives in the input file. The preprocessor also processes stencils defined in the input file using the `START STENCIL` and `END STENCIL` directives, inserting the necessary code to run the stencils and adding nvtx profile statements if specified with the `--profile` flag. + +### Usage + +To use the `icon_liskov` tool, run the following command: + +```bash +icon_liskov <input_filepath> <output_filepath> [--profile] +``` + +Where `input_filepath` is the path to the input file to be processed, and `output_filepath` is the path to the output file. The optional `--profile` flag adds nvtx profile statements to the stencils. + +### Preprocessor directives + +The ICON-Liskov DSL Preprocessor supports the following directives: + +#### `!$DSL IMPORTS()` + +This directive generates the necessary `USE` statements to import the Fortran to C interfaces. + +#### `!$DSL START CREATE()` + +This directive generates an OpenACC `DATA CREATE` statement for all output fields used in each DSL (icon4py) stencil. + +#### `!$DSL END CREATE()` + +This directive generates an OpenACC `END DATA` statement which is necessary to close the OpenACC data region. + +#### `!$DSL DECLARE()` + +This directive is used to declare all DSL input/output fields. The required arguments are the field name and its associated dimensions. For example: + +```fortran +!$DSL DECLARE(vn=(nproma, p_patch%nlev, p_patch%nblks_e)) +``` + +will generate the following code: + +```fortran +! DSL INPUT / OUTPUT FIELDS +REAL(wp), DIMENSION((nproma, p_patch%nlev, p_patch%nblks_e)) :: vn_before +``` + +Furthermore, this directive also takes two optional keyword arguments.
`type` takes a string which will be used to fill in the type of the declared field, for example `type=LOGICAL`. `suffix` takes a string which will be used as the suffix of the field, e.g. `suffix=dsl`; by default the suffix is `before`. + +#### `!$DSL START STENCIL()` + +This directive denotes the start of a stencil. Required arguments are `name`, `vertical_lower`, `vertical_upper`, `horizontal_lower`, `horizontal_upper`. The value for `name` must correspond to a stencil found in one of the stencil modules inside `icon4py`, and all fields defined in the directive must correspond to the fields defined in the respective icon4py stencil. Optionally, absolute and relative tolerances for the output fields can also be set using the `_abs_tol` or `_rel_tol` suffixes respectively. An example call looks like this: + +```fortran +!$DSL START STENCIL(name=mo_nh_diffusion_stencil_06; & +!$DSL z_nabla2_e=z_nabla2_e(:,:,1); area_edge=p_patch%edges%area_edge(:,1); & +!$DSL fac_bdydiff_v=fac_bdydiff_v; vn=p_nh_prog%vn(:,:,1); vn_abs_tol=1e-21_wp; & +!$DSL vertical_lower=1; vertical_upper=nlev; & +!$DSL horizontal_lower=i_startidx; horizontal_upper=i_endidx) +``` + +In addition, other optional keyword arguments are the following: + +- `accpresent`: Takes a boolean string input, and controls the default data-sharing behavior for variables used in the OpenACC parallel region. Setting the flag to true will cause all variables to be assumed present on the device by default (`DEFAULT(PRESENT)`), and no explicit data-sharing attributes need to be specified. Setting it to false will require explicit data-sharing attributes for every variable used in the parallel region (`DEFAULT(NONE)`). By default it is set to false.

+ +- `mergecopy`: Takes a boolean string input. When set to true, the before-field copy regions of consecutive stencils that both have the `mergecopy` flag set to true are combined into a single before-field copy region, whose name is created by concatenating the names of the merged stencil regions. This is useful when there are consecutive stencils. By default it is set to false.

+ +- `copies`: Takes a boolean string input, and controls whether before field copies should be made or not. If set to False only the `#ifdef __DSL_VERIFY` directive is generated. Defaults to true.

+ +#### `!$DSL END STENCIL()` + +This directive denotes the end of a stencil. The required argument is `name`, which must match the name of the preceding `START STENCIL` directive. + +Together, the `START STENCIL` and `END STENCIL` directives result in the following generated code at the start and end of a stencil respectively. + +```fortran +#ifdef __DSL_VERIFY +!$ACC PARALLEL IF( i_am_accel_node .AND. acc_on ) DEFAULT(NONE) ASYNC(1) +vn_before(:, :, :) = vn(:, :, :) +!$ACC END PARALLEL +``` + +```fortran +call nvtxEndRange() +#endif +call wrap_run_mo_nh_diffusion_stencil_06( & +z_nabla2_e=z_nabla2_e(:, :, 1), & +area_edge=p_patch%edges%area_edge(:, 1), & +fac_bdydiff_v=fac_bdydiff_v, & +vn=p_nh_prog%vn(:, :, 1), & +vn_before=vn_before(:, :, 1), & +vn_abs_tol=1e-21_wp, & +vertical_lower=1, & +vertical_upper=nlev, & +horizontal_lower=i_startidx, & +horizontal_upper=i_endidx +) +``` + +Additionally, there are the following keyword arguments: + +- `noendif`: Takes a boolean string input and controls whether an `#endif` is generated or not. Defaults to false.

+ +- `noprofile`: Takes a boolean string input and controls whether an nvtx end profile directive is generated or not. Defaults to false.

+ +#### `!$DSL INSERT()` + +This directive allows the user to generate any text that is placed between the parentheses. This is useful for situations where custom code generation is necessary. + +#### `!$DSL START PROFILE()` + +This directive allows generating an nvtx start profile data statement, and takes the stencil `name` as an argument. + +#### `!$DSL END PROFILE()` + +This directive allows generating an nvtx end profile statement. + +#### `!$DSL ENDIF()` + +This directive generates an `#endif` statement. diff --git a/liskov/requirements-dev.txt b/liskov/requirements-dev.txt new file mode 100644 index 0000000000..2cb7bda5e8 --- /dev/null +++ b/liskov/requirements-dev.txt @@ -0,0 +1,3 @@ +-r ../base-requirements-dev.txt +-e ../common +-e . diff --git a/liskov/requirements.txt b/liskov/requirements.txt new file mode 100644 index 0000000000..c8420e5bf8 --- /dev/null +++ b/liskov/requirements.txt @@ -0,0 +1,3 @@ +-r ../base-requirements.txt +../common +. diff --git a/liskov/setup.cfg b/liskov/setup.cfg new file mode 100644 index 0000000000..a85b07a00f --- /dev/null +++ b/liskov/setup.cfg @@ -0,0 +1,55 @@ +# This file is mainly used to configure package creation with setuptools. +# Documentation: +# http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files +# +[metadata] +name = icon4py_liskov +description = ICON preprocessor to integrate Gt4Py code. 
+long_description = file: README.md +long_description_content_type = text/markdown +url = https://github.com/C2SM/icon4py +author = ETH Zurich +author_email = gridtools@cscs.ch +license = gpl3 +license_files = LICENSE +platforms = Linux, Mac +classifiers = + Development Status :: 3 - Alpha + Intended Audience :: Science/Research + License :: OSI Approved :: GNU General Public License v3 or later (GPLv3+) + Operating System :: POSIX + Programming Language :: Python + Programming Language :: Python :: 3 + Programming Language :: Python :: 3 :: Only + Programming Language :: Python :: 3.10 + Programming Language :: Python :: Implementation :: CPython + Topic :: Scientific/Engineering :: Atmospheric Science + Topic :: Scientific/Engineering :: Mathematics + Topic :: Scientific/Engineering :: Physics +project_urls = + Source Code = https://github.com/GridTools/gt4py + +[options] +packages = find_namespace: +install_requires = + icon4py-common + +python_requires = >=3.10 +package_dir = + = src +zip_safe = False + +[options.package_data] +# References: +# https://setuptools.pypa.io/en/latest/userguide/datafiles.html +# https://github.com/abravalheri/experiment-setuptools-package-data +* = *.md, *.rst, *.toml, *.txt, py.typed + +[options.packages.find] +where = src +exclude = + tests + +[options.entry_points] +console_scripts = + icon_liskov = icon4py.liskov.cli:main diff --git a/liskov/setup.py b/liskov/setup.py new file mode 100644 index 0000000000..9c9f7b81c8 --- /dev/null +++ b/liskov/setup.py @@ -0,0 +1,18 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. 
See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from setuptools import setup + + +if __name__ == "__main__": + setup() diff --git a/liskov/src/icon4py/liskov/__init__.py b/liskov/src/icon4py/liskov/__init__.py new file mode 100644 index 0000000000..15dfdb0098 --- /dev/null +++ b/liskov/src/icon4py/liskov/__init__.py @@ -0,0 +1,12 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/liskov/src/icon4py/liskov/cli.py b/liskov/src/icon4py/liskov/cli.py new file mode 100644 index 0000000000..a27798d706 --- /dev/null +++ b/liskov/src/icon4py/liskov/cli.py @@ -0,0 +1,64 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . 
+# +# SPDX-License-Identifier: GPL-3.0-or-later + +import pathlib + +import click + +from icon4py.liskov.logger import setup_logger +from icon4py.liskov.pipeline import ( + load_gt4py_stencils, + parse_fortran_file, + run_code_generation, +) + + +logger = setup_logger(__name__) + + +@click.command("icon_liskov") +@click.argument( + "input_filepath", + type=click.Path( + exists=True, dir_okay=False, resolve_path=True, path_type=pathlib.Path + ), +) +@click.argument( + "output_filepath", + type=click.Path(dir_okay=False, resolve_path=True, path_type=pathlib.Path), +) +@click.option( + "--profile", "-p", is_flag=True, help="Add nvtx profile statements to stencils." +) +def main( + input_filepath: pathlib.Path, output_filepath: pathlib.Path, profile: bool +) -> None: + """Command line interface for interacting with the ICON-Liskov DSL Preprocessor. + + Usage: + icon_liskov [--profile] + + Options: + -p --profile Add nvtx profile statements to stencils. + + Arguments: + input_filepath Path to the input file to process. + output_filepath Path to the output file to generate. + """ + parsed = parse_fortran_file(input_filepath) + parsed_checked = load_gt4py_stencils(parsed) + run_code_generation(parsed_checked, input_filepath, output_filepath, profile) + + +if __name__ == "__main__": + main() diff --git a/liskov/src/icon4py/liskov/codegen/__init__.py b/liskov/src/icon4py/liskov/codegen/__init__.py new file mode 100644 index 0000000000..8d0cdb4874 --- /dev/null +++ b/liskov/src/icon4py/liskov/codegen/__init__.py @@ -0,0 +1,15 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. 
See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +# required import for CheckForDirectiveClasses metaclass +import icon4py.liskov.parsing.types as ts # noqa: F401 diff --git a/liskov/src/icon4py/liskov/codegen/f90.py b/liskov/src/icon4py/liskov/codegen/f90.py new file mode 100644 index 0000000000..bdc62df51b --- /dev/null +++ b/liskov/src/icon4py/liskov/codegen/f90.py @@ -0,0 +1,404 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import re +from dataclasses import asdict +from typing import Optional, Sequence, Type + +import gt4py.eve as eve +from gt4py.eve.codegen import JinjaTemplate as as_jinja +from gt4py.eve.codegen import TemplatedGenerator + +from icon4py.bindings.utils import format_fortran_code +from icon4py.liskov.codegen.interface import ( + CodeGenInput, + DeclareData, + StartStencilData, +) + + +def enclose_in_parentheses(string: str) -> str: + return f"({string})" + + +def generate_fortran_code( + parent_node: Type[eve.Node], + code_generator: Type[TemplatedGenerator], + **kwargs: CodeGenInput | Sequence[CodeGenInput] | Optional[bool], +) -> str: + """ + Generate Fortran code for the given parent node and code generator. + + Args: + parent_node: A subclass of eve.Node that represents the parent node. + code_generator: A subclass of TemplatedGenerator that will be used + to generate the code. + **kwargs: Arguments to be passed to the parent node constructor. 
+ This can be a single CodeGenInput value, a sequence of CodeGenInput + values, or a boolean value, which is required by some parent nodes which + require a profile argument. + + Returns: + A string containing the formatted Fortran code. + """ + parent = parent_node(**kwargs) + source = code_generator.apply(parent) + formatted_source = format_fortran_code(source) + return formatted_source + + +class BoundsFields(eve.Node): + vlower: str + vupper: str + hlower: str + hupper: str + + +class Assign(eve.Node): + variable: str + association: str + + +class Field(Assign): + dims: Optional[int] + abs_tol: Optional[str] = None + rel_tol: Optional[str] = None + inp: bool + out: bool + + +class InputFields(eve.Node): + fields: list[Field] + + +class OutputFields(InputFields): + ... + + +class ToleranceFields(InputFields): + ... + + +def get_array_dims(association: str) -> str: + """ + Return the dimensions of an array in a string format. + + Args: + association: The string representation of the array. 
+ """ + indexes = re.findall("\\(([^)]+)", association) + if len(indexes) > 1: + idx = indexes[-1] + else: + idx = indexes[0] + + dims = list(idx) + + return "".join(list(dims)) + + +class EndStencilStatement(eve.Node): + stencil_data: StartStencilData + profile: bool + noendif: Optional[bool] + noprofile: Optional[bool] + + name: str = eve.datamodels.field(init=False) + input_fields: InputFields = eve.datamodels.field(init=False) + output_fields: OutputFields = eve.datamodels.field(init=False) + tolerance_fields: ToleranceFields = eve.datamodels.field(init=False) + bounds_fields: BoundsFields = eve.datamodels.field(init=False) + + def __post_init__(self) -> None: # type: ignore + all_fields = [Field(**asdict(f)) for f in self.stencil_data.fields] + self.bounds_fields = BoundsFields(**asdict(self.stencil_data.bounds)) + self.name = self.stencil_data.name + self.input_fields = InputFields(fields=[f for f in all_fields if f.inp]) + self.output_fields = OutputFields(fields=[f for f in all_fields if f.out]) + self.tolerance_fields = ToleranceFields( + fields=[f for f in all_fields if f.rel_tol or f.abs_tol] + ) + + +class EndStencilStatementGenerator(TemplatedGenerator): + EndStencilStatement = as_jinja( + """ + {%- if _this_node.profile %} + {% if _this_node.noprofile %}{% else %}call nvtxEndRange(){% endif %} + {%- endif %} + {% if _this_node.noendif %}{% else %}#endif{% endif %} + call wrap_run_{{ name }}( & + {{ input_fields }} + {{ output_fields }} + {{ tolerance_fields }} + {{ bounds_fields }} + """ + ) + + InputFields = as_jinja( + """ + {%- for field in _this_node.fields %} + {%- if field.out %} + + {%- else %} + {{ field.variable }}={{ field.association }},& + {%- endif -%} + {%- endfor %} + """ + ) + + OutputFields = as_jinja( + """ + {%- for field in _this_node.fields %} + {{ field.variable }}={{ field.association }},& + {{ field.variable }}_before={{ field.variable }}_before{{ field.rh_index }},& + {%- endfor %} + """ + ) + + def visit_OutputFields(self, 
out: OutputFields) -> OutputFields: # type: ignore + for f in out.fields: # type: ignore + idx = render_index(f.dims) + split_idx = idx.split(",") + + if len(split_idx) >= 3: + split_idx[-1] = "1" + + f.rh_index = enclose_in_parentheses(",".join(split_idx)) + return self.generic_visit(out) + + ToleranceFields = as_jinja( + """ + {%- if _this_node.fields|length < 1 -%} + + {%- else -%} + + {%- for f in _this_node.fields -%} + {% if f.rel_tol %} + {{ f.variable }}_rel_tol={{ f.rel_tol }}, & + {%- endif -%} + {% if f.abs_tol %} + {{ f.variable }}_abs_tol={{ f.abs_tol }}, & + {% endif %} + {%- endfor -%} + + {%- endif -%} + """ + ) + + BoundsFields = as_jinja( + """vertical_lower={{ vlower }}, & + vertical_upper={{ vupper }}, & + horizontal_lower={{ hlower }}, & + horizontal_upper={{ hupper }}) + """ + ) + + +class Declaration(Assign): + ... + + +class DeclareStatement(eve.Node): + declare_data: DeclareData + declarations: list[Declaration] = eve.datamodels.field(init=False) + + def __post_init__(self) -> None: # type: ignore + self.declarations = [ + Declaration(variable=k, association=v) + for k, v in self.declare_data.declarations.items() + ] + + +class DeclareStatementGenerator(TemplatedGenerator): + DeclareStatement = as_jinja( + """ + ! 
DSL INPUT / OUTPUT FIELDS + {%- for d in _this_node.declarations %} + {{ _this_node.declare_data.ident_type }}, DIMENSION({{ d.association }}) :: {{ d.variable }}_{{ _this_node.declare_data.suffix }} + {%- endfor %} + """ + ) + + +class CopyDeclaration(Declaration): + lh_index: str + rh_index: str + + +class StartStencilStatement(eve.Node): + stencil_data: StartStencilData + profile: bool + copy_declarations: list[CopyDeclaration] = eve.datamodels.field(init=False) + + def __post_init__(self) -> None: # type: ignore + all_fields = [Field(**asdict(f)) for f in self.stencil_data.fields] + self.copy_declarations = [ + self.make_copy_declaration(f) for f in all_fields if f.out + ] + self.acc_present = "PRESENT" if self.stencil_data.acc_present else "NONE" + + @staticmethod + def make_copy_declaration(f: Field) -> CopyDeclaration: + if f.dims is None: + raise Exception(f"{f.variable} not declared!") + + lh_idx = render_index(f.dims) + + # get length of association index + association_dims = get_array_dims(f.association).split(",") + n_association_dims = len(association_dims) + + offset = len(",".join(association_dims)) + 2 + truncated_association = f.association[:-offset] + + if n_association_dims > f.dims: + rh_idx = f"{lh_idx},{association_dims[-1]}" + else: + rh_idx = f"{lh_idx}" + + lh_idx = enclose_in_parentheses(lh_idx) + rh_idx = enclose_in_parentheses(rh_idx) + + return CopyDeclaration( + variable=f.variable, + association=truncated_association, + lh_index=lh_idx, + rh_index=rh_idx, + ) + + +def render_index(n: int) -> str: + """ + Render a string of comma-separated colon characters, used to define the shape of an array in Fortran. + + Args: + n (int): The number of colons to include in the returned string. + + Returns: + str: A comma-separated string of n colons. 
+ + Example: + >>> render_index(3) + ':,:,:' + """ + return ",".join([":" for _ in range(n)]) + + +class StartStencilStatementGenerator(TemplatedGenerator): + StartStencilStatement = as_jinja( + """ + #ifdef __DSL_VERIFY + {% if _this_node.stencil_data.copies -%} + !$ACC PARALLEL IF( i_am_accel_node ) DEFAULT({{ _this_node.acc_present }}) ASYNC(1) + {%- for d in _this_node.copy_declarations %} + {{ d.variable }}_before{{ d.lh_index }} = {{ d.association }}{{ d.rh_index }} + {%- endfor %} + !$ACC END PARALLEL + {%- endif -%} + + {%- if _this_node.profile %} + call nvtxStartRange("{{ _this_node.stencil_data.name }}") + {%- endif %} + """ + ) + + +class ImportsStatement(eve.Node): + stencils: list[StartStencilData] + stencil_names: list[str] = eve.datamodels.field(init=False) + + def __post_init__(self) -> None: # type: ignore + self.stencil_names = sorted(set([stencil.name for stencil in self.stencils])) + + +class ImportsStatementGenerator(TemplatedGenerator): + ImportsStatement = as_jinja( + """ {% for name in stencil_names %}USE {{ name }}, ONLY: wrap_run_{{ name }}\n{% endfor %}""" + ) + + +class StartCreateStatement(eve.Node): + stencils: list[StartStencilData] + out_field_names: list[str] = eve.datamodels.field(init=False) + + def __post_init__(self) -> None: # type: ignore + self.out_field_names = sorted( + set( + [ + field.variable + for stencil in self.stencils + for field in stencil.fields + if field.out + ] + ) + ) + + +class StartCreateStatementGenerator(TemplatedGenerator): + StartCreateStatement = as_jinja( + """ + #ifdef __DSL_VERIFY + dsl_verify = .TRUE. + #else + dsl_verify = .FALSE. + #endif + + !$ACC DATA CREATE( & + {%- for name in out_field_names %} + !$ACC {{ name }}_before {%- if not loop.last -%}, & {% else %} & {%- endif -%} + {%- endfor %} + !$ACC ), & + !$ACC IF ( i_am_accel_node .AND. dsl_verify) + """ + ) + + +class EndCreateStatement(eve.Node): + ... 
+ + +class EndCreateStatementGenerator(TemplatedGenerator): + EndCreateStatement = as_jinja("!$ACC END DATA") + + +class EndIfStatement(eve.Node): + ... + + +class EndIfStatementGenerator(TemplatedGenerator): + EndIfStatement = as_jinja("#endif") + + +class StartProfileStatement(eve.Node): + name: str + + +class StartProfileStatementGenerator(TemplatedGenerator): + StartProfileStatement = as_jinja('call nvtxStartRange("{{ _this_node.name }}")') + + +class EndProfileStatement(eve.Node): + ... + + +class EndProfileStatementGenerator(TemplatedGenerator): + EndProfileStatement = as_jinja("call nvtxEndRange()") + + +class InsertStatement(eve.Node): + content: str + + +class InsertStatementGenerator(TemplatedGenerator): + InsertStatement = as_jinja("{{ _this_node.content }}") diff --git a/liskov/src/icon4py/liskov/codegen/generate.py b/liskov/src/icon4py/liskov/codegen/generate.py new file mode 100644 index 0000000000..6cbb057ccd --- /dev/null +++ b/liskov/src/icon4py/liskov/codegen/generate.py @@ -0,0 +1,264 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . 
+# +# SPDX-License-Identifier: GPL-3.0-or-later + +from dataclasses import dataclass +from typing import Optional, Sequence, Type + +import gt4py.eve as eve +from gt4py.eve.codegen import TemplatedGenerator +from typing_extensions import Any + +from icon4py.liskov.codegen.f90 import ( + DeclareStatement, + DeclareStatementGenerator, + EndCreateStatement, + EndCreateStatementGenerator, + EndIfStatement, + EndIfStatementGenerator, + EndProfileStatement, + EndProfileStatementGenerator, + EndStencilStatement, + EndStencilStatementGenerator, + ImportsStatement, + ImportsStatementGenerator, + InsertStatement, + InsertStatementGenerator, + StartCreateStatement, + StartCreateStatementGenerator, + StartProfileStatement, + StartProfileStatementGenerator, + StartStencilStatement, + StartStencilStatementGenerator, + generate_fortran_code, +) +from icon4py.liskov.codegen.interface import ( + CodeGenInput, + DeserialisedDirectives, + StartStencilData, + UnusedDirective, +) +from icon4py.liskov.common import Step +from icon4py.liskov.logger import setup_logger + + +logger = setup_logger(__name__) + + +@dataclass +class GeneratedCode: + """A class for storing generated f90 code and its line number information.""" + + source: str + startln: int + endln: int + + +class IntegrationGenerator(Step): + def __init__(self, directives: DeserialisedDirectives, profile: bool): + self.profile = profile + self.directives = directives + self.generated: list[GeneratedCode] = [] + + def __call__(self, data: Any = None) -> list[GeneratedCode]: + """Generate all f90 code for integration. + + Args: + profile: A boolean indicating whether to include profiling calls in the generated code. 
+ """ + self._generate_create() + self._generate_imports() + self._generate_declare() + self._generate_start_stencil() + self._generate_end_stencil() + self._generate_endif() + self._generate_profile() + self._generate_insert() + return self.generated + + def _generate( + self, + parent_node: Type[eve.Node], + code_generator: Type[TemplatedGenerator], + startln: int, + endln: int, + **kwargs: CodeGenInput | Sequence[CodeGenInput] | Optional[bool] | Any, + ) -> None: + """Add a GeneratedCode object to the `generated` attribute with the given source code and line number information. + + Args: + parent_node: The parent node of the code to be generated. + code_generator: The code generator to use for generating the code. + startln: The start line number of the generated code. + endln: The end line number of the generated code. + **kwargs: Additional keyword arguments to be passed to the code generator. + """ + source = generate_fortran_code(parent_node, code_generator, **kwargs) + code = GeneratedCode(source=source, startln=startln, endln=endln) + self.generated.append(code) + + def _generate_declare(self) -> None: + """Generate f90 code for declaration statements.""" + for i, declare in enumerate(self.directives.Declare): + logger.info("Generating DECLARE statement.") + self._generate( + DeclareStatement, + DeclareStatementGenerator, + self.directives.Declare[i].startln, + self.directives.Declare[i].endln, + declare_data=declare, + ) + + def _generate_start_stencil(self) -> None: + """Generate f90 integration code surrounding a stencil. + + Args: + profile: A boolean indicating whether to include profiling calls in the generated code. 
+ """ + i = 0 + + while i < len(self.directives.StartStencil): + stencil = self.directives.StartStencil[i] + logger.info(f"Generating START statement for {stencil.name}") + + try: + next_stencil = self.directives.StartStencil[i + 1] + except IndexError: + pass + + if stencil.mergecopy and next_stencil.mergecopy: + stencil = StartStencilData( + startln=stencil.startln, + endln=next_stencil.endln, + name=stencil.name + "_" + next_stencil.name, + fields=stencil.fields + next_stencil.fields, + bounds=stencil.bounds, + acc_present=stencil.acc_present, + mergecopy=stencil.mergecopy, + copies=stencil.copies, + ) + i += 2 + + self._generate( + StartStencilStatement, + StartStencilStatementGenerator, + stencil.startln, + next_stencil.endln, + stencil_data=stencil, + profile=self.profile, + ) + else: + self._generate( + StartStencilStatement, + StartStencilStatementGenerator, + self.directives.StartStencil[i].startln, + self.directives.StartStencil[i].endln, + stencil_data=stencil, + profile=self.profile, + ) + i += 1 + + def _generate_end_stencil(self) -> None: + """Generate f90 integration code surrounding a stencil. + + Args: + profile: A boolean indicating whether to include profiling calls in the generated code. 
+ """ + for i, stencil in enumerate(self.directives.StartStencil): + logger.info(f"Generating END statement for {stencil.name}") + self._generate( + EndStencilStatement, + EndStencilStatementGenerator, + self.directives.EndStencil[i].startln, + self.directives.EndStencil[i].endln, + stencil_data=stencil, + profile=self.profile, + noendif=self.directives.EndStencil[i].noendif, + noprofile=self.directives.EndStencil[i].noprofile, + ) + + def _generate_imports(self) -> None: + """Generate f90 code for import statements.""" + logger.info("Generating IMPORT statement.") + self._generate( + ImportsStatement, + ImportsStatementGenerator, + self.directives.Imports.startln, + self.directives.Imports.endln, + stencils=self.directives.StartStencil, + ) + + def _generate_create(self) -> None: + """Generate f90 code for OpenACC DATA CREATE statements.""" + logger.info("Generating DATA CREATE statement.") + self._generate( + StartCreateStatement, + StartCreateStatementGenerator, + self.directives.StartCreate.startln, + self.directives.StartCreate.endln, + stencils=self.directives.StartStencil, + ) + + self._generate( + EndCreateStatement, + EndCreateStatementGenerator, + self.directives.EndCreate.startln, + self.directives.EndCreate.endln, + ) + + def _generate_endif(self) -> None: + """Generate f90 code for endif statements.""" + if self.directives.EndIf != UnusedDirective: + for endif in self.directives.EndIf: # type: ignore + logger.info("Generating ENDIF statement.") + self._generate( + EndIfStatement, + EndIfStatementGenerator, + endif.startln, + endif.endln, + ) + + def _generate_profile(self) -> None: + """Generate additional nvtx profiling statements.""" + if self.directives.StartProfile != UnusedDirective: + for start in self.directives.StartProfile: # type: ignore + logger.info("Generating nvtx start statement.") + self._generate( + StartProfileStatement, + StartProfileStatementGenerator, + start.startln, + start.endln, + name=start.name, + ) + + if 
self.directives.EndProfile != UnusedDirective: + for end in self.directives.EndProfile: # type: ignore + logger.info("Generating nvtx end statement.") + self._generate( + EndProfileStatement, + EndProfileStatementGenerator, + end.startln, + end.endln, + ) + + def _generate_insert(self) -> None: + """Generate free form statement from insert directive.""" + if self.directives.Insert != UnusedDirective: + for insert in self.directives.Insert: # type: ignore + logger.info("Generating free form statement.") + self._generate( + InsertStatement, + InsertStatementGenerator, + insert.startln, + insert.endln, + content=insert.content, + ) diff --git a/liskov/src/icon4py/liskov/codegen/interface.py b/liskov/src/icon4py/liskov/codegen/interface.py new file mode 100644 index 0000000000..f89fccd9e7 --- /dev/null +++ b/liskov/src/icon4py/liskov/codegen/interface.py @@ -0,0 +1,117 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later +import dataclasses +from dataclasses import dataclass +from typing import Optional, Sequence + + +class UnusedDirective: + ... 
+ + +@dataclass +class CodeGenInput: + startln: int + endln: int + + +@dataclass +class BoundsData: + hlower: str + hupper: str + vlower: str + vupper: str + + +@dataclass +class FieldAssociationData: + variable: str + association: str + dims: Optional[int] + abs_tol: Optional[str] = dataclasses.field(kw_only=True, default=None) + rel_tol: Optional[str] = dataclasses.field(kw_only=True, default=None) + inp: Optional[bool] = dataclasses.field(kw_only=False, default=None) + out: Optional[bool] = dataclasses.field(kw_only=False, default=None) + + +@dataclass +class DeclareData(CodeGenInput): + declarations: dict[str, str] + ident_type: str + suffix: str + + +@dataclass +class ImportsData(CodeGenInput): + ... + + +@dataclass +class StartCreateData(CodeGenInput): + ... + + +@dataclass +class EndCreateData(CodeGenInput): + ... + + +@dataclass +class EndIfData(CodeGenInput): + ... + + +@dataclass +class StartProfileData(CodeGenInput): + name: str + + +@dataclass +class EndProfileData(CodeGenInput): + ... 
+ + +@dataclass +class StartStencilData(CodeGenInput): + name: str + fields: list[FieldAssociationData] + bounds: BoundsData + acc_present: Optional[bool] + mergecopy: Optional[bool] + copies: Optional[bool] + + +@dataclass +class EndStencilData(CodeGenInput): + name: str + noendif: Optional[bool] + noprofile: Optional[bool] + + +@dataclass +class InsertData(CodeGenInput): + content: str + + +@dataclass +class DeserialisedDirectives: + StartStencil: Sequence[StartStencilData] + EndStencil: Sequence[EndStencilData] + Declare: Sequence[DeclareData] + Imports: ImportsData + StartCreate: StartCreateData + EndCreate: EndCreateData + EndIf: Sequence[EndIfData] | UnusedDirective + StartProfile: Sequence[StartProfileData] | UnusedDirective + EndProfile: Sequence[EndProfileData] | UnusedDirective + Insert: Sequence[InsertData] | UnusedDirective diff --git a/liskov/src/icon4py/liskov/codegen/write.py b/liskov/src/icon4py/liskov/codegen/write.py new file mode 100644 index 0000000000..6ce3567510 --- /dev/null +++ b/liskov/src/icon4py/liskov/codegen/write.py @@ -0,0 +1,113 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . 
+# +# SPDX-License-Identifier: GPL-3.0-or-later +from pathlib import Path +from typing import List + +from icon4py.liskov.codegen.generate import GeneratedCode +from icon4py.liskov.common import Step +from icon4py.liskov.logger import setup_logger +from icon4py.liskov.parsing.types import DIRECTIVE_IDENT + + +logger = setup_logger(__name__) + + +class IntegrationWriter(Step): + def __init__(self, input_filepath: Path, output_filepath: Path) -> None: + """Initialize an IntegrationWriter instance with a list of generated code. + + Args: + input_filepath: Path to file containing directives. + output_filepath: Path to file to write generated code. + """ + self.input_filepath = input_filepath + self.output_filepath = output_filepath + + def __call__(self, generated: List[GeneratedCode]) -> None: + """Write a file containing generated code, with the DSL directives removed in the same directory as filepath using a new suffix. + + Args: + generated: A list of GeneratedCode instances representing the generated code that will be written to a file. + """ + current_file = self._read_file() + with_generated_code = self._insert_generated_code(current_file, generated) + without_directives = self._remove_directives(with_generated_code) + self._write_file(without_directives) + + def _read_file(self) -> List[str]: + """Read the lines of a file into a list. + + Returns: + A list of strings representing the lines of the file. + """ + with self.input_filepath.open("r") as f: + lines = f.readlines() + return lines + + @staticmethod + def _insert_generated_code( + current_file: List[str], generated_code: List[GeneratedCode] + ) -> List[str]: + """Insert generated code into the current file at the specified line numbers. + + The generated code is sorted in ascending order of the start line number to ensure that + it is inserted into the current file in the correct order. 
The `cur_line_num` variable is + used to keep track of the current line number in the current file, and is updated after + each generated code block is inserted to account for any additional lines that have been + added to the file. + + Args: + current_file: A list of strings representing the lines of the current file. + generated_code: A list of GeneratedCode instances representing the generated code to be inserted into the current file. + + Returns: + A list of strings representing the current file with the generated code inserted at the + specified line numbers. + """ + generated_code.sort(key=lambda gen: gen.startln) + cur_line_num = 0 + + for gen in generated_code: + gen.startln += cur_line_num + + to_insert = gen.source.split("\n") + + to_insert = [f"{s}\n" for s in to_insert] + + current_file[gen.startln : gen.startln] = to_insert + + cur_line_num += len(to_insert) + return current_file + + def _write_file(self, generated_code: List[str]) -> None: + """Write generated code to a file. + + Args: + generated_code: A list of strings representing the generated code to be written to the file. + """ + code = "".join(generated_code) + with self.output_filepath.open("w") as f: + f.write(code) + logger.info(f"Wrote new file to {self.output_filepath}") + + @staticmethod + def _remove_directives(current_file: List[str]) -> List[str]: + """Remove the directives from the current file. + + Args: + current_file: A list of strings representing the lines of the current file. + + Returns: + A list of strings representing the current file with the directives removed. + """ + return [ln for ln in current_file if DIRECTIVE_IDENT not in ln] diff --git a/liskov/src/icon4py/liskov/common.py b/liskov/src/icon4py/liskov/common.py new file mode 100644 index 0000000000..6fb9bef3bc --- /dev/null +++ b/liskov/src/icon4py/liskov/common.py @@ -0,0 +1,72 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. 
+# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from abc import ABC, abstractmethod +from functools import wraps +from typing import Any, Callable, Sequence + + +class Step(ABC): + """Abstract base class for pipeline steps. + + Defines the interface for pipeline steps to implement. + """ + + @abstractmethod + def __call__(self, data: Any) -> Any: + """Abstract method to be implemented by concrete steps. + + Args: + data (Any): Data to be processed. + + Returns: + Any: Processed data to be passed to the next step. + + """ + pass + + +class LinearPipelineComposer: + """Creates a linear pipeline for executing a sequence of functions (steps). + + Args: + steps (List): List of functions to be executed in the pipeline. + """ + + def __init__(self, steps: Sequence[Step]) -> None: + self.steps = steps + + def execute(self, data: Any = None) -> Any: + """Execute all pipeline steps.""" + for step in self.steps: + data = step(data) + return data + + +def linear_pipeline(func: Callable) -> Callable: + """Apply a linear pipeline to a function using a decorator. + + Args: + func: The function to be decorated. + + Returns: + The decorated function. 
+ """ + + @wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + steps = func(*args, **kwargs) + composer = LinearPipelineComposer(steps) + return composer.execute() + + return wrapper diff --git a/liskov/src/icon4py/liskov/external/__init__.py b/liskov/src/icon4py/liskov/external/__init__.py new file mode 100644 index 0000000000..15dfdb0098 --- /dev/null +++ b/liskov/src/icon4py/liskov/external/__init__.py @@ -0,0 +1,12 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/liskov/src/icon4py/liskov/external/gt4py.py b/liskov/src/icon4py/liskov/external/gt4py.py new file mode 100644 index 0000000000..618c18115a --- /dev/null +++ b/liskov/src/icon4py/liskov/external/gt4py.py @@ -0,0 +1,81 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . 
+# +# SPDX-License-Identifier: GPL-3.0-or-later + +import importlib +from inspect import getmembers + +from gt4py.next.ffront.decorator import Program +from typing_extensions import Any + +from icon4py.liskov.codegen.interface import DeserialisedDirectives +from icon4py.liskov.common import Step +from icon4py.liskov.logger import setup_logger +from icon4py.liskov.parsing.exceptions import ( + IncompatibleFieldError, + UnknownStencilError, +) +from icon4py.pyutils.metadata import get_stencil_info + + +logger = setup_logger(__name__) + + +class UpdateFieldsWithGt4PyStencils(Step): + _STENCIL_PACKAGES = ["atm_dyn_iconam", "advection"] + + def __init__(self, parsed: DeserialisedDirectives): + self.parsed = parsed + + def __call__(self, data: Any = None) -> DeserialisedDirectives: + logger.info("Updating parsed fields with data from icon4py stencils...") + + for s in self.parsed.StartStencil: + gt4py_program = self._collect_icon4py_stencil(s.name) + gt4py_stencil_info = get_stencil_info(gt4py_program) + gt4py_fields = gt4py_stencil_info.fields + for f in s.fields: + try: + field_info = gt4py_fields[f.variable] + except KeyError: + raise IncompatibleFieldError( + f"Used field variable name that is incompatible with the expected field names defined in {s.name} in icon4py." + ) + f.out = field_info.out + f.inp = field_info.inp + return self.parsed + + def _collect_icon4py_stencil(self, stencil_name: str) -> Program: + """Collect and return the ICON4PY stencil program with the given name.""" + err_counter = 0 + for pkg in self._STENCIL_PACKAGES: + + try: + module_name = f"icon4py.{pkg}.{stencil_name}" + module = importlib.import_module(module_name) + except ModuleNotFoundError: + err_counter += 1 + + if err_counter == len(self._STENCIL_PACKAGES): + raise UnknownStencilError( + f"Did not find module: {stencil_name} in icon4py." 
+ ) + + module_members = getmembers(module) + found_stencil = [elt for elt in module_members if elt[0] == stencil_name] + + if len(found_stencil) == 0: + raise UnknownStencilError( + f"Did not find module member: {stencil_name} in module: {module.__name__} in icon4py." + ) + + return found_stencil[0][1] diff --git a/liskov/src/icon4py/liskov/logger.py b/liskov/src/icon4py/liskov/logger.py new file mode 100644 index 0000000000..636a582d49 --- /dev/null +++ b/liskov/src/icon4py/liskov/logger.py @@ -0,0 +1,27 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import logging + + +def setup_logger(name: str, log_level: int = logging.INFO) -> logging.Logger: + """Set up a logger with a given name and log level.""" + logger = logging.getLogger(name) + logger.setLevel(log_level) + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(formatter) + logger.addHandler(stream_handler) + return logger diff --git a/liskov/src/icon4py/liskov/parsing/__init__.py b/liskov/src/icon4py/liskov/parsing/__init__.py new file mode 100644 index 0000000000..15dfdb0098 --- /dev/null +++ b/liskov/src/icon4py/liskov/parsing/__init__.py @@ -0,0 +1,12 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. 
+# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/liskov/src/icon4py/liskov/parsing/deserialise.py b/liskov/src/icon4py/liskov/parsing/deserialise.py new file mode 100644 index 0000000000..5b8850b71b --- /dev/null +++ b/liskov/src/icon4py/liskov/parsing/deserialise.py @@ -0,0 +1,423 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . 
+# +# SPDX-License-Identifier: GPL-3.0-or-later + +from dataclasses import dataclass +from typing import Any, Callable, Optional, Protocol, Type + +import icon4py.liskov.parsing.types as ts +from icon4py.liskov.codegen.interface import ( + BoundsData, + CodeGenInput, + DeclareData, + DeserialisedDirectives, + EndCreateData, + EndIfData, + EndProfileData, + EndStencilData, + FieldAssociationData, + ImportsData, + InsertData, + StartCreateData, + StartProfileData, + StartStencilData, + UnusedDirective, +) +from icon4py.liskov.common import Step +from icon4py.liskov.logger import setup_logger +from icon4py.liskov.parsing.exceptions import ( + DirectiveSyntaxError, + MissingBoundsError, + MissingDirectiveArgumentError, +) +from icon4py.liskov.parsing.utils import ( + extract_directive, + flatten_list_of_dicts, + string_to_bool, +) + + +TOLERANCE_ARGS = ["abs_tol", "rel_tol"] +DEFAULT_DECLARE_IDENT_TYPE = "REAL(wp)" +DEFAULT_DECLARE_SUFFIX = "before" + +logger = setup_logger(__name__) + + +def _extract_stencil_name(named_args: dict, directive: ts.ParsedDirective) -> str: + """Extract stencil name from directive arguments.""" + try: + stencil_name = named_args["name"] + except KeyError as e: + raise MissingDirectiveArgumentError( + f"Missing argument {e} in {directive.type_name} directive on line {directive.startln}." + ) + return stencil_name + + +def _extract_boolean_kwarg( + directive: ts.ParsedDirective, args: dict, arg_name: str +) -> Optional[bool]: + """Extract a boolean kwarg from the parsed dictionary. Kwargs are false by default.""" + if a := args.get(arg_name): + try: + return string_to_bool(a) + except Exception: + raise DirectiveSyntaxError( + f"Expected boolean string as value to keyword argument {arg_name} on line {directive.startln}. Got {a}" + ) + return False + + +class DirectiveInputFactory(Protocol): + def __call__( + self, parsed: ts.ParsedDict + ) -> list[CodeGenInput] | CodeGenInput | Type[UnusedDirective]: + ... 
+ + +@dataclass +class DataFactoryBase: + directive_cls: Type[ts.ParsedDirective] + dtype: Type[CodeGenInput] + + +@dataclass +class OptionalMultiUseDataFactory(DataFactoryBase): + def __call__( + self, parsed: ts.ParsedDict, **kwargs: Any + ) -> Type[UnusedDirective] | list[CodeGenInput]: + extracted = extract_directive(parsed["directives"], self.directive_cls) + if len(extracted) < 1: + return UnusedDirective + else: + deserialised = [] + for directive in extracted: + deserialised.append( + self.dtype( + startln=directive.startln, endln=directive.endln, **kwargs + ) + ) + return deserialised + + +@dataclass +class RequiredSingleUseDataFactory(DataFactoryBase): + def __call__(self, parsed: ts.ParsedDict) -> CodeGenInput: + extracted = extract_directive(parsed["directives"], self.directive_cls)[0] + return self.dtype(startln=extracted.startln, endln=extracted.endln) + + +@dataclass +class StartCreateDataFactory(RequiredSingleUseDataFactory): + directive_cls: Type[ts.ParsedDirective] = ts.StartCreate + dtype: Type[StartCreateData] = StartCreateData + + +@dataclass +class EndCreateDataFactory(RequiredSingleUseDataFactory): + directive_cls: Type[ts.ParsedDirective] = ts.EndCreate + dtype: Type[EndCreateData] = EndCreateData + + +@dataclass +class ImportsDataFactory(RequiredSingleUseDataFactory): + directive_cls: Type[ts.ParsedDirective] = ts.Imports + dtype: Type[ImportsData] = ImportsData + + +@dataclass +class EndIfDataFactory(OptionalMultiUseDataFactory): + directive_cls: Type[ts.ParsedDirective] = ts.EndIf + dtype: Type[EndIfData] = EndIfData + + +@dataclass +class EndProfileDataFactory(OptionalMultiUseDataFactory): + directive_cls: Type[ts.ParsedDirective] = ts.EndProfile + dtype: Type[EndProfileData] = EndProfileData + + +def pop_item_from_dict(dictionary: dict, key: str, default_value: str) -> str: + return dictionary.pop(key, default_value) + + +@dataclass +class DeclareDataFactory(DataFactoryBase): + directive_cls: Type[ts.ParsedDirective] = ts.Declare + 
dtype: Type[DeclareData] = DeclareData + + @staticmethod + def get_field_dimensions(declarations: dict) -> dict[str, int]: + return {k: len(v.split(",")) for k, v in declarations.items()} + + def __call__(self, parsed: ts.ParsedDict) -> list[DeclareData]: + deserialised = [] + extracted = extract_directive(parsed["directives"], self.directive_cls) + for i, directive in enumerate(extracted): + named_args = parsed["content"]["Declare"][i] + ident_type = pop_item_from_dict( + named_args, "type", DEFAULT_DECLARE_IDENT_TYPE + ) + suffix = pop_item_from_dict(named_args, "suffix", DEFAULT_DECLARE_SUFFIX) + deserialised.append( + self.dtype( + startln=directive.startln, + endln=directive.endln, + declarations=named_args, + ident_type=ident_type, + suffix=suffix, + ) + ) + return deserialised + + +@dataclass +class StartProfileDataFactory(DataFactoryBase): + directive_cls: Type[ts.ParsedDirective] = ts.StartProfile + dtype: Type[StartProfileData] = StartProfileData + + def __call__(self, parsed: ts.ParsedDict) -> list[StartProfileData]: + deserialised = [] + extracted = extract_directive(parsed["directives"], self.directive_cls) + for i, directive in enumerate(extracted): + named_args = parsed["content"]["StartProfile"][i] + stencil_name = _extract_stencil_name(named_args, directive) + deserialised.append( + self.dtype( + name=stencil_name, + startln=directive.startln, + endln=directive.endln, + ) + ) + return deserialised + + +@dataclass +class EndStencilDataFactory(DataFactoryBase): + directive_cls: Type[ts.ParsedDirective] = ts.EndStencil + dtype: Type[EndStencilData] = EndStencilData + + def __call__(self, parsed: ts.ParsedDict) -> list[EndStencilData]: + deserialised = [] + extracted = extract_directive(parsed["directives"], self.directive_cls) + for i, directive in enumerate(extracted): + named_args = parsed["content"]["EndStencil"][i] + stencil_name = _extract_stencil_name(named_args, directive) + noendif = _extract_boolean_kwarg(directive, named_args, "noendif") + 
noprofile = _extract_boolean_kwarg(directive, named_args, "noprofile") + deserialised.append( + self.dtype( + name=stencil_name, + startln=directive.startln, + endln=directive.endln, + noendif=noendif, + noprofile=noprofile, + ) + ) + return deserialised + + +@dataclass +class StartStencilDataFactory(DataFactoryBase): + directive_cls: Type[ts.ParsedDirective] = ts.StartStencil + dtype: Type[StartStencilData] = StartStencilData + + def __call__(self, parsed: ts.ParsedDict) -> list[StartStencilData]: + """Create and return a list of StartStencilData objects from the parsed directives. + + Args: + parsed (ParsedDict): Dictionary of parsed directives and their associated content. + + Returns: + List[StartStencilData]: List of StartStencilData objects created from the parsed directives. + """ + deserialised = [] + field_dimensions = flatten_list_of_dicts( + [ + DeclareDataFactory.get_field_dimensions(dim) + for dim in parsed["content"]["Declare"] + ] + ) + directives = extract_directive(parsed["directives"], self.directive_cls) + for i, directive in enumerate(directives): + named_args = parsed["content"]["StartStencil"][i] + acc_present = string_to_bool( + pop_item_from_dict(named_args, "accpresent", "false") + ) + mergecopy = string_to_bool( + pop_item_from_dict(named_args, "mergecopy", "false") + ) + copies = string_to_bool(pop_item_from_dict(named_args, "copies", "true")) + stencil_name = _extract_stencil_name(named_args, directive) + bounds = self._make_bounds(named_args) + fields = self._make_fields(named_args, field_dimensions) + fields_w_tolerance = self._update_tolerances(named_args, fields) + + deserialised.append( + self.dtype( + name=stencil_name, + fields=fields_w_tolerance, + bounds=bounds, + startln=directive.startln, + endln=directive.endln, + acc_present=acc_present, + mergecopy=mergecopy, + copies=copies, + ) + ) + return deserialised + + @staticmethod + def _make_bounds(named_args: dict) -> BoundsData: + """Extract stencil bounds from directive 
arguments.""" + try: + bounds = BoundsData( + hlower=named_args["horizontal_lower"], + hupper=named_args["horizontal_upper"], + vlower=named_args["vertical_lower"], + vupper=named_args["vertical_upper"], + ) + except Exception: + raise MissingBoundsError( + f"Missing or invalid bounds provided in stencil: {named_args['name']}" + ) + return bounds + + def _make_fields( + self, named_args: dict[str, str], dimensions: dict + ) -> list[FieldAssociationData]: + """Extract all fields from directive arguments and create corresponding field association data.""" + field_args = self._create_field_args(named_args) + fields = self._make_field_associations(field_args, dimensions) + return fields + + @staticmethod + def _create_field_args(named_args: dict[str, str]) -> dict[str, str]: + """Create a dictionary of field names and their associations from named_args. + + Raises: + MissingDirectiveArgumentError: If a required argument is missing in the named_args. + """ + field_args = named_args.copy() + required_args = ( + "name", + "horizontal_lower", + "horizontal_upper", + "vertical_lower", + "vertical_upper", + ) + + for arg in required_args: + if arg not in field_args: + raise MissingDirectiveArgumentError( + f"Missing required argument '{arg}' in a StartStencil directive." 
+ ) + else: + field_args.pop(arg) + + return field_args + + @staticmethod + def _make_field_associations( + field_args: dict[str, str], dimensions: dict + ) -> list[FieldAssociationData]: + """Create a list of FieldAssociation objects.""" + fields = [] + for field_name, association in field_args.items(): + + # skipped as handled by _update_field_tolerances + if any([field_name.endswith(tol) for tol in TOLERANCE_ARGS]): + continue + + field_association_data = FieldAssociationData( + variable=field_name, + association=association, + dims=dimensions.get(field_name), + ) + fields.append(field_association_data) + return fields + + @staticmethod + def _update_tolerances( + named_args: dict, fields: list[FieldAssociationData] + ) -> list[FieldAssociationData]: + """Set relative and absolute tolerance for a given field if set in the directives.""" + for field_name, association in named_args.items(): + for tol in TOLERANCE_ARGS: + + _tol = f"_{tol}" + + if field_name.endswith(_tol): + name = field_name.replace(_tol, "") + + for f in fields: + if f.variable == name: + setattr(f, tol, association) + return fields + + +@dataclass +class InsertDataFactory(DataFactoryBase): + directive_cls: Type[ts.ParsedDirective] = ts.Insert + dtype: Type[InsertData] = InsertData + + def __call__(self, parsed: ts.ParsedDict) -> list[InsertData]: + deserialised = [] + extracted = extract_directive(parsed["directives"], self.directive_cls) + for i, directive in enumerate(extracted): + content = parsed["content"]["Insert"][i] + deserialised.append( + self.dtype( + startln=directive.startln, endln=directive.endln, content=content # type: ignore + ) + ) + return deserialised + + +class DirectiveDeserialiser(Step): + _FACTORIES: dict[str, Callable] = { + "StartCreate": StartCreateDataFactory(), + "EndCreate": EndCreateDataFactory(), + "Imports": ImportsDataFactory(), + "Declare": DeclareDataFactory(), + "StartStencil": StartStencilDataFactory(), + "EndStencil": EndStencilDataFactory(), + "EndIf": 
EndIfDataFactory(), + "StartProfile": StartProfileDataFactory(), + "EndProfile": EndProfileDataFactory(), + "Insert": InsertDataFactory(), + } + + def __call__(self, directives: ts.ParsedDict) -> DeserialisedDirectives: + """Deserialise the provided parsed directives to a DeserialisedDirectives object. + + Args: + directives: The parsed directives to deserialise. + + Returns: + A DeserialisedDirectives object containing the deserialised directives. + + Note: + The method uses the `_FACTORIES` class attribute to create the appropriate + factory object for each directive type, and uses these objects to deserialise + the parsed directives. The DeserialisedDirectives class is a dataclass + containing the deserialised versions of the different directives. + """ + logger.info("Deserialising directives ...") + deserialised = dict() + + for key, func in self._FACTORIES.items(): + ser = func(directives) + deserialised[key] = ser + + return DeserialisedDirectives(**deserialised) diff --git a/liskov/src/icon4py/liskov/parsing/exceptions.py b/liskov/src/icon4py/liskov/parsing/exceptions.py new file mode 100644 index 0000000000..89b7b8b6b0 --- /dev/null +++ b/liskov/src/icon4py/liskov/parsing/exceptions.py @@ -0,0 +1,48 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . 
class UnsupportedDirectiveError(Exception):
    """Raised when a scanned directive matches none of the supported directive patterns."""


class DirectiveSyntaxError(Exception):
    """Raised when a directive is syntactically malformed."""


class RepeatedDirectiveError(Exception):
    """Raised when a directive that must be unique is used more than once."""


class RequiredDirectivesError(Exception):
    """Raised when a required directive is missing from the source."""


class UnbalancedStencilDirectiveError(Exception):
    """Raised when START STENCIL and END STENCIL directives do not pair up."""


class MissingBoundsError(Exception):
    """Signal missing bounds; raised by code outside the parsing step."""


class MissingDirectiveArgumentError(Exception):
    """Signal a missing directive argument; raised by code outside the parsing step."""


class IncompatibleFieldError(Exception):
    """Signal an incompatible field; raised by code outside the parsing step."""


class UnknownStencilError(Exception):
    """Signal an unknown stencil; raised by code outside the parsing step."""


# Tokens dropped from a directive string when it is normalised by ``_clean_string``.
REPLACE_CHARS = [ts.DIRECTIVE_IDENT, "&", "\n"]


class DirectivesParser(Step):
    """Pipeline step that types, cleans, validates and parses raw directives."""

    def __init__(self, filepath: Path) -> None:
        """Initialize a DirectivesParser instance.

        Args:
            filepath: Path to file being parsed (used for log and error messages
                and handed to every validator).
        """
        self.filepath = filepath

    def __call__(self, directives: list[ts.RawDirective]) -> ts.ParsedDict:
        """Parse the directives and return them together with their parsed content.

        Args:
            directives: Raw directives to parse.

        Returns:
            A dictionary with the preprocessed directives under ``directives`` and
            their content, grouped by directive type name, under ``content``.

        Raises:
            UnsupportedDirectiveError: If a directive matches no supported pattern.
            SystemExit: If ``directives`` is empty; a warning is logged first and
                the interpreter exits with the default (success) status.
        """
        logger.info(f"Parsing DSL Preprocessor directives at {self.filepath}")
        if not directives:
            logger.warning(f"No DSL Preprocessor directives found in {self.filepath}")
            sys.exit()

        typed = self._determine_type(directives)
        preprocessed = self._preprocess(typed)
        self._run_validation_passes(preprocessed)
        return dict(directives=preprocessed, content=self._parse(preprocessed))

    @staticmethod
    def _determine_type(
        raw_directives: Sequence[ts.RawDirective],
    ) -> Sequence[ts.ParsedDirective]:
        """Determine the type of each RawDirective and return typed directive objects.

        Raises:
            UnsupportedDirectiveError: If no supported pattern occurs in a directive.
        """
        typed = []
        for raw in raw_directives:
            # NOTE(review): this is a first-match substring test in the order of
            # ts.SUPPORTED_DIRECTIVES, so a keyword appearing inside the arguments
            # of another directive also counts as a match.
            for directive in ts.SUPPORTED_DIRECTIVES:
                if directive.pattern in raw.string:
                    typed.append(directive(raw.string, raw.startln, raw.endln))  # type: ignore
                    break
            else:
                raise UnsupportedDirectiveError(
                    f"Used unsupported directive(s): {raw.string} on line(s) {raw.startln}."
                )
        return typed

    def _preprocess(
        self, directives: Sequence[ts.ParsedDirective]
    ) -> Sequence[ts.ParsedDirective]:
        """Rebuild every directive with a cleaned directive string, keeping type and lines."""
        return [
            d.__class__(self._clean_string(d.string), d.startln, d.endln)  # type: ignore
            for d in directives
        ]

    def _run_validation_passes(
        self, preprocessed: Sequence[ts.ParsedDirective]
    ) -> None:
        """Instantiate every validator in VALIDATORS and run it on the directives."""
        for validator in VALIDATORS:
            validator(self.filepath).validate(preprocessed)

    @staticmethod
    def _clean_string(string: str) -> str:
        """Collapse whitespace and drop tokens listed in REPLACE_CHARS.

        The string is split on whitespace, so only whole tokens equal to an entry
        of REPLACE_CHARS are removed; the remaining tokens are joined by one space.
        """
        return " ".join(token for token in string.split() if token not in REPLACE_CHARS)

    @staticmethod
    def _parse(directives: Sequence[ts.ParsedDirective]) -> ts.ParsedContent:
        """Group the content of each directive by its directive type name.

        The result is deliberately left as a ``defaultdict(list)``: looking up a
        type name that was never used yields an empty list instead of a KeyError.
        """
        parsed_content = collections.defaultdict(list)
        for d in directives:
            parsed_content[d.type_name].append(d.get_content())
        return parsed_content
+# +# SPDX-License-Identifier: GPL-3.0-or-later +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import icon4py.liskov.parsing.types as ts +from icon4py.liskov.common import Step +from icon4py.liskov.logger import setup_logger +from icon4py.liskov.parsing.exceptions import DirectiveSyntaxError + + +logger = setup_logger(__name__) + + +@dataclass(frozen=True) +class Scanned: + string: str + lnumber: int + + +class DirectivesScanner(Step): + def __init__(self, filepath: Path) -> None: + r"""Class for scanning a file for ICON-Liskov DSL directives. + + A directive must start with !$DSL ( with the + directive arguments delimited by a ;. The directive if on multiple + lines must include a & at the end of the line. The directive + must always be closed by a closing bracket ). + + Example: + !$DSL IMPORTS() + + !$DSL START STENCIL(name=single; b=test) + + !$DSL START STENCIL(name=multi; &\n + !$DSL b=test) + + Args: + filepath: Path to file to scan for directives. + """ + self.filepath = filepath + + def __call__(self, data: Any = None) -> list[ts.RawDirective]: + """Scan filepath for directives and return them along with their line numbers. + + Returns: + A list of RawDirective objects containing the scanned directives and their line numbers. + """ + directives = [] + with self.filepath.open() as f: + + scanned_directives = [] + lines = f.readlines() + for lnumber, string in enumerate(lines): + + if ts.DIRECTIVE_IDENT in string: + stripped = string.strip() + eol = stripped[-1] + scanned = Scanned(string, lnumber) + scanned_directives.append(scanned) + + match eol: + case ")": + directives.append(self._process_scanned(scanned_directives)) + scanned_directives = [] + case "&": + next_line = self._peek_directive(lines, lnumber) + if ts.DIRECTIVE_IDENT not in next_line: + raise DirectiveSyntaxError( + f"Error in directive on line number: {lnumber + 1}\n Invalid use of & in single line " + f"directive. 
" + ) + continue + case _: + raise DirectiveSyntaxError( + f"Error in directive on line number: {lnumber + 1}\n Used invalid end of line character." + ) + logger.info(f"Scanning for directives at {self.filepath}") + return directives + + @staticmethod + def _process_scanned(collected: list[Scanned]) -> ts.RawDirective: + """Process a list of scanned directives. + + Returns + A RawDirective object containing the concatenated directive string and its line numbers. + """ + directive_string = "".join([c.string for c in collected]) + abs_startln, abs_endln = collected[0].lnumber, collected[-1].lnumber + return ts.RawDirective(directive_string, startln=abs_startln, endln=abs_endln) + + @staticmethod + def _peek_directive(lines: list[str], lnumber: int) -> str: + """Retrieve the next line in the input file. + + This method is used to check if a directive that spans multiple lines is still a valid directive. + + Args: + lines: List of lines from the input file. + lnumber: Line number of the current line being processed. + + Returns: + Next line in the input file. + """ + return lines[lnumber + 1] diff --git a/liskov/src/icon4py/liskov/parsing/types.py b/liskov/src/icon4py/liskov/parsing/types.py new file mode 100644 index 0000000000..b49664f026 --- /dev/null +++ b/liskov/src/icon4py/liskov/parsing/types.py @@ -0,0 +1,154 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . 
+# +# SPDX-License-Identifier: GPL-3.0-or-later +from dataclasses import dataclass, field +from typing import ( + Any, + Protocol, + Sequence, + Type, + TypeAlias, + TypedDict, + runtime_checkable, +) + + +DIRECTIVE_IDENT = "!$DSL" + + +@runtime_checkable +class ParsedDirective(Protocol): + string: str + startln: int + endln: int + pattern: str + regex: str + + @property + def type_name(self) -> str: + ... + + def get_content(self) -> Any: + ... + + +ParsedContent: TypeAlias = dict[str, list[dict[str, str]]] + + +class ParsedDict(TypedDict): + directives: Sequence[ParsedDirective] + content: ParsedContent + + +@dataclass +class RawDirective: + string: str + startln: int + endln: int + + +class TypedDirective(RawDirective): + pattern: str + + @property + def type_name(self) -> str: + return self.__class__.__name__ + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TypedDirective): + raise NotImplementedError + return self.string == other.string + + +@dataclass(eq=False) +class WithArguments(TypedDirective): + regex: str = field(default=r"(.+?)=(.+?)", init=False) + + def get_content(self) -> dict[str, str]: + args = self.string.replace(f"{self.pattern}", "") + delimited = args[1:-1].split(";") + content = {a.split("=")[0].strip(): a.split("=")[1] for a in delimited} + return content + + +@dataclass(eq=False) +class WithoutArguments(TypedDirective): + # matches an empty string at the beginning of a line + regex: str = field(default=r"^(?![\s\S])", init=False) + + def get_content(self) -> dict: + return {} + + +@dataclass(eq=False) +class FreeForm(TypedDirective): + # matches any string inside brackets + regex: str = field(default=r"(.+?)", init=False) + + def get_content(self) -> str: + args = self.string.replace(f"{self.pattern}", "") + return args[1:-1] + + +class StartStencil(WithArguments): + pattern = "START STENCIL" + + +class EndStencil(WithArguments): + pattern = "END STENCIL" + + +class Declare(WithArguments): + pattern = "DECLARE" + + 
+class Imports(WithoutArguments): + pattern = "IMPORTS" + + +class StartCreate(WithoutArguments): + pattern = "START CREATE" + + +class EndCreate(WithoutArguments): + pattern = "END CREATE" + + +class EndIf(WithoutArguments): + pattern = "ENDIF" + + +class StartProfile(WithArguments): + pattern = "START PROFILE" + + +class EndProfile(WithoutArguments): + pattern = "END PROFILE" + + +class Insert(FreeForm): + pattern = "INSERT" + + +# When adding a new directive this list must be updated. +SUPPORTED_DIRECTIVES: Sequence[Type[ParsedDirective]] = [ + StartStencil, + EndStencil, + Imports, + Declare, + StartCreate, + EndCreate, + EndIf, + StartProfile, + EndProfile, + Insert, +] diff --git a/liskov/src/icon4py/liskov/parsing/utils.py b/liskov/src/icon4py/liskov/parsing/utils.py new file mode 100644 index 0000000000..8a6f1d743b --- /dev/null +++ b/liskov/src/icon4py/liskov/parsing/utils.py @@ -0,0 +1,58 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . 
def flatten_list_of_dicts(list_of_dicts: list[dict]) -> dict:
    """Flatten a list of dictionaries into a single dictionary.

    Later dictionaries win when a key occurs more than once.

    Raises:
        TypeError: If the input is not a list or contains a non-dict element.
    """
    if not isinstance(list_of_dicts, list):
        raise TypeError("Input must be a list")
    if not all(isinstance(d, dict) for d in list_of_dicts):
        raise TypeError("Input list must contain dictionaries only")

    return {k: v for d in list_of_dicts for k, v in d.items()}


def string_to_bool(string: str) -> bool:
    """Convert a string representation of a boolean to a bool.

    The comparison is case-insensitive; only "true" and "false" are accepted.

    Raises:
        ValueError: If the string is neither "true" nor "false".
    """
    normalised = string.lower()
    if normalised == "true":
        return True
    if normalised == "false":
        return False
    raise ValueError(f"Cannot convert '{string}' to a boolean.")


def print_parsed_directive(directive: ts.ParsedDirective) -> str:
    """Format a parsed directive with its string and start and end line numbers.

    Despite the name nothing is printed; the formatted line is returned.
    """
    return f"Directive: {directive.string}, start line: {directive.startln}, end line: {directive.endln}\n"


def extract_directive(
    directives: Sequence[ts.ParsedDirective],
    required_type: Type[ts.ParsedDirective],
) -> Sequence[ts.ParsedDirective]:
    """Return the directives whose exact type is ``required_type``.

    Instances of subclasses of ``required_type`` are not included.
    """
    return [d for d in directives if type(d) is required_type]


def remove_directive_types(
    directives: Sequence[ts.ParsedDirective],
    exclude_types: Sequence[Type[ts.ParsedDirective]],
) -> Sequence[ts.ParsedDirective]:
    """Return the directives whose exact type is not listed in ``exclude_types``."""
    return [d for d in directives if type(d) not in exclude_types]
class Validator(Protocol):
    """Interface of a validation pass: built from a file path, validates directives."""

    filepath: Path

    @abstractmethod
    def validate(self, directives: list[ts.ParsedDirective]) -> None:
        ...


class DirectiveSyntaxValidator:
    """Validates syntax of preprocessor directives."""

    def __init__(self, filepath: Path) -> None:
        """Initialise a DirectiveSyntaxValidator.

        Args:
            filepath: Path to file being parsed.
        """
        self.filepath = filepath
        self.exception_handler = SyntaxExceptionHandler

    def validate(self, directives: list[ts.ParsedDirective]) -> None:
        """Validate the syntax of preprocessor directives.

        Checks that each directive's pattern and inner contents, if any, match the expected syntax.
        If a syntax error is detected an appropriate exception using the exception_handler attribute
        is raised.

        Args:
            directives: A list of typed directives to validate.
        """
        for d in directives:
            self._validate_outer(d.string, d.pattern, d)
            self._validate_inner(d.string, d.pattern, d)

    def _validate_outer(
        self, to_validate: str, pattern: str, d: ts.ParsedDirective
    ) -> None:
        """Check that the whole directive has the form ``PATTERN(...)``."""
        regex = f"{pattern}\\((.*)\\)"
        match = re.fullmatch(regex, to_validate)
        self.exception_handler.check_for_matches(d, match, regex, self.filepath)

    def _validate_inner(
        self, to_validate: str, pattern: str, d: ts.ParsedDirective
    ) -> None:
        """Check every ``;``-delimited argument against the directive's own regex."""
        # Strip only the leading keyword so the same word inside the arguments
        # is validated as written.
        inner = to_validate.replace(pattern, "", 1)[1:-1].split(";")
        for arg in inner:
            match = re.fullmatch(d.regex, arg)
            self.exception_handler.check_for_matches(d, match, d.regex, self.filepath)


class DirectiveSemanticsValidator:
    """Validates semantics of preprocessor directives."""

    def __init__(self, filepath: Path) -> None:
        """Initialise a DirectiveSemanticsValidator.

        Args:
            filepath: Path to file being parsed.
        """
        self.filepath = filepath

    def validate(self, directives: list[ts.ParsedDirective]) -> None:
        """Validate the semantics of preprocessor directives.

        Checks that all used directives are unique, that all required directives
        are used at least once, and that the number of start and end stencil directives match.

        Args:
            directives: A list of typed directives to validate.
        """
        self._validate_directive_uniqueness(directives)
        self._validate_required_directives(directives)
        self._validate_stencil_directives(directives)

    def _validate_directive_uniqueness(
        self, directives: list[ts.ParsedDirective]
    ) -> None:
        """Check that all used directives are unique.

        Note: Repeated START STENCIL, END STENCIL, ENDIF, START PROFILE,
        END PROFILE and INSERT directives are allowed.

        Raises:
            RepeatedDirectiveError: If any other directive occurs more than once.
        """
        repeated = remove_directive_types(
            [d for d in directives if directives.count(d) > 1],
            [
                ts.StartStencil,
                ts.EndStencil,
                ts.EndIf,
                ts.EndProfile,
                ts.StartProfile,
                ts.Insert,
            ],
        )
        if repeated:
            pretty_printed = " ".join([print_parsed_directive(d) for d in repeated])
            raise RepeatedDirectiveError(
                f"Error in {self.filepath}.\n Found same directive more than once in the following directives:\n {pretty_printed}"
            )

    def _validate_required_directives(
        self, directives: list[ts.ParsedDirective]
    ) -> None:
        """Check that all required directives are used at least once.

        Raises:
            RequiredDirectivesError: If a required directive type is absent.
        """
        expected = [
            ts.Declare,
            ts.Imports,
            ts.StartCreate,
            ts.EndCreate,
            ts.StartStencil,
            ts.EndStencil,
        ]
        for expected_type in expected:
            if not any(isinstance(d, expected_type) for d in directives):
                raise RequiredDirectivesError(
                    f"Error in {self.filepath}.\n Missing required directive of type {expected_type.pattern} in source."
                )

    @staticmethod
    def extract_arg_from_directive(directive: str, arg: str) -> str:
        """Return the value assigned to argument ``arg`` in a directive string.

        The argument name must match as a whole key: ``name`` does not match
        inside ``stencil_name``. The value ends at the next ``;`` or ``)``.

        Raises:
            ValueError: If the directive has no such argument.
        """
        match = re.search(rf"(?<!\w){re.escape(arg)}=([^;)]+)", directive)
        if match is None:
            raise ValueError(
                f"Invalid directive string, could not find '{arg}' parameter."
            )
        return match.group(1)

    def _validate_stencil_directives(
        self, directives: list[ts.ParsedDirective]
    ) -> None:
        """Validate that the number of start and end stencil directives match in the input `directives`.

        Every START STENCIL adds one and every END STENCIL subtracts one from a
        per-stencil-name counter; a stencil whose counter does not end at zero is
        reported as unbalanced. The order of the directives is not checked.

        Args:
            directives (list[ts.ParsedDirective]): List of directives to validate;
                directives other than START/END STENCIL are ignored.

        Raises:
            UnbalancedStencilDirectiveError: If any stencil name is unbalanced.
        """
        stencil_counts: dict = {}
        for directive in directives:
            if not isinstance(directive, (ts.StartStencil, ts.EndStencil)):
                continue
            stencil_name = self.extract_arg_from_directive(directive.string, "name")
            step = 1 if isinstance(directive, ts.StartStencil) else -1
            stencil_counts[stencil_name] = stencil_counts.get(stencil_name, 0) + step

        unbalanced_stencils = [
            stencil for stencil, count in stencil_counts.items() if count != 0
        ]
        if unbalanced_stencils:
            raise UnbalancedStencilDirectiveError(
                f"Error in {self.filepath}. Each unique stencil must have a corresponding START STENCIL and END STENCIL directive."
                f" Errors found in the following stencils: {', '.join(unbalanced_stencils)}"
            )


# Validation passes, run in this order on the preprocessed directives.
VALIDATORS: list = [
    DirectiveSyntaxValidator,
    DirectiveSemanticsValidator,
]


class SyntaxExceptionHandler:
    """Turns a failed regex match into a DirectiveSyntaxError."""

    @staticmethod
    def check_for_matches(
        directive: ts.ParsedDirective,
        match: Optional[Match[str]],
        regex: str,
        filepath: Path,
    ) -> None:
        """Raise DirectiveSyntaxError if ``match`` is None, otherwise do nothing."""
        if match is None:
            raise DirectiveSyntaxError(
                f"Error in {filepath} on line {directive.startln + 1}.\n {directive.string} is invalid, "
                f"expected the following regex pattern {directive.pattern}({regex}).\n"
            )
@linear_pipeline
def parse_fortran_file(filepath: Path) -> list[Step]:
    """Build the pipeline that parses and deserialises directives from a file.

    The pipeline consists of three steps: DirectivesScanner, DirectivesParser, and
    DirectiveDeserialiser. The DirectivesScanner scans the file for directives,
    the DirectivesParser parses the directives into a dictionary, and the
    DirectiveDeserialiser deserialises the dictionary.

    Args:
        filepath (Path): The file path of the file containing the directives.

    Returns:
        The list of steps, in execution order.
        NOTE(review): callers presumably receive the result of the executed
        pipeline (a DeserialisedDirectives object) through ``linear_pipeline``;
        confirm against the decorator.
    """
    return [
        DirectivesScanner(filepath),
        DirectivesParser(filepath),
        DirectiveDeserialiser(),
    ]


@linear_pipeline
def load_gt4py_stencils(parsed: DeserialisedDirectives) -> list[Step]:
    """Build the pipeline that updates a DeserialisedDirectives object with GT4Py stencils.

    Args:
        parsed: The input DeserialisedDirectives object.

    Returns:
        The single-step list containing UpdateFieldsWithGt4PyStencils.
        NOTE(review): the updated object is presumably produced when
        ``linear_pipeline`` runs the step; confirm against the decorator.
    """
    return [UpdateFieldsWithGt4PyStencils(parsed)]


@linear_pipeline
def run_code_generation(
    parsed: DeserialisedDirectives,
    input_filepath: Path,
    output_filepath: Path,
    profile: bool,
) -> list[Step]:
    """Build the pipeline that generates and writes code for a set of directives.

    The pipeline consists of two steps: IntegrationGenerator and IntegrationWriter.
    The IntegrationGenerator is constructed from the parsed directives and the
    profile flag, the IntegrationWriter from the input and output file paths.

    Args:
        parsed: The deserialised directives object.
        input_filepath: Path handed to the IntegrationWriter as its input file.
        output_filepath: Path handed to the IntegrationWriter as its output file.
        profile: A flag to indicate if profiling information should be included in the generated code.

    Returns:
        The list of steps, in execution order.
    """
    return [
        IntegrationGenerator(parsed, profile),
        IntegrationWriter(input_filepath, output_filepath),
    ]
@pytest.fixture
def make_f90_tmpfile(tmp_path) -> Path:
    """Fixture factory which creates a temporary Fortran file.

    The fixture value is a callable taking the file content and returning the
    path of ``tmp.f90`` inside pytest's ``tmp_path``; calling it again
    overwrites that file.
    """

    def _make_f90_tmpfile(content: str):
        target = tmp_path / "tmp.f90"
        target.write_text(content)
        return target

    return _make_f90_tmpfile


@pytest.fixture
def cli():
    """Provide a fresh click CliRunner."""
    runner = CliRunner()
    return runner


def scan_for_directives(fpath: Path) -> list[ts.RawDirective]:
    """Run a DirectivesScanner on ``fpath`` and return the raw directives found."""
    return DirectivesScanner(fpath)()


def insert_new_lines(fname: Path, lines: list[str]) -> None:
    """Append each entry of ``lines`` to the file, terminated by a newline."""
    with open(fname, "a") as handle:
        handle.writelines(f"{ln}\n" for ln in lines)
acc_on ) + DO jk = 1, nlev + !DIR$ IVDEP + DO je = i_startidx, i_endidx + p_nh_prog%vn(je,jk,jb) = & + p_nh_prog%vn(je,jk,jb) + & + z_nabla2_e(je,jk,jb) * & + p_patch%edges%area_edge(je,jb)*fac_bdydiff_v + ENDDO + ENDDO + !$ACC END PARALLEL LOOP + """ + +SINGLE_STENCIL = """\ + !$DSL IMPORTS() + + !$DSL START CREATE() + + !$DSL DECLARE(vn=nproma,p_patch%nlev,p_patch%nblks_e; suffix=dsl) + + !$DSL DECLARE(vn=nproma,p_patch%nlev,p_patch%nblks_e; a=nproma,p_patch%nlev,p_patch%nblks_e; & + !$DSL b=nproma,p_patch%nlev,p_patch%nblks_e; type=REAL(vp)) + + !$DSL START STENCIL(name=apply_nabla2_to_vn_in_lateral_boundary; & + !$DSL z_nabla2_e=z_nabla2_e(:,:,1); area_edge=p_patch%edges%area_edge(:,1); & + !$DSL fac_bdydiff_v=fac_bdydiff_v; vn=p_nh_prog%vn(:,:,1); & + !$DSL vertical_lower=1; vertical_upper=nlev; & + !$DSL horizontal_lower=i_startidx; horizontal_upper=i_endidx; & + !$DSL accpresent=True) + !$OMP DO PRIVATE(je,jk,jb,i_startidx,i_endidx) ICON_OMP_DEFAULT_SCHEDULE + DO jb = i_startblk,i_endblk + + CALL get_indices_e(p_patch, jb, i_startblk, i_endblk, & + i_startidx, i_endidx, start_bdydiff_e, grf_bdywidth_e) + + !$ACC PARALLEL IF( i_am_accel_node .AND. acc_on ) DEFAULT(NONE) ASYNC(1) + vn_before(:,:,:) = p_nh_prog%vn(:,:,:) + !$ACC END PARALLEL + + !$ACC PARALLEL LOOP DEFAULT(NONE) GANG VECTOR COLLAPSE(2) ASYNC(1) IF( i_am_accel_node .AND. 
acc_on ) + DO jk = 1, nlev + !DIR$ IVDEP + DO je = i_startidx, i_endidx + p_nh_prog%vn(je,jk,jb) = & + p_nh_prog%vn(je,jk,jb) + & + z_nabla2_e(je,jk,jb) * & + p_patch%edges%area_edge(je,jb)*fac_bdydiff_v + ENDDO + ENDDO + !$DSL START PROFILE(name=apply_nabla2_to_vn_in_lateral_boundary) + !$ACC END PARALLEL LOOP + !$DSL END PROFILE() + !$DSL END STENCIL(name=apply_nabla2_to_vn_in_lateral_boundary; noprofile=True) + !$DSL END CREATE() + """ + +MULTIPLE_STENCILS = """\ + !$DSL IMPORTS() + + !$DSL START CREATE() + + !$DSL DECLARE(vn=nproma,p_patch%nlev,p_patch%nblks_e; z_rho_e=nproma,p_patch%nlev,p_patch%nblks_c; & + !$DSL z_theta_v_e=nproma,p_patch%nlev,p_patch%nblks_c; z_nabla2_c=nproma,p_patch%nlev,p_patch%nblks_c; & + !$DSL z_rth_pr_1=nproma,p_patch%nlev,p_patch%nblks_c; z_rth_pr_2=nproma,p_patch%nlev,p_patch%nblks_c; & + !$DSL rho_ic=nproma,p_patch%nlev,p_patch%nblks_c) + + !$DSL START STENCIL(name=mo_solve_nonhydro_stencil_08; wgtfac_c=p_nh%metrics%wgtfac_c(:,:,1); rho=p_nh%prog(nnow)%rho(:,:,1); rho_ref_mc=p_nh%metrics%rho_ref_mc(:,:,1); & + !$DSL theta_v=p_nh%prog(nnow)%theta_v(:,:,1); theta_ref_mc=p_nh%metrics%theta_ref_mc(:,:,1); rho_ic=p_nh%diag%rho_ic(:,:,1); z_rth_pr_1=z_rth_pr(:,:,1,1); & + !$DSL z_rth_pr_2=z_rth_pr(:,:,1,2); vertical_lower=2; vertical_upper=nlev; horizontal_lower=i_startidx; horizontal_upper=i_endidx) + !$ACC PARALLEL IF(i_am_accel_node) DEFAULT(NONE) ASYNC(1) + !$ACC LOOP GANG VECTOR TILE(32, 4) + DO jk = 2, nlev + !DIR$ IVDEP + DO jc = i_startidx, i_endidx + ! density at interface levels for vertical flux divergence computation + p_nh%diag%rho_ic(jc,jk,jb) = p_nh%metrics%wgtfac_c(jc,jk,jb) *p_nh%prog(nnow)%rho(jc,jk ,jb) + & + (1._wp-p_nh%metrics%wgtfac_c(jc,jk,jb))*p_nh%prog(nnow)%rho(jc,jk-1,jb) + + ! perturbation density and virtual potential temperature at main levels for horizontal flux divergence term + ! 
(needed in the predictor step only) + #ifdef __SWAPDIM + z_rth_pr(jc,jk,jb,1) = p_nh%prog(nnow)%rho(jc,jk,jb) - p_nh%metrics%rho_ref_mc(jc,jk,jb) + z_rth_pr(jc,jk,jb,2) = p_nh%prog(nnow)%theta_v(jc,jk,jb) - p_nh%metrics%theta_ref_mc(jc,jk,jb) + #else + z_rth_pr(1,jc,jk,jb) = p_nh%prog(nnow)%rho(jc,jk,jb) - p_nh%metrics%rho_ref_mc(jc,jk,jb) + z_rth_pr(2,jc,jk,jb) = p_nh%prog(nnow)%theta_v(jc,jk,jb) - p_nh%metrics%theta_ref_mc(jc,jk,jb) + #endif + #ifdef _OPENACC + ENDDO + ENDDO + !$ACC END PARALLEL + #endif + + !$DSL END STENCIL(name=mo_solve_nonhydro_stencil_08) + + + !$DSL START STENCIL(name=apply_nabla2_to_vn_in_lateral_boundary; & + !$DSL z_nabla2_e=z_nabla2_e(:,:,1); area_edge=p_patch%edges%area_edge(:,1); & + !$DSL fac_bdydiff_v=fac_bdydiff_v; vn=p_nh_prog%vn(:,:,1); vn_abs_tol=1e-21_wp; & + !$DSL vertical_lower=1; vertical_upper=nlev; & + !$DSL horizontal_lower=i_startidx; horizontal_upper=i_endidx) + !$ACC PARALLEL LOOP DEFAULT(NONE) GANG VECTOR COLLAPSE(2) ASYNC(1) IF( i_am_accel_node .AND. acc_on ) + DO jk = 1, nlev + !DIR$ IVDEP + DO je = i_startidx, i_endidx + p_nh_prog%vn(je,jk,jb) = & + p_nh_prog%vn(je,jk,jb) + & + z_nabla2_e(je,jk,jb) * & + p_patch%edges%area_edge(je,jb)*fac_bdydiff_v + ENDDO + ENDDO + !$ACC END PARALLEL LOOP + !$DSL END STENCIL(name=apply_nabla2_to_vn_in_lateral_boundary) + + !$DSL START STENCIL(name=calculate_nabla2_for_w; & + !$DSL w=p_nh_prog%w(:,:,1); geofac_n2s=p_int%geofac_n2s(:,:,1); & + !$DSL z_nabla2_c=z_nabla2_c(:,:,1); z_nabla2_c_abs_tol=1e-21_wp; & + !$DSL z_nabla2_c_rel_tol=1e-21_wp; & + !$DSL vertical_lower=1; vertical_upper=nlev; & + !$DSL horizontal_lower=i_startidx; horizontal_upper=i_endidx) + !$ACC PARALLEL LOOP DEFAULT(NONE) GANG VECTOR COLLAPSE(2) ASYNC(1) IF( i_am_accel_node .AND. 
acc_on ) + #ifdef __LOOP_EXCHANGE + DO jc = i_startidx, i_endidx + !DIR$ IVDEP + #ifdef _CRAYFTN + !DIR$ PREFERVECTOR + #endif + DO jk = 1, nlev + #else + DO jk = 1, nlev + DO jc = i_startidx, i_endidx + #endif + z_nabla2_c(jc,jk,jb) = & + p_nh_prog%w(jc,jk,jb) *p_int%geofac_n2s(jc,1,jb) + & + p_nh_prog%w(icidx(jc,jb,1),jk,icblk(jc,jb,1))*p_int%geofac_n2s(jc,2,jb) + & + p_nh_prog%w(icidx(jc,jb,2),jk,icblk(jc,jb,2))*p_int%geofac_n2s(jc,3,jb) + & + p_nh_prog%w(icidx(jc,jb,3),jk,icblk(jc,jb,3))*p_int%geofac_n2s(jc,4,jb) + ENDDO + ENDDO + !$ACC END PARALLEL LOOP + !$DSL ENDIF() + !$DSL END STENCIL(name=calculate_nabla2_for_w; noendif=true) + !$DSL END CREATE() + """ + +DIRECTIVES_SAMPLE = """\ +!$DSL IMPORTS() + +!$DSL START CREATE() + +!$DSL DECLARE(vn=p_patch%vn; vn2=p_patch%vn2) + +!$DSL START STENCIL(name=mo_nh_diffusion_06; vn=p_patch%vn; & +!$DSL a=a; b=c) + +!$DSL END STENCIL(name=mo_nh_diffusion_06) + +!$DSL START STENCIL(name=mo_nh_diffusion_07; xn=p_patch%xn) + +!$DSL END STENCIL(name=mo_nh_diffusion_07) + +!$DSL UNKNOWN_DIRECTIVE() +!$DSL END CREATE() +""" + +CONSECUTIVE_STENCIL = """\ + !$DSL IMPORTS() + + !$DSL START CREATE() + + !$DSL DECLARE(z_q=nproma,p_patch%nlev; z_alpha=nproma,p_patch%nlev) + + !$DSL START STENCIL(name=mo_solve_nonhydro_stencil_45; z_alpha=z_alpha(:,:); vertical_lower=nlevp1; & + !$DSL vertical_upper=nlevp1; horizontal_lower=i_startidx; horizontal_upper=i_endidx; mergecopy=true) + + !$DSL START STENCIL(name=mo_solve_nonhydro_stencil_45_b; z_q=z_q(:,:); vertical_lower=1; vertical_upper=1; & + !$DSL horizontal_lower=i_startidx; horizontal_upper=i_endidx; mergecopy=true) + + !$ACC PARALLEL IF(i_am_accel_node) DEFAULT(NONE) ASYNC(1) + !$ACC LOOP GANG VECTOR + DO jc = i_startidx, i_endidx + z_alpha(jc,nlevp1) = 0.0_wp + ! + ! Note: z_q is used in the tridiagonal matrix solver for w below. + ! z_q(1) is always zero, irrespective of w(1)=0 or w(1)/=0 + ! 
z_q(1)=0 is equivalent to cp(slev)=c(slev)/b(slev) in mo_math_utilities:tdma_solver_vec + z_q(jc,1) = 0._vp + ENDDO + !$ACC END PARALLEL + !$DSL END PROFILE() + !$DSL ENDIF() + + !$DSL END STENCIL(name=mo_solve_nonhydro_stencil_45; noendif=true; noprofile=true) + !$DSL END STENCIL(name=mo_solve_nonhydro_stencil_45_b; noendif=true; noprofile=true) + + !$DSL END CREATE() +""" + + +FREE_FORM_STENCIL = """\ + !$DSL IMPORTS() + + !$DSL START CREATE() + + !$DSL DECLARE(z_q=nproma,p_patch%nlev; z_alpha=nproma,p_patch%nlev) + + !$DSL INSERT(some custom fields go here) + + !$DSL START STENCIL(name=mo_solve_nonhydro_stencil_45; z_alpha=z_alpha(:,:); vertical_lower=nlevp1; & + !$DSL vertical_upper=nlevp1; horizontal_lower=i_startidx; horizontal_upper=i_endidx) + + !$ACC PARALLEL IF(i_am_accel_node) DEFAULT(NONE) ASYNC(1) + !$ACC LOOP GANG VECTOR + DO jc = i_startidx, i_endidx + z_alpha(jc,nlevp1) = 0.0_wp + ! + ! Note: z_q is used in the tridiagonal matrix solver for w below. + ! z_q(1) is always zero, irrespective of w(1)=0 or w(1)/=0 + ! z_q(1)=0 is equivalent to cp(slev)=c(slev)/b(slev) in mo_math_utilities:tdma_solver_vec + z_q(jc,1) = 0._vp + ENDDO + !$ACC END PARALLEL + + !$DSL INSERT(some custom code goes here) + + !$DSL END STENCIL(name=mo_solve_nonhydro_stencil_45) + + !$DSL END CREATE() +""" diff --git a/liskov/tests/test_cli.py b/liskov/tests/test_cli.py new file mode 100644 index 0000000000..75a6114da5 --- /dev/null +++ b/liskov/tests/test_cli.py @@ -0,0 +1,58 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . 
+# +# SPDX-License-Identifier: GPL-3.0-or-later + +import pytest +from samples.fortran_samples import ( + CONSECUTIVE_STENCIL, + FREE_FORM_STENCIL, + MULTIPLE_STENCILS, + NO_DIRECTIVES_STENCIL, + SINGLE_STENCIL, +) + +from icon4py.liskov.cli import main + + +@pytest.fixture +def outfile(tmp_path): + return str(tmp_path / "gen.f90") + + +@pytest.mark.parametrize("file", [NO_DIRECTIVES_STENCIL]) +def test_cli_no_directives(make_f90_tmpfile, cli, file, outfile): + fpath = str(make_f90_tmpfile(content=file)) + result = cli.invoke(main, [fpath, outfile]) + assert result.exit_code == 0 + + +@pytest.mark.parametrize( + "file, profile", + [ + (NO_DIRECTIVES_STENCIL, False), + (SINGLE_STENCIL, False), + (CONSECUTIVE_STENCIL, False), + (FREE_FORM_STENCIL, False), + (MULTIPLE_STENCILS, False), + (SINGLE_STENCIL, True), + (CONSECUTIVE_STENCIL, True), + (FREE_FORM_STENCIL, True), + (MULTIPLE_STENCILS, True), + ], +) +def test_cli(make_f90_tmpfile, cli, file, outfile, profile): + fpath = str(make_f90_tmpfile(content=file)) + args = [fpath, outfile] + if profile: + args.append("--profile") + result = cli.invoke(main, args) + assert result.exit_code == 0 diff --git a/liskov/tests/test_deserialiser.py b/liskov/tests/test_deserialiser.py new file mode 100644 index 0000000000..c474dd2cfb --- /dev/null +++ b/liskov/tests/test_deserialiser.py @@ -0,0 +1,297 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
+# +# SPDX-License-Identifier: GPL-3.0-or-later + +import unittest + +import pytest + +import icon4py.liskov.parsing.types as ts +from icon4py.liskov.codegen.interface import ( + BoundsData, + DeclareData, + EndCreateData, + EndIfData, + EndProfileData, + EndStencilData, + FieldAssociationData, + ImportsData, + InsertData, + StartCreateData, + StartProfileData, +) +from icon4py.liskov.parsing.deserialise import ( + DeclareDataFactory, + EndCreateDataFactory, + EndIfDataFactory, + EndProfileDataFactory, + EndStencilDataFactory, + ImportsDataFactory, + InsertDataFactory, + StartCreateDataFactory, + StartProfileDataFactory, + StartStencilDataFactory, +) +from icon4py.liskov.parsing.exceptions import ( + DirectiveSyntaxError, + MissingBoundsError, + MissingDirectiveArgumentError, +) + + +@pytest.mark.parametrize( + "factory_class, directive_type, startln, endln, string, expected", + [ + ( + StartCreateDataFactory, + ts.StartCreate, + "START CREATE()", + 1, + 1, + StartCreateData, + ), + (EndCreateDataFactory, ts.EndCreate, "END CREATE", 2, 2, EndCreateData), + (ImportsDataFactory, ts.Imports, "IMPORTS", 3, 3, ImportsData), + (EndIfDataFactory, ts.EndIf, "ENDIF", 4, 4, EndIfData), + (EndProfileDataFactory, ts.EndProfile, "END PROFILE", 5, 5, EndProfileData), + ], +) +def test_data_factories_no_args( + factory_class, directive_type, string, startln, endln, expected +): + parsed = { + "directives": [directive_type(string=string, startln=startln, endln=endln)], + "content": {}, + } + factory = factory_class() + result = factory(parsed) + + if type(result) == list: + result = result[0] + + assert isinstance(result, expected) + assert result.startln == startln + assert result.endln == endln + + +@pytest.mark.parametrize( + "factory,target,mock_data", + [ + ( + EndStencilDataFactory, + EndStencilData, + { + "directives": [ + ts.EndStencil("END STENCIL(name=foo)", 5, 5), + ts.EndStencil( + "END STENCIL(name=bar; noendif=true; noprofile=true)", 20, 20 + ), + ], + "content": { + 
"EndStencil": [ + {"name": "foo"}, + {"name": "bar", "noendif": "true", "noprofile": "true"}, + ] + }, + }, + ), + ( + EndStencilDataFactory, + EndStencilData, + { + "directives": [ + ts.EndStencil("END STENCIL(name=foo; noprofile=true)", 5, 5) + ], + "content": {"EndStencil": [{"name": "foo"}]}, + }, + ), + ( + StartProfileDataFactory, + StartProfileData, + { + "directives": [ + ts.StartProfile("START PROFILE(name=foo)", 5, 5), + ts.StartProfile("START PROFILE(name=bar)", 20, 20), + ], + "content": {"StartProfile": [{"name": "foo"}, {"name": "bar"}]}, + }, + ), + ( + StartProfileDataFactory, + StartProfileData, + { + "directives": [ts.StartProfile("START PROFILE(name=foo)", 5, 5)], + "content": {"StartProfile": [{"name": "foo"}]}, + }, + ), + ( + DeclareDataFactory, + DeclareData, + { + "directives": [ + ts.Declare( + "DECLARE(vn=nlev,nblks_c; w=nlevp1,nblks_e; suffix=dsl; type=LOGICAL)", + 5, + 5, + ) + ], + "content": { + "Declare": [ + { + "vn": "nlev,nblks_c", + "w": "nlevp1,nblks_e", + "suffix": "dsl", + "type": "LOGICAL", + } + ] + }, + }, + ), + ( + InsertDataFactory, + InsertData, + { + "directives": [ts.Insert("INSERT(content=foo)", 5, 5)], + "content": {"Insert": ["foo"]}, + }, + ), + ], +) +def test_data_factories_with_args(factory, target, mock_data): + factory_init = factory() + result = factory_init(mock_data) + assert all([isinstance(r, target) for r in result]) + + +@pytest.mark.parametrize( + "factory,target,mock_data", + [ + ( + EndStencilDataFactory, + EndStencilData, + { + "directives": [ + ts.EndStencil("END STENCIL(name=foo)", 5, 5), + ts.EndStencil("END STENCIL(name=bar; noendif=foo)", 20, 20), + ], + "content": { + "EndStencil": [{"name": "foo"}, {"name": "bar", "noendif": "foo"}] + }, + }, + ), + ], +) +def test_data_factories_invalid_args(factory, target, mock_data): + factory_init = factory() + with pytest.raises(DirectiveSyntaxError): + factory_init(mock_data) + + +class TestStartStencilFactory(unittest.TestCase): + def setUp(self): + 
self.factory = StartStencilDataFactory() + self.mock_fields = [ + FieldAssociationData("x", "i", 3), + FieldAssociationData("y", "i", 3), + ] + + def test_get_bounds(self): + """Test that bounds are extracted correctly.""" + named_args = { + "name": "stencil1", + "horizontal_lower": 0, + "horizontal_upper": 10, + "vertical_lower": -5, + "vertical_upper": 15, + } + assert self.factory._make_bounds(named_args) == BoundsData(0, 10, -5, 15) + + def test_get_bounds_missing_bounds(self): + """Test that exception is raised if bounds are not provided.""" + named_args = {"name": "stencil1", "horizontal_upper": 10, "vertical_upper": 15} + with pytest.raises(MissingBoundsError): + self.factory._make_bounds(named_args) + + def test_get_field_associations(self): + """Test that field associations are extracted correctly.""" + named_args = { + "name": "mo_nh_diffusion_stencil_06", + "z_nabla2_e": "z_nabla2_e(:,:,1)", + "area_edge": "p_patch%edges%area_edge(:,1)", + "fac_bdydiff_v": "fac_bdydiff_v", + "vn": "p_nh_prog%vn(:,:,1)", + "vertical_lower": "1", + "vertical_upper": "nlev", + "horizontal_lower": "i_startidx", + "horizontal_upper": "i_endidx", + } + dimensions = {"z_nabla2_e": 3, "area_edge": 3, "fac_bdydiff_v": 3, "vn": 2} + + expected_fields = [ + FieldAssociationData("z_nabla2_e", "z_nabla2_e(:,:,1)", 3), + FieldAssociationData("area_edge", "p_patch%edges%area_edge(:,1)", 3), + FieldAssociationData("fac_bdydiff_v", "fac_bdydiff_v", 3), + FieldAssociationData("vn", "p_nh_prog%vn(:,:,1)", 2), + ] + assert self.factory._make_fields(named_args, dimensions) == expected_fields + + def test_missing_directive_argument_error(self): + """Test that exception is raised if 'name' argument is not provided.""" + named_args = { + "vn": "p_nh_prog%vn(:,:,1)", + "vertical_lower": "1", + "vertical_upper": "nlev", + "horizontal_lower": "i_startidx", + "horizontal_upper": "i_endidx", + } + dimensions = {"vn": 2} + with pytest.raises(MissingDirectiveArgumentError): + 
self.factory._make_fields(named_args, dimensions) + + def test_update_field_tolerances(self): + """Test that relative and absolute tolerances are set correctly for fields.""" + named_args = { + "x_rel_tol": "0.01", + "x_abs_tol": "0.1", + "y_rel_tol": "0.001", + } + expected_fields = [ + FieldAssociationData("x", "i", 3, rel_tol="0.01", abs_tol="0.1"), + FieldAssociationData("y", "i", 3, rel_tol="0.001"), + ] + assert ( + self.factory._update_tolerances(named_args, self.mock_fields) + == expected_fields + ) + + def test_update_field_tolerances_not_all_fields(self): + # Test that tolerance is not set for fields that are not provided in the named_args. + named_args = { + "x_rel_tol": "0.01", + "x_abs_tol": "0.1", + } + expected_fields = [ + FieldAssociationData("x", "i", 3, rel_tol="0.01", abs_tol="0.1"), + FieldAssociationData("y", "i", 3), + ] + assert ( + self.factory._update_tolerances(named_args, self.mock_fields) + == expected_fields + ) + + def test_update_field_tolerances_no_tolerances(self): + # Test that fields are not updated if named_args does not contain any tolerances. + named_args = {} + assert ( + self.factory._update_tolerances(named_args, self.mock_fields) + == self.mock_fields + ) diff --git a/liskov/tests/test_external.py b/liskov/tests/test_external.py new file mode 100644 index 0000000000..0d11730e99 --- /dev/null +++ b/liskov/tests/test_external.py @@ -0,0 +1,101 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
+# +# SPDX-License-Identifier: GPL-3.0-or-later + +import os +from pathlib import Path + +import pytest +from gt4py.next.ffront.decorator import Program + +from icon4py.liskov.codegen.interface import ( + DeserialisedDirectives, + FieldAssociationData, + StartStencilData, +) +from icon4py.liskov.external.gt4py import UpdateFieldsWithGt4PyStencils +from icon4py.liskov.parsing.exceptions import ( + IncompatibleFieldError, + UnknownStencilError, +) + + +def test_stencil_collector(): + name = "calculate_nabla4" + updater = UpdateFieldsWithGt4PyStencils(None) + assert isinstance(updater._collect_icon4py_stencil(name), Program) + + +def test_stencil_collector_invalid_module(): + name = "non_existent_module" + updater = UpdateFieldsWithGt4PyStencils(None) + with pytest.raises(UnknownStencilError, match=r"Did not find module: (\w*)"): + updater._collect_icon4py_stencil(name) + + +def test_stencil_collector_invalid_member(): + from icon4py.atm_dyn_iconam import apply_nabla2_to_w + + module_path = Path(apply_nabla2_to_w.__file__) + parents = module_path.parents[0] + + updater = UpdateFieldsWithGt4PyStencils(None) + + path = os.path.join(parents, "foo.py") + with open(path, "w") as f: + f.write("") + + with pytest.raises(UnknownStencilError, match=r"Did not find module member: (\w*)"): + updater._collect_icon4py_stencil("foo") + + os.remove(path) + + +mock_deserialised_directives = DeserialisedDirectives( + StartStencil=[ + StartStencilData( + name="apply_nabla2_to_w", + fields=[ + FieldAssociationData( + variable="incompatible_field_name", + association="z_nabla2_e(:,:,1)", + dims=None, + abs_tol=None, + rel_tol=None, + inp=None, + out=None, + ) + ], + bounds=None, + startln=None, + endln=None, + acc_present=False, + mergecopy=False, + copies=True, + ) + ], + Imports=None, + Declare=None, + EndStencil=None, + StartCreate=None, + EndCreate=None, + EndIf=None, + StartProfile=None, + EndProfile=None, + Insert=None, +) + + +def test_incompatible_field_error(): + updater = 
UpdateFieldsWithGt4PyStencils(mock_deserialised_directives) + with pytest.raises(IncompatibleFieldError): + updater() diff --git a/liskov/tests/test_generation.py b/liskov/tests/test_generation.py new file mode 100644 index 0000000000..2c8facda8a --- /dev/null +++ b/liskov/tests/test_generation.py @@ -0,0 +1,236 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check <https://www.gnu.org/licenses/>. +# +# SPDX-License-Identifier: GPL-3.0-or-later + + +import pytest + +from icon4py.liskov.codegen.generate import IntegrationGenerator +from icon4py.liskov.codegen.interface import ( + BoundsData, + DeclareData, + DeserialisedDirectives, + EndCreateData, + EndIfData, + EndProfileData, + EndStencilData, + FieldAssociationData, + ImportsData, + InsertData, + StartCreateData, + StartProfileData, + StartStencilData, +) + + +# TODO: fix tests to adapt to new custom output fields +@pytest.fixture +def serialised_directives(): + start_stencil_data = StartStencilData( + name="stencil1", + fields=[ + FieldAssociationData("scalar1", "scalar1", inp=True, out=False, dims=None), + FieldAssociationData("inp1", "inp1(:,:,1)", inp=True, out=False, dims=2), + FieldAssociationData( + "out1", "out1(:,:,1)", inp=False, out=True, dims=2, abs_tol="0.5" + ), + FieldAssociationData( + "out2", + "p_nh%prog(nnew)%out2(:,:,1)", + inp=False, + out=True, + dims=3, + abs_tol="0.2", + ), + FieldAssociationData( + "out3", "p_nh%prog(nnew)%w(:,:,jb)", inp=False, out=True, dims=2 + ), + FieldAssociationData( + "out4", "p_nh%prog(nnew)%w(:,:,1,2)", inp=False, out=True, dims=3 + ), + FieldAssociationData( + "out5",
"p_nh%prog(nnew)%w(:,:,:,ntnd)", inp=False, out=True, dims=3 + ), + FieldAssociationData( + "out6", "p_nh%prog(nnew)%w(:,:,1,ntnd)", inp=False, out=True, dims=3 + ), + ], + bounds=BoundsData("1", "10", "-1", "-10"), + startln=1, + endln=2, + acc_present=False, + mergecopy=False, + copies=True, + ) + end_stencil_data = EndStencilData( + name="stencil1", startln=3, endln=4, noendif=False, noprofile=False + ) + declare_data = DeclareData( + startln=5, + endln=6, + declarations={"field2": "(nproma, p_patch%nlev, p_patch%nblks_e)"}, + ident_type="REAL(wp)", + suffix="before", + ) + imports_data = ImportsData(startln=7, endln=8) + start_create_data = StartCreateData(startln=9, endln=10) + end_create_data = EndCreateData(startln=11, endln=11) + endif_data = EndIfData(startln=12, endln=12) + start_profile_data = StartProfileData(startln=13, endln=13, name="test_stencil") + end_profile_data = EndProfileData(startln=14, endln=14) + insert_data = InsertData(startln=15, endln=15, content="print *, 'Hello, World!'") + + return DeserialisedDirectives( + StartStencil=[start_stencil_data], + EndStencil=[end_stencil_data], + Declare=[declare_data], + Imports=imports_data, + StartCreate=start_create_data, + EndCreate=end_create_data, + EndIf=[endif_data], + StartProfile=[start_profile_data], + EndProfile=[end_profile_data], + Insert=[insert_data], + ) + + +@pytest.fixture +def expected_start_create_source(): + return """ +#ifdef __DSL_VERIFY + dsl_verify = .TRUE. +#else + dsl_verify = .FALSE. +#endif + + !$ACC DATA CREATE( & + !$ACC out1_before, & + !$ACC out2_before, & + !$ACC out3_before, & + !$ACC out4_before, & + !$ACC out5_before, & + !$ACC out6_before & + !$ACC ), & + !$ACC IF ( i_am_accel_node .AND. dsl_verify)""" + + +@pytest.fixture +def expected_end_create_source(): + return "!$ACC END DATA" + + +@pytest.fixture +def expected_imports_source(): + return " USE stencil1, ONLY: wrap_run_stencil1" + + +@pytest.fixture +def expected_declare_source(): + return """ + ! 
DSL INPUT / OUTPUT FIELDS + REAL(wp), DIMENSION((nproma, p_patch%nlev, p_patch%nblks_e)) :: field2_before""" + + +@pytest.fixture +def expected_start_stencil_source(): + return """ +#ifdef __DSL_VERIFY + !$ACC PARALLEL IF( i_am_accel_node ) DEFAULT(NONE) ASYNC(1) + out1_before(:, :) = out1(:, :, 1) + out2_before(:, :, :) = p_nh%prog(nnew)%out2(:, :, :) + out3_before(:, :) = p_nh%prog(nnew)%w(:, :, jb) + out4_before(:, :, :) = p_nh%prog(nnew)%w(:, :, :, 2) + out5_before(:, :, :) = p_nh%prog(nnew)%w(:, :, :, ntnd) + out6_before(:, :, :) = p_nh%prog(nnew)%w(:, :, :, ntnd) + !$ACC END PARALLEL + call nvtxStartRange("stencil1")""" + + +@pytest.fixture +def expected_end_stencil_source(): + return """ + call nvtxEndRange() +#endif + call wrap_run_stencil1( & + scalar1=scalar1, & + inp1=inp1(:, :, 1), & + out1=out1(:, :, 1), & + out1_before=out1_before(:, :), & + out2=p_nh%prog(nnew)%out2(:, :, 1), & + out2_before=out2_before(:, :, 1), & + out3=p_nh%prog(nnew)%w(:, :, jb), & + out3_before=out3_before(:, :), & + out4=p_nh%prog(nnew)%w(:, :, 1, 2), & + out4_before=out4_before(:, :, 1), & + out5=p_nh%prog(nnew)%w(:, :, :, ntnd), & + out5_before=out5_before(:, :, 1), & + out6=p_nh%prog(nnew)%w(:, :, 1, ntnd), & + out6_before=out6_before(:, :, 1), & + out1_abs_tol=0.5, & + out2_abs_tol=0.2, & + vertical_lower=-1, & + vertical_upper=-10, & + horizontal_lower=1, & + horizontal_upper=10)""" + + +@pytest.fixture +def expected_endif_source(): + return "#endif" + + +@pytest.fixture +def expected_start_profile_source(): + return 'call nvtxStartRange("test_stencil")' + + +@pytest.fixture +def expected_end_profile_source(): + return "call nvtxEndRange()" + + +@pytest.fixture +def expected_insert_source(): + return "print *, 'Hello, World!'" + + +@pytest.fixture +def generator(serialised_directives): + return IntegrationGenerator(serialised_directives, profile=True) + + +def test_generate( + generator, + expected_start_create_source, + expected_end_create_source, + 
expected_imports_source, + expected_declare_source, + expected_start_stencil_source, + expected_end_stencil_source, + expected_endif_source, + expected_start_profile_source, + expected_end_profile_source, + expected_insert_source, +): + # Check that the generated code snippets are as expected + generated = generator() + assert len(generated) == 10 + assert generated[0].source == expected_start_create_source + assert generated[1].source == expected_end_create_source + assert generated[2].source == expected_imports_source + assert generated[3].source == expected_declare_source + assert generated[4].source == expected_start_stencil_source + assert generated[5].source == expected_end_stencil_source + assert generated[6].source == expected_endif_source + assert generated[7].source == expected_start_profile_source + assert generated[8].source == expected_end_profile_source + assert generated[9].source == expected_insert_source diff --git a/liskov/tests/test_parser.py b/liskov/tests/test_parser.py new file mode 100644 index 0000000000..b4f9069e3d --- /dev/null +++ b/liskov/tests/test_parser.py @@ -0,0 +1,129 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
+# +# SPDX-License-Identifier: GPL-3.0-or-later + + +from collections import defaultdict + +import pytest +from conftest import insert_new_lines, scan_for_directives +from pytest import mark +from samples.fortran_samples import ( + MULTIPLE_STENCILS, + NO_DIRECTIVES_STENCIL, + SINGLE_STENCIL, +) + +import icon4py.liskov.parsing.types as ts +from icon4py.liskov.parsing.exceptions import UnsupportedDirectiveError +from icon4py.liskov.parsing.parse import DirectivesParser + + +def test_parse_no_input(): + directives = [] + assert DirectivesParser._parse(directives) == defaultdict(list) + + +@mark.parametrize( + "directive, string, startln, endln, expected_content", + [ + ( + ts.Imports("IMPORTS()", 1, 1), + "IMPORTS()", + 1, + 1, + defaultdict(list, {"Imports": [{}]}), + ), + ( + ts.StartCreate("START CREATE()", 2, 2), + "START CREATE()", + 2, + 2, + defaultdict(list, {"StartCreate": [{}]}), + ), + ( + ts.StartStencil( + "START STENCIL(name=mo_nh_diffusion_06; vn=p_patch%p%vn; foo=abc)", 3, 4 + ), + "START STENCIL(name=mo_nh_diffusion_06; vn=p_patch%p%vn; foo=abc)", + 3, + 4, + defaultdict( + list, + { + "StartStencil": [ + { + "name": "mo_nh_diffusion_06", + "vn": "p_patch%p%vn", + "foo": "abc", + } + ] + }, + ), + ), + ], +) +def test_parse_single_directive(directive, string, startln, endln, expected_content): + directives = [directive] + assert DirectivesParser._parse(directives) == expected_content + + +@mark.parametrize( + "stencil, num_directives, num_content", + [(SINGLE_STENCIL, 9, 8), (MULTIPLE_STENCILS, 11, 7)], +) +def test_file_parsing(make_f90_tmpfile, stencil, num_directives, num_content): + fpath = make_f90_tmpfile(content=stencil) + directives = scan_for_directives(fpath) + parser = DirectivesParser(fpath) + parsed = parser(directives) + + directives = parsed["directives"] + content = parsed["content"] + + assert len(directives) == num_directives + assert len(content) == num_content + + assert isinstance(content, defaultdict) + assert 
all([isinstance(d, ts.ParsedDirective) for d in directives]) + + +def test_directive_parser_no_directives_found(make_f90_tmpfile): + fpath = make_f90_tmpfile(content=NO_DIRECTIVES_STENCIL) + directives = scan_for_directives(fpath) + parser = DirectivesParser(fpath) + with pytest.raises(SystemExit): + parser(directives) + + +@mark.parametrize( + "stencil, directive", + [ + (SINGLE_STENCIL, "!$DSL FOO()"), + (MULTIPLE_STENCILS, "!$DSL BAR()"), + ], +) +def test_unsupported_directives( + make_f90_tmpfile, + stencil, + directive, +): + fpath = make_f90_tmpfile(content=stencil) + insert_new_lines(fpath, [directive]) + directives = scan_for_directives(fpath) + parser = DirectivesParser(fpath) + + with pytest.raises( + UnsupportedDirectiveError, + match=r"Used unsupported directive\(s\):.", + ): + parser(directives) diff --git a/liskov/tests/test_scanner.py b/liskov/tests/test_scanner.py new file mode 100644 index 0000000000..64dcc4aedf --- /dev/null +++ b/liskov/tests/test_scanner.py @@ -0,0 +1,91 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
+# +# SPDX-License-Identifier: GPL-3.0-or-later +import string +import tempfile +from pathlib import Path + +import pytest +from pytest import mark +from samples.fortran_samples import DIRECTIVES_SAMPLE, NO_DIRECTIVES_STENCIL + +from icon4py.liskov.parsing.exceptions import DirectiveSyntaxError +from icon4py.liskov.parsing.scan import DirectivesScanner +from icon4py.liskov.parsing.types import RawDirective + + +ALLOWED_EOL_CHARS = [")", "&"] + + +def scan_tempfile(string: str): + with tempfile.NamedTemporaryFile() as tmp: + tmp.write(string.encode()) + tmp.flush() + scanner = DirectivesScanner(Path(tmp.name)) + return scanner() + + +def special_char(): + def special_chars_generator(): + for char in string.punctuation: + yield char + + return special_chars_generator() + + +@mark.parametrize( + "string,expected", + [ + (NO_DIRECTIVES_STENCIL, []), + ( + DIRECTIVES_SAMPLE, + [ + RawDirective("!$DSL IMPORTS()\n", 0, 0), + RawDirective("!$DSL START CREATE()\n", 2, 2), + RawDirective("!$DSL DECLARE(vn=p_patch%vn; vn2=p_patch%vn2)\n", 4, 4), + RawDirective( + "!$DSL START STENCIL(name=mo_nh_diffusion_06; vn=p_patch%vn; &\n!$DSL a=a; b=c)\n", + 6, + 7, + ), + RawDirective("!$DSL END STENCIL(name=mo_nh_diffusion_06)\n", 9, 9), + RawDirective( + "!$DSL START STENCIL(name=mo_nh_diffusion_07; xn=p_patch%xn)\n", + 11, + 11, + ), + RawDirective("!$DSL END STENCIL(name=mo_nh_diffusion_07)\n", 13, 13), + RawDirective("!$DSL UNKNOWN_DIRECTIVE()\n", 15, 15), + RawDirective("!$DSL END CREATE()\n", 16, 16), + ], + ), + ], +) +def test_directives_scanning(string, expected): + scanned = scan_tempfile(string) + assert scanned == expected + + +@pytest.mark.parametrize("special_char", special_char()) +def test_directive_eol(special_char): + if special_char in ALLOWED_EOL_CHARS: + pytest.skip() + else: + directive = "!$DSL IMPORT(" + special_char + with pytest.raises(DirectiveSyntaxError): + scan_tempfile(directive) + + +def test_directive_unclosed(): + directive = "!$DSL IMPORT(&\n!CALL 
foo()" + with pytest.raises(DirectiveSyntaxError): + scan_tempfile(directive) diff --git a/liskov/tests/test_utils.py b/liskov/tests/test_utils.py new file mode 100644 index 0000000000..1a327af09d --- /dev/null +++ b/liskov/tests/test_utils.py @@ -0,0 +1,69 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check <https://www.gnu.org/licenses/>. +# +# SPDX-License-Identifier: GPL-3.0-or-later +from copy import deepcopy + +import pytest + +import icon4py.liskov.parsing.types as ts +from icon4py.liskov.parsing.types import Imports, StartCreate +from icon4py.liskov.parsing.utils import ( + extract_directive, + print_parsed_directive, + remove_directive_types, + string_to_bool, +) + + +def test_extract_directive(): + directives = [ + Imports("IMPORTS()", 1, 1), + StartCreate("START CREATE()", 3, 4), + ] + + # Test that only the expected directive is extracted.
+ assert extract_directive(directives, Imports) == [directives[0]] + assert extract_directive(directives, StartCreate) == [directives[1]] + + +def test_remove_directive(): + directives = [ + Imports("IMPORTS()", 1, 1), + StartCreate("START CREATE()", 3, 4), + ] + new_directives = deepcopy(directives) + assert remove_directive_types(new_directives, [Imports]) == [directives[1]] + + +@pytest.mark.parametrize( + "string, expected", + [ + ("True", True), + ("TRUE", True), + ("false", False), + ("FALSE", False), + ("not a boolean", ValueError("Cannot convert 'not a boolean' to a boolean.")), + ], +) +def test_string_to_bool(string, expected): + if isinstance(expected, bool): + assert string_to_bool(string) == expected + else: + with pytest.raises(ValueError) as exc_info: + string_to_bool(string) + assert str(exc_info.value) == str(expected) + + +def test_print_parsed_directive(): + directive = ts.Imports("IMPORTS()", 1, 1) + expected_output = "Directive: IMPORTS(), start line: 1, end line: 1\n" + assert print_parsed_directive(directive) == expected_output diff --git a/liskov/tests/test_validation.py b/liskov/tests/test_validation.py new file mode 100644 index 0000000000..0e484b63f1 --- /dev/null +++ b/liskov/tests/test_validation.py @@ -0,0 +1,130 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
+# +# SPDX-License-Identifier: GPL-3.0-or-later + +import pytest +from conftest import insert_new_lines, scan_for_directives +from pytest import mark +from samples.fortran_samples import MULTIPLE_STENCILS, SINGLE_STENCIL + +from icon4py.liskov.parsing.exceptions import ( + DirectiveSyntaxError, + RepeatedDirectiveError, + RequiredDirectivesError, + UnbalancedStencilDirectiveError, +) +from icon4py.liskov.parsing.parse import DirectivesParser +from icon4py.liskov.parsing.types import ( + Declare, + Imports, + StartCreate, + StartStencil, +) +from icon4py.liskov.parsing.validation import DirectiveSyntaxValidator + + +@mark.parametrize( + "stencil, directive", + [ + ( + MULTIPLE_STENCILS, + "!$DSL START STENCIL(name=foo)\n!$DSL END STENCIL(name=bar)", + ), + (MULTIPLE_STENCILS, "!$DSL END STENCIL(name=foo)"), + ], +) +def test_directive_semantics_validation_unbalanced_stencil_directives( + make_f90_tmpfile, stencil, directive +): + fpath = make_f90_tmpfile(stencil + directive) + directives = scan_for_directives(fpath) + parser = DirectivesParser(fpath) + + with pytest.raises(UnbalancedStencilDirectiveError): + parser(directives) + + +@mark.parametrize( + "directive", + ( + [StartStencil("!$DSL START STENCIL(name=foo, x=bar)", 0, 0)], + [StartStencil("!$DSL START STENCIL(name=foo, x=bar;)", 0, 0)], + [StartStencil("!$DSL START STENCIL(name=foo, x=bar;", 0, 0)], + [Declare("!$DSL DECLARE(name=foo; bar)", 0, 0)], + [Imports("!$DSL IMPORTS(foo)", 0, 0)], + [Imports("!$DSL IMPORTS())", 0, 0)], + [StartCreate("!$DSL START CREATE(;)", 0, 0)], + ), +) +def test_directive_syntax_validator(directive): + validator = DirectiveSyntaxValidator("test") + with pytest.raises(DirectiveSyntaxError, match=r"Error in .+ on line \d+\.\s+."): + validator.validate(directive) + + +@mark.parametrize( + "directive", + [ + "!$DSL IMPORTS()", + "!$DSL START CREATE()", + ], +) +def test_directive_semantics_validation_repeated_directives( + make_f90_tmpfile, directive +): + fpath = 
make_f90_tmpfile(content=SINGLE_STENCIL) + insert_new_lines(fpath, [directive]) + directives = scan_for_directives(fpath) + parser = DirectivesParser(fpath) + + with pytest.raises( + RepeatedDirectiveError, + match="Found same directive more than once in the following directives:\n", + ): + parser(directives) + + +@mark.parametrize( + "directive", + [ + "!$DSL START STENCIL(name=mo_nh_diffusion_stencil_06)\n!$DSL END STENCIL(name=mo_nh_diffusion_stencil_06)" + ], +) +def test_directive_semantics_validation_repeated_stencil(make_f90_tmpfile, directive): + fpath = make_f90_tmpfile(content=SINGLE_STENCIL) + insert_new_lines(fpath, [directive]) + directives = scan_for_directives(fpath) + parser = DirectivesParser(fpath) + parser(directives) + + +@mark.parametrize( + "directive", + [ + """!$DSL IMPORTS()""", + """!$DSL START CREATE()""", + """!$DSL END STENCIL(name=apply_nabla2_to_vn_in_lateral_boundary; noprofile=True)""", + ], +) +def test_directive_semantics_validation_required_directives( + make_f90_tmpfile, directive +): + new = SINGLE_STENCIL.replace(directive, "") + fpath = make_f90_tmpfile(content=new) + directives = scan_for_directives(fpath) + parser = DirectivesParser(fpath) + + with pytest.raises( + RequiredDirectivesError, + match=r"Missing required directive of type (\w.*) in source.", + ): + parser(directives) diff --git a/liskov/tests/test_writer.py b/liskov/tests/test_writer.py new file mode 100644 index 0000000000..2a91077910 --- /dev/null +++ b/liskov/tests/test_writer.py @@ -0,0 +1,87 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. 
See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check <https://www.gnu.org/licenses/>. +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from pathlib import Path +from tempfile import TemporaryDirectory + +from icon4py.liskov.codegen.generate import GeneratedCode +from icon4py.liskov.codegen.write import DIRECTIVE_IDENT, IntegrationWriter + + +def test_write_from(): + # create temporary directory and file + with TemporaryDirectory() as temp_dir: + input_filepath = Path(temp_dir) / "test.f90" + output_filepath = input_filepath.with_suffix(".gen") + + with open(input_filepath, "w") as f: + f.write("!$DSL\n some code\n another line") + + # create an instance of IntegrationWriter and write generated code + generated = [GeneratedCode("generated code", 1, 3)] + integration_writer = IntegrationWriter(input_filepath, output_filepath) + integration_writer(generated) + + # check that the generated code was inserted into the file + with open(output_filepath, "r") as f: + content = f.read() + assert "generated code" in content + + # check that the directive was removed from the file + assert DIRECTIVE_IDENT not in content + + +def test_remove_directives(): + current_file = [ + "some code", + "!$DSL directive", + "another line", + "!$DSL another directive", + ] + expected_output = ["some code", "another line"] + assert IntegrationWriter._remove_directives(current_file) == expected_output + + +def test_insert_generated_code(): + current_file = ["some code", "another line"] + generated = [ + GeneratedCode("generated code2", 5, 6), + GeneratedCode("generated code1", 1, 3), + ] + expected_output = [ + "some code", + "generated code1\n", + "another line", + "generated code2\n", + ] + assert ( + IntegrationWriter._insert_generated_code(current_file, generated) + == expected_output + ) + + +def test_write_file(): + # create temporary directory and file + with TemporaryDirectory() as temp_dir: + input_filepath = Path(temp_dir) / "test.f90" + output_filepath =
input_filepath.with_suffix(".gen") + + generated_code = ["some code", "another line"] + writer = IntegrationWriter(input_filepath, output_filepath) + writer._write_file(generated_code) + + # check that the generated code was written to the file + with open(output_filepath, "r") as f: + content = f.read() + assert "some code" in content + assert "another line" in content diff --git a/requirements-dev.txt b/requirements-dev.txt index bacaeba53c..0ae33918fe 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -5,3 +5,4 @@ -e ./pyutils -e ./testutils -e ./atm_dyn_iconam +-e ./liskov diff --git a/requirements.txt b/requirements.txt index 4e67c2a537..0331681cb8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ ./pyutils ./testutils ./atm_dyn_iconam +./liskov \ No newline at end of file