Skip to content

Commit

Permalink
test ellipsis function field getitem
Browse files Browse the repository at this point in the history
  • Loading branch information
samkellerhals committed Sep 6, 2023
1 parent 08b40d0 commit fc8ec3b
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 38 deletions.
1 change: 0 additions & 1 deletion src/gt4py/next/embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from gt4py.next.embedded import exceptions as embedded_exceptions


# TODO: handle 0-D empty domain case. If Ellipsis should give back domain and in all other cases error.
def sub_domain(domain: common.Domain, index: common.AnyIndexSpec) -> common.Domain:
index_sequence = common.as_any_index_sequence(index)

Expand Down
59 changes: 29 additions & 30 deletions src/gt4py/next/embedded/function_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,34 +33,35 @@
@dataclasses.dataclass(frozen=True)
class FunctionField(common.Field[common.DimsT, core_defs.ScalarT], common.FieldBuiltinFuncRegistry):
"""
A `FunctionField` represents a field of values generated by a callable function
over a specified domain. The function supplied to the `func` parameter will be
used to create the ndarray when accessing the `ndarray` property. The result of
calling `ndarray` will be the same as using `np.fromfunction` with the provided
function.
Args:
func (Callable): The callable function that generates field values.
domain (common.Domain, optional): The domain over which the function is defined.
Defaults to an empty domain.
_skip_invariant (bool, optional): Internal flag to skip invariant checks.
Defaults to False.
Examples:
Create a FunctionField and compute its ndarray:
>>> import numpy as np
>>> from gt4py.next import common
>>> from gt4py.next.embedded.function_field import FunctionField
>>> I = common.Dimension("I")
>>> domain = common.Domain((I, common.UnitRange(0, 5)))
>>> func = lambda i: i ** 2
>>> field = FunctionField(func, domain)
>>> ndarray = field.ndarray
>>> expected_ndarray = np.fromfunction(func, (5,))
>>> np.array_equal(ndarray, expected_ndarray)
True
"""
A `FunctionField` represents a field of values generated by a callable function
over a specified domain. The function supplied to the `func` parameter will be
used to create the ndarray when accessing the `ndarray` property. The result of
calling `ndarray` will be the same as using `np.fromfunction` with the provided
function.
Args:
func (Callable): The callable function that generates field values.
domain (common.Domain, optional): The domain over which the function is defined.
Defaults to an empty domain.
_skip_invariant (bool, optional): Internal flag to skip invariant checks.
Defaults to False.
Examples:
Create a FunctionField and compute its ndarray:
>>> import numpy as np
>>> from gt4py.next import common
>>> from gt4py.next.embedded.function_field import FunctionField
>>> I = common.Dimension("I")
>>> domain = common.Domain((I, common.UnitRange(0, 5)))
>>> func = lambda i: i ** 2
>>> field = FunctionField(func, domain)
>>> ndarray = field.ndarray
>>> expected_ndarray = np.fromfunction(func, (5,))
>>> np.array_equal(ndarray, expected_ndarray)
True
"""

func: Callable
domain: common.Domain = common.Domain()
_skip_invariant: bool = False
Expand Down Expand Up @@ -93,7 +94,6 @@ def ndarray(self) -> core_defs.NDArrayObject:
raise embedded_exceptions.InfiniteRangeNdarrayError(
self.__class__.__name__, self.domain
)

shape = [len(rng) for rng in self.domain.ranges]
return np.fromfunction(self.func, shape)

Expand Down Expand Up @@ -208,7 +208,6 @@ def remap(self, *args, **kwargs) -> common.Field:
raise NotImplementedError("Method remap not implemented")



def _compose(operation: Callable, *fields: FunctionField) -> Callable:
return lambda *args: operation(*[f.func(*args) for f in fields])

Expand Down
21 changes: 14 additions & 7 deletions tests/next_tests/unit_tests/embedded_tests/test_function_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,22 @@ def test_constant_field_no_domain(op_func, expected_result):
assert result.func() == expected_result


@pytest.mark.parametrize(
"index", [((I, UnitRange(0, 10)),), common.Domain(dims=(I,), ranges=(UnitRange(0, 10),))]
@pytest.fixture(
params=[((I, UnitRange(0, 10)),), common.Domain(dims=(I,), ranges=(UnitRange(0, 10),))]
)
def test_constant_field_getitem_missing_domain(index):
def test_index(request):
return request.param


def test_constant_field_getitem_missing_domain(test_index):
cf = funcf.constant_field(10)
with pytest.raises(embedded_exceptions.IndexOutOfBounds):
cf[index]
cf[test_index]


def test_constant_field_getitem_missing_domain_ellipsis(test_index):
cf = funcf.constant_field(10)
cf[...].domain == cf.domain


@pytest.mark.parametrize(
Expand Down Expand Up @@ -175,9 +184,7 @@ def test_function_field_logical_operators(op_func):
func1 = lambda x, y: x > 5
func2 = lambda y: y < 10

domain1 = common.Domain(
dims=(I, J), ranges=(common.UnitRange(1, 10), common.UnitRange(5, 10))
)
domain1 = common.Domain(dims=(I, J), ranges=(common.UnitRange(1, 10), common.UnitRange(5, 10)))
domain2 = common.Domain(dims=(J,), ranges=(common.UnitRange(7, 15),))

ff1 = funcf.FunctionField(func1, domain1)
Expand Down

0 comments on commit fc8ec3b

Please sign in to comment.