Skip to content

Commit

Permalink
from xp to np edits
Browse files Browse the repository at this point in the history
  • Loading branch information
nfarabullini committed Nov 26, 2024
1 parent bf95797 commit 3b6b3be
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions model/common/src/icon4py/model/common/test_utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from gt4py.next import as_field, common as gt_common, constructors
from gt4py.next.ffront.decorator import Program

from icon4py.model.common.settings import xp
import numpy as np

from ..grid.base import BaseGrid
from ..type_alias import wpfloat
Expand Down Expand Up @@ -70,13 +70,13 @@ def random_mask(
dtype: Optional[npt.DTypeLike] = None,
extend: Optional[dict[gt_common.Dimension, int]] = None,
) -> gt_common.Field:
rng = xp.random.default_rng()
rng = np.random.default_rng()
shape = _shape(grid, *dims, extend=extend)
arr = xp.full(shape, False).flatten()
arr = np.full(shape, False).flatten()
num_true = int(arr.size * 0.5)
arr[:num_true] = True
rng.shuffle(arr)
arr = xp.reshape(arr, newshape=shape)
arr = np.reshape(arr, newshape=shape)
if dtype:
arr = arr.astype(dtype)
return as_field(dims, arr)
Expand All @@ -90,7 +90,7 @@ def random_field(
extend: Optional[dict[gt_common.Dimension, int]] = None,
dtype: Optional[npt.DTypeLike] = None,
) -> gt_common.Field:
arr = xp.random.default_rng().uniform(
arr = np.random.default_rng().uniform(
low=low, high=high, size=_shape(grid, *dims, extend=extend)
)
if dtype:
Expand All @@ -104,15 +104,15 @@ def zero_field(
dtype=wpfloat,
extend: Optional[dict[gt_common.Dimension, int]] = None,
) -> gt_common.Field:
return as_field(dims, xp.zeros(shape=_shape(grid, *dims, extend=extend), dtype=dtype))
return as_field(dims, np.zeros(shape=_shape(grid, *dims, extend=extend), dtype=dtype))


def constant_field(
grid: BaseGrid, value: float, *dims: gt_common.Dimension, dtype=wpfloat
) -> gt_common.Field:
return as_field(
dims,
value * xp.ones(shape=tuple(map(lambda x: grid.size[x], dims)), dtype=dtype),
value * np.ones(shape=tuple(map(lambda x: grid.size[x], dims)), dtype=dtype),
)


Expand All @@ -122,7 +122,7 @@ def as_1D_sparse_field(field: gt_common.Field, target_dim: gt_common.Dimension)
return numpy_to_1D_sparse_field(buffer, target_dim)


def numpy_to_1D_sparse_field(field: xp.ndarray, dim: gt_common.Dimension) -> gt_common.Field:
def numpy_to_1D_sparse_field(field: np.ndarray, dim: gt_common.Dimension) -> gt_common.Field:
"""Convert a 2D sparse field to a 1D flattened (Felix-style) sparse field."""
old_shape = field.shape
assert len(old_shape) == 2
Expand All @@ -142,15 +142,15 @@ def flatten_first_two_dims(*dims: gt_common.Dimension, field: gt_common.Field) -
return as_field(dims, newarray)


def unflatten_first_two_dims(field: gt_common.Field) -> xp.array:
def unflatten_first_two_dims(field: gt_common.Field) -> np.array:
"""Convert a (n-1)-D flattened (Felix-style) sparse field back to a n-D sparse field."""
old_shape = xp.asarray(field).shape
old_shape = np.asarray(field).shape
new_shape = (old_shape[0] // 3, 3) + old_shape[1:]
return xp.asarray(field).reshape(new_shape)
return np.asarray(field).reshape(new_shape)


def dallclose(a, b, rtol=1.0e-12, atol=0.0, equal_nan=False):
return xp.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
return np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)


def allocate_data(backend, input_data):
Expand All @@ -173,7 +173,7 @@ def _test_validation(self, grid, backend, input_data):
reference_outputs = self.reference(
grid,
**{
k: v.asnumpy() if isinstance(v, gt_common.Field) else xp.array(v)
k: v.asnumpy() if isinstance(v, gt_common.Field) else np.array(v)
for k, v in input_data.items()
},
)
Expand All @@ -191,7 +191,7 @@ def _test_validation(self, grid, backend, input_data):
else (out, (slice(None),), (slice(None),))
)

assert xp.allclose(
assert np.allclose(
input_data[name].asnumpy()[gtslice],
reference_outputs[name][refslice],
equal_nan=True,
Expand Down Expand Up @@ -235,7 +235,7 @@ class StencilTest:
...
... @staticmethod
... def reference(some_input, **kwargs):
... return dict(some_output=xp.asarray(some_input) * 2)
... return dict(some_output=np.asarray(some_input) * 2)
"""

PROGRAM: ClassVar[Program]
Expand All @@ -250,5 +250,5 @@ def __init_subclass__(cls, **kwargs):
setattr(cls, f"test_{cls.__name__}_benchmark", _test_execution_benchmark)


def reshape(arr: xp.array, shape: tuple[int, ...]):
return xp.reshape(arr, shape)
def reshape(arr: np.array, shape: tuple[int, ...]):
return np.reshape(arr, shape)

0 comments on commit 3b6b3be

Please sign in to comment.