Skip to content

Commit

Permalink
add convenience functions for
Browse files Browse the repository at this point in the history
- import of array_ns depending on backend
- transfer field to a given backend
  • Loading branch information
halungge committed Nov 26, 2024
1 parent a1244cd commit 042e8e5
Showing 1 changed file with 49 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,49 @@
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause
import logging as log
from typing import Optional

import gt4py._core.definitions as gt_core_defs
import gt4py.next as gtx
from gt4py.next import backend

from icon4py.model.common import type_alias as ta
from icon4py.model.common.settings import xp
from icon4py.model.common import dimension, type_alias as ta


def is_cupy_device(backend: backend.Backend) -> bool:
cuda_device_types = (
gt_core_defs.DeviceType.CUDA,
gt_core_defs.DeviceType.CUDA_MANAGED,
gt_core_defs.DeviceType.ROCM,
)
if backend is not None:
return backend.allocator.__gt_device_type__ in cuda_device_types
else:
return False


def array_ns(try_cupy: bool):
if try_cupy:
try:
import cupy as cp

return cp
except ImportError:
log.warn("No cupy installed, falling back to numpy for array_ns")
import numpy as np

return np


def import_array_ns(backend: backend.Backend):
"""Import cupy or numpy depending on a chosen GT4Py backend DevicType."""
return array_ns(is_cupy_device(backend))


def as_field(field: gtx.Field, backend: backend.Backend) -> gtx.Field:
"""Convenience function to transfer an existing Field to a given backend."""
return gtx.as_field(field.domain, field.ndarray, allocator=backend)


def allocate_zero_field(
Expand All @@ -20,12 +56,15 @@ def allocate_zero_field(
is_halfdim=False,
dtype=ta.wpfloat,
backend: Optional[backend.Backend] = None,
):
shapex = tuple(map(lambda x: grid.size[x], dims))
if is_halfdim:
assert len(shapex) == 2
shapex = (shapex[0], shapex[1] + 1)
return gtx.as_field(dims, xp.zeros(shapex, dtype=dtype), allocator=backend)
) -> gtx.Field:
def size(dim: gtx.Dimension, is_half_dim: bool) -> int:
if dim == dimension.KDim and is_half_dim:
return grid.size[dim] + 1
else:
return grid.size[dim]

dimensions = {d: range(size(d, is_halfdim)) for d in dims}
return gtx.zeros(dimensions, dtype=dtype, allocator=backend)


def allocate_indices(
Expand All @@ -34,6 +73,7 @@ def allocate_indices(
is_halfdim=False,
dtype=gtx.int32,
backend: Optional[backend.Backend] = None,
):
) -> gtx.Field:
xp = import_array_ns(backend)
shapex = grid.size[dim] + 1 if is_halfdim else grid.size[dim]
return gtx.as_field((dim,), xp.arange(shapex, dtype=dtype), allocator=backend)

0 comments on commit 042e8e5

Please sign in to comment.