From 19644b467a34011d6303534d742e53db37ff0dbe Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 26 Nov 2024 22:35:02 +0100 Subject: [PATCH 01/11] refactor[next]: use eve.datamodel for types --- src/gt4py/eve/datamodels/core.py | 8 ++- .../ffront/foast_passes/type_deduction.py | 11 ++-- .../next/ffront/past_passes/type_deduction.py | 2 +- src/gt4py/next/ffront/past_process_args.py | 1 + src/gt4py/next/ffront/type_info.py | 4 +- src/gt4py/next/ffront/type_specifications.py | 8 +-- src/gt4py/next/iterator/embedded.py | 11 ++-- .../iterator/transforms/fuse_as_fieldop.py | 9 +-- .../next/iterator/type_system/inference.py | 4 +- .../type_system/type_specifications.py | 18 +----- .../iterator/type_system/type_synthesizer.py | 26 ++++---- src/gt4py/next/otf/binding/nanobind.py | 1 + .../codegens/gtfn/itir_to_gtfn_ir.py | 2 +- .../gtir_builtin_translators.py | 14 +++-- .../runners/dace_fieldview/gtir_dataflow.py | 29 +++++---- .../runners/dace_fieldview/gtir_sdfg.py | 2 + .../runners/dace_fieldview/utility.py | 7 ++- .../runners/dace_iterator/__init__.py | 1 + .../runners/dace_iterator/itir_to_sdfg.py | 4 +- .../runners/dace_iterator/itir_to_tasklet.py | 1 + src/gt4py/next/type_system/type_info.py | 28 ++++----- .../next/type_system/type_specifications.py | 61 ++++++++----------- .../next/type_system/type_translation.py | 8 +-- tests/eve_tests/unit_tests/test_datamodels.py | 3 +- .../iterator_tests/test_type_inference.py | 18 +++--- 25 files changed, 133 insertions(+), 148 deletions(-) diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index d596f59cfb..a45fbb821e 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -16,6 +16,7 @@ import dataclasses import functools import sys +import types import typing import warnings @@ -1040,6 +1041,11 @@ def _make_datamodel( for key in annotations: type_hint = annotations[key] = resolved_annotations[key] + if isinstance( + type_hint, types.UnionType + ): # see https://github.com/python/cpython/issues/105499 + type_hint = typing.Union[type_hint.__args__] + # Skip members annotated as class variables if type_hint is ClassVar or xtyping.get_origin(type_hint) is ClassVar: continue @@ -1255,7 +1261,7 @@ def _make_concrete_with_cache( raise TypeError(f"'{datamodel_cls.__name__}' is not a generic model class.") for t in type_args: if not ( - isinstance(t, (type, type(None), xtyping.StdGenericAliasType)) + isinstance(t, (type, type(None), xtyping.StdGenericAliasType, types.UnionType)) or (getattr(type(t), "__module__", None) in ("typing", "typing_extensions")) ): raise TypeError( diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index d334487ae1..4495717173 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -6,7 +6,7 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Optional, TypeVar, cast +from typing import Any, Optional, TypeAlias, TypeVar, cast import gt4py.next.ffront.field_operator_ast as foast from gt4py.eve import NodeTranslator, NodeVisitor, traits @@ -48,7 +48,7 @@ def with_altered_scalar_kind( if isinstance(type_spec, ts.FieldType): return ts.FieldType( dims=type_spec.dims, - dtype=ts.ScalarType(kind=new_scalar_kind, shape=type_spec.dtype.shape), + dtype=with_altered_scalar_kind(type_spec.dtype, new_scalar_kind), ) elif isinstance(type_spec, ts.ScalarType): return ts.ScalarType(kind=new_scalar_kind, shape=type_spec.shape) @@ -360,7 +360,7 @@ def visit_Assign(self, node: foast.Assign, **kwargs: Any) -> foast.Assign: def visit_TupleTargetAssign( self, node: foast.TupleTargetAssign, **kwargs: Any ) -> foast.TupleTargetAssign: - TargetType = list[foast.Starred | foast.Symbol] + TargetType: TypeAlias = list[foast.Starred | foast.Symbol] values = self.visit(node.value, **kwargs) if isinstance(values.type, ts.TupleType): @@ -374,7 +374,7 @@ def visit_TupleTargetAssign( ) new_targets: TargetType = [] - new_type: ts.TupleType | ts.DataType + new_type: ts.DataType for i, index in enumerate(indices): old_target = targets[i] @@ -391,7 +391,8 @@ def visit_TupleTargetAssign( location=old_target.location, ) else: - new_type = values.type.types[index] + new_type = values.type.types[index] # type: ignore[assignment] # see check in next line + assert isinstance(new_type, ts.DataType) new_target = self.visit( old_target, refine_type=new_type, location=old_target.location, **kwargs ) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 92f7327218..31998a6d26 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -108,7 +108,7 @@ def visit_Subscript(self, node: past.Subscript, **kwargs: Any) -> past.Subscript value = self.visit(node.value, **kwargs) return past.Subscript( value=value, - slice_=self.visit(node.slice_, **kwargs), + slice_=node.slice_, # TODO: I don't think we are using this type type=value.type, location=node.location, ) diff --git a/src/gt4py/next/ffront/past_process_args.py b/src/gt4py/next/ffront/past_process_args.py index 7958b7a8d3..1add668791 100644 --- a/src/gt4py/next/ffront/past_process_args.py +++ b/src/gt4py/next/ffront/past_process_args.py @@ -109,6 +109,7 @@ def _field_constituents_shape_and_dims( match arg_type: case ts.TupleType(): for el, el_type in zip(arg, arg_type.types): + assert isinstance(el_type, ts.DataType) yield from _field_constituents_shape_and_dims(el, el_type) case ts.FieldType(): dims = type_info.extract_dims(arg_type) diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py index 8160a2c42d..d0c1e9bbfd 100644 --- a/src/gt4py/next/ffront/type_info.py +++ b/src/gt4py/next/ffront/type_info.py @@ -252,8 +252,8 @@ def function_signature_incompatibilities_scanop( # build a function type to leverage the already existing signature checking capabilities function_type = ts.FunctionType( pos_only_args=[], - pos_or_kw_args=promoted_params, # type: ignore[arg-type] # dict is invariant, but we don't care here. 
- kw_only_args=promoted_kwparams, # type: ignore[arg-type] # same as above + pos_or_kw_args=promoted_params, + kw_only_args=promoted_kwparams, returns=ts.DeferredType(constraint=None), ) diff --git a/src/gt4py/next/ffront/type_specifications.py b/src/gt4py/next/ffront/type_specifications.py index e4f6c826fe..b76a116297 100644 --- a/src/gt4py/next/ffront/type_specifications.py +++ b/src/gt4py/next/ffront/type_specifications.py @@ -6,23 +6,19 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from dataclasses import dataclass import gt4py.next.type_system.type_specifications as ts -from gt4py.next import common as func_common +from gt4py.next import common -@dataclass(frozen=True) class ProgramType(ts.TypeSpec, ts.CallableType): definition: ts.FunctionType -@dataclass(frozen=True) class FieldOperatorType(ts.TypeSpec, ts.CallableType): definition: ts.FunctionType -@dataclass(frozen=True) class ScanOperatorType(ts.TypeSpec, ts.CallableType): - axis: func_common.Dimension + axis: common.Dimension definition: ts.FunctionType diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 3c63ffef30..83f278afc5 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -54,7 +54,6 @@ ) from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import builtins, runtime -from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.otf import arguments from gt4py.next.type_system import type_specifications as ts, type_translation @@ -1460,7 +1459,7 @@ class _List(Generic[DT]): def __getitem__(self, i: int): return self.values[i] - def __gt_type__(self) -> itir_ts.ListType: + def __gt_type__(self) -> ts.ListType: offset_tag = self.offset.value assert isinstance(offset_tag, str) element_type = type_translation.from_value(self.values[0]) @@ -1470,7 +1469,7 @@ def __gt_type__(self) -> itir_ts.ListType: connectivity = offset_provider[offset_tag] assert common.is_neighbor_connectivity(connectivity) local_dim = connectivity.__gt_type__().neighbor_dim - return itir_ts.ListType(element_type=element_type, offset_type=local_dim) + return ts.ListType(element_type=element_type, offset_type=local_dim) @dataclasses.dataclass(frozen=True) @@ -1480,10 +1479,10 @@ class _ConstList(Generic[DT]): def __getitem__(self, _): return self.value - def __gt_type__(self) -> itir_ts.ListType: + def __gt_type__(self) -> ts.ListType: element_type = type_translation.from_value(self.value) assert isinstance(element_type, ts.DataType) - return itir_ts.ListType( + return ts.ListType( element_type=element_type, offset_type=_CONST_DIM, ) @@ -1799,7 +1798,7 @@ def _fieldspec_list_to_value( domain: common.Domain, type_: ts.TypeSpec ) -> tuple[common.Domain, ts.TypeSpec]: """Translate the list element type into the domain.""" - if isinstance(type_, itir_ts.ListType): + if isinstance(type_, ts.ListType): if type_.offset_type == _CONST_DIM: return domain.insert( len(domain), common.named_range((_CONST_DIM, 1)) diff --git a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py index 9076bf2d3f..848655da19 100644 --- a/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py +++ b/src/gt4py/next/iterator/transforms/fuse_as_fieldop.py @@ -20,10 +20,7 @@ inline_lifts, trace_shifts, ) -from gt4py.next.iterator.type_system import ( - inference as type_inference, - type_specifications as it_ts, -) +from gt4py.next.iterator.type_system import 
inference as type_inference from gt4py.next.type_system import type_info, type_specifications as ts @@ -178,7 +175,7 @@ def visit_FunCall(self, node: itir.FunCall): and isinstance(arg.fun.args[0], itir.Lambda) or cpm.is_call_to(arg, "if_") ) - and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1) + and (isinstance(dtype, ts.ListType) or len(arg_shifts) <= 1) ) if should_inline: if cpm.is_applied_as_fieldop(arg): @@ -199,7 +196,7 @@ def visit_FunCall(self, node: itir.FunCall): new_args = _merge_arguments(new_args, extracted_args) else: - assert not isinstance(dtype, it_ts.ListType) + assert not isinstance(dtype, ts.ListType) new_param: str if isinstance( arg, itir.SymRef diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 987eb0f308..2247b35e32 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -272,7 +272,7 @@ def _get_dimensions(obj: Any): if isinstance(obj, common.Dimension): yield obj elif isinstance(obj, ts.TypeSpec): - for field in dataclasses.fields(obj.__class__): + for field in obj.__datamodel_fields__.values(): yield from _get_dimensions(getattr(obj, field.name)) elif isinstance(obj, collections.abc.Mapping): for el in obj.values(): @@ -490,7 +490,7 @@ def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.Tup assert domain.dims != "unknown" assert node.dtype return type_info.apply_to_primitive_constituents( - lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), # type: ignore[arg-type] # ensured by domain.dims != "unknown" above + lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), node.dtype, ) diff --git a/src/gt4py/next/iterator/type_system/type_specifications.py b/src/gt4py/next/iterator/type_system/type_specifications.py index edb56f5659..c0b2c7125b 100644 --- a/src/gt4py/next/iterator/type_system/type_specifications.py +++ b/src/gt4py/next/iterator/type_system/type_specifications.py @@ -6,44 +6,30 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -import dataclasses -from typing import Literal, Optional +from typing import Literal from gt4py.next import common from gt4py.next.type_system import type_specifications as ts -@dataclasses.dataclass(frozen=True) class NamedRangeType(ts.TypeSpec): dim: common.Dimension -@dataclasses.dataclass(frozen=True) class DomainType(ts.DataType): dims: list[common.Dimension] | Literal["unknown"] -@dataclasses.dataclass(frozen=True) class OffsetLiteralType(ts.TypeSpec): value: ts.ScalarType | common.Dimension -@dataclasses.dataclass(frozen=True) -class ListType(ts.DataType): - element_type: ts.DataType - # TODO(havogt): the `offset_type` is not yet used in type_inference, - # it is meant to describe the neighborhood (via the local dimension) - offset_type: Optional[common.Dimension] = None - - -@dataclasses.dataclass(frozen=True) class IteratorType(ts.DataType, ts.CallableType): position_dims: list[common.Dimension] | Literal["unknown"] defined_dims: list[common.Dimension] element_type: ts.DataType -@dataclasses.dataclass(frozen=True) class StencilClosureType(ts.TypeSpec): domain: DomainType stencil: ts.FunctionType @@ -61,12 +47,10 @@ def __post_init__(self): # TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere -@dataclasses.dataclass(frozen=True) class FencilType(ts.TypeSpec): params: dict[str, ts.DataType] closures: list[StencilClosureType] -@dataclasses.dataclass(frozen=True) class ProgramType(ts.TypeSpec): params: dict[str, ts.DataType] diff --git a/src/gt4py/next/iterator/type_system/type_synthesizer.py b/src/gt4py/next/iterator/type_system/type_synthesizer.py index 5be9ed7438..22a04ec04a 100644 --- a/src/gt4py/next/iterator/type_system/type_synthesizer.py +++ b/src/gt4py/next/iterator/type_system/type_synthesizer.py @@ -155,18 +155,18 @@ def if_(pred: ts.ScalarType, true_branch: ts.DataType, false_branch: ts.DataType @_register_builtin_type_synthesizer -def make_const_list(scalar: ts.ScalarType) -> it_ts.ListType: +def make_const_list(scalar: ts.ScalarType) -> ts.ListType: assert isinstance(scalar, ts.ScalarType) - return it_ts.ListType(element_type=scalar) + return ts.ListType(element_type=scalar) @_register_builtin_type_synthesizer -def list_get(index: ts.ScalarType | it_ts.OffsetLiteralType, list_: it_ts.ListType) -> ts.DataType: +def list_get(index: ts.ScalarType | it_ts.OffsetLiteralType, list_: ts.ListType) -> ts.DataType: if isinstance(index, it_ts.OffsetLiteralType): assert isinstance(index.value, ts.ScalarType) index = index.value assert isinstance(index, ts.ScalarType) and type_info.is_integral(index) - assert isinstance(list_, it_ts.ListType) + assert isinstance(list_, ts.ListType) return list_.element_type @@ -198,14 +198,14 @@ def index(arg: ts.DimensionType) -> ts.FieldType: @_register_builtin_type_synthesizer -def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> it_ts.ListType: +def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> ts.ListType: assert ( isinstance(offset_literal, it_ts.OffsetLiteralType) and isinstance(offset_literal.value, common.Dimension) and offset_literal.value.kind == common.DimensionKind.LOCAL ) assert isinstance(it, it_ts.IteratorType) - return it_ts.ListType(element_type=it.element_type) + return ts.ListType(element_type=it.element_type) @_register_builtin_type_synthesizer @@ -270,7 +270,7 @@ def _convert_as_fieldop_input_to_iterator( else: defined_dims.append(dim) if is_nb_field: - element_type = 
it_ts.ListType(element_type=element_type) + element_type = ts.ListType(element_type=element_type) return it_ts.IteratorType( position_dims=domain.dims, defined_dims=defined_dims, element_type=element_type @@ -342,14 +342,14 @@ def apply_scan( def map_(op: TypeSynthesizer) -> TypeSynthesizer: @TypeSynthesizer def applied_map( - *args: it_ts.ListType, offset_provider_type: common.OffsetProviderType - ) -> it_ts.ListType: + *args: ts.ListType, offset_provider_type: common.OffsetProviderType + ) -> ts.ListType: assert len(args) > 0 - assert all(isinstance(arg, it_ts.ListType) for arg in args) + assert all(isinstance(arg, ts.ListType) for arg in args) arg_el_types = [arg.element_type for arg in args] el_type = op(*arg_el_types, offset_provider_type=offset_provider_type) assert isinstance(el_type, ts.DataType) - return it_ts.ListType(element_type=el_type) + return ts.ListType(element_type=el_type) return applied_map @@ -357,8 +357,8 @@ def applied_map( @_register_builtin_type_synthesizer def reduce(op: TypeSynthesizer, init: ts.TypeSpec) -> TypeSynthesizer: @TypeSynthesizer - def applied_reduce(*args: it_ts.ListType, offset_provider_type: common.OffsetProviderType): - assert all(isinstance(arg, it_ts.ListType) for arg in args) + def applied_reduce(*args: ts.ListType, offset_provider_type: common.OffsetProviderType): + assert all(isinstance(arg, ts.ListType) for arg in args) return op( init, *(arg.element_type for arg in args), offset_provider_type=offset_provider_type ) diff --git a/src/gt4py/next/otf/binding/nanobind.py b/src/gt4py/next/otf/binding/nanobind.py index 24913a1365..edd56fad48 100644 --- a/src/gt4py/next/otf/binding/nanobind.py +++ b/src/gt4py/next/otf/binding/nanobind.py @@ -86,6 +86,7 @@ def _type_string(type_: ts.TypeSpec) -> str: return f"std::tuple<{','.join(_type_string(t) for t in type_.types)}>" elif isinstance(type_, ts.FieldType): ndims = len(type_.dims) + assert isinstance(type_.dtype, ts.ScalarType) dtype = cpp_interface.render_scalar_type(type_.dtype) shape = f"nanobind::shape<{', '.join(['gridtools::nanobind::dynamic_size'] * ndims)}>" buffer_t = f"nanobind::ndarray<{dtype}, {shape}>" diff --git a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py index 129d81d6f9..283eafb207 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/itir_to_gtfn_ir.py @@ -698,7 +698,7 @@ def visit_Temporary( def dtype_to_cpp(x: ts.DataType) -> str: if isinstance(x, ts.TupleType): assert all(isinstance(i, ts.ScalarType) for i in x.types) - return "::gridtools::tuple<" + ", ".join(dtype_to_cpp(i) for i in x.types) + ">" + return "::gridtools::tuple<" + ", ".join(dtype_to_cpp(i) for i in x.types) + ">" # type: ignore[arg-type] # ensured by assert assert isinstance(x, ts.ScalarType) res = pytype_to_cpptype(x) assert isinstance(res, str) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py index 69aedf44d6..cb45071ee7 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_builtin_translators.py @@ -19,7 +19,6 @@ from gt4py.next.ffront import fbuiltins as gtx_fbuiltins from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm -from 
gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_dataflow, @@ -74,7 +73,7 @@ def get_local_view( ) elif len(local_dims) == 1: - field_dtype = itir_ts.ListType( + field_dtype = ts.ListType( element_type=self.gt_type.dtype, offset_type=local_dims[0] ) field_dims = [ @@ -193,12 +192,13 @@ def _create_temporary_field( output_desc = dataflow_output.result.dc_node.desc(sdfg) if isinstance(output_desc, dace.data.Array): - assert isinstance(node_type.dtype, itir_ts.ListType) + assert isinstance(node_type.dtype, ts.ListType) assert isinstance(node_type.dtype.element_type, ts.ScalarType) assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype.element_type) # extend the array with the local dimensions added by the field operator (e.g. `neighbors`) field_shape.extend(output_desc.shape) elif isinstance(output_desc, dace.data.Scalar): + assert isinstance(node_type.dtype, ts.ScalarType) assert output_desc.dtype == dace_utils.as_dace_type(node_type.dtype) else: raise ValueError(f"Cannot create field for dace type {output_desc}.") @@ -295,7 +295,7 @@ def translate_as_fieldop( input_edges, output = taskgen.visit(stencil_expr, args=stencil_args) output_desc = output.result.dc_node.desc(sdfg) - if isinstance(node.type.dtype, itir_ts.ListType): + if isinstance(node.type.dtype, ts.ListType): assert isinstance(output_desc, dace.data.Array) # additional local dimension for neighbors # TODO(phimuell): Investigate if we should swap the two. @@ -367,7 +367,8 @@ def translate_broadcast_scalar( gt_dtype = node.args[0].type elif isinstance(node.args[0].type, ts.FieldType): assert isinstance(scalar_expr, gtir_dataflow.IteratorExpr) - if len(node.args[0].type.dims) == 0: # zero-dimensional field + type_ = node.args[0].type + if len(type_.dims) == 0: # zero-dimensional field input_subset = "0" elif all( isinstance(scalar_expr.indices[dim], gtir_dataflow.SymbolExpr) @@ -384,7 +385,8 @@ def translate_broadcast_scalar( raise ValueError(f"Cannot deref field {scalar_expr.field} in broadcast expression.") input_node = scalar_expr.field - gt_dtype = node.args[0].type.dtype + assert isinstance(type_.dtype, ts.ScalarType) + gt_dtype = type_.dtype else: raise ValueError(f"Unexpected argument {node.args[0]} in broadcast expression.") diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py index 74142dec66..cc14b6841b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_dataflow.py @@ -19,7 +19,6 @@ from gt4py.next import common as gtx_common from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next.iterator.type_system import type_specifications as itir_ts from gt4py.next.program_processors.runners.dace_common import utility as dace_utils from gt4py.next.program_processors.runners.dace_fieldview import ( gtir_python_codegen, @@ -52,7 +51,7 @@ class ValueExpr: """ dc_node: dace.nodes.AccessNode - gt_dtype: itir_ts.ListType | ts.ScalarType + gt_dtype: ts.ListType | ts.ScalarType @dataclasses.dataclass(frozen=True) @@ -67,7 +66,7 @@ class MemletExpr: """ dc_node: dace.nodes.AccessNode - gt_dtype: itir_ts.ListType | ts.ScalarType + gt_dtype: ts.ListType | 
ts.ScalarType subset: sbs.Indices | sbs.Range @@ -97,7 +96,7 @@ class IteratorExpr: """ field: dace.nodes.AccessNode - gt_dtype: itir_ts.ListType | ts.ScalarType + gt_dtype: ts.ListType | ts.ScalarType dimensions: list[gtx_common.Dimension] indices: dict[gtx_common.Dimension, DataExpr] @@ -406,7 +405,7 @@ def _construct_tasklet_result( return ValueExpr( dc_node=temp_node, gt_dtype=( - itir_ts.ListType(element_type=data_type, offset_type=_CONST_DIM) + ts.ListType(element_type=data_type, offset_type=_CONST_DIM) if use_array else data_type ), @@ -447,7 +446,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: # default case: deref a field with one or more dimensions if all(isinstance(index, SymbolExpr) for index in arg_expr.indices.values()): # when all indices are symblic expressions, we can perform direct field access through a memlet - if isinstance(arg_expr.gt_dtype, itir_ts.ListType): + if isinstance(arg_expr.gt_dtype, ts.ListType): assert len(field_desc.shape) == len(arg_expr.dimensions) + 1 assert arg_expr.gt_dtype.offset_type is not None field_dims = [*arg_expr.dimensions, arg_expr.gt_dtype.offset_type] @@ -521,7 +520,7 @@ def _visit_deref(self, node: gtir.FunCall) -> DataExpr: return self._construct_tasklet_result(field_desc.dtype, deref_node, "val") def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: - assert isinstance(node.type, itir_ts.ListType) + assert isinstance(node.type, ts.ListType) assert len(node.args) == 2 assert isinstance(node.args[0], gtir.OffsetLiteral) @@ -622,7 +621,7 @@ def _visit_neighbors(self, node: gtir.FunCall) -> ValueExpr: ) return ValueExpr( - dc_node=neighbors_node, gt_dtype=itir_ts.ListType(node.type.element_type, offset_type) + dc_node=neighbors_node, gt_dtype=ts.ListType(node.type.element_type, offset_type) ) def _visit_map(self, node: gtir.FunCall) -> ValueExpr: @@ -641,7 +640,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: In above example, the result would be an array with size V2E.max_neighbors, containing the V2E neighbor values incremented by 1.0. 
""" - assert isinstance(node.type, itir_ts.ListType) + assert isinstance(node.type, ts.ListType) assert isinstance(node.fun, gtir.FunCall) assert len(node.fun.args) == 1 # the operation to be mapped on the arguments @@ -661,7 +660,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: gtx_common.Dimension, gtx_common.NeighborConnectivityType ] = {} for input_arg in input_args: - assert isinstance(input_arg.gt_dtype, itir_ts.ListType) + assert isinstance(input_arg.gt_dtype, ts.ListType) assert input_arg.gt_dtype.offset_type is not None offset_type = input_arg.gt_dtype.offset_type if offset_type == _CONST_DIM: @@ -731,7 +730,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: connectivity_slice = self._construct_local_view( MemletExpr( dc_node=self.state.add_access(connectivity), - gt_dtype=itir_ts.ListType( + gt_dtype=ts.ListType( element_type=node.type.element_type, offset_type=offset_type ), subset=sbs.Range.from_string( @@ -770,7 +769,7 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: return ValueExpr( dc_node=result_node, - gt_dtype=itir_ts.ListType(node.type.element_type, offset_type), + gt_dtype=ts.ListType(node.type.element_type, offset_type), ) def _make_reduce_with_skip_values( @@ -797,7 +796,7 @@ def _make_reduce_with_skip_values( origin_map_index = dace_gtir_utils.get_map_variable(offset_provider_type.source_dim) assert ( - isinstance(input_expr.gt_dtype, itir_ts.ListType) + isinstance(input_expr.gt_dtype, ts.ListType) and input_expr.gt_dtype.offset_type is not None ) offset_type = input_expr.gt_dtype.offset_type @@ -908,7 +907,7 @@ def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr: input_expr = self.visit(node.args[0]) assert isinstance(input_expr, (MemletExpr, ValueExpr)) assert ( - isinstance(input_expr.gt_dtype, itir_ts.ListType) + isinstance(input_expr.gt_dtype, ts.ListType) and input_expr.gt_dtype.offset_type is not None ) offset_type = input_expr.gt_dtype.offset_type @@ -1210,7 +1209,7 @@ def _visit_generic_builtin(self, node: gtir.FunCall) -> ValueExpr: connector, ) - if isinstance(node.type, itir_ts.ListType): + if isinstance(node.type, ts.ListType): # The only builtin function (so far) handled here that returns a list # is 'make_const_list'. There are other builtin functions (map_, neighbors) # that return a list but they are handled in specialized visit methods. 
diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py index 52284edfac..864dfbef7a 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_sdfg.py @@ -251,6 +251,7 @@ def _add_storage( for tname, tsymbol_type in dace_gtir_utils.get_tuple_fields( name, gt_type, flatten=True ): + assert isinstance(tsymbol_type, ts.DataType) tuple_fields.extend( self._add_storage( sdfg, symbolic_arguments, tname, tsymbol_type, transient, tuple_name=name @@ -263,6 +264,7 @@ def _add_storage( # represent zero-dimensional fields as scalar arguments return self._add_storage(sdfg, symbolic_arguments, name, gt_type.dtype, transient) # handle default case: field with one or more dimensions + assert isinstance(gt_type.dtype, ts.ScalarType) dc_dtype = dace_utils.as_dace_type(gt_type.dtype) if tuple_name is None: # Use symbolic shape, which allows to invoke the program with fields of different size; diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py index caec6cd87e..7c369b9b41 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/utility.py @@ -45,17 +45,18 @@ def get_tuple_fields( ... ("a_1_1", sty), ... ] """ + assert all(isinstance(t, ts.DataType) for t in tuple_type.types) fields = [(f"{tuple_name}_{i}", field_type) for i, field_type in enumerate(tuple_type.types)] if flatten: - expanded_fields = [ + expanded_fields: list[list[tuple[str, ts.DataType]]] = [ get_tuple_fields(field_name, field_type) if isinstance(field_type, ts.TupleType) - else [(field_name, field_type)] + else [(field_name, field_type)] # type: ignore[list-item] # checked in assert for field_name, field_type in fields ] return list(itertools.chain(*expanded_fields)) else: - return fields + return fields # type: ignore[return-value] # checked in assert def get_tuple_type(data: tuple[Any, ...]) -> ts.TupleType: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index ef09cf51cd..87e931fbc3 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -366,6 +366,7 @@ def _crosscheck_dace_parsing(dace_parsed_args: list[Any], gt4py_program_args: li elif isinstance(dace_parsed_arg, dace.data.Array): assert isinstance(gt4py_program_arg, ts.FieldType) assert len(dace_parsed_arg.shape) == len(gt4py_program_arg.dims) + assert isinstance(gt4py_program_arg.dtype, ts.ScalarType) assert dace_parsed_arg.dtype == dace_utils.as_dace_type(gt4py_program_arg.dtype) elif isinstance( dace_parsed_arg, (dace.data.Structure, dict, OrderedDict) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 823943cfd5..fd02cdf8d6 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -64,7 +64,7 @@ def _get_scan_dim( storage_types: dict[str, ts.TypeSpec], output: SymRef, use_field_canonical_representation: bool, -) -> tuple[str, int, ts.ScalarType]: +) -> tuple[str, int, ts.ScalarType | 
ts.ListType]: """ Extract information about the scan dimension. @@ -170,6 +170,7 @@ def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, sort_dimen shape, strides = _make_array_shape_and_strides( name, type_.dims, self.offset_provider_type, sort_dimensions ) + assert isinstance(type_.dtype, ts.ScalarType) dtype = dace_utils.as_dace_type(type_.dtype) sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) @@ -595,6 +596,7 @@ def _visit_scan_stencil_closure( # the carry value of the scan operator exists only in the scope of the scan sdfg scan_carry_name = unique_var_name() + assert isinstance(scan_dtype, ts.ScalarType) scan_sdfg.add_scalar( scan_carry_name, dtype=dace_utils.as_dace_type(scan_dtype), transient=True ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 2b2669187a..056ab4b543 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -1527,6 +1527,7 @@ def closure_to_tasklet_sdfg( ndim = len(ty.dims) shape, strides = new_array_symbols(name, ndim) dims = [dim.value for dim in ty.dims] + assert isinstance(ty.dtype, ts.ScalarType) dtype = dace_utils.as_dace_type(ty.dtype) body.add_array(name, shape=shape, strides=strides, dtype=dtype) field = state.add_access(name, debuginfo=body.debuginfo) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 66f8937dc5..34dde1e3e2 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -78,15 +78,15 @@ def type_class(symbol_type: ts.TypeSpec) -> Type[ts.TypeSpec]: >>> type_class(ts.TupleType(types=[])).__name__ 'TupleType' """ - match symbol_type: - case ts.DeferredType(constraint): - if constraint is None: - raise ValueError(f"No type information available for '{symbol_type}'.") - elif isinstance(constraint, tuple): - raise ValueError(f"Not sufficient type information available for '{symbol_type}'.") - return constraint - case ts.TypeSpec() as concrete_type: - return concrete_type.__class__ + if isinstance(symbol_type, ts.DeferredType): + constraint = symbol_type.constraint + if constraint is None: + raise ValueError(f"No type information available for '{symbol_type}'.") + elif isinstance(constraint, tuple): + raise ValueError(f"Not sufficient type information available for '{symbol_type}'.") + return constraint + if isinstance(symbol_type, ts.TypeSpec): + return symbol_type.__class__ raise ValueError( f"Invalid type for TypeInfo: requires '{ts.TypeSpec}', got '{type(symbol_type)}'." 
) @@ -213,6 +213,7 @@ def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType: """ match symbol_type: case ts.FieldType(dtype=dtype): + assert isinstance(dtype, ts.ScalarType) return dtype case ts.ScalarType() as dtype: return dtype @@ -385,11 +386,10 @@ def extract_dims(symbol_type: ts.TypeSpec) -> list[common.Dimension]: >>> extract_dims(ts.FieldType(dims=[I, J], dtype=ts.ScalarType(kind=ts.ScalarKind.INT64))) [Dimension(value='I', kind=), Dimension(value='J', kind=)] """ - match symbol_type: - case ts.ScalarType(): - return [] - case ts.FieldType(dims): - return dims + if isinstance(symbol_type, ts.ScalarType): + return [] + if isinstance(symbol_type, ts.FieldType): + return symbol_type.dims raise ValueError(f"Can not extract dimensions from '{symbol_type}'.") diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index fa8c9b9ab1..7deb9a2150 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -6,21 +6,13 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -from dataclasses import dataclass from typing import Iterator, Optional, Sequence, Union -from gt4py.eve.type_definitions import IntEnum -from gt4py.eve.utils import content_hash -from gt4py.next import common as func_common +from gt4py.eve import datamodels as eve_datamodels, type_definitions as eve_types +from gt4py.next import common -@dataclass(frozen=True) -class TypeSpec: - def __hash__(self) -> int: - return hash(content_hash(self)) - - def __init_subclass__(cls) -> None: - cls.__hash__ = TypeSpec.__hash__ # type: ignore[method-assign] +class TypeSpec(eve_datamodels.DataModel, kw_only=False, frozen=True): ... # type: ignore[call-arg] class DataType(TypeSpec): @@ -40,14 +32,12 @@ class CallableType: """ -@dataclass(frozen=True) class DeferredType(TypeSpec): """Dummy used to represent a type not yet inferred.""" constraint: Optional[type[TypeSpec] | tuple[type[TypeSpec], ...]] -@dataclass(frozen=True) class VoidType(TypeSpec): """ Return type of a function without return values. @@ -56,22 +46,20 @@ class VoidType(TypeSpec): """ -@dataclass(frozen=True) class DimensionType(TypeSpec): - dim: func_common.Dimension + dim: common.Dimension -@dataclass(frozen=True) class OffsetType(TypeSpec): # TODO(havogt): replace by ConnectivityType - source: func_common.Dimension - target: tuple[func_common.Dimension] | tuple[func_common.Dimension, func_common.Dimension] + source: common.Dimension + target: tuple[common.Dimension] | tuple[common.Dimension, common.Dimension] def __str__(self) -> str: return f"Offset[{self.source}, {self.target}]" -class ScalarKind(IntEnum): +class ScalarKind(eve_types.IntEnum): BOOL = 1 INT32 = 32 INT64 = 64 @@ -80,7 +68,6 @@ class ScalarKind(IntEnum): STRING = 3001 -@dataclass(frozen=True) class ScalarType(DataType): kind: ScalarKind shape: Optional[list[int]] = None @@ -92,31 +79,37 @@ def __str__(self) -> str: return f"{kind_str}{self.shape}" -@dataclass(frozen=True) +class ListType(DataType): + element_type: DataType + # TODO(havogt): the `offset_type` is not yet used in type_inference, + # it is meant to describe the neighborhood (via the local dimension) + offset_type: Optional[common.Dimension] = None + + +class FieldType(DataType, CallableType): + dims: list[common.Dimension] + dtype: ScalarType | ListType + + def __str__(self) -> str: + dims = "..." 
if self.dims is Ellipsis else f"[{', '.join(dim.value for dim in self.dims)}]" + return f"Field[{dims}, {self.dtype}]" + + class TupleType(DataType): - types: list[DataType] + # TODO: or DimensionType is a DataType + types: list[DataType | DimensionType | DeferredType] + # TODO validate DeferredType constraints def __str__(self) -> str: return f"tuple[{', '.join(map(str, self.types))}]" - def __iter__(self) -> Iterator[DataType]: + def __iter__(self) -> Iterator[DataType | DimensionType | DeferredType]: yield from self.types def __len__(self) -> int: return len(self.types) -@dataclass(frozen=True) -class FieldType(DataType, CallableType): - dims: list[func_common.Dimension] - dtype: ScalarType - - def __str__(self) -> str: - dims = "..." if self.dims is Ellipsis else f"[{', '.join(dim.value for dim in self.dims)}]" - return f"Field[{dims}, {self.dtype}]" - - -@dataclass(frozen=True) class FunctionType(TypeSpec, CallableType): pos_only_args: Sequence[TypeSpec] pos_or_kw_args: dict[str, TypeSpec] diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 62a6781316..e601556e55 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -10,7 +10,6 @@ import builtins import collections.abc -import dataclasses import functools import types import typing @@ -105,7 +104,7 @@ def from_type_hint( raise ValueError(f"Unbound tuples '{type_hint}' are not allowed.") tuple_types = [recursive_make_symbol(arg) for arg in args] assert all(isinstance(elem, ts.DataType) for elem in tuple_types) - return ts.TupleType(types=tuple_types) # type: ignore[arg-type] # checked in assert + return ts.TupleType(types=tuple_types) case common.Field: if (n_args := len(args)) != 2: @@ -168,7 +167,6 @@ def from_type_hint( raise ValueError(f"'{type_hint}' type is not supported.") -@dataclasses.dataclass(frozen=True) class UnknownPythonObject(ts.TypeSpec): _object: Any @@ -217,9 +215,9 @@ def from_value(value: Any) -> ts.TypeSpec: # not needed anymore. 
elems = [from_value(el) for el in value] assert all(isinstance(elem, ts.DataType) for elem in elems) - return ts.TupleType(types=elems) # type: ignore[arg-type] # checked in assert + return ts.TupleType(types=elems) elif isinstance(value, types.ModuleType): - return UnknownPythonObject(_object=value) + return UnknownPythonObject(value) else: type_ = xtyping.infer_type(value, annotate_callable_kwargs=True) symbol_type = from_type_hint(type_) diff --git a/tests/eve_tests/unit_tests/test_datamodels.py b/tests/eve_tests/unit_tests/test_datamodels.py index 05be5f3db0..7f523df6cf 100644 --- a/tests/eve_tests/unit_tests/test_datamodels.py +++ b/tests/eve_tests/unit_tests/test_datamodels.py @@ -12,7 +12,6 @@ import numbers import types import typing -from typing import Set # noqa: F401 [unused-import] used in exec() context from typing import ( Any, Callable, @@ -26,6 +25,7 @@ MutableSequence, Optional, Sequence, + Set, # noqa: F401 [unused-import] used in exec() context Tuple, Type, TypeVar, @@ -555,6 +555,7 @@ class WrongModel: ("typing.MutableSequence[int]", ([1, 2, 3], []), ((1, 2, 3), tuple(), 1, [1.0], {1})), ("typing.Set[int]", ({1, 2, 3}, set()), (1, [1], (1,), {1: None})), ("typing.Union[int, float, str]", [1, 3.0, "one"], [[1], [], 1j]), + ("int | float | str", [1, 3.0, "one"], [[1], [], 1j]), ("typing.Optional[int]", [1, None], [[1], [], 1j]), ( "typing.Dict[Union[int, float, str], Union[Tuple[int, Optional[float]], Set[int]]]", diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 65a5b5888d..ece7af5633 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -23,13 +23,12 @@ ) from gt4py.next.type_system import type_specifications as ts -from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import simple_mesh - from next_tests.integration_tests.cases import ( C2E, E2V, V2E, E2VDim, + Edge, IDim, Ioff, JDim, @@ -37,17 +36,18 @@ Koff, V2EDim, Vertex, - Edge, - mesh_descriptor, exec_alloc_descriptor, + mesh_descriptor, unstructured_case, ) +from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import simple_mesh + bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL) int_type = ts.ScalarType(kind=ts.ScalarKind.INT32) float64_type = ts.ScalarType(kind=ts.ScalarKind.FLOAT64) -float64_list_type = it_ts.ListType(element_type=float64_type) -int_list_type = it_ts.ListType(element_type=int_type) +float64_list_type = ts.ListType(element_type=float64_type) +int_list_type = ts.ListType(element_type=int_type) float_i_field = ts.FieldType(dims=[IDim], dtype=float64_type) float_vertex_k_field = ts.FieldType(dims=[Vertex, KDim], dtype=float64_type) @@ -77,8 +77,8 @@ def expression_test_cases(): (im.deref(im.ref("it", it_on_e_of_e_type)), it_on_e_of_e_type.element_type), (im.call("can_deref")(im.ref("it", it_on_e_of_e_type)), bool_type), (im.if_(True, 1, 2), int_type), - (im.call("make_const_list")(True), it_ts.ListType(element_type=bool_type)), - (im.call("list_get")(0, im.ref("l", it_ts.ListType(element_type=bool_type))), bool_type), + (im.call("make_const_list")(True), ts.ListType(element_type=bool_type)), + (im.call("list_get")(0, im.ref("l", ts.ListType(element_type=bool_type))), bool_type), ( im.call("named_range")(itir.AxisLiteral(value="Vertex"), 0, 1), it_ts.NamedRangeType(dim=Vertex), @@ -110,7 +110,7 @@ def expression_test_cases(): # neighbors ( 
im.neighbors("E2V", im.ref("a", it_on_e_of_e_type)), - it_ts.ListType(element_type=it_on_e_of_e_type.element_type), + ts.ListType(element_type=it_on_e_of_e_type.element_type), ), # cast (im.call("cast_")(1, "int32"), int_type), From f220b34af4b771636b022b1ecebfbd6ff809bf95 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 3 Dec 2024 15:44:28 +0100 Subject: [PATCH 02/11] fix | in sequences --- src/gt4py/eve/datamodels/core.py | 10 ++++------ src/gt4py/eve/type_validation.py | 9 +++++++++ tests/eve_tests/unit_tests/test_datamodels.py | 1 + 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index a45fbb821e..af80973fcd 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -1041,11 +1041,6 @@ def _make_datamodel( for key in annotations: type_hint = annotations[key] = resolved_annotations[key] - if isinstance( - type_hint, types.UnionType - ): # see https://github.com/python/cpython/issues/105499 - type_hint = typing.Union[type_hint.__args__] - # Skip members annotated as class variables if type_hint is ClassVar or xtyping.get_origin(type_hint) is ClassVar: continue @@ -1260,8 +1255,11 @@ def _make_concrete_with_cache( if not is_generic_datamodel_class(datamodel_cls): raise TypeError(f"'{datamodel_cls.__name__}' is not a generic model class.") for t in type_args: + _accepted_types: tuple[type, ...] = (type, type(None), xtyping.StdGenericAliasType) + if sys.version_info >= (3, 10): + _accepted_types = (*_accepted_types, types.UnionType) if not ( - isinstance(t, (type, type(None), xtyping.StdGenericAliasType, types.UnionType)) + isinstance(t, _accepted_types) or (getattr(type(t), "__module__", None) in ("typing", "typing_extensions")) ): raise TypeError( diff --git a/src/gt4py/eve/type_validation.py b/src/gt4py/eve/type_validation.py index 613eca40b2..7464db1b4e 100644 --- a/src/gt4py/eve/type_validation.py +++ b/src/gt4py/eve/type_validation.py @@ -14,6 +14,8 @@ import collections.abc import dataclasses import functools +import sys +import types import typing from . import exceptions, extended_typing as xtyping, utils @@ -193,6 +195,12 @@ def __call__( if type_annotation is None: type_annotation = type(None) + if sys.version_info >= (3, 10): + if isinstance( + type_annotation, types.UnionType + ): # see https://github.com/python/cpython/issues/105499 + type_annotation = typing.Union[type_annotation.__args__] + # Non-generic types if xtyping.is_actual_type(type_annotation): assert not xtyping.get_args(type_annotation) @@ -277,6 +285,7 @@ def __call__( if issubclass(origin_type, (collections.abc.Sequence, collections.abc.Set)): assert len(type_args) == 1 + make_recursive(type_args[0]) if (member_validator := make_recursive(type_args[0])) is None: raise exceptions.EveValueError( f"{type_args[0]} type annotation is not supported." 
diff --git a/tests/eve_tests/unit_tests/test_datamodels.py b/tests/eve_tests/unit_tests/test_datamodels.py index 7f523df6cf..d826d7a02f 100644 --- a/tests/eve_tests/unit_tests/test_datamodels.py +++ b/tests/eve_tests/unit_tests/test_datamodels.py @@ -556,6 +556,7 @@ class WrongModel: ("typing.Set[int]", ({1, 2, 3}, set()), (1, [1], (1,), {1: None})), ("typing.Union[int, float, str]", [1, 3.0, "one"], [[1], [], 1j]), ("int | float | str", [1, 3.0, "one"], [[1], [], 1j]), + ("typing.List[int|float]", [[1, 2.0], []], [1, 2.0, [1, "2.0"]]), ("typing.Optional[int]", [1, None], [[1], [], 1j]), ( "typing.Dict[Union[int, float, str], Union[Tuple[int, Optional[float]], Set[int]]]", From 7fb18aea51f493e5edf5add2940e057e92a3a2bd Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 3 Dec 2024 15:47:32 +0100 Subject: [PATCH 03/11] address review comments --- src/gt4py/next/iterator/type_system/inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 87f599bc1b..8c99fa9430 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -272,8 +272,8 @@ def _get_dimensions(obj: Any): if isinstance(obj, common.Dimension): yield obj elif isinstance(obj, ts.TypeSpec): - for field in obj.__datamodel_fields__.values(): - yield from _get_dimensions(getattr(obj, field.name)) + for field in obj.__datamodel_fields__.keys(): + yield from _get_dimensions(getattr(obj, field)) elif isinstance(obj, collections.abc.Mapping): for el in obj.values(): yield from _get_dimensions(el) From 1471750f14367e5ea7766d9a5198f03e2e3a8fac Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 3 Dec 2024 17:46:10 +0100 Subject: [PATCH 04/11] exclude test for python < 3.10 --- tests/eve_tests/unit_tests/test_datamodels.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tests/eve_tests/unit_tests/test_datamodels.py b/tests/eve_tests/unit_tests/test_datamodels.py index d826d7a02f..75b07fd8a0 100644 --- a/tests/eve_tests/unit_tests/test_datamodels.py +++ b/tests/eve_tests/unit_tests/test_datamodels.py @@ -10,6 +10,7 @@ import enum import numbers +import sys import types import typing from typing import ( @@ -555,8 +556,18 @@ class WrongModel: ("typing.MutableSequence[int]", ([1, 2, 3], []), ((1, 2, 3), tuple(), 1, [1.0], {1})), ("typing.Set[int]", ({1, 2, 3}, set()), (1, [1], (1,), {1: None})), ("typing.Union[int, float, str]", [1, 3.0, "one"], [[1], [], 1j]), - ("int | float | str", [1, 3.0, "one"], [[1], [], 1j]), - ("typing.List[int|float]", [[1, 2.0], []], [1, 2.0, [1, "2.0"]]), + pytest.param( + "int | float | str", + [1, 3.0, "one"], + [[1], [], 1j], + marks=pytest.mark.skipif(sys.version_info < (3, 10), reason="| union syntax not supported"), + ), + pytest.param( + "typing.List[int|float]", + [[1, 2.0], []], + [1, 2.0, [1, "2.0"]], + marks=pytest.mark.skipif(sys.version_info < (3, 10), reason="| union syntax not supported"), + ), ("typing.Optional[int]", [1, None], [[1], [], 1j]), ( "typing.Dict[Union[int, float, str], Union[Tuple[int, Optional[float]], Set[int]]]", From 86960cb7f368d488398e74ac4c63d7e411583d24 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 3 Dec 2024 18:38:36 +0100 Subject: [PATCH 05/11] fix extract_dtype --- src/gt4py/next/ffront/foast_to_gtir.py | 3 +++ src/gt4py/next/ffront/foast_to_itir.py | 3 +++ .../next/iterator/transforms/global_tmps.py | 12 ++++++------ 
src/gt4py/next/type_system/type_info.py | 17 ++++++++++++----- 4 files changed, 24 insertions(+), 11 deletions(-) diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 2c2971f49a..0331adf1c1 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -236,6 +236,7 @@ def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr: def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: # TODO(tehrengruber): extend iterator ir to support unary operators dtype = type_info.extract_dtype(node.type) + assert isinstance(dtype, ts.ScalarType) if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]: if dtype.kind != ts.ScalarKind.BOOL: raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.") @@ -421,12 +422,14 @@ def _visit_neighbor_sum(self, node: foast.Call, **kwargs: Any) -> itir.Expr: def _visit_max_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: dtype = type_info.extract_dtype(node.type) + assert isinstance(dtype, ts.ScalarType) min_value, _ = type_info.arithmetic_bounds(dtype) init_expr = self._make_literal(str(min_value), dtype) return self._make_reduction_expr(node, "maximum", init_expr, **kwargs) def _visit_min_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: dtype = type_info.extract_dtype(node.type) + assert isinstance(dtype, ts.ScalarType) _, max_value = type_info.arithmetic_bounds(dtype) init_expr = self._make_literal(str(max_value), dtype) return self._make_reduction_expr(node, "minimum", init_expr, **kwargs) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index 538b0f3ddb..e46513222b 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -264,6 +264,7 @@ def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr: def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr: # TODO(tehrengruber): extend iterator ir to support unary operators dtype = type_info.extract_dtype(node.type) + assert isinstance(dtype, ts.ScalarType) if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]: if dtype.kind != ts.ScalarKind.BOOL: raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.") @@ -441,12 +442,14 @@ def _visit_neighbor_sum(self, node: foast.Call, **kwargs: Any) -> itir.Expr: def _visit_max_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: dtype = type_info.extract_dtype(node.type) + assert isinstance(dtype, ts.ScalarType) min_value, _ = type_info.arithmetic_bounds(dtype) init_expr = self._make_literal(str(min_value), dtype) return self._make_reduction_expr(node, "maximum", init_expr, **kwargs) def _visit_min_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr: dtype = type_info.extract_dtype(node.type) + assert isinstance(dtype, ts.ScalarType) _, max_value = type_info.arithmetic_bounds(dtype) init_expr = self._make_literal(str(max_value), dtype) return self._make_reduction_expr(node, "minimum", init_expr, **kwargs) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index a6d39883e3..d37cfc239d 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -98,12 +98,12 @@ def _transform_by_pattern( tmp_expr.type, tuple_constructor=lambda *elements: tuple(elements), ) - tmp_dtypes: 
ts.ScalarType | tuple[ts.ScalarType | tuple, ...] = ( - type_info.apply_to_primitive_constituents( - type_info.extract_dtype, - tmp_expr.type, - tuple_constructor=lambda *elements: tuple(elements), - ) + tmp_dtypes: ( + ts.ScalarType | ts.ListType | tuple[ts.ScalarType | ts.ListType | tuple, ...] + ) = type_info.apply_to_primitive_constituents( + type_info.extract_dtype, + tmp_expr.type, + tuple_constructor=lambda *elements: tuple(elements), ) # allocate temporary for all tuple elements diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 34dde1e3e2..7812699005 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -197,7 +197,7 @@ def apply_to_primitive_constituents( return fun(*symbol_types) -def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType: +def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType | ts.ListType: """ Extract the data type from ``symbol_type`` if it is either `FieldType` or `ScalarType`. @@ -213,7 +213,6 @@ def extract_dtype(symbol_type: ts.TypeSpec) -> ts.ScalarType: """ match symbol_type: case ts.FieldType(dtype=dtype): - assert isinstance(dtype, ts.ScalarType) return dtype case ts.ScalarType() as dtype: return dtype @@ -235,7 +234,10 @@ def is_floating_point(symbol_type: ts.TypeSpec) -> bool: >>> is_floating_point(ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))) True """ - return extract_dtype(symbol_type).kind in [ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64] + return isinstance(symbol_type, ts.ScalarType) and extract_dtype(symbol_type).kind in [ # type: ignore[union-attr] # checked is `ScalarType` + ts.ScalarKind.FLOAT32, + ts.ScalarKind.FLOAT64, + ] def is_integer(symbol_type: ts.TypeSpec) -> bool: @@ -296,7 +298,10 @@ def is_number(symbol_type: ts.TypeSpec) -> bool: def is_logical(symbol_type: ts.TypeSpec) -> bool: - return extract_dtype(symbol_type).kind is ts.ScalarKind.BOOL + return ( + isinstance(symbol_type, ts.ScalarType) + and extract_dtype(symbol_type).kind is ts.ScalarKind.BOOL # type: ignore[union-attr] # checked is `ScalarType` + ) def is_arithmetic(symbol_type: ts.TypeSpec) -> bool: @@ -502,7 +507,9 @@ def promote( return types[0] elif all(isinstance(type_, (ts.ScalarType, ts.FieldType)) for type_ in types): dims = common.promote_dims(*(extract_dims(type_) for type_ in types)) - dtype = cast(ts.ScalarType, promote(*(extract_dtype(type_) for type_ in types))) + extracted_dtypes = [extract_dtype(type_) for type_ in types] + assert all(isinstance(dtype, ts.ScalarType) for dtype in extracted_dtypes) + dtype = cast(ts.ScalarType, promote(*extracted_dtypes)) # type: ignore[arg-type] # checked is `ScalarType` return ts.FieldType(dims=dims, dtype=dtype) raise TypeError("Expected a 'FieldType' or 'ScalarType'.") From 75bd7377c96f69abaf07cc67338a49fe5bf7daea Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 4 Dec 2024 15:57:36 +0100 Subject: [PATCH 06/11] fix the fix for is_xxx which include Field --- src/gt4py/next/type_system/type_info.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 7812699005..983063a9cb 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -234,7 +234,7 @@ def is_floating_point(symbol_type: ts.TypeSpec) -> bool: >>> is_floating_point(ts.FieldType(dims=[], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT32))) True """ - return isinstance(symbol_type, 
ts.ScalarType) and extract_dtype(symbol_type).kind in [ # type: ignore[union-attr] # checked is `ScalarType` + return isinstance(dtype := extract_dtype(symbol_type), ts.ScalarType) and dtype.kind in [ ts.ScalarKind.FLOAT32, ts.ScalarKind.FLOAT64, ] @@ -299,8 +299,8 @@ def is_number(symbol_type: ts.TypeSpec) -> bool: def is_logical(symbol_type: ts.TypeSpec) -> bool: return ( - isinstance(symbol_type, ts.ScalarType) - and extract_dtype(symbol_type).kind is ts.ScalarKind.BOOL # type: ignore[union-attr] # checked is `ScalarType` + isinstance(dtype := extract_dtype(symbol_type), ts.ScalarType) + and dtype.kind is ts.ScalarKind.BOOL ) From d217a7b062172e4bc81bf4a4ef63ab984513207e Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 4 Dec 2024 16:12:44 +0100 Subject: [PATCH 07/11] address review comments --- src/gt4py/next/ffront/past_passes/type_deduction.py | 11 ++++++++++- src/gt4py/next/type_system/type_specifications.py | 5 +++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index 31998a6d26..9355273588 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -104,11 +104,20 @@ def visit_Program(self, node: past.Program, **kwargs: Any) -> past.Program: location=node.location, ) + def visit_Slice(self, node: past.Slice, **kwargs: Any) -> past.Slice: + return past.Slice( + lower=self.visit(node.lower, **kwargs), + upper=self.visit(node.upper, **kwargs), + step=self.visit(node.step, **kwargs), + type=ts.DeferredType(constraint=None), + location=node.location, + ) + def visit_Subscript(self, node: past.Subscript, **kwargs: Any) -> past.Subscript: value = self.visit(node.value, **kwargs) return past.Subscript( value=value, - slice_=node.slice_, # TODO: I don't think we are using this type + slice_=self.visit(node.slice_, **kwargs), type=value.type, location=node.location, ) diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index 7deb9a2150..a695b2268a 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -80,6 +80,11 @@ def __str__(self) -> str: class ListType(DataType): + """Represents a neighbor list in the ITIR representation. + + Note: not used in the frontend. + """ + element_type: DataType # TODO(havogt): the `offset_type` is not yet used in type_inference, # it is meant to describe the neighborhood (via the local dimension) From ea81913ed04b92444ea62dc6cdd56beade6a51f4 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Wed, 4 Dec 2024 16:23:27 +0100 Subject: [PATCH 08/11] Update src/gt4py/next/type_system/type_specifications.py Co-authored-by: Till Ehrengruber --- src/gt4py/next/type_system/type_specifications.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py index a695b2268a..f57f50aa97 100644 --- a/src/gt4py/next/type_system/type_specifications.py +++ b/src/gt4py/next/type_system/type_specifications.py @@ -102,6 +102,9 @@ def __str__(self) -> str: class TupleType(DataType): # TODO: or DimensionType is a DataType + # TODO(tehrengruber): Remove `DeferredType` again. This was erroneously + # introduced before we checked the annotations at runtime. All attributes of + # a type that are types themselves must be concrete. 
     types: list[DataType | DimensionType | DeferredType]
     # TODO validate DeferredType constraints
 

From ffd555b84e7659886cba4379f3f97757a40939c6 Mon Sep 17 00:00:00 2001
From: Hannes Vogt
Date: Wed, 4 Dec 2024 16:24:43 +0100
Subject: [PATCH 09/11] fix formatting

---
 src/gt4py/next/type_system/type_specifications.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py
index f57f50aa97..e10c06864c 100644
--- a/src/gt4py/next/type_system/type_specifications.py
+++ b/src/gt4py/next/type_system/type_specifications.py
@@ -103,7 +103,7 @@ def __str__(self) -> str:
 class TupleType(DataType):
     # TODO: or DimensionType is a DataType
     # TODO(tehrengruber): Remove `DeferredType` again. This was erroneously
-    # introduced before we checked the annotations at runtime. All attributes of 
+    # introduced before we checked the annotations at runtime. All attributes of
     # a type that are types themselves must be concrete.
     types: list[DataType | DimensionType | DeferredType]
     # TODO validate DeferredType constraints

From 1fdcd406ca46d73391d074ca99ad6d2123c3c5c4 Mon Sep 17 00:00:00 2001
From: Hannes Vogt
Date: Wed, 4 Dec 2024 16:36:06 +0100
Subject: [PATCH 10/11] remove todos

---
 src/gt4py/next/type_system/type_specifications.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/gt4py/next/type_system/type_specifications.py b/src/gt4py/next/type_system/type_specifications.py
index e10c06864c..060d56aea2 100644
--- a/src/gt4py/next/type_system/type_specifications.py
+++ b/src/gt4py/next/type_system/type_specifications.py
@@ -101,12 +101,10 @@ def __str__(self) -> str:
 
 
 class TupleType(DataType):
-    # TODO: or DimensionType is a DataType
     # TODO(tehrengruber): Remove `DeferredType` again. This was erroneously
    # introduced before we checked the annotations at runtime. All attributes of
     # a type that are types themselves must be concrete.
     types: list[DataType | DimensionType | DeferredType]
-    # TODO validate DeferredType constraints
 
     def __str__(self) -> str:
         return f"tuple[{', '.join(map(str, self.types))}]"

From 23e31e65e7878c8e3c8a58192f88dced6cf0271c Mon Sep 17 00:00:00 2001
From: Hannes Vogt
Date: Wed, 4 Dec 2024 18:29:32 +0100
Subject: [PATCH 11/11] fix doctests

---
 .../next/ffront/foast_passes/type_deduction.py | 17 +++++++++++------
 src/gt4py/next/ffront/type_info.py             |  4 +++-
 2 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py
index 4495717173..6b40cbb77f 100644
--- a/src/gt4py/next/ffront/foast_passes/type_deduction.py
+++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py
@@ -68,13 +68,18 @@ def construct_tuple_type(
     >>> mask_type = ts.FieldType(
     ...     dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL)
     ... )
-    >>> true_branch_types = [ts.ScalarType(kind=ts.ScalarKind), ts.ScalarType(kind=ts.ScalarKind)]
+    >>> true_branch_types = [
+    ...     ts.ScalarType(kind=ts.ScalarKind.FLOAT64),
+    ...     ts.ScalarType(kind=ts.ScalarKind.FLOAT64),
+    ... ]
     >>> false_branch_types = [
-    ...     ts.FieldType(dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind)),
-    ...     ts.ScalarType(kind=ts.ScalarKind),
+    ...     ts.FieldType(
+    ...         dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
+    ...     ),
+    ...     ts.ScalarType(kind=ts.ScalarKind.FLOAT64),
     ... ]
     >>> print(construct_tuple_type(true_branch_types, false_branch_types, mask_type))
-    [FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None)), FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None))]
+    [FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None)), FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None))]
     """
     element_types_new = true_branch_types
     for i, element in enumerate(true_branch_types):
@@ -105,8 +110,8 @@ def promote_to_mask_type(
     >>> I, J = (Dimension(value=dim) for dim in ["I", "J"])
     >>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL)
     >>> dtype = ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
-    >>> promote_to_mask_type(ts.FieldType(dims=[I, J], dtype=bool_type), ts.ScalarType(kind=dtype))
-    FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=ScalarType(kind=, shape=None), shape=None))
+    >>> promote_to_mask_type(ts.FieldType(dims=[I, J], dtype=bool_type), dtype)
+    FieldType(dims=[Dimension(value='I', kind=), Dimension(value='J', kind=)], dtype=ScalarType(kind=, shape=None))
     >>> promote_to_mask_type(
     ...     ts.FieldType(dims=[I, J], dtype=bool_type), ts.FieldType(dims=[I], dtype=dtype)
     ... )
diff --git a/src/gt4py/next/ffront/type_info.py b/src/gt4py/next/ffront/type_info.py
index d0c1e9bbfd..83ecf92839 100644
--- a/src/gt4py/next/ffront/type_info.py
+++ b/src/gt4py/next/ffront/type_info.py
@@ -169,7 +169,9 @@ def _scan_param_promotion(param: ts.TypeSpec, arg: ts.TypeSpec) -> ts.FieldType
     --------
     >>> _scan_param_promotion(
     ...     ts.ScalarType(kind=ts.ScalarKind.INT64),
-    ...     ts.FieldType(dims=[common.Dimension("I")], dtype=ts.ScalarKind.FLOAT64),
+    ...     ts.FieldType(
+    ...         dims=[common.Dimension("I")], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
+    ...     ),
     ... )
     FieldType(dims=[Dimension(value='I', kind=)], dtype=ScalarType(kind=, shape=None))
     """