Skip to content

Commit

Permalink
feat[next][dace]: use new LoopRegion construct for scan operator (#1424)
Browse files Browse the repository at this point in the history
The lowering of the scan operator to SDFG uses a state machine to represent a loop. This PR replaces the state machine with the LoopRegion construct introduced in DaCe v0.15. The LoopRegion construct is not yet supported by DaCe transformations, but it will be in the future, and it could open up new optimization opportunities (e.g. K-caching).
  • Loading branch information
edopao authored Jan 25, 2024
1 parent ac0478a commit 70f0f88
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import dace
import numpy as np
from dace.codegen.compiled_sdfg import CompiledSDFG
from dace.sdfg import utils as sdutils
from dace.transformation.auto import auto_optimize as autoopt

import gt4py.next.allocators as next_allocators
Expand Down Expand Up @@ -293,6 +294,9 @@ def build_sdfg_from_itir(
filename=frameinfo.filename,
)

# TODO(edopao): remove `inline_loop_blocks` when DaCe transformations support LoopRegion construct
sdutils.inline_loop_blocks(sdfg)

# run DaCe transformations to simplify the SDFG
sdfg.simplify()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Any, Mapping, Optional, Sequence, cast

import dace
from dace.sdfg.state import LoopRegion

import gt4py.eve as eve
from gt4py.next import Dimension, DimensionKind, type_inference as next_typing
Expand Down Expand Up @@ -477,15 +478,38 @@ def _visit_scan_stencil_closure(
scan_sdfg = dace.SDFG(name="scan")
scan_sdfg.debuginfo = dace_debuginfo(node)

# create a state machine for lambda call over the scan dimension
start_state = scan_sdfg.add_state("start", True)
lambda_state = scan_sdfg.add_state("lambda_compute")
end_state = scan_sdfg.add_state("end")

# the carry value of the scan operator exists only in the scope of the scan sdfg
scan_carry_name = unique_var_name()
scan_sdfg.add_scalar(scan_carry_name, dtype=as_dace_type(scan_dtype), transient=True)

# create a loop region for lambda call over the scan dimension
scan_loop_var = f"i_{scan_dim}"
if is_forward:
scan_loop = LoopRegion(
label="scan",
condition_expr=f"{scan_loop_var} < {scan_ub_str}",
loop_var=scan_loop_var,
initialize_expr=f"{scan_loop_var} = {scan_lb_str}",
update_expr=f"{scan_loop_var} = {scan_loop_var} + 1",
inverted=False,
)
else:
scan_loop = LoopRegion(
label="scan",
condition_expr=f"{scan_loop_var} >= {scan_lb_str}",
loop_var=scan_loop_var,
initialize_expr=f"{scan_loop_var} = {scan_ub_str} - 1",
update_expr=f"{scan_loop_var} = {scan_loop_var} - 1",
inverted=False,
)
scan_sdfg.add_node(scan_loop)
compute_state = scan_loop.add_state("lambda_compute", is_start_block=True)
update_state = scan_loop.add_state("lambda_update")
scan_loop.add_edge(compute_state, update_state, dace.InterstateEdge())

start_state = scan_sdfg.add_state("start", is_start_block=True)
scan_sdfg.add_edge(start_state, scan_loop, dace.InterstateEdge())

# tasklet for initialization of carry
carry_init_tasklet = start_state.add_tasklet(
"get_carry_init_value",
Expand All @@ -502,19 +526,6 @@ def _visit_scan_stencil_closure(
dace.Memlet.simple(scan_carry_name, "0"),
)

# TODO(edopao): replace state machine with dace loop construct
scan_sdfg.add_loop(
start_state,
lambda_state,
end_state,
loop_var=f"i_{scan_dim}",
initialize_expr=f"{scan_lb_str}" if is_forward else f"{scan_ub_str} - 1",
condition_expr=f"i_{scan_dim} < {scan_ub_str}"
if is_forward
else f"i_{scan_dim} >= {scan_lb_str}",
increment_expr=f"i_{scan_dim} + 1" if is_forward else f"i_{scan_dim} - 1",
)

# add storage to scan SDFG for inputs
for name in [*input_names, *connectivity_names]:
assert name not in scan_sdfg.arrays
Expand Down Expand Up @@ -569,7 +580,7 @@ def _visit_scan_stencil_closure(
array_mapping = {**input_mapping, **connectivity_mapping}
symbol_mapping = map_nested_sdfg_symbols(scan_sdfg, lambda_context.body, array_mapping)

scan_inner_node = lambda_state.add_nested_sdfg(
scan_inner_node = compute_state.add_nested_sdfg(
lambda_context.body,
parent=scan_sdfg,
inputs=set(lambda_input_names) | set(connectivity_names),
Expand All @@ -580,29 +591,25 @@ def _visit_scan_stencil_closure(

# connect scan SDFG to lambda inputs
for name, memlet in array_mapping.items():
access_node = lambda_state.add_access(name, debuginfo=lambda_context.body.debuginfo)
lambda_state.add_edge(access_node, None, scan_inner_node, name, memlet)
access_node = compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo)
compute_state.add_edge(access_node, None, scan_inner_node, name, memlet)

output_names = [output_name]
assert len(lambda_output_names) == 1
# connect lambda output to scan SDFG
for name, connector in zip(output_names, lambda_output_names):
lambda_state.add_edge(
compute_state.add_edge(
scan_inner_node,
connector,
lambda_state.add_access(name, debuginfo=lambda_context.body.debuginfo),
compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo),
None,
dace.Memlet.simple(name, f"i_{scan_dim}"),
dace.Memlet.simple(name, scan_loop_var),
)

# add state to scan SDFG to update the carry value at each loop iteration
lambda_update_state = scan_sdfg.add_state_after(lambda_state, "lambda_update")
lambda_update_state.add_memlet_path(
lambda_update_state.add_access(output_name, debuginfo=lambda_context.body.debuginfo),
lambda_update_state.add_access(
scan_carry_name, debuginfo=lambda_context.body.debuginfo
),
memlet=dace.Memlet.simple(output_names[0], f"i_{scan_dim}", other_subset_str="0"),
update_state.add_nedge(
update_state.add_access(output_name, debuginfo=lambda_context.body.debuginfo),
update_state.add_access(scan_carry_name, debuginfo=lambda_context.body.debuginfo),
dace.Memlet.simple(output_names[0], scan_loop_var, other_subset_str="0"),
)

return scan_sdfg, map_ranges, scan_dim_index
Expand Down

0 comments on commit 70f0f88

Please sign in to comment.