Skip to content

Commit

Permalink
Add unit test for next.allocators and many other fixes for reviewer's…
Browse files Browse the repository at this point in the history
… comments
  • Loading branch information
egparedes committed Nov 13, 2023
1 parent f01fd06 commit 99e8856
Show file tree
Hide file tree
Showing 8 changed files with 264 additions and 31 deletions.
7 changes: 2 additions & 5 deletions src/gt4py/eve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,7 @@ def itemgetter_(key: Any, default: Any = NOTHING) -> Callable[[Any], Any]:


class fluid_partial(functools.partial):
"""
A `functools.partial` subclass supporting multiple applications by calling `.partial()`.
"""
"""Create a `functools.partial` with support for multiple applications calling `.partial()`."""

def partial(self, *args, **kwargs) -> fluid_partial:
return fluid_partial(self, *args, **kwargs)
Expand All @@ -257,8 +255,7 @@ def with_fluid_partial( # noqa: F811 # redefinition of unused function
def with_fluid_partial( # noqa: F811 # redefinition of unused function
func: Optional[Callable[..., Any]] = None, *args: Any, **kwargs: Any
) -> Union[Callable[..., Any], Callable[[Callable[..., Any]], Callable[..., Any]]]:
"""
A decorator that adds a `partial` attribute to the decorated function.
"""Add a `partial` attribute to the decorated function.
The `partial` attribute is a function that behaves like `functools.partial`,
but also supports partial application of the decorated function. It can be
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
"NeighborTableOffsetProvider",
"StridedNeighborOffsetProvider",
"index_field",
"np_as_located_field"
"np_as_located_field",
# from ffront
"FieldOffset",
"field_operator",
Expand Down
61 changes: 47 additions & 14 deletions src/gt4py/next/allocators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np

import gt4py._core.definitions as core_defs
import gt4py.eve as eve
import gt4py.next.common as common
import gt4py.storage.allocators as core_allocators
from gt4py.eve.extended_typing import (
Expand Down Expand Up @@ -53,6 +54,8 @@


class FieldBufferAllocatorProtocol(Protocol[core_defs.DeviceTypeT]):
"""Protocol for buffer allocators used to allocate memory for fields with a given domain."""

@property
@abc.abstractmethod
def __gt_device_type__(self) -> core_defs.DeviceTypeT:
Expand Down Expand Up @@ -80,7 +83,10 @@ def is_field_allocator_for(


class FieldBufferAllocatorFactoryProtocol(Protocol[core_defs.DeviceTypeT]):
"""Protocol for device-specific buffer allocator factories for fields."""

@property
@abc.abstractmethod
def __gt_allocator__(self) -> FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]:
...

Expand All @@ -92,35 +98,50 @@ def is_field_allocator_factory(obj: Any) -> TypeGuard[FieldBufferAllocatorFactor
def is_field_allocator_factory_for(
obj: Any, device: core_defs.DeviceTypeT
) -> TypeGuard[FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]]:
return is_field_allocator_factory(obj) and obj.__gt_allocator__().__gt_device_type__ is device
return is_field_allocator_factory(obj) and obj.__gt_allocator__.__gt_device_type__ is device


FieldBufferAllocationTool = (
FieldBufferAllocationUtil = (
FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]
| FieldBufferAllocatorFactoryProtocol[core_defs.DeviceTypeT]
)


def is_field_allocation_tool(obj: Any) -> TypeGuard[FieldBufferAllocationTool]:
def is_field_allocation_tool(obj: Any) -> TypeGuard[FieldBufferAllocationUtil]:
return is_field_allocator(obj) or is_field_allocator_factory(obj)


def get_allocator(
obj: FieldBufferAllocationTool, default: Optional[FieldBufferAllocatorProtocol] = None
obj: Any, *, default: FieldBufferAllocatorProtocol | eve.NothingType = eve.NOTHING
) -> FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]:
"""
Return a field-buffer-allocator from an object assumed to be an allocator or an allocator factory.
A default allocator can be provided as fallback in case `obj` is neither an allocator nor a factory.
Arguments:
obj: The allocator or allocator factory.
default: Fallback allocator.
Returns:
A field buffer allocator.
Raises:
TypeError: If `obj` is neither a field allocator nor a field allocator factory and no default is provided.
"""
if is_field_allocator(obj):
return obj
elif is_field_allocator_factory(obj):
return obj.__gt_allocator__
elif default is not None:
elif default is not eve.NOTHING:
return default
else:
raise TypeError(f"Object {obj} is neither a field allocator nor a field allocator factory")


@dataclasses.dataclass(frozen=True)
class BaseFieldBufferAllocator(FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]):
"""Parametrizable field allocator base class."""
"""Parametrizable field buffer allocator base class."""

device_type: core_defs.DeviceTypeT
array_ns: core_allocators.ValidNumPyLikeAllocationNS
Expand Down Expand Up @@ -156,6 +177,8 @@ def __gt_allocate__(
def horizontal_first_layout_mapper(
dims: Sequence[common.Dimension],
) -> core_allocators.BufferLayoutMap:
"""Map dimensions to a buffer layout making horizonal dims change the slowest (i.e. larger strides)."""

def pos_of_kind(kind: common.DimensionKind) -> list[int]:
return [i for i, dim in enumerate(dims) if dim.kind == kind]

Expand All @@ -166,8 +189,10 @@ def pos_of_kind(kind: common.DimensionKind) -> list[int]:
layout_map = [0] * len(dims)
for i, pos in enumerate(horizontals + verticals + locals_):
layout_map[pos] = len(dims) - 1 - i

valid_layout_map = tuple(layout_map)
assert core_allocators.is_valid_layout_map(valid_layout_map)

return valid_layout_map


Expand All @@ -184,6 +209,8 @@ def pos_of_kind(kind: common.DimensionKind) -> list[int]:


class StandardCPUFieldBufferAllocator(BaseFieldBufferAllocator[core_defs.CPUDeviceTyping]):
"""A field buffer allocator for CPU devices that uses a horizontal-first layout mapper and 64-byte alignment."""

def __init__(self) -> None:
super().__init__(
device_type=core_defs.DeviceType.CPU,
Expand All @@ -200,6 +227,8 @@ def __init__(self) -> None:

@dataclasses.dataclass(frozen=True)
class InvalidFieldBufferAllocator(FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]):
"""A field buffer allocator that always raises an exception."""

device_type: core_defs.DeviceTypeT
exception: Exception

Expand Down Expand Up @@ -267,26 +296,30 @@ def allocate(
dtype: core_defs.DType[core_defs.ScalarT],
*,
aligned_index: Optional[Sequence[common.NamedIndex]] = None,
allocator: Optional[FieldBufferAllocationTool] = None,
allocator: Optional[FieldBufferAllocationUtil] = None,
device: Optional[core_defs.Device] = None,
) -> core_allocators.TensorBuffer:
"""
Allocate a TensorBuffer with the given settings on the given device.
Allocate a TensorBuffer for the given domain and device or allocator.
The arguments `device` and `allocator` are mutually exclusive.
If `device` is specified, the corresponding default allocator
(defined in :data:`device_allocators`) is used.
Args: TODO
domain:
dtype: Data type descriptor as defined in :meth:`BufferAllocator.allocate`.
Args:
domain: The domain which should be backed by the allocated tensor buffer.
dtype: Data type.
aligned_index: N-dimensional index of the first aligned element
allocator: The allocator to use for the allocation.
device: The device to allocate the tensor buffer on (using the default
allocator for this kind of device from :data:`device_allocators`).
aligned_index: N-dimensional index of the first aligned element as defined
in :meth:`BufferAllocator.allocate`.
Returns:
The allocated tensor buffer.
"""
if device is None and allocator is None:
raise ValueError("No 'device' or 'allocator' specified")
actual_allocator = get_allocator(allocator, None) if allocator is not None else None
actual_allocator = get_allocator(allocator, default=None)
if device is None:
assert actual_allocator is not None # for mypy
device = core_defs.Device(actual_allocator.__gt_device_type__, 0)
Expand Down
20 changes: 14 additions & 6 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,12 @@ def __str__(self) -> str:
return f"({self.start}:{self.stop})"


RangeLike: TypeAlias = UnitRange | range | tuple[core_defs.INTEGRAL_TYPES, core_defs.INTEGRAL_TYPES] | core_defs.INTEGRAL_TYPES
RangeLike: TypeAlias = (
UnitRange
| range
| tuple[core_defs.IntegralScalar, core_defs.IntegralScalar]
| core_defs.IntegralScalar
)


def unit_range(r: RangeLike) -> UnitRange:
Expand Down Expand Up @@ -404,22 +409,25 @@ def __call__(self, func: fbuiltins.BuiltInFunction[_R, _P], /) -> Callable[_P, _
...


# TODO(havogt): replace this protocol with the new `GTFieldInterface` protocol
class NextGTDimsInterface(Protocol):
"""
A `GTDimsInterface` is an object providing the `__gt_dims__` property, naming :class:`Field` dimensions.
Protocol for objects providing the `__gt_dims__` property, naming :class:`Field` dimensions.
The dimension names are objects of type :class:`Dimension`, in contrast to :mod:`gt4py.cartesian`,
where the labels are `str` s with implied semantics, see :class:`~gt4py._core.definitions.GTDimsInterface` .
The dimension names are objects of type :class:`Dimension`, in contrast to
:mod:`gt4py.cartesian`, where the labels are `str` s with implied semantics,
see :class:`~gt4py._core.definitions.GTDimsInterface` .
"""

# TODO(havogt): unify with GTDimsInterface, ideally in backward compatible way
@property
def __gt_dims__(self) -> tuple[Dimension, ...]:
...


####### TODO-> Update cartsian gtdims
# TODO(egparedes): add support for this new protocol in the cartesian module
class GTFieldInterface(Protocol):
"""Protocol for object providing the `__gt_domain__` property, specifying the :class:`Domain` of a :class:`Field`."""

@property
def __gt_domain__(self) -> Domain:
...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def visit_Sym(self, node: itir.Sym) -> itir.Sym:


@ppi.program_formatter
def format_itir(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str:
def format_itir_and_check(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str:
# remove types from ITIR as they are not supported for the roundtrip
root = _RemoveITIRSymTypes().visit(program)
pretty = pretty_printer.pformat(root)
Expand Down
8 changes: 5 additions & 3 deletions src/gt4py/storage/allocators.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@
xtyping.DLPackBuffer,
]

#: Tuple of positive integers encoding a permutation of the dimensions.
#: Tuple of positive integers encoding a permutation of the dimensions, such that
#: layout_map[i] = j means that the i-th dimension of the tensor corresponds
#: to the j-th dimension in the (C-layout) buffer.
BufferLayoutMap = NewType("BufferLayoutMap", Sequence[core_defs.PositiveIntegral])


Expand Down Expand Up @@ -85,9 +87,9 @@ class TensorBuffer(Generic[core_defs.DeviceTypeT, core_defs.ScalarT]):
dtype: Data type descriptor.
shape: Tuple with lengths of the corresponding tensor dimensions.
strides: Tuple with sizes (in bytes) of the steps in each dimension.
layout_map: Tuple with the order of the dimensions in the buffer.
layout_map: Tuple with the order of the dimensions in the buffer
layout_map[i] = j means that the i-th dimension of the tensor
corresponds to the j-th dimension in the buffer.
corresponds to the j-th dimension in the (C-layout) buffer.
byte_offset: Offset (in bytes) from the beginning of the buffer to
the first valid element.
byte_alignment: Alignment (in bytes) of the first valid element.
Expand Down
4 changes: 3 additions & 1 deletion tests/next_tests/exclusion_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ class OptionalProgramExecutorId(_PythonObjectIdMixin, str, enum.Enum):

class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum):
GTFN_CPP_FORMATTER = "gt4py.next.program_processors.formatters.gtfn.format_cpp"
ITIR_PRETTY_PRINTER = "gt4py.next.program_processors.formatters.pretty_print.format_itir"
ITIR_PRETTY_PRINTER = (
"gt4py.next.program_processors.formatters.pretty_print.format_itir_and_check"
)
ITIR_TYPE_CHECKER = "gt4py.next.program_processors.formatters.type_check.check_type_inference"
LISP_FORMATTER = "gt4py.next.program_processors.formatters.lisp.format_lisp"

Expand Down
Loading

0 comments on commit 99e8856

Please sign in to comment.