Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[dace]: Computing SDFG call arguments #1398

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 49 additions & 30 deletions src/gt4py/next/program_processors/runners/dace_iterator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,9 @@ def preprocess_program(
return fencil_definition


def get_args(params: Sequence[itir.Sym], args: Sequence[Any]) -> dict[str, Any]:
return {name.id: convert_arg(arg) for name, arg in zip(params, args)}
def get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]:
sdfg_params: Sequence[str] = sdfg.arg_names
return {sdfg_param: convert_arg(arg) for sdfg_param, arg in zip(sdfg_params, args)}


def _ensure_is_on_device(
Expand Down Expand Up @@ -131,13 +132,16 @@ def get_shape_args(


def get_offset_args(
arrays: Mapping[str, dace.data.Array], params: Sequence[itir.Sym], args: Sequence[Any]
sdfg: dace.SDFG,
args: Sequence[Any],
) -> Mapping[str, int]:
sdfg_arrays: Mapping[str, dace.data.Array] = sdfg.arrays
sdfg_params: Sequence[str] = sdfg.arg_names
return {
str(sym): -drange.start
for param, arg in zip(params, args)
for sdfg_param, arg in zip(sdfg_params, args)
if common.is_field(arg)
for sym, drange in zip(arrays[param.id].offset, get_sorted_dim_ranges(arg.domain))
for sym, drange in zip(sdfg_arrays[sdfg_param].offset, get_sorted_dim_ranges(arg.domain))
}


Expand Down Expand Up @@ -193,6 +197,45 @@ def get_cache_id(
return m.hexdigest()


def get_sdfg_args(sdfg: dace.SDFG, *args, **kwargs) -> dict[str, Any]:
"""Extracts the arguments needed to call the SDFG.

This function can handle the same arguments that are passed to `run_dace_iterator()`.

Args:
sdfg: The SDFG for which we want to get the arguments.
""" # noqa: D401
offset_provider = kwargs["offset_provider"]
on_gpu = kwargs.get("on_gpu", False)

neighbor_tables = filter_neighbor_tables(offset_provider)
device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU

dace_args = get_args(sdfg, args)
dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)}
dace_conn_args = get_connectivity_args(neighbor_tables, device)
dace_shapes = get_shape_args(sdfg.arrays, dace_field_args)
dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args)
dace_strides = get_stride_args(sdfg.arrays, dace_field_args)
dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args)
dace_offsets = get_offset_args(sdfg, args)
all_args = {
**dace_args,
**dace_conn_args,
**dace_shapes,
**dace_conn_shapes,
**dace_strides,
**dace_conn_strides,
**dace_offsets,
}
expected_args = {
key: value
for key, value in all_args.items()
if key in sdfg.signature_arglist(with_types=False)
}
return expected_args


def build_sdfg_from_itir(
program: itir.FencilDefinition,
*args,
Expand Down Expand Up @@ -251,8 +294,6 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
offset_provider = kwargs["offset_provider"]

arg_types = [type_translation.from_value(arg) for arg in args]
device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU
neighbor_tables = filter_neighbor_tables(offset_provider)

cache_id = get_cache_id(program, arg_types, column_axis, offset_provider)
if build_cache is not None and cache_id in build_cache:
Expand Down Expand Up @@ -281,29 +322,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
if build_cache is not None:
build_cache[cache_id] = sdfg_program

dace_args = get_args(program.params, args)
dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)}
dace_conn_args = get_connectivity_args(neighbor_tables, device)
dace_shapes = get_shape_args(sdfg.arrays, dace_field_args)
dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args)
dace_strides = get_stride_args(sdfg.arrays, dace_field_args)
dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args)
dace_offsets = get_offset_args(sdfg.arrays, program.params, args)

all_args = {
**dace_args,
**dace_conn_args,
**dace_shapes,
**dace_conn_shapes,
**dace_strides,
**dace_conn_strides,
**dace_offsets,
}
expected_args = {
key: value
for key, value in all_args.items()
if key in sdfg.signature_arglist(with_types=False)
}
expected_args = get_sdfg_args(sdfg, *args, **kwargs)

with dace.config.temporary_config():
dace.config.Config.set("compiler", "allow_view_arguments", value=True)
Expand Down