Skip to content

Commit

Permalink
example: cartesian with next compatibility (#1202)
Browse files Browse the repository at this point in the history
Add an example illustrating using gt4py.cartesian and gt4py.next computations next to each other using gt4py.next storages.

Refactor GTFieldInterface and cleanup GTDimsInterface for next.
  • Loading branch information
havogt authored Jan 18, 2024
1 parent 3edf21e commit ba36856
Show file tree
Hide file tree
Showing 8 changed files with 239 additions and 38 deletions.
189 changes: 189 additions & 0 deletions examples/lap_cartesian_vs_next.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"GT4Py - GridTools for Python\n",
"\n",
"Copyright (c) 2014-2023, ETH Zurich\n",
"All rights reserved.\n",
"\n",
"This file is part the GT4Py project and the GridTools framework.\n",
"GT4Py is free software: you can redistribute it and/or modify it under\n",
"the terms of the GNU General Public License as published by the\n",
"Free Software Foundation, either version 3 of the License, or any later\n",
"version. See the LICENSE.txt file at the top-level directory of this\n",
"distribution for a copy of the license or check <https://www.gnu.org/licenses/>.\n",
"\n",
"SPDX-License-Identifier: GPL-3.0-or-later"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Demonstrates gt4py.cartesian with gt4py.next compatibility"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"nx = 32\n",
"ny = 32\n",
"nz = 1\n",
"dtype = np.float64"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Storages\n",
"--\n",
"\n",
"We create fields using the gt4py.next constructors. These fields are compatible with gt4py.cartesian when we use \"I\", \"J\", \"K\" as the dimension names."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<link href=\"https://fonts.googleapis.com/icon?family=Material+Icons\" rel=\"stylesheet\"><script src=\"https://spcl.github.io/dace/webclient2/dist/sdfv.js\"></script>\n",
"<link href=\"https://spcl.github.io/dace/webclient2/sdfv.css\" rel=\"stylesheet\">\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import gt4py.next as gtx\n",
"\n",
"allocator = gtx.itir_embedded # should match the executor\n",
"# allocator = gtx.gtfn_cpu\n",
"# allocator = gtx.gtfn_gpu\n",
"\n",
"# Note: for gt4py.next, names don't matter, for gt4py.cartesian they have to be \"I\", \"J\", \"K\"\n",
"I = gtx.Dimension(\"I\")\n",
"J = gtx.Dimension(\"J\")\n",
"K = gtx.Dimension(\"K\", kind=gtx.DimensionKind.VERTICAL)\n",
"\n",
"domain = gtx.domain({I: nx, J: ny, K: nz})\n",
"\n",
"inp = gtx.as_field(domain, np.fromfunction(lambda x, y, z: x**2+y**2, shape=(nx, ny, nz)), dtype, allocator=allocator)\n",
"out_cartesian = gtx.zeros(domain, dtype, allocator=allocator)\n",
"out_next = gtx.zeros(domain, dtype, allocator=allocator)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"gt4py.cartesian\n",
"--"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import gt4py.cartesian.gtscript as gtscript\n",
"\n",
"cartesian_backend = \"numpy\"\n",
"# cartesian_backend = \"gt:cpu_ifirst\"\n",
"# cartesian_backend = \"gt:gpu\"\n",
"\n",
"@gtscript.stencil(backend=cartesian_backend)\n",
"def lap_cartesian(\n",
" inp: gtscript.Field[dtype],\n",
" out: gtscript.Field[dtype],\n",
"):\n",
" with computation(PARALLEL), interval(...):\n",
" out = -4.0 * inp[0, 0, 0] + inp[-1, 0, 0] + inp[1, 0, 0] + inp[0, -1, 0] + inp[0, 1, 0]\n",
"\n",
"lap_cartesian(inp=inp, out=out_cartesian, origin=(1, 1, 0), domain=(nx-2, ny-2, nz))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from gt4py.next import Field\n",
"\n",
"next_backend = gtx.itir_embedded\n",
"# next_backend = gtx.gtfn_cpu\n",
"# next_backend = gtx.gtfn_gpu\n",
"\n",
"Ioff = gtx.FieldOffset(\"I\", source=I, target=(I,))\n",
"Joff = gtx.FieldOffset(\"J\", source=J, target=(J,))\n",
"\n",
"@gtx.field_operator\n",
"def lap_next(inp: Field[[I, J, K], dtype]) -> Field[[I, J, K], dtype]:\n",
" return -4.0 * inp + inp(Ioff[-1]) + inp(Ioff[1]) + inp(Joff[-1]) + inp(Joff[1])\n",
"\n",
"@gtx.program(backend=next_backend)\n",
"def lap_next_program(inp: Field[[I, J, K], dtype], out: Field[[I, J, K], dtype]):\n",
" lap_next(inp, out=out[1:-1, 1:-1, :])\n",
"\n",
"lap_next_program(inp, out_next, offset_provider={\"Ioff\": I, \"Joff\": J})"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"assert np.allclose(out_cartesian.asnumpy(), out_next.asnumpy())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
9 changes: 9 additions & 0 deletions src/gt4py/next/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@
index_field,
np_as_located_field,
)
from .program_processors.runners.gtfn import (
run_gtfn_cached as gtfn_cpu,
run_gtfn_gpu_cached as gtfn_gpu,
)
from .program_processors.runners.roundtrip import backend as itir_python


__all__ = [
Expand Down Expand Up @@ -74,5 +79,9 @@
"field_operator",
"program",
"scan_operator",
# from program_processor
"gtfn_cpu",
"gtfn_gpu",
"itir_python",
*fbuiltins.__all__,
]
35 changes: 16 additions & 19 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,38 +574,39 @@ def __call__(self, func: fbuiltins.BuiltInFunction[_R, _P], /) -> Callable[_P, _
...


# TODO(havogt): replace this protocol with the new `GTFieldInterface` protocol
class NextGTDimsInterface(Protocol):
# TODO(havogt): we need to describe when this interface should be used instead of the `Field` protocol.
class GTFieldInterface(core_defs.GTDimsInterface, core_defs.GTOriginInterface, Protocol):
"""
Protocol for objects providing the `__gt_dims__` property, naming :class:`Field` dimensions.
Protocol for object providing the `__gt_domain__` property, specifying the :class:`Domain` of a :class:`Field`.
The dimension names are objects of type :class:`Dimension`, in contrast to
:mod:`gt4py.cartesian`, where the labels are `str` s with implied semantics,
see :class:`~gt4py._core.definitions.GTDimsInterface` .
Note:
- A default implementation of the `__gt_dims__` interface from `gt4py.cartesian` is provided.
- No implementation of `__gt_origin__` is provided because of infinite fields.
"""

@property
def __gt_dims__(self) -> tuple[Dimension, ...]:
def __gt_domain__(self) -> Domain:
# TODO probably should be changed to `DomainLike` (with a new concept `DimensionLike`)
# to allow implementations without having to import gtx.Domain.
...


# TODO(egparedes): add support for this new protocol in the cartesian module
class GTFieldInterface(Protocol):
"""Protocol for object providing the `__gt_domain__` property, specifying the :class:`Domain` of a :class:`Field`."""

@property
def __gt_domain__(self) -> Domain:
...
def __gt_dims__(self) -> tuple[str, ...]:
return tuple(d.value for d in self.__gt_domain__.dims)


@extended_runtime_checkable
class Field(NextGTDimsInterface, core_defs.GTOriginInterface, Protocol[DimsT, core_defs.ScalarT]):
class Field(GTFieldInterface, Protocol[DimsT, core_defs.ScalarT]):
__gt_builtin_func__: ClassVar[GTBuiltInFuncDispatcher]

@property
def domain(self) -> Domain:
...

@property
def __gt_domain__(self) -> Domain:
return self.domain

@property
def codomain(self) -> type[core_defs.ScalarT] | Dimension:
...
Expand Down Expand Up @@ -923,10 +924,6 @@ def asnumpy(self) -> Never:
def domain(self) -> Domain:
return Domain(dims=(self.dimension,), ranges=(UnitRange.infinite(),))

@property
def __gt_dims__(self) -> tuple[Dimension, ...]:
return self.domain.dims

@property
def __gt_origin__(self) -> Never:
raise TypeError("'CartesianConnectivity' does not support this operation.")
Expand Down
4 changes: 0 additions & 4 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,6 @@ def domain(self) -> common.Domain:
def shape(self) -> tuple[int, ...]:
return self._ndarray.shape

@property
def __gt_dims__(self) -> tuple[common.Dimension, ...]:
return self._domain.dims

@property
def __gt_origin__(self) -> tuple[int, ...]:
assert common.Domain.is_finite(self._domain)
Expand Down
32 changes: 19 additions & 13 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ class LocatedField(Protocol):

@property
@abc.abstractmethod
def __gt_dims__(self) -> tuple[common.Dimension, ...]:
def __gt_domain__(self) -> common.Domain:
...

# TODO(havogt): define generic Protocol to provide a concrete return type
Expand All @@ -182,7 +182,7 @@ def field_getitem(self, indices: NamedFieldIndices) -> Any:

@property
def __gt_origin__(self) -> tuple[int, ...]:
return tuple([0] * len(self.__gt_dims__))
return tuple([0] * len(self.__gt_domain__.dims))


@runtime_checkable
Expand Down Expand Up @@ -675,12 +675,18 @@ def _is_concrete_position(pos: Position) -> TypeGuard[ConcretePosition]:
def _get_axes(
field_or_tuple: LocatedField | tuple,
) -> Sequence[common.Dimension]: # arbitrary nesting of tuples of LocatedField
return _get_domain(field_or_tuple).dims


def _get_domain(
field_or_tuple: LocatedField | tuple,
) -> common.Domain: # arbitrary nesting of tuples of LocatedField
if isinstance(field_or_tuple, tuple):
first = _get_axes(field_or_tuple[0])
assert all(first == _get_axes(f) for f in field_or_tuple)
first = _get_domain(field_or_tuple[0])
assert all(first == _get_domain(f) for f in field_or_tuple)
return first
else:
return field_or_tuple.__gt_dims__
return field_or_tuple.__gt_domain__


def _single_vertical_idx(
Expand Down Expand Up @@ -894,14 +900,14 @@ class NDArrayLocatedFieldWrapper(MutableLocatedField):
_ndarrayfield: common.Field

@property
def __gt_dims__(self) -> tuple[common.Dimension, ...]:
return self._ndarrayfield.__gt_dims__
def __gt_domain__(self) -> common.Domain:
return self._ndarrayfield.__gt_domain__

def _translate_named_indices(
self, _named_indices: NamedFieldIndices
) -> common.AbsoluteIndexSequence:
named_indices: Mapping[common.Dimension, FieldIndex | SparsePositionEntry] = {
d: _named_indices[d.value] for d in self._ndarrayfield.__gt_dims__
d: _named_indices[d.value] for d in self._ndarrayfield.__gt_domain__.dims
}
domain_slice: list[common.NamedRange | common.NamedIndex] = []
for d, v in named_indices.items():
Expand Down Expand Up @@ -1046,8 +1052,8 @@ class IndexField(common.Field):
_dimension: common.Dimension

@property
def __gt_dims__(self) -> tuple[common.Dimension, ...]:
return (self._dimension,)
def __gt_domain__(self) -> common.Domain:
return self.domain

@property
def __gt_origin__(self) -> tuple[int, ...]:
Expand Down Expand Up @@ -1165,8 +1171,8 @@ class ConstantField(common.Field[Any, core_defs.ScalarT]):
_value: core_defs.ScalarT

@property
def __gt_dims__(self) -> tuple[common.Dimension, ...]:
return tuple()
def __gt_domain__(self) -> common.Domain:
return self.domain

@property
def __gt_origin__(self) -> tuple[int, ...]:
Expand Down Expand Up @@ -1452,7 +1458,7 @@ def _tuple_assign(field: tuple | MutableLocatedField, value: Any, named_indices:
class TupleOfFields(TupleField):
def __init__(self, data):
self.data = data
self.__gt_dims__ = _get_axes(data)
self.__gt_domain__ = _get_domain(data)

def field_getitem(self, named_indices: NamedFieldIndices) -> Any:
return _build_tuple_result(self.data, named_indices)
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def _contains_tuple_dtype_field(arg):
# other `np.int32`). We just ignore the error here and postpone fixing this to when
# the new storages land (The implementation here works for LocatedFieldImpl).

return common.is_field(arg) and any(dim is None for dim in arg.__gt_dims__)
return common.is_field(arg) and any(dim is None for dim in arg.domain.dims)


def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]:
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/type_system/type_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def from_value(value: Any) -> ts.TypeSpec:
elif isinstance(value, common.Dimension):
symbol_type = ts.DimensionType(dim=value)
elif common.is_field(value):
dims = list(value.__gt_dims__)
dims = list(value.domain.dims)
dtype = from_type_hint(value.dtype.scalar_type)
symbol_type = ts.FieldType(dims=dims, dtype=dtype)
elif isinstance(value, tuple):
Expand Down
Loading

0 comments on commit ba36856

Please sign in to comment.