Fix typing in icon4pytools (#267)

Fix static typing annotations in icon4pytools.

samkellerhals authored Sep 21, 2023
1 parent 32e8f6a commit 92e23fb

Showing 33 changed files with 236 additions and 215 deletions.
1 change: 1 addition & 0 deletions base-requirements-dev.txt
@@ -28,3 +28,4 @@ setuptools>=40.8.0
wheel>=0.37.1
tox >= 3.25
wget>=3.2
types-cffi>=1.15
Empty file.
3 changes: 2 additions & 1 deletion model/common/src/icon4py/model/common/grid/vertical.py
@@ -15,6 +15,7 @@
from typing import Final

import numpy as np
from gt4py.next import common
from gt4py.next.ffront.fbuiltins import int32

from icon4py.model.common.dimension import KDim
@@ -34,7 +35,7 @@ class VerticalModelParams:
rayleigh_damping_height: height of rayleigh damping in [m] mo_nonhydro_nml
"""

vct_a: Field[[KDim], float]
vct_a: common.Field
rayleigh_damping_height: Final[float]
index_of_damping_layer: Final[int32] = field(init=False)

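The dimension-parameterised `Field[[KDim], float]` annotation is replaced by the generic `gt4py.next.common.Field`, which static checkers can resolve without the non-standard subscript syntax. A minimal sketch of the resulting annotation style (the class name and default value below are illustrative, not the real `VerticalModelParams`):

```python
from dataclasses import dataclass

from gt4py.next import common


@dataclass(frozen=True)
class VerticalParamsSketch:
    """Illustrative stand-in for VerticalModelParams, not the real class."""

    vct_a: common.Field  # generic Field annotation that mypy can resolve
    rayleigh_damping_height: float = 12500.0
```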
4 changes: 3 additions & 1 deletion tools/.flake8
@@ -15,7 +15,9 @@ extend-ignore =
# Line too long (using Bugbear's B950 warning)
E501,
# Line break occurred before a binary operator
W503
W503,
# Calling setattr with a constant attribute value
B010

exclude =
.eggs,
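B010 is flake8-bugbear's check for `setattr` calls that use a constant attribute name. Ignoring it often accompanies typing work, since typed code sometimes uses `setattr` deliberately, for example to assign attributes that mypy would otherwise flag on the annotated class. A small, hypothetical illustration of what the rule reports:

```python
class Config:
    """Hypothetical example class, not from the repository."""


cfg = Config()

# B010: setattr with a constant attribute name ...
setattr(cfg, "verbose", True)

# ... is equivalent to a plain assignment, which is what Bugbear suggests.
cfg.verbose = True
```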
72 changes: 28 additions & 44 deletions tools/src/icon4pytools/f2ser/deserialise.py
@@ -31,40 +31,32 @@ def __init__(
self.parsed = parsed
self.directory = directory
self.prefix = prefix
self.data = {"Savepoint": [], "Init": ..., "Import": ...}

def __call__(self) -> SerialisationCodeInterface:
"""Deserialise the parsed granule and returns a serialisation interface.
Returns:
A `SerialisationInterface` object representing the deserialised data.
"""
"""Deserialise the parsed granule and returns a serialisation interface."""
self._merge_out_inout_fields()
self._make_savepoints()
self._make_init_data()
self._make_imports()
return SerialisationCodeInterface(**self.data)
savepoints = self._make_savepoints()
init_data = self._make_init_data()
import_data = self._make_imports()
return SerialisationCodeInterface(Import=import_data, Init=init_data, Savepoint=savepoints)

def _make_savepoints(self) -> None:
"""Create savepoints for each subroutine and intent in the parsed granule.
def _make_savepoints(self) -> list[SavepointData]:
"""Create savepoints for each subroutine and intent in the parsed granule."""
savepoints: list[SavepointData] = []

Returns:
None.
"""
for subroutine_name, intent_dict in self.parsed.subroutines.items():
for intent, var_dict in intent_dict.items():
self._create_savepoint(subroutine_name, intent, var_dict)
savepoints.append(self._create_savepoint(subroutine_name, intent, var_dict))

def _create_savepoint(self, subroutine_name: str, intent: str, var_dict: dict) -> None:
return savepoints

def _create_savepoint(self, subroutine_name: str, intent: str, var_dict: dict) -> SavepointData:
"""Create a savepoint for the given variables.
Args:
subroutine_name: The name of the subroutine.
intent: The intent of the fields to be serialised.
var_dict: A dictionary representing the variables to be saved.
Returns:
None.
"""
field_vals = {k: v for k, v in var_dict.items() if isinstance(v, dict)}
fields = [
@@ -80,14 +72,12 @@ def _create_savepoint(self, subroutine_name: str, intent: str, var_dict: dict) -
for var_name, var_data in field_vals.items()
]

self.data["Savepoint"].append(
SavepointData(
subroutine=subroutine_name,
intent=intent,
startln=self._get_codegen_line(var_dict["codegen_ctx"], intent),
fields=fields,
metadata=None,
)
return SavepointData(
subroutine=subroutine_name,
intent=intent,
startln=self._get_codegen_line(var_dict["codegen_ctx"], intent),
fields=fields,
metadata=None,
)

@staticmethod
@@ -123,39 +113,33 @@ def _create_association(self, var_data: dict, var_name: str) -> str:
)
return var_name

def _make_init_data(self) -> None:
"""Create an `InitData` object and sets it to the `Init` key in the `data` dictionary.
Returns:
None.
"""
def _make_init_data(self) -> InitData:
"""Create an `InitData` object and sets it to the `Init` key in the `data` dictionary."""
first_intent_in_subroutine = [
var_dict
for intent_dict in self.parsed.subroutines.values()
for intent, var_dict in intent_dict.items()
if intent == "in"
][0]

startln = self._get_codegen_line(first_intent_in_subroutine["codegen_ctx"], "init")
self.data["Init"] = InitData(

return InitData(
startln=startln,
directory=self.directory,
prefix=self.prefix,
)

def _merge_out_inout_fields(self):
"""Merge the `inout` fields into the `in` and `out` fields in the `parsed` dictionary.
Returns:
None.
"""
def _merge_out_inout_fields(self) -> None:
"""Merge the `inout` fields into the `in` and `out` fields in the `parsed` dictionary."""
for _, intent_dict in self.parsed.subroutines.items():
if "inout" in intent_dict:
intent_dict["in"].update(intent_dict["inout"])
intent_dict["out"].update(intent_dict["inout"])
del intent_dict["inout"]

@staticmethod
def _get_codegen_line(ctx: CodegenContext, intent: str):
def _get_codegen_line(ctx: CodegenContext, intent: str) -> int:
if intent == "in":
return ctx.last_declaration_ln
elif intent == "out":
@@ -165,5 +149,5 @@ def _get_codegen_line(ctx: CodegenContext, intent: str):
else:
raise ValueError(f"Unrecognized intent: {intent}")

def _make_imports(self):
self.data["Import"] = ImportData(startln=self.parsed.last_import_ln)
def _make_imports(self) -> ImportData:
return ImportData(startln=self.parsed.last_import_ln)
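The core of this refactor is replacing the loosely typed `self.data` dictionary (initialised with `...` placeholders) by helper methods that return concrete values, so the interface object is constructed once from fully typed parts. A minimal sketch of the pattern, using stand-in classes rather than the real `SavepointData`/`InitData`/`ImportData`:

```python
from dataclasses import dataclass


@dataclass
class Savepoint:
    """Stand-in for SavepointData."""

    subroutine: str
    intent: str


@dataclass
class Interface:
    """Stand-in for SerialisationCodeInterface."""

    savepoints: list[Savepoint]


class GranuleDeserialiser:
    """Sketch of the return-a-value style used after the refactor."""

    def __call__(self) -> Interface:
        # Each helper returns a typed object; nothing is accumulated in an
        # untyped dict, so mypy checks every field at the construction site.
        return Interface(savepoints=self._make_savepoints())

    def _make_savepoints(self) -> list[Savepoint]:
        return [Savepoint(subroutine="run_granule", intent="in")]
```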
31 changes: 13 additions & 18 deletions tools/src/icon4pytools/f2ser/parse.py
@@ -38,7 +38,7 @@ class CodegenContext:
end_subroutine_ln: int


ParsedSubroutines = dict[str, dict[str, dict[str, any] | CodegenContext]]
ParsedSubroutines = dict[str, dict[str, dict[str, Any]]]


@dataclass
@@ -69,36 +69,34 @@ def __call__(self) -> ParsedGranule:
last_import_ln = self._find_last_fortran_use_statement()
return ParsedGranule(subroutines=subroutines, last_import_ln=last_import_ln)

def _find_last_fortran_use_statement(self) -> Optional[int]:
def _find_last_fortran_use_statement(self) -> int:
"""Find the line number of the last Fortran USE statement in the code.
Returns:
int: the line number of the last USE statement, or None if no USE statement is found.
int: the line number of the last USE statement.
"""
# Reverse the order of the lines so we can search from the end
code = self._read_code_from_file()
code_lines = code.splitlines()
code_lines.reverse()

# Look for the last USE statement
use_ln = None
for i, line in enumerate(code_lines):
if line.strip().lower().startswith("use"):
use_ln = len(code_lines) - i
if i > 0 and code_lines[i - 1].strip().lower() == "#endif":
# If the USE statement is preceded by an #endif statement, return the line number after the #endif statement
return use_ln + 1
else:
return use_ln
return None
use_ln += 1
return use_ln
raise ParsingError("Could not find any USE statements.")

def _read_code_from_file(self) -> str:
"""Read the content of the granule and returns it as a string."""
with open(self.granule_path) as f:
code = f.read()
return code

def parse_subroutines(self):
def parse_subroutines(self) -> dict:
subroutines = self._extract_subroutines(crack(self.granule_path))
variables_grouped_by_intent = {
name: self._extract_intent_vars(routine) for name, routine in subroutines.items()
@@ -263,7 +261,7 @@ def _combine_types(derived_type_vars: dict, intrinsic_type_vars: dict) -> dict:
combined[subroutine_name][intent].update(new_vars)
return combined

def _update_with_codegen_lines(self, parsed_types: dict) -> dict:
def _update_with_codegen_lines(self, parsed_types: dict[str, Any]) -> dict[str, Any]:
"""Update the parsed_types dictionary with the line numbers for codegen.
Args:
@@ -285,9 +283,6 @@ def _get_subroutine_lines(self, subroutine_name: str) -> CodegenContext:
Args:
subroutine_name (str): Name of the subroutine to look for in the code.
Returns:
CodegenContext: Object containing the line number of the last declaration statement and the line number of the last line of the code before the end of the given subroutine.
"""
code = self._read_code_from_file()

@@ -312,7 +307,7 @@ def _get_subroutine_lines(self, subroutine_name: str) -> CodegenContext:
return CodegenContext(first_declaration_ln, last_declaration_ln, pre_end_subroutine_ln)

@staticmethod
def _find_subroutine_lines(code: str, subroutine_name: str) -> tuple[int]:
def _find_subroutine_lines(code: str, subroutine_name: str) -> tuple[int, int]:
"""Find line numbers of a subroutine within a code block.
Args:
@@ -327,15 +322,15 @@ def _find_subroutine_lines(code: str, subroutine_name: str) -> tuple[int]:
start_match = re.search(start_subroutine_pattern, code)
end_match = re.search(end_subroutine_pattern, code)
if start_match is None or end_match is None:
return None
raise ParsingError(f"Could not find {start_match} or {end_match}")
start_subroutine_ln = code[: start_match.start()].count("\n") + 1
end_subroutine_ln = code[: end_match.start()].count("\n") + 1
return start_subroutine_ln, end_subroutine_ln

@staticmethod
def _find_variable_declarations(
code: str, start_subroutine_ln: int, end_subroutine_ln: int
) -> list:
) -> list[int]:
"""Find line numbers of variable declarations within a code block.
Args:
@@ -371,8 +366,8 @@ def _find_variable_declarations(

@staticmethod
def _get_variable_declaration_bounds(
declaration_pattern_lines: list, start_subroutine_ln: int
) -> tuple:
declaration_pattern_lines: list[int], start_subroutine_ln: int
) -> tuple[int, int]:
"""Return the line numbers of the bounds for a variable declaration block.
Args:
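A recurring change in the parser is raising `ParsingError` where the old code returned `None`, which lets return annotations drop `Optional` (for example `-> int` instead of `-> Optional[int]`). A minimal sketch of that narrowing, with a stand-in exception class:

```python
class ParsingError(Exception):
    """Stand-in for the exception type used in icon4pytools."""


def find_last_use_statement(code_lines: list[str]) -> int:
    """Return the 1-based line number of the last Fortran USE statement."""
    for i, line in enumerate(reversed(code_lines)):
        if line.strip().lower().startswith("use"):
            return len(code_lines) - i
    # Raising instead of returning None keeps the return type a plain int.
    raise ParsingError("Could not find any USE statements.")


print(find_last_use_statement(["MODULE m", "  USE iso_c_binding", "CONTAINS"]))  # 2
```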
7 changes: 3 additions & 4 deletions tools/src/icon4pytools/icon4pygen/backend.py
@@ -95,7 +95,6 @@ def _is_size_param(param: itir.Sym) -> bool:
@staticmethod
def _missing_domain_params(params: List[itir.Sym]) -> Iterable[itir.Sym]:
"""Get domain limit params that are not present in param list."""
return map(
lambda p: itir.Sym(id=p),
filter(lambda s: s not in map(lambda p: p.id, params), _DOMAIN_ARGS),
)
param_ids = [p.id for p in params]
missing_args = [s for s in _DOMAIN_ARGS if s not in param_ids]
return (itir.Sym(id=p) for p in missing_args)
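Replacing the nested `map`/`filter`/lambda chain with comprehensions gives mypy concrete element types to infer and makes the intent easier to read. A self-contained restatement of the same rewrite, with a tiny stand-in for `itir.Sym` and illustrative domain-argument names:

```python
from dataclasses import dataclass
from typing import Iterable

# Illustrative names only; the real _DOMAIN_ARGS live in icon4pytools.
_DOMAIN_ARGS = ("horizontal_start", "horizontal_end", "vertical_start", "vertical_end")


@dataclass
class Sym:
    """Stand-in for gt4py's itir.Sym."""

    id: str


def missing_domain_params(params: list[Sym]) -> Iterable[Sym]:
    """Return Syms for domain limits that are absent from the parameter list."""
    param_ids = [p.id for p in params]
    return (Sym(id=name) for name in _DOMAIN_ARGS if name not in param_ids)


print([s.id for s in missing_domain_params([Sym(id="horizontal_start")])])
# ['horizontal_end', 'vertical_start', 'vertical_end']
```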
4 changes: 2 additions & 2 deletions tools/src/icon4pytools/icon4pygen/bindings/codegen/cpp.py
@@ -12,7 +12,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later

from pathlib import Path
from typing import Sequence
from typing import Any, Sequence

from gt4py import eve
from gt4py.eve.codegen import JinjaTemplate as as_jinja
@@ -678,7 +678,7 @@ def _get_field_data(self) -> tuple:
)
return fields, offsets

def __post_init__(self) -> None: # type: ignore
def __post_init__(self, *args: Any, **kwargs: Any) -> None:
fields, offsets = self._get_field_data()
offset_renderer = GpuTriMeshOffsetRenderer(self.offsets)

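The `__post_init__` overrides in the code-generation nodes previously carried `# type: ignore` comments; widening the signature to `(*args: Any, **kwargs: Any)` keeps the override compatible with the base-class hook, so the ignores can be dropped. The same change recurs in the f90.py and header.py diffs below. A generic, self-contained illustration of the override-compatibility rule, using a plain base class rather than gt4py's eve datamodels:

```python
from typing import Any


class Base:
    """Stand-in base class; gt4py's eve datamodels define the real hook."""

    def __post_init__(self, *args: Any, **kwargs: Any) -> None:
        ...


class Node(Base):
    # Matching the base signature keeps mypy's override check (and Liskov
    # substitution) happy without a `# type: ignore` comment.
    def __post_init__(self, *args: Any, **kwargs: Any) -> None:
        self.ready = True
```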
14 changes: 7 additions & 7 deletions tools/src/icon4pytools/icon4pygen/bindings/codegen/f90.py
@@ -12,7 +12,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later

from pathlib import Path
from typing import Sequence, Union
from typing import Any, Sequence, Union

from gt4py import eve
from gt4py.eve import Node
@@ -214,7 +214,7 @@ class F90RunFun(eve.Node):
params: F90EntityList = eve.datamodels.field(init=False)
binds: F90EntityList = eve.datamodels.field(init=False)

def __post_init__(self) -> None: # type: ignore
def __post_init__(self, *args: Any, **kwargs: Any) -> None:
param_fields = [F90Field(name=field.name) for field in self.all_fields] + [
F90Field(name=name) for name in _DOMAIN_ARGS
]
@@ -242,7 +242,7 @@ class F90RunAndVerifyFun(eve.Node):
params: F90EntityList = eve.datamodels.field(init=False)
binds: F90EntityList = eve.datamodels.field(init=False)

def __post_init__(self) -> None: # type: ignore
def __post_init__(self, *args: Any, **kwargs: Any) -> None:
param_fields = (
[F90Field(name=field.name) for field in self.all_fields]
+ [F90Field(name=field.name, suffix="before") for field in self.out_fields]
@@ -295,7 +295,7 @@ class F90SetupFun(Node):
params: F90EntityList = eve.datamodels.field(init=False)
binds: F90EntityList = eve.datamodels.field(init=False)

def __post_init__(self) -> None: # type: ignore
def __post_init__(self, *args: Any, **kwargs: Any) -> None:
param_fields = [
F90Field(name=name)
for name in [
@@ -346,7 +346,7 @@ class F90WrapRunFun(Node):
run_ver_params: F90EntityList = eve.datamodels.field(init=False)
run_params: F90EntityList = eve.datamodels.field(init=False)

def __post_init__(self) -> None: # type: ignore
def __post_init__(self, *args: Any, **kwargs: Any) -> None:
param_fields = (
[F90Field(name=field.name) for field in self.all_fields]
+ [F90Field(name=field.name, suffix="before") for field in self.out_fields]
@@ -457,7 +457,7 @@ class F90WrapSetupFun(Node):
vert_conditionals: F90EntityList = eve.datamodels.field(init=False)
setup_params: F90EntityList = eve.datamodels.field(init=False)

def __post_init__(self) -> None: # type: ignore
def __post_init__(self, *args: Any, **kwargs: Any) -> None:
param_fields = [
F90Field(name=name)
for name in [
@@ -534,7 +534,7 @@ class F90File(Node):
wrap_run_fun: F90WrapRunFun = eve.datamodels.field(init=False)
wrap_setup_fun: F90WrapSetupFun = eve.datamodels.field(init=False)

def __post_init__(self) -> None: # type: ignore
def __post_init__(self, *args: Any, **kwargs: Any) -> None:
all_fields = self.fields
out_fields = [field for field in self.fields if field.intent.out]
tol_fields = [field for field in out_fields if not field.is_integral()]
4 changes: 2 additions & 2 deletions tools/src/icon4pytools/icon4pygen/bindings/codegen/header.py
@@ -12,7 +12,7 @@
# SPDX-License-Identifier: GPL-3.0-or-later

from pathlib import Path
from typing import Sequence
from typing import Any, Sequence

from gt4py import eve
from gt4py.eve import Node
@@ -149,7 +149,7 @@ class CppHeaderFile(Node):
setupFunc: CppSetupFuncDeclaration = eve.datamodels.field(init=False)
freeFunc: CppFreeFunc = eve.datamodels.field(init=False)

def __post_init__(self) -> None: # type: ignore
def __post_init__(self, *args: Any, **kwargs: Any) -> None:
output_fields = [field for field in self.fields if field.intent.out]
tolerance_fields = [field for field in output_fields if not field.is_integral()]
