From 07e4a251cab2e3be845f838e2eaa2610393c98d1 Mon Sep 17 00:00:00 2001 From: Lukas Truemper Date: Sat, 2 Sep 2023 07:33:15 +0200 Subject: [PATCH] PruneConnectors: Generalize to scoped NestedSDFGs --- .../dataflow/prune_connectors.py | 9 +- .../transformations/prune_connectors_test.py | 84 ++++++++++++++++++- 2 files changed, 90 insertions(+), 3 deletions(-) diff --git a/dace/transformation/dataflow/prune_connectors.py b/dace/transformation/dataflow/prune_connectors.py index ecc89bc753..8e889b81b7 100644 --- a/dace/transformation/dataflow/prune_connectors.py +++ b/dace/transformation/dataflow/prune_connectors.py @@ -46,7 +46,10 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi # Add WCR outputs to "do not prune" input list for e in graph.out_edges(nsdfg): if e.data.wcr is not None and e.src_conn in prune_in: - if (graph.in_degree(next(iter(graph.in_edges_by_connector(nsdfg, e.src_conn))).src) > 0): + # Find access node + in_edge = next(iter(graph.in_edges_by_connector(nsdfg, e.src_conn))) + access_node = graph.memlet_path(in_edge)[0].src + if graph.in_degree(access_node) > 0: prune_in.remove(e.src_conn) has_before = all( graph.in_degree(graph.memlet_path(e)[0].src) > 0 for e in graph.in_edges(nsdfg) if e.dst_conn in prune_in) @@ -73,7 +76,9 @@ def apply(self, state: SDFGState, sdfg: SDFG): # Add WCR outputs to "do not prune" input list for e in state.out_edges(nsdfg): if e.data.wcr is not None and e.src_conn in prune_in: - if (state.in_degree(next(iter(state.in_edges_by_connector(nsdfg, e.src_conn))).src) > 0): + in_edge = next(iter(state.in_edges_by_connector(nsdfg, e.src_conn))) + access_node = state.memlet_path(in_edge)[0].src + if state.in_degree(access_node) > 0: prune_in.remove(e.src_conn) do_not_prune = set() for conn in prune_in: diff --git a/tests/transformations/prune_connectors_test.py b/tests/transformations/prune_connectors_test.py index 1b9ee4369d..b34306a318 100644 --- a/tests/transformations/prune_connectors_test.py +++ b/tests/transformations/prune_connectors_test.py @@ -4,7 +4,10 @@ import os import pytest import dace -from dace.transformation.dataflow import PruneConnectors + +from dace.sdfg.state import StateSubgraphView +from dace.transformation.dataflow import PruneConnectors, AugAssignToWCR +from dace.transformation import helpers def make_sdfg(): @@ -237,6 +240,83 @@ def test_unused_retval_2(): assert np.allclose(a, 1) +def test_prune_connectors_in_scope(): + + @dace.program + def sdfg_prune_connectors_in_scope(A: dace.float64[32], B: dace.float64[32]): + for i in dace.map[0:32]: + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + c << B[i] + b >> A[i] + """ + b = min(c, a); + """ + + sdfg = sdfg_prune_connectors_in_scope.to_sdfg() + sdfg.simplify() + + state = sdfg.start_state + map_entry = None + for node in state.nodes(): + if isinstance(node, dace.nodes.MapEntry): + map_entry = node + break + + map_exit = state.exit_node(map_entry) + subgraph = StateSubgraphView(state, set(state.all_nodes_between(map_entry, map_exit))) + helpers.nest_state_subgraph(sdfg, state, subgraph) + + # WCR Conversion + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + applied = sdfg.apply_transformations_repeated(PruneConnectors) + assert applied == 1 + + +def test_prune_connectors_in_scope_dependency(): + + @dace.program + def sdfg_prune_connectors_in_scope_dependency(A: dace.float64[32], B: dace.float64[32]): + for j in dace.map[0:32]: + with dace.tasklet(language=dace.Language.CPP): + a >> A[j] + """ + a = 0; + """ + for i in dace.map[0:32]: + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + c << B[i] + b >> A[i] + """ + b = min(c, a); + """ + + sdfg = sdfg_prune_connectors_in_scope_dependency.to_sdfg() + sdfg.simplify() + + state = sdfg.start_state + map_entry = None + for node in state.nodes(): + if isinstance(node, dace.nodes.MapEntry) and str(node.map.params[0]) == "i": + map_entry = node + break + + # Create NestedSDFG inside map + map_exit = state.exit_node(map_entry) + subgraph = StateSubgraphView(state, set(state.all_nodes_between(map_entry, map_exit))) + helpers.nest_state_subgraph(sdfg, state, subgraph) + + # WCR Conversion: A is no input of NestedSDFG + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + applied = sdfg.apply_transformations_repeated(PruneConnectors) + assert applied == 0 + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--N", default=64) @@ -248,3 +328,5 @@ def test_unused_retval_2(): test_prune_connectors(True, n=n) test_unused_retval() test_unused_retval_2() + test_prune_connectors_in_scope() + test_prune_connectors_in_scope_dependency()