Skip to content

Commit

Permalink
Use custom exceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
samkellerhals committed Sep 6, 2023
1 parent 7cefa7e commit c5a026d
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 152 deletions.
16 changes: 1 addition & 15 deletions src/gt4py/next/embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

from typing import Any, Optional, Sequence, cast

import numpy as np

from gt4py.next import common
from gt4py.next.embedded import exceptions as embedded_exceptions

Expand Down Expand Up @@ -135,19 +133,7 @@ def _find_index_of_dim(
return None


def _compute_domain_slice(
field: common.Field, new_dimensions: tuple[common.Dimension, ...]
) -> Sequence[slice | None]:
domain_slice: list[slice | None] = []
for dim in new_dimensions:
if _find_index_of_dim(dim, field.domain) is not None:
domain_slice.append(slice(None))
else:
domain_slice.append(np.newaxis)
return domain_slice


def _compute_named_ranges(
def _broadcast_domain(
field: common.Field, new_dimensions: tuple[common.Dimension, ...]
) -> Sequence[common.NamedRange]:
named_ranges = []
Expand Down
9 changes: 4 additions & 5 deletions src/gt4py/next/embedded/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,11 @@ def __init__(self, cls_name: str):
self.cls_name = cls_name


class InvalidDomainForNdarrayError(gt4py_exceptions.GT4PyError):
def __init__(self, cls_name: str):
super().__init__(
f"Error in `{cls_name}`: Cannot construct an ndarray with an empty domain."
)
class FunctionFieldError(gt4py_exceptions.GT4PyError):
def __init__(self, cls_name: str, msg: str):
super().__init__(f"Error in `{cls_name}`: {msg}.")
self.cls_name = cls_name
self.msg = msg


class InfiniteRangeNdarrayError(gt4py_exceptions.GT4PyError):
Expand Down
19 changes: 12 additions & 7 deletions src/gt4py/next/embedded/function_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,15 @@ def __post_init__(self):
try:
func_params = self.func.__code__.co_argcount
if func_params != num_params:
raise ValueError(
f"Invariant violation: len(self.domain) ({num_params}) does not match the number of parameters of the self.func ({func_params})."
raise embedded_exceptions.FunctionFieldError(
self.__class__.__name__,
f"Invariant violation: len(self.domain) ({num_params}) does not match the number of parameters of the provided function ({func_params})",
)
except AttributeError:
raise ValueError(f"Must pass a function as an argument to self.func.")
raise embedded_exceptions.FunctionFieldError(
self.__class__.__name__,
f"Invalid first argument type: Expected a function but got {self.func}",
)

def restrict(self, index: common.AnyIndexSpec) -> FunctionField:
new_domain = embedded_common.sub_domain(self.domain, index)
Expand All @@ -57,10 +61,11 @@ def restrict(self, index: common.AnyIndexSpec) -> FunctionField:
@property
def ndarray(self) -> core_defs.NDArrayObject:
if not self.domain.is_finite():
embedded_exceptions.InfiniteRangeNdarrayError(self.__class__.__name__, self.domain)
raise embedded_exceptions.InfiniteRangeNdarrayError(
self.__class__.__name__, self.domain
)

shape = [len(rng) for rng in self.domain.ranges]

return np.fromfunction(self.func, shape)

def _handle_function_field_op(self, other: FunctionField, op: Callable) -> FunctionField:
Expand Down Expand Up @@ -90,7 +95,7 @@ def _binary_operation(self, op: Callable, other: common.Field) -> common.Field:
def _binary_operation(self, op, other):
if isinstance(other, self.__class__):
return self._handle_function_field_op(other, op)
elif isinstance(other, (int, float)): # Handle scalar values
elif isinstance(other, (int, float)):
return self._handle_scalar_op(other, op)
else:
return op(other, self)
Expand Down Expand Up @@ -183,7 +188,7 @@ def broadcasted_func(*args: int):
selected_args = [args[i] for i, dim in enumerate(dims) if dim in field.domain.dims]
return field.func(*selected_args)

named_ranges = embedded_common._compute_named_ranges(field, dims)
named_ranges = embedded_common._broadcast_domain(field, dims)
return FunctionField(broadcasted_func, common.Domain(*named_ranges), _skip_invariant=True)


Expand Down
18 changes: 15 additions & 3 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
import dataclasses
from collections.abc import Callable, Sequence
from types import ModuleType
from typing import Any, ClassVar, Optional, ParamSpec, TypeAlias, TypeVar
from typing import Any, ClassVar, Optional, ParamSpec, TypeAlias, TypeVar, Sequence

import numpy as np
from numpy import typing as npt

from gt4py._core import definitions as core_defs
from gt4py.next import common
from gt4py.next.embedded import common as embedded_common
from gt4py.next.embedded.common import _compute_domain_slice, _compute_named_ranges
from gt4py.next.embedded.common import _broadcast_domain, _find_index_of_dim
from gt4py.next.ffront import fbuiltins


Expand Down Expand Up @@ -335,7 +335,7 @@ def __setitem__(

def _broadcast(field: common.Field, new_dimensions: tuple[common.Dimension, ...]) -> common.Field:
domain_slice = _compute_domain_slice(field, new_dimensions)
named_ranges = _compute_named_ranges(field, new_dimensions)
named_ranges = _broadcast_domain(field, new_dimensions)

# handle case where we have a constant FunctionField where ndarray is a scalar
if isinstance(value := field.ndarray, (int, float)):
Expand Down Expand Up @@ -415,3 +415,15 @@ def _compute_slice(
return rng - domain.ranges[pos].start
else:
raise ValueError(f"Can only use integer or UnitRange ranges, provided type: {type(rng)}")


def _compute_domain_slice(
field: common.Field, new_dimensions: tuple[common.Dimension, ...]
) -> Sequence[slice | None]:
domain_slice: list[slice | None] = []
for dim in new_dimensions:
if _find_index_of_dim(dim, field.domain) is not None:
domain_slice.append(slice(None))
else:
domain_slice.append(np.newaxis)
return domain_slice
Empty file.
79 changes: 37 additions & 42 deletions tests/next_tests/unit_tests/embedded_tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,64 +65,64 @@ def test_slice_range(rng, slce, expected):
([(I, (-2, 3))], (I, 3), IndexError),
([(I, (-2, 3))], (I, UnitRange(3, 4)), IndexError),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
2,
[(J, (3, 6)), (K, (4, 7))],
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
2,
[(J, (3, 6)), (K, (4, 7))],
),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
slice(2, 3),
[(I, (4, 5)), (J, (3, 6)), (K, (4, 7))],
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
slice(2, 3),
[(I, (4, 5)), (J, (3, 6)), (K, (4, 7))],
),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(I, 2),
[(J, (3, 6)), (K, (4, 7))],
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(I, 2),
[(J, (3, 6)), (K, (4, 7))],
),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(I, UnitRange(2, 3)),
[(I, (2, 3)), (J, (3, 6)), (K, (4, 7))],
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(I, UnitRange(2, 3)),
[(I, (2, 3)), (J, (3, 6)), (K, (4, 7))],
),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(J, 3),
[(I, (2, 5)), (K, (4, 7))],
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(J, 3),
[(I, (2, 5)), (K, (4, 7))],
),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(J, UnitRange(4, 5)),
[(I, (2, 5)), (J, (4, 5)), (K, (4, 7))],
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(J, UnitRange(4, 5)),
[(I, (2, 5)), (J, (4, 5)), (K, (4, 7))],
),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
((J, 3), (I, 2)),
[(K, (4, 7))],
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
((J, 3), (I, 2)),
[(K, (4, 7))],
),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
((J, UnitRange(4, 5)), (I, 2)),
[(J, (4, 5)), (K, (4, 7))],
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
((J, UnitRange(4, 5)), (I, 2)),
[(J, (4, 5)), (K, (4, 7))],
),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(slice(1, 2), slice(2, 3)),
[(I, (3, 4)), (J, (5, 6)), (K, (4, 7))],
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(slice(1, 2), slice(2, 3)),
[(I, (3, 4)), (J, (5, 6)), (K, (4, 7))],
),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(Ellipsis, slice(2, 3)),
[(I, (2, 5)), (J, (3, 6)), (K, (6, 7))],
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(Ellipsis, slice(2, 3)),
[(I, (2, 5)), (J, (3, 6)), (K, (6, 7))],
),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(slice(1, 2), Ellipsis, slice(2, 3)),
[(I, (3, 4)), (J, (3, 6)), (K, (6, 7))],
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(slice(1, 2), Ellipsis, slice(2, 3)),
[(I, (3, 4)), (J, (3, 6)), (K, (6, 7))],
),
(
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(slice(1, 2), slice(1, 2), Ellipsis),
[(I, (3, 4)), (J, (4, 5)), (K, (4, 7))],
[(I, (2, 5)), (J, (3, 6)), (K, (4, 7))],
(slice(1, 2), slice(1, 2), Ellipsis),
[(I, (3, 4)), (J, (4, 5)), (K, (4, 7))],
),
([], Ellipsis, []),
([], slice(None), IndexError),
Expand All @@ -144,21 +144,16 @@ def test_sub_domain(domain, index, expected):

@pytest.fixture
def finite_domain():
I = common.Dimension("I")
J = common.Dimension("J")
return common.Domain((I, UnitRange(-1, 3)), (J, UnitRange(2, 4)))


@pytest.fixture
def infinite_domain():
I = common.Dimension("I")
return common.Domain((I, UnitRange.infinity()))
return common.Domain((I, UnitRange.infinity()), (J, UnitRange.infinity()))


@pytest.fixture
def mixed_domain():
I = common.Dimension("I")
J = common.Dimension("J")
return common.Domain((I, UnitRange(-1, 3)), (J, UnitRange.infinity()))


Expand Down
Loading

0 comments on commit c5a026d

Please sign in to comment.