From 70f0f88df76d10f29f28a63bcc8802460da2269c Mon Sep 17 00:00:00 2001 From: edopao Date: Thu, 25 Jan 2024 13:02:13 +0100 Subject: [PATCH] feat[next][dace]: use new LoopRegion construct for scan operator (#1424) The lowering of the scan operator to SDFG uses a state machine to represent a loop. This PR replaces the state machine with a LoopRegion construct introduced in DaCe v0.15. The LoopRegion construct is not yet supported by DaCe transformations, but it will be in the future, and it could open new optimization opportunities (e.g. K-caching). --- .../runners/dace_iterator/__init__.py | 4 ++ .../runners/dace_iterator/itir_to_sdfg.py | 71 ++++++++++--------- 2 files changed, 43 insertions(+), 32 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index a039d311ca..6a8b9bc9c6 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -20,6 +20,7 @@ import dace import numpy as np from dace.codegen.compiled_sdfg import CompiledSDFG +from dace.sdfg import utils as sdutils from dace.transformation.auto import auto_optimize as autoopt import gt4py.next.allocators as next_allocators @@ -293,6 +294,9 @@ def build_sdfg_from_itir( filename=frameinfo.filename, ) + # TODO(edopao): remove `inline_loop_blocks` when DaCe transformations support LoopRegion construct + sdutils.inline_loop_blocks(sdfg) + # run DaCe transformations to simplify the SDFG sdfg.simplify() diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index ce1ac6073a..8a7826dae4 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -14,6 +14,7 @@ from typing import Any, Mapping, Optional, Sequence, 
cast import dace +from dace.sdfg.state import LoopRegion import gt4py.eve as eve from gt4py.next import Dimension, DimensionKind, type_inference as next_typing @@ -477,15 +478,38 @@ def _visit_scan_stencil_closure( scan_sdfg = dace.SDFG(name="scan") scan_sdfg.debuginfo = dace_debuginfo(node) - # create a state machine for lambda call over the scan dimension - start_state = scan_sdfg.add_state("start", True) - lambda_state = scan_sdfg.add_state("lambda_compute") - end_state = scan_sdfg.add_state("end") - # the carry value of the scan operator exists only in the scope of the scan sdfg scan_carry_name = unique_var_name() scan_sdfg.add_scalar(scan_carry_name, dtype=as_dace_type(scan_dtype), transient=True) + # create a loop region for lambda call over the scan dimension + scan_loop_var = f"i_{scan_dim}" + if is_forward: + scan_loop = LoopRegion( + label="scan", + condition_expr=f"{scan_loop_var} < {scan_ub_str}", + loop_var=scan_loop_var, + initialize_expr=f"{scan_loop_var} = {scan_lb_str}", + update_expr=f"{scan_loop_var} = {scan_loop_var} + 1", + inverted=False, + ) + else: + scan_loop = LoopRegion( + label="scan", + condition_expr=f"{scan_loop_var} >= {scan_lb_str}", + loop_var=scan_loop_var, + initialize_expr=f"{scan_loop_var} = {scan_ub_str} - 1", + update_expr=f"{scan_loop_var} = {scan_loop_var} - 1", + inverted=False, + ) + scan_sdfg.add_node(scan_loop) + compute_state = scan_loop.add_state("lambda_compute", is_start_block=True) + update_state = scan_loop.add_state("lambda_update") + scan_loop.add_edge(compute_state, update_state, dace.InterstateEdge()) + + start_state = scan_sdfg.add_state("start", is_start_block=True) + scan_sdfg.add_edge(start_state, scan_loop, dace.InterstateEdge()) + # tasklet for initialization of carry carry_init_tasklet = start_state.add_tasklet( "get_carry_init_value", @@ -502,19 +526,6 @@ def _visit_scan_stencil_closure( dace.Memlet.simple(scan_carry_name, "0"), ) - # TODO(edopao): replace state machine with dace loop construct - 
scan_sdfg.add_loop( - start_state, - lambda_state, - end_state, - loop_var=f"i_{scan_dim}", - initialize_expr=f"{scan_lb_str}" if is_forward else f"{scan_ub_str} - 1", - condition_expr=f"i_{scan_dim} < {scan_ub_str}" - if is_forward - else f"i_{scan_dim} >= {scan_lb_str}", - increment_expr=f"i_{scan_dim} + 1" if is_forward else f"i_{scan_dim} - 1", - ) - # add storage to scan SDFG for inputs for name in [*input_names, *connectivity_names]: assert name not in scan_sdfg.arrays @@ -569,7 +580,7 @@ def _visit_scan_stencil_closure( array_mapping = {**input_mapping, **connectivity_mapping} symbol_mapping = map_nested_sdfg_symbols(scan_sdfg, lambda_context.body, array_mapping) - scan_inner_node = lambda_state.add_nested_sdfg( + scan_inner_node = compute_state.add_nested_sdfg( lambda_context.body, parent=scan_sdfg, inputs=set(lambda_input_names) | set(connectivity_names), @@ -580,29 +591,25 @@ def _visit_scan_stencil_closure( # connect scan SDFG to lambda inputs for name, memlet in array_mapping.items(): - access_node = lambda_state.add_access(name, debuginfo=lambda_context.body.debuginfo) - lambda_state.add_edge(access_node, None, scan_inner_node, name, memlet) + access_node = compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo) + compute_state.add_edge(access_node, None, scan_inner_node, name, memlet) output_names = [output_name] assert len(lambda_output_names) == 1 # connect lambda output to scan SDFG for name, connector in zip(output_names, lambda_output_names): - lambda_state.add_edge( + compute_state.add_edge( scan_inner_node, connector, - lambda_state.add_access(name, debuginfo=lambda_context.body.debuginfo), + compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo), None, - dace.Memlet.simple(name, f"i_{scan_dim}"), + dace.Memlet.simple(name, scan_loop_var), ) - # add state to scan SDFG to update the carry value at each loop iteration - lambda_update_state = scan_sdfg.add_state_after(lambda_state, "lambda_update") - 
lambda_update_state.add_memlet_path( - lambda_update_state.add_access(output_name, debuginfo=lambda_context.body.debuginfo), - lambda_update_state.add_access( - scan_carry_name, debuginfo=lambda_context.body.debuginfo - ), - memlet=dace.Memlet.simple(output_names[0], f"i_{scan_dim}", other_subset_str="0"), + update_state.add_nedge( + update_state.add_access(output_name, debuginfo=lambda_context.body.debuginfo), + update_state.add_access(scan_carry_name, debuginfo=lambda_context.body.debuginfo), + dace.Memlet.simple(output_names[0], scan_loop_var, other_subset_str="0"), ) return scan_sdfg, map_ranges, scan_dim_index