diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/init_exner_pr.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/init_exner_pr.py new file mode 100644 index 0000000000..5ec1cb398c --- /dev/null +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/init_exner_pr.py @@ -0,0 +1,50 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +from gt4py.next.common import GridType +from gt4py.next.ffront.decorator import field_operator, program +from gt4py.next.ffront.fbuiltins import Field, int32 + +from icon4py.model.common.dimension import CellDim, KDim +from icon4py.model.common.settings import backend +from icon4py.model.common.type_alias import vpfloat + + +@field_operator +def _init_exner_pr( + exner: Field[[CellDim, KDim], vpfloat], + exner_ref: Field[[CellDim, KDim], vpfloat], +) -> Field[[CellDim, KDim], vpfloat]: + exner_pr = exner - exner_ref + return exner_pr + + +@program(grid_type=GridType.UNSTRUCTURED, backend=backend) +def init_exner_pr( + exner: Field[[CellDim, KDim], vpfloat], + exner_ref: Field[[CellDim, KDim], vpfloat], + exner_pr: Field[[CellDim, KDim], vpfloat], + horizontal_start: int32, + horizontal_end: int32, + vertical_start: int32, + vertical_end: int32, +): + _init_exner_pr( + exner, + exner_ref, + out=exner_pr, + domain={ + CellDim: (horizontal_start, horizontal_end), + KDim: (vertical_start, vertical_end), + }, + ) diff --git a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/nh_solve/solve_nonhydro.py b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/nh_solve/solve_nonhydro.py index c7ce4a55c2..ae6df8f9ac 100644 --- a/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/nh_solve/solve_nonhydro.py +++ b/model/atmosphere/dycore/src/icon4py/model/atmosphere/dycore/nh_solve/solve_nonhydro.py @@ -1647,20 +1647,20 @@ def run_corrector_step( offset_provider={}, ) - # verified for e-9 - log.debug(f"corrector: start stencile 41") - compute_divergence_of_fluxes_of_rho_and_theta( - geofac_div=self.interpolation_state.geofac_div, - mass_fl_e=diagnostic_state_nh.mass_fl_e, - z_theta_v_fl_e=self.z_theta_v_fl_e, - z_flxdiv_mass=self.z_flxdiv_mass, - z_flxdiv_theta=self.z_flxdiv_theta, - horizontal_start=start_cell_nudging, - horizontal_end=end_cell_local, - vertical_start=0, - vertical_end=self.grid.num_levels, - offset_provider=self.grid.offset_providers, - ) + # verified for e-9 + log.debug(f"corrector: start stencil 41") + compute_divergence_of_fluxes_of_rho_and_theta( + geofac_div=self.interpolation_state.geofac_div, + mass_fl_e=diagnostic_state_nh.mass_fl_e, + z_theta_v_fl_e=self.z_theta_v_fl_e, + z_flxdiv_mass=self.z_flxdiv_mass, + z_flxdiv_theta=self.z_flxdiv_theta, + horizontal_start=start_cell_nudging, + horizontal_end=end_cell_local, + vertical_start=0, + vertical_end=self.grid.num_levels, + offset_provider=self.grid.offset_providers, + ) if self.config.itime_scheme == 4: log.debug(f"corrector start stencil 42 44 45 45b") diff --git a/model/atmosphere/dycore/tests/dycore_stencil_tests/test_init_exner_pr.py b/model/atmosphere/dycore/tests/dycore_stencil_tests/test_init_exner_pr.py new file mode 100644 index 0000000000..a40dd7e5bd --- /dev/null +++ b/model/atmosphere/dycore/tests/dycore_stencil_tests/test_init_exner_pr.py @@ -0,0 +1,51 @@ +# ICON4Py - ICON inspired code in Python and GT4Py +# +# Copyright (c) 2022, ETH Zurich and MeteoSwiss +# All rights reserved. +# +# This file is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . +# +# SPDX-License-Identifier: GPL-3.0-or-later + +import numpy as np +import pytest +from gt4py.next.ffront.fbuiltins import int32 + +from icon4py.model.atmosphere.dycore.init_exner_pr import ( + init_exner_pr, +) +from icon4py.model.common.dimension import CellDim, KDim +from icon4py.model.common.test_utils.helpers import StencilTest, random_field, zero_field +from icon4py.model.common.type_alias import vpfloat + + +class TestInitExnerPr(StencilTest): + PROGRAM = init_exner_pr + OUTPUTS = ("exner_pr",) + + @staticmethod + def reference(grid, exner: np.array, exner_ref: np.array, **kwargs) -> dict: + exner_pr = exner - exner_ref + return dict( + exner_pr=exner_pr, + ) + + @pytest.fixture + def input_data(self, grid): + exner = random_field(grid, CellDim, KDim, dtype=vpfloat) + exner_ref = random_field(grid, CellDim, KDim, dtype=vpfloat) + exner_pr = zero_field(grid, CellDim, KDim, dtype=vpfloat) + + return dict( + exner=exner, + exner_ref=exner_ref, + exner_pr=exner_pr, + horizontal_start=int32(0), + horizontal_end=int32(grid.num_cells), + vertical_start=int32(0), + vertical_end=int32(grid.num_levels), + ) diff --git a/model/common/src/icon4py/model/common/test_utils/serialbox_utils.py b/model/common/src/icon4py/model/common/test_utils/serialbox_utils.py index c049f12349..b68422d276 100644 --- a/model/common/src/icon4py/model/common/test_utils/serialbox_utils.py +++ b/model/common/src/icon4py/model/common/test_utils/serialbox_utils.py @@ -536,12 +536,14 @@ def rbf_vec_coeff_e(self): ).transpose() return as_field((EdgeDim, E2C2EDim), buffer) + @IconSavepoint.optionally_registered() def rbf_vec_coeff_c1(self): buffer = np.squeeze( self.serializer.read("rbf_vec_coeff_c1", self.savepoint).astype(float) ).transpose() return as_field((CellDim, C2E2C2EDim), buffer) + @IconSavepoint.optionally_registered() def rbf_vec_coeff_c2(self): buffer = np.squeeze( self.serializer.read("rbf_vec_coeff_c2", self.savepoint).astype(float) @@ -580,6 +582,7 @@ def hmask_dd3d(self): def inv_ddqz_z_full(self): return self._get_field("inv_ddqz_z_full", CellDim, KDim) + @IconSavepoint.optionally_registered(CellDim, KDim) def ddqz_z_full(self): return self._get_field("ddqz_z_full", CellDim, KDim) diff --git a/model/driver/src/icon4py/model/driver/dycore_driver.py b/model/driver/src/icon4py/model/driver/dycore_driver.py index 25d75230e7..6c183c9817 100644 --- a/model/driver/src/icon4py/model/driver/dycore_driver.py +++ b/model/driver/src/icon4py/model/driver/dycore_driver.py @@ -186,10 +186,10 @@ def time_integration( for time_step in range(self._n_time_steps): log.info(f"simulation date : {self._simulation_date} run timestep : {time_step}") log.info( - f" MAX VN: {prognostic_state_list[self._now].vn.asnumpy().max():.5e} , MAX W: {prognostic_state_list[self._now].w.asnumpy().max():.5e}" + f" MAX VN: {prognostic_state_list[self._now].vn.asnumpy().max():.15e} , MAX W: {prognostic_state_list[self._now].w.asnumpy().max():.15e}" ) log.info( - f" MAX RHO: {prognostic_state_list[self._now].rho.asnumpy().max():.5e} , MAX THETA_V: {prognostic_state_list[self._now].theta_v.asnumpy().max():.5e}" + f" MAX RHO: {prognostic_state_list[self._now].rho.asnumpy().max():.15e} , MAX THETA_V: {prognostic_state_list[self._now].theta_v.asnumpy().max():.15e}" ) # TODO (Chia Rui): check with Anurag about printing of max and min of variables. @@ -294,6 +294,8 @@ def initialize( props: ProcessProperties, serialization_type: SerializationType, experiment_type: ExperimentType, + grid_root, + grid_level, ): """ Inititalize the driver run. @@ -323,16 +325,24 @@ def initialize( log.info(f"reading configuration: experiment {experiment_type}") config = read_config(experiment_type) - decomp_info = read_decomp_info(file_path, props, serialization_type) + decomp_info = read_decomp_info(file_path, props, serialization_type, grid_root, grid_level) log.info(f"initializing the grid from '{file_path}'") - icon_grid = read_icon_grid(file_path, rank=props.rank, ser_type=serialization_type) + icon_grid = read_icon_grid( + file_path, + rank=props.rank, + ser_type=serialization_type, + grid_root=grid_root, + grid_level=grid_level, + ) log.info(f"reading input fields from '{file_path}'") (edge_geometry, cell_geometry, vertical_geometry, c_owner_mask) = read_geometry_fields( file_path, damping_height=config.run_config.damping_height, rank=props.rank, ser_type=serialization_type, + grid_root=grid_root, + grid_level=grid_level, ) ( diffusion_metric_state, @@ -341,7 +351,11 @@ def initialize( solve_nonhydro_interpolation_state, diagnostic_metric_state, ) = read_static_fields( - file_path, rank=props.rank, ser_type=serialization_type, experiment_type=experiment_type + file_path, + rank=props.rank, + ser_type=serialization_type, + grid_root=grid_root, + grid_level=grid_level, ) log.info("initializing diffusion") @@ -418,7 +432,9 @@ def initialize( help="serialization type for grid info and static fields", ) @click.option("--experiment_type", default="any", help="experiment selection") -def main(input_path, run_path, mpi, serialization_type, experiment_type): +@click.option("--grid_root", default=2, help="experiment selection") +@click.option("--grid_level", default=4, help="experiment selection") +def main(input_path, run_path, mpi, serialization_type, experiment_type, grid_root, grid_level): """ Run the driver. @@ -449,7 +465,9 @@ def main(input_path, run_path, mpi, serialization_type, experiment_type): diagnostic_state, prep_adv, inital_divdamp_fac_o2, - ) = initialize(Path(input_path), parallel_props, serialization_type, experiment_type) + ) = initialize( + Path(input_path), parallel_props, serialization_type, experiment_type, grid_root, grid_level + ) log.info(f"Starting ICON dycore run: {timeloop.simulation_date.isoformat()}") log.info( f"input args: input_path={input_path}, n_time_steps={timeloop.n_time_steps}, ending date={timeloop.run_config.end_date}" diff --git a/model/driver/src/icon4py/model/driver/initialization_utils.py b/model/driver/src/icon4py/model/driver/initialization_utils.py index c6a8711c08..968faf973e 100644 --- a/model/driver/src/icon4py/model/driver/initialization_utils.py +++ b/model/driver/src/icon4py/model/driver/initialization_utils.py @@ -25,6 +25,7 @@ DiffusionInterpolationState, DiffusionMetricState, ) +from icon4py.model.atmosphere.dycore.init_exner_pr import init_exner_pr from icon4py.model.atmosphere.dycore.state_utils.states import ( DiagnosticStateNonHydro, InterpolationState, @@ -45,7 +46,6 @@ from icon4py.model.common.decomposition.definitions import DecompositionInfo, ProcessProperties from icon4py.model.common.decomposition.mpi_decomposition import ParallelLogger from icon4py.model.common.dimension import ( - C2E2C2EDim, CEDim, CellDim, EdgeDim, @@ -88,29 +88,35 @@ class SerializationType(str, Enum): class ExperimentType(str, Enum): - """Jablonowski-Williamson test""" - JABW = "jabw" - """any test with initial conditions read from serialized data""" + """initial condition of Jablonowski-Williamson test""" ANY = "any" + """any test with initial conditions read from serialized data (remember to set correct SIMULATION_START_DATE)""" def read_icon_grid( - path: Path, rank=0, ser_type: SerializationType = SerializationType.SB + path: Path, + rank=0, + ser_type: SerializationType = SerializationType.SB, + grid_root=2, + grid_level=4, ) -> IconGrid: """ Read icon grid. Args: path: path where to find the input data + rank: mpi rank of the current compute node ser_type: type of input data. Currently only 'sb (serialbox)' is supported. It reads from ppser serialized test data + grid_root: global grid root division number + grid_level: global grid refinement number Returns: IconGrid parsed from a given input type. """ if ser_type == SerializationType.SB: return ( sb.IconSerialDataProvider("icon_pydycore", str(path.absolute()), False, mpi_rank=rank) - .from_savepoint_grid(2, 4) + .from_savepoint_grid(grid_root, grid_level) .construct_icon_grid(on_gpu=False) ) else: @@ -123,7 +129,29 @@ def model_initialization_jabw( edge_param: EdgeParams, path: Path, rank=0, -): +) -> tuple[ + DiffusionDiagnosticState, + DiagnosticStateNonHydro, + PrepAdvection, + float, + DiagnosticState, + PrognosticState, + PrognosticState, +]: + """ + Initial condition of Jablonowski-Williamson test. Set jw_up to values larger than 0.01 if + you want to run baroclinic case. + + Args: + icon_grid: IconGrid + cell_param: cell properties + edge_param: edge properties + path: path where to find the input data + rank: mpi rank of the current compute node + Returns: A tuple containing Diagnostic variables for diffusion and solve_nonhydro granules, + PrepAdvection, second order divdamp factor, diagnostic variables, and two prognostic + variables (now and next). + """ data_provider = sb.IconSerialDataProvider( "icon_pydycore", str(path.absolute()), False, mpi_rank=rank ) @@ -152,6 +180,9 @@ def model_initialization_jabw( EdgeDim, HorizontalMarkerIndex.lateral_boundary(EdgeDim) + 1 ) grid_idx_edge_end = icon_grid.get_end_index(EdgeDim, HorizontalMarkerIndex.end(EdgeDim)) + grid_idx_cell_interior_start = icon_grid.get_start_index( + CellDim, HorizontalMarkerIndex.interior(CellDim) + ) grid_idx_cell_start_plus1 = icon_grid.get_end_index( CellDim, HorizontalMarkerIndex.lateral_boundary(CellDim) + 1 ) @@ -324,6 +355,19 @@ def model_initialization_jabw( log.info("U, V computation completed.") + exner_pr = _allocate(CellDim, KDim, grid=icon_grid) + init_exner_pr( + exner, + data_provider.from_metrics_savepoint().exner_ref_mc(), + exner_pr, + grid_idx_cell_interior_start, + grid_idx_cell_end, + 0, + icon_grid.num_levels, + offset_provider={}, + ) + log.info("exner_pr initialization completed.") + diagnostic_state = DiagnosticState( pressure=pressure, pressure_ifc=pressure_ifc, @@ -355,7 +399,7 @@ def model_initialization_jabw( ) solve_nonhydro_diagnostic_state = DiagnosticStateNonHydro( theta_v_ic=_allocate(CellDim, KDim, grid=icon_grid, is_halfdim=True), - exner_pr=_allocate(CellDim, KDim, grid=icon_grid), + exner_pr=exner_pr, rho_ic=_allocate(CellDim, KDim, grid=icon_grid, is_halfdim=True), ddt_exner_phy=_allocate(CellDim, KDim, grid=icon_grid), grf_tend_rho=_allocate(CellDim, KDim, grid=icon_grid), @@ -396,19 +440,28 @@ def model_initialization_jabw( ) -def model_initialization_serialbox(icon_grid: IconGrid, path: Path, rank=0): +def model_initialization_serialbox( + icon_grid: IconGrid, path: Path, rank=0 +) -> tuple[ + DiffusionDiagnosticState, + DiagnosticStateNonHydro, + PrepAdvection, + float, + DiagnosticState, + PrognosticState, + PrognosticState, +]: """ - Read prognostic and diagnostic state from serialized data. + Initial condition read from serialized data. Diagnostic variables are allocated as zero + fields. Args: - icon_grid: icon grid - path: path to the serialized input data + icon_grid: IconGrid + path: path where to find the input data rank: mpi rank of the current compute node - - Returns: a tuple containing the data_provider, the initial diagnostic and prognostic state. - The data_provider is returned such that further timesteps of diagnostics and prognostics - can be read from within the dummy timeloop - + Returns: A tuple containing Diagnostic variables for diffusion and solve_nonhydro granules, + PrepAdvection, second order divdamp factor, diagnostic variables, and two prognostic + variables (now and next). """ data_provider = sb.IconSerialDataProvider( @@ -453,9 +506,8 @@ def model_initialization_serialbox(icon_grid: IconGrid, path: Path, rank=0): diagnostic_state = DiagnosticState( pressure=_allocate(CellDim, KDim, grid=icon_grid), - pressure_ifc=_allocate(CellDim, KDim, grid=icon_grid), + pressure_ifc=_allocate(CellDim, KDim, grid=icon_grid, is_halfdim=True), temperature=_allocate(CellDim, KDim, grid=icon_grid), - pressure_sfc=_allocate(CellDim, grid=icon_grid), u=_allocate(CellDim, KDim, grid=icon_grid), v=_allocate(CellDim, KDim, grid=icon_grid), ) @@ -501,6 +553,21 @@ def read_initial_state( PrognosticState, PrognosticState, ]: + """ + Read initial prognostic and diagnostic fields. + + Args: + icon_grid: IconGrid + cell_param: cell properties + edge_param: edge properties + path: path to the serialized input data + rank: mpi rank of the current compute node + experiment_type: (optional) defaults to ANY=any, type of initial condition to be read + + Returns: A tuple containing Diagnostic variables for diffusion and solve_nonhydro granules, + PrepAdvection, second order divdamp factor, diagnostic variables, and two prognostic + variables (now and next). + """ if experiment_type == ExperimentType.JABW: ( diffusion_diagnostic_state, @@ -536,16 +603,23 @@ def read_initial_state( def read_geometry_fields( - path: Path, damping_height, rank=0, ser_type: SerializationType = SerializationType.SB + path: Path, + damping_height, + rank=0, + ser_type: SerializationType = SerializationType.SB, + grid_root=2, + grid_level=4, ) -> tuple[EdgeParams, CellParams, VerticalModelParams, Field[[CellDim], bool]]: """ Read fields containing grid properties. Args: path: path to the serialized input data - damping_height: damping height for Rayleigh and divergence damping TODO (CHia Rui): Check - rank: + damping_height: damping height for Rayleigh and divergence damping + rank: mpi rank of the current compute node ser_type: (optional) defaults to SB=serialbox, type of input data to be read + grid_root: global grid root division number + grid_level: global grid refinement number Returns: a tuple containing fields describing edges, cells, vertical properties of the model the data is originally obtained from the grid file (horizontal fields) or some special input files. @@ -553,7 +627,7 @@ def read_geometry_fields( if ser_type == SerializationType.SB: sp = sb.IconSerialDataProvider( "icon_pydycore", str(path.absolute()), False, mpi_rank=rank - ).from_savepoint_grid(2, 4) + ).from_savepoint_grid(grid_root, grid_level) edge_geometry = sp.construct_edge_geometry() cell_geometry = sp.construct_cell_geometry() vertical_geometry = VerticalModelParams( @@ -571,12 +645,14 @@ def read_decomp_info( path: Path, procs_props: ProcessProperties, ser_type=SerializationType.SB, + grid_root=2, + grid_level=4, ) -> DecompositionInfo: if ser_type == SerializationType.SB: sp = sb.IconSerialDataProvider( "icon_pydycore", str(path.absolute()), True, procs_props.rank ) - return sp.from_savepoint_grid(2, 4).construct_decomposition_info() + return sp.from_savepoint_grid(grid_root, grid_level).construct_decomposition_info() else: raise NotImplementedError(SB_ONLY_MSG) @@ -585,7 +661,8 @@ def read_static_fields( path: Path, rank=0, ser_type: SerializationType = SerializationType.SB, - experiment_type: ExperimentType = ExperimentType.JABW, + grid_root=2, + grid_level=4, ) -> tuple[ DiffusionMetricState, DiffusionInterpolationState, @@ -600,7 +677,8 @@ def read_static_fields( path: path to the serialized input data rank: mpi rank, defaults to 0 for serial run ser_type: (optional) defaults to SB=serialbox, type of input data to be read - experiment_type: TODO (CHia RUi): Add description + grid_root: global grid root division number + grid_level: global grid refinement number Returns: a tuple containing the metric_state and interpolation state, @@ -613,7 +691,7 @@ def read_static_fields( ) icon_grid = ( sb.IconSerialDataProvider("icon_pydycore", str(path.absolute()), False, mpi_rank=rank) - .from_savepoint_grid(2, 4) + .from_savepoint_grid(grid_root, grid_level) .construct_icon_grid(on_gpu=False) ) diffusion_interpolation_state = construct_interpolation_state_for_diffusion( @@ -679,19 +757,11 @@ def read_static_fields( coeff_gradekin=metrics_savepoint.coeff_gradekin(), ) - if experiment_type == ExperimentType.JABW: - diagnostic_metric_state = DiagnosticMetricState( - ddqz_z_full=metrics_savepoint.ddqz_z_full(), - rbf_vec_coeff_c1=data_provider.from_interpolation_savepoint().rbf_vec_coeff_c1(), - rbf_vec_coeff_c2=data_provider.from_interpolation_savepoint().rbf_vec_coeff_c2(), - ) - else: - # ddqz_z_full is not serialized for mch_ch_r04b09_dsl and exclaim_ape_R02B04 - diagnostic_metric_state = DiagnosticMetricState( - ddqz_z_full=_allocate(CellDim, KDim, grid=icon_grid, dtype=float), - rbf_vec_coeff_c1=_allocate(CellDim, C2E2C2EDim, grid=icon_grid, dtype=float), - rbf_vec_coeff_c2=_allocate(CellDim, C2E2C2EDim, grid=icon_grid, dtype=float), - ) + diagnostic_metric_state = DiagnosticMetricState( + ddqz_z_full=metrics_savepoint.ddqz_z_full(), + rbf_vec_coeff_c1=interpolation_savepoint.rbf_vec_coeff_c1(), + rbf_vec_coeff_c2=interpolation_savepoint.rbf_vec_coeff_c2(), + ) return ( diffusion_metric_state, diff --git a/model/driver/src/icon4py/model/driver/jablonowski_willamson_testcase.py b/model/driver/src/icon4py/model/driver/jablonowski_willamson_testcase.py index ddec9fee7a..599f076d88 100644 --- a/model/driver/src/icon4py/model/driver/jablonowski_willamson_testcase.py +++ b/model/driver/src/icon4py/model/driver/jablonowski_willamson_testcase.py @@ -29,7 +29,21 @@ def zonalwind_2_normalwind_jabw_numpy( primal_normal_x: np.array, eta_v_e: np.array, ): - """mask = np.repeat(np.expand_dims(mask, axis=-1), eta_v_e.shape[1], axis=1)""" + """ + Compute normal wind at edge center from vertical eta coordinate (eta_v_e). + + Args: + icon_grid: IconGrid + jw_u0: base zonal wind speed factor + jw_up: perturbation amplitude + lat_perturbation_center: perturbation center in latitude + lon_perturbation_center: perturbation center in longitude + edge_lat: edge center latitude + edge_lon: edge center longitude + primal_normal_x: zonal component of primal normal vector at edge center + eta_v_e: vertical eta coordinate at edge center + Returns: normal wind + """ mask = np.ones((icon_grid.num_edges, icon_grid.num_levels), dtype=bool) mask[ 0 : icon_grid.get_end_index(EdgeDim, HorizontalMarkerIndex.lateral_boundary(EdgeDim) + 1), : diff --git a/model/driver/tests/test_testcase.py b/model/driver/tests/test_testcase.py index ad1211c282..341251fe17 100644 --- a/model/driver/tests/test_testcase.py +++ b/model/driver/tests/test_testcase.py @@ -89,3 +89,9 @@ def test_jabw_initial_condition( diagnostic_state.pressure_sfc.asnumpy(), data_provider.from_savepoint_jabw_init().pressure_sfc().asnumpy(), ) + + assert dallclose( + solve_nonhydro_diagnostic_state.exner_pr.asnumpy(), + data_provider.from_savepoint_jabw_diagnostic().exner_pr().asnumpy(), + atol=1.0e-14, + )