Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Add missing UnitRange comparison functions #1363

Merged
merged 37 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from 36 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
faca090
feat[next] Enable embedded field view in ffront_tests
havogt Nov 16, 2023
7931cfd
broadcast for scalars
havogt Nov 16, 2023
4734c84
implement astype
havogt Nov 16, 2023
e1463d0
support binary builtins for scalars
havogt Nov 16, 2023
f1047dc
support domain
havogt Nov 16, 2023
9ac0ddd
add __ne__, __eq__
havogt Nov 16, 2023
f8682ed
fix typo
havogt Nov 16, 2023
42805f7
this is the typo, the other was improve alloc
havogt Nov 16, 2023
ec0a0d5
cleanup import in fbuiltin
havogt Nov 16, 2023
ac28ea0
fix test case
havogt Nov 16, 2023
89e05ea
fix/ignore typing
havogt Nov 16, 2023
d639ff0
improve default backend selection
havogt Nov 17, 2023
5ad6be5
add comment
havogt Nov 17, 2023
11872f3
address review comments
havogt Nov 17, 2023
e9893be
clarify comment
havogt Nov 17, 2023
7bc0689
implement le for UnitRange
havogt Nov 17, 2023
d5b15c4
Merge remote-tracking branch 'origin/main' into enable_embedded_in_ff…
havogt Nov 17, 2023
3a9bfd6
address last comment
havogt Nov 17, 2023
4b6633f
fix test: convert to ndarray
havogt Nov 17, 2023
b1ddd31
feat[next] add missing UnitRange comparison functions
havogt Nov 17, 2023
d1a5045
Merge remote-tracking branch 'origin/main' into unit_range_comparison
havogt Nov 17, 2023
f1f1fae
Merge remote-tracking branch 'upstream/main' into unit_range_comparison
havogt Dec 1, 2023
6ac320f
cleaner unbound UnitRange
havogt Dec 4, 2023
98f92c4
test for bound_start, bound_stop
havogt Dec 4, 2023
eb6f38b
use enum and rename
havogt Dec 5, 2023
97abbac
last comments
havogt Dec 5, 2023
470caf3
remove set from unitrange
havogt Dec 5, 2023
dffcf0e
address review comments
havogt Dec 6, 2023
7fb65b3
add constructor test
havogt Dec 6, 2023
7b574e9
refactor with finite UnitRange and Domain
havogt Dec 7, 2023
e6d92e0
cleanup unit range constructor
havogt Dec 7, 2023
534ba45
Merge remote-tracking branch 'upstream/main' into unit_range_comparison
havogt Dec 7, 2023
a7a7a76
parametrize unitrange in left, right inf or int
havogt Dec 7, 2023
539c9a2
Merge remote-tracking branch 'upstream/main' into unit_range_comparison
havogt Dec 12, 2023
0ace18e
Merge remote-tracking branch 'upstream/main' into unit_range_comparison
havogt Dec 12, 2023
10f9d3c
Merge remote-tracking branch 'upstream/main' into unit_range_comparison
havogt Dec 14, 2023
5b6a02c
address review comments
havogt Dec 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 148 additions & 78 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@
import enum
import functools
import numbers
import sys
import types
from collections.abc import Mapping, Sequence, Set
from collections.abc import Mapping, Sequence

import numpy as np
import numpy.typing as npt
Expand All @@ -33,10 +32,12 @@
Any,
Callable,
ClassVar,
Generic,
Never,
Optional,
ParamSpec,
Protocol,
Self,
TypeAlias,
TypeGuard,
TypeVar,
Expand All @@ -52,16 +53,6 @@
DimsT = TypeVar("DimsT", bound=Sequence["Dimension"], covariant=True)


class Infinity(int):
@classmethod
def positive(cls) -> Infinity:
return cls(sys.maxsize)

@classmethod
def negative(cls) -> Infinity:
return cls(-sys.maxsize)


Tag: TypeAlias = str


Expand All @@ -84,31 +75,87 @@ def __str__(self):
return f"{self.value}[{self.kind}]"


class Infinity(enum.Enum):
"""Describes an unbounded `UnitRange`."""

NEGATIVE = enum.auto()
POSITIVE = enum.auto()

def __add__(self, _: int) -> Self:
return self

__radd__ = __add__

def __sub__(self, _: int) -> Self:
return self

__rsub__ = __sub__

def __le__(self, other: int | Infinity) -> bool:
return self is self.NEGATIVE or other is self.POSITIVE

def __lt__(self, other: int | Infinity) -> bool:
return self is self.NEGATIVE and other is not self

def __ge__(self, other: int | Infinity) -> bool:
return self is self.POSITIVE or other is self.NEGATIVE

def __gt__(self, other: int | Infinity) -> bool:
return self is self.POSITIVE and other is not self


def _as_int(v: core_defs.IntegralScalar | Infinity) -> int | Infinity:
return v if isinstance(v, Infinity) else int(v)


_Left = TypeVar("_Left", int, Infinity)
_Right = TypeVar("_Right", int, Infinity)


@dataclasses.dataclass(frozen=True, init=False)
class UnitRange(Sequence[int], Set[int]):
"""Range from `start` to `stop` with step size one."""
class UnitRange(Sequence[int], Generic[_Left, _Right]):
"""
Range from `start` to `stop` with step size one.

start: int
stop: int
An open range is constructed by passing `None` for `start` and/or `stop`.
havogt marked this conversation as resolved.
Show resolved Hide resolved
"""

start: _Left
stop: _Right

def __init__(self, start: core_defs.IntegralScalar, stop: core_defs.IntegralScalar) -> None:
def __init__(
self, start: core_defs.IntegralScalar | Infinity, stop: core_defs.IntegralScalar | Infinity
) -> None:
if start < stop:
object.__setattr__(self, "start", int(start))
object.__setattr__(self, "stop", int(stop))
object.__setattr__(self, "start", _as_int(start))
object.__setattr__(self, "stop", _as_int(stop))
else:
# make UnitRange(0,0) the single empty UnitRange
object.__setattr__(self, "start", 0)
object.__setattr__(self, "stop", 0)

# TODO: the whole infinity idea and implementation is broken and should be replaced
@classmethod
def infinity(cls) -> UnitRange:
return cls(Infinity.negative(), Infinity.positive())
def infinite(
cls,
) -> UnitRange:
return cls(Infinity.NEGATIVE, Infinity.POSITIVE)

def __len__(self) -> int:
if Infinity.positive() in (abs(self.start), abs(self.stop)):
return Infinity.positive()
return max(0, self.stop - self.start)
if UnitRange.is_finite(self):
return max(0, self.stop - self.start)
raise ValueError("Cannot compute length of open 'UnitRange'.")

@classmethod
havogt marked this conversation as resolved.
Show resolved Hide resolved
def is_finite(cls, obj: UnitRange) -> TypeGuard[FiniteUnitRange]:
return obj.start is not Infinity.NEGATIVE and obj.stop is not Infinity.POSITIVE

@classmethod
def is_right_finite(cls, obj: UnitRange) -> TypeGuard[UnitRange[_Left, int]]:
return obj.stop is not Infinity.POSITIVE

@classmethod
def is_left_finite(cls, obj: UnitRange) -> TypeGuard[UnitRange[int, _Right]]:
return obj.start is not Infinity.NEGATIVE

def __repr__(self) -> str:
return f"UnitRange({self.start}, {self.stop})"
Expand All @@ -122,6 +169,7 @@ def __getitem__(self, index: slice) -> UnitRange: # noqa: F811 # redefine unuse
...

def __getitem__(self, index: int | slice) -> int | UnitRange: # noqa: F811 # redefine unused
assert UnitRange.is_finite(self)
if isinstance(index, slice):
start, stop, step = index.indices(len(self))
if step != 1:
Expand All @@ -138,61 +186,58 @@ def __getitem__(self, index: int | slice) -> int | UnitRange: # noqa: F811 # re
else:
raise IndexError("'UnitRange' index out of range")

def __and__(self, other: Set[int]) -> UnitRange:
if isinstance(other, UnitRange):
start = max(self.start, other.start)
stop = min(self.stop, other.stop)
return UnitRange(start, stop)
else:
raise NotImplementedError(
"Can only find the intersection between 'UnitRange' instances."
)
def __and__(self, other: UnitRange) -> UnitRange:
return UnitRange(max(self.start, other.start), min(self.stop, other.stop))

def __le__(self, other: Set[int]):
if isinstance(other, UnitRange):
return self.start >= other.start and self.stop <= other.stop
elif len(self) == Infinity.positive():
def __contains__(self, value: Any) -> bool:
if not isinstance(value, core_defs.INTEGRAL_TYPES):
return False
return value >= self.start and value < self.stop
havogt marked this conversation as resolved.
Show resolved Hide resolved

def __le__(self, other: UnitRange) -> bool:
return self.start >= other.start and self.stop <= other.stop

def __lt__(self, other: UnitRange) -> bool:
return (self.start > other.start and self.stop <= other.stop) or (
self.start >= other.start and self.stop < other.stop
)
Comment on lines +202 to +204
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return (self.start > other.start and self.stop <= other.stop) or (
self.start >= other.start and self.stop < other.stop
)
return self < other or self == other

The version I proposed appears much easier to comprehend and now that I see that 4 comparisons are needed performance appears to be the same anyway.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with your proposal we have 2 function calls, 6 comparison and an isinstance check (at least), but I agree that it's simpler to understand

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. @havogt's proposal is better for performance reasons, @tehrengruber is better for readability. In this case I think @havogt version is actually better, but both are fine for me.


def __ge__(self, other: UnitRange) -> bool:
return self.start <= other.start and self.stop >= other.stop

def __gt__(self, other: UnitRange) -> bool:
return (self.start < other.start and self.stop >= other.stop) or (
self.start <= other.start and self.stop > other.stop
)

def __eq__(self, other: Any) -> bool:
if isinstance(other, UnitRange):
return self.start == other.start and self.stop == other.stop
else:
return Set.__le__(self, other)

def __add__(self, other: int | Set[int]) -> UnitRange:
if isinstance(other, int):
if other == Infinity.positive():
return UnitRange.infinity()
elif other == Infinity.negative():
return UnitRange(0, 0)
return UnitRange(
*(
s if s in [Infinity.negative(), Infinity.positive()] else s + other
for s in (self.start, self.stop)
)
)
else:
raise NotImplementedError("Can only compute union with 'int' instances.")

def __sub__(self, other: int | Set[int]) -> UnitRange:
if isinstance(other, int):
if other == Infinity.negative():
return self + Infinity.positive()
elif other == Infinity.positive():
return self + Infinity.negative()
else:
return self + (-other)
else:
raise NotImplementedError("Can only compute substraction with 'int' instances.")
return False

def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)

__ge__ = __lt__ = __gt__ = lambda self, other: NotImplemented
def __add__(self, other: int) -> UnitRange:
return UnitRange(self.start + other, self.stop + other)

def __sub__(self, other: int) -> UnitRange:
return UnitRange(self.start - other, self.stop - other)

def __str__(self) -> str:
return f"({self.start}:{self.stop})"


FiniteUnitRange: TypeAlias = UnitRange[int, int]


RangeLike: TypeAlias = (
UnitRange
| range
| tuple[core_defs.IntegralScalar, core_defs.IntegralScalar]
| core_defs.IntegralScalar
| None
)


Expand All @@ -207,18 +252,27 @@ def unit_range(r: RangeLike) -> UnitRange:
# once the related mypy bug (#16358) gets fixed
if (
isinstance(r, tuple)
and isinstance(r[0], core_defs.INTEGRAL_TYPES)
and isinstance(r[1], core_defs.INTEGRAL_TYPES)
and (
isinstance(r[0], core_defs.INTEGRAL_TYPES) or r[0] is None or r[0] is Infinity.NEGATIVE
havogt marked this conversation as resolved.
Show resolved Hide resolved
)
and (
isinstance(r[1], core_defs.INTEGRAL_TYPES) or r[1] is None or r[1] is Infinity.POSITIVE
havogt marked this conversation as resolved.
Show resolved Hide resolved
)
):
return UnitRange(r[0], r[1])
start = r[0] if r[0] is not None else Infinity.NEGATIVE
stop = r[1] if r[1] is not None else Infinity.POSITIVE
return UnitRange(start, stop)
if isinstance(r, core_defs.INTEGRAL_TYPES):
return UnitRange(0, cast(core_defs.IntegralScalar, r))
if r is None:
return UnitRange.infinite()
raise ValueError(f"'{r!r}' cannot be interpreted as 'UnitRange'.")


IntIndex: TypeAlias = int | core_defs.IntegralScalar
NamedIndex: TypeAlias = tuple[Dimension, IntIndex] # TODO: convert to NamedTuple
NamedRange: TypeAlias = tuple[Dimension, UnitRange] # TODO: convert to NamedTuple
FiniteNamedRange: TypeAlias = tuple[Dimension, FiniteUnitRange] # TODO: convert to NamedTuple
RelativeIndexElement: TypeAlias = IntIndex | slice | types.EllipsisType
AbsoluteIndexElement: TypeAlias = NamedIndex | NamedRange
AnyIndexElement: TypeAlias = RelativeIndexElement | AbsoluteIndexElement
Expand All @@ -245,6 +299,10 @@ def is_named_range(v: AnyIndexSpec) -> TypeGuard[NamedRange]:
)


def is_finite_named_range(v: NamedRange) -> TypeGuard[FiniteNamedRange]:
return UnitRange.is_finite(v[1])


def is_named_index(v: AnyIndexSpec) -> TypeGuard[NamedRange]:
return (
isinstance(v, tuple) and len(v) == 2 and isinstance(v[0], Dimension) and is_int_index(v[1])
Expand Down Expand Up @@ -283,18 +341,21 @@ def named_range(v: tuple[Dimension, RangeLike]) -> NamedRange:
return (v[0], unit_range(v[1]))


_Rng = TypeVar("_Rng", bound=UnitRange)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a question: is this actually correct? Shouldn't it be _Rng = TypeVar("_Rng", UnitRange, FiniteUnitRange) instead since FiniteUnitRange doesn't inherit from UnitRange?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe

_Rng = TypeVar(
    "_Rng",
    UnitRange[int, int],
    UnitRange[Infinity, int],
    UnitRange[int, Infinity],
    UnitRange[Infinity, Infinity],
)



@dataclasses.dataclass(frozen=True, init=False)
class Domain(Sequence[NamedRange]):
class Domain(Sequence[tuple[Dimension, _Rng]], Generic[_Rng]):
"""Describes the `Domain` of a `Field` as a `Sequence` of `NamedRange` s."""

dims: tuple[Dimension, ...]
ranges: tuple[UnitRange, ...]
ranges: tuple[_Rng, ...]

def __init__(
self,
*args: NamedRange,
*args: tuple[Dimension, _Rng],
dims: Optional[Sequence[Dimension]] = None,
ranges: Optional[Sequence[UnitRange]] = None,
ranges: Optional[Sequence[_Rng]] = None,
) -> None:
if dims is not None or ranges is not None:
if dims is None and ranges is None:
Expand Down Expand Up @@ -343,16 +404,22 @@ def ndim(self) -> int:
def shape(self) -> tuple[int, ...]:
return tuple(len(r) for r in self.ranges)

@classmethod
def is_finite(cls, obj: Domain) -> TypeGuard[FiniteDomain]:
return all(UnitRange.is_finite(rng) for rng in obj.ranges)

@overload
def __getitem__(self, index: int) -> NamedRange:
def __getitem__(self, index: int) -> tuple[Dimension, _Rng]:
...

@overload
def __getitem__(self, index: slice) -> Domain: # noqa: F811 # redefine unused
def __getitem__(self, index: slice) -> Self: # noqa: F811 # redefine unused
...

@overload
def __getitem__(self, index: Dimension) -> NamedRange: # noqa: F811 # redefine unused
def __getitem__( # noqa: F811 # redefine unused
self, index: Dimension
) -> tuple[Dimension, _Rng]:
...

def __getitem__( # noqa: F811 # redefine unused
Expand Down Expand Up @@ -434,6 +501,9 @@ def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain:
return Domain(dims=dims, ranges=ranges)


FiniteDomain: TypeAlias = Domain[FiniteUnitRange]


DomainLike: TypeAlias = (
Sequence[tuple[Dimension, RangeLike]] | Mapping[Dimension, RangeLike]
) # `Domain` is `Sequence[NamedRange]` and therefore a subset
Expand Down Expand Up @@ -484,7 +554,7 @@ def _broadcast_ranges(
broadcast_dims: Sequence[Dimension], dims: Sequence[Dimension], ranges: Sequence[UnitRange]
) -> tuple[UnitRange, ...]:
return tuple(
ranges[dims.index(d)] if d in dims else UnitRange.infinity() for d in broadcast_dims
ranges[dims.index(d)] if d in dims else UnitRange.infinite() for d in broadcast_dims
)


Expand Down Expand Up @@ -847,7 +917,7 @@ def asnumpy(self) -> Never:

@functools.cached_property
def domain(self) -> Domain:
return Domain(dims=(self.dimension,), ranges=(UnitRange.infinity(),))
return Domain(dims=(self.dimension,), ranges=(UnitRange.infinite(),))

@property
def __gt_dims__(self) -> tuple[Dimension, ...]:
Expand Down
1 change: 1 addition & 0 deletions src/gt4py/next/embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def _relative_sub_domain(
else:
# not in new domain
assert common.is_int_index(idx)
assert common.UnitRange.is_finite(rng)
new_index = (rng.start if idx >= 0 else rng.stop) + idx
if new_index < rng.start or new_index >= rng.stop:
raise embedded_exceptions.IndexOutOfBounds(
Expand Down
Loading