Skip to content

Commit

Permalink
make gtx.Fields compatible with gt4py.cartesian
Browse files Browse the repository at this point in the history
  • Loading branch information
havogt committed Dec 20, 2023
1 parent 2393a29 commit ee54718
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 75 deletions.
117 changes: 80 additions & 37 deletions examples/lap_cartesian_vs_next.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -49,6 +49,49 @@
"dtype = np.float64"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Storages\n",
"--"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<link href=\"https://fonts.googleapis.com/icon?family=Material+Icons\" rel=\"stylesheet\"><script src=\"https://spcl.github.io/dace/webclient2/dist/sdfv.js\"></script>\n",
"<link href=\"https://spcl.github.io/dace/webclient2/sdfv.css\" rel=\"stylesheet\">\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import gt4py.next as gtx\n",
"\n",
"allocator = gtx.itir_embedded # should match the executor\n",
"\n",
"I = gtx.Dimension(\"I\")\n",
"J = gtx.Dimension(\"J\")\n",
"K = gtx.Dimension(\"K\")\n",
"\n",
"domain = gtx.domain({I: nx, J: ny, K: nz})\n",
"\n",
"inp = gtx.as_field(domain, np.fromfunction(lambda x, y, z: x**2+y**2, shape=(nx, ny, nz)), dtype, allocator=allocator)\n",
"out_cartesian = gtx.zeros(domain, dtype, allocator=allocator)\n",
"out_next = gtx.zeros(domain, dtype, allocator=allocator)"
]
},
{
"attachments": {},
"cell_type": "markdown",
Expand All @@ -60,74 +103,74 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import gt4py.storage as storage\n",
"import gt4py.cartesian.gtscript as gtscript\n",
"\n",
"backend = \"numpy\"\n",
"# backend = \"gt:cpu_ifirst\"\n",
"cartesian_backend = \"numpy\"\n",
"# cartesian_backend = \"gt:cpu_ifirst\"\n",
"\n",
"@gtscript.stencil(backend=backend)\n",
"@gtscript.stencil(backend=cartesian_backend)\n",
"def lap_cartesian(\n",
" inp: gtscript.Field[dtype],\n",
" out: gtscript.Field[dtype],\n",
"):\n",
" with computation(PARALLEL), interval(...):\n",
" out = -4.0 * inp[0, 0, 0] + inp[-1, 0, 0] + inp[1, 0, 0] + inp[0, -1, 0] + inp[0, 1, 0]\n",
"\n",
"inp = storage.from_array(np.fromfunction(lambda x, y, z: x**2+y**2, shape=(nx, ny, nz)), dtype, backend=backend)\n",
"out_cartesian = storage.zeros((nx, ny, nz), dtype, backend=backend)\n",
"\n",
"lap_cartesian(inp=inp, out=out_cartesian, origin=(1, 1, 0), domain=(nx-2, ny-2, nz))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from gt4py.next.common import Dimension, Field\n",
"from gt4py.next.ffront.decorator import field_operator, program\n",
"from gt4py.next.ffront.fbuiltins import FieldOffset\n",
"from gt4py.next.iterator.embedded import np_as_located_field as with_dimensions\n",
"from gt4py.next.program_processors.runners.roundtrip import executor as itir_embedded\n",
"from gt4py.next.program_processors.runners.gtfn_cpu import run_gtfn as cpu_backend\n",
"\n",
"backend = itir_embedded\n",
"# backend = cpu_backend\n",
"from gt4py.next import Field\n",
"\n",
"IDim = Dimension(\"IDim\")\n",
"JDim = Dimension(\"JDim\")\n",
"KDim = Dimension(\"KDim\")\n",
"backend = gtx.itir_embedded\n",
"# backend = cpu_backend\n",
"\n",
"I = FieldOffset(\"I\", source=IDim, target=(IDim,))\n",
"J = FieldOffset(\"J\", source=JDim, target=(JDim,))\n",
"Ioff = gtx.FieldOffset(\"I\", source=I, target=(I,))\n",
"Joff = gtx.FieldOffset(\"J\", source=J, target=(J,))\n",
"\n",
"@field_operator\n",
"def lap_next(inp: Field[[IDim, JDim, KDim], dtype]) -> Field[[IDim, JDim, KDim], dtype]:\n",
" return -4.0 * inp + inp(I[-1]) + inp(I[1]) + inp(J[-1]) + inp(J[1])\n",
"@gtx.field_operator\n",
"def lap_next(inp: Field[[I, J, K], dtype]) -> Field[[I, J, K], dtype]:\n",
" return -4.0 * inp + inp(Ioff[-1]) + inp(Ioff[1]) + inp(Joff[-1]) + inp(Joff[1])\n",
"\n",
"@program(backend=backend)\n",
"def lap_next_program(inp: Field[[IDim, JDim, KDim], dtype], out: Field[[IDim, JDim, KDim], dtype]):\n",
"@gtx.program(backend=backend)\n",
"def lap_next_program(inp: Field[[I, J, K], dtype], out: Field[[I, J, K], dtype]):\n",
" lap_next(inp, out=out[1:-1, 1:-1, :])\n",
"\n",
"inp_next = with_dimensions(IDim,JDim,KDim)(inp)\n",
"out_next = with_dimensions(IDim, JDim, KDim)(np.zeros(shape=(nx, ny, nz)))\n",
"\n",
"lap_next_program(inp_next, out_next, offset_provider={\"I\": IDim, \"J\": JDim})"
"lap_next_program(inp, out_next, offset_provider={\"Ioff\": I, \"Joff\": J})"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"metadata": {},
"outputs": [],
"outputs": [
{
"ename": "TypeError",
"evalue": "ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mallclose\u001b[49m\u001b[43m(\u001b[49m\u001b[43mout_cartesian\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mout_next\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m<__array_function__ internals>:200\u001b[0m, in \u001b[0;36mallclose\u001b[0;34m(*args, **kwargs)\u001b[0m\n",
"File \u001b[0;32m~/git/gt4py/.venv/lib64/python3.10/site-packages/numpy/core/numeric.py:2270\u001b[0m, in \u001b[0;36mallclose\u001b[0;34m(a, b, rtol, atol, equal_nan)\u001b[0m\n\u001b[1;32m 2199\u001b[0m \u001b[38;5;129m@array_function_dispatch\u001b[39m(_allclose_dispatcher)\n\u001b[1;32m 2200\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mallclose\u001b[39m(a, b, rtol\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1.e-5\u001b[39m, atol\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1.e-8\u001b[39m, equal_nan\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[1;32m 2201\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 2202\u001b[0m \u001b[38;5;124;03m Returns True if two arrays are element-wise equal within a tolerance.\u001b[39;00m\n\u001b[1;32m 2203\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 2268\u001b[0m \n\u001b[1;32m 2269\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 2270\u001b[0m res \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mall\u001b[39m(\u001b[43misclose\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mb\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrtol\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrtol\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43matol\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43matol\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mequal_nan\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mequal_nan\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 2271\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mbool\u001b[39m(res)\n",
"File \u001b[0;32m<__array_function__ internals>:200\u001b[0m, in \u001b[0;36misclose\u001b[0;34m(*args, **kwargs)\u001b[0m\n",
"File \u001b[0;32m~/git/gt4py/.venv/lib64/python3.10/site-packages/numpy/core/numeric.py:2377\u001b[0m, in \u001b[0;36misclose\u001b[0;34m(a, b, rtol, atol, equal_nan)\u001b[0m\n\u001b[1;32m 2374\u001b[0m dt \u001b[38;5;241m=\u001b[39m multiarray\u001b[38;5;241m.\u001b[39mresult_type(y, \u001b[38;5;241m1.\u001b[39m)\n\u001b[1;32m 2375\u001b[0m y \u001b[38;5;241m=\u001b[39m asanyarray(y, dtype\u001b[38;5;241m=\u001b[39mdt)\n\u001b[0;32m-> 2377\u001b[0m xfin \u001b[38;5;241m=\u001b[39m \u001b[43misfinite\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2378\u001b[0m yfin \u001b[38;5;241m=\u001b[39m isfinite(y)\n\u001b[1;32m 2379\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mall\u001b[39m(xfin) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mall\u001b[39m(yfin):\n",
"\u001b[0;31mTypeError\u001b[0m: ufunc 'isfinite' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''"
]
}
],
"source": [
"assert np.allclose(out_cartesian, out_next)"
"assert np.allclose(out_cartesian.asnumpy(), out_next.asnumpy())"
]
}
],
Expand All @@ -147,7 +190,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.2"
"version": "3.10.13"
}
},
"nbformat": 4,
Expand Down
9 changes: 9 additions & 0 deletions src/gt4py/next/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@
index_field,
np_as_located_field,
)
from .program_processors.runners.gtfn import (
run_gtfn_cached as gtfn_cpu,
run_gtfn_gpu_cached as gtfn_gpu,
)
from .program_processors.runners.roundtrip import backend as itir_embedded


__all__ = [
Expand Down Expand Up @@ -74,5 +79,9 @@
"field_operator",
"program",
"scan_operator",
# from program_processor
"gtfn_cpu",
"gtfn_gpu",
"itir_embedded",
*fbuiltins.__all__,
]
33 changes: 12 additions & 21 deletions src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,22 +574,7 @@ def __call__(self, func: fbuiltins.BuiltInFunction[_R, _P], /) -> Callable[_P, _
...


# TODO(havogt): replace this protocol with the new `GTFieldInterface` protocol
class NextGTDimsInterface(Protocol):
"""
Protocol for objects providing the `__gt_dims__` property, naming :class:`Field` dimensions.
The dimension names are objects of type :class:`Dimension`, in contrast to
:mod:`gt4py.cartesian`, where the labels are `str` s with implied semantics,
see :class:`~gt4py._core.definitions.GTDimsInterface` .
"""

@property
def __gt_dims__(self) -> tuple[Dimension, ...]:
...


# TODO(egparedes): add support for this new protocol in the cartesian module
# TODO(havogt): we need to specify when we should use this interface vs the `Field` protocol.
class GTFieldInterface(Protocol):
"""Protocol for object providing the `__gt_domain__` property, specifying the :class:`Domain` of a :class:`Field`."""

Expand All @@ -599,13 +584,23 @@ def __gt_domain__(self) -> Domain:


@extended_runtime_checkable
class Field(NextGTDimsInterface, core_defs.GTOriginInterface, Protocol[DimsT, core_defs.ScalarT]):
class Field(
core_defs.GTDimsInterface, core_defs.GTOriginInterface, Protocol[DimsT, core_defs.ScalarT]
):
__gt_builtin_func__: ClassVar[GTBuiltInFuncDispatcher]

@property
def domain(self) -> Domain:
...

@property
def __gt_domain__(self) -> Domain:
return self.domain

@property
def __gt_dims__(self) -> tuple[str, ...]:
return tuple(d.value for d in self.domain.dims)

@property
def codomain(self) -> type[core_defs.ScalarT] | Dimension:
...
Expand Down Expand Up @@ -923,10 +918,6 @@ def asnumpy(self) -> Never:
def domain(self) -> Domain:
return Domain(dims=(self.dimension,), ranges=(UnitRange.infinite(),))

@property
def __gt_dims__(self) -> tuple[Dimension, ...]:
return self.domain.dims

@property
def __gt_origin__(self) -> Never:
raise TypeError("'CartesianConnectivity' does not support this operation.")
Expand Down
4 changes: 0 additions & 4 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,6 @@ def domain(self) -> common.Domain:
def shape(self) -> tuple[int, ...]:
return self._ndarray.shape

@property
def __gt_dims__(self) -> tuple[common.Dimension, ...]:
return self._domain.dims

@property
def __gt_origin__(self) -> tuple[int, ...]:
assert common.Domain.is_finite(self._domain)
Expand Down
33 changes: 22 additions & 11 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ class LocatedField(Protocol):

@property
@abc.abstractmethod
def __gt_dims__(self) -> tuple[common.Dimension, ...]:
def __gt_domain__(self) -> common.Domain:
...

# TODO(havogt): define generic Protocol to provide a concrete return type
Expand All @@ -182,7 +182,7 @@ def field_getitem(self, indices: NamedFieldIndices) -> Any:

@property
def __gt_origin__(self) -> tuple[int, ...]:
return tuple([0] * len(self.__gt_dims__))
return tuple([0] * len(self.__gt_domain__.dims))


@runtime_checkable
Expand Down Expand Up @@ -680,7 +680,18 @@ def _get_axes(
assert all(first == _get_axes(f) for f in field_or_tuple)
return first
else:
return field_or_tuple.__gt_dims__
return field_or_tuple.__gt_domain__.dims


def _get_domain(
field_or_tuple: LocatedField | tuple,
) -> common.Domain: # arbitrary nesting of tuples of LocatedField
if isinstance(field_or_tuple, tuple):
first = _get_domain(field_or_tuple[0])
assert all(first == _get_domain(f) for f in field_or_tuple)
return first
else:
return field_or_tuple.__gt_domain__


def _single_vertical_idx(
Expand Down Expand Up @@ -894,14 +905,14 @@ class NDArrayLocatedFieldWrapper(MutableLocatedField):
_ndarrayfield: common.Field

@property
def __gt_dims__(self) -> tuple[common.Dimension, ...]:
return self._ndarrayfield.__gt_dims__
def __gt_domain__(self) -> common.Domain:
return self._ndarrayfield.__gt_domain__

def _translate_named_indices(
self, _named_indices: NamedFieldIndices
) -> common.AbsoluteIndexSequence:
named_indices: Mapping[common.Dimension, FieldIndex | SparsePositionEntry] = {
d: _named_indices[d.value] for d in self._ndarrayfield.__gt_dims__
d: _named_indices[d.value] for d in self._ndarrayfield.__gt_domain__.dims
}
domain_slice: list[common.NamedRange | common.NamedIndex] = []
for d, v in named_indices.items():
Expand Down Expand Up @@ -1046,8 +1057,8 @@ class IndexField(common.Field):
_dimension: common.Dimension

@property
def __gt_dims__(self) -> tuple[common.Dimension, ...]:
return (self._dimension,)
def __gt_domain__(self) -> common.Domain:
return self.domain

@property
def __gt_origin__(self) -> tuple[int, ...]:
Expand Down Expand Up @@ -1165,8 +1176,8 @@ class ConstantField(common.Field[Any, core_defs.ScalarT]):
_value: core_defs.ScalarT

@property
def __gt_dims__(self) -> tuple[common.Dimension, ...]:
return tuple()
def __gt_domain__(self) -> common.Domain:
return self.domain

@property
def __gt_origin__(self) -> tuple[int, ...]:
Expand Down Expand Up @@ -1452,7 +1463,7 @@ def _tuple_assign(field: tuple | MutableLocatedField, value: Any, named_indices:
class TupleOfFields(TupleField):
def __init__(self, data):
self.data = data
self.__gt_dims__ = _get_axes(data)
self.__gt_domain__ = _get_domain(data)

def field_getitem(self, named_indices: NamedFieldIndices) -> Any:
return _build_tuple_result(self.data, named_indices)
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/iterator/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def _contains_tuple_dtype_field(arg):
# other `np.int32`). We just ignore the error here and postpone fixing this to when
# the new storages land (The implementation here works for LocatedFieldImpl).

return common.is_field(arg) and any(dim is None for dim in arg.__gt_dims__)
return common.is_field(arg) and any(dim is None for dim in arg.domain.dims)


def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]:
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/type_system/type_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def from_value(value: Any) -> ts.TypeSpec:
elif isinstance(value, common.Dimension):
symbol_type = ts.DimensionType(dim=value)
elif common.is_field(value):
dims = list(value.__gt_dims__)
dims = list(value.domain.dims)
dtype = from_type_hint(value.dtype.scalar_type)
symbol_type = ts.FieldType(dims=dims, dtype=dtype)
elif isinstance(value, tuple):
Expand Down
2 changes: 2 additions & 0 deletions src/gt4py/storage/cartesian/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ def cpu_copy(array: Union[np.ndarray, "cp.ndarray"]) -> np.ndarray:
def asarray(
array: FieldLike, *, device: Literal["cpu", "gpu", None] = None
) -> np.ndarray | cp.ndarray:
if hasattr(array, "ndarray"):
array = array.ndarray
if device == "gpu" or (not device and hasattr(array, "__cuda_array_interface__")):
return cp.asarray(array)
if device == "cpu" or (
Expand Down

0 comments on commit ee54718

Please sign in to comment.