Skip to content

Commit

Permalink
AugAssignToWCR: Support for min/max functions
Browse files Browse the repository at this point in the history
  • Loading branch information
lukastruemper committed Sep 1, 2023
1 parent 881c7a6 commit 07644f8
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 10 deletions.
40 changes: 34 additions & 6 deletions dace/transformation/dataflow/wcr_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class AugAssignToWCR(transformation.SingleStateTransformation):
map_exit = transformation.PatternNode(nodes.MapExit)

_EXPRESSIONS = ['+', '-', '*', '^', '%'] #, '/']
_FUNCTIONS = ['min', 'max']
_EXPR_MAP = {'-': ('+', '-({expr})'), '/': ('*', '((decltype({expr}))1)/({expr})')}
_PYOP_MAP = {ast.Add: '+', ast.Sub: '-', ast.Mult: '*', ast.BitXor: '^', ast.Mod: '%', ast.Div: '/'}

Expand Down Expand Up @@ -78,6 +79,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
outconn = outedge.src_conn

ops = '[%s]' % ''.join(re.escape(o) for o in AugAssignToWCR._EXPRESSIONS)
funcs = '|'.join(re.escape(o) for o in AugAssignToWCR._FUNCTIONS)

if tasklet.language is dtypes.Language.Python:
# Match a single assignment with a binary operation as RHS
Expand Down Expand Up @@ -109,8 +111,12 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
inconn = edge.dst_conn
lhs = r'^\s*%s\s*=\s*%s\s*%s.*;$' % (re.escape(outconn), re.escape(inconn), ops)
rhs = r'^\s*%s\s*=\s*\(.*\)\s*%s\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn))
func_lhs = r'^\s*%s\s*=\s*(%s)\(\s*%s\s*,.*\)\s*;$' % (re.escape(outconn), funcs, re.escape(inconn))
func_rhs = r'^\s*%s\s*=\s*(%s)\(.*,\s*%s\s*\)\s*;$' % (re.escape(outconn), funcs, re.escape(inconn))
if re.match(lhs, cstr) is None and re.match(rhs, cstr) is None:
continue
if re.match(func_lhs, cstr) is None and re.match(func_rhs, cstr) is None:
continue

# Same memlet
if edge.data.subset != outedge.data.subset:
continue
Expand Down Expand Up @@ -183,6 +189,7 @@ def apply(self, state: SDFGState, sdfg: SDFG):
outconn = outedge.src_conn

ops = '[%s]' % ''.join(re.escape(o) for o in AugAssignToWCR._EXPRESSIONS)
funcs = '|'.join(re.escape(o) for o in AugAssignToWCR._FUNCTIONS)

# Change tasklet code
if tasklet.language is dtypes.Language.Python:
Expand All @@ -209,9 +216,24 @@ def apply(self, state: SDFGState, sdfg: SDFG):
match = re.match(
r'^\s*%s\s*=\s*\((.*)\)\s*(%s)\s*%s;$' % (re.escape(outconn), ops, re.escape(inconn)), cstr)
if match is None:
continue
op = match.group(2)
expr = match.group(1)
func_rhs = r'^\s*%s\s*=\s*(%s)\((.*),\s*%s\s*\)\s*;$' % (re.escape(outconn), funcs,
re.escape(inconn))
match = re.match(func_rhs, cstr)
if match is None:
func_lhs = r'^\s*%s\s*=\s*(%s)\(\s*%s\s*,(.*)\)\s*;$' % (re.escape(outconn), funcs,
re.escape(inconn))
match = re.match(func_lhs, cstr)
if match is None:
continue
else:
op = match.group(1)
expr = match.group(2)
else:
op = match.group(1)
expr = match.group(2)
else:
op = match.group(2)
expr = match.group(1)
else:
op = match.group(1)
expr = match.group(2)
Expand All @@ -231,7 +253,10 @@ def apply(self, state: SDFGState, sdfg: SDFG):
raise NotImplementedError

# Change output edge
outedge.data.wcr = f'lambda a,b: a {op} b'
if op in AugAssignToWCR._FUNCTIONS:
outedge.data.wcr = f'lambda a,b: {op}(a, b)'
else:
outedge.data.wcr = f'lambda a,b: a {op} b'

if self.expr_index == 0:
# Remove input node and connector
Expand All @@ -251,6 +276,9 @@ def apply(self, state: SDFGState, sdfg: SDFG):
sd = sd.parent_sdfg
outedge = next(iter(nstate.out_edges_by_connector(nsdfg, outedge.data.data)))
for outedge in nstate.memlet_path(outedge):
outedge.data.wcr = f'lambda a,b: a {op} b'
if op in AugAssignToWCR._FUNCTIONS:
outedge.data.wcr = f'lambda a,b: {op}(a, b)'
else:
outedge.data.wcr = f'lambda a,b: a {op} b'
# At this point we are leading to an access node again and can
# traverse further up
40 changes: 36 additions & 4 deletions tests/transformations/wcr_conversion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,39 @@ def sdfg_aug_assign_tasklet_rhs_brackets_cpp(A: dace.float64[32]):
assert applied == 1


if __name__ == "__main__":
test_aug_assign_tasklet_lhs_cpp()
test_aug_assign_tasklet_lhs_brackets_cpp()
test_aug_assign_tasklet_rhs_brackets_cpp()
def test_aug_assign_tasklet_func_lhs_cpp():

@dace.program
def sdfg_aug_assign_tasklet_func_lhs_cpp(A: dace.float64[32]):
for i in range(32):
with dace.tasklet(language=dace.Language.CPP):
a << A[i]
b >> A[i]
"""
b = min(a, 0);
"""

sdfg = sdfg_aug_assign_tasklet_func_lhs_cpp.to_sdfg()
sdfg.simplify()

applied = sdfg.apply_transformations_repeated(AugAssignToWCR)
assert applied == 1


def test_aug_assign_tasklet_func_rhs_cpp():

@dace.program
def sdfg_aug_assign_tasklet_func_rhs_cpp(A: dace.float64[32]):
for i in range(32):
with dace.tasklet(language=dace.Language.CPP):
a << A[i]
b >> A[i]
"""
b = min(0, a);
"""

sdfg = sdfg_aug_assign_tasklet_func_rhs_cpp.to_sdfg()
sdfg.simplify()

applied = sdfg.apply_transformations_repeated(AugAssignToWCR)
assert applied == 1

0 comments on commit 07644f8

Please sign in to comment.