From de4de8e59607bb126bae29e26f5b773a46e3dfef Mon Sep 17 00:00:00 2001
From: Edoardo Paone
Date: Thu, 25 Jan 2024 10:55:59 +0100
Subject: [PATCH] Review comments

---
 .../runners/dace_iterator/__init__.py        |  9 +--
 .../runners/dace_iterator/itir_to_sdfg.py    | 58 ++++++++++++++-----
 .../runners/dace_iterator/itir_to_tasklet.py | 16 ++---
 .../runners/dace_iterator/utility.py         | 22 +------
 4 files changed, 55 insertions(+), 50 deletions(-)

diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
index c7ba883ade..bb2ac04933 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
@@ -14,6 +14,7 @@
 import hashlib
 import warnings
 from inspect import currentframe, getframeinfo
+from pathlib import Path
 from typing import Any, Mapping, Optional, Sequence

 import dace
@@ -329,12 +330,12 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
         sdfg = sdfg_program.sdfg
     else:
-        # debug: test large icon4py stencils without regenerating the SDFG at each run
-        generate_sdfg = True
+        # useful for debug: run gt4py program without regenerating the SDFG at each run
+        skip_itir_lowering_to_sdfg = False
         target = "gpu" if on_gpu else "cpu"
-        sdfg_filename = f"_dacegraphs/{target}/{program.id}.sdfg"
+        sdfg_filename = f"_dacegraphs/gt4py/{target}/{program.id}.sdfg"
-        if generate_sdfg:
+        if not (skip_itir_lowering_to_sdfg and Path(sdfg_filename).exists()):
             sdfg = build_sdfg_from_itir(
                 program,
                 *args,
diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py
index e82ac37cdf..ce1ac6073a 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py
@@ -11,7 +11,7 @@
 # distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
-from typing import Any, Mapping, Optional, cast
+from typing import Any, Mapping, Optional, Sequence, cast

 import dace
@@ -42,14 +42,13 @@
     filter_neighbor_tables,
     flatten_list,
     get_sorted_dims,
-    map_field_dimensions_to_sdfg_symbols,
     map_nested_sdfg_symbols,
     unique_name,
     unique_var_name,
 )


-def get_scan_args(stencil: Expr) -> tuple[bool, Literal]:
+def _get_scan_args(stencil: Expr) -> tuple[bool, Literal]:
     """
     Parse stencil expression to extract the scan arguments.

@@ -68,7 +67,7 @@ def get_scan_args(stencil: Expr) -> tuple[bool, Literal]:
     return is_forward.value == "True", init_carry


-def get_scan_dim(
+def _get_scan_dim(
     column_axis: Dimension,
     storage_types: dict[str, ts.TypeSpec],
     output: SymRef,
@@ -93,6 +92,35 @@ def get_scan_dim(
     )


+def _make_array_shape_and_strides(
+    name: str,
+    dims: Sequence[Dimension],
+    neighbor_tables: Mapping[str, NeighborTable],
+    sort_dims: bool,
+) -> tuple[list[dace.symbol], list[dace.symbol]]:
+    """
+    Parse field dimensions and allocate symbols for array shape and strides.
+
+    For local dimensions, the size is known at compile-time and therefore
+    the corresponding array shape dimension is set to an integer literal value.
+
+    Returns
+    -------
+    tuple(shape, strides)
+        The output tuple fields are arrays of dace symbolic expressions.
+ """ + dtype = dace.int64 + sorted_dims = [dim for _, dim in get_sorted_dims(dims)] if sort_dims else dims + shape = [ + neighbor_tables[dim.value].max_neighbors + if dim.kind == DimensionKind.LOCAL + else dace.symbol(unique_name(f"{name}_shape{i}"), dtype) + for i, dim in enumerate(sorted_dims) + ] + strides = [dace.symbol(unique_name(f"{name}_stride{i}"), dtype) for i, _ in enumerate(shape)] + return shape, strides + + class ItirToSDFG(eve.NodeVisitor): param_types: list[ts.TypeSpec] storage_types: dict[str, ts.TypeSpec] @@ -122,7 +150,7 @@ def add_storage( sort_dimensions: bool = True, ): if isinstance(type_, ts.FieldType): - shape, strides = map_field_dimensions_to_sdfg_symbols( + shape, strides = _make_array_shape_and_strides( name, type_.dims, neighbor_tables, sort_dimensions ) offset = ( @@ -166,10 +194,12 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): self.add_storage(program_sdfg, str(param.id), type_, neighbor_tables) # Add connectivities as SDFG storages. - for offset, table in neighbor_tables.items(): - scalar_kind = type_translation.get_scalar_kind(table.table.dtype) + for offset, offset_provider in neighbor_tables.items(): + scalar_kind = type_translation.get_scalar_kind(offset_provider.table.dtype) local_dim = Dimension(offset, kind=DimensionKind.LOCAL) - type_ = ts.FieldType([table.origin_axis, local_dim], ts.ScalarType(scalar_kind)) + type_ = ts.FieldType( + [offset_provider.origin_axis, local_dim], ts.ScalarType(scalar_kind) + ) self.add_storage( program_sdfg, connectivity_identifier(offset), @@ -239,9 +269,7 @@ def visit_StencilClosure( input_names = [str(inp.id) for inp in node.inputs] neighbor_tables = filter_neighbor_tables(self.offset_provider) - connectivity_names = [ - connectivity_identifier(offset) for offset, _ in neighbor_tables.items() - ] + connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()] output_nodes = self.get_output_nodes(node, closure_sdfg, closure_state) output_names = [k for k, _ in output_nodes.items()] @@ -419,11 +447,11 @@ def _visit_scan_stencil_closure( output_name: str, ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], int]: # extract scan arguments - is_forward, init_carry_value = get_scan_args(node.stencil) + is_forward, init_carry_value = _get_scan_args(node.stencil) # select the scan dimension based on program argument for column axis assert self.column_axis assert isinstance(node.output, SymRef) - scan_dim, scan_dim_index, scan_dtype = get_scan_dim( + scan_dim, scan_dim_index, scan_dtype = _get_scan_dim( self.column_axis, self.storage_types, node.output, @@ -589,9 +617,7 @@ def _visit_parallel_stencil_closure( ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], list[str]]: neighbor_tables = filter_neighbor_tables(self.offset_provider) input_names = [str(inp.id) for inp in node.inputs] - connectivity_names = [ - connectivity_identifier(offset) for offset, _ in neighbor_tables.items() - ] + connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()] # find the scan dimension, same as output dimension, and exclude it from the map domain map_ranges = {} diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 13e864e0ec..322a147382 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -382,11 +382,9 @@ def 
     result_name = unique_var_name()
     transformer.context.body.add_scalar(result_name, args[1].dtype, transient=True)
     result_node = transformer.context.state.add_access(result_name)
-    transformer.context.state.add_edge(
+    transformer.context.state.add_nedge(
         args[1].value,
-        None,
         result_node,
-        None,
         dace.Memlet.simple(args[1].value.data, index_value),
     )
     return [ValueExpr(result_node, args[1].dtype)]
@@ -576,9 +574,7 @@ def visit_Lambda(
         neighbor_tables = (
             filter_neighbor_tables(self.offset_provider) if use_neighbor_tables else {}
         )
-        connectivity_names = [
-            connectivity_identifier(offset) for offset, _ in neighbor_tables.items()
-        ]
+        connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()]

         # Create the SDFG for the lambda's body
         lambda_sdfg = dace.SDFG(func_name)
@@ -714,8 +710,8 @@ def _visit_call(self, node: itir.FunCall):
             nsdfg_inputs[var] = create_memlet_full(store, self.context.body.arrays[store])

         neighbor_tables = filter_neighbor_tables(self.offset_provider)
-        for conn, _ in neighbor_tables.items():
-            var = connectivity_identifier(conn)
+        for offset in neighbor_tables.keys():
+            var = connectivity_identifier(offset)
             nsdfg_inputs[var] = create_memlet_full(var, self.context.body.arrays[var])

         symbol_mapping = map_nested_sdfg_symbols(self.context.body, func_context.body, nsdfg_inputs)
@@ -743,8 +739,8 @@ def _visit_call(self, node: itir.FunCall):
                 store = value.indices[dim]
                 idx_memlet = nsdfg_inputs[var]
                 self.context.state.add_edge(store, None, nsdfg_node, var, idx_memlet)
-        for conn, _ in neighbor_tables.items():
-            var = connectivity_identifier(conn)
+        for offset in neighbor_tables.keys():
+            var = connectivity_identifier(offset)
             memlet = nsdfg_inputs[var]
             access = self.context.state.add_access(var, debuginfo=nsdfg_node.debuginfo)
             self.context.state.add_edge(access, None, nsdfg_node, var, memlet)
diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py
index f6435cfb14..a66fc36b1b 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py
@@ -12,11 +12,11 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 import itertools
-from typing import Any, Mapping, Optional, Sequence
+from typing import Any, Optional, Sequence

 import dace

-from gt4py.next import Dimension, DimensionKind
+from gt4py.next import Dimension
 from gt4py.next.common import NeighborTable
 from gt4py.next.iterator.ir import Node
 from gt4py.next.type_system import type_specifications as ts
@@ -181,24 +181,6 @@ def unique_var_name():
     return unique_name("_var")


-def map_field_dimensions_to_sdfg_symbols(
-    name: str,
-    dims: Sequence[Dimension],
-    neighbor_tables: Mapping[str, NeighborTable],
-    sort_dims: bool,
-) -> tuple[list[dace.symbol], list[dace.symbol]]:
-    dtype = dace.int64
-    sorted_dims = [dim for _, dim in get_sorted_dims(dims)] if sort_dims else dims
-    shape = [
-        neighbor_tables[dim.value].max_neighbors
-        if dim.kind == DimensionKind.LOCAL
-        else dace.symbol(unique_name(f"{name}_shape{i}"), dtype)
-        for i, dim in enumerate(sorted_dims)
-    ]
-    strides = [dace.symbol(unique_name(f"{name}_stride{i}"), dtype) for i, _ in enumerate(shape)]
-    return shape, strides
-
-
 def new_array_symbols(name: str, ndim: int) -> tuple[list[dace.symbol], list[dace.symbol]]:
     dtype = dace.int64
     shape = [dace.symbol(unique_name(f"{name}_shape{i}"), dtype) for i in range(ndim)]
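
A minimal sketch of the SDFG caching logic introduced in run_dace_iterator above, for reviewers who want to try the new debug switch; program_id, target, and the switch value are assumed placeholders standing in for program.id and the on_gpu selection in the real code:

from pathlib import Path

# Assumed placeholder values; in run_dace_iterator these come from program.id and on_gpu.
program_id = "example_program"
target = "cpu"  # "gpu" if on_gpu else "cpu"
skip_itir_lowering_to_sdfg = True  # debug switch; the patch defaults it to False

sdfg_filename = f"_dacegraphs/gt4py/{target}/{program_id}.sdfg"

if not (skip_itir_lowering_to_sdfg and Path(sdfg_filename).exists()):
    # Normal path: lower the ITIR program to an SDFG (build_sdfg_from_itir in the real code).
    print(f"(re)generating SDFG -> {sdfg_filename}")
else:
    # Debug path: reuse the SDFG dumped by a previous run and skip ITIR lowering.
    print(f"loading cached SDFG from {sdfg_filename}")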