diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index f78d90095c..1c1bed9c5e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -11,29 +11,44 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later - -from typing import Any, Mapping, Sequence +import hashlib +from typing import Any, Mapping, Optional, Sequence import dace import numpy as np +from dace.codegen.compiled_sdfg import CompiledSDFG +from dace.transformation.auto import auto_optimize as autoopt import gt4py.next.iterator.ir as itir -from gt4py.next import common -from gt4py.next.iterator.embedded import NeighborTableOffsetProvider +from gt4py.next.common import Dimension, Domain, UnitRange, is_field +from gt4py.next.iterator.embedded import NeighborTableOffsetProvider, StridedNeighborOffsetProvider from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms from gt4py.next.otf.compilation import cache from gt4py.next.program_processors.processor_interface import program_executor -from gt4py.next.type_system import type_translation +from gt4py.next.type_system import type_specifications as ts, type_translation from .itir_to_sdfg import ItirToSDFG -from .utility import connectivity_identifier, filter_neighbor_tables +from .utility import connectivity_identifier, filter_neighbor_tables, get_sorted_dims + + +def get_sorted_dim_ranges(domain: Domain) -> Sequence[UnitRange]: + sorted_dims = get_sorted_dims(domain.dims) + return [domain.ranges[dim_index] for dim_index, _ in sorted_dims] + + +""" Default build configuration in DaCe backend """ +_build_type = "Release" +# removing -ffast-math from DaCe default compiler args in order to support isfinite/isinf/isnan built-ins +_cpu_args = ( + "-std=c++14 -fPIC -Wall -Wextra -O3 -march=native -Wno-unused-parameter -Wno-unused-label" +) def convert_arg(arg: Any): - if common.is_field(arg): - sorted_dims = sorted(enumerate(arg.__gt_dims__), key=lambda v: v[1].value) + if is_field(arg): + sorted_dims = get_sorted_dims(arg.domain.dims) ndim = len(sorted_dims) - dim_indices = [dim[0] for dim in sorted_dims] + dim_indices = [dim_index for dim_index, _ in sorted_dims] assert isinstance(arg.ndarray, np.ndarray) return np.moveaxis(arg.ndarray, range(ndim), dim_indices) return arg @@ -69,6 +84,17 @@ def get_shape_args( } +def get_offset_args( + arrays: Mapping[str, dace.data.Array], params: Sequence[itir.Sym], args: Sequence[Any] +) -> Mapping[str, int]: + return { + str(sym): -drange.start + for param, arg in zip(params, args) + if is_field(arg) + for sym, drange in zip(arrays[param.id].offset, get_sorted_dim_ranges(arg.domain)) + } + + def get_stride_args( arrays: Mapping[str, dace.data.Array], args: Mapping[str, Any] ) -> Mapping[str, int]: @@ -85,17 +111,89 @@ def get_stride_args( return stride_args +_build_cache_cpu: dict[str, CompiledSDFG] = {} +_build_cache_gpu: dict[str, CompiledSDFG] = {} + + +def get_cache_id( + program: itir.FencilDefinition, + arg_types: Sequence[ts.TypeSpec], + column_axis: Optional[Dimension], + offset_provider: Mapping[str, Any], +) -> str: + max_neighbors = [ + (k, v.max_neighbors) + for k, v in offset_provider.items() + if isinstance(v, (NeighborTableOffsetProvider, StridedNeighborOffsetProvider)) + ] + cache_id_args = [ + str(arg) + for arg in ( + program, + *arg_types, + column_axis, + 
*max_neighbors, + ) + ] + m = hashlib.sha256() + for s in cache_id_args: + m.update(s.encode()) + return m.hexdigest() + + @program_executor def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: + # build parameters + auto_optimize = kwargs.get("auto_optimize", False) + build_type = kwargs.get("build_type", "RelWithDebInfo") + run_on_gpu = kwargs.get("run_on_gpu", False) + build_cache = kwargs.get("build_cache", None) + # ITIR parameters column_axis = kwargs.get("column_axis", None) offset_provider = kwargs["offset_provider"] - neighbor_tables = filter_neighbor_tables(offset_provider) - program = preprocess_program(program, offset_provider) arg_types = [type_translation.from_value(arg) for arg in args] - sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis) - sdfg: dace.SDFG = sdfg_genenerator.visit(program) - sdfg.simplify() + neighbor_tables = filter_neighbor_tables(offset_provider) + + cache_id = get_cache_id(program, arg_types, column_axis, offset_provider) + if build_cache is not None and cache_id in build_cache: + # retrieve SDFG program from build cache + sdfg_program = build_cache[cache_id] + sdfg = sdfg_program.sdfg + else: + # visit ITIR and generate SDFG + program = preprocess_program(program, offset_provider) + sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis) + sdfg = sdfg_genenerator.visit(program) + sdfg.simplify() + + # set array storage for GPU execution + if run_on_gpu: + device = dace.DeviceType.GPU + sdfg._name = f"{sdfg.name}_gpu" + for _, _, array in sdfg.arrays_recursive(): + if not array.transient: + array.storage = dace.dtypes.StorageType.GPU_Global + else: + device = dace.DeviceType.CPU + + # run DaCe auto-optimization heuristics + if auto_optimize: + # TODO Investigate how symbol definitions improve autoopt transformations, + # in which case the cache table should take the symbols map into account. 
+ symbols: dict[str, int] = {} + sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols) + + # compile SDFG and retrieve SDFG program + sdfg.build_folder = cache._session_cache_dir_path / ".dacecache" + with dace.config.temporary_config(): + dace.config.Config.set("compiler", "build_type", value=build_type) + dace.config.Config.set("compiler", "cpu", "args", value=_cpu_args) + sdfg_program = sdfg.compile(validate=False) + + # store SDFG program in build cache + if build_cache is not None: + build_cache[cache_id] = sdfg_program dace_args = get_args(program.params, args) dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} @@ -103,9 +201,8 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: dace_shapes = get_shape_args(sdfg.arrays, dace_field_args) dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args) dace_strides = get_stride_args(sdfg.arrays, dace_field_args) - dace_conn_stirdes = get_stride_args(sdfg.arrays, dace_conn_args) - - sdfg.build_folder = cache._session_cache_dir_path / ".dacecache" + dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args) + dace_offsets = get_offset_args(sdfg.arrays, program.params, args) all_args = { **dace_args, @@ -113,16 +210,40 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs) -> None: **dace_shapes, **dace_conn_shapes, **dace_strides, - **dace_conn_stirdes, + **dace_conn_strides, + **dace_offsets, } expected_args = { key: value for key, value in all_args.items() if key in sdfg.signature_arglist(with_types=False) } + with dace.config.temporary_config(): dace.config.Config.set("compiler", "allow_view_arguments", value=True) - dace.config.Config.set("compiler", "build_type", value="Debug") - dace.config.Config.set("compiler", "cpu", "args", value="-O0") dace.config.Config.set("frontend", "check_args", value=True) - sdfg(**expected_args) + sdfg_program(**expected_args) + + +@program_executor +def run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: + run_dace_iterator( + program, + *args, + **kwargs, + build_cache=_build_cache_cpu, + build_type=_build_type, + run_on_gpu=False, + ) + + +@program_executor +def run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: + run_dace_iterator( + program, + *args, + **kwargs, + build_cache=_build_cache_gpu, + build_type=_build_type, + run_on_gpu=True, + ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index 56031d8555..7017815688 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -38,6 +38,7 @@ create_memlet_at, create_memlet_full, filter_neighbor_tables, + get_sorted_dims, map_nested_sdfg_symbols, unique_var_name, ) @@ -79,9 +80,10 @@ def get_scan_dim( - scan_dim_dtype: data type along the scan dimension """ output_type = cast(ts.FieldType, storage_types[output.id]) + sorted_dims = [dim for _, dim in get_sorted_dims(output_type.dims)] return ( column_axis.value, - output_type.dims.index(column_axis), + sorted_dims.index(column_axis), output_type.dtype, ) @@ -105,12 +107,17 @@ def __init__( self.offset_provider = offset_provider self.storage_types = {} - def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec): + def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset: bool = True): if isinstance(type_, ts.FieldType): shape = 
[dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] strides = [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] + offset = ( + [dace.symbol(unique_var_name()) for _ in range(len(type_.dims))] + if has_offset + else None + ) dtype = as_dace_type(type_.dtype) - sdfg.add_array(name, shape=shape, strides=strides, dtype=dtype) + sdfg.add_array(name, shape=shape, strides=strides, offset=offset, dtype=dtype) elif isinstance(type_, ts.ScalarType): sdfg.add_symbol(name, as_dace_type(type_)) else: @@ -134,7 +141,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): scalar_kind = type_translation.get_scalar_kind(table.table.dtype) local_dim = Dimension("ElementDim", kind=DimensionKind.LOCAL) type_ = ts.FieldType([table.origin_axis, local_dim], ts.ScalarType(scalar_kind)) - self.add_storage(program_sdfg, connectivity_identifier(offset), type_) + self.add_storage(program_sdfg, connectivity_identifier(offset), type_, has_offset=False) # Create a nested SDFG for all stencil closures. for closure in node.closures: @@ -246,7 +253,7 @@ def visit_StencilClosure( ) access = closure_init_state.add_access(out_name) value = ValueExpr(access, dtype) - memlet = create_memlet_at(out_name, ("0",)) + memlet = dace.Memlet.simple(out_name, "0") closure_init_state.add_edge(out_tasklet, "__result", access, None, memlet) program_arg_syms[name] = value else: @@ -274,7 +281,7 @@ def visit_StencilClosure( transient_to_arg_name_mapping[nsdfg_output_name] = output_name # scan operator should always be the first function call in a closure if is_scan(node.stencil): - nsdfg, map_domain, scan_dim_index = self._visit_scan_stencil_closure( + nsdfg, map_ranges, scan_dim_index = self._visit_scan_stencil_closure( node, closure_sdfg.arrays, closure_domain, nsdfg_output_name ) results = [nsdfg_output_name] @@ -285,8 +292,8 @@ def visit_StencilClosure( closure_sdfg.add_array( nsdfg_output_name, dtype=output_descriptor.dtype, - shape=(array_table[output_name].shape[scan_dim_index],), - strides=(array_table[output_name].strides[scan_dim_index],), + shape=(output_descriptor.shape[scan_dim_index],), + strides=(output_descriptor.strides[scan_dim_index],), transient=True, ) @@ -294,13 +301,13 @@ def visit_StencilClosure( output_name, tuple( f"i_{dim}" - if f"i_{dim}" in map_domain + if f"i_{dim}" in map_ranges else f"0:{output_descriptor.shape[scan_dim_index]}" for dim, _ in closure_domain ), ) else: - nsdfg, map_domain, results = self._visit_parallel_stencil_closure( + nsdfg, map_ranges, results = self._visit_parallel_stencil_closure( node, closure_sdfg.arrays, closure_domain ) assert len(results) == 1 @@ -313,7 +320,7 @@ def visit_StencilClosure( transient=True, ) - output_memlet = create_memlet_at(output_name, tuple(idx for idx in map_domain.keys())) + output_memlet = create_memlet_at(output_name, tuple(idx for idx in map_ranges.keys())) input_mapping = {param: arg for param, arg in zip(input_names, input_memlets)} output_mapping = {param: arg_memlet for param, arg_memlet in zip(results, [output_memlet])} @@ -325,7 +332,7 @@ def visit_StencilClosure( nsdfg_node, map_entry, map_exit = add_mapped_nested_sdfg( closure_state, sdfg=nsdfg, - map_ranges=map_domain or {"__dummy": "0"}, + map_ranges=map_ranges or {"__dummy": "0"}, inputs=array_mapping, outputs=output_mapping, symbol_mapping=symbol_mapping, @@ -341,10 +348,10 @@ def visit_StencilClosure( edge.src_conn, transient_access, None, - dace.Memlet(data=memlet.data, subset=output_subset), + dace.Memlet.simple(memlet.data, output_subset), ) - 
inner_memlet = dace.Memlet( - data=memlet.data, subset=output_subset, other_subset=memlet.subset + inner_memlet = dace.Memlet.simple( + memlet.data, output_subset, other_subset_str=memlet.subset ) closure_state.add_edge(transient_access, None, map_exit, edge.dst_conn, inner_memlet) closure_state.remove_edge(edge) @@ -360,7 +367,7 @@ def visit_StencilClosure( None, map_entry, b.value.data, - create_memlet_at(b.value.data, ("0",)), + dace.Memlet.simple(b.value.data, "0"), ) return closure_sdfg @@ -390,12 +397,12 @@ def _visit_scan_stencil_closure( connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] # find the scan dimension, same as output dimension, and exclude it from the map domain - map_domain = {} + map_ranges = {} for dim, (lb, ub) in closure_domain: lb_str = lb.value.data if isinstance(lb, ValueExpr) else lb.value ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value if not dim == scan_dim: - map_domain[f"i_{dim}"] = f"{lb_str}:{ub_str}" + map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}" else: scan_lb_str = lb_str scan_ub_str = ub_str @@ -481,29 +488,28 @@ def _visit_scan_stencil_closure( "__result", carry_node1, None, - dace.Memlet(data=f"{scan_carry_name}", subset="0"), + dace.Memlet.simple(scan_carry_name, "0"), ) carry_node2 = lambda_state.add_access(scan_carry_name) lambda_state.add_memlet_path( carry_node2, scan_inner_node, - memlet=dace.Memlet(data=f"{scan_carry_name}", subset="0"), + memlet=dace.Memlet.simple(scan_carry_name, "0"), src_conn=None, dst_conn=lambda_carry_name, ) # connect access nodes to lambda inputs for (inner_name, _), data_name in zip(lambda_inputs[1:], input_names): - data_subset = ( - ", ".join([f"i_{dim}" for dim, _ in closure_domain]) - if isinstance(self.storage_types[data_name], ts.FieldType) - else "0" - ) + if isinstance(self.storage_types[data_name], ts.FieldType): + memlet = create_memlet_at(data_name, tuple(f"i_{dim}" for dim, _ in closure_domain)) + else: + memlet = dace.Memlet.simple(data_name, "0") lambda_state.add_memlet_path( lambda_state.add_access(data_name), scan_inner_node, - memlet=dace.Memlet(data=f"{data_name}", subset=data_subset), + memlet=memlet, src_conn=None, dst_conn=inner_name, ) @@ -527,12 +533,13 @@ def _visit_scan_stencil_closure( data_name, shape=(array_table[node.output.id].shape[scan_dim_index],), strides=(array_table[node.output.id].strides[scan_dim_index],), + offset=(array_table[node.output.id].offset[scan_dim_index],), dtype=array_table[node.output.id].dtype, ) lambda_state.add_memlet_path( scan_inner_node, lambda_state.add_access(data_name), - memlet=dace.Memlet(data=data_name, subset=f"i_{scan_dim}"), + memlet=dace.Memlet.simple(data_name, f"i_{scan_dim}"), src_conn=lambda_connector.value.label, dst_conn=None, ) @@ -544,10 +551,10 @@ def _visit_scan_stencil_closure( lambda_update_state.add_memlet_path( result_node, carry_node3, - memlet=dace.Memlet(data=f"{output_names[0]}", subset=f"i_{scan_dim}", other_subset="0"), + memlet=dace.Memlet.simple(output_names[0], f"i_{scan_dim}", other_subset_str="0"), ) - return scan_sdfg, map_domain, scan_dim_index + return scan_sdfg, map_ranges, scan_dim_index def _visit_parallel_stencil_closure( self, @@ -562,11 +569,11 @@ def _visit_parallel_stencil_closure( conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] # find the scan dimension, same as output dimension, and exclude it from the map domain - map_domain = {} + map_ranges = {} for dim, (lb, ub) in closure_domain: lb_str = lb.value.data if isinstance(lb, 
ValueExpr) else lb.value ub_str = ub.value.data if isinstance(ub, ValueExpr) else ub.value - map_domain[f"i_{dim}"] = f"{lb_str}:{ub_str}" + map_ranges[f"i_{dim}"] = f"{lb_str}:{ub_str}" # Create an SDFG for the tasklet that computes a single item of the output domain. index_domain = {dim: f"i_{dim}" for dim, _ in closure_domain} @@ -583,7 +590,7 @@ def _visit_parallel_stencil_closure( self.node_types, ) - return context.body, map_domain, [r.value.data for r in results] + return context.body, map_ranges, [r.value.data for r in results] def _visit_domain( self, node: itir.FunCall, context: Context diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 2e7a598d9a..610698646a 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -23,7 +23,7 @@ from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols import gt4py.eve.codegen -from gt4py.next import Dimension, type_inference as next_typing +from gt4py.next import Dimension, StridedNeighborOffsetProvider, type_inference as next_typing from gt4py.next.iterator import ir as itir, type_inference as itir_typing from gt4py.next.iterator.embedded import NeighborTableOffsetProvider from gt4py.next.iterator.ir import FunCall, Lambda @@ -34,7 +34,6 @@ add_mapped_nested_sdfg, as_dace_type, connectivity_identifier, - create_memlet_at, create_memlet_full, filter_neighbor_tables, map_nested_sdfg_symbols, @@ -244,10 +243,8 @@ def builtin_neighbors( ) # select full shape only in the neighbor-axis dimension field_subset = [ - f"0:{sdfg.arrays[iterator.field.data].shape[idx]}" - if dim == table.neighbor_axis.value - else f"i_{dim}" - for idx, dim in enumerate(iterator.dimensions) + f"0:{shape}" if dim == table.neighbor_axis.value else f"i_{dim}" + for dim, shape in zip(sorted(iterator.dimensions), sdfg.arrays[iterator.field.data].shape) ] state.add_memlet_path( iterator.field, @@ -576,6 +573,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: return iterator args: list[ValueExpr] + sorted_dims = sorted(iterator.dimensions) if self.context.reduce_limit: # we are visiting a child node of reduction, so the neighbor index can be used for indirect addressing result_name = unique_var_name() @@ -595,9 +593,9 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: ) # if dim is not found in iterator indices, we take the neighbor index over the reduction domain - array_index = [ + flat_index = [ f"{iterator.indices[dim].data}_v" if dim in iterator.indices else index_name - for dim in sorted(iterator.dimensions) + for dim in sorted_dims ] args = [ValueExpr(iterator.field, iterator.dtype)] + [ ValueExpr(iterator.indices[dim], iterator.dtype) for dim in iterator.indices @@ -608,7 +606,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: name="deref", inputs=set(internals), outputs={"__result"}, - code=f"__result = {args[0].value.data}_v[{', '.join(array_index)}]", + code=f"__result = {args[0].value.data}_v[{', '.join(flat_index)}]", ) for arg, internal in zip(args, internals): @@ -630,12 +628,9 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: return [ValueExpr(value=result_access, dtype=iterator.dtype)] else: - sorted_index = sorted(iterator.indices.items(), key=lambda x: x[0]) - flat_index = [ - ValueExpr(x[1], iterator.dtype) for x in sorted_index if x[0] in 
iterator.dimensions + args = [ValueExpr(iterator.field, iterator.dtype)] + [ + ValueExpr(iterator.indices[dim], iterator.dtype) for dim in sorted_dims ] - - args = [ValueExpr(iterator.field, int), *flat_index] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]}[{', '.join(internals[1:])}]" return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref") @@ -702,18 +697,31 @@ def _visit_indirect_addressing(self, node: itir.FunCall) -> IteratorExpr: element = tail[1].value assert isinstance(element, int) - table: NeighborTableOffsetProvider = self.offset_provider[offset] - shifted_dim = table.origin_axis.value - target_dim = table.neighbor_axis.value + if isinstance(self.offset_provider[offset], NeighborTableOffsetProvider): + table = self.offset_provider[offset] + shifted_dim = table.origin_axis.value + target_dim = table.neighbor_axis.value - conn = self.context.state.add_access(connectivity_identifier(offset)) + conn = self.context.state.add_access(connectivity_identifier(offset)) + + args = [ + ValueExpr(conn, table.table.dtype), + ValueExpr(iterator.indices[shifted_dim], dace.int64), + ] + + internals = [f"{arg.value.data}_v" for arg in args] + expr = f"{internals[0]}[{internals[1]}, {element}]" + else: + offset_provider = self.offset_provider[offset] + assert isinstance(offset_provider, StridedNeighborOffsetProvider) + + shifted_dim = offset_provider.origin_axis.value + target_dim = offset_provider.neighbor_axis.value + offset_value = iterator.indices[shifted_dim] + args = [ValueExpr(offset_value, dace.int64)] + internals = [f"{offset_value.data}_v"] + expr = f"{internals[0]} * {offset_provider.max_neighbors} + {element}" - args = [ - ValueExpr(conn, table.table.dtype), - ValueExpr(iterator.indices[shifted_dim], dace.int64), - ] - internals = [f"{arg.value.data}_v" for arg in args] - expr = f"{internals[0]}[{internals[1]}, {element}]" shifted_value = self.add_expr_tasklet( list(zip(args, internals)), expr, dace.dtypes.int64, "ind_addr" )[0].value @@ -849,7 +857,7 @@ def _visit_reduce(self, node: itir.FunCall): p.apply_pass(lambda_context.body, {}) input_memlets = [ - create_memlet_at(expr.value.data, ("__idx",)) for arg, expr in zip(node.args, args) + dace.Memlet.simple(expr.value.data, "__idx") for arg, expr in zip(node.args, args) ] output_memlet = dace.Memlet.simple(result_name, "0") @@ -928,7 +936,7 @@ def add_expr_tasklet( ) self.context.state.add_edge(arg.value, None, expr_tasklet, internal, memlet) - memlet = create_memlet_at(result_access.data, ("0",)) + memlet = dace.Memlet.simple(result_access.data, "0") self.context.state.add_edge(expr_tasklet, "__result", result_access, None, memlet) return [ValueExpr(result_access, result_type)] diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 889a1ab150..7e6fe13ac7 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -12,10 +12,11 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any +from typing import Any, Sequence import dace +from gt4py.next import Dimension from gt4py.next.iterator.embedded import NeighborTableOffsetProvider from gt4py.next.type_system import type_specifications as ts @@ -49,7 +50,7 @@ def connectivity_identifier(name: str): def create_memlet_full(source_identifier: str, source_array: dace.data.Array): bounds = [(0, size) for size in 
source_array.shape] subset = ", ".join(f"{lb}:{ub}" for lb, ub in bounds) - return dace.Memlet(data=source_identifier, subset=subset) + return dace.Memlet.simple(source_identifier, subset) def create_memlet_at(source_identifier: str, index: tuple[str, ...]): @@ -57,6 +58,10 @@ def create_memlet_at(source_identifier: str, index: tuple[str, ...]): return dace.Memlet(data=source_identifier, subset=subset) +def get_sorted_dims(dims: Sequence[Dimension]) -> Sequence[tuple[int, Dimension]]: + return sorted(enumerate(dims), key=lambda v: v[1].value) + + def map_nested_sdfg_symbols( parent_sdfg: dace.SDFG, nested_sdfg: dace.SDFG, array_mapping: dict[str, dace.Memlet] ) -> dict[str, str]: diff --git a/tests/next_tests/exclusion_matrices.py b/tests/next_tests/exclusion_matrices.py index 27ccb29095..98ac9352c3 100644 --- a/tests/next_tests/exclusion_matrices.py +++ b/tests/next_tests/exclusion_matrices.py @@ -61,7 +61,6 @@ (USES_NEGATIVE_MODULO, XFAIL, UNSUPPORTED_MESSAGE), (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE), (USES_SCAN_IN_FIELD_OPERATOR, XFAIL, UNSUPPORTED_MESSAGE), - (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), ] #: Skip matrix, contains for each backend processor a list of tuples with following fields: @@ -81,11 +80,18 @@ (USES_TUPLE_RETURNS, XFAIL, UNSUPPORTED_MESSAGE), (USES_ZERO_DIMENSIONAL_FIELDS, XFAIL, UNSUPPORTED_MESSAGE), ], - GTFN_CPU: GTFN_SKIP_TEST_LIST, - GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST, + GTFN_CPU: GTFN_SKIP_TEST_LIST + + [ + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + ], + GTFN_CPU_IMPERATIVE: GTFN_SKIP_TEST_LIST + + [ + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), + ], GTFN_CPU_WITH_TEMPORARIES: GTFN_SKIP_TEST_LIST + [ (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, BINDINGS_UNSUPPORTED_MESSAGE), ], GTFN_FORMAT_SOURCECODE: [ (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE),
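
For reference, the sketch below illustrates the dimension-sorting and per-dimension offset convention that the new get_sorted_dims, get_sorted_dim_ranges and get_offset_args helpers rely on: field dimensions are put into a canonical alphabetical order, host buffers are transposed to match, and each field receives one offset per dimension equal to the negated start of its domain range so that domain indices address the buffer directly. This is a minimal, self-contained approximation in plain Python/NumPy; Dim, Range and get_offsets are simplified stand-ins for gt4py's Dimension, UnitRange and the real get_offset_args, not the actual APIs.

# Minimal sketch of the dimension-sorting and offset logic, using simplified
# stand-ins ("Dim", "Range") for gt4py's Dimension and UnitRange types.
from typing import NamedTuple, Sequence

import numpy as np


class Dim(NamedTuple):
    value: str  # dimension name, e.g. "Vertex" or "K"


class Range(NamedTuple):
    start: int
    stop: int


def get_sorted_dims(dims: Sequence[Dim]) -> list[tuple[int, Dim]]:
    # Order dimensions alphabetically by name, keeping their original position.
    return sorted(enumerate(dims), key=lambda v: v[1].value)


def convert_arg(buffer: np.ndarray, dims: Sequence[Dim]) -> np.ndarray:
    # Transpose the buffer so its axes follow the sorted dimension order,
    # mirroring the np.moveaxis call in convert_arg above.
    dim_indices = [dim_index for dim_index, _ in get_sorted_dims(dims)]
    return np.moveaxis(buffer, range(buffer.ndim), dim_indices)


def get_offsets(dims: Sequence[Dim], ranges: Sequence[Range]) -> list[int]:
    # One offset per sorted dimension: the negated start of its domain range
    # (hypothetical helper mirroring the "-drange.start" part of get_offset_args).
    return [-ranges[dim_index].start for dim_index, _ in get_sorted_dims(dims)]


field = np.arange(6).reshape(2, 3)     # axes in declaration order: (Vertex, K)
dims = [Dim("Vertex"), Dim("K")]
ranges = [Range(1, 3), Range(0, 3)]    # Vertex domain: [1, 3), K domain: [0, 3)

print(convert_arg(field, dims).shape)  # (3, 2): axes reordered to (K, Vertex)
print(get_offsets(dims, ranges))       # [0, -1]: offsets in (K, Vertex) order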