Skip to content

Commit

Permalink
extract CUPY devices
Browse files Browse the repository at this point in the history
extract _size function and use in all allocation funtions
  • Loading branch information
halungge committed Nov 27, 2024
1 parent 042e8e5 commit db15ea9
Showing 1 changed file with 20 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,20 @@
from icon4py.model.common import dimension, type_alias as ta


""" Enum values from Enum values taken from DLPack reference implementation at:
https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h
via GT4Py
"""
CUDA_DEVICE_TYPES = (
gt_core_defs.DeviceType.CUDA,
gt_core_defs.DeviceType.CUDA_MANAGED,
gt_core_defs.DeviceType.ROCM,
)


def is_cupy_device(backend: backend.Backend) -> bool:
cuda_device_types = (
gt_core_defs.DeviceType.CUDA,
gt_core_defs.DeviceType.CUDA_MANAGED,
gt_core_defs.DeviceType.ROCM,
)
if backend is not None:
return backend.allocator.__gt_device_type__ in cuda_device_types
return backend.allocator.__gt_device_type__ in CUDA_DEVICE_TYPES
else:
return False

Expand Down Expand Up @@ -50,20 +56,20 @@ def as_field(field: gtx.Field, backend: backend.Backend) -> gtx.Field:
return gtx.as_field(field.domain, field.ndarray, allocator=backend)


def _size(grid, dim: gtx.Dimension, is_half_dim: bool) -> int:
if dim == dimension.KDim and is_half_dim:
return grid.size[dim] + 1
return grid.size[dim]


def allocate_zero_field(
*dims: gtx.Dimension,
grid,
is_halfdim=False,
dtype=ta.wpfloat,
backend: Optional[backend.Backend] = None,
) -> gtx.Field:
def size(dim: gtx.Dimension, is_half_dim: bool) -> int:
if dim == dimension.KDim and is_half_dim:
return grid.size[dim] + 1
else:
return grid.size[dim]

dimensions = {d: range(size(d, is_halfdim)) for d in dims}
dimensions = {d: range(_size(grid, d, is_halfdim)) for d in dims}
return gtx.zeros(dimensions, dtype=dtype, allocator=backend)


Expand All @@ -75,5 +81,5 @@ def allocate_indices(
backend: Optional[backend.Backend] = None,
) -> gtx.Field:
xp = import_array_ns(backend)
shapex = grid.size[dim] + 1 if is_halfdim else grid.size[dim]
shapex = _size(grid, dim, is_halfdim)
return gtx.as_field((dim,), xp.arange(shapex, dtype=dtype), allocator=backend)

0 comments on commit db15ea9

Please sign in to comment.