Skip to content

Commit

Permalink
WCRToAugAssign: Use state fission instead of nested SDFG to prevent r…
Browse files Browse the repository at this point in the history
…ace conditions
  • Loading branch information
lukastruemper committed Oct 30, 2023
1 parent 736caca commit f66ef87
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 120 deletions.
127 changes: 16 additions & 111 deletions dace/transformation/dataflow/wcr_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from dace.sdfg import graph as gr, utils as sdutil
from dace import SDFG, SDFGState
from dace.sdfg.state import StateSubgraphView

from dace.transformation import helpers
from dace import propagate_memlet

class AugAssignToWCR(transformation.SingleStateTransformation):
"""
Expand Down Expand Up @@ -42,7 +43,6 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):

# Free tasklet
if expr_index == 0:
# Only free tasklets supported for now
if graph.entry_node(tasklet) is not None:
return False

Expand All @@ -53,8 +53,6 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
# Make sure augmented assignment can be fissioned as necessary
if any(not isinstance(e.src, nodes.AccessNode) for e in graph.in_edges(tasklet)):
return False
if graph.in_degree(inarr) > 0 and graph.out_degree(outarr) > 0:
return False

outedge = graph.edges_between(tasklet, outarr)[0]
else: # Free map
Expand All @@ -69,7 +67,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
if len(graph.edges_between(tasklet, mx)) > 1:
return False

# Currently no fission is supported
# Make sure augmented assignment can be fissioned as necessary
if any(e.src is not me and not isinstance(e.src, nodes.AccessNode)
for e in graph.in_edges(me) + graph.in_edges(tasklet)):
return False
Expand Down Expand Up @@ -150,50 +148,24 @@ def apply(self, state: SDFGState, sdfg: SDFG):
input: nodes.AccessNode = self.input
tasklet: nodes.Tasklet = self.tasklet
output: nodes.AccessNode = self.output
if self.expr_index == 1:
me = self.map_entry
mx = self.map_exit

# If state fission is necessary to keep semantics, do it first
if (self.expr_index == 0 and state.in_degree(input) > 0 and state.out_degree(output) == 0):
newstate = sdfg.add_state_after(state)
newstate.add_node(tasklet)
new_input, new_output = None, None

# Keep old edges for after we remove tasklet from the original state
in_edges = list(state.in_edges(tasklet))
out_edges = list(state.out_edges(tasklet))

for e in in_edges:
r = newstate.add_read(e.src.data)
newstate.add_edge(r, e.src_conn, e.dst, e.dst_conn, e.data)
if e.src is input:
new_input = r
for e in out_edges:
w = newstate.add_write(e.dst.data)
newstate.add_edge(e.src, e.src_conn, w, e.dst_conn, e.data)
if e.dst is output:
new_output = w

# Remove tasklet and resulting isolated nodes
state.remove_node(tasklet)
for e in in_edges:
if state.degree(e.src) == 0:
state.remove_node(e.src)
for e in out_edges:
if state.degree(e.dst) == 0:
state.remove_node(e.dst)

# Reset state and nodes for rest of transformation
input = new_input
output = new_output
state = newstate
# End of state fission
if state.in_degree(input) > 0:
subgraph_nodes = set(
[e.src for e in state.bfs_edges(input, reverse=True)]
)
subgraph_nodes.add(input)

subgraph = StateSubgraphView(state, subgraph_nodes)
helpers.state_fission(sdfg, subgraph)

if self.expr_index == 0:
inedges = state.edges_between(input, tasklet)
outedge = state.edges_between(tasklet, output)[0]
else:
me = self.map_entry
mx = self.map_exit

inedges = state.edges_between(me, tasklet)
outedge = state.edges_between(tasklet, mx)[0]

Expand Down Expand Up @@ -283,75 +255,8 @@ def apply(self, state: SDFGState, sdfg: SDFG):
else:
outedge.data.wcr = f'lambda a,b: a {op} b'

if self.expr_index == 0:
# Remove input node and connector
state.remove_edge_and_connectors(inedge)
if state.degree(input) == 0:
state.remove_node(input)
else:
# Put into NestedSDFG to retain input dependencies
map_entry = self.map_entry
map_exit = self.map_exit
subgraph_nodes = set(state.all_nodes_between(map_entry, map_exit))
subgraph_nodes.add(map_entry)
subgraph_nodes.add(map_exit)

in_access_nodes = set()
out_access_nodes = set()
for edge in state.in_edges(map_entry):
subgraph_nodes.add(edge.src)
in_access_nodes.add(edge.src)
for edge in state.out_edges(map_exit):
subgraph_nodes.add(edge.dst)
out_access_nodes.add(edge.dst)

subgraph = StateSubgraphView(state, subgraph_nodes)

# Add subgraph as nested SDFG
nested_sdfg = SDFG("nested_" + map_entry.label)
inputs = set()
for data in in_access_nodes:
inputs.add(data.data)
nested_sdfg.arrays[data.data] = copy.deepcopy(sdfg.arrays[data.data])
outputs = set()
for data in out_access_nodes:
outputs.add(data.data)
nested_sdfg.arrays[data.data] = copy.deepcopy(sdfg.arrays[data.data])

nested_state = nested_sdfg.add_state("nested_" + map_entry.label + "_state", is_start_state=True)
node_map = {}
new_inedge = None
for node in subgraph.nodes():
new_node = copy.deepcopy(node)
nested_state.add_node(new_node)
node_map[node] = new_node
for edge in subgraph.edges():
new_edge = nested_state.add_edge(node_map[edge.src], edge.src_conn, node_map[edge.dst], edge.dst_conn,
edge.data)
if edge == inedge:
new_inedge = new_edge

nested_sdfg_node = state.add_nested_sdfg(nested_sdfg, sdfg, inputs=inputs, outputs=outputs)
for in_access in in_access_nodes:
nested_sdfg_node.add_in_connector(in_access.data)
state.add_edge(in_access, None, nested_sdfg_node, in_access.data,
Memlet.from_array(in_access.data, sdfg.arrays[in_access.data]))

for out_access in out_access_nodes:
nested_sdfg_node.add_in_connector(out_access.data)
state.add_edge(nested_sdfg_node, out_access.data, out_access, None,
Memlet.from_array(out_access.data, sdfg.arrays[out_access.data]))

# Remove subgraph from state
for edge in subgraph.edges():
state.remove_edge(edge)
for node in subgraph.nodes():
if node in in_access_nodes or node in out_access_nodes:
continue

state.remove_node(node)

nested_state.remove_memlet_path(new_inedge)
# Remove input node and connector
state.remove_memlet_path(inedge)

# If outedge leads to non-transient, and this is a nested SDFG,
# propagate outwards
Expand Down
18 changes: 9 additions & 9 deletions tests/transformations/wcr_conversion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,10 +195,10 @@ def sdfg_aug_assign_free_map(A: dace.float64[32], B: dace.float64[32]):
assert applied == 1


def test_aug_assign_dependent_map():
def test_aug_assign_state_fission_map():

@dace.program
def sdfg_aug_assign_dependent_map(A: dace.float64[32], B: dace.float64[32]):
def sdfg_aug_assign_state_fission(A: dace.float64[32], B: dace.float64[32]):
for i in dace.map[0:32]:
with dace.tasklet:
a << B[i]
Expand All @@ -207,21 +207,21 @@ def sdfg_aug_assign_dependent_map(A: dace.float64[32], B: dace.float64[32]):

for i in dace.map[0:32]:
with dace.tasklet:
a << A[i]
b >> A[i]
a << A[0]
b >> A[0]
b = a * 2

for i in dace.map[0:32]:
with dace.tasklet:
a << A[i]
b >> B[i]
b = a
a << A[0]
b >> A[0]
b = a * 2

sdfg = sdfg_aug_assign_dependent_map.to_sdfg()
sdfg = sdfg_aug_assign_state_fission.to_sdfg()
sdfg.simplify()

applied = sdfg.apply_transformations_repeated(AugAssignToWCR)
assert applied == 1
assert applied == 2


def test_free_map_permissive():
Expand Down

0 comments on commit f66ef87

Please sign in to comment.