Merge branch 'hacked_iterator_embedded_with_field' into column_oob_fix
havogt committed Sep 5, 2023
2 parents 8de4b10 + 42ae8e4 commit 2e9eec0
Showing 10 changed files with 245 additions and 163 deletions.
4 changes: 3 additions & 1 deletion src/gt4py/_core/definitions.py
@@ -247,13 +247,15 @@ def subndim(self) -> int:
return len(self.tensor_shape)

def __eq__(self, other: Any) -> bool:
# TODO: discuss (make concrete subclasses equal to instances of this with the same type)
return (
isinstance(other, DType)
and self.scalar_type == other.scalar_type
and self.tensor_shape == other.tensor_shape
)

def __hash__(self) -> int:
return hash((self.scalar_type, self.tensor_shape))


@dataclasses.dataclass(frozen=True)
class IntegerDType(DType[IntegralT]):
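A side note on the `__eq__`/`__hash__` pair added above: defining `__eq__` on a class implicitly sets `__hash__` to `None`, so the explicit `__hash__` keeps `DType` instances usable as dict keys. A minimal standalone sketch of the same pattern (illustrative only, not gt4py code):

class Key:
    def __init__(self, name: str) -> None:
        self.name = name

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Key) and self.name == other.name

    # Without this, defining `__eq__` would leave the class unhashable
    # (the data model sets `__hash__ = None`).
    def __hash__(self) -> int:
        return hash(self.name)


registry = {Key("float64"): "some dtype metadata"}
assert Key("float64") in registry  # equal keys hash equally, so lookup works
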
4 changes: 3 additions & 1 deletion src/gt4py/next/__init__.py
@@ -25,7 +25,9 @@

from . import common, ffront, iterator, program_processors, type_inference
from .common import Dimension, DimensionKind, Field, GridType
from .embedded import nd_array_field
from .embedded import ( # Just for registering field implementations
nd_array_field as _nd_array_field,
)
from .ffront import fbuiltins
from .ffront.decorator import field_operator, program, scan_operator
from .ffront.fbuiltins import * # noqa: F403 # fbuiltins defines __all__ and we explicitly want to reexport everything here
201 changes: 140 additions & 61 deletions src/gt4py/next/common.py
@@ -15,13 +15,14 @@
from __future__ import annotations

import abc
import collections
import dataclasses
import enum
import functools
import sys
import types
from collections.abc import Mapping, Sequence, Set
from types import EllipsisType
from typing import TypeGuard, overload
from typing import overload

import numpy as np
import numpy.typing as npt
@@ -36,7 +37,9 @@
ParamSpec,
Protocol,
TypeAlias,
TypeGuard,
TypeVar,
cast,
extended_runtime_checkable,
runtime_checkable,
)
@@ -139,36 +142,42 @@ def __str__(self) -> str:
return f"({self.start}:{self.stop})"


def unit_range(r: UnitRangeLike) -> UnitRange:
assert is_unit_range_like(r)
RangeLike: TypeAlias = UnitRange | range | tuple[int, int]


def unit_range(r: RangeLike) -> UnitRange:
if isinstance(r, UnitRange):
return r
if isinstance(r, range) and r.step == 1:
if isinstance(r, range):
if r.step != 1:
raise ValueError(f"`UnitRange` requires step size 1, got `{r.step}`.")
return UnitRange(r.start, r.stop)
if isinstance(r, tuple) and isinstance(r[0], int) and isinstance(r[1], int):
return UnitRange(r[0], r[1])
raise ValueError(f"`{r}` is not `UnitRangeLike`.")
raise ValueError(f"`{r}` cannot be interpreted as `UnitRange`.")

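For orientation, a short usage sketch of the reworked `unit_range` helper (assuming `UnitRange` and `unit_range` are imported from `gt4py.next.common`, the file shown in this diff; not part of the commit):

from gt4py.next.common import UnitRange, unit_range

assert unit_range(UnitRange(0, 5)) == UnitRange(0, 5)  # passed through unchanged
assert unit_range(range(0, 5)) == UnitRange(0, 5)      # `range` with step 1 is accepted
assert unit_range((0, 5)) == UnitRange(0, 5)           # plain (start, stop) tuple

try:
    unit_range(range(0, 10, 2))  # a step other than 1 now raises instead of falling through
except ValueError as err:
    print(err)  # `UnitRange` requires step size 1, got `2`.
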

IntIndex: TypeAlias = int | np.integer
DomainRange: TypeAlias = UnitRange | IntIndex
NamedRange: TypeAlias = tuple[Dimension, UnitRange]
IntIndex: TypeAlias = int | core_defs.IntegralScalar
NamedIndex: TypeAlias = tuple[Dimension, IntIndex]
AnyIndex: TypeAlias = IntIndex | NamedRange | NamedIndex | slice | EllipsisType
DomainSlice: TypeAlias = Sequence[NamedRange | NamedIndex]
BufferSlice: TypeAlias = tuple[slice | IntIndex | EllipsisType, ...]
FieldSlice: TypeAlias = DomainSlice | BufferSlice | AnyIndex
UnitRangeLike: TypeAlias = UnitRange | range | tuple[int, int]
DomainLike: TypeAlias = (
Sequence[tuple[Dimension, UnitRangeLike]] | Mapping[Dimension, UnitRangeLike]
)
NamedRange: TypeAlias = tuple[Dimension, UnitRange]
RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType
AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange
AnyIndexElement: TypeAlias = RelativeIndexElement | AbsoluteIndexElement
AbsoluteIndexSequence: TypeAlias = Sequence[NamedRange | NamedIndex]
RelativeIndexSequence: TypeAlias = tuple[
slice | IntIndex | types.EllipsisType, ...
] # is a tuple but called Sequence for symmetry
AnyIndexSequence: TypeAlias = RelativeIndexSequence | AbsoluteIndexSequence
AnyIndexSpec: TypeAlias = AnyIndexElement | AnyIndexSequence


def is_int_index(p: Any) -> TypeGuard[IntIndex]:
return isinstance(p, (int, np.integer))
# should be replaced by isinstance(p, IntIndex), but mypy complains with
# `Argument 2 to "isinstance" has incompatible type "<typing special form>"; expected "_ClassInfo" [arg-type]`
return isinstance(p, (int, core_defs.INTEGRAL_TYPES))


def is_named_range(v: Any) -> TypeGuard[NamedRange]:
def is_named_range(v: AnyIndexSpec) -> TypeGuard[NamedRange]:
return (
isinstance(v, tuple)
and len(v) == 2
@@ -177,44 +186,41 @@ def is_named_range(v: Any) -> TypeGuard[NamedRange]:
)


def is_named_index(v: Any) -> TypeGuard[NamedRange]:
def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedIndex]:
return (
isinstance(v, tuple) and len(v) == 2 and isinstance(v[0], Dimension) and is_int_index(v[1])
)


def is_domain_slice(v: Any) -> TypeGuard[DomainSlice]:
def is_any_index_element(v: AnyIndexSpec) -> TypeGuard[AnyIndexElement]:
return (
is_int_index(v)
or is_named_range(v)
or is_named_index(v)
or isinstance(v, slice)
or v is Ellipsis
)


def is_absolute_index_sequence(v: AnyIndexSequence) -> TypeGuard[AbsoluteIndexSequence]:
return isinstance(v, Sequence) and all(is_named_range(e) or is_named_index(e) for e in v)


def is_buffer_slice(v: Any) -> TypeGuard[BufferSlice]:
def is_relative_index_sequence(v: AnyIndexSequence) -> TypeGuard[RelativeIndexSequence]:
return isinstance(v, tuple) and all(
isinstance(e, slice) or is_int_index(e) or e is Ellipsis for e in v
)


def is_unit_range_like(v: Any) -> TypeGuard[UnitRangeLike]:
return (
isinstance(v, UnitRange)
or (isinstance(v, range) and v.step == 1)
or (isinstance(v, tuple) and isinstance(v[0], int) and isinstance(v[1], int))
def as_any_index_sequence(index: AnyIndexSpec) -> AnyIndexSequence:
# `cast` because mypy/typing doesn't special case 1-element tuples, i.e. `tuple[A|B] != tuple[A]|tuple[B]`
return cast(
AnyIndexSequence,
(index,) if is_any_index_element(index) else index,
)

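A brief sketch of how the new index taxonomy composes (illustrative, assuming the helpers above are importable from `gt4py.next.common`): single index elements are normalized to 1-tuples by `as_any_index_sequence`, and sequences can then be classified as relative or absolute.

from gt4py.next.common import (
    Dimension,
    UnitRange,
    as_any_index_sequence,
    is_absolute_index_sequence,
    is_relative_index_sequence,
)

I = Dimension("I")

# A single element, whether relative (plain int) or absolute (named range), is wrapped.
assert as_any_index_sequence(3) == (3,)
assert as_any_index_sequence((I, UnitRange(0, 4))) == ((I, UnitRange(0, 4)),)

# Sequences pass through unchanged and can be classified.
assert is_relative_index_sequence((slice(0, 3), Ellipsis, 2))
assert is_absolute_index_sequence(((I, UnitRange(0, 4)), (I, 2)))
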

def is_domain_like(v: Any) -> TypeGuard[DomainLike]:
return (
isinstance(v, Sequence)
and all(
isinstance(e, tuple) and isinstance(e[0], Dimension) and is_unit_range_like(e[1])
for e in v
)
) or (
isinstance(v, Mapping)
and all(isinstance(d, Dimension) and is_unit_range_like(r) for d, r in v.items())
)


def named_range(v: tuple[Dimension, UnitRangeLike]) -> NamedRange:
def named_range(v: tuple[Dimension, RangeLike]) -> NamedRange:
return (v[0], unit_range(v[1]))


@@ -230,7 +236,7 @@ def __init__(
*args: NamedRange,
dims: Optional[tuple[Dimension, ...]] = None,
ranges: Optional[tuple[UnitRange, ...]] = None,
):
) -> None:
if dims is not None or ranges is not None:
if dims is None and ranges is None:
raise ValueError("Either both none of `dims` and `ranges` must be specified.")
Expand All @@ -239,8 +245,15 @@ def __init__(
"No extra `args` allowed when constructing fomr `dims` and `ranges`."
)

assert dims is not None
assert ranges is not None
assert dims is not None and ranges is not None # for mypy
if not all(isinstance(dim, Dimension) for dim in dims):
raise ValueError(
f"`dims` argument needs to be a `tuple[Dimension, ...], got `{dims}`."
)
if not all(isinstance(rng, UnitRange) for rng in ranges):
raise ValueError(
f"`ranges` argument needs to be a `tuple[UnitRange, ...], got `{ranges}`."
)
if len(dims) != len(ranges):
raise ValueError(
f"Number of provided dimensions ({len(dims)}) does not match number of provided ranges ({len(ranges)})."
@@ -249,8 +262,9 @@
object.__setattr__(self, "dims", dims)
object.__setattr__(self, "ranges", ranges)
else:
assert all(is_named_range(arg) for arg in args)
dims, ranges = zip(*args) if len(args) > 0 else ((), ())
if not all(is_named_range(arg) for arg in args):
raise ValueError(f"Elements of `Domain` need to be `NamedRange`s, got `{args}`.")
dims, ranges = zip(*args) if args else ((), ())
object.__setattr__(self, "dims", tuple(dims))
object.__setattr__(self, "ranges", tuple(ranges))

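To illustrate the two construction paths the reworked `__init__` now validates, a sketch for orientation (assuming `Dimension`, `Domain`, and `UnitRange` are importable from `gt4py.next.common`; not part of the commit):

from gt4py.next.common import Dimension, Domain, UnitRange

I = Dimension("I")
J = Dimension("J")

# Positional form: every argument must be a `NamedRange`, i.e. a `(Dimension, UnitRange)` pair.
d1 = Domain((I, UnitRange(0, 10)), (J, UnitRange(0, 5)))

# Keyword form: `dims` and `ranges` must be given together, with matching lengths.
d2 = Domain(dims=(I, J), ranges=(UnitRange(0, 10), UnitRange(0, 5)))

assert d1.dims == d2.dims and d1.ranges == d2.ranges

# Ill-typed input now raises a `ValueError` instead of tripping an `assert`.
try:
    Domain(dims=(I, "J"), ranges=(UnitRange(0, 10), UnitRange(0, 5)))
except ValueError as err:
    print(err)
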
@@ -265,7 +279,7 @@ def __getitem__(self, index: int) -> NamedRange:
...

@overload
def __getitem__(self, index: slice) -> "Domain":
def __getitem__(self, index: slice) -> Domain:
...

@overload
@@ -289,6 +303,20 @@ def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain:
raise KeyError("Invalid index type, must be either int, slice, or Dimension.")

def __and__(self, other: Domain) -> Domain:
"""
        Intersect `Domain`s; missing `Dimension`s are considered infinite.
Examples:
---------
>>> I = Dimension("I")
>>> J = Dimension("J")
>>> Domain((I, UnitRange(-1, 3))) & Domain((I, UnitRange(1, 6)))
Domain(dims=(Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>),), ranges=(UnitRange(1, 3),))
>>> Domain((I, UnitRange(-1, 3)), (J, UnitRange(2, 4))) & Domain((I, UnitRange(1, 6)))
Domain(dims=(Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='J', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)), ranges=(UnitRange(1, 3), UnitRange(2, 4)))
"""
broadcast_dims = tuple(promote_dims(self.dims, other.dims))
intersected_ranges = tuple(
rng1 & rng2
@@ -303,6 +331,11 @@ def __str__(self) -> str:
return f"Domain({', '.join(f'{e[0]}={e[1]}' for e in self)})"


DomainLike: TypeAlias = (
Sequence[tuple[Dimension, RangeLike]] | Mapping[Dimension, RangeLike]
) # `Domain` is `Sequence[NamedRange]` and therefore a subset


def domain(domain_like: DomainLike) -> Domain:
"""
Construct `Domain` from `DomainLike` object.
@@ -318,17 +351,11 @@ def domain(domain_like: DomainLike) -> Domain:
>>> domain({I: (2, 4), J: (3, 5)})
Domain(dims=(Dimension(value='I', kind=<DimensionKind.HORIZONTAL: 'horizontal'>), Dimension(value='J', kind=<DimensionKind.HORIZONTAL: 'horizontal'>)), ranges=(UnitRange(2, 4), UnitRange(3, 5)))
"""
assert is_domain_like(domain_like)
if isinstance(domain_like, Domain):
return domain_like
if isinstance(domain_like, Sequence) and all(
isinstance(e, tuple) and isinstance(e[0], Dimension) and is_unit_range_like(e[1])
for e in domain_like
):
if isinstance(domain_like, Sequence):
return Domain(*tuple(named_range(d) for d in domain_like))
if isinstance(domain_like, Mapping) and all(
isinstance(d, Dimension) and is_unit_range_like(r) for d, r in domain_like.items()
):
if isinstance(domain_like, Mapping):
return Domain(
dims=tuple(domain_like.keys()),
ranges=tuple(unit_range(r) for r in domain_like.values()),
@@ -360,8 +387,8 @@ class NextGTDimsInterface(Protocol):
"""
A `GTDimsInterface` is an object providing the `__gt_dims__` property, naming :class:`Field` dimensions.
The dimension names are objects of type :class:`Dimension`, in contrast to :py:mod:`gt4py.cartesian`,
where the labels are `str` s with implied semantics, see :py:class:`~gt4py._core.definitions.GTDimsInterface` .
The dimension names are objects of type :class:`Dimension`, in contrast to :mod:`gt4py.cartesian`,
where the labels are `str` s with implied semantics, see :class:`~gt4py._core.definitions.GTDimsInterface` .
"""

# TODO(havogt): unify with GTDimsInterface, ideally in backward compatible way
@@ -394,7 +421,7 @@ def remap(self, index_field: Field) -> Field:
...

@abc.abstractmethod
def restrict(self, item: FieldSlice) -> Field | core_defs.ScalarT:
def restrict(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT:
...

# Operators
@@ -403,7 +430,7 @@ def __call__(self, index_field: Field) -> Field:
...

@abc.abstractmethod
def __getitem__(self, item: FieldSlice) -> Field | core_defs.ScalarT:
def __getitem__(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT:
...

@abc.abstractmethod
@@ -414,6 +441,10 @@ def __abs__(self) -> Field:
def __neg__(self) -> Field:
...

@abc.abstractmethod
def __invert__(self) -> Field:
"""Only defined for `Field` of value type `bool`."""

@abc.abstractmethod
def __add__(self, other: Field | core_defs.ScalarT) -> Field:
...
@@ -458,6 +489,18 @@ def __rtruediv__(self, other: Field | core_defs.ScalarT) -> Field:
def __pow__(self, other: Field | core_defs.ScalarT) -> Field:
...

@abc.abstractmethod
def __and__(self, other: Field | core_defs.ScalarT) -> Field:
"""Only defined for `Field` of value type `bool`."""

@abc.abstractmethod
def __or__(self, other: Field | core_defs.ScalarT) -> Field:
"""Only defined for `Field` of value type `bool`."""

@abc.abstractmethod
def __xor__(self, other: Field | core_defs.ScalarT) -> Field:
"""Only defined for `Field` of value type `bool`."""


def is_field(
v: Any,
@@ -472,12 +515,12 @@ def is_field(
@extended_runtime_checkable
class MutableField(Field[DimsT, core_defs.ScalarT], Protocol[DimsT, core_defs.ScalarT]):
@abc.abstractmethod
def __setitem__(self, index: FieldSlice, value: Field | core_defs.ScalarT) -> None:
def __setitem__(self, index: AnyIndexSpec, value: Field | core_defs.ScalarT) -> None:
...


def is_mutable_field(
v: Any,
v: Field,
) -> TypeGuard[MutableField]:
# This function is introduced to localize the `type: ignore` because
# extended_runtime_checkable does not make the protocol runtime_checkable
@@ -612,3 +655,39 @@ def promote_dims(*dims_list: Sequence[Dimension]) -> list[Dimension]:
)

return topologically_sorted_list


class FieldBuiltinFuncRegistry:
"""
Mixin for adding `fbuiltins` registry to a `Field`.
Subclasses of a `Field` with `FieldBuiltinFuncRegistry` get their own registry,
dispatching (via ChainMap) to its parent's registries.
"""

_builtin_func_map: collections.ChainMap[
fbuiltins.BuiltInFunction, Callable
] = collections.ChainMap()

def __init_subclass__(cls, **kwargs):
cls._builtin_func_map = collections.ChainMap(
{}, # New empty `dict`` for new registrations on this class
*[
c.__dict__["_builtin_func_map"].maps[0] # adding parent `dict`s in mro order
for c in cls.__mro__
if "_builtin_func_map" in c.__dict__
],
)

@classmethod
def register_builtin_func(
cls, /, op: fbuiltins.BuiltInFunction[_R, _P], op_func: Optional[Callable[_P, _R]] = None
) -> Any:
assert op not in cls._builtin_func_map
if op_func is None: # when used as a decorator
return functools.partial(cls.register_builtin_func, op)
return cls._builtin_func_map.setdefault(op, op_func)

@classmethod
def __gt_builtin_func__(cls, /, func: fbuiltins.BuiltInFunction[_R, _P]) -> Callable[_P, _R]:
return cls._builtin_func_map.get(func, NotImplemented)
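Finally, a hedged sketch of the `FieldBuiltinFuncRegistry` dispatch described in the docstring above, using plain functions as stand-ins for the `fbuiltins.BuiltInFunction` keys named in the annotation (illustrative only, not part of the commit):

from gt4py.next.common import FieldBuiltinFuncRegistry


def builtin_a():  # stand-in key; actual keys are `fbuiltins.BuiltInFunction` objects
    ...


def builtin_b():  # second stand-in key
    ...


class BaseField(FieldBuiltinFuncRegistry):
    ...


class DerivedField(BaseField):
    ...


BaseField.register_builtin_func(builtin_a, lambda: "base impl")
DerivedField.register_builtin_func(builtin_b, lambda: "derived impl")

# The subclass resolves both its own and its parent's registrations through the ChainMap ...
assert DerivedField.__gt_builtin_func__(builtin_a)() == "base impl"
assert DerivedField.__gt_builtin_func__(builtin_b)() == "derived impl"

# ... while the parent class only sees entries registered on itself (or its own parents).
assert BaseField.__gt_builtin_func__(builtin_b) is NotImplemented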