From e16c4aead942fde64e317681276f4374719af175 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 31 Aug 2023 16:09:45 +0200 Subject: [PATCH 01/10] refactor typealiases --- src/gt4py/_core/definitions.py | 4 +- src/gt4py/next/common.py | 125 ++++++++++++---------- src/gt4py/next/embedded/common.py | 39 +++---- src/gt4py/next/embedded/exceptions.py | 4 +- src/gt4py/next/embedded/nd_array_field.py | 23 ++-- 5 files changed, 97 insertions(+), 98 deletions(-) diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index f49bac531a..059ba6c24c 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -247,13 +247,15 @@ def subndim(self) -> int: return len(self.tensor_shape) def __eq__(self, other: Any) -> bool: - # TODO: discuss (make concrete subclasses equal to instances of this with the same type) return ( isinstance(other, DType) and self.scalar_type == other.scalar_type and self.tensor_shape == other.tensor_shape ) + def __hash__(self) -> int: + return hash((self.scalar_type, self.tensor_shape)) + @dataclasses.dataclass(frozen=True) class IntegerDType(DType[IntegralT]): diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 889c73d61f..a1f40c2f92 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -21,7 +21,7 @@ import sys from collections.abc import Mapping, Sequence, Set from types import EllipsisType -from typing import TypeGuard, overload +from typing import overload import numpy as np import numpy.typing as npt @@ -36,7 +36,9 @@ ParamSpec, Protocol, TypeAlias, + TypeGuard, TypeVar, + cast, extended_runtime_checkable, runtime_checkable, ) @@ -139,36 +141,42 @@ def __str__(self) -> str: return f"({self.start}:{self.stop})" -def unit_range(r: UnitRangeLike) -> UnitRange: - assert is_unit_range_like(r) +RangeLike: TypeAlias = UnitRange | range | tuple[int, int] + + +def unit_range(r: RangeLike) -> UnitRange: if isinstance(r, UnitRange): return r - if isinstance(r, range) and r.step == 1: + if isinstance(r, range): + if r.step != 1: + raise ValueError(f"`UnitRange` requires step size 1, got `{r.step}`.") return UnitRange(r.start, r.stop) if isinstance(r, tuple) and isinstance(r[0], int) and isinstance(r[1], int): return UnitRange(r[0], r[1]) - raise ValueError(f"`{r}` is not `UnitRangeLike`.") + raise ValueError(f"`{r}` cannot be interpreted as `UnitRange`.") -IntIndex: TypeAlias = int | np.integer -DomainRange: TypeAlias = UnitRange | IntIndex -NamedRange: TypeAlias = tuple[Dimension, UnitRange] +IntIndex: TypeAlias = int | core_defs.IntegralScalar NamedIndex: TypeAlias = tuple[Dimension, IntIndex] -AnyIndex: TypeAlias = IntIndex | NamedRange | NamedIndex | slice | EllipsisType -DomainSlice: TypeAlias = Sequence[NamedRange | NamedIndex] -BufferSlice: TypeAlias = tuple[slice | IntIndex | EllipsisType, ...] -FieldSlice: TypeAlias = DomainSlice | BufferSlice | AnyIndex -UnitRangeLike: TypeAlias = UnitRange | range | tuple[int, int] -DomainLike: TypeAlias = ( - Sequence[tuple[Dimension, UnitRangeLike]] | Mapping[Dimension, UnitRangeLike] -) +NamedRange: TypeAlias = tuple[Dimension, UnitRange] +RelativeIndexElement: TypeAlias = IntIndex | slice | EllipsisType +AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange +AnyIndexElement: TypeAlias = RelativeIndexElement | AbsoluteIndexElement +AbsoluteIndexSequence: TypeAlias = Sequence[NamedRange | NamedIndex] +RelativeIndexSequence: TypeAlias = tuple[ + slice | IntIndex | EllipsisType, ... +] # is a tuple but called Sequence for symmetry +AnyIndexSequence: TypeAlias = RelativeIndexSequence | AbsoluteIndexSequence +AnyIndex: TypeAlias = AnyIndexElement | AnyIndexSequence def is_int_index(p: Any) -> TypeGuard[IntIndex]: - return isinstance(p, (int, np.integer)) + # should be replaced by isinstance(p, IntIndex), but mypy complains with + # `Argument 2 to "isinstance" has incompatible type ""; expected "_ClassInfo" [arg-type]` + return isinstance(p, (int, core_defs.INTEGRAL_TYPES)) -def is_named_range(v: Any) -> TypeGuard[NamedRange]: +def is_named_range(v: AnyIndex) -> TypeGuard[NamedRange]: return ( isinstance(v, tuple) and len(v) == 2 @@ -177,44 +185,41 @@ def is_named_range(v: Any) -> TypeGuard[NamedRange]: ) -def is_named_index(v: Any) -> TypeGuard[NamedRange]: +def is_named_index(v: AnyIndex) -> TypeGuard[NamedRange]: return ( isinstance(v, tuple) and len(v) == 2 and isinstance(v[0], Dimension) and is_int_index(v[1]) ) -def is_domain_slice(v: Any) -> TypeGuard[DomainSlice]: +def is_any_index_element(v: AnyIndex) -> TypeGuard[AnyIndexElement]: + return ( + is_int_index(v) + or is_named_range(v) + or is_named_index(v) + or isinstance(v, slice) + or v is Ellipsis + ) + + +def is_absolute_index_sequence(v: AnyIndexSequence) -> TypeGuard[AbsoluteIndexSequence]: return isinstance(v, Sequence) and all(is_named_range(e) or is_named_index(e) for e in v) -def is_buffer_slice(v: Any) -> TypeGuard[BufferSlice]: +def is_relative_index_sequence(v: AnyIndexSequence) -> TypeGuard[RelativeIndexSequence]: return isinstance(v, tuple) and all( isinstance(e, slice) or is_int_index(e) or e is Ellipsis for e in v ) -def is_unit_range_like(v: Any) -> TypeGuard[UnitRangeLike]: - return ( - isinstance(v, UnitRange) - or (isinstance(v, range) and v.step == 1) - or (isinstance(v, tuple) and isinstance(v[0], int) and isinstance(v[1], int)) - ) - - -def is_domain_like(v: Any) -> TypeGuard[DomainLike]: - return ( - isinstance(v, Sequence) - and all( - isinstance(e, tuple) and isinstance(e[0], Dimension) and is_unit_range_like(e[1]) - for e in v - ) - ) or ( - isinstance(v, Mapping) - and all(isinstance(d, Dimension) and is_unit_range_like(r) for d, r in v.items()) +def as_any_index_sequence(index: AnyIndex) -> AnyIndexSequence: + # `cast` because mypy/typing doesn't special case 1-element tuples, i.e. `tuple[A|B] != tuple[A]|tuple[B]` + return cast( + AnyIndexSequence, + (index,) if is_any_index_element(index) else index, ) -def named_range(v: tuple[Dimension, UnitRangeLike]) -> NamedRange: +def named_range(v: tuple[Dimension, RangeLike]) -> NamedRange: return (v[0], unit_range(v[1])) @@ -230,7 +235,8 @@ def __init__( *args: NamedRange, dims: Optional[tuple[Dimension, ...]] = None, ranges: Optional[tuple[UnitRange, ...]] = None, - ): + ) -> None: + # TODO throw user error in case pre-conditions are not met if dims is not None or ranges is not None: if dims is None and ranges is None: raise ValueError("Either both none of `dims` and `ranges` must be specified.") @@ -239,8 +245,15 @@ def __init__( "No extra `args` allowed when constructing fomr `dims` and `ranges`." ) - assert dims is not None - assert ranges is not None + assert dims is not None and ranges is not None # for mypy + if not all(isinstance(dim, Dimension) for dim in dims): + raise ValueError( + f"`dims` argument needs to be a `tuple[Dimension, ...], got `{dims}`." + ) + if not all(isinstance(rng, Dimension) for rng in ranges): + raise ValueError( + f"`ranges` argument needs to be a `tuple[UnitRange, ...], got `{ranges}`." + ) if len(dims) != len(ranges): raise ValueError( f"Number of provided dimensions ({len(dims)}) does not match number of provided ranges ({len(ranges)})." @@ -249,7 +262,8 @@ def __init__( object.__setattr__(self, "dims", dims) object.__setattr__(self, "ranges", ranges) else: - assert all(is_named_range(arg) for arg in args) + if not all(is_named_range(arg) for arg in args): + raise ValueError(f"Elements of `Domain` need to be `NamedRange`s, got `{args}`.") dims, ranges = zip(*args) if len(args) > 0 else ((), ()) object.__setattr__(self, "dims", tuple(dims)) object.__setattr__(self, "ranges", tuple(ranges)) @@ -265,7 +279,7 @@ def __getitem__(self, index: int) -> NamedRange: ... @overload - def __getitem__(self, index: slice) -> "Domain": + def __getitem__(self, index: slice) -> Domain: ... @overload @@ -303,6 +317,9 @@ def __str__(self) -> str: return f"Domain({', '.join(f'{e[0]}={e[1]}' for e in self)})" +DomainLike: TypeAlias = Sequence[tuple[Dimension, RangeLike]] | Mapping[Dimension, RangeLike] + + def domain(domain_like: DomainLike) -> Domain: """ Construct `Domain` from `DomainLike` object. @@ -318,17 +335,11 @@ def domain(domain_like: DomainLike) -> Domain: >>> domain({I: (2, 4), J: (3, 5)}) Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(2, 4), UnitRange(3, 5))) """ - assert is_domain_like(domain_like) if isinstance(domain_like, Domain): return domain_like - if isinstance(domain_like, Sequence) and all( - isinstance(e, tuple) and isinstance(e[0], Dimension) and is_unit_range_like(e[1]) - for e in domain_like - ): + if isinstance(domain_like, Sequence): return Domain(*tuple(named_range(d) for d in domain_like)) - if isinstance(domain_like, Mapping) and all( - isinstance(d, Dimension) and is_unit_range_like(r) for d, r in domain_like.items() - ): + if isinstance(domain_like, Mapping): return Domain( dims=tuple(domain_like.keys()), ranges=tuple(unit_range(r) for r in domain_like.values()), @@ -394,7 +405,7 @@ def remap(self, index_field: Field) -> Field: ... @abc.abstractmethod - def restrict(self, item: FieldSlice) -> Field | core_defs.ScalarT: + def restrict(self, item: AnyIndex) -> Field | core_defs.ScalarT: ... # Operators @@ -403,7 +414,7 @@ def __call__(self, index_field: Field) -> Field: ... @abc.abstractmethod - def __getitem__(self, item: FieldSlice) -> Field | core_defs.ScalarT: + def __getitem__(self, item: AnyIndex) -> Field | core_defs.ScalarT: ... @abc.abstractmethod @@ -472,12 +483,12 @@ def is_field( @extended_runtime_checkable class MutableField(Field[DimsT, core_defs.ScalarT], Protocol[DimsT, core_defs.ScalarT]): @abc.abstractmethod - def __setitem__(self, index: FieldSlice, value: Field | core_defs.ScalarT) -> None: + def __setitem__(self, index: AnyIndex, value: Field | core_defs.ScalarT) -> None: ... def is_mutable_field( - v: Any, + v: Field, ) -> TypeGuard[MutableField]: # This function is introduced to localize the `type: ignore` because # extended_runtime_checkable does not make the protocol runtime_checkable diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index a9a77aeeee..d02ac9d44c 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -14,26 +14,27 @@ import itertools from types import EllipsisType -from typing import Any, Optional, Sequence, cast +from typing import Any, Optional, Sequence from gt4py.next import common from gt4py.next.embedded import exceptions as embedded_exceptions -def sub_domain(domain: common.Domain, index: common.FieldSlice) -> common.Domain: - index = _tuplize_field_slice(index) +def sub_domain(domain: common.Domain, index: common.AnyIndex) -> common.Domain: + index_sequence = common.as_any_index_sequence(index) - if common.is_domain_slice(index): - return _absolute_sub_domain(domain, index) + if common.is_absolute_index_sequence(index_sequence): + return _absolute_sub_domain(domain, index_sequence) - assert isinstance(index, tuple) - if all(isinstance(idx, slice) or common.is_int_index(idx) or idx is Ellipsis for idx in index): - return _relative_sub_domain(domain, index) + if common.is_relative_index_sequence(index_sequence): + return _relative_sub_domain(domain, index_sequence) raise IndexError(f"Unsupported index type: {index}") -def _relative_sub_domain(domain: common.Domain, index: common.BufferSlice) -> common.Domain: +def _relative_sub_domain( + domain: common.Domain, index: common.RelativeIndexSequence +) -> common.Domain: named_ranges: list[common.NamedRange] = [] expanded = _expand_ellipsis(index, len(domain)) @@ -62,7 +63,9 @@ def _relative_sub_domain(domain: common.Domain, index: common.BufferSlice) -> co return common.Domain(*named_ranges) -def _absolute_sub_domain(domain: common.Domain, index: common.DomainSlice) -> common.Domain: +def _absolute_sub_domain( + domain: common.Domain, index: common.AbsoluteIndexSequence +) -> common.Domain: named_ranges: list[common.NamedRange] = [] for i, (dim, rng) in enumerate(domain): if (pos := _find_index_of_dim(dim, index)) is not None: @@ -89,22 +92,6 @@ def _absolute_sub_domain(domain: common.Domain, index: common.DomainSlice) -> co return common.Domain(*named_ranges) -def _tuplize_field_slice(v: common.FieldSlice) -> common.FieldSlice: - """ - Wrap a single index/slice/range into a tuple. - - Note: the condition is complex as `NamedRange`, `NamedIndex` are implemented as `tuple`. - """ - if ( - not isinstance(v, tuple) - and not common.is_domain_slice(v) - or common.is_named_index(v) - or common.is_named_range(v) - ): - return cast(common.FieldSlice, (v,)) - return v - - def _expand_ellipsis( indices: tuple[common.IntIndex | slice | EllipsisType, ...], target_size: int ) -> tuple[common.IntIndex | slice, ...]: diff --git a/src/gt4py/next/embedded/exceptions.py b/src/gt4py/next/embedded/exceptions.py index 78c8be5f4b..b190d1a821 100644 --- a/src/gt4py/next/embedded/exceptions.py +++ b/src/gt4py/next/embedded/exceptions.py @@ -20,8 +20,8 @@ class IndexOutOfBounds(gt4py_exceptions.GT4PyError): def __init__( self, domain: common.Domain, - indices: common.FieldSlice, - index: common.AnyIndex, + indices: common.AnyIndex, + index: common.AnyIndexElement, dim: common.Dimension, ): super().__init__( diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 119cd89bbc..7e2dc598cd 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -188,7 +188,7 @@ def from_array( def remap(self: _BaseNdArrayField, connectivity) -> _BaseNdArrayField: raise NotImplementedError() - def restrict(self, index: common.FieldSlice) -> common.Field | core_defs.ScalarT: + def restrict(self, index: common.AnyIndex) -> common.Field | core_defs.ScalarT: new_domain, buffer_slice = self._slice(index) new_buffer = self.ndarray[buffer_slice] @@ -254,17 +254,16 @@ def __invert__(self) -> _BaseNdArrayField: return _make_unary_array_field_intrinsic_func("invert", "invert")(self) raise NotImplementedError("`__invert__` not implemented for non-`bool` fields.") - def _slice(self, index: common.FieldSlice) -> tuple[common.Domain, common.BufferSlice]: + def _slice(self, index: common.AnyIndex) -> tuple[common.Domain, common.RelativeIndexSequence]: new_domain = embedded_common.sub_domain(self.domain, index) - index = embedded_common._tuplize_field_slice(index) - + index_sequence = common.as_any_index_sequence(index) slice_ = ( - _get_slices_from_domain_slice(self.domain, index) - if common.is_domain_slice(index) - else index + _get_slices_from_domain_slice(self.domain, index_sequence) + if common.is_absolute_index_sequence(index_sequence) + else index_sequence ) - assert common.is_buffer_slice(slice_), slice_ + assert common.is_relative_index_sequence(slice_), slice_ return new_domain, slice_ @@ -298,7 +297,7 @@ def _slice(self, index: common.FieldSlice) -> tuple[common.Domain, common.Buffer def _np_cp_setitem( self: _BaseNdArrayField[common.DimsT, core_defs.ScalarT], - index: common.FieldSlice, + index: common.AnyIndex, value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, ) -> None: target_domain, target_slice = self._slice(index) @@ -350,7 +349,7 @@ class JaxArrayField(_BaseNdArrayField): def __setitem__( self, - index: common.FieldSlice, + index: common.AnyIndex, value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, ) -> None: # use `self.ndarray.at(index).set(value)` @@ -388,7 +387,7 @@ def _builtins_broadcast( def _get_slices_from_domain_slice( domain: common.Domain, domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], -) -> common.BufferSlice: +) -> common.RelativeIndexSequence: """Generate slices for sub-array extraction based on named ranges or named indices within a Domain. This function generates a tuple of slices that can be used to extract sub-arrays from a field. The provided @@ -415,7 +414,7 @@ def _get_slices_from_domain_slice( def _compute_slice( - rng: common.DomainRange, domain: common.Domain, pos: int + rng: common.UnitRange | common.IntIndex, domain: common.Domain, pos: int ) -> slice | common.IntIndex: """Compute a slice or integer based on the provided range, domain, and position. From af5aa118ca3a31a9a587ce29c581b3d02cce9b46 Mon Sep 17 00:00:00 2001 From: Samuel Date: Thu, 31 Aug 2023 16:36:26 +0200 Subject: [PATCH 02/10] feat[next]: Add `FieldBuiltinFuncRegistry` mixin (#1330) Adds FieldBuiltinFuncRegistry to allow Field subclasses to register their own builtins Co-authored-by: Hannes Vogt --- src/gt4py/next/common.py | 24 +++++++++++++++- src/gt4py/next/embedded/nd_array_field.py | 35 ++--------------------- src/gt4py/next/ffront/fbuiltins.py | 4 --- 3 files changed, 26 insertions(+), 37 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index e06f9c54b1..866b2aadb7 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -15,13 +15,14 @@ from __future__ import annotations import abc +import collections import dataclasses import enum import functools import sys from collections.abc import Sequence, Set from types import EllipsisType -from typing import TypeGuard, overload +from typing import ChainMap, TypeGuard, overload import numpy as np import numpy.typing as npt @@ -481,3 +482,24 @@ def is_domain_slice(index: Any) -> TypeGuard[DomainSlice]: return isinstance(index, Sequence) and all( is_named_range(idx) or is_named_index(idx) for idx in index ) + + +class FieldBuiltinFuncRegistry: + _builtin_func_map: ChainMap[fbuiltins.BuiltInFunction, Callable] = collections.ChainMap() + + def __init_subclass__(cls, **kwargs): + # might break in multiple inheritance (if multiple ancestors have `_builtin_func_map`) + cls._builtin_func_map = cls._builtin_func_map.new_child() + + @classmethod + def register_builtin_func( + cls, /, op: fbuiltins.BuiltInFunction[_R, _P], op_func: Optional[Callable[_P, _R]] = None + ) -> Any: + assert op not in cls._builtin_func_map + if op_func is None: # when used as a decorator + return functools.partial(cls.register_builtin_func, op) + return cls._builtin_func_map.setdefault(op, op_func) + + @classmethod + def __gt_builtin_func__(cls, /, func: fbuiltins.BuiltInFunction[_R, _P]) -> Callable[_P, _R]: + return cls._builtin_func_map.get(func, NotImplemented) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 9813efdd22..ddef77bb78 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -15,17 +15,17 @@ from __future__ import annotations import dataclasses -import functools import itertools from collections.abc import Callable, Sequence from types import EllipsisType, ModuleType -from typing import Any, ClassVar, Optional, ParamSpec, TypeAlias, TypeVar, cast, overload +from typing import Any, ClassVar, Optional, ParamSpec, TypeAlias, TypeVar, cast import numpy as np from numpy import typing as npt from gt4py._core import definitions as core_defs from gt4py.next import common +from gt4py.next.common import FieldBuiltinFuncRegistry from gt4py.next.ffront import fbuiltins @@ -82,7 +82,7 @@ def _builtin_binary_op(a: _BaseNdArrayField, b: common.Field) -> common.Field: @dataclasses.dataclass(frozen=True) -class _BaseNdArrayField(common.FieldABC[common.DimsT, core_defs.ScalarT]): +class _BaseNdArrayField(common.FieldABC[common.DimsT, core_defs.ScalarT], FieldBuiltinFuncRegistry): """ Shared field implementation for NumPy-like fields. @@ -100,35 +100,6 @@ class _BaseNdArrayField(common.FieldABC[common.DimsT, core_defs.ScalarT]): ModuleType ] # TODO(havogt) after storage PR is merged, update to the NDArrayNamespace protocol - _builtin_func_map: ClassVar[dict[fbuiltins.BuiltInFunction, Callable]] = {} - - @classmethod - def __gt_builtin_func__(cls, func: fbuiltins.BuiltInFunction[_R, _P], /) -> Callable[_P, _R]: - return cls._builtin_func_map.get(func, NotImplemented) - - @overload - @classmethod - def register_builtin_func( - cls, op: fbuiltins.BuiltInFunction[_R, _P], op_func: None - ) -> functools.partial[Callable[_P, _R]]: - ... - - @overload - @classmethod - def register_builtin_func( - cls, op: fbuiltins.BuiltInFunction[_R, _P], op_func: Callable[_P, _R] - ) -> Callable[_P, _R]: - ... - - @classmethod - def register_builtin_func( - cls, op: fbuiltins.BuiltInFunction[_R, _P], op_func: Optional[Callable[_P, _R]] = None - ) -> Callable[_P, _R] | functools.partial[Callable[_P, _R]]: - assert op not in cls._builtin_func_map - if op_func is None: # when used as a decorator - return functools.partial(cls.register_builtin_func, op) # type: ignore[arg-type] - return cls._builtin_func_map.setdefault(op, op_func) - @property def domain(self) -> common.Domain: return self._domain diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index b6831b35df..ba027be13c 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -11,7 +11,6 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later - import dataclasses import inspect from builtins import bool, float, int, tuple @@ -49,7 +48,6 @@ TYPE_ALIAS_NAMES = ["IndexType"] - _P = ParamSpec("_P") _R = TypeVar("_R") @@ -205,7 +203,6 @@ def astype(field: Field | gt4py_defs.ScalarT, type_: type, /) -> Field: "trunc", ] - UNARY_MATH_FP_PREDICATE_BUILTIN_NAMES = ["isfinite", "isinf", "isnan"] @@ -224,7 +221,6 @@ def impl(value: Field | gt4py_defs.ScalarT, /) -> Field | gt4py_defs.ScalarT: ): _make_unary_math_builtin(f) - BINARY_MATH_NUMBER_BUILTIN_NAMES = ["minimum", "maximum", "fmod", "power"] From 00897377e6df97d84d8f854d1ed32cb8a8b8b231 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Thu, 31 Aug 2023 17:25:48 +0200 Subject: [PATCH 03/10] remaining review comments --- src/gt4py/next/common.py | 44 ++++++++++++++++--- src/gt4py/next/embedded/common.py | 24 +++++----- src/gt4py/next/embedded/exceptions.py | 9 ++++ src/gt4py/next/embedded/nd_array_field.py | 10 ++--- src/gt4py/next/utils.py | 7 ++- .../unit_tests/embedded_tests/test_common.py | 22 +++++++--- 6 files changed, 82 insertions(+), 34 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index a1f40c2f92..9708d2a6d2 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -19,8 +19,8 @@ import enum import functools import sys +import types from collections.abc import Mapping, Sequence, Set -from types import EllipsisType from typing import overload import numpy as np @@ -159,12 +159,12 @@ def unit_range(r: RangeLike) -> UnitRange: IntIndex: TypeAlias = int | core_defs.IntegralScalar NamedIndex: TypeAlias = tuple[Dimension, IntIndex] NamedRange: TypeAlias = tuple[Dimension, UnitRange] -RelativeIndexElement: TypeAlias = IntIndex | slice | EllipsisType +RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange AnyIndexElement: TypeAlias = RelativeIndexElement | AbsoluteIndexElement AbsoluteIndexSequence: TypeAlias = Sequence[NamedRange | NamedIndex] RelativeIndexSequence: TypeAlias = tuple[ - slice | IntIndex | EllipsisType, ... + slice | IntIndex | types.EllipsisType, ... ] # is a tuple but called Sequence for symmetry AnyIndexSequence: TypeAlias = RelativeIndexSequence | AbsoluteIndexSequence AnyIndex: TypeAlias = AnyIndexElement | AnyIndexSequence @@ -250,7 +250,7 @@ def __init__( raise ValueError( f"`dims` argument needs to be a `tuple[Dimension, ...], got `{dims}`." ) - if not all(isinstance(rng, Dimension) for rng in ranges): + if not all(isinstance(rng, UnitRange) for rng in ranges): raise ValueError( f"`ranges` argument needs to be a `tuple[UnitRange, ...], got `{ranges}`." ) @@ -264,7 +264,7 @@ def __init__( else: if not all(is_named_range(arg) for arg in args): raise ValueError(f"Elements of `Domain` need to be `NamedRange`s, got `{args}`.") - dims, ranges = zip(*args) if len(args) > 0 else ((), ()) + dims, ranges = zip(*args) if args else ((), ()) object.__setattr__(self, "dims", tuple(dims)) object.__setattr__(self, "ranges", tuple(ranges)) @@ -303,6 +303,20 @@ def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: raise KeyError("Invalid index type, must be either int, slice, or Dimension.") def __and__(self, other: Domain) -> Domain: + """ + Intersect `Domain`s, missing `Dimension`s are considered infinite. + + Examples: + --------- + >>> I = Dimension("I") + >>> J = Dimension("J") + + >>> Domain((I, UnitRange(-1, 3))) & Domain((I, UnitRange(1, 6))) + Domain(dims=(Dimension(value='I', kind=),), ranges=(UnitRange(1, 3),)) + + >>> Domain((I, UnitRange(-1, 3)), (J, UnitRange(2, 4))) & Domain((I, UnitRange(1, 6))) + Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(1, 3), UnitRange(2, 4))) + """ broadcast_dims = tuple(promote_dims(self.dims, other.dims)) intersected_ranges = tuple( rng1 & rng2 @@ -371,8 +385,8 @@ class NextGTDimsInterface(Protocol): """ A `GTDimsInterface` is an object providing the `__gt_dims__` property, naming :class:`Field` dimensions. - The dimension names are objects of type :class:`Dimension`, in contrast to :py:mod:`gt4py.cartesian`, - where the labels are `str` s with implied semantics, see :py:class:`~gt4py._core.definitions.GTDimsInterface` . + The dimension names are objects of type :class:`Dimension`, in contrast to :mod:`gt4py.cartesian`, + where the labels are `str` s with implied semantics, see :class:`~gt4py._core.definitions.GTDimsInterface` . """ # TODO(havogt): unify with GTDimsInterface, ideally in backward compatible way @@ -425,6 +439,10 @@ def __abs__(self) -> Field: def __neg__(self) -> Field: ... + @abc.abstractmethod + def __invert__(self) -> Field: + """Only defined for `Field` of value type `bool`.""" + @abc.abstractmethod def __add__(self, other: Field | core_defs.ScalarT) -> Field: ... @@ -469,6 +487,18 @@ def __rtruediv__(self, other: Field | core_defs.ScalarT) -> Field: def __pow__(self, other: Field | core_defs.ScalarT) -> Field: ... + @abc.abstractmethod + def __and__(self, other: Field | core_defs.ScalarT) -> Field: + """Only defined for `Field` of value type `bool`.""" + + @abc.abstractmethod + def __or__(self, other: Field | core_defs.ScalarT) -> Field: + """Only defined for `Field` of value type `bool`.""" + + @abc.abstractmethod + def __xor__(self, other: Field | core_defs.ScalarT) -> Field: + """Only defined for `Field` of value type `bool`.""" + def is_field( v: Any, diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index d02ac9d44c..6ad909fc1b 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -12,9 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import itertools -from types import EllipsisType -from typing import Any, Optional, Sequence +from typing import Any, Optional, Sequence, cast from gt4py.next import common from gt4py.next.embedded import exceptions as embedded_exceptions @@ -40,9 +38,8 @@ def _relative_sub_domain( expanded = _expand_ellipsis(index, len(domain)) if len(domain) < len(expanded): raise IndexError(f"Trying to index a `Field` with {len(domain)} dimensions with {index}.") - for (dim, rng), idx in itertools.zip_longest( # type: ignore[misc] # "slice" object is not iterable, not sure which slice... - domain, expanded, fillvalue=slice(None) - ): + expanded += (slice(None),) * (len(domain) - len(expanded)) + for (dim, rng), idx in zip(domain, expanded, strict=True): if isinstance(idx, slice): try: sliced = _slice_range(rng, idx) @@ -93,15 +90,14 @@ def _absolute_sub_domain( def _expand_ellipsis( - indices: tuple[common.IntIndex | slice | EllipsisType, ...], target_size: int + indices: common.RelativeIndexSequence, target_size: int ) -> tuple[common.IntIndex | slice, ...]: - expanded_indices: list[common.IntIndex | slice] = [] - for idx in indices: - if idx is Ellipsis: - expanded_indices.extend([slice(None)] * (target_size - (len(indices) - 1))) - else: - expanded_indices.append(idx) - return tuple(expanded_indices) + if Ellipsis in indices: + idx = indices.index(Ellipsis) + indices = ( + indices[:idx] + (slice(None),) * (target_size - (len(indices) - 1)) + indices[idx + 1 :] + ) + return cast(tuple[common.IntIndex | slice, ...], indices) # mypy leave me alone and trust me! def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.UnitRange: diff --git a/src/gt4py/next/embedded/exceptions.py b/src/gt4py/next/embedded/exceptions.py index b190d1a821..c115487367 100644 --- a/src/gt4py/next/embedded/exceptions.py +++ b/src/gt4py/next/embedded/exceptions.py @@ -17,6 +17,11 @@ class IndexOutOfBounds(gt4py_exceptions.GT4PyError): + domain: common.Domain + indices: common.AnyIndex + index: common.AnyIndexElement + dim: common.Dimension + def __init__( self, domain: common.Domain, @@ -27,3 +32,7 @@ def __init__( super().__init__( f"Out of bounds: slicing {domain} with index `{indices}`, `{index}` is out of bounds in dimension `{dim}`." ) + self.domain = domain + self.indices = indices + self.index = index + self.dim = dim diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 7e2dc598cd..effd5e2694 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -224,7 +224,7 @@ def restrict(self, index: common.AnyIndex) -> common.Field | core_defs.ScalarT: __mod__ = __rmod__ = _make_binary_array_field_intrinsic_func("mod", "mod") - def __and__(self, other: common.Field) -> _BaseNdArrayField: + def __and__(self, other: common.Field | core_defs.ScalarT) -> _BaseNdArrayField: if self.dtype == core_defs.BoolDType(): return _make_binary_array_field_intrinsic_func("logical_and", "logical_and")( self, other @@ -233,14 +233,14 @@ def __and__(self, other: common.Field) -> _BaseNdArrayField: __rand__ = __and__ - def __or__(self, other: common.Field) -> _BaseNdArrayField: + def __or__(self, other: common.Field | core_defs.ScalarT) -> _BaseNdArrayField: if self.dtype == core_defs.BoolDType(): return _make_binary_array_field_intrinsic_func("logical_or", "logical_or")(self, other) raise NotImplementedError("`__or__` not implemented for non-`bool` fields.") __ror__ = __or__ - def __xor__(self, other: common.Field) -> _BaseNdArrayField: + def __xor__(self, other: common.Field | core_defs.ScalarT) -> _BaseNdArrayField: if self.dtype == core_defs.BoolDType(): return _make_binary_array_field_intrinsic_func("logical_xor", "logical_xor")( self, other @@ -263,7 +263,7 @@ def _slice(self, index: common.AnyIndex) -> tuple[common.Domain, common.Relative if common.is_absolute_index_sequence(index_sequence) else index_sequence ) - assert common.is_relative_index_sequence(slice_), slice_ + assert common.is_relative_index_sequence(slice_) return new_domain, slice_ @@ -352,7 +352,7 @@ def __setitem__( index: common.AnyIndex, value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, ) -> None: - # use `self.ndarray.at(index).set(value)` + # TODO(havogt): use something like `self.ndarray = self.ndarray.at(index).set(value)` raise NotImplementedError("`__setitem__` for JaxArrayField not yet implemented.") common.field.register(jnp.ndarray, JaxArrayField.from_array) diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 3e7dc2d4d3..006b3057b0 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any, ClassVar, TypeGuard +from typing import Any, ClassVar, TypeGuard, TypeVar class RecursionGuard: @@ -51,5 +51,8 @@ def __exit__(self, *exc): self.guarded_objects.remove(id(self.obj)) -def is_tuple_of(v: Any, t: type) -> TypeGuard[tuple]: +_T = TypeVar("_T") + + +def is_tuple_of(v: Any, t: type[_T]) -> TypeGuard[tuple[_T, ...]]: return isinstance(v, tuple) and all(isinstance(e, t) for e in v) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py index 444978097c..640ed326bb 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_common.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -22,12 +22,17 @@ from gt4py.next.embedded.common import _slice_range, sub_domain -def test_slice_range(): - input_range = UnitRange(2, 10) - slice_obj = slice(2, -2) - expected = UnitRange(4, 8) - - result = _slice_range(input_range, slice_obj) +@pytest.mark.parametrize( + "rng, slce, expected", + [ + (UnitRange(2, 10), slice(2, -2), UnitRange(4, 8)), + (UnitRange(2, 10), slice(2, None), UnitRange(4, 10)), + (UnitRange(2, 10), slice(None, -2), UnitRange(2, 8)), + (UnitRange(2, 10), slice(None), UnitRange(2, 10)), + ], +) +def test_slice_range(rng, slce, expected): + result = _slice_range(rng, slce) assert result == expected @@ -114,6 +119,11 @@ def test_slice_range(): (slice(1, 2), Ellipsis, slice(2, 3)), [(I, (3, 4)), (J, (3, 6)), (K, (6, 7))], ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (slice(1, 2), slice(1, 2), Ellipsis), + [(I, (3, 4)), (J, (4, 5)), (K, (4, 7))], + ), ], ) def test_sub_domain(domain, index, expected): From 9ff35ab74da602b515579f9cfe02e70d656b5d9a Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 5 Sep 2023 09:27:38 +0200 Subject: [PATCH 04/10] Apply suggestions from code review Co-authored-by: Enrique G. Paredes <18477+egparedes@users.noreply.github.com> --- src/gt4py/next/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index 5d7a5f480e..f5af3d1f97 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -25,7 +25,7 @@ from . import common, ffront, iterator, program_processors, type_inference from .common import Dimension, DimensionKind, Field, GridType -from .embedded import nd_array_field +from .embedded import nd_array_field as _nd_array_field # Just for registering field implementations from .ffront import fbuiltins from .ffront.decorator import field_operator, program, scan_operator from .ffront.fbuiltins import * # noqa: F403 # fbuiltins defines __all__ and we explicitly want to reexport everything here From 0ceaaf4dd28b38061770cc386200d7a4a997bde7 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 5 Sep 2023 09:28:30 +0200 Subject: [PATCH 05/10] address review comments --- src/gt4py/next/common.py | 17 ++++++++--------- src/gt4py/next/embedded/common.py | 2 +- src/gt4py/next/embedded/exceptions.py | 4 ++-- src/gt4py/next/embedded/nd_array_field.py | 10 ++++++---- 4 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 84e0d1f145..f5c96f07df 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -168,7 +168,7 @@ def unit_range(r: RangeLike) -> UnitRange: slice | IntIndex | types.EllipsisType, ... ] # is a tuple but called Sequence for symmetry AnyIndexSequence: TypeAlias = RelativeIndexSequence | AbsoluteIndexSequence -AnyIndex: TypeAlias = AnyIndexElement | AnyIndexSequence +AnyIndexSpec: TypeAlias = AnyIndexElement | AnyIndexSequence def is_int_index(p: Any) -> TypeGuard[IntIndex]: @@ -177,7 +177,7 @@ def is_int_index(p: Any) -> TypeGuard[IntIndex]: return isinstance(p, (int, core_defs.INTEGRAL_TYPES)) -def is_named_range(v: AnyIndex) -> TypeGuard[NamedRange]: +def is_named_range(v: AnyIndexSpec) -> TypeGuard[NamedRange]: return ( isinstance(v, tuple) and len(v) == 2 @@ -186,13 +186,13 @@ def is_named_range(v: AnyIndex) -> TypeGuard[NamedRange]: ) -def is_named_index(v: AnyIndex) -> TypeGuard[NamedRange]: +def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedRange]: return ( isinstance(v, tuple) and len(v) == 2 and isinstance(v[0], Dimension) and is_int_index(v[1]) ) -def is_any_index_element(v: AnyIndex) -> TypeGuard[AnyIndexElement]: +def is_any_index_element(v: AnyIndexSpec) -> TypeGuard[AnyIndexElement]: return ( is_int_index(v) or is_named_range(v) @@ -212,7 +212,7 @@ def is_relative_index_sequence(v: AnyIndexSequence) -> TypeGuard[RelativeIndexSe ) -def as_any_index_sequence(index: AnyIndex) -> AnyIndexSequence: +def as_any_index_sequence(index: AnyIndexSpec) -> AnyIndexSequence: # `cast` because mypy/typing doesn't special case 1-element tuples, i.e. `tuple[A|B] != tuple[A]|tuple[B]` return cast( AnyIndexSequence, @@ -237,7 +237,6 @@ def __init__( dims: Optional[tuple[Dimension, ...]] = None, ranges: Optional[tuple[UnitRange, ...]] = None, ) -> None: - # TODO throw user error in case pre-conditions are not met if dims is not None or ranges is not None: if dims is None and ranges is None: raise ValueError("Either both none of `dims` and `ranges` must be specified.") @@ -420,7 +419,7 @@ def remap(self, index_field: Field) -> Field: ... @abc.abstractmethod - def restrict(self, item: AnyIndex) -> Field | core_defs.ScalarT: + def restrict(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: ... # Operators @@ -429,7 +428,7 @@ def __call__(self, index_field: Field) -> Field: ... @abc.abstractmethod - def __getitem__(self, item: AnyIndex) -> Field | core_defs.ScalarT: + def __getitem__(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: ... @abc.abstractmethod @@ -514,7 +513,7 @@ def is_field( @extended_runtime_checkable class MutableField(Field[DimsT, core_defs.ScalarT], Protocol[DimsT, core_defs.ScalarT]): @abc.abstractmethod - def __setitem__(self, index: AnyIndex, value: Field | core_defs.ScalarT) -> None: + def __setitem__(self, index: AnyIndexSpec, value: Field | core_defs.ScalarT) -> None: ... diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 6ad909fc1b..3799923d87 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -18,7 +18,7 @@ from gt4py.next.embedded import exceptions as embedded_exceptions -def sub_domain(domain: common.Domain, index: common.AnyIndex) -> common.Domain: +def sub_domain(domain: common.Domain, index: common.AnyIndexSpec) -> common.Domain: index_sequence = common.as_any_index_sequence(index) if common.is_absolute_index_sequence(index_sequence): diff --git a/src/gt4py/next/embedded/exceptions.py b/src/gt4py/next/embedded/exceptions.py index c115487367..393123db36 100644 --- a/src/gt4py/next/embedded/exceptions.py +++ b/src/gt4py/next/embedded/exceptions.py @@ -18,14 +18,14 @@ class IndexOutOfBounds(gt4py_exceptions.GT4PyError): domain: common.Domain - indices: common.AnyIndex + indices: common.AnyIndexSpec index: common.AnyIndexElement dim: common.Dimension def __init__( self, domain: common.Domain, - indices: common.AnyIndex, + indices: common.AnyIndexSpec, index: common.AnyIndexElement, dim: common.Dimension, ): diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index e684f3e24c..fcaa09e7eb 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -160,7 +160,7 @@ def from_array( def remap(self: _BaseNdArrayField, connectivity) -> _BaseNdArrayField: raise NotImplementedError() - def restrict(self, index: common.AnyIndex) -> common.Field | core_defs.ScalarT: + def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.ScalarT: new_domain, buffer_slice = self._slice(index) new_buffer = self.ndarray[buffer_slice] @@ -226,7 +226,9 @@ def __invert__(self) -> _BaseNdArrayField: return _make_unary_array_field_intrinsic_func("invert", "invert")(self) raise NotImplementedError("`__invert__` not implemented for non-`bool` fields.") - def _slice(self, index: common.AnyIndex) -> tuple[common.Domain, common.RelativeIndexSequence]: + def _slice( + self, index: common.AnyIndexSpec + ) -> tuple[common.Domain, common.RelativeIndexSequence]: new_domain = embedded_common.sub_domain(self.domain, index) index_sequence = common.as_any_index_sequence(index) @@ -269,7 +271,7 @@ def _slice(self, index: common.AnyIndex) -> tuple[common.Domain, common.Relative def _np_cp_setitem( self: _BaseNdArrayField[common.DimsT, core_defs.ScalarT], - index: common.AnyIndex, + index: common.AnyIndexSpec, value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, ) -> None: target_domain, target_slice = self._slice(index) @@ -321,7 +323,7 @@ class JaxArrayField(_BaseNdArrayField): def __setitem__( self, - index: common.AnyIndex, + index: common.AnyIndexSpec, value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, ) -> None: # TODO(havogt): use something like `self.ndarray = self.ndarray.at(index).set(value)` From d7366cba9114b24213c9abf3cdc644681360aa0b Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 5 Sep 2023 09:34:29 +0200 Subject: [PATCH 06/10] fix formatting --- src/gt4py/next/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index f5af3d1f97..cc35899668 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -25,7 +25,9 @@ from . import common, ffront, iterator, program_processors, type_inference from .common import Dimension, DimensionKind, Field, GridType -from .embedded import nd_array_field as _nd_array_field # Just for registering field implementations +from .embedded import ( # Just for registering field implementations + nd_array_field as _nd_array_field, +) from .ffront import fbuiltins from .ffront.decorator import field_operator, program, scan_operator from .ffront.fbuiltins import * # noqa: F403 # fbuiltins defines __all__ and we explicitly want to reexport everything here From 9d87bbd8add1a0d741b4481886f657df84c41e5b Mon Sep 17 00:00:00 2001 From: edopao Date: Tue, 5 Sep 2023 09:40:36 +0200 Subject: [PATCH 07/10] [dace] Enable tests with type inference (#1331) Enable some ITIR tests which were disabled because argument types were not propagated. Now possible to run, after improvements to type inference. --- .../feature_tests/ffront_tests/test_gt4py_builtins.py | 5 ----- .../feature_tests/iterator_tests/test_builtins.py | 8 -------- .../feature_tests/iterator_tests/test_implicit_fencil.py | 4 ---- .../iterator_tests/test_column_stencil.py | 4 ---- 4 files changed, 21 deletions(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py index 26f01ca813..5f19311a32 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py @@ -46,8 +46,6 @@ ids=["positive_values", "negative_values"], ) def test_maxover_execution_(unstructured_case, strategy): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: reductions") if unstructured_case.backend in [gtfn_cpu.run_gtfn, gtfn_cpu.run_gtfn_imperative]: pytest.xfail("`maxover` broken in gtfn, see #1289.") @@ -65,9 +63,6 @@ def testee(edge_f: cases.EField) -> cases.VField: def test_minover_execution(unstructured_case): - if unstructured_case.backend == dace_iterator.run_dace_iterator: - pytest.xfail("Not supported in DaCe backend: reductions") - @gtx.field_operator def minover(edge_f: cases.EField) -> cases.VField: out = min_over(edge_f(V2E), axis=V2EDim) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py index 13fcf3b87f..673a989122 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_builtins.py @@ -171,10 +171,6 @@ def arithmetic_and_logical_test_data(): @pytest.mark.parametrize("builtin, inputs, expected", arithmetic_and_logical_test_data()) def test_arithmetic_and_logical_builtins(program_processor, builtin, inputs, expected, as_column): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail( - "Not supported in DaCe backend: argument types are not propagated for ITIR tests" - ) inps = asfield(*asarray(*inputs)) out = asfield((np.zeros_like(*asarray(expected))))[0] @@ -207,10 +203,6 @@ def test_arithmetic_and_logical_functors_gtfn(builtin, inputs, expected): @pytest.mark.parametrize("builtin_name, inputs", math_builtin_test_data()) def test_math_function_builtins(program_processor, builtin_name, inputs, as_column): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail( - "Not supported in DaCe backend: argument types are not propagated for ITIR tests" - ) if builtin_name == "gamma": # numpy has no gamma function diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py index 37ac4623fd..2f7808b30e 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_implicit_fencil.py @@ -59,10 +59,6 @@ def test_single_argument(program_processor, dom): def test_2_arguments(program_processor, dom): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail( - "Not supported in DaCe backend: argument types are not propagated for ITIR tests" - ) @fundef def fun(inp0, inp1): diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index 6c58ded3a9..5970b9a2a9 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -292,10 +292,6 @@ def sum_shifted_fencil(out, inp0, inp1, k_size): def test_different_vertical_sizes(program_processor): program_processor, validate = program_processor - if program_processor == run_dace_iterator: - pytest.xfail( - "Not supported in DaCe backend: argument types are not propagated for ITIR tests" - ) k_size = 10 inp0 = gtx.np_as_located_field(KDim)(np.arange(0, k_size)) From f2a9030c81cd18781877fbc004c6442fab5acb9c Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 5 Sep 2023 10:10:02 +0200 Subject: [PATCH 08/10] fix comments --- src/gt4py/next/common.py | 21 ++++++++++++++++++--- src/gt4py/next/embedded/common.py | 2 +- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index f5c96f07df..b85239cd0a 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -331,7 +331,9 @@ def __str__(self) -> str: return f"Domain({', '.join(f'{e[0]}={e[1]}' for e in self)})" -DomainLike: TypeAlias = Sequence[tuple[Dimension, RangeLike]] | Mapping[Dimension, RangeLike] +DomainLike: TypeAlias = ( + Sequence[tuple[Dimension, RangeLike]] | Mapping[Dimension, RangeLike] +) # `Domain` is `Sequence[NamedRange]` and therefore a subset def domain(domain_like: DomainLike) -> Domain: @@ -656,13 +658,26 @@ def promote_dims(*dims_list: Sequence[Dimension]) -> list[Dimension]: class FieldBuiltinFuncRegistry: + """ + Mixin for adding `fbuiltins` registry to a `Field`. + + Subclasses of a `Field` with `FieldBuiltinFuncRegistry` get their own registry, + dispatching (via ChainMap) to its parent's registries. + """ + _builtin_func_map: collections.ChainMap[ fbuiltins.BuiltInFunction, Callable ] = collections.ChainMap() def __init_subclass__(cls, **kwargs): - # might break in multiple inheritance (if multiple ancestors have `_builtin_func_map`) - cls._builtin_func_map = cls._builtin_func_map.new_child() + cls._builtin_func_map = collections.ChainMap( + {}, # New empty `dict`` for new registrations on this class + *[ + c.__dict__["_builtin_func_map"].maps[0] # adding parent `dict`s in mro order + for c in cls.__mro__ + if "_builtin_func_map" in c.__dict__ + ], + ) @classmethod def register_builtin_func( diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 3799923d87..37ba4954f3 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -112,7 +112,7 @@ def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.Unit ) + (slice_obj.stop or len(input_range)) if start < input_range.start or stop > input_range.stop: - raise IndexError() + raise IndexError("Slice out of range (no clipping following array API standard).") return common.UnitRange(start, stop) From 8d91a7b047f5d3af57ac3c1407a412e69f6cea89 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 5 Sep 2023 11:03:27 +0200 Subject: [PATCH 09/10] refactor[next] Prepare new Field for itir.embedded (#1329) - improve TypeAliases - add `domain` and `unit_range` constructors - extract domain slicing utils to `next.embedded.common` - introduce `MutableField` - add some missing operators to `Field` --------- Co-authored-by: Enrique G. Paredes <18477+egparedes@users.noreply.github.com> --- docs/user/cartesian/Makefile | 15 +- docs/user/cartesian/arrays.rst | 6 +- src/gt4py/_core/definitions.py | 54 ++- src/gt4py/next/__init__.py | 3 + src/gt4py/next/common.py | 326 ++++++++++++++---- src/gt4py/next/embedded/common.py | 127 +++++++ src/gt4py/next/embedded/exceptions.py | 38 ++ src/gt4py/next/embedded/nd_array_field.py | 272 +++++++-------- src/gt4py/next/errors/exceptions.py | 12 +- src/gt4py/next/ffront/fbuiltins.py | 1 + src/gt4py/next/utils.py | 9 +- .../ffront_tests/test_foast_pretty_printer.py | 2 +- .../unit_tests/embedded_tests/test_common.py | 137 ++++++++ .../embedded_tests/test_nd_array_field.py | 197 ++++++++--- tests/next_tests/unit_tests/test_common.py | 125 ++++--- 15 files changed, 1002 insertions(+), 322 deletions(-) create mode 100644 src/gt4py/next/embedded/common.py create mode 100644 src/gt4py/next/embedded/exceptions.py create mode 100644 tests/next_tests/unit_tests/embedded_tests/test_common.py diff --git a/docs/user/cartesian/Makefile b/docs/user/cartesian/Makefile index 091bc3b8d2..13e692b96d 100644 --- a/docs/user/cartesian/Makefile +++ b/docs/user/cartesian/Makefile @@ -2,12 +2,13 @@ # # You can set these variables from the command line. -SPHINXOPTS = -SPHINXBUILD = sphinx-build -PAPER = -SRCDIR = ../../../src/gt4py -AUTODOCDIR = _source -BUILDDIR = _build +SPHINXOPTS = +SPHINXBUILD = sphinx-build +PAPER = +SRCDIR = ../../../src/gt4py +SPHINX_APIDOC_OPTS = --private # private modules for gt4py._core +AUTODOCDIR = _source +BUILDDIR = _build # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) @@ -55,7 +56,7 @@ clean: autodoc: @echo @echo "Running sphinx-apidoc..." - sphinx-apidoc ${SPHINX_OPTS} -o ${AUTODOCDIR} ${SRCDIR} + sphinx-apidoc ${SPHINX_APIDOC_OPTS} -o ${AUTODOCDIR} ${SRCDIR} @echo @echo "sphinx-apidoc finished. The generated autodocs are in $(AUTODOCDIR)." diff --git a/docs/user/cartesian/arrays.rst b/docs/user/cartesian/arrays.rst index 6788e2757f..6ef7c6e5c1 100644 --- a/docs/user/cartesian/arrays.rst +++ b/docs/user/cartesian/arrays.rst @@ -39,6 +39,8 @@ Internally, gt4py uses the utilities :code:`gt4py.utils.as_numpy` and :code:`gt4 buffers. GT4Py developers are advised to always use those utilities as to guarantee support across gt4py as the supported interfaces are extended. +.. _cartesian-arrays-dimension-mapping: + Dimension Mapping ^^^^^^^^^^^^^^^^^ @@ -56,6 +58,8 @@ which implements this lookup. Note: Support for xarray can be added manually by the user by means of the mechanism described `here `_. +.. _cartesian-arrays-default-origin: + Default Origin ^^^^^^^^^^^^^^ @@ -180,4 +184,4 @@ Additionally, these **optional** keyword-only parameters are accepted: determine the default layout for the storage. Currently supported will be :code:`"I"`, :code:`"J"`, :code:`"K"` and additional dimensions as string representations of integers, starting at :code:`"0"`. (This information is not retained in the resulting array, and needs to be specified instead - with the :code:`__gt_dims__` interface. ) \ No newline at end of file + with the :code:`__gt_dims__` interface. ) diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 2546ae3e4e..059ba6c24c 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -213,18 +213,13 @@ class DType(Generic[ScalarT]): """ scalar_type: Type[ScalarT] - tensor_shape: TensorShape + tensor_shape: TensorShape = dataclasses.field(default=()) - def __init__( - self, scalar_type: Type[ScalarT], tensor_shape: Sequence[IntegralScalar] = () - ) -> None: - if not isinstance(scalar_type, type): - raise TypeError(f"Invalid scalar type '{scalar_type}'") - if not is_valid_tensor_shape(tensor_shape): - raise TypeError(f"Invalid tensor shape '{tensor_shape}'") - - object.__setattr__(self, "scalar_type", scalar_type) - object.__setattr__(self, "tensor_shape", tensor_shape) + def __post_init__(self) -> None: + if not isinstance(self.scalar_type, type): + raise TypeError(f"Invalid scalar type '{self.scalar_type}'") + if not is_valid_tensor_shape(self.tensor_shape): + raise TypeError(f"Invalid tensor shape '{self.tensor_shape}'") @functools.cached_property def kind(self) -> DTypeKind: @@ -251,6 +246,16 @@ def lanes(self) -> int: def subndim(self) -> int: return len(self.tensor_shape) + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, DType) + and self.scalar_type == other.scalar_type + and self.tensor_shape == other.tensor_shape + ) + + def __hash__(self) -> int: + return hash((self.scalar_type, self.tensor_shape)) + @dataclasses.dataclass(frozen=True) class IntegerDType(DType[IntegralT]): @@ -322,6 +327,11 @@ class Float64DType(FloatingDType[float64]): scalar_type: Final[Type[float64]] = dataclasses.field(default=float64, init=False) +@dataclasses.dataclass(frozen=True) +class BoolDType(DType[bool_]): + scalar_type: Final[Type[bool_]] = dataclasses.field(default=bool_, init=False) + + DTypeLike = Union[DType, npt.DTypeLike] @@ -332,11 +342,29 @@ def dtype(dtype_like: DTypeLike) -> DType: # -- Custom protocols -- class GTDimsInterface(Protocol): - __gt_dims__: Tuple[str, ...] + """ + A `GTDimsInterface` is an object providing the `__gt_dims__` property, naming the buffer dimensions. + + In `gt4py.cartesian` the allowed values are `"I"`, `"J"` and `"K"` with the established semantics. + + See :ref:`cartesian-arrays-dimension-mapping` for details. + """ + + @property + def __gt_dims__(self) -> Tuple[str, ...]: + ... class GTOriginInterface(Protocol): - __gt_origin__: Tuple[int, ...] + """ + A `GTOriginInterface` is an object providing `__gt_origin__`, describing the origin of a buffer. + + See :ref:`cartesian-arrays-default-origin` for details. + """ + + @property + def __gt_origin__(self) -> Tuple[int, ...]: + ... # -- Device representation -- diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index b4d1fc0c09..cc35899668 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -25,6 +25,9 @@ from . import common, ffront, iterator, program_processors, type_inference from .common import Dimension, DimensionKind, Field, GridType +from .embedded import ( # Just for registering field implementations + nd_array_field as _nd_array_field, +) from .ffront import fbuiltins from .ffront.decorator import field_operator, program, scan_operator from .ffront.fbuiltins import * # noqa: F403 # fbuiltins defines __all__ and we explicitly want to reexport everything here diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 866b2aadb7..b85239cd0a 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -20,9 +20,9 @@ import enum import functools import sys -from collections.abc import Sequence, Set -from types import EllipsisType -from typing import ChainMap, TypeGuard, overload +import types +from collections.abc import Mapping, Sequence, Set +from typing import overload import numpy as np import numpy.typing as npt @@ -37,16 +37,18 @@ ParamSpec, Protocol, TypeAlias, + TypeGuard, TypeVar, + cast, extended_runtime_checkable, - final, runtime_checkable, ) from gt4py.eve.type_definitions import StrEnum -DimT = TypeVar("DimT", bound="Dimension") -DimsT = TypeVar("DimsT", bound=Sequence["Dimension"], covariant=True) +DimsT = TypeVar( + "DimsT", covariant=True +) # bound to `Sequence[Dimension]` if instance of Dimension would be a type class Infinity(int): @@ -66,7 +68,7 @@ class DimensionKind(StrEnum): LOCAL = "local" def __str__(self): - return f"{type(self).__name__}.{self.name}" + return self.value @dataclasses.dataclass(frozen=True) @@ -75,7 +77,7 @@ class Dimension: kind: DimensionKind = dataclasses.field(default=DimensionKind.HORIZONTAL) def __str__(self): - return f'Dimension(value="{self.value}", kind={self.kind})' + return f"{self.value}[{self.kind}]" @dataclasses.dataclass(frozen=True) @@ -136,36 +138,139 @@ def __and__(self, other: Set[Any]) -> UnitRange: else: raise NotImplementedError("Can only find the intersection between UnitRange instances.") + def __str__(self) -> str: + return f"({self.start}:{self.stop})" + + +RangeLike: TypeAlias = UnitRange | range | tuple[int, int] + + +def unit_range(r: RangeLike) -> UnitRange: + if isinstance(r, UnitRange): + return r + if isinstance(r, range): + if r.step != 1: + raise ValueError(f"`UnitRange` requires step size 1, got `{r.step}`.") + return UnitRange(r.start, r.stop) + if isinstance(r, tuple) and isinstance(r[0], int) and isinstance(r[1], int): + return UnitRange(r[0], r[1]) + raise ValueError(f"`{r}` cannot be interpreted as `UnitRange`.") -DomainRange: TypeAlias = UnitRange | int + +IntIndex: TypeAlias = int | core_defs.IntegralScalar +NamedIndex: TypeAlias = tuple[Dimension, IntIndex] NamedRange: TypeAlias = tuple[Dimension, UnitRange] -NamedIndex: TypeAlias = tuple[Dimension, int] -DomainSlice: TypeAlias = Sequence[NamedRange | NamedIndex] -FieldSlice: TypeAlias = ( - DomainSlice - | tuple[slice | int | EllipsisType, ...] - | slice - | int - | EllipsisType - | NamedRange - | NamedIndex -) +RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType +AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange +AnyIndexElement: TypeAlias = RelativeIndexElement | AbsoluteIndexElement +AbsoluteIndexSequence: TypeAlias = Sequence[NamedRange | NamedIndex] +RelativeIndexSequence: TypeAlias = tuple[ + slice | IntIndex | types.EllipsisType, ... +] # is a tuple but called Sequence for symmetry +AnyIndexSequence: TypeAlias = RelativeIndexSequence | AbsoluteIndexSequence +AnyIndexSpec: TypeAlias = AnyIndexElement | AnyIndexSequence + + +def is_int_index(p: Any) -> TypeGuard[IntIndex]: + # should be replaced by isinstance(p, IntIndex), but mypy complains with + # `Argument 2 to "isinstance" has incompatible type ""; expected "_ClassInfo" [arg-type]` + return isinstance(p, (int, core_defs.INTEGRAL_TYPES)) + + +def is_named_range(v: AnyIndexSpec) -> TypeGuard[NamedRange]: + return ( + isinstance(v, tuple) + and len(v) == 2 + and isinstance(v[0], Dimension) + and isinstance(v[1], UnitRange) + ) -@dataclasses.dataclass(frozen=True) +def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedRange]: + return ( + isinstance(v, tuple) and len(v) == 2 and isinstance(v[0], Dimension) and is_int_index(v[1]) + ) + + +def is_any_index_element(v: AnyIndexSpec) -> TypeGuard[AnyIndexElement]: + return ( + is_int_index(v) + or is_named_range(v) + or is_named_index(v) + or isinstance(v, slice) + or v is Ellipsis + ) + + +def is_absolute_index_sequence(v: AnyIndexSequence) -> TypeGuard[AbsoluteIndexSequence]: + return isinstance(v, Sequence) and all(is_named_range(e) or is_named_index(e) for e in v) + + +def is_relative_index_sequence(v: AnyIndexSequence) -> TypeGuard[RelativeIndexSequence]: + return isinstance(v, tuple) and all( + isinstance(e, slice) or is_int_index(e) or e is Ellipsis for e in v + ) + + +def as_any_index_sequence(index: AnyIndexSpec) -> AnyIndexSequence: + # `cast` because mypy/typing doesn't special case 1-element tuples, i.e. `tuple[A|B] != tuple[A]|tuple[B]` + return cast( + AnyIndexSequence, + (index,) if is_any_index_element(index) else index, + ) + + +def named_range(v: tuple[Dimension, RangeLike]) -> NamedRange: + return (v[0], unit_range(v[1])) + + +@dataclasses.dataclass(frozen=True, init=False) class Domain(Sequence[NamedRange]): + """Describes the `Domain` of a `Field` as a `Sequence` of `NamedRange` s.""" + dims: tuple[Dimension, ...] ranges: tuple[UnitRange, ...] - def __post_init__(self): + def __init__( + self, + *args: NamedRange, + dims: Optional[tuple[Dimension, ...]] = None, + ranges: Optional[tuple[UnitRange, ...]] = None, + ) -> None: + if dims is not None or ranges is not None: + if dims is None and ranges is None: + raise ValueError("Either both none of `dims` and `ranges` must be specified.") + if len(args) > 0: + raise ValueError( + "No extra `args` allowed when constructing fomr `dims` and `ranges`." + ) + + assert dims is not None and ranges is not None # for mypy + if not all(isinstance(dim, Dimension) for dim in dims): + raise ValueError( + f"`dims` argument needs to be a `tuple[Dimension, ...], got `{dims}`." + ) + if not all(isinstance(rng, UnitRange) for rng in ranges): + raise ValueError( + f"`ranges` argument needs to be a `tuple[UnitRange, ...], got `{ranges}`." + ) + if len(dims) != len(ranges): + raise ValueError( + f"Number of provided dimensions ({len(dims)}) does not match number of provided ranges ({len(ranges)})." + ) + + object.__setattr__(self, "dims", dims) + object.__setattr__(self, "ranges", ranges) + else: + if not all(is_named_range(arg) for arg in args): + raise ValueError(f"Elements of `Domain` need to be `NamedRange`s, got `{args}`.") + dims, ranges = zip(*args) if args else ((), ()) + object.__setattr__(self, "dims", tuple(dims)) + object.__setattr__(self, "ranges", tuple(ranges)) + if len(set(self.dims)) != len(self.dims): raise NotImplementedError(f"Domain dimensions must be unique, not {self.dims}.") - if len(self.dims) != len(self.ranges): - raise ValueError( - f"Number of provided dimensions ({len(self.dims)}) does not match number of provided ranges ({len(self.ranges)})." - ) - def __len__(self) -> int: return len(self.ranges) @@ -174,7 +279,7 @@ def __getitem__(self, index: int) -> NamedRange: ... @overload - def __getitem__(self, index: slice) -> "Domain": + def __getitem__(self, index: slice) -> Domain: ... @overload @@ -187,7 +292,7 @@ def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: elif isinstance(index, slice): dims_slice = self.dims[index] ranges_slice = self.ranges[index] - return Domain(dims_slice, ranges_slice) + return Domain(dims=dims_slice, ranges=ranges_slice) elif isinstance(index, Dimension): try: index_pos = self.dims.index(index) @@ -197,7 +302,21 @@ def __getitem__(self, index: int | slice | Dimension) -> NamedRange | Domain: else: raise KeyError("Invalid index type, must be either int, slice, or Dimension.") - def __and__(self, other: "Domain") -> "Domain": + def __and__(self, other: Domain) -> Domain: + """ + Intersect `Domain`s, missing `Dimension`s are considered infinite. + + Examples: + --------- + >>> I = Dimension("I") + >>> J = Dimension("J") + + >>> Domain((I, UnitRange(-1, 3))) & Domain((I, UnitRange(1, 6))) + Domain(dims=(Dimension(value='I', kind=),), ranges=(UnitRange(1, 3),)) + + >>> Domain((I, UnitRange(-1, 3)), (J, UnitRange(2, 4))) & Domain((I, UnitRange(1, 6))) + Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(1, 3), UnitRange(2, 4))) + """ broadcast_dims = tuple(promote_dims(self.dims, other.dims)) intersected_ranges = tuple( rng1 & rng2 @@ -206,15 +325,49 @@ def __and__(self, other: "Domain") -> "Domain": _broadcast_ranges(broadcast_dims, other.dims, other.ranges), ) ) - return Domain(broadcast_dims, intersected_ranges) + return Domain(dims=broadcast_dims, ranges=intersected_ranges) + + def __str__(self) -> str: + return f"Domain({', '.join(f'{e[0]}={e[1]}' for e in self)})" + + +DomainLike: TypeAlias = ( + Sequence[tuple[Dimension, RangeLike]] | Mapping[Dimension, RangeLike] +) # `Domain` is `Sequence[NamedRange]` and therefore a subset + + +def domain(domain_like: DomainLike) -> Domain: + """ + Construct `Domain` from `DomainLike` object. + + Examples: + --------- + >>> I = Dimension("I") + >>> J = Dimension("J") + + >>> domain(((I, (2, 4)), (J, (3, 5)))) + Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(2, 4), UnitRange(3, 5))) + + >>> domain({I: (2, 4), J: (3, 5)}) + Domain(dims=(Dimension(value='I', kind=), Dimension(value='J', kind=)), ranges=(UnitRange(2, 4), UnitRange(3, 5))) + """ + if isinstance(domain_like, Domain): + return domain_like + if isinstance(domain_like, Sequence): + return Domain(*tuple(named_range(d) for d in domain_like)) + if isinstance(domain_like, Mapping): + return Domain( + dims=tuple(domain_like.keys()), + ranges=tuple(unit_range(r) for r in domain_like.values()), + ) + raise ValueError(f"`{domain_like}` is not `DomainLike`.") def _broadcast_ranges( broadcast_dims: Sequence[Dimension], dims: Sequence[Dimension], ranges: Sequence[UnitRange] ) -> tuple[UnitRange, ...]: return tuple( - ranges[dims.index(d)] if d in dims else UnitRange(Infinity.negative(), Infinity.positive()) - for d in broadcast_dims + ranges[dims.index(d)] if d in dims else UnitRange.infinity() for d in broadcast_dims ) @@ -230,8 +383,22 @@ def __call__(self, func: fbuiltins.BuiltInFunction[_R, _P], /) -> Callable[_P, _ ... +class NextGTDimsInterface(Protocol): + """ + A `GTDimsInterface` is an object providing the `__gt_dims__` property, naming :class:`Field` dimensions. + + The dimension names are objects of type :class:`Dimension`, in contrast to :mod:`gt4py.cartesian`, + where the labels are `str` s with implied semantics, see :class:`~gt4py._core.definitions.GTDimsInterface` . + """ + + # TODO(havogt): unify with GTDimsInterface, ideally in backward compatible way + @property + def __gt_dims__(self) -> tuple[Dimension, ...]: + ... + + @extended_runtime_checkable -class Field(Protocol[DimsT, core_defs.ScalarT]): +class Field(NextGTDimsInterface, core_defs.GTOriginInterface, Protocol[DimsT, core_defs.ScalarT]): __gt_builtin_func__: ClassVar[GTBuiltInFuncDispatcher] @property @@ -242,24 +409,19 @@ def domain(self) -> Domain: def dtype(self) -> core_defs.DType[core_defs.ScalarT]: ... - @property - def value_type(self) -> type[core_defs.ScalarT]: - ... - @property def ndarray(self) -> core_defs.NDArrayObject: ... def __str__(self) -> str: - codomain = self.value_type.__name__ - return f"⟨{self.domain!s} → {codomain}⟩" + return f"⟨{self.domain!s} → {self.dtype}⟩" @abc.abstractmethod def remap(self, index_field: Field) -> Field: ... @abc.abstractmethod - def restrict(self, item: FieldSlice) -> Field | core_defs.ScalarT: + def restrict(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: ... # Operators @@ -268,7 +430,7 @@ def __call__(self, index_field: Field) -> Field: ... @abc.abstractmethod - def __getitem__(self, item: FieldSlice) -> Field | core_defs.ScalarT: + def __getitem__(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: ... @abc.abstractmethod @@ -279,6 +441,10 @@ def __abs__(self) -> Field: def __neg__(self) -> Field: ... + @abc.abstractmethod + def __invert__(self) -> Field: + """Only defined for `Field` of value type `bool`.""" + @abc.abstractmethod def __add__(self, other: Field | core_defs.ScalarT) -> Field: ... @@ -323,23 +489,44 @@ def __rtruediv__(self, other: Field | core_defs.ScalarT) -> Field: def __pow__(self, other: Field | core_defs.ScalarT) -> Field: ... + @abc.abstractmethod + def __and__(self, other: Field | core_defs.ScalarT) -> Field: + """Only defined for `Field` of value type `bool`.""" + + @abc.abstractmethod + def __or__(self, other: Field | core_defs.ScalarT) -> Field: + """Only defined for `Field` of value type `bool`.""" + + @abc.abstractmethod + def __xor__(self, other: Field | core_defs.ScalarT) -> Field: + """Only defined for `Field` of value type `bool`.""" + def is_field( v: Any, -) -> TypeGuard[Field]: # this function is introduced to localize the `type: ignore`` +) -> TypeGuard[Field]: + # This function is introduced to localize the `type: ignore` because + # extended_runtime_checkable does not make the protocol runtime_checkable + # for mypy. + # TODO(egparedes): remove it when extended_runtime_checkable is fixed return isinstance(v, Field) # type: ignore[misc] # we use extended_runtime_checkable -class FieldABC(Field[DimsT, core_defs.ScalarT]): - """Abstract base class for implementations of the :class:`Field` protocol.""" +@extended_runtime_checkable +class MutableField(Field[DimsT, core_defs.ScalarT], Protocol[DimsT, core_defs.ScalarT]): + @abc.abstractmethod + def __setitem__(self, index: AnyIndexSpec, value: Field | core_defs.ScalarT) -> None: + ... - @final - def __setattr__(self, key, value) -> None: - raise TypeError("Immutable type") - @final - def __setitem__(self, key, value) -> None: - raise TypeError("Immutable type") +def is_mutable_field( + v: Field, +) -> TypeGuard[MutableField]: + # This function is introduced to localize the `type: ignore` because + # extended_runtime_checkable does not make the protocol runtime_checkable + # for mypy. + # TODO(egparedes): remove it when extended_runtime_checkable is fixed + return isinstance(v, MutableField) # type: ignore[misc] # we use extended_runtime_checkable @functools.singledispatch @@ -347,8 +534,8 @@ def field( definition: Any, /, *, - domain: Optional[Any] = None, # TODO(havogt): provide domain_like to Domain conversion - value_type: Optional[type] = None, + domain: Optional[DomainLike] = None, + dtype: Optional[core_defs.DType] = None, ) -> Field: raise NotImplementedError @@ -470,26 +657,27 @@ def promote_dims(*dims_list: Sequence[Dimension]) -> list[Dimension]: return topologically_sorted_list -def is_named_range(v: Any) -> TypeGuard[NamedRange]: - return isinstance(v, tuple) and isinstance(v[0], Dimension) and isinstance(v[1], UnitRange) - - -def is_named_index(v: Any) -> TypeGuard[NamedIndex]: - return isinstance(v, tuple) and isinstance(v[0], Dimension) and isinstance(v[1], int) - - -def is_domain_slice(index: Any) -> TypeGuard[DomainSlice]: - return isinstance(index, Sequence) and all( - is_named_range(idx) or is_named_index(idx) for idx in index - ) +class FieldBuiltinFuncRegistry: + """ + Mixin for adding `fbuiltins` registry to a `Field`. + Subclasses of a `Field` with `FieldBuiltinFuncRegistry` get their own registry, + dispatching (via ChainMap) to its parent's registries. + """ -class FieldBuiltinFuncRegistry: - _builtin_func_map: ChainMap[fbuiltins.BuiltInFunction, Callable] = collections.ChainMap() + _builtin_func_map: collections.ChainMap[ + fbuiltins.BuiltInFunction, Callable + ] = collections.ChainMap() def __init_subclass__(cls, **kwargs): - # might break in multiple inheritance (if multiple ancestors have `_builtin_func_map`) - cls._builtin_func_map = cls._builtin_func_map.new_child() + cls._builtin_func_map = collections.ChainMap( + {}, # New empty `dict`` for new registrations on this class + *[ + c.__dict__["_builtin_func_map"].maps[0] # adding parent `dict`s in mro order + for c in cls.__mro__ + if "_builtin_func_map" in c.__dict__ + ], + ) @classmethod def register_builtin_func( diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py new file mode 100644 index 0000000000..37ba4954f3 --- /dev/null +++ b/src/gt4py/next/embedded/common.py @@ -0,0 +1,127 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Any, Optional, Sequence, cast + +from gt4py.next import common +from gt4py.next.embedded import exceptions as embedded_exceptions + + +def sub_domain(domain: common.Domain, index: common.AnyIndexSpec) -> common.Domain: + index_sequence = common.as_any_index_sequence(index) + + if common.is_absolute_index_sequence(index_sequence): + return _absolute_sub_domain(domain, index_sequence) + + if common.is_relative_index_sequence(index_sequence): + return _relative_sub_domain(domain, index_sequence) + + raise IndexError(f"Unsupported index type: {index}") + + +def _relative_sub_domain( + domain: common.Domain, index: common.RelativeIndexSequence +) -> common.Domain: + named_ranges: list[common.NamedRange] = [] + + expanded = _expand_ellipsis(index, len(domain)) + if len(domain) < len(expanded): + raise IndexError(f"Trying to index a `Field` with {len(domain)} dimensions with {index}.") + expanded += (slice(None),) * (len(domain) - len(expanded)) + for (dim, rng), idx in zip(domain, expanded, strict=True): + if isinstance(idx, slice): + try: + sliced = _slice_range(rng, idx) + named_ranges.append((dim, sliced)) + except IndexError: + raise embedded_exceptions.IndexOutOfBounds( + domain=domain, indices=index, index=idx, dim=dim + ) + else: + # not in new domain + assert common.is_int_index(idx) + new_index = (rng.start if idx >= 0 else rng.stop) + idx + if new_index < rng.start or new_index >= rng.stop: + raise embedded_exceptions.IndexOutOfBounds( + domain=domain, indices=index, index=idx, dim=dim + ) + + return common.Domain(*named_ranges) + + +def _absolute_sub_domain( + domain: common.Domain, index: common.AbsoluteIndexSequence +) -> common.Domain: + named_ranges: list[common.NamedRange] = [] + for i, (dim, rng) in enumerate(domain): + if (pos := _find_index_of_dim(dim, index)) is not None: + named_idx = index[pos] + idx = named_idx[1] + if isinstance(idx, common.UnitRange): + if not idx <= rng: + raise embedded_exceptions.IndexOutOfBounds( + domain=domain, indices=index, index=named_idx, dim=dim + ) + + named_ranges.append((dim, idx)) + else: + # not in new domain + assert common.is_int_index(idx) + if idx < rng.start or idx >= rng.stop: + raise embedded_exceptions.IndexOutOfBounds( + domain=domain, indices=index, index=named_idx, dim=dim + ) + else: + # dimension not mentioned in slice + named_ranges.append((dim, domain.ranges[i])) + + return common.Domain(*named_ranges) + + +def _expand_ellipsis( + indices: common.RelativeIndexSequence, target_size: int +) -> tuple[common.IntIndex | slice, ...]: + if Ellipsis in indices: + idx = indices.index(Ellipsis) + indices = ( + indices[:idx] + (slice(None),) * (target_size - (len(indices) - 1)) + indices[idx + 1 :] + ) + return cast(tuple[common.IntIndex | slice, ...], indices) # mypy leave me alone and trust me! + + +def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.UnitRange: + if slice_obj == slice(None): + return common.UnitRange(input_range.start, input_range.stop) + + start = ( + input_range.start if slice_obj.start is None or slice_obj.start >= 0 else input_range.stop + ) + (slice_obj.start or 0) + stop = ( + input_range.start if slice_obj.stop is None or slice_obj.stop >= 0 else input_range.stop + ) + (slice_obj.stop or len(input_range)) + + if start < input_range.start or stop > input_range.stop: + raise IndexError("Slice out of range (no clipping following array API standard).") + + return common.UnitRange(start, stop) + + +def _find_index_of_dim( + dim: common.Dimension, + domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], +) -> Optional[int]: + for i, (d, _) in enumerate(domain_slice): + if dim == d: + return i + return None diff --git a/src/gt4py/next/embedded/exceptions.py b/src/gt4py/next/embedded/exceptions.py new file mode 100644 index 0000000000..393123db36 --- /dev/null +++ b/src/gt4py/next/embedded/exceptions.py @@ -0,0 +1,38 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from gt4py.next import common +from gt4py.next.errors import exceptions as gt4py_exceptions + + +class IndexOutOfBounds(gt4py_exceptions.GT4PyError): + domain: common.Domain + indices: common.AnyIndexSpec + index: common.AnyIndexElement + dim: common.Dimension + + def __init__( + self, + domain: common.Domain, + indices: common.AnyIndexSpec, + index: common.AnyIndexElement, + dim: common.Dimension, + ): + super().__init__( + f"Out of bounds: slicing {domain} with index `{indices}`, `{index}` is out of bounds in dimension `{dim}`." + ) + self.domain = domain + self.indices = indices + self.index = index + self.dim = dim diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index ddef77bb78..fcaa09e7eb 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -15,17 +15,16 @@ from __future__ import annotations import dataclasses -import itertools from collections.abc import Callable, Sequence -from types import EllipsisType, ModuleType -from typing import Any, ClassVar, Optional, ParamSpec, TypeAlias, TypeVar, cast +from types import ModuleType +from typing import Any, ClassVar, Optional, ParamSpec, TypeAlias, TypeVar import numpy as np from numpy import typing as npt from gt4py._core import definitions as core_defs from gt4py.next import common -from gt4py.next.common import FieldBuiltinFuncRegistry +from gt4py.next.embedded import common as embedded_common from gt4py.next.ffront import fbuiltins @@ -56,7 +55,7 @@ def _make_binary_array_field_intrinsic_func(builtin_name: str, array_builtin_nam def _builtin_binary_op(a: _BaseNdArrayField, b: common.Field) -> common.Field: xp = a.__class__.array_ns op = getattr(xp, array_builtin_name) - if hasattr(b, "__gt_builtin_func__"): # isinstance(b, common.Field): + if hasattr(b, "__gt_builtin_func__"): # common.is_field(b): if not a.domain == b.domain: domain_intersection = a.domain & b.domain a_broadcasted = _broadcast(a, domain_intersection.dims) @@ -82,7 +81,9 @@ def _builtin_binary_op(a: _BaseNdArrayField, b: common.Field) -> common.Field: @dataclasses.dataclass(frozen=True) -class _BaseNdArrayField(common.FieldABC[common.DimsT, core_defs.ScalarT], FieldBuiltinFuncRegistry): +class _BaseNdArrayField( + common.MutableField[common.DimsT, core_defs.ScalarT], common.FieldBuiltinFuncRegistry +): """ Shared field implementation for NumPy-like fields. @@ -94,7 +95,6 @@ class _BaseNdArrayField(common.FieldABC[common.DimsT, core_defs.ScalarT], FieldB _domain: common.Domain _ndarray: core_defs.NDArrayObject - _value_type: type[core_defs.ScalarT] array_ns: ClassVar[ ModuleType @@ -104,13 +104,28 @@ class _BaseNdArrayField(common.FieldABC[common.DimsT, core_defs.ScalarT], FieldB def domain(self) -> common.Domain: return self._domain + @property + def shape(self) -> tuple[int, ...]: + return self._ndarray.shape + + @property + def __gt_dims__(self) -> tuple[common.Dimension, ...]: + return self._domain.dims + + @property + def __gt_origin__(self) -> tuple[int, ...]: + return tuple(-r.start for _, r in self._domain) + @property def ndarray(self) -> core_defs.NDArrayObject: return self._ndarray + def __array__(self, dtype: npt.DTypeLike = None) -> np.ndarray: + return np.asarray(self._ndarray, dtype) + @property - def value_type(self) -> type[core_defs.ScalarT]: - return self._value_type + def dtype(self) -> core_defs.DType[core_defs.ScalarT]: + return core_defs.dtype(self._ndarray.dtype.type) @classmethod def from_array( @@ -119,38 +134,52 @@ def from_array( | core_defs.NDArrayObject, # TODO: NDArrayObject should be part of ArrayLike /, *, - domain: common.Domain, - value_type: Optional[type] = None, + domain: common.DomainLike, + dtype_like: Optional[core_defs.DType] = None, # TODO define DTypeLike ) -> _BaseNdArrayField: + domain = common.domain(domain) xp = cls.array_ns - dtype = None - if value_type is not None: - dtype = xp.dtype(value_type) - array = xp.asarray(data, dtype=dtype) - value_type = array.dtype.type # TODO add support for Dimensions as value_type + xp_dtype = None if dtype_like is None else xp.dtype(core_defs.dtype(dtype_like).scalar_type) + array = xp.asarray(data, dtype=xp_dtype) + + if dtype_like is not None: + assert array.dtype.type == core_defs.dtype(dtype_like).scalar_type assert issubclass(array.dtype.type, core_defs.SCALAR_TYPES) - assert all(isinstance(d, common.Dimension) for d, r in domain), domain + assert all(isinstance(d, common.Dimension) for d in domain.dims), domain assert len(domain) == array.ndim assert all( - len(nr[1]) == s or (s == 1 and nr[1] == common.UnitRange.infinity()) - for nr, s in zip(domain, array.shape) + len(r) == s or (s == 1 and r == common.UnitRange.infinity()) + for r, s in zip(domain.ranges, array.shape) ) - assert value_type is not None # for mypy - return cls(domain, array, value_type) + return cls(domain, array) def remap(self: _BaseNdArrayField, connectivity) -> _BaseNdArrayField: raise NotImplementedError() + def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.ScalarT: + new_domain, buffer_slice = self._slice(index) + + new_buffer = self.ndarray[buffer_slice] + if len(new_domain) == 0: + assert core_defs.is_scalar_type(new_buffer) + return new_buffer # type: ignore[return-value] # I don't think we can express that we return `ScalarT` here + else: + return self.__class__.from_array(new_buffer, domain=new_domain) + + __getitem__ = restrict + __call__ = None # type: ignore[assignment] # TODO: remap __abs__ = _make_unary_array_field_intrinsic_func("abs", "abs") __neg__ = _make_unary_array_field_intrinsic_func("neg", "negative") + __pos__ = _make_unary_array_field_intrinsic_func("pos", "positive") + __add__ = __radd__ = _make_binary_array_field_intrinsic_func("add", "add") __sub__ = __rsub__ = _make_binary_array_field_intrinsic_func("sub", "subtract") @@ -165,78 +194,51 @@ def remap(self: _BaseNdArrayField, connectivity) -> _BaseNdArrayField: __pow__ = _make_binary_array_field_intrinsic_func("pow", "power") - def __getitem__(self, index: common.FieldSlice) -> common.Field | core_defs.ScalarT: - if ( - not isinstance(index, tuple) - and not common.is_domain_slice(index) - or common.is_named_index(index) - or common.is_named_range(index) - ): - index = cast(common.FieldSlice, (index,)) + __mod__ = __rmod__ = _make_binary_array_field_intrinsic_func("mod", "mod") - if common.is_domain_slice(index): - return self._getitem_absolute_slice(index) + def __and__(self, other: common.Field | core_defs.ScalarT) -> _BaseNdArrayField: + if self.dtype == core_defs.BoolDType(): + return _make_binary_array_field_intrinsic_func("logical_and", "logical_and")( + self, other + ) + raise NotImplementedError("`__and__` not implemented for non-`bool` fields.") - assert isinstance(index, tuple) - if all(isinstance(idx, (slice, int)) or idx is Ellipsis for idx in index): - return self._getitem_relative_slice(index) + __rand__ = __and__ - raise IndexError(f"Unsupported index type: {index}") + def __or__(self, other: common.Field | core_defs.ScalarT) -> _BaseNdArrayField: + if self.dtype == core_defs.BoolDType(): + return _make_binary_array_field_intrinsic_func("logical_or", "logical_or")(self, other) + raise NotImplementedError("`__or__` not implemented for non-`bool` fields.") - restrict = ( - __getitem__ # type:ignore[assignment] # TODO(havogt) I don't see the problem that mypy has - ) + __ror__ = __or__ - def _getitem_absolute_slice( - self, index: common.DomainSlice - ) -> common.Field | core_defs.ScalarT: - slices = _get_slices_from_domain_slice(self.domain, index) - new_ranges = [] - new_dims = [] - new = self.ndarray[slices] - - for i, dim in enumerate(self.domain.dims): - if (pos := _find_index_of_dim(dim, index)) is not None: - index_or_range = index[pos][1] - if isinstance(index_or_range, common.UnitRange): - new_ranges.append(index_or_range) - new_dims.append(dim) - else: - # dimension not mentioned in slice - new_ranges.append(self.domain.ranges[i]) - new_dims.append(dim) - - new_domain = common.Domain(dims=tuple(new_dims), ranges=tuple(new_ranges)) + def __xor__(self, other: common.Field | core_defs.ScalarT) -> _BaseNdArrayField: + if self.dtype == core_defs.BoolDType(): + return _make_binary_array_field_intrinsic_func("logical_xor", "logical_xor")( + self, other + ) + raise NotImplementedError("`__xor__` not implemented for non-`bool` fields.") - if len(new_domain) == 0: - assert core_defs.is_scalar_type(new) - return new # type: ignore[return-value] # I don't think we can express that we return `ScalarT` here - else: - return self.__class__.from_array(new, domain=new_domain, value_type=self.value_type) - - def _getitem_relative_slice( - self, indices: tuple[slice | int | EllipsisType, ...] - ) -> common.Field | core_defs.ScalarT: - new = self.ndarray[indices] - new_dims = [] - new_ranges = [] - - for (dim, rng), idx in itertools.zip_longest( # type: ignore[misc] # "slice" object is not iterable, not sure which slice... - self.domain, _expand_ellipsis(indices, len(self.domain)), fillvalue=slice(None) - ): - if isinstance(idx, slice): - new_dims.append(dim) - new_ranges.append(_slice_range(rng, idx)) - else: - assert isinstance(idx, int) # not in new_domain - - new_domain = common.Domain(dims=tuple(new_dims), ranges=tuple(new_ranges)) + __rxor__ = __xor__ - if len(new_domain) == 0: - assert core_defs.is_scalar_type(new), new - return new # type: ignore[return-value] # I don't think we can express that we return `ScalarT` here - else: - return self.__class__.from_array(new, domain=new_domain, value_type=self.value_type) + def __invert__(self) -> _BaseNdArrayField: + if self.dtype == core_defs.BoolDType(): + return _make_unary_array_field_intrinsic_func("invert", "invert")(self) + raise NotImplementedError("`__invert__` not implemented for non-`bool` fields.") + + def _slice( + self, index: common.AnyIndexSpec + ) -> tuple[common.Domain, common.RelativeIndexSequence]: + new_domain = embedded_common.sub_domain(self.domain, index) + + index_sequence = common.as_any_index_sequence(index) + slice_ = ( + _get_slices_from_domain_slice(self.domain, index_sequence) + if common.is_absolute_index_sequence(index_sequence) + else index_sequence + ) + assert common.is_relative_index_sequence(slice_) + return new_domain, slice_ # -- Specialized implementations for intrinsic operations on array fields -- @@ -266,6 +268,25 @@ def _getitem_relative_slice( fbuiltins.fmod, _make_binary_array_field_intrinsic_func("fmod", "fmod") # type: ignore[attr-defined] ) + +def _np_cp_setitem( + self: _BaseNdArrayField[common.DimsT, core_defs.ScalarT], + index: common.AnyIndexSpec, + value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, +) -> None: + target_domain, target_slice = self._slice(index) + + if common.is_field(value): + if not value.domain == target_domain: + raise ValueError( + f"Incompatible `Domain` in assignment. Source domain = {value.domain}, target domain = {target_domain}." + ) + value = value.ndarray + + assert hasattr(self.ndarray, "__setitem__") + self.ndarray[target_slice] = value + + # -- Concrete array implementations -- # NumPy _nd_array_implementations = [np] @@ -275,6 +296,8 @@ def _getitem_relative_slice( class NumPyArrayField(_BaseNdArrayField): array_ns: ClassVar[ModuleType] = np + __setitem__ = _np_cp_setitem + common.field.register(np.ndarray, NumPyArrayField.from_array) @@ -286,6 +309,8 @@ class NumPyArrayField(_BaseNdArrayField): class CuPyArrayField(_BaseNdArrayField): array_ns: ClassVar[ModuleType] = cp + __setitem__ = _np_cp_setitem + common.field.register(cp.ndarray, CuPyArrayField.from_array) # JAX @@ -296,38 +321,30 @@ class CuPyArrayField(_BaseNdArrayField): class JaxArrayField(_BaseNdArrayField): array_ns: ClassVar[ModuleType] = jnp - common.field.register(jnp.ndarray, JaxArrayField.from_array) - + def __setitem__( + self, + index: common.AnyIndexSpec, + value: common.Field | core_defs.NDArrayObject | core_defs.ScalarT, + ) -> None: + # TODO(havogt): use something like `self.ndarray = self.ndarray.at(index).set(value)` + raise NotImplementedError("`__setitem__` for JaxArrayField not yet implemented.") -def _find_index_of_dim( - dim: common.Dimension, - domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], -) -> Optional[int]: - for i, (d, _) in enumerate(domain_slice): - if dim == d: - return i - return None + common.field.register(jnp.ndarray, JaxArrayField.from_array) def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field: domain_slice: list[slice | None] = [] - new_domain_dims = [] - new_domain_ranges = [] + named_ranges = [] for dim in new_dimensions: - if (pos := _find_index_of_dim(dim, field.domain)) is not None: + if (pos := embedded_common._find_index_of_dim(dim, field.domain)) is not None: domain_slice.append(slice(None)) - new_domain_dims.append(dim) - new_domain_ranges.append(field.domain[pos][1]) + named_ranges.append((dim, field.domain[pos][1])) else: domain_slice.append(np.newaxis) - new_domain_dims.append(dim) - new_domain_ranges.append( - common.UnitRange(common.Infinity.negative(), common.Infinity.positive()) + named_ranges.append( + (dim, common.UnitRange(common.Infinity.negative(), common.Infinity.positive())) ) - return common.field( - field.ndarray[tuple(domain_slice)], - domain=common.Domain(tuple(new_domain_dims), tuple(new_domain_ranges)), - ) + return common.field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges)) def _builtins_broadcast( @@ -344,7 +361,7 @@ def _builtins_broadcast( def _get_slices_from_domain_slice( domain: common.Domain, domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any], -) -> tuple[slice | int | None, ...]: +) -> common.RelativeIndexSequence: """Generate slices for sub-array extraction based on named ranges or named indices within a Domain. This function generates a tuple of slices that can be used to extract sub-arrays from a field. The provided @@ -359,10 +376,10 @@ def _get_slices_from_domain_slice( specified in the Domain. If a dimension is not included in the named indices or ranges, a None is used to indicate expansion along that axis. """ - slice_indices: list[slice | int | None] = [] + slice_indices: list[slice | common.IntIndex] = [] for pos_old, (dim, _) in enumerate(domain): - if (pos := _find_index_of_dim(dim, domain_slice)) is not None: + if (pos := embedded_common._find_index_of_dim(dim, domain_slice)) is not None: index_or_range = domain_slice[pos][1] slice_indices.append(_compute_slice(index_or_range, domain, pos_old)) else: @@ -370,7 +387,9 @@ def _get_slices_from_domain_slice( return tuple(slice_indices) -def _compute_slice(rng: common.DomainRange, domain: common.Domain, pos: int) -> slice | int: +def _compute_slice( + rng: common.UnitRange | common.IntIndex, domain: common.Domain, pos: int +) -> slice | common.IntIndex: """Compute a slice or integer based on the provided range, domain, and position. Args: @@ -392,34 +411,7 @@ def _compute_slice(rng: common.DomainRange, domain: common.Domain, pos: int) -> rng.start - domain.ranges[pos].start, rng.stop - domain.ranges[pos].start, ) - elif isinstance(rng, int): + elif common.is_int_index(rng): return rng - domain.ranges[pos].start else: raise ValueError(f"Can only use integer or UnitRange ranges, provided type: {type(rng)}") - - -def _slice_range(input_range: common.UnitRange, slice_obj: slice) -> common.UnitRange: - # handle slice(None) case - if slice_obj == slice(None): - return common.UnitRange(input_range.start, input_range.stop) - - start = ( - input_range.start if slice_obj.start is None or slice_obj.start >= 0 else input_range.stop - ) + (slice_obj.start or 0) - stop = ( - input_range.start if slice_obj.stop is None or slice_obj.stop >= 0 else input_range.stop - ) + (slice_obj.stop or len(input_range)) - - return common.UnitRange(start, stop) - - -def _expand_ellipsis( - indices: tuple[int | slice | EllipsisType, ...], target_size: int -) -> tuple[int | slice, ...]: - expanded_indices: list[int | slice] = [] - for idx in indices: - if idx is Ellipsis: - expanded_indices.extend([slice(None)] * (target_size - (len(indices) - 1))) - else: - expanded_indices.append(idx) - return tuple(expanded_indices) diff --git a/src/gt4py/next/errors/exceptions.py b/src/gt4py/next/errors/exceptions.py index 74230263db..e956858549 100644 --- a/src/gt4py/next/errors/exceptions.py +++ b/src/gt4py/next/errors/exceptions.py @@ -33,17 +33,19 @@ from . import formatting -class DSLError(Exception): +class GT4PyError(Exception): + @property + def message(self) -> str: + return self.args[0] + + +class DSLError(GT4PyError): location: Optional[SourceLocation] def __init__(self, location: Optional[SourceLocation], message: str) -> None: self.location = location super().__init__(message) - @property - def message(self) -> str: - return self.args[0] - def with_location(self, location: Optional[SourceLocation]) -> DSLError: self.location = location return self diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index ba027be13c..52aae34b3f 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -11,6 +11,7 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later + import dataclasses import inspect from builtins import bool, float, int, tuple diff --git a/src/gt4py/next/utils.py b/src/gt4py/next/utils.py index 0c5de764f2..006b3057b0 100644 --- a/src/gt4py/next/utils.py +++ b/src/gt4py/next/utils.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any, ClassVar +from typing import Any, ClassVar, TypeGuard, TypeVar class RecursionGuard: @@ -49,3 +49,10 @@ def __enter__(self): def __exit__(self, *exc): self.guarded_objects.remove(id(self.obj)) + + +_T = TypeVar("_T") + + +def is_tuple_of(v: Any, t: type[_T]) -> TypeGuard[tuple[_T, ...]]: + return isinstance(v, tuple) and all(isinstance(e, t) for e in v) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py index 0bc5a98a4e..c1bee4fa2f 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_foast_pretty_printer.py @@ -82,7 +82,7 @@ def scan(inp: int32) -> int32: expected = textwrap.dedent( f""" - @scan_operator(axis=Dimension(value="KDim", kind=DimensionKind.VERTICAL), forward=False, init=1) + @scan_operator(axis=KDim[vertical], forward=False, init=1) def scan(inp: int32) -> int32: {ssa.unique_name("foo", 0)} = inp return inp diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py new file mode 100644 index 0000000000..640ed326bb --- /dev/null +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -0,0 +1,137 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Sequence + +import pytest + +from gt4py.next import common +from gt4py.next.common import UnitRange +from gt4py.next.embedded import exceptions as embedded_exceptions +from gt4py.next.embedded.common import _slice_range, sub_domain + + +@pytest.mark.parametrize( + "rng, slce, expected", + [ + (UnitRange(2, 10), slice(2, -2), UnitRange(4, 8)), + (UnitRange(2, 10), slice(2, None), UnitRange(4, 10)), + (UnitRange(2, 10), slice(None, -2), UnitRange(2, 8)), + (UnitRange(2, 10), slice(None), UnitRange(2, 10)), + ], +) +def test_slice_range(rng, slce, expected): + result = _slice_range(rng, slce) + assert result == expected + + +I = common.Dimension("I") +J = common.Dimension("J") +K = common.Dimension("K") + + +@pytest.mark.parametrize( + "domain, index, expected", + [ + ([(I, (2, 5))], 1, []), + ([(I, (2, 5))], slice(1, 2), [(I, (3, 4))]), + ([(I, (2, 5))], (I, 2), []), + ([(I, (2, 5))], (I, UnitRange(2, 3)), [(I, (2, 3))]), + ([(I, (-2, 3))], 1, []), + ([(I, (-2, 3))], slice(1, 2), [(I, (-1, 0))]), + ([(I, (-2, 3))], (I, 1), []), + ([(I, (-2, 3))], (I, UnitRange(2, 3)), [(I, (2, 3))]), + ([(I, (-2, 3))], -5, []), + ([(I, (-2, 3))], -6, IndexError), + ([(I, (-2, 3))], slice(-7, -6), IndexError), + ([(I, (-2, 3))], slice(-6, -7), IndexError), + ([(I, (-2, 3))], 4, []), + ([(I, (-2, 3))], 5, IndexError), + ([(I, (-2, 3))], slice(4, 5), [(I, (2, 3))]), + ([(I, (-2, 3))], slice(5, 6), IndexError), + ([(I, (-2, 3))], (I, -3), IndexError), + ([(I, (-2, 3))], (I, UnitRange(-3, -2)), IndexError), + ([(I, (-2, 3))], (I, 3), IndexError), + ([(I, (-2, 3))], (I, UnitRange(3, 4)), IndexError), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + 2, + [(J, (3, 6)), (K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + slice(2, 3), + [(I, (4, 5)), (J, (3, 6)), (K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (I, 2), + [(J, (3, 6)), (K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (I, UnitRange(2, 3)), + [(I, (2, 3)), (J, (3, 6)), (K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (J, 3), + [(I, (2, 5)), (K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (J, UnitRange(4, 5)), + [(I, (2, 5)), (J, (4, 5)), (K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + ((J, 3), (I, 2)), + [(K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + ((J, UnitRange(4, 5)), (I, 2)), + [(J, (4, 5)), (K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (slice(1, 2), slice(2, 3)), + [(I, (3, 4)), (J, (5, 6)), (K, (4, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (Ellipsis, slice(2, 3)), + [(I, (2, 5)), (J, (3, 6)), (K, (6, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (slice(1, 2), Ellipsis, slice(2, 3)), + [(I, (3, 4)), (J, (3, 6)), (K, (6, 7))], + ), + ( + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (slice(1, 2), slice(1, 2), Ellipsis), + [(I, (3, 4)), (J, (4, 5)), (K, (4, 7))], + ), + ], +) +def test_sub_domain(domain, index, expected): + domain = common.domain(domain) + if expected is IndexError: + with pytest.raises(embedded_exceptions.IndexOutOfBounds): + sub_domain(domain, index) + else: + expected = common.domain(expected) + result = sub_domain(domain, index) + assert result == expected diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index a2aa3112bd..95093c8307 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -22,8 +22,8 @@ from gt4py.next import Dimension, common from gt4py.next.common import Domain, UnitRange -from gt4py.next.embedded import nd_array_field -from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice, _slice_range +from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field +from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice from gt4py.next.ffront import fbuiltins from next_tests.integration_tests.feature_tests.math_builtin_test_data import math_builtin_test_data @@ -40,16 +40,42 @@ def nd_array_implementation(request): @pytest.fixture( - params=[operator.add, operator.sub, operator.mul, operator.truediv, operator.floordiv], + params=[ + operator.add, + operator.sub, + operator.mul, + operator.truediv, + operator.floordiv, + operator.mod, + ] ) -def binary_op(request): +def binary_arithmetic_op(request): yield request.param -def _make_field(lst: Iterable, nd_array_implementation): +@pytest.fixture( + params=[operator.xor, operator.and_, operator.or_], +) +def binary_logical_op(request): + yield request.param + + +@pytest.fixture(params=[operator.neg, operator.pos]) +def unary_arithmetic_op(request): + yield request.param + + +@pytest.fixture(params=[operator.invert]) +def unary_logical_op(request): + yield request.param + + +def _make_field(lst: Iterable, nd_array_implementation, *, dtype=None): + if not dtype: + dtype = nd_array_implementation.float32 return common.field( - nd_array_implementation.asarray(lst, dtype=nd_array_implementation.float32), - domain=((common.Dimension("foo"), common.UnitRange(0, len(lst))),), + nd_array_implementation.asarray(lst, dtype=dtype), + domain={common.Dimension("foo"): (0, len(lst))}, ) @@ -72,16 +98,57 @@ def test_math_function_builtins(builtin_name: str, inputs, nd_array_implementati assert np.allclose(result.ndarray, expected) -def test_binary_ops(binary_op, nd_array_implementation): +def test_binary_arithmetic_ops(binary_arithmetic_op, nd_array_implementation): inp_a = [-1.0, 4.2, 42] inp_b = [2.0, 3.0, -3.0] inputs = [inp_a, inp_b] - expected = binary_op(*[np.asarray(inp, dtype=np.float32) for inp in inputs]) + expected = binary_arithmetic_op(*[np.asarray(inp, dtype=np.float32) for inp in inputs]) field_inputs = [_make_field(inp, nd_array_implementation) for inp in inputs] - result = binary_op(*field_inputs) + result = binary_arithmetic_op(*field_inputs) + + assert np.allclose(result.ndarray, expected) + + +def test_binary_logical_ops(binary_logical_op, nd_array_implementation): + inp_a = [True, True, False, False] + inp_b = [True, False, True, False] + inputs = [inp_a, inp_b] + + expected = binary_logical_op(*[np.asarray(inp) for inp in inputs]) + + field_inputs = [_make_field(inp, nd_array_implementation, dtype=bool) for inp in inputs] + + result = binary_logical_op(*field_inputs) + + assert np.allclose(result.ndarray, expected) + + +def test_unary_logical_ops(unary_logical_op, nd_array_implementation): + inp = [ + True, + False, + ] + + expected = unary_logical_op(np.asarray(inp)) + + field_input = _make_field(inp, nd_array_implementation, dtype=bool) + + result = unary_logical_op(field_input) + + assert np.allclose(result.ndarray, expected) + + +def test_unary_arithmetic_ops(unary_arithmetic_op, nd_array_implementation): + inp = [1.0, -2.0, 0.0] + + expected = unary_arithmetic_op(np.asarray(inp, dtype=np.float32)) + + field_input = _make_field(inp, nd_array_implementation) + + result = unary_arithmetic_op(field_input) assert np.allclose(result.ndarray, expected) @@ -93,7 +160,7 @@ def test_binary_ops(binary_op, nd_array_implementation): ((JDim,), (None, slice(5, 10))), ], ) -def test_binary_operations_with_intersection(binary_op, dims, expected_indices): +def test_binary_operations_with_intersection(binary_arithmetic_op, dims, expected_indices): arr1 = np.arange(10) arr1_domain = common.Domain(dims=dims, ranges=(UnitRange(0, 10),)) @@ -103,8 +170,8 @@ def test_binary_operations_with_intersection(binary_op, dims, expected_indices): field1 = common.field(arr1, domain=arr1_domain) field2 = common.field(arr2, domain=arr2_domain) - op_result = binary_op(field1, field2) - expected_result = binary_op(arr1[expected_indices[0], expected_indices[1]], arr2) + op_result = binary_arithmetic_op(field1, field2) + expected_result = binary_arithmetic_op(arr1[expected_indices[0], expected_indices[1]], arr2) assert op_result.ndarray.shape == (5, 5) assert np.allclose(op_result.ndarray, expected_result) @@ -122,10 +189,8 @@ def product_nd_array_implementation(request): def test_mixed_fields(product_nd_array_implementation): first_impl, second_impl = product_nd_array_implementation - if (first_impl.__name__ == "cupy" and second_impl.__name__ == "numpy") or ( - first_impl.__name__ == "numpy" and second_impl.__name__ == "cupy" - ): - pytest.skip("Binary operation between CuPy and NumPy requires explicit conversion.") + if "numpy" in first_impl.__name__ and "cupy" in second_impl.__name__: + pytest.skip("Binary operation between NumPy and CuPy requires explicit conversion.") inp_a = [-1.0, 4.2, 42] inp_b = [2.0, 3.0, -3.0] @@ -271,7 +336,7 @@ def test_get_slices_invalid_type(): (JDim, KDim), (2, 15), ), - ((IDim, 1), (JDim, KDim), (10, 15)), + ((IDim, 5), (JDim, KDim), (10, 15)), ((IDim, UnitRange(5, 7)), (IDim, JDim, KDim), (2, 10, 15)), ], ) @@ -282,20 +347,20 @@ def test_absolute_indexing(domain_slice, expected_dimensions, expected_shape): field = common.field(np.ones((5, 10, 15)), domain=domain) indexed_field = field[domain_slice] - assert isinstance(indexed_field, common.Field) + assert common.is_field(indexed_field) assert indexed_field.ndarray.shape == expected_shape assert indexed_field.domain.dims == expected_dimensions def test_absolute_indexing_value_return(): domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(10, 20), UnitRange(5, 15))) - field = common.field(np.ones((10, 10), dtype=np.int32), domain=domain) + field = common.field(np.reshape(np.arange(100, dtype=np.int32), (10, 10)), domain=domain) - named_index = ((IDim, 2), (JDim, 4)) + named_index = ((IDim, 12), (JDim, 6)) value = field[named_index] assert isinstance(value, np.int32) - assert value == 1 + assert value == 21 @pytest.mark.parametrize( @@ -304,19 +369,23 @@ def test_absolute_indexing_value_return(): ( (slice(None, 5), slice(None, 2)), (5, 2), - Domain((IDim, JDim), (UnitRange(5, 10), UnitRange(2, 4))), + Domain((IDim, UnitRange(5, 10)), (JDim, UnitRange(2, 4))), + ), + ((slice(None, 5),), (5, 10), Domain((IDim, UnitRange(5, 10)), (JDim, UnitRange(2, 12)))), + ( + (Ellipsis, 1), + (10,), + Domain((IDim, UnitRange(5, 15))), ), - ((slice(None, 5),), (5, 10), Domain((IDim, JDim), (UnitRange(5, 10), UnitRange(2, 12)))), - ((Ellipsis, 1), (10,), Domain((IDim,), (UnitRange(5, 15),))), ( (slice(2, 3), slice(5, 7)), (1, 2), - Domain((IDim, JDim), (UnitRange(7, 8), UnitRange(7, 9))), + Domain((IDim, UnitRange(7, 8)), (JDim, UnitRange(7, 9))), ), ( (slice(1, 2), 0), (1,), - Domain((IDim,), (UnitRange(6, 7),)), + Domain((IDim, UnitRange(6, 7))), ), ], ) @@ -325,7 +394,7 @@ def test_relative_indexing_slice_2D(index, expected_shape, expected_domain): field = common.field(np.ones((10, 10)), domain=domain) indexed_field = field[index] - assert isinstance(indexed_field, common.Field) + assert common.is_field(indexed_field) assert indexed_field.ndarray.shape == expected_shape assert indexed_field.domain == expected_domain @@ -333,32 +402,44 @@ def test_relative_indexing_slice_2D(index, expected_shape, expected_domain): @pytest.mark.parametrize( "index, expected_shape, expected_domain", [ - ((1, slice(None), 2), (15,), Domain((JDim,), (UnitRange(10, 25),))), + ((1, slice(None), 2), (15,), Domain(dims=(JDim,), ranges=(UnitRange(10, 25),))), ( (slice(None), slice(None), 2), (10, 15), - Domain((IDim, JDim), (UnitRange(5, 15), UnitRange(10, 25))), + Domain(dims=(IDim, JDim), ranges=(UnitRange(5, 15), UnitRange(10, 25))), ), ( (slice(None),), (10, 15, 10), - Domain((IDim, JDim, KDim), (UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), + ), ), ( (slice(None), slice(None), slice(None)), (10, 15, 10), - Domain((IDim, JDim, KDim), (UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), + ), ), ( (slice(None)), (10, 15, 10), - Domain((IDim, JDim, KDim), (UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), + ), ), - ((0, Ellipsis, 0), (15,), Domain((JDim,), (UnitRange(10, 25),))), + ((0, Ellipsis, 0), (15,), Domain(dims=(JDim,), ranges=(UnitRange(10, 25),))), ( Ellipsis, (10, 15, 10), - Domain((IDim, JDim, KDim), (UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(5, 15), UnitRange(10, 25), UnitRange(10, 20)), + ), ), ], ) @@ -369,7 +450,7 @@ def test_relative_indexing_slice_3D(index, expected_shape, expected_domain): field = common.field(np.ones((10, 15, 10)), domain=domain) indexed_field = field[index] - assert isinstance(indexed_field, common.Field) + assert common.is_field(indexed_field) assert indexed_field.ndarray.shape == expected_shape assert indexed_field.domain == expected_domain @@ -391,7 +472,7 @@ def test_relative_indexing_out_of_bounds(lazy_slice): domain = common.Domain(dims=(IDim, JDim), ranges=(UnitRange(3, 13), UnitRange(-5, 5))) field = common.field(np.ones((10, 10)), domain=domain) - with pytest.raises(IndexError): + with pytest.raises((embedded_exceptions.IndexOutOfBounds, IndexError)): lazy_slice(field) @@ -403,10 +484,40 @@ def test_field_unsupported_index(index): field[index] -def test_slice_range(): - input_range = UnitRange(2, 10) - slice_obj = slice(2, -2) - expected = UnitRange(4, 8) +@pytest.mark.parametrize( + "index, value", + [ + ((1, 1), 42.0), + ((1, slice(None)), np.ones((10,)) * 42.0), + ( + (1, slice(None)), + common.field(np.ones((10,)) * 42.0, domain=common.Domain((JDim, UnitRange(0, 10)))), + ), + ], +) +def test_setitem(index, value): + field = common.field( + np.arange(100).reshape(10, 10), + domain=common.Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange(0, 10))), + ) + + expected = np.copy(field.ndarray) + expected[index] = value + + field[index] = value + + assert np.allclose(field.ndarray, expected) + + +def test_setitem_wrong_domain(): + field = common.field( + np.arange(100).reshape(10, 10), + domain=common.Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange(0, 10))), + ) + + value_incompatible = common.field( + np.ones((10,)) * 42.0, domain=common.Domain((JDim, UnitRange(-5, 5))) + ) - result = _slice_range(input_range, slice_obj) - assert result == expected + with pytest.raises(ValueError, match=r"Incompatible `Domain`.*"): + field[(1, slice(None))] = value_incompatible diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index 8cdc96254c..31e35221ab 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -15,7 +15,17 @@ import pytest -from gt4py.next.common import Dimension, DimensionKind, Domain, Infinity, UnitRange, promote_dims +from gt4py.next.common import ( + Dimension, + DimensionKind, + Domain, + Infinity, + UnitRange, + domain, + named_range, + promote_dims, + unit_range, +) IDim = Dimension("IDim") @@ -25,15 +35,8 @@ @pytest.fixture -def domain(): - range1 = UnitRange(0, 10) - range2 = UnitRange(5, 15) - range3 = UnitRange(20, 30) - - dimensions = (IDim, JDim, KDim) - ranges = (range1, range2, range3) - - return Domain(dimensions, ranges) +def a_domain(): + return Domain((IDim, UnitRange(0, 10)), (JDim, UnitRange(5, 15)), (KDim, UnitRange(20, 30))) def test_empty_range(): @@ -53,6 +56,11 @@ def test_unit_range_length(rng): assert len(rng) == 10 +@pytest.mark.parametrize("rng_like", [(2, 4), range(2, 4), UnitRange(2, 4)]) +def test_unit_range_like(rng_like): + assert unit_range(rng_like) == UnitRange(2, 4) + + def test_unit_range_repr(rng): assert repr(rng) == "UnitRange(-5, 5)" @@ -142,54 +150,87 @@ def test_mixed_infinity_range(): assert len(mixed_inf_range) == Infinity.positive() -def test_domain_length(domain): - assert len(domain) == 3 +@pytest.mark.parametrize( + "named_rng_like", + [ + (IDim, (2, 4)), + (IDim, range(2, 4)), + (IDim, UnitRange(2, 4)), + ], +) +def test_named_range_like(named_rng_like): + assert named_range(named_rng_like) == (IDim, UnitRange(2, 4)) + + +def test_domain_length(a_domain): + assert len(a_domain) == 3 -def test_domain_iteration(domain): - iterated_values = [val for val in domain] - assert iterated_values == list(zip(domain.dims, domain.ranges)) +@pytest.mark.parametrize( + "domain_like", + [ + (Domain(dims=(IDim, JDim), ranges=(UnitRange(2, 4), UnitRange(3, 5)))), + ((IDim, (2, 4)), (JDim, (3, 5))), + ({IDim: (2, 4), JDim: (3, 5)}), + ], +) +def test_domain_like(domain_like): + assert domain(domain_like) == Domain( + dims=(IDim, JDim), ranges=(UnitRange(2, 4), UnitRange(3, 5)) + ) + +def test_domain_iteration(a_domain): + iterated_values = [val for val in a_domain] + assert iterated_values == list(zip(a_domain.dims, a_domain.ranges)) -def test_domain_contains_named_range(domain): - assert (IDim, UnitRange(0, 10)) in domain - assert (IDim, UnitRange(-5, 5)) not in domain + +def test_domain_contains_named_range(a_domain): + assert (IDim, UnitRange(0, 10)) in a_domain + assert (IDim, UnitRange(-5, 5)) not in a_domain @pytest.mark.parametrize( "second_domain, expected", [ ( - Domain((IDim, JDim), (UnitRange(2, 12), UnitRange(7, 17))), - Domain((IDim, JDim, KDim), (UnitRange(2, 10), UnitRange(7, 15), UnitRange(20, 30))), + Domain(dims=(IDim, JDim), ranges=(UnitRange(2, 12), UnitRange(7, 17))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(2, 10), UnitRange(7, 15), UnitRange(20, 30)), + ), ), ( - Domain((IDim, KDim), (UnitRange(2, 12), UnitRange(7, 27))), - Domain((IDim, JDim, KDim), (UnitRange(2, 10), UnitRange(5, 15), UnitRange(20, 27))), + Domain(dims=(IDim, KDim), ranges=(UnitRange(2, 12), UnitRange(7, 27))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(2, 10), UnitRange(5, 15), UnitRange(20, 27)), + ), ), ( - Domain((JDim, KDim), (UnitRange(2, 12), UnitRange(4, 27))), - Domain((IDim, JDim, KDim), (UnitRange(0, 10), UnitRange(5, 12), UnitRange(20, 27))), + Domain(dims=(JDim, KDim), ranges=(UnitRange(2, 12), UnitRange(4, 27))), + Domain( + dims=(IDim, JDim, KDim), + ranges=(UnitRange(0, 10), UnitRange(5, 12), UnitRange(20, 27)), + ), ), ], ) -def test_domain_intersection_different_dimensions(domain, second_domain, expected): - result_domain = domain & second_domain +def test_domain_intersection_different_dimensions(a_domain, second_domain, expected): + result_domain = a_domain & second_domain print(result_domain) assert result_domain == expected -def test_domain_intersection_reversed_dimensions(domain): - dimensions = (JDim, IDim) - ranges = (UnitRange(2, 12), UnitRange(7, 17)) - domain2 = Domain(dimensions, ranges) +def test_domain_intersection_reversed_dimensions(a_domain): + domain2 = Domain(dims=(JDim, IDim), ranges=(UnitRange(2, 12), UnitRange(7, 17))) with pytest.raises( ValueError, match="Dimensions can not be promoted. The following dimensions appear in contradicting order: IDim, JDim.", ): - domain & domain2 + a_domain & domain2 @pytest.mark.parametrize( @@ -202,8 +243,8 @@ def test_domain_intersection_reversed_dimensions(domain): (-2, (JDim, UnitRange(5, 15))), ], ) -def test_domain_integer_indexing(domain, index, expected): - result = domain[index] +def test_domain_integer_indexing(a_domain, index, expected): + result = a_domain[index] assert result == expected @@ -214,8 +255,8 @@ def test_domain_integer_indexing(domain, index, expected): (slice(1, None), ((JDim, UnitRange(5, 15)), (KDim, UnitRange(20, 30)))), ], ) -def test_domain_slice_indexing(domain, slice_obj, expected): - result = domain[slice_obj] +def test_domain_slice_indexing(a_domain, slice_obj, expected): + result = a_domain[slice_obj] assert isinstance(result, Domain) assert len(result) == len(expected) assert all(res == exp for res, exp in zip(result, expected)) @@ -228,28 +269,28 @@ def test_domain_slice_indexing(domain, slice_obj, expected): (KDim, (KDim, UnitRange(20, 30))), ], ) -def test_domain_dimension_indexing(domain, index, expected_result): - result = domain[index] +def test_domain_dimension_indexing(a_domain, index, expected_result): + result = a_domain[index] assert result == expected_result -def test_domain_indexing_dimension_missing(domain): +def test_domain_indexing_dimension_missing(a_domain): with pytest.raises(KeyError, match=r"No Dimension of type .* is present in the Domain."): - domain[ECDim] + a_domain[ECDim] -def test_domain_indexing_invalid_type(domain): +def test_domain_indexing_invalid_type(a_domain): with pytest.raises( KeyError, match="Invalid index type, must be either int, slice, or Dimension." ): - domain["foo"] + a_domain["foo"] def test_domain_repeat_dims(): dims = (IDim, JDim, IDim) ranges = (UnitRange(0, 5), UnitRange(0, 8), UnitRange(0, 3)) with pytest.raises(NotImplementedError, match=r"Domain dimensions must be unique, not .*"): - Domain(dims, ranges) + Domain(dims=dims, ranges=ranges) def test_domain_dims_ranges_length_mismatch(): From 42ae8e4183830f713364652128147efe07451f97 Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 5 Sep 2023 09:05:55 +0000 Subject: [PATCH 10/10] add missing methods --- src/gt4py/next/iterator/embedded.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index ef5d571fff..b9bc3a69c9 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1060,6 +1060,9 @@ def __abs__(self) -> common.Field: def __neg__(self) -> common.Field: raise NotImplementedError() + def __invert__(self) -> common.Field: + raise NotImplementedError() + def __add__(self, other: common.Field | core_defs.ScalarT) -> common.Field: raise NotImplementedError() @@ -1093,6 +1096,15 @@ def __rtruediv__(self, other: common.Field | core_defs.ScalarT) -> common.Field: def __pow__(self, other: common.Field | core_defs.ScalarT) -> common.Field: raise NotImplementedError() + def __and__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __or__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __xor__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + def index_field(axis: common.Dimension) -> common.Field: return IndexField(axis) @@ -1149,6 +1161,9 @@ def __abs__(self) -> common.Field: def __neg__(self) -> common.Field: raise NotImplementedError() + def __invert__(self) -> common.Field: + raise NotImplementedError() + def __add__(self, other: common.Field | core_defs.ScalarT) -> common.Field: raise NotImplementedError() @@ -1182,6 +1197,15 @@ def __rtruediv__(self, other: common.Field | core_defs.ScalarT) -> common.Field: def __pow__(self, other: common.Field | core_defs.ScalarT) -> common.Field: raise NotImplementedError() + def __and__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __or__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + + def __xor__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + raise NotImplementedError() + def constant_field(value: Any, dtype_like: Optional[core_defs.DTypeLike] = None) -> common.Field: if dtype_like is None: