Skip to content

Commit

Permalink
make map_strides recursive
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Dec 12, 2024
1 parent eb17345 commit 8b163da
Showing 1 changed file with 34 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -971,36 +971,55 @@ def apply(
tmp_out_subset = dace_subsets.Range.from_array(tmp_desc)
assert glob_in_subset is not None

# Find the source of the edge entering the map exit node
map_exit_in_conn = map_to_tmp_edge.src_conn.replace("OUT_", "IN_")
src_to_map_exit_edge = next(
edge for edge in graph.in_edges(map_exit) if edge.dst_conn == map_exit_in_conn
)
if isinstance(src_to_map_exit_edge.src, dace.nodes.NestedSDFG):
nsdfg_node = src_to_map_exit_edge.src
# Recursively visit the nested SDFGs for mapping from inner to outer strides on the vertical dimension
def map_strides(edge: dace.sdfg.graph.Edge, outer_node: dace.nodes.AccessNode) -> None:
if isinstance(edge.src, dace.nodes.MapExit):
# Find the source of the edge entering the map exit node
map_exit_in_conn = edge.src_conn.replace("OUT_", "IN_")
for edge_to_map_exit_edge in graph.in_edges_by_connector(
edge.src, map_exit_in_conn
):
map_strides(edge_to_map_exit_edge, outer_node)
return

if not isinstance(edge.src, dace.nodes.NestedSDFG):
return

# We need to propagate the strides inside the nested SDFG on the global arrays
# TODO: the stride should be propagated recursively to nested SDFGs, if directly connected
nsdfg_node = edge.src
new_strides = tuple(
stride
for stride, to_map_size in zip(
glob_ac.desc(sdfg).strides,
src_to_map_exit_edge.data.subset.size(),
outer_node.desc(sdfg).strides,
edge.data.subset.size(),
strict=True,
)
if to_map_size != 1
)
inner_data = src_to_map_exit_edge.src_conn
inner_data = edge.src_conn
inner_desc = nsdfg_node.sdfg.arrays[inner_data]
if isinstance(inner_desc, dace.data.Array):
inner_desc.set_shape(inner_desc.shape, new_strides)
else:
assert isinstance(inner_desc, dace.data.Scalar)
assert not inner_desc.transient

if isinstance(inner_desc, dace.data.Scalar):
assert len(new_strides) == 0
return

assert isinstance(inner_desc, dace.data.Array)
inner_desc.set_shape(inner_desc.shape, new_strides)

for stride in new_strides:
for sym in stride.free_symbols:
nsdfg_node.sdfg.add_symbol(str(sym), sym.dtype)
nsdfg_node.symbol_mapping |= {str(sym): sym}

for inner_state in nsdfg_node.sdfg.states():
for inner_node in inner_state.data_nodes():
if inner_node.data == inner_data:
for inner_edge in inner_state.in_edges(inner_node):
map_strides(inner_edge, inner_node)

map_strides(map_to_tmp_edge, glob_ac)

# We now remove the `tmp` node, and create a new connection between
# the global node and the map exit.
new_map_to_glob_edge = graph.add_edge(
Expand Down

0 comments on commit 8b163da

Please sign in to comment.