diff --git a/src/gt4py/next/embedded/function_field.py b/src/gt4py/next/embedded/function_field.py index ee9bb6f40a..b9236c14ed 100644 --- a/src/gt4py/next/embedded/function_field.py +++ b/src/gt4py/next/embedded/function_field.py @@ -34,6 +34,18 @@ class FunctionField(common.Field[common.DimsT, core_defs.ScalarT], common.FieldBuiltinFuncRegistry): func: Callable domain: common.Domain = common.Domain() + _skip_invariant: bool = False + + def __post_init__(self): + if not self._skip_invariant: + num_params = len(self.domain) + try: + func_params = self.func.__code__.co_argcount + if func_params != num_params: + raise ValueError( + f"Invariant violation: len(self.domain) ({num_params}) does not match the number of parameters of the self.func ({func_params}).") + except AttributeError: + raise ValueError(f"Must pass a function as an argument to self.func.") def restrict(self, index: common.AnyIndexSpec) -> FunctionField: new_domain = embedded_common.sub_domain(self.domain, index) @@ -43,9 +55,6 @@ def restrict(self, index: common.AnyIndexSpec) -> FunctionField: @property def ndarray(self) -> core_defs.NDArrayObject: - if _has_empty_domain(self): - raise embedded_exceptions.InvalidDomainForNdarrayError(self.__class__.__name__) - if not self.domain.is_finite(): embedded_exceptions.InfiniteRangeNdarrayError(self.__class__.__name__, self.domain) @@ -60,12 +69,12 @@ def _handle_function_field_op( broadcasted_self = _broadcast(self, domain_intersection.dims) broadcasted_other = _broadcast(other, domain_intersection.dims) return self.__class__( - _compose(op, broadcasted_self, broadcasted_other), domain_intersection + _compose(op, broadcasted_self, broadcasted_other), domain_intersection, _skip_invariant=True ) def _handle_scalar_op(self, other: FunctionField, op: Callable) -> FunctionField: new_func = lambda *args: op(self.func(*args), other) - return self.__class__(new_func, self.domain) + return self.__class__(new_func, self.domain, _skip_invariant=True) # skip invariant as we cannot deduce number of args @overload def _binary_operation(self, op: Callable, other: core_defs.ScalarT) -> common.Field: @@ -83,6 +92,9 @@ def _binary_operation(self, op, other): else: return op(other, self) + def _unary_op(self, op: Callable) -> FunctionField: + return self.__class__(_compose(op, self), self.domain, _skip_invariant=True) + def __add__(self, other: common.Field | core_defs.ScalarT) -> common.Field: return self._binary_operation(operator.add, other) @@ -117,16 +129,16 @@ def __ge__(self, other: common.Field | core_defs.ScalarT) -> common.Field: return self._binary_operation(operator.ge, other) def __pos__(self) -> common.Field: - return self.__class__(_compose(operator.pos, self), self.domain) + return self._unary_op(operator.pos) def __neg__(self) -> common.Field: - return self.__class__(_compose(operator.neg, self), self.domain) + return self._unary_op(operator.neg) def __invert__(self) -> common.Field: - return self.__class__(_compose(operator.invert, self), self.domain) + return self._unary_op(operator.invert) def __abs__(self) -> common.Field: - return self.__class__(_compose(abs, self), self.domain) + return self._unary_op(abs) def __and__(self, other) -> common.Field: raise NotImplementedError("Method __and__ not implemented") @@ -165,27 +177,21 @@ def _compose(operation: Callable, *fields: FunctionField) -> Callable: def _broadcast(field: FunctionField, dims: tuple[common.Dimension, ...]) -> FunctionField: def broadcasted_func(*args: int): - if not _has_empty_domain(field): - selected_args = [args[i] for i, dim in enumerate(dims) if dim in field.domain.dims] - return field.func(*selected_args) - return field.func(*args) + selected_args = [args[i] for i, dim in enumerate(dims) if dim in field.domain.dims] + return field.func(*selected_args) named_ranges = embedded_common._compute_named_ranges(field, dims) - return FunctionField(broadcasted_func, common.Domain(*named_ranges)) + return FunctionField(broadcasted_func, common.Domain(*named_ranges), _skip_invariant=True) def _is_nd_array(other: Any) -> TypeGuard[nd._BaseNdArrayField]: return isinstance(other, nd._BaseNdArrayField) -def _has_empty_domain(field: common.Field) -> bool: - return len(field.domain) < 1 - - def constant_field( value: core_defs.ScalarT, domain: common.Domain = common.Domain() ) -> common.Field: - return FunctionField(lambda *args: value, domain) + return FunctionField(lambda *args: value, domain, _skip_invariant=True) FunctionField.register_builtin_func(fbuiltins.broadcast, _broadcast) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_function_field.py b/tests/next_tests/unit_tests/embedded_tests/test_function_field.py index f16088a31c..32403269ef 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_function_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_function_field.py @@ -52,8 +52,6 @@ def test_constant_field_no_domain(op_func, expected_result): cf2 = funcf.constant_field(20) result = op_func(cf1, cf2) assert result.func() == expected_result - with pytest.raises(embedded_exceptions.InvalidDomainForNdarrayError): - result.ndarray @pytest.mark.parametrize( @@ -76,14 +74,15 @@ def test_function_field_no_domain(op_func): func1 = lambda x, y: x + y func2 = lambda x, y: 2 * x + y - ff1 = funcf.FunctionField(func1) - ff2 = funcf.FunctionField(func2) + domain = common.Domain(*((I, UnitRange(5, 10)), (J, UnitRange(10, 15)))) + + ff1 = funcf.FunctionField(func1, domain) + ff2 = funcf.FunctionField(func2, domain) result = op_func(ff1, ff2) assert result.func(1, 2) == op_func(func1(1, 2), func2(1, 2)) - with pytest.raises(embedded_exceptions.InvalidDomainForNdarrayError): - result.ndarray + assert isinstance(result.ndarray, np.ndarray) @pytest.mark.parametrize( @@ -151,8 +150,8 @@ def test_constant_field_empty_domain_op(): field = common.field(np.ones((10, 10)), domain=domain) cf = funcf.constant_field(10) - with pytest.raises(embedded_exceptions.InvalidDomainForNdarrayError): - cf + field + result = cf + field + assert np.allclose(result.ndarray, 11) binary_op_field_intersection_cases = [ @@ -190,7 +189,7 @@ def test_constant_field_non_empty_domain_op( assert np.all(result.ndarray == expected_value) -def adder(i, j, k=None): +def adder(i, j): return i + j @@ -198,13 +197,6 @@ def adder(i, j, k=None): "domain,expected_shape", [ (common.Domain(dims=(I, J), ranges=(UnitRange(3, 13), UnitRange(-5, 5))), (10, 10)), - ( - common.Domain( - dims=(I, J, K), - ranges=(UnitRange(-6, -3), UnitRange(-5, 10), UnitRange(1, 2)), - ), - (3, 15, 1), - ), ], ) def test_function_field_ndarray(domain, expected_shape): @@ -233,6 +225,11 @@ def test_function_field_with_field(domain): assert result.ndarray.shape == (10, 10) assert np.allclose(result.ndarray, expected_values) +@pytest.fixture +def function_field(): + return funcf.FunctionField(adder, domain=common.Domain( + dims=(I, J), ranges=(common.UnitRange(1, 10), common.UnitRange(5, 10)) + )) def test_function_field_addition(): res = funcf.FunctionField( @@ -247,23 +244,21 @@ def test_function_field_addition(): assert res.func(1, 2) == 89 -def test_function_field_unary(): - ff = funcf.FunctionField(adder) +def test_function_field_unary(function_field): - pos_result = +ff + pos_result = +function_field assert pos_result.func(1, 2) == 3 - neg_result = -ff + neg_result = -function_field assert neg_result.func(1, 2) == -3 - invert_result = ~ff + invert_result = ~function_field assert invert_result.func(1, 2) == -4 - abs_result = abs(ff) + abs_result = abs(function_field) assert abs_result.func(1, 2) == 3 -def test_function_field_scalar_op(): - ff = funcf.FunctionField(adder) - new = ff * 5.0 - assert new.func(1, 2) == 15 \ No newline at end of file +def test_function_field_scalar_op(function_field): + new = function_field * 5.0 + assert new.func(1, 2) == 15