Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next][dace]: use new LoopRegion construct for scan operator #1424

Merged
merged 5 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import dace
import numpy as np
from dace.codegen.compiled_sdfg import CompiledSDFG
from dace.sdfg import utils as sdutils
from dace.transformation.auto import auto_optimize as autoopt

import gt4py.next.allocators as next_allocators
Expand Down Expand Up @@ -262,7 +263,7 @@ def build_sdfg_from_itir(
# visit ITIR and generate SDFG
program = preprocess_program(program, offset_provider, lift_mode)
sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis)
sdfg = sdfg_genenerator.visit(program)
sdfg: dace.SDFG = sdfg_genenerator.visit(program)
if sdfg is None:
raise RuntimeError(f"Visit failed for program {program.id}.")

Expand All @@ -279,6 +280,9 @@ def build_sdfg_from_itir(
filename=frameinfo.filename,
)

# TODO(edopao): remove `inline_loop_blocks` when DaCe transformations support LoopRegion construct
sdutils.inline_loop_blocks(sdfg)

# run DaCe transformations to simplify the SDFG
sdfg.simplify()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Any, Optional, cast

import dace
from dace.sdfg.state import LoopRegion

import gt4py.eve as eve
from gt4py.next import Dimension, DimensionKind, type_inference as next_typing
Expand Down Expand Up @@ -430,15 +431,38 @@ def _visit_scan_stencil_closure(
scan_sdfg = dace.SDFG(name="scan")
scan_sdfg.debuginfo = dace_debuginfo(node)

# create a state machine for lambda call over the scan dimension
start_state = scan_sdfg.add_state("start", True)
lambda_state = scan_sdfg.add_state("lambda_compute")
end_state = scan_sdfg.add_state("end")

# the carry value of the scan operator exists only in the scope of the scan sdfg
scan_carry_name = unique_var_name()
scan_sdfg.add_scalar(scan_carry_name, dtype=as_dace_type(scan_dtype), transient=True)

# create a loop region for lambda call over the scan dimension
scan_loop_var = f"i_{scan_dim}"
if is_forward:
scan_loop = LoopRegion(
label="scan",
condition_expr=f"{scan_loop_var} < {scan_ub_str}",
loop_var=scan_loop_var,
initialize_expr=f"{scan_loop_var} = {scan_lb_str}",
update_expr=f"{scan_loop_var} = {scan_loop_var} + 1",
inverted=False,
)
else:
scan_loop = LoopRegion(
label="scan",
condition_expr=f"{scan_loop_var} >= {scan_lb_str}",
loop_var=scan_loop_var,
initialize_expr=f"{scan_loop_var} = {scan_ub_str} - 1",
update_expr=f"{scan_loop_var} = {scan_loop_var} - 1",
inverted=False,
)
scan_sdfg.add_node(scan_loop)
compute_state = scan_loop.add_state("lambda_compute", is_start_block=True)
update_state = scan_loop.add_state("lambda_update")
scan_loop.add_edge(compute_state, update_state, dace.InterstateEdge())

start_state = scan_sdfg.add_state("start", is_start_block=True)
scan_sdfg.add_edge(start_state, scan_loop, dace.InterstateEdge())

# tasklet for initialization of carry
carry_init_tasklet = start_state.add_tasklet(
"get_carry_init_value",
Expand All @@ -455,19 +479,6 @@ def _visit_scan_stencil_closure(
dace.Memlet.simple(scan_carry_name, "0"),
)

# TODO(edopao): replace state machine with dace loop construct
scan_sdfg.add_loop(
start_state,
lambda_state,
end_state,
loop_var=f"i_{scan_dim}",
initialize_expr=f"{scan_lb_str}" if is_forward else f"{scan_ub_str} - 1",
condition_expr=f"i_{scan_dim} < {scan_ub_str}"
if is_forward
else f"i_{scan_dim} >= {scan_lb_str}",
increment_expr=f"i_{scan_dim} + 1" if is_forward else f"i_{scan_dim} - 1",
)

# add storage to scan SDFG for inputs
for name in [*input_names, *connectivity_names]:
assert name not in scan_sdfg.arrays
Expand Down Expand Up @@ -522,7 +533,7 @@ def _visit_scan_stencil_closure(
array_mapping = {**input_mapping, **connectivity_mapping}
symbol_mapping = map_nested_sdfg_symbols(scan_sdfg, lambda_context.body, array_mapping)

scan_inner_node = lambda_state.add_nested_sdfg(
scan_inner_node = compute_state.add_nested_sdfg(
lambda_context.body,
parent=scan_sdfg,
inputs=set(lambda_input_names) | set(connectivity_names),
Expand All @@ -533,29 +544,25 @@ def _visit_scan_stencil_closure(

# connect scan SDFG to lambda inputs
for name, memlet in array_mapping.items():
access_node = lambda_state.add_access(name, debuginfo=lambda_context.body.debuginfo)
lambda_state.add_edge(access_node, None, scan_inner_node, name, memlet)
access_node = compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo)
compute_state.add_edge(access_node, None, scan_inner_node, name, memlet)

output_names = [output_name]
assert len(lambda_output_names) == 1
# connect lambda output to scan SDFG
for name, connector in zip(output_names, lambda_output_names):
lambda_state.add_edge(
compute_state.add_edge(
scan_inner_node,
connector,
lambda_state.add_access(name, debuginfo=lambda_context.body.debuginfo),
compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo),
None,
dace.Memlet.simple(name, f"i_{scan_dim}"),
dace.Memlet.simple(name, scan_loop_var),
)

# add state to scan SDFG to update the carry value at each loop iteration
lambda_update_state = scan_sdfg.add_state_after(lambda_state, "lambda_update")
lambda_update_state.add_memlet_path(
lambda_update_state.add_access(output_name, debuginfo=lambda_context.body.debuginfo),
lambda_update_state.add_access(
scan_carry_name, debuginfo=lambda_context.body.debuginfo
),
memlet=dace.Memlet.simple(output_names[0], f"i_{scan_dim}", other_subset_str="0"),
update_state.add_nedge(
update_state.add_access(output_name, debuginfo=lambda_context.body.debuginfo),
update_state.add_access(scan_carry_name, debuginfo=lambda_context.body.debuginfo),
dace.Memlet.simple(output_names[0], scan_loop_var, other_subset_str="0"),
)

return scan_sdfg, map_ranges, scan_dim_index
Expand Down