Skip to content

Commit

Permalink
[dace] Removed transient array, only use views for partial array subset
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Jan 10, 2024
1 parent 75c7b8c commit 378110b
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 24 deletions.
52 changes: 51 additions & 1 deletion src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ def visit_FieldAccess(
node: oir.FieldAccess,
*,
is_target: bool,
symbol_collector: "DaCeIRBuilder.SymbolCollector",
targets: Set[eve.SymbolRef],
var_offset_fields: Set[eve.SymbolRef],
**kwargs: Any,
Expand Down Expand Up @@ -347,6 +348,9 @@ def visit_FieldAccess(
res = dcir.ScalarAccess(name=name, dtype=node.dtype)
if is_target:
targets.add(node.name)
for index in node.data_index:
if isinstance(index, oir.ScalarAccess):
symbol_collector.add_symbol(index.name)
return res

def visit_ScalarAccess(
Expand Down Expand Up @@ -451,13 +455,59 @@ def visit_HorizontalExecution(
k_interval=k_interval,
)

dcir_node = dcir.Tasklet(
dcir_node: dcir.ComputationNode = dcir.Tasklet(
decls=decls,
stmts=stmts,
read_memlets=read_memlets,
write_memlets=write_memlets,
)

if next(dcir_node.walk_values().if_isinstance(dcir.IndexAccess).iterator, None) is not None:
"""
In case of tasklet inside a map scope for the vertical dimension, the tasklet is not mapped directly
rather it is instanciated inside a nested SDFG. The reason is to avoid the tasklet to be connected
to map nodes, instead ensuring that it is connected to access nodes. This in order to be able to use
views of array containers, in case the tasklet contains array access with partial subset index.
"""
field_decls = global_ctx.get_dcir_decls(
{
field: global_ctx.library_node.access_infos[field]
for field in set(memlet.field for memlet in [*read_memlets, *write_memlets])
},
symbol_collector=symbol_collector,
)
for memlet in [*read_memlets, *write_memlets]:
for sym in memlet.access_info.grid_subset.free_symbols:
symbol_collector.add_symbol(sym, common.DataType.INT32)
nested_read_memlets = [
dcir.Memlet(
field=field,
connector=field,
access_info=global_ctx.library_node.access_infos[field],
is_read=True,
is_write=False,
)
for field in set(memlet.field for memlet in read_memlets)
]
nested_write_memlets = [
dcir.Memlet(
field=field,
connector=field,
access_info=global_ctx.library_node.access_infos[field],
is_read=False,
is_write=True,
)
for field in set(memlet.field for memlet in write_memlets)
]
dcir_node = dcir.NestedSDFG(
label=f"{global_ctx.library_node.label}_tasklet",
field_decls=field_decls,
read_memlets=nested_read_memlets,
write_memlets=nested_write_memlets,
states=self.to_state(dcir_node, grid_subset=iteration_ctx.grid_subset),
symbol_decls=list(symbol_collector.symbol_decls.values()),
)

for item in reversed(expansion_items):
iteration_ctx = iteration_ctx.pop()
dcir_node = self._process_iteration_item(
Expand Down
26 changes: 5 additions & 21 deletions src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,32 +110,17 @@ def visit_Memlet(
"""
In case of memlet to/from a tasklet where the data passed to the tasklet is an array (not a scalar),
the tasklet should use explicit indexes for the full array shape. In case the memlet is limited to
a subset of the full shape, the resulting slice needs to be presented as a view or a transient array.
a subset of the full shape, the resulting slice needs to be presented as a view of the original array.
"""
if isinstance(scope_node, dace.nodes.Tasklet) and field_decl.data_dims:
dtype = data_type_to_dace_typeclass(field_decl.dtype)
endpoint, endpoint_connector = (
node_ctx.input_node_and_conns[memlet.data]
if node.is_read
else node_ctx.output_node_and_conns[memlet.data]
slice_data, slice_desc = sdfg_ctx.sdfg.add_view(
f"{node.connector}_v", field_decl.data_dims, dtype, find_new_name=True
)
if isinstance(endpoint, dace.nodes.AccessNode):
slice_data, slice_desc = sdfg_ctx.sdfg.add_view(
f"{node.connector}_v", field_decl.data_dims, dtype, find_new_name=True
)
else:
slice_data, slice_desc = sdfg_ctx.sdfg.add_array(
f"{node.connector}_t",
field_decl.data_dims,
dtype,
find_new_name=True,
transient=True,
)
slice_node = sdfg_ctx.state.add_access(slice_data)
if node.is_read:
sdfg_ctx.state.add_edge(
endpoint,
endpoint_connector,
*node_ctx.input_node_and_conns[memlet.data],
slice_node,
None,
memlet,
Expand All @@ -158,8 +143,7 @@ def visit_Memlet(
sdfg_ctx.state.add_edge(
slice_node,
None,
endpoint,
endpoint_connector,
*node_ctx.output_node_and_conns[memlet.data],
memlet,
)
else:
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/cartesian/gtc/daceir.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ def overapproximated_shape(self):
def apply_iteration(self, grid_subset: GridSubset):
res_intervals = dict(self.grid_subset.intervals)
for axis, field_interval in self.grid_subset.intervals.items():
if axis in grid_subset.intervals:
if axis in grid_subset.intervals and not isinstance(field_interval, DomainInterval):
grid_interval = grid_subset.intervals[axis]
assert isinstance(field_interval, IndexWithExtent)
extent = field_interval.extent
Expand Down Expand Up @@ -880,7 +880,7 @@ class DomainMap(ComputationNode, IterationNode):


class ComputationState(IterationNode):
computations: List[Union[Tasklet, DomainMap]]
computations: List[Union[Tasklet, DomainMap, NestedSDFG]]


class DomainLoop(IterationNode, ComputationNode):
Expand Down

0 comments on commit 378110b

Please sign in to comment.