Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run icon4py on gpu #579

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -835,81 +835,82 @@ def _do_diffusion_step(
"running stencils 07 08 09 10 (apply_diffusion_to_w_and_compute_horizontal_gradients_for_turbulence): end"
)

log.debug(
"running fused stencils 11 12 (calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools): start"
)

self.calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools.with_connectivities(
self.compile_time_connectivities
)(
theta_v=prognostic_state.theta_v,
theta_ref_mc=self._metric_state.theta_ref_mc,
thresh_tdiff=self.thresh_tdiff,
smallest_vpfloat=constants.DBL_EPS,
kh_smag_e=self.kh_smag_e,
horizontal_start=self._edge_start_nudging,
horizontal_end=self._edge_end_halo,
vertical_start=(self._grid.num_levels - 2),
vertical_end=self._grid.num_levels,
offset_provider=self._grid.offset_providers,
)
log.debug(
"running stencils 11 12 (calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools): end"
)
if self.config.apply_to_temperature:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So that stencil should be inside the if, and it was not before? We never catch these kinds of bugs because we always run the same configuration, so we only ever test one branch of the ifs... :-( From this point of view it is very bad that the configuration of EXCLAIM.APE is similar to MCH_CH_R04B09_DSL. We had a discussion on this yesterday. We really should come up with a list of configurations that we want to support, then test all of them and delete whatever we neither want nor test. @anuragdipankar @lxavier @muellch

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, those stencils are inside the if statement. The temperature diffusion should be turned off in the standard Jablonowski Williamson test.

log.debug(
"running fused stencils 11 12 (calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools): start"
)

log.debug("running stencils 13 14 (calculate_nabla2_for_theta): start")
self.calculate_nabla2_for_theta.with_connectivities(self.compile_time_connectivities)(
kh_smag_e=self.kh_smag_e,
inv_dual_edge_length=self._edge_params.inverse_dual_edge_lengths,
theta_v=prognostic_state.theta_v,
geofac_div=self._interpolation_state.geofac_div,
z_temp=self.z_temp,
horizontal_start=self._cell_start_nudging,
horizontal_end=self._cell_end_local,
vertical_start=0,
vertical_end=self._grid.num_levels,
offset_provider=self._grid.offset_providers,
)
log.debug("running stencils 13_14 (calculate_nabla2_for_theta): end")
log.debug(
"running stencil 15 (truly_horizontal_diffusion_nabla_of_theta_over_steep_points): start"
)
if self.config.apply_zdiffusion_t:
self.truly_horizontal_diffusion_nabla_of_theta_over_steep_points.with_connectivities(
self.calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools.with_connectivities(
self.compile_time_connectivities
)(
mask=self._metric_state.mask_hdiff,
zd_vertoffset=self._metric_state.zd_vertoffset,
zd_diffcoef=self._metric_state.zd_diffcoef,
geofac_n2s_c=self._interpolation_state.geofac_n2s_c,
geofac_n2s_nbh=self._interpolation_state.geofac_n2s_nbh,
vcoef=self._metric_state.zd_intcoef,
theta_v=prognostic_state.theta_v,
theta_ref_mc=self._metric_state.theta_ref_mc,
thresh_tdiff=self.thresh_tdiff,
smallest_vpfloat=constants.DBL_EPS,
kh_smag_e=self.kh_smag_e,
horizontal_start=self._edge_start_nudging,
horizontal_end=self._edge_end_halo,
vertical_start=(self._grid.num_levels - 2),
vertical_end=self._grid.num_levels,
offset_provider=self._grid.offset_providers,
)
log.debug(
"running stencils 11 12 (calculate_enhanced_diffusion_coefficients_for_grid_point_cold_pools): end"
)

log.debug("running stencils 13 14 (calculate_nabla2_for_theta): start")
self.calculate_nabla2_for_theta.with_connectivities(self.compile_time_connectivities)(
kh_smag_e=self.kh_smag_e,
inv_dual_edge_length=self._edge_params.inverse_dual_edge_lengths,
theta_v=prognostic_state.theta_v,
geofac_div=self._interpolation_state.geofac_div,
z_temp=self.z_temp,
horizontal_start=self._cell_start_nudging,
horizontal_end=self._cell_end_local,
vertical_start=0,
vertical_end=self._grid.num_levels,
offset_provider=self._grid.offset_providers,
)

log.debug("running stencils 13_14 (calculate_nabla2_for_theta): end")
log.debug(
"running fused stencil 15 (truly_horizontal_diffusion_nabla_of_theta_over_steep_points): end"
"running stencil 15 (truly_horizontal_diffusion_nabla_of_theta_over_steep_points): start"
)
log.debug("running stencil 16 (update_theta_and_exner): start")
self.update_theta_and_exner.with_connectivities(self.compile_time_connectivities)(
z_temp=self.z_temp,
area=self._cell_params.area,
theta_v=prognostic_state.theta_v,
exner=prognostic_state.exner,
rd_o_cvd=self.rd_o_cvd,
horizontal_start=self._cell_start_nudging,
horizontal_end=self._cell_end_local,
vertical_start=0,
vertical_end=self._grid.num_levels,
offset_provider={},
)
log.debug("running stencil 16 (update_theta_and_exner): end")
if self.config.apply_zdiffusion_t:
self.truly_horizontal_diffusion_nabla_of_theta_over_steep_points.with_connectivities(
self.compile_time_connectivities
)(
mask=self._metric_state.mask_hdiff,
zd_vertoffset=self._metric_state.zd_vertoffset,
zd_diffcoef=self._metric_state.zd_diffcoef,
geofac_n2s_c=self._interpolation_state.geofac_n2s_c,
geofac_n2s_nbh=self._interpolation_state.geofac_n2s_nbh,
vcoef=self._metric_state.zd_intcoef,
theta_v=prognostic_state.theta_v,
z_temp=self.z_temp,
horizontal_start=self._cell_start_nudging,
horizontal_end=self._cell_end_local,
vertical_start=0,
vertical_end=self._grid.num_levels,
offset_provider=self._grid.offset_providers,
)

log.debug(
"running fused stencil 15 (truly_horizontal_diffusion_nabla_of_theta_over_steep_points): end"
)
log.debug("running stencil 16 (update_theta_and_exner): start")
self.update_theta_and_exner.with_connectivities(self.compile_time_connectivities)(
z_temp=self.z_temp,
area=self._cell_params.area,
theta_v=prognostic_state.theta_v,
exner=prognostic_state.exner,
rd_o_cvd=self.rd_o_cvd,
horizontal_start=self._cell_start_nudging,
horizontal_end=self._cell_end_local,
vertical_start=0,
vertical_end=self._grid.num_levels,
offset_provider={},
)
log.debug("running stencil 16 (update_theta_and_exner): end")

self.halo_exchange_wait(
handle_edge_comm
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from icon4py.model.atmosphere.dycore.interpolate_to_half_levels_vp import (
_interpolate_to_half_levels_vp,
)
from icon4py.model.atmosphere.dycore.interpolate_to_half_levels_wp import (
_interpolate_to_half_levels_wp,
)
from icon4py.model.common import dimension as dims, field_type_aliases as fa
from icon4py.model.common.dimension import Koff
from icon4py.model.common.settings import backend
Expand All @@ -37,7 +40,7 @@ def _compute_virtual_potential_temperatures_and_pressure_gradient(
wgtfac_c_wp, ddqz_z_half_wp = astype((wgtfac_c, ddqz_z_half), wpfloat)

z_theta_v_pr_ic_vp = _interpolate_to_half_levels_vp(wgtfac_c=wgtfac_c, interpolant=z_rth_pr_2)
theta_v_ic_wp = wgtfac_c_wp * theta_v + (wpfloat("1.0") - wgtfac_c_wp) * theta_v(Koff[-1])
theta_v_ic_wp = _interpolate_to_half_levels_wp(wgtfac_c=wgtfac_c_wp, interpolant=theta_v)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What was the problem with the inline version again? (I have nothing against the stencil; I think it is even better in terms of readability, but I am still wondering why it changes anything.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I always hit a CUDA illegal memory access error with the inline version when running at resolutions greater than R2B6. I still do not understand why. I am not sure whether it is due to some deeper reason or to bugs elsewhere. At least so far, for R2B7, everything is okay and the results are verified with the new function call. (As I also like the function call for the interpolation, I tend to keep it this way.)

z_th_ddz_exner_c_wp = vwind_expl_wgt * theta_v_ic_wp * (
exner_pr(Koff[-1]) - exner_pr
) / ddqz_z_half_wp + astype(z_theta_v_pr_ic_vp * d_exner_dz_ref_ic, wpfloat)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# ICON4Py - ICON inspired code in Python and GT4Py
#
# Copyright (c) 2022-2024, ETH Zurich and MeteoSwiss
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause
import gt4py.next as gtx
from gt4py.next.common import GridType
from gt4py.next.ffront.decorator import field_operator, program

from icon4py.model.common import dimension as dims, field_type_aliases as fa
from icon4py.model.common.dimension import Koff
from icon4py.model.common.settings import backend
from icon4py.model.common.type_alias import wpfloat


@field_operator
def _interpolate_to_half_levels_wp(
    wgtfac_c: fa.CellKField[wpfloat],
    interpolant: fa.CellKField[wpfloat],
) -> fa.CellKField[wpfloat]:
    """Blend each level of ``interpolant`` with its neighbour at K offset -1.

    Formerly known as mo_velocity_advection_stencil_10 and as
    _mo_solve_nonhydro_stencil_05.

    Args:
        wgtfac_c: weight applied to ``interpolant`` at the current level; the
            level at ``Koff[-1]`` is weighted with ``1.0 - wgtfac_c``.
        interpolant: field to be interpolated.

    Returns:
        ``wgtfac_c * interpolant + (1.0 - wgtfac_c) * interpolant(Koff[-1])``.
    """
    interpolation_to_half_levels_wp = wgtfac_c * interpolant + (
        wpfloat("1.0") - wgtfac_c
    ) * interpolant(Koff[-1])
    return interpolation_to_half_levels_wp


# Program entry point for _interpolate_to_half_levels_wp: runs the field
# operator on the cell range [horizontal_start, horizontal_end) and the level
# range [vertical_start, vertical_end) and writes the result into
# interpolation_to_half_levels_wp.
# NOTE(review): the operator reads interpolant at Koff[-1], so vertical_start
# presumably has to leave one level available below it in K index -- confirm
# against the callers.
@program(grid_type=GridType.UNSTRUCTURED, backend=backend)
def interpolate_to_half_levels_wp(
    wgtfac_c: fa.CellKField[wpfloat],
    interpolant: fa.CellKField[wpfloat],
    interpolation_to_half_levels_wp: fa.CellKField[wpfloat],
    horizontal_start: gtx.int32,
    horizontal_end: gtx.int32,
    vertical_start: gtx.int32,
    vertical_end: gtx.int32,
):
    _interpolate_to_half_levels_wp(
        wgtfac_c,
        interpolant,
        out=interpolation_to_half_levels_wp,
        domain={
            dims.CellDim: (horizontal_start, horizontal_end),
            dims.KDim: (vertical_start, vertical_end),
        },
    )
Loading
Loading