[dace] WCR-based reduction now works in icon4py
edopao committed Sep 7, 2023
1 parent 8cac9da commit 2cebbde
Showing 2 changed files with 55 additions and 15 deletions.
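
For context on the mechanism this commit relies on: in DaCe, a memlet can carry a write-conflict-resolution (WCR) lambda, and every write through that memlet is then combined with the value already stored at the destination. The sketch below is not part of this commit; it is a minimal, self-contained illustration of a WCR sum in plain DaCe (program name and sizes are made up).

# Minimal sketch of a DaCe write-conflict-resolution (WCR) reduction,
# independent of this commit: each write through the output memlet is
# combined with the stored value via the given lambda (here, a sum).
import dace
import numpy as np

N = dace.symbol("N")


@dace.program
def wcr_sum(a: dace.float64[N], out: dace.float64[1]):
    for i in dace.map[0:N]:
        with dace.tasklet:
            v << a[i]
            # "(1, lambda x, y: x + y)" attaches the WCR to the output memlet
            s >> out(1, lambda x, y: x + y)[0]
            s = v


if __name__ == "__main__":
    values = np.random.rand(16)
    result = np.zeros(1)  # start at the reduction identity (0 for a sum)
    wcr_sum(values, result)
    assert np.isclose(result[0], values.sum())

The diff below applies the same idea to the ITIR reduction lowering: the WCR lambda is stored on the visitor context (reduce_wcr) and attached to the memlet that writes the lambda result, while a new init state seeds the accumulator with the operator's identity value.
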
@@ -59,6 +59,21 @@ def itir_type_as_dace_type(type_: next_typing.Type):
raise NotImplementedError()


def reduction_init_value(op_name_: str, type_: Any):
if op_name_ == "plus":
init_value = type_(0)
elif op_name_ == "multiplies":
init_value = type_(1)
elif op_name_ == "minimum":
init_value = type_("inf")
elif op_name_ == "maximum":
init_value = type_("-inf")
else:
raise NotImplementedError()

return init_value


_MATH_BUILTINS_MAPPING = {
"abs": "abs({})",
"sin": "math.sin({})",
@@ -139,6 +154,7 @@ class Context:
state: dace.SDFGState
symbol_map: dict[str, IteratorExpr | ValueExpr | SymbolExpr]
reduce_limit: int
reduce_wcr: Optional[str]

def __init__(
self,
@@ -150,6 +166,7 @@ def __init__(
self.state = state
self.symbol_map = symbol_map
self.reduce_limit = 0
self.reduce_wcr = None


def builtin_neighbors(
@@ -365,6 +382,8 @@ def visit_Lambda(
value = IteratorExpr(field, indices, arg.dtype, arg.dimensions)
symbol_map[param] = value
context = Context(context_sdfg, context_state, symbol_map)
context.reduce_limit = prev_context.reduce_limit
context.reduce_wcr = prev_context.reduce_wcr
self.context = context

# Add input parameters as arrays
@@ -411,7 +430,11 @@ def visit_Lambda(
self.context.body.add_scalar(result_name, result.dtype, transient=True)
result_access = self.context.state.add_access(result_name)
self.context.state.add_edge(
result.value, None, result_access, None, dace.Memlet(f"{result.value.data}[0]")
result.value,
None,
result_access,
None,
dace.Memlet(f"{result_access.data}[0]", wcr=context.reduce_wcr),
)
result = ValueExpr(value=result_access, dtype=result.dtype)
else:
@@ -770,21 +793,43 @@ def _visit_reduce(self, node: itir.FunCall):
nreduce = self.context.body.arrays[neighbor_args[0].value.data].shape[0]
nreduce_domain = {"__idx": f"0:{nreduce}"}

# set variable in context to enable dereference of neighbors in input fields
result_dtype = neighbor_args[0].dtype
self.context.body.add_scalar(result_name, result_dtype, transient=True)

assert isinstance(fun_node.expr, itir.FunCall)
op_name = fun_node.expr.fun
assert isinstance(op_name, itir.SymRef)

init_value = reduction_init_value(op_name.id, result_dtype)
init_state = self.context.body.add_state_before(self.context.state, "init")
init_tasklet = init_state.add_tasklet(
"init_reduce", {}, {"__out"}, f"__out = {init_value}"
)
init_state.add_edge(
init_tasklet,
"__out",
init_state.add_access(result_name),
None,
dace.Memlet.simple(result_name, "0"),
)

# set variable in context to enable dereference of neighbors in input fields and WCR on reduce tasklet
self.context.reduce_limit = nreduce
self.context.reduce_wcr = "lambda x, y: " + _MATH_BUILTINS_MAPPING[str(op_name)].format(
"x", "y"
)

for i, node_arg in enumerate(node.args):
if not args[i]:
args[i] = self.visit(node_arg)[0]
# clear context
self.context.reduce_limit = 0

result_dtype = neighbor_args[0].dtype
self.context.body.add_scalar(result_name, result_dtype, transient=True)

assert isinstance(fun_node.expr, itir.FunCall)
lambda_node = itir.Lambda(expr=fun_node.expr.args[1], params=fun_node.params[1:])
lambda_context, inner_inputs, inner_outputs = self.visit(lambda_node, args=args)

# clear context
self.context.reduce_limit = 0
self.context.reduce_wcr = None

# the connectivity arrays (neighbor tables) are not needed inside the lambda SDFG
neighbor_tables = filter_neighbor_tables(self.offset_provider)
for conn, _ in neighbor_tables:
@@ -797,11 +842,7 @@ def _visit_reduce(self, node: itir.FunCall):
input_memlets = [
create_memlet_at(expr.value.data, ("__idx",)) for arg, expr in zip(node.args, args)
]

op_name = fun_node.expr.fun
assert isinstance(op_name, itir.SymRef)
wcr_str = "lambda x, y: " + _MATH_BUILTINS_MAPPING[str(op_name)].format("x", "y")
output_memlet = dace.Memlet(data=result_name, subset="0", wcr=wcr_str)
output_memlet = dace.Memlet.simple(result_name, "0")

input_mapping = {param: arg for (param, _), arg in zip(inner_inputs, input_memlets)}
output_mapping = {inner_outputs[0].value.data: output_memlet}
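
The reduce_wcr string attached above is assembled from _MATH_BUILTINS_MAPPING. A hedged sketch of that construction follows; the template strings for "plus", "multiplies", "minimum", and "maximum" are assumptions, since only the "abs" and "sin" entries of the mapping are visible in this diff.

# Hedged sketch of the WCR lambda string built in _visit_reduce.
# The operator templates below are assumptions (only the "abs" and "sin"
# entries of _MATH_BUILTINS_MAPPING appear in this diff).
_MATH_BUILTINS_MAPPING = {
    "plus": "({} + {})",
    "multiplies": "({} * {})",
    "minimum": "min({}, {})",
    "maximum": "max({}, {})",
}


def make_reduce_wcr(op_name: str) -> str:
    # Mirrors: "lambda x, y: " + _MATH_BUILTINS_MAPPING[str(op_name)].format("x", "y")
    return "lambda x, y: " + _MATH_BUILTINS_MAPPING[op_name].format("x", "y")


assert make_reduce_wcr("plus") == "lambda x, y: (x + y)"
assert make_reduce_wcr("minimum") == "lambda x, y: min(x, y)"
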
@@ -92,8 +92,7 @@ def fencil(edge_f: cases.EField, out: cases.VField):

def test_reduction_expression_in_call(unstructured_case):
if unstructured_case.backend == dace_iterator.run_dace_iterator:
# -edge_f(V2E) * tmp_nbh * 2 gets inlined with the neighbor_sum operation in the reduction in itir,
# so in addition to the skipped reason, currently itir is a lambda instead of the 'plus' operation
# -edge_f(V2E) * tmp_nbh * 2 gets inlined with the neighbor_sum operation in the reduction in itir
pytest.skip("Not supported in DaCe backend: Reductions not directly on a field.")

@gtx.field_operator

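A closing note on why _visit_reduce now prepends an init state: a WCR write always combines with whatever the accumulator already holds, so the result scalar has to be seeded with the operator's identity element, which is what reduction_init_value provides. Below is a small, DaCe-free sketch of that pairing; the identities mirror the function added in this commit.

# Hedged sketch of the identity values used to seed the WCR accumulator;
# they mirror reduction_init_value added in this commit.
import numpy as np

REDUCTION_IDENTITY = {
    "plus": lambda t: t(0),          # x + 0 == x
    "multiplies": lambda t: t(1),    # x * 1 == x
    "minimum": lambda t: t("inf"),   # min(x, +inf) == x
    "maximum": lambda t: t("-inf"),  # max(x, -inf) == x
}

assert REDUCTION_IDENTITY["plus"](np.float64) == 0.0
assert REDUCTION_IDENTITY["minimum"](np.float64) == np.inf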