Skip to content

Commit

Permalink
feat[next]: Add missing UnitRange comparison functions (#1363)
Browse files Browse the repository at this point in the history
- Introduce a better Infinity
- Make UnitRange Generic to express finite, infinite, left-finite, right-finite properly.
- Remove `Set` from UnitRange
  • Loading branch information
havogt authored Dec 18, 2023
1 parent 0d66829 commit cdcd653
Show file tree
Hide file tree
Showing 8 changed files with 305 additions and 144 deletions.
228 changes: 151 additions & 77 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
import enum
import functools
import numbers
import sys
import types
from collections.abc import Mapping, Sequence, Set
from collections.abc import Mapping, Sequence

import numpy as np
import numpy.typing as npt
Expand All @@ -33,10 +32,12 @@
Any,
Callable,
ClassVar,
Generic,
Never,
Optional,
ParamSpec,
Protocol,
Self,
TypeAlias,
TypeGuard,
TypeVar,
Expand All @@ -52,16 +53,6 @@
DimsT = TypeVar("DimsT", bound=Sequence["Dimension"], covariant=True)


class Infinity(int):
@classmethod
def positive(cls) -> Infinity:
return cls(sys.maxsize)

@classmethod
def negative(cls) -> Infinity:
return cls(-sys.maxsize)


Tag: TypeAlias = str


Expand All @@ -84,31 +75,86 @@ def __str__(self):
return f"{self.value}[{self.kind}]"


class Infinity(enum.Enum):
"""Describes an unbounded `UnitRange`."""

NEGATIVE = enum.auto()
POSITIVE = enum.auto()

def __add__(self, _: int) -> Self:
return self

__radd__ = __add__

def __sub__(self, _: int) -> Self:
return self

__rsub__ = __sub__

def __le__(self, other: int | Infinity) -> bool:
return self is self.NEGATIVE or other is self.POSITIVE

def __lt__(self, other: int | Infinity) -> bool:
return self is self.NEGATIVE and other is not self

def __ge__(self, other: int | Infinity) -> bool:
return self is self.POSITIVE or other is self.NEGATIVE

def __gt__(self, other: int | Infinity) -> bool:
return self is self.POSITIVE and other is not self


def _as_int(v: core_defs.IntegralScalar | Infinity) -> int | Infinity:
return v if isinstance(v, Infinity) else int(v)


_Left = TypeVar("_Left", int, Infinity)
_Right = TypeVar("_Right", int, Infinity)


@dataclasses.dataclass(frozen=True, init=False)
class UnitRange(Sequence[int], Set[int]):
class UnitRange(Sequence[int], Generic[_Left, _Right]):
"""Range from `start` to `stop` with step size one."""

start: int
stop: int
start: _Left
stop: _Right

def __init__(self, start: core_defs.IntegralScalar, stop: core_defs.IntegralScalar) -> None:
def __init__(
self, start: core_defs.IntegralScalar | Infinity, stop: core_defs.IntegralScalar | Infinity
) -> None:
if start < stop:
object.__setattr__(self, "start", int(start))
object.__setattr__(self, "stop", int(stop))
object.__setattr__(self, "start", _as_int(start))
object.__setattr__(self, "stop", _as_int(stop))
else:
# make UnitRange(0,0) the single empty UnitRange
object.__setattr__(self, "start", 0)
object.__setattr__(self, "stop", 0)

# TODO: the whole infinity idea and implementation is broken and should be replaced
@classmethod
def infinity(cls) -> UnitRange:
return cls(Infinity.negative(), Infinity.positive())
def infinite(
cls,
) -> UnitRange:
return cls(Infinity.NEGATIVE, Infinity.POSITIVE)

def __len__(self) -> int:
if Infinity.positive() in (abs(self.start), abs(self.stop)):
return Infinity.positive()
return max(0, self.stop - self.start)
if UnitRange.is_finite(self):
return max(0, self.stop - self.start)
raise ValueError("Cannot compute length of open 'UnitRange'.")

@classmethod
def is_finite(cls, obj: UnitRange) -> TypeGuard[FiniteUnitRange]:
# classmethod since TypeGuards requires the guarded obj as separate argument
return obj.start is not Infinity.NEGATIVE and obj.stop is not Infinity.POSITIVE

@classmethod
def is_right_finite(cls, obj: UnitRange) -> TypeGuard[UnitRange[_Left, int]]:
# classmethod since TypeGuards requires the guarded obj as separate argument
return obj.stop is not Infinity.POSITIVE

@classmethod
def is_left_finite(cls, obj: UnitRange) -> TypeGuard[UnitRange[int, _Right]]:
# classmethod since TypeGuards requires the guarded obj as separate argument
return obj.start is not Infinity.NEGATIVE

def __repr__(self) -> str:
return f"UnitRange({self.start}, {self.stop})"
Expand All @@ -122,6 +168,7 @@ def __getitem__(self, index: slice) -> UnitRange: # noqa: F811 # redefine unuse
...

def __getitem__(self, index: int | slice) -> int | UnitRange: # noqa: F811 # redefine unused
assert UnitRange.is_finite(self)
if isinstance(index, slice):
start, stop, step = index.indices(len(self))
if step != 1:
Expand All @@ -138,61 +185,60 @@ def __getitem__(self, index: int | slice) -> int | UnitRange: # noqa: F811 # re
else:
raise IndexError("'UnitRange' index out of range")

def __and__(self, other: Set[int]) -> UnitRange:
if isinstance(other, UnitRange):
start = max(self.start, other.start)
stop = min(self.stop, other.stop)
return UnitRange(start, stop)
else:
raise NotImplementedError(
"Can only find the intersection between 'UnitRange' instances."
)
def __and__(self, other: UnitRange) -> UnitRange:
return UnitRange(max(self.start, other.start), min(self.stop, other.stop))

def __contains__(self, value: Any) -> bool:
return (
isinstance(value, core_defs.INTEGRAL_TYPES)
and value >= self.start
and value < self.stop
)

def __le__(self, other: UnitRange) -> bool:
return self.start >= other.start and self.stop <= other.stop

def __lt__(self, other: UnitRange) -> bool:
return (self.start > other.start and self.stop <= other.stop) or (
self.start >= other.start and self.stop < other.stop
)

def __ge__(self, other: UnitRange) -> bool:
return self.start <= other.start and self.stop >= other.stop

def __le__(self, other: Set[int]):
def __gt__(self, other: UnitRange) -> bool:
return (self.start < other.start and self.stop >= other.stop) or (
self.start <= other.start and self.stop > other.stop
)

def __eq__(self, other: Any) -> bool:
if isinstance(other, UnitRange):
return self.start >= other.start and self.stop <= other.stop
elif len(self) == Infinity.positive():
return False
else:
return Set.__le__(self, other)

def __add__(self, other: int | Set[int]) -> UnitRange:
if isinstance(other, int):
if other == Infinity.positive():
return UnitRange.infinity()
elif other == Infinity.negative():
return UnitRange(0, 0)
return UnitRange(
*(
s if s in [Infinity.negative(), Infinity.positive()] else s + other
for s in (self.start, self.stop)
)
)
else:
raise NotImplementedError("Can only compute union with 'int' instances.")

def __sub__(self, other: int | Set[int]) -> UnitRange:
if isinstance(other, int):
if other == Infinity.negative():
return self + Infinity.positive()
elif other == Infinity.positive():
return self + Infinity.negative()
else:
return self + (-other)
return self.start == other.start and self.stop == other.stop
else:
raise NotImplementedError("Can only compute substraction with 'int' instances.")
return False

def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)

__ge__ = __lt__ = __gt__ = lambda self, other: NotImplemented
def __add__(self, other: int) -> UnitRange:
return UnitRange(self.start + other, self.stop + other)

def __sub__(self, other: int) -> UnitRange:
return UnitRange(self.start - other, self.stop - other)

def __str__(self) -> str:
return f"({self.start}:{self.stop})"


FiniteUnitRange: TypeAlias = UnitRange[int, int]


RangeLike: TypeAlias = (
UnitRange
| range
| tuple[core_defs.IntegralScalar, core_defs.IntegralScalar]
| core_defs.IntegralScalar
| None
)


Expand All @@ -207,18 +253,23 @@ def unit_range(r: RangeLike) -> UnitRange:
# once the related mypy bug (#16358) gets fixed
if (
isinstance(r, tuple)
and isinstance(r[0], core_defs.INTEGRAL_TYPES)
and isinstance(r[1], core_defs.INTEGRAL_TYPES)
and (isinstance(r[0], core_defs.INTEGRAL_TYPES) or r[0] in (None, Infinity.NEGATIVE))
and (isinstance(r[1], core_defs.INTEGRAL_TYPES) or r[1] in (None, Infinity.POSITIVE))
):
return UnitRange(r[0], r[1])
start = r[0] if r[0] is not None else Infinity.NEGATIVE
stop = r[1] if r[1] is not None else Infinity.POSITIVE
return UnitRange(start, stop)
if isinstance(r, core_defs.INTEGRAL_TYPES):
return UnitRange(0, cast(core_defs.IntegralScalar, r))
if r is None:
return UnitRange.infinite()
raise ValueError(f"'{r!r}' cannot be interpreted as 'UnitRange'.")


IntIndex: TypeAlias = int | core_defs.IntegralScalar
NamedIndex: TypeAlias = tuple[Dimension, IntIndex] # TODO: convert to NamedTuple
NamedRange: TypeAlias = tuple[Dimension, UnitRange] # TODO: convert to NamedTuple
FiniteNamedRange: TypeAlias = tuple[Dimension, FiniteUnitRange] # TODO: convert to NamedTuple
RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType
AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange
AnyIndexElement: TypeAlias = RelativeIndexElement | AbsoluteIndexElement
Expand All @@ -245,6 +296,10 @@ def is_named_range(v: AnyIndexSpec) -> TypeGuard[NamedRange]:
)


def is_finite_named_range(v: NamedRange) -> TypeGuard[FiniteNamedRange]:
return UnitRange.is_finite(v[1])


def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedRange]:
return (
isinstance(v, tuple) and len(v) == 2 and isinstance(v[0], Dimension) and is_int_index(v[1])
Expand Down Expand Up @@ -283,18 +338,27 @@ def named_range(v: tuple[Dimension, RangeLike]) -> NamedRange:
return (v[0], unit_range(v[1]))


_Rng = TypeVar(
"_Rng",
UnitRange[int, int],
UnitRange[Infinity, int],
UnitRange[int, Infinity],
UnitRange[Infinity, Infinity],
)


@dataclasses.dataclass(frozen=True, init=False)
class Domain(Sequence[NamedRange]):
class Domain(Sequence[tuple[Dimension, _Rng]], Generic[_Rng]):
"""Describes the `Domain` of a `Field` as a `Sequence` of `NamedRange` s."""

dims: tuple[Dimension, ...]
ranges: tuple[UnitRange, ...]
ranges: tuple[_Rng, ...]

def __init__(
self,
*args: NamedRange,
*args: tuple[Dimension, _Rng],
dims: Optional[Sequence[Dimension]] = None,
ranges: Optional[Sequence[UnitRange]] = None,
ranges: Optional[Sequence[_Rng]] = None,
) -> None:
if dims is not None or ranges is not None:
if dims is None and ranges is None:
Expand Down Expand Up @@ -343,16 +407,23 @@ def ndim(self) -> int:
def shape(self) -> tuple[int, ...]:
return tuple(len(r) for r in self.ranges)

@classmethod
def is_finite(cls, obj: Domain) -> TypeGuard[FiniteDomain]:
# classmethod since TypeGuards requires the guarded obj as separate argument
return all(UnitRange.is_finite(rng) for rng in obj.ranges)

@overload
def __getitem__(self, index: int) -> NamedRange:
def __getitem__(self, index: int) -> tuple[Dimension, _Rng]:
...

@overload
def __getitem__(self, index: slice) -> Domain: # noqa: F811 # redefine unused
def __getitem__(self, index: slice) -> Self: # noqa: F811 # redefine unused
...

@overload
def __getitem__(self, index: Dimension) -> NamedRange: # noqa: F811 # redefine unused
def __getitem__( # noqa: F811 # redefine unused
self, index: Dimension
) -> tuple[Dimension, _Rng]:
...

def __getitem__( # noqa: F811 # redefine unused
Expand Down Expand Up @@ -434,6 +505,9 @@ def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain:
return Domain(dims=dims, ranges=ranges)


FiniteDomain: TypeAlias = Domain[FiniteUnitRange]


DomainLike: TypeAlias = (
Sequence[tuple[Dimension, RangeLike]] | Mapping[Dimension, RangeLike]
) # `Domain` is `Sequence[NamedRange]` and therefore a subset
Expand Down Expand Up @@ -484,7 +558,7 @@ def _broadcast_ranges(
broadcast_dims: Sequence[Dimension], dims: Sequence[Dimension], ranges: Sequence[UnitRange]
) -> tuple[UnitRange, ...]:
return tuple(
ranges[dims.index(d)] if d in dims else UnitRange.infinity() for d in broadcast_dims
ranges[dims.index(d)] if d in dims else UnitRange.infinite() for d in broadcast_dims
)


Expand Down Expand Up @@ -847,7 +921,7 @@ def asnumpy(self) -> Never:

@functools.cached_property
def domain(self) -> Domain:
return Domain(dims=(self.dimension,), ranges=(UnitRange.infinity(),))
return Domain(dims=(self.dimension,), ranges=(UnitRange.infinite(),))

@property
def __gt_dims__(self) -> tuple[Dimension, ...]:
Expand Down
1 change: 1 addition & 0 deletions src/gt4py/next/embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def _relative_sub_domain(
else:
# not in new domain
assert common.is_int_index(idx)
assert common.UnitRange.is_finite(rng)
new_index = (rng.start if idx >= 0 else rng.stop) + idx
if new_index < rng.start or new_index >= rng.stop:
raise embedded_exceptions.IndexOutOfBounds(
Expand Down
Loading

0 comments on commit cdcd653

Please sign in to comment.