From 378110bc9ef544fe110e285ef10bcee9b0eb5214 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 10 Jan 2024 16:01:59 +0100 Subject: [PATCH] [dace] Removed transient array, only use views for partial array subset --- .../gtc/dace/expansion/daceir_builder.py | 52 ++++++++++++++++++- .../gtc/dace/expansion/sdfg_builder.py | 26 ++-------- src/gt4py/cartesian/gtc/daceir.py | 4 +- 3 files changed, 58 insertions(+), 24 deletions(-) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index db276a48b9..1187537d18 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -316,6 +316,7 @@ def visit_FieldAccess( node: oir.FieldAccess, *, is_target: bool, + symbol_collector: "DaCeIRBuilder.SymbolCollector", targets: Set[eve.SymbolRef], var_offset_fields: Set[eve.SymbolRef], **kwargs: Any, @@ -347,6 +348,9 @@ def visit_FieldAccess( res = dcir.ScalarAccess(name=name, dtype=node.dtype) if is_target: targets.add(node.name) + for index in node.data_index: + if isinstance(index, oir.ScalarAccess): + symbol_collector.add_symbol(index.name) return res def visit_ScalarAccess( @@ -451,13 +455,59 @@ def visit_HorizontalExecution( k_interval=k_interval, ) - dcir_node = dcir.Tasklet( + dcir_node: dcir.ComputationNode = dcir.Tasklet( decls=decls, stmts=stmts, read_memlets=read_memlets, write_memlets=write_memlets, ) + if next(dcir_node.walk_values().if_isinstance(dcir.IndexAccess).iterator, None) is not None: + """ + In case of tasklet inside a map scope for the vertical dimension, the tasklet is not mapped directly + rather it is instantiated inside a nested SDFG. The reason is to avoid the tasklet to be connected + to map nodes, instead ensuring that it is connected to access nodes. This is in order to be able to use + views of array containers, in case the tasklet contains array access with partial subset index. 
+ """ + field_decls = global_ctx.get_dcir_decls( + { + field: global_ctx.library_node.access_infos[field] + for field in set(memlet.field for memlet in [*read_memlets, *write_memlets]) + }, + symbol_collector=symbol_collector, + ) + for memlet in [*read_memlets, *write_memlets]: + for sym in memlet.access_info.grid_subset.free_symbols: + symbol_collector.add_symbol(sym, common.DataType.INT32) + nested_read_memlets = [ + dcir.Memlet( + field=field, + connector=field, + access_info=global_ctx.library_node.access_infos[field], + is_read=True, + is_write=False, + ) + for field in set(memlet.field for memlet in read_memlets) + ] + nested_write_memlets = [ + dcir.Memlet( + field=field, + connector=field, + access_info=global_ctx.library_node.access_infos[field], + is_read=False, + is_write=True, + ) + for field in set(memlet.field for memlet in write_memlets) + ] + dcir_node = dcir.NestedSDFG( + label=f"{global_ctx.library_node.label}_tasklet", + field_decls=field_decls, + read_memlets=nested_read_memlets, + write_memlets=nested_write_memlets, + states=self.to_state(dcir_node, grid_subset=iteration_ctx.grid_subset), + symbol_decls=list(symbol_collector.symbol_decls.values()), + ) + for item in reversed(expansion_items): iteration_ctx = iteration_ctx.pop() dcir_node = self._process_iteration_item( diff --git a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py index bda99f26c1..0157e2ce08 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py @@ -110,32 +110,17 @@ def visit_Memlet( """ In case of memlet to/from a tasklet where the data passed to the tasklet is an array (not a scalar), the tasklet should use explicit indexes for the full array shape. In case the memlet is limited to - a subset of the full shape, the resulting slice needs to be presented as a view or a transient array. 
+ a subset of the full shape, the resulting slice needs to be presented as a view of the original array. """ if isinstance(scope_node, dace.nodes.Tasklet) and field_decl.data_dims: dtype = data_type_to_dace_typeclass(field_decl.dtype) - endpoint, endpoint_connector = ( - node_ctx.input_node_and_conns[memlet.data] - if node.is_read - else node_ctx.output_node_and_conns[memlet.data] + slice_data, slice_desc = sdfg_ctx.sdfg.add_view( + f"{node.connector}_v", field_decl.data_dims, dtype, find_new_name=True ) - if isinstance(endpoint, dace.nodes.AccessNode): - slice_data, slice_desc = sdfg_ctx.sdfg.add_view( - f"{node.connector}_v", field_decl.data_dims, dtype, find_new_name=True - ) - else: - slice_data, slice_desc = sdfg_ctx.sdfg.add_array( - f"{node.connector}_t", - field_decl.data_dims, - dtype, - find_new_name=True, - transient=True, - ) slice_node = sdfg_ctx.state.add_access(slice_data) if node.is_read: sdfg_ctx.state.add_edge( - endpoint, - endpoint_connector, + *node_ctx.input_node_and_conns[memlet.data], slice_node, None, memlet, @@ -158,8 +143,7 @@ def visit_Memlet( sdfg_ctx.state.add_edge( slice_node, None, - endpoint, - endpoint_connector, + *node_ctx.output_node_and_conns[memlet.data], memlet, ) else: diff --git a/src/gt4py/cartesian/gtc/daceir.py b/src/gt4py/cartesian/gtc/daceir.py index 28ebc8cd8e..5ad89dd37a 100644 --- a/src/gt4py/cartesian/gtc/daceir.py +++ b/src/gt4py/cartesian/gtc/daceir.py @@ -573,7 +573,7 @@ def overapproximated_shape(self): def apply_iteration(self, grid_subset: GridSubset): res_intervals = dict(self.grid_subset.intervals) for axis, field_interval in self.grid_subset.intervals.items(): - if axis in grid_subset.intervals: + if axis in grid_subset.intervals and not isinstance(field_interval, DomainInterval): grid_interval = grid_subset.intervals[axis] assert isinstance(field_interval, IndexWithExtent) extent = field_interval.extent @@ -880,7 +880,7 @@ class DomainMap(ComputationNode, IterationNode): class 
ComputationState(IterationNode): - computations: List[Union[Tasklet, DomainMap]] + computations: List[Union[Tasklet, DomainMap, NestedSDFG]] class DomainLoop(IterationNode, ComputationNode):