Skip to content

Commit

Permalink
some edits
Browse files Browse the repository at this point in the history
  • Loading branch information
nfarabullini committed Nov 26, 2024
1 parent 2a43848 commit 14ccc0e
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause
import gt4py.next as gtx
import numpy as np
import pytest

from icon4py.model.atmosphere.diffusion.stencils.apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence import (
apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence,
)
from icon4py.model.common import dimension as dims
from icon4py.model.common.settings import xp
from icon4py.model.common.test_utils.helpers import StencilTest, random_field, zero_field

from .test_apply_nabla2_to_w import apply_nabla2_to_w_numpy
Expand Down Expand Up @@ -49,10 +49,10 @@ def reference(
halo_idx,
**kwargs,
):
reshaped_k = k[xp.newaxis, :]
reshaped_cell = cell[:, xp.newaxis]
reshaped_k = k[np.newaxis, :]
reshaped_cell = cell[:, np.newaxis]
if type_shear == 2:
dwdx, dwdy = xp.where(
dwdx, dwdy = np.where(
0 < reshaped_k,
calculate_horizontal_gradients_for_turbulence_numpy(
grid, w_old, geofac_grg_x, geofac_grg_y
Expand All @@ -62,13 +62,13 @@ def reference(

z_nabla2_c = calculate_nabla2_for_w_numpy(grid, w_old, geofac_n2s)

w = xp.where(
w = np.where(
(interior_idx <= reshaped_cell) & (reshaped_cell < halo_idx),
apply_nabla2_to_w_numpy(grid, area, z_nabla2_c, geofac_n2s, w_old, diff_multfac_w),
w_old,
)

w = xp.where(
w = np.where(
(0 < reshaped_k)
& (reshaped_k < nrdmax)
& (interior_idx <= reshaped_cell)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,24 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import numpy as np
import pytest

from icon4py.model.atmosphere.diffusion.stencils.calculate_diagnostics_for_turbulence import (
calculate_diagnostics_for_turbulence,
)
from icon4py.model.common import dimension as dims
from icon4py.model.common.settings import xp
from icon4py.model.common.test_utils.helpers import StencilTest, random_field, zero_field
from icon4py.model.common.type_alias import vpfloat


def calculate_diagnostics_for_turbulence_numpy(
wgtfac_c: xp.array, div: xp.array, kh_c: xp.array, div_ic, hdef_ic
) -> tuple[xp.array, xp.array]:
div = xp.asarray(div)
wgtfac_c = xp.asarray(wgtfac_c)
kc_offset_1 = xp.roll(xp.asarray(kh_c), shift=1, axis=1)
div_offset_1 = xp.roll(div, shift=1, axis=1)
wgtfac_c: np.array, div: np.array, kh_c: np.array, div_ic, hdef_ic
) -> tuple[np.array, np.array]:
div = np.asarray(div)
wgtfac_c = np.asarray(wgtfac_c)
kc_offset_1 = np.roll(np.asarray(kh_c), shift=1, axis=1)
div_offset_1 = np.roll(div, shift=1, axis=1)
div_ic[:, 1:] = (wgtfac_c * div + (1.0 - wgtfac_c) * div_offset_1)[:, 1:]
hdef_ic[:, 1:] = ((wgtfac_c * kh_c + (1.0 - wgtfac_c) * kc_offset_1) ** 2)[:, 1:]
return div_ic, hdef_ic
Expand All @@ -34,7 +34,7 @@ class TestCalculateDiagnosticsForTurbulence(StencilTest):
OUTPUTS = ("div_ic", "hdef_ic")

@staticmethod
def reference(grid, wgtfac_c: xp.array, div: xp.array, kh_c: xp.array, div_ic, hdef_ic) -> dict:
def reference(grid, wgtfac_c: np.array, div: np.array, kh_c: np.array, div_ic, hdef_ic) -> dict:
div_ic, hdef_ic = calculate_diagnostics_for_turbulence_numpy(
wgtfac_c, div, kh_c, div_ic, hdef_ic
)
Expand Down
3 changes: 1 addition & 2 deletions model/common/src/icon4py/model/common/test_utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@
from dataclasses import dataclass, field
from typing import ClassVar, Optional

import numpy as np
import numpy.typing as npt
import pytest
from gt4py._core.definitions import is_scalar_type
from gt4py.next import as_field, common as gt_common, constructors
from gt4py.next.ffront.decorator import Program

import numpy as np

from ..grid.base import BaseGrid
from ..type_alias import wpfloat

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from typing import Optional

import gt4py.next as gtx
import numpy as np
from gt4py.next import backend

from icon4py.model.common import type_alias as ta
import numpy as np


def allocate_zero_field(
*dims: gtx.Dimension,
Expand Down

0 comments on commit 14ccc0e

Please sign in to comment.