diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py index c4e9be3835..3debb6a5eb 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/simplify.py @@ -971,36 +971,55 @@ def apply( tmp_out_subset = dace_subsets.Range.from_array(tmp_desc) assert glob_in_subset is not None - # Find the source of the edge entering the map exit node - map_exit_in_conn = map_to_tmp_edge.src_conn.replace("OUT_", "IN_") - src_to_map_exit_edge = next( - edge for edge in graph.in_edges(map_exit) if edge.dst_conn == map_exit_in_conn - ) - if isinstance(src_to_map_exit_edge.src, dace.nodes.NestedSDFG): - nsdfg_node = src_to_map_exit_edge.src + # Recursively visit the nested SDFGs for mapping from inner to outer strides on the vertical dimension + def map_strides(edge: dace.sdfg.graph.Edge, outer_node: dace.nodes.AccessNode) -> None: + if isinstance(edge.src, dace.nodes.MapExit): + # Find the source of the edge entering the map exit node + map_exit_in_conn = edge.src_conn.replace("OUT_", "IN_") + for edge_to_map_exit_edge in graph.in_edges_by_connector( + edge.src, map_exit_in_conn + ): + map_strides(edge_to_map_exit_edge, outer_node) + return + + if not isinstance(edge.src, dace.nodes.NestedSDFG): + return + # We need to propagate the strides inside the nested SDFG on the global arrays - # TODO: the stride should be propagated recursively to nested SDFGs, if directly connected + nsdfg_node = edge.src new_strides = tuple( stride for stride, to_map_size in zip( - glob_ac.desc(sdfg).strides, - src_to_map_exit_edge.data.subset.size(), + outer_node.desc(sdfg).strides, + edge.data.subset.size(), strict=True, ) if to_map_size != 1 ) - inner_data = src_to_map_exit_edge.src_conn + inner_data = edge.src_conn inner_desc = nsdfg_node.sdfg.arrays[inner_data] - if isinstance(inner_desc, dace.data.Array): - inner_desc.set_shape(inner_desc.shape, new_strides) - else: - assert isinstance(inner_desc, dace.data.Scalar) + assert not inner_desc.transient + + if isinstance(inner_desc, dace.data.Scalar): assert len(new_strides) == 0 + return + + assert isinstance(inner_desc, dace.data.Array) + inner_desc.set_shape(inner_desc.shape, new_strides) + for stride in new_strides: for sym in stride.free_symbols: nsdfg_node.sdfg.add_symbol(str(sym), sym.dtype) nsdfg_node.symbol_mapping |= {str(sym): sym} + for inner_state in nsdfg_node.sdfg.states(): + for inner_node in inner_state.data_nodes(): + if inner_node.data == inner_data: + for inner_edge in inner_state.in_edges(inner_node): + map_strides(inner_edge, inner_node) + + map_strides(map_to_tmp_edge, glob_ac) + # We now remove the `tmp` node, and create a new connection between # the global node and the map exit. new_map_to_glob_edge = graph.add_edge(