diff --git a/model/common/src/icon4py/model/common/utils/gt4py_field_allocation.py b/model/common/src/icon4py/model/common/utils/gt4py_field_allocation.py index 77740070b..137ead06b 100644 --- a/model/common/src/icon4py/model/common/utils/gt4py_field_allocation.py +++ b/model/common/src/icon4py/model/common/utils/gt4py_field_allocation.py @@ -15,14 +15,20 @@ from icon4py.model.common import dimension, type_alias as ta +""" Enum values from Enum values taken from DLPack reference implementation at: + https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h + via GT4Py +""" +CUDA_DEVICE_TYPES = ( + gt_core_defs.DeviceType.CUDA, + gt_core_defs.DeviceType.CUDA_MANAGED, + gt_core_defs.DeviceType.ROCM, +) + + def is_cupy_device(backend: backend.Backend) -> bool: - cuda_device_types = ( - gt_core_defs.DeviceType.CUDA, - gt_core_defs.DeviceType.CUDA_MANAGED, - gt_core_defs.DeviceType.ROCM, - ) if backend is not None: - return backend.allocator.__gt_device_type__ in cuda_device_types + return backend.allocator.__gt_device_type__ in CUDA_DEVICE_TYPES else: return False @@ -50,6 +56,12 @@ def as_field(field: gtx.Field, backend: backend.Backend) -> gtx.Field: return gtx.as_field(field.domain, field.ndarray, allocator=backend) +def _size(grid, dim: gtx.Dimension, is_half_dim: bool) -> int: + if dim == dimension.KDim and is_half_dim: + return grid.size[dim] + 1 + return grid.size[dim] + + def allocate_zero_field( *dims: gtx.Dimension, grid, @@ -57,13 +69,7 @@ def allocate_zero_field( dtype=ta.wpfloat, backend: Optional[backend.Backend] = None, ) -> gtx.Field: - def size(dim: gtx.Dimension, is_half_dim: bool) -> int: - if dim == dimension.KDim and is_half_dim: - return grid.size[dim] + 1 - else: - return grid.size[dim] - - dimensions = {d: range(size(d, is_halfdim)) for d in dims} + dimensions = {d: range(_size(grid, d, is_halfdim)) for d in dims} return gtx.zeros(dimensions, dtype=dtype, allocator=backend) @@ -75,5 +81,5 @@ def allocate_indices( backend: Optional[backend.Backend] = None, ) -> gtx.Field: xp = import_array_ns(backend) - shapex = grid.size[dim] + 1 if is_halfdim else grid.size[dim] + shapex = _size(grid, dim, is_halfdim) return gtx.as_field((dim,), xp.arange(shapex, dtype=dtype), allocator=backend)