From 7cefa7ea5324cb7237882598bd00907570e6970b Mon Sep 17 00:00:00 2001 From: samkellerhals Date: Wed, 6 Sep 2023 14:42:53 +0200 Subject: [PATCH] test is_finite --- src/gt4py/next/embedded/common.py | 1 + src/gt4py/next/embedded/function_field.py | 15 ++- .../unit_tests/embedded_tests/test_common.py | 104 ++++++++++++------ 3 files changed, 78 insertions(+), 42 deletions(-) diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 66fc9b427e..ba55c517b6 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -20,6 +20,7 @@ from gt4py.next import common from gt4py.next.embedded import exceptions as embedded_exceptions + # TODO: handle 0-D empty domain case. If Ellipsis should give back domain and in all other cases error. def sub_domain(domain: common.Domain, index: common.AnyIndexSpec) -> common.Domain: index_sequence = common.as_any_index_sequence(index) diff --git a/src/gt4py/next/embedded/function_field.py b/src/gt4py/next/embedded/function_field.py index b9236c14ed..cb095f5f72 100644 --- a/src/gt4py/next/embedded/function_field.py +++ b/src/gt4py/next/embedded/function_field.py @@ -43,7 +43,8 @@ def __post_init__(self): func_params = self.func.__code__.co_argcount if func_params != num_params: raise ValueError( - f"Invariant violation: len(self.domain) ({num_params}) does not match the number of parameters of the self.func ({func_params}).") + f"Invariant violation: len(self.domain) ({num_params}) does not match the number of parameters of the self.func ({func_params})." + ) except AttributeError: raise ValueError(f"Must pass a function as an argument to self.func.") @@ -62,19 +63,21 @@ def ndarray(self) -> core_defs.NDArrayObject: return np.fromfunction(self.func, shape) - def _handle_function_field_op( - self, other: FunctionField, op: Callable - ) -> FunctionField: + def _handle_function_field_op(self, other: FunctionField, op: Callable) -> FunctionField: domain_intersection = self.domain & other.domain broadcasted_self = _broadcast(self, domain_intersection.dims) broadcasted_other = _broadcast(other, domain_intersection.dims) return self.__class__( - _compose(op, broadcasted_self, broadcasted_other), domain_intersection, _skip_invariant=True + _compose(op, broadcasted_self, broadcasted_other), + domain_intersection, + _skip_invariant=True, ) def _handle_scalar_op(self, other: FunctionField, op: Callable) -> FunctionField: new_func = lambda *args: op(self.func(*args), other) - return self.__class__(new_func, self.domain, _skip_invariant=True) # skip invariant as we cannot deduce number of args + return self.__class__( + new_func, self.domain, _skip_invariant=True + ) # skip invariant as we cannot deduce number of args @overload def _binary_operation(self, op: Callable, other: core_defs.ScalarT) -> common.Field: diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py index bcff9a414b..7c35be6447 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_common.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -65,64 +65,64 @@ def test_slice_range(rng, slce, expected): ([(I, (-2, 3))], (I, 3), IndexError), ([(I, (-2, 3))], (I, UnitRange(3, 4)), IndexError), ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - 2, - [(J, (3, 6)), (K, (4, 7))], + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + 2, + [(J, (3, 6)), (K, (4, 7))], ), ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - slice(2, 3), - [(I, (4, 5)), (J, (3, 6)), (K, (4, 7))], + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + slice(2, 3), + [(I, (4, 5)), (J, (3, 6)), (K, (4, 7))], ), ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - (I, 2), - [(J, (3, 6)), (K, (4, 7))], + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (I, 2), + [(J, (3, 6)), (K, (4, 7))], ), ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - (I, UnitRange(2, 3)), - [(I, (2, 3)), (J, (3, 6)), (K, (4, 7))], + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (I, UnitRange(2, 3)), + [(I, (2, 3)), (J, (3, 6)), (K, (4, 7))], ), ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - (J, 3), - [(I, (2, 5)), (K, (4, 7))], + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (J, 3), + [(I, (2, 5)), (K, (4, 7))], ), ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - (J, UnitRange(4, 5)), - [(I, (2, 5)), (J, (4, 5)), (K, (4, 7))], + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (J, UnitRange(4, 5)), + [(I, (2, 5)), (J, (4, 5)), (K, (4, 7))], ), ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - ((J, 3), (I, 2)), - [(K, (4, 7))], + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + ((J, 3), (I, 2)), + [(K, (4, 7))], ), ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - ((J, UnitRange(4, 5)), (I, 2)), - [(J, (4, 5)), (K, (4, 7))], + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + ((J, UnitRange(4, 5)), (I, 2)), + [(J, (4, 5)), (K, (4, 7))], ), ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - (slice(1, 2), slice(2, 3)), - [(I, (3, 4)), (J, (5, 6)), (K, (4, 7))], + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (slice(1, 2), slice(2, 3)), + [(I, (3, 4)), (J, (5, 6)), (K, (4, 7))], ), ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - (Ellipsis, slice(2, 3)), - [(I, (2, 5)), (J, (3, 6)), (K, (6, 7))], + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (Ellipsis, slice(2, 3)), + [(I, (2, 5)), (J, (3, 6)), (K, (6, 7))], ), ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - (slice(1, 2), Ellipsis, slice(2, 3)), - [(I, (3, 4)), (J, (3, 6)), (K, (6, 7))], + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (slice(1, 2), Ellipsis, slice(2, 3)), + [(I, (3, 4)), (J, (3, 6)), (K, (6, 7))], ), ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - (slice(1, 2), slice(1, 2), Ellipsis), - [(I, (3, 4)), (J, (4, 5)), (K, (4, 7))], + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (slice(1, 2), slice(1, 2), Ellipsis), + [(I, (3, 4)), (J, (4, 5)), (K, (4, 7))], ), ([], Ellipsis, []), ([], slice(None), IndexError), @@ -140,3 +140,35 @@ def test_sub_domain(domain, index, expected): expected = common.domain(expected) result = sub_domain(domain, index) assert result == expected + + +@pytest.fixture +def finite_domain(): + I = common.Dimension("I") + J = common.Dimension("J") + return common.Domain((I, UnitRange(-1, 3)), (J, UnitRange(2, 4))) + + +@pytest.fixture +def infinite_domain(): + I = common.Dimension("I") + return common.Domain((I, UnitRange.infinity())) + + +@pytest.fixture +def mixed_domain(): + I = common.Dimension("I") + J = common.Dimension("J") + return common.Domain((I, UnitRange(-1, 3)), (J, UnitRange.infinity())) + + +def test_finite_domain_is_finite(finite_domain): + assert finite_domain.is_finite() == True + + +def test_infinite_domain_is_finite(infinite_domain): + assert infinite_domain.is_finite() == False + + +def test_mixed_domain_is_finite(mixed_domain): + assert mixed_domain.is_finite() == False