Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor[next]: use eve.datamodel for types #1750

Merged
merged 16 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/gt4py/eve/datamodels/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import dataclasses
import functools
import sys
import types
import typing
import warnings

Expand Down Expand Up @@ -1254,8 +1255,11 @@ def _make_concrete_with_cache(
if not is_generic_datamodel_class(datamodel_cls):
raise TypeError(f"'{datamodel_cls.__name__}' is not a generic model class.")
for t in type_args:
_accepted_types: tuple[type, ...] = (type, type(None), xtyping.StdGenericAliasType)
if sys.version_info >= (3, 10):
_accepted_types = (*_accepted_types, types.UnionType)
if not (
isinstance(t, (type, type(None), xtyping.StdGenericAliasType))
isinstance(t, _accepted_types)
or (getattr(type(t), "__module__", None) in ("typing", "typing_extensions"))
):
raise TypeError(
Expand Down
9 changes: 9 additions & 0 deletions src/gt4py/eve/type_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import collections.abc
import dataclasses
import functools
import sys
import types
import typing

from . import exceptions, extended_typing as xtyping, utils
Expand Down Expand Up @@ -193,6 +195,12 @@ def __call__(
if type_annotation is None:
type_annotation = type(None)

if sys.version_info >= (3, 10):
if isinstance(
type_annotation, types.UnionType
): # see https://github.com/python/cpython/issues/105499
type_annotation = typing.Union[type_annotation.__args__]

# Non-generic types
if xtyping.is_actual_type(type_annotation):
assert not xtyping.get_args(type_annotation)
Expand Down Expand Up @@ -277,6 +285,7 @@ def __call__(

if issubclass(origin_type, (collections.abc.Sequence, collections.abc.Set)):
assert len(type_args) == 1
make_recursive(type_args[0])
if (member_validator := make_recursive(type_args[0])) is None:
raise exceptions.EveValueError(
f"{type_args[0]} type annotation is not supported."
Expand Down
28 changes: 17 additions & 11 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Any, Optional, TypeVar, cast
from typing import Any, Optional, TypeAlias, TypeVar, cast

import gt4py.next.ffront.field_operator_ast as foast
from gt4py.eve import NodeTranslator, NodeVisitor, traits
Expand Down Expand Up @@ -48,7 +48,7 @@ def with_altered_scalar_kind(
if isinstance(type_spec, ts.FieldType):
return ts.FieldType(
dims=type_spec.dims,
dtype=ts.ScalarType(kind=new_scalar_kind, shape=type_spec.dtype.shape),
dtype=with_altered_scalar_kind(type_spec.dtype, new_scalar_kind),
)
elif isinstance(type_spec, ts.ScalarType):
return ts.ScalarType(kind=new_scalar_kind, shape=type_spec.shape)
Expand All @@ -68,13 +68,18 @@ def construct_tuple_type(
>>> mask_type = ts.FieldType(
... dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind.BOOL)
... )
>>> true_branch_types = [ts.ScalarType(kind=ts.ScalarKind), ts.ScalarType(kind=ts.ScalarKind)]
>>> true_branch_types = [
... ts.ScalarType(kind=ts.ScalarKind.FLOAT64),
... ts.ScalarType(kind=ts.ScalarKind.FLOAT64),
... ]
>>> false_branch_types = [
... ts.FieldType(dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind)),
... ts.ScalarType(kind=ts.ScalarKind),
... ts.FieldType(
... dims=[Dimension(value="I")], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
... ),
... ts.ScalarType(kind=ts.ScalarKind.FLOAT64),
... ]
>>> print(construct_tuple_type(true_branch_types, false_branch_types, mask_type))
[FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<enum 'ScalarKind'>, shape=None)), FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<enum 'ScalarKind'>, shape=None))]
[FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.FLOAT64: 1064>, shape=None)), FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.FLOAT64: 1064>, shape=None))]
"""
element_types_new = true_branch_types
for i, element in enumerate(true_branch_types):
Expand Down Expand Up @@ -105,8 +110,8 @@ def promote_to_mask_type(
>>> I, J = (Dimension(value=dim) for dim in ["I", "J"])
>>> bool_type = ts.ScalarType(kind=ts.ScalarKind.BOOL)
>>> dtype = ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
>>> promote_to_mask_type(ts.FieldType(dims=[I, J], dtype=bool_type), ts.ScalarType(kind=dtype))
FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='J', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=ScalarType(kind=<ScalarKind.FLOAT64: 1064>, shape=None), shape=None))
>>> promote_to_mask_type(ts.FieldType(dims=[I, J], dtype=bool_type), dtype)
FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='J', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.FLOAT64: 1064>, shape=None))
>>> promote_to_mask_type(
... ts.FieldType(dims=[I, J], dtype=bool_type), ts.FieldType(dims=[I], dtype=dtype)
... )
Expand Down Expand Up @@ -360,7 +365,7 @@ def visit_Assign(self, node: foast.Assign, **kwargs: Any) -> foast.Assign:
def visit_TupleTargetAssign(
self, node: foast.TupleTargetAssign, **kwargs: Any
) -> foast.TupleTargetAssign:
TargetType = list[foast.Starred | foast.Symbol]
TargetType: TypeAlias = list[foast.Starred | foast.Symbol]
values = self.visit(node.value, **kwargs)

if isinstance(values.type, ts.TupleType):
Expand All @@ -374,7 +379,7 @@ def visit_TupleTargetAssign(
)

new_targets: TargetType = []
new_type: ts.TupleType | ts.DataType
new_type: ts.DataType
for i, index in enumerate(indices):
old_target = targets[i]

Expand All @@ -391,7 +396,8 @@ def visit_TupleTargetAssign(
location=old_target.location,
)
else:
new_type = values.type.types[index]
new_type = values.type.types[index] # type: ignore[assignment] # see check in next line
assert isinstance(new_type, ts.DataType)
new_target = self.visit(
old_target, refine_type=new_type, location=old_target.location, **kwargs
)
Expand Down
3 changes: 3 additions & 0 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def visit_TupleExpr(self, node: foast.TupleExpr, **kwargs: Any) -> itir.Expr:
def visit_UnaryOp(self, node: foast.UnaryOp, **kwargs: Any) -> itir.Expr:
# TODO(tehrengruber): extend iterator ir to support unary operators
dtype = type_info.extract_dtype(node.type)
assert isinstance(dtype, ts.ScalarType)
if node.op in [dialect_ast_enums.UnaryOperator.NOT, dialect_ast_enums.UnaryOperator.INVERT]:
if dtype.kind != ts.ScalarKind.BOOL:
raise NotImplementedError(f"'{node.op}' is only supported on 'bool' arguments.")
Expand Down Expand Up @@ -417,12 +418,14 @@ def _visit_neighbor_sum(self, node: foast.Call, **kwargs: Any) -> itir.Expr:

def _visit_max_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
dtype = type_info.extract_dtype(node.type)
assert isinstance(dtype, ts.ScalarType)
min_value, _ = type_info.arithmetic_bounds(dtype)
init_expr = self._make_literal(str(min_value), dtype)
return self._make_reduction_expr(node, "maximum", init_expr, **kwargs)

def _visit_min_over(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
dtype = type_info.extract_dtype(node.type)
assert isinstance(dtype, ts.ScalarType)
_, max_value = type_info.arithmetic_bounds(dtype)
init_expr = self._make_literal(str(max_value), dtype)
return self._make_reduction_expr(node, "minimum", init_expr, **kwargs)
Expand Down
9 changes: 9 additions & 0 deletions src/gt4py/next/ffront/past_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,15 @@ def visit_Program(self, node: past.Program, **kwargs: Any) -> past.Program:
location=node.location,
)

def visit_Slice(self, node: past.Slice, **kwargs: Any) -> past.Slice:
return past.Slice(
lower=self.visit(node.lower, **kwargs),
upper=self.visit(node.upper, **kwargs),
step=self.visit(node.step, **kwargs),
type=ts.DeferredType(constraint=None),
location=node.location,
)

def visit_Subscript(self, node: past.Subscript, **kwargs: Any) -> past.Subscript:
value = self.visit(node.value, **kwargs)
return past.Subscript(
Expand Down
1 change: 1 addition & 0 deletions src/gt4py/next/ffront/past_process_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def _field_constituents_shape_and_dims(
match arg_type:
case ts.TupleType():
for el, el_type in zip(arg, arg_type.types):
assert isinstance(el_type, ts.DataType)
yield from _field_constituents_shape_and_dims(el, el_type)
case ts.FieldType():
dims = type_info.extract_dims(arg_type)
Expand Down
8 changes: 5 additions & 3 deletions src/gt4py/next/ffront/type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ def _scan_param_promotion(param: ts.TypeSpec, arg: ts.TypeSpec) -> ts.FieldType
--------
>>> _scan_param_promotion(
... ts.ScalarType(kind=ts.ScalarKind.INT64),
... ts.FieldType(dims=[common.Dimension("I")], dtype=ts.ScalarKind.FLOAT64),
... ts.FieldType(
... dims=[common.Dimension("I")], dtype=ts.ScalarType(kind=ts.ScalarKind.FLOAT64)
... ),
... )
FieldType(dims=[Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)], dtype=ScalarType(kind=<ScalarKind.INT64: 64>, shape=None))
"""
Expand Down Expand Up @@ -252,8 +254,8 @@ def function_signature_incompatibilities_scanop(
# build a function type to leverage the already existing signature checking capabilities
function_type = ts.FunctionType(
pos_only_args=[],
pos_or_kw_args=promoted_params, # type: ignore[arg-type] # dict is invariant, but we don't care here.
kw_only_args=promoted_kwparams, # type: ignore[arg-type] # same as above
pos_or_kw_args=promoted_params,
kw_only_args=promoted_kwparams,
returns=ts.DeferredType(constraint=None),
)

Expand Down
8 changes: 2 additions & 6 deletions src/gt4py/next/ffront/type_specifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,19 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from dataclasses import dataclass

import gt4py.next.type_system.type_specifications as ts
from gt4py.next import common as func_common
from gt4py.next import common


@dataclass(frozen=True)
class ProgramType(ts.TypeSpec, ts.CallableType):
definition: ts.FunctionType


@dataclass(frozen=True)
class FieldOperatorType(ts.TypeSpec, ts.CallableType):
definition: ts.FunctionType


@dataclass(frozen=True)
class ScanOperatorType(ts.TypeSpec, ts.CallableType):
axis: func_common.Dimension
axis: common.Dimension
definition: ts.FunctionType
11 changes: 5 additions & 6 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
)
from gt4py.next.ffront import fbuiltins
from gt4py.next.iterator import builtins, runtime
from gt4py.next.iterator.type_system import type_specifications as itir_ts
from gt4py.next.otf import arguments
from gt4py.next.type_system import type_specifications as ts, type_translation

Expand Down Expand Up @@ -1460,7 +1459,7 @@ class _List(Generic[DT]):
def __getitem__(self, i: int):
return self.values[i]

def __gt_type__(self) -> itir_ts.ListType:
def __gt_type__(self) -> ts.ListType:
offset_tag = self.offset.value
assert isinstance(offset_tag, str)
element_type = type_translation.from_value(self.values[0])
Expand All @@ -1470,7 +1469,7 @@ def __gt_type__(self) -> itir_ts.ListType:
connectivity = offset_provider[offset_tag]
assert common.is_neighbor_connectivity(connectivity)
local_dim = connectivity.__gt_type__().neighbor_dim
return itir_ts.ListType(element_type=element_type, offset_type=local_dim)
return ts.ListType(element_type=element_type, offset_type=local_dim)


@dataclasses.dataclass(frozen=True)
Expand All @@ -1480,10 +1479,10 @@ class _ConstList(Generic[DT]):
def __getitem__(self, _):
return self.value

def __gt_type__(self) -> itir_ts.ListType:
def __gt_type__(self) -> ts.ListType:
element_type = type_translation.from_value(self.value)
assert isinstance(element_type, ts.DataType)
return itir_ts.ListType(
return ts.ListType(
element_type=element_type,
offset_type=_CONST_DIM,
)
Expand Down Expand Up @@ -1801,7 +1800,7 @@ def _fieldspec_list_to_value(
domain: common.Domain, type_: ts.TypeSpec
) -> tuple[common.Domain, ts.TypeSpec]:
"""Translate the list element type into the domain."""
if isinstance(type_, itir_ts.ListType):
if isinstance(type_, ts.ListType):
if type_.offset_type == _CONST_DIM:
return domain.insert(
len(domain), common.named_range((_CONST_DIM, 1))
Expand Down
9 changes: 3 additions & 6 deletions src/gt4py/next/iterator/transforms/fuse_as_fieldop.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@
inline_lifts,
trace_shifts,
)
from gt4py.next.iterator.type_system import (
inference as type_inference,
type_specifications as it_ts,
)
from gt4py.next.iterator.type_system import inference as type_inference
from gt4py.next.type_system import type_info, type_specifications as ts


Expand Down Expand Up @@ -140,7 +137,7 @@ def fuse_as_fieldop(
if arg.type and not isinstance(arg.type, ts.DeferredType):
assert isinstance(arg.type, ts.TypeSpec)
dtype = type_info.apply_to_primitive_constituents(type_info.extract_dtype, arg.type)
assert not isinstance(dtype, it_ts.ListType)
assert not isinstance(dtype, ts.ListType)
new_param: str
if isinstance(
arg, itir.SymRef
Expand Down Expand Up @@ -246,7 +243,7 @@ def visit_FunCall(self, node: itir.FunCall):
)
or cpm.is_call_to(arg, "if_")
)
and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1)
and (isinstance(dtype, ts.ListType) or len(arg_shifts) <= 1)
)
)

Expand Down
12 changes: 6 additions & 6 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,12 @@ def _transform_by_pattern(
tmp_expr.type,
tuple_constructor=lambda *elements: tuple(elements),
)
tmp_dtypes: ts.ScalarType | tuple[ts.ScalarType | tuple, ...] = (
type_info.apply_to_primitive_constituents(
type_info.extract_dtype,
tmp_expr.type,
tuple_constructor=lambda *elements: tuple(elements),
)
tmp_dtypes: (
ts.ScalarType | ts.ListType | tuple[ts.ScalarType | ts.ListType | tuple, ...]
) = type_info.apply_to_primitive_constituents(
type_info.extract_dtype,
tmp_expr.type,
tuple_constructor=lambda *elements: tuple(elements),
)

# allocate temporary for all tuple elements
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/iterator/type_system/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,8 @@ def _get_dimensions(obj: Any):
if isinstance(obj, common.Dimension):
yield obj
elif isinstance(obj, ts.TypeSpec):
for field in dataclasses.fields(obj.__class__):
yield from _get_dimensions(getattr(obj, field.name))
for field in obj.__datamodel_fields__.keys():
yield from _get_dimensions(getattr(obj, field))
elif isinstance(obj, collections.abc.Mapping):
for el in obj.values():
yield from _get_dimensions(el)
Expand Down Expand Up @@ -479,7 +479,7 @@ def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.Tup
assert domain.dims != "unknown"
assert node.dtype
return type_info.apply_to_primitive_constituents(
lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), # type: ignore[arg-type] # ensured by domain.dims != "unknown" above
lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype),
node.dtype,
)

Expand Down
16 changes: 1 addition & 15 deletions src/gt4py/next/iterator/type_system/type_specifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,43 +6,29 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import dataclasses
from typing import Literal, Optional
from typing import Literal

from gt4py.next import common
from gt4py.next.type_system import type_specifications as ts


@dataclasses.dataclass(frozen=True)
class NamedRangeType(ts.TypeSpec):
dim: common.Dimension


@dataclasses.dataclass(frozen=True)
class DomainType(ts.DataType):
dims: list[common.Dimension] | Literal["unknown"]


@dataclasses.dataclass(frozen=True)
class OffsetLiteralType(ts.TypeSpec):
value: ts.ScalarType | common.Dimension


@dataclasses.dataclass(frozen=True)
class ListType(ts.DataType):
element_type: ts.DataType
# TODO(havogt): the `offset_type` is not yet used in type_inference,
# it is meant to describe the neighborhood (via the local dimension)
offset_type: Optional[common.Dimension] = None


@dataclasses.dataclass(frozen=True)
class IteratorType(ts.DataType, ts.CallableType):
position_dims: list[common.Dimension] | Literal["unknown"]
defined_dims: list[common.Dimension]
element_type: ts.DataType


@dataclasses.dataclass(frozen=True)
class ProgramType(ts.TypeSpec):
params: dict[str, ts.DataType]
Loading
Loading