diff --git a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py index dcffd9e410..a29ecc71c6 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py @@ -107,20 +107,78 @@ def visit_Memlet( subset=make_dace_subset(field_decl.access_info, node.access_info, field_decl.data_dims), dynamic=field_decl.is_dynamic, ) - if node.is_read: - sdfg_ctx.state.add_edge( - *node_ctx.input_node_and_conns[memlet.data], - scope_node, - connector_prefix + node.connector, - memlet, - ) - if node.is_write: - sdfg_ctx.state.add_edge( - scope_node, - connector_prefix + node.connector, - *node_ctx.output_node_and_conns[memlet.data], - memlet, + """ + In case of memlet to/from a tasklet where the data passed to the tasklet is an array (not a scalar), + the tasklet should use explicit indexes for the full array shape. In case the memlet is limited to + a subset of the full shape, the resulting slice needs to be presented as a view or a transient array. + """ + if isinstance(scope_node, dace.nodes.Tasklet) and field_decl.data_dims: + dtype = data_type_to_dace_typeclass(field_decl.dtype) + endpoint, endpoint_connector = ( + node_ctx.input_node_and_conns[memlet.data] + if node.is_read + else node_ctx.output_node_and_conns[memlet.data] ) + if isinstance(endpoint, dace.nodes.AccessNode): + slice_data, slice_desc = sdfg_ctx.sdfg.add_view( + f"{node.connector}_v", field_decl.data_dims, dtype, find_new_name=True + ) + slice_node = sdfg_ctx.state.add_access(slice_data) + else: + slice_data, slice_desc = sdfg_ctx.sdfg.add_array( + f"{node.connector}_t", + field_decl.data_dims, + dtype, + find_new_name=True, + transient=True, + ) + slice_node = sdfg_ctx.state.add_access(slice_data) + + if node.is_read: + sdfg_ctx.state.add_edge( + endpoint, + endpoint_connector, + slice_node, + None, + memlet, + ) + sdfg_ctx.state.add_edge( + slice_node, + None, + scope_node, + connector_prefix + node.connector, + dace.Memlet.from_array(slice_data, slice_desc), + ) + if node.is_write: + sdfg_ctx.state.add_edge( + scope_node, + connector_prefix + node.connector, + slice_node, + None, + dace.Memlet.from_array(slice_data, slice_desc), + ) + sdfg_ctx.state.add_edge( + slice_node, + None, + endpoint, + endpoint_connector, + memlet, + ) + else: + if node.is_read: + sdfg_ctx.state.add_edge( + *node_ctx.input_node_and_conns[memlet.data], + scope_node, + connector_prefix + node.connector, + memlet, + ) + if node.is_write: + sdfg_ctx.state.add_edge( + scope_node, + connector_prefix + node.connector, + *node_ctx.output_node_and_conns[memlet.data], + memlet, + ) @classmethod def _add_empty_edges(