From c36b8a49da493a9c9b91fcfaf7bf9e5646808124 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Fri, 12 Jan 2024 17:48:14 +0100 Subject: [PATCH] Remove view node, use full index inside tasklet --- .../gtc/dace/expansion/daceir_builder.py | 79 ++++++++----------- .../gtc/dace/expansion/sdfg_builder.py | 66 +++------------- src/gt4py/cartesian/gtc/daceir.py | 4 +- 3 files changed, 49 insertions(+), 100 deletions(-) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index 1187537d18..bee10f6e0c 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -30,6 +30,7 @@ compute_dcir_access_infos, flatten_list, get_tasklet_symbol, + make_dace_subset, union_inout_memlets, union_node_grid_subsets, untile_memlets, @@ -316,7 +317,6 @@ def visit_FieldAccess( node: oir.FieldAccess, *, is_target: bool, - symbol_collector: "DaCeIRBuilder.SymbolCollector", targets: Set[eve.SymbolRef], var_offset_fields: Set[eve.SymbolRef], **kwargs: Any, @@ -348,9 +348,6 @@ def visit_FieldAccess( res = dcir.ScalarAccess(name=name, dtype=node.dtype) if is_target: targets.add(node.name) - for index in node.data_index: - if isinstance(index, oir.ScalarAccess): - symbol_collector.add_symbol(index.name) return res def visit_ScalarAccess( @@ -455,7 +452,7 @@ def visit_HorizontalExecution( k_interval=k_interval, ) - dcir_node: dcir.ComputationNode = dcir.Tasklet( + dcir_node = dcir.Tasklet( decls=decls, stmts=stmts, read_memlets=read_memlets, @@ -464,49 +461,41 @@ def visit_HorizontalExecution( if next(dcir_node.walk_values().if_isinstance(dcir.IndexAccess).iterator, None) is not None: """ - In case of tasklet inside a map scope for the vertical dimension, the tasklet is not mapped directly - rather it is instanciated inside a nested SDFG. The reason is to avoid the tasklet to be connected - to map nodes, instead ensuring that it is connected to access nodes. This in order to be able to use - views of array containers, in case the tasklet contains array access with partial subset index. + Special case of tasklet performing array access. The memlet should pass the full array shape + (no slicing) and the tasklet code should use all explicit indexes for array access. """ - field_decls = global_ctx.get_dcir_decls( - { - field: global_ctx.library_node.access_infos[field] - for field in set(memlet.field for memlet in [*read_memlets, *write_memlets]) - }, - symbol_collector=symbol_collector, - ) for memlet in [*read_memlets, *write_memlets]: - for sym in memlet.access_info.grid_subset.free_symbols: - symbol_collector.add_symbol(sym, common.DataType.INT32) - nested_read_memlets = [ - dcir.Memlet( - field=field, - connector=field, - access_info=global_ctx.library_node.access_infos[field], - is_read=True, - is_write=False, - ) - for field in set(memlet.field for memlet in read_memlets) - ] - nested_write_memlets = [ - dcir.Memlet( - field=field, - connector=field, - access_info=global_ctx.library_node.access_infos[field], - is_read=False, - is_write=True, + ndims = len(iteration_ctx.grid_subset.intervals) + field_decl = global_ctx.library_node.field_decls[memlet.field] + # calculate array subset from original memlet + memlet_subset = make_dace_subset( + global_ctx.library_node.access_infos[memlet.field], + memlet.access_info, + field_decl.data_dims, ) - for field in set(memlet.field for memlet in write_memlets) - ] - dcir_node = dcir.NestedSDFG( - label=f"{global_ctx.library_node.label}_tasklet", - field_decls=field_decls, - read_memlets=nested_read_memlets, - write_memlets=nested_write_memlets, - states=self.to_state(dcir_node, grid_subset=iteration_ctx.grid_subset), - symbol_decls=list(symbol_collector.symbol_decls.values()), - ) + # ensure grid access on single point + assert memlet_subset.size()[:ndims] == [1] * ndims + memlet_data_index = [ + dcir.Literal(value=str(r[0]), dtype=common.DataType.INT32) + for r in memlet_subset[:ndims] + ] + # loop through assignment statements in the tasklet body + tasklet_subset_size = 0 + for access_node in dcir_node.walk_values().if_isinstance(dcir.IndexAccess): + if access_node.data_index and access_node.name == memlet.connector: + if tasklet_subset_size != 0: + assert len(access_node.data_index) == tasklet_subset_size + else: + tasklet_subset_size = len(access_node.data_index) + for idx in reversed(memlet_data_index): + access_node.data_index.insert(0, idx) + # reshape memlet if tasklet accessed the endpoint array with partial index + if tasklet_subset_size != 0: + # ensure that memlet symbols used for array subset are defined in context + for sym in memlet.access_info.grid_subset.free_symbols: + symbol_collector.add_symbol(sym) + # set full shape on memlet + memlet.access_info = global_ctx.library_node.access_infos[memlet.field] for item in reversed(expansion_items): iteration_ctx = iteration_ctx.pop() diff --git a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py index 0157e2ce08..dcffd9e410 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py @@ -107,60 +107,20 @@ def visit_Memlet( subset=make_dace_subset(field_decl.access_info, node.access_info, field_decl.data_dims), dynamic=field_decl.is_dynamic, ) - """ - In case of memlet to/from a tasklet where the data passed to the tasklet is an array (not a scalar), - the tasklet should use explicit indexes for the full array shape. In case the memlet is limited to - a subset of the full shape, the resulting slice needs to be presented as a view of the original array. - """ - if isinstance(scope_node, dace.nodes.Tasklet) and field_decl.data_dims: - dtype = data_type_to_dace_typeclass(field_decl.dtype) - slice_data, slice_desc = sdfg_ctx.sdfg.add_view( - f"{node.connector}_v", field_decl.data_dims, dtype, find_new_name=True + if node.is_read: + sdfg_ctx.state.add_edge( + *node_ctx.input_node_and_conns[memlet.data], + scope_node, + connector_prefix + node.connector, + memlet, + ) + if node.is_write: + sdfg_ctx.state.add_edge( + scope_node, + connector_prefix + node.connector, + *node_ctx.output_node_and_conns[memlet.data], + memlet, ) - slice_node = sdfg_ctx.state.add_access(slice_data) - if node.is_read: - sdfg_ctx.state.add_edge( - *node_ctx.input_node_and_conns[memlet.data], - slice_node, - None, - memlet, - ) - sdfg_ctx.state.add_edge( - slice_node, - None, - scope_node, - connector_prefix + node.connector, - dace.Memlet.from_array(slice_data, slice_desc), - ) - if node.is_write: - sdfg_ctx.state.add_edge( - scope_node, - connector_prefix + node.connector, - slice_node, - None, - dace.Memlet.from_array(slice_data, slice_desc), - ) - sdfg_ctx.state.add_edge( - slice_node, - None, - *node_ctx.output_node_and_conns[memlet.data], - memlet, - ) - else: - if node.is_read: - sdfg_ctx.state.add_edge( - *node_ctx.input_node_and_conns[memlet.data], - scope_node, - connector_prefix + node.connector, - memlet, - ) - if node.is_write: - sdfg_ctx.state.add_edge( - scope_node, - connector_prefix + node.connector, - *node_ctx.output_node_and_conns[memlet.data], - memlet, - ) @classmethod def _add_empty_edges( diff --git a/src/gt4py/cartesian/gtc/daceir.py b/src/gt4py/cartesian/gtc/daceir.py index 5ad89dd37a..0366317360 100644 --- a/src/gt4py/cartesian/gtc/daceir.py +++ b/src/gt4py/cartesian/gtc/daceir.py @@ -536,7 +536,7 @@ def union(self, other): else: assert ( isinstance(interval2, (TileInterval, DomainInterval)) - and isinstance(interval1, IndexWithExtent) + and isinstance(interval1, (IndexWithExtent, DomainInterval)) ) or ( isinstance(interval1, (TileInterval, DomainInterval)) and isinstance(interval2, IndexWithExtent) @@ -880,7 +880,7 @@ class DomainMap(ComputationNode, IterationNode): class ComputationState(IterationNode): - computations: List[Union[Tasklet, DomainMap, NestedSDFG]] + computations: List[Union[Tasklet, DomainMap]] class DomainLoop(IterationNode, ComputationNode):