diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index ecdcc84fc1..8cd2d8ce15 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -19,7 +19,6 @@ from gt4py.next.embedded import exceptions as embedded_exceptions -# TODO: handle 0-D empty domain case. If Ellipsis should give back domain and in all other cases error. def sub_domain(domain: common.Domain, index: common.AnyIndexSpec) -> common.Domain: index_sequence = common.as_any_index_sequence(index) diff --git a/src/gt4py/next/embedded/function_field.py b/src/gt4py/next/embedded/function_field.py index 938f9b4c63..b3d215da11 100644 --- a/src/gt4py/next/embedded/function_field.py +++ b/src/gt4py/next/embedded/function_field.py @@ -33,34 +33,35 @@ @dataclasses.dataclass(frozen=True) class FunctionField(common.Field[common.DimsT, core_defs.ScalarT], common.FieldBuiltinFuncRegistry): """ - A `FunctionField` represents a field of values generated by a callable function - over a specified domain. The function supplied to the `func` parameter will be - used to create the ndarray when accessing the `ndarray` property. The result of - calling `ndarray` will be the same as using `np.fromfunction` with the provided - function. - - Args: - func (Callable): The callable function that generates field values. - domain (common.Domain, optional): The domain over which the function is defined. - Defaults to an empty domain. - _skip_invariant (bool, optional): Internal flag to skip invariant checks. - Defaults to False. - - Examples: - Create a FunctionField and compute its ndarray: - - >>> import numpy as np - >>> from gt4py.next import common - >>> from gt4py.next.embedded.function_field import FunctionField - >>> I = common.Dimension("I") - >>> domain = common.Domain((I, common.UnitRange(0, 5))) - >>> func = lambda i: i ** 2 - >>> field = FunctionField(func, domain) - >>> ndarray = field.ndarray - >>> expected_ndarray = np.fromfunction(func, (5,)) - >>> np.array_equal(ndarray, expected_ndarray) - True - """ + A `FunctionField` represents a field of values generated by a callable function + over a specified domain. The function supplied to the `func` parameter will be + used to create the ndarray when accessing the `ndarray` property. The result of + calling `ndarray` will be the same as using `np.fromfunction` with the provided + function. + + Args: + func (Callable): The callable function that generates field values. + domain (common.Domain, optional): The domain over which the function is defined. + Defaults to an empty domain. + _skip_invariant (bool, optional): Internal flag to skip invariant checks. + Defaults to False. + + Examples: + Create a FunctionField and compute its ndarray: + + >>> import numpy as np + >>> from gt4py.next import common + >>> from gt4py.next.embedded.function_field import FunctionField + >>> I = common.Dimension("I") + >>> domain = common.Domain((I, common.UnitRange(0, 5))) + >>> func = lambda i: i ** 2 + >>> field = FunctionField(func, domain) + >>> ndarray = field.ndarray + >>> expected_ndarray = np.fromfunction(func, (5,)) + >>> np.array_equal(ndarray, expected_ndarray) + True + """ + func: Callable domain: common.Domain = common.Domain() _skip_invariant: bool = False @@ -93,7 +94,6 @@ def ndarray(self) -> core_defs.NDArrayObject: raise embedded_exceptions.InfiniteRangeNdarrayError( self.__class__.__name__, self.domain ) - shape = [len(rng) for rng in self.domain.ranges] return np.fromfunction(self.func, shape) @@ -208,7 +208,6 @@ def remap(self, *args, **kwargs) -> common.Field: raise NotImplementedError("Method remap not implemented") - def _compose(operation: Callable, *fields: FunctionField) -> Callable: return lambda *args: operation(*[f.func(*args) for f in fields]) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_function_field.py b/tests/next_tests/unit_tests/embedded_tests/test_function_field.py index 922a87dcf6..b20bee93e1 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_function_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_function_field.py @@ -77,13 +77,22 @@ def test_constant_field_no_domain(op_func, expected_result): assert result.func() == expected_result -@pytest.mark.parametrize( - "index", [((I, UnitRange(0, 10)),), common.Domain(dims=(I,), ranges=(UnitRange(0, 10),))] +@pytest.fixture( + params=[((I, UnitRange(0, 10)),), common.Domain(dims=(I,), ranges=(UnitRange(0, 10),))] ) -def test_constant_field_getitem_missing_domain(index): +def test_index(request): + return request.param + + +def test_constant_field_getitem_missing_domain(test_index): cf = funcf.constant_field(10) with pytest.raises(embedded_exceptions.IndexOutOfBounds): - cf[index] + cf[test_index] + + +def test_constant_field_getitem_missing_domain_ellipsis(test_index): + cf = funcf.constant_field(10) + cf[...].domain == cf.domain @pytest.mark.parametrize( @@ -175,9 +184,7 @@ def test_function_field_logical_operators(op_func): func1 = lambda x, y: x > 5 func2 = lambda y: y < 10 - domain1 = common.Domain( - dims=(I, J), ranges=(common.UnitRange(1, 10), common.UnitRange(5, 10)) - ) + domain1 = common.Domain(dims=(I, J), ranges=(common.UnitRange(1, 10), common.UnitRange(5, 10))) domain2 = common.Domain(dims=(J,), ranges=(common.UnitRange(7, 15),)) ff1 = funcf.FunctionField(func1, domain1)