Skip to content

Commit

Permalink
Add invariant
Browse files Browse the repository at this point in the history
  • Loading branch information
samkellerhals committed Sep 6, 2023
1 parent b29c355 commit a6c4e14
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 45 deletions.
44 changes: 25 additions & 19 deletions src/gt4py/next/embedded/function_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@
class FunctionField(common.Field[common.DimsT, core_defs.ScalarT], common.FieldBuiltinFuncRegistry):
func: Callable
domain: common.Domain = common.Domain()
_skip_invariant: bool = False

def __post_init__(self):
if not self._skip_invariant:
num_params = len(self.domain)
try:
func_params = self.func.__code__.co_argcount
if func_params != num_params:
raise ValueError(
f"Invariant violation: len(self.domain) ({num_params}) does not match the number of parameters of the self.func ({func_params}).")
except AttributeError:
raise ValueError(f"Must pass a function as an argument to self.func.")

def restrict(self, index: common.AnyIndexSpec) -> FunctionField:
new_domain = embedded_common.sub_domain(self.domain, index)
Expand All @@ -43,9 +55,6 @@ def restrict(self, index: common.AnyIndexSpec) -> FunctionField:

@property
def ndarray(self) -> core_defs.NDArrayObject:
if _has_empty_domain(self):
raise embedded_exceptions.InvalidDomainForNdarrayError(self.__class__.__name__)

if not self.domain.is_finite():
embedded_exceptions.InfiniteRangeNdarrayError(self.__class__.__name__, self.domain)

Expand All @@ -60,12 +69,12 @@ def _handle_function_field_op(
broadcasted_self = _broadcast(self, domain_intersection.dims)
broadcasted_other = _broadcast(other, domain_intersection.dims)
return self.__class__(
_compose(op, broadcasted_self, broadcasted_other), domain_intersection
_compose(op, broadcasted_self, broadcasted_other), domain_intersection, _skip_invariant=True
)

def _handle_scalar_op(self, other: FunctionField, op: Callable) -> FunctionField:
new_func = lambda *args: op(self.func(*args), other)
return self.__class__(new_func, self.domain)
return self.__class__(new_func, self.domain, _skip_invariant=True) # skip invariant as we cannot deduce number of args

@overload
def _binary_operation(self, op: Callable, other: core_defs.ScalarT) -> common.Field:
Expand All @@ -83,6 +92,9 @@ def _binary_operation(self, op, other):
else:
return op(other, self)

def _unary_op(self, op: Callable) -> FunctionField:
return self.__class__(_compose(op, self), self.domain, _skip_invariant=True)

def __add__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
return self._binary_operation(operator.add, other)

Expand Down Expand Up @@ -117,16 +129,16 @@ def __ge__(self, other: common.Field | core_defs.ScalarT) -> common.Field:
return self._binary_operation(operator.ge, other)

def __pos__(self) -> common.Field:
return self.__class__(_compose(operator.pos, self), self.domain)
return self._unary_op(operator.pos)

def __neg__(self) -> common.Field:
return self.__class__(_compose(operator.neg, self), self.domain)
return self._unary_op(operator.neg)

def __invert__(self) -> common.Field:
return self.__class__(_compose(operator.invert, self), self.domain)
return self._unary_op(operator.invert)

def __abs__(self) -> common.Field:
return self.__class__(_compose(abs, self), self.domain)
return self._unary_op(abs)

def __and__(self, other) -> common.Field:
raise NotImplementedError("Method __and__ not implemented")
Expand Down Expand Up @@ -165,27 +177,21 @@ def _compose(operation: Callable, *fields: FunctionField) -> Callable:

def _broadcast(field: FunctionField, dims: tuple[common.Dimension, ...]) -> FunctionField:
def broadcasted_func(*args: int):
if not _has_empty_domain(field):
selected_args = [args[i] for i, dim in enumerate(dims) if dim in field.domain.dims]
return field.func(*selected_args)
return field.func(*args)
selected_args = [args[i] for i, dim in enumerate(dims) if dim in field.domain.dims]
return field.func(*selected_args)

named_ranges = embedded_common._compute_named_ranges(field, dims)
return FunctionField(broadcasted_func, common.Domain(*named_ranges))
return FunctionField(broadcasted_func, common.Domain(*named_ranges), _skip_invariant=True)


def _is_nd_array(other: Any) -> TypeGuard[nd._BaseNdArrayField]:
return isinstance(other, nd._BaseNdArrayField)


def _has_empty_domain(field: common.Field) -> bool:
return len(field.domain) < 1


def constant_field(
value: core_defs.ScalarT, domain: common.Domain = common.Domain()
) -> common.Field:
return FunctionField(lambda *args: value, domain)
return FunctionField(lambda *args: value, domain, _skip_invariant=True)


FunctionField.register_builtin_func(fbuiltins.broadcast, _broadcast)
47 changes: 21 additions & 26 deletions tests/next_tests/unit_tests/embedded_tests/test_function_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ def test_constant_field_no_domain(op_func, expected_result):
cf2 = funcf.constant_field(20)
result = op_func(cf1, cf2)
assert result.func() == expected_result
with pytest.raises(embedded_exceptions.InvalidDomainForNdarrayError):
result.ndarray


@pytest.mark.parametrize(
Expand All @@ -76,14 +74,15 @@ def test_function_field_no_domain(op_func):
func1 = lambda x, y: x + y
func2 = lambda x, y: 2 * x + y

ff1 = funcf.FunctionField(func1)
ff2 = funcf.FunctionField(func2)
domain = common.Domain(*((I, UnitRange(5, 10)), (J, UnitRange(10, 15))))

ff1 = funcf.FunctionField(func1, domain)
ff2 = funcf.FunctionField(func2, domain)

result = op_func(ff1, ff2)

assert result.func(1, 2) == op_func(func1(1, 2), func2(1, 2))
with pytest.raises(embedded_exceptions.InvalidDomainForNdarrayError):
result.ndarray
assert isinstance(result.ndarray, np.ndarray)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -151,8 +150,8 @@ def test_constant_field_empty_domain_op():
field = common.field(np.ones((10, 10)), domain=domain)
cf = funcf.constant_field(10)

with pytest.raises(embedded_exceptions.InvalidDomainForNdarrayError):
cf + field
result = cf + field
assert np.allclose(result.ndarray, 11)


binary_op_field_intersection_cases = [
Expand Down Expand Up @@ -190,21 +189,14 @@ def test_constant_field_non_empty_domain_op(
assert np.all(result.ndarray == expected_value)


def adder(i, j, k=None):
def adder(i, j):
return i + j


@pytest.mark.parametrize(
"domain,expected_shape",
[
(common.Domain(dims=(I, J), ranges=(UnitRange(3, 13), UnitRange(-5, 5))), (10, 10)),
(
common.Domain(
dims=(I, J, K),
ranges=(UnitRange(-6, -3), UnitRange(-5, 10), UnitRange(1, 2)),
),
(3, 15, 1),
),
],
)
def test_function_field_ndarray(domain, expected_shape):
Expand Down Expand Up @@ -233,6 +225,11 @@ def test_function_field_with_field(domain):
assert result.ndarray.shape == (10, 10)
assert np.allclose(result.ndarray, expected_values)

@pytest.fixture
def function_field():
return funcf.FunctionField(adder, domain=common.Domain(
dims=(I, J), ranges=(common.UnitRange(1, 10), common.UnitRange(5, 10))
))

def test_function_field_addition():
res = funcf.FunctionField(
Expand All @@ -247,23 +244,21 @@ def test_function_field_addition():
assert res.func(1, 2) == 89


def test_function_field_unary():
ff = funcf.FunctionField(adder)
def test_function_field_unary(function_field):

pos_result = +ff
pos_result = +function_field
assert pos_result.func(1, 2) == 3

neg_result = -ff
neg_result = -function_field
assert neg_result.func(1, 2) == -3

invert_result = ~ff
invert_result = ~function_field
assert invert_result.func(1, 2) == -4

abs_result = abs(ff)
abs_result = abs(function_field)
assert abs_result.func(1, 2) == 3


def test_function_field_scalar_op():
ff = funcf.FunctionField(adder)
new = ff * 5.0
assert new.func(1, 2) == 15
def test_function_field_scalar_op(function_field):
new = function_field * 5.0
assert new.func(1, 2) == 15

0 comments on commit a6c4e14

Please sign in to comment.