diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index ba55c517b6..ecdcc84fc1 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -15,8 +15,6 @@ from typing import Any, Optional, Sequence, cast -import numpy as np - from gt4py.next import common from gt4py.next.embedded import exceptions as embedded_exceptions @@ -135,19 +133,7 @@ def _find_index_of_dim( return None -def _compute_domain_slice( - field: common.Field, new_dimensions: tuple[common.Dimension, ...] -) -> Sequence[slice | None]: - domain_slice: list[slice | None] = [] - for dim in new_dimensions: - if _find_index_of_dim(dim, field.domain) is not None: - domain_slice.append(slice(None)) - else: - domain_slice.append(np.newaxis) - return domain_slice - - -def _compute_named_ranges( +def _broadcast_domain( field: common.Field, new_dimensions: tuple[common.Dimension, ...] ) -> Sequence[common.NamedRange]: named_ranges = [] diff --git a/src/gt4py/next/embedded/exceptions.py b/src/gt4py/next/embedded/exceptions.py index 5153ddc965..c86912fc87 100644 --- a/src/gt4py/next/embedded/exceptions.py +++ b/src/gt4py/next/embedded/exceptions.py @@ -49,12 +49,11 @@ def __init__(self, cls_name: str): self.cls_name = cls_name -class InvalidDomainForNdarrayError(gt4py_exceptions.GT4PyError): - def __init__(self, cls_name: str): - super().__init__( - f"Error in `{cls_name}`: Cannot construct an ndarray with an empty domain." - ) +class FunctionFieldError(gt4py_exceptions.GT4PyError): + def __init__(self, cls_name: str, msg: str): + super().__init__(f"Error in `{cls_name}`: {msg}.") self.cls_name = cls_name + self.msg = msg class InfiniteRangeNdarrayError(gt4py_exceptions.GT4PyError): diff --git a/src/gt4py/next/embedded/function_field.py b/src/gt4py/next/embedded/function_field.py index cb095f5f72..a15e3d39f8 100644 --- a/src/gt4py/next/embedded/function_field.py +++ b/src/gt4py/next/embedded/function_field.py @@ -42,11 +42,15 @@ def __post_init__(self): try: func_params = self.func.__code__.co_argcount if func_params != num_params: - raise ValueError( - f"Invariant violation: len(self.domain) ({num_params}) does not match the number of parameters of the self.func ({func_params})." + raise embedded_exceptions.FunctionFieldError( + self.__class__.__name__, + f"Invariant violation: len(self.domain) ({num_params}) does not match the number of parameters of the provided function ({func_params})", ) except AttributeError: - raise ValueError(f"Must pass a function as an argument to self.func.") + raise embedded_exceptions.FunctionFieldError( + self.__class__.__name__, + f"Invalid first argument type: Expected a function but got {self.func}", + ) def restrict(self, index: common.AnyIndexSpec) -> FunctionField: new_domain = embedded_common.sub_domain(self.domain, index) @@ -57,10 +61,11 @@ def restrict(self, index: common.AnyIndexSpec) -> FunctionField: @property def ndarray(self) -> core_defs.NDArrayObject: if not self.domain.is_finite(): - embedded_exceptions.InfiniteRangeNdarrayError(self.__class__.__name__, self.domain) + raise embedded_exceptions.InfiniteRangeNdarrayError( + self.__class__.__name__, self.domain + ) shape = [len(rng) for rng in self.domain.ranges] - return np.fromfunction(self.func, shape) def _handle_function_field_op(self, other: FunctionField, op: Callable) -> FunctionField: @@ -90,7 +95,7 @@ def _binary_operation(self, op: Callable, other: common.Field) -> common.Field: def _binary_operation(self, op, other): if isinstance(other, self.__class__): return self._handle_function_field_op(other, op) - elif isinstance(other, (int, float)): # Handle scalar values + elif isinstance(other, (int, float)): return self._handle_scalar_op(other, op) else: return op(other, self) @@ -183,7 +188,7 @@ def broadcasted_func(*args: int): selected_args = [args[i] for i, dim in enumerate(dims) if dim in field.domain.dims] return field.func(*selected_args) - named_ranges = embedded_common._compute_named_ranges(field, dims) + named_ranges = embedded_common._broadcast_domain(field, dims) return FunctionField(broadcasted_func, common.Domain(*named_ranges), _skip_invariant=True) diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 9ba84ab697..2656519ea9 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -17,7 +17,7 @@ import dataclasses from collections.abc import Callable, Sequence from types import ModuleType -from typing import Any, ClassVar, Optional, ParamSpec, TypeAlias, TypeVar +from typing import Any, ClassVar, Optional, ParamSpec, TypeAlias, TypeVar, Sequence import numpy as np from numpy import typing as npt @@ -25,7 +25,7 @@ from gt4py._core import definitions as core_defs from gt4py.next import common from gt4py.next.embedded import common as embedded_common -from gt4py.next.embedded.common import _compute_domain_slice, _compute_named_ranges +from gt4py.next.embedded.common import _broadcast_domain, _find_index_of_dim from gt4py.next.ffront import fbuiltins @@ -335,7 +335,7 @@ def __setitem__( def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field: domain_slice = _compute_domain_slice(field, new_dimensions) - named_ranges = _compute_named_ranges(field, new_dimensions) + named_ranges = _broadcast_domain(field, new_dimensions) # handle case where we have a constant FunctionField where ndarray is a scalar if isinstance(value := field.ndarray, (int, float)): @@ -415,3 +415,15 @@ def _compute_slice( return rng - domain.ranges[pos].start else: raise ValueError(f"Can only use integer or UnitRange ranges, provided type: {type(rng)}") + + +def _compute_domain_slice( + field: common.Field, new_dimensions: tuple[common.Dimension, ...] +) -> Sequence[slice | None]: + domain_slice: list[slice | None] = [] + for dim in new_dimensions: + if _find_index_of_dim(dim, field.domain) is not None: + domain_slice.append(slice(None)) + else: + domain_slice.append(np.newaxis) + return domain_slice diff --git a/tests/next_tests/unit_tests/embedded_tests/__init__.py b/tests/next_tests/unit_tests/embedded_tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/next_tests/unit_tests/embedded_tests/test_common.py b/tests/next_tests/unit_tests/embedded_tests/test_common.py index 7c35be6447..3a4f5fd66b 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_common.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_common.py @@ -65,64 +65,64 @@ def test_slice_range(rng, slce, expected): ([(I, (-2, 3))], (I, 3), IndexError), ([(I, (-2, 3))], (I, UnitRange(3, 4)), IndexError), ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - 2, - [(J, (3, 6)), (K, (4, 7))], + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + 2, + [(J, (3, 6)), (K, (4, 7))], ), ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - slice(2, 3), - [(I, (4, 5)), (J, (3, 6)), (K, (4, 7))], + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + slice(2, 3), + [(I, (4, 5)), (J, (3, 6)), (K, (4, 7))], ), ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - (I, 2), - [(J, (3, 6)), (K, (4, 7))], + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (I, 2), + [(J, (3, 6)), (K, (4, 7))], ), ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - (I, UnitRange(2, 3)), - [(I, (2, 3)), (J, (3, 6)), (K, (4, 7))], + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (I, UnitRange(2, 3)), + [(I, (2, 3)), (J, (3, 6)), (K, (4, 7))], ), ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - (J, 3), - [(I, (2, 5)), (K, (4, 7))], + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (J, 3), + [(I, (2, 5)), (K, (4, 7))], ), ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - (J, UnitRange(4, 5)), - [(I, (2, 5)), (J, (4, 5)), (K, (4, 7))], + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (J, UnitRange(4, 5)), + [(I, (2, 5)), (J, (4, 5)), (K, (4, 7))], ), ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - ((J, 3), (I, 2)), - [(K, (4, 7))], + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + ((J, 3), (I, 2)), + [(K, (4, 7))], ), ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - ((J, UnitRange(4, 5)), (I, 2)), - [(J, (4, 5)), (K, (4, 7))], + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + ((J, UnitRange(4, 5)), (I, 2)), + [(J, (4, 5)), (K, (4, 7))], ), ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - (slice(1, 2), slice(2, 3)), - [(I, (3, 4)), (J, (5, 6)), (K, (4, 7))], + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (slice(1, 2), slice(2, 3)), + [(I, (3, 4)), (J, (5, 6)), (K, (4, 7))], ), ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - (Ellipsis, slice(2, 3)), - [(I, (2, 5)), (J, (3, 6)), (K, (6, 7))], + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (Ellipsis, slice(2, 3)), + [(I, (2, 5)), (J, (3, 6)), (K, (6, 7))], ), ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - (slice(1, 2), Ellipsis, slice(2, 3)), - [(I, (3, 4)), (J, (3, 6)), (K, (6, 7))], + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (slice(1, 2), Ellipsis, slice(2, 3)), + [(I, (3, 4)), (J, (3, 6)), (K, (6, 7))], ), ( - [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], - (slice(1, 2), slice(1, 2), Ellipsis), - [(I, (3, 4)), (J, (4, 5)), (K, (4, 7))], + [(I, (2, 5)), (J, (3, 6)), (K, (4, 7))], + (slice(1, 2), slice(1, 2), Ellipsis), + [(I, (3, 4)), (J, (4, 5)), (K, (4, 7))], ), ([], Ellipsis, []), ([], slice(None), IndexError), @@ -144,21 +144,16 @@ def test_sub_domain(domain, index, expected): @pytest.fixture def finite_domain(): - I = common.Dimension("I") - J = common.Dimension("J") return common.Domain((I, UnitRange(-1, 3)), (J, UnitRange(2, 4))) @pytest.fixture def infinite_domain(): - I = common.Dimension("I") - return common.Domain((I, UnitRange.infinity())) + return common.Domain((I, UnitRange.infinity()), (J, UnitRange.infinity())) @pytest.fixture def mixed_domain(): - I = common.Dimension("I") - J = common.Dimension("J") return common.Domain((I, UnitRange(-1, 3)), (J, UnitRange.infinity())) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_function_field.py b/tests/next_tests/unit_tests/embedded_tests/test_function_field.py index 32403269ef..6486fc7bfa 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_function_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_function_field.py @@ -21,6 +21,7 @@ from gt4py.next.common import Dimension, UnitRange from gt4py.next.embedded import exceptions as embedded_exceptions, function_field as funcf +from .test_common import mixed_domain, infinite_domain I = Dimension("I") J = Dimension("J") @@ -31,6 +32,21 @@ def rfloordiv(x, y): return operator.floordiv(y, x) +operators = [ + operator.add, + operator.sub, + operator.mul, + operator.truediv, + operator.floordiv, + rfloordiv, + operator.pow, + lambda x, y: operator.truediv(y, x), + operator.add, + operator.mul, + lambda x, y: operator.sub(y, x), +] + + @pytest.mark.parametrize( "op_func, expected_result", [ @@ -54,69 +70,6 @@ def test_constant_field_no_domain(op_func, expected_result): assert result.func() == expected_result -@pytest.mark.parametrize( - "op_func", - [ - operator.add, - operator.sub, - operator.mul, - operator.truediv, - operator.floordiv, - rfloordiv, - operator.pow, - lambda x, y: operator.truediv(y, x), - operator.add, - operator.mul, - lambda x, y: operator.sub(y, x), - ], -) -def test_function_field_no_domain(op_func): - func1 = lambda x, y: x + y - func2 = lambda x, y: 2 * x + y - - domain = common.Domain(*((I, UnitRange(5, 10)), (J, UnitRange(10, 15)))) - - ff1 = funcf.FunctionField(func1, domain) - ff2 = funcf.FunctionField(func2, domain) - - result = op_func(ff1, ff2) - - assert result.func(1, 2) == op_func(func1(1, 2), func2(1, 2)) - assert isinstance(result.ndarray, np.ndarray) - - -@pytest.mark.parametrize( - "op_func", - [ - operator.add, - operator.sub, - operator.mul, - operator.truediv, - operator.floordiv, - rfloordiv, - operator.pow, - lambda x, y: operator.truediv(y, x), - operator.add, - operator.mul, - lambda x, y: operator.sub(y, x), - ], -) -def test_function_field_broadcast(op_func): - func1 = lambda x, y: x + y - func2 = lambda y: 2 * y - - domain1 = common.Domain(dims=(I, J), ranges=(common.UnitRange(1, 10), common.UnitRange(5, 10))) - domain2 = common.Domain(dims=(J,), ranges=(common.UnitRange(7, 15),)) - - ff1 = funcf.FunctionField(func1, domain1) - ff2 = funcf.FunctionField(func2, domain2) - - result = op_func(ff1, ff2) - - assert result.func(5, 10) == op_func(func1(5, 10), func2(10)) - assert isinstance(result.ndarray, np.ndarray) - - @pytest.mark.parametrize( "index", [((I, UnitRange(0, 10)),), common.Domain(dims=(I,), ranges=(UnitRange(0, 10),))] ) @@ -127,31 +80,28 @@ def test_constant_field_getitem_missing_domain(index): @pytest.mark.parametrize( - "domain,expected_shape", + "domain", [ - (common.Domain(dims=(I, J), ranges=(UnitRange(3, 13), UnitRange(-5, 5))), (10, 10)), - ( - common.Domain( - dims=(I, J, K), - ranges=(UnitRange(-6, -3), UnitRange(-5, 10), UnitRange(1, 2)), - ), - (3, 15, 1), + common.Domain(dims=(I, J), ranges=(UnitRange(3, 13), UnitRange(-5, 5))), + common.Domain( + dims=(I, J, K), ranges=(UnitRange(-6, -3), UnitRange(-5, 10), UnitRange(1, 2)) ), ], ) -def test_constant_field_ndarray(domain, expected_shape): +def test_constant_field_ndarray(domain): cf = funcf.constant_field(10, domain) assert isinstance(cf.ndarray, int) assert cf.ndarray == 10 -def test_constant_field_empty_domain_op(): +def test_constant_field_and_field_op(): domain = common.Domain(dims=(I, J), ranges=(UnitRange(3, 13), UnitRange(-5, 5))) field = common.field(np.ones((10, 10)), domain=domain) cf = funcf.constant_field(10) result = cf + field assert np.allclose(result.ndarray, 11) + assert result.domain == domain binary_op_field_intersection_cases = [ @@ -193,6 +143,26 @@ def adder(i, j): return i + j +@pytest.mark.parametrize( + "op_func", + operators, +) +def test_function_field_broadcast(op_func): + func1 = lambda x, y: x + y + func2 = lambda y: 2 * y + + domain1 = common.Domain(dims=(I, J), ranges=(common.UnitRange(1, 10), common.UnitRange(5, 10))) + domain2 = common.Domain(dims=(J,), ranges=(common.UnitRange(7, 15),)) + + ff1 = funcf.FunctionField(func1, domain1) + ff2 = funcf.FunctionField(func2, domain2) + + result = op_func(ff1, ff2) + + assert result.func(5, 10) == op_func(func1(5, 10), func2(10)) + assert isinstance(result.ndarray, np.ndarray) + + @pytest.mark.parametrize( "domain,expected_shape", [ @@ -225,13 +195,8 @@ def test_function_field_with_field(domain): assert result.ndarray.shape == (10, 10) assert np.allclose(result.ndarray, expected_values) -@pytest.fixture -def function_field(): - return funcf.FunctionField(adder, domain=common.Domain( - dims=(I, J), ranges=(common.UnitRange(1, 10), common.UnitRange(5, 10)) - )) -def test_function_field_addition(): +def test_function_field_function_field_op(): res = funcf.FunctionField( lambda x, y: x + 42 * y, domain=common.Domain( @@ -244,8 +209,17 @@ def test_function_field_addition(): assert res.func(1, 2) == 89 -def test_function_field_unary(function_field): +@pytest.fixture +def function_field(): + return funcf.FunctionField( + adder, + domain=common.Domain( + dims=(I, J), ranges=(common.UnitRange(1, 10), common.UnitRange(5, 10)) + ), + ) + +def test_function_field_unary(function_field): pos_result = +function_field assert pos_result.func(1, 2) == 3 @@ -262,3 +236,29 @@ def test_function_field_unary(function_field): def test_function_field_scalar_op(function_field): new = function_field * 5.0 assert new.func(1, 2) == 15 + + +@pytest.mark.parametrize("func", ["foo", 1.0, 1]) +def test_function_field_invalid_func(func): + with pytest.raises(embedded_exceptions.FunctionFieldError, match="Invalid first argument type"): + funcf.FunctionField(func) + + +@pytest.mark.parametrize( + "domain", + [ + common.Domain(), + common.Domain(*((I, UnitRange(1, 10)), (J, UnitRange(5, 10)))), + ], +) +def test_function_field_invalid_invariant(domain): + with pytest.raises(embedded_exceptions.FunctionFieldError, match="Invariant violation"): + funcf.FunctionField(lambda x: x, domain) + + +def test_function_field_infinite_range(infinite_domain, mixed_domain): + domains = [infinite_domain, mixed_domain] + for d in domains: + with pytest.raises(embedded_exceptions.InfiniteRangeNdarrayError): + ff = funcf.FunctionField(adder, d) + ff.ndarray