Skip to content

Commit

Permalink
Fixes for TaskletFusion, AugAssignToWCR and MapExpansion (#1432)
Browse files Browse the repository at this point in the history
- The PR fixes two minor bugs in corner cases of the AugAssignToWCR and
TaskletFusion transformations, both of which are covered by additional test cases:
- TaskletFusion: Should not remove the array from the SDFG, since it could be
used elsewhere
- AugAssignToWCR: Handle tasklets where all inputs come from the same array
- The PR rewrites MapExpansion to create only one memlet path per out
connector, which is more efficient. I experienced MapExpansion running for
literally hours because it used add_memlet_path for each edge to a
tasklet. This is too expensive for stencils with more than 4 dimensions and
more than 50 edges
  • Loading branch information
lukastruemper committed Dec 2, 2023
1 parent 5e1bdfa commit 8e6331d
Show file tree
Hide file tree
Showing 7 changed files with 451 additions and 59 deletions.
31 changes: 24 additions & 7 deletions dace/transformation/dataflow/map_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@

from dace.sdfg.utils import consolidate_edges
from typing import Dict, List
import copy
import dace
from dace import dtypes, subsets, symbolic
from dace.sdfg import nodes
from dace.sdfg import utils as sdutil
from dace.sdfg.graph import OrderedMultiDiConnectorGraph
from dace.transformation import transformation as pm
from dace.sdfg.propagation import propagate_memlets_scope


class MapExpansion(pm.SingleStateTransformation):
Expand Down Expand Up @@ -61,14 +63,28 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
# 1. If there are no edges coming from the outside, use empty memlets
# 2. Edges with IN_* connectors replicate along the maps
# 3. Edges for dynamic map ranges replicate until reaching range(s)
for edge in graph.out_edges(map_entry):
for edge in list(graph.out_edges(map_entry)):
if edge.src_conn is not None and edge.src_conn not in entries[-1].out_connectors:
entries[-1].add_out_connector(edge.src_conn)

graph.add_edge(entries[-1], edge.src_conn, edge.dst, edge.dst_conn, memlet=copy.deepcopy(edge.data))
graph.remove_edge(edge)
graph.add_memlet_path(map_entry,
*entries,
edge.dst,
src_conn=edge.src_conn,
memlet=edge.data,
dst_conn=edge.dst_conn)

if graph.in_degree(map_entry) == 0:
graph.add_memlet_path(map_entry, *entries, memlet=dace.Memlet())
else:
for edge in graph.in_edges(map_entry):
if not edge.dst_conn.startswith("IN_"):
continue

in_conn = edge.dst_conn
out_conn = "OUT_" + in_conn[3:]
if in_conn not in entries[-1].in_connectors:
graph.add_memlet_path(map_entry,
*entries,
memlet=copy.deepcopy(edge.data),
src_conn=out_conn,
dst_conn=in_conn)

# Modify dynamic map ranges
dynamic_edges = dace.sdfg.dynamic_map_inputs(graph, map_entry)
Expand Down Expand Up @@ -111,6 +127,7 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
else:
raise ValueError('Cannot find scope in state')

propagate_memlets_scope(sdfg, state=graph, scopes=scope)
consolidate_edges(sdfg, scope)

return [map_entry] + entries
2 changes: 1 addition & 1 deletion dace/transformation/dataflow/tasklet_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,5 +272,5 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
graph.remove_node(t1)
if data is not None:
graph.remove_node(data)
sdfg.remove_data(data.data, True)

graph.remove_node(t2)
27 changes: 13 additions & 14 deletions dace/transformation/dataflow/wcr_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):

outedge = graph.edges_between(tasklet, mx)[0]

# If in map, only match if the subset is independent of any
# map indices (otherwise no conflict)
if not permissive and len(outedge.data.subset.free_symbols & set(me.map.params)) == len(
me.map.params):
return False

# Get relevant output connector
outconn = outedge.src_conn

Expand Down Expand Up @@ -115,16 +121,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
if edge.data.subset != outedge.data.subset:
continue

# If in map, only match if the subset is independent of any
# map indices (otherwise no conflict)
if (expr_index == 1 and len(outedge.data.subset.free_symbols
& set(me.map.params)) == len(me.map.params)):
continue

return True
else:
# Only Python/C++ tasklets supported
return False

return False

Expand Down Expand Up @@ -192,11 +189,13 @@ def apply(self, state: SDFGState, sdfg: SDFG):
rhs: ast.BinOp = ast_node.value
op = AugAssignToWCR._PYOP_MAP[type(rhs.op)]
inconns = list(edge.dst_conn for edge in inedges)
for n in (rhs.left, rhs.right):
if isinstance(n, ast.Name) and n.id in inconns:
inedge = inedges[inconns.index(n.id)]
else:
new_rhs = n
if isinstance(rhs.left, ast.Name) and rhs.left.id in inconns:
inedge = inedges[inconns.index(rhs.left.id)]
new_rhs = rhs.right
else:
inedge = inedges[inconns.index(rhs.right.id)]
new_rhs = rhs.left

new_node = ast.copy_location(ast.Assign(targets=[lhs], value=new_rhs), ast_node)
tasklet.code.code = [new_node]

Expand Down
37 changes: 0 additions & 37 deletions tests/expansion_dynamic_range_test.py

This file was deleted.

119 changes: 119 additions & 0 deletions tests/transformations/map_expansion_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
import dace
import numpy as np
from dace.transformation.dataflow import MapExpansion

def test_expand_with_inputs():
    """Expand a 2-D map whose tasklet reads three elements of ``A``.

    After expansion there must be exactly two map entries. Each one has a
    single incoming edge and a single out connector; the outermost entry has
    one outgoing edge, while the nested entry has three.
    """

    @dace.program
    def toexpand(A: dace.float64[4, 2], B: dace.float64[2, 2]):
        for i, j in dace.map[1:3, 0:2]:
            with dace.tasklet:
                a1 << A[i, j]
                a2 << A[i + 1, j]
                a3 << A[i - 1, j]
                b >> B[i-1, j]
                b = a1 + a2 + a3

    sdfg = toexpand.to_sdfg()
    sdfg.simplify()

    # Init conditions: a valid SDFG with exactly one map scope
    sdfg.validate()
    assert sum(1 for n in sdfg.start_state.nodes() if isinstance(n, dace.nodes.MapEntry)) == 1
    assert sum(1 for n in sdfg.start_state.nodes() if isinstance(n, dace.nodes.MapExit)) == 1

    # Expansion
    num_applied = sdfg.apply_transformations_repeated(MapExpansion)
    assert num_applied == 1
    sdfg.validate()

    state = sdfg.start_state
    found_entries = [n for n in state.nodes() if isinstance(n, dace.nodes.MapEntry)]
    for entry in found_entries:
        # (Fast) MapExpansion should not add memlet paths for each memlet to a tasklet
        is_outermost = state.entry_node(entry) is None
        expected_out_degree = 1 if is_outermost else 3

        assert state.in_degree(entry) == 1
        assert state.out_degree(entry) == expected_out_degree
        assert len(entry.out_connectors) == 1

    assert len(set(found_entries)) == 2

def test_expand_without_inputs():
    """Expand a 2-D map whose tasklet has no inputs and only writes ``B``.

    After expansion there must be exactly two map entries, neither of which
    has out connectors. The outermost entry has no incoming edge, the nested
    entry has one; both have a single outgoing edge.
    """

    @dace.program
    def toexpand(B: dace.float64[4, 4]):
        for i, j in dace.map[0:4, 0:4]:
            with dace.tasklet:
                b >> B[i, j]
                b = 0

    sdfg = toexpand.to_sdfg()
    sdfg.simplify()

    # Init conditions: a valid SDFG with exactly one map scope
    sdfg.validate()
    assert sum(1 for n in sdfg.start_state.nodes() if isinstance(n, dace.nodes.MapEntry)) == 1
    assert sum(1 for n in sdfg.start_state.nodes() if isinstance(n, dace.nodes.MapExit)) == 1

    # Expansion
    num_applied = sdfg.apply_transformations_repeated(MapExpansion)
    assert num_applied == 1
    sdfg.validate()

    state = sdfg.start_state
    found_entries = [n for n in state.nodes() if isinstance(n, dace.nodes.MapEntry)]
    for entry in found_entries:
        # (Fast) MapExpansion should not add memlet paths for each memlet to a tasklet
        is_outermost = state.entry_node(entry) is None
        expected_in_degree = 0 if is_outermost else 1

        assert state.in_degree(entry) == expected_in_degree
        assert state.out_degree(entry) == 1
        assert len(entry.out_connectors) == 0

    assert len(set(found_entries)) == 2

def test_expand_without_dynamic_inputs():
    """Run a 3-D map before and after MapExpansion and compare the results.

    Note that the range of ``j`` is read from the ``rng`` array. The program
    doubles ``A[:, rng[0]:rng[1], :]`` in place, so the reference is doubled
    once per SDFG invocation.
    """

    @dace.program
    def expansion(A: dace.float32[20, 30, 5], rng: dace.int32[2]):
        @dace.map
        def mymap(i: _[0:20], j: _[rng[0]:rng[1]], k: _[0:5]):
            a << A[i, j, k]
            b >> A[i, j, k]
            b = a * 2

    arr = np.random.rand(20, 30, 5).astype(np.float32)
    bounds = np.array([5, 10], dtype=np.int32)

    # Reference result after the first (untransformed) run
    expected = arr.copy()
    expected[:, 5:10, :] *= 2

    sdfg = expansion.to_sdfg()
    sdfg(A=arr, rng=bounds)
    diff = np.linalg.norm(arr - expected)
    print('Difference (before transformation):', diff)

    sdfg.apply_transformations(MapExpansion)

    # The second run operates on the already-doubled array
    sdfg(A=arr, rng=bounds)
    expected[:, 5:10, :] *= 2
    diff2 = np.linalg.norm(arr - expected)
    print('Difference:', diff2)

    assert diff <= 1e-5
    assert diff2 <= 1e-5

if __name__ == '__main__':
    # Run every test of this module, in order, when executed as a script.
    for run_test in (
            test_expand_with_inputs,
            test_expand_without_inputs,
            test_expand_without_dynamic_inputs,
    ):
        run_test()
29 changes: 29 additions & 0 deletions tests/transformations/tasklet_fusion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import dace
from dace import dtypes
from dace.transformation.dataflow import TaskletFusion, MapFusion
from dace.transformation.optimizer import Optimizer
import pytest

datatype = dace.float32
Expand Down Expand Up @@ -195,6 +196,33 @@ def test_map_with_tasklets(language: str, with_data: bool):
assert (np.allclose(C, ref))



def test_intermediate_transients():
    """Fuse the tasklets around one ``tmp`` access node and check the result.

    The program writes ``tmp`` twice, so the state starts with two ``tmp``
    access nodes. Applying TaskletFusion on a match whose data node is ``tmp``
    must leave one access node behind and keep ``tmp`` registered in
    ``sdfg.arrays``.
    """

    @dace.program
    def sdfg_intermediate_transients(A: dace.float32[10], B: dace.float32[10]):
        tmp = dace.define_local_scalar(dace.float32)

        # Use tmp twice to test removal of data
        tmp = A[0] + 1
        tmp = tmp * 2
        B[0] = tmp

    sdfg = sdfg_intermediate_transients.to_sdfg(simplify=True)

    def count_tmp_nodes():
        # Number of access nodes in the start state that refer to "tmp"
        return sum(1 for n in sdfg.start_state.data_nodes() if n.data == "tmp")

    assert count_tmp_nodes() == 2

    # Pick the first TaskletFusion match whose intermediate data node is "tmp"
    candidates = Optimizer(sdfg=sdfg).get_pattern_matches(patterns=(TaskletFusion,))
    fusion = next((match for match in candidates if match.data.data == "tmp"), None)
    applied = fusion is not None
    if applied:
        fusion.apply(sdfg.start_state, sdfg)

    assert applied
    assert count_tmp_nodes() == 1
    assert "tmp" in sdfg.arrays

if __name__ == '__main__':
test_basic()
test_same_name()
Expand All @@ -204,3 +232,4 @@ def test_map_with_tasklets(language: str, with_data: bool):
test_map_with_tasklets(language='Python', with_data=True)
test_map_with_tasklets(language='CPP', with_data=False)
test_map_with_tasklets(language='CPP', with_data=True)
test_intermediate_transients()
Loading

0 comments on commit 8e6331d

Please sign in to comment.