From f66ef87478492109b06a217f8b2aa3644abfc810 Mon Sep 17 00:00:00 2001 From: Lukas Truemper Date: Mon, 30 Oct 2023 21:37:23 +0100 Subject: [PATCH] WCRToAugAssign: Use state fission instead of nested SDFG to prevent race conditions --- .../transformation/dataflow/wcr_conversion.py | 127 +++--------------- tests/transformations/wcr_conversion_test.py | 18 +-- 2 files changed, 25 insertions(+), 120 deletions(-) diff --git a/dace/transformation/dataflow/wcr_conversion.py b/dace/transformation/dataflow/wcr_conversion.py index f4151191cb..0dcb6e2fc5 100644 --- a/dace/transformation/dataflow/wcr_conversion.py +++ b/dace/transformation/dataflow/wcr_conversion.py @@ -8,7 +8,8 @@ from dace.sdfg import graph as gr, utils as sdutil from dace import SDFG, SDFGState from dace.sdfg.state import StateSubgraphView - +from dace.transformation import helpers +from dace import propagate_memlet class AugAssignToWCR(transformation.SingleStateTransformation): """ @@ -42,7 +43,6 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Free tasklet if expr_index == 0: - # Only free tasklets supported for now if graph.entry_node(tasklet) is not None: return False @@ -53,8 +53,6 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Make sure augmented assignment can be fissioned as necessary if any(not isinstance(e.src, nodes.AccessNode) for e in graph.in_edges(tasklet)): return False - if graph.in_degree(inarr) > 0 and graph.out_degree(outarr) > 0: - return False outedge = graph.edges_between(tasklet, outarr)[0] else: # Free map @@ -69,7 +67,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): if len(graph.edges_between(tasklet, mx)) > 1: return False - # Currently no fission is supported + # Make sure augmented assignment can be fissioned as necessary if any(e.src is not me and not isinstance(e.src, nodes.AccessNode) for e in graph.in_edges(me) + graph.in_edges(tasklet)): return False @@ -150,50 +148,24 @@ def apply(self, state: SDFGState, sdfg: SDFG): input: nodes.AccessNode = self.input tasklet: nodes.Tasklet = self.tasklet output: nodes.AccessNode = self.output + if self.expr_index == 1: + me = self.map_entry + mx = self.map_exit # If state fission is necessary to keep semantics, do it first - if (self.expr_index == 0 and state.in_degree(input) > 0 and state.out_degree(output) == 0): - newstate = sdfg.add_state_after(state) - newstate.add_node(tasklet) - new_input, new_output = None, None - - # Keep old edges for after we remove tasklet from the original state - in_edges = list(state.in_edges(tasklet)) - out_edges = list(state.out_edges(tasklet)) - - for e in in_edges: - r = newstate.add_read(e.src.data) - newstate.add_edge(r, e.src_conn, e.dst, e.dst_conn, e.data) - if e.src is input: - new_input = r - for e in out_edges: - w = newstate.add_write(e.dst.data) - newstate.add_edge(e.src, e.src_conn, w, e.dst_conn, e.data) - if e.dst is output: - new_output = w - - # Remove tasklet and resulting isolated nodes - state.remove_node(tasklet) - for e in in_edges: - if state.degree(e.src) == 0: - state.remove_node(e.src) - for e in out_edges: - if state.degree(e.dst) == 0: - state.remove_node(e.dst) - - # Reset state and nodes for rest of transformation - input = new_input - output = new_output - state = newstate - # End of state fission + if state.in_degree(input) > 0: + subgraph_nodes = set( + [e.src for e in state.bfs_edges(input, reverse=True)] + ) + subgraph_nodes.add(input) + + subgraph = StateSubgraphView(state, subgraph_nodes) + helpers.state_fission(sdfg, subgraph) if self.expr_index == 0: inedges = state.edges_between(input, tasklet) outedge = state.edges_between(tasklet, output)[0] else: - me = self.map_entry - mx = self.map_exit - inedges = state.edges_between(me, tasklet) outedge = state.edges_between(tasklet, mx)[0] @@ -283,75 +255,8 @@ def apply(self, state: SDFGState, sdfg: SDFG): else: outedge.data.wcr = f'lambda a,b: a {op} b' - if self.expr_index == 0: - # Remove input node and connector - state.remove_edge_and_connectors(inedge) - if state.degree(input) == 0: - state.remove_node(input) - else: - # Put into NestedSDFG to retain input dependencies - map_entry = self.map_entry - map_exit = self.map_exit - subgraph_nodes = set(state.all_nodes_between(map_entry, map_exit)) - subgraph_nodes.add(map_entry) - subgraph_nodes.add(map_exit) - - in_access_nodes = set() - out_access_nodes = set() - for edge in state.in_edges(map_entry): - subgraph_nodes.add(edge.src) - in_access_nodes.add(edge.src) - for edge in state.out_edges(map_exit): - subgraph_nodes.add(edge.dst) - out_access_nodes.add(edge.dst) - - subgraph = StateSubgraphView(state, subgraph_nodes) - - # Add subgraph as nested SDFG - nested_sdfg = SDFG("nested_" + map_entry.label) - inputs = set() - for data in in_access_nodes: - inputs.add(data.data) - nested_sdfg.arrays[data.data] = copy.deepcopy(sdfg.arrays[data.data]) - outputs = set() - for data in out_access_nodes: - outputs.add(data.data) - nested_sdfg.arrays[data.data] = copy.deepcopy(sdfg.arrays[data.data]) - - nested_state = nested_sdfg.add_state("nested_" + map_entry.label + "_state", is_start_state=True) - node_map = {} - new_inedge = None - for node in subgraph.nodes(): - new_node = copy.deepcopy(node) - nested_state.add_node(new_node) - node_map[node] = new_node - for edge in subgraph.edges(): - new_edge = nested_state.add_edge(node_map[edge.src], edge.src_conn, node_map[edge.dst], edge.dst_conn, - edge.data) - if edge == inedge: - new_inedge = new_edge - - nested_sdfg_node = state.add_nested_sdfg(nested_sdfg, sdfg, inputs=inputs, outputs=outputs) - for in_access in in_access_nodes: - nested_sdfg_node.add_in_connector(in_access.data) - state.add_edge(in_access, None, nested_sdfg_node, in_access.data, - Memlet.from_array(in_access.data, sdfg.arrays[in_access.data])) - - for out_access in out_access_nodes: - nested_sdfg_node.add_in_connector(out_access.data) - state.add_edge(nested_sdfg_node, out_access.data, out_access, None, - Memlet.from_array(out_access.data, sdfg.arrays[out_access.data])) - - # Remove subgraph from state - for edge in subgraph.edges(): - state.remove_edge(edge) - for node in subgraph.nodes(): - if node in in_access_nodes or node in out_access_nodes: - continue - - state.remove_node(node) - - nested_state.remove_memlet_path(new_inedge) + # Remove input node and connector + state.remove_memlet_path(inedge) # If outedge leads to non-transient, and this is a nested SDFG, # propagate outwards diff --git a/tests/transformations/wcr_conversion_test.py b/tests/transformations/wcr_conversion_test.py index b75ee86717..091b2a9db8 100644 --- a/tests/transformations/wcr_conversion_test.py +++ b/tests/transformations/wcr_conversion_test.py @@ -195,10 +195,10 @@ def sdfg_aug_assign_free_map(A: dace.float64[32], B: dace.float64[32]): assert applied == 1 -def test_aug_assign_dependent_map(): +def test_aug_assign_state_fission_map(): @dace.program - def sdfg_aug_assign_dependent_map(A: dace.float64[32], B: dace.float64[32]): + def sdfg_aug_assign_state_fission(A: dace.float64[32], B: dace.float64[32]): for i in dace.map[0:32]: with dace.tasklet: a << B[i] @@ -207,21 +207,21 @@ def sdfg_aug_assign_dependent_map(A: dace.float64[32], B: dace.float64[32]): for i in dace.map[0:32]: with dace.tasklet: - a << A[i] - b >> A[i] + a << A[0] + b >> A[0] b = a * 2 for i in dace.map[0:32]: with dace.tasklet: - a << A[i] - b >> B[i] - b = a + a << A[0] + b >> A[0] + b = a * 2 - sdfg = sdfg_aug_assign_dependent_map.to_sdfg() + sdfg = sdfg_aug_assign_state_fission.to_sdfg() sdfg.simplify() applied = sdfg.apply_transformations_repeated(AugAssignToWCR) - assert applied == 1 + assert applied == 2 def test_free_map_permissive():