Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor[next]: use eve.datamodel for types #1750

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
8 changes: 7 additions & 1 deletion src/gt4py/eve/datamodels/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import dataclasses
import functools
import sys
import types
import typing
import warnings

Expand Down Expand Up @@ -1040,6 +1041,11 @@ def _make_datamodel(
for key in annotations:
type_hint = annotations[key] = resolved_annotations[key]

if isinstance(
type_hint, types.UnionType
): # see https://github.com/python/cpython/issues/105499
type_hint = typing.Union[type_hint.__args__]

# Skip members annotated as class variables
if type_hint is ClassVar or xtyping.get_origin(type_hint) is ClassVar:
continue
Expand Down Expand Up @@ -1255,7 +1261,7 @@ def _make_concrete_with_cache(
raise TypeError(f"'{datamodel_cls.__name__}' is not a generic model class.")
for t in type_args:
if not (
isinstance(t, (type, type(None), xtyping.StdGenericAliasType))
isinstance(t, (type, type(None), xtyping.StdGenericAliasType, types.UnionType))
or (getattr(type(t), "__module__", None) in ("typing", "typing_extensions"))
):
raise TypeError(
Expand Down
11 changes: 6 additions & 5 deletions src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Any, Optional, TypeVar, cast
from typing import Any, Optional, TypeAlias, TypeVar, cast

import gt4py.next.ffront.field_operator_ast as foast
from gt4py.eve import NodeTranslator, NodeVisitor, traits
Expand Down Expand Up @@ -48,7 +48,7 @@ def with_altered_scalar_kind(
if isinstance(type_spec, ts.FieldType):
return ts.FieldType(
dims=type_spec.dims,
dtype=ts.ScalarType(kind=new_scalar_kind, shape=type_spec.dtype.shape),
dtype=with_altered_scalar_kind(type_spec.dtype, new_scalar_kind),
)
elif isinstance(type_spec, ts.ScalarType):
return ts.ScalarType(kind=new_scalar_kind, shape=type_spec.shape)
Expand Down Expand Up @@ -360,7 +360,7 @@ def visit_Assign(self, node: foast.Assign, **kwargs: Any) -> foast.Assign:
def visit_TupleTargetAssign(
self, node: foast.TupleTargetAssign, **kwargs: Any
) -> foast.TupleTargetAssign:
TargetType = list[foast.Starred | foast.Symbol]
TargetType: TypeAlias = list[foast.Starred | foast.Symbol]
values = self.visit(node.value, **kwargs)

if isinstance(values.type, ts.TupleType):
Expand All @@ -374,7 +374,7 @@ def visit_TupleTargetAssign(
)

new_targets: TargetType = []
new_type: ts.TupleType | ts.DataType
new_type: ts.DataType
for i, index in enumerate(indices):
old_target = targets[i]

Expand All @@ -391,7 +391,8 @@ def visit_TupleTargetAssign(
location=old_target.location,
)
else:
new_type = values.type.types[index]
new_type = values.type.types[index] # type: ignore[assignment] # see check in next line
assert isinstance(new_type, ts.DataType)
new_target = self.visit(
old_target, refine_type=new_type, location=old_target.location, **kwargs
)
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/past_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def visit_Subscript(self, node: past.Subscript, **kwargs: Any) -> past.Subscript
value = self.visit(node.value, **kwargs)
return past.Subscript(
value=value,
slice_=self.visit(node.slice_, **kwargs),
slice_=node.slice_, # TODO: I don't think we are using this type
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
slice_=node.slice_, # TODO: I don't think we are using this type
slice_=node.slice_,

check and remove

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't have a type for past.Slice, but the visit here also types & validates its children (e.g 1+1. in a slice bound would error here). For that reason I consider unconditionally visiting all children a good approach. The role of ts.DeferredType and None in the type system is however not well specified. I would postpone that discussion to the point where we also have time to think about generics / "templated" / partial types, because the discussion is related.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are you proposing to do? Accept None in TupleType elements, this proposed solution or add visit_Slice and return DeferredType.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I overlooked that slice_ is a tuple of slices (maybe the attr name is not so great?). visit_Slice with DeferredType makes sense to me even though this would add another violation of all-type-attrs-of-types-are-concrete rule of a ts.TypeSpec. Otherwise introduction of a slice type makes sense to me too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

type=value.type,
location=node.location,
)
Expand Down
1 change: 1 addition & 0 deletions src/gt4py/next/ffront/past_process_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def _field_constituents_shape_and_dims(
match arg_type:
case ts.TupleType():
for el, el_type in zip(arg, arg_type.types):
assert isinstance(el_type, ts.DataType)
yield from _field_constituents_shape_and_dims(el, el_type)
case ts.FieldType():
dims = type_info.extract_dims(arg_type)
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/ffront/type_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,8 @@ def function_signature_incompatibilities_scanop(
# build a function type to leverage the already existing signature checking capabilities
function_type = ts.FunctionType(
pos_only_args=[],
pos_or_kw_args=promoted_params, # type: ignore[arg-type] # dict is invariant, but we don't care here.
kw_only_args=promoted_kwparams, # type: ignore[arg-type] # same as above
pos_or_kw_args=promoted_params,
kw_only_args=promoted_kwparams,
returns=ts.DeferredType(constraint=None),
)

Expand Down
8 changes: 2 additions & 6 deletions src/gt4py/next/ffront/type_specifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,19 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from dataclasses import dataclass

import gt4py.next.type_system.type_specifications as ts
from gt4py.next import common as func_common
from gt4py.next import common


@dataclass(frozen=True)
class ProgramType(ts.TypeSpec, ts.CallableType):
definition: ts.FunctionType


@dataclass(frozen=True)
class FieldOperatorType(ts.TypeSpec, ts.CallableType):
definition: ts.FunctionType


@dataclass(frozen=True)
class ScanOperatorType(ts.TypeSpec, ts.CallableType):
axis: func_common.Dimension
axis: common.Dimension
definition: ts.FunctionType
11 changes: 5 additions & 6 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
)
from gt4py.next.ffront import fbuiltins
from gt4py.next.iterator import builtins, runtime
from gt4py.next.iterator.type_system import type_specifications as itir_ts
from gt4py.next.otf import arguments
from gt4py.next.type_system import type_specifications as ts, type_translation

Expand Down Expand Up @@ -1460,7 +1459,7 @@ class _List(Generic[DT]):
def __getitem__(self, i: int):
return self.values[i]

def __gt_type__(self) -> itir_ts.ListType:
def __gt_type__(self) -> ts.ListType:
offset_tag = self.offset.value
assert isinstance(offset_tag, str)
element_type = type_translation.from_value(self.values[0])
Expand All @@ -1470,7 +1469,7 @@ def __gt_type__(self) -> itir_ts.ListType:
connectivity = offset_provider[offset_tag]
assert common.is_neighbor_connectivity(connectivity)
local_dim = connectivity.__gt_type__().neighbor_dim
return itir_ts.ListType(element_type=element_type, offset_type=local_dim)
return ts.ListType(element_type=element_type, offset_type=local_dim)


@dataclasses.dataclass(frozen=True)
Expand All @@ -1480,10 +1479,10 @@ class _ConstList(Generic[DT]):
def __getitem__(self, _):
return self.value

def __gt_type__(self) -> itir_ts.ListType:
def __gt_type__(self) -> ts.ListType:
element_type = type_translation.from_value(self.value)
assert isinstance(element_type, ts.DataType)
return itir_ts.ListType(
return ts.ListType(
element_type=element_type,
offset_type=_CONST_DIM,
)
Expand Down Expand Up @@ -1799,7 +1798,7 @@ def _fieldspec_list_to_value(
domain: common.Domain, type_: ts.TypeSpec
) -> tuple[common.Domain, ts.TypeSpec]:
"""Translate the list element type into the domain."""
if isinstance(type_, itir_ts.ListType):
if isinstance(type_, ts.ListType):
if type_.offset_type == _CONST_DIM:
return domain.insert(
len(domain), common.named_range((_CONST_DIM, 1))
Expand Down
9 changes: 3 additions & 6 deletions src/gt4py/next/iterator/transforms/fuse_as_fieldop.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@
inline_lifts,
trace_shifts,
)
from gt4py.next.iterator.type_system import (
inference as type_inference,
type_specifications as it_ts,
)
from gt4py.next.iterator.type_system import inference as type_inference
from gt4py.next.type_system import type_info, type_specifications as ts


Expand Down Expand Up @@ -178,7 +175,7 @@ def visit_FunCall(self, node: itir.FunCall):
and isinstance(arg.fun.args[0], itir.Lambda)
or cpm.is_call_to(arg, "if_")
)
and (isinstance(dtype, it_ts.ListType) or len(arg_shifts) <= 1)
and (isinstance(dtype, ts.ListType) or len(arg_shifts) <= 1)
)
if should_inline:
if cpm.is_applied_as_fieldop(arg):
Expand All @@ -199,7 +196,7 @@ def visit_FunCall(self, node: itir.FunCall):

new_args = _merge_arguments(new_args, extracted_args)
else:
assert not isinstance(dtype, it_ts.ListType)
assert not isinstance(dtype, ts.ListType)
new_param: str
if isinstance(
arg, itir.SymRef
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/next/iterator/type_system/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def _get_dimensions(obj: Any):
if isinstance(obj, common.Dimension):
yield obj
elif isinstance(obj, ts.TypeSpec):
for field in dataclasses.fields(obj.__class__):
havogt marked this conversation as resolved.
Show resolved Hide resolved
for field in obj.__datamodel_fields__.values():
yield from _get_dimensions(getattr(obj, field.name))
havogt marked this conversation as resolved.
Show resolved Hide resolved
elif isinstance(obj, collections.abc.Mapping):
for el in obj.values():
Expand Down Expand Up @@ -490,7 +490,7 @@ def visit_Temporary(self, node: itir.Temporary, *, ctx) -> ts.FieldType | ts.Tup
assert domain.dims != "unknown"
assert node.dtype
return type_info.apply_to_primitive_constituents(
lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype), # type: ignore[arg-type] # ensured by domain.dims != "unknown" above
lambda dtype: ts.FieldType(dims=domain.dims, dtype=dtype),
node.dtype,
)

Expand Down
18 changes: 1 addition & 17 deletions src/gt4py/next/iterator/type_system/type_specifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,44 +6,30 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import dataclasses
from typing import Literal, Optional
from typing import Literal

from gt4py.next import common
from gt4py.next.type_system import type_specifications as ts


@dataclasses.dataclass(frozen=True)
class NamedRangeType(ts.TypeSpec):
dim: common.Dimension


@dataclasses.dataclass(frozen=True)
class DomainType(ts.DataType):
dims: list[common.Dimension] | Literal["unknown"]


@dataclasses.dataclass(frozen=True)
class OffsetLiteralType(ts.TypeSpec):
value: ts.ScalarType | common.Dimension


@dataclasses.dataclass(frozen=True)
class ListType(ts.DataType):
element_type: ts.DataType
# TODO(havogt): the `offset_type` is not yet used in type_inference,
# it is meant to describe the neighborhood (via the local dimension)
offset_type: Optional[common.Dimension] = None


@dataclasses.dataclass(frozen=True)
class IteratorType(ts.DataType, ts.CallableType):
position_dims: list[common.Dimension] | Literal["unknown"]
defined_dims: list[common.Dimension]
element_type: ts.DataType


@dataclasses.dataclass(frozen=True)
class StencilClosureType(ts.TypeSpec):
domain: DomainType
stencil: ts.FunctionType
Expand All @@ -61,12 +47,10 @@ def __post_init__(self):


# TODO(tehrengruber): Remove after new ITIR format with apply_stencil is used everywhere
@dataclasses.dataclass(frozen=True)
class FencilType(ts.TypeSpec):
params: dict[str, ts.DataType]
closures: list[StencilClosureType]


@dataclasses.dataclass(frozen=True)
class ProgramType(ts.TypeSpec):
params: dict[str, ts.DataType]
26 changes: 13 additions & 13 deletions src/gt4py/next/iterator/type_system/type_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,18 +155,18 @@ def if_(pred: ts.ScalarType, true_branch: ts.DataType, false_branch: ts.DataType


@_register_builtin_type_synthesizer
def make_const_list(scalar: ts.ScalarType) -> it_ts.ListType:
def make_const_list(scalar: ts.ScalarType) -> ts.ListType:
assert isinstance(scalar, ts.ScalarType)
return it_ts.ListType(element_type=scalar)
return ts.ListType(element_type=scalar)


@_register_builtin_type_synthesizer
def list_get(index: ts.ScalarType | it_ts.OffsetLiteralType, list_: it_ts.ListType) -> ts.DataType:
def list_get(index: ts.ScalarType | it_ts.OffsetLiteralType, list_: ts.ListType) -> ts.DataType:
if isinstance(index, it_ts.OffsetLiteralType):
assert isinstance(index.value, ts.ScalarType)
index = index.value
assert isinstance(index, ts.ScalarType) and type_info.is_integral(index)
assert isinstance(list_, it_ts.ListType)
assert isinstance(list_, ts.ListType)
return list_.element_type


Expand Down Expand Up @@ -198,14 +198,14 @@ def index(arg: ts.DimensionType) -> ts.FieldType:


@_register_builtin_type_synthesizer
def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> it_ts.ListType:
def neighbors(offset_literal: it_ts.OffsetLiteralType, it: it_ts.IteratorType) -> ts.ListType:
assert (
isinstance(offset_literal, it_ts.OffsetLiteralType)
and isinstance(offset_literal.value, common.Dimension)
and offset_literal.value.kind == common.DimensionKind.LOCAL
)
assert isinstance(it, it_ts.IteratorType)
return it_ts.ListType(element_type=it.element_type)
return ts.ListType(element_type=it.element_type)


@_register_builtin_type_synthesizer
Expand Down Expand Up @@ -270,7 +270,7 @@ def _convert_as_fieldop_input_to_iterator(
else:
defined_dims.append(dim)
if is_nb_field:
element_type = it_ts.ListType(element_type=element_type)
element_type = ts.ListType(element_type=element_type)

return it_ts.IteratorType(
position_dims=domain.dims, defined_dims=defined_dims, element_type=element_type
Expand Down Expand Up @@ -342,23 +342,23 @@ def apply_scan(
def map_(op: TypeSynthesizer) -> TypeSynthesizer:
@TypeSynthesizer
def applied_map(
*args: it_ts.ListType, offset_provider_type: common.OffsetProviderType
) -> it_ts.ListType:
*args: ts.ListType, offset_provider_type: common.OffsetProviderType
) -> ts.ListType:
assert len(args) > 0
assert all(isinstance(arg, it_ts.ListType) for arg in args)
assert all(isinstance(arg, ts.ListType) for arg in args)
arg_el_types = [arg.element_type for arg in args]
el_type = op(*arg_el_types, offset_provider_type=offset_provider_type)
assert isinstance(el_type, ts.DataType)
return it_ts.ListType(element_type=el_type)
return ts.ListType(element_type=el_type)

return applied_map


@_register_builtin_type_synthesizer
def reduce(op: TypeSynthesizer, init: ts.TypeSpec) -> TypeSynthesizer:
@TypeSynthesizer
def applied_reduce(*args: it_ts.ListType, offset_provider_type: common.OffsetProviderType):
assert all(isinstance(arg, it_ts.ListType) for arg in args)
def applied_reduce(*args: ts.ListType, offset_provider_type: common.OffsetProviderType):
assert all(isinstance(arg, ts.ListType) for arg in args)
return op(
init, *(arg.element_type for arg in args), offset_provider_type=offset_provider_type
)
Expand Down
1 change: 1 addition & 0 deletions src/gt4py/next/otf/binding/nanobind.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def _type_string(type_: ts.TypeSpec) -> str:
return f"std::tuple<{','.join(_type_string(t) for t in type_.types)}>"
elif isinstance(type_, ts.FieldType):
ndims = len(type_.dims)
assert isinstance(type_.dtype, ts.ScalarType)
dtype = cpp_interface.render_scalar_type(type_.dtype)
shape = f"nanobind::shape<{', '.join(['gridtools::nanobind::dynamic_size'] * ndims)}>"
buffer_t = f"nanobind::ndarray<{dtype}, {shape}>"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,7 @@ def visit_Temporary(
def dtype_to_cpp(x: ts.DataType) -> str:
if isinstance(x, ts.TupleType):
assert all(isinstance(i, ts.ScalarType) for i in x.types)
return "::gridtools::tuple<" + ", ".join(dtype_to_cpp(i) for i in x.types) + ">"
return "::gridtools::tuple<" + ", ".join(dtype_to_cpp(i) for i in x.types) + ">" # type: ignore[arg-type] # ensured by assert
assert isinstance(x, ts.ScalarType)
res = pytype_to_cpptype(x)
assert isinstance(res, str)
Expand Down
Loading