Skip to content

Commit

Permalink
PruneConnectors: Generalize to scoped NestedSDFGs
Browse files Browse the repository at this point in the history
  • Loading branch information
lukastruemper committed Sep 2, 2023
1 parent 64df91f commit 07e4a25
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 3 deletions.
9 changes: 7 additions & 2 deletions dace/transformation/dataflow/prune_connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi
# Add WCR outputs to "do not prune" input list
for e in graph.out_edges(nsdfg):
if e.data.wcr is not None and e.src_conn in prune_in:
if (graph.in_degree(next(iter(graph.in_edges_by_connector(nsdfg, e.src_conn))).src) > 0):
# Find access node
in_edge = next(iter(graph.in_edges_by_connector(nsdfg, e.src_conn)))
access_node = graph.memlet_path(in_edge)[0].src
if graph.in_degree(access_node) > 0:
prune_in.remove(e.src_conn)
has_before = all(
graph.in_degree(graph.memlet_path(e)[0].src) > 0 for e in graph.in_edges(nsdfg) if e.dst_conn in prune_in)
Expand All @@ -73,7 +76,9 @@ def apply(self, state: SDFGState, sdfg: SDFG):
# Add WCR outputs to "do not prune" input list
for e in state.out_edges(nsdfg):
if e.data.wcr is not None and e.src_conn in prune_in:
if (state.in_degree(next(iter(state.in_edges_by_connector(nsdfg, e.src_conn))).src) > 0):
in_edge = next(iter(state.in_edges_by_connector(nsdfg, e.src_conn)))
access_node = state.memlet_path(in_edge)[0].src
if state.in_degree(access_node) > 0:
prune_in.remove(e.src_conn)
do_not_prune = set()
for conn in prune_in:
Expand Down
84 changes: 83 additions & 1 deletion tests/transformations/prune_connectors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import os
import pytest
import dace
from dace.transformation.dataflow import PruneConnectors

from dace.sdfg.state import StateSubgraphView
from dace.transformation.dataflow import PruneConnectors, AugAssignToWCR
from dace.transformation import helpers


def make_sdfg():
Expand Down Expand Up @@ -237,6 +240,83 @@ def test_unused_retval_2():
assert np.allclose(a, 1)


def test_prune_connectors_in_scope():

@dace.program
def sdfg_prune_connectors_in_scope(A: dace.float64[32], B: dace.float64[32]):
for i in dace.map[0:32]:
with dace.tasklet(language=dace.Language.CPP):
a << A[i]
c << B[i]
b >> A[i]
"""
b = min(c, a);
"""

sdfg = sdfg_prune_connectors_in_scope.to_sdfg()
sdfg.simplify()

state = sdfg.start_state
map_entry = None
for node in state.nodes():
if isinstance(node, dace.nodes.MapEntry):
map_entry = node
break

map_exit = state.exit_node(map_entry)
subgraph = StateSubgraphView(state, set(state.all_nodes_between(map_entry, map_exit)))
helpers.nest_state_subgraph(sdfg, state, subgraph)

# WCR Conversion
applied = sdfg.apply_transformations_repeated(AugAssignToWCR)
assert applied == 1

applied = sdfg.apply_transformations_repeated(PruneConnectors)
assert applied == 1


def test_prune_connectors_in_scope_dependency():

@dace.program
def sdfg_prune_connectors_in_scope_dependency(A: dace.float64[32], B: dace.float64[32]):
for j in dace.map[0:32]:
with dace.tasklet(language=dace.Language.CPP):
a >> A[j]
"""
a = 0;
"""
for i in dace.map[0:32]:
with dace.tasklet(language=dace.Language.CPP):
a << A[i]
c << B[i]
b >> A[i]
"""
b = min(c, a);
"""

sdfg = sdfg_prune_connectors_in_scope_dependency.to_sdfg()
sdfg.simplify()

state = sdfg.start_state
map_entry = None
for node in state.nodes():
if isinstance(node, dace.nodes.MapEntry) and str(node.map.params[0]) == "i":
map_entry = node
break

# Create NestedSDFG inside map
map_exit = state.exit_node(map_entry)
subgraph = StateSubgraphView(state, set(state.all_nodes_between(map_entry, map_exit)))
helpers.nest_state_subgraph(sdfg, state, subgraph)

# WCR Conversion: A is no input of NestedSDFG
applied = sdfg.apply_transformations_repeated(AugAssignToWCR)
assert applied == 1

applied = sdfg.apply_transformations_repeated(PruneConnectors)
assert applied == 0


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--N", default=64)
Expand All @@ -248,3 +328,5 @@ def test_unused_retval_2():
test_prune_connectors(True, n=n)
test_unused_retval()
test_unused_retval_2()
test_prune_connectors_in_scope()
test_prune_connectors_in_scope_dependency()

0 comments on commit 07e4a25

Please sign in to comment.