Skip to content

Commit

Permalink
[dace] Fix for ITIR temporary pass
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Oct 19, 2023
1 parent 2380efb commit 5f10d3d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,47 +157,28 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None:

arg_types = [type_translation.from_value(arg) for arg in args]
neighbor_tables = filter_neighbor_tables(offset_provider)
program_itir_not_working = [
"calculate_horizontal_gradients_for_turbulence",
"calculate_nabla2_and_smag_coefficients_for_vn",
"calculate_nabla4",
"mo_advection_traj_btraj_compute_o1_dsl",
"mo_math_gradients_grad_green_gauss_cell_dsl",
"mo_solve_nonhydro_stencil_16_fused_btraj_traj_o1",
"mo_solve_nonhydro_stencil_20",
"mo_solve_nonhydro_stencil_21",
"mo_solve_nonhydro_stencil_30",
"mo_solve_nonhydro_stencil_41",
"mo_velocity_advection_stencil_19",
"temporary_fields_for_turbulence_diagnostics",
"truly_horizontal_diffusion_nabla_of_theta_over_steep_points",
]

with_temporaries = False
cache_id = get_cache_id(program, arg_types, column_axis, offset_provider)
if build_cache is not None and cache_id in build_cache:
# retrieve SDFG program from build cache
sdfg_program = build_cache[cache_id]
sdfg = sdfg_program.sdfg
else:
# visit ITIR and generate SDFG
if (
any([ItirToSDFG._check_no_lifts(node) for node in program.closures])
and program.id not in program_itir_not_working
):
program = preprocess_program(program, offset_provider, LiftMode.FORCE_INLINE)
if all([ItirToSDFG._check_no_lifts(node) for node in program.closures]):
tmps = []
else:
program_with_tmps: global_tmps.FencilWithTemporaries = preprocess_program(
program, offset_provider, LiftMode.FORCE_TEMPORARIES
)
program = program_with_tmps.fencil
tmps = program_with_tmps.tmps
with_temporaries = True
else:
program = preprocess_program(program, offset_provider, LiftMode.FORCE_INLINE)
tmps = []

# visit ITIR and generate SDFG
sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, tmps, column_axis)
sdfg = sdfg_genenerator.visit(program)
if with_temporaries:
if tmps:
# This pass is needed to avoid transformation errors in SDFG inlining, because temporaries are using offsets
sdfg.apply_transformations_repeated(RefineNestedAccess)
sdfg.simplify()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition, **kargs):
last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet)

if self.tmps:
# on the first interstate edge define symbols for shape/stride/offsets of temporary arrays
inter_state_edge = program_sdfg.out_edges(entry_state)[0]
inter_state_edge.data.assignments.update(tmp_symbols)

Expand Down

0 comments on commit 5f10d3d

Please sign in to comment.