Skip to content

Commit

Permalink
AugAssignToWCR: Minor bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
lukastruemper committed Nov 16, 2023
1 parent 43ca982 commit d345c59
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 15 deletions.
28 changes: 13 additions & 15 deletions dace/transformation/dataflow/wcr_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):

outedge = graph.edges_between(tasklet, mx)[0]

# If in map, only match if the subset is independent of any
# map indices (otherwise no conflict)
if not permissive and len(outedge.data.subset.free_symbols & set(me.map.params)) == len(
me.map.params):
return False

# Get relevant output connector
outconn = outedge.src_conn

Expand Down Expand Up @@ -131,17 +137,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
if edge.data.subset != outedge.data.subset:
continue

# If in map, only match if the subset is independent of any
# map indices (otherwise no conflict)
if expr_index == 1:
if not permissive and len(outedge.data.subset.free_symbols & set(me.map.params)) == len(
me.map.params):
continue

return True
else:
# Only Python/C++ tasklets supported
return False

return False

Expand Down Expand Up @@ -182,11 +178,13 @@ def apply(self, state: SDFGState, sdfg: SDFG):
rhs: ast.BinOp = ast_node.value
op = AugAssignToWCR._PYOP_MAP[type(rhs.op)]
inconns = list(edge.dst_conn for edge in inedges)
for n in (rhs.left, rhs.right):
if isinstance(n, ast.Name) and n.id in inconns:
inedge = inedges[inconns.index(n.id)]
else:
new_rhs = n
if isinstance(rhs.left, ast.Name) and rhs.left.id in inconns:
inedge = inedges[inconns.index(rhs.left.id)]
new_rhs = rhs.right
else:
inedge = inedges[inconns.index(rhs.right.id)]
new_rhs = rhs.left

new_node = ast.copy_location(ast.Assign(targets=[lhs], value=new_rhs), ast_node)
tasklet.code.code = [new_node]

Expand Down
18 changes: 18 additions & 0 deletions tests/transformations/wcr_conversion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,21 @@ def sdfg_free_map_permissive(A: dace.float64[32], B: dace.float64[32]):

applied = sdfg.apply_transformations_repeated(AugAssignToWCR, permissive=True)
assert applied == 1

def test_aug_assign_same_inconns():

@dace.program
def sdfg_aug_assign_same_inconns(A: dace.float64[32]):
for i in dace.map[0:31]:
with dace.tasklet(language=dace.Language.Python):
a << A[i]
b << A[i+1]
c >> A[i]

c = a * b

sdfg = sdfg_aug_assign_same_inconns.to_sdfg()
sdfg.simplify()

applied = sdfg.apply_transformations_repeated(AugAssignToWCR, permissive=True)
assert applied == 1

0 comments on commit d345c59

Please sign in to comment.