From 0650d77b594e89ea092784593d3eaa7559e4fa51 Mon Sep 17 00:00:00 2001 From: edopao Date: Mon, 30 Oct 2023 14:44:50 +0100 Subject: [PATCH] feat[next]: Extend DaCe support for offset providers (#1353) Extend support in DaCe backend for offset providers, in order to generate the tasklet code in case of shift expressions with both direct and indirect addressing. Visitors for different types of addressing are merged in one unified visit_shift method. --- .../runners/dace_iterator/itir_to_sdfg.py | 1 - .../runners/dace_iterator/itir_to_tasklet.py | 114 ++++++++---------- 2 files changed, 47 insertions(+), 68 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 580486aa4a..1f9692356e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -196,7 +196,6 @@ def visit_StencilClosure( self, node: itir.StencilClosure, array_table: dict[str, dace.data.Array] ) -> tuple[dace.SDFG, list[str], list[str]]: assert ItirToSDFG._check_no_lifts(node) - assert ItirToSDFG._check_shift_offsets_are_literals(node) # Create the closure's nested SDFG and single state. closure_sdfg = dace.SDFG(name="closure") diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index b28703feef..1634596afa 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -478,17 +478,7 @@ def visit_FunCall(self, node: itir.FunCall) -> list[ValueExpr] | IteratorExpr: return self._visit_deref(node) if isinstance(node.fun, itir.FunCall) and isinstance(node.fun.fun, itir.SymRef): if node.fun.fun.id == "shift": - offset = node.fun.args[0] - assert isinstance(offset, itir.OffsetLiteral) - offset_name = offset.value - assert isinstance(offset_name, str) - if offset_name not in self.offset_provider: - raise ValueError(f"offset provider for `{offset_name}` is missing") - offset_provider = self.offset_provider[offset_name] - if isinstance(offset_provider, Dimension): - return self._visit_direct_addressing(node) - else: - return self._visit_indirect_addressing(node) + return self._visit_shift(node) elif node.fun.fun.id == "reduce": return self._visit_reduce(node) @@ -653,39 +643,7 @@ def _make_shift_for_rest(self, rest, iterator): fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=rest), args=[iterator] ) - def _visit_direct_addressing(self, node: itir.FunCall) -> IteratorExpr: - assert isinstance(node.fun, itir.FunCall) - shift = node.fun - assert isinstance(shift, itir.FunCall) - - tail, rest = self._split_shift_args(shift.args) - if rest: - iterator = self.visit(self._make_shift_for_rest(rest, node.args[0])) - else: - iterator = self.visit(node.args[0]) - - assert isinstance(tail[0], itir.OffsetLiteral) - offset = tail[0].value - assert isinstance(offset, str) - shifted_dim = self.offset_provider[offset].value - - assert isinstance(tail[1], itir.OffsetLiteral) - shift_amount = tail[1].value - assert isinstance(shift_amount, int) - - args = [ValueExpr(iterator.indices[shifted_dim], dace.int64)] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]} + {shift_amount}" - shifted_value = self.add_expr_tasklet( - list(zip(args, internals)), expr, dace.dtypes.int64, "dir_addr" - )[0].value - - shifted_index = {dim: value for dim, value in iterator.indices.items()} - shifted_index[shifted_dim] = shifted_value - - return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions) - - def _visit_indirect_addressing(self, node: itir.FunCall) -> IteratorExpr: + def _visit_shift(self, node: itir.FunCall) -> IteratorExpr: shift = node.fun assert isinstance(shift, itir.FunCall) tail, rest = self._split_shift_args(shift.args) @@ -695,40 +653,48 @@ def _visit_indirect_addressing(self, node: itir.FunCall) -> IteratorExpr: iterator = self.visit(node.args[0]) assert isinstance(tail[0], itir.OffsetLiteral) - offset = tail[0].value - assert isinstance(offset, str) + offset_dim = tail[0].value + assert isinstance(offset_dim, str) + offset_node = self.visit(tail[1])[0] - assert isinstance(tail[1], itir.OffsetLiteral) - element = tail[1].value - assert isinstance(element, int) - - if isinstance(self.offset_provider[offset], NeighborTableOffsetProvider): - table = self.offset_provider[offset] - shifted_dim = table.origin_axis.value - target_dim = table.neighbor_axis.value - - conn = self.context.state.add_access(connectivity_identifier(offset)) + if isinstance(self.offset_provider[offset_dim], NeighborTableOffsetProvider): + offset_provider = self.offset_provider[offset_dim] + connectivity = self.context.state.add_access(connectivity_identifier(offset_dim)) + shifted_dim = offset_provider.origin_axis.value + target_dim = offset_provider.neighbor_axis.value args = [ - ValueExpr(conn, table.table.dtype), + ValueExpr(connectivity, offset_provider.table.dtype), ValueExpr(iterator.indices[shifted_dim], dace.int64), + offset_node, ] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]}[{internals[1]}, {element}]" - else: - offset_provider = self.offset_provider[offset] - assert isinstance(offset_provider, StridedNeighborOffsetProvider) + expr = f"{internals[0]}[{internals[1]}, {internals[2]}]" + elif isinstance(self.offset_provider[offset_dim], StridedNeighborOffsetProvider): + offset_provider = self.offset_provider[offset_dim] shifted_dim = offset_provider.origin_axis.value target_dim = offset_provider.neighbor_axis.value - offset_value = iterator.indices[shifted_dim] - args = [ValueExpr(offset_value, dace.int64)] - internals = [f"{offset_value.data}_v"] - expr = f"{internals[0]} * {offset_provider.max_neighbors} + {element}" + args = [ + ValueExpr(iterator.indices[shifted_dim], dace.int64), + offset_node, + ] + internals = [f"{arg.value.data}_v" for arg in args] + expr = f"{internals[0]} * {offset_provider.max_neighbors} + {internals[1]}" + else: + assert isinstance(self.offset_provider[offset_dim], Dimension) + + shifted_dim = self.offset_provider[offset_dim].value + target_dim = shifted_dim + args = [ + ValueExpr(iterator.indices[shifted_dim], dace.int64), + offset_node, + ] + internals = [f"{arg.value.data}_v" for arg in args] + expr = f"{internals[0]} + {internals[1]}" shifted_value = self.add_expr_tasklet( - list(zip(args, internals)), expr, dace.dtypes.int64, "ind_addr" + list(zip(args, internals)), expr, dace.dtypes.int64, "shift" )[0].value shifted_index = {dim: value for dim, value in iterator.indices.items()} @@ -737,6 +703,20 @@ def _visit_indirect_addressing(self, node: itir.FunCall) -> IteratorExpr: return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions) + def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]: + offset = node.value + assert isinstance(offset, int) + offset_var = unique_var_name() + self.context.body.add_scalar(offset_var, dace.dtypes.int64, transient=True) + offset_node = self.context.state.add_access(offset_var) + tasklet_node = self.context.state.add_tasklet( + "get_offset", {}, {"__out"}, f"__out = {offset}" + ) + self.context.state.add_edge( + tasklet_node, "__out", offset_node, None, dace.Memlet.simple(offset_var, "0") + ) + return [ValueExpr(offset_node, self.context.body.arrays[offset_var].dtype)] + def _visit_reduce(self, node: itir.FunCall): result_name = unique_var_name() result_access = self.context.state.add_access(result_name)