From 898ff7c256ae14e40e316af04db8ce720b0d2513 Mon Sep 17 00:00:00 2001 From: David Strassmann Date: Mon, 25 Nov 2024 14:25:11 +0100 Subject: [PATCH] Address requested changes --- .../advection/advection_vertical.py | 30 +++++++++++++------ .../compute_vertical_tracer_flux_upwind.py | 2 +- .../advection/tests/advection_tests/utils.py | 12 ++++---- 3 files changed, 28 insertions(+), 16 deletions(-) diff --git a/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection_vertical.py b/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection_vertical.py index cd6c9dd0f..374db8794 100644 --- a/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection_vertical.py +++ b/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/advection_vertical.py @@ -556,6 +556,10 @@ def __init__( ) self._integrate_tracer_vertically = integrate_tracer_vertically.with_backend(self._backend) + # misc + self._ivadv_tracer = 1 + self._iadv_slev_jt = 0 + log.debug("vertical advection class init - end") def _get_horizontal_start_end(self, even_timestep: bool): @@ -633,8 +637,8 @@ def _update_unknowns( tracer_new=p_tracer_new, k=self._k_field, p_dtime=dtime, - ivadv_tracer=1, - iadv_slev_jt=0, + ivadv_tracer=self._ivadv_tracer, + iadv_slev_jt=self._iadv_slev_jt, horizontal_start=horizontal_start, horizontal_end=horizontal_end, vertical_start=0, @@ -730,6 +734,14 @@ def __init__( ) self._integrate_tracer_vertically = integrate_tracer_vertically.with_backend(self._backend) + # misc + self._slev = 0 + self._slevp1_ti = 1 + self._elev = self._grid.num_levels - 1 + self._nlev = self._grid.num_levels - 1 + self._ivadv_tracer = 1 + self._iadv_slev_jt = 0 + log.debug("vertical advection class init - end") def _get_horizontal_start_end(self, even_timestep: bool): @@ -777,8 +789,8 @@ def _compute_numerical_flux( p_cellmass_now=rhodz_now, z_cfl=self._z_cfl, k=self._k_field, - slevp1_ti=1, - nlev=self._grid.num_levels - 1, + slevp1_ti=self._slevp1_ti, + nlev=self._nlev, dbl_eps=constants.DBL_EPS, p_dtime=dtime, horizontal_start=horizontal_start, @@ -798,7 +810,7 @@ def _compute_numerical_flux( p_cellhgt_mc_now=self._metric_state.ddqz_z_full, k=self._k_field, z_slope=self._z_slope, - elev=self._grid.num_levels - 1, + elev=self._elev, horizontal_start=horizontal_start, horizontal_end=horizontal_end, vertical_start=1, @@ -921,7 +933,7 @@ def _compute_numerical_flux( z_a1=self._z_a1, p_upflux=p_mflx_tracer_v, k=self._k_field, - slev=0, + slev=self._slev, p_dtime=dtime, horizontal_start=horizontal_start, horizontal_end=horizontal_end, @@ -940,7 +952,7 @@ def _compute_numerical_flux( z_cfl=self._z_cfl, p_upflux=p_mflx_tracer_v, k=self._k_field, - slev=0, + slev=self._slev, p_dtime=dtime, horizontal_start=horizontal_start, horizontal_end=horizontal_end, @@ -994,8 +1006,8 @@ def _update_unknowns( tracer_new=p_tracer_new, k=self._k_field, p_dtime=dtime, - ivadv_tracer=1, - iadv_slev_jt=0, + ivadv_tracer=self._ivadv_tracer, + iadv_slev_jt=self._iadv_slev_jt, horizontal_start=horizontal_start, horizontal_end=horizontal_end, vertical_start=0, diff --git a/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/stencils/compute_vertical_tracer_flux_upwind.py b/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/stencils/compute_vertical_tracer_flux_upwind.py index 556c15c73..7e4a55d7a 100644 --- a/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/stencils/compute_vertical_tracer_flux_upwind.py +++ b/model/atmosphere/advection/src/icon4py/model/atmosphere/advection/stencils/compute_vertical_tracer_flux_upwind.py @@ -35,7 +35,7 @@ def compute_vertical_tracer_flux_upwind( _compute_vertical_tracer_flux_upwind( p_cc, p_mflx_contra_v, - out=(p_upflux), + out=p_upflux, domain={ dims.CellDim: (horizontal_start, horizontal_end), dims.KDim: (vertical_start, vertical_end), diff --git a/model/atmosphere/advection/tests/advection_tests/utils.py b/model/atmosphere/advection/tests/advection_tests/utils.py index 6aecba2e6..95af10235 100644 --- a/model/atmosphere/advection/tests/advection_tests/utils.py +++ b/model/atmosphere/advection/tests/advection_tests/utils.py @@ -171,17 +171,17 @@ def verify_advection_fields( # verify advection output fields assert helpers.dallclose( - diagnostic_state.hfl_tracer.asnumpy()[hfl_tracer_range, :], - diagnostic_state_ref.hfl_tracer.asnumpy()[hfl_tracer_range, :], + diagnostic_state.hfl_tracer.ndarray[hfl_tracer_range, :], + diagnostic_state_ref.hfl_tracer.ndarray[hfl_tracer_range, :], rtol=1e-10, ) assert helpers.dallclose( - diagnostic_state.vfl_tracer.asnumpy()[vfl_tracer_range, :], - diagnostic_state_ref.vfl_tracer.asnumpy()[vfl_tracer_range, :], + diagnostic_state.vfl_tracer.ndarray[vfl_tracer_range, :], + diagnostic_state_ref.vfl_tracer.ndarray[vfl_tracer_range, :], rtol=1e-10, ) assert helpers.dallclose( - p_tracer_new.asnumpy()[p_tracer_new_range, :], - p_tracer_new_ref.asnumpy()[p_tracer_new_range, :], + p_tracer_new.ndarray[p_tracer_new_range, :], + p_tracer_new_ref.ndarray[p_tracer_new_range, :], atol=1e-16, )