diff --git a/dace/cli/sdfv.py b/dace/cli/sdfv.py index c0ff3da36d..f503775814 100644 --- a/dace/cli/sdfv.py +++ b/dace/cli/sdfv.py @@ -41,9 +41,10 @@ def view(sdfg: dace.SDFG, filename: Optional[Union[str, int]] = None): or 'VSCODE_IPC_HOOK_CLI' in os.environ or 'VSCODE_GIT_IPC_HANDLE' in os.environ ): - filename = tempfile.mktemp(suffix='.sdfg') + fd, filename = tempfile.mkstemp(suffix='.sdfg') sdfg.save(filename) os.system(f'code {filename}') + os.close(fd) return if type(sdfg) is dace.SDFG: diff --git a/dace/transformation/change_strides.py b/dace/transformation/change_strides.py new file mode 100644 index 0000000000..001cd4aa63 --- /dev/null +++ b/dace/transformation/change_strides.py @@ -0,0 +1,210 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" This module provides a function to change the stride in a given SDFG """ +from typing import List, Union, Tuple +import sympy + +import dace +from dace.dtypes import ScheduleType +from dace.sdfg import SDFG, nodes, SDFGState +from dace.data import Array, Scalar +from dace.memlet import Memlet + + +def list_access_nodes( + sdfg: dace.SDFG, + array_name: str) -> List[Tuple[nodes.AccessNode, Union[SDFGState, dace.SDFG]]]: + """ + Find all access nodes in the SDFG of the given array name. Does not recourse into nested SDFGs. + + :param sdfg: The SDFG to search through + :type sdfg: dace.SDFG + :param array_name: The name of the wanted array + :type array_name: str + :return: List of the found access nodes together with their state + :rtype: List[Tuple[nodes.AccessNode, Union[dace.SDFGState, dace.SDFG]]] + """ + found_nodes = [] + for state in sdfg.states(): + for node in state.nodes(): + if isinstance(node, nodes.AccessNode) and node.data == array_name: + found_nodes.append((node, state)) + return found_nodes + + +def change_strides( + sdfg: dace.SDFG, + stride_one_values: List[str], + schedule: ScheduleType) -> SDFG: + """ + Change the strides of the arrays on the given SDFG such that the given dimension has stride 1. Returns a new SDFG. + + :param sdfg: The input SDFG + :type sdfg: dace.SDFG + :param stride_one_values: Length of the dimension whose stride should be set to one. Expects that each array has + only one dimension whose length is in this list. Expects that list contains name of symbols + :type stride_one_values: List[str] + :param schedule: Schedule to use to copy the arrays + :type schedule: ScheduleType + :return: SDFG with changed strides + :rtype: SDFG + """ + # Create new SDFG and copy constants and symbols + original_name = sdfg.name + sdfg.name = "changed_strides" + new_sdfg = SDFG(original_name) + for dname, value in sdfg.constants.items(): + new_sdfg.add_constant(dname, value) + for dname, stype in sdfg.symbols.items(): + new_sdfg.add_symbol(dname, stype) + + changed_stride_state = new_sdfg.add_state("with_changed_strides", is_start_state=True) + inputs, outputs = sdfg.read_and_write_sets() + # Get all arrays which are persistent == not transient + persistent_arrays = {name: desc for name, desc in sdfg.arrays.items() if not desc.transient} + + # Get the persistent arrays of all the transient arrays which get copied to GPU + for dname in persistent_arrays: + for access, state in list_access_nodes(sdfg, dname): + if len(state.out_edges(access)) == 1: + edge = state.out_edges(access)[0] + if isinstance(edge.dst, nodes.AccessNode): + if edge.dst.data in inputs: + inputs.remove(edge.dst.data) + inputs.add(dname) + if len(state.in_edges(access)) == 1: + edge = state.in_edges(access)[0] + if isinstance(edge.src, nodes.AccessNode): + if edge.src.data in inputs: + outputs.remove(edge.src.data) + outputs.add(dname) + + # Only keep inputs and outputs which are persistent + inputs.intersection_update(persistent_arrays.keys()) + outputs.intersection_update(persistent_arrays.keys()) + nsdfg = changed_stride_state.add_nested_sdfg(sdfg, new_sdfg, inputs=inputs, outputs=outputs) + transform_state = new_sdfg.add_state_before(changed_stride_state, label="transform_data", is_start_state=True) + transform_state_back = new_sdfg.add_state_after(changed_stride_state, "transform_data_back", is_start_state=False) + + # copy arrays + for dname, desc in sdfg.arrays.items(): + if not desc.transient: + if isinstance(desc, Array): + new_sdfg.add_array(dname, desc.shape, desc.dtype, desc.storage, + desc.location, desc.transient, desc.strides, + desc.offset) + elif isinstance(desc, Scalar): + new_sdfg.add_scalar(dname, desc.dtype, desc.storage, desc.transient, desc.lifetime, desc.debuginfo) + + new_order = {} + new_strides_map = {} + + # Map of array names in the nested sdfg: key: array name in parent sdfg (this sdfg), value: name in the nsdfg + # Assumes that name changes only appear in the first level of nsdfg nesting + array_names_map = {} + for graph in sdfg.sdfg_list: + if graph.parent_nsdfg_node is not None: + if graph.parent_sdfg == sdfg: + for connector in graph.parent_nsdfg_node.in_connectors: + for in_edge in graph.parent.in_edges_by_connector(graph.parent_nsdfg_node, connector): + array_names_map[str(connector)] = in_edge.data.data + + for containing_sdfg, dname, desc in sdfg.arrays_recursive(): + shape_str = [str(s) for s in desc.shape] + # Get index of the dimension we want to have stride 1 + stride_one_idx = None + this_stride_one_value = None + for dim in stride_one_values: + if str(dim) in shape_str: + stride_one_idx = shape_str.index(str(dim)) + this_stride_one_value = dim + break + + if stride_one_idx is not None: + new_order[dname] = [stride_one_idx] + + new_strides = list(desc.strides) + new_strides[stride_one_idx] = sympy.S.One + + previous_size = dace.symbolic.symbol(this_stride_one_value) + previous_stride = sympy.S.One + for i in range(len(new_strides)): + if i != stride_one_idx: + new_order[dname].append(i) + new_strides[i] = previous_size * previous_stride + previous_size = desc.shape[i] + previous_stride = new_strides[i] + + new_strides_map[dname] = {} + # Create a map entry for this data linking old strides to new strides. This assumes that each entry in + # strides is unique which is given as otherwise there would be two dimension i, j where a[i, j] would point + # to the same address as a[j, i] + for new_stride, old_stride in zip(new_strides, desc.strides): + new_strides_map[dname][old_stride] = new_stride + desc.strides = tuple(new_strides) + else: + parent_name = array_names_map[dname] if dname in array_names_map else dname + if parent_name in new_strides_map: + new_strides = [] + for stride in desc.strides: + new_strides.append(new_strides_map[parent_name][stride]) + desc.strides = new_strides + + # Add new flipped arrays for every non-transient array + flipped_names_map = {} + for dname, desc in sdfg.arrays.items(): + if not desc.transient: + flipped_name = f"{dname}_flipped" + flipped_names_map[dname] = flipped_name + new_sdfg.add_array(flipped_name, desc.shape, desc.dtype, + desc.storage, desc.location, True, + desc.strides, desc.offset) + + # Deal with the inputs: Create tasklet to flip them and connect via memlets + # for input in inputs: + for input in set([*inputs, *outputs]): + if input in new_order: + flipped_data = flipped_names_map[input] + if input in inputs: + changed_stride_state.add_memlet_path(changed_stride_state.add_access(flipped_data), nsdfg, + dst_conn=input, memlet=Memlet(data=flipped_data)) + # Simply need to copy the data, the different strides take care of the transposing + arr = sdfg.arrays[input] + tasklet, map_entry, map_exit = transform_state.add_mapped_tasklet( + name=f"transpose_{input}", + map_ranges={f"_i{i}": f"0:{s}" for i, s in enumerate(arr.shape)}, + inputs={'_in': Memlet(data=input, subset=", ".join(f"_i{i}" for i, _ in enumerate(arr.shape)))}, + code='_out = _in', + outputs={'_out': Memlet(data=flipped_data, + subset=", ".join(f"_i{i}" for i, _ in enumerate(arr.shape)))}, + external_edges=True, + schedule=schedule, + ) + # Do the same for the outputs + for output in outputs: + if output in new_order: + flipped_data = flipped_names_map[output] + changed_stride_state.add_memlet_path(nsdfg, changed_stride_state.add_access(flipped_data), + src_conn=output, memlet=Memlet(data=flipped_data)) + # Simply need to copy the data, the different strides take care of the transposing + arr = sdfg.arrays[output] + tasklet, map_entry, map_exit = transform_state_back.add_mapped_tasklet( + name=f"transpose_{output}", + map_ranges={f"_i{i}": f"0:{s}" for i, s in enumerate(arr.shape)}, + inputs={'_in': Memlet(data=flipped_data, + subset=", ".join(f"_i{i}" for i, _ in enumerate(arr.shape)))}, + code='_out = _in', + outputs={'_out': Memlet(data=output, subset=", ".join(f"_i{i}" for i, _ in enumerate(arr.shape)))}, + external_edges=True, + schedule=schedule, + ) + # Deal with any arrays which have not been flipped (should only be scalars). Connect them directly + for dname, desc in sdfg.arrays.items(): + if not desc.transient and dname not in new_order: + if dname in inputs: + changed_stride_state.add_memlet_path(changed_stride_state.add_access(dname), nsdfg, dst_conn=dname, + memlet=Memlet(data=dname)) + if dname in outputs: + changed_stride_state.add_memlet_path(nsdfg, changed_stride_state.add_access(dname), src_conn=dname, + memlet=Memlet(data=dname)) + + return new_sdfg diff --git a/dace/transformation/dataflow/map_expansion.py b/dace/transformation/dataflow/map_expansion.py index 275b99c7e8..60f1f13f32 100644 --- a/dace/transformation/dataflow/map_expansion.py +++ b/dace/transformation/dataflow/map_expansion.py @@ -1,16 +1,18 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ Contains classes that implement the map-expansion transformation. """ from dace.sdfg.utils import consolidate_edges from typing import Dict, List import dace from dace import dtypes, subsets, symbolic +from dace.properties import EnumProperty, make_properties from dace.sdfg import nodes from dace.sdfg import utils as sdutil from dace.sdfg.graph import OrderedMultiDiConnectorGraph from dace.transformation import transformation as pm +@make_properties class MapExpansion(pm.SingleStateTransformation): """ Implements the map-expansion pattern. @@ -25,14 +27,16 @@ class MapExpansion(pm.SingleStateTransformation): map_entry = pm.PatternNode(nodes.MapEntry) + inner_schedule = EnumProperty(desc="Schedule for inner maps", + dtype=dtypes.ScheduleType, + default=dtypes.ScheduleType.Sequential, + allow_none=True) + @classmethod def expressions(cls): return [sdutil.node_path_graph(cls.map_entry)] - def can_be_applied(self, graph: dace.SDFGState, - expr_index: int, - sdfg: dace.SDFG, - permissive: bool = False): + def can_be_applied(self, graph: dace.SDFGState, expr_index: int, sdfg: dace.SDFG, permissive: bool = False): # A candidate subgraph matches the map-expansion pattern when it # includes an N-dimensional map, with N greater than one. return self.map_entry.map.get_param_num() > 1 @@ -44,10 +48,11 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG): current_map = map_entry.map # Create new maps + inner_schedule = self.inner_schedule or current_map.schedule new_maps = [ nodes.Map(current_map.label + '_' + str(param), [param], subsets.Range([param_range]), - schedule=dtypes.ScheduleType.Sequential) + schedule=inner_schedule) for param, param_range in zip(current_map.params[1:], current_map.range[1:]) ] current_map.params = [current_map.params[0]] diff --git a/dace/transformation/dataflow/trivial_map_elimination.py b/dace/transformation/dataflow/trivial_map_elimination.py index 327d5d8c9a..9387cfce23 100644 --- a/dace/transformation/dataflow/trivial_map_elimination.py +++ b/dace/transformation/dataflow/trivial_map_elimination.py @@ -5,6 +5,7 @@ from dace.sdfg import utils as sdutil from dace.transformation import transformation from dace.properties import make_properties +from dace.memlet import Memlet @make_properties @@ -48,12 +49,15 @@ def apply(self, graph, sdfg): if len(remaining_ranges) == 0: # Redirect map entry's out edges + write_only_map = True for edge in graph.out_edges(map_entry): path = graph.memlet_path(edge) index = path.index(edge) - # Add an edge directly from the previous source connector to the destination - graph.add_edge(path[index - 1].src, path[index - 1].src_conn, edge.dst, edge.dst_conn, edge.data) + if not edge.data.is_empty(): + # Add an edge directly from the previous source connector to the destination + graph.add_edge(path[index - 1].src, path[index - 1].src_conn, edge.dst, edge.dst_conn, edge.data) + write_only_map = False # Redirect map exit's in edges. for edge in graph.in_edges(map_exit): @@ -63,6 +67,11 @@ def apply(self, graph, sdfg): # Add an edge directly from the source to the next destination connector if len(path) > index + 1: graph.add_edge(edge.src, edge.src_conn, path[index + 1].dst, path[index + 1].dst_conn, edge.data) + if write_only_map: + outer_exit = path[index+1].dst + outer_entry = graph.entry_node(outer_exit) + if outer_entry is not None: + graph.add_edge(outer_entry, None, edge.src, None, Memlet()) # Remove map graph.remove_nodes_from([map_entry, map_exit]) diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index 8986c4e37f..9c41e4dec4 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -1137,7 +1137,8 @@ def traverse(state: SDFGState, treenode: ScopeTree): ntree.state = nstate treenode.children.append(ntree) for child in treenode.children: - traverse(getattr(child, 'state', state), child) + if hasattr(child, 'state') and child.state != state: + traverse(getattr(child, 'state', state), child) traverse(state, stree) return stree diff --git a/dace/transformation/interstate/__init__.py b/dace/transformation/interstate/__init__.py index 0bd168751c..b8bcc716e6 100644 --- a/dace/transformation/interstate/__init__.py +++ b/dace/transformation/interstate/__init__.py @@ -15,3 +15,4 @@ from .move_loop_into_map import MoveLoopIntoMap from .trivial_loop_elimination import TrivialLoopElimination from .multistate_inline import InlineMultistateSDFG +from .move_assignment_outside_if import MoveAssignmentOutsideIf diff --git a/dace/transformation/interstate/move_assignment_outside_if.py b/dace/transformation/interstate/move_assignment_outside_if.py new file mode 100644 index 0000000000..3d4db9ae25 --- /dev/null +++ b/dace/transformation/interstate/move_assignment_outside_if.py @@ -0,0 +1,113 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" +Transformation to move assignments outside if statements to potentially avoid warp divergence. Speedup gained is +questionable. +""" + +import ast +import sympy as sp + +from dace import sdfg as sd +from dace.sdfg import graph as gr +from dace.sdfg.nodes import Tasklet, AccessNode +from dace.transformation import transformation + + +class MoveAssignmentOutsideIf(transformation.MultiStateTransformation): + + if_guard = transformation.PatternNode(sd.SDFGState) + if_stmt = transformation.PatternNode(sd.SDFGState) + else_stmt = transformation.PatternNode(sd.SDFGState) + + @classmethod + def expressions(cls): + sdfg = gr.OrderedDiGraph() + sdfg.add_nodes_from([cls.if_guard, cls.if_stmt, cls.else_stmt]) + sdfg.add_edge(cls.if_guard, cls.if_stmt, sd.InterstateEdge()) + sdfg.add_edge(cls.if_guard, cls.else_stmt, sd.InterstateEdge()) + return [sdfg] + + def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + # The if-guard can only have two outgoing edges: to the if and to the else part + guard_outedges = graph.out_edges(self.if_guard) + if len(guard_outedges) != 2: + return False + + # Outgoing edges must be a negation of each other + if guard_outedges[0].data.condition_sympy() != (sp.Not(guard_outedges[1].data.condition_sympy())): + return False + + # The if guard should either have zero or one incoming edge + if len(sdfg.in_edges(self.if_guard)) > 1: + return False + + # set of the variables which get a const value assigned + assigned_const = set() + # Dict which collects all AccessNodes for each variable together with its state + access_nodes = {} + # set of the variables which are only written to + self.write_only_values = set() + # Dictionary which stores additional information for the variables which are written only + self.assign_context = {} + for state in [self.if_stmt, self.else_stmt]: + for node in state.nodes(): + if isinstance(node, Tasklet): + # If node is a tasklet, check if assigns a constant value + assigns_const = True + for code_stmt in node.code.code: + if not (isinstance(code_stmt, ast.Assign) and isinstance(code_stmt.value, ast.Constant)): + assigns_const = False + if assigns_const: + for edge in state.out_edges(node): + if isinstance(edge.dst, AccessNode): + assigned_const.add(edge.dst.data) + self.assign_context[edge.dst.data] = {"state": state, "tasklet": node} + elif isinstance(node, AccessNode): + if node.data not in access_nodes: + access_nodes[node.data] = [] + access_nodes[node.data].append((node, state)) + + # check that the found access nodes only get written to + for data, nodes in access_nodes.items(): + write_only = True + for node, state in nodes: + if node.has_reads(state): + # The read is only a problem if it is not written before -> the access node has no incoming edge + if state.in_degree(node) == 0: + write_only = False + else: + # There is also a problem if any edge is an update instead of write + for edge in [*state.out_edges(node), *state.out_edges(node)]: + if edge.data.wcr is not None: + write_only = False + + if write_only: + self.write_only_values.add(data) + + # Want only the values which are only written to and one option uses a constant value + self.write_only_values = assigned_const.intersection(self.write_only_values) + + if len(self.write_only_values) == 0: + return False + return True + + def apply(self, _, sdfg: sd.SDFG): + # create a new state before the guard state where the zero assignment happens + new_assign_state = sdfg.add_state_before(self.if_guard, label="const_assignment_state") + + # Move all the Tasklets together with the AccessNode + for value in self.write_only_values: + state = self.assign_context[value]["state"] + tasklet = self.assign_context[value]["tasklet"] + new_assign_state.add_node(tasklet) + for edge in state.out_edges(tasklet): + state.remove_edge(edge) + state.remove_node(edge.dst) + new_assign_state.add_node(edge.dst) + new_assign_state.add_edge(tasklet, edge.src_conn, edge.dst, edge.dst_conn, edge.data) + + state.remove_node(tasklet) + # Remove the state if it was emptied + if state.is_empty(): + sdfg.remove_node(state) + return sdfg diff --git a/tests/transformations/change_strides_test.py b/tests/transformations/change_strides_test.py new file mode 100644 index 0000000000..3975761fd5 --- /dev/null +++ b/tests/transformations/change_strides_test.py @@ -0,0 +1,48 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import dace +from dace import nodes +from dace.dtypes import ScheduleType +from dace.memlet import Memlet +from dace.transformation.change_strides import change_strides + + +def change_strides_test(): + sdfg = dace.SDFG('change_strides_test') + N = dace.symbol('N') + M = dace.symbol('M') + sdfg.add_array('A', [N, M], dace.float64) + sdfg.add_array('B', [N, M, 3], dace.float64) + state = sdfg.add_state() + + task1, mentry1, mexit1 = state.add_mapped_tasklet( + name="map1", + map_ranges={'i': '0:N', 'j': '0:M'}, + inputs={'a': Memlet(data='A', subset='i, j')}, + outputs={'b': Memlet(data='B', subset='i, j, 0')}, + code='b = a + 1', + external_edges=True, + propagate=True) + + # Check that states are as expected + changed_sdfg = change_strides(sdfg, ['N'], ScheduleType.Sequential) + assert len(changed_sdfg.states()) == 3 + assert len(changed_sdfg.out_edges(changed_sdfg.start_state)) == 1 + work_state = changed_sdfg.out_edges(changed_sdfg.start_state)[0].dst + nsdfg = None + for node in work_state.nodes(): + if isinstance(node, nodes.NestedSDFG): + nsdfg = node + # Check shape and strides of data inside nested SDFG + assert nsdfg is not None + assert nsdfg.sdfg.data('A').shape == (N, M) + assert nsdfg.sdfg.data('B').shape == (N, M, 3) + assert nsdfg.sdfg.data('A').strides == (1, N) + assert nsdfg.sdfg.data('B').strides == (1, N, M*N) + + +def main(): + change_strides_test() + + +if __name__ == '__main__': + main() diff --git a/tests/transformations/move_assignment_outside_if_test.py b/tests/transformations/move_assignment_outside_if_test.py new file mode 100644 index 0000000000..323e83cf61 --- /dev/null +++ b/tests/transformations/move_assignment_outside_if_test.py @@ -0,0 +1,161 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import dace +from dace.transformation.interstate import MoveAssignmentOutsideIf +from dace.sdfg import InterstateEdge +from dace.memlet import Memlet +from dace.sdfg.nodes import Tasklet + + +def one_variable_simple_test(const_value: int = 0): + """ Test with one variable which has formula and const branch. Uses the given const value """ + sdfg = dace.SDFG('one_variable_simple_test') + # Create guard state and one state where A is set to 0 and another where it is set using B and some formula + guard = sdfg.add_state('guard', is_start_state=True) + formula_state = sdfg.add_state('formula', is_start_state=False) + const_state = sdfg.add_state('const', is_start_state=False) + sdfg.add_array('A', [1], dace.float64) + sdfg.add_array('B', [1], dace.float64) + + # Add tasklet inside states + formula_tasklet = formula_state.add_tasklet('formula_assign', {'b'}, {'a'}, 'a = 2*b') + formula_state.add_memlet_path(formula_state.add_read('B'), formula_tasklet, memlet=Memlet(data='B', subset='0'), + dst_conn='b') + formula_state.add_memlet_path(formula_tasklet, formula_state.add_write('A'), memlet=Memlet(data='A', subset='0'), + src_conn='a') + const_tasklet = const_state.add_tasklet('const_assign', {}, {'a'}, f"a = {const_value}") + const_state.add_memlet_path(const_tasklet, const_state.add_write('A'), memlet=Memlet(data='A', subset='0'), + src_conn='a') + + # Create if-else condition such that either the formula state or the const state is executed + sdfg.add_edge(guard, formula_state, InterstateEdge(condition='B[0] < 0.5')) + sdfg.add_edge(guard, const_state, InterstateEdge(condition='B[0] >= 0.5')) + sdfg.validate() + + # Assure transformation is applied + assert sdfg.apply_transformations_repeated([MoveAssignmentOutsideIf]) == 1 + # SDFG now starts with a state containing the const_tasklet + assert const_tasklet in sdfg.start_state.nodes() + # The formula state has only one in_edge with the condition + assert len(sdfg.in_edges(formula_state)) == 1 + assert sdfg.in_edges(formula_state)[0].data.condition.as_string == '(B[0] < 0.5)' + # All state have at most one out_edge -> there is no if-else branching anymore + for state in sdfg.states(): + assert len(sdfg.out_edges(state)) <= 1 + + +def multiple_variable_test(): + """ Test with multiple variables where not all appear in the const branch """ + sdfg = dace.SDFG('one_variable_simple_test') + # Create guard state and one state where A is set to 0 and another where it is set using B and some formula + guard = sdfg.add_state('guard', is_start_state=True) + formula_state = sdfg.add_state('formula', is_start_state=False) + const_state = sdfg.add_state('const', is_start_state=False) + sdfg.add_array('A', [1], dace.float64) + sdfg.add_array('B', [1], dace.float64) + sdfg.add_array('C', [1], dace.float64) + sdfg.add_array('D', [1], dace.float64) + + A = formula_state.add_access('A') + B = formula_state.add_access('B') + C = formula_state.add_access('C') + D = formula_state.add_access('D') + formula_tasklet_a = formula_state.add_tasklet('formula_assign', {'b'}, {'a'}, 'a = 2*b') + formula_state.add_memlet_path(B, formula_tasklet_a, memlet=Memlet(data='B', subset='0'), dst_conn='b') + formula_state.add_memlet_path(formula_tasklet_a, A, memlet=Memlet(data='A', subset='0'), src_conn='a') + formula_tasklet_b = formula_state.add_tasklet('formula_assign', {'c'}, {'b'}, 'a = 2*c') + formula_state.add_memlet_path(C, formula_tasklet_b, memlet=Memlet(data='C', subset='0'), dst_conn='c') + formula_state.add_memlet_path(formula_tasklet_b, B, memlet=Memlet(data='B', subset='0'), src_conn='b') + formula_tasklet_c = formula_state.add_tasklet('formula_assign', {'d'}, {'c'}, 'a = 2*d') + formula_state.add_memlet_path(D, formula_tasklet_c, memlet=Memlet(data='D', subset='0'), dst_conn='d') + formula_state.add_memlet_path(formula_tasklet_c, C, memlet=Memlet(data='C', subset='0'), src_conn='c') + + const_tasklet_a = const_state.add_tasklet('const_assign', {}, {'a'}, 'a = 0') + const_state.add_memlet_path(const_tasklet_a, const_state.add_write('A'), memlet=Memlet(data='A', subset='0'), + src_conn='a') + const_tasklet_b = const_state.add_tasklet('const_assign', {}, {'b'}, 'b = 0') + const_state.add_memlet_path(const_tasklet_b, const_state.add_write('B'), memlet=Memlet(data='B', subset='0'), + src_conn='b') + + # Create if-else condition such that either the formula state or the const state is executed + sdfg.add_edge(guard, formula_state, InterstateEdge(condition='D[0] < 0.5')) + sdfg.add_edge(guard, const_state, InterstateEdge(condition='D[0] >= 0.5')) + sdfg.validate() + + # Assure transformation is applied + assert sdfg.apply_transformations_repeated([MoveAssignmentOutsideIf]) == 1 + # There are no other tasklets in the start state beside the const assignment tasklet as there are no other const + # assignments + for node in sdfg.start_state.nodes(): + if isinstance(node, Tasklet): + assert node == const_tasklet_a or node == const_tasklet_b + # The formula state has only one in_edge with the condition + assert len(sdfg.in_edges(formula_state)) == 1 + assert sdfg.in_edges(formula_state)[0].data.condition.as_string == '(D[0] < 0.5)' + # All state have at most one out_edge -> there is no if-else branching anymore + for state in sdfg.states(): + assert len(sdfg.out_edges(state)) <= 1 + + +def multiple_variable_not_all_const_test(): + """ Test with multiple variables where not all get const-assigned in const branch """ + sdfg = dace.SDFG('one_variable_simple_test') + # Create guard state and one state where A is set to 0 and another where it is set using B and some formula + guard = sdfg.add_state('guard', is_start_state=True) + formula_state = sdfg.add_state('formula', is_start_state=False) + const_state = sdfg.add_state('const', is_start_state=False) + sdfg.add_array('A', [1], dace.float64) + sdfg.add_array('B', [1], dace.float64) + sdfg.add_array('C', [1], dace.float64) + + A = formula_state.add_access('A') + B = formula_state.add_access('B') + C = formula_state.add_access('C') + formula_tasklet_a = formula_state.add_tasklet('formula_assign', {'b'}, {'a'}, 'a = 2*b') + formula_state.add_memlet_path(B, formula_tasklet_a, memlet=Memlet(data='B', subset='0'), dst_conn='b') + formula_state.add_memlet_path(formula_tasklet_a, A, memlet=Memlet(data='A', subset='0'), src_conn='a') + formula_tasklet_b = formula_state.add_tasklet('formula_assign', {'c'}, {'b'}, 'a = 2*c') + formula_state.add_memlet_path(C, formula_tasklet_b, memlet=Memlet(data='C', subset='0'), dst_conn='c') + formula_state.add_memlet_path(formula_tasklet_b, B, memlet=Memlet(data='B', subset='0'), src_conn='b') + + const_tasklet_a = const_state.add_tasklet('const_assign', {}, {'a'}, 'a = 0') + const_state.add_memlet_path(const_tasklet_a, const_state.add_write('A'), memlet=Memlet(data='A', subset='0'), + src_conn='a') + const_tasklet_b = const_state.add_tasklet('const_assign', {'c'}, {'b'}, 'b = 1.5 * c') + const_state.add_memlet_path(const_state.add_read('C'), const_tasklet_b, memlet=Memlet(data='C', subset='0'), + dst_conn='c') + const_state.add_memlet_path(const_tasklet_b, const_state.add_write('B'), memlet=Memlet(data='B', subset='0'), + src_conn='b') + + # Create if-else condition such that either the formula state or the const state is executed + sdfg.add_edge(guard, formula_state, InterstateEdge(condition='C[0] < 0.5')) + sdfg.add_edge(guard, const_state, InterstateEdge(condition='C[0] >= 0.5')) + sdfg.validate() + + # Assure transformation is applied + assert sdfg.apply_transformations_repeated([MoveAssignmentOutsideIf]) == 1 + # There are no other tasklets in the start state beside the const assignment tasklet as there are no other const + # assignments + for node in sdfg.start_state.nodes(): + if isinstance(node, Tasklet): + assert node == const_tasklet_a or node == const_tasklet_b + # The formula state has only one in_edge with the condition + assert len(sdfg.in_edges(formula_state)) == 1 + assert sdfg.in_edges(formula_state)[0].data.condition.as_string == '(C[0] < 0.5)' + # Guard still has two outgoing edges as if-else pattern still exists + assert len(sdfg.out_edges(guard)) == 2 + # const state now has only const_tasklet_b left plus two access nodes + assert len(const_state.nodes()) == 3 + for node in const_state.nodes(): + if isinstance(node, Tasklet): + assert node == const_tasklet_b + + +def main(): + one_variable_simple_test(0) + one_variable_simple_test(2) + multiple_variable_test() + multiple_variable_not_all_const_test() + + +if __name__ == '__main__': + main() diff --git a/tests/trivial_map_elimination_test.py b/tests/trivial_map_elimination_test.py index 44b1f77652..9600dad640 100644 --- a/tests/trivial_map_elimination_test.py +++ b/tests/trivial_map_elimination_test.py @@ -25,7 +25,69 @@ def trivial_map_sdfg(): return sdfg +def trivial_map_init_sdfg(): + sdfg = dace.SDFG('trivial_map_range_expanded') + sdfg.add_array('B', [5, 1], dace.float64) + state = sdfg.add_state() + + # Nodes + map_entry_outer, map_exit_outer = state.add_map('map_outer', dict(j='0:5')) + map_entry_inner, map_exit_inner = state.add_map('map_inner', dict(i='0:1')) + + tasklet = state.add_tasklet('tasklet', {}, {'b'}, 'b = 1') + write = state.add_write('B') + + # Edges + state.add_memlet_path(map_entry_outer, map_entry_inner, memlet=dace.Memlet()) + state.add_memlet_path(map_entry_inner, tasklet, memlet=dace.Memlet()) + + state.add_memlet_path(tasklet, map_exit_inner, memlet=dace.Memlet.simple('B', 'j, i'), src_conn='b', + dst_conn='IN_B') + state.add_memlet_path(map_exit_inner, map_exit_outer, memlet=dace.Memlet.simple('B', 'j, 0'), src_conn='OUT_B', + dst_conn='IN_B') + state.add_memlet_path(map_exit_outer, write, memlet=dace.Memlet.simple('B', '0:5, 0'), + src_conn='OUT_B') + + sdfg.validate() + return sdfg + + +def trivial_map_pseudo_init_sdfg(): + sdfg = dace.SDFG('trivial_map_range_expanded') + sdfg.add_array('A', [5, 1], dace.float64) + sdfg.add_array('B', [5, 1], dace.float64) + state = sdfg.add_state() + + # Nodes + map_entry_outer, map_exit_outer = state.add_map('map_outer', dict(j='0:5')) + map_entry_inner, map_exit_inner = state.add_map('map_inner', dict(i='0:1')) + + read = state.add_read('A') + tasklet = state.add_tasklet('tasklet', {'a'}, {'b'}, 'b = a') + write = state.add_write('B') + + # Edges + state.add_memlet_path(map_entry_outer, map_entry_inner, memlet=dace.Memlet()) + state.add_memlet_path(read, map_entry_outer, map_entry_inner, memlet=dace.Memlet.simple('A', '0:5, 0'), + dst_conn='IN_A') + state.add_memlet_path(map_entry_inner, tasklet, memlet=dace.Memlet()) + state.add_memlet_path(map_entry_inner, tasklet, memlet=dace.Memlet.simple('A', 'j, 0'), src_conn='OUT_A', dst_conn='a') + + state.add_memlet_path(tasklet, map_exit_inner, memlet=dace.Memlet.simple('B', 'j, i'), src_conn='b', + dst_conn='IN_B') + state.add_memlet_path(map_exit_inner, map_exit_outer, memlet=dace.Memlet.simple('B', 'j, 0'), src_conn='OUT_B', + dst_conn='IN_B') + state.add_memlet_path(map_exit_outer, write, memlet=dace.Memlet.simple('B', '0:5, 0'), + src_conn='OUT_B') + + sdfg.validate() + return sdfg + + class TrivialMapEliminationTest(unittest.TestCase): + """ + Tests the case where the map has an empty input edge + """ def test_can_be_applied(self): graph = trivial_map_sdfg() @@ -56,5 +118,75 @@ def test_raplaces_map_params_in_scope(self): self.assertEqual(out_memlet.data.subset, dace.subsets.Range([(0, 0, 1)])) +class TrivialMapInitEliminationTest(unittest.TestCase): + def test_can_be_applied(self): + graph = trivial_map_init_sdfg() + + count = graph.apply_transformations(TrivialMapElimination, validate=False, validate_all=False) + graph.validate() + + self.assertGreater(count, 0) + + def test_removes_map(self): + graph = trivial_map_init_sdfg() + + state = graph.nodes()[0] + map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)] + self.assertEqual(len(map_entries), 2) + + graph.apply_transformations(TrivialMapElimination) + + state = graph.nodes()[0] + map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)] + self.assertEqual(len(map_entries), 1) + + def test_reconnects_edges(self): + graph = trivial_map_init_sdfg() + + graph.apply_transformations(TrivialMapElimination) + state = graph.nodes()[0] + map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)] + self.assertEqual(len(map_entries), 1) + # Check that there is an outgoing edge from the map entry + self.assertEqual(len(state.out_edges(map_entries[0])), 1) + + +class TrivialMapPseudoInitEliminationTest(unittest.TestCase): + """ + Test cases where the map has an empty input and a non empty input + """ + def test_can_be_applied(self): + graph = trivial_map_pseudo_init_sdfg() + + count = graph.apply_transformations(TrivialMapElimination, validate=False, validate_all=False) + graph.validate() + graph.view() + + self.assertGreater(count, 0) + + def test_removes_map(self): + graph = trivial_map_pseudo_init_sdfg() + + state = graph.nodes()[0] + map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)] + self.assertEqual(len(map_entries), 2) + + graph.apply_transformations(TrivialMapElimination) + + state = graph.nodes()[0] + map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)] + self.assertEqual(len(map_entries), 1) + + def test_reconnects_edges(self): + graph = trivial_map_pseudo_init_sdfg() + + graph.apply_transformations(TrivialMapElimination) + state = graph.nodes()[0] + map_entries = [n for n in state.nodes() if isinstance(n, dace.sdfg.nodes.MapEntry)] + self.assertEqual(len(map_entries), 1) + # Check that there is an outgoing edge from the map entry + self.assertEqual(len(state.out_edges(map_entries[0])), 1) + + if __name__ == '__main__': unittest.main()