diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 54ca08fe6e..aa963d101d 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -19,6 +19,7 @@ import dace import numpy as np from dace.codegen.compiled_sdfg import CompiledSDFG +from dace.sdfg import utils as sdutils from dace.transformation.auto import auto_optimize as autoopt import gt4py.next.allocators as next_allocators @@ -262,7 +263,7 @@ def build_sdfg_from_itir( # visit ITIR and generate SDFG program = preprocess_program(program, offset_provider, lift_mode) sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis) - sdfg = sdfg_genenerator.visit(program) + sdfg: dace.SDFG = sdfg_genenerator.visit(program) if sdfg is None: raise RuntimeError(f"Visit failed for program {program.id}.") @@ -279,6 +280,9 @@ def build_sdfg_from_itir( filename=frameinfo.filename, ) + # TODO(edopao): remove `inline_loop_blocks` when DaCe transformations support LoopRegion construct + sdutils.inline_loop_blocks(sdfg) + # run DaCe transformations to simplify the SDFG sdfg.simplify() diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index dc194c0436..f8010ffd83 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -14,6 +14,7 @@ from typing import Any, Optional, cast import dace +from dace.sdfg.state import LoopRegion import gt4py.eve as eve from gt4py.next import Dimension, DimensionKind, type_inference as next_typing @@ -430,15 +431,38 @@ def _visit_scan_stencil_closure( scan_sdfg = dace.SDFG(name="scan") scan_sdfg.debuginfo = dace_debuginfo(node) - # create a state machine for lambda call over the scan dimension - start_state = scan_sdfg.add_state("start", True) - lambda_state = scan_sdfg.add_state("lambda_compute") - end_state = scan_sdfg.add_state("end") - # the carry value of the scan operator exists only in the scope of the scan sdfg scan_carry_name = unique_var_name() scan_sdfg.add_scalar(scan_carry_name, dtype=as_dace_type(scan_dtype), transient=True) + # create a loop region for lambda call over the scan dimension + scan_loop_var = f"i_{scan_dim}" + if is_forward: + scan_loop = LoopRegion( + label="scan", + condition_expr=f"{scan_loop_var} < {scan_ub_str}", + loop_var=scan_loop_var, + initialize_expr=f"{scan_loop_var} = {scan_lb_str}", + update_expr=f"{scan_loop_var} = {scan_loop_var} + 1", + inverted=False, + ) + else: + scan_loop = LoopRegion( + label="scan", + condition_expr=f"{scan_loop_var} >= {scan_lb_str}", + loop_var=scan_loop_var, + initialize_expr=f"{scan_loop_var} = {scan_ub_str} - 1", + update_expr=f"{scan_loop_var} = {scan_loop_var} - 1", + inverted=False, + ) + scan_sdfg.add_node(scan_loop) + compute_state = scan_loop.add_state("lambda_compute", is_start_block=True) + update_state = scan_loop.add_state("lambda_update") + scan_loop.add_edge(compute_state, update_state, dace.InterstateEdge()) + + start_state = scan_sdfg.add_state("start", is_start_block=True) + scan_sdfg.add_edge(start_state, scan_loop, dace.InterstateEdge()) + # tasklet for initialization of carry carry_init_tasklet = start_state.add_tasklet( "get_carry_init_value", @@ -455,19 +479,6 @@ def _visit_scan_stencil_closure( dace.Memlet.simple(scan_carry_name, "0"), ) - # TODO(edopao): replace state machine with dace loop construct - scan_sdfg.add_loop( - start_state, - lambda_state, - end_state, - loop_var=f"i_{scan_dim}", - initialize_expr=f"{scan_lb_str}" if is_forward else f"{scan_ub_str} - 1", - condition_expr=f"i_{scan_dim} < {scan_ub_str}" - if is_forward - else f"i_{scan_dim} >= {scan_lb_str}", - increment_expr=f"i_{scan_dim} + 1" if is_forward else f"i_{scan_dim} - 1", - ) - # add storage to scan SDFG for inputs for name in [*input_names, *connectivity_names]: assert name not in scan_sdfg.arrays @@ -522,7 +533,7 @@ def _visit_scan_stencil_closure( array_mapping = {**input_mapping, **connectivity_mapping} symbol_mapping = map_nested_sdfg_symbols(scan_sdfg, lambda_context.body, array_mapping) - scan_inner_node = lambda_state.add_nested_sdfg( + scan_inner_node = compute_state.add_nested_sdfg( lambda_context.body, parent=scan_sdfg, inputs=set(lambda_input_names) | set(connectivity_names), @@ -533,29 +544,25 @@ def _visit_scan_stencil_closure( # connect scan SDFG to lambda inputs for name, memlet in array_mapping.items(): - access_node = lambda_state.add_access(name, debuginfo=lambda_context.body.debuginfo) - lambda_state.add_edge(access_node, None, scan_inner_node, name, memlet) + access_node = compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo) + compute_state.add_edge(access_node, None, scan_inner_node, name, memlet) output_names = [output_name] assert len(lambda_output_names) == 1 # connect lambda output to scan SDFG for name, connector in zip(output_names, lambda_output_names): - lambda_state.add_edge( + compute_state.add_edge( scan_inner_node, connector, - lambda_state.add_access(name, debuginfo=lambda_context.body.debuginfo), + compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo), None, - dace.Memlet.simple(name, f"i_{scan_dim}"), + dace.Memlet.simple(name, scan_loop_var), ) - # add state to scan SDFG to update the carry value at each loop iteration - lambda_update_state = scan_sdfg.add_state_after(lambda_state, "lambda_update") - lambda_update_state.add_memlet_path( - lambda_update_state.add_access(output_name, debuginfo=lambda_context.body.debuginfo), - lambda_update_state.add_access( - scan_carry_name, debuginfo=lambda_context.body.debuginfo - ), - memlet=dace.Memlet.simple(output_names[0], f"i_{scan_dim}", other_subset_str="0"), + update_state.add_nedge( + update_state.add_access(output_name, debuginfo=lambda_context.body.debuginfo), + update_state.add_access(scan_carry_name, debuginfo=lambda_context.body.debuginfo), + dace.Memlet.simple(output_names[0], scan_loop_var, other_subset_str="0"), ) return scan_sdfg, map_ranges, scan_dim_index