diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
index 54ca08fe6e..a039d311ca 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
@@ -14,6 +14,7 @@
 import hashlib
 import warnings
 from inspect import currentframe, getframeinfo
+from pathlib import Path
 from typing import Any, Mapping, Optional, Sequence
 
 import dace
@@ -26,7 +27,7 @@
 import gt4py.next.program_processors.otf_compile_executor as otf_exec
 import gt4py.next.program_processors.processor_interface as ppi
 from gt4py.next import common
-from gt4py.next.iterator import embedded as itir_embedded, transforms as itir_transforms
+from gt4py.next.iterator import transforms as itir_transforms
 from gt4py.next.otf.compilation import cache as compilation_cache
 from gt4py.next.type_system import type_specifications as ts, type_translation
 
@@ -109,23 +110,29 @@ def _ensure_is_on_device(
 
 
 def get_connectivity_args(
-    neighbor_tables: Sequence[tuple[str, itir_embedded.NeighborTableOffsetProvider]],
+    neighbor_tables: Mapping[str, common.NeighborTable],
     device: dace.dtypes.DeviceType,
 ) -> dict[str, Any]:
     return {
-        connectivity_identifier(offset): _ensure_is_on_device(table.table, device)
-        for offset, table in neighbor_tables
+        connectivity_identifier(offset): _ensure_is_on_device(offset_provider.table, device)
+        for offset, offset_provider in neighbor_tables.items()
     }
 
 
 def get_shape_args(
     arrays: Mapping[str, dace.data.Array], args: Mapping[str, Any]
 ) -> Mapping[str, int]:
-    return {
-        str(sym): size
-        for name, value in args.items()
-        for sym, size in zip(arrays[name].shape, value.shape)
-    }
+    shape_args: dict[str, int] = {}
+    for name, value in args.items():
+        for sym, size in zip(arrays[name].shape, value.shape):
+            if isinstance(sym, dace.symbol):
+                assert sym.name not in shape_args
+                shape_args[sym.name] = size
+            elif sym != size:
+                raise RuntimeError(
+                    f"Expected shape {arrays[name].shape} for arg {name}, got {value.shape}."
+                )
+    return shape_args
 
 
 def get_offset_args(
@@ -158,34 +165,41 @@ def get_stride_args(
     return stride_args
 
 
-_build_cache_cpu: dict[str, CompiledSDFG] = {}
-_build_cache_gpu: dict[str, CompiledSDFG] = {}
+_build_cache: dict[str, CompiledSDFG] = {}
 
 
 def get_cache_id(
+    build_type: str,
+    build_for_gpu: bool,
     program: itir.FencilDefinition,
     arg_types: Sequence[ts.TypeSpec],
     column_axis: Optional[common.Dimension],
     offset_provider: Mapping[str, Any],
 ) -> str:
-    max_neighbors = [
-        (k, v.max_neighbors)
-        for k, v in offset_provider.items()
-        if isinstance(
-            v,
-            (
-                itir_embedded.NeighborTableOffsetProvider,
-                itir_embedded.StridedNeighborOffsetProvider,
-            ),
-        )
+    def offset_invariants(offset):
+        if isinstance(offset, common.Connectivity):
+            return (
+                offset.origin_axis,
+                offset.neighbor_axis,
+                offset.has_skip_values,
+                offset.max_neighbors,
+            )
+        if isinstance(offset, common.Dimension):
+            return (offset,)
+        return tuple()
+
+    offset_cache_keys = [
+        (name, *offset_invariants(offset)) for name, offset in offset_provider.items()
     ]
     cache_id_args = [
         str(arg)
         for arg in (
+            build_type,
+            build_for_gpu,
             program,
             *arg_types,
             column_axis,
-            *max_neighbors,
+            *offset_cache_keys,
        )
     ]
     m = hashlib.sha256()
@@ -262,7 +276,7 @@ def build_sdfg_from_itir(
     # visit ITIR and generate SDFG
     program = preprocess_program(program, offset_provider, lift_mode)
     sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis)
-    sdfg = sdfg_genenerator.visit(program)
+    sdfg: dace.SDFG = sdfg_genenerator.visit(program)
     if sdfg is None:
         raise RuntimeError(f"Visit failed for program {program.id}.")
 
@@ -284,8 +298,8 @@ def build_sdfg_from_itir(
 
     # run DaCe auto-optimization heuristics
     if auto_optimize:
-        # TODO: Investigate how symbol definitions improve autoopt transformations,
-        # in which case the cache table should take the symbols map into account.
+        # TODO: Investigate performance improvement from SDFG specialization with constant symbols,
+        # for array shape and strides, although this would imply JIT compilation.
         symbols: dict[str, int] = {}
         device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU
         sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu)
@@ -307,25 +321,31 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
 
     # ITIR parameters
     column_axis = kwargs.get("column_axis", None)
     offset_provider = kwargs["offset_provider"]
+    # debug option to store SDFGs on filesystem and skip lowering ITIR to SDFG at each run
+    skip_itir_lowering_to_sdfg = kwargs.get("skip_itir_lowering_to_sdfg", False)
 
     arg_types = [type_translation.from_value(arg) for arg in args]
 
-    cache_id = get_cache_id(program, arg_types, column_axis, offset_provider)
+    cache_id = get_cache_id(build_type, on_gpu, program, arg_types, column_axis, offset_provider)
     if build_cache is not None and cache_id in build_cache:
         # retrieve SDFG program from build cache
         sdfg_program = build_cache[cache_id]
         sdfg = sdfg_program.sdfg
-    else:
-        sdfg = build_sdfg_from_itir(
-            program,
-            *args,
-            offset_provider=offset_provider,
-            auto_optimize=auto_optimize,
-            on_gpu=on_gpu,
-            column_axis=column_axis,
-            lift_mode=lift_mode,
-        )
+    sdfg_filename = f"_dacegraphs/gt4py/{cache_id}/{program.id}.sdfg"
+    if not (skip_itir_lowering_to_sdfg and Path(sdfg_filename).exists()):
+        sdfg = build_sdfg_from_itir(
+            program,
+            *args,
+            offset_provider=offset_provider,
+            auto_optimize=auto_optimize,
+            on_gpu=on_gpu,
+            column_axis=column_axis,
+            lift_mode=lift_mode,
+        )
+        sdfg.save(sdfg_filename)
+    else:
+        sdfg = dace.SDFG.from_file(sdfg_filename)
 
     sdfg.build_folder = compilation_cache._session_cache_dir_path / ".dacecache"
     with dace.config.temporary_config():
@@ -361,7 +381,7 @@ def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
         program,
         *args,
         **kwargs,
-        build_cache=_build_cache_cpu,
+        build_cache=_build_cache,
         build_type=_build_type,
         compiler_args=compiler_args,
         on_gpu=False,
@@ -380,7 +400,7 @@ def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
         program,
         *args,
         **kwargs,
-        build_cache=_build_cache_gpu,
+        build_cache=_build_cache,
         build_type=_build_type,
         on_gpu=True,
     )
diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py
index dc194c0436..ce1ac6073a 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py
@@ -11,14 +11,14 @@
 # distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
-from typing import Any, Optional, cast
+from typing import Any, Mapping, Optional, Sequence, cast
 
 import dace
 
 import gt4py.eve as eve
 from gt4py.next import Dimension, DimensionKind, type_inference as next_typing
+from gt4py.next.common import NeighborTable
 from gt4py.next.iterator import ir as itir, type_inference as itir_typing
-from gt4py.next.iterator.embedded import NeighborTableOffsetProvider
 from gt4py.next.iterator.ir import Expr, FunCall, Literal, SymRef
 from gt4py.next.type_system import type_specifications as ts, type_translation
 
@@ -43,13 +43,12 @@
     flatten_list,
     get_sorted_dims,
     map_nested_sdfg_symbols,
-    new_array_symbols,
     unique_name,
     unique_var_name,
 )
 
 
-def get_scan_args(stencil: Expr) -> tuple[bool, Literal]:
+def _get_scan_args(stencil: Expr) -> tuple[bool, Literal]:
     """
     Parse stencil expression to extract the scan arguments.
 
@@ -68,7 +67,7 @@ def get_scan_args(stencil: Expr) -> tuple[bool, Literal]:
     return is_forward.value == "True", init_carry
 
 
-def get_scan_dim(
+def _get_scan_dim(
     column_axis: Dimension,
     storage_types: dict[str, ts.TypeSpec],
     output: SymRef,
@@ -93,6 +92,35 @@ def get_scan_dim(
     )
 
 
+def _make_array_shape_and_strides(
+    name: str,
+    dims: Sequence[Dimension],
+    neighbor_tables: Mapping[str, NeighborTable],
+    sort_dims: bool,
+) -> tuple[list[dace.symbol], list[dace.symbol]]:
+    """
+    Parse field dimensions and allocate symbols for array shape and strides.
+
+    For local dimensions, the size is known at compile-time and therefore
+    the corresponding array shape dimension is set to an integer literal value.
+
+    Returns
+    -------
+    tuple(shape, strides)
+        The output tuple fields are arrays of dace symbolic expressions.
+    """
+    dtype = dace.int64
+    sorted_dims = [dim for _, dim in get_sorted_dims(dims)] if sort_dims else dims
+    shape = [
+        neighbor_tables[dim.value].max_neighbors
+        if dim.kind == DimensionKind.LOCAL
+        else dace.symbol(unique_name(f"{name}_shape{i}"), dtype)
+        for i, dim in enumerate(sorted_dims)
+    ]
+    strides = [dace.symbol(unique_name(f"{name}_stride{i}"), dtype) for i, _ in enumerate(shape)]
+    return shape, strides
+
+
 class ItirToSDFG(eve.NodeVisitor):
     param_types: list[ts.TypeSpec]
     storage_types: dict[str, ts.TypeSpec]
@@ -104,7 +132,7 @@ class ItirToSDFG(eve.NodeVisitor):
     def __init__(
         self,
         param_types: list[ts.TypeSpec],
-        offset_provider: dict[str, NeighborTableOffsetProvider],
+        offset_provider: dict[str, NeighborTable],
         column_axis: Optional[Dimension] = None,
     ):
         self.param_types = param_types
@@ -112,9 +140,19 @@ def __init__(
         self.offset_provider = offset_provider
         self.storage_types = {}
 
-    def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset: bool = True):
+    def add_storage(
+        self,
+        sdfg: dace.SDFG,
+        name: str,
+        type_: ts.TypeSpec,
+        neighbor_tables: Mapping[str, NeighborTable],
+        has_offset: bool = True,
+        sort_dimensions: bool = True,
+    ):
         if isinstance(type_, ts.FieldType):
-            shape, strides = new_array_symbols(name, len(type_.dims))
+            shape, strides = _make_array_shape_and_strides(
+                name, type_.dims, neighbor_tables, sort_dimensions
+            )
             offset = (
                 [dace.symbol(unique_name(f"{name}_offset{i}_")) for i in range(len(type_.dims))]
                 if has_offset
@@ -153,14 +191,23 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
 
         # Add program parameters as SDFG storages.
         for param, type_ in zip(node.params, self.param_types):
-            self.add_storage(program_sdfg, str(param.id), type_)
+            self.add_storage(program_sdfg, str(param.id), type_, neighbor_tables)
 
         # Add connectivities as SDFG storages.
-        for offset, table in neighbor_tables:
-            scalar_kind = type_translation.get_scalar_kind(table.table.dtype)
-            local_dim = Dimension("ElementDim", kind=DimensionKind.LOCAL)
-            type_ = ts.FieldType([table.origin_axis, local_dim], ts.ScalarType(scalar_kind))
-            self.add_storage(program_sdfg, connectivity_identifier(offset), type_, has_offset=False)
+        for offset, offset_provider in neighbor_tables.items():
+            scalar_kind = type_translation.get_scalar_kind(offset_provider.table.dtype)
+            local_dim = Dimension(offset, kind=DimensionKind.LOCAL)
+            type_ = ts.FieldType(
+                [offset_provider.origin_axis, local_dim], ts.ScalarType(scalar_kind)
+            )
+            self.add_storage(
+                program_sdfg,
+                connectivity_identifier(offset),
+                type_,
+                neighbor_tables,
+                has_offset=False,
+                sort_dimensions=False,
+            )
 
         # Create a nested SDFG for all stencil closures.
         for closure in node.closures:
@@ -222,7 +269,7 @@ def visit_StencilClosure(
 
         input_names = [str(inp.id) for inp in node.inputs]
         neighbor_tables = filter_neighbor_tables(self.offset_provider)
-        connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables]
+        connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()]
 
         output_nodes = self.get_output_nodes(node, closure_sdfg, closure_state)
         output_names = [k for k, _ in output_nodes.items()]
@@ -400,11 +447,11 @@ def _visit_scan_stencil_closure(
         output_name: str,
     ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], int]:
         # extract scan arguments
-        is_forward, init_carry_value = get_scan_args(node.stencil)
+        is_forward, init_carry_value = _get_scan_args(node.stencil)
         # select the scan dimension based on program argument for column axis
         assert self.column_axis
         assert isinstance(node.output, SymRef)
-        scan_dim, scan_dim_index, scan_dtype = get_scan_dim(
+        scan_dim, scan_dim_index, scan_dtype = _get_scan_dim(
             self.column_axis,
             self.storage_types,
             node.output,
@@ -570,7 +617,7 @@ def _visit_parallel_stencil_closure(
     ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], list[str]]:
         neighbor_tables = filter_neighbor_tables(self.offset_provider)
         input_names = [str(inp.id) for inp in node.inputs]
-        conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables]
+        connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()]
 
         # find the scan dimension, same as output dimension, and exclude it from the map domain
         map_ranges = {}
@@ -583,7 +630,7 @@ def _visit_parallel_stencil_closure(
         index_domain = {dim: f"i_{dim}" for dim, _ in closure_domain}
 
         input_arrays = [(name, self.storage_types[name]) for name in input_names]
-        connectivity_arrays = [(array_table[name], name) for name in conn_names]
+        connectivity_arrays = [(array_table[name], name) for name in connectivity_names]
 
         context, results = closure_to_tasklet_sdfg(
             node,
diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py
index 0ace6948b0..322a147382 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py
@@ -237,7 +237,12 @@ def builtin_neighbors(
     )
     data_access_tasklet = state.add_tasklet(
         "data_access",
-        code=f"__result = __field[{field_index}] if {neighbor_check} else {transformer.context.reduce_identity.value}",
+        code=f"__result = __field[{field_index}]"
+        + (
+            f" if {neighbor_check} else {transformer.context.reduce_identity.value}"
+            if offset_provider.has_skip_values
+            else ""
+        ),
         inputs={"__field", field_index},
         outputs={"__result"},
         debuginfo=di,
@@ -372,20 +377,25 @@ def builtin_list_get(
     args = list(itertools.chain(*transformer.visit(node_args)))
     assert len(args) == 2
     # index node
-    assert isinstance(args[0], (SymbolExpr, ValueExpr))
-    # 1D-array node
-    assert isinstance(args[1], ValueExpr)
-    # source node should be a 1D array
-    assert len(transformer.context.body.arrays[args[1].value.data].shape) == 1
-
-    expr_args = [(arg, f"{arg.value.data}_v") for arg in args if not isinstance(arg, SymbolExpr)]
-    internals = [
-        arg.value if isinstance(arg, SymbolExpr) else f"{arg.value.data}_v" for arg in args
-    ]
-    expr = f"{internals[1]}[{internals[0]}]"
-    return transformer.add_expr_tasklet(
-        expr_args, expr, args[1].dtype, "list_get", dace_debuginfo=di
-    )
+    if isinstance(args[0], SymbolExpr):
+        index_value = args[0].value
+        result_name = unique_var_name()
+        transformer.context.body.add_scalar(result_name, args[1].dtype, transient=True)
+        result_node = transformer.context.state.add_access(result_name)
+        transformer.context.state.add_nedge(
+            args[1].value,
+            result_node,
+            dace.Memlet.simple(args[1].value.data, index_value),
+        )
+        return [ValueExpr(result_node, args[1].dtype)]
+
+    else:
+        expr_args = [(arg, f"{arg.value.data}_v") for arg in args]
+        internals = [f"{arg.value.data}_v" for arg in args]
+        expr = f"{internals[1]}[{internals[0]}]"
+        return transformer.add_expr_tasklet(
+            expr_args, expr, args[1].dtype, "list_get", dace_debuginfo=di
+        )
 
 
 def builtin_cast(
@@ -562,9 +572,9 @@ def visit_Lambda(
     ]:
         func_name = f"lambda_{abs(hash(node)):x}"
         neighbor_tables = (
-            filter_neighbor_tables(self.offset_provider) if use_neighbor_tables else []
+            filter_neighbor_tables(self.offset_provider) if use_neighbor_tables else {}
        )
-        connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables]
+        connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()]
 
         # Create the SDFG for the lambda's body
         lambda_sdfg = dace.SDFG(func_name)
@@ -700,8 +710,8 @@ def _visit_call(self, node: itir.FunCall):
             nsdfg_inputs[var] = create_memlet_full(store, self.context.body.arrays[store])
 
         neighbor_tables = filter_neighbor_tables(self.offset_provider)
-        for conn, _ in neighbor_tables:
-            var = connectivity_identifier(conn)
+        for offset in neighbor_tables.keys():
+            var = connectivity_identifier(offset)
             nsdfg_inputs[var] = create_memlet_full(var, self.context.body.arrays[var])
 
         symbol_mapping = map_nested_sdfg_symbols(self.context.body, func_context.body, nsdfg_inputs)
@@ -729,8 +739,8 @@ def _visit_call(self, node: itir.FunCall):
                 store = value.indices[dim]
                 idx_memlet = nsdfg_inputs[var]
                 self.context.state.add_edge(store, None, nsdfg_node, var, idx_memlet)
-        for conn, _ in neighbor_tables:
-            var = connectivity_identifier(conn)
+        for offset in neighbor_tables.keys():
+            var = connectivity_identifier(offset)
             memlet = nsdfg_inputs[var]
             access = self.context.state.add_access(var, debuginfo=nsdfg_node.debuginfo)
             self.context.state.add_edge(access, None, nsdfg_node, var, memlet)
diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py
index 971c1bbdf2..a66fc36b1b 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py
@@ -17,7 +17,7 @@
 import dace
 
 from gt4py.next import Dimension
-from gt4py.next.iterator.embedded import NeighborTableOffsetProvider
+from gt4py.next.common import NeighborTable
 from gt4py.next.iterator.ir import Node
 from gt4py.next.type_system import type_specifications as ts
 
@@ -52,11 +52,11 @@ def as_dace_type(type_: ts.ScalarType):
 
 
 def filter_neighbor_tables(offset_provider: dict[str, Any]):
-    return [
-        (offset, table)
+    return {
+        offset: table
         for offset, table in offset_provider.items()
-        if isinstance(table, NeighborTableOffsetProvider)
-    ]
+        if isinstance(table, NeighborTable)
+    }
 
 
 def connectivity_identifier(name: str):