diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index fbd143b7e8..2c5f33d4eb 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -157,47 +157,28 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: arg_types = [type_translation.from_value(arg) for arg in args] neighbor_tables = filter_neighbor_tables(offset_provider) - program_itir_not_working = [ - "calculate_horizontal_gradients_for_turbulence", - "calculate_nabla2_and_smag_coefficients_for_vn", - "calculate_nabla4", - "mo_advection_traj_btraj_compute_o1_dsl", - "mo_math_gradients_grad_green_gauss_cell_dsl", - "mo_solve_nonhydro_stencil_16_fused_btraj_traj_o1", - "mo_solve_nonhydro_stencil_20", - "mo_solve_nonhydro_stencil_21", - "mo_solve_nonhydro_stencil_30", - "mo_solve_nonhydro_stencil_41", - "mo_velocity_advection_stencil_19", - "temporary_fields_for_turbulence_diagnostics", - "truly_horizontal_diffusion_nabla_of_theta_over_steep_points", - ] - with_temporaries = False cache_id = get_cache_id(program, arg_types, column_axis, offset_provider) if build_cache is not None and cache_id in build_cache: # retrieve SDFG program from build cache sdfg_program = build_cache[cache_id] sdfg = sdfg_program.sdfg else: - # visit ITIR and generate SDFG - if ( - any([ItirToSDFG._check_no_lifts(node) for node in program.closures]) - and program.id not in program_itir_not_working - ): + program = preprocess_program(program, offset_provider, LiftMode.FORCE_INLINE) + if all([ItirToSDFG._check_no_lifts(node) for node in program.closures]): + tmps = [] + else: program_with_tmps: global_tmps.FencilWithTemporaries = preprocess_program( program, offset_provider, LiftMode.FORCE_TEMPORARIES ) program = program_with_tmps.fencil tmps = program_with_tmps.tmps - with_temporaries = True - else: - program = preprocess_program(program, offset_provider, LiftMode.FORCE_INLINE) - tmps = [] + # visit ITIR and generate SDFG sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, tmps, column_axis) sdfg = sdfg_genenerator.visit(program) - if with_temporaries: + if tmps: + # This pass is needed to avoid transformation errors in SDFG inlining, because temporaries are using offsets sdfg.apply_transformations_repeated(RefineNestedAccess) sdfg.simplify() diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 065fb89508..e0570f517e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -318,6 +318,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition, **kargs): last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet) if self.tmps: + # on the first interstate edge define symbols for shape/stride/offsets of temporary arrays inter_state_edge = program_sdfg.out_edges(entry_state)[0] inter_state_edge.data.assignments.update(tmp_symbols)