Skip to content

Commit

Permalink
Remove view node, use full index inside tasklet
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Jan 12, 2024
1 parent 378110b commit c36b8a4
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 100 deletions.
79 changes: 34 additions & 45 deletions src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
compute_dcir_access_infos,
flatten_list,
get_tasklet_symbol,
make_dace_subset,
union_inout_memlets,
union_node_grid_subsets,
untile_memlets,
Expand Down Expand Up @@ -316,7 +317,6 @@ def visit_FieldAccess(
node: oir.FieldAccess,
*,
is_target: bool,
symbol_collector: "DaCeIRBuilder.SymbolCollector",
targets: Set[eve.SymbolRef],
var_offset_fields: Set[eve.SymbolRef],
**kwargs: Any,
Expand Down Expand Up @@ -348,9 +348,6 @@ def visit_FieldAccess(
res = dcir.ScalarAccess(name=name, dtype=node.dtype)
if is_target:
targets.add(node.name)
for index in node.data_index:
if isinstance(index, oir.ScalarAccess):
symbol_collector.add_symbol(index.name)
return res

def visit_ScalarAccess(
Expand Down Expand Up @@ -455,7 +452,7 @@ def visit_HorizontalExecution(
k_interval=k_interval,
)

dcir_node: dcir.ComputationNode = dcir.Tasklet(
dcir_node = dcir.Tasklet(
decls=decls,
stmts=stmts,
read_memlets=read_memlets,
Expand All @@ -464,49 +461,41 @@ def visit_HorizontalExecution(

if next(dcir_node.walk_values().if_isinstance(dcir.IndexAccess).iterator, None) is not None:
"""
In case of a tasklet inside a map scope for the vertical dimension, the tasklet is not mapped directly;
rather, it is instantiated inside a nested SDFG. The reason is to avoid connecting the tasklet
to map nodes, ensuring instead that it is connected to access nodes. This makes it possible to use
views of array containers, in case the tasklet contains array accesses with a partial subset index.
Special case of tasklet performing array access. The memlet should pass the full array shape
(no slicing) and the tasklet code should use all explicit indexes for array access.
"""
field_decls = global_ctx.get_dcir_decls(
{
field: global_ctx.library_node.access_infos[field]
for field in set(memlet.field for memlet in [*read_memlets, *write_memlets])
},
symbol_collector=symbol_collector,
)
for memlet in [*read_memlets, *write_memlets]:
for sym in memlet.access_info.grid_subset.free_symbols:
symbol_collector.add_symbol(sym, common.DataType.INT32)
nested_read_memlets = [
dcir.Memlet(
field=field,
connector=field,
access_info=global_ctx.library_node.access_infos[field],
is_read=True,
is_write=False,
)
for field in set(memlet.field for memlet in read_memlets)
]
nested_write_memlets = [
dcir.Memlet(
field=field,
connector=field,
access_info=global_ctx.library_node.access_infos[field],
is_read=False,
is_write=True,
ndims = len(iteration_ctx.grid_subset.intervals)
field_decl = global_ctx.library_node.field_decls[memlet.field]
# calculate array subset from original memlet
memlet_subset = make_dace_subset(
global_ctx.library_node.access_infos[memlet.field],
memlet.access_info,
field_decl.data_dims,
)
for field in set(memlet.field for memlet in write_memlets)
]
dcir_node = dcir.NestedSDFG(
label=f"{global_ctx.library_node.label}_tasklet",
field_decls=field_decls,
read_memlets=nested_read_memlets,
write_memlets=nested_write_memlets,
states=self.to_state(dcir_node, grid_subset=iteration_ctx.grid_subset),
symbol_decls=list(symbol_collector.symbol_decls.values()),
)
# ensure grid access on single point
assert memlet_subset.size()[:ndims] == [1] * ndims
memlet_data_index = [
dcir.Literal(value=str(r[0]), dtype=common.DataType.INT32)
for r in memlet_subset[:ndims]
]
# loop through assignment statements in the tasklet body
tasklet_subset_size = 0
for access_node in dcir_node.walk_values().if_isinstance(dcir.IndexAccess):
if access_node.data_index and access_node.name == memlet.connector:
if tasklet_subset_size != 0:
assert len(access_node.data_index) == tasklet_subset_size
else:
tasklet_subset_size = len(access_node.data_index)
for idx in reversed(memlet_data_index):
access_node.data_index.insert(0, idx)
# reshape memlet if tasklet accessed the endpoint array with partial index
if tasklet_subset_size != 0:
# ensure that memlet symbols used for array subset are defined in context
for sym in memlet.access_info.grid_subset.free_symbols:
symbol_collector.add_symbol(sym)
# set full shape on memlet
memlet.access_info = global_ctx.library_node.access_infos[memlet.field]

for item in reversed(expansion_items):
iteration_ctx = iteration_ctx.pop()
Expand Down
66 changes: 13 additions & 53 deletions src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,60 +107,20 @@ def visit_Memlet(
subset=make_dace_subset(field_decl.access_info, node.access_info, field_decl.data_dims),
dynamic=field_decl.is_dynamic,
)
"""
In case of memlet to/from a tasklet where the data passed to the tasklet is an array (not a scalar),
the tasklet should use explicit indexes for the full array shape. In case the memlet is limited to
a subset of the full shape, the resulting slice needs to be presented as a view of the original array.
"""
if isinstance(scope_node, dace.nodes.Tasklet) and field_decl.data_dims:
dtype = data_type_to_dace_typeclass(field_decl.dtype)
slice_data, slice_desc = sdfg_ctx.sdfg.add_view(
f"{node.connector}_v", field_decl.data_dims, dtype, find_new_name=True
if node.is_read:
sdfg_ctx.state.add_edge(
*node_ctx.input_node_and_conns[memlet.data],
scope_node,
connector_prefix + node.connector,
memlet,
)
if node.is_write:
sdfg_ctx.state.add_edge(
scope_node,
connector_prefix + node.connector,
*node_ctx.output_node_and_conns[memlet.data],
memlet,
)
slice_node = sdfg_ctx.state.add_access(slice_data)
if node.is_read:
sdfg_ctx.state.add_edge(
*node_ctx.input_node_and_conns[memlet.data],
slice_node,
None,
memlet,
)
sdfg_ctx.state.add_edge(
slice_node,
None,
scope_node,
connector_prefix + node.connector,
dace.Memlet.from_array(slice_data, slice_desc),
)
if node.is_write:
sdfg_ctx.state.add_edge(
scope_node,
connector_prefix + node.connector,
slice_node,
None,
dace.Memlet.from_array(slice_data, slice_desc),
)
sdfg_ctx.state.add_edge(
slice_node,
None,
*node_ctx.output_node_and_conns[memlet.data],
memlet,
)
else:
if node.is_read:
sdfg_ctx.state.add_edge(
*node_ctx.input_node_and_conns[memlet.data],
scope_node,
connector_prefix + node.connector,
memlet,
)
if node.is_write:
sdfg_ctx.state.add_edge(
scope_node,
connector_prefix + node.connector,
*node_ctx.output_node_and_conns[memlet.data],
memlet,
)

@classmethod
def _add_empty_edges(
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/cartesian/gtc/daceir.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def union(self, other):
else:
assert (
isinstance(interval2, (TileInterval, DomainInterval))
and isinstance(interval1, IndexWithExtent)
and isinstance(interval1, (IndexWithExtent, DomainInterval))
) or (
isinstance(interval1, (TileInterval, DomainInterval))
and isinstance(interval2, IndexWithExtent)
Expand Down Expand Up @@ -880,7 +880,7 @@ class DomainMap(ComputationNode, IterationNode):


class ComputationState(IterationNode):
computations: List[Union[Tasklet, DomainMap, NestedSDFG]]
computations: List[Union[Tasklet, DomainMap]]


class DomainLoop(IterationNode, ComputationNode):
Expand Down

0 comments on commit c36b8a4

Please sign in to comment.