Skip to content

Commit

Permalink
fix[dace]: Fixed SDFG args (#1400)
Browse files Browse the repository at this point in the history
Modified how the SDFG arguments are computed.

It was noticed that some transformations, especially the `SDFG.apply_gpu_transformation()`, to the SDFG, added new arguments to the SDFG.
But, since a lot of functions build on the `SDFG.arg_names` member and this member was populated before the transformation, an error occurred.
Thus it was changed such that `SDFG.arg_names` was only populated with the arguments also known to the Fencil.
  • Loading branch information
philip-paul-mueller authored Dec 19, 2023
1 parent 15a7bd6 commit af33e21
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def get_sdfg_args(sdfg: dace.SDFG, *args, **kwargs) -> dict[str, Any]:
neighbor_tables = filter_neighbor_tables(offset_provider)
device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU

sdfg_sig = sdfg.signature_arglist(with_types=False)
dace_args = get_args(sdfg, args)
dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)}
dace_conn_args = get_connectivity_args(neighbor_tables, device)
Expand All @@ -224,11 +225,8 @@ def get_sdfg_args(sdfg: dace.SDFG, *args, **kwargs) -> dict[str, Any]:
**dace_conn_strides,
**dace_offsets,
}
expected_args = {
key: value
for key, value in all_args.items()
if key in sdfg.signature_arglist(with_types=False)
}
expected_args = {key: all_args[key] for key in sdfg_sig}

return expected_args


Expand Down Expand Up @@ -258,21 +256,22 @@ def build_sdfg_from_itir(
# TODO(edopao): As temporary fix until temporaries are supported in the DaCe Backend force
# `lift_more` to `FORCE_INLINE` mode.
lift_mode = itir_transforms.LiftMode.FORCE_INLINE

arg_types = [type_translation.from_value(arg) for arg in args]
device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU

# visit ITIR and generate SDFG
program = preprocess_program(program, offset_provider, lift_mode)
# TODO: According to Lex one should build the SDFG first in a general mannor.
# Generalisation to a particular device should happen only at the end.
sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, on_gpu)
sdfg = sdfg_genenerator.visit(program)
sdfg.simplify()

# run DaCe auto-optimization heuristics
if auto_optimize:
# TODO Investigate how symbol definitions improve autoopt transformations,
# in which case the cache table should take the symbols map into account.
# TODO: Investigate how symbol definitions improve autoopt transformations,
# in which case the cache table should take the symbols map into account.
symbols: dict[str, int] = {}
device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU
sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu)

return sdfg
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,9 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet)

# Create the call signature for the SDFG.
# All arguments required by the SDFG, regardless if explicit and implicit, are added
# as positional arguments. In the front are all arguments to the Fencil, in that
# order, they are followed by the arguments created by the translation process,
arg_list = [str(a) for a in node.params]
sig_list = program_sdfg.signature_arglist(with_types=False)
implicit_args = set(sig_list) - set(arg_list)
call_params = arg_list + [ia for ia in sig_list if ia in implicit_args]
program_sdfg.arg_names = call_params
# Only the arguments requiered by the Fencil, i.e. `node.params` are added as poitional arguments.
# The implicit arguments, such as the offset providers or the arguments created by the translation process, must be passed as keywords only arguments.
program_sdfg.arg_names = [str(a) for a in node.params]

program_sdfg.validate()
return program_sdfg
Expand Down

0 comments on commit af33e21

Please sign in to comment.