Skip to content

Commit

Permalink
fix[next]: Fix usage of DaCe fast-call to SDFG (#1656)
Browse files Browse the repository at this point in the history
This PR addresses some flaky test failures observed in GT4Py CI. The
root cause was that the dace backend did not check the connectivity
arrays, which are passed as keyword-arguments to the SDFG. It did only
check the positional arguments. The connectivity arrays do not have to
be allocated on the device memory: for gpu execution, the backend
ensures that the connectivity arrays are copied to device memory just
before passing them to the SDFG call. Previous implementation worked
sometimes, when by chance cupy was reusing the same array on gpu memory,
hence the flaky behavior of the tests.
New test is added for the connectivity case. The previous test case is
cleaned up and improved, by invalidating all scalar positional arguments
at each SDFG call: this allows to test that they are overridden before
fast_call.

Additionally, this PR reduces the overhead of regular SDFG call:
previous implementation was copying all the connectivity arrays to gpu
memory, with this PR we only allocate cupy arrays for the connectivities
used in the SDFG.
  • Loading branch information
edopao authored Sep 27, 2024
1 parent 48b13cc commit 6d011ea
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 133 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,20 +73,9 @@ def _ensure_is_on_device(
return connectivity_arg


def _get_connectivity_args(
neighbor_tables: Mapping[str, gtx_common.NeighborTable], device: dace.dtypes.DeviceType
) -> dict[str, Any]:
return {
dace_util.connectivity_identifier(offset): _ensure_is_on_device(
offset_provider.table, device
)
for offset, offset_provider in neighbor_tables.items()
}


def _get_shape_args(
arrays: Mapping[str, dace.data.Array], args: Mapping[str, Any]
) -> Mapping[str, int]:
arrays: Mapping[str, dace.data.Array], args: Mapping[str, np.typing.NDArray]
) -> dict[str, int]:
shape_args: dict[str, int] = {}
for name, value in args.items():
for sym, size in zip(arrays[name].shape, value.shape, strict=True):
Expand All @@ -101,8 +90,8 @@ def _get_shape_args(


def _get_stride_args(
arrays: Mapping[str, dace.data.Array], args: Mapping[str, Any]
) -> Mapping[str, int]:
arrays: Mapping[str, dace.data.Array], args: Mapping[str, np.typing.NDArray]
) -> dict[str, int]:
stride_args = {}
for name, value in args.items():
for sym, stride_size in zip(arrays[name].strides, value.strides, strict=True):
Expand All @@ -121,6 +110,27 @@ def _get_stride_args(
return stride_args


def get_sdfg_conn_args(
sdfg: dace.SDFG,
offset_provider: dict[str, Any],
on_gpu: bool,
) -> dict[str, np.typing.NDArray]:
"""
Extracts the connectivity tables that are used in the sdfg and ensures
that the memory buffers are allocated for the target device.
"""
device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU

connectivity_args = {}
for offset, connectivity in dace_util.filter_connectivities(offset_provider).items():
assert isinstance(connectivity, gtx_common.NeighborTable)
param = dace_util.connectivity_identifier(offset)
if param in sdfg.arrays:
connectivity_args[param] = _ensure_is_on_device(connectivity.table, device)

return connectivity_args


def get_sdfg_args(
sdfg: dace.SDFG,
*args: Any,
Expand All @@ -138,17 +148,9 @@ def get_sdfg_args(
"""
offset_provider = kwargs["offset_provider"]

neighbor_tables: dict[str, gtx_common.NeighborTable] = {}
for offset, connectivity in dace_util.filter_connectivities(offset_provider).items():
assert isinstance(connectivity, gtx_common.NeighborTable)
neighbor_tables[offset] = connectivity
device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU

dace_args = _get_args(sdfg, args, use_field_canonical_representation)
dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)}
dace_conn_args = _get_connectivity_args(neighbor_tables, device)
# keep only connectivity tables that are used in the sdfg
dace_conn_args = {n: v for n, v in dace_conn_args.items() if n in sdfg.arrays}
dace_conn_args = get_sdfg_conn_args(sdfg, offset_provider, on_gpu)
dace_shapes = _get_shape_args(sdfg.arrays, dace_field_args)
dace_conn_shapes = _get_shape_args(sdfg.arrays, dace_conn_args)
dace_strides = _get_stride_args(sdfg.arrays, dace_field_args)
Expand Down
11 changes: 10 additions & 1 deletion src/gt4py/next/program_processors/runners/dace_common/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

from __future__ import annotations

from typing import Any, Mapping, Optional, Sequence
import re
from typing import Any, Final, Mapping, Optional, Sequence

import dace

Expand All @@ -17,6 +18,10 @@
from gt4py.next.type_system import type_specifications as ts


# regex to match the symbols for field shape and strides
FIELD_SYMBOL_RE: Final[re.Pattern] = re.compile("__.+_(size|stride)_\d+")


def as_scalar_type(typestr: str) -> ts.ScalarType:
"""Obtain GT4Py scalar type from generic numpy string representation."""
try:
Expand All @@ -38,6 +43,10 @@ def field_stride_symbol_name(field_name: str, axis: int) -> str:
return f"__{field_name}_stride_{axis}"


def is_field_symbol(name: str) -> bool:
return FIELD_SYMBOL_RE.match(name) is not None


def debug_info(
node: gtir.Node, *, default: Optional[dace.dtypes.DebugInfo] = None
) -> Optional[dace.dtypes.DebugInfo]:
Expand Down
76 changes: 37 additions & 39 deletions src/gt4py/next/program_processors/runners/dace_common/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import ctypes
import dataclasses
from typing import Any, Optional
from typing import Any

import dace
import factory
Expand All @@ -20,26 +20,23 @@
from gt4py.next import common, config
from gt4py.next.otf import arguments, languages, stages, step_types, workflow
from gt4py.next.otf.compilation import cache
from gt4py.next.program_processors.runners.dace_common import dace_backend
from gt4py.next.program_processors.runners.dace_common import dace_backend, utility as dace_utils


class CompiledDaceProgram(stages.CompiledProgram):
sdfg_program: dace.CompiledSDFG
# Map SDFG argument to its position in program ABI; scalar arguments that are not used in the SDFG will not be present.
sdfg_arg_position: list[Optional[int]]

def __init__(self, program: dace.CompiledSDFG):
# extract position of arguments in program ABI
sdfg_arglist = program.sdfg.signature_arglist(with_types=False)
sdfg_arg_pos_mapping = {param: pos for pos, param in enumerate(sdfg_arglist)}
sdfg_used_symbols = program.sdfg.used_symbols(all_symbols=False)
# Sorted list of SDFG arguments as they appear in program ABI and corresponding data type;
# scalar arguments that are not used in the SDFG will not be present.
sdfg_arglist: list[tuple[str, dace.dtypes.Data]]

def __init__(self, program: dace.CompiledSDFG):
self.sdfg_program = program
self.sdfg_arg_position = [
sdfg_arg_pos_mapping[param]
if param in program.sdfg.arrays or param in sdfg_used_symbols
else None
for param in program.sdfg.arg_names
# `dace.CompiledSDFG.arglist()` returns an ordered dictionary that maps the argument
# name to its data type, in the same order as arguments appear in the program ABI.
# This is also the same order of arguments in `dace.CompiledSDFG._lastargs[0]`.
self.sdfg_arglist = [
(arg_name, arg_type) for arg_name, arg_type in program.sdfg.arglist().items()
]

def __call__(self, *args: Any, **kwargs: Any) -> None:
Expand Down Expand Up @@ -94,13 +91,6 @@ class Meta:
model = DaCeCompiler


def _get_ctype_value(arg: Any, dtype: dace.dtypes.dataclass) -> Any:
if not isinstance(arg, (ctypes._SimpleCData, ctypes._Pointer)):
actype = dtype.as_ctypes()
return actype(arg)
return arg


def convert_args(
inp: CompiledDaceProgram,
device: core_defs.DeviceType = core_defs.DeviceType.CPU,
Expand All @@ -119,28 +109,36 @@ def decorated_program(
args = (*args, out)
if len(sdfg.arg_names) > len(args):
args = (*args, *arguments.iter_size_args(args))

if sdfg_program._lastargs:
# The scalar arguments should be replaced with the actual value; for field arguments,
# the data pointer should remain the same otherwise fast-call cannot be used and
# the args list needs to be reconstructed.
kwargs = dict(zip(sdfg.arg_names, args, strict=True))
kwargs.update(dace_backend.get_sdfg_conn_args(sdfg, offset_provider, on_gpu))

use_fast_call = True
for arg, param, pos in zip(args, sdfg.arg_names, inp.sdfg_arg_position, strict=True):
if isinstance(arg, common.Field):
desc = sdfg.arrays[param]
assert isinstance(desc, dace.data.Array)
assert isinstance(sdfg_program._lastargs[0][pos], ctypes.c_void_p)
if sdfg_program._lastargs[0][pos].value != get_array_interface_ptr(
arg.ndarray, desc.storage
):
last_call_args = sdfg_program._lastargs[0]
# The scalar arguments should be overridden with the new value; for field arguments,
# the data pointer should remain the same otherwise fast_call cannot be used and
# the arguments list has to be reconstructed.
for i, (arg_name, arg_type) in enumerate(inp.sdfg_arglist):
if isinstance(arg_type, dace.data.Array):
assert arg_name in kwargs, f"Argument '{arg_name}' not found."
data_ptr = get_array_interface_ptr(kwargs[arg_name], arg_type.storage)
assert isinstance(last_call_args[i], ctypes.c_void_p)
if last_call_args[i].value != data_ptr:
use_fast_call = False
break
elif param in sdfg.arrays:
desc = sdfg.arrays[param]
assert isinstance(desc, dace.data.Scalar)
sdfg_program._lastargs[0][pos] = _get_ctype_value(arg, desc.dtype)
elif pos:
sym_dtype = sdfg.symbols[param]
sdfg_program._lastargs[0][pos] = _get_ctype_value(arg, sym_dtype)
else:
assert isinstance(arg_type, dace.data.Scalar)
assert isinstance(last_call_args[i], ctypes._SimpleCData)
if arg_name in kwargs:
# override the scalar value used in previous program call
actype = arg_type.dtype.as_ctypes()
last_call_args[i] = actype(kwargs[arg_name])
else:
# shape and strides of arrays are supposed not to change, and can therefore be omitted
assert dace_utils.is_field_symbol(
arg_name
), f"Argument '{arg_name}' not found."

if use_fast_call:
return sdfg_program.fast_call(*sdfg_program._lastargs)
Expand Down
Loading

0 comments on commit 6d011ea

Please sign in to comment.