diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/init_exner_pr.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/init_exner_pr.py
new file mode 100644
index 0000000000..5ec1cb398c
--- /dev/null
+++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/init_exner_pr.py
@@ -0,0 +1,50 @@
+# ICON4Py - ICON inspired code in Python and GT4Py
+#
+# Copyright (c) 2022, ETH Zurich and MeteoSwiss
+# All rights reserved.
+#
+# This file is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+from gt4py.next.common import GridType
+from gt4py.next.ffront.decorator import field_operator, program
+from gt4py.next.ffront.fbuiltins import Field, int32
+
+from icon4py.model.common.dimension import CellDim, KDim
+from icon4py.model.common.settings import backend
+from icon4py.model.common.type_alias import vpfloat
+
+
+@field_operator
+def _init_exner_pr(
+ exner: Field[[CellDim, KDim], vpfloat],
+ exner_ref: Field[[CellDim, KDim], vpfloat],
+) -> Field[[CellDim, KDim], vpfloat]:
+ exner_pr = exner - exner_ref
+ return exner_pr
+
+
+@program(grid_type=GridType.UNSTRUCTURED, backend=backend)
+def init_exner_pr(
+ exner: Field[[CellDim, KDim], vpfloat],
+ exner_ref: Field[[CellDim, KDim], vpfloat],
+ exner_pr: Field[[CellDim, KDim], vpfloat],
+ horizontal_start: int32,
+ horizontal_end: int32,
+ vertical_start: int32,
+ vertical_end: int32,
+):
+ _init_exner_pr(
+ exner,
+ exner_ref,
+ out=exner_pr,
+ domain={
+ CellDim: (horizontal_start, horizontal_end),
+ KDim: (vertical_start, vertical_end),
+ },
+ )
diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/nh_solve/solve_nonhydro.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/nh_solve/solve_nonhydro.py
index c7ce4a55c2..ae6df8f9ac 100644
--- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/nh_solve/solve_nonhydro.py
+++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/nh_solve/solve_nonhydro.py
@@ -1647,20 +1647,20 @@ def run_corrector_step(
offset_provider={},
)
- # verified for e-9
- log.debug(f"corrector: start stencile 41")
- compute_divergence_of_fluxes_of_rho_and_theta(
- geofac_div=self.interpolation_state.geofac_div,
- mass_fl_e=diagnostic_state_nh.mass_fl_e,
- z_theta_v_fl_e=self.z_theta_v_fl_e,
- z_flxdiv_mass=self.z_flxdiv_mass,
- z_flxdiv_theta=self.z_flxdiv_theta,
- horizontal_start=start_cell_nudging,
- horizontal_end=end_cell_local,
- vertical_start=0,
- vertical_end=self.grid.num_levels,
- offset_provider=self.grid.offset_providers,
- )
+ # verified for e-9
+ log.debug(f"corrector: start stencil 41")
+ compute_divergence_of_fluxes_of_rho_and_theta(
+ geofac_div=self.interpolation_state.geofac_div,
+ mass_fl_e=diagnostic_state_nh.mass_fl_e,
+ z_theta_v_fl_e=self.z_theta_v_fl_e,
+ z_flxdiv_mass=self.z_flxdiv_mass,
+ z_flxdiv_theta=self.z_flxdiv_theta,
+ horizontal_start=start_cell_nudging,
+ horizontal_end=end_cell_local,
+ vertical_start=0,
+ vertical_end=self.grid.num_levels,
+ offset_provider=self.grid.offset_providers,
+ )
if self.config.itime_scheme == 4:
log.debug(f"corrector start stencil 42 44 45 45b")
diff --git a/model/atmosphere/dycore/tests/dycore_stencil_tests/test_init_exner_pr.py b/model/atmosphere/dycore/tests/dycore_stencil_tests/test_init_exner_pr.py
new file mode 100644
index 0000000000..a40dd7e5bd
--- /dev/null
+++ b/model/atmosphere/dycore/tests/dycore_stencil_tests/test_init_exner_pr.py
@@ -0,0 +1,51 @@
+# ICON4Py - ICON inspired code in Python and GT4Py
+#
+# Copyright (c) 2022, ETH Zurich and MeteoSwiss
+# All rights reserved.
+#
+# This file is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check .
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import numpy as np
+import pytest
+from gt4py.next.ffront.fbuiltins import int32
+
+from icon4py.model.atmosphere.dycore.init_exner_pr import (
+ init_exner_pr,
+)
+from icon4py.model.common.dimension import CellDim, KDim
+from icon4py.model.common.test_utils.helpers import StencilTest, random_field, zero_field
+from icon4py.model.common.type_alias import vpfloat
+
+
+class TestInitExnerPr(StencilTest):
+ PROGRAM = init_exner_pr
+ OUTPUTS = ("exner_pr",)
+
+ @staticmethod
+ def reference(grid, exner: np.array, exner_ref: np.array, **kwargs) -> dict:
+ exner_pr = exner - exner_ref
+ return dict(
+ exner_pr=exner_pr,
+ )
+
+ @pytest.fixture
+ def input_data(self, grid):
+ exner = random_field(grid, CellDim, KDim, dtype=vpfloat)
+ exner_ref = random_field(grid, CellDim, KDim, dtype=vpfloat)
+ exner_pr = zero_field(grid, CellDim, KDim, dtype=vpfloat)
+
+ return dict(
+ exner=exner,
+ exner_ref=exner_ref,
+ exner_pr=exner_pr,
+ horizontal_start=int32(0),
+ horizontal_end=int32(grid.num_cells),
+ vertical_start=int32(0),
+ vertical_end=int32(grid.num_levels),
+ )
diff --git a/model/common/src/icon4py/model/common/test_utils/serialbox_utils.py b/model/common/src/icon4py/model/common/test_utils/serialbox_utils.py
index c049f12349..b68422d276 100644
--- a/model/common/src/icon4py/model/common/test_utils/serialbox_utils.py
+++ b/model/common/src/icon4py/model/common/test_utils/serialbox_utils.py
@@ -536,12 +536,14 @@ def rbf_vec_coeff_e(self):
).transpose()
return as_field((EdgeDim, E2C2EDim), buffer)
+ @IconSavepoint.optionally_registered()
def rbf_vec_coeff_c1(self):
buffer = np.squeeze(
self.serializer.read("rbf_vec_coeff_c1", self.savepoint).astype(float)
).transpose()
return as_field((CellDim, C2E2C2EDim), buffer)
+ @IconSavepoint.optionally_registered()
def rbf_vec_coeff_c2(self):
buffer = np.squeeze(
self.serializer.read("rbf_vec_coeff_c2", self.savepoint).astype(float)
@@ -580,6 +582,7 @@ def hmask_dd3d(self):
def inv_ddqz_z_full(self):
return self._get_field("inv_ddqz_z_full", CellDim, KDim)
+ @IconSavepoint.optionally_registered(CellDim, KDim)
def ddqz_z_full(self):
return self._get_field("ddqz_z_full", CellDim, KDim)
diff --git a/model/driver/src/icon4py/model/driver/dycore_driver.py b/model/driver/src/icon4py/model/driver/dycore_driver.py
index 25d75230e7..6c183c9817 100644
--- a/model/driver/src/icon4py/model/driver/dycore_driver.py
+++ b/model/driver/src/icon4py/model/driver/dycore_driver.py
@@ -186,10 +186,10 @@ def time_integration(
for time_step in range(self._n_time_steps):
log.info(f"simulation date : {self._simulation_date} run timestep : {time_step}")
log.info(
- f" MAX VN: {prognostic_state_list[self._now].vn.asnumpy().max():.5e} , MAX W: {prognostic_state_list[self._now].w.asnumpy().max():.5e}"
+ f" MAX VN: {prognostic_state_list[self._now].vn.asnumpy().max():.15e} , MAX W: {prognostic_state_list[self._now].w.asnumpy().max():.15e}"
)
log.info(
- f" MAX RHO: {prognostic_state_list[self._now].rho.asnumpy().max():.5e} , MAX THETA_V: {prognostic_state_list[self._now].theta_v.asnumpy().max():.5e}"
+ f" MAX RHO: {prognostic_state_list[self._now].rho.asnumpy().max():.15e} , MAX THETA_V: {prognostic_state_list[self._now].theta_v.asnumpy().max():.15e}"
)
# TODO (Chia Rui): check with Anurag about printing of max and min of variables.
@@ -294,6 +294,8 @@ def initialize(
props: ProcessProperties,
serialization_type: SerializationType,
experiment_type: ExperimentType,
+ grid_root,
+ grid_level,
):
"""
Inititalize the driver run.
@@ -323,16 +325,24 @@ def initialize(
log.info(f"reading configuration: experiment {experiment_type}")
config = read_config(experiment_type)
- decomp_info = read_decomp_info(file_path, props, serialization_type)
+ decomp_info = read_decomp_info(file_path, props, serialization_type, grid_root, grid_level)
log.info(f"initializing the grid from '{file_path}'")
- icon_grid = read_icon_grid(file_path, rank=props.rank, ser_type=serialization_type)
+ icon_grid = read_icon_grid(
+ file_path,
+ rank=props.rank,
+ ser_type=serialization_type,
+ grid_root=grid_root,
+ grid_level=grid_level,
+ )
log.info(f"reading input fields from '{file_path}'")
(edge_geometry, cell_geometry, vertical_geometry, c_owner_mask) = read_geometry_fields(
file_path,
damping_height=config.run_config.damping_height,
rank=props.rank,
ser_type=serialization_type,
+ grid_root=grid_root,
+ grid_level=grid_level,
)
(
diffusion_metric_state,
@@ -341,7 +351,11 @@ def initialize(
solve_nonhydro_interpolation_state,
diagnostic_metric_state,
) = read_static_fields(
- file_path, rank=props.rank, ser_type=serialization_type, experiment_type=experiment_type
+ file_path,
+ rank=props.rank,
+ ser_type=serialization_type,
+ grid_root=grid_root,
+ grid_level=grid_level,
)
log.info("initializing diffusion")
@@ -418,7 +432,9 @@ def initialize(
help="serialization type for grid info and static fields",
)
@click.option("--experiment_type", default="any", help="experiment selection")
-def main(input_path, run_path, mpi, serialization_type, experiment_type):
+@click.option("--grid_root", default=2, help="experiment selection")
+@click.option("--grid_level", default=4, help="experiment selection")
+def main(input_path, run_path, mpi, serialization_type, experiment_type, grid_root, grid_level):
"""
Run the driver.
@@ -449,7 +465,9 @@ def main(input_path, run_path, mpi, serialization_type, experiment_type):
diagnostic_state,
prep_adv,
inital_divdamp_fac_o2,
- ) = initialize(Path(input_path), parallel_props, serialization_type, experiment_type)
+ ) = initialize(
+ Path(input_path), parallel_props, serialization_type, experiment_type, grid_root, grid_level
+ )
log.info(f"Starting ICON dycore run: {timeloop.simulation_date.isoformat()}")
log.info(
f"input args: input_path={input_path}, n_time_steps={timeloop.n_time_steps}, ending date={timeloop.run_config.end_date}"
diff --git a/model/driver/src/icon4py/model/driver/initialization_utils.py b/model/driver/src/icon4py/model/driver/initialization_utils.py
index c6a8711c08..968faf973e 100644
--- a/model/driver/src/icon4py/model/driver/initialization_utils.py
+++ b/model/driver/src/icon4py/model/driver/initialization_utils.py
@@ -25,6 +25,7 @@
DiffusionInterpolationState,
DiffusionMetricState,
)
+from icon4py.model.atmosphere.dycore.init_exner_pr import init_exner_pr
from icon4py.model.atmosphere.dycore.state_utils.states import (
DiagnosticStateNonHydro,
InterpolationState,
@@ -45,7 +46,6 @@
from icon4py.model.common.decomposition.definitions import DecompositionInfo, ProcessProperties
from icon4py.model.common.decomposition.mpi_decomposition import ParallelLogger
from icon4py.model.common.dimension import (
- C2E2C2EDim,
CEDim,
CellDim,
EdgeDim,
@@ -88,29 +88,35 @@ class SerializationType(str, Enum):
class ExperimentType(str, Enum):
- """Jablonowski-Williamson test"""
-
JABW = "jabw"
- """any test with initial conditions read from serialized data"""
+ """initial condition of Jablonowski-Williamson test"""
ANY = "any"
+ """any test with initial conditions read from serialized data (remember to set correct SIMULATION_START_DATE)"""
def read_icon_grid(
- path: Path, rank=0, ser_type: SerializationType = SerializationType.SB
+ path: Path,
+ rank=0,
+ ser_type: SerializationType = SerializationType.SB,
+ grid_root=2,
+ grid_level=4,
) -> IconGrid:
"""
Read icon grid.
Args:
path: path where to find the input data
+ rank: mpi rank of the current compute node
ser_type: type of input data. Currently only 'sb (serialbox)' is supported. It reads
from ppser serialized test data
+ grid_root: global grid root division number
+ grid_level: global grid refinement number
Returns: IconGrid parsed from a given input type.
"""
if ser_type == SerializationType.SB:
return (
sb.IconSerialDataProvider("icon_pydycore", str(path.absolute()), False, mpi_rank=rank)
- .from_savepoint_grid(2, 4)
+ .from_savepoint_grid(grid_root, grid_level)
.construct_icon_grid(on_gpu=False)
)
else:
@@ -123,7 +129,29 @@ def model_initialization_jabw(
edge_param: EdgeParams,
path: Path,
rank=0,
-):
+) -> tuple[
+ DiffusionDiagnosticState,
+ DiagnosticStateNonHydro,
+ PrepAdvection,
+ float,
+ DiagnosticState,
+ PrognosticState,
+ PrognosticState,
+]:
+ """
+ Initial condition of Jablonowski-Williamson test. Set jw_up to values larger than 0.01 if
+ you want to run baroclinic case.
+
+ Args:
+ icon_grid: IconGrid
+ cell_param: cell properties
+ edge_param: edge properties
+ path: path where to find the input data
+ rank: mpi rank of the current compute node
+ Returns: A tuple containing Diagnostic variables for diffusion and solve_nonhydro granules,
+ PrepAdvection, second order divdamp factor, diagnostic variables, and two prognostic
+ variables (now and next).
+ """
data_provider = sb.IconSerialDataProvider(
"icon_pydycore", str(path.absolute()), False, mpi_rank=rank
)
@@ -152,6 +180,9 @@ def model_initialization_jabw(
EdgeDim, HorizontalMarkerIndex.lateral_boundary(EdgeDim) + 1
)
grid_idx_edge_end = icon_grid.get_end_index(EdgeDim, HorizontalMarkerIndex.end(EdgeDim))
+ grid_idx_cell_interior_start = icon_grid.get_start_index(
+ CellDim, HorizontalMarkerIndex.interior(CellDim)
+ )
grid_idx_cell_start_plus1 = icon_grid.get_end_index(
CellDim, HorizontalMarkerIndex.lateral_boundary(CellDim) + 1
)
@@ -324,6 +355,19 @@ def model_initialization_jabw(
log.info("U, V computation completed.")
+ exner_pr = _allocate(CellDim, KDim, grid=icon_grid)
+ init_exner_pr(
+ exner,
+ data_provider.from_metrics_savepoint().exner_ref_mc(),
+ exner_pr,
+ grid_idx_cell_interior_start,
+ grid_idx_cell_end,
+ 0,
+ icon_grid.num_levels,
+ offset_provider={},
+ )
+ log.info("exner_pr initialization completed.")
+
diagnostic_state = DiagnosticState(
pressure=pressure,
pressure_ifc=pressure_ifc,
@@ -355,7 +399,7 @@ def model_initialization_jabw(
)
solve_nonhydro_diagnostic_state = DiagnosticStateNonHydro(
theta_v_ic=_allocate(CellDim, KDim, grid=icon_grid, is_halfdim=True),
- exner_pr=_allocate(CellDim, KDim, grid=icon_grid),
+ exner_pr=exner_pr,
rho_ic=_allocate(CellDim, KDim, grid=icon_grid, is_halfdim=True),
ddt_exner_phy=_allocate(CellDim, KDim, grid=icon_grid),
grf_tend_rho=_allocate(CellDim, KDim, grid=icon_grid),
@@ -396,19 +440,28 @@ def model_initialization_jabw(
)
-def model_initialization_serialbox(icon_grid: IconGrid, path: Path, rank=0):
+def model_initialization_serialbox(
+ icon_grid: IconGrid, path: Path, rank=0
+) -> tuple[
+ DiffusionDiagnosticState,
+ DiagnosticStateNonHydro,
+ PrepAdvection,
+ float,
+ DiagnosticState,
+ PrognosticState,
+ PrognosticState,
+]:
"""
- Read prognostic and diagnostic state from serialized data.
+ Initial condition read from serialized data. Diagnostic variables are allocated as zero
+ fields.
Args:
- icon_grid: icon grid
- path: path to the serialized input data
+ icon_grid: IconGrid
+ path: path where to find the input data
rank: mpi rank of the current compute node
-
- Returns: a tuple containing the data_provider, the initial diagnostic and prognostic state.
- The data_provider is returned such that further timesteps of diagnostics and prognostics
- can be read from within the dummy timeloop
-
+ Returns: A tuple containing Diagnostic variables for diffusion and solve_nonhydro granules,
+ PrepAdvection, second order divdamp factor, diagnostic variables, and two prognostic
+ variables (now and next).
"""
data_provider = sb.IconSerialDataProvider(
@@ -453,9 +506,8 @@ def model_initialization_serialbox(icon_grid: IconGrid, path: Path, rank=0):
diagnostic_state = DiagnosticState(
pressure=_allocate(CellDim, KDim, grid=icon_grid),
- pressure_ifc=_allocate(CellDim, KDim, grid=icon_grid),
+ pressure_ifc=_allocate(CellDim, KDim, grid=icon_grid, is_halfdim=True),
temperature=_allocate(CellDim, KDim, grid=icon_grid),
- pressure_sfc=_allocate(CellDim, grid=icon_grid),
u=_allocate(CellDim, KDim, grid=icon_grid),
v=_allocate(CellDim, KDim, grid=icon_grid),
)
@@ -501,6 +553,21 @@ def read_initial_state(
PrognosticState,
PrognosticState,
]:
+ """
+ Read initial prognostic and diagnostic fields.
+
+ Args:
+ icon_grid: IconGrid
+ cell_param: cell properties
+ edge_param: edge properties
+ path: path to the serialized input data
+ rank: mpi rank of the current compute node
+ experiment_type: (optional) defaults to ANY=any, type of initial condition to be read
+
+ Returns: A tuple containing Diagnostic variables for diffusion and solve_nonhydro granules,
+ PrepAdvection, second order divdamp factor, diagnostic variables, and two prognostic
+ variables (now and next).
+ """
if experiment_type == ExperimentType.JABW:
(
diffusion_diagnostic_state,
@@ -536,16 +603,23 @@ def read_initial_state(
def read_geometry_fields(
- path: Path, damping_height, rank=0, ser_type: SerializationType = SerializationType.SB
+ path: Path,
+ damping_height,
+ rank=0,
+ ser_type: SerializationType = SerializationType.SB,
+ grid_root=2,
+ grid_level=4,
) -> tuple[EdgeParams, CellParams, VerticalModelParams, Field[[CellDim], bool]]:
"""
Read fields containing grid properties.
Args:
path: path to the serialized input data
- damping_height: damping height for Rayleigh and divergence damping TODO (CHia Rui): Check
- rank:
+ damping_height: damping height for Rayleigh and divergence damping
+ rank: mpi rank of the current compute node
ser_type: (optional) defaults to SB=serialbox, type of input data to be read
+ grid_root: global grid root division number
+ grid_level: global grid refinement number
Returns: a tuple containing fields describing edges, cells, vertical properties of the model
the data is originally obtained from the grid file (horizontal fields) or some special input files.
@@ -553,7 +627,7 @@ def read_geometry_fields(
if ser_type == SerializationType.SB:
sp = sb.IconSerialDataProvider(
"icon_pydycore", str(path.absolute()), False, mpi_rank=rank
- ).from_savepoint_grid(2, 4)
+ ).from_savepoint_grid(grid_root, grid_level)
edge_geometry = sp.construct_edge_geometry()
cell_geometry = sp.construct_cell_geometry()
vertical_geometry = VerticalModelParams(
@@ -571,12 +645,14 @@ def read_decomp_info(
path: Path,
procs_props: ProcessProperties,
ser_type=SerializationType.SB,
+ grid_root=2,
+ grid_level=4,
) -> DecompositionInfo:
if ser_type == SerializationType.SB:
sp = sb.IconSerialDataProvider(
"icon_pydycore", str(path.absolute()), True, procs_props.rank
)
- return sp.from_savepoint_grid(2, 4).construct_decomposition_info()
+ return sp.from_savepoint_grid(grid_root, grid_level).construct_decomposition_info()
else:
raise NotImplementedError(SB_ONLY_MSG)
@@ -585,7 +661,8 @@ def read_static_fields(
path: Path,
rank=0,
ser_type: SerializationType = SerializationType.SB,
- experiment_type: ExperimentType = ExperimentType.JABW,
+ grid_root=2,
+ grid_level=4,
) -> tuple[
DiffusionMetricState,
DiffusionInterpolationState,
@@ -600,7 +677,8 @@ def read_static_fields(
path: path to the serialized input data
rank: mpi rank, defaults to 0 for serial run
ser_type: (optional) defaults to SB=serialbox, type of input data to be read
- experiment_type: TODO (CHia RUi): Add description
+ grid_root: global grid root division number
+ grid_level: global grid refinement number
Returns:
a tuple containing the metric_state and interpolation state,
@@ -613,7 +691,7 @@ def read_static_fields(
)
icon_grid = (
sb.IconSerialDataProvider("icon_pydycore", str(path.absolute()), False, mpi_rank=rank)
- .from_savepoint_grid(2, 4)
+ .from_savepoint_grid(grid_root, grid_level)
.construct_icon_grid(on_gpu=False)
)
diffusion_interpolation_state = construct_interpolation_state_for_diffusion(
@@ -679,19 +757,11 @@ def read_static_fields(
coeff_gradekin=metrics_savepoint.coeff_gradekin(),
)
- if experiment_type == ExperimentType.JABW:
- diagnostic_metric_state = DiagnosticMetricState(
- ddqz_z_full=metrics_savepoint.ddqz_z_full(),
- rbf_vec_coeff_c1=data_provider.from_interpolation_savepoint().rbf_vec_coeff_c1(),
- rbf_vec_coeff_c2=data_provider.from_interpolation_savepoint().rbf_vec_coeff_c2(),
- )
- else:
- # ddqz_z_full is not serialized for mch_ch_r04b09_dsl and exclaim_ape_R02B04
- diagnostic_metric_state = DiagnosticMetricState(
- ddqz_z_full=_allocate(CellDim, KDim, grid=icon_grid, dtype=float),
- rbf_vec_coeff_c1=_allocate(CellDim, C2E2C2EDim, grid=icon_grid, dtype=float),
- rbf_vec_coeff_c2=_allocate(CellDim, C2E2C2EDim, grid=icon_grid, dtype=float),
- )
+ diagnostic_metric_state = DiagnosticMetricState(
+ ddqz_z_full=metrics_savepoint.ddqz_z_full(),
+ rbf_vec_coeff_c1=interpolation_savepoint.rbf_vec_coeff_c1(),
+ rbf_vec_coeff_c2=interpolation_savepoint.rbf_vec_coeff_c2(),
+ )
return (
diffusion_metric_state,
diff --git a/model/driver/src/icon4py/model/driver/jablonowski_willamson_testcase.py b/model/driver/src/icon4py/model/driver/jablonowski_willamson_testcase.py
index ddec9fee7a..599f076d88 100644
--- a/model/driver/src/icon4py/model/driver/jablonowski_willamson_testcase.py
+++ b/model/driver/src/icon4py/model/driver/jablonowski_willamson_testcase.py
@@ -29,7 +29,21 @@ def zonalwind_2_normalwind_jabw_numpy(
primal_normal_x: np.array,
eta_v_e: np.array,
):
- """mask = np.repeat(np.expand_dims(mask, axis=-1), eta_v_e.shape[1], axis=1)"""
+ """
+ Compute normal wind at edge center from vertical eta coordinate (eta_v_e).
+
+ Args:
+ icon_grid: IconGrid
+ jw_u0: base zonal wind speed factor
+ jw_up: perturbation amplitude
+ lat_perturbation_center: perturbation center in latitude
+ lon_perturbation_center: perturbation center in longitude
+ edge_lat: edge center latitude
+ edge_lon: edge center longitude
+ primal_normal_x: zonal component of primal normal vector at edge center
+ eta_v_e: vertical eta coordinate at edge center
+ Returns: normal wind
+ """
mask = np.ones((icon_grid.num_edges, icon_grid.num_levels), dtype=bool)
mask[
0 : icon_grid.get_end_index(EdgeDim, HorizontalMarkerIndex.lateral_boundary(EdgeDim) + 1), :
diff --git a/model/driver/tests/test_testcase.py b/model/driver/tests/test_testcase.py
index ad1211c282..341251fe17 100644
--- a/model/driver/tests/test_testcase.py
+++ b/model/driver/tests/test_testcase.py
@@ -89,3 +89,9 @@ def test_jabw_initial_condition(
diagnostic_state.pressure_sfc.asnumpy(),
data_provider.from_savepoint_jabw_init().pressure_sfc().asnumpy(),
)
+
+ assert dallclose(
+ solve_nonhydro_diagnostic_state.exner_pr.asnumpy(),
+ data_provider.from_savepoint_jabw_diagnostic().exner_pr().asnumpy(),
+ atol=1.0e-14,
+ )