diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 3e1fe52f31..29d606ccc0 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -20,9 +20,8 @@ import enum import functools import numbers -import sys import types -from collections.abc import Mapping, Sequence, Set +from collections.abc import Mapping, Sequence import numpy as np import numpy.typing as npt @@ -33,10 +32,12 @@ Any, Callable, ClassVar, + Generic, Never, Optional, ParamSpec, Protocol, + Self, TypeAlias, TypeGuard, TypeVar, @@ -52,16 +53,6 @@ DimsT = TypeVar("DimsT", bound=Sequence["Dimension"], covariant=True) -class Infinity(int): - @classmethod - def positive(cls) -> Infinity: - return cls(sys.maxsize) - - @classmethod - def negative(cls) -> Infinity: - return cls(-sys.maxsize) - - Tag: TypeAlias = str @@ -84,31 +75,86 @@ def __str__(self): return f"{self.value}[{self.kind}]" +class Infinity(enum.Enum): + """Describes an unbounded `UnitRange`.""" + + NEGATIVE = enum.auto() + POSITIVE = enum.auto() + + def __add__(self, _: int) -> Self: + return self + + __radd__ = __add__ + + def __sub__(self, _: int) -> Self: + return self + + __rsub__ = __sub__ + + def __le__(self, other: int | Infinity) -> bool: + return self is self.NEGATIVE or other is self.POSITIVE + + def __lt__(self, other: int | Infinity) -> bool: + return self is self.NEGATIVE and other is not self + + def __ge__(self, other: int | Infinity) -> bool: + return self is self.POSITIVE or other is self.NEGATIVE + + def __gt__(self, other: int | Infinity) -> bool: + return self is self.POSITIVE and other is not self + + +def _as_int(v: core_defs.IntegralScalar | Infinity) -> int | Infinity: + return v if isinstance(v, Infinity) else int(v) + + +_Left = TypeVar("_Left", int, Infinity) +_Right = TypeVar("_Right", int, Infinity) + + @dataclasses.dataclass(frozen=True, init=False) -class UnitRange(Sequence[int], Set[int]): +class UnitRange(Sequence[int], Generic[_Left, _Right]): """Range from `start` to `stop` with step size one.""" - start: int - stop: int + start: _Left + stop: _Right - def __init__(self, start: core_defs.IntegralScalar, stop: core_defs.IntegralScalar) -> None: + def __init__( + self, start: core_defs.IntegralScalar | Infinity, stop: core_defs.IntegralScalar | Infinity + ) -> None: if start < stop: - object.__setattr__(self, "start", int(start)) - object.__setattr__(self, "stop", int(stop)) + object.__setattr__(self, "start", _as_int(start)) + object.__setattr__(self, "stop", _as_int(stop)) else: # make UnitRange(0,0) the single empty UnitRange object.__setattr__(self, "start", 0) object.__setattr__(self, "stop", 0) - # TODO: the whole infinity idea and implementation is broken and should be replaced @classmethod - def infinity(cls) -> UnitRange: - return cls(Infinity.negative(), Infinity.positive()) + def infinite( + cls, + ) -> UnitRange: + return cls(Infinity.NEGATIVE, Infinity.POSITIVE) def __len__(self) -> int: - if Infinity.positive() in (abs(self.start), abs(self.stop)): - return Infinity.positive() - return max(0, self.stop - self.start) + if UnitRange.is_finite(self): + return max(0, self.stop - self.start) + raise ValueError("Cannot compute length of open 'UnitRange'.") + + @classmethod + def is_finite(cls, obj: UnitRange) -> TypeGuard[FiniteUnitRange]: + # classmethod since TypeGuards requires the guarded obj as separate argument + return obj.start is not Infinity.NEGATIVE and obj.stop is not Infinity.POSITIVE + + @classmethod + def is_right_finite(cls, obj: UnitRange) -> TypeGuard[UnitRange[_Left, int]]: + # classmethod since TypeGuards requires the guarded obj as separate argument + return obj.stop is not Infinity.POSITIVE + + @classmethod + def is_left_finite(cls, obj: UnitRange) -> TypeGuard[UnitRange[int, _Right]]: + # classmethod since TypeGuards requires the guarded obj as separate argument + return obj.start is not Infinity.NEGATIVE def __repr__(self) -> str: return f"UnitRange({self.start}, {self.stop})" @@ -122,6 +168,7 @@ def __getitem__(self, index: slice) -> UnitRange: # noqa: F811 # redefine unuse ... def __getitem__(self, index: int | slice) -> int | UnitRange: # noqa: F811 # redefine unused + assert UnitRange.is_finite(self) if isinstance(index, slice): start, stop, step = index.indices(len(self)) if step != 1: @@ -138,61 +185,60 @@ def __getitem__(self, index: int | slice) -> int | UnitRange: # noqa: F811 # re else: raise IndexError("'UnitRange' index out of range") - def __and__(self, other: Set[int]) -> UnitRange: - if isinstance(other, UnitRange): - start = max(self.start, other.start) - stop = min(self.stop, other.stop) - return UnitRange(start, stop) - else: - raise NotImplementedError( - "Can only find the intersection between 'UnitRange' instances." - ) + def __and__(self, other: UnitRange) -> UnitRange: + return UnitRange(max(self.start, other.start), min(self.stop, other.stop)) + + def __contains__(self, value: Any) -> bool: + return ( + isinstance(value, core_defs.INTEGRAL_TYPES) + and value >= self.start + and value < self.stop + ) + + def __le__(self, other: UnitRange) -> bool: + return self.start >= other.start and self.stop <= other.stop + + def __lt__(self, other: UnitRange) -> bool: + return (self.start > other.start and self.stop <= other.stop) or ( + self.start >= other.start and self.stop < other.stop + ) + + def __ge__(self, other: UnitRange) -> bool: + return self.start <= other.start and self.stop >= other.stop - def __le__(self, other: Set[int]): + def __gt__(self, other: UnitRange) -> bool: + return (self.start < other.start and self.stop >= other.stop) or ( + self.start <= other.start and self.stop > other.stop + ) + + def __eq__(self, other: Any) -> bool: if isinstance(other, UnitRange): - return self.start >= other.start and self.stop <= other.stop - elif len(self) == Infinity.positive(): - return False - else: - return Set.__le__(self, other) - - def __add__(self, other: int | Set[int]) -> UnitRange: - if isinstance(other, int): - if other == Infinity.positive(): - return UnitRange.infinity() - elif other == Infinity.negative(): - return UnitRange(0, 0) - return UnitRange( - *( - s if s in [Infinity.negative(), Infinity.positive()] else s + other - for s in (self.start, self.stop) - ) - ) - else: - raise NotImplementedError("Can only compute union with 'int' instances.") - - def __sub__(self, other: int | Set[int]) -> UnitRange: - if isinstance(other, int): - if other == Infinity.negative(): - return self + Infinity.positive() - elif other == Infinity.positive(): - return self + Infinity.negative() - else: - return self + (-other) + return self.start == other.start and self.stop == other.stop else: - raise NotImplementedError("Can only compute substraction with 'int' instances.") + return False + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) - __ge__ = __lt__ = __gt__ = lambda self, other: NotImplemented + def __add__(self, other: int) -> UnitRange: + return UnitRange(self.start + other, self.stop + other) + + def __sub__(self, other: int) -> UnitRange: + return UnitRange(self.start - other, self.stop - other) def __str__(self) -> str: return f"({self.start}:{self.stop})" +FiniteUnitRange: TypeAlias = UnitRange[int, int] + + RangeLike: TypeAlias = ( UnitRange | range | tuple[core_defs.IntegralScalar, core_defs.IntegralScalar] | core_defs.IntegralScalar + | None ) @@ -207,18 +253,23 @@ def unit_range(r: RangeLike) -> UnitRange: # once the related mypy bug (#16358) gets fixed if ( isinstance(r, tuple) - and isinstance(r[0], core_defs.INTEGRAL_TYPES) - and isinstance(r[1], core_defs.INTEGRAL_TYPES) + and (isinstance(r[0], core_defs.INTEGRAL_TYPES) or r[0] in (None, Infinity.NEGATIVE)) + and (isinstance(r[1], core_defs.INTEGRAL_TYPES) or r[1] in (None, Infinity.POSITIVE)) ): - return UnitRange(r[0], r[1]) + start = r[0] if r[0] is not None else Infinity.NEGATIVE + stop = r[1] if r[1] is not None else Infinity.POSITIVE + return UnitRange(start, stop) if isinstance(r, core_defs.INTEGRAL_TYPES): return UnitRange(0, cast(core_defs.IntegralScalar, r)) + if r is None: + return UnitRange.infinite() raise ValueError(f"'{r!r}' cannot be interpreted as 'UnitRange'.") IntIndex: TypeAlias = int | core_defs.IntegralScalar NamedIndex: TypeAlias = tuple[Dimension, IntIndex] # TODO: convert to NamedTuple NamedRange: TypeAlias = tuple[Dimension, UnitRange] # TODO: convert to NamedTuple +FiniteNamedRange: TypeAlias = tuple[Dimension, FiniteUnitRange] # TODO: convert to NamedTuple RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange AnyIndexElement: TypeAlias = RelativeIndexElement | AbsoluteIndexElement @@ -245,6 +296,10 @@ def is_named_range(v: AnyIndexSpec) -> TypeGuard[NamedRange]: ) +def is_finite_named_range(v: NamedRange) -> TypeGuard[FiniteNamedRange]: + return UnitRange.is_finite(v[1]) + + def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedRange]: return ( isinstance(v, tuple) and len(v) == 2 and isinstance(v[0], Dimension) and is_int_index(v[1]) @@ -283,18 +338,27 @@ def named_range(v: tuple[Dimension, RangeLike]) -> NamedRange: return (v[0], unit_range(v[1])) +_Rng = TypeVar( + "_Rng", + UnitRange[int, int], + UnitRange[Infinity, int], + UnitRange[int, Infinity], + UnitRange[Infinity, Infinity], +) + + @dataclasses.dataclass(frozen=True, init=False) -class Domain(Sequence[NamedRange]): +class Domain(Sequence[tuple[Dimension, _Rng]], Generic[_Rng]): """Describes the `Domain` of a `Field` as a `Sequence` of `NamedRange` s.""" dims: tuple[Dimension, ...] - ranges: tuple[UnitRange, ...] + ranges: tuple[_Rng, ...] def __init__( self, - *args: NamedRange, + *args: tuple[Dimension, _Rng], dims: Optional[Sequence[Dimension]] = None, - ranges: Optional[Sequence[UnitRange]] = None, + ranges: Optional[Sequence[_Rng]] = None, ) -> None: if dims is not None or ranges is not None: if dims is None and ranges is None: @@ -343,16 +407,23 @@ def ndim(self) -> int: def shape(self) -> tuple[int, ...]: return tuple(len(r) for r in self.ranges) + @classmethod + def is_finite(cls, obj: Domain) -> TypeGuard[FiniteDomain]: + # classmethod since TypeGuards requires the guarded obj as separate argument + return all(UnitRange.is_finite(rng) for rng in obj.ranges) + @overload - def __getitem__(self, index: int) -> NamedRange: + def __getitem__(self, index: int) -> tuple[Dimension, _Rng]: ... @overload - def __getitem__(self, index: slice) -> Domain: # noqa: F811 # redefine unused + def __getitem__(self, index: slice) -> Self: # noqa: F811 # redefine unused ... @overload - def __getitem__(self, index: Dimension) -> NamedRange: # noqa: F811 # redefine unused + def __getitem__( # noqa: F811 # redefine unused + self, index: Dimension + ) -> tuple[Dimension, _Rng]: ... def __getitem__( # noqa: F811 # redefine unused @@ -434,6 +505,9 @@ def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain: return Domain(dims=dims, ranges=ranges) +FiniteDomain: TypeAlias = Domain[FiniteUnitRange] + + DomainLike: TypeAlias = ( Sequence[tuple[Dimension, RangeLike]] | Mapping[Dimension, RangeLike] ) # `Domain` is `Sequence[NamedRange]` and therefore a subset @@ -484,7 +558,7 @@ def _broadcast_ranges( broadcast_dims: Sequence[Dimension], dims: Sequence[Dimension], ranges: Sequence[UnitRange] ) -> tuple[UnitRange, ...]: return tuple( - ranges[dims.index(d)] if d in dims else UnitRange.infinity() for d in broadcast_dims + ranges[dims.index(d)] if d in dims else UnitRange.infinite() for d in broadcast_dims ) @@ -847,7 +921,7 @@ def asnumpy(self) -> Never: @functools.cached_property def domain(self) -> Domain: - return Domain(dims=(self.dimension,), ranges=(UnitRange.infinity(),)) + return Domain(dims=(self.dimension,), ranges=(UnitRange.infinite(),)) @property def __gt_dims__(self) -> tuple[Dimension, ...]: diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 87e0800a10..94efe4d61d 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -58,6 +58,7 @@ def _relative_sub_domain( else: # not in new domain assert common.is_int_index(idx) + assert common.UnitRange.is_finite(rng) new_index = (rng.start if idx >= 0 else rng.stop) + idx if new_index < rng.start or new_index >= rng.stop: raise embedded_exceptions.IndexOutOfBounds( diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index fbfe64ac42..8bd2673db9 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -113,6 +113,7 @@ def __gt_dims__(self) -> tuple[common.Dimension, ...]: @property def __gt_origin__(self) -> tuple[int, ...]: + assert common.Domain.is_finite(self._domain) return tuple(-r.start for _, r in self._domain) @property @@ -386,6 +387,7 @@ def inverse_image( assert isinstance(image_range, common.UnitRange) + assert common.UnitRange.is_finite(image_range) restricted_mask = (self._ndarray >= image_range.start) & ( self._ndarray < image_range.stop ) @@ -566,9 +568,7 @@ def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...] named_ranges.append((dim, field.domain[pos][1])) else: domain_slice.append(np.newaxis) - named_ranges.append( - (dim, common.UnitRange(common.Infinity.negative(), common.Infinity.positive())) - ) + named_ranges.append((dim, common.UnitRange.infinite())) return common.field(field.ndarray[tuple(domain_slice)], domain=common.Domain(*named_ranges)) @@ -638,14 +638,19 @@ def _compute_slice( ValueError: If `new_rng` is not an integer or a UnitRange. """ if isinstance(rng, common.UnitRange): - if domain.ranges[pos] == common.UnitRange.infinity(): - return slice(None) - else: - return slice( - rng.start - domain.ranges[pos].start, - rng.stop - domain.ranges[pos].start, - ) + start = ( + rng.start - domain.ranges[pos].start + if common.UnitRange.is_left_finite(domain.ranges[pos]) + else None + ) + stop = ( + rng.stop - domain.ranges[pos].start + if common.UnitRange.is_right_finite(domain.ranges[pos]) + else None + ) + return slice(start, stop) elif common.is_int_index(rng): + assert common.Domain.is_finite(domain) return rng - domain.ranges[pos].start else: raise ValueError(f"Can only use integer or UnitRange ranges, provided type: '{type(rng)}'.") diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 93f17b1eb8..278dde9180 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -192,7 +192,7 @@ def broadcast( np.asarray(field)[ tuple([np.newaxis] * len(dims)) ], # TODO(havogt) use FunctionField once available - domain=common.Domain(dims=dims, ranges=tuple([common.UnitRange.infinity()] * len(dims))), + domain=common.Domain(dims=dims, ranges=tuple([common.UnitRange.infinite()] * len(dims))), ) diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index a4f32929db..ef70a2e645 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -1059,7 +1059,7 @@ def __gt_builtin_func__(func: Callable, /) -> NoReturn: # type: ignore[override @property def domain(self) -> common.Domain: - return common.Domain((self._dimension, common.UnitRange.infinity())) + return common.Domain((self._dimension, common.UnitRange.infinite())) @property def codomain(self) -> type[core_defs.int32]: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 65f9d9d71a..037c4f3e4d 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -24,10 +24,9 @@ import gt4py.next.iterator.ir as itir import gt4py.next.program_processors.otf_compile_executor as otf_exec import gt4py.next.program_processors.processor_interface as ppi -from gt4py.next.common import Dimension, Domain, UnitRange, is_field -from gt4py.next.iterator.embedded import NeighborTableOffsetProvider, StridedNeighborOffsetProvider -from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms -from gt4py.next.otf.compilation import cache +from gt4py.next import common +from gt4py.next.iterator import embedded as itir_embedded, transforms as itir_transforms +from gt4py.next.otf.compilation import cache as compilation_cache from gt4py.next.type_system import type_specifications as ts, type_translation from .itir_to_sdfg import ItirToSDFG @@ -40,7 +39,8 @@ cp = None -def get_sorted_dim_ranges(domain: Domain) -> Sequence[UnitRange]: +def get_sorted_dim_ranges(domain: common.Domain) -> Sequence[common.FiniteUnitRange]: + assert common.Domain.is_finite(domain) sorted_dims = get_sorted_dims(domain.dims) return [domain.ranges[dim_index] for dim_index, _ in sorted_dims] @@ -54,7 +54,7 @@ def get_sorted_dim_ranges(domain: Domain) -> Sequence[UnitRange]: def convert_arg(arg: Any): - if is_field(arg): + if common.is_field(arg): sorted_dims = get_sorted_dims(arg.domain.dims) ndim = len(sorted_dims) dim_indices = [dim_index for dim_index, _ in sorted_dims] @@ -67,9 +67,11 @@ def convert_arg(arg: Any): def preprocess_program( - program: itir.FencilDefinition, offset_provider: Mapping[str, Any], lift_mode: LiftMode + program: itir.FencilDefinition, + offset_provider: Mapping[str, Any], + lift_mode: itir_transforms.LiftMode, ): - node = apply_common_transforms( + node = itir_transforms.apply_common_transforms( program, common_subexpression_elimination=False, lift_mode=lift_mode, @@ -81,7 +83,7 @@ def preprocess_program( if all([ItirToSDFG._check_no_lifts(closure) for closure in node.closures]): fencil_definition = node else: - fencil_definition = apply_common_transforms( + fencil_definition = itir_transforms.apply_common_transforms( program, common_subexpression_elimination=False, force_inline_lambda_args=True, @@ -109,7 +111,7 @@ def _ensure_is_on_device( def get_connectivity_args( - neighbor_tables: Sequence[tuple[str, NeighborTableOffsetProvider]], + neighbor_tables: Sequence[tuple[str, itir_embedded.NeighborTableOffsetProvider]], device: dace.dtypes.DeviceType, ) -> dict[str, Any]: return { @@ -134,7 +136,7 @@ def get_offset_args( return { str(sym): -drange.start for param, arg in zip(params, args) - if is_field(arg) + if common.is_field(arg) for sym, drange in zip(arrays[param.id].offset, get_sorted_dim_ranges(arg.domain)) } @@ -162,13 +164,19 @@ def get_stride_args( def get_cache_id( program: itir.FencilDefinition, arg_types: Sequence[ts.TypeSpec], - column_axis: Optional[Dimension], + column_axis: Optional[common.Dimension], offset_provider: Mapping[str, Any], ) -> str: max_neighbors = [ (k, v.max_neighbors) for k, v in offset_provider.items() - if isinstance(v, (NeighborTableOffsetProvider, StridedNeighborOffsetProvider)) + if isinstance( + v, + ( + itir_embedded.NeighborTableOffsetProvider, + itir_embedded.StridedNeighborOffsetProvider, + ), + ) ] cache_id_args = [ str(arg) @@ -191,8 +199,8 @@ def build_sdfg_from_itir( offset_provider: dict[str, Any], auto_optimize: bool = False, on_gpu: bool = False, - column_axis: Optional[Dimension] = None, - lift_mode: LiftMode = LiftMode.FORCE_INLINE, + column_axis: Optional[common.Dimension] = None, + lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE, ) -> dace.SDFG: """Translate a Fencil into an SDFG. @@ -210,7 +218,7 @@ def build_sdfg_from_itir( """ # TODO(edopao): As temporary fix until temporaries are supported in the DaCe Backend force # `lift_more` to `FORCE_INLINE` mode. - lift_mode = LiftMode.FORCE_INLINE + lift_mode = itir_transforms.LiftMode.FORCE_INLINE arg_types = [type_translation.from_value(arg) for arg in args] device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU @@ -237,7 +245,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): build_type = kwargs.get("build_type", "RelWithDebInfo") on_gpu = kwargs.get("on_gpu", False) auto_optimize = kwargs.get("auto_optimize", False) - lift_mode = kwargs.get("lift_mode", LiftMode.FORCE_INLINE) + lift_mode = kwargs.get("lift_mode", itir_transforms.LiftMode.FORCE_INLINE) # ITIR parameters column_axis = kwargs.get("column_axis", None) offset_provider = kwargs["offset_provider"] @@ -263,7 +271,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): lift_mode=lift_mode, ) - sdfg.build_folder = cache._session_cache_dir_path / ".dacecache" + sdfg.build_folder = compilation_cache._session_cache_dir_path / ".dacecache" with dace.config.temporary_config(): dace.config.Config.set("compiler", "build_type", value=build_type) dace.config.Config.set("compiler", "cpu", "args", value=_cpu_args) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py index 1a38e5245e..6863b09c12 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py @@ -11,7 +11,6 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later -import dataclasses import itertools import math import operator @@ -20,7 +19,7 @@ import numpy as np import pytest -from gt4py.next import common, embedded +from gt4py.next import common from gt4py.next.common import Dimension, Domain, UnitRange from gt4py.next.embedded import exceptions as embedded_exceptions, nd_array_field from gt4py.next.embedded.nd_array_field import _get_slices_from_domain_slice @@ -353,7 +352,7 @@ def test_cartesian_remap_implementation(): common.field( np.arange(10), domain=common.Domain(dims=(IDim,), ranges=(UnitRange(0, 10),)) ), - Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange.infinity())), + Domain(dims=(IDim, JDim), ranges=(UnitRange(0, 10), UnitRange.infinite())), ) ), ( @@ -362,7 +361,7 @@ def test_cartesian_remap_implementation(): common.field( np.arange(10), domain=common.Domain(dims=(JDim,), ranges=(UnitRange(0, 10),)) ), - Domain(dims=(IDim, JDim), ranges=(UnitRange.infinity(), UnitRange(0, 10))), + Domain(dims=(IDim, JDim), ranges=(UnitRange.infinite(), UnitRange(0, 10))), ) ), ( @@ -373,7 +372,7 @@ def test_cartesian_remap_implementation(): ), Domain( dims=(IDim, JDim, KDim), - ranges=(UnitRange.infinity(), UnitRange(0, 10), UnitRange.infinity()), + ranges=(UnitRange.infinite(), UnitRange(0, 10), UnitRange.infinite()), ), ) ), diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index bafabfb56e..7650e90c3c 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -14,6 +14,7 @@ import operator from typing import Optional, Pattern +import numpy as np import pytest from gt4py.next.common import ( @@ -41,6 +42,56 @@ def a_domain(): return Domain((IDim, UnitRange(0, 10)), (JDim, UnitRange(5, 15)), (KDim, UnitRange(20, 30))) +@pytest.fixture(params=[Infinity.POSITIVE, Infinity.NEGATIVE]) +def unbounded(request): + yield request.param + + +def test_unbounded_add_sub(unbounded): + assert unbounded + 1 == unbounded + assert unbounded - 1 == unbounded + + +@pytest.mark.parametrize("value", [-1, 0, 1]) +@pytest.mark.parametrize("op", [operator.le, operator.lt]) +def test_unbounded_comparison_less(value, op): + assert not op(Infinity.POSITIVE, value) + assert op(value, Infinity.POSITIVE) + + assert op(Infinity.NEGATIVE, value) + assert not op(value, Infinity.NEGATIVE) + + assert op(Infinity.NEGATIVE, Infinity.POSITIVE) + + +@pytest.mark.parametrize("value", [-1, 0, 1]) +@pytest.mark.parametrize("op", [operator.ge, operator.gt]) +def test_unbounded_comparison_greater(value, op): + assert op(Infinity.POSITIVE, value) + assert not op(value, Infinity.POSITIVE) + + assert not op(Infinity.NEGATIVE, value) + assert op(value, Infinity.NEGATIVE) + + assert not op(Infinity.NEGATIVE, Infinity.POSITIVE) + + +def test_unbounded_eq(unbounded): + assert unbounded == unbounded + assert unbounded <= unbounded + assert unbounded >= unbounded + assert not unbounded < unbounded + assert not unbounded > unbounded + + +@pytest.mark.parametrize("value", [-1, 0, 1]) +def test_unbounded_max_min(value): + assert max(Infinity.POSITIVE, value) == Infinity.POSITIVE + assert min(Infinity.POSITIVE, value) == value + assert max(Infinity.NEGATIVE, value) == value + assert min(Infinity.NEGATIVE, value) == Infinity.NEGATIVE + + def test_empty_range(): expected = UnitRange(0, 0) assert UnitRange(1, 1) == expected @@ -58,9 +109,20 @@ def test_unit_range_length(rng): assert len(rng) == 10 -@pytest.mark.parametrize("rng_like", [(2, 4), range(2, 4), UnitRange(2, 4)]) -def test_unit_range_like(rng_like): - assert unit_range(rng_like) == UnitRange(2, 4) +@pytest.mark.parametrize( + "rng_like, expected", + [ + ((2, 4), UnitRange(2, 4)), + (range(2, 4), UnitRange(2, 4)), + (UnitRange(2, 4), UnitRange(2, 4)), + ((None, None), UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE)), + ((2, None), UnitRange(2, Infinity.POSITIVE)), + ((None, 4), UnitRange(Infinity.NEGATIVE, 4)), + (None, UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE)), + ], +) +def test_unit_range_like(rng_like, expected): + assert unit_range(rng_like) == expected def test_unit_range_repr(rng): @@ -94,13 +156,6 @@ def test_unit_range_slice_error(rng): rng[1:2:5] -def test_unit_range_set_intersection(rng): - with pytest.raises( - NotImplementedError, match="Can only find the intersection between 'UnitRange' instances." - ): - rng & {1, 5} - - @pytest.mark.parametrize( "rng1, rng2, expected", [ @@ -121,46 +176,65 @@ def test_unit_range_intersection(rng1, rng2, expected): @pytest.mark.parametrize( "rng1, rng2, expected", [ - (UnitRange(20, Infinity.positive()), UnitRange(10, 15), UnitRange(0, 0)), - (UnitRange(Infinity.negative(), 0), UnitRange(5, 10), UnitRange(0, 0)), - (UnitRange(Infinity.negative(), 0), UnitRange(-10, 0), UnitRange(-10, 0)), - (UnitRange(0, Infinity.positive()), UnitRange(Infinity.negative(), 5), UnitRange(0, 5)), + (UnitRange(20, Infinity.POSITIVE), UnitRange(10, 15), UnitRange(0, 0)), + (UnitRange(Infinity.NEGATIVE, 0), UnitRange(5, 10), UnitRange(0, 0)), + (UnitRange(Infinity.NEGATIVE, 0), UnitRange(-10, 0), UnitRange(-10, 0)), + (UnitRange(0, Infinity.POSITIVE), UnitRange(Infinity.NEGATIVE, 5), UnitRange(0, 5)), ( - UnitRange(Infinity.negative(), 0), - UnitRange(Infinity.negative(), 5), - UnitRange(Infinity.negative(), 0), + UnitRange(Infinity.NEGATIVE, 0), + UnitRange(Infinity.NEGATIVE, 5), + UnitRange(Infinity.NEGATIVE, 0), ), ( - UnitRange(Infinity.negative(), Infinity.positive()), - UnitRange(Infinity.negative(), Infinity.positive()), - UnitRange(Infinity.negative(), Infinity.positive()), + UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE), + UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE), + UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE), ), ], ) -def test_unit_range_infinite_intersection(rng1, rng2, expected): +def test_unit_range_unbounded_intersection(rng1, rng2, expected): result = rng1 & rng2 assert result == expected -def test_positive_infinity_range(): - pos_inf_range = UnitRange(Infinity.positive(), Infinity.positive()) - assert len(pos_inf_range) == 0 +@pytest.mark.parametrize( + "rng", + [ + UnitRange(Infinity.NEGATIVE, 0), + UnitRange(0, Infinity.POSITIVE), + UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE), + ], +) +def test_positive_infinite_range_len(rng): + with pytest.raises(ValueError, match=r".*open.*"): + len(rng) -def test_mixed_infinity_range(): - mixed_inf_range = UnitRange(Infinity.negative(), Infinity.positive()) - assert len(mixed_inf_range) == Infinity.positive() +def test_range_contains(): + assert 1 in UnitRange(0, 2) + assert 1 not in UnitRange(0, 1) + assert 1 in UnitRange(0, Infinity.POSITIVE) + assert 1 in UnitRange(Infinity.NEGATIVE, 2) + assert 1 in UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE) + assert "s" not in UnitRange(Infinity.NEGATIVE, Infinity.POSITIVE) @pytest.mark.parametrize( "op, rng1, rng2, expected", [ (operator.le, UnitRange(-1, 2), UnitRange(-2, 3), True), - (operator.le, UnitRange(-1, 2), {-1, 0, 1}, True), - (operator.le, UnitRange(-1, 2), {-1, 0}, False), - (operator.le, UnitRange(-1, 2), {-2, -1, 0, 1, 2}, True), - (operator.le, UnitRange(Infinity.negative(), 2), UnitRange(Infinity.negative(), 3), True), - (operator.le, UnitRange(Infinity.negative(), 2), {1, 2, 3}, False), + (operator.le, UnitRange(Infinity.NEGATIVE, 2), UnitRange(Infinity.NEGATIVE, 3), True), + (operator.ge, UnitRange(-2, 3), UnitRange(-1, 2), True), + (operator.ge, UnitRange(Infinity.NEGATIVE, 3), UnitRange(Infinity.NEGATIVE, 2), True), + (operator.lt, UnitRange(-1, 2), UnitRange(-2, 2), True), + (operator.lt, UnitRange(-2, 1), UnitRange(-2, 2), True), + (operator.lt, UnitRange(Infinity.NEGATIVE, 2), UnitRange(Infinity.NEGATIVE, 3), True), + (operator.gt, UnitRange(-2, 2), UnitRange(-1, 2), True), + (operator.gt, UnitRange(-2, 2), UnitRange(-2, 1), True), + (operator.gt, UnitRange(Infinity.NEGATIVE, 3), UnitRange(Infinity.NEGATIVE, 2), True), + (operator.eq, UnitRange(Infinity.NEGATIVE, 2), UnitRange(Infinity.NEGATIVE, 2), True), + (operator.ne, UnitRange(Infinity.NEGATIVE, 2), UnitRange(Infinity.NEGATIVE, 3), True), + (operator.ne, UnitRange(Infinity.NEGATIVE, 2), UnitRange(Infinity.NEGATIVE, 2), False), ], ) def test_range_comparison(op, rng1, rng2, expected):