Skip to content

Commit

Permalink
feat[next]: extend DaCe support of reduction operator (#1332)
Browse files Browse the repository at this point in the history
Adding generic implementation of neighbor-reduction to DaCe backend based on map with Write-Conflict Resolution (WCR) on output memlet. This PR enables use of lambdas as reduction function.
  • Loading branch information
edopao authored and Nina Burgdorfer committed Oct 19, 2023
1 parent 598d96b commit 3b3fc55
Show file tree
Hide file tree
Showing 6 changed files with 351 additions and 151 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
is_scan,
)
from .utility import (
add_mapped_nested_sdfg,
as_dace_type,
connectivity_identifier,
create_memlet_at,
Expand Down Expand Up @@ -321,7 +322,7 @@ def visit_StencilClosure(
array_mapping = {**input_mapping, **conn_mapping}
symbol_mapping = map_nested_sdfg_symbols(closure_sdfg, nsdfg, array_mapping)

nsdfg_node, map_entry, map_exit = self._add_mapped_nested_sdfg(
nsdfg_node, map_entry, map_exit = add_mapped_nested_sdfg(
closure_state,
sdfg=nsdfg,
map_ranges=map_domain or {"__dummy": "0"},
Expand Down Expand Up @@ -584,76 +585,6 @@ def _visit_parallel_stencil_closure(

return context.body, map_domain, [r.value.data for r in results]

def _add_mapped_nested_sdfg(
self,
state: dace.SDFGState,
map_ranges: dict[str, str | dace.subsets.Subset]
| list[tuple[str, str | dace.subsets.Subset]],
inputs: dict[str, dace.Memlet],
outputs: dict[str, dace.Memlet],
sdfg: dace.SDFG,
symbol_mapping: dict[str, Any] | None = None,
schedule: Any = dace.dtypes.ScheduleType.Default,
unroll_map: bool = False,
location: Any = None,
debuginfo: Any = None,
input_nodes: dict[str, dace.nodes.AccessNode] | None = None,
output_nodes: dict[str, dace.nodes.AccessNode] | None = None,
) -> tuple[dace.nodes.NestedSDFG, dace.nodes.MapEntry, dace.nodes.MapExit]:
if not symbol_mapping:
symbol_mapping = {sym: sym for sym in sdfg.free_symbols}

nsdfg_node = state.add_nested_sdfg(
sdfg,
None,
set(inputs.keys()),
set(outputs.keys()),
symbol_mapping,
name=sdfg.name,
schedule=schedule,
location=location,
debuginfo=debuginfo,
)

map_entry, map_exit = state.add_map(
f"{sdfg.name}_map", map_ranges, schedule, unroll_map, debuginfo
)

if input_nodes is None:
input_nodes = {
memlet.data: state.add_access(memlet.data) for name, memlet in inputs.items()
}
if output_nodes is None:
output_nodes = {
memlet.data: state.add_access(memlet.data) for name, memlet in outputs.items()
}
if not inputs:
state.add_edge(map_entry, None, nsdfg_node, None, dace.Memlet())
for name, memlet in inputs.items():
state.add_memlet_path(
input_nodes[memlet.data],
map_entry,
nsdfg_node,
memlet=memlet,
src_conn=None,
dst_conn=name,
propagate=True,
)
if not outputs:
state.add_edge(nsdfg_node, None, map_exit, None, dace.Memlet())
for name, memlet in outputs.items():
state.add_memlet_path(
nsdfg_node,
map_exit,
output_nodes[memlet.data],
memlet=memlet,
src_conn=name,
dst_conn=None,
propagate=True,
)

return nsdfg_node, map_entry, map_exit

def _visit_domain(
self, node: itir.FunCall, context: Context
) -> tuple[tuple[str, tuple[ValueExpr, ValueExpr]], ...]:
Expand Down
Loading

0 comments on commit 3b3fc55

Please sign in to comment.