Skip to content

Commit

Permalink
[dace][cartesian] Fix array access in tasklet
Browse files Browse the repository at this point in the history
Found some incompatible tasklet represention while upgrading
to dace v0.15.1. Array access inside tasklet with partial index
subset worked in v0.14.1, although not valid. Added view or
transient array in case of memlet slice, to ensure that tasklet
uses explicit indexes for the full array shape.
  • Loading branch information
edopao committed Jan 9, 2024
1 parent 6b269bd commit 92a2761
Showing 1 changed file with 71 additions and 13 deletions.
84 changes: 71 additions & 13 deletions src/gt4py/cartesian/gtc/dace/expansion/sdfg_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,20 +107,78 @@ def visit_Memlet(
subset=make_dace_subset(field_decl.access_info, node.access_info, field_decl.data_dims),
dynamic=field_decl.is_dynamic,
)
if node.is_read:
sdfg_ctx.state.add_edge(
*node_ctx.input_node_and_conns[memlet.data],
scope_node,
connector_prefix + node.connector,
memlet,
)
if node.is_write:
sdfg_ctx.state.add_edge(
scope_node,
connector_prefix + node.connector,
*node_ctx.output_node_and_conns[memlet.data],
memlet,
"""
In case of memlet to/from a tasklet where the data passed to the tasklet is an array (not a scalar),
the tasklet should use explicit indexes for the full array shape. In case the memlet is limited to
a subset of the full shape, the resulting slice needs to be presented as a view or a transient array.
"""
if isinstance(scope_node, dace.nodes.Tasklet) and field_decl.data_dims:
dtype = data_type_to_dace_typeclass(field_decl.dtype)
endpoint, endpoint_connector = (
node_ctx.input_node_and_conns[memlet.data]
if node.is_read
else node_ctx.output_node_and_conns[memlet.data]
)
if isinstance(endpoint, dace.nodes.AccessNode):
slice_data, slice_desc = sdfg_ctx.sdfg.add_view(
f"{node.connector}_v", field_decl.data_dims, dtype, find_new_name=True
)
slice_node = sdfg_ctx.state.add_access(slice_data)
else:
slice_data, slice_desc = sdfg_ctx.sdfg.add_array(
f"{node.connector}_t",
field_decl.data_dims,
dtype,
find_new_name=True,
transient=True,
)
slice_node = sdfg_ctx.state.add_access(slice_data)

if node.is_read:
sdfg_ctx.state.add_edge(
endpoint,
endpoint_connector,
slice_node,
None,
memlet,
)
sdfg_ctx.state.add_edge(
slice_node,
None,
scope_node,
connector_prefix + node.connector,
dace.Memlet.from_array(slice_data, slice_desc),
)
if node.is_write:
sdfg_ctx.state.add_edge(
scope_node,
connector_prefix + node.connector,
slice_node,
None,
dace.Memlet.from_array(slice_data, slice_desc),
)
sdfg_ctx.state.add_edge(
slice_node,
None,
endpoint,
endpoint_connector,
memlet,
)
else:
if node.is_read:
sdfg_ctx.state.add_edge(
*node_ctx.input_node_and_conns[memlet.data],
scope_node,
connector_prefix + node.connector,
memlet,
)
if node.is_write:
sdfg_ctx.state.add_edge(
scope_node,
connector_prefix + node.connector,
*node_ctx.output_node_and_conns[memlet.data],
memlet,
)

@classmethod
def _add_empty_edges(
Expand Down

0 comments on commit 92a2761

Please sign in to comment.