Skip to content

Commit

Permalink
Merge branch 'settings_elements_migration' of https://github.com/C2SM…
Browse files Browse the repository at this point in the history
…/icon4py into settings_elements_migration
  • Loading branch information
nfarabullini committed Nov 26, 2024
2 parents 9827453 + a06d303 commit 8acb135
Showing 1 changed file with 19 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
compute_antidiffusive_cell_fluxes_and_min_max,
)
from icon4py.model.common import dimension as dims
from icon4py.model.common.settings import xp
import numpy as np


class TestComputeAntidiffusiveCellFluxesAndMinMax(helpers.StencilTest):
Expand All @@ -30,44 +30,44 @@ class TestComputeAntidiffusiveCellFluxesAndMinMax(helpers.StencilTest):
@staticmethod
def reference(
grid,
geofac_div: xp.ndarray,
p_rhodz_now: xp.ndarray,
p_rhodz_new: xp.ndarray,
z_mflx_low: xp.ndarray,
z_anti: xp.ndarray,
p_cc: xp.ndarray,
geofac_div: np.ndarray,
p_rhodz_now: np.ndarray,
p_rhodz_new: np.ndarray,
z_mflx_low: np.ndarray,
z_anti: np.ndarray,
p_cc: np.ndarray,
p_dtime: float,
**kwargs,
) -> dict:
c2e = xp.asarray(grid.connectivities[dims.C2EDim])
c2e = grid.connectivities[dims.C2EDim]
z_anti_c2e = z_anti[c2e]

geofac_div = helpers.reshape(geofac_div, c2e.shape)
geofac_div = xp.expand_dims(xp.asarray(geofac_div), axis=-1)
geofac_div = np.expand_dims(geofac_div, axis=-1)

zero_array = xp.zeros(p_rhodz_now.shape)
zero_array = np.zeros(p_rhodz_now.shape)

z_mflx_anti_1 = p_dtime * geofac_div[:, 0] / p_rhodz_new * z_anti_c2e[:, 0]
z_mflx_anti_2 = p_dtime * geofac_div[:, 1] / p_rhodz_new * z_anti_c2e[:, 1]
z_mflx_anti_3 = p_dtime * geofac_div[:, 2] / p_rhodz_new * z_anti_c2e[:, 2]

z_mflx_anti_in = -1.0 * (
xp.minimum(zero_array, z_mflx_anti_1)
+ xp.minimum(zero_array, z_mflx_anti_2)
+ xp.minimum(zero_array, z_mflx_anti_3)
np.minimum(zero_array, z_mflx_anti_1)
+ np.minimum(zero_array, z_mflx_anti_2)
+ np.minimum(zero_array, z_mflx_anti_3)
)

z_mflx_anti_out = (
xp.maximum(zero_array, z_mflx_anti_1)
+ xp.maximum(zero_array, z_mflx_anti_2)
+ xp.maximum(zero_array, z_mflx_anti_3)
np.maximum(zero_array, z_mflx_anti_1)
+ np.maximum(zero_array, z_mflx_anti_2)
+ np.maximum(zero_array, z_mflx_anti_3)
)

z_fluxdiv_c = xp.sum(z_mflx_low[c2e] * geofac_div, axis=1)
z_fluxdiv_c = np.sum(z_mflx_low[c2e] * geofac_div, axis=1)

z_tracer_new_low = (p_cc * p_rhodz_now - p_dtime * z_fluxdiv_c) / p_rhodz_new
z_tracer_max = xp.maximum(p_cc, z_tracer_new_low)
z_tracer_min = xp.minimum(p_cc, z_tracer_new_low)
z_tracer_max = np.maximum(p_cc, z_tracer_new_low)
z_tracer_min = np.minimum(p_cc, z_tracer_new_low)

return dict(
z_mflx_anti_in=z_mflx_anti_in,
Expand Down

0 comments on commit 8acb135

Please sign in to comment.