diff --git a/dace/transformation/dataflow/map_expansion.py b/dace/transformation/dataflow/map_expansion.py index 275b99c7e8..ce3d7600d9 100644 --- a/dace/transformation/dataflow/map_expansion.py +++ b/dace/transformation/dataflow/map_expansion.py @@ -3,12 +3,14 @@ from dace.sdfg.utils import consolidate_edges from typing import Dict, List +import copy import dace from dace import dtypes, subsets, symbolic from dace.sdfg import nodes from dace.sdfg import utils as sdutil from dace.sdfg.graph import OrderedMultiDiConnectorGraph from dace.transformation import transformation as pm +from dace.sdfg.propagation import propagate_memlets_scope class MapExpansion(pm.SingleStateTransformation): @@ -61,14 +63,28 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG): # 1. If there are no edges coming from the outside, use empty memlets # 2. Edges with IN_* connectors replicate along the maps # 3. Edges for dynamic map ranges replicate until reaching range(s) - for edge in graph.out_edges(map_entry): + for edge in list(graph.out_edges(map_entry)): + if edge.src_conn is not None and edge.src_conn not in entries[-1].out_connectors: + entries[-1].add_out_connector(edge.src_conn) + + graph.add_edge(entries[-1], edge.src_conn, edge.dst, edge.dst_conn, memlet=copy.deepcopy(edge.data)) graph.remove_edge(edge) - graph.add_memlet_path(map_entry, - *entries, - edge.dst, - src_conn=edge.src_conn, - memlet=edge.data, - dst_conn=edge.dst_conn) + + if graph.in_degree(map_entry) == 0: + graph.add_memlet_path(map_entry, *entries, memlet=dace.Memlet()) + else: + for edge in graph.in_edges(map_entry): + if not edge.dst_conn.startswith("IN_"): + continue + + in_conn = edge.dst_conn + out_conn = "OUT_" + in_conn[3:] + if in_conn not in entries[-1].in_connectors: + graph.add_memlet_path(map_entry, + *entries, + memlet=copy.deepcopy(edge.data), + src_conn=out_conn, + dst_conn=in_conn) # Modify dynamic map ranges dynamic_edges = dace.sdfg.dynamic_map_inputs(graph, map_entry) @@ -111,6 +127,7 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG): else: raise ValueError('Cannot find scope in state') + propagate_memlets_scope(sdfg, state=graph, scopes=scope) consolidate_edges(sdfg, scope) return [map_entry] + entries diff --git a/dace/transformation/dataflow/tasklet_fusion.py b/dace/transformation/dataflow/tasklet_fusion.py index 8179ead457..22a9dc9f6a 100644 --- a/dace/transformation/dataflow/tasklet_fusion.py +++ b/dace/transformation/dataflow/tasklet_fusion.py @@ -272,5 +272,5 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG): graph.remove_node(t1) if data is not None: graph.remove_node(data) - sdfg.remove_data(data.data, True) + graph.remove_node(t2) diff --git a/dace/transformation/dataflow/wcr_conversion.py b/dace/transformation/dataflow/wcr_conversion.py index e95674adc1..d8bdc0b4b3 100644 --- a/dace/transformation/dataflow/wcr_conversion.py +++ b/dace/transformation/dataflow/wcr_conversion.py @@ -74,6 +74,12 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): outedge = graph.edges_between(tasklet, mx)[0] + # If in map, only match if the subset is independent of any + # map indices (otherwise no conflict) + if not permissive and len(outedge.data.subset.free_symbols & set(me.map.params)) == len( + me.map.params): + return False + # Get relevant output connector outconn = outedge.src_conn @@ -115,16 +121,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): if edge.data.subset != outedge.data.subset: continue - # If in map, only match if the subset is independent of any - # map indices (otherwise no conflict) - if (expr_index == 1 and len(outedge.data.subset.free_symbols - & set(me.map.params)) == len(me.map.params)): - continue - return True - else: - # Only Python/C++ tasklets supported - return False return False @@ -192,11 +189,13 @@ def apply(self, state: SDFGState, sdfg: SDFG): rhs: ast.BinOp = ast_node.value op = AugAssignToWCR._PYOP_MAP[type(rhs.op)] inconns = list(edge.dst_conn for edge in inedges) - for n in (rhs.left, rhs.right): - if isinstance(n, ast.Name) and n.id in inconns: - inedge = inedges[inconns.index(n.id)] - else: - new_rhs = n + if isinstance(rhs.left, ast.Name) and rhs.left.id in inconns: + inedge = inedges[inconns.index(rhs.left.id)] + new_rhs = rhs.right + else: + inedge = inedges[inconns.index(rhs.right.id)] + new_rhs = rhs.left + new_node = ast.copy_location(ast.Assign(targets=[lhs], value=new_rhs), ast_node) tasklet.code.code = [new_node] diff --git a/tests/expansion_dynamic_range_test.py b/tests/expansion_dynamic_range_test.py deleted file mode 100644 index 2cafe5b6f1..0000000000 --- a/tests/expansion_dynamic_range_test.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -import dace -from dace.transformation.dataflow import MapExpansion -import numpy as np - - -@dace.program -def expansion(A: dace.float32[20, 30, 5], rng: dace.int32[2]): - @dace.map - def mymap(i: _[0:20], j: _[rng[0]:rng[1]], k: _[0:5]): - a << A[i, j, k] - b >> A[i, j, k] - b = a * 2 - - -def test(): - A = np.random.rand(20, 30, 5).astype(np.float32) - b = np.array([5, 10], dtype=np.int32) - expected = A.copy() - expected[:, 5:10, :] *= 2 - - sdfg = expansion.to_sdfg() - sdfg(A=A, rng=b) - diff = np.linalg.norm(A - expected) - print('Difference (before transformation):', diff) - - sdfg.apply_transformations(MapExpansion) - - sdfg(A=A, rng=b) - expected[:, 5:10, :] *= 2 - diff2 = np.linalg.norm(A - expected) - print('Difference:', diff2) - assert (diff <= 1e-5) and (diff2 <= 1e-5) - - -if __name__ == "__main__": - test() diff --git a/tests/transformations/map_expansion_test.py b/tests/transformations/map_expansion_test.py new file mode 100644 index 0000000000..1f9a97f810 --- /dev/null +++ b/tests/transformations/map_expansion_test.py @@ -0,0 +1,119 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import dace +import numpy as np +from dace.transformation.dataflow import MapExpansion + +def test_expand_with_inputs(): + @dace.program + def toexpand(A: dace.float64[4, 2], B: dace.float64[2, 2]): + for i, j in dace.map[1:3, 0:2]: + with dace.tasklet: + a1 << A[i, j] + a2 << A[i + 1, j] + a3 << A[i - 1, j] + b >> B[i-1, j] + b = a1 + a2 + a3 + + sdfg = toexpand.to_sdfg() + sdfg.simplify() + + # Init conditions + sdfg.validate() + assert len([node for node in sdfg.start_state.nodes() if isinstance(node, dace.nodes.MapEntry)]) == 1 + assert len([node for node in sdfg.start_state.nodes() if isinstance(node, dace.nodes.MapExit)]) == 1 + + # Expansion + assert sdfg.apply_transformations_repeated(MapExpansion) == 1 + sdfg.validate() + + map_entries = set() + state = sdfg.start_state + for node in state.nodes(): + if not isinstance(node, dace.nodes.MapEntry): + continue + + # (Fast) MapExpansion should not add memlet paths for each memlet to a tasklet + if sdfg.start_state.entry_node(node) is None: + assert state.in_degree(node) == 1 + assert state.out_degree(node) == 1 + assert len(node.out_connectors) == 1 + else: + assert state.in_degree(node) == 1 + assert state.out_degree(node) == 3 + assert len(node.out_connectors) == 1 + + map_entries.add(node) + + assert len(map_entries) == 2 + +def test_expand_without_inputs(): + @dace.program + def toexpand(B: dace.float64[4, 4]): + for i, j in dace.map[0:4, 0:4]: + with dace.tasklet: + b >> B[i, j] + b = 0 + + sdfg = toexpand.to_sdfg() + sdfg.simplify() + + # Init conditions + sdfg.validate() + assert len([node for node in sdfg.start_state.nodes() if isinstance(node, dace.nodes.MapEntry)]) == 1 + assert len([node for node in sdfg.start_state.nodes() if isinstance(node, dace.nodes.MapExit)]) == 1 + + # Expansion + assert sdfg.apply_transformations_repeated(MapExpansion) == 1 + sdfg.validate() + + map_entries = set() + state = sdfg.start_state + for node in state.nodes(): + if not isinstance(node, dace.nodes.MapEntry): + continue + + # (Fast) MapExpansion should not add memlet paths for each memlet to a tasklet + if sdfg.start_state.entry_node(node) is None: + assert state.in_degree(node) == 0 + assert state.out_degree(node) == 1 + assert len(node.out_connectors) == 0 + else: + assert state.in_degree(node) == 1 + assert state.out_degree(node) == 1 + assert len(node.out_connectors) == 0 + + map_entries.add(node) + + assert len(map_entries) == 2 + +def test_expand_without_dynamic_inputs(): + @dace.program + def expansion(A: dace.float32[20, 30, 5], rng: dace.int32[2]): + @dace.map + def mymap(i: _[0:20], j: _[rng[0]:rng[1]], k: _[0:5]): + a << A[i, j, k] + b >> A[i, j, k] + b = a * 2 + + A = np.random.rand(20, 30, 5).astype(np.float32) + b = np.array([5, 10], dtype=np.int32) + expected = A.copy() + expected[:, 5:10, :] *= 2 + + sdfg = expansion.to_sdfg() + sdfg(A=A, rng=b) + diff = np.linalg.norm(A - expected) + print('Difference (before transformation):', diff) + + sdfg.apply_transformations(MapExpansion) + + sdfg(A=A, rng=b) + expected[:, 5:10, :] *= 2 + diff2 = np.linalg.norm(A - expected) + print('Difference:', diff2) + assert (diff <= 1e-5) and (diff2 <= 1e-5) + +if __name__ == '__main__': + test_expand_with_inputs() + test_expand_without_inputs() + test_expand_without_dynamic_inputs() diff --git a/tests/transformations/tasklet_fusion_test.py b/tests/transformations/tasklet_fusion_test.py index a65d218d98..b314b0feb9 100644 --- a/tests/transformations/tasklet_fusion_test.py +++ b/tests/transformations/tasklet_fusion_test.py @@ -3,6 +3,7 @@ import dace from dace import dtypes from dace.transformation.dataflow import TaskletFusion, MapFusion +from dace.transformation.optimizer import Optimizer import pytest datatype = dace.float32 @@ -195,6 +196,33 @@ def test_map_with_tasklets(language: str, with_data: bool): assert (np.allclose(C, ref)) + +def test_intermediate_transients(): + @dace.program + def sdfg_intermediate_transients(A: dace.float32[10], B: dace.float32[10]): + tmp = dace.define_local_scalar(dace.float32) + + # Use tmp twice to test removal of data + tmp = A[0] + 1 + tmp = tmp * 2 + B[0] = tmp + + + sdfg = sdfg_intermediate_transients.to_sdfg(simplify=True) + assert len([node for node in sdfg.start_state.data_nodes() if node.data == "tmp"]) == 2 + + xforms = Optimizer(sdfg=sdfg).get_pattern_matches(patterns=(TaskletFusion,)) + applied = False + for xform in xforms: + if xform.data.data == "tmp": + xform.apply(sdfg.start_state, sdfg) + applied = True + break + + assert applied + assert len([node for node in sdfg.start_state.data_nodes() if node.data == "tmp"]) == 1 + assert "tmp" in sdfg.arrays + if __name__ == '__main__': test_basic() test_same_name() @@ -204,3 +232,4 @@ def test_map_with_tasklets(language: str, with_data: bool): test_map_with_tasklets(language='Python', with_data=True) test_map_with_tasklets(language='CPP', with_data=False) test_map_with_tasklets(language='CPP', with_data=True) + test_intermediate_transients() diff --git a/tests/transformations/wcr_conversion_test.py b/tests/transformations/wcr_conversion_test.py new file mode 100644 index 0000000000..25913e8db1 --- /dev/null +++ b/tests/transformations/wcr_conversion_test.py @@ -0,0 +1,265 @@ +import dace + +from dace.transformation.dataflow import AugAssignToWCR + + +def test_aug_assign_tasklet_lhs(): + + @dace.program + def sdfg_aug_assign_tasklet_lhs(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet: + a << A[i] + k << B[i] + b >> A[i] + b = a + k + + sdfg = sdfg_aug_assign_tasklet_lhs.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_lhs_brackets(): + + @dace.program + def sdfg_aug_assign_tasklet_lhs_brackets(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet: + a << A[i] + k << B[i] + b >> A[i] + b = a + (k + 1) + + sdfg = sdfg_aug_assign_tasklet_lhs_brackets.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_rhs(): + + @dace.program + def sdfg_aug_assign_tasklet_rhs(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet: + a << A[i] + k << B[i] + b >> A[i] + b = k + a + + sdfg = sdfg_aug_assign_tasklet_rhs.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_rhs_brackets(): + + @dace.program + def sdfg_aug_assign_tasklet_rhs_brackets(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet: + a << A[i] + k << B[i] + b >> A[i] + b = (k + 1) + a + + sdfg = sdfg_aug_assign_tasklet_rhs_brackets.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_lhs_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_lhs_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + k << B[i] + b >> A[i] + """ + b = a + k; + """ + + sdfg = sdfg_aug_assign_tasklet_lhs_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_lhs_brackets_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_lhs_brackets_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + k << B[i] + b >> A[i] + """ + b = a + (k + 1); + """ + + sdfg = sdfg_aug_assign_tasklet_lhs_brackets_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_rhs_brackets_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_rhs_brackets_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + k << B[i] + b >> A[i] + """ + b = (k + 1) + a; + """ + + sdfg = sdfg_aug_assign_tasklet_rhs_brackets_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_func_lhs_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_func_lhs_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + c << B[i] + b >> A[i] + """ + b = min(a, c); + """ + + sdfg = sdfg_aug_assign_tasklet_func_lhs_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_tasklet_func_rhs_cpp(): + + @dace.program + def sdfg_aug_assign_tasklet_func_rhs_cpp(A: dace.float64[32], B: dace.float64[32]): + for i in range(32): + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + c << B[i] + b >> A[i] + """ + b = min(c, a); + """ + + sdfg = sdfg_aug_assign_tasklet_func_rhs_cpp.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_free_map(): + + @dace.program + def sdfg_aug_assign_free_map(A: dace.float64[32], B: dace.float64[32]): + for i in dace.map[0:32]: + with dace.tasklet(language=dace.Language.CPP): + a << A[0] + k << B[i] + b >> A[0] + """ + b = k * a; + """ + + sdfg = sdfg_aug_assign_free_map.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 1 + + +def test_aug_assign_state_fission_map(): + + @dace.program + def sdfg_aug_assign_state_fission(A: dace.float64[32], B: dace.float64[32]): + for i in dace.map[0:32]: + with dace.tasklet: + a << B[i] + b >> A[i] + b = a + + for i in dace.map[0:32]: + with dace.tasklet: + a << A[0] + b >> A[0] + b = a * 2 + + for i in dace.map[0:32]: + with dace.tasklet: + a << A[0] + b >> A[0] + b = a * 2 + + sdfg = sdfg_aug_assign_state_fission.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR) + assert applied == 2 + + +def test_free_map_permissive(): + + @dace.program + def sdfg_free_map_permissive(A: dace.float64[32], B: dace.float64[32]): + for i in dace.map[0:32]: + with dace.tasklet(language=dace.Language.CPP): + a << A[i] + k << B[i] + b >> A[i] + """ + b = k * a; + """ + + sdfg = sdfg_free_map_permissive.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR, permissive=False) + assert applied == 0 + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR, permissive=True) + assert applied == 1 + +def test_aug_assign_same_inconns(): + + @dace.program + def sdfg_aug_assign_same_inconns(A: dace.float64[32]): + for i in dace.map[0:31]: + with dace.tasklet(language=dace.Language.Python): + a << A[i] + b << A[i+1] + c >> A[i] + + c = a * b + + sdfg = sdfg_aug_assign_same_inconns.to_sdfg() + sdfg.simplify() + + applied = sdfg.apply_transformations_repeated(AugAssignToWCR, permissive=True) + assert applied == 1