diff --git a/src/gt4py/next/embedded/function_field.py b/src/gt4py/next/embedded/function_field.py index 8078144a23..938f9b4c63 100644 --- a/src/gt4py/next/embedded/function_field.py +++ b/src/gt4py/next/embedded/function_field.py @@ -165,6 +165,30 @@ def __gt__(self, other: common.Field | core_defs.ScalarT) -> common.Field: def __ge__(self, other: common.Field | core_defs.ScalarT) -> common.Field: return self._binary_operation(operator.ge, other) + def __and__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(operator.and_, other) + + def __or__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(operator.or_, other) + + def __xor__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(operator.xor, other) + + def __radd__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(lambda x, y: y + x, other) + + def __rfloordiv__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(lambda x, y: y // x, other) + + def __rmul__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(lambda x, y: y * x, other) + + def __rsub__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(lambda x, y: y - x, other) + + def __rtruediv__(self, other: common.Field | core_defs.ScalarT) -> common.Field: + return self._binary_operation(lambda x, y: y / x, other) + def __pos__(self) -> common.Field: return self._unary_op(operator.pos) @@ -177,37 +201,14 @@ def __invert__(self) -> common.Field: def __abs__(self) -> common.Field: return self._unary_op(abs) - def __and__(self, other) -> common.Field: - raise NotImplementedError("Method __and__ not implemented") - def __call__(self, *args, **kwargs) -> common.Field: - raise NotImplementedError("Method __call__ not implemented") - - def __or__(self, other) -> common.Field: - raise NotImplementedError("Method __or__ not implemented") - - def __radd__(self, other) -> common.Field: - raise NotImplementedError("Method __radd__ not implemented") - - def __rfloordiv__(self, other) -> common.Field: - raise NotImplementedError("Method __rfloordiv__ not implemented") - - def __rmul__(self, other) -> common.Field: - raise NotImplementedError("Method __rmul__ not implemented") - - def __rsub__(self, other) -> common.Field: - raise NotImplementedError("Method __rsub__ not implemented") - - def __rtruediv__(self, other) -> common.Field: - raise NotImplementedError("Method __rtruediv__ not implemented") - - def __xor__(self, other) -> common.Field: - raise NotImplementedError("Method __xor__ not implemented") + return self.func(*args, **kwargs) def remap(self, *args, **kwargs) -> common.Field: raise NotImplementedError("Method remap not implemented") + def _compose(operation: Callable, *fields: FunctionField) -> Callable: return lambda *args: operation(*[f.func(*args) for f in fields]) diff --git a/tests/next_tests/unit_tests/embedded_tests/test_function_field.py b/tests/next_tests/unit_tests/embedded_tests/test_function_field.py index 6486fc7bfa..922a87dcf6 100644 --- a/tests/next_tests/unit_tests/embedded_tests/test_function_field.py +++ b/tests/next_tests/unit_tests/embedded_tests/test_function_field.py @@ -38,12 +38,19 @@ def rfloordiv(x, y): operator.mul, operator.truediv, operator.floordiv, - rfloordiv, - operator.pow, lambda x, y: operator.truediv(y, x), - operator.add, - operator.mul, - lambda x, y: operator.sub(y, x), + operator.pow, + lambda x, y: operator.truediv(y, x), # Reverse true division + lambda x, y: operator.add(y, x), # Reverse addition + lambda x, y: operator.mul(y, x), # Reverse multiplication + lambda x, y: operator.sub(y, x), # Reverse subtraction + lambda x, y: operator.floordiv(y, x), # Reverse floor division +] + +logical_operators = [ + operator.xor, + operator.and_, + operator.or_, ] @@ -163,6 +170,25 @@ def test_function_field_broadcast(op_func): assert isinstance(result.ndarray, np.ndarray) +@pytest.mark.parametrize("op_func", logical_operators) +def test_function_field_logical_operators(op_func): + func1 = lambda x, y: x > 5 + func2 = lambda y: y < 10 + + domain1 = common.Domain( + dims=(I, J), ranges=(common.UnitRange(1, 10), common.UnitRange(5, 10)) + ) + domain2 = common.Domain(dims=(J,), ranges=(common.UnitRange(7, 15),)) + + ff1 = funcf.FunctionField(func1, domain1) + ff2 = funcf.FunctionField(func2, domain2) + + result = op_func(ff1, ff2) + + assert result.func(5, 10) == op_func(func1(5, 10), func2(10)) + assert isinstance(result.ndarray, np.ndarray) + + @pytest.mark.parametrize( "domain,expected_shape", [