Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next]: Add missing UnitRange comparison functions #1363

Merged
merged 37 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
faca090
feat[next] Enable embedded field view in ffront_tests
havogt Nov 16, 2023
7931cfd
broadcast for scalars
havogt Nov 16, 2023
4734c84
implement astype
havogt Nov 16, 2023
e1463d0
support binary builtins for scalars
havogt Nov 16, 2023
f1047dc
support domain
havogt Nov 16, 2023
9ac0ddd
add __ne__, __eq__
havogt Nov 16, 2023
f8682ed
fix typo
havogt Nov 16, 2023
42805f7
this is the typo, the other was improve alloc
havogt Nov 16, 2023
ec0a0d5
cleanup import in fbuiltin
havogt Nov 16, 2023
ac28ea0
fix test case
havogt Nov 16, 2023
89e05ea
fix/ignore typing
havogt Nov 16, 2023
d639ff0
improve default backend selection
havogt Nov 17, 2023
5ad6be5
add comment
havogt Nov 17, 2023
11872f3
address review comments
havogt Nov 17, 2023
e9893be
clarify comment
havogt Nov 17, 2023
7bc0689
implement le for UnitRange
havogt Nov 17, 2023
d5b15c4
Merge remote-tracking branch 'origin/main' into enable_embedded_in_ff…
havogt Nov 17, 2023
3a9bfd6
address last comment
havogt Nov 17, 2023
4b6633f
fix test: convert to ndarray
havogt Nov 17, 2023
b1ddd31
feat[next] add missing UnitRange comparison functions
havogt Nov 17, 2023
d1a5045
Merge remote-tracking branch 'origin/main' into unit_range_comparison
havogt Nov 17, 2023
f1f1fae
Merge remote-tracking branch 'upstream/main' into unit_range_comparison
havogt Dec 1, 2023
6ac320f
cleaner unbound UnitRange
havogt Dec 4, 2023
98f92c4
test for bound_start, bound_stop
havogt Dec 4, 2023
eb6f38b
use enum and rename
havogt Dec 5, 2023
97abbac
last comments
havogt Dec 5, 2023
470caf3
remove set from unitrange
havogt Dec 5, 2023
dffcf0e
address review comments
havogt Dec 6, 2023
7fb65b3
add constructor test
havogt Dec 6, 2023
7b574e9
refactor with finite UnitRange and Domain
havogt Dec 7, 2023
e6d92e0
cleanup unit range constructor
havogt Dec 7, 2023
534ba45
Merge remote-tracking branch 'upstream/main' into unit_range_comparison
havogt Dec 7, 2023
a7a7a76
parametrize unitrange in left, right inf or int
havogt Dec 7, 2023
539c9a2
Merge remote-tracking branch 'upstream/main' into unit_range_comparison
havogt Dec 12, 2023
0ace18e
Merge remote-tracking branch 'upstream/main' into unit_range_comparison
havogt Dec 12, 2023
10f9d3c
Merge remote-tracking branch 'upstream/main' into unit_range_comparison
havogt Dec 14, 2023
5b6a02c
address review comments
havogt Dec 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 55 additions & 2 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,68 @@ def __and__(self, other: Set[int]) -> UnitRange:
else:
raise NotImplementedError("Can only find the intersection between UnitRange instances.")

def __le__(self, other: Set[int]):
def __contains__(self, value: Any) -> bool:
if not isinstance(value, core_defs.INTEGRAL_TYPES):
return False
if value == Infinity.positive() or value == Infinity.negative():
# raising error was an adhoc decision, feel free to improve
raise ValueError("Cannot check if Infinity is in UnitRange.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if value == Infinity.positive() or value == Infinity.negative():
# raising error was an adhoc decision, feel free to improve
raise ValueError("Cannot check if Infinity is in UnitRange.")

I would suggest to skip this for performance reasons. Let's talk if you want to know why I care about performance here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, let's talk

return value >= self.start and value < self.stop

def __le__(self, other: Set[int]) -> bool:
if isinstance(other, UnitRange):
return self.start >= other.start and self.stop <= other.stop
elif len(self) == Infinity.positive():
return False
else:
return Set.__le__(self, other)

__ge__ = __lt__ = __gt__ = lambda self, other: NotImplemented
def __lt__(self, other: Set[int]) -> bool:
if isinstance(other, UnitRange):
return (self.start > other.start and self.stop <= other.stop) or (
self.start >= other.start and self.stop < other.stop
)
elif len(self) == Infinity.positive():
return False
else:
return Set.__lt__(self, other)

def __ge__(self, other: Set[int]) -> bool:
if isinstance(other, UnitRange):
return self.start <= other.start and self.stop >= other.stop
elif len(self) == Infinity.positive():
for v in other:
if v not in self:
return False
return True
else:
return Set.__ge__(self, other)

def __gt__(self, other: Set[int]) -> bool:
if isinstance(other, UnitRange):
return (self.start < other.start and self.stop >= other.stop) or (
self.start <= other.start and self.stop > other.stop
)
elif len(self) == Infinity.positive():
for v in other:
if v not in self:
return False
return True
else:
return Set.__gt__(self, other)

def __eq__(self, other: Any) -> bool:
if isinstance(other, UnitRange):
return self.start == other.start and self.stop == other.stop
elif len(self) == Infinity.positive():
return False
elif isinstance(other, Set):
return Set.__eq__(self, other)
else:
return False

def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)

def __str__(self) -> str:
return f"({self.start}:{self.stop})"
Expand Down
60 changes: 60 additions & 0 deletions tests/next_tests/unit_tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@
KDim = Dimension("KDim", kind=DimensionKind.VERTICAL)


@pytest.fixture(params=[Infinity.positive(), Infinity.negative()])
def inf(request):
yield request.param


@pytest.fixture
def a_domain():
return Domain((IDim, UnitRange(0, 10)), (JDim, UnitRange(5, 15)), (KDim, UnitRange(20, 30)))
Expand Down Expand Up @@ -151,6 +156,61 @@ def test_mixed_infinity_range():
assert len(mixed_inf_range) == Infinity.positive()


def test_range_contains():
assert 1 in UnitRange(0, 2)
assert 1 not in UnitRange(0, 1)
assert 1 in UnitRange(0, Infinity.positive())
assert 1 in UnitRange(Infinity.negative(), 2)
assert 1 in UnitRange(Infinity.negative(), Infinity.positive())
assert "s" not in UnitRange(Infinity.negative(), Infinity.positive())


def test_range_contains_infinity(inf):
with pytest.raises(ValueError):
inf in UnitRange(Infinity.negative(), Infinity.positive())


@pytest.mark.parametrize(
"op, rng1, rng2, expected",
[
(operator.le, UnitRange(-1, 2), UnitRange(-2, 3), True),
(operator.le, UnitRange(-1, 2), {-1, 0, 1}, True),
(operator.le, UnitRange(-1, 2), {-1, 0}, False),
(operator.le, UnitRange(-1, 2), {-2, -1, 0, 1, 2}, True),
(operator.le, UnitRange(Infinity.negative(), 2), UnitRange(Infinity.negative(), 3), True),
(operator.le, UnitRange(Infinity.negative(), 2), {1, 2, 3}, False),
(operator.ge, UnitRange(-2, 3), UnitRange(-1, 2), True),
(operator.ge, UnitRange(-2, 3), {-2, -1, 0, 1, 2}, True),
(operator.ge, UnitRange(-2, 3), {-2, -1, 0, 1, 2, 3}, False),
(operator.ge, UnitRange(Infinity.negative(), 3), UnitRange(Infinity.negative(), 2), True),
(operator.ge, UnitRange(Infinity.negative(), 3), {1, 2}, True),
(operator.lt, UnitRange(-1, 2), UnitRange(-2, 2), True),
(operator.lt, UnitRange(-2, 1), UnitRange(-2, 2), True),
(operator.lt, UnitRange(-2, 2), {-1, 0, 1, 2}, False),
(operator.lt, UnitRange(-2, 2), {-2, -1, 0, 1, 2}, True),
(operator.lt, UnitRange(-2, 2), {-3, -2, -1, 0, 1, 2}, True),
(operator.lt, UnitRange(Infinity.negative(), 2), UnitRange(Infinity.negative(), 3), True),
(operator.lt, UnitRange(Infinity.negative(), 2), {1, 2, 3}, False),
(operator.gt, UnitRange(-2, 2), UnitRange(-1, 2), True),
(operator.gt, UnitRange(-2, 2), UnitRange(-2, 1), True),
(operator.gt, UnitRange(-2, 2), {-1, 0, 1}, True),
(operator.gt, UnitRange(-2, 2), {-2, -1, 0}, True),
(operator.gt, UnitRange(-2, 2), {-2, -1, 0, 1}, False),
(operator.gt, UnitRange(Infinity.negative(), 3), UnitRange(Infinity.negative(), 2), True),
(operator.gt, UnitRange(Infinity.negative(), 2), {0, 1}, True),
(operator.gt, UnitRange(Infinity.negative(), 2), {1, 2}, False),
(operator.eq, UnitRange(Infinity.negative(), 2), UnitRange(Infinity.negative(), 2), True),
(operator.eq, UnitRange(-2, 2), {-2, -1, 0, 1}, True),
(operator.eq, UnitRange(-2, 2), {-2, 1}, False),
(operator.ne, UnitRange(Infinity.negative(), 2), UnitRange(Infinity.negative(), 3), True),
(operator.ne, UnitRange(Infinity.negative(), 2), UnitRange(Infinity.negative(), 2), False),
(operator.ne, UnitRange(-2, 2), {-2, -1, 0}, True),
],
)
def test_range_comparison(op, rng1, rng2, expected):
assert op(rng1, rng2) == expected


@pytest.mark.parametrize(
"op, rng1, rng2, expected",
[
Expand Down