Skip to content

Commit

Permalink
fix bug in test multi substeps
Browse files Browse the repository at this point in the history
  • Loading branch information
OngChia committed Nov 27, 2024
1 parent 982774f commit d3b0ab8
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,7 @@ def test_run_solve_nonhydro_multi_step(
recompute = sp.get_metadata("recompute").get("recompute")
linit = sp.get_metadata("linit").get("linit")

diagnostic_state_nh = utils.construct_diagnostics(sp, swap_ddt_w_adv_pc=True)
diagnostic_state_nh = utils.construct_diagnostics(sp, swap_ddt_w_adv_pc=not linit)
prognostic_states = utils.create_prognostic_states(sp)

interpolation_state = utils.construct_interpolation_state(interpolation_savepoint)
Expand Down
13 changes: 10 additions & 3 deletions tools/src/icon4pytools/py2fgen/wrappers/dycore_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def solve_nh_run(
divdamp_fac_o2: gtx.float64,
ndyn_substeps: gtx.float64,
idyn_timestep: gtx.int32,
):
) -> tuple[bool, bool]:
logger.info(f"Using Device = {settings.device}")

prep_adv = dycore_states.PrepAdvection(
Expand All @@ -394,6 +394,8 @@ def solve_nh_run(
vol_flx_ic=zero_field(dycore_wrapper_state["grid"], CellDim, KDim, dtype=gtx.float64),
)

ddt_vn_apc = common_utils.TimeStepPair(ddt_vn_apc_ntl1, ddt_vn_apc_ntl2)
ddt_w_adv = common_utils.TimeStepPair(ddt_w_adv_ntl1, ddt_w_adv_ntl2)
diagnostic_state_nh = dycore_states.DiagnosticStateNonHydro(
theta_v_ic=theta_v_ic,
exner_pr=exner_pr,
Expand All @@ -405,8 +407,8 @@ def solve_nh_run(
mass_fl_e=mass_fl_e,
ddt_vn_phy=ddt_vn_phy,
grf_tend_vn=grf_tend_vn,
ddt_vn_apc_pc=common_utils.TimeStepPair(ddt_vn_apc_ntl1, ddt_vn_apc_ntl2),
ddt_w_adv_pc=common_utils.TimeStepPair(ddt_w_adv_ntl1, ddt_w_adv_ntl2),
ddt_vn_apc_pc=ddt_vn_apc,
ddt_w_adv_pc=ddt_w_adv,
vt=vt,
vn_ie=vn_ie,
w_concorr_c=w_concorr_c,
Expand Down Expand Up @@ -449,6 +451,11 @@ def solve_nh_run(
at_last_substep=idyn_timestep == (ndyn_substeps - 1),
)

is_ddt_vn_apc_swapped = False if ddt_vn_apc_ntl1 == ddt_vn_apc.current else True
is_ddt_w_adv_swapped = False if ddt_w_adv_ntl1 == ddt_w_adv.current else True

return is_ddt_vn_apc_swapped, is_ddt_w_adv_swapped


def grid_init(
cell_starts: gt4py_common.Field[[CellIndexDim], gtx.int32],
Expand Down
87 changes: 48 additions & 39 deletions tools/tests/py2fgen/test_dycore_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from icon4py.model.common.grid import horizontal as h_grid, vertical as v_grid
from icon4py.model.common.grid.vertical import VerticalGridConfig
from icon4py.model.common.states import prognostic_state as prognostics
from icon4py.model.common.states.prognostic_state import PrognosticState
from icon4py.model.common.test_utils import (
datatest_utils as dt_utils,
helpers,
Expand Down Expand Up @@ -292,8 +293,6 @@ def test_dycore_wrapper_granule_inputs(
exner_new = sp.exner_new()

# using fortran indices
nnow = 1
nnew = 2
jstep_init_fortran = jstep_init + 1

# --- Expected objects that form inputs into init function ---
Expand Down Expand Up @@ -382,9 +381,7 @@ def test_dycore_wrapper_granule_inputs(
ddt_vn_phy=sp.ddt_vn_phy(),
grf_tend_vn=sp.grf_tend_vn(),
ddt_vn_apc_pc=common_utils.TimeStepPair(sp.ddt_vn_apc_pc(1), sp.ddt_vn_apc_pc(2)),
ddt_w_adv_pc=common_utils.TimeStepPair(
sp.ddt_w_adv_pc(1), sp.ddt_w_adv_pc(2)
),
ddt_w_adv_pc=common_utils.TimeStepPair(sp.ddt_w_adv_pc(1), sp.ddt_w_adv_pc(2)),
vt=sp.vt(),
vn_ie=sp.vn_ie(),
w_concorr_c=sp.w_concorr_c(),
Expand Down Expand Up @@ -423,8 +420,6 @@ def test_dycore_wrapper_granule_inputs(
expected_linit = sp.get_metadata("linit").get("linit")
expected_clean_mflx = sp.get_metadata("clean_mflx").get("clean_mflx")
expected_lprep_adv = sp.get_metadata("prep_adv").get("prep_adv")
expected_nnow = 0
expected_nnew = 1
expected_at_first_substep = jstep_init == 0
expected_at_last_substep = jstep_init == (ndyn_substeps - 1)

Expand Down Expand Up @@ -663,8 +658,6 @@ def test_dycore_wrapper_granule_inputs(
divdamp_fac_o2=initial_divdamp_fac,
ndyn_substeps=ndyn_substeps,
idyn_timestep=jstep_init_fortran,
nnow=nnow,
nnew=nnew,
)

# Check input arguments to SolveNonhydro.time_step
Expand Down Expand Up @@ -711,12 +704,6 @@ def test_dycore_wrapper_granule_inputs(
)
assert result, f"Prep Advection flag comparison failed: {error_message}"

result, error_message = utils.compare_objects(captured_kwargs["nnew"], expected_nnew)
assert result, f"nnew comparison failed: {error_message}"

result, error_message = utils.compare_objects(captured_kwargs["nnow"], expected_nnow)
assert result, f"nnow comparison failed: {error_message}"

result, error_message = utils.compare_objects(
captured_kwargs["at_first_substep"], expected_at_first_substep
)
Expand Down Expand Up @@ -1107,11 +1094,9 @@ def test_granule_solve_nonhydro_single_step_regional(
exner_new = sp.exner_new()

# using fortran indices
nnow = 1
nnew = 2
jstep_init_fortran = jstep_init + 1

dycore_wrapper.solve_nh_run(
is_ddt_vn_apc_swapped, is_ddt_w_adv_swapped = dycore_wrapper.solve_nh_run(
rho_now=rho_now,
rho_new=rho_new,
exner_now=exner_now,
Expand Down Expand Up @@ -1151,8 +1136,6 @@ def test_granule_solve_nonhydro_single_step_regional(
divdamp_fac_o2=initial_divdamp_fac,
ndyn_substeps=ndyn_substeps,
idyn_timestep=jstep_init_fortran,
nnow=nnow,
nnew=nnew,
)

assert helpers.dallclose(
Expand Down Expand Up @@ -1537,8 +1520,12 @@ def test_granule_solve_nonhydro_multi_step_regional(
grf_tend_vn = sp.grf_tend_vn()
ddt_vn_apc_ntl1 = sp.ddt_vn_apc_pc(1)
ddt_vn_apc_ntl2 = sp.ddt_vn_apc_pc(2)
ddt_w_adv_ntl1 = sp.ddt_w_adv_pc(1)
ddt_w_adv_ntl2 = sp.ddt_w_adv_pc(2)
if linit:
ddt_w_adv_ntl1 = sp.ddt_w_adv_pc(1)
ddt_w_adv_ntl2 = sp.ddt_w_adv_pc(2)
else:
ddt_w_adv_ntl1 = sp.ddt_w_adv_pc(2)
ddt_w_adv_ntl2 = sp.ddt_w_adv_pc(1)
vt = sp.vt()
vn_ie = sp.vn_ie()
w_concorr_c = sp.w_concorr_c()
Expand All @@ -1557,26 +1544,42 @@ def test_granule_solve_nonhydro_multi_step_regional(
rho_new = sp.rho_new()
exner_new = sp.exner_new()

prognostic_state_nnow = PrognosticState(
w=w_now,
vn=vn_now,
theta_v=theta_v_now,
rho=rho_now,
exner=exner_now,
)
prognostic_state_nnew = PrognosticState(
w=w_new,
vn=vn_new,
theta_v=theta_v_new,
rho=rho_new,
exner=exner_new,
)
prognostic_states = common_utils.TimeStepPair(prognostic_state_nnow, prognostic_state_nnew)
ddt_vn_apc = common_utils.TimeStepPair(ddt_vn_apc_ntl1, ddt_vn_apc_ntl2)
ddt_w_adv = common_utils.TimeStepPair(ddt_w_adv_ntl1, ddt_w_adv_ntl2)

# use fortran indices in the driving loop to compute i_substep
for i_substep in range(1, ndyn_substeps + 1):
is_last_substep = i_substep == (ndyn_substeps)

dycore_wrapper.solve_nh_run(
rho_now=rho_now,
rho_new=rho_new,
exner_now=exner_now,
exner_new=exner_new,
w_now=w_now,
w_new=w_new,
theta_v_now=theta_v_now,
theta_v_new=theta_v_new,
vn_now=vn_now,
vn_new=vn_new,
is_ddt_vn_apc_swapped, is_ddt_w_adv_swapped = dycore_wrapper.solve_nh_run(
rho_now=prognostic_states.current.rho,
rho_new=prognostic_states.next.rho,
exner_now=prognostic_states.current.exner,
exner_new=prognostic_states.next.exner,
w_now=prognostic_states.current.w,
w_new=prognostic_states.next.w,
theta_v_now=prognostic_states.current.theta_v,
theta_v_new=prognostic_states.next.theta_v,
vn_now=prognostic_states.current.vn,
vn_new=prognostic_states.next.vn,
w_concorr_c=w_concorr_c,
ddt_vn_apc_ntl1=ddt_vn_apc_ntl1,
ddt_vn_apc_ntl2=ddt_vn_apc_ntl2,
ddt_w_adv_ntl1=ddt_w_adv_ntl1,
ddt_w_adv_ntl2=ddt_w_adv_ntl2,
ddt_vn_apc_ntl1=ddt_vn_apc.current,
ddt_vn_apc_ntl2=ddt_vn_apc.next,
ddt_w_adv_ntl1=ddt_w_adv.current,
ddt_w_adv_ntl2=ddt_w_adv.next,
theta_v_ic=theta_v_ic,
rho_ic=rho_ic,
exner_pr=exner_pr,
Expand Down Expand Up @@ -1606,6 +1609,12 @@ def test_granule_solve_nonhydro_multi_step_regional(
recompute = False
clean_mflx = False

prognostic_states.swap()
if is_ddt_vn_apc_swapped:
ddt_vn_apc.swap()
if is_ddt_w_adv_swapped:
ddt_w_adv.swap()

cell_start_lb_plus2 = icon_grid.start_index(
h_grid.domain(dims.CellDim)(h_grid.Zone.LATERAL_BOUNDARY_LEVEL_3)
)
Expand Down

0 comments on commit d3b0ab8

Please sign in to comment.