From e573336b65f080df9451342a411b6cf548110781 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 18 Dec 2023 09:29:35 +0100 Subject: [PATCH 1/2] Added a function to get the arguments to call an SDFG. This commit adds a function that allows to generate the arguments needed to call an SDFG, before this was part of `run_dace_iterator()`. This made it very complex to run an SDFG outside this function. One should consider this as an amend to [PR #1379](https://github.com/GridTools/gt4py/pull/1379). --- .../runners/dace_iterator/__init__.py | 87 ++++++++++++------- 1 file changed, 56 insertions(+), 31 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 037c4f3e4d..e1952e075d 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -13,7 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import hashlib import warnings -from typing import Any, Mapping, Optional, Sequence +from typing import Any, Mapping, Optional, Sequence, Union import dace import numpy as np @@ -94,8 +94,12 @@ def preprocess_program( return fencil_definition -def get_args(params: Sequence[itir.Sym], args: Sequence[Any]) -> dict[str, Any]: - return {name.id: convert_arg(arg) for name, arg in zip(params, args)} +def get_args(params: Union[Sequence[itir.Sym], dace.SDFG], args: Sequence[Any]) -> dict[str, Any]: + if isinstance(params, dace.SDFG): + params = params.arg_names + else: + params = [name.id for name in params] + return {name: convert_arg(arg) for name, arg in zip(params, args)} def _ensure_is_on_device( @@ -131,13 +135,19 @@ def get_shape_args( def get_offset_args( - arrays: Mapping[str, dace.data.Array], params: Sequence[itir.Sym], args: Sequence[Any] + arrays: Mapping[str, dace.data.Array], + params: Union[Sequence[itir.Sym], dace.SDFG], + args: Sequence[Any], ) -> Mapping[str, int]: + if isinstance(params, dace.SDFG): + params = params.arg_names + else: + params = [name.id for name in params] return { str(sym): -drange.start - for param, arg in zip(params, args) + for name, arg in zip(params, args) if common.is_field(arg) - for sym, drange in zip(arrays[param.id].offset, get_sorted_dim_ranges(arg.domain)) + for sym, drange in zip(arrays[name].offset, get_sorted_dim_ranges(arg.domain)) } @@ -193,6 +203,45 @@ def get_cache_id( return m.hexdigest() +def get_sdfg_args(sdfg: dace.SDFG, *args, **kwargs) -> dict[str, Any]: + """Extracts the arguments needed to call the SDFG. + + This function can handle the same arguments that are passed to `run_dace_iterator()`. + + Args: + sdfg: The SDFG for which we want to get the arguments. + """ # noqa: D401 + offset_provider = kwargs["offset_provider"] + on_gpu = kwargs.get("on_gpu", False) + + neighbor_tables = filter_neighbor_tables(offset_provider) + device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU + + dace_args = get_args(sdfg, args) + dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} + dace_conn_args = get_connectivity_args(neighbor_tables, device) + dace_shapes = get_shape_args(sdfg.arrays, dace_field_args) + dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args) + dace_strides = get_stride_args(sdfg.arrays, dace_field_args) + dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args) + dace_offsets = get_offset_args(sdfg.arrays, sdfg, args) + all_args = { + **dace_args, + **dace_conn_args, + **dace_shapes, + **dace_conn_shapes, + **dace_strides, + **dace_conn_strides, + **dace_offsets, + } + expected_args = { + key: value + for key, value in all_args.items() + if key in sdfg.signature_arglist(with_types=False) + } + return expected_args + + def build_sdfg_from_itir( program: itir.FencilDefinition, *args, @@ -251,8 +300,6 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): offset_provider = kwargs["offset_provider"] arg_types = [type_translation.from_value(arg) for arg in args] - device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU - neighbor_tables = filter_neighbor_tables(offset_provider) cache_id = get_cache_id(program, arg_types, column_axis, offset_provider) if build_cache is not None and cache_id in build_cache: @@ -281,29 +328,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs): if build_cache is not None: build_cache[cache_id] = sdfg_program - dace_args = get_args(program.params, args) - dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)} - dace_conn_args = get_connectivity_args(neighbor_tables, device) - dace_shapes = get_shape_args(sdfg.arrays, dace_field_args) - dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args) - dace_strides = get_stride_args(sdfg.arrays, dace_field_args) - dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args) - dace_offsets = get_offset_args(sdfg.arrays, program.params, args) - - all_args = { - **dace_args, - **dace_conn_args, - **dace_shapes, - **dace_conn_shapes, - **dace_strides, - **dace_conn_strides, - **dace_offsets, - } - expected_args = { - key: value - for key, value in all_args.items() - if key in sdfg.signature_arglist(with_types=False) - } + expected_args = get_sdfg_args(sdfg, *args, **kwargs) with dace.config.temporary_config(): dace.config.Config.set("compiler", "allow_view_arguments", value=True) From 9f1d99b5a27bcd88dabd4f0678ff74616c2cf39f Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 18 Dec 2023 10:52:31 +0100 Subject: [PATCH 2/2] Removed some old backwards compatibility. --- .../runners/dace_iterator/__init__.py | 26 +++++++------------ 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index e1952e075d..6c70d0087e 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -13,7 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import hashlib import warnings -from typing import Any, Mapping, Optional, Sequence, Union +from typing import Any, Mapping, Optional, Sequence import dace import numpy as np @@ -94,12 +94,9 @@ def preprocess_program( return fencil_definition -def get_args(params: Union[Sequence[itir.Sym], dace.SDFG], args: Sequence[Any]) -> dict[str, Any]: - if isinstance(params, dace.SDFG): - params = params.arg_names - else: - params = [name.id for name in params] - return {name: convert_arg(arg) for name, arg in zip(params, args)} +def get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]: + sdfg_params: Sequence[str] = sdfg.arg_names + return {sdfg_param: convert_arg(arg) for sdfg_param, arg in zip(sdfg_params, args)} def _ensure_is_on_device( @@ -135,19 +132,16 @@ def get_shape_args( def get_offset_args( - arrays: Mapping[str, dace.data.Array], - params: Union[Sequence[itir.Sym], dace.SDFG], + sdfg: dace.SDFG, args: Sequence[Any], ) -> Mapping[str, int]: - if isinstance(params, dace.SDFG): - params = params.arg_names - else: - params = [name.id for name in params] + sdfg_arrays: Mapping[str, dace.data.Array] = sdfg.arrays + sdfg_params: Sequence[str] = sdfg.arg_names return { str(sym): -drange.start - for name, arg in zip(params, args) + for sdfg_param, arg in zip(sdfg_params, args) if common.is_field(arg) - for sym, drange in zip(arrays[name].offset, get_sorted_dim_ranges(arg.domain)) + for sym, drange in zip(sdfg_arrays[sdfg_param].offset, get_sorted_dim_ranges(arg.domain)) } @@ -224,7 +218,7 @@ def get_sdfg_args(sdfg: dace.SDFG, *args, **kwargs) -> dict[str, Any]: dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args) dace_strides = get_stride_args(sdfg.arrays, dace_field_args) dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args) - dace_offsets = get_offset_args(sdfg.arrays, sdfg, args) + dace_offsets = get_offset_args(sdfg, args) all_args = { **dace_args, **dace_conn_args,