Skip to content

Commit

Permalink
Implement remaining methods
Browse files Browse the repository at this point in the history
  • Loading branch information
samkellerhals committed Sep 6, 2023
1 parent 37abfeb commit 08b40d0
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 30 deletions.
51 changes: 26 additions & 25 deletions src/gt4py/next/embedded/function_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,30 @@ def __gt__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
def __ge__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
return self._binary_operation(operator.ge, other)

def __and__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
return self._binary_operation(operator.and_, other)

def __or__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
return self._binary_operation(operator.or_, other)

def __xor__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
return self._binary_operation(operator.xor, other)

def __radd__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
return self._binary_operation(lambda x, y: y + x, other)

def __rfloordiv__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
return self._binary_operation(lambda x, y: y // x, other)

def __rmul__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
return self._binary_operation(lambda x, y: y * x, other)

def __rsub__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
return self._binary_operation(lambda x, y: y - x, other)

def __rtruediv__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
return self._binary_operation(lambda x, y: y / x, other)

def __pos__(self) -> common.Field:
return self._unary_op(operator.pos)

Expand All @@ -177,37 +201,14 @@ def __invert__(self) -> common.Field:
def __abs__(self) -> common.Field:
return self._unary_op(abs)

def __and__(self, other) -> common.Field:
raise NotImplementedError("Method __and__ not implemented")

def __call__(self, *args, **kwargs) -> common.Field:
raise NotImplementedError("Method __call__ not implemented")

def __or__(self, other) -> common.Field:
raise NotImplementedError("Method __or__ not implemented")

def __radd__(self, other) -> common.Field:
raise NotImplementedError("Method __radd__ not implemented")

def __rfloordiv__(self, other) -> common.Field:
raise NotImplementedError("Method __rfloordiv__ not implemented")

def __rmul__(self, other) -> common.Field:
raise NotImplementedError("Method __rmul__ not implemented")

def __rsub__(self, other) -> common.Field:
raise NotImplementedError("Method __rsub__ not implemented")

def __rtruediv__(self, other) -> common.Field:
raise NotImplementedError("Method __rtruediv__ not implemented")

def __xor__(self, other) -> common.Field:
raise NotImplementedError("Method __xor__ not implemented")
return self.func(*args, **kwargs)

def remap(self, *args, **kwargs) -> common.Field:
raise NotImplementedError("Method remap not implemented")



def _compose(operation: Callable, *fields: FunctionField) -> Callable:
return lambda *args: operation(*[f.func(*args) for f in fields])

Expand Down
36 changes: 31 additions & 5 deletions tests/next_tests/unit_tests/embedded_tests/test_function_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,19 @@ def rfloordiv(x, y):
operator.mul,
operator.truediv,
operator.floordiv,
rfloordiv,
operator.pow,
lambda x, y: operator.truediv(y, x),
operator.add,
operator.mul,
lambda x, y: operator.sub(y, x),
operator.pow,
lambda x, y: operator.truediv(y, x), # Reverse true division
lambda x, y: operator.add(y, x), # Reverse addition
lambda x, y: operator.mul(y, x), # Reverse multiplication
lambda x, y: operator.sub(y, x), # Reverse subtraction
lambda x, y: operator.floordiv(y, x), # Reverse floor division
]

logical_operators = [
operator.xor,
operator.and_,
operator.or_,
]


Expand Down Expand Up @@ -163,6 +170,25 @@ def test_function_field_broadcast(op_func):
assert isinstance(result.ndarray, np.ndarray)


@pytest.mark.parametrize("op_func", logical_operators)
def test_function_field_logical_operators(op_func):
func1 = lambda x, y: x > 5
func2 = lambda y: y < 10

domain1 = common.Domain(
dims=(I, J), ranges=(common.UnitRange(1, 10), common.UnitRange(5, 10))
)
domain2 = common.Domain(dims=(J,), ranges=(common.UnitRange(7, 15),))

ff1 = funcf.FunctionField(func1, domain1)
ff2 = funcf.FunctionField(func2, domain2)

result = op_func(ff1, ff2)

assert result.func(5, 10) == op_func(func1(5, 10), func2(10))
assert isinstance(result.ndarray, np.ndarray)


@pytest.mark.parametrize(
"domain,expected_shape",
[
Expand Down

0 comments on commit 08b40d0

Please sign in to comment.