Skip to content

Commit

Permalink
Refactor stencil with high compile time
Browse files Browse the repository at this point in the history
  • Loading branch information
dastrm committed Nov 25, 2024
1 parent 381d378 commit 61f6b59
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 221 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -771,53 +771,23 @@ def _compute_numerical_flux(
)
log.debug("running stencil init_constant_cell_kdim_field - end")

# TODO (dastrm): write missing stencil here
if 0:
slevp1_ti = 1
nlev = self._grid.num_levels - 1
for jc in range(horizontal_start, horizontal_end):
for jk in range(1, self._grid.num_levels):
z_mass = dtime * prep_adv.mass_flx_ic.ndarray[jc, jk]
if z_mass > 0.0:
jks = jk
while (z_mass >= rhodz_now.ndarray[jc, jks]) and (jks <= nlev - 1):
z_mass -= rhodz_now.ndarray[jc, jks]
jks += 1
self._z_cfl.ndarray[jc, jk] += 1.0
z_cflfrac = z_mass / rhodz_now.ndarray[jc, jks]
if z_cflfrac < 1.0:
self._z_cfl.ndarray[jc, jk] += z_cflfrac
else:
self._z_cfl.ndarray[jc, jk] += 1.0 - constants.DBL_EPS
else:
jks = jk - 1
while (abs(z_mass) >= rhodz_now.ndarray[jc, jks]) and (jks >= slevp1_ti):
z_mass += rhodz_now.ndarray[jc, jks]
jks -= 1
self._z_cfl.ndarray[jc, jk] -= 1.0
z_cflfrac = z_mass / rhodz_now.ndarray[jc, jks]
if abs(z_cflfrac) < 1.0:
self._z_cfl.ndarray[jc, jk] += z_cflfrac
else:
self._z_cfl.ndarray[jc, jk] += constants.DBL_EPS - 1.0
else:
log.debug("running stencil compute_ppm4gpu_courant_number - start")
self._compute_ppm4gpu_courant_number(
p_mflx_contra_v=prep_adv.mass_flx_ic,
p_cellmass_now=rhodz_now,
z_cfl=self._z_cfl,
k=self._k_field,
slevp1_ti=1,
nlev=self._grid.num_levels - 1,
dbl_eps=constants.DBL_EPS,
p_dtime=dtime,
horizontal_start=horizontal_start,
horizontal_end=horizontal_end,
vertical_start=1,
vertical_end=self._grid.num_levels,
offset_provider=self._grid.offset_providers,
)
log.debug("running stencil compute_ppm4gpu_courant_number - end")
log.debug("running stencil compute_ppm4gpu_courant_number - start")
self._compute_ppm4gpu_courant_number(
p_mflx_contra_v=prep_adv.mass_flx_ic,
p_cellmass_now=rhodz_now,
z_cfl=self._z_cfl,
k=self._k_field,
slevp1_ti=1,
nlev=self._grid.num_levels - 1,
dbl_eps=constants.DBL_EPS,
p_dtime=dtime,
horizontal_start=horizontal_start,
horizontal_end=horizontal_end,
vertical_start=1,
vertical_end=self._grid.num_levels,
offset_provider=self._grid.offset_providers,
)
log.debug("running stencil compute_ppm4gpu_courant_number - end")

## reconstruct face values

Expand Down Expand Up @@ -942,94 +912,43 @@ def _compute_numerical_flux(
)
log.debug("running stencil compute_ppm4gpu_parabola_coefficients - end")

# TODO (dastrm): write missing stencil here
if 0:
slev = 0
for jc in range(horizontal_start, horizontal_end):
for jk in range(1, self._grid.num_levels):
js = int(abs(self._z_cfl.ndarray[jc, jk]))
z_cflfrac = abs(self._z_cfl.ndarray[jc, jk]) - js
if self._z_cfl.ndarray[jc, jk] > 0.0:
jks = min(jk, self._grid.num_levels - 1) + js
wsign = 1.0
else:
jks = (jk - 1) - js
wsign = -1.0
if jks < slev:
p_mflx_tracer_v.ndarray[jc, jk] = 0.0
continue
z_q_int = (
p_tracer_now.ndarray[jc, jks]
+ wsign * (self._z_delta_q.ndarray[jc, jks] * (1.0 - z_cflfrac))
- self._z_a1.ndarray[jc, jks]
* (1.0 - 3.0 * z_cflfrac + 2.0 * z_cflfrac * z_cflfrac)
)
p_mflx_tracer_v.ndarray[jc, jk] = (
wsign * rhodz_now.ndarray[jc, jks] * z_cflfrac * z_q_int / dtime
)
else:
log.debug("running stencil compute_ppm4gpu_fractional_flux - start")
self._compute_ppm4gpu_fractional_flux(
p_cc=p_tracer_now,
p_cellmass_now=rhodz_now,
z_cfl=self._z_cfl,
z_delta_q=self._z_delta_q,
z_a1=self._z_a1,
p_upflux=p_mflx_tracer_v,
k=self._k_field,
slev=0,
p_dtime=dtime,
horizontal_start=horizontal_start,
horizontal_end=horizontal_end,
vertical_start=1,
vertical_end=self._grid.num_levels,
offset_provider=self._grid.offset_providers,
)
log.debug("running stencil compute_ppm4gpu_fractional_flux - end")
log.debug("running stencil compute_ppm4gpu_fractional_flux - start")
self._compute_ppm4gpu_fractional_flux(
p_cc=p_tracer_now,
p_cellmass_now=rhodz_now,
z_cfl=self._z_cfl,
z_delta_q=self._z_delta_q,
z_a1=self._z_a1,
p_upflux=p_mflx_tracer_v,
k=self._k_field,
slev=0,
p_dtime=dtime,
horizontal_start=horizontal_start,
horizontal_end=horizontal_end,
vertical_start=1,
vertical_end=self._grid.num_levels,
offset_provider=self._grid.offset_providers,
)
log.debug("running stencil compute_ppm4gpu_fractional_flux - end")

## compute integer numerical flux

# TODO (dastrm): write missing stencil here
if 0:
slev = 0
for jc in range(horizontal_start, horizontal_end):
for jk in range(1, self._grid.num_levels):
js = int(abs(self._z_cfl.ndarray[jc, jk]))
if js == 0:
continue
z_iflx = 0.0
if self._z_cfl.ndarray[jc, jk] > 0.0:
for n in range(1, js + 1):
jk_shift = jk - 1 + n
z_iflx += (
p_tracer_now.ndarray[jc, jk_shift] * rhodz_now.ndarray[jc, jk_shift]
)
else:
for n in range(1, js + 1):
jk_shift = jk - n
if jk_shift < slev:
continue
z_iflx -= (
p_tracer_now.ndarray[jc, jk_shift] * rhodz_now.ndarray[jc, jk_shift]
)
p_mflx_tracer_v.ndarray[jc, jk] += z_iflx / dtime
else:
log.debug("running stencil compute_ppm4gpu_integer_flux - start")
self._compute_ppm4gpu_integer_flux(
p_cc=p_tracer_now,
p_cellmass_now=rhodz_now,
z_cfl=self._z_cfl,
p_upflux=p_mflx_tracer_v,
k=self._k_field,
slev=0,
p_dtime=dtime,
horizontal_start=horizontal_start,
horizontal_end=horizontal_end,
vertical_start=1,
vertical_end=self._grid.num_levels,
offset_provider=self._grid.offset_providers,
)
log.debug("running stencil compute_ppm4gpu_integer_flux - end")
log.debug("running stencil compute_ppm4gpu_integer_flux - start")
self._compute_ppm4gpu_integer_flux(
p_cc=p_tracer_now,
p_cellmass_now=rhodz_now,
z_cfl=self._z_cfl,
p_upflux=p_mflx_tracer_v,
k=self._k_field,
slev=0,
p_dtime=dtime,
horizontal_start=horizontal_start,
horizontal_end=horizontal_end,
vertical_start=1,
vertical_end=self._grid.num_levels,
offset_provider=self._grid.offset_providers,
)
log.debug("running stencil compute_ppm4gpu_integer_flux - end")

## set boundary conditions

Expand Down
Loading

0 comments on commit 61f6b59

Please sign in to comment.