Skip to content

Commit

Permalink
test is_finite
Browse files Browse the repository at this point in the history
  • Loading branch information
samkellerhals committed Sep 6, 2023
1 parent a6c4e14 commit 7cefa7e
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 42 deletions.
1 change: 1 addition & 0 deletions src/gt4py/next/embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from gt4py.next import common
from gt4py.next.embedded import exceptions as embedded_exceptions


# TODO: handle 0-D empty domain case. If Ellipsis should give back domain and in all other cases error.
def sub_domain(domain: common.Domain, index: common.AnyIndexSpec) -> common.Domain:
index_sequence = common.as_any_index_sequence(index)
Expand Down
15 changes: 9 additions & 6 deletions src/gt4py/next/embedded/function_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ def __post_init__(self):
func_params = self.func.__code__.co_argcount
if func_params != num_params:
raise ValueError(
f"Invariant violation: len(self.domain) ({num_params}) does not match the number of parameters of the self.func ({func_params}).")
f"Invariant violation: len(self.domain) ({num_params}) does not match the number of parameters of the self.func ({func_params})."
)
except AttributeError:
raise ValueError(f"Must pass a function as an argument to self.func.")

Expand All @@ -62,19 +63,21 @@ def ndarray(self) -> core_defs.NDArrayObject:

return np.fromfunction(self.func, shape)

def _handle_function_field_op(
self, other: FunctionField, op: Callable
) -> FunctionField:
def _handle_function_field_op(self, other: FunctionField, op: Callable) -> FunctionField:
domain_intersection = self.domain & other.domain
broadcasted_self = _broadcast(self, domain_intersection.dims)
broadcasted_other = _broadcast(other, domain_intersection.dims)
return self.__class__(
_compose(op, broadcasted_self, broadcasted_other), domain_intersection, _skip_invariant=True
_compose(op, broadcasted_self, broadcasted_other),
domain_intersection,
_skip_invariant=True,
)

def _handle_scalar_op(self, other: FunctionField, op: Callable) -> FunctionField:
new_func = lambda *args: op(self.func(*args), other)
return self.__class__(new_func, self.domain, _skip_invariant=True) # skip invariant as we cannot deduce number of args
return self.__class__(
new_func, self.domain, _skip_invariant=True
) # skip invariant as we cannot deduce number of args

@overload
def _binary_operation(self, op: Callable, other: core_defs.ScalarT) -> common.Field:
Expand Down
104 changes: 68 additions & 36 deletions tests/next_tests/unit_tests/embedded_tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,64 +65,64 @@ def test_slice_range(rng, slce, expected):
([(I, (-2, 3))], (I, 3), IndexError),
([(I, (-2, 3))], (I, UnitRange(3, 4)), IndexError),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
2,
[(J, (3, 6)), (K, (4, 7))],
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
2,
[(J, (3, 6)), (K, (4, 7))],
),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
slice(2, 3),
[(I, (4, 5)), (J, (3, 6)), (K, (4, 7))],
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
slice(2, 3),
[(I, (4, 5)), (J, (3, 6)), (K, (4, 7))],
),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(I, 2),
[(J, (3, 6)), (K, (4, 7))],
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(I, 2),
[(J, (3, 6)), (K, (4, 7))],
),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(I, UnitRange(2, 3)),
[(I, (2, 3)), (J, (3, 6)), (K, (4, 7))],
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(I, UnitRange(2, 3)),
[(I, (2, 3)), (J, (3, 6)), (K, (4, 7))],
),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(J, 3),
[(I, (2, 5)), (K, (4, 7))],
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(J, 3),
[(I, (2, 5)), (K, (4, 7))],
),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(J, UnitRange(4, 5)),
[(I, (2, 5)), (J, (4, 5)), (K, (4, 7))],
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(J, UnitRange(4, 5)),
[(I, (2, 5)), (J, (4, 5)), (K, (4, 7))],
),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
((J, 3), (I, 2)),
[(K, (4, 7))],
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
((J, 3), (I, 2)),
[(K, (4, 7))],
),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
((J, UnitRange(4, 5)), (I, 2)),
[(J, (4, 5)), (K, (4, 7))],
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
((J, UnitRange(4, 5)), (I, 2)),
[(J, (4, 5)), (K, (4, 7))],
),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(slice(1, 2), slice(2, 3)),
[(I, (3, 4)), (J, (5, 6)), (K, (4, 7))],
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(slice(1, 2), slice(2, 3)),
[(I, (3, 4)), (J, (5, 6)), (K, (4, 7))],
),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(Ellipsis, slice(2, 3)),
[(I, (2, 5)), (J, (3, 6)), (K, (6, 7))],
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(Ellipsis, slice(2, 3)),
[(I, (2, 5)), (J, (3, 6)), (K, (6, 7))],
),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(slice(1, 2), Ellipsis, slice(2, 3)),
[(I, (3, 4)), (J, (3, 6)), (K, (6, 7))],
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(slice(1, 2), Ellipsis, slice(2, 3)),
[(I, (3, 4)), (J, (3, 6)), (K, (6, 7))],
),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(slice(1, 2), slice(1, 2), Ellipsis),
[(I, (3, 4)), (J, (4, 5)), (K, (4, 7))],
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(slice(1, 2), slice(1, 2), Ellipsis),
[(I, (3, 4)), (J, (4, 5)), (K, (4, 7))],
),
([], Ellipsis, []),
([], slice(None), IndexError),
Expand All @@ -140,3 +140,35 @@ def test_sub_domain(domain, index, expected):
expected = common.domain(expected)
result = sub_domain(domain, index)
assert result == expected


@pytest.fixture
def finite_domain():
I = common.Dimension("I")
J = common.Dimension("J")
return common.Domain((I, UnitRange(-1, 3)), (J, UnitRange(2, 4)))


@pytest.fixture
def infinite_domain():
I = common.Dimension("I")
return common.Domain((I, UnitRange.infinity()))


@pytest.fixture
def mixed_domain():
I = common.Dimension("I")
J = common.Dimension("J")
return common.Domain((I, UnitRange(-1, 3)), (J, UnitRange.infinity()))


def test_finite_domain_is_finite(finite_domain):
assert finite_domain.is_finite() == True


def test_infinite_domain_is_finite(infinite_domain):
assert infinite_domain.is_finite() == False


def test_mixed_domain_is_finite(mixed_domain):
assert mixed_domain.is_finite() == False

0 comments on commit 7cefa7e

Please sign in to comment.